From 7942e605f0aa05e1187f8fb6ef6c9c5e59d68f27 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Thu, 26 Mar 2026 11:02:49 +0800 Subject: [PATCH 1/4] issue/1102: qy gptq_gemm --- include/infiniop/ops/gptq_gemm.h | 38 + src/infiniop/ops/gptq_gemm/cuda/compat.cuh | 70 + src/infiniop/ops/gptq_gemm/cuda/kernel.cuh | 202 ++ .../ops/gptq_gemm/cuda/matrix_view.cuh | 295 +++ .../ops/gptq_gemm/cuda/my_operator.cpp | 14 + src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh | 76 + src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh | 149 ++ src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh | 126 ++ src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh | 30 + src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh | 62 + src/infiniop/ops/gptq_gemm/gptq_gemm.h | 52 + src/infiniop/ops/gptq_gemm/info.h | 77 + .../ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu | 1749 +++++++++++++++++ .../ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh | 7 + src/infiniop/ops/gptq_gemm/operator.cc | 97 + test/infiniop/gptq_gemm.py | 305 +++ test/infiniop/libinfiniop/op_register.py | 42 + xmake/qy.lua | 36 + 18 files changed, 3427 insertions(+) create mode 100644 include/infiniop/ops/gptq_gemm.h create mode 100644 src/infiniop/ops/gptq_gemm/cuda/compat.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/kernel.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp create mode 100644 src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh create mode 100644 src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh create mode 100644 src/infiniop/ops/gptq_gemm/gptq_gemm.h create mode 100644 src/infiniop/ops/gptq_gemm/info.h create mode 100644 src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu create mode 100644 src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh create mode 100644 src/infiniop/ops/gptq_gemm/operator.cc 
create mode 100644 test/infiniop/gptq_gemm.py diff --git a/include/infiniop/ops/gptq_gemm.h b/include/infiniop/ops/gptq_gemm.h new file mode 100644 index 000000000..d1f25a799 --- /dev/null +++ b/include/infiniop/ops/gptq_gemm.h @@ -0,0 +1,38 @@ +#ifndef __INFINIOP_GPTQ_GEMM_API_H__ +#define __INFINIOP_GPTQ_GEMM_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopGptqGemmDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateGptqGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit); + +__INFINI_C __export infiniStatus_t infiniopGetGptqGemmWorkspaceSize( + infiniopGptqGemmDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopGptqGemm( + infiniopGptqGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + const void *b_scale, + const void *b_zero, + const void *b_g_idx, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyGptqGemmDescriptor( + infiniopGptqGemmDescriptor_t desc); +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/compat.cuh b/src/infiniop/ops/gptq_gemm/cuda/compat.cuh new file mode 100644 index 000000000..beed57c7d --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/compat.cuh @@ -0,0 +1,70 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _compat_cuh +#define _compat_cuh + +// 1. 包含CUDA核心运行时头文件(必加,提供CUDA基础类型定义) +#include + +// 2. 
包含CUDA半精度浮点类型定义头文件(核心,定义half/half2/__half_raw) +#include + +namespace vllm { +namespace gptq { +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} + + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif + + #endif +#endif + +} // namespace gptq +} // namespace vllm +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh new file mode 100644 index 000000000..73724f142 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh @@ -0,0 +1,202 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/qwopqwop200/GPTQ-for-LLaMa +*/ + +#include +#include + +#include +#include +#include 
+#include +#include + +#include "compat.cuh" +#include "matrix_view.cuh" +#include "qdq_2.cuh" +#include "qdq_3.cuh" +#include "qdq_4.cuh" +#include "qdq_8.cuh" + +namespace vllm { +namespace gptq { + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_ROWS_8BIT 24 +#define MAX_ALT_GEMM_ROWS 8 +#define THREADS_X 32 +#define THREADS_Y 32 +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#if defined(USE_ROCM) +#include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half *alpha, + const half *AP, int lda, const half *BP, int ldb, const half *beta, + half *CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. 
+#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half *a_ptr, + const half2 g_result) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half *a_ptr) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half *a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half *a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half *a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half *a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2 *a2_ptr = (const half2 
*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half *a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half *a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half *a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half *a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) { + 
result = __hfma2(dq[i], *a2_ptr++, result); + } + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half *a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2 *a2_ptr = (const half2 *)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) { + result = __hfma2(dq[i], *a2_ptr++, result); + } + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +} // namespace gptq +} // namespace vllm diff --git a/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh b/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh new file mode 100644 index 000000000..2b6719fbd --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh @@ -0,0 +1,295 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama +*/ + +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = 
ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; 
+ +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } +}; + +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + 
int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } +}; + +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; + } + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { 
+ d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); + } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } +}; + +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } +}; + +} // namespace gptq +} // namespace vllm +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp b/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp new file mode 100644 index 000000000..e2e3c0499 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp @@ -0,0 +1,14 @@ +#include +#include + +// 声明 CUDA 函数 +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int64_t bit); + + +// 绑定到 Python +PYBIND11_MODULE(vllm_gptq, m) { + 
m.def("gptq_gemm", &gptq_gemm, "GPTQ GEMM (CUDA)"); +} diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh new file mode 100644 index 000000000..ca0f81060 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh @@ -0,0 +1,76 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 
0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +} // namespace gptq +} // namespace vllm + +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh new file mode 100644 index 000000000..0d5c2adf5 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh @@ -0,0 +1,149 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; + qa >>= 6; + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 
16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; + qb >>= 6; + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; + qc >>= 6; + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 
0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +} // namespace gptq +} // namespace vllm + +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh new file mode 100644 index 000000000..7f65d2d28 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh @@ -0,0 +1,126 @@ +/* +Copied from 
https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = 
__float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} +} // namespace gptq +} // namespace vllm + +#endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh 
b/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh
new file mode 100644
index 000000000..feb5d2204
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh
@@ -0,0 +1,30 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_8_cuh
#define _qdq_8_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

// 8-bit values need no re-packing; kept so all bit widths share one
// shuffle_* interface.
__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {}

// Dequantizes 8 packed 8-bit values (4 from q_0, 4 from q_1) into 4 half2
// pairs, subtracting the zero point. Scaling is applied by the caller.
__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0,
                                               const uint32_t q_1,
                                               half2 (&dq)[4], int stride,
                                               const uint32_t zero) {
  half dqh[8];
  for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero);
  for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);

  for (int i = 0; i < 4; i++)
    dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

}  // namespace gptq
}  // namespace vllm

#endif
diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh
new file mode 100644
index 000000000..b65238b3b
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh
@@ -0,0 +1,62 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

// 1. CUDA core runtime header (provides the basic CUDA type definitions).
//    FIX(review): the include target was lost in extraction (bare
//    `#include`); restored from the adjacent comment — verify against the
//    original file.
#include <cuda_runtime.h>

// 2.
// (continued) CUDA half-precision header: defines half / half2 / __half_raw.
//    FIX(review): the include target was lost in extraction (bare
//    `#include`); restored from the adjacent comment — verify against the
//    original file.
#include <cuda_fp16.h>

