Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define __INFINIOP_API_H__

#include "infiniop/handle.h"
#include "infiniop/ops/2dmrope.h"
#include "infiniop/ops/3dmrope.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h"
Expand Down
32 changes: 32 additions & 0 deletions include/infiniop/ops/2dmrope.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef __INFINIOP_2DMROPE_API_H__
#define __INFINIOP_2DMROPE_API_H__

#include "../operator_descriptor.h"

/** Opaque handle type for a 2D MRoPE operator descriptor (a pointer to struct InfiniopDescriptor). */
typedef struct InfiniopDescriptor *infiniopMRoPE2DDescriptor_t;

/**
 * Creates a 2D MRoPE operator descriptor from tensor descriptors.
 *
 * @param handle     Library handle the descriptor is created for.
 * @param desc_ptr   Output location that receives the new descriptor.
 * @param y          Descriptor of the output tensor.
 * @param x          Descriptor of the input tensor.
 * @param pos_ids    Descriptor of the position-id tensor.
 * @param sin_table  Descriptor of the sine lookup table.
 * @param cos_table  Descriptor of the cosine lookup table.
 * @return An infiniStatus_t status code.
 *
 * NOTE(review): the internal validator MRoPE2DInfo::createMRoPE2DInfo expects
 * y/x to be 3-D [nhead, seqlen, dhead], pos_ids to be [seqlen, 2] and
 * sin/cos tables to be [table_len, dhead / 4]; presumably this entry point
 * enforces the same constraints - confirm against the implementation.
 */
__C __export infiniStatus_t infiniopCreateMRoPE2DDescriptor(
infiniopHandle_t handle,
infiniopMRoPE2DDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table);

/**
 * Queries the workspace size required by a 2D MRoPE descriptor.
 *
 * @param desc  Descriptor to query.
 * @param size  Output location for the workspace size (presumably in bytes - TODO confirm).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetMRoPE2DWorkspaceSize(infiniopMRoPE2DDescriptor_t desc, size_t *size);

/**
 * Runs the 2D MRoPE operator described by `desc`.
 *
 * @param desc            Descriptor created by infiniopCreateMRoPE2DDescriptor.
 * @param workspace       Scratch buffer supplied by the caller.
 * @param workspace_size  Size of `workspace`; see infiniopGetMRoPE2DWorkspaceSize.
 * @param y               Output tensor data.
 * @param x               Input tensor data (read-only).
 * @param pos_ids         Position-id data (read-only).
 * @param sin_table       Sine lookup table data (read-only).
 * @param cos_table       Cosine lookup table data (read-only).
 * @param stream          Backend stream/queue handle, passed as an untyped pointer
 *                        (semantics are backend specific - TODO confirm).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopMRoPE2D(
infiniopMRoPE2DDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
void *stream);

/**
 * Destroys a 2D MRoPE descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyMRoPE2DDescriptor(infiniopMRoPE2DDescriptor_t desc);

#endif
34 changes: 34 additions & 0 deletions include/infiniop/ops/3dmrope.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_3DMROPE_API_H__
#define __INFINIOP_3DMROPE_API_H__

#include "../operator_descriptor.h"

/** Opaque handle type for a 3D MRoPE operator descriptor (a pointer to struct InfiniopDescriptor). */
typedef struct InfiniopDescriptor *infiniopMRoPE3DDescriptor_t;

