From 82683f8577049c1d1fc666608888396ee4b10b27 Mon Sep 17 00:00:00 2001
From: wooway777
Date: Thu, 15 Jan 2026 15:38:53 +0800
Subject: [PATCH] issue/931 - ninetoothed swiglu

---
Review notes (free-form text in this region is skipped when the patch is
applied):

* The patch text had been collapsed onto a handful of lines. Line structure
  is restored here; indentation is reconstructed conventionally and may
  differ from the author's bytes.
* Template arguments that had been stripped from swiglu.h are restored:
  std::vector<infiniopTensorDescriptor_t>, std::vector<Size> and
  std::vector<Stride>. The element type of `inputs` in `calculate`,
  std::vector<const void *>, is an assumption - confirm it against the
  CALCULATE macro in operator.cc.
* swiglu.h: the include guard is now unique (the header includes a generated
  file that is also named swiglu.h), <cstddef> and <vector> are included
  explicitly, and the block size passed to launch_swiglu is a named constant.
* operator.cc: the nvidia header is still included for QY/HYGON builds when
  ENABLE_NINETOOTHED is defined, because QY keeps dispatching to the nvidia
  descriptor.
* Hunk headers and the diffstat are recomputed for the revised contents. The
  post-image blob ids on the "index" lines predate these revisions.

 src/infiniop/ops/swiglu/ninetoothed/build.py  |  36 +++++++
 src/infiniop/ops/swiglu/ninetoothed/swiglu.h  | 101 ++++++++++++++++++
 src/infiniop/ops/swiglu/ninetoothed/swiglu.py |  28 +++++
 src/infiniop/ops/swiglu/operator.cc           |  60 +++++++++++
 4 files changed, 225 insertions(+)
 create mode 100644 src/infiniop/ops/swiglu/ninetoothed/build.py
 create mode 100644 src/infiniop/ops/swiglu/ninetoothed/swiglu.h
 create mode 100644 src/infiniop/ops/swiglu/ninetoothed/swiglu.py

