diff --git a/include/infiniop.h b/include/infiniop.h
index c0a09fcb4..3780a80e7 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -9,6 +9,7 @@
 #include "infiniop/ops/clip.h"
 #include "infiniop/ops/conv.h"
 #include "infiniop/ops/dequantize_awq.h"
+#include "infiniop/ops/dequantize_gptq.h"
 #include "infiniop/ops/gelu.h"
 #include "infiniop/ops/gemm.h"
 #include "infiniop/ops/layer_norm.h"
diff --git a/include/infiniop/ops/dequantize_gptq.h b/include/infiniop/ops/dequantize_gptq.h
new file mode 100644
index 000000000..5b0a83351
--- /dev/null
+++ b/include/infiniop/ops/dequantize_gptq.h
@@ -0,0 +1,30 @@
+#ifndef __INFINIOP_DEQUANTIZE_GPTQ_API_H__
+#define __INFINIOP_DEQUANTIZE_GPTQ_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopDequantizeGPTQDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(infiniopHandle_t handle,
+                                                                   infiniopDequantizeGPTQDescriptor_t *desc_ptr,
+                                                                   infiniopTensorDescriptor_t out_desc,
+                                                                   infiniopTensorDescriptor_t qweight_desc,
+                                                                   infiniopTensorDescriptor_t scales_desc,
+                                                                   infiniopTensorDescriptor_t zeros_desc,
+                                                                   infiniopTensorDescriptor_t g_idx_desc); // add g_idx
+
+__C __export infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQDescriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopDequantizeGPTQ(infiniopDequantizeGPTQDescriptor_t desc,
+                                                   void *workspace,
+                                                   size_t workspace_size,
+                                                   void *out,
+                                                   const void *qweight,
+                                                   const void *scales,
+                                                   const void *zeros,
+                                                   const void *g_idx, // add g_idx
+                                                   void *stream);
+
+__C __export infiniStatus_t infiniopDestroyDequantizeGPTQDescriptor(infiniopDequantizeGPTQDescriptor_t desc);
+
+#endif
diff --git a/scripts/python_test.py b/scripts/python_test.py
index 06af369ef..23ecdf046 100644
--- a/scripts/python_test.py
+++ b/scripts/python_test.py
@@ -17,7 +17,8 @@ def run_tests(args):
         "causal_softmax.py",
         "clip.py",
         "conv.py",
-        #"dequantize_awq.py",
+        "dequantize_awq.py",
+        "dequantize_gptq.py",
         "gelu.py",
         "gemm.py",
         #"layer_norm.py",
diff --git a/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_iluvatar.cu b/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_iluvatar.cu
index d873ca49a..d0f7bc73a 100644
--- a/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_iluvatar.cu
+++ b/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_iluvatar.cu
@@ -7,7 +7,7 @@
 #include 
 
 __global__ void __launch_bounds__(64)
-    dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
+    dequantize_weights_awq(int *__restrict__ B, half *__restrict__ scaling_factors,
                        int *__restrict__ zeros, half *__restrict__ C, int G) {
     // static constexpr uint32_t ZERO = 0x0;
     half B_shared[32 * (128 + 8)];
@@ -29,11 +29,11 @@
     half *scaling_factors_ptr2 = scaling_factors + index4;
 
     uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
-    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2_awq(zeros_loaded);
     uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2);
 
     uint32_t B_loaded = *(uint32_t *)B_ptr2;
-    uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+    uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2_awq(B_loaded);
 
     // Reinterpret uint4 components as __half2
     __half2 *B_loaded_fp16_h2 = reinterpret_cast<__half2 *>(&B_loaded_fp16);
@@ -119,7 +119,7 @@ Descriptor::calculate(
     half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
     int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
 
-    dequantize_weights<<(stream)>>>(
+
dequantize_weights_awq<<(stream)>>>( qweight_, scales_, zeros_, out_, group_size); return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_kernel.cuh b/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_kernel.cuh index 7ef31c5f9..7acdfe122 100644 --- a/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_kernel.cuh +++ b/src/infiniop/ops/dequantize_awq/iluvatar/dequantize_w42f16_kernel.cuh @@ -11,7 +11,7 @@ * @param source 输入的32位无符号整数,它打包了8个4-bit的数据。 * @return 一个 uint4 变量,其中包含8个反量化后的 half 值。 */ -__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) { +__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2_awq(uint32_t const &source) { // 步骤 1: 从一个 32-bit 源数据中解包出 8 个 4-bit 无符号整数。 // 源数据的内存布局被假定为 [v7, v6, v5, v4, v3, v2, v1, v0], // 其中每个 'v' 都是一个 4-bit 的半字节 (nibble)。 diff --git a/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h b/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h index 3bc6c2d6c..5b926e892 100644 --- a/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h +++ b/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h @@ -11,7 +11,7 @@ * @param source 输入的32位无符号整数,它打包了8个4-bit的数据。 * @return 一个 uint4 变量,其中包含8个反量化后的 half 值。 */ -__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) { +__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2_awq(uint32_t const &source) { // 步骤 1: 从一个 32-bit 源数据中解包出 8 个 4-bit 无符号整数。 // 源数据的内存布局被假定为 [v7, v6, v5, v4, v3, v2, v1, v0], // 其中每个 'v' 都是一个 4-bit 的半字节 (nibble)。 diff --git a/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu b/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu index e1b1f4fc3..970a58cd4 100644 --- a/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu +++ b/src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu @@ -7,7 +7,7 @@ #include __global__ void __launch_bounds__(64) - dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors, + dequantize_weights_awq(int *__restrict__ B, half *__restrict__ scaling_factors, int *__restrict__ zeros, half *__restrict__ C, int G) { // static constexpr uint32_t ZERO = 0x0; half B_shared[32 * (128 + 8)]; @@ -29,11 +29,11 @@ __global__ void __launch_bounds__(64) half *scaling_factors_ptr2 = scaling_factors + index4; uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2_awq(zeros_loaded); uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2); uint32_t B_loaded = *(uint32_t *)B_ptr2; - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2_awq(B_loaded); // Reinterpret uint4 components as __half2 __half2 *B_loaded_fp16_h2 = reinterpret_cast<__half2 *>(&B_loaded_fp16); @@ -119,7 +119,7 @@ Descriptor::calculate( half *scales_ = const_cast(reinterpret_cast(scales)); int *zeros_ = const_cast(reinterpret_cast(zeros)); - dequantize_weights<<(stream)>>>( + dequantize_weights_awq<<(stream)>>>( qweight_, scales_, zeros_, out_, group_size); return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_kernel.cuh b/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_kernel.cuh index d1dcc0f44..133e0dcb5 100644 --- a/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_kernel.cuh +++ b/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_kernel.cuh @@ -1,6 +1,6 @@ 
#pragma once -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) { +__device__ uint4 dequantize_s4_to_fp16x2_awq(uint32_t const &source) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 // 步骤 1: 从一个 32-bit 源数据中解包出 8 个 4-bit 无符号整数。 // 源数据的内存布局被假定为 [v7, v6, v5, v4, v3, v2, v1, v0], diff --git a/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_nvidia.cu b/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_nvidia.cu index d83e94c5c..e2bf64ad3 100644 --- a/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_nvidia.cu +++ b/src/infiniop/ops/dequantize_awq/nvidia/dequantize_w42f16_nvidia.cu @@ -10,33 +10,36 @@ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 750) __global__ void __launch_bounds__(64) - dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors, - int *__restrict__ zeros, half *__restrict__ C, int G) { + dequantize_weights_awq(int *__restrict__ B, half *__restrict__ scaling_factors, + int *__restrict__ zeros, half *__restrict__ C, int G, + int out_features, int in_features) { // static constexpr uint32_t ZERO = 0x0; - half B_shared[32 * (128 + 8)]; - half *B_shared_ptr2 = B_shared; - - int N = blockDim.x * gridDim.x; // 2 int col = (blockIdx.x * blockDim.x + threadIdx.x); int row = (blockIdx.y * blockDim.y + threadIdx.y); - int index1 = 8 * col + 8 * row * N; + + // 边界检查,防止越界访问 + if (col >= out_features || row >= in_features) return; + + // 每个元素在输出中的起始地址:行主序,连续 8 个 half + int index1 = 8 * col + 8 * row * out_features; half *C_ptr2 = C + index1; - int index2 = col + row * N; + int index2 = col + row * out_features; int *B_ptr2 = B + index2; - int index3 = col + (int)(row / G) * N; + int index3 = col + (int)(row / G) * out_features; int *zeros_ptr2 = zeros + index3; - int index4 = 8 * col + (int)(row / G) * N * 8; + + int index4 = 8 * col + (int)(row / G) * out_features * 8; half *scaling_factors_ptr2 = scaling_factors + index4; uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2_awq(zeros_loaded); uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2); uint32_t B_loaded = *(uint32_t *)B_ptr2; - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2_awq(B_loaded); // Reinterpret uint4 components as __half2 __half2 *B_loaded_fp16_h2 = reinterpret_cast<__half2 *>(&B_loaded_fp16); @@ -55,42 +58,43 @@ __global__ void __launch_bounds__(64) B_loaded_fp16_h2[2] = __hfma2(B_loaded_fp16_h2[2], B_loaded_scale_h2[2], __float2half2_rn(0.0f)); B_loaded_fp16_h2[3] = __hfma2(B_loaded_fp16_h2[3], B_loaded_scale_h2[3], __float2half2_rn(0.0f)); - // Store back to shared memory - *(uint4 *)B_shared_ptr2 = B_loaded_fp16; - + // 直接写回全局内存输出 + half *out_vec = reinterpret_cast(&B_loaded_fp16); + #pragma unroll for (int i = 0; i < 8; ++i) { - *(C_ptr2 + i) = B_shared[i]; + C_ptr2[i] = out_vec[i]; } } #else __global__ void __launch_bounds__(64) - dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors, - int *__restrict__ zeros, half *__restrict__ C, int group_size) { + dequantize_weights_awq(int *__restrict__ B, half *__restrict__ scaling_factors, + int *__restrict__ zeros, half *__restrict__ C, int group_size, + int out_features, int in_features) { static constexpr uint32_t ZERO = 0x0; - half B_shared[32 * (128 + 8)]; - half *B_shared_ptr2 = B_shared; - - int N = blockDim.x * gridDim.x; // 2 int col = (blockIdx.x * blockDim.x + threadIdx.x); int row = blockIdx.y * blockDim.y + 
threadIdx.y;
-    int index1 = 8 * col + 8 * row * N;
+
+    // Bounds check to avoid out-of-range accesses
+    if (col >= out_features || row >= in_features) return;
+
+    int index1 = 8 * col + 8 * row * out_features;
     half *C_ptr2 = C + index1;
-    int index2 = col + row * N;
+    int index2 = col + row * out_features;
     int *B_ptr2 = B + index2;
-    int index3 = col + (int)(row / group_size) * N;
+    int index3 = col + (int)(row / group_size) * out_features;
     int *zeros_ptr2 = zeros + index3;
-    int index4 = 8 * col + (int)(row / group_size) * N * 8;
+    int index4 = 8 * col + (int)(row / group_size) * out_features * 8;
     half *scaling_factors_ptr2 = scaling_factors + index4;
 
     uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
-    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2_awq(zeros_loaded);
     uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2);
 
     uint32_t B_loaded = *(uint32_t *)B_ptr2;
-    uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+    uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2_awq(B_loaded);
 
     asm volatile("sub.f16x2 %0, %1, %2;\n"
                  : "=r"(B_loaded_fp16.x)
                  : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
@@ -116,10 +120,11 @@ __global__ void __launch_bounds__(64)
                  : "=r"(B_loaded_fp16.w)
                  : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
 
-    *(uint4 *)B_shared_ptr2 = B_loaded_fp16;
-
+    // Write the result directly back to global memory
+    half *out_vec = reinterpret_cast<half *>(&B_loaded_fp16);
+    #pragma unroll
     for (int i = 0; i < 8; ++i) {
-        *(C_ptr2 + i) = B_shared[i];
+        C_ptr2[i] = out_vec[i];
     }
 }
 #endif
@@ -183,8 +188,8 @@ Descriptor::calculate(
     half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
     int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
 
-    dequantize_weights<<(stream)>>>(
-        qweight_, scales_, zeros_, out_, group_size);
+    dequantize_weights_awq<<(stream)>>>(
+        qweight_, scales_, zeros_, out_, group_size, out_features, in_features);
     return INFINI_STATUS_SUCCESS;
 }
diff --git a/src/infiniop/ops/dequantize_gptq/dequantize_gptq.h b/src/infiniop/ops/dequantize_gptq/dequantize_gptq.h
new file mode 100644
index 000000000..82898fbaf
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/dequantize_gptq.h
@@ -0,0 +1,55 @@
+#ifndef __DEQUANTIZE_GPTQ_H__
+#define __DEQUANTIZE_GPTQ_H__
+
+#include "../../../utils.h"
+#include "../../operator.h"
+#include "../../tensor.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                     \
+                                                                  \
+    namespace op::dequantize_gptq::NAMESPACE {                    \
+    class Descriptor final : public InfiniopDescriptor {          \
+        struct Opaque;                                            \
+        Opaque *_opaque;                                          \
+        DequantizeGPTQInfo _info;                                 \
+        size_t _workspace_size;                                   \
+                                                                  \
+        Descriptor(                                               \
+            size_t workspace_size_,                               \
+            Opaque *opaque,                                       \
+            DequantizeGPTQInfo info,                              \
+            infiniDevice_t device_type,                           \
+            int device_id)                                        \
+            : InfiniopDescriptor{device_type, device_id},         \
+              _opaque(opaque),                                    \
+              _info(info),                                        \
+              _workspace_size(workspace_size_) {}                 \
+                                                                  \
+    public:                                                       \
+        ~Descriptor();                                            \
+                                                                  \
+        size_t workspaceSize() const { return _workspace_size; }  \
+                                                                  \
+        static infiniStatus_t create(                             \
+            infiniopHandle_t handle,                              \
+            Descriptor **desc_ptr,                                \
+            infiniopTensorDescriptor_t out_desc,                  \
+            infiniopTensorDescriptor_t qweight_desc,              \
+            infiniopTensorDescriptor_t scales_desc,               \
+            infiniopTensorDescriptor_t zeros_desc,                \
+            infiniopTensorDescriptor_t g_idx_desc);               \
+                                                                  \
+        infiniStatus_t calculate(                                 \
+            void *workspace,                                      \
+            size_t workspace_size,                                \
+            void *out,                                            \
+            const void *qweight,                                  \
+            const void *scales,                                   \
+            const void *zeros,                                    \
+            const void *g_idx,                                    \
+            void *stream) const;                                  \
+    };                                                            \
+    }
+
+#endif //__DEQUANTIZE_GPTQ_H__
diff --git a/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_iluvatar.cu b/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_iluvatar.cu
new file mode 100644
index 000000000..a0b79d0ca
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_iluvatar.cu
@@ -0,0 +1,125 @@
+#include "../../../devices/nvidia/nvidia_handle.cuh"
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "dequantize_w42f16_iluvatar.cuh"
+#include "dequantize_w42f16_kernel.cuh"
+
+#include "../dequantize_gptq.h"
+#include 
+#include 
+
+namespace op::dequantize_gptq::iluvatar {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::nvidia::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() { delete _opaque; }
+
+// Aligned with the nvidia version: supports g_idx
+// qweight: [in_packed, out_features]  packing 8 input channels per word
+// zeros:   [num_groups, out_packed]   packing 8 output channels per word
+// scales:  [num_groups, out_features], g_idx: [in_features]
+__global__ void __launch_bounds__(128)
+dequantize_weights_gptq(const uint32_t *__restrict__ qweight,
+                        const half *__restrict__ scales,
+                        const uint32_t *__restrict__ zeros,
+                        const int *__restrict__ g_idx,
+                        half *__restrict__ out,
+                        int in_features,
+                        int out_features,
+                        int out_packed, // ceil(out_features / 8)
+                        int num_groups) {
+    const int col_pack = blockIdx.x * blockDim.x + threadIdx.x; // packed output column
+    const int row = blockIdx.y * blockDim.y + threadIdx.y;      // real input row
+    if (col_pack >= out_packed || row >= in_features) return;
+
+    const int gid_raw = g_idx ? g_idx[row] : 0;
+    const int gid = ((gid_raw % num_groups) + num_groups) % num_groups;
+
+    const int pack_row = row >> 3;     // packed input row (8 rows per pack)
+    const int q_shift = (row & 7) * 4; // nibble shift within uint32
+
+    const uint32_t zeros_loaded = zeros[gid * out_packed + col_pack];
+
+    const int col_base = col_pack << 3; // 8 real cols per pack
+    const int scale_base = gid * out_features + col_base;
+
+    #pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int col = col_base + j;
+        if (col >= out_features) break;
+
+        const uint32_t q_loaded = qweight[pack_row * out_features + col];
+        const int q_nib = (q_loaded >> q_shift) & 0xF;
+
+        const int z_nib = (zeros_loaded >> (j * 4)) & 0xF;
+        const half scale = scales[scale_base + j];
+
+        // aligned with nvidia: (q - (z + 1)) * s
+        const float v = float(q_nib - (z_nib + 1)) * __half2float(scale);
+        out[row * out_features + col] = __float2half(v);
+    }
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t qweight_desc,
+    infiniopTensorDescriptor_t scales_desc,
+    infiniopTensorDescriptor_t zeros_desc,
+    infiniopTensorDescriptor_t g_idx_desc) {
+
+    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
+    auto result = DequantizeGPTQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc, g_idx_desc);
+
+    *desc_ptr = new Descriptor(
+        0,
+        new Opaque{handle->internal()},
+        result.take(),
+        handle->device,
+        handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *qweight,
+    const void *scales,
+    const void *zeros,
+    const void *g_idx,
+    void *stream) const {
+
+    (void)workspace;
+    (void)workspace_size;
+
+    const int in_features = _info.in_features();
+    const int out_features = _info.out_features();
+    const int out_packed = _info.out_packed();
+    const int in_packed = _info.in_packed();
+    const int num_groups = _info.num_groups();
+
+    if (num_groups <= 0 || in_features <= 0 || out_features <= 0 || out_packed <= 0 || in_packed <= 0)
+        return INFINI_STATUS_BAD_PARAM;
+
+    constexpr int BLOCK_X = 16; // packed columns
+    constexpr int BLOCK_Y = 4;  // rows
+    dim3 threads(BLOCK_X, BLOCK_Y);
+    dim3 blocks((out_packed + BLOCK_X - 1) / BLOCK_X,
+                (in_features + BLOCK_Y - 1) / BLOCK_Y);
+
+    dequantize_weights_gptq<<<blocks, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+        reinterpret_cast<const uint32_t *>(qweight),
+        reinterpret_cast<const half *>(scales),
+        reinterpret_cast<const uint32_t *>(zeros),
+        reinterpret_cast<const int *>(g_idx),
+        reinterpret_cast<half *>(out),
+        in_features, out_features, out_packed, num_groups);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::dequantize_gptq::iluvatar
diff --git a/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_iluvatar.cuh b/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_iluvatar.cuh
new file mode 100644
index 000000000..2916b0961
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_iluvatar.cuh
@@ -0,0 +1,8 @@
+#ifndef __DEQUANTIZE_GPTQ_ILUVATAR_CUH__
+#define __DEQUANTIZE_GPTQ_ILUVATAR_CUH__
+
+#include "../dequantize_gptq.h"
+
+DESCRIPTOR(iluvatar)
+
+#endif // __DEQUANTIZE_GPTQ_ILUVATAR_CUH__
diff --git a/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_kernel.cuh b/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_kernel.cuh
new file mode 100644
index 000000000..0cbf67743
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/iluvatar/dequantize_w42f16_kernel.cuh
@@ -0,0 +1,41 @@
+#pragma once
+#include 
+#include 
+
+/**
+ * @brief Unpack 8x 4-bit unsigned integers (0..15) from a uint32_t into 8 half values.
+ *
+ * GPTQ dequant is applied outside this helper (aligned with nvidia impl):
+ *   out = (q - (z + 1)) * s
+ *
+ * Output order matches interleaved half2 packing:
+ *   (v0,v4), (v1,v5), (v2,v6), (v3,v7)
+ */
+__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2_gptq(uint32_t const &source) {
+    const unsigned int v0 = (source >> 0) & 0x0F;
+    const unsigned int v1 = (source >> 4) & 0x0F;
+    const unsigned int v2 = (source >> 8) & 0x0F;
+    const unsigned int v3 = (source >> 12) & 0x0F;
+    const unsigned int v4 = (source >> 16) & 0x0F;
+    const unsigned int v5 = (source >> 20) & 0x0F;
+    const unsigned int v6 = (source >> 24) & 0x0F;
+    const unsigned int v7 = (source >> 28) & 0x0F;
+
+    // NOTE: GPTQ uses unsigned q/z in [0,15]. No "-8" signed mapping here.
+    const half hv0 = half(v0);
+    const half hv1 = half(v1);
+    const half hv2 = half(v2);
+    const half hv3 = half(v3);
+    const half hv4 = half(v4);
+    const half hv5 = half(v5);
+    const half hv6 = half(v6);
+    const half hv7 = half(v7);
+
+    uint4 result;
+    __half2 *p = reinterpret_cast<__half2 *>(&result);
+    p[0] = __halves2half2(hv0, hv4);
+    p[1] = __halves2half2(hv1, hv5);
+    p[2] = __halves2half2(hv2, hv6);
+    p[3] = __halves2half2(hv3, hv7);
+    return result;
+}
diff --git a/src/infiniop/ops/dequantize_gptq/info.h b/src/infiniop/ops/dequantize_gptq/info.h
new file mode 100644
index 000000000..9c12237c2
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/info.h
@@ -0,0 +1,51 @@
+#ifndef __DEQUANTIZE_GPTQ_INFO_H__
+#define __DEQUANTIZE_GPTQ_INFO_H__
+
+#include "../../../utils.h"
+#include "../../tensor.h"
+#include 
+
+#include 
+
+namespace op::dequantize_gptq {
+
+class DequantizeGPTQInfo {
+    DequantizeGPTQInfo() = default;
+
+public:
+    int _in_features, _out_features, _num_groups, _out_packed, _in_packed;
+
+    int in_features() const { return _in_features; }
+    int out_features() const { return _out_features; }
+    int num_groups() const { return _num_groups; }
+    int out_packed() const { return _out_packed; }
+    int in_packed() const { return _in_packed; }
+
+    static utils::Result<DequantizeGPTQInfo> create(
+        infiniopTensorDescriptor_t out_desc,
+        infiniopTensorDescriptor_t qweight_desc,
+        infiniopTensorDescriptor_t scales_desc,
+        infiniopTensorDescriptor_t zeros_desc,
+        infiniopTensorDescriptor_t g_idx_desc) {
+
+        const int _in_features = g_idx_desc->dim(0);    // real input channels
+        const int _in_packed = qweight_desc->dim(0);    // ceil(in_features / 8)
+        const int _out_features = qweight_desc->dim(1); // real output channels
+        const int _num_groups = scales_desc->dim(0);    // should be in_features / group_size
+        const int _out_packed = zeros_desc->dim(1);     // ceil(out_features / 8)
+
+        assert(out_desc->dim(0) == _in_features);
+        assert(out_desc->dim(1) == _out_features);
+        assert(_in_packed == (_in_features + 7) / 8);
+        assert(scales_desc->dim(1) == _out_features);
+        assert(_num_groups == zeros_desc->dim(0));
+        assert(_out_packed == (_out_features + 7) / 8);
+
+        return utils::Result<DequantizeGPTQInfo>(
+            DequantizeGPTQInfo{_in_features, _out_features, _num_groups, _out_packed, _in_packed});
+    }
+};
+
+} // namespace op::dequantize_gptq
+
+#endif // __DEQUANTIZE_GPTQ_INFO_H__
diff --git a/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_kernel.h b/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_kernel.h
new file mode 100644
index 000000000..681404b54
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_kernel.h
@@ -0,0 +1,41 @@
+#pragma once
+#include  // __half / __half2
+
+/**
+ * @brief Unpack 8x 4-bit unsigned integers (0..15) from a uint32_t into 8 half values.
+ * + * GPTQ dequant is applied outside this helper (aligned with nvidia impl): + * out = (q - (z + 1)) * s + * + * Output order matches the interleaved half2 packing: + * (v0,v4), (v1,v5), (v2,v6), (v3,v7) + */ +__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2_gptq(uint32_t const &source) { + // unpack 8 nibbles: v0..v7 in [0, 15] + const unsigned int v0 = (source >> 0) & 0x0F; + const unsigned int v1 = (source >> 4) & 0x0F; + const unsigned int v2 = (source >> 8) & 0x0F; + const unsigned int v3 = (source >> 12) & 0x0F; + const unsigned int v4 = (source >> 16) & 0x0F; + const unsigned int v5 = (source >> 20) & 0x0F; + const unsigned int v6 = (source >> 24) & 0x0F; + const unsigned int v7 = (source >> 28) & 0x0F; + + // NOTE: no "-8" offset here (unlike signed s4). GPTQ uses unsigned q/z. + const __half hv0 = __half(v0); + const __half hv1 = __half(v1); + const __half hv2 = __half(v2); + const __half hv3 = __half(v3); + const __half hv4 = __half(v4); + const __half hv5 = __half(v5); + const __half hv6 = __half(v6); + const __half hv7 = __half(v7); + + uint4 result; + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + result_ptr[0] = __halves2half2(hv0, hv4); + result_ptr[1] = __halves2half2(hv1, hv5); + result_ptr[2] = __halves2half2(hv2, hv6); + result_ptr[3] = __halves2half2(hv3, hv7); + return result; +} \ No newline at end of file diff --git a/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_moore.h b/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_moore.h new file mode 100644 index 000000000..689657e0a --- /dev/null +++ b/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_moore.h @@ -0,0 +1,8 @@ +#ifndef __DEQUANTIZE_GPTQ_MOORE_H__ +#define __DEQUANTIZE_GPTQ_MOORE_H__ + +#include "../dequantize_gptq.h" + +DESCRIPTOR(moore) + +#endif // __DEQUANTIZE_GPTQ_MOORE_H__ diff --git a/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_moore.mu b/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_moore.mu new file mode 100644 index 000000000..290d68c8b --- /dev/null +++ b/src/infiniop/ops/dequantize_gptq/moore/dequantize_w42f16_moore.mu @@ -0,0 +1,125 @@ +#include "../../../devices/moore/moore_handle.h" +#include "../../../devices/moore/moore_kernel_common.h" +#include "dequantize_w42f16_moore.h" +// #include "dequantize_w42f16_kernel.h" // 不再需要(保留也无妨) + +#include "../dequantize_gptq.h" +#include +#include + +// 对齐 nvidia 版:支持 g_idx +// qweight: [in_packed, out_features],每个 uint32 打包 8 个输入通道的 4-bit +// zeros: [num_groups, out_packed],每个 uint32 打包 8 个输出通道的 4-bit +// scales: [num_groups, out_features] +// g_idx: [in_features] +__global__ void __launch_bounds__(128) +dequantize_weights_gptq(const uint32_t *__restrict__ qweight, + const half *__restrict__ scales, + const uint32_t *__restrict__ zeros, + const int *__restrict__ g_idx, + half *__restrict__ out, + int in_features, + int out_features, + int out_packed, // ceil(out_features / 8) + int num_groups) { + const int col_pack = blockIdx.x * blockDim.x + threadIdx.x; // packed output column + const int row = blockIdx.y * blockDim.y + threadIdx.y; // real input row + if (col_pack >= out_packed || row >= in_features) return; + + // clamp gid to [0, num_groups) + const int gid_raw = g_idx ? 
g_idx[row] : 0; + const int gid = ((gid_raw % num_groups) + num_groups) % num_groups; + + const int pack_row = row >> 3; // packed input row (8 rows per pack) + const int q_shift = (row & 7) * 4; // nibble shift within uint32 + + const int zero_idx = gid * out_packed + col_pack; + const uint32_t zeros_loaded = zeros[zero_idx]; + + const int col_base = col_pack << 3; // 8 real cols per pack + const int scale_base = gid * out_features + col_base; + + #pragma unroll + for (int j = 0; j < 8; ++j) { + const int col = col_base + j; + if (col >= out_features) break; + + const uint32_t q_loaded = qweight[pack_row * out_features + col]; + const int q_nib = (q_loaded >> q_shift) & 0xF; + + const int z_nib = (zeros_loaded >> (j * 4)) & 0xF; + const half scale = scales[scale_base + j]; + + // 与 nvidia 版一致: (q - (z + 1)) * s + const float v = float(q_nib - (z_nib + 1)) * __half2float(scale); + out[row * out_features + col] = __float2half(v); + } +} + +namespace op::dequantize_gptq::moore { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t qweight_desc, + infiniopTensorDescriptor_t scales_desc, + infiniopTensorDescriptor_t zeros_desc, + infiniopTensorDescriptor_t g_idx_desc) { + + auto handle = reinterpret_cast(handle_); + auto result = DequantizeGPTQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc, g_idx_desc); + + *desc_ptr = new Descriptor( + 0, + new Opaque{handle->internal()}, + result.take(), + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, + const void *qweight, + const void *scales, + const void *zeros, + const void *g_idx, + void *stream) const { + + const int in_features = _info.in_features(); + const int out_features = _info.out_features(); + const int out_packed = _info.out_packed(); + const int in_packed = _info.in_packed(); + const int num_groups = _info.num_groups(); + + if (num_groups <= 0 || in_features <= 0 || out_features <= 0 || out_packed <= 0 || in_packed <= 0) + return INFINI_STATUS_BAD_PARAM; + + constexpr int BLOCK_X = 16; // packed columns + constexpr int BLOCK_Y = 4; // rows + dim3 threads(BLOCK_X, BLOCK_Y); + dim3 blocks((out_packed + BLOCK_X - 1) / BLOCK_X, + (in_features + BLOCK_Y - 1) / BLOCK_Y); + + dequantize_weights_gptq<<(stream)>>>( + reinterpret_cast(qweight), + reinterpret_cast(scales), + reinterpret_cast(zeros), + reinterpret_cast(g_idx), + reinterpret_cast(out), + in_features, out_features, out_packed, num_groups); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::dequantize_gptq::moore diff --git a/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_kernel.cuh b/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_kernel.cuh new file mode 100644 index 000000000..d77d888ee --- /dev/null +++ b/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_kernel.cuh @@ -0,0 +1,124 @@ +#pragma once + +__device__ uint4 dequantize_s4_to_fp16x2_gptq(uint32_t const &source) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + // 步骤 1: 从一个 32-bit 源数据中解包出 8 个 4-bit 无符号整数。 + // 源数据的内存布局被假定为 [v7, v6, v5, v4, v3, v2, v1, v0], + // 其中每个 'v' 都是一个 4-bit 的半字节 (nibble)。 + const unsigned int v0 = (source >> 0) & 0x0F; + const unsigned int v1 = (source >> 4) & 0x0F; + const unsigned int v2 = (source >> 8) & 0x0F; 
+    const unsigned int v3 = (source >> 12) & 0x0F;
+    const unsigned int v4 = (source >> 16) & 0x0F;
+    const unsigned int v5 = (source >> 20) & 0x0F;
+    const unsigned int v6 = (source >> 24) & 0x0F;
+    const unsigned int v7 = (source >> 28) & 0x0F;
+
+    // Step 2: GPTQ computes (Q - Z) * S.
+    // Q and Z are both unsigned values in [0, 15].
+    // No "- offset" is needed here.
+
+    __half hv0 = __half(v0);
+    __half hv1 = __half(v1);
+    __half hv2 = __half(v2);
+    __half hv3 = __half(v3);
+    __half hv4 = __half(v4);
+    __half hv5 = __half(v5);
+    __half hv6 = __half(v6);
+    __half hv7 = __half(v7);
+
+    // Step 3: pack the half values into __half2 in the PTX interleaved order and store them in result.
+    // Order: result_ptr[0]: low=hv0, high=hv4
+    //        result_ptr[1]: low=hv1, high=hv5
+    //        result_ptr[2]: low=hv2, high=hv6
+    //        result_ptr[3]: low=hv3, high=hv7
+    // __halves2half2: the first argument is the low half, the second is the high half.
+    uint4 result;
+    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
+
+    result_ptr[0] = __halves2half2(hv0, hv4);
+    result_ptr[1] = __halves2half2(hv1, hv5);
+    result_ptr[2] = __halves2half2(hv2, hv6);
+    result_ptr[3] = __halves2half2(hv3, hv7);
+
+    return result;
+#else
+    uint4 result;
+
+    uint32_t *h = reinterpret_cast<uint32_t *>(&result);
+    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
+
+    // First, we extract the i4s and construct an intermediate fp16 number.
+    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
+    static constexpr uint32_t TOP_MASK = 0x00f000f0;
+    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+    // Note that the entire sequence only requires 1 shift instruction. This is
+    // thanks to the register packing format and the fact that we force our
+    // integers to be unsigned, and account for this in the fp16 subtractions. In
+    // addition, I exploit the fact that sub and fma have the same throughput in
+    // order to convert elt_23 and elt_67 to fp16 without having to shift them to
+    // the bottom bits before hand.
+
+    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
+    // dependency if we issue immediately before required.
+    const uint32_t top_i4s = i4s >> 8;
+    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[0])
+                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                   "n"(immLut));
+    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[1])
+                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                   "n"(immLut));
+    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[2])
+                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                   "n"(immLut));
+    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[3])
+                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
+                   "n"(immLut));
+
+    // I use inline PTX below because I am not sure if the compiler will emit
+    // float2half instructions if I use the half2 ctor. In this case, I chose
+    // performance reliability over code readability.
+
+    // This is the half2 {1032, 1032} represented as an integer.
+    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
+    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
+    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
+    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+    // This is the half2 {-72, -72} represented as an integer.
+    // static constexpr uint32_t NEG_72 = 0xd480d480;
+    // Haotian: Let's use {-64, -64}.
+    static constexpr uint32_t NEG_64 = 0xd400d400;
+
+    // Finally, we construct the output numbers.
+    // Convert elt_01
+    asm volatile("sub.f16x2 %0, %1, %2;\n"
+                 : "=r"(h[0])
+                 : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_23
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                 : "=r"(h[1])
+                 : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    // Convert elt_45
+    asm volatile("sub.f16x2 %0, %1, %2;\n"
+                 : "=r"(h[2])
+                 : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_67
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
+                 : "=r"(h[3])
+                 : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+
+    return result;
+#endif
+    __builtin_unreachable(); // Suppress missing return statement warning
+}
\ No newline at end of file
diff --git a/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cu b/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cu
new file mode 100644
index 000000000..b77446190
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cu
@@ -0,0 +1,126 @@
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+
+#include "../../../devices/nvidia/nvidia_handle.cuh"
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "dequantize_w42f16_kernel.cuh"
+#include "dequantize_w42f16_nvidia.cuh"
+#include "../dequantize_gptq.h"
+#include 
+
+namespace op::dequantize_gptq::nvidia {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::nvidia::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() { delete _opaque; }
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t qweight_desc,
+    infiniopTensorDescriptor_t scales_desc,
+    infiniopTensorDescriptor_t zeros_desc,
+    infiniopTensorDescriptor_t g_idx_desc) {
+
+    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
+    auto result = DequantizeGPTQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc, g_idx_desc);
+
+    *desc_ptr = new Descriptor(
+        0,
+        new Opaque{handle->internal()},
+        result.take(),
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Only this kernel variant is kept (supports g_idx).
+// qweight: [in_packed, out_features]  packing 8 input channels per word
+// zeros:   [num_groups, out_packed]   packing 8 output channels per word
+// scales:  [num_groups, out_features], g_idx: [in_features]
+__global__ void __launch_bounds__(128)
+dequantize_weights_gptq(const uint32_t *__restrict__ qweight,
+                        const half *__restrict__ scales,
+                        const uint32_t *__restrict__ zeros,
+                        const int *__restrict__ g_idx,
+                        half *__restrict__ out,
+                        int in_features,
+                        int out_features,
+                        int out_packed, // ceil(out_features / 8)
+                        int num_groups) {
+    // Each thread handles one packed output column (8 real output cols).
+    const int col_pack = blockIdx.x * blockDim.x + threadIdx.x; // packed output column
+    const int row = blockIdx.y * blockDim.y + threadIdx.y;      // real input row
+    if (col_pack >= out_packed || row >= in_features) return;
+
+    // Clamp gid to valid range
+    const int gid_raw = g_idx ? g_idx[row] : 0;
+    const int gid = ((gid_raw % num_groups) + num_groups) % num_groups;
+
+    const int pack_row = row >> 3; // packed input row
+
+    const int zero_idx = gid * out_packed + col_pack; // zeros layout: [num_groups, out_packed]
+    const uint32_t zeros_loaded = zeros[zero_idx];
+
+    const int q_shift = (row & 7) * 4;  // qweight packs 8 input rows
+    const int col_base = col_pack << 3; // 8 real cols per pack
+    const int scale_base = gid * out_features + col_base;
+
+    #pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int col = col_base + j;
+        if (col >= out_features) break;
+
+        const uint32_t q_loaded = qweight[pack_row * out_features + col];
+        const int q_nib = (q_loaded >> q_shift) & 0xF;
+
+        const int z_nib = (zeros_loaded >> (j * 4)) & 0xF;
+        const half scale = scales[scale_base + j];
+
+        // GPTQ quirk: The stored zero point is usually offset by 1.
+        // Standard formula: (q - (z + 1)) * s
+        const float v = float(q_nib - (z_nib + 1)) * __half2float(scale);
+        out[row * out_features + col] = __float2half(v);
+    }
+}
+
+infiniStatus_t
+Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *out,
+    const void *qweight,
+    const void *scales,
+    const void *zeros,
+    const void *g_idx,
+    void *stream) const {
+
+    const int in_features = _info.in_features();
+    const int out_features = _info.out_features();
+    const int out_packed = _info.out_packed();
+    const int in_packed = _info.in_packed();
+    const int num_groups = _info.num_groups();
+
+    if (num_groups <= 0 || in_features <= 0 || out_features <= 0 || out_packed <= 0 || in_packed <= 0)
+        return INFINI_STATUS_BAD_PARAM;
+
+    constexpr int BLOCK_X = 16; // packed columns
+    constexpr int BLOCK_Y = 4;  // rows
+    dim3 threads(BLOCK_X, BLOCK_Y);
+    dim3 blocks((out_packed + BLOCK_X - 1) / BLOCK_X,
+                (in_features + BLOCK_Y - 1) / BLOCK_Y);
+
+    dequantize_weights_gptq<<<blocks, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+        reinterpret_cast<const uint32_t *>(qweight),
+        reinterpret_cast<const half *>(scales),
+        reinterpret_cast<const uint32_t *>(zeros),
+        reinterpret_cast<const int *>(g_idx),
+        reinterpret_cast<half *>(out),
+        in_features, out_features, out_packed, num_groups);
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::dequantize_gptq::nvidia
+
+#endif
\ No newline at end of file
diff --git a/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cuh b/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cuh
new file mode 100644
index 000000000..ccaa0e429
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/nvidia/dequantize_w42f16_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __DEQUANTIZE_GPTQ_CUDA_CUH__
+#define __DEQUANTIZE_GPTQ_CUDA_CUH__
+
+#include "../dequantize_gptq.h"
+
+DESCRIPTOR(nvidia)
+
+#endif // __DEQUANTIZE_GPTQ_CUDA_CUH__
diff --git a/src/infiniop/ops/dequantize_gptq/operator.cc b/src/infiniop/ops/dequantize_gptq/operator.cc
new file mode 100644
index 000000000..a0790612d
--- /dev/null
+++ b/src/infiniop/ops/dequantize_gptq/operator.cc
@@ -0,0 +1,143 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/dequantize_gptq.h"
+
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#include "nvidia/dequantize_w42f16_nvidia.cuh"
+#endif
+#ifdef ENABLE_MOORE_API
+#include "moore/dequantize_w42f16_moore.h"
+#endif
+#ifdef ENABLE_ILUVATAR_API
+#include "iluvatar/dequantize_w42f16_iluvatar.cuh"
+#endif
+
+__C infiniStatus_t infiniopCreateDequantizeGPTQDescriptor(
+    infiniopHandle_t handle,
+    infiniopDequantizeGPTQDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t qweight_desc,
+    infiniopTensorDescriptor_t scales_desc,
+    infiniopTensorDescriptor_t zeros_desc,
+    infiniopTensorDescriptor_t g_idx_desc) { // add g_idx
+#define CREATE(CASE, NAMESPACE)                                                         \
+    case CASE:                                                                          \
+        return op::dequantize_gptq::NAMESPACE::Descriptor::create(                     \
+            handle,                                                                     \
+            reinterpret_cast<op::dequantize_gptq::NAMESPACE::Descriptor **>(desc_ptr), \
+            out_desc,                                                                   \
+            qweight_desc,                                                               \
+            scales_desc,                                                                \
+            zeros_desc, g_idx_desc)
+
+    switch (handle->device) {
+#ifdef ENABLE_NVIDIA_API
+        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_MOORE_API
+        CREATE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CREATE(INFINI_DEVICE_ILUVATAR, iluvatar);
+#endif
+#ifdef ENABLE_QY_API
+        CREATE(INFINI_DEVICE_QY, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetDequantizeGPTQWorkspaceSize(infiniopDequantizeGPTQDescriptor_t desc,
+                                                          size_t *size) {
+#define GET(CASE, NAMESPACE)                                                                            \
+    case CASE:                                                                                          \
+        *size = reinterpret_cast<op::dequantize_gptq::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+#ifdef ENABLE_NVIDIA_API
+        GET(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_MOORE_API
+        GET(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        GET(INFINI_DEVICE_ILUVATAR, iluvatar);
+#endif
+#ifdef ENABLE_QY_API
+        GET(INFINI_DEVICE_QY, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef GET
+}
+
+__C infiniStatus_t infiniopDequantizeGPTQ(
+    infiniopDequantizeGPTQDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *qweight,
+    const void *scales,
+    const void *zeros,
+    const void *g_idx, // add g_idx
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE)                                                   \
+    case CASE:                                                                       \
+        return reinterpret_cast<op::dequantize_gptq::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, out, qweight, scales, zeros, g_idx, stream)
+
+    switch (desc->device_type) {
+#ifdef ENABLE_NVIDIA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_MOORE_API
+        CALCULATE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CALCULATE(INFINI_DEVICE_ILUVATAR, iluvatar);
+#endif
+#ifdef ENABLE_QY_API
+        CALCULATE(INFINI_DEVICE_QY, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t
+infiniopDestroyDequantizeGPTQDescriptor(infiniopDequantizeGPTQDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE)                                                  \
+    case CASE:                                                                   \
+        delete reinterpret_cast<op::dequantize_gptq::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_NVIDIA_API
+        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_MOORE_API
+        DELETE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        DELETE(INFINI_DEVICE_ILUVATAR, iluvatar);
+#endif
+#ifdef ENABLE_QY_API
+        DELETE(INFINI_DEVICE_QY, nvidia);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
+
+// #endif
\ No newline at end of file
diff --git a/test/infiniop/dequantize_gptq.py b/test/infiniop/dequantize_gptq.py
new file mode 100644
index 000000000..a9d1a0404
--- /dev/null
+++ b/test/infiniop/dequantize_gptq.py
@@ -0,0 +1,276 @@
+import torch
+import ctypes
+from ctypes import c_uint64
+from libinfiniop import (
+    LIBINFINIOP,
+    TestTensor,
+    TestWorkspace,
+    get_test_devices,
+    check_error,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    InfiniDtype,
+    InfiniDtypeNames,
+    InfiniDeviceNames,
+    infiniopOperatorDescriptor_t,
+)
+
+# ==============================================================================
+# Configuration
+# ==============================================================================
+# Each case: (in_features, out_features, group_size)
+_TEST_CASES = [
+    (128, 256, 32),
+    (512, 2048, 128),
+    (1024, 1024, 128),
+    # Non-multiple-of-8 edge case for both dims
+    (513, 257, 32),
+]
+
+_TENSOR_DTYPES = [InfiniDtype.F16]
+
+_TOLERANCE_MAP = {
+    InfiniDtype.F16: {"atol": 0.0, "rtol": 1e-3},
+}
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+# ==============================================================================
+# Reference Implementation (matches CUDA kernel)
+# ==============================================================================
+
+def _unpack_qweight_int4_packed_by_rows(qweight_packed: torch.Tensor, in_features: int) -> torch.Tensor:
+    """qweight_packed: [in_packed, out_features] int32/uint32.
+
+    Packs 8 input rows per 32-bit word; nibble shift = (row_in_pack * 4).
+    Returns: [in_features, out_features] int32 in [0, 15].
+    """
+    assert qweight_packed.dim() == 2
+    shifts = torch.arange(0, 32, 4, device=qweight_packed.device, dtype=torch.int32)
+    # [in_packed, 8, out_features]
+    expanded = torch.bitwise_right_shift(qweight_packed[:, None, :], shifts[None, :, None])
+    vals = torch.bitwise_and(expanded, 0xF).to(torch.int32)
+    out_features = qweight_packed.shape[1]
+    return vals.reshape(-1, out_features)[:in_features, :]
+
+
+def _unpack_zeros_int4_packed_by_cols(zeros_packed: torch.Tensor, out_features: int) -> torch.Tensor:
+    """zeros_packed: [num_groups, out_packed] int32/uint32.
+
+    Packs 8 output cols per 32-bit word; nibble shift = (col_in_pack * 4).
+    Returns: [num_groups, out_features] int32 in [0, 15].
+    """
+    assert zeros_packed.dim() == 2
+    shifts = torch.arange(0, 32, 4, device=zeros_packed.device, dtype=torch.int32)
+    # [num_groups, out_packed, 8]
+    expanded = torch.bitwise_right_shift(zeros_packed[:, :, None], shifts[None, None, :])
+    vals = torch.bitwise_and(expanded, 0xF).to(torch.int32)
+    return vals.reshape(zeros_packed.shape[0], -1)[:, :out_features]
+
+
+def dequantize_gptq_ref(
+    qweight_packed: torch.Tensor,
+    scales: torch.Tensor,
+    zeros_packed: torch.Tensor,
+    g_idx: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    """Reference matches kernel in dequantize_w42f16_nvidia.cu.
+
+    out[row, col] = (q - (z + 1)) * scale
+    where q is int4 packed by input rows and z is int4 packed by output cols.
+    gid = g_idx[row] (clamped modulo num_groups).
+    """
+    in_features = g_idx.numel()
+    out_features = qweight_packed.shape[1]
+
+    num_groups = scales.shape[0]
+    assert scales.shape == (num_groups, out_features)
+
+    q = _unpack_qweight_int4_packed_by_rows(qweight_packed, in_features)  # [in_features, out_features]
+    z_full = _unpack_zeros_int4_packed_by_cols(zeros_packed, out_features)  # [num_groups, out_features]
+
+    gid_raw = g_idx.to(torch.int32)
+    gid = ((gid_raw % num_groups) + num_groups) % num_groups
+
+    z = z_full[gid]  # [in_features, out_features]
+    s = scales[gid]  # [in_features, out_features]
+
+    out = (q - (z + 1)).to(torch.float32) * s.to(torch.float32)
+    return out.to(torch.float16)
+
+
+def _make_packed_inputs(in_features: int, out_features: int, group_size: int, torch_device: str):
+    # num_groups is implicit (same convention as GPTQ: group per input channel)
+    num_groups = (in_features + group_size - 1) // group_size
+    in_packed = (in_features + 7) // 8
+    out_packed = (out_features + 7) // 8
+
+    # Deterministic group mapping
+    g_idx = (torch.arange(in_features, device=torch_device, dtype=torch.int32) // group_size)
+
+    # Random nibble-level values to avoid relying on sign/shift corner cases
+    q_nib = torch.randint(0, 16, (in_features, out_features), device=torch_device, dtype=torch.int32)
+    z_nib = torch.randint(0, 16, (num_groups, out_features), device=torch_device, dtype=torch.int32)
+
+    # scales in fp16
+    scales = (torch.rand((num_groups, out_features), device=torch_device, dtype=torch.float16) * 0.5 + 0.01)
+
+    # Pack qweight: [in_packed, out_features]
+    qweight_packed = torch.zeros((in_packed, out_features), device=torch_device, dtype=torch.int32)
+    for i in range(8):
+        rows = torch.arange(i, in_features, 8, device=torch_device)
+        if rows.numel() == 0:
+            continue
+        pack_rows = (rows // 8).to(torch.int64)
+        qweight_packed[pack_rows, :] |= (q_nib[rows, :] & 0xF) << (i * 4)
+
+    # Pack zeros: [num_groups, out_packed]
+    zeros_packed = torch.zeros((num_groups, out_packed), device=torch_device, dtype=torch.int32)
+    for j in range(8):
+        cols = torch.arange(j, out_features, 8, device=torch_device)
+        if cols.numel() == 0:
+            continue
+        pack_cols = (cols // 8).to(torch.int64)
+        zeros_packed[:, pack_cols] |= (z_nib[:, cols] & 0xF) << (j * 4)
+
+    return qweight_packed, scales, zeros_packed, g_idx
+
+
+# ==============================================================================
+# Test Entrypoint
+# ==============================================================================
+
+def test(
+    handle,
+    device,
+    in_features,
+    out_features,
+    group_size,
+    dtype=None,
+    sync=None,
+):
+    print(
+        f"Testing Dequantize GPTQ on {InfiniDeviceNames[device]} with in_features:{in_features}, out_features:{out_features}, group_size:{group_size}"
+    )
+
+    # Infer torch device from a probe tensor created by TestTensor.
+    probe = TestTensor((1,), None, InfiniDtype.U8, device, mode="ones")
+    torch_device = probe.actual_tensor().device
+
+    qweight_packed, scales, zeros_packed, g_idx = _make_packed_inputs(
+        in_features, out_features, group_size, torch_device
+    )
+
+    # Reference
+    ans = dequantize_gptq_ref(qweight_packed, scales, zeros_packed, g_idx, group_size)
+
+    # Wrap into TestTensor so we get descriptors + raw pointers
+    qweight = TestTensor.from_torch(qweight_packed, InfiniDtype.I32, device)
+    zeros = TestTensor.from_torch(zeros_packed, InfiniDtype.I32, device)
+    qscales = TestTensor.from_torch(scales, InfiniDtype.F16, device)
+    g_idx_t = TestTensor.from_torch(g_idx, InfiniDtype.I32, device)
+
+    out = TestTensor((in_features, out_features), None, InfiniDtype.F16, device, mode="zeros")
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateDequantizeGPTQDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            out.descriptor,
+            qweight.descriptor,
+            qscales.descriptor,
+            zeros.descriptor,
+            g_idx_t.descriptor,
+        )
+    )
+
+    # Invalidate descriptors (same pattern as other tests)
+    for tensor in [qweight, zeros, qscales, g_idx_t, out]:
+        tensor.destroy_desc()
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetDequantizeGPTQWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, device)
+
+    def lib_dequantize_gptq():
+        check_error(
+            LIBINFINIOP.infiniopDequantizeGPTQ(
+                descriptor,
+                workspace.data(),
+                workspace_size.value,
+                out.data(),
+                qweight.data(),
+                qscales.data(),
+                zeros.data(),
+                g_idx_t.data(),
+                None,
+            )
+        )
+
+    lib_dequantize_gptq()
+
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+
+    if DEBUG:
+        debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)
+
+    assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol)
+
+    if PROFILE:
+        profile_operation("PyTorch", lambda: dequantize_gptq_ref(qweight_packed, scales, zeros_packed, g_idx, group_size), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("    lib", lambda: lib_dequantize_gptq(), device, NUM_PRERUN, NUM_ITERATIONS)
+
+    check_error(LIBINFINIOP.infiniopDestroyDequantizeGPTQDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # This operator is intended to run on NVIDIA in our current environment.
+    # Make `srun ... python test/infiniop/dequantize_gptq.py` work without extra flags.
+    if not any(
+        getattr(args, name)
+        for name in [
+            "cpu",
+            "nvidia",
+            "iluvatar",
+            "qy",
+            "cambricon",
+            "ascend",
+            "metax",
+            "moore",
+            "kunlun",
+            "hygon",
+        ]
+    ):
+        args.nvidia = True
+
+    devices = [d for d in get_test_devices(args) if InfiniDeviceNames[d] == "NVIDIA"]
+    if not devices:
+        raise RuntimeError("No NVIDIA device selected; run with --nvidia under srun.")
+
+    for device in devices:
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92mTest passed!\033[0m")
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index 283bdb1cd..270b91908 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -723,6 +723,42 @@ def dequantize_(lib):
     ]
 
 
+
+@OpRegister.operator
+def dequantize_gptq_(lib):
+    lib.infiniopCreateDequantizeGPTQDescriptor.restype = c_int32
+    lib.infiniopCreateDequantizeGPTQDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+    lib.infiniopGetDequantizeGPTQWorkspaceSize.restype = c_int32
+    lib.infiniopGetDequantizeGPTQWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+    lib.infiniopDequantizeGPTQ.restype = c_int32
+    lib.infiniopDequantizeGPTQ.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyDequantizeGPTQDescriptor.restype = c_int32
+    lib.infiniopDestroyDequantizeGPTQDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
+
+
 @OpRegister.operator
 def softplus_(lib):
     lib.infiniopCreateSoftplusDescriptor.restype = c_int32
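
Reviewer note (not part of the patch): the C++ sketch below shows the call sequence the new entry
points in include/infiniop/ops/dequantize_gptq.h are meant to follow, mirroring what the Python
test drives through ctypes. Tensor shapes follow the convention asserted in info.h (out:
[in_features, out_features], qweight: [in_packed, out_features], scales/zeros: [num_groups, ...],
g_idx: [in_features]). The CHECK macro, the *_desc descriptors, and the d_* device buffers are
placeholders assumed for illustration, not APIs from this repository.

    // Hypothetical driver code; only the infiniop* calls come from the new public header.
    infiniopDequantizeGPTQDescriptor_t desc = nullptr;
    CHECK(infiniopCreateDequantizeGPTQDescriptor(handle, &desc,
                                                 out_desc, qweight_desc, scales_desc,
                                                 zeros_desc, g_idx_desc));

    size_t workspace_size = 0;
    CHECK(infiniopGetDequantizeGPTQWorkspaceSize(desc, &workspace_size));
    void *workspace = nullptr; // allocate workspace_size bytes on the device if it is non-zero

    // Launch the dequantization on the given stream (nullptr = default stream in the test).
    CHECK(infiniopDequantizeGPTQ(desc, workspace, workspace_size,
                                 d_out, d_qweight, d_scales, d_zeros, d_g_idx,
                                 stream));

    CHECK(infiniopDestroyDequantizeGPTQDescriptor(desc));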