/**
 * Creates a 3D MRoPE operator descriptor from tensor descriptors.
 *
 * @param handle        Library handle the descriptor is created for.
 * @param desc_ptr      Output location that receives the new descriptor.
 * @param y             Descriptor of the output tensor.
 * @param x             Descriptor of the input tensor.
 * @param pos_ids       Descriptor of the position-id tensor.
 * @param sin_table     Descriptor of the sine lookup table.
 * @param cos_table     Descriptor of the cosine lookup table.
 * @param rope_section  Descriptor of the rope-section tensor
 *                      (presumably how the rotary dimensions are split across
 *                      the three position axes - TODO confirm).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopCreateMRoPE3DDescriptor(
infiniopHandle_t handle,
infiniopMRoPE3DDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table,
infiniopTensorDescriptor_t rope_section);

/**
 * Queries the workspace size required by a 3D MRoPE descriptor.
 *
 * @param desc  Descriptor to query.
 * @param size  Output location for the workspace size (presumably in bytes - TODO confirm).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetMRoPE3DWorkspaceSize(infiniopMRoPE3DDescriptor_t desc, size_t *size);

/**
 * Runs the 3D MRoPE operator described by `desc`.
 *
 * @param desc            Descriptor created by infiniopCreateMRoPE3DDescriptor.
 * @param workspace       Scratch buffer supplied by the caller.
 * @param workspace_size  Size of `workspace`; see infiniopGetMRoPE3DWorkspaceSize.
 * @param y               Output tensor data.
 * @param x               Input tensor data (read-only).
 * @param pos_ids         Position-id data (read-only).
 * @param sin_table       Sine lookup table data (read-only).
 * @param cos_table       Cosine lookup table data (read-only).
 * @param rope_section    Rope-section data (read-only).
 * @param stream          Backend stream/queue handle, passed as an untyped pointer
 *                        (semantics are backend specific - TODO confirm).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopMRoPE3D(
infiniopMRoPE3DDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
void const *rope_section,
void *stream);

/**
 * Destroys a 3D MRoPE descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards.
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyMRoPE3DDescriptor(infiniopMRoPE3DDescriptor_t desc);

#endif
2 changes: 2 additions & 0 deletions scripts/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def run_tests(args):
"topkrouter.py",
"topksoftmax.py",
"zeros.py",
"2dmrope.py",
"3dmrope.py",
]:
result = subprocess.run(
f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True
Expand Down
140 changes: 140 additions & 0 deletions src/infiniop/ops/2dmrope/2dmrope.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#ifndef __2DMROPE_H__
#define __2DMROPE_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <cstdio>

/*
 * DESCRIPTOR(NAMESPACE)
 *
 * Expands to the declaration of op::mrope2d::NAMESPACE::Descriptor, a final
 * class derived from InfiniopDescriptor. The generated class
 *   - stores an Opaque pointer (Opaque is only forward-declared here), an
 *     MRoPE2DInfo and the workspace size;
 *   - has a private constructor, so instances are obtained through the static
 *     create() factory, which writes the new descriptor to *desc_ptr;
 *   - exposes workspaceSize() and a const calculate() taking the raw
 *     y / x / pos_ids / sin_table / cos_table buffers and a stream pointer.
 * Only declarations are produced; the destructor, create() and calculate()
 * are not defined by this macro.
 *
 * MRoPE2DInfo is defined further down in this header, so the macro must be
 * expanded after that definition.
 *
 * Do not put `//` comments inside the macro body: the trailing backslash
 * would splice the following line into the comment.
 */
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::mrope2d::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
MRoPE2DInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
MRoPE2DInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t pos_desc, \
infiniopTensorDescriptor_t sin_desc, \
infiniopTensorDescriptor_t cos_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
const void *pos_ids, \
const void *sin_table, \
const void *cos_table, \
void *stream) const; \
}; \
}

/**
 * Validated dtype / shape / stride metadata for the 2D MRoPE operator.
 *
 * Instances are only produced by createMRoPE2DInfo(), which checks the tensor
 * descriptors before capturing the values below.
 */
class MRoPE2DInfo {
private:
    MRoPE2DInfo() = default;

public:
    // Element type shared by y, x and the sin/cos tables; integer type of pos_ids.
    infiniDtype_t data_type, pos_type;
    // y is [nhead, seqlen, dhead]; the sin/cos tables are [table_len, table_dim].
    size_t seqlen, nhead, dhead, table_len, table_dim;
    // Strides of y and x along the sequence (dim 1) and head (dim 0) axes.
    ptrdiff_t
        y_stride_seqlen,
        y_stride_nhead,
        x_stride_seqlen,
        x_stride_nhead;

