diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
new file mode 100644
index 000000000..3d6b13b53
--- /dev/null
+++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.cuh
@@ -0,0 +1,8 @@
+#ifndef __ADD_RMS_NORM_METAX_CUH__
+#define __ADD_RMS_NORM_METAX_CUH__
+
+#include "../add_rms_norm.h"
+
+DESCRIPTOR(metax)
+
+#endif
diff --git a/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
new file mode 100644
index 000000000..8339ec5aa
--- /dev/null
+++ b/src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
@@ -0,0 +1,167 @@
+#include "../../../devices/metax/metax_common.h"
+#include "add_rms_norm_metax.cuh"
+
+#include "../../../devices/metax/metax_kernel_common.h"
+#include
+
+#include "../../../reduce/cuda/reduce.cuh"
+
+#include "../cuda/kernel.cuh"
+
+// Kernel function template for add_rms_norm on the Metax platform
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
+INFINIOP_METAX_KERNEL add_rmsnormKernel(
+    Tdata *__restrict__ y,
+    Tdata *__restrict__ residual_out,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
+    ptrdiff_t stride_residual_out_batch,
+    ptrdiff_t stride_residual_out_nhead,
+    const Tdata *__restrict__ a,
+    ptrdiff_t stride_a_batch,
+    ptrdiff_t stride_a_nhead,
+    const Tdata *__restrict__ b,
+    ptrdiff_t stride_b_batch,
+    ptrdiff_t stride_b_nhead,
+    const Tweight *__restrict__ w,
+    size_t nhead,
+    size_t dim,
+    float epsilon) {
+    add_rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute>(
+        y, residual_out,
+        stride_y_batch, stride_y_nhead,
+        stride_residual_out_batch, stride_residual_out_nhead,
+        a, stride_a_batch, stride_a_nhead,
+        b, stride_b_batch, stride_b_nhead,
+        w, nhead, dim, epsilon);
+}
+
+namespace op::add_rms_norm::metax {
+
+// Internal opaque structure for the Metax device handle
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+// Destructor
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+// Create descriptor for the add_rms_norm operator
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y_desc,
+    infiniopTensorDescriptor_t a_desc,
+    infiniopTensorDescriptor_t b_desc,
+    infiniopTensorDescriptor_t weight_desc,
+    float epsilon,
+    infiniopTensorDescriptor_t residual_out_desc) {
+    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
+    CHECK_RESULT(result);
+    auto info = result.take();
+
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        std::move(info),
+        0,
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Launch kernel with different data types
+template <unsigned int BLOCK_SIZE>
+infiniStatus_t launchKernel(
+    uint32_t batch_size, size_t nhead, size_t dim,
+    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
+    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
+    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
+    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
+    const void *w, infiniDtype_t wtype,
+    float epsilon,
+    hcStream_t stream) {
+
+#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
+    add_rmsnormKernel<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, stream>>>( \
+        reinterpret_cast<Tdata *>(y), \
+        reinterpret_cast<Tdata *>(residual_out), \
+        stride_y_batch, \
+        stride_y_nhead, \
+        stride_residual_out_batch, \
+        stride_residual_out_nhead, \
+        reinterpret_cast<const Tdata *>(a), \
+        stride_a_batch, \
+        stride_a_nhead, \
+        reinterpret_cast<const Tdata *>(b), \
+        stride_b_batch, \
+        stride_b_nhead, \
+        reinterpret_cast<const Tweight *>(w), \
+        nhead, \
+        dim, \
+        epsilon)
+
+    // Handle different data type combinations, following the Metax pattern
+    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(half, half, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(half, float, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
+    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(float, float, float);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+#undef LAUNCH_KERNEL
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Main calculation function
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *y, const void *a, const void *b, const void *weight,
+    void *residual_out, void *stream_) const {
+
+    // Check workspace size
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    // Extract tensor strides and dimensions
+    auto stride_a_batch = _info.a_strides[0];
+    auto stride_a_nhead = _info.a_strides[1];
+    auto stride_b_batch = _info.b_strides[0];
+    auto stride_b_nhead = _info.b_strides[1];
+    auto stride_y_batch = _info.y_strides[0];
+    auto stride_y_nhead = _info.y_strides[1];
+    auto stride_residual_out_batch = _info.residual_out_strides[0];
+    auto stride_residual_out_nhead = _info.residual_out_strides[1];
+    auto dim = _info.dim();
+    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
+    auto stream = reinterpret_cast<hcStream_t>(stream_);
+
+    // Launch kernel with the appropriate block size based on device capability
+    if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
+        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, stream));
+    } else {
+        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::add_rms_norm::metax
diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h
new file mode 100644
index 000000000..9d3f810f2
--- /dev/null
+++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.h
@@ -0,0 +1,8 @@
+#ifndef __ADD_RMS_NORM_MOORE_H__
+#define __ADD_RMS_NORM_MOORE_H__
+
+#include "../add_rms_norm.h"
+
+DESCRIPTOR(moore)
+
+#endif
diff --git a/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
new file mode 100644
index 000000000..fe7a49765
--- /dev/null
+++ b/src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
@@ -0,0 +1,183 @@
+#include "../../../devices/moore/moore_common.h"
+#include "add_rms_norm_moore.h"
+
+#include "../../../devices/moore/moore_kernel_common.h"
+#include
+
+#include "../../../reduce/cuda/reduce.cuh"
+
+#include "../cuda/kernel.cuh"
+
+// Kernel function template for add_rms_norm on the Moore platform
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute>
+INFINIOP_MOORE_KERNEL add_rmsnormKernel(
+    Tdata *__restrict__ y,
+    Tdata *__restrict__ residual_out,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
+    ptrdiff_t stride_residual_out_batch,
+    ptrdiff_t stride_residual_out_nhead,
+    const Tdata *__restrict__ a,
+    ptrdiff_t stride_a_batch,
+    ptrdiff_t stride_a_nhead,
+    const Tdata *__restrict__ b,
+    ptrdiff_t stride_b_batch,
+    ptrdiff_t stride_b_nhead,
+    const Tweight *__restrict__ w,
+    size_t nhead,
+    size_t dim,
+    float epsilon) {
+    add_rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute>(
+        y, residual_out,
+        stride_y_batch, stride_y_nhead,
+        stride_residual_out_batch, stride_residual_out_nhead,
+        a, stride_a_batch, stride_a_nhead,
+        b, stride_b_batch, stride_b_nhead,
+        w, nhead, dim, epsilon);
+}
+
+namespace op::add_rms_norm::moore {
+
+// Internal opaque structure for the Moore device handle
+struct Descriptor::Opaque {
+    std::shared_ptr<device::moore::Handle::Internal> internal;
+};
+
+// Destructor
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+// Create descriptor for the add_rms_norm operator
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y_desc,
+    infiniopTensorDescriptor_t a_desc,
+    infiniopTensorDescriptor_t b_desc,
+    infiniopTensorDescriptor_t weight_desc,
+    float epsilon,
+    infiniopTensorDescriptor_t residual_out_desc) {
+    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
+    CHECK_RESULT(result);
+    auto info = result.take();
+
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
+        std::move(info),
+        0,
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Launch kernel with different data types
+template <unsigned int BLOCK_SIZE>
+infiniStatus_t launchKernel(
+    uint32_t batch_size, size_t nhead, size_t dim,
+    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
+    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
+    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
+    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
+    const void *w, infiniDtype_t wtype,
+    float epsilon,
+    musaStream_t musa_stream) {
+
+#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
+    add_rmsnormKernel<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
+        reinterpret_cast<Tdata *>(y), \
+        reinterpret_cast<Tdata *>(residual_out), \
+        stride_y_batch, \
+        stride_y_nhead, \
+        stride_residual_out_batch, \
+        stride_residual_out_nhead, \
+        reinterpret_cast<const Tdata *>(a), \
+        stride_a_batch, \
+        stride_a_nhead, \
+        reinterpret_cast<const Tdata *>(b), \
+        stride_b_batch, \
+        stride_b_nhead, \
+        reinterpret_cast<const Tweight *>(w), \
+        nhead, \
+        dim, \
+        epsilon)
+
+    // Handle different data type combinations
+    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(half, half, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(half, __mt_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(half, float, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
+        LAUNCH_KERNEL(__mt_bfloat16, half, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(__mt_bfloat16, float, float);
+    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(float, float, float);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+#undef LAUNCH_KERNEL
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Main calculation function
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *y, const void *a, const void *b, const void *weight,
+    void *residual_out, void *stream) const {
+
+    // Check workspace size
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    // Extract tensor strides and dimensions
+    auto stride_a_batch = _info.a_strides[0];
+    auto stride_a_nhead = _info.a_strides[1];
+    auto stride_b_batch = _info.b_strides[0];
+    auto stride_b_nhead = _info.b_strides[1];
+    auto stride_y_batch = _info.y_strides[0];
+    auto stride_y_nhead = _info.y_strides[1];
+    auto stride_residual_out_batch = _info.residual_out_strides[0];
+    auto stride_residual_out_nhead = _info.residual_out_strides[1];
+    auto dim = _info.dim();
+    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
+
+    // Launch kernel with the appropriate block size based on device capability
+    if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
+        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, musa_stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
+        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, musa_stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
+        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, musa_stream));
+    } else {
+        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::add_rms_norm::moore
diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
index 03601205f..6fc9175bb 100644
--- a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
+++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
@@ -143,7 +143,15 @@ infiniStatus_t Descriptor::calculate(
     auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
 
     // launch kernel with different block sizes
-    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
+    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, cuda_stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
         CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
             batch_size, nhead, dim,
             y, _info.atype, stride_y_batch, stride_y_nhead,
@@ -151,8 +159,8 @@ infiniStatus_t Descriptor::calculate(
             a, stride_a_batch, stride_a_nhead,
             b, stride_b_batch, stride_b_nhead,
             weight, _info.wtype, _info.epsilon, cuda_stream));
-    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
+    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_2048>(
             batch_size, nhead, dim,
             y, _info.atype, stride_y_batch, stride_y_nhead,
             residual_out, stride_residual_out_batch, stride_residual_out_nhead,
diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc
index a856e5447..1151c6340 100644
--- a/src/infiniop/ops/add_rms_norm/operator.cc
+++ b/src/infiniop/ops/add_rms_norm/operator.cc
@@ -17,12 +17,10 @@
 // #include "bang/add_rms_norm_bang.h"
 #endif
 #ifdef ENABLE_METAX_API
-// TODO: Add Metax implementation
-// #include "metax/add_rms_norm_metax.cuh"
+#include "metax/add_rms_norm_metax.cuh"
 #endif
 #ifdef ENABLE_MOORE_API
-// TODO: Add Moore implementation
-// #include "moore/add_rms_norm_moore.h"
+#include "moore/add_rms_norm_moore.h"
 #endif
 #ifdef ENABLE_KUNLUN_API
 // TODO: Add Kunlun implementation
@@ -61,6 +59,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
 #ifdef ENABLE_ILUVATAR_API
     CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    CREATE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    CREATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -94,6 +98,12 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
 #ifdef ENABLE_ILUVATAR_API
     GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    GET(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    GET(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     GET(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -138,6 +148,12 @@ __C infiniStatus_t infiniopAddRMSNorm(
 #ifdef ENABLE_ILUVATAR_API
     CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    CALCULATE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -173,6 +189,12 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip
 #ifdef ENABLE_ILUVATAR_API
     DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    DESTROY(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    DESTROY(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     DESTROY(INFINI_DEVICE_QY, nvidia);
 #endif
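Note on the shared device-side helper: add_rmsnormBlock lives in ../cuda/kernel.cuh and is not part of this patch, so its body is not shown above. For readers reviewing the new Metax/Moore wrappers, the sketch below illustrates the computation that helper is assumed to perform per row: a fused residual add (residual_out = a + b) followed by RMS normalization scaled by the weight (y = residual_out * rsqrt(mean(residual_out^2) + epsilon) * w). It is a minimal, self-contained CUDA illustration only, assuming contiguous float32 rows and a power-of-two BLOCK_SIZE; it omits the stride handling, mixed data types, and the shared reduce utilities used by the real kernels, and its names (add_rms_norm_sketch, num_rows) are hypothetical.

// Illustrative sketch only (not the repository implementation): a fused
// residual-add + RMS-norm kernel, one contiguous float row per block.
// BLOCK_SIZE is assumed to be a power of two.
#include <cuda_runtime.h>

template <unsigned int BLOCK_SIZE>
__global__ void add_rms_norm_sketch(float *y, float *residual_out,
                                    const float *a, const float *b,
                                    const float *w, size_t dim, float epsilon) {
    const size_t row = blockIdx.x; // one block normalizes one row
    const float *a_row = a + row * dim;
    const float *b_row = b + row * dim;
    float *r_row = residual_out + row * dim;
    float *y_row = y + row * dim;

    __shared__ float partial[BLOCK_SIZE];
    float sum_sq = 0.0f;

    // 1) residual add, accumulating the per-thread sum of squares on the fly
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        float v = a_row[i] + b_row[i];
        r_row[i] = v;
        sum_sq += v * v;
    }

    // 2) block-wide tree reduction of the sum of squares
    partial[threadIdx.x] = sum_sq;
    __syncthreads();
    for (unsigned int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            partial[threadIdx.x] += partial[threadIdx.x + s];
        }
        __syncthreads();
    }
    float inv_rms = rsqrtf(partial[0] / static_cast<float>(dim) + epsilon);

    // 3) normalize and scale by the weight
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        y_row[i] = r_row[i] * inv_rms * w[i];
    }
}

// Example launch: one block per normalized row, e.g.
//   add_rms_norm_sketch<256><<<num_rows, 256, 0, stream>>>(y, res, a, b, w, dim, 1e-5f);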