diff --git a/src/infiniop/ops/swiglu/ninetoothed/build.py b/src/infiniop/ops/swiglu/ninetoothed/build.py
new file mode 100644
index 000000000..899bef8ec
--- /dev/null
+++ b/src/infiniop/ops/swiglu/ninetoothed/build.py
@@ -0,0 +1,36 @@
+import ninetoothed
+import swiglu
+
+import infiniop.ninetoothed.build
+
+
+def build():
+    """Pass the SwiGLU ``premake`` and its parameter grid to the shared builder.
+
+    The grid lists tensor ranks 1 through ``MAX_NDIM``, the float16,
+    bfloat16 and float32 dtypes, and a single block size of 1024.
+    """
+    MAX_NDIM = 5
+
+    ndim_values = range(1, MAX_NDIM + 1)
+    dtype_values = (
+        ninetoothed.float16,
+        ninetoothed.bfloat16,
+        ninetoothed.float32,
+    )
+
+    constexpr_param_grid = {
+        "ndim": ndim_values,
+        "dtype": dtype_values,
+        # Keep in sync with `kBlockSize` in swiglu.h, which passes the same
+        # value to `launch_swiglu` at run time.
+        "block_size": (1024,),
+    }
+
+    infiniop.ninetoothed.build.build(
+        swiglu.premake,
+        constexpr_param_grid,
+        caller="cuda",
+        op_name="swiglu",
+        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
+    )
diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.h b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h
new file mode 100644
index 000000000..4aa2fa70e
--- /dev/null
+++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h
@@ -0,0 +1,101 @@
+#ifndef INFINIOP_SWIGLU_NINETOOTHED_H
+#define INFINIOP_SWIGLU_NINETOOTHED_H
+
+#include <cstddef>
+#include <vector>
+
+#include "../../../handle.h"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+#include "../../../../../build/ninetoothed/swiglu.h"
+#include "../../../ninetoothed/utils.h"
+
+namespace op::swiglu::ninetoothed {
+
+// SwiGLU descriptor that forwards to the generated `launch_swiglu` entry
+// point. Shapes, strides and the dtype are captured once at construction and
+// reused by every `calculate` call.
+class Descriptor final : public InfiniopDescriptor {
+
+public:
+    // `input_desc_vec[0]` describes `up` and `input_desc_vec[1]` describes
+    // `gate`; both elements are read without a bounds check.
+    Descriptor(
+        infiniopHandle_t handle,
+        infiniopTensorDescriptor_t out_desc,
+        std::vector<infiniopTensorDescriptor_t> input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id},
+                                                                  out_shape_{out_desc->shape()},
+                                                                  out_strides_{out_desc->strides()},
+                                                                  up_shape_{input_desc_vec[0]->shape()},
+                                                                  up_strides_{input_desc_vec[0]->strides()},
+                                                                  gate_shape_{input_desc_vec[1]->shape()},
+                                                                  gate_strides_{input_desc_vec[1]->strides()},
+                                                                  dtype_{out_desc->dtype()} {}
+
+    ~Descriptor() = default;
+
+    // Always reports a zero-byte workspace.
+    size_t workspaceSize() const {
+        return 0;
+    }
+
+    // Heap-allocates a descriptor and stores it in `*desc_ptr`. No validation
+    // is performed, so the call always reports success.
+    static infiniStatus_t create(
+        infiniopHandle_t handle,
+        Descriptor **desc_ptr,
+        infiniopTensorDescriptor_t out_desc,
+        std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
+        *desc_ptr = new Descriptor(handle, out_desc, input_desc_vec);
+        return INFINI_STATUS_SUCCESS;
+    }
+
+    // Launches the kernel with `inputs[0]` as `up` and `inputs[1]` as `gate`.
+    // `workspace` and `workspace_size` are unused. A non-zero result from
+    // `launch_swiglu` is reported as INFINI_STATUS_NOT_IMPLEMENTED.
+    infiniStatus_t calculate(
+        void *workspace,
+        size_t workspace_size,
+        void *output,
+        std::vector<const void *> inputs,
+        void *stream) const {
+        auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)};
+        auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)};
+        auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)};
+
+        if (launch_swiglu(stream,
+                          out_nt,
+                          up_nt,
+                          gate_nt,
+                          out_shape_.size(),
+                          dtype_,
+                          kBlockSize)) {
+            return INFINI_STATUS_NOT_IMPLEMENTED;
+        }
+
+        return INFINI_STATUS_SUCCESS;
+    }
+
+private:
+    using Size = ::ninetoothed::Tensor<>::Size;
+    using Stride = ::ninetoothed::Tensor<>::Stride;
+
+    // Must match the single "block_size" value listed in build.py.
+    static constexpr int kBlockSize{1024};
+
+    std::vector<Size> out_shape_;
+    std::vector<Stride> out_strides_;
+
+    std::vector<Size> up_shape_;
+    std::vector<Stride> up_strides_;
+
+    std::vector<Size> gate_shape_;
+    std::vector<Stride> gate_strides_;
+
+    infiniDtype_t dtype_;
+};
+
+} // namespace op::swiglu::ninetoothed
+
+#endif // INFINIOP_SWIGLU_NINETOOTHED_H
diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.py b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py
new file mode 100644
index 000000000..62074a84b
--- /dev/null
+++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py
@@ -0,0 +1,28 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+# Element-wise SwiGLU: output = sigmoid(gate) * gate * up, with the sigmoid
+# evaluated on a float32 cast of `gate`. `noqa: F841` silences the
+# unused-local warning that linters raise for the assignment to `output`.
+def application(output, up, gate):
+    output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up  # noqa: F841
+
+
+# Returns the (arrangement, application, tensors) triple for one variant:
+# the element-wise arrangement bound to `block_size`, plus three
+# `ndim`-dimensional tensors of `dtype`, one per `application` parameter.
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc
index 9d8e6406a..b3fabba32 100644
--- a/src/infiniop/ops/swiglu/operator.cc
+++ b/src/infiniop/ops/swiglu/operator.cc
@@ -6,14 +6,26 @@
 #include "cpu/swiglu_cpu.h"
 #endif
 #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NINETOOTHED)
+#include "ninetoothed/swiglu.h"
+#endif
+// QY is always dispatched to the `nvidia` descriptor below, so its header
+// must stay visible for QY builds even when NineToothed is enabled. HYGON is
+// kept alongside it on the assumption that its dispatch does the same.
+#if !defined(ENABLE_NINETOOTHED) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
 #include "nvidia/swiglu_nvidia.cuh"
 #endif
