Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/infiniop/ops/swiglu/ninetoothed/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import ninetoothed
import swiglu

import infiniop.ninetoothed.build


def build():
MAX_NDIM = 5

ndim_values = range(1, MAX_NDIM + 1)
dtype_values = (
ninetoothed.float16,
ninetoothed.bfloat16,
ninetoothed.float32,
)

constexpr_param_grid = {
"ndim": ndim_values,
"dtype": dtype_values,
"block_size": (1024,),
}

infiniop.ninetoothed.build.build(
swiglu.premake,
constexpr_param_grid,
caller="cuda",
op_name="swiglu",
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
)
82 changes: 82 additions & 0 deletions src/infiniop/ops/swiglu/ninetoothed/swiglu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#ifndef SWIGLU_H
#define SWIGLU_H

#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include "../../../../../build/ninetoothed/swiglu.h"
#include "../../../ninetoothed/utils.h"

namespace op::swiglu::ninetoothed {
class Descriptor final : public InfiniopDescriptor {

public:
Descriptor(
infiniopHandle_t handle,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id},
out_shape_{out_desc->shape()},
out_strides_{out_desc->strides()},
up_shape_{input_desc_vec[0]->shape()},
up_strides_{input_desc_vec[0]->strides()},
gate_shape_{input_desc_vec[1]->shape()},
gate_strides_{input_desc_vec[1]->strides()},
dtype_{out_desc->dtype()} {}

~Descriptor() = default;

size_t workspaceSize() const {
return 0;
}

static infiniStatus_t create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
*desc_ptr = new Descriptor(handle, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)};
auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)};
auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)};

if (launch_swiglu(stream,
out_nt,
up_nt,
gate_nt,
out_shape_.size(),
dtype_,
1024)) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}

return INFINI_STATUS_SUCCESS;
}

private:
using Size = ::ninetoothed::Tensor<>::Size;
using Stride = ::ninetoothed::Tensor<>::Stride;

std::vector<Size> out_shape_;
std::vector<Stride> out_strides_;

std::vector<Size> up_shape_;
std::vector<Stride> up_strides_;

std::vector<Size> gate_shape_;
std::vector<Stride> gate_strides_;

infiniDtype_t dtype_;
};
} // namespace op::swiglu::ninetoothed

#endif // SWIGLU_H
22 changes: 22 additions & 0 deletions src/infiniop/ops/swiglu/ninetoothed/swiglu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(output, up, gate):
output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841


def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

return arrangement_, application, tensors
56 changes: 56 additions & 0 deletions src/infiniop/ops/swiglu/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,22 @@
#include "cpu/swiglu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NINETOOTHED)
#include "ninetoothed/swiglu.h"
#else
#include "nvidia/swiglu_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
#ifdef ENABLE_METAX_API
#if defined(ENABLE_NINETOOTHED)
#include "ninetoothed/swiglu.h"
#else
#include "metax/swiglu_metax.h"
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/swiglu_bang.h"
#endif
Expand Down Expand Up @@ -46,11 +54,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
Expand All @@ -61,8 +77,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_METAX, ninetoothed);
#else
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
Expand Down Expand Up @@ -92,11 +112,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
Expand All @@ -107,8 +135,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_METAX, ninetoothed);
#else
GET(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
Expand Down Expand Up @@ -145,11 +177,19 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
Expand All @@ -160,8 +200,12 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
Expand Down Expand Up @@ -193,11 +237,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
Expand All @@ -208,8 +260,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_METAX, ninetoothed);
#else
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
Expand Down