    /**
     * Validates the descriptors of a 2D MRoPE call and packs them into an MRoPE2DInfo.
     *
     * Requirements checked here:
     *  - no descriptor is null                          -> INFINI_STATUS_NULL_POINTER
     *  - y, x, sin and cos share one dtype              -> INFINI_STATUS_BAD_TENSOR_DTYPE
     *  - that dtype is F16 / BF16 / F32 / F64 and pos_ids is an integer type
     *    (via CHECK_DTYPE / CHECK_DTYPE_ANY_INT)
     *  - y and x are 3-D with identical shape [nhead, seqlen, dhead],
     *    pos_ids is [seqlen, 2], sin and cos are [table_len, table_dim]
     *    with dhead == 4 * table_dim                    -> INFINI_STATUS_BAD_TENSOR_SHAPE
     *  - the last dimension of y and x has stride 1, and sin, cos and pos_ids
     *    are fully contiguous                           -> INFINI_STATUS_BAD_TENSOR_STRIDES
     *
     * @return The populated info on success, otherwise the failing status.
     */
    static utils::Result<MRoPE2DInfo> createMRoPE2DInfo(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t pos_desc,
        infiniopTensorDescriptor_t sin_desc,
        infiniopTensorDescriptor_t cos_desc) {
        CHECK_OR_RETURN(
            y_desc != nullptr && x_desc != nullptr && pos_desc != nullptr && sin_desc != nullptr && cos_desc != nullptr,
            INFINI_STATUS_NULL_POINTER);

        const infiniDtype_t data_type = y_desc->dtype();
        const infiniDtype_t pos_type = pos_desc->dtype();
        // All floating-point operands must share one dtype.
        // NOTE(review): an earlier revision required F32 sin/cos tables for
        // precision; confirm the kernels read the tables with this dtype.
        CHECK_OR_RETURN(data_type == x_desc->dtype() && data_type == sin_desc->dtype() && data_type == cos_desc->dtype(),
                        INFINI_STATUS_BAD_TENSOR_DTYPE);
        CHECK_DTYPE(data_type, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
        CHECK_DTYPE_ANY_INT(pos_type);

        CHECK_OR_RETURN(y_desc->ndim() == 3
                            && x_desc->ndim() == 3
                            && pos_desc->ndim() == 2
                            && sin_desc->ndim() == 2
                            && cos_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        const auto nhead = y_desc->dim(0),
                   seqlen = y_desc->dim(1),
                   dhead = y_desc->dim(2),
                   table_len = sin_desc->dim(0),
                   table_dim = sin_desc->dim(1);

        CHECK_OR_RETURN(nhead == x_desc->dim(0)
                            && seqlen == x_desc->dim(1) && seqlen == pos_desc->dim(0)
                            && dhead == x_desc->dim(2)
                            && table_len == cos_desc->dim(0) && table_dim == cos_desc->dim(1)
                            && pos_desc->dim(1) == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        // 2D MRoPE: each of the two position axes covers dhead / 4 table entries.
        CHECK_OR_RETURN(dhead == table_dim * 4, INFINI_STATUS_BAD_TENSOR_SHAPE);
        // Last dimension of x and y must be contiguous.
        CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
        // sin table and cos table must be totally contiguous.
        CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
        // pos_ids must be contiguous.
        CHECK_OR_RETURN(pos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);

        return utils::Result<MRoPE2DInfo>(MRoPE2DInfo{
            data_type,
            pos_type,
            seqlen,
            nhead,
            dhead,
            table_len,
            table_dim,
            y_desc->stride(1),
            y_desc->stride(0),
            x_desc->stride(1),
            x_desc->stride(0),
        });
    }
};

#endif
42 changes: 42 additions & 0 deletions src/infiniop/ops/2dmrope/cuda/mrope.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
 * Device-side 2D MRoPE rotation for a single element pair.
 *
 * Launch layout read from the CUDA built-ins:
 *   - blockIdx.y  : token index
 *   - blockIdx.x  : index of the head group handled by this block
 *   - threadIdx.y : head index inside the group (blockDim.y heads per block)
 *   - threadIdx.x : pair index; blockDim.x is the pair distance
 *                   (named dh_div_2 originally, i.e. presumably dhead / 2)
 *
 * Each thread rotates the pair (x[i], x[i + blockDim.x]) and writes the result
 * to the same offsets of y. The first half of the pair indices uses
 * pos_[token][0], the second half pos_[token][1]; pos_ is indexed as a
 * contiguous [token, 2] array. Each sin/cos table row holds blockDim.x / 2
 * entries.
 */
template <class Tp, class Ta>
static __device__ void padding(
    Ta *__restrict__ y_,
    int const stride_token_y,
    int const stride_head_y,
    Ta const *__restrict__ x_,
    int const stride_token_x,
    int const stride_head_x,
    Tp const *__restrict__ pos_,
    float const *__restrict__ sin_table,
    float const *__restrict__ cos_table) {

    // Decode the launch geometry.
    int const heads_per_block = blockDim.y;
    int const half_dim = blockDim.x;
    int const token = blockIdx.y;
    int const head_group = blockIdx.x;
    int const head_in_group = threadIdx.y;
    int const head = head_group * heads_per_block + head_in_group;
    int const lane = threadIdx.x;

    // Locate the two elements of the pair; they sit half_dim apart in x and y.
    auto x_row = x_ + token * stride_token_x + head * stride_head_x;
    auto y_row = y_ + token * stride_token_y + head * stride_head_y;
    auto x_lo = x_row + lane;
    auto x_hi = x_row + lane + half_dim;
    auto y_lo = y_row + lane;
    auto y_hi = y_row + lane + half_dim;

    // The two position axes split half_dim evenly, half_dim / 2 lanes each.
    int const section_len = half_dim / 2;
    int const axis = lane / section_len;     // which position axis this lane uses
    int const in_axis = lane % section_len;  // offset inside that axis' section

    // pos_ is addressed as [token, 2] with strides [2, 1].
    auto pos = pos_[token * 2 + axis];
    auto const table_idx = pos * section_len + in_axis;

    float const sin = sin_table[table_idx];
    float const cos = cos_table[table_idx];
    float const a = x_lo[0];
    float const b = x_hi[0];

    // Apply the rotation and store into y.
    y_lo[0] = Ta(a * cos - b * sin);
    y_hi[0] = Ta(a * sin + b * cos);
}
Loading