Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gelutanh.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
Expand All @@ -20,6 +21,7 @@
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/quickgelu.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
Expand Down
43 changes: 43 additions & 0 deletions include/infiniop/ops/gelutanh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef __INFINIOP_GELUTANH_API_H__
#define __INFINIOP_GELUTANH_API_H__

#include "../operator_descriptor.h"

/** Descriptor handle for the GELU-Tanh operator (a pointer to InfiniopDescriptor). */
typedef struct InfiniopDescriptor *infiniopGeluTanhDescriptor_t;

/**
 * Creates a GELU-Tanh operator descriptor.
 *
 * The operator is the tanh approximation of GELU, applied element-wise:
 *
 *   y = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
 *
 * @param handle   Library handle the descriptor is created for.
 * @param desc_ptr Location that receives the created descriptor.
 * @param y        Descriptor of the output tensor.
 * @param x        Descriptor of the input tensor.
 * @return An infiniStatus_t status code.
 *
 * NOTE(review): the CPU backend accepts F16, BF16, F32 and F64 outputs and
 * requires y and x to have the same shape; presumably the other backends
 * enforce the same constraints - confirm.
 */
__C __export infiniStatus_t infiniopCreateGeluTanhDescriptor(
infiniopHandle_t handle,
infiniopGeluTanhDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x);

/**
 * Queries the workspace size associated with a GELU-Tanh descriptor.
 *
 * @param desc Descriptor to query.
 * @param size Location that receives the workspace size
 *             (presumably in bytes - confirm against the implementation).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetGeluTanhWorkspaceSize(
infiniopGeluTanhDescriptor_t desc,
size_t *size);

/**
 * Launches the GELU-Tanh operator described by `desc`.
 *
 * @param desc           Descriptor of the operator to run.
 * @param workspace      Scratch buffer for the operator.
 *                       NOTE(review): the CPU backend ignores it.
 * @param workspace_size Size of `workspace`.
 * @param y              Output tensor data (written).
 * @param x              Input tensor data (read-only).
 * @param stream         Backend-specific stream argument, forwarded to the
 *                       backend; presumably a device stream/queue handle -
 *                       confirm per device.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGeluTanh(
infiniopGeluTanhDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);

/**
 * Destroys a GELU-Tanh descriptor.
 *
 * @param desc Descriptor to destroy; presumably one obtained from
 *             infiniopCreateGeluTanhDescriptor and not used afterwards -
 *             confirm.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyGeluTanhDescriptor(
infiniopGeluTanhDescriptor_t desc);

#endif
42 changes: 42 additions & 0 deletions include/infiniop/ops/quickgelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#ifndef __INFINIOP_QUICKGELU_API_H__
#define __INFINIOP_QUICKGELU_API_H__

#include "../operator_descriptor.h"

/** Descriptor handle for the QuickGELU operator (a pointer to InfiniopDescriptor). */
typedef struct InfiniopDescriptor *infiniopQuickGeluDescriptor_t;

/**
 * Creates a QuickGELU operator descriptor.
 *
 * The operator is applied element-wise:
 *
 *   y = x * sigmoid(1.702 * x)
 *
 * @param handle   Library handle the descriptor is created for.
 * @param desc_ptr Location that receives the created descriptor.
 * @param y        Descriptor of the output tensor.
 * @param x        Descriptor of the input tensor.
 * @return An infiniStatus_t status code.
 *
 * NOTE(review): supported dtypes and shape constraints are defined by the
 * backend implementations, which are not visible here - confirm.
 */
__C __export infiniStatus_t infiniopCreateQuickGeluDescriptor(
infiniopHandle_t handle,
infiniopQuickGeluDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x);