+#endif
 #ifdef ENABLE_KUNLUN_API
 #include "kunlun/swiglu_kunlun.h"
 #endif
 #ifdef ENABLE_METAX_API
+#if defined(ENABLE_NINETOOTHED)
+#include "ninetoothed/swiglu.h"
+#else
 #include "metax/swiglu_metax.h"
 #endif
+#endif
 #ifdef ENABLE_CAMBRICON_API
 #include "bang/swiglu_bang.h"
 #endif
@@ -46,11 +58,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
         CREATE(INFINI_DEVICE_CPU, cpu);
 #endif
 #ifdef ENABLE_NVIDIA_API
+#ifdef ENABLE_NINETOOTHED
+        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
         CREATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
+#endif
 #ifdef ENABLE_ILUVATAR_API
+#ifdef ENABLE_NINETOOTHED
+        CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
+#else
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#endif
 #ifdef ENABLE_QY_API
         CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -61,8 +81,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
         CREATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
 #ifdef ENABLE_METAX_API
+#ifdef ENABLE_NINETOOTHED
+        CREATE(INFINI_DEVICE_METAX, ninetoothed);
+#else
         CREATE(INFINI_DEVICE_METAX, metax);
 #endif
+#endif
 #ifdef ENABLE_CAMBRICON_API
         CREATE(INFINI_DEVICE_CAMBRICON, bang);
 #endif
@@ -92,11 +116,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
         GET(INFINI_DEVICE_CPU, cpu);
 #endif
 #ifdef ENABLE_NVIDIA_API
+#ifdef ENABLE_NINETOOTHED
+        GET(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
         GET(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
+#endif
 #ifdef ENABLE_ILUVATAR_API
+#ifdef ENABLE_NINETOOTHED
+        GET(INFINI_DEVICE_ILUVATAR, ninetoothed);
+#else
         GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#endif
 #ifdef ENABLE_QY_API
         GET(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -107,8 +139,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
         GET(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
 #ifdef ENABLE_METAX_API
+#ifdef ENABLE_NINETOOTHED
+        GET(INFINI_DEVICE_METAX, ninetoothed);
+#else
         GET(INFINI_DEVICE_METAX, metax);
 #endif
+#endif
 #ifdef ENABLE_CAMBRICON_API
         GET(INFINI_DEVICE_CAMBRICON, bang);
 #endif
@@ -145,11 +181,19 @@ __C infiniStatus_t infiniopSwiGLU(
         CALCULATE(INFINI_DEVICE_CPU, cpu);
 #endif
 #ifdef ENABLE_NVIDIA_API
+#ifdef ENABLE_NINETOOTHED
+        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
+#endif
 #ifdef ENABLE_ILUVATAR_API
+#ifdef ENABLE_NINETOOTHED
+        CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
+#else
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#endif
 #ifdef ENABLE_QY_API
         CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -160,8 +204,12 @@ __C infiniStatus_t infiniopSwiGLU(
         CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
 #ifdef ENABLE_METAX_API
+#ifdef ENABLE_NINETOOTHED
+        CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
+#else
         CALCULATE(INFINI_DEVICE_METAX, metax);
 #endif
+#endif
 #ifdef ENABLE_CAMBRICON_API
         CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
 #endif
@@ -193,11 +241,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
         DELETE(INFINI_DEVICE_CPU, cpu);
 #endif
 #ifdef ENABLE_NVIDIA_API
+#ifdef ENABLE_NINETOOTHED
+        DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
+#else
         DELETE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
+#endif
 #ifdef ENABLE_ILUVATAR_API
+#ifdef ENABLE_NINETOOTHED
+        DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
+#else
         DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#endif
 #ifdef ENABLE_QY_API
         DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -208,8 +264,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
         DELETE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
 #ifdef ENABLE_METAX_API
+#ifdef ENABLE_NINETOOTHED
+        DELETE(INFINI_DEVICE_METAX, ninetoothed);
+#else
         DELETE(INFINI_DEVICE_METAX, metax);
 #endif
+#endif
 #ifdef ENABLE_CAMBRICON_API
         DELETE(INFINI_DEVICE_CAMBRICON, bang);
 #endif