namespace vllm {
namespace gptq {

// Type-punning helpers: reinterpret raw integer bit patterns as packed
// half/half2 without UB-prone pointer casts.
union half2_uint32 {
  uint32_t as_uint32;
  half2 as_half2;
  __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
  __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16 {
  uint16_t as_uint16;
  half as_half;
  __device__ half_uint16(uint16_t val) : as_uint16(val) {}
  __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

// Decodes a quantized scale exponent: scale = (qs + 1)^2 * max_scale.
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) {
  int qs_i = qs + 1;
  half qs_h = __int2half_rn(qs_i * qs_i);
  qs_h = __hmul(qs_h, max_scale);
  return qs_h;
}

// Full dequant: (q - qzero) * scale.
__forceinline__ __device__ half dq(const int q, const int qzero,
                                   const half scale) {
  return __hmul(__int2half_rn(q - qzero), scale);
}

// "No scale" dequant: (q - qzero) only.
__forceinline__ __device__ half dq_ns(const int q, const int qzero) {
  // return __hsub(__int2half_rn(q), __int2half_rn(qzero));
  return __int2half_rn(q - qzero);
}

// Extracts a bit field of `q` at `shift` selected by `mask`.
__forceinline__ __device__ int exb(const uint32_t q, const int shift,
                                   const int mask) {
  return (int)((q >> shift) & mask);
}

// Extracts a bit field straddling two 32-bit words via a funnel shift.
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0,
                                   const int shift, const int mask) {
  return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

}  // namespace gptq
}  // namespace vllm
#endif
diff --git a/src/infiniop/ops/gptq_gemm/gptq_gemm.h b/src/infiniop/ops/gptq_gemm/gptq_gemm.h
new file mode 100644
index 000000000..9498b5b6d
--- /dev/null
+++ b/src/infiniop/ops/gptq_gemm/gptq_gemm.h
@@ -0,0 +1,52 @@
#ifndef GPTQ_GEMM_H
#define GPTQ_GEMM_H

#include "../../operator.h"
#include "info.h"

// Declares the per-backend Descriptor class for the GPTQ GEMM operator
// inside namespace op::gptq_gemm::NAMESPACE (standard infiniop pattern).
#define DESCRIPTOR(NAMESPACE)                                          \
                                                                       \
    namespace op::gptq_gemm::NAMESPACE {                               \
    class Descriptor final : public InfiniopDescriptor {               \
        struct Opaque;                                                 \
        Opaque *_opaque;                                               \
        GptqGemmInfo _info;                                            \
        size_t _workspace_size;                                        \
                                                                       \
        Descriptor(                                                    \
            Opaque *opaque,                                            \
            GptqGemmInfo info,                                         \
            size_t workspace_size,                                     \
infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, \ + infiniopTensorDescriptor_t b_scales_desc, \ + infiniopTensorDescriptor_t b_zeros_desc, \ + infiniopTensorDescriptor_t b_g_idx_desc, \ + bool use_exllama, \ + int quant_bit); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, \ + const void *a, const void *b, const void *b_scale, const void *b_zero, const void *b_g_idx, \ + void *stream) const; \ + }; \ + } + +#endif // GPTQ_GEMM_H diff --git a/src/infiniop/ops/gptq_gemm/info.h b/src/infiniop/ops/gptq_gemm/info.h new file mode 100644 index 000000000..66805e6ad --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/info.h @@ -0,0 +1,77 @@ +#ifndef __GPTQ_GEMM_INFO_H__ +#define __GPTQ_GEMM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include + +namespace op::gptq_gemm { + +class GptqGemmInfo { + GptqGemmInfo() = default; + +public: + // --- Data Type --- + infiniDtype_t dtype; + + // --- Shape Dimensions --- + size_t M, K, N, b_size_0; + int block_size, num_groups; + bool use_exllama; + int64_t quant_bit; + + static utils::Result createGptqGemmInfo( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit) { + + auto dtype = out_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16); + if (b_scales_desc->dtype() != dtype) { + return 
INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (b_zeros_desc->dtype() != INFINI_DTYPE_I32 || b_g_idx_desc->dtype() != INFINI_DTYPE_I32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + size_t M = out_desc->shape()[0]; + size_t N = out_desc->shape()[1]; + size_t K = a_desc->shape()[1]; + size_t b_size_0 = b_desc->shape()[0]; + int block_size = 128; + int num_groups = K / block_size; + if (quant_bit != 4) { + throw std::runtime_error( + "quant_bit must be 4, but got " + std::to_string(quant_bit)); + } + + auto ndim = out_desc->ndim(); + CHECK_OR_RETURN(ndim == 2 + && a_desc->ndim() == ndim + && b_desc->ndim() == ndim + && b_scales_desc->ndim() == ndim + && b_zeros_desc->ndim() == ndim, + INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(b_scales_desc->shape()[1] == N + && b_scales_desc->shape()[0] == num_groups + && b_zeros_desc->shape()[1] == N + && b_zeros_desc->shape()[0] == num_groups, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result(GptqGemmInfo{ + dtype, + M, K, N, b_size_0, + block_size, num_groups, + use_exllama, static_cast(quant_bit)}); + } +}; + +} // namespace op::gptq_gemm + +#endif // __GPTQ_GEMM_INFO_H__ diff --git a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu new file mode 100644 index 000000000..ba57a93c2 --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu @@ -0,0 +1,1749 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "gptq_gemm_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../../../reduce/cuda/reduce.cuh" +#include + +#include "../cuda/kernel.cuh" +namespace vllm { +namespace gptq { + +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half *, const uint32_t *, + const uint32_t *, const half *, + half *, const int, const int, + const int, const int, + const int *); + +template +INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_4bit_kernel( + const half *__restrict__ a, const uint32_t *__restrict__ 
b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int *__restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) { + a0 = a_ptr[b_q_perm[offset_k + t]]; + } else { + a0 = a_ptr[offset_k + t]; + } + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) { + return; + } + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) { + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; + } + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + 
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4 *b_ptr4 = (int4 *)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], + block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], + block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], + block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], + block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), + __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), + __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, 
result23); + } +} + +template +INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_2bit_kernel( + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int *__restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) { + a0 = a_ptr[b_q_perm[offset_k + t]]; + } else { + a0 = a_ptr[offset_k + t]; + } + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) { + return; + } + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) { + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; + } + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + 
b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4 *b_ptr4 = (int4 *)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; + } + + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_3bit_kernel( + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int *__restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half 
b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) { + a0 = a_ptr[b_q_perm[offset_k + t]]; + } else { + a0 = a_ptr[offset_k + t]; + } + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) { + return; + } + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) { + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; + } + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + 
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } +} + +template +INFINIOP_CUDA_KERNEL gemm_half_q_half_gptq_8bit_kernel( + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int *__restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto t = threadIdx.x; + + // Block + auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + auto offset_m = blockIdx.y * m_count; + auto offset_k = blockIdx.z * BLOCK_KN_SIZE; + + [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + [[maybe_unused]] int end_m = min(offset_m + 
m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) { + a0 = a_ptr[b_q_perm[offset_k + t]]; + } else { + a0 = a_ptr[offset_k + t]; + } + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) { + return; + } + + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) { + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; + } + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = dot22_8_h(dq[0], 
a_ptr + m * a_stride, block_c[m][0], scales[0]);
        block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
        block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
        block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
      }
      a_ptr += 8;
    }
    k += 32;
  }

  // Accumulate this k-slice's partial results into global output; blockIdx.z
  // slices race on the same cells, hence atomicAdd on half2 pairs.
  for (int m = 0; m < m_count; m++) {
    half2 *out = (half2 *)c_.item_ptr(offset_m + m, n);
    half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
    half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
    atomicAdd(out, result01);
    atomicAdd(out + 1, result23);
  }
}

// Selects the kernel instantiation for the given rows-per-block (m_count)
// and quantization bit width. Returns NULL for unsupported combinations.
// NOTE(review): `first_block` is accepted but not used in the selection; the
// template arguments below (<true, M_COUNT>) were lost in extraction and
// restored per the upstream exllamav2/vLLM layout — verify.
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(
    bool first_block, const int m_count, const int bit) {
#define SELECT_KERNEL(M_COUNT)                                          \
  if (m_count == M_COUNT) {                                             \
    if (bit == 2)                                                       \
      return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>;          \
    if (bit == 3)                                                       \
      return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>;          \
    if (bit == 4)                                                       \
      return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>;          \
    if (bit == 8)                                                       \
      return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>;          \
  }
#if BLOCK_M_SIZE_MAX >= 1
  SELECT_KERNEL(1);
#endif
#if BLOCK_M_SIZE_MAX >= 2
  SELECT_KERNEL(2);
#endif
#if BLOCK_M_SIZE_MAX >= 3
  SELECT_KERNEL(3);
#endif
#if BLOCK_M_SIZE_MAX >= 4
  SELECT_KERNEL(4);
#endif
#if BLOCK_M_SIZE_MAX >= 5
  SELECT_KERNEL(5);
#endif
#if BLOCK_M_SIZE_MAX >= 6
  SELECT_KERNEL(6);
#endif
#if BLOCK_M_SIZE_MAX >= 7
  SELECT_KERNEL(7);
#endif
#if BLOCK_M_SIZE_MAX >= 8
  SELECT_KERNEL(8);
#endif
  return NULL;
}

// Launches one tiled quantized-GEMM pass over an m_count-row slab of `a`.
//
// FIX(review): the original fetched its stream via
// at::cuda::getCurrentCUDAStream(), a PyTorch/ATen dependency that does not
// belong in this library and that silently ignores the stream the public
// infiniopGptqGemm() API passes in. A trailing `stream` parameter
// (defaulting to the legacy default stream) keeps existing call sites
// compiling while letting callers thread the real stream through.
void gemm_half_q_half_cuda_part(const half *a, const uint32_t *b_q_weight,
                                const uint32_t *b_gptq_qzeros,
                                const half *b_gptq_scales, const int *b_q_perm,
                                half *c, int size_m, int size_n, int size_k,
                                int m_count, int groups, int bit,
                                cudaStream_t stream = nullptr) {
  dim3 blockDim, gridDim;
  blockDim.x = BLOCK_KN_SIZE;
  blockDim.y = 1;
  blockDim.z = 1;
  gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
  gridDim.y = DIVIDE(size_m, m_count);
  gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

  fp_gemm_half_q_half_gptq_kernel kernel =
      pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);

  // NOTE(review): launch config restored (lost in extraction); no shared
  // memory is requested dynamically.
  kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros,
                                           b_gptq_scales, c, size_m, size_n,
                                           size_k, groups, b_q_perm);
}