/**
 * Queries the workspace size associated with a QuickGELU descriptor.
 *
 * @param desc Descriptor to query.
 * @param size Location that receives the workspace size
 *             (presumably in bytes - confirm against the implementation).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetQuickGeluWorkspaceSize(
infiniopQuickGeluDescriptor_t desc,
size_t *size);

/**
 * Launches the QuickGELU operator described by `desc`.
 *
 * @param desc           Descriptor of the operator to run.
 * @param workspace      Scratch buffer for the operator.
 * @param workspace_size Size of `workspace`.
 * @param y              Output tensor data (written).
 * @param x              Input tensor data (read-only).
 * @param stream         Backend-specific stream argument; presumably a
 *                       device stream/queue handle - confirm per device.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopQuickGelu(
infiniopQuickGeluDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);

/**
 * Destroys a QuickGELU descriptor.
 *
 * @param desc Descriptor to destroy; presumably one obtained from
 *             infiniopCreateQuickGeluDescriptor and not used afterwards -
 *             confirm.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyQuickGeluDescriptor(
infiniopQuickGeluDescriptor_t desc);

#endif
14 changes: 13 additions & 1 deletion src/infiniop/ops/add/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/add_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
Expand Down Expand Up @@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateAddDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
Expand Down Expand Up @@ -91,6 +94,9 @@ __C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, siz
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
Expand Down Expand Up @@ -139,6 +145,9 @@ __C infiniStatus_t infiniopAdd(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
Expand Down Expand Up @@ -181,6 +190,9 @@ infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
Expand Down
14 changes: 13 additions & 1 deletion src/infiniop/ops/conv/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/conv_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/conv_nvidia.cuh"
#endif

Expand Down Expand Up @@ -45,6 +45,9 @@ __C __export infiniStatus_t infiniopCreateConvDescriptor(infiniopHandle_t handle
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -76,6 +79,9 @@ infiniopGetConvWorkspaceSize(
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -115,6 +121,9 @@ __C infiniStatus_t infiniopConv(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -142,6 +151,9 @@ infiniopDestroyConvDescriptor(infiniopConvDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif

default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down
14 changes: 13 additions & 1 deletion src/infiniop/ops/gelu/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/gelu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/gelu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
Expand Down Expand Up @@ -43,6 +43,9 @@ __C infiniStatus_t infiniopCreateGeluDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
Expand Down Expand Up @@ -77,6 +80,9 @@ __C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, s
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
Expand Down Expand Up @@ -118,6 +124,9 @@ __C infiniStatus_t infiniopGelu(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
Expand Down Expand Up @@ -154,6 +163,9 @@ infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
Expand Down
53 changes: 53 additions & 0 deletions src/infiniop/ops/gelutanh/cpu/gelutanh_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "gelutanh_cpu.h"

namespace op::gelutanh::cpu {

// Out-of-line, compiler-generated destructor for the CPU GELU-Tanh descriptor.
Descriptor::~Descriptor() = default;

// Builds the CPU GELU-Tanh descriptor.
//
// The output dtype must be one of F16 / F32 / F64 / BF16 and the output shape
// must equal the shape of the first (and only) input; the check macros return
// an error status otherwise. On success the element-wise descriptor is
// created through CREATE_ELEMENTWISE_CPU_DESCRIPTOR.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto cpu_handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto out_dtype = out_desc->dtype();

    // GELU-Tanh is unary: only the first input descriptor is inspected.
    const auto &in_desc = input_desc_vec.at(0);
    const auto &out_shape = out_desc->shape();
    const auto &in_shape = in_desc->shape();

    // Validation order matters: dtype errors are reported before shape errors.
    CHECK_DTYPE(out_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(out_shape, in_shape);

    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(cpu_handle, out_dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

// Runs GELU-Tanh on the CPU by dispatching on the descriptor's dtype.
//
// `workspace` / `workspace_size` are accepted for interface uniformity but
// are not used here. Returns INFINI_STATUS_BAD_TENSOR_DTYPE for any dtype
// other than F16, F32, F64 or BF16.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // Silence unused-parameter warnings.
    static_cast<void>(workspace);
    static_cast<void>(workspace_size);

    if (_dtype == INFINI_DTYPE_F16) {
        return _device_info->calculate<GeluTanhOp, fp16_t>(_info, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F32) {
        return _device_info->calculate<GeluTanhOp, float>(_info, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_F64) {
        return _device_info->calculate<GeluTanhOp, double>(_info, output, inputs, stream);
    }
    if (_dtype == INFINI_DTYPE_BF16) {
        return _device_info->calculate<GeluTanhOp, bf16_t>(_info, output, inputs, stream);
    }

    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

} // namespace op::gelutanh::cpu

27 changes: 27 additions & 0 deletions src/infiniop/ops/gelutanh/cpu/gelutanh_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef __GELUTANH_CPU_H__
#define __GELUTANH_CPU_H__

#include "../../../elementwise/cpu/elementwise_cpu.h"

#include <cmath>

ELEMENTWISE_DESCRIPTOR(gelutanh, cpu)

namespace op::gelutanh::cpu {
/**
 * Element-wise functor for the tanh approximation of GELU:
 *
 *   y = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
 *
 * `num_inputs` advertises that the operator is unary. `operator()` is a
 * template over the arithmetic type T; T must support the arithmetic used
 * below and have a `std::tanh` overload (float and double do).
 */
