Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,47 @@
#define __INFINIOP_API_H__

#include "infiniop/handle.h"
#include "infiniop/ops/abs.h"
#include "infiniop/ops/acos.h"
#include "infiniop/ops/acosh.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/asin.h"
#include "infiniop/ops/asinh.h"
#include "infiniop/ops/atan.h"
#include "infiniop/ops/atanh.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/ceil.h"
#include "infiniop/ops/cos.h"
#include "infiniop/ops/cosh.h"
#include "infiniop/ops/erf.h"
#include "infiniop/ops/floor.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/div.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/log.h"
#include "infiniop/ops/logsoftmax.h"
#include "infiniop/ops/lp_norm.h"
#include "infiniop/ops/max.h"
#include "infiniop/ops/min.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/neg.h"
#include "infiniop/ops/ones.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/reciprocal.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/round.h"
#include "infiniop/ops/sign.h"
#include "infiniop/ops/sinh.h"
#include "infiniop/ops/sqrt.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
Expand All @@ -30,6 +52,7 @@
#include "infiniop/ops/softplus.h"
#include "infiniop/ops/sub.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/ops/tan.h"
#include "infiniop/ops/tanh.h"
#include "infiniop/ops/topkrouter.h"
#include "infiniop/ops/topksoftmax.h"
Expand Down
8 changes: 8 additions & 0 deletions include/infiniop/ops/abs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ABS_API_H__
#define __INFINIOP_ABS_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(abs, Abs)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/acos.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ACOS_API_H__
#define __INFINIOP_ACOS_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(acos, Acos)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/acosh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ACOSH_API_H__
#define __INFINIOP_ACOSH_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(acosh, Acosh)

#endif
22 changes: 2 additions & 20 deletions include/infiniop/ops/add.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
#ifndef __INFINIOP_ADD_API_H__
#define __INFINIOP_ADD_API_H__

#include "../operator_descriptor.h"
#include "binary_op_api.h"

typedef struct InfiniopDescriptor *infiniopAddDescriptor_t;

__C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);

__C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream);

__C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
BINARY_OP_API_DECLARE(add, Add)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/asin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ASIN_API_H__
#define __INFINIOP_ASIN_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(asin, Asin)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/asinh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ASINH_API_H__
#define __INFINIOP_ASINH_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(asinh, Asinh)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/atan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ATAN_API_H__
#define __INFINIOP_ATAN_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(atan, Atan)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/atanh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ATANH_API_H__
#define __INFINIOP_ATANH_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(atanh, Atanh)

#endif
50 changes: 50 additions & 0 deletions include/infiniop/ops/binary_op_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef __INFINIOP_BINARY_OP_API_H__
#define __INFINIOP_BINARY_OP_API_H__

#include "../operator_descriptor.h"

/**
* @brief Macro to generate the C API header for a binary operator.
*
* This macro generates all the necessary declarations for a binary operator:
* - Descriptor type definition
* - Create descriptor function
* - Get workspace size function
* - Execute operator function
* - Destroy descriptor function
*
* Usage:
* BINARY_OP_API_DECLARE(div, Div)
* BINARY_OP_API_DECLARE(pow, Pow)
*
* @param OP_NAME Lowercase operator name (e.g., div, pow, mod)
* @param OP_NAME_UPPER Uppercase operator name (e.g., Div, Pow, Mod)
*/
#define BINARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
\
typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
\
__C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
infiniopHandle_t handle, \
infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
infiniopTensorDescriptor_t c, \
infiniopTensorDescriptor_t a, \
infiniopTensorDescriptor_t b); \
\
__C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
size_t *size); \
\
__C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
void *workspace, \
size_t workspace_size, \
void *c, \
const void *a, \
const void *b, \
void *stream); \
\
__C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
infiniop##OP_NAME_UPPER##Descriptor_t desc);

#endif // __INFINIOP_BINARY_OP_API_H__
8 changes: 8 additions & 0 deletions include/infiniop/ops/ceil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_CEIL_API_H__
#define __INFINIOP_CEIL_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(ceil, Ceil)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/cos.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_COS_API_H__
#define __INFINIOP_COS_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(cos, Cos)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/cosh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_COSH_API_H__
#define __INFINIOP_COSH_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(cosh, Cosh)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/div.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_DIV_API_H__
#define __INFINIOP_DIV_API_H__

#include "binary_op_api.h"

BINARY_OP_API_DECLARE(div, Div)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/erf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ERF_API_H__
#define __INFINIOP_ERF_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(erf, Erf)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/floor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_FLOOR_API_H__
#define __INFINIOP_FLOOR_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(floor, Floor)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/log.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_LOG_API_H__
#define __INFINIOP_LOG_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(log, Log)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/max.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_MAX_API_H__
#define __INFINIOP_MAX_API_H__

#include "binary_op_api.h"

BINARY_OP_API_DECLARE(max, Max)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/min.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_MIN_API_H__
#define __INFINIOP_MIN_API_H__

#include "binary_op_api.h"

BINARY_OP_API_DECLARE(min, Min)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/mod.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_MOD_API_H__
#define __INFINIOP_MOD_API_H__

#include "binary_op_api.h"

BINARY_OP_API_DECLARE(mod, Mod)

#endif
22 changes: 2 additions & 20 deletions include/infiniop/ops/mul.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
#ifndef __INFINIOP_MUL_API_H__
#define __INFINIOP_MUL_API_H__

#include "../operator_descriptor.h"
#include "binary_op_api.h"

typedef struct InfiniopDescriptor *infiniopMulDescriptor_t;

__C __export infiniStatus_t infiniopCreateMulDescriptor(infiniopHandle_t handle,
infiniopMulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);

__C __export infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopMul(infiniopMulDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream);

__C __export infiniStatus_t infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc);
BINARY_OP_API_DECLARE(mul, Mul)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/neg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_NEG_API_H__
#define __INFINIOP_NEG_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(neg, Neg)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/pow.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_POW_API_H__
#define __INFINIOP_POW_API_H__

#include "binary_op_api.h"

BINARY_OP_API_DECLARE(pow, Pow)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/reciprocal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_RECIPROCAL_API_H__
#define __INFINIOP_RECIPROCAL_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(reciprocal, Reciprocal)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/round.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_ROUND_API_H__
#define __INFINIOP_ROUND_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(round, Round)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/sign.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_SIGN_API_H__
#define __INFINIOP_SIGN_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(sign, Sign)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/sinh.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_SINH_API_H__
#define __INFINIOP_SINH_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(sinh, Sinh)

#endif
8 changes: 8 additions & 0 deletions include/infiniop/ops/sqrt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __INFINIOP_SQRT_API_H__
#define __INFINIOP_SQRT_API_H__

#include "unary_op_api.h"

UNARY_OP_API_DECLARE(sqrt, Sqrt)

#endif
Loading