INFINIOP_CUDA_KERNEL reconstruct_exllama_8bit_kernel(
    const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm,
    const uint32_t *__restrict__ b_gptq_qzeros,
    const half *__restrict__ b_gptq_scales, const int size_k, const int size_n,
    const int groups, half *__restrict__ b) {
  MatrixView_half_rw b_(b, size_k, size_n);
  MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  // Preload remapping table
  __shared__ int perm[BLOCK_KN_SIZE];
  auto t = threadIdx.x;

  if (b_q_perm) {
    if (offset_k + t < size_k) {
      perm[t] = b_q_perm[offset_k + t];
    }
  }

  // Column
  int n = offset_n + t * 4;
  if (n >= size_n) {
    return;
  }

  // Find initial group
  int groupsize = size_k / groups;
  int group = offset_k / groupsize;
  int nextgroup = offset_k + groupsize;

  // b offset
  int qk = offset_k / (32 / 8);

  const uint32_t *b_ptr = b_q_weight + qk * size_n + n;

  // Initial zeros/scale
  int zeros[4];
  half2 scales[4];
  b_gptq_qzeros_.item4(zeros, group, n);
  b_gptq_scales_.item4_h2(scales, group, n);

  __syncthreads();

  int k = offset_k;
  int lk = 0;

  while (k < end_k) {
    if (k == nextgroup) {
      group++;
      nextgroup += groupsize;
      b_gptq_qzeros_.item4(zeros, group, n);
      b_gptq_scales_.item4_h2(scales, group, n);
    }

    for (int p = 0; p < 4; p++) {
      int4 load_int4[2];
      load_int4[0] = *((int4 *)b_ptr);
      b_ptr += size_n;
      load_int4[1] = *((int4 *)b_ptr);
      b_ptr += size_n;

      half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_4bit_kernel( + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= 
size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4 *b_ptr4 = (int4 *)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + 
__high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_3bit_kernel( + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / 32 * 3; + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; 
p++) { + int4 load_int4[3]; + load_int4[0] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4 *)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4 *)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_exllama_2bit_kernel( + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; + auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, 
size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + auto t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) { + perm[t] = b_q_perm[offset_k + t]; + } + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) { + return; + } + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 2); + + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 2; p++) { + const int4 *b_ptr4 = (int4 *)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) { + dq[v][j] = __hmul2(scales[v], dq[v][j]); + } + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + 
__high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +void reconstruct_exllama(const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, const int *b_q_perm, + half *out, int height, int width, int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, + out); +} + +INFINIOP_CUDA_KERNEL gemm_half_q_half_alt_4bit_kernel( + const half2 *__restrict__ vec, const uint32_t *__restrict__ mat, + half *__restrict__ mul, const half *__restrict__ scales, + const uint32_t *__restrict__ zeros, const int *__restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + auto b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + auto h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + auto val = threadIdx.x / 8; + auto off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE 
/ 8) { + deq2[val][off] = __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) { + mul[(b + m) * width + w] = __int2half_rn(0); + } + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + // res2 = {}; + res2 = __halves2half2(__float2half(0.0f), __float2half(0.0f)); +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + 
atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +INFINIOP_CUDA_KERNEL gemm_half_q_half_alt_8bit_kernel( + const half2 *__restrict__ vec, const uint32_t *__restrict__ mat, + half *__restrict__ mul, const half *__restrict__ scales, + const uint32_t *__restrict__ zeros, const int *__restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + auto b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + auto h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + threadIdx.x]; + } + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) { + mul[(b + m) * width + w] = __int2half_rn(0); + } + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + // res2 = {}; + res2 = __halves2half2(__float2half(0.0f), 
__float2half(0.0f)); +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), + __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), + __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +void gemm_half_q_half_alt(const half *a, const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, const int *b_g_idx, + half *c, int size_m, int size_n, int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + (const half2 *)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); +} + +template +INFINIOP_CUDA_KERNEL reconstruct_gptq_kernel(const uint32_t *__restrict__ w, + const half *__restrict__ w_scales, + const uint32_t *__restrict__ w_zeros, + const int *__restrict__ g_idx, + const int height, const int width, + const int group, + half *__restrict__ out) { + // Start of block + + auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + auto row = blockIdx.y * 32 / bit; + if (column >= width) { + return; + } + + 
// Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half *out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } +} + +INFINIOP_CUDA_KERNEL reconstruct_gptq_3bit_kernel( + const uint32_t *__restrict__ w, const half *__restrict__ w_scales, + const uint32_t *__restrict__ w_zeros, const int *__restrict__ g_idx, + const int height, const int width, const int group, + half *__restrict__ out) { + // Start of block + auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + auto row = blockIdx.y * 32; + if (column >= width) { + return; + } + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half *out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); + } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + 
} +} + +void reconstruct_gptq(const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, const int *b_g_idx, half *out, + int height, int width, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(b_q_weight, b_gptq_scales, + b_gptq_qzeros, b_g_idx, height, + width, groups, out); +} + +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half *a, + const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, const int *b_g_idx, + half *c, half *temp_dq, int size_m, int size_n, + int size_k, int groups, bool use_exllama, int bit) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. 
+ use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); + } else { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); + } + + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit); + } + + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit); + } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit); + } +} + +INFINIOP_CUDA_KERNEL shuffle_4bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } +} + +INFINIOP_CUDA_KERNEL shuffle_8bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + 
b_ptr += 1 * size_n; + k += 4; + } +} + +INFINIOP_CUDA_KERNEL shuffle_2bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } +} + +INFINIOP_CUDA_KERNEL shuffle_3bit_kernel(uint32_t *__restrict__ b_q_weight, + const int size_k, const int size_n) { + auto n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) { + return; + } + int k = 0; + uint32_t *b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } +} + +INFINIOP_CUDA_KERNEL make_sequential_4bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) { + return; + } + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +INFINIOP_CUDA_KERNEL make_sequential_2bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if 
(w2_column >= w2_stride) { + return; + } + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +INFINIOP_CUDA_KERNEL make_sequential_3bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + auto w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) { + return; + } + auto w_new_row = blockIdx.y * 3; + auto q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 
0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); + } + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; +} + +INFINIOP_CUDA_KERNEL make_sequential_8bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, + const int w_width) { + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; + int w2_stride = w_width >> 1; + auto w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) { + return; + } + auto w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +} // namespace gptq +} // namespace vllm + +infiniStatus_t GptqGemmKernel(void *c, const void *a, const void *b, + const void *b_scales, const void *b_zeros, const void *b_g_idx, + int M, int K, int N, int num_groups, + bool use_exllama, int64_t bit, cublasHandle_t cublas_handle, void *workspace) { + + char *workspace_ptr = reinterpret_cast(workspace); + half *temp_dq = reinterpret_cast(workspace_ptr); // shape ? 
+ + vllm::gptq::gemm_half_q_half_cuda( + cublas_handle, (const half *)a, + (const uint32_t *)b, + (const uint32_t *)b_zeros, + (const half *)b_scales, + (const int *)b_g_idx, + (half *)c, temp_dq, + M, + N, + K, + num_groups, + use_exllama, bit); + return INFINI_STATUS_SUCCESS; +} + +namespace op::gptq_gemm::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit) { + + auto info = GptqGemmInfo::createGptqGemmInfo(out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc, b_g_idx_desc, use_exllama, quant_bit); + + CHECK_RESULT(info); + + size_t workspace_size = b_desc->shape()[0] * 32 / static_cast(quant_bit) * b_desc->shape()[1]; + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + const void *b_scales, + const void *b_zeros, + const void *b_g_idx, + void *stream) const { + + int M = _info.M; + int K = _info.K; + int N = _info.N; + int num_groups = _info.num_groups; + bool use_exllama = _info.use_exllama; + int64_t quant_bit = _info.quant_bit; + + if (_info.dtype == INFINI_DTYPE_F16) { + CHECK_STATUS(_opaque->internal->useCublas( + (cudaStream_t)stream, + [&](cublasHandle_t handle) { + CHECK_CUBLAS( + GptqGemmKernel(out, a, b, + b_scales, b_zeros, + b_g_idx, M, K, N, num_groups, + use_exllama, quant_bit, handle, workspace)); + return INFINI_STATUS_SUCCESS; + })); + return 
INFINI_STATUS_SUCCESS; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gptq_gemm::nvidia diff --git a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh new file mode 100644 index 000000000..e718932ea --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cuh @@ -0,0 +1,7 @@ +#ifndef __GPTQ_GEMM_NVIDIA_API_H__ +#define __GPTQ_GEMM_NVIDIA_API_H__ +#include "../gptq_gemm.h" + +DESCRIPTOR(nvidia) + +#endif // __GPTQ_GEMM_NVIDIA_API_H__ diff --git a/src/infiniop/ops/gptq_gemm/operator.cc b/src/infiniop/ops/gptq_gemm/operator.cc new file mode 100644 index 000000000..8c477c23a --- /dev/null +++ b/src/infiniop/ops/gptq_gemm/operator.cc @@ -0,0 +1,97 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gptq_gemm.h" + +#if defined(ENABLE_QY_API) +#include "nvidia/gptq_gemm_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateGptqGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t b_g_idx_desc, + bool use_exllama, + int quant_bit) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gptq_gemm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc, b_g_idx_desc, use_exllama, quant_bit); + + switch (handle->device) { +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopGetGptqGemmWorkspaceSize( + infiniopGptqGemmDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = 
reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopGptqGemm( + infiniopGptqGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + const void *b_scale, + const void *b_zero, + const void *b_g_idx, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, out, a, b, b_scale, b_zero, b_g_idx, stream); + + switch (desc->device_type) { +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopDestroyGptqGemmDescriptor( + infiniopGptqGemmDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_QY_API + DESTROY(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/test/infiniop/gptq_gemm.py b/test/infiniop/gptq_gemm.py new file mode 100644 index 000000000..0b63fbed1 --- /dev/null +++ b/test/infiniop/gptq_gemm.py @@ -0,0 +1,305 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not 
meant to be imported from other modules
+import numpy
+
+_TEST_CASES = [
+    # M, K, N, use_exllama, quant_bit, group_size
+    (128, 256, 32, False, 4, 128),
+    (512, 2048, 128, True, 4, 128),
+    (1024, 1024, 128, False, 8, 128),
+    (1024, 1024, 128, True, 8, 128),
+]
+
+
+# Data types used for testing
+_TENSOR_DTYPES = [InfiniDtype.F16]
+
+# Tolerance map for different data types
+_TOLERANCE_MAP = {
+    InfiniDtype.F16: {"atol": 1e-3, "rtol": 5e-2},
+}
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N):
+    assert bit == 4, "Reference dequantization only supports 4-bit"
+    group_size = K // scales.shape[0]
+    pack_factor = 32 // bit
+
+    # unpack q_weight: (K//pack_factor, N) -> (K, N)
+    
unpacked_q_weight = torch.empty( + q_weight.shape[0] * pack_factor, + q_weight.shape[1], + dtype=torch.uint8, + device=q_weight.device, + ) + for i in range(pack_factor): + unpacked_q_weight[i::pack_factor, :] = (q_weight >> (i * 4)) & 0x0F + + # unpack q_zeros: (num_groups, N//pack_factor) -> (num_groups, N) + unpacked_q_zeros = torch.empty( + q_zeros.shape[0], + q_zeros.shape[1] * pack_factor, + dtype=torch.uint8, + device=q_zeros.device, + ) + for i in range(pack_factor): + unpacked_q_zeros[:, i::pack_factor] = (q_zeros >> (i * 4)) & 0x0F + + unpacked_q_zeros += 1 + unpacked_q_zeros = unpacked_q_zeros.to(scales.dtype) + + scale_zeros = unpacked_q_zeros * scales # (num_groups, N) + + current_g_idx = torch.tensor( + [i // group_size for i in range(K)], dtype=torch.int32, device=q_weight.device + ) + + scale_mat = scales[current_g_idx] # (K, N) + scale_zeros_mat = scale_zeros[current_g_idx] # (K, N) + + # dequant: weight * scale - scale_zeros + dequantized_b = unpacked_q_weight.to(scales.dtype) * scale_mat - scale_zeros_mat + + return dequantized_b.reshape(K, N) + + +def torch_gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit +): + K, N = a.shape[1], b_q_weight.shape[1] + + b_dequant = torch_dequantize( + b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit, K, N + ) + c = torch.matmul(a, b_dequant) + return c + + +def test( + handle, + device, + M, + K, + N, + use_exllama, + quant_bit, + group_size, + dtype=InfiniDtype.F16, + sync=None, +): + + print( + f"Testing Gptq Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, use_exllama:{use_exllama}, quant_bit:{quant_bit}, group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + b_fp = TestTensor((K, N), None, dtype, device) + + assert K % group_size == 0, "K must be divisible by group_size" + num_groups = K // group_size + use_shuffle = use_exllama + + if use_shuffle: + print(f"not support use_shuffle:{use_shuffle}") + return + else: + g_idx = torch.tensor( 
+            [i // group_size for i in range(K)], dtype=torch.int32, device=device
+        )
+        b_shuffled = b_fp.torch_tensor()[g_idx]
+
+    b_grouped = b_shuffled.reshape(num_groups, group_size, N)
+
+    b_max = torch.max(b_grouped, dim=1, keepdim=True)[0]
+    b_min = torch.min(b_grouped, dim=1, keepdim=True)[0]
+
+    scales = (b_max - b_min) / (2**quant_bit - 1)
+    scales = scales.clamp(min=1e-6)
+
+    zeros_float = (-b_min / scales).round()
+
+    q_b = (
+        (b_grouped / scales + zeros_float).round().clamp(0, 2**quant_bit - 1).to(torch.uint8)
+    )
+
+    q_zeros_unpacked = zeros_float.to(torch.uint8) - 1
+
+    b_q_weight = pack_rows(q_b.reshape(K, N), quant_bit, K, N)
+
+    q_zeros_unpacked = q_zeros_unpacked.reshape(num_groups, N)
+    b_gptq_qzeros = pack_cols(q_zeros_unpacked, quant_bit, num_groups, N)
+    b_gptq_scales = scales.squeeze(1)
+
+    A = TestTensor((M, K), None, dtype, device)
+    C = TestTensor((M, N), None, dtype, device)
+
+
+    B = TestTensor(b_q_weight.shape, b_q_weight.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=b_q_weight)
+    b_scales = TestTensor(b_gptq_scales.shape, b_gptq_scales.stride(), dtype, device, mode="manual", set_tensor=b_gptq_scales)
+    b_zeros = TestTensor(b_gptq_qzeros.shape, b_gptq_qzeros.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=b_gptq_qzeros)
+    b_g_idx = TestTensor((K, ), g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx)
+
+    if sync is not None:
+        sync()
+
+    ans = torch_gptq_gemm(
+        A.torch_tensor(), B.torch_tensor(), b_zeros.torch_tensor(), b_scales.torch_tensor(), b_g_idx.torch_tensor(), use_shuffle, quant_bit
+    )
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateGptqGemmDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            C.descriptor,
+            A.descriptor,
+            B.descriptor,
+            b_scales.descriptor,
+            b_zeros.descriptor,
+            b_g_idx.descriptor,
+            use_exllama,
+            quant_bit,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the 
kernel + + for tensor in [C, A, B, b_scales, b_zeros, b_g_idx]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetGptqGemmWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, C.device) + + def lib_gptq_gemm(): + check_error( + LIBINFINIOP.infiniopGptqGemm( + descriptor, + workspace.data(), + workspace_size.value, + C.data(), + A.data(), + B.data(), + b_scales.data(), + b_zeros.data(), + b_g_idx.data(), + None, + ) + ) + + lib_gptq_gemm() + + if sync is not None: + sync() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(C.actual_tensor(), ans, atol=atol, rtol=rtol) + + + assert torch.allclose(C.actual_tensor(), ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch_gptq_gemm(A.torch_tensor(), B.torch_tensor(), b_zeros.torch_tensor(), b_scales.torch_tensor(), b_g_idx.torch_tensor(), use_shuffle, quant_bit), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_gptq_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + check_error(LIBINFINIOP.infiniopDestroyGptqGemmDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 1c90feb22..546b96ff3 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1131,6 +1131,48 @@ def per_tensor_dequant_int8_(lib): ] + +@OpRegister.operator +def gptq_gemm_(lib): + lib.infiniopCreateGptqGemmDescriptor.restype = c_int32 + 
lib.infiniopCreateGptqGemmDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_bool, + c_int32, + ] + + lib.infiniopGetGptqGemmWorkspaceSize.restype = c_int32 + lib.infiniopGetGptqGemmWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopGptqGemm.restype = c_int32 + lib.infiniopGptqGemm.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyGptqGemmDescriptor.restype = c_int32 + lib.infiniopDestroyGptqGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32 diff --git a/xmake/qy.lua b/xmake/qy.lua index 810f88c2f..1e1ae1be6 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -1,3 +1,33 @@ +local TORCH_DIR = os.getenv("TORCH_DIR") + +if not TORCH_DIR then + raise("TORCH_DIR is not set! please export it first") +end + +print("TORCH_DIR =", TORCH_DIR) + +if TORCH_DIR and os.isdir(TORCH_DIR) then + local TORCH_INCLUDE = TORCH_DIR .. "/include" + local TORCH_LIB = TORCH_DIR .. "/lib" + + print("✅ 自动找到 PyTorch 路径: " .. TORCH_DIR) + print("✅ PyTorch 头文件: " .. TORCH_INCLUDE) + print("✅ PyTorch 库路径: " .. TORCH_LIB) + + -- 添加 PyTorch 头文件 + add_includedirs(TORCH_INCLUDE) + add_includedirs(TORCH_INCLUDE .. 
"/torch/csrc/api/include") + + -- 添加 PyTorch 库路径 + add_linkdirs(TORCH_LIB) + + -- 链接 PyTorch 核心库(解决 undefined symbol) + add_links("torch", "torch_cpu", "torch_cuda", "c10", "c10_cuda", "torch_python") +else + print("⚠️ 未检测到 PyTorch,将跳过 PyTorch 依赖") +end + + local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH") if CUDNN_ROOT ~= nil then add_includedirs(CUDNN_ROOT .. "/include") @@ -21,6 +51,12 @@ rule("qy.cuda") on_load(function (target) target:add("includedirs", "/usr/local/denglin/sdk/include") + + -- 把 PyTorch 头文件也加入自定义规则 + if TORCH_DIR then + target:add("includedirs", TORCH_DIR .. "/include") + target:add("includedirs", TORCH_DIR .. "/include/torch/csrc/api/include") + end end) after_load(function (target) From 4b2b96eaebfef993cdca05d444fba315075bb3b3 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Thu, 26 Mar 2026 15:01:47 +0800 Subject: [PATCH 2/4] issue/1102: success nvidia gptq --- src/infiniop/ops/gptq_gemm/cuda/compat.cuh | 69 ++- src/infiniop/ops/gptq_gemm/cuda/kernel.cuh | 2 +- .../ops/gptq_gemm/cuda/matrix_view.cuh | 498 +++++++++--------- .../ops/gptq_gemm/cuda/my_operator.cpp | 3 +- src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh | 92 ++-- src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh | 250 ++++----- src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh | 150 +++--- src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh | 23 +- src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh | 38 +- src/infiniop/ops/gptq_gemm/info.h | 4 +- .../ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu | 6 +- src/infiniop/ops/gptq_gemm/operator.cc | 14 +- test/infiniop/gptq_gemm.py | 72 ++- xmake/nvidia.lua | 26 + 14 files changed, 659 insertions(+), 588 deletions(-) diff --git a/src/infiniop/ops/gptq_gemm/cuda/compat.cuh b/src/infiniop/ops/gptq_gemm/cuda/compat.cuh index beed57c7d..1da7ebd14 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/compat.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/compat.cuh @@ -15,56 +15,55 @@ namespace vllm { namespace gptq { // atomicAdd for half types, to 
support CC < 7.x -__device__ __forceinline__ void atomicAdd_half(half* address, half val) { - unsigned int* address_as_ui = - (unsigned int*)((char*)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; +__device__ __forceinline__ void atomicAdd_half(half *address, half val) { + unsigned int *address_as_ui = (unsigned int *)((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; - do { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) - : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } while (assumed != old); + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? 
(old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); } // atomicAdd for half2 types -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } while (assumed != old); +__device__ __forceinline__ void atomicAdd_half2(half2 *address, half2 val) { + unsigned int *address_as_ui = (unsigned int *)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2 *)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int *)&new_val)); + } while (assumed != old); } // #if defined(__CUDA_ARCH__) || defined(USE_ROCM) - #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half* address, half val) { - atomicAdd_half(address, val); +__device__ __forceinline__ void atomicAdd(half *address, half val) { + atomicAdd_half(address, val); } - #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { - atomicAdd_half2(address, val); +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2 *address, half2 val) { + atomicAdd_half2(address, val); } - #endif +#endif - #endif +#endif #endif -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh index 73724f142..3d354d375 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh @@ -30,7 +30,7 
@@ namespace gptq { #define MAX_ALT_GEMM_ROWS 8 #define THREADS_X 32 #define THREADS_Y 32 -#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) +#define DIVIDE(x, size) (((x) + (size)-1) / (size)) #if defined(USE_ROCM) #include diff --git a/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh b/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh index 2b6719fbd..5b7e0bd88 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/matrix_view.cuh @@ -6,8 +6,8 @@ https://github.com/turboderp/exllama #ifndef _matrix_view_cuh #define _matrix_view_cuh -#include #include +#include #include "qdq_util.cuh" @@ -15,281 +15,277 @@ namespace vllm { namespace gptq { class MatrixView_half { - public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ half item(int row, int column) const { - return data[row * width + column]; - } - __device__ __forceinline__ half2 item_half2(int row, int column) const { - return ((half2*)data)[(row * width + column) / 2]; - } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { - return __half2half2(data[row * width + column]); - } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { - return &data[row * width + column]; - } - - __device__ __forceinline__ void item4(half (&items)[4], int row, - int column) const { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __low2half(i01); - items[1] = __high2half(i01); - items[2] = __low2half(i23); - items[3] = __high2half(i23); - } - __device__ __forceinline__ void item4_f(float (&items)[4], int row, +public: + const half *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half *data, const int height, + const int width) + : data(data), 
height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2 *)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half *item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2float(__low2half(i01)); - items[1] = __half2float(__high2half(i01)); - items[2] = __half2float(__low2half(i23)); - items[3] = __half2float(__high2half(i23)); - } - - __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, - int column) const { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2half2(__low2half(i01)); - items[1] = __half2half2(__high2half(i01)); - items[2] = __half2half2(__low2half(i23)); - items[3] = __half2half2(__high2half(i23)); - } + half2 *ptr = (half2 *)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2 *ptr = (half2 *)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2 *ptr = (half2 *)item_ptr(row, column); + half2 i01 = 
ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } }; class MatrixView_half_rw { - public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ half item(int row, int column) const { - return data[row * width + column]; - } - __device__ __forceinline__ half2 item_half2(int row, int column) const { - return ((half2*)data)[(row * width + column) / 2]; - } - __device__ __forceinline__ half* item_ptr(int row, int column) { - return &data[row * width + column]; - } - __device__ __forceinline__ void set(int row, int column, half value) { - data[row * width + column] = value; - } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { - ((half2*)data)[(row * width + column) / 2] = value; - } - - __device__ __forceinline__ void set4(int row, int column, half v0, half v1, - half v2, half v3) { - half2 v01 = __halves2half2(v0, v1); - half2 v23 = __halves2half2(v2, v3); - half2* ptr = (half2*)item_ptr(row, column); - ptr[0] = v01; - ptr[1] = v23; - } +public: + half *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half *data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2 *)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half *item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + 
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2 *)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2 *ptr = (half2 *)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } }; class MatrixView_q4_row { - public: - const uint32_t* data; - const int height; - const int width; +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, - const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ int item(int row, int column) const { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, - int column) const { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, - int column) const { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - items[2] = (d >> 8) & 0x0f; - items[3] = (d >> 12) & 0x0f; - } + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 
4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } }; class MatrixView_q4_column { - public: - const uint32_t* data; - const int height; - const int width; +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, - const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ int item(int row, int column) const { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { - return data[row / 8 * width + column]; - } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, - int column) { - return &data[row / 8 * width + column]; - } + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t *item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } }; class MatrixView_q2_row { - public: - const uint32_t* data; - const int height; - const int width; +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ 
__forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } - __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, - const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ int item(int row, int column) const { - int shift = (column & 0x0f) * 2; - return (data[row * width / 16 + column / 16] >> shift) & 0x03; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, - int column) const { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, - int column) const { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - items[2] = (d >> 4) & 0x03; - items[3] = (d >> 6) & 0x03; - } + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } }; class MatrixView_q3_row { - public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, - const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ int item(int row, int column) const { - int z_w = column * 3 / 32; - int z_mod = column & 0x1f; - - if (z_mod == 10) { - return 
(data[row * width * 3 / 32 + z_w] >> 30) | - ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); - } else if (z_mod == 21) { - return (data[row * width * 3 / 32 + z_w] >> 31) | - ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); - } else if (z_mod < 10) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; - } else if (z_mod < 21) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; - } else { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; + } } - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, - int column) const { - int shift = (column & 0x1f); - uint32_t d; - if (shift <= 4) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); - } else if (shift == 8) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | - ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); - } else if (shift <= 16) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); - } else if (shift == 20) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | - 
((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); - } else { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); + } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; } - items[0] = d & 0x07; - items[1] = (d >> 3) & 0x07; - items[2] = (d >> 6) & 0x07; - items[3] = (d >> 9) & 0x07; - } }; class MatrixView_q8_row { - public: - const uint32_t* data; - const int height; - const int width; +public: + const uint32_t *data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t *data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } - __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, - const int height, - const int width) - : data(data), height(height), width(width) {} - - __device__ __forceinline__ int item(int row, int column) const { - int shift = (column & 0x03) * 8; - return (data[row * width / 4 + column / 4] >> shift) & 0xff; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, - 
int column) const { - int shift = (column & 0x03) * 8; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, - int column) const { - int shift = (column & 0x03) * 2; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - items[2] = (d >> 16) & 0xff; - items[3] = (d >> 24) & 0xff; - } + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } }; -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp b/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp index e2e3c0499..ff4b779e8 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp +++ b/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp @@ -1,5 +1,5 @@ -#include #include +#include // 声明 CUDA 函数 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, @@ -7,7 +7,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, bool use_exllama, int64_t bit); - // 绑定到 Python PYBIND11_MODULE(vllm_gptq, m) { m.def("gptq_gemm", &gptq_gemm, "GPTQ GEMM (CUDA)"); diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh index ca0f81060..c9bb50efc 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_2.cuh @@ -14,63 
+14,63 @@ namespace gptq { // // ffddbb99 77553311 eeccaa88 66442200 -__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { - uint32_t qa = q[0]; - uint32_t qb = 0; +__forceinline__ __device__ void shuffle_2bit_16(uint32_t *q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; #pragma unroll - for (int i = 0; i < 8; i++) { - uint32_t qa0 = qa & 0x03; - uint32_t qa1 = (qa & 0x0c) >> 2; - qa >>= 4; - qb |= (qa1 << (i * 2 + 16)); - qb |= (qa0 << (i * 2)); - } - q[0] = qb; + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; } __forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, half2 (&dq)[8], int stride, const uint32_t zero) { - const uint32_t c0 = 0x64006400; - const half y4_ = __float2half_rn(1.0f / 4.0f); - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y4 = __halves2half2(y4_, y4_); - const half2 y16 = __halves2half2(y16_, y16_); - const half2 y64 = __halves2half2(y64_, y64_); + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z4 = __half2half2(z4_); - const half2 z16 = __half2half2(z16_); - const half2 z64 = __half2half2(z64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = 
__hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 - half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 - half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 - qa >>= 8; - half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 - half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 - half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 - half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y4, z4); - dq[2] = __hfma2(q2.as_half2, y16, z16); - dq[3] = __hfma2(q3.as_half2, y64, z64); - dq[4] = __hadd2(q4.as_half2, z1); - dq[5] = __hfma2(q5.as_half2, y4, z4); - dq[6] = __hfma2(q6.as_half2, y16, z16); - dq[7] = __hfma2(q7.as_half2, y64, z64); + dq[0] = __hadd2(q0.as_half2, z1); + 
dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); } -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh index 0d5c2adf5..eefef3151 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_3.cuh @@ -11,71 +11,71 @@ namespace gptq { // vjjjhhhf ffdddbbb uiiiggge eecccaaa // vtttrrrp ppnnnlll usssqqqo oommmkkk -__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - - // qa: aa999888 77766655 54443332 22111000 - // qb: lkkkjjji iihhhggg fffeeedd dcccbbba - // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll - - uint32_t qd = qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: ..999888 77766655 54443332 22111000 - // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa - // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk - // qd: vvvuuu - - uint32_t za = 0; - uint32_t zb = 0; - uint32_t zc = 0; - - for (int i = 0; i < 5; i++) { - uint32_t t0 = qa & 0x07; - uint32_t t1 = (qa & 0x38) >> 3; - qa >>= 6; - za |= (t0 << (i * 3)); - za |= (t1 << (i * 3 + 16)); - } - for (int i = 0; i < 5; i++) { - uint32_t t0 = qb & 0x07; - uint32_t t1 = (qb & 0x38) >> 3; - qb >>= 6; - zb |= (t0 << (i * 3)); - zb |= (t1 << (i * 3 + 16)); - } - for (int i = 0; i < 5; i++) { - uint32_t t0 = qc & 0x07; - uint32_t t1 = (qc & 0x38) >> 3; - qc >>= 6; - zc |= (t0 << (i * 3)); - zc |= (t1 << (i * 3 + 16)); - } - - // za: 9997775 55333111 8886664 44222000 - // zb: jjjhhhf ffdddbbb iiiggge eecccaaa - // zc: tttrrrp ppnnnlll sssqqqo oommmkkk - // qd: vvvuuu - - za |= ((qd & 0x01) >> 0) << 
15; - zb |= ((qd & 0x02) >> 1) << 15; - zc |= ((qd & 0x04) >> 2) << 15; - za |= ((qd & 0x08) >> 3) << 31; - zb |= ((qd & 0x10) >> 4) << 31; - zc |= ((qd & 0x20) >> 5) << 31; - - // za: v9997775 55333111 u8886664 44222000 (u, v lsb) - // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa - // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; +__forceinline__ __device__ void shuffle_3bit_32(uint32_t *q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; + qa >>= 6; + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; + qb >>= 6; + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; + qc >>= 6; + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 
* stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; } __forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, @@ -83,67 +83,67 @@ __forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, const uint32_t q_2, half2 (&dq)[16], int stride, const uint32_t zero) { - const uint32_t c0 = 0x64006400; - const half y8_ = __float2half_rn(1.0f / 8.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y8 = __halves2half2(y8_, y8_); - const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); - const half2 z8 = __halves2half2(z8_, z8_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - uint32_t qb = q_1; - uint32_t qc = q_2; - - half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 - qa >>= 6; - half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 - half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 - qa >>= 9; - qa &= 0x00010001; - half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 - half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 - qb >>= 6; - half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 - half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 - half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 - qb >>= 8; - qb &= 0x00020002; - half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 - half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 - qc >>= 6; - half2_uint32 q12((qc & 0x00070007) | c0); // 
half2(q[24], q[25]) + 1024 - half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 - half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 - qc >>= 7; - qc &= 0x00040004; - half2_uint32 q15((qa | qb | qc) | c0); - - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y8, z8); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y8, z8); - dq[4] = __hfma2(q4.as_half2, y64, z64); - dq[5] = __hadd2(q5.as_half2, z1); - dq[6] = __hfma2(q6.as_half2, y8, z8); - dq[7] = __hadd2(q7.as_half2, z1); - dq[8] = __hfma2(q8.as_half2, y8, z8); - dq[9] = __hfma2(q9.as_half2, y64, z64); - dq[10] = __hadd2(q10.as_half2, z1); - dq[11] = __hfma2(q11.as_half2, y8, z8); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y8, z8); - dq[14] = __hfma2(q14.as_half2, y64, z64); - dq[15] = __hadd2(q15.as_half2, z1); + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) 
| c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); } -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh index 7f65d2d28..29ca387ce 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_4.cuh @@ -13,77 +13,77 @@ namespace gptq { // // 77775555 33331111 66664444 22220000 -__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { - uint32_t qa = q[0]; - uint32_t qb = 
0; +__forceinline__ __device__ void shuffle_4bit_8(uint32_t *q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t qa0 = qa & 0x0f; - uint32_t qa1 = (qa & 0xf0) >> 4; - qa >>= 8; - qb |= (qa1 << (i * 4 + 16)); - qb |= (qa0 << (i * 4)); - } - q[0] = qb; + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; } __forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, half2 (&dq)[4], int stride, const uint32_t zero) { - const uint32_t c0 = 0x64006400; - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half2 y16 = __halves2half2(y16_, y16_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z16 = __half2half2(z16_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 - - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y16, z16); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y16, z16); + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa 
>>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); } __forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( const uint32_t zero, const half scale, half2 (&z1z16)[2], half2 (&y1y16)[2]) { - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - half2 scale2 = __half2half2(scale); + half2 scale2 = __half2half2(scale); - z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); - z1z16[1] = __hmul2(scale2, __half2half2(z16)); + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __hmul2(scale2, __half2half2(y1)); - y1y16[1] = __hmul2(scale2, __half2half2(y16)); + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); } __forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, half2 (&z1z16)[2], half2 (&y1y16)[2]) { - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - z1z16[0] = __half2half2(z1.as_half); - z1z16[1] = __half2half2(z16); + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = 
__float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __half2half2(y1); - y1y16[1] = __half2half2(y16); + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); } __forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, @@ -91,36 +91,32 @@ __forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, half2 (&z1z16)[2], half2 (&y1y16)[2], int stride, bool scaled) { - const uint32_t c0 = 0x64006400; - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | - c0); // half2( q[0] + 1024, q[1] + 1024 ) - half2_uint32 q1((qa & 0x00f000f0) | - c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | - c0); // half2( q[4] + 1024, q[5] + 1024 ) - half2_uint32 q3((qa & 0x00f000f0) | - c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) - - if (scaled) { - dq[0] = __hfma2(q0.as_half2, y1y16[0], - z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) - dq[1] = __hfma2(q1.as_half2, y1y16[1], - z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) - dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); - } else { - dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) - dq[1] = __hfma2(q1.as_half2, y1y16[1], - z1z16[1]); // half2( q[2] - z, q[3] - z ) - dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) - dq[3] = __hfma2(q3.as_half2, y1y16[1], - z1z16[1]); // half2( q[6] - z, q[7] - z ) - } + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s 
- z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } } -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh b/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh index feb5d2204..fb9b8fc47 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_8.cuh @@ -10,21 +10,26 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} +__forceinline__ __device__ void shuffle_8bit_4(uint32_t *q, int stride) {} __forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, const uint32_t q_1, half2 (&dq)[4], int stride, const uint32_t zero) { - half dqh[8]; - for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); - for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); - - for (int i = 0; i < 4; i++) - dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); + half dqh[8]; + for (int i = 0; i < 4; i++) { + dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + } + for (int i = 0; i < 4; i++) { + dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + } + + for (int i = 0; i < 4; i++) { + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); + } } -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh 
b/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh index b65238b3b..acac37c81 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/qdq_util.cuh @@ -15,48 +15,48 @@ namespace vllm { namespace gptq { union half2_uint32 { - uint32_t as_uint32; - half2 as_half2; - __device__ half2_uint32(uint32_t val) : as_uint32(val) {} - __device__ half2_uint32(half2 val) : as_half2(val) {} + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} }; union half_uint16 { - uint16_t as_uint16; - half as_half; - __device__ half_uint16(uint16_t val) : as_uint16(val) {} - __device__ half_uint16(half val) : as_half(val) {} + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} }; // Max_scale premultiplied by 1/256 __forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { - int qs_i = qs + 1; - half qs_h = __int2half_rn(qs_i * qs_i); - qs_h = __hmul(qs_h, max_scale); - return qs_h; + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; } __forceinline__ __device__ half dq(const int q, const int qzero, const half scale) { - return __hmul(__int2half_rn(q - qzero), scale); + return __hmul(__int2half_rn(q - qzero), scale); } __forceinline__ __device__ half dq_ns(const int q, const int qzero) { - // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); - return __int2half_rn(q - qzero); + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); } __forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) { - return (int)((q >> shift) & mask); + return (int)((q >> shift) & mask); } __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) { - return (int)(__funnelshift_rc(q0, q1, 
shift) & mask); + return (int)(__funnelshift_rc(q0, q1, shift) & mask); } -} // namespace gptq -} // namespace vllm +} // namespace gptq +} // namespace vllm #endif diff --git a/src/infiniop/ops/gptq_gemm/info.h b/src/infiniop/ops/gptq_gemm/info.h index 66805e6ad..11c82d38d 100644 --- a/src/infiniop/ops/gptq_gemm/info.h +++ b/src/infiniop/ops/gptq_gemm/info.h @@ -59,9 +59,7 @@ class GptqGemmInfo { && b_zeros_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE); CHECK_OR_RETURN(b_scales_desc->shape()[1] == N - && b_scales_desc->shape()[0] == num_groups - && b_zeros_desc->shape()[1] == N - && b_zeros_desc->shape()[0] == num_groups, + && static_cast(b_scales_desc->shape()[0]) == num_groups, INFINI_STATUS_BAD_TENSOR_SHAPE); return utils::Result(GptqGemmInfo{ diff --git a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu index ba57a93c2..7f336c2f5 100644 --- a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu +++ b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu @@ -1654,7 +1654,7 @@ INFINIOP_CUDA_KERNEL make_sequential_8bit_kernel(const uint32_t *__restrict__ w, } // namespace gptq } // namespace vllm -infiniStatus_t GptqGemmKernel(void *c, const void *a, const void *b, +cublasStatus_t GptqGemmKernel(void *c, const void *a, const void *b, const void *b_scales, const void *b_zeros, const void *b_g_idx, int M, int K, int N, int num_groups, bool use_exllama, int64_t bit, cublasHandle_t cublas_handle, void *workspace) { @@ -1674,7 +1674,7 @@ infiniStatus_t GptqGemmKernel(void *c, const void *a, const void *b, K, num_groups, use_exllama, bit); - return INFINI_STATUS_SUCCESS; + return CUBLAS_STATUS_SUCCESS; } namespace op::gptq_gemm::nvidia { @@ -1702,7 +1702,7 @@ infiniStatus_t Descriptor::create( CHECK_RESULT(info); - size_t workspace_size = b_desc->shape()[0] * 32 / static_cast(quant_bit) * b_desc->shape()[1]; + size_t workspace_size = b_desc->shape()[0] * 32 / static_cast(quant_bit) * 
b_desc->shape()[1] * infiniSizeOf(a_desc->dtype()); *desc_ptr = new Descriptor( new Opaque{reinterpret_cast(handle)->internal()}, info.take(), workspace_size, handle->device, handle->device_id); diff --git a/src/infiniop/ops/gptq_gemm/operator.cc b/src/infiniop/ops/gptq_gemm/operator.cc index 8c477c23a..220dab07e 100644 --- a/src/infiniop/ops/gptq_gemm/operator.cc +++ b/src/infiniop/ops/gptq_gemm/operator.cc @@ -2,7 +2,7 @@ #include "../../handle.h" #include "infiniop/ops/gptq_gemm.h" -#if defined(ENABLE_QY_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #include "nvidia/gptq_gemm_nvidia.cuh" #endif @@ -26,6 +26,9 @@ __INFINI_C infiniStatus_t infiniopCreateGptqGemmDescriptor( out_desc, a_desc, b_desc, b_scales_desc, b_zeros_desc, b_g_idx_desc, use_exllama, quant_bit); switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia) #endif @@ -45,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopGetGptqGemmWorkspaceSize( return INFINI_STATUS_SUCCESS; switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia) #endif @@ -71,6 +77,9 @@ __INFINI_C infiniStatus_t infiniopGptqGemm( workspace, workspace_size, out, a, b, b_scale, b_zero, b_g_idx, stream); switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia) #endif @@ -88,6 +97,9 @@ __INFINI_C infiniStatus_t infiniopDestroyGptqGemmDescriptor( return INFINI_STATUS_SUCCESS; switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif #ifdef ENABLE_QY_API DESTROY(INFINI_DEVICE_QY, nvidia) #endif diff --git a/test/infiniop/gptq_gemm.py b/test/infiniop/gptq_gemm.py index 0b63fbed1..dcb2cbfbf 100644 --- a/test/infiniop/gptq_gemm.py +++ b/test/infiniop/gptq_gemm.py @@ -1,4 +1,5 @@ import torch 
+import numpy import ctypes from ctypes import c_uint64 from libinfiniop import ( @@ -25,10 +26,15 @@ # These are not meant to be imported from other modules _TEST_CASES = [ # M, K, N, use_exllama, quant_bit, group_size - (128, 256, 32, False, 4, 128), - (512, 2048, 128, True, 4, 128), - (1024, 1024, 128, False, 8, 128), - (1024, 1024, 128, True, 8, 128), + (1, 2048, 2048, True, 4, 128), + (1, 2048, 4096, False, 4, 128), + (1, 4096, 2048, False, 4, 128), + (8, 2048, 2048, False, 4, 128), + (8, 2048, 4096, False, 4, 128), + (8, 4096, 2048, False, 4, 128), + (128, 2048, 2048, False, 4, 128), + (128, 2048, 4096, False, 4, 128), + (128, 4096, 2048, False, 4, 128), ] @@ -50,6 +56,7 @@ def get_pack_factor(num_bits): assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits + def pack_cols( q_w: torch.Tensor, num_bits: int, @@ -99,6 +106,7 @@ def pack_rows( q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) return q_res + def torch_dequantize(q_weight, q_zeros, scales, g_idx, use_shuffle, bit, K, N): assert bit == 4, "Reference dequantization only supports 4-bit" group_size = K // scales.shape[0] @@ -160,8 +168,8 @@ def test( M, K, N, - use_exllama, - quant_bit, + use_exllama, + quant_bit, group_size, dtype=InfiniDtype.F16, sync=None, @@ -181,7 +189,9 @@ def test( return else: g_idx = torch.tensor( - [i // group_size for i in range(K)], dtype=torch.int32, device=device + [i // group_size for i in range(K)], + dtype=torch.int32, + device=b_fp.torch_tensor().device, ) b_shuffled = b_fp.torch_tensor()[g_idx] @@ -196,7 +206,10 @@ def test( zeros_float = (-b_min / scales).round() q_b = ( - (b_grouped / scales + zeros_float).round().clamp(0, 2**quant_bit - 1).to(torch.uint8) + (b_grouped / scales + zeros_float) + .round() + .clamp(0, 2**quant_bit - 1) + .to(torch.uint8) ) q_zeros_unpacked = zeros_float.to(torch.uint8) - 1 @@ -210,17 +223,45 @@ def test( A = TestTensor((M, K), None, dtype, device) C = TestTensor((M, N), None, dtype, 
device) - - B = TestTensor(b_q_weight.shape, b_q_weight.stride(), infiniDtype.I32, device, mode="manual", set_tensor=b_q_weight) - b_scales = TestTensor(b_gptq_scales.shape, b_gptq_scales.stride(), dtype, device, mode="manual", set_tensor=b_gptq_scales) - b_zeros = TestTensor(b_gptq_qzeros.shape, b_gptq_qzeros.stride(), infiniDtype.I32, device, mode="manual", set_tensor=b_gptq_qzeros) - b_g_idx = TestTensor((K, ), g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx) + B = TestTensor( + b_q_weight.shape, + b_q_weight.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=b_q_weight, + ) + b_scales = TestTensor( + b_gptq_scales.shape, + b_gptq_scales.stride(), + dtype, + device, + mode="manual", + set_tensor=b_gptq_scales, + ) + b_zeros = TestTensor( + b_gptq_qzeros.shape, + b_gptq_qzeros.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=b_gptq_qzeros, + ) + b_g_idx = TestTensor( + (K,), g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx + ) if sync is not None: sync() ans = torch_gptq_gemm( - A.torch_tensor(), B.torch_tensor(), b_zeros.torch_tensor(), b_scales.torch_tensor(), b_g_idx.torch_tensor(), use_shuffle, quant_bit + A.torch_tensor(), + B.torch_tensor(), + b_zeros.torch_tensor(), + b_scales.torch_tensor(), + b_g_idx.torch_tensor(), + use_shuffle, + quant_bit, ) descriptor = infiniopOperatorDescriptor_t() @@ -250,7 +291,7 @@ def test( descriptor, ctypes.byref(workspace_size) ) ) - workspace = TestWorkspace(workspace_size.value, x.device) + workspace = TestWorkspace(workspace_size.value, A.device) def lib_gptq_gemm(): check_error( @@ -276,7 +317,6 @@ def lib_gptq_gemm(): atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: debug(C.actual_tensor(), ans, atol=atol, rtol=rtol) - assert torch.allclose(C.actual_tensor(), ans, atol=atol, rtol=rtol) diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 602fb190d..d4f96ed29 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -25,6 
+25,32 @@ target("infiniop-nvidia") add_links("cudnn") end + before_build(function (target) + local TORCH_DIR = os.iorun("python -c 'import torch; import os; print(os.path.dirname(torch.__file__))' 2>/dev/null"):trim() + local PYTHON_INCLUDE = os.iorun("python -c 'import sysconfig; print(sysconfig.get_paths()[\"include\"])' 2>/dev/null"):trim() + local PYTHON_LIB_DIR = os.iorun("python -c 'import sysconfig; print(sysconfig.get_config_var(\"LIBDIR\"))' 2>/dev/null"):trim() + local LIB_PYTHON = os.iorun("python -c 'import glob,sysconfig,os; print(glob.glob(os.path.join(sysconfig.get_config_var(\"LIBDIR\"),\"libpython*.so\"))[0])' 2>/dev/null"):trim() + + target:add("includedirs", + TORCH_DIR .. "/include/torch/csrc/api/include", + TORCH_DIR .. "/include", + PYTHON_INCLUDE, + { public = true } + ) + + target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR, { public = true }) + target:add("links", + "torch", + "torch_cuda", + "torch_cpu", + "c10", + "c10_cuda", + "torch_python", + { public = true } + ) + target:add("links", LIB_PYTHON, { public = true }) + end) + on_load(function (target) import("lib.detect.find_tool") local nvcc = find_tool("nvcc") From 3c57498dae296620254d5dcc91b43c27fa55d02e Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 27 Mar 2026 17:23:15 +0800 Subject: [PATCH 3/4] issue/1102: success qy gptq --- src/infiniop/ops/gptq_gemm/cuda/kernel.cuh | 5 +-- .../ops/gptq_gemm/cuda/my_operator.cpp | 13 ------- .../ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu | 30 +++++++--------- xmake/nvidia.lua | 26 -------------- xmake/qy.lua | 36 ------------------- 5 files changed, 14 insertions(+), 96 deletions(-) delete mode 100644 src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp diff --git a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh index 3d354d375..352ff8c79 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh @@ -6,11 +6,8 @@ 
https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include -#include #include #include -#include #include "compat.cuh" #include "matrix_view.cuh" @@ -30,7 +27,7 @@ namespace gptq { #define MAX_ALT_GEMM_ROWS 8 #define THREADS_X 32 #define THREADS_Y 32 -#define DIVIDE(x, size) (((x) + (size)-1) / (size)) +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) #include diff --git a/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp b/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp deleted file mode 100644 index ff4b779e8..000000000 --- a/src/infiniop/ops/gptq_gemm/cuda/my_operator.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include -#include - -// 声明 CUDA 函数 -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit); - -// 绑定到 Python -PYBIND11_MODULE(vllm_gptq, m) { - m.def("gptq_gemm", &gptq_gemm, "GPTQ GEMM (CUDA)"); -} diff --git a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu index 7f336c2f5..dd7b78dc1 100644 --- a/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu +++ b/src/infiniop/ops/gptq_gemm/nvidia/gptq_gemm_nvidia.cu @@ -574,7 +574,7 @@ void gemm_half_q_half_cuda_part(const half *a, const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, const half *b_gptq_scales, const int *b_q_perm, half *c, int size_m, int size_n, int size_k, - int m_count, int groups, int bit) { + int m_count, int groups, int bit, cudaStream_t stream) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -585,7 +585,6 @@ void gemm_half_q_half_cuda_part(const half *a, const uint32_t *b_q_weight, fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k, groups, 
b_q_perm); @@ -1018,7 +1017,7 @@ void reconstruct_exllama(const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, const half *b_gptq_scales, const int *b_q_perm, half *out, int height, int width, int groups, - int bit) { + int bit, cudaStream_t stream) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1034,7 +1033,6 @@ void reconstruct_exllama(const uint32_t *b_q_weight, reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); reconstruct_exllama_kernel<<>>( b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, out); @@ -1230,7 +1228,7 @@ void gemm_half_q_half_alt(const half *a, const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, const half *b_gptq_scales, const int *b_g_idx, half *c, int size_m, int size_n, int size_k, - int bit) { + int bit, cudaStream_t stream) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1244,7 +1242,6 @@ void gemm_half_q_half_alt(const half *a, const uint32_t *b_q_weight, kernel = gemm_half_q_half_alt_8bit_kernel; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>( (const half2 *)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, size_m, size_k / 32 * bit, size_n); @@ -1334,7 +1331,7 @@ INFINIOP_CUDA_KERNEL reconstruct_gptq_3bit_kernel( void reconstruct_gptq(const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, const half *b_gptq_scales, const int *b_g_idx, half *out, - int height, int width, int groups, int bit) { + int height, int width, int groups, int bit, cudaStream_t stream) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1351,7 +1348,6 @@ void reconstruct_gptq(const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, gridDim.y = DIVIDE(height, 32); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>(b_q_weight, b_gptq_scales, b_gptq_qzeros, b_g_idx, height, width, groups, out); @@ -1362,7 +1358,7 
@@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half *a, const uint32_t *b_gptq_qzeros, const half *b_gptq_scales, const int *b_g_idx, half *c, half *temp_dq, int size_m, int size_n, - int size_k, int groups, bool use_exllama, int bit) { + int size_k, int groups, bool use_exllama, int bit, cudaStream_t stream) { bool use_reconstruct; if (use_exllama) { use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); @@ -1375,10 +1371,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half *a, // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); + temp_dq, size_k, size_n, groups, bit, stream); } else { reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); + temp_dq, size_k, size_n, groups, bit, stream); } const half alpha = __float2half(1.0f); @@ -1394,18 +1390,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half *a, if (max_chunks) { gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, last_chunk, size_n, size_k, - BLOCK_M_SIZE_MAX, groups, bit); + BLOCK_M_SIZE_MAX, groups, bit, stream); } if (last_chunk_size) { gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c + last_chunk * size_n, last_chunk_size, - size_n, size_k, last_chunk_size, groups, bit); + size_n, size_k, last_chunk_size, groups, bit, stream); } } else { gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + c, size_m, size_n, size_k, bit, stream); } } @@ -1657,7 +1653,7 @@ INFINIOP_CUDA_KERNEL make_sequential_8bit_kernel(const uint32_t *__restrict__ w, cublasStatus_t GptqGemmKernel(void *c, const void *a, const void *b, const void *b_scales, const void *b_zeros, const void *b_g_idx, int M, int K, 
int N, int num_groups, - bool use_exllama, int64_t bit, cublasHandle_t cublas_handle, void *workspace) { + bool use_exllama, int64_t bit, cublasHandle_t cublas_handle, void *workspace, cudaStream_t stream) { char *workspace_ptr = reinterpret_cast(workspace); half *temp_dq = reinterpret_cast(workspace_ptr); // shape ? @@ -1673,7 +1669,7 @@ cublasStatus_t GptqGemmKernel(void *c, const void *a, const void *b, N, K, num_groups, - use_exllama, bit); + use_exllama, bit, stream); return CUBLAS_STATUS_SUCCESS; } @@ -1735,7 +1731,7 @@ infiniStatus_t Descriptor::calculate(void *workspace, GptqGemmKernel(out, a, b, b_scales, b_zeros, b_g_idx, M, K, N, num_groups, - use_exllama, quant_bit, handle, workspace)); + use_exllama, quant_bit, handle, workspace, (cudaStream_t)stream)); return INFINI_STATUS_SUCCESS; })); return INFINI_STATUS_SUCCESS; diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index d4f96ed29..602fb190d 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -25,32 +25,6 @@ target("infiniop-nvidia") add_links("cudnn") end - before_build(function (target) - local TORCH_DIR = os.iorun("python -c 'import torch; import os; print(os.path.dirname(torch.__file__))' 2>/dev/null"):trim() - local PYTHON_INCLUDE = os.iorun("python -c 'import sysconfig; print(sysconfig.get_paths()[\"include\"])' 2>/dev/null"):trim() - local PYTHON_LIB_DIR = os.iorun("python -c 'import sysconfig; print(sysconfig.get_config_var(\"LIBDIR\"))' 2>/dev/null"):trim() - local LIB_PYTHON = os.iorun("python -c 'import glob,sysconfig,os; print(glob.glob(os.path.join(sysconfig.get_config_var(\"LIBDIR\"),\"libpython*.so\"))[0])' 2>/dev/null"):trim() - - target:add("includedirs", - TORCH_DIR .. "/include/torch/csrc/api/include", - TORCH_DIR .. "/include", - PYTHON_INCLUDE, - { public = true } - ) - - target:add("linkdirs", TORCH_DIR .. 
"/lib", PYTHON_LIB_DIR, { public = true }) - target:add("links", - "torch", - "torch_cuda", - "torch_cpu", - "c10", - "c10_cuda", - "torch_python", - { public = true } - ) - target:add("links", LIB_PYTHON, { public = true }) - end) - on_load(function (target) import("lib.detect.find_tool") local nvcc = find_tool("nvcc") diff --git a/xmake/qy.lua b/xmake/qy.lua index 1e1ae1be6..810f88c2f 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -1,33 +1,3 @@ -local TORCH_DIR = os.getenv("TORCH_DIR") - -if not TORCH_DIR then - raise("TORCH_DIR is not set! please export it first") -end - -print("TORCH_DIR =", TORCH_DIR) - -if TORCH_DIR and os.isdir(TORCH_DIR) then - local TORCH_INCLUDE = TORCH_DIR .. "/include" - local TORCH_LIB = TORCH_DIR .. "/lib" - - print("✅ 自动找到 PyTorch 路径: " .. TORCH_DIR) - print("✅ PyTorch 头文件: " .. TORCH_INCLUDE) - print("✅ PyTorch 库路径: " .. TORCH_LIB) - - -- 添加 PyTorch 头文件 - add_includedirs(TORCH_INCLUDE) - add_includedirs(TORCH_INCLUDE .. "/torch/csrc/api/include") - - -- 添加 PyTorch 库路径 - add_linkdirs(TORCH_LIB) - - -- 链接 PyTorch 核心库(解决 undefined symbol) - add_links("torch", "torch_cpu", "torch_cuda", "c10", "c10_cuda", "torch_python") -else - print("⚠️ 未检测到 PyTorch,将跳过 PyTorch 依赖") -end - - local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH") if CUDNN_ROOT ~= nil then add_includedirs(CUDNN_ROOT .. "/include") @@ -51,12 +21,6 @@ rule("qy.cuda") on_load(function (target) target:add("includedirs", "/usr/local/denglin/sdk/include") - - -- 把 PyTorch 头文件也加入自定义规则 - if TORCH_DIR then - target:add("includedirs", TORCH_DIR .. "/include") - target:add("includedirs", TORCH_DIR .. 
"/include/torch/csrc/api/include") - end end) after_load(function (target) From 2459e9bf6cb27c70f4cb1c8fcf7872f4bd453a22 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Mon, 30 Mar 2026 09:27:58 +0800 Subject: [PATCH 4/4] issue/1102: modified format --- src/infiniop/ops/gptq_gemm/cuda/kernel.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh index 352ff8c79..b93715198 100644 --- a/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh +++ b/src/infiniop/ops/gptq_gemm/cuda/kernel.cuh @@ -27,7 +27,7 @@ namespace gptq { #define MAX_ALT_GEMM_ROWS 8 #define THREADS_X 32 #define THREADS_Y 32 -#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) +#define DIVIDE(x, size) (((x) + (size)-1) / (size)) #if defined(USE_ROCM) #include