typedef struct GeluTanhOp {
public:
    static constexpr size_t num_inputs = 1;

    /**
     * Evaluates GELU-Tanh for a single value.
     *
     * @param x Input value.
     * @return  GELU-Tanh of x, computed in type T.
     */
    template <typename T>
    T operator()(const T &x) const {
        // sqrt(2/pi) is spelled with full double precision so the double
        // instantiation is not limited by a truncated literal; for float the
        // value rounds to the same constant as before.
        // Plain `const` (not `constexpr`) so T is not required to be a
        // literal type.
        const T alpha = static_cast<T>(0.7978845608028654); // sqrt(2/pi)
        const T beta = static_cast<T>(0.044715);
        const T inner = alpha * (x + beta * x * x * x);
        return x * static_cast<T>(0.5) * (static_cast<T>(1) + std::tanh(inner));
    }
} GeluTanhOp;
} // namespace op::gelutanh::cpu

#endif // __GELUTANH_CPU_H__

59 changes: 59 additions & 0 deletions src/infiniop/ops/gelutanh/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#ifndef __GELUTANH_CUDA_H__
#define __GELUTANH_CUDA_H__

#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cmath>

namespace op::gelutanh::cuda {

/**
 * Device-side element-wise functor for the tanh approximation of GELU:
 *
 *   y = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
 *
 * `num_inputs` advertises that the operator is unary. `operator()` handles
 * half2, half, __nv_bfloat16 and float by evaluating the formula in float and
 * rounding once to the storage type; any other T (intended: double) is
 * evaluated in double.
 */
typedef struct GeluTanhOp {
public:
    static constexpr size_t num_inputs = 1;

    // Single-precision constants used by the half / bfloat16 / float paths.
    static constexpr float alpha = 0.7978845608f; // sqrt(2/pi)
    static constexpr float beta = 0.044715f;

    // f32 tanh helper
    __device__ __forceinline__ float tanh_f32_func(float x) const {
        return tanhf(x);
    }

    // Evaluates GELU-Tanh for one float value; shared by every path that
    // computes in single precision so they all round identically.
    __device__ __forceinline__ float gelu_tanh_f32(float x) const {
        const float inner = alpha * (x + beta * x * x * x);
        return x * 0.5f * (1.0f + tanh_f32_func(inner));
    }

    /**
     * Evaluates GELU-Tanh for a single element (or element pair for half2).
     *
     * @param x Input value.
     * @return  GELU-Tanh of x in the same type T.
     */
    template <typename T>
    __device__ __forceinline__ T operator()(const T &x) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Compute both lanes fully in float and round once, so a value
            // yields the same result here as in the scalar half path
            // (multiplying in half would round the gate and the product
            // separately).
            const float2 vf = __half22float2(x);
            return __float22half2_rn(make_float2(gelu_tanh_f32(vf.x), gelu_tanh_f32(vf.y)));
        } else if constexpr (std::is_same_v<T, half>) {
            return __float2half_rn(gelu_tanh_f32(__half2float(x)));
        } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
            return __float2bfloat16(gelu_tanh_f32(__bfloat162float(x)));
        } else if constexpr (std::is_same_v<T, float>) {
            return gelu_tanh_f32(x);
        } else { // double
            // Double-precision constants: the float members above would cap
            // the accuracy of this path at single precision.
            constexpr double alpha_f64 = 0.7978845608028654; // sqrt(2/pi)
            constexpr double beta_f64 = 0.044715;
            const double inner = alpha_f64 * (x + beta_f64 * x * x * x);
            return x * 0.5 * (1.0 + std::tanh(inner));
        }
    }

} GeluTanhOp;

} // namespace op::gelutanh::cuda

#endif // __GELUTANH_CUDA_H__

Loading