From f736704c2bb216ad39058f6c483146491edde43d Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 14:14:26 +0800 Subject: [PATCH 01/15] =?UTF-8?q?Device=20=E5=A2=9E=E5=8A=A0=E6=91=A9?= =?UTF-8?q?=E5=B0=94=E7=BA=BF=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/data_type.h | 8 +++++ src/devices/handle.cc | 14 +++++++++ src/devices/musa/common_musa.h | 38 +++++++++++++++++++++++ src/devices/musa/musa_handle.cc | 30 ++++++++++++++++++ src/devices/musa/musa_handle.h | 37 ++++++++++++++++++++++ src/devices/musa/pool.h | 50 ++++++++++++++++++++++++++++++ src/devices/musa/utils.cc | 17 ++++++++++ xmake.lua | 55 +++++++++++++++++++++++++++++++++ 8 files changed, 249 insertions(+) create mode 100644 src/devices/musa/common_musa.h create mode 100644 src/devices/musa/musa_handle.cc create mode 100644 src/devices/musa/musa_handle.h create mode 100644 src/devices/musa/pool.h create mode 100644 src/devices/musa/utils.cc diff --git a/include/data_type.h b/include/data_type.h index e2f24c4f..954a42ea 100644 --- a/include/data_type.h +++ b/include/data_type.h @@ -46,4 +46,12 @@ const static struct DataLayout F64 = {1, 1, 8, 52, 11}; // clang-format on +DT get_F16(); + +DT get_U32(); + +DT get_F32(); + +DT get_U64(); + #endif// __DATA_TYPE_H__ diff --git a/src/devices/handle.cc b/src/devices/handle.cc index 45779776..d00278e5 100644 --- a/src/devices/handle.cc +++ b/src/devices/handle.cc @@ -14,6 +14,9 @@ #ifdef ENABLE_METAX_GPU #include "./maca/maca_handle.h" #endif +#ifdef ENABLE_MT_GPU +#include "./musa/musa_handle.h" +#endif __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device device, int device_id) { @@ -48,6 +51,11 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d case DevMetaxGpu: { return createMacaHandle((MacaHandle_t *) handle_ptr, device_id); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return createMusaHandle((MusaHandle_t *) handle_ptr, device_id); + } #endif } return STATUS_BAD_DEVICE; @@ -81,6 +89,12 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { case DevMetaxGpu: { return deleteMacaHandle((MacaHandle_t) handle); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + deleteMusaHandle((MusaHandle_t) handle); + return STATUS_SUCCESS; + } #endif } return STATUS_BAD_DEVICE; diff --git a/src/devices/musa/common_musa.h b/src/devices/musa/common_musa.h new file mode 100644 index 00000000..bfed9900 --- /dev/null +++ b/src/devices/musa/common_musa.h @@ -0,0 +1,38 @@ +#ifndef __COMMON_MUSA_H__ +#define __COMMON_MUSA_H__ + +enum class Type { + QINT4, + QINT8, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + HALF, + BFLOAT16, + FLOAT, + DOUBLE, + BOOL, +}; + +enum class Format { + UNKNOWN, + NCW, + NWC, + NCHW, + NHWC, + HWCN, + NCDHW, + NDHWC, + DHWCN, +}; + +#define MAX_THREADS_PER_BLOCK 1024 +#define MAX_WARP_PER_BLOCK 32 +#define WARP_SIZE 32 + +#endif // __COMMON_MUSA_H__ \ No newline at end of file diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc new file mode 100644 index 00000000..00f43e9d --- /dev/null +++ b/src/devices/musa/musa_handle.cc @@ -0,0 +1,30 @@ +#include "musa_handle.h" +#include + +infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) { + int device_count; + musaGetDeviceCount(&device_count); + if (device_id >= device_count) { + return STATUS_BAD_DEVICE; + } + + // if (musaSetDevice(device_id) != musaSuccess){ + // return STATUS_BAD_DEVICE; + // } + + auto mublas_pool = std::make_shared>(); + mublasHandle_t *mublas_handle = new mublasHandle_t; + mublasCreate(mublas_handle); + mublas_pool->push(mublas_handle); + + *handle_ptr = new MusaContext{DevMtGpu, device_id, std::move(mublas_pool)}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr) { + handle_ptr->mublas_handles_t = nullptr; + delete handle_ptr; + + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/devices/musa/musa_handle.h b/src/devices/musa/musa_handle.h new file mode 100644 index 00000000..f91caba8 --- /dev/null +++ b/src/devices/musa/musa_handle.h @@ -0,0 +1,37 @@ +#ifndef __MUSA_HANDLE_H__ +#define __MUSA_HANDLE_H__ + +#include "pool.h" +#include "device.h" +#include "status.h" +#include "ops/matmul/matmul.h" +#include +#include +#include +#include + +struct MusaContext { + Device device; + int device_id; + std::shared_ptr> mublas_handles_t; +}; +typedef struct MusaContext *MusaHandle_t; + +infiniopStatus_t createMusaHandle(MusaHandle_t *handle_ptr, int device_id); + +infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr); + +template +void use_mublas(std::shared_ptr> mublas_handles_t, int device_id, MUstream stream, T const &f) { + mublasHandle_t *handle = mublas_handles_t->pop(); + if (!handle) { + // musaSetDevice(device_id); + mublasHandle_t *handle = new mublasHandle_t; + mublasCreate(handle); + } + mublasSetStream(*handle, (MUstream) stream); + f(*handle); + mublas_handles_t->push(handle); +} + +#endif // __MUSA_HANDLE_H__ \ No newline at end of file diff --git a/src/devices/musa/pool.h b/src/devices/musa/pool.h new file mode 100644 index 00000000..9c6a107b --- /dev/null +++ b/src/devices/musa/pool.h @@ -0,0 +1,50 @@ +#ifndef __POOL_MUSA_H__ +#define __POOL_MUSA_H__ + +#include +#include +#include + +template +class Pool { +public: + Pool() : _head(nullptr) {} + + Pool(const Pool &) = delete; + + Pool(Pool &&pool) noexcept : _head(pool._head.exchange(nullptr)) {} + + ~Pool() { + while (this->pop()) {} + } + + void push(T *val) const { + Node *new_node = new Node(val); + new_node->next = _head.load(); + while (!_head.compare_exchange_weak(new_node->next, new_node)); + } + + T* pop() const { + Node *top = _head.load(); + Node *new_head = nullptr; + do { + if (!top) { + return nullptr; + } + new_head = top->next; + } while (!_head.compare_exchange_weak(top, new_head)); + return top->data; + } + +private: + template + struct Node { + U *data; + Node *next; + Node(U *data) : data(data), next(nullptr) {} + }; + + mutable std::atomic *> _head; +}; + +#endif // __POOL_MUSA_H__ \ No newline at end of file diff --git a/src/devices/musa/utils.cc b/src/devices/musa/utils.cc new file mode 100644 index 00000000..466fcf7d --- /dev/null +++ b/src/devices/musa/utils.cc @@ -0,0 +1,17 @@ +#include "data_type.h" + +DT get_F16() { + return F16; +} + +DT get_F32() { + return F32; +} + +DT get_U32() { + return U32; +} + +DT get_U64() { + return U64; +} \ No newline at end of file diff --git a/xmake.lua b/xmake.lua index ce8f065a..4f3adfdb 100644 --- a/xmake.lua +++ b/xmake.lua @@ -48,6 +48,13 @@ option("metax-gpu") option_end() +option("mthreads-gpu") + set_default(false) + set_showmenu(true) + set_description("Enable or disable MThreads GPU kernel") + add_defines("ENABLE_MT_GPU") +option_end() + option("sugon-dcu") set_default(false) set_showmenu(true) @@ -172,6 +179,51 @@ if has_config("cambricon-mlu") then end +if has_config("mthreads-gpu") then + + add_defines("ENABLE_MT_GPU") + local musa_home = os.getenv("MUSA_INSTALL_PATH") + -- Add include dirs + add_includedirs(musa_home .. "/include") + -- Add shared lib + add_linkdirs(musa_home .. "/lib") + add_links("libmusa.so") + add_links("libmusart.so") + add_links("libmudnn.so") + add_links("libmublas.so") + + rule("mu") + set_extensions(".mu") + on_load(function (target) + target:add("includedirs", "include") + end) + + on_build_file(function (target, sourcefile) + local objectfile = target:objectfile(sourcefile) + os.mkdir(path.directory(objectfile)) + + local mcc = "/usr/local/musa/bin/mcc" + local includedirs = table.concat(target:get("includedirs"), " ") + local args = {"-c", sourcefile, "-o", objectfile, "-I/usr/local/musa/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"} + for _, includedir in ipairs(target:get("includedirs")) do + table.insert(args, "-I" .. includedir) + end + + os.execv(mcc, args) + table.insert(target:objectfiles(), objectfile) + end) + rule_end() + + target("mthreads-gpu") + set_kind("static") + set_languages("cxx17") + add_files("src/devices/musa/*.cc", "src/ops/*/musa/*.cc") + add_files("src/ops/*/musa/*.mu", {rule = "mu"}) + add_cxflags("-lstdc++ -Wall -fPIC") + target_end() + +end + if has_config("ascend-npu") then add_defines("ENABLE_ASCEND_NPU") @@ -315,6 +367,9 @@ target("infiniop") if has_config("metax-gpu") then add_deps("metax-gpu") end + if has_config("mthreads-gpu") then + add_deps("mthreads-gpu") + end set_languages("cxx17") add_files("src/devices/handle.cc") add_files("src/ops/*/operator.cc") From 49ee1e3379a90e1ee086b7daf29b4d66980f9e11 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 14:58:09 +0800 Subject: [PATCH 02/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=91=A9=E5=B0=94?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=20MatMul=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/matmul.py | 33 +++++++++++++ operatorspy/tests/test_utils.py | 5 ++ src/ops/matmul/musa/matmul_musa.cc | 48 +++++++++++++++++++ src/ops/matmul/musa/matmul_musa.h | 45 +++++++++++++++++ src/ops/matmul/musa/matmul_musa.mu | 77 ++++++++++++++++++++++++++++++ src/ops/matmul/operator.cc | 23 +++++++++ 6 files changed, 231 insertions(+) create mode 100644 src/ops/matmul/musa/matmul_musa.cc create mode 100644 src/ops/matmul/musa/matmul_musa.h create mode 100644 src/ops/matmul/musa/matmul_musa.mu diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index ba590447..46469222 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -325,6 +325,37 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for ( + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride, + b_stride, + c_stride, + dtype, + ) in test_cases: + test( + lib, + handle, + "musa", + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride, + b_stride, + c_stride, + dtype, + ) + if __name__ == "__main__": test_cases = [ # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype @@ -387,6 +418,8 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) + if args.musa: + test_musa(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/test_utils.py b/operatorspy/tests/test_utils.py index 68b71bc4..6e4960d5 100644 --- a/operatorspy/tests/test_utils.py +++ b/operatorspy/tests/test_utils.py @@ -32,6 +32,11 @@ def get_args(): action="store_true", help="Run ASCEND NPU test", ) + parser.add_argument( + "--musa", + action="store_true", + help="Run MUSA test", + ) return parser.parse_args() diff --git a/src/ops/matmul/musa/matmul_musa.cc b/src/ops/matmul/musa/matmul_musa.cc new file mode 100644 index 00000000..8a090291 --- /dev/null +++ b/src/ops/matmul/musa/matmul_musa.cc @@ -0,0 +1,48 @@ +#include "matmul_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include +#include + +#include + +infiniopStatus_t musaCreateMatmulDescriptor(MusaHandle_t handle, + MatmulMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + float alpha, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + float beta) { + DT dtype = c_desc->dt; + + if (dtype != F16 && dtype != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + + infiniopStatus_t *status = new infiniopStatus_t{STATUS_EXECUTION_FAILED}; + auto info = MatmulInfo(c_desc, a_desc, b_desc, status); + if (*status != STATUS_SUCCESS) { + return *status; + } + + *desc_ptr = new MatmulMusaDescriptor{ + DevMtGpu, + dtype, + handle->device_id, + info, + alpha, + beta, + handle->mublas_handles_t}; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetMatmulWorkspaceSize(MatmulMusaDescriptor_t desc, uint64_t *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyMatmulDescriptor(MatmulMusaDescriptor_t desc) { + desc->mublas_handles_t = nullptr; + delete desc; + return STATUS_SUCCESS; +} \ No newline at end of file diff --git a/src/ops/matmul/musa/matmul_musa.h b/src/ops/matmul/musa/matmul_musa.h new file mode 100644 index 00000000..617a8318 --- /dev/null +++ b/src/ops/matmul/musa/matmul_musa.h @@ -0,0 +1,45 @@ +#ifndef __MUSA_MATMUL_H__ +#define __MUSA_MATMUL_H__ + +#include +#include +#include +#include +#include +#include "../blas.h" +#include "operators.h" +#include "../../../devices/musa/musa_handle.h" + +typedef struct MatmulMusaDescriptor { + Device device; + DT dtype; + int device_id; + MatmulInfo info; + float alpha; + float beta; + std::shared_ptr> mublas_handles_t; +} MatmulMusaDescriptor; + +typedef struct MatmulMusaDescriptor *MatmulMusaDescriptor_t; + +infiniopStatus_t musaCreateMatmulDescriptor(MusaHandle_t handle, + MatmulMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + float alpha, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + float beta); + +infiniopStatus_t musaGetMatmulWorkspaceSize(MatmulMusaDescriptor_t desc, uint64_t *size); + +infiniopStatus_t musaMatmul(MatmulMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *c, + void const *a, + void const *b, + void *stream); + +infiniopStatus_t musaDestroyMatmulDescriptor(MatmulMusaDescriptor_t desc); + +#endif // __MUSA_MATMUL_H__ \ No newline at end of file diff --git a/src/ops/matmul/musa/matmul_musa.mu b/src/ops/matmul/musa/matmul_musa.mu new file mode 100644 index 00000000..4685beb8 --- /dev/null +++ b/src/ops/matmul/musa/matmul_musa.mu @@ -0,0 +1,77 @@ +#include "../../../devices/musa/musa_handle.h" +#include "../../utils.h" +#include "../blas.h" +#include "matmul_musa.h" +#include +#include + +template +infiniopStatus_t matmul_musa(MatmulMusaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) { + auto info = desc->info; + + if (info.is_transed) { + std::swap(a, b); + } + + Tdata alpha_, beta_; + musaDataType_t a_type, b_type, c_type; + mublasComputeType_t compute_type; + + if constexpr (std::is_same::value) { + alpha_ = __float2half(alpha); + beta_ = __float2half(beta); + a_type = b_type = c_type = MUSA_R_16F; + compute_type = MUBLAS_COMPUTE_16F; + } else { + alpha_ = alpha; + beta_ = beta; + a_type = b_type = c_type = MUSA_R_32F; + compute_type = MUBLAS_COMPUTE_32F_FAST_TF32; + } + + auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T; + auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T; + + use_mublas(desc->mublas_handles_t, desc->device_id, (MUstream) stream, + [&](mublasHandle_t handle) { mublasGemmStridedBatchedEx( + handle, + op_a, + op_b, + info.m, + info.n, + info.k, + &alpha_, + a, + a_type, + info.a_matrix.ld(), + info.a_matrix.stride, + b, + b_type, + info.b_matrix.ld(), + info.b_matrix.stride, + &beta_, + c, + c_type, + info.c_matrix.ld(), + info.c_matrix.stride, + info.batch, + compute_type, + MUBLAS_GEMM_DEFAULT);}); + return STATUS_SUCCESS; +} + +infiniopStatus_t musaMatmul(MatmulMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *c, + void const *a, + void const *b, + void *stream) { + if (desc->dtype == F16) { + return matmul_musa(desc, c, desc->beta, a, b, desc->alpha, stream); + } + if (desc->dtype == F32) { + return matmul_musa(desc, c, desc->beta, a, b, desc->alpha, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} \ No newline at end of file diff --git a/src/ops/matmul/operator.cc b/src/ops/matmul/operator.cc index 14748b99..5dd880a4 100644 --- a/src/ops/matmul/operator.cc +++ b/src/ops/matmul/operator.cc @@ -17,6 +17,9 @@ #ifdef ENABLE_METAX_GPU #include "maca/matmul_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/matmul_musa.h" +#endif __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle, infiniopMatmulDescriptor_t *desc_ptr, @@ -56,6 +59,11 @@ __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle, case DevMetaxGpu: { return macaCreateMatmulDescriptor((MacaHandle_t) handle, (MatmulMacaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateMatmulDescriptor((MusaHandle_t) handle, (MatmulMusaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta); + } #endif } return STATUS_BAD_DEVICE; @@ -88,6 +96,11 @@ __C infiniopStatus_t infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t d case DevMetaxGpu: { return macaGetMatmulWorkspaceSize((MatmulMacaDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaGetMatmulWorkspaceSize((MatmulMusaDescriptor_t) desc, size); + } #endif } return STATUS_BAD_DEVICE; @@ -122,6 +135,11 @@ __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc, void *works case DevMetaxGpu: { return macaMatmul((MatmulMacaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaMatmul((MatmulMusaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -153,6 +171,11 @@ __C infiniopStatus_t infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t case DevMetaxGpu: { return macaDestroyMatmulDescriptor((MatmulMacaDescriptor_t) desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyMatmulDescriptor((MatmulMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From b14587b0e55d4849934efadf16274401c27b0267 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 15:02:10 +0800 Subject: [PATCH 03/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=91=A9=E5=B0=94?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=20Causal=5Fsoftmax=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/causal_softmax.py | 12 + .../musa/causal_softmax_musa.cc | 55 ++++ .../causal_softmax/musa/causal_softmax_musa.h | 36 +++ .../musa/causal_softmax_musa.mu | 258 ++++++++++++++++++ src/ops/causal_softmax/operator.cc | 23 ++ 5 files changed, 384 insertions(+) create mode 100644 src/ops/causal_softmax/musa/causal_softmax_musa.cc create mode 100644 src/ops/causal_softmax/musa/causal_softmax_musa.h create mode 100644 src/ops/causal_softmax/musa/causal_softmax_musa.mu diff --git a/operatorspy/tests/causal_softmax.py b/operatorspy/tests/causal_softmax.py index 623c0fac..762b0707 100644 --- a/operatorspy/tests/causal_softmax.py +++ b/operatorspy/tests/causal_softmax.py @@ -119,6 +119,16 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + + handle = create_handle(lib, device) + for x_shape, x_stride in test_cases: + test(lib, handle, "musa", x_shape, x_stride) + + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ # x_shape, x_stride @@ -161,6 +171,8 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) + if args.musa: + test_musa(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.cc b/src/ops/causal_softmax/musa/causal_softmax_musa.cc new file mode 100644 index 00000000..ae138efd --- /dev/null +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.cc @@ -0,0 +1,55 @@ +#include "causal_softmax_musa.h" +#include "../../utils.h" +#include "../../../devices/musa/common_musa.h" + +infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, + CausalSoftmaxMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y) { + unsigned long int ndim = y->ndim; + // TODO: only support 2d or 3d tensor + if (ndim != 2 && ndim != 3) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(y->dt, F16)) { + return STATUS_BAD_TENSOR_DTYPE; + } + unsigned long int total_seq_len = y->shape[ndim - 1]; + unsigned long int seq_len = y->shape[ndim - 2]; + unsigned long int batch_size = 1; + unsigned long int stride_b = 0; + unsigned long int stride_i = y->strides[ndim - 2]; + unsigned long int stride_j = y->strides[ndim - 1]; + if (stride_j != 1) { + return STATUS_BAD_TENSOR_STRIDES; + } + for (uint64_t i = 0; i < ndim - 2; i++) { + batch_size *= y->shape[i]; + } + if (ndim == 3) + stride_b = y->strides[ndim - 3]; + unsigned int max_items_per_thread = ROUND_UP_DIV(total_seq_len, MAX_THREADS_PER_BLOCK); + + *desc_ptr = new CausalSoftmaxMusaDescriptor{ + handle->device, + handle->device_id, + y->dt, + batch_size, + stride_b, + seq_len, + stride_i, + total_seq_len, + stride_j, + max_items_per_thread}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.h b/src/ops/causal_softmax/musa/causal_softmax_musa.h new file mode 100644 index 00000000..90d588f0 --- /dev/null +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.h @@ -0,0 +1,36 @@ +#ifndef __MUSA_CAUSAL_SOFTMAX_H__ +#define __MUSA_CAUSAL_SOFTMAX_H__ + +#include "operators.h" +#include "../../../devices/musa/musa_handle.h" + +struct CausalSoftmaxMusaDescriptor { + Device device; + int device_id; + DT dtype; + unsigned long int batch_size; + unsigned long int stride_b; + unsigned long int seq_len; + unsigned long int stride_i; + unsigned long int total_seq_len; + unsigned long int stride_j; + unsigned int max_items_per_thread; +}; + +typedef struct CausalSoftmaxMusaDescriptor *CausalSoftmaxMusaDescriptor_t; + +infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, + CausalSoftmaxMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc); + +infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, + void *workspace, + unsigned long int workspace_size, + void *data, + void *stream); + +infiniopStatus_t musaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMusaDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.mu b/src/ops/causal_softmax/musa/causal_softmax_musa.mu new file mode 100644 index 00000000..3bb92ad4 --- /dev/null +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.mu @@ -0,0 +1,258 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "causal_softmax_musa.h" +#include + +struct AttentionCausualMask { + __forceinline__ __device__ bool + operator()(int tok_id, int seq_len, + int pos_id, int total_seq_len) { + // tok_id ↓ |<-total_seq_len->| + // 0 | * * * ... * | + // 1 | * * * ... * * | + // 2 | * * * ... * * * | + // seq_len: 3 pos_id-> + return total_seq_len + tok_id >= pos_id + seq_len; + } +}; + +template +static __device__ void block_padding( + Tdata *__restrict__ att, + Tmask mask, + unsigned int const token_idx, + unsigned int const seq_len) { + auto att_idx = threadIdx.x, total_seq_len = blockDim.x; + auto thread_data = mask(token_idx, seq_len, att_idx, total_seq_len) + ? float(att[att_idx]) + : -__FLT_MAX__; + + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto block_op = BlockOp(temp_storage); + + __shared__ float max; + { + auto acc = block_op.Reduce(thread_data, cub::Max(), total_seq_len); + if (threadIdx.x == 0) { max = acc; } + } + __syncthreads(); + + __shared__ float mean; + { + auto acc = block_op.Sum(thread_data = expf(thread_data - max), total_seq_len); + if (threadIdx.x == 0) { mean = fdividef(1, acc); } + } + __syncthreads(); + + att[att_idx] = Tdata(thread_data * mean); +} + +template +static __device__ void block_folding( + Tdata *__restrict__ att, + Tmask mask, + unsigned int const token_idx, + unsigned int const seq_len, + unsigned int const total_seq_len) { + + auto local = (total_seq_len + blockDim.x - 1) / blockDim.x; + + auto thread_offset = threadIdx.x * local; + att += thread_offset; + + float thread_data[ITEMS_PER_THREAD], thread_max = -__FLT_MAX__, thread_sum = 0; + for (unsigned int i = 0; i < local; ++i) { + auto att_idx = thread_offset + i; + thread_data[i] = att_idx < total_seq_len && mask(token_idx, seq_len, att_idx, total_seq_len) + ? float(att[i]) + : -__FLT_MAX__; + thread_max = cub::Max()(thread_max, thread_data[i]); + } + + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto block_op = BlockOp(temp_storage); + + __shared__ float max; + { + auto acc = block_op.Reduce(thread_max, cub::Max()); + if (threadIdx.x == 0) { max = acc; } + } + __syncthreads(); + + __shared__ float mean; + { + for (unsigned int i = 0; i < local; ++i) { + thread_data[i] = expf(thread_data[i] - max); + thread_sum += thread_data[i]; + } + auto acc = block_op.Sum(thread_sum); + if (threadIdx.x == 0) { mean = fdividef(1, acc); } + } + __syncthreads(); + + for (unsigned int i = 0; i < local; ++i) { + if (auto att_idx = thread_offset + i; att_idx < total_seq_len) { + att[i] = Tdata(thread_data[i] * mean); + } + } +} + +// assert BLOCK_SIZE >= blockDim.x +template +static __forceinline__ __device__ void padding( + Tdata *__restrict__ att, + Tmask mask, + int const stride_x, + int const stride_y, + int const stride_z) { + auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y, + token_idx = blockIdx.y, + seq_len = gridDim.y; + block_padding( + att + offset, mask, token_idx, seq_len); +} + +template +static __forceinline__ __device__ void folding( + Tdata *__restrict__ att, + Tmask mask, + unsigned int const total_seq_len, + int const stride_x, + int const stride_y, + int const stride_z) { + auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y, + token_idx = blockIdx.y, + seq_len = gridDim.y; + block_folding( + att + offset, mask, token_idx, seq_len, total_seq_len); +} + +template +__global__ void fused_softmax_padding( + Tdata *__restrict__ att, + unsigned int const stride_x, + unsigned int const stride_y, + unsigned int const stride_z) { + + padding(att, AttentionCausualMask(), stride_x, stride_y, stride_z); +} + +template +__global__ void fused_softmax_folding( + Tdata *__restrict__ att, + unsigned int const stride_x, + unsigned int const stride_y, + unsigned int const stride_z, + unsigned int const total_seq_len) { + { + folding(att, AttentionCausualMask(), total_seq_len, stride_x, stride_y, stride_z); + } +} + +template +__global__ void fused_softmax_standard( + Tdata *__restrict__ att_, + unsigned int const stride_x, + unsigned int const stride_y, + unsigned int const stride_z, + unsigned int const total_seq_len) { + { + auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y, + token_idx = blockIdx.y, + seq_len = gridDim.y; + + auto att = att_ + offset; + auto att_idx = threadIdx.x; + + float partial; + __shared__ float max_; + __shared__ float sum_; + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto block_op = BlockOp(temp_storage); + + // Partial max + partial = -__FLT_MAX__; + for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) { + if (i <= total_seq_len - seq_len + token_idx) { + partial = max(partial, float(att[i])); + } + } + __syncthreads(); + // Block reduce max + { + auto acc = block_op.Reduce(partial, cub::Max()); + if (threadIdx.x == 0) { max_ = acc; } + } + __syncthreads(); + + // Partial sum + partial = 0.; + for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) { + if (i <= total_seq_len - seq_len + token_idx) { + float e = expf(float(att[i]) - max_); + partial += e; + } + } + __syncthreads(); + + // Block reduce sum + { + auto acc = block_op.Reduce(partial, cub::Sum()); + if (threadIdx.x == 0) { sum_ = acc; } + } + __syncthreads(); + + // Softmax + for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) { + if (i <= total_seq_len - seq_len + token_idx) { + float e = expf(float(att[i]) - max_); + att[i] = e / sum_; + } else { + att[i] = half(0); + } + } + } +} + + +void causal_softmax_mt_gpu_f16(CausalSoftmaxMusaDescriptor_t desc, void* y, void *stream) { + unsigned long int total_seq_len = desc->total_seq_len; + unsigned long int seq_len = desc->seq_len; + unsigned long int batch_size = desc->batch_size; + unsigned long int stride_x = desc->stride_b; + unsigned long int stride_y = desc->stride_i; + unsigned long int stride_z = desc->stride_j;// covert byte strides to element strides + unsigned int max_items_per_thread = desc->max_items_per_thread; + + dim3 grid(batch_size, seq_len); + + if (max_items_per_thread == 1) { + fused_softmax_padding + <<>>((half *) (y), stride_x, stride_y, stride_z); + } else if (max_items_per_thread <= 16) { + fused_softmax_folding + <<>>((half *) (y), stride_x, stride_y, stride_z, total_seq_len); + } else { + fused_softmax_standard + <<>>((half *) (y), stride_x, stride_y, stride_z, total_seq_len); + } +} + +infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *data, + void *stream) { +// if(musaSetDevice(desc->device_id) != musaSuccess){ +// return STATUS_BAD_DEVICE; +// } + if (dtype_eq(desc->dtype, F16)) { + causal_softmax_mt_gpu_f16(desc, data, stream); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/causal_softmax/operator.cc b/src/ops/causal_softmax/operator.cc index c9d87dda..841eb75a 100644 --- a/src/ops/causal_softmax/operator.cc +++ b/src/ops/causal_softmax/operator.cc @@ -21,6 +21,10 @@ #ifdef ENABLE_METAX_GPU #include "maca/causal_softmax_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/causal_softmax_musa.h" +#include "../../devices/musa/common_musa.h" +#endif __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor( infiniopHandle_t handle, @@ -52,6 +56,11 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor( case DevMetaxGpu: { return macaCreateCausalSoftmaxDescriptor((MacaHandle_t) handle, (CausalSoftmaxMacaDescriptor_t *) desc_ptr, y_desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateCausalSoftmaxDescriptor((MusaHandle_t) handle, (CausalSoftmaxMusaDescriptor_t *) desc_ptr, y_desc); + } #endif } return STATUS_BAD_DEVICE; @@ -85,6 +94,11 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax case DevMetaxGpu: { return macaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMacaDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMusaDescriptor_t) desc, size); + } #endif } return STATUS_BAD_DEVICE; @@ -117,6 +131,11 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des case DevMetaxGpu: { return macaCausalSoftmax((CausalSoftmaxMacaDescriptor_t) desc, workspace, workspace_size, data, stream); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCausalSoftmax((CausalSoftmaxMusaDescriptor_t) desc, workspace, workspace_size, data, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -149,6 +168,10 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma case DevMetaxGpu: { return macaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMacaDescriptor_t) desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMusaDescriptor_t) desc); #endif } return STATUS_BAD_DEVICE; From 6e84da6adc4f46d5ec1ee95f44836511157a67dd Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 15:06:37 +0800 Subject: [PATCH 04/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=91=A9=E5=B0=94?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=20rearrange=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/rearrange.py | 22 +++++++ src/ops/rearrange/musa/rearrange_musa.cc | 77 ++++++++++++++++++++++++ src/ops/rearrange/musa/rearrange_musa.h | 33 ++++++++++ src/ops/rearrange/musa/rearrange_musa.mu | 69 +++++++++++++++++++++ src/ops/rearrange/operator.cc | 18 ++++++ 5 files changed, 219 insertions(+) create mode 100644 src/ops/rearrange/musa/rearrange_musa.cc create mode 100644 src/ops/rearrange/musa/rearrange_musa.h create mode 100644 src/ops/rearrange/musa/rearrange_musa.mu diff --git a/operatorspy/tests/rearrange.py b/operatorspy/tests/rearrange.py index 124fe552..9709e6b3 100644 --- a/operatorspy/tests/rearrange.py +++ b/operatorspy/tests/rearrange.py @@ -117,6 +117,26 @@ def test_maca(lib, test_cases): test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for test_case in test_cases: + x_shape, x_stride = test_case[0] + y_shape, y_stride = test_case[1] + test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride) + destroy_handle(lib, handle) + +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for test_case in test_cases: + x_shape, x_stride = test_case[0] + y_shape, y_stride = test_case[1] + test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride) + destroy_handle(lib, handle) + if __name__ == "__main__": args = get_args() test_cases = [ @@ -156,4 +176,6 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) + if args.musa: + test_musa(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/rearrange/musa/rearrange_musa.cc b/src/ops/rearrange/musa/rearrange_musa.cc new file mode 100644 index 00000000..29f2b6b5 --- /dev/null +++ b/src/ops/rearrange/musa/rearrange_musa.cc @@ -0,0 +1,77 @@ +#include "rearrange_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include + +infiniopStatus_t musaCreateRearrangeDescriptor(MusaHandle_t handle, + RearrangeMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t dst, + infiniopTensorDescriptor_t src) { + if (!dtype_eq(dst->dt, src->dt)) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (dst->ndim != src->ndim || dst->ndim < 2) { + return STATUS_BAD_TENSOR_SHAPE; + } + auto ndim = dst->ndim; + for (uint64_t i = 0; i < ndim; ++i) { + if (dst->shape[i] != src->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) { + return STATUS_BAD_TENSOR_STRIDES; + } + unsigned int r = 0, c = 0, b = 0; + unsigned int rsa = 0, csa = 0, rsb = 0, csb = 0; + if (ndim == 2) { + c = dst->shape[0]; + b = dst->shape[1]; + csa = dst->strides[0]; + csb = src->strides[0]; + } else if (ndim == 3) { + r = dst->shape[0]; + c = dst->shape[1]; + b = dst->shape[2]; + csa = dst->strides[1]; + csb = src->strides[1]; + rsa = dst->strides[0]; + rsb = src->strides[0]; + } else { + for (uint64_t i = ndim - 3; i >= 1; --i) { + if ((int64_t) dst->shape[i] * dst->strides[i] != dst->strides[i - 1] || (int64_t) src->shape[i] * src->strides[i] != src->strides[i - 1]) { + return STATUS_BAD_TENSOR_STRIDES; + } + } + r = std::accumulate(dst->shape, dst->shape + ndim - 2, 1, std::multiplies()); + c = dst->shape[ndim - 2]; + b = dst->shape[ndim - 1]; + csa = dst->strides[ndim - 2]; + csb = src->strides[ndim - 2]; + rsa = dst->strides[ndim - 3]; + rsb = src->strides[ndim - 3]; + } + auto contiguous_bytes = b * dst->dt.size; + if (contiguous_bytes % WARP_SIZE != 0) { + return STATUS_BAD_PARAM; + } + auto bytes_per_thread = contiguous_bytes / WARP_SIZE ; + if (bytes_per_thread <= 0 || bytes_per_thread > 32 || (bytes_per_thread & (bytes_per_thread - 1)) != 0) { + return STATUS_BAD_PARAM; + } + *desc_ptr = new RearrangeMusaDescriptor{ + handle->device, + handle->device_id, + rsa, + rsb, + csa, + csb, + r, c, b, + bytes_per_thread}; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyRearrangeDescriptor(RearrangeMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/rearrange/musa/rearrange_musa.h b/src/ops/rearrange/musa/rearrange_musa.h new file mode 100644 index 00000000..7ebdb4e5 --- /dev/null +++ b/src/ops/rearrange/musa/rearrange_musa.h @@ -0,0 +1,33 @@ +#ifndef __MUSA_REARRANGE_H__ +#define __MUSA_REARRANGE_H__ + +#include "operators.h" +#include "../../../devices/musa/musa_handle.h" + +struct RearrangeMusaDescriptor { + Device device; + int device_id; + unsigned long int rsa; + unsigned long int rsb; + unsigned long int csa; + unsigned long int csb; + unsigned long int r, c, b; + unsigned long int bytes_per_thread; +}; + +typedef struct RearrangeMusaDescriptor *RearrangeMusaDescriptor_t; + +infiniopStatus_t musaCreateRearrangeDescriptor(MusaHandle_t handle, + RearrangeMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t dst, + infiniopTensorDescriptor_t src); + +infiniopStatus_t musaRearrange(RearrangeMusaDescriptor_t desc, + void *dst, + void const *src, + void *stream); + +infiniopStatus_t musaDestroyRearrangeDescriptor(RearrangeMusaDescriptor_t desc); + +void rearrange_mt_gpu(RearrangeMusaDescriptor *, void *y, void const *x, void *stream); +#endif // __MUSA_REARRANGE_H__ diff --git a/src/ops/rearrange/musa/rearrange_musa.mu b/src/ops/rearrange/musa/rearrange_musa.mu new file mode 100644 index 00000000..ee094869 --- /dev/null +++ b/src/ops/rearrange/musa/rearrange_musa.mu @@ -0,0 +1,69 @@ +#include "../../../devices/musa/common_musa.h" +#include "rearrange_musa.h" + +template +static __global__ void rearrange( + void *__restrict__ dst, + unsigned int const rsa, + unsigned int const csa, + void const *__restrict__ src, + unsigned int const rsb, + unsigned int const csb, + unsigned int const ncols) { + + auto row = blockIdx.y, + col = blockIdx.x * blockDim.y + threadIdx.y; + if (col >= ncols) return; + + auto thread = threadIdx.x, + warp_size = blockDim.x; + auto i = (row * rsa + col * csa) * warp_size + thread; + auto j = (row * rsb + col * csb) * warp_size + thread; + + reinterpret_cast(dst)[i] = reinterpret_cast(src)[j]; +} + + +void rearrange_mt_gpu(RearrangeMusaDescriptor_t desc, void *y, void const *x, void *stream) { + unsigned long int rsa = desc->rsa, csa = desc->csa, rsb = desc->rsb, csb = desc->csb; + unsigned int r = desc->r, c = desc->c, b = desc->b, bytes_per_thread = desc->bytes_per_thread; + auto dst_ptr = static_cast(reinterpret_cast(y)); + rsa /= b; + csa /= b; + auto src_ptr = static_cast(reinterpret_cast(x)); + rsb /= b; + csb /= b; + auto musa_stream = reinterpret_cast(stream); + dim3 grid_dims = dim3((c + MAX_WARP_PER_BLOCK - 1) / MAX_WARP_PER_BLOCK, r); + dim3 block_dims = dim3(WARP_SIZE, (c + grid_dims.x - 1) / grid_dims.x); + switch (bytes_per_thread) { + case 1: + rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + break; + case 2: + rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + break; + case 4: + rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + break; + case 8: + rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + break; + case 16: + rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + break; + case 32: + rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + break; + default: + break; + } +} +infiniopStatus_t musaRearrange(RearrangeMusaDescriptor_t desc, + void *dst, void const *src, void *stream) { +// if(musaSetDevice(desc->device_id) != musaSuccess){ +// return STATUS_BAD_DEVICE; +// } + rearrange_mt_gpu(desc, dst, src, stream); + return STATUS_SUCCESS; +} diff --git a/src/ops/rearrange/operator.cc b/src/ops/rearrange/operator.cc index 752211e5..d3da887c 100644 --- a/src/ops/rearrange/operator.cc +++ b/src/ops/rearrange/operator.cc @@ -20,6 +20,9 @@ #ifdef ENABLE_METAX_GPU #include "maca/rearrange_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/rearrange_musa.h" +#endif __C infiniopStatus_t infiniopCreateRearrangeDescriptor( infiniopHandle_t handle, @@ -54,6 +57,11 @@ __C infiniopStatus_t infiniopCreateRearrangeDescriptor( case DevMetaxGpu: { return macaCreateRearrangeDescriptor((MacaHandle_t) handle, (RearrangeMacaDescriptor_t *) desc_ptr, dst, src); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateRearrangeDescriptor((MusaHandle_t)handle, (RearrangeMusaDescriptor_t *) desc_ptr, dst, src); + } #endif } return STATUS_BAD_DEVICE; @@ -88,6 +96,11 @@ __C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void case DevMetaxGpu: { return macaRearrange((RearrangeMacaDescriptor_t) desc, dst, src, stream); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaRearrange((RearrangeMusaDescriptor_t) desc, dst, src, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -119,6 +132,11 @@ __C infiniopStatus_t infiniopDestroyRearrangeDescriptor(infiniopRearrangeDescrip case DevMetaxGpu: { return macaDestroyRearrangeDescriptor((RearrangeMacaDescriptor_t) desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyRearrangeDescriptor((RearrangeMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From c5bc2819a94023f189a7bb94c053ed029d36f489 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 15:12:18 +0800 Subject: [PATCH 05/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=91=A9=E5=B0=94?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=20rms=5Fnorm=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/rms_norm.py | 10 ++ src/ops/rms_norm/musa/rms_norm_musa.cc | 46 +++++++ src/ops/rms_norm/musa/rms_norm_musa.h | 40 ++++++ src/ops/rms_norm/musa/rms_norm_musa.mu | 173 +++++++++++++++++++++++++ src/ops/rms_norm/operator.cc | 24 +++- 5 files changed, 292 insertions(+), 1 deletion(-) create mode 100644 src/ops/rms_norm/musa/rms_norm_musa.cc create mode 100644 src/ops/rms_norm/musa/rms_norm_musa.h create mode 100644 src/ops/rms_norm/musa/rms_norm_musa.mu diff --git a/operatorspy/tests/rms_norm.py b/operatorspy/tests/rms_norm.py index 8176af64..a11b794f 100644 --- a/operatorspy/tests/rms_norm.py +++ b/operatorspy/tests/rms_norm.py @@ -125,6 +125,14 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases: + test(lib, handle, "musa", y_shape, x_shape, w_shape, dtype, w_dtype) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ # y_shape, x_shape, w_shape, dtype, w_dtype @@ -174,6 +182,8 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) + if args.musa: + test_musa(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/rms_norm/musa/rms_norm_musa.cc b/src/ops/rms_norm/musa/rms_norm_musa.cc new file mode 100644 index 00000000..5b053e73 --- /dev/null +++ b/src/ops/rms_norm/musa/rms_norm_musa.cc @@ -0,0 +1,46 @@ +#include "rms_norm_musa.h" +#include "../../utils.h" +#include "../../../devices/musa/common_musa.h" + +infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle, RMSNormMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + float epsilon) { + if (y_desc->ndim != 2 || x_desc->ndim != 2 || w_desc->ndim != 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + + auto n = y_desc->shape[0], + d = y_desc->shape[1]; + + if (x_desc->shape[0] != n || x_desc->shape[1] != d || w_desc->shape[0] != d) { + return STATUS_BAD_TENSOR_SHAPE; + } + + unsigned long int stride_y = y_desc->strides[0]; + unsigned long int stride_x = x_desc->strides[0]; + auto w_datatype = w_desc->dt; + *desc_ptr = new RMSNormMusaDescriptor{ + handle->device, + handle->device_id, + y_desc->dt, + n, + d, + stride_y, + stride_x, + w_datatype, + epsilon}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyRMSNormDescriptor(RMSNormMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/rms_norm/musa/rms_norm_musa.h b/src/ops/rms_norm/musa/rms_norm_musa.h new file mode 100644 index 00000000..292d5212 --- /dev/null +++ b/src/ops/rms_norm/musa/rms_norm_musa.h @@ -0,0 +1,40 @@ +#ifndef __MUSA_RMS_NORM_H__ +#define __MUSA_RMS_NORM_H__ + +#include "operators.h" +#include "../../../devices/musa/musa_handle.h" + +struct RMSNormMusaDescriptor { + Device device; + int device_id; + DT dtype; + unsigned long int n; + unsigned long int d; + unsigned long int stride_y; + unsigned long int stride_x; + DT w_datatype; + float epsilon; +}; + +typedef struct RMSNormMusaDescriptor *RMSNormMusaDescriptor_t; + +infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle, + RMSNormMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + float epsilon); + +infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc, + void *workspace, + unsigned long int workspace_size, + void *y, void const *x, void const *w, + void *stream); + +infiniopStatus_t musaDestroyRMSNormDescriptor(RMSNormMusaDescriptor_t desc); + +void rms_norm_mt_gpu_f16(RMSNormMusaDescriptor_t desc, void *y, void const *x, void const *w, float epsilon, void *stream); + +#endif// __MT_GPU_RMS_NORM_H__ diff --git a/src/ops/rms_norm/musa/rms_norm_musa.mu b/src/ops/rms_norm/musa/rms_norm_musa.mu new file mode 100644 index 00000000..c023b8b7 --- /dev/null +++ b/src/ops/rms_norm/musa/rms_norm_musa.mu @@ -0,0 +1,173 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "rms_norm_musa.h" +#include +#include + +// assert BLOCK_SIZE >= blockDim.x +template +static __global__ void rms_norm_padding( + Tdata *__restrict__ o_, + unsigned int const stride_y, + Tdata const *__restrict__ x_, + unsigned int const stride_x, + Wdata const *__restrict__ w_, + float const epsilon) { + auto y = o_ + blockIdx.x * stride_y + threadIdx.x; + auto x = x_[blockIdx.x * stride_x + threadIdx.x]; + auto w = w_[threadIdx.x]; + + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum()); + + __shared__ Tdata rms; + if (threadIdx.x == 0) { + rms = Tdata(rsqrtf(acc / float(blockDim.x) + epsilon)); + } + __syncthreads(); + + *y = rms * x * (Tdata)w; +} + +template +static __global__ void rms_norm_folding( + Tdata *__restrict__ y, + unsigned int const stride_y, + Tdata const *__restrict__ x, + unsigned int const stride_x, + Wdata const *__restrict__ w, + float const epsilon, + unsigned int const items_size) { + y += blockIdx.x * stride_y; + x += blockIdx.x * stride_x; + + float thread_data[ITEMS_PER_THREAD]; + { + using BlockOp = cub::BlockLoad; + __shared__ typename BlockOp::TempStorage temp_storage; + BlockOp(temp_storage).Load(x, thread_data, items_size, 0.f); + } + + float squared[ITEMS_PER_THREAD]; +#pragma unroll + for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) { + squared[i] = thread_data[i] * thread_data[i]; + } + + float acc; + { + using BlockOp = cub::BlockReduce; + __shared__ typename BlockOp::TempStorage temp_storage; + acc = BlockOp(temp_storage).Reduce(squared, cub::Sum()); + } + + __shared__ Tdata rms; + if (threadIdx.x == 0) { + rms = Tdata(rsqrtf(acc / float(items_size) + epsilon)); + } + __syncthreads(); + +#pragma unroll + for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) { + if (auto j = i + threadIdx.x * ITEMS_PER_THREAD; j < items_size) { + y[j] = Tdata(float(rms) * float(thread_data[i]) * float(w[j])); + } + } +} + +template +static __global__ void rms_norm_standard( + Tdata *__restrict__ y_, + unsigned int const stride_y, + Tdata const *__restrict__ x_, + unsigned int const stride_x, + Wdata const *__restrict__ w, + float const epsilon, + unsigned int const d) { + auto y = y_ + blockIdx.x * stride_y; + auto x = x_ + blockIdx.x * stride_x; + + __shared__ float partial_sum[BLOCK_SIZE]; + + float sum = 0.0f; + for (int i = threadIdx.x; i < d; i += BLOCK_SIZE) { + sum += float(x[i]) * float(x[i]); + } + + partial_sum[threadIdx.x] = sum; + __syncthreads(); + for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + partial_sum[threadIdx.x] += partial_sum[threadIdx.x + stride]; + } + __syncthreads(); + } + + __shared__ Tdata rms; + if (threadIdx.x == 0) { + float row_sum = partial_sum[0]; + rms = Tdata(rsqrtf(row_sum / float(d) + epsilon)); + } + __syncthreads(); + + for (int i = threadIdx.x; i < d; i += BLOCK_SIZE) { + y[i] = rms * x[i] * (Tdata)w[i]; + } +} + +void rms_norm_mt_gpu_f16(RMSNormMusaDescriptor_t desc, void *y, void const *x, void const *w, void *stream) { + auto n = desc->n, d = desc->d; + auto y_ = reinterpret_cast(y); + auto x_ = reinterpret_cast(x); + auto epsilon = desc->epsilon; + + // Get strides in terms of elements + auto stride_y = desc->stride_y; + auto stride_x = desc->stride_x; + + auto musa_stream = reinterpret_cast(stream); + unsigned int items_per_thread = ROUND_UP_DIV(d, MAX_THREADS_PER_BLOCK); + auto w_datatype = desc->w_datatype; + if (dtype_eq(w_datatype, F16)) { + auto w_ = reinterpret_cast(w); + if (items_per_thread == 1) { + rms_norm_padding + <<>>(y_, stride_y, x_, stride_x, w_, epsilon); + } else if (items_per_thread <= 16) { + rms_norm_folding + <<>>(y_, stride_y, x_, stride_x, w_, epsilon, d); + } else { + rms_norm_standard + <<>>(y_, stride_y, x_, stride_x, w_, epsilon, d); + } + } else { + auto w_ = reinterpret_cast(w); + if (items_per_thread == 1) { + rms_norm_padding + <<>>(y_, stride_y, x_, stride_x, w_, epsilon); + } else if (items_per_thread <= 16) { + rms_norm_folding + <<>>(y_, stride_y, x_, stride_x, w_, epsilon, d); + } else { + rms_norm_standard + <<>>(y_, stride_y, x_, stride_x, w_, epsilon, d); + } + } +} + +infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc, + void *workspace, + unsigned long int workspace_size, + void *y, void const *x, void const *w, + void *stream){ +// if(musaSetDevice(desc->device_id) != musaSuccess){ +// return STATUS_BAD_DEVICE; +// } + if (dtype_eq(desc->dtype, F16)){ + rms_norm_mt_gpu_f16(desc, y, x, w, stream); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc index dff9573b..b90adef7 100644 --- a/src/ops/rms_norm/operator.cc +++ b/src/ops/rms_norm/operator.cc @@ -20,6 +20,9 @@ #ifdef ENABLE_METAX_GPU #include "maca/rms_norm_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/rms_norm_musa.h" +#endif __C infiniopStatus_t infiniopCreateRMSNormDescriptor( infiniopHandle_t handle, @@ -57,6 +60,11 @@ __C infiniopStatus_t infiniopCreateRMSNormDescriptor( case DevMetaxGpu: { return macaCreateRMSNormDescriptor((MacaHandle_t) handle, (RMSNormMacaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateRMSNormDescriptor((MusaHandle_t) handle, (RMSNormMusaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon); + } #endif } return STATUS_BAD_DEVICE; @@ -89,6 +97,11 @@ __C infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t case DevMetaxGpu: { return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaGetRMSNormWorkspaceSize((RMSNormMusaDescriptor_t) desc, size); + } #endif } return STATUS_BAD_DEVICE; @@ -127,6 +140,11 @@ __C infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *wor case DevMetaxGpu: { return macaRMSNorm((RMSNormMacaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaRMSNorm((RMSNormMusaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -153,12 +171,16 @@ __C infiniopStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_ case DevAscendNpu: { return aclnnDestroyRMSNormDescriptor((RMSNormAclnnDescriptor_t) desc); } - #endif #ifdef ENABLE_METAX_GPU case DevMetaxGpu: { return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t) desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyRMSNormDescriptor((RMSNormMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From 8bd132fb34f3319e945bfedc064186914370d490 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 15:19:43 +0800 Subject: [PATCH 06/15] =?UTF-8?q?=E6=91=A9=E5=B0=94=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20Rope=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/rotary_embedding.py | 19 ++++- .../musa/rotary_embedding_musa.cc | 76 +++++++++++++++++++ .../musa/rotary_embedding_musa.h | 40 ++++++++++ .../musa/rotary_embedding_musa.mu | 68 +++++++++++++++++ src/ops/rotary_embedding/operator.cc | 23 ++++++ 5 files changed, 222 insertions(+), 4 deletions(-) create mode 100644 src/ops/rotary_embedding/musa/rotary_embedding_musa.cc create mode 100644 src/ops/rotary_embedding/musa/rotary_embedding_musa.h create mode 100644 src/ops/rotary_embedding/musa/rotary_embedding_musa.mu diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py index b7123052..de5b471a 100644 --- a/operatorspy/tests/rotary_embedding.py +++ b/operatorspy/tests/rotary_embedding.py @@ -77,7 +77,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): pos[2 * i] = posTmp[i] pos[2 * i + 1] = 0 theta = 1e4 - if torch_device == 'mlu' or torch_device == 'npu': + if torch_device == 'mlu' or torch_device == 'npu' or torch_device == 'musa': ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device) pos = pos.to(torch_device) t = t.to(torch_device) @@ -94,8 +94,9 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): # 2x table length for test sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta) t_tensor = to_tensor(t, lib) - pos_tensor = to_tensor(pos[: t.shape[0]], lib) - pos_tensor.descriptor.contents.dt = U64 + pos_tensor = to_tensor(pos, lib) + if(torch_device == 'mlu' or torch_device == 'musa'): + pos_tensor.descriptor.contents.dt = U64 sin_table_tensor = to_tensor(sin_table, lib) cos_table_tensor = to_tensor(cos_table, lib) @@ -181,6 +182,14 @@ def test_maca(lib, test_cases) : test(lib, handle, "maca", shape, strides, dtype) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for shape, strides, dtype in test_cases: + test(lib, handle, "musa", shape, strides, dtype) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ ((1, 32, 128), None, torch.float16), @@ -233,6 +242,8 @@ def test_maca(lib, test_cases) : test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc b/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc new file mode 100644 index 00000000..b5bdf33a --- /dev/null +++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc @@ -0,0 +1,76 @@ +#include "rotary_embedding_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" + +infiniopStatus_t musaCreateRoPEDescriptor(MusaHandle_t handle, + RoPEMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t t, + infiniopTensorDescriptor_t pos_ids, + infiniopTensorDescriptor_t sin_table, + infiniopTensorDescriptor_t cos_table) { + if (desc_ptr == nullptr) + return STATUS_MEMORY_NOT_ALLOCATED; + + if (t->ndim != 3 || + pos_ids->ndim != 1 || + sin_table->ndim != 2 || + cos_table->ndim != 2) + return STATUS_BAD_TENSOR_SHAPE; + + auto seq_len = t->shape[0]; + auto nhead = t->shape[1]; + auto dim = t->shape[2]; + auto total_seq_len = sin_table->shape[0]; + + if (dim % 2 != 0) + return STATUS_BAD_TENSOR_SHAPE; + + if (pos_ids->shape[0] != seq_len || + sin_table->shape[1] != dim || + cos_table->shape[1] != dim || + sin_table->shape[0] != cos_table->shape[0]) + return STATUS_BAD_TENSOR_SHAPE; + + // TODO: support larger dim in the future + if (dim / 2 > MAX_THREADS_PER_BLOCK) { + return STATUS_BAD_TENSOR_SHAPE; + } + + if (t->strides[2] != 1 || + pos_ids->strides[0] != 1 || + sin_table->strides[1] != 1 || + cos_table->strides[1] != 1) + return STATUS_BAD_TENSOR_STRIDES; + + if (!dtype_eq(t->dt, F16)) + return STATUS_BAD_TENSOR_DTYPE; + + if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32)) + return STATUS_BAD_TENSOR_DTYPE; + + if (!dtype_eq(pos_ids->dt, U64)) + return STATUS_BAD_TENSOR_DTYPE; + + *desc_ptr = new RoPEMusaDescriptor{ + handle->device, + handle->device_id, + t->dt, + seq_len, + nhead, + dim, + total_seq_len, + {t->strides[0], t->strides[1]}}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} + + +infiniopStatus_t musaDestroyRoPEDescriptor(RoPEMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.h b/src/ops/rotary_embedding/musa/rotary_embedding_musa.h new file mode 100644 index 00000000..7124a76f --- /dev/null +++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.h @@ -0,0 +1,40 @@ +#ifndef __MUSA_ROTARY_EMBEDDING_H__ +#define __MUSA_ROTARY_EMBEDDING_H__ + +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" + +struct RoPEMusaDescriptor { + Device device; + int device_id; + DT dtype; + uint64_t seq_len; + uint64_t nhead; + uint64_t dim; + uint64_t total_seq_len; + int64_t strides[2]; +}; + +typedef struct RoPEMusaDescriptor *RoPEMusaDescriptor_t; + +infiniopStatus_t musaCreateRoPEDescriptor(MusaHandle_t handle, + RoPEMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t t, + infiniopTensorDescriptor_t pos_ids, + infiniopTensorDescriptor_t sin_table, + infiniopTensorDescriptor_t cos_table); + +infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc, + void *workspace, + unsigned long int workspace_size, + void *t, + void const *pos_ids, + void const *sin_table, + void const *cos_table, + void *stream); + +infiniopStatus_t musaDestroyRoPEDescriptor(RoPEMusaDescriptor_t desc); + +#endif// __MT_GPU_ROTARY_EMBEDDING_H__ diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu b/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu new file mode 100644 index 00000000..56875482 --- /dev/null +++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu @@ -0,0 +1,68 @@ +#include "../../utils.h" +#include "rotary_embedding_musa.h" +#include + +static __global__ void padding_f16( + half *__restrict__ x_, + unsigned long const *__restrict__ pos_, + float const *__restrict__ sin_, + float const *__restrict__ cos_, + long const stride0, + long const stride1) { + auto dk = blockDim.x; + auto k = threadIdx.x; + auto offset = blockIdx.x * stride0 + blockIdx.y * stride1 + k * 2; + auto &x = reinterpret_cast(x_[offset]); + auto pos = pos_[blockIdx.x]; + auto sincos_offset = pos * dk * 2 + k * 2; + + float sin0 = sin_[sincos_offset], cos0 = cos_[sincos_offset], + sin1 = sin_[sincos_offset + 1], cos1 = cos_[sincos_offset + 1]; + float x0 = __half2float(x.x) * cos0 - __half2float(x.y) * sin0; + float x1 = __half2float(x.y) * cos1 + __half2float(x.x) * sin1; + x = half2(x0, x1); +} + + +void rotary_embedding_mt_gpu_f16( + RoPEMusaDescriptor_t desc, + half *t, + unsigned long const *pos, + float const *sin_, float const *cos_, + void *stream) { + auto nt = desc->seq_len, + nh = desc->nhead, + dh = desc->dim; + + // batching 2 half together + auto stride0 = desc->strides[0], + stride1 = desc->strides[1]; + + auto musa_stream = reinterpret_cast(stream); + padding_f16<<>>(t, pos, sin_, cos_, stride0, stride1); +} + +infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc, + void *workspace, + unsigned long int workspace_size, + void *t, + void const *pos_ids, + void const *sin_table, + void const *cos_table, + void *stream) { + if (t == nullptr || pos_ids == nullptr || sin_table == nullptr || cos_table == nullptr) + return STATUS_BAD_PARAM; + + if (dtype_eq(desc->dtype, F16)) { + rotary_embedding_mt_gpu_f16(desc, + reinterpret_cast(t), + reinterpret_cast(pos_ids), + reinterpret_cast(sin_table), + reinterpret_cast(cos_table), + stream); + } else { + return STATUS_BAD_TENSOR_DTYPE; + } + + return STATUS_SUCCESS; +} diff --git a/src/ops/rotary_embedding/operator.cc b/src/ops/rotary_embedding/operator.cc index 5c1d4aec..8f3707b2 100644 --- a/src/ops/rotary_embedding/operator.cc +++ b/src/ops/rotary_embedding/operator.cc @@ -18,6 +18,9 @@ #ifdef ENABLE_METAX_GPU #include "maca/rotary_embedding_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/rotary_embedding_musa.h" +#endif struct RoPEDescriptor { Device device; @@ -65,6 +68,11 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle, sin_table, cos_table); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateRoPEDescriptor((MusaHandle_t) handle, (RoPEMusaDescriptor_t *) desc_ptr, t, pos_ids, sin_table, cos_table); + } #endif } return STATUS_BAD_DEVICE; @@ -98,6 +106,11 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, return macaGetRoPEWorkspaceSize((RoPEMacaDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaGetRoPEWorkspaceSize((RoPEMusaDescriptor_t) desc, size); + } #endif } return STATUS_BAD_DEVICE; @@ -150,6 +163,11 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc, cos_table, stream); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaRoPE((RoPEMusaDescriptor_t) desc, workspace, workspace_size, t, pos_ids, sin_table, cos_table, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -181,6 +199,11 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc case DevMetaxGpu: { return macaDestroyRoPEDescriptor((RoPEMacaDescriptor_t) desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyRoPEDescriptor((RoPEMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From 4522473618305ee89dbfa484aa9800637de5e3a8 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 15:25:34 +0800 Subject: [PATCH 07/15] =?UTF-8?q?=E6=91=A9=E5=B0=94=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20swiglu=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/swiglu.py | 16 +++++++ src/ops/swiglu/musa/swiglu.mu | 68 ++++++++++++++++++++++++++++++ src/ops/swiglu/musa/swiglu_musa.cc | 50 ++++++++++++++++++++++ src/ops/swiglu/musa/swiglu_musa.h | 34 +++++++++++++++ src/ops/swiglu/operator.cc | 15 +++++++ 5 files changed, 183 insertions(+) create mode 100644 src/ops/swiglu/musa/swiglu.mu create mode 100644 src/ops/swiglu/musa/swiglu_musa.cc create mode 100644 src/ops/swiglu/musa/swiglu_musa.h diff --git a/operatorspy/tests/swiglu.py b/operatorspy/tests/swiglu.py index fcd044f1..9ca07c14 100644 --- a/operatorspy/tests/swiglu.py +++ b/operatorspy/tests/swiglu.py @@ -262,6 +262,20 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + + for shape, a_stride, b_stride, c_stride, dtype in test_cases: + test_out_of_place( + lib, handle, "musa", shape, a_stride, b_stride, c_stride, dtype + ) + test_in_place1(lib, handle, "musa", shape, a_stride, b_stride, dtype) + test_in_place2(lib, handle, "musa", shape, a_stride, b_stride, dtype) + + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -307,4 +321,6 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) + if args.musa: + test_musa(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/swiglu/musa/swiglu.mu b/src/ops/swiglu/musa/swiglu.mu new file mode 100644 index 00000000..259e5c6f --- /dev/null +++ b/src/ops/swiglu/musa/swiglu.mu @@ -0,0 +1,68 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "swiglu_musa.h" +#include + +static __forceinline__ __device__ float silu(float x) { + return x * fdividef(1, 1 + expf(-x)); +} + +inline int gcd(int a, int b) { + while (b != 0) { + int rem = a % b; + a = b; + b = rem; + } + return a; +} + +template +static __global__ void swiglu( + Tdata *__restrict__ c, + int const stride_c, + Tdata const *__restrict__ a, + int const stride_a, + Tdata const *__restrict__ b, + int const stride_b) { + auto i = blockIdx.y * stride_b + blockIdx.x * blockDim.x + threadIdx.x, + j = blockIdx.y * stride_a + blockIdx.x * blockDim.x + threadIdx.x, + k = blockIdx.y * stride_c + blockIdx.x * blockDim.x + threadIdx.x; + auto x = float(b[i]), + y = float(a[j]); + c[k] = Tdata(silu(x) * y); +} + +void swiglu_mt_gpu_f16(SwiGLUMusaDescriptor_t desc, void *c, void const *a, void const *b, void *stream) { + + auto seq_len = desc->seq_len, + di = desc->di; + + auto stride_a = desc->stride_a, + stride_b = desc->stride_b, + stride_c = desc->stride_c; + + dim3 block_dims = gcd(MAX_THREADS_PER_BLOCK, di); + dim3 grid_dims = dim3(di / block_dims.x, seq_len); + + auto a_ptr = reinterpret_cast(a); + auto b_ptr = reinterpret_cast(b); + auto c_ptr = reinterpret_cast(c); + + auto musa_stream = reinterpret_cast(stream); + + swiglu<<>>( + c_ptr, stride_c, a_ptr, stride_a, b_ptr, stride_b); +} + +infiniopStatus_t musaSwiGLU(SwiGLUMusaDescriptor_t desc, + void *c, + void const *a, + void const *b, + void *stream) { + if (dtype_eq(desc->dtype, F16)) { + swiglu_mt_gpu_f16(desc, c, a, b, stream); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/swiglu/musa/swiglu_musa.cc b/src/ops/swiglu/musa/swiglu_musa.cc new file mode 100644 index 00000000..88169be3 --- /dev/null +++ b/src/ops/swiglu/musa/swiglu_musa.cc @@ -0,0 +1,50 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "swiglu_musa.h" + +infiniopStatus_t musaCreateSwiGLUDescriptor(infiniopHandle_t handle, + SwiGLUMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + if (c_desc->ndim != 2 || a_desc->ndim != 2 || b_desc->ndim != 2) { + return STATUS_BAD_TENSOR_SHAPE; + } + + DT dtype = c_desc->dt; + + if (!dtype_eq(dtype, F16)) { + return STATUS_BAD_TENSOR_DTYPE; + } + + if (a_desc->strides[1] != 1 || b_desc->strides[1] != 1 || c_desc->strides[1] != 1) { + return STATUS_BAD_TENSOR_STRIDES; + } + + uint64_t seq_len = c_desc->shape[0], + di = c_desc->shape[1]; + + uint64_t stride_a = a_desc->strides[0], + stride_b = b_desc->strides[0], + stride_c = c_desc->strides[0]; + + + if (a_desc->shape[0] != seq_len || a_desc->shape[1] != di || !dtype_eq(a_desc->dt, dtype) || + b_desc->shape[0] != seq_len || b_desc->shape[1] != di || !dtype_eq(b_desc->dt, dtype)) { + return STATUS_BAD_PARAM; + } + + *desc_ptr = new SwiGLUMusaDescriptor{DevMtGpu, + dtype, + seq_len, + di, + stride_a, + stride_b, + stride_c}; + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroySwiGLUDescriptor(SwiGLUMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/swiglu/musa/swiglu_musa.h b/src/ops/swiglu/musa/swiglu_musa.h new file mode 100644 index 00000000..00ae1155 --- /dev/null +++ b/src/ops/swiglu/musa/swiglu_musa.h @@ -0,0 +1,34 @@ +#ifndef __MUSA_SWIGLU_H__ +#define __MUSA_SWIGLU_H__ + +#include "operators.h" + +struct SwiGLUMusaDescriptor { + Device device; + DT dtype; + uint64_t seq_len; + uint64_t di; + uint64_t stride_a; + uint64_t stride_b; + uint64_t stride_c; +}; + +typedef struct SwiGLUMusaDescriptor *SwiGLUMusaDescriptor_t; + +infiniopStatus_t musaCreateSwiGLUDescriptor(infiniopHandle_t handle, + SwiGLUMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_dec, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + +infiniopStatus_t musaSwiGLU(SwiGLUMusaDescriptor_t desc, + void *c, + void const *a, + void const *b, + void *stream); + +infiniopStatus_t musaDestroySwiGLUDescriptor(SwiGLUMusaDescriptor_t desc); + +void swiglu_mt_gpu_f16(SwiGLUMusaDescriptor_t desc, void *c, void const *a, void const *b, void *stream); + +#endif// __MT_GPU_SWIGLU_H__ diff --git a/src/ops/swiglu/operator.cc b/src/ops/swiglu/operator.cc index 3eb68a97..06699b0d 100644 --- a/src/ops/swiglu/operator.cc +++ b/src/ops/swiglu/operator.cc @@ -17,6 +17,9 @@ #ifdef ENABLE_METAX_GPU #include "maca/swiglu_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/swiglu_musa.h" +#endif __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle, infiniopSwiGLUDescriptor_t *desc_ptr, @@ -57,6 +60,10 @@ __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle, a_desc, b_desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaCreateSwiGLUDescriptor(handle, (SwiGLUMusaDescriptor_t *) desc_ptr, c_desc, a_desc, b_desc); #endif } return STATUS_BAD_DEVICE; @@ -88,6 +95,10 @@ __C infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, #ifdef ENABLE_METAX_GPU case DevMetaxGpu: return macaSwiGLU((SwiGLUMacaDescriptor_t) desc, c, a, b, stream); +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaSwiGLU((SwiGLUMusaDescriptor_t) desc, c, a, b, stream); #endif } return STATUS_BAD_DEVICE; @@ -115,6 +126,10 @@ __C infiniopStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t #ifdef ENABLE_METAX_GPU case DevMetaxGpu: return macaDestroySwiGLUDescriptor((SwiGLUMacaDescriptor_t) desc); +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaDestroySwiGLUDescriptor((SwiGLUMusaDescriptor_t) desc); #endif } return STATUS_BAD_DEVICE; From 329ca2152a274873b20c8ef195ca413d23a40774 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 15:37:56 +0800 Subject: [PATCH 08/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=91=A9=E5=B0=94?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=20random=20sample=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/random_sample.py | 13 +- .../random_sample/musa/random_sample_musa.cc | 37 ++++ .../random_sample/musa/random_sample_musa.h | 38 ++++ .../random_sample/musa/random_sample_musa.mu | 180 ++++++++++++++++++ src/ops/random_sample/operator.cc | 20 ++ 5 files changed, 286 insertions(+), 2 deletions(-) create mode 100644 src/ops/random_sample/musa/random_sample_musa.cc create mode 100644 src/ops/random_sample/musa/random_sample_musa.h create mode 100644 src/ops/random_sample/musa/random_sample_musa.mu diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index 4b0c2a10..2c464522 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -170,7 +170,7 @@ def test_ascend(lib, test_cases): for (voc, random_val, topp, topk, temperature) in test_cases: test(lib, handle, "npu", voc, random_val, topp, topk, temperature) destroy_handle(lib, handle) - + def test_maca(lib, test_cases): device = DeviceEnum.DEVICE_MACA handle = create_handle(lib, device) @@ -179,6 +179,13 @@ def test_maca(lib, test_cases): destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for (voc, random_val, topp, topk, temperature) in test_cases: + test(lib, handle, "musa", voc, random_val, topp, topk, temperature) + destroy_handle(lib, handle) if __name__ == "__main__": test_cases = [ @@ -236,6 +243,8 @@ def test_maca(lib, test_cases): test_ascend(lib, test_cases) if args.maca: test_maca(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/random_sample/musa/random_sample_musa.cc b/src/ops/random_sample/musa/random_sample_musa.cc new file mode 100644 index 00000000..29f676f9 --- /dev/null +++ b/src/ops/random_sample/musa/random_sample_musa.cc @@ -0,0 +1,37 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "random_sample_musa.h" + +infiniopStatus_t musaCreateRandomSampleDescriptor(MusaHandle_t handle, + RandomSampleMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs) { + if (probs->ndim != 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!dtype_eq(result->dt, U64)) + return STATUS_BAD_TENSOR_DTYPE; + int voc = probs->shape[0]; + int rLength = result->shape[0]; + if (result->ndim != 1 && rLength != 1) { + return STATUS_BAD_TENSOR_SHAPE; + } + *desc_ptr = new RandomSampleMusaDescriptor{ + handle->device, + handle->device_id, + probs->dt, + voc, + result->dt, + rLength}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, unsigned long int *size) { + *size = desc->voc * (2 * sizeof(uint64_t) + sizeof(desc->dtype)); + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyRandomSampleDescriptor(RandomSampleMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/random_sample/musa/random_sample_musa.h b/src/ops/random_sample/musa/random_sample_musa.h new file mode 100644 index 00000000..493cd3f4 --- /dev/null +++ b/src/ops/random_sample/musa/random_sample_musa.h @@ -0,0 +1,38 @@ +#ifndef __MUSA_RANDOM_SAMPLE_H__ +#define __MUSA_RANDOM_SAMPLE_H__ + +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" + +struct RandomSampleMusaDescriptor { + Device device; + int device_id; + DT dtype; + int voc; + DT rDtype; + int rLength; +}; + +typedef struct RandomSampleMusaDescriptor *RandomSampleMusaDescriptor_t; + +infiniopStatus_t musaCreateRandomSampleDescriptor(MusaHandle_t handle, + RandomSampleMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs); + +infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t musaRandomSample(RandomSampleMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *result, + void const *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream); + +infiniopStatus_t musaDestroyRandomSampleDescriptor(RandomSampleMusaDescriptor_t desc); + + +#endif diff --git a/src/ops/random_sample/musa/random_sample_musa.mu b/src/ops/random_sample/musa/random_sample_musa.mu new file mode 100644 index 00000000..c8000098 --- /dev/null +++ b/src/ops/random_sample/musa/random_sample_musa.mu @@ -0,0 +1,180 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "random_sample_musa.h" +#include +#include + +template +__global__ void softmax( + T *val_out, + int topk, + float temperature, int voc) { + float sum_s = 0.0f; + for (int i = threadIdx.x; i < topk; i += BLOCK_DIM) { + sum_s += __expf(static_cast(val_out[i] - val_out[0]) / temperature); + } + __shared__ float sum_inverse_total; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float block_sum = BlockReduce(temp_storage).Reduce(sum_s, cub::Sum()); + if (threadIdx.x == 0) { + sum_inverse_total = __fdividef(1.0F, block_sum);//高精度除法 + } + + __syncthreads(); + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < topk) { + val_out[tid] = static_cast(__expf(static_cast(val_out[tid] - val_out[0]) / temperature) * sum_inverse_total); + } +} + +__global__ void index(uint64_t *key_in, int voc) { + int ind = threadIdx.x + blockIdx.x * blockDim.x; + if (ind < voc) { + key_in[ind] = static_cast(ind); + } +} +template +__global__ void random_sample_kernel(uint64_t *result, + T *val_out, + float random_val, + float topp, + int topk, + uint64_t *key_out) { + int end = 0; + for (end = 0; end < topk; end++) { + if (val_out[end] >= static_cast(topp)) { + break; + } + } + if (end < topk - 1) { + end += 1; + } else { + end = topk; + } + + random_val *= static_cast(val_out[end - 1]); + for (int i = 0; i < end; i++) { + if (random_val < static_cast(val_out[i])) { + result[0] = key_out[i]; + break; + } + } +} +template +void sort_pairs_descending( + void *workspace, size_t &size_radix_sort, + T const *val_in, T *val_out, + I *key_in, I *key_out, + int voc, musaStream_t stream) { + cub::DeviceRadixSort::SortPairsDescending( + workspace, size_radix_sort, + val_in, val_out, + key_in, key_out, + voc, 0, sizeof(T) * 8, stream); +} +template +void inclusive_sum( + void *workspace, size_t &size_scan, + T *data, int voc, + musaStream_t stream) { + cub::DeviceScan::InclusiveSum( + workspace, size_scan, + data, data, voc, + stream); +} +template +void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan, + int voc, musaStream_t stream) { + + + sort_pairs_descending(nullptr, size_radix_sort, + nullptr, nullptr, + nullptr, nullptr, + voc, stream); + + inclusive_sum( + nullptr, size_scan, + nullptr, voc, + stream); +} +__global__ void random_sample_kernel(uint64_t *result, + uint64_t *key_out) { + result[0] = key_out[0]; +} +void random_sample_nv_gpu_f16(RandomSampleMusaDescriptor_t desc, void *workspace, void *result, + void const *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream) { + int voc = desc->voc; + //下面这段代码在排序 + char *origin = reinterpret_cast(workspace); + char *keyTmp = origin + voc * sizeof(half); + half *val_out = (half *) origin; + + uint64_t *key_in = (uint64_t *) keyTmp; + uint64_t *key_out = key_in + voc; + + index<<<(voc + 1023) / 1024, 1024, 0, (musaStream_t) stream>>>(key_in, voc); + //下面开始计算workspace空间 + size_t size_radix_sort; + size_t size_scan; + random_sample_workspace(size_radix_sort, size_scan, + voc, (musaStream_t) stream); + void *workspace_extra; + musaMalloc(&workspace_extra, size_radix_sort + size_scan); + sort_pairs_descending( + workspace_extra, size_radix_sort, + (half *) probs, val_out, + key_in, key_out, + voc, (musaStream_t) stream);//该函数会把排序结果和对应索引保存在val_out和key_out上 + //排序结束,然后开始做softmax变换 + if (topp > 0 && topk > 1) { + int BLOCK_DIM = 1024; + int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM; + softmax<<>>(val_out, topk, + temperature, voc); + + + inclusive_sum( + workspace_extra, size_scan, + val_out, voc, + (musaStream_t) stream);//该函数会实现scan功能不断累加结果 + random_sample_kernel<<<1, 1, 0, (musaStream_t) stream>>>((uint64_t *) result, + val_out, + random_val, + topp, + topk, + key_out); + + } else { + random_sample_kernel<<<1, 1, 0, (musaStream_t) stream>>>((uint64_t *) result, + key_out); + } + musaFree(workspace_extra); +} + +infiniopStatus_t musaRandomSample(RandomSampleMusaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *result, + void const *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream) { +// if (musaSetDevice(desc->device_id) != musaSuccess) { +// return STATUS_BAD_DEVICE; +// } + if (dtype_eq(desc->dtype, F16)) { + random_sample_nv_gpu_f16(desc, workspace, result, probs, random_val, topp, topk, temperature, stream); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/random_sample/operator.cc b/src/ops/random_sample/operator.cc index b9cf3ded..f335b14f 100644 --- a/src/ops/random_sample/operator.cc +++ b/src/ops/random_sample/operator.cc @@ -17,6 +17,9 @@ #ifdef ENABLE_METAX_GPU #include "maca/random_sample_maca.h" #endif +#ifdef ENABLE_MT_GPU +#include "musa/random_sample_musa.h" +#endif __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle, infiniopRandomSampleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs) { switch (handle->device) { @@ -47,6 +50,10 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handl (RandomSampleMacaDescriptor_t *) desc_ptr, result, probs); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaCreateRandomSampleDescriptor((MusaHandle_t) handle, (RandomSampleMusaDescriptor_t *) desc_ptr, result, probs); #endif } return STATUS_BAD_DEVICE; @@ -79,6 +86,11 @@ __C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDe case DevMetaxGpu: { return macaGetRandomSampleWorkspaceSize((RandomSampleMacaDescriptor_t) desc, size); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaGetRandomSampleWorkspaceSize((RandomSampleMusaDescriptor_t) desc, size); + } #endif } return STATUS_BAD_DEVICE; @@ -117,6 +129,10 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc, case DevMetaxGpu: { return macaRandomSample((RandomSampleMacaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaRandomSample((RandomSampleMusaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream); #endif } return STATUS_BAD_DEVICE; @@ -146,6 +162,10 @@ __C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleD case DevMetaxGpu: { return macaDestroyRandomSampleDescriptor((RandomSampleMacaDescriptor_t) desc); } +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: + return musaDestroyRandomSampleDescriptor((RandomSampleMusaDescriptor_t) desc); #endif } return STATUS_BAD_DEVICE; From 19565fbdfc54423594e7fa974d896494b3787a52 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 29 Nov 2024 17:15:08 +0800 Subject: [PATCH 09/15] =?UTF-8?q?=E6=91=A9=E5=B0=94=EF=BC=9Asetdevice?= =?UTF-8?q?=E4=B9=8B=E5=89=8D=E8=BF=9B=E8=A1=8C=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/rms_norm.py | 2 +- src/devices/musa/musa_handle.cc | 10 +++++++--- src/devices/musa/musa_handle.h | 7 ++++++- src/ops/causal_softmax/musa/causal_softmax_musa.mu | 10 +++++++--- src/ops/random_sample/musa/random_sample_musa.mu | 10 +++++++--- src/ops/rearrange/musa/rearrange_musa.mu | 10 +++++++--- src/ops/rms_norm/musa/rms_norm_musa.mu | 10 +++++++--- 7 files changed, 42 insertions(+), 17 deletions(-) diff --git a/operatorspy/tests/rms_norm.py b/operatorspy/tests/rms_norm.py index a11b794f..46b1d0f3 100644 --- a/operatorspy/tests/rms_norm.py +++ b/operatorspy/tests/rms_norm.py @@ -184,6 +184,6 @@ def test_musa(lib, test_cases): test_maca(lib, test_cases) if args.musa: test_musa(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc index 00f43e9d..bc40560a 100644 --- a/src/devices/musa/musa_handle.cc +++ b/src/devices/musa/musa_handle.cc @@ -8,9 +8,13 @@ infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) { return STATUS_BAD_DEVICE; } - // if (musaSetDevice(device_id) != musaSuccess){ - // return STATUS_BAD_DEVICE; - // } + int current_device; + if (musaGetDevice(¤t_device) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (current_device != device_id && musaSetDevice(device_id) != musaSuccess) { + return STATUS_BAD_DEVICE; + } auto mublas_pool = std::make_shared>(); mublasHandle_t *mublas_handle = new mublasHandle_t; diff --git a/src/devices/musa/musa_handle.h b/src/devices/musa/musa_handle.h index f91caba8..9c1842ee 100644 --- a/src/devices/musa/musa_handle.h +++ b/src/devices/musa/musa_handle.h @@ -7,6 +7,7 @@ #include "ops/matmul/matmul.h" #include #include +#include #include #include @@ -25,7 +26,11 @@ template void use_mublas(std::shared_ptr> mublas_handles_t, int device_id, MUstream stream, T const &f) { mublasHandle_t *handle = mublas_handles_t->pop(); if (!handle) { - // musaSetDevice(device_id); + int current_device; + musaGetDevice(¤t_device); + if (current_device != device_id) { + musaSetDevice(device_id); + } mublasHandle_t *handle = new mublasHandle_t; mublasCreate(handle); } diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.mu b/src/ops/causal_softmax/musa/causal_softmax_musa.mu index 3bb92ad4..8957134b 100644 --- a/src/ops/causal_softmax/musa/causal_softmax_musa.mu +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.mu @@ -246,9 +246,13 @@ infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, uint64_t workspace_size, void *data, void *stream) { -// if(musaSetDevice(desc->device_id) != musaSuccess){ -// return STATUS_BAD_DEVICE; -// } + int current_device; + if (musaGetDevice(¤t_device) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) { + return STATUS_BAD_DEVICE; + } if (dtype_eq(desc->dtype, F16)) { causal_softmax_mt_gpu_f16(desc, data, stream); return STATUS_SUCCESS; diff --git a/src/ops/random_sample/musa/random_sample_musa.mu b/src/ops/random_sample/musa/random_sample_musa.mu index c8000098..55dbdd0a 100644 --- a/src/ops/random_sample/musa/random_sample_musa.mu +++ b/src/ops/random_sample/musa/random_sample_musa.mu @@ -168,9 +168,13 @@ infiniopStatus_t musaRandomSample(RandomSampleMusaDescriptor_t desc, int topk, float temperature, void *stream) { -// if (musaSetDevice(desc->device_id) != musaSuccess) { -// return STATUS_BAD_DEVICE; -// } + int current_device; + if (musaGetDevice(¤t_device) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) { + return STATUS_BAD_DEVICE; + } if (dtype_eq(desc->dtype, F16)) { random_sample_nv_gpu_f16(desc, workspace, result, probs, random_val, topp, topk, temperature, stream); return STATUS_SUCCESS; diff --git a/src/ops/rearrange/musa/rearrange_musa.mu b/src/ops/rearrange/musa/rearrange_musa.mu index ee094869..77489add 100644 --- a/src/ops/rearrange/musa/rearrange_musa.mu +++ b/src/ops/rearrange/musa/rearrange_musa.mu @@ -61,9 +61,13 @@ void rearrange_mt_gpu(RearrangeMusaDescriptor_t desc, void *y, void const *x, vo } infiniopStatus_t musaRearrange(RearrangeMusaDescriptor_t desc, void *dst, void const *src, void *stream) { -// if(musaSetDevice(desc->device_id) != musaSuccess){ -// return STATUS_BAD_DEVICE; -// } + int current_device; + if (musaGetDevice(¤t_device) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) { + return STATUS_BAD_DEVICE; + } rearrange_mt_gpu(desc, dst, src, stream); return STATUS_SUCCESS; } diff --git a/src/ops/rms_norm/musa/rms_norm_musa.mu b/src/ops/rms_norm/musa/rms_norm_musa.mu index c023b8b7..0b1837ad 100644 --- a/src/ops/rms_norm/musa/rms_norm_musa.mu +++ b/src/ops/rms_norm/musa/rms_norm_musa.mu @@ -161,9 +161,13 @@ infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc, unsigned long int workspace_size, void *y, void const *x, void const *w, void *stream){ -// if(musaSetDevice(desc->device_id) != musaSuccess){ -// return STATUS_BAD_DEVICE; -// } + int current_device; + if (musaGetDevice(¤t_device) != musaSuccess) { + return STATUS_BAD_DEVICE; + } + if (current_device != desc->device_id && musaSetDevice(desc->device_id) != musaSuccess) { + return STATUS_BAD_DEVICE; + } if (dtype_eq(desc->dtype, F16)){ rms_norm_mt_gpu_f16(desc, y, x, w, stream); return STATUS_SUCCESS; From 37c4f545b7ee33d922ce36a0a0857aee789087d2 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 5 Dec 2024 15:41:51 +0800 Subject: [PATCH 10/15] =?UTF-8?q?=E6=91=A9=E5=B0=94=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20Add=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/add.py | 14 +++- src/devices/musa/common_musa.h | 45 ++++++++++++- src/devices/musa/utils.cc | 17 ----- src/ops/add/musa/add_musa.cc | 81 +++++++++++++++++++++++ src/ops/add/musa/add_musa.h | 37 +++++++++++ src/ops/add/musa/add_musa.mu | 116 +++++++++++++++++++++++++++++++++ src/ops/add/operator.cc | 18 +++++ 7 files changed, 307 insertions(+), 21 deletions(-) delete mode 100644 src/devices/musa/utils.cc create mode 100644 src/ops/add/musa/add_musa.cc create mode 100644 src/ops/add/musa/add_musa.h create mode 100644 src/ops/add/musa/add_musa.mu diff --git a/operatorspy/tests/add.py b/operatorspy/tests/add.py index 455014cc..da9c58c9 100644 --- a/operatorspy/tests/add.py +++ b/operatorspy/tests/add.py @@ -115,6 +115,16 @@ def test_bang(lib, test_cases): test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for c_shape, a_shape, b_shape, inplace in test_cases: + test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace) + test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -163,6 +173,8 @@ def test_bang(lib, test_cases): test_cuda(lib, test_cases) if args.bang: test_bang(lib, test_cases) - if not (args.cpu or args.cuda or args.bang): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/devices/musa/common_musa.h b/src/devices/musa/common_musa.h index bfed9900..02d97330 100644 --- a/src/devices/musa/common_musa.h +++ b/src/devices/musa/common_musa.h @@ -1,6 +1,16 @@ #ifndef __COMMON_MUSA_H__ #define __COMMON_MUSA_H__ +#define MAX_THREADS_PER_BLOCK 1024 +#define MAX_WARP_PER_BLOCK 32 +#define WARP_SIZE 32 + +#include +#include "data_type.h" +#include +#include +#include + enum class Type { QINT4, QINT8, @@ -31,8 +41,37 @@ enum class Format { DHWCN, }; -#define MAX_THREADS_PER_BLOCK 1024 -#define MAX_WARP_PER_BLOCK 32 -#define WARP_SIZE 32 +#define checkMusaErrorWithCode(call, errorCode) \ + do { \ + if (auto status = call; status != musaSuccess) { \ + std::cerr << "MUSA error: " << musaGetErrorString(status) \ + << " in file " << __FILE__ \ + << ", function " << __func__ \ + << ", line " << __LINE__ << std::endl; \ + return errorCode; \ + } \ + } while (0) + +#define checkMusaError(call) checkMusaErrorWithCode(call, STATUS_BAD_DEVICE) + +// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast) +inline __device__ uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) { + uint64_t res = 0; + for (uint64_t i = 0; i < ndim; ++i) { + res += flat_index / src_strides[i] * dst_strides[i]; + flat_index %= src_strides[i]; + } + return res; +} + +// get the memory offset of the given element in a tensor given its flat index +inline __device__ uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) { + uint64_t res = 0; + for (long i = ndim - 1; i >= 0; --i) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} #endif // __COMMON_MUSA_H__ \ No newline at end of file diff --git a/src/devices/musa/utils.cc b/src/devices/musa/utils.cc deleted file mode 100644 index 466fcf7d..00000000 --- a/src/devices/musa/utils.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "data_type.h" - -DT get_F16() { - return F16; -} - -DT get_F32() { - return F32; -} - -DT get_U32() { - return U32; -} - -DT get_U64() { - return U64; -} \ No newline at end of file diff --git a/src/ops/add/musa/add_musa.cc b/src/ops/add/musa/add_musa.cc new file mode 100644 index 00000000..21fbbdd1 --- /dev/null +++ b/src/ops/add/musa/add_musa.cc @@ -0,0 +1,81 @@ +#include "add_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" + +infiniopStatus_t musaCreateAddDescriptor(MusaHandle_t handle, + AddMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t a, + infiniopTensorDescriptor_t b) { + uint64_t ndim = c->ndim; + if (!isValidBroadcastShape(a, b, c)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (!is_contiguous(a) || !is_contiguous(b) || !is_contiguous(c)) { + return STATUS_BAD_TENSOR_STRIDES; + } + if (c->dt != F16 && c->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (c->dt != a->dt || c->dt != b->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + bool broadcasted = false; + if (ndim != a->ndim || ndim != b->ndim) { + broadcasted = true; + } else { + for (uint64_t i = 0; i < ndim; ++i) { + if (c->shape[i] != a->shape[i] || c->shape[i] != b->shape[i]) { + broadcasted = true; + break; + } + } + } + + uint64_t c_data_size = std::accumulate(c->shape, c->shape + c->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for a and b + int64_t *a_strides = new int64_t[ndim]; + int64_t *b_strides = new int64_t[ndim]; + for (size_t i = 0; i < ndim; ++i) { + a_strides[i] = (i < ndim - a->ndim || c->shape[i] != a->shape[i + a->ndim - ndim]) ? 0 : a->strides[i + a->ndim - ndim]; + b_strides[i] = (i < ndim - b->ndim || c->shape[i] != b->shape[i + b->ndim - ndim]) ? 0 : b->strides[i + b->ndim - ndim]; + } + + musaDeviceProp prop; + musaGetDeviceProperties(&prop, handle->device_id); + + int64_t *a_strides_d, *b_strides_d, *c_strides_d; + checkMusaErrorWithCode(musaMalloc(&a_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMalloc(&b_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMalloc(&c_strides_d, ndim * sizeof(int64_t)), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMemcpy(a_strides_d, a_strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(b_strides_d, b_strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(c_strides_d, c->strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new AddMusaDescriptor{ + DevMtGpu, + c->dt, + handle->device_id, + ndim, + c_data_size, + static_cast(prop.maxGridSize[0]), + a_strides_d, + b_strides_d, + c_strides_d, + broadcasted, + }; + + delete[] a_strides; + delete[] b_strides; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyAddDescriptor(AddMusaDescriptor_t desc) { + checkMusaErrorWithCode(musaFree((void *) desc->a_strides), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaFree((void *) desc->b_strides), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaFree((void *) desc->c_strides), STATUS_EXECUTION_FAILED); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/add/musa/add_musa.h b/src/ops/add/musa/add_musa.h new file mode 100644 index 00000000..c492c45c --- /dev/null +++ b/src/ops/add/musa/add_musa.h @@ -0,0 +1,37 @@ +#ifndef __MUSA_ADD_H__ +#define __MUSA_ADD_H__ + +#include "../../../devices/musa/common_musa.h" +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" +#include +#include + +struct AddMusaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t c_data_size; + uint64_t max_grid_size; + int64_t const *a_strides; + int64_t const *b_strides; + int64_t const *c_strides; + bool broadcasted; +}; + +typedef struct AddMusaDescriptor *AddMusaDescriptor_t; + +infiniopStatus_t musaCreateAddDescriptor(MusaHandle_t, + AddMusaDescriptor_t *, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t a, + infiniopTensorDescriptor_t b); + +infiniopStatus_t musaAdd(AddMusaDescriptor_t desc, + void *c, void const *a, void const *b, + void *stream); + +infiniopStatus_t musaDestroyAddDescriptor(AddMusaDescriptor_t desc); + +#endif diff --git a/src/ops/add/musa/add_musa.mu b/src/ops/add/musa/add_musa.mu new file mode 100644 index 00000000..0766aa7c --- /dev/null +++ b/src/ops/add/musa/add_musa.mu @@ -0,0 +1,116 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "add_musa.h" + +/** + * @brief A templated vector struct that supports element-wise addition on arrays. + * + * @tparam T - The access data type for elements in the vector. + * @tparam TComp - The computation data type used for arithmetic operations. + * @tparam N - The number of elements of type T in the vector for a single access. + */ +template +struct vecN { + T data[N]; + + __device__ __forceinline__ vecN operator+(const vecN &other) const { + vecN result; + + for (int i = 0; i < N; ++i) { + if constexpr (std::is_same::value) { + result.data[i] = data[i] + other.data[i]; + } else { + constexpr static size_t pack_size = sizeof(T) / sizeof(TComp); + auto data_ = reinterpret_cast *>(result.data); + data_[i] = std::move(reinterpret_cast const *>(data)[i] + + reinterpret_cast const *>(other.data)[i]); + } + } + + return result; + } + + __device__ __forceinline__ const T &operator[](size_t i) const { + return data[i]; + } +}; + +template +__global__ void add( + Tdata *c, + const Tdata *a, + const Tdata *b, + const int64_t *a_strides, + const int64_t *b_strides, + const int64_t *c_strides, + uint64_t data_size, + uint64_t ndim, + uint64_t offset, + bool broadcasted, + unsigned pack_size) { + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < data_size) { + if (broadcasted) { + idx *= pack_size; + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); +#pragma unroll + for (size_t i = 0; i < pack_size; ++i) { + auto a_idx = getDstOffset(idx + i, ndim, c_strides, a_strides); + auto b_idx = getDstOffset(idx + i, ndim, c_strides, b_strides); + c_[idx + i] = a_[a_idx] + b_[b_idx]; + } + return; + } + c[idx] = a[idx] + b[idx]; + } +} + +template +void _add_nv_gpu(AddMusaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) { + if (data_size == 0) { + return; + } + dim3 blockDims = dim3(std::min(static_cast(256), data_size)); + dim3 gridDims = dim3(std::min(ROUND_UP_DIV(data_size, blockDims.x), desc->max_grid_size)); + uint64_t step = gridDims.x * blockDims.x; + + musaStream_t musa_stream = reinterpret_cast(stream); + +#pragma unroll + for (uint64_t i = 0; i < data_size; i += step) { + add<<>>( + c, a, b, desc->a_strides, desc->b_strides, desc->c_strides, offset + data_size, desc->ndim, offset + i, desc->broadcasted, pack_size); + } +} + +template +infiniopStatus_t add_mt_gpu(AddMusaDescriptor_t desc, void *c, void const *a, void const *b, void *stream, uint64_t pack_size) { + const auto data_size = desc->c_data_size / pack_size; + const auto a_vec = reinterpret_cast(a); + const auto b_vec = reinterpret_cast(b); + const auto c_vec = reinterpret_cast(c); + _add_nv_gpu(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream); + + const auto remainder = desc->c_data_size % pack_size; + const auto a_ = reinterpret_cast(a); + const auto b_ = reinterpret_cast(b); + const auto c_ = reinterpret_cast(c); + _add_nv_gpu(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream); + return STATUS_SUCCESS; +} + +infiniopStatus_t musaAdd(AddMusaDescriptor_t desc, + void *c, void const *a, void const *b, + void *stream) { + checkMusaError(musaSetDevice(desc->device_id)); + if (desc->dtype == F16) { + return add_mt_gpu, half>(desc, c, a, b, stream, 8); + } + if (desc->dtype == F32) { + return add_mt_gpu, float>(desc, c, a, b, stream, 4); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/add/operator.cc b/src/ops/add/operator.cc index c2a30ea8..9d090243 100644 --- a/src/ops/add/operator.cc +++ b/src/ops/add/operator.cc @@ -9,6 +9,9 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/add.cuh" #endif +#ifdef ENABLE_MT_GPU +#include "musa/add_musa.h" +#endif __C infiniopStatus_t infiniopCreateAddDescriptor( infiniopHandle_t handle, @@ -29,6 +32,11 @@ __C infiniopStatus_t infiniopCreateAddDescriptor( #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateAddDescriptor((MusaHandle_t) handle, (AddMusaDescriptor_t *) desc_ptr, c, a, b); + } #endif } return STATUS_BAD_DEVICE; @@ -48,6 +56,11 @@ __C infiniopStatus_t infiniopAdd(infiniopAddDescriptor_t desc, void *c, void con #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaAdd((AddMusaDescriptor_t) desc, c, a, b, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -67,6 +80,11 @@ __C infiniopStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyAddDescriptor((AddMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From 251ca48597e39805e5f7bd913aaec80c28d1f7e8 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 5 Dec 2024 16:05:10 +0800 Subject: [PATCH 11/15] =?UTF-8?q?=E6=91=A9=E5=B0=94=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20expand=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/expand.py | 14 +++++++- src/devices/musa/musa_handle.cc | 7 +++- src/devices/musa/musa_handle.h | 1 + src/ops/expand/musa/expand_musa.cc | 51 ++++++++++++++++++++++++++ src/ops/expand/musa/expand_musa.h | 33 +++++++++++++++++ src/ops/expand/musa/expand_musa.mu | 58 ++++++++++++++++++++++++++++++ src/ops/expand/operator.cc | 19 ++++++++++ 7 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 src/ops/expand/musa/expand_musa.cc create mode 100644 src/ops/expand/musa/expand_musa.h create mode 100644 src/ops/expand/musa/expand_musa.mu diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py index e060ad73..87365c05 100644 --- a/operatorspy/tests/expand.py +++ b/operatorspy/tests/expand.py @@ -133,6 +133,16 @@ def test_bang(lib, test_cases): test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for y_shape, x_shape, y_stride, x_stride in test_cases: + test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16) + test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -174,6 +184,8 @@ def test_bang(lib, test_cases): test_cuda(lib, test_cases) if args.bang: test_bang(lib, test_cases) - if not (args.cpu or args.cuda or args.bang): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc index bc40560a..cd242114 100644 --- a/src/devices/musa/musa_handle.cc +++ b/src/devices/musa/musa_handle.cc @@ -16,12 +16,17 @@ infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) { return STATUS_BAD_DEVICE; } + // set CUDA device property + musaDeviceProp prop; + musaGetDeviceProperties(&prop, device_id); + + auto mublas_pool = std::make_shared>(); mublasHandle_t *mublas_handle = new mublasHandle_t; mublasCreate(mublas_handle); mublas_pool->push(mublas_handle); - *handle_ptr = new MusaContext{DevMtGpu, device_id, std::move(mublas_pool)}; + *handle_ptr = new MusaContext{DevMtGpu, device_id, std::move(mublas_pool), std::move(prop)}; return STATUS_SUCCESS; } diff --git a/src/devices/musa/musa_handle.h b/src/devices/musa/musa_handle.h index 9c1842ee..fed050d8 100644 --- a/src/devices/musa/musa_handle.h +++ b/src/devices/musa/musa_handle.h @@ -15,6 +15,7 @@ struct MusaContext { Device device; int device_id; std::shared_ptr> mublas_handles_t; + musaDeviceProp prop; }; typedef struct MusaContext *MusaHandle_t; diff --git a/src/ops/expand/musa/expand_musa.cc b/src/ops/expand/musa/expand_musa.cc new file mode 100644 index 00000000..02980d71 --- /dev/null +++ b/src/ops/expand/musa/expand_musa.cc @@ -0,0 +1,51 @@ +#include "expand_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" + +infiniopStatus_t musaCreateExpandDescriptor(MusaHandle_t handle, + ExpandMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (!isValidBroadcastShape(y, x)) { + return STATUS_BAD_TENSOR_SHAPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t y_data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + // get the adjusted strides for x in terms of y + int64_t *x_strides = new int64_t[ndim]; + for (size_t i = 0; i < ndim; ++i) { + x_strides[i] = (i < ndim - x->ndim || y->shape[i] != x->shape[i + x->ndim - ndim]) ? 0 : x->strides[i + x->ndim - ndim]; + } + + int64_t *x_strides_d, *y_strides_d; + char *strides_and_shape_d; + checkMusaErrorWithCode(musaMalloc(&strides_and_shape_d, ndim * (2 * sizeof(int64_t) + sizeof(uint64_t))), STATUS_MEMORY_NOT_ALLOCATED); + checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d, x_strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d + ndim * sizeof(int64_t), y->strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d + 2 * ndim * sizeof(int64_t), y->shape, ndim * sizeof(uint64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); + + *desc_ptr = new ExpandMusaDescriptor{ + DevMtGpu, + y->dt, + handle->device_id, + ndim, + y_data_size, + static_cast(handle->prop.maxGridSize[0]), + strides_and_shape_d, + }; + + delete[] x_strides; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyExpandDescriptor(ExpandMusaDescriptor_t desc) { + checkMusaErrorWithCode(musaFree((void *) desc->strides_and_shape_d), STATUS_EXECUTION_FAILED); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/expand/musa/expand_musa.h b/src/ops/expand/musa/expand_musa.h new file mode 100644 index 00000000..8e4651e1 --- /dev/null +++ b/src/ops/expand/musa/expand_musa.h @@ -0,0 +1,33 @@ +#ifndef __MUSA_EXPAND_H__ +#define __MUSA_EXPAND_H__ + +#include "../../../devices/musa/common_musa.h" +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" +#include +#include + +struct ExpandMusaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t y_data_size; + uint64_t max_grid_size; + char const *strides_and_shape_d; +}; + +typedef struct ExpandMusaDescriptor *ExpandMusaDescriptor_t; + +infiniopStatus_t musaCreateExpandDescriptor(MusaHandle_t, + ExpandMusaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t musaExpand(ExpandMusaDescriptor_t desc, + void *y, void const *x, + void *stream); + +infiniopStatus_t musaDestroyExpandDescriptor(ExpandMusaDescriptor_t desc); + +#endif diff --git a/src/ops/expand/musa/expand_musa.mu b/src/ops/expand/musa/expand_musa.mu new file mode 100644 index 00000000..4b549541 --- /dev/null +++ b/src/ops/expand/musa/expand_musa.mu @@ -0,0 +1,58 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "expand_musa.h" + +template +__global__ void expand( + Tdata *y, + const Tdata *x, + const int64_t *y_strides, + const int64_t *x_strides, + const uint64_t *y_shape, + uint64_t y_data_size, + uint64_t ndim, + uint64_t offset) { + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < y_data_size) { + uint64_t y_idx = getOffset(idx, ndim, y_shape, y_strides); + y[y_idx] = x[getDstOffset(y_idx, ndim, y_strides, x_strides)]; + } +} + +template +infiniopStatus_t expand_mt_gpu(ExpandMusaDescriptor_t desc, void *y, void const *x, void *stream) { + if (desc->y_data_size == 0) { + return STATUS_SUCCESS; + } + dim3 blockDims = dim3(std::min(static_cast(256), desc->y_data_size)); + dim3 gridDims = dim3(std::min(ROUND_UP_DIV(desc->y_data_size, blockDims.x), desc->max_grid_size)); + uint64_t step = gridDims.x * blockDims.x; + + const auto x_ = reinterpret_cast(x); + const auto y_ = reinterpret_cast(y); + const auto x_strides = reinterpret_cast(desc->strides_and_shape_d); + const auto y_strides = reinterpret_cast(desc->strides_and_shape_d + desc->ndim * sizeof(int64_t)); + const auto y_shape = reinterpret_cast(desc->strides_and_shape_d + 2 * desc->ndim * sizeof(int64_t)); + musaStream_t musa_stream = reinterpret_cast(stream); + +#pragma unroll + for (uint64_t i = 0; i < desc->y_data_size; i += step) { + expand<<>>( + y_, x_, y_strides, x_strides, y_shape, i + desc->y_data_size, desc->ndim, i); + } + return STATUS_SUCCESS; +} + +infiniopStatus_t musaExpand(ExpandMusaDescriptor_t desc, + void *y, void const *x, + void *stream) { + checkMusaError(musaSetDevice(desc->device_id)); + if (desc->dtype == F16) { + return expand_mt_gpu(desc, y, x, stream); + } + if (desc->dtype == F32) { + return expand_mt_gpu(desc, y, x, stream); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/expand/operator.cc b/src/ops/expand/operator.cc index 0572acd0..f5852e46 100644 --- a/src/ops/expand/operator.cc +++ b/src/ops/expand/operator.cc @@ -9,6 +9,10 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/expand.cuh" #endif +#ifdef ENABLE_MT_GPU +#include "musa/expand_musa.h" +#endif + __C infiniopStatus_t infiniopCreateExpandDescriptor( infiniopHandle_t handle, @@ -28,6 +32,11 @@ __C infiniopStatus_t infiniopCreateExpandDescriptor( #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateExpandDescriptor((MusaHandle_t) handle, (ExpandMusaDescriptor_t *) desc_ptr, y, x); + } #endif } return STATUS_BAD_DEVICE; @@ -47,6 +56,11 @@ __C infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, void *y, vo #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaExpand((ExpandMusaDescriptor_t) desc, y, x, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -66,6 +80,11 @@ __C infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyExpandDescriptor((ExpandMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From 1c142a8b31af35b2990ca46361bf708796d6420c Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 5 Dec 2024 16:30:56 +0800 Subject: [PATCH 12/15] =?UTF-8?q?=E6=91=A9=E5=B0=94=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20relu=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- operatorspy/tests/relu.py | 14 ++++- src/ops/relu/musa/relu_musa.cc | 45 +++++++++++++ src/ops/relu/musa/relu_musa.h | 32 ++++++++++ src/ops/relu/musa/relu_musa.mu | 111 +++++++++++++++++++++++++++++++++ src/ops/relu/operator.cc | 19 ++++++ 5 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 src/ops/relu/musa/relu_musa.cc create mode 100644 src/ops/relu/musa/relu_musa.h create mode 100644 src/ops/relu/musa/relu_musa.mu diff --git a/operatorspy/tests/relu.py b/operatorspy/tests/relu.py index b7f76627..b99706ff 100644 --- a/operatorspy/tests/relu.py +++ b/operatorspy/tests/relu.py @@ -132,6 +132,16 @@ def test_bang(lib, test_cases): test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) destroy_handle(lib, handle) +def test_musa(lib, test_cases): + import torch_musa + + device = DeviceEnum.DEVICE_MUSA + handle = create_handle(lib, device) + for tensor_shape, inplace in test_cases: + test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float16, inplace=inplace) + test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float32, inplace=inplace) + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ @@ -172,6 +182,8 @@ def test_bang(lib, test_cases): test_cuda(lib, test_cases) if args.bang: test_bang(lib, test_cases) - if not (args.cpu or args.cuda or args.bang): + if args.musa: + test_musa(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/src/ops/relu/musa/relu_musa.cc b/src/ops/relu/musa/relu_musa.cc new file mode 100644 index 00000000..3e3c35fe --- /dev/null +++ b/src/ops/relu/musa/relu_musa.cc @@ -0,0 +1,45 @@ +#include "relu_musa.h" +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" + +infiniopStatus_t musaCreateReluDescriptor(MusaHandle_t handle, + ReluMusaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x) { + uint64_t ndim = y->ndim; + if (ndim != x->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + for (size_t i = 0; i < ndim; ++i) { + if (y->shape[i] != x->shape[i]) { + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!is_contiguous(y) || !is_contiguous(x)) { + return STATUS_BAD_TENSOR_STRIDES; + } + if (y->dt != F16 && y->dt != F32) { + return STATUS_BAD_TENSOR_DTYPE; + } + if (y->dt != x->dt) { + return STATUS_BAD_TENSOR_DTYPE; + } + + uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); + + *desc_ptr = new ReluMusaDescriptor{ + DevMtGpu, + y->dt, + handle->device_id, + ndim, + data_size, + static_cast(handle->prop.maxGridSize[0]), + }; + + return STATUS_SUCCESS; +} + +infiniopStatus_t musaDestroyReluDescriptor(ReluMusaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/relu/musa/relu_musa.h b/src/ops/relu/musa/relu_musa.h new file mode 100644 index 00000000..84276369 --- /dev/null +++ b/src/ops/relu/musa/relu_musa.h @@ -0,0 +1,32 @@ +#ifndef __MUSA_RELU_H__ +#define __MUSA_RELU_H__ + +#include "../../../devices/musa/common_musa.h" +#include "../../../devices/musa/musa_handle.h" +#include "operators.h" +#include +#include + +struct ReluMusaDescriptor { + Device device; + DT dtype; + int device_id; + uint64_t ndim; + uint64_t data_size; + uint64_t max_grid_size; +}; + +typedef struct ReluMusaDescriptor *ReluMusaDescriptor_t; + +infiniopStatus_t musaCreateReluDescriptor(MusaHandle_t, + ReluMusaDescriptor_t *, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +infiniopStatus_t musaRelu(ReluMusaDescriptor_t desc, + void *y, void const *x, + void *stream); + +infiniopStatus_t musaDestroyReluDescriptor(ReluMusaDescriptor_t desc); + +#endif diff --git a/src/ops/relu/musa/relu_musa.mu b/src/ops/relu/musa/relu_musa.mu new file mode 100644 index 00000000..3d91b4e2 --- /dev/null +++ b/src/ops/relu/musa/relu_musa.mu @@ -0,0 +1,111 @@ +#include "../../../devices/musa/common_musa.h" +#include "../../utils.h" +#include "relu_musa.h" + +/** + * @brief A templated vector struct that supports applying relu on arrays. + * + * @tparam T - The access data type for elements in the vector. + * @tparam TComp - The computation data type used for arithmetic operations. sizeof(T) should + * be >= sizeof(TComp) + * @tparam N - The number of elements of type T in the vector for a single access. + */ +template +struct vecN { + T data[N]; + constexpr static size_t pack_size = sizeof(T) / sizeof(TComp); + + // Constructor that initializes the data array with type TComp + __device__ __forceinline__ constexpr vecN(const TComp &val) { + const auto data_ = reinterpret_cast(data); + const auto size = N * pack_size; +#pragma unroll + for (size_t i = 0; i < size; ++i) { + data_[i] = 0; + } + } + + // Assignment operator with relu assignment logic + __device__ __forceinline__ vecN &operator=(const vecN &other) { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < N; ++i) { + data[i] = other.data[i] < TComp(0) ? TComp(0) : other.data[i]; + } + } else { + auto *data_this = reinterpret_cast *>(data); + auto *data_other = reinterpret_cast *>(other.data); +#pragma unroll + for (int i = 0; i < N; ++i) { + data_this[i] = data_other[i]; + } + } + return *this; + } + + // Always returns false since the actual relu logic is in the assignment process + __device__ __forceinline__ bool operator<(const vecN &other) const { + return false; + } + + __device__ __forceinline__ const T &operator[](size_t i) const { + return data[i]; + } +}; + +template +__global__ void relu( + Tdata *y, + const Tdata *x, + uint64_t data_size, + uint64_t offset) { + uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + + if (idx < data_size) { + y[idx] = x[idx] < Tdata(0) ? Tdata(0) : x[idx]; + } +} + +template +void relu_mt_gpu(ReluMusaDescriptor_t desc, Tdata *y, Tdata const *x, uint64_t data_size, uint64_t offset, void *stream) { + if (data_size == 0) { + return; + } + dim3 blockDims = dim3(std::min(static_cast(256), data_size)); + dim3 gridDims = dim3(std::min(ROUND_UP_DIV(data_size, blockDims.x), desc->max_grid_size)); + uint64_t step = gridDims.x * blockDims.x; + + musaStream_t musa_stream = reinterpret_cast(stream); + +#pragma unroll + for (uint64_t i = 0; i < data_size; i += step) { + relu<<>>(y, x, offset + data_size, offset + i); + } +} + +template +infiniopStatus_t relu_mt_gpu(ReluMusaDescriptor_t desc, void *y, void const *x, void *stream, uint64_t pack_size) { + const auto data_size = desc->data_size / pack_size; + const auto x_vec = reinterpret_cast(x); + const auto y_vec = reinterpret_cast(y); + relu_mt_gpu(desc, y_vec, x_vec, data_size, 0, stream); + + const auto remainder = desc->data_size % pack_size; + const auto x_ = reinterpret_cast(x); + const auto y_ = reinterpret_cast(y); + relu_mt_gpu(desc, y_, x_, remainder, data_size * pack_size, stream); + return STATUS_SUCCESS; +} + +infiniopStatus_t musaRelu(ReluMusaDescriptor_t desc, + void *y, void const *x, + void *stream) { + checkMusaError(musaSetDevice(desc->device_id)); + if (desc->dtype == F16) { + return relu_mt_gpu, half>(desc, y, x, stream, 4); + } + if (desc->dtype == F32) { + return relu_mt_gpu, float>(desc, y, x, stream, 4); + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/relu/operator.cc b/src/ops/relu/operator.cc index 89122915..16e1d583 100644 --- a/src/ops/relu/operator.cc +++ b/src/ops/relu/operator.cc @@ -9,6 +9,10 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/relu.cuh" #endif +#ifdef ENABLE_MT_GPU +#include "musa/relu_musa.h" +#endif + __C infiniopStatus_t infiniopCreateReluDescriptor( infiniopHandle_t handle, @@ -28,6 +32,11 @@ __C infiniopStatus_t infiniopCreateReluDescriptor( #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaCreateReluDescriptor((MusaHandle_t) handle, (ReluMusaDescriptor_t *) desc_ptr, y, x); + } #endif } return STATUS_BAD_DEVICE; @@ -47,6 +56,11 @@ __C infiniopStatus_t infiniopRelu(infiniopReluDescriptor_t desc, void *y, void c #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaRelu((ReluMusaDescriptor_t) desc, y, x, stream); + } #endif } return STATUS_BAD_DEVICE; @@ -66,6 +80,11 @@ __C infiniopStatus_t infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc #endif #ifdef ENABLE_CAMBRICON_MLU // TODO +#endif +#ifdef ENABLE_MT_GPU + case DevMtGpu: { + return musaDestroyReluDescriptor((ReluMusaDescriptor_t) desc); + } #endif } return STATUS_BAD_DEVICE; From 862e4f271d69d6aa81ad3ea2b4dd153ab1ac5e7e Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 9 Jan 2025 15:08:39 +0800 Subject: [PATCH 13/15] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9mudnn=E7=9A=84?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/devices/musa/musa_handle.cc | 22 ++++++++- src/devices/musa/musa_handle.h | 21 +++++++++ src/devices/musa/tensor_desc.cc | 81 +++++++++++++++++++++++++++++++++ src/devices/musa/tensor_desc.h | 42 +++++++++++++++++ 4 files changed, 164 insertions(+), 2 deletions(-) create mode 100644 src/devices/musa/tensor_desc.cc create mode 100644 src/devices/musa/tensor_desc.h diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc index cd242114..e8d9be5b 100644 --- a/src/devices/musa/musa_handle.cc +++ b/src/devices/musa/musa_handle.cc @@ -20,19 +20,37 @@ infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) { musaDeviceProp prop; musaGetDeviceProperties(&prop, device_id); - + // create a mublas handle pool auto mublas_pool = std::make_shared>(); mublasHandle_t *mublas_handle = new mublasHandle_t; mublasCreate(mublas_handle); mublas_pool->push(mublas_handle); - *handle_ptr = new MusaContext{DevMtGpu, device_id, std::move(mublas_pool), std::move(prop)}; + // create a mudnn handle pool + auto mudnn_pool = std::make_shared>(); + musa::dnn::Handle *mudnn_handle = new musa::dnn::Handle; + mudnn_pool->push(mudnn_handle); + + int capability_major; + int capability_minor; + musaDeviceGetAttribute(&capability_major, musaDevAttrComputeCapabilityMajor, device_id); + musaDeviceGetAttribute(&capability_minor, musaDevAttrComputeCapabilityMinor, device_id); + + *handle_ptr = new MusaContext{ + DevMtGpu, + device_id, + std::move(mublas_pool), + std::move(mudnn_pool), + std::move(prop), + capability_major, + capability_minor,}; return STATUS_SUCCESS; } infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr) { handle_ptr->mublas_handles_t = nullptr; + handle_ptr->mudnn_handles_t = nullptr; delete handle_ptr; return STATUS_SUCCESS; diff --git a/src/devices/musa/musa_handle.h b/src/devices/musa/musa_handle.h index fed050d8..0c715b83 100644 --- a/src/devices/musa/musa_handle.h +++ b/src/devices/musa/musa_handle.h @@ -15,7 +15,10 @@ struct MusaContext { Device device; int device_id; std::shared_ptr> mublas_handles_t; + std::shared_ptr> mudnn_handles_t; musaDeviceProp prop; + int compute_capability_major; + int compute_capability_minor; }; typedef struct MusaContext *MusaHandle_t; @@ -40,4 +43,22 @@ void use_mublas(std::shared_ptr> mublas_handles_t, int devi mublas_handles_t->push(handle); } +template +void use_mudnn(std::shared_ptr> mudnn_handles_t, int device_id, musaStream_t stream, T const &f) { + musa::dnn::Handle* handle = mudnn_handles_t->pop(); + if (!handle) { + int current_device; + musaGetDevice(¤t_device); + if (current_device != device_id) { + musaSetDevice(device_id); + } + handle = new musa::dnn::Handle(device_id); + // mudnnCreate(handle); + } + // mudnnSetStream(*handle, (MUstream) stream); + handle->SetStream(stream); + f(handle); + mudnn_handles_t->push(handle); +} + #endif // __MUSA_HANDLE_H__ \ No newline at end of file diff --git a/src/devices/musa/tensor_desc.cc b/src/devices/musa/tensor_desc.cc new file mode 100644 index 00000000..e706a8c6 --- /dev/null +++ b/src/devices/musa/tensor_desc.cc @@ -0,0 +1,81 @@ + +#include "tensor_desc.h" +#include +#include + +// void mudnnSqueezeTensorDim(mudnnTensorDesc_t &ldesc, mudnnTensorDesc_t &rdesc, mudnnTensorDesc_t &outdesc) { +// if (outdesc->ndims > 2) { +// if (ldesc->ndims > 2 && *ldesc->dim == 1) { +// ldesc->ndims -= 1; +// ldesc->dim = ldesc->dim+1; +// } +// if (rdesc->ndims > 2 && *rdesc->dim == 1) { +// rdesc->ndims -= 1; +// rdesc->dim = rdesc->dim+1; +// } +// } +// } + +// void mudnnCreateTensorDescriptor(mudnnTensorDesc_t *desc) { +// *desc = new mudnnTensorDesc; +// (*desc)->type = Type::FLOAT; +// (*desc)->format = Format::UNKNOWN; +// (*desc)->ndims = 0; +// (*desc)->dim = nullptr; +// (*desc)->stride = nullptr; +// (*desc)->scales = nullptr; +// (*desc)->addr = nullptr; +// } + + +// void mudnnSetTensorDescriptor(mudnnTensorDesc_t &desc, int64_t *shape, int64_t *stride, int64_t ndim, +// int64_t offset, Type type, Format format) { +// desc->type = type; +// desc->format = format; +// desc->ndims = ndim; +// desc->dim = shape; +// if (stride) { +// desc->stride = stride; +// } else { +// std::vector stride_v(ndim, 1); +// for (int64_t i = ndim - 2; i >= 0; i--) { +// stride_v[i] = shape[i + 1] * stride_v[i + 1]; +// } +// desc->stride = stride_v.data(); +// } +// } + +// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout) { +// auto dims = new int64_t(layout->ndim); +// for (uint64_t i = 0; i < layout->ndim; i++) { +// dims[i] = static_cast(layout->shape[i]); +// } +// // Cast bytes stride to element stride +// auto strides = new int64_t(layout->ndim); +// for (uint64_t i = 0; i < layout->ndim; i++) { +// strides[i] = layout->strides[i] / (layout->dt).size; +// } + +// Type type = Type::HALF; +// Format format = Format::NCHW; + +// mudnnSetTensorDescriptor(desc, dims, strides, layout->ndim, 0, type, format); +// } + +// void mudnnDestroyTensorDescriptor(mudnnTensorDesc_t &desc) { +// if (desc) { +// delete desc; +// desc = nullptr; +// } +// } + +// int mudnnCreateTensor(TensorDescriptor desc, void *data, musa::dnn::Tensor **tensor) { +// *tensor = new musa::dnn::Tensor(); + +// (*tensor)->SetAddr(data); +// // (*tensor)->SetType(musa::dnn::Tensor::Type(desc->type)); +// (*tensor)->SetFormat(musa::dnn::Tensor::Format(desc->format)); +// // (*tensor)->SetNdInfo(desc->ndims, desc->dim, desc->stride); +// (*tensor)->SetNdInfo(desc->ndims, desc->dim); +// return 0; +// } \ No newline at end of file diff --git a/src/devices/musa/tensor_desc.h b/src/devices/musa/tensor_desc.h new file mode 100644 index 00000000..9b896f18 --- /dev/null +++ b/src/devices/musa/tensor_desc.h @@ -0,0 +1,42 @@ +#ifndef __TENSOR_DESC_H__ +#define __TENSOR_DESC_H__ + +#include "tensor.h" +#include "common_musa.h" +#include +#include +#include +#include + +// using namespace musa::dnn; + +// struct mudnnTensorDesc { +// Type type; +// Format format; +// int64_t ndims; +// int64_t *dim; +// int64_t *stride; +// int64_t *scales; +// int64_t *addr; +// }; + +// typedef mudnnTensorDesc *mudnnTensorDesc_t; + +// void mudnnCreateTensorDescriptor(mudnnTensorDesc_t *desc); + +// void mudnnSetTensorDescriptor(mudnnTensorDesc_t &desc, int64_t *shape, +// int64_t *stride, int64_t ndim, int64_t offset, +// Type type, Format format); + +// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout); + +// void mudnnDestroyTensorDescriptor(mudnnTensorDesc_t &desc); + +int mudnnCreateTensor(TensorDescriptor desc, void *data, musa::dnn::Tensor **tensor); + +// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout); + +// void mudnnSqueezeTensorDim(mudnnTensorDesc_t &ldesc, mudnnTensorDesc_t &rdesc, mudnnTensorDesc_t &outdesc); + + +#endif // __TENSOR_DESC_H__ \ No newline at end of file From bac08e9687cfe35c16a297b83d0fab81e83e3db6 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Fri, 7 Feb 2025 10:09:06 +0800 Subject: [PATCH 14/15] rebase dev --- operatorspy/tests/causal_softmax.py | 2 +- operatorspy/tests/matmul.py | 2 +- operatorspy/tests/rotary_embedding.py | 9 +- src/devices/handle.cc | 10 +- src/devices/musa/musa_handle.cc | 2 +- src/ops/add/musa/add_musa.cc | 2 +- src/ops/add/operator.cc | 14 +-- .../musa/causal_softmax_musa.cc | 16 ++-- .../causal_softmax/musa/causal_softmax_musa.h | 18 ++-- .../musa/causal_softmax_musa.mu | 12 +-- src/ops/causal_softmax/operator.cc | 18 ++-- src/ops/expand/musa/expand_musa.cc | 2 +- src/ops/expand/operator.cc | 14 +-- src/ops/matmul/musa/matmul_musa.cc | 2 +- src/ops/matmul/operator.cc | 18 ++-- .../random_sample/musa/random_sample_musa.cc | 2 +- .../random_sample/musa/random_sample_musa.h | 2 +- src/ops/random_sample/operator.cc | 18 ++-- src/ops/rearrange/musa/rearrange_musa.cc | 93 +++++++++---------- src/ops/rearrange/musa/rearrange_musa.h | 8 +- src/ops/rearrange/musa/rearrange_musa.mu | 50 +++++----- src/ops/rearrange/operator.cc | 14 +-- src/ops/relu/musa/relu_musa.cc | 2 +- src/ops/relu/operator.cc | 14 +-- src/ops/rms_norm/musa/rms_norm_musa.cc | 6 +- src/ops/rms_norm/musa/rms_norm_musa.h | 12 +-- src/ops/rms_norm/musa/rms_norm_musa.mu | 2 +- src/ops/rms_norm/operator.cc | 18 ++-- .../musa/rotary_embedding_musa.cc | 2 +- .../musa/rotary_embedding_musa.h | 4 +- .../musa/rotary_embedding_musa.mu | 8 +- src/ops/rotary_embedding/operator.cc | 18 ++-- src/ops/swiglu/musa/swiglu_musa.cc | 2 +- src/ops/swiglu/operator.cc | 14 +-- xmake.lua | 4 +- 35 files changed, 215 insertions(+), 219 deletions(-) diff --git a/operatorspy/tests/causal_softmax.py b/operatorspy/tests/causal_softmax.py index 762b0707..b7cabc4a 100644 --- a/operatorspy/tests/causal_softmax.py +++ b/operatorspy/tests/causal_softmax.py @@ -173,6 +173,6 @@ def test_musa(lib, test_cases): test_maca(lib, test_cases) if args.musa: test_musa(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index 46469222..31076fb5 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -420,6 +420,6 @@ def test_musa(lib, test_cases): test_maca(lib, test_cases) if args.musa: test_musa(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py index de5b471a..3064e0ac 100644 --- a/operatorspy/tests/rotary_embedding.py +++ b/operatorspy/tests/rotary_embedding.py @@ -94,9 +94,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): # 2x table length for test sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta) t_tensor = to_tensor(t, lib) - pos_tensor = to_tensor(pos, lib) - if(torch_device == 'mlu' or torch_device == 'musa'): - pos_tensor.descriptor.contents.dt = U64 + pos_tensor = to_tensor(pos[: t.shape[0]], lib) + pos_tensor.descriptor.contents.dt = U64 sin_table_tensor = to_tensor(sin_table, lib) cos_table_tensor = to_tensor(cos_table, lib) @@ -182,7 +181,7 @@ def test_maca(lib, test_cases) : test(lib, handle, "maca", shape, strides, dtype) destroy_handle(lib, handle) -def test_musa(lib, test_cases): +def test_musa(lib, test_cases) : import torch_musa device = DeviceEnum.DEVICE_MUSA handle = create_handle(lib, device) @@ -246,4 +245,4 @@ def test_musa(lib, test_cases): test_musa(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) - print("\033[92mTest passed!\033[0m") + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/src/devices/handle.cc b/src/devices/handle.cc index d00278e5..6b7f54a8 100644 --- a/src/devices/handle.cc +++ b/src/devices/handle.cc @@ -14,7 +14,7 @@ #ifdef ENABLE_METAX_GPU #include "./maca/maca_handle.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "./musa/musa_handle.h" #endif @@ -52,8 +52,8 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d return createMacaHandle((MacaHandle_t *) handle_ptr, device_id); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return createMusaHandle((MusaHandle_t *) handle_ptr, device_id); } #endif @@ -90,8 +90,8 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { return deleteMacaHandle((MacaHandle_t) handle); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { deleteMusaHandle((MusaHandle_t) handle); return STATUS_SUCCESS; } diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc index e8d9be5b..ab6c88ce 100644 --- a/src/devices/musa/musa_handle.cc +++ b/src/devices/musa/musa_handle.cc @@ -37,7 +37,7 @@ infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) { musaDeviceGetAttribute(&capability_minor, musaDevAttrComputeCapabilityMinor, device_id); *handle_ptr = new MusaContext{ - DevMtGpu, + DevMthreadsGpu, device_id, std::move(mublas_pool), std::move(mudnn_pool), diff --git a/src/ops/add/musa/add_musa.cc b/src/ops/add/musa/add_musa.cc index 21fbbdd1..8c4475fe 100644 --- a/src/ops/add/musa/add_musa.cc +++ b/src/ops/add/musa/add_musa.cc @@ -54,7 +54,7 @@ infiniopStatus_t musaCreateAddDescriptor(MusaHandle_t handle, checkMusaErrorWithCode(musaMemcpy(c_strides_d, c->strides, ndim * sizeof(int64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); *desc_ptr = new AddMusaDescriptor{ - DevMtGpu, + DevMthreadsGpu, c->dt, handle->device_id, ndim, diff --git a/src/ops/add/operator.cc b/src/ops/add/operator.cc index 9d090243..de97dc94 100644 --- a/src/ops/add/operator.cc +++ b/src/ops/add/operator.cc @@ -9,7 +9,7 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/add.cuh" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/add_musa.h" #endif @@ -33,8 +33,8 @@ __C infiniopStatus_t infiniopCreateAddDescriptor( #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateAddDescriptor((MusaHandle_t) handle, (AddMusaDescriptor_t *) desc_ptr, c, a, b); } #endif @@ -57,8 +57,8 @@ __C infiniopStatus_t infiniopAdd(infiniopAddDescriptor_t desc, void *c, void con #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaAdd((AddMusaDescriptor_t) desc, c, a, b, stream); } #endif @@ -81,8 +81,8 @@ __C infiniopStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyAddDescriptor((AddMusaDescriptor_t) desc); } #endif diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.cc b/src/ops/causal_softmax/musa/causal_softmax_musa.cc index ae138efd..6ff55d65 100644 --- a/src/ops/causal_softmax/musa/causal_softmax_musa.cc +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.cc @@ -5,7 +5,7 @@ infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, CausalSoftmaxMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y) { - unsigned long int ndim = y->ndim; + uint64_t ndim = y->ndim; // TODO: only support 2d or 3d tensor if (ndim != 2 && ndim != 3) { return STATUS_BAD_TENSOR_SHAPE; @@ -13,12 +13,12 @@ infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, if (!dtype_eq(y->dt, F16)) { return STATUS_BAD_TENSOR_DTYPE; } - unsigned long int total_seq_len = y->shape[ndim - 1]; - unsigned long int seq_len = y->shape[ndim - 2]; - unsigned long int batch_size = 1; - unsigned long int stride_b = 0; - unsigned long int stride_i = y->strides[ndim - 2]; - unsigned long int stride_j = y->strides[ndim - 1]; + uint64_t total_seq_len = y->shape[ndim - 1]; + uint64_t seq_len = y->shape[ndim - 2]; + uint64_t batch_size = 1; + uint64_t stride_b = 0; + uint64_t stride_i = y->strides[ndim - 2]; + uint64_t stride_j = y->strides[ndim - 1]; if (stride_j != 1) { return STATUS_BAD_TENSOR_STRIDES; } @@ -44,7 +44,7 @@ infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, unsigned long int *size) { +infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; } diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.h b/src/ops/causal_softmax/musa/causal_softmax_musa.h index 90d588f0..65d88423 100644 --- a/src/ops/causal_softmax/musa/causal_softmax_musa.h +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.h @@ -8,13 +8,13 @@ struct CausalSoftmaxMusaDescriptor { Device device; int device_id; DT dtype; - unsigned long int batch_size; - unsigned long int stride_b; - unsigned long int seq_len; - unsigned long int stride_i; - unsigned long int total_seq_len; - unsigned long int stride_j; - unsigned int max_items_per_thread; + uint64_t batch_size; + uint64_t stride_b; + uint64_t seq_len; + uint64_t stride_i; + uint64_t total_seq_len; + uint64_t stride_j; + uint64_t max_items_per_thread; }; typedef struct CausalSoftmaxMusaDescriptor *CausalSoftmaxMusaDescriptor_t; @@ -23,11 +23,11 @@ infiniopStatus_t musaCreateCausalSoftmaxDescriptor(MusaHandle_t handle, CausalSoftmaxMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc); -infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, unsigned long int *size); +infiniopStatus_t musaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMusaDescriptor_t desc, uint64_t *size); infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, void *workspace, - unsigned long int workspace_size, + uint64_t workspace_size, void *data, void *stream); diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.mu b/src/ops/causal_softmax/musa/causal_softmax_musa.mu index 8957134b..5eb5c8d9 100644 --- a/src/ops/causal_softmax/musa/causal_softmax_musa.mu +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.mu @@ -219,12 +219,12 @@ __global__ void fused_softmax_standard( void causal_softmax_mt_gpu_f16(CausalSoftmaxMusaDescriptor_t desc, void* y, void *stream) { - unsigned long int total_seq_len = desc->total_seq_len; - unsigned long int seq_len = desc->seq_len; - unsigned long int batch_size = desc->batch_size; - unsigned long int stride_x = desc->stride_b; - unsigned long int stride_y = desc->stride_i; - unsigned long int stride_z = desc->stride_j;// covert byte strides to element strides + uint64_t total_seq_len = desc->total_seq_len; + uint64_t seq_len = desc->seq_len; + uint64_t batch_size = desc->batch_size; + uint64_t stride_x = desc->stride_b; + uint64_t stride_y = desc->stride_i; + uint64_t stride_z = desc->stride_j;// covert byte strides to element strides unsigned int max_items_per_thread = desc->max_items_per_thread; dim3 grid(batch_size, seq_len); diff --git a/src/ops/causal_softmax/operator.cc b/src/ops/causal_softmax/operator.cc index 841eb75a..92498dca 100644 --- a/src/ops/causal_softmax/operator.cc +++ b/src/ops/causal_softmax/operator.cc @@ -21,7 +21,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/causal_softmax_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/causal_softmax_musa.h" #include "../../devices/musa/common_musa.h" #endif @@ -57,8 +57,8 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor( return macaCreateCausalSoftmaxDescriptor((MacaHandle_t) handle, (CausalSoftmaxMacaDescriptor_t *) desc_ptr, y_desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateCausalSoftmaxDescriptor((MusaHandle_t) handle, (CausalSoftmaxMusaDescriptor_t *) desc_ptr, y_desc); } #endif @@ -95,8 +95,8 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax return macaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMacaDescriptor_t) desc, size); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMusaDescriptor_t) desc, size); } #endif @@ -132,8 +132,8 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des return macaCausalSoftmax((CausalSoftmaxMacaDescriptor_t) desc, workspace, workspace_size, data, stream); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCausalSoftmax((CausalSoftmaxMusaDescriptor_t) desc, workspace, workspace_size, data, stream); } #endif @@ -169,8 +169,8 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma return macaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMacaDescriptor_t) desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMusaDescriptor_t) desc); #endif } diff --git a/src/ops/expand/musa/expand_musa.cc b/src/ops/expand/musa/expand_musa.cc index 02980d71..0e2e4581 100644 --- a/src/ops/expand/musa/expand_musa.cc +++ b/src/ops/expand/musa/expand_musa.cc @@ -30,7 +30,7 @@ infiniopStatus_t musaCreateExpandDescriptor(MusaHandle_t handle, checkMusaErrorWithCode(musaMemcpy(strides_and_shape_d + 2 * ndim * sizeof(int64_t), y->shape, ndim * sizeof(uint64_t), musaMemcpyHostToDevice), STATUS_EXECUTION_FAILED); *desc_ptr = new ExpandMusaDescriptor{ - DevMtGpu, + DevMthreadsGpu, y->dt, handle->device_id, ndim, diff --git a/src/ops/expand/operator.cc b/src/ops/expand/operator.cc index f5852e46..b0374645 100644 --- a/src/ops/expand/operator.cc +++ b/src/ops/expand/operator.cc @@ -9,7 +9,7 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/expand.cuh" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/expand_musa.h" #endif @@ -33,8 +33,8 @@ __C infiniopStatus_t infiniopCreateExpandDescriptor( #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateExpandDescriptor((MusaHandle_t) handle, (ExpandMusaDescriptor_t *) desc_ptr, y, x); } #endif @@ -57,8 +57,8 @@ __C infiniopStatus_t infiniopExpand(infiniopExpandDescriptor_t desc, void *y, vo #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaExpand((ExpandMusaDescriptor_t) desc, y, x, stream); } #endif @@ -81,8 +81,8 @@ __C infiniopStatus_t infiniopDestroyExpandDescriptor(infiniopExpandDescriptor_t #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyExpandDescriptor((ExpandMusaDescriptor_t) desc); } #endif diff --git a/src/ops/matmul/musa/matmul_musa.cc b/src/ops/matmul/musa/matmul_musa.cc index 8a090291..1b5f98fc 100644 --- a/src/ops/matmul/musa/matmul_musa.cc +++ b/src/ops/matmul/musa/matmul_musa.cc @@ -26,7 +26,7 @@ infiniopStatus_t musaCreateMatmulDescriptor(MusaHandle_t handle, } *desc_ptr = new MatmulMusaDescriptor{ - DevMtGpu, + DevMthreadsGpu, dtype, handle->device_id, info, diff --git a/src/ops/matmul/operator.cc b/src/ops/matmul/operator.cc index 5dd880a4..5fa766eb 100644 --- a/src/ops/matmul/operator.cc +++ b/src/ops/matmul/operator.cc @@ -17,7 +17,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/matmul_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/matmul_musa.h" #endif @@ -60,8 +60,8 @@ __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle, return macaCreateMatmulDescriptor((MacaHandle_t) handle, (MatmulMacaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateMatmulDescriptor((MusaHandle_t) handle, (MatmulMusaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta); } #endif @@ -97,8 +97,8 @@ __C infiniopStatus_t infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t d return macaGetMatmulWorkspaceSize((MatmulMacaDescriptor_t) desc, size); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaGetMatmulWorkspaceSize((MatmulMusaDescriptor_t) desc, size); } #endif @@ -136,8 +136,8 @@ __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc, void *works return macaMatmul((MatmulMacaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaMatmul((MatmulMusaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream); } #endif @@ -172,8 +172,8 @@ __C infiniopStatus_t infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t return macaDestroyMatmulDescriptor((MatmulMacaDescriptor_t) desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyMatmulDescriptor((MatmulMusaDescriptor_t) desc); } #endif diff --git a/src/ops/random_sample/musa/random_sample_musa.cc b/src/ops/random_sample/musa/random_sample_musa.cc index 29f676f9..70ff941c 100644 --- a/src/ops/random_sample/musa/random_sample_musa.cc +++ b/src/ops/random_sample/musa/random_sample_musa.cc @@ -26,7 +26,7 @@ infiniopStatus_t musaCreateRandomSampleDescriptor(MusaHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, unsigned long int *size) { +infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, uint64_t *size) { *size = desc->voc * (2 * sizeof(uint64_t) + sizeof(desc->dtype)); return STATUS_SUCCESS; } diff --git a/src/ops/random_sample/musa/random_sample_musa.h b/src/ops/random_sample/musa/random_sample_musa.h index 493cd3f4..d8839ff1 100644 --- a/src/ops/random_sample/musa/random_sample_musa.h +++ b/src/ops/random_sample/musa/random_sample_musa.h @@ -19,7 +19,7 @@ infiniopStatus_t musaCreateRandomSampleDescriptor(MusaHandle_t handle, RandomSampleMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs); -infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, unsigned long int *size); +infiniopStatus_t musaGetRandomSampleWorkspaceSize(RandomSampleMusaDescriptor_t desc, uint64_t *size); infiniopStatus_t musaRandomSample(RandomSampleMusaDescriptor_t desc, void *workspace, diff --git a/src/ops/random_sample/operator.cc b/src/ops/random_sample/operator.cc index f335b14f..40a8ec03 100644 --- a/src/ops/random_sample/operator.cc +++ b/src/ops/random_sample/operator.cc @@ -17,7 +17,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/random_sample_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/random_sample_musa.h" #endif @@ -51,8 +51,8 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handl probs); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaCreateRandomSampleDescriptor((MusaHandle_t) handle, (RandomSampleMusaDescriptor_t *) desc_ptr, result, probs); #endif } @@ -87,8 +87,8 @@ __C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDe return macaGetRandomSampleWorkspaceSize((RandomSampleMacaDescriptor_t) desc, size); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaGetRandomSampleWorkspaceSize((RandomSampleMusaDescriptor_t) desc, size); } #endif @@ -130,8 +130,8 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc, return macaRandomSample((RandomSampleMacaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaRandomSample((RandomSampleMusaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream); #endif } @@ -163,8 +163,8 @@ __C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleD return macaDestroyRandomSampleDescriptor((RandomSampleMacaDescriptor_t) desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaDestroyRandomSampleDescriptor((RandomSampleMusaDescriptor_t) desc); #endif } diff --git a/src/ops/rearrange/musa/rearrange_musa.cc b/src/ops/rearrange/musa/rearrange_musa.cc index 29f2b6b5..5fa2e768 100644 --- a/src/ops/rearrange/musa/rearrange_musa.cc +++ b/src/ops/rearrange/musa/rearrange_musa.cc @@ -7,14 +7,16 @@ infiniopStatus_t musaCreateRearrangeDescriptor(MusaHandle_t handle, RearrangeMusaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t dst, infiniopTensorDescriptor_t src) { - if (!dtype_eq(dst->dt, src->dt)) { + auto dt = dst->dt; + if (!dtype_eq(src->dt, dt)) { return STATUS_BAD_TENSOR_DTYPE; } - if (dst->ndim != src->ndim || dst->ndim < 2) { + + auto ndim = dst->ndim; + if (src->ndim != ndim || ndim == 0) { return STATUS_BAD_TENSOR_SHAPE; } - auto ndim = dst->ndim; - for (uint64_t i = 0; i < ndim; ++i) { + for (int i = 0; i < ndim; ++i) { if (dst->shape[i] != src->shape[i]) { return STATUS_BAD_TENSOR_SHAPE; } @@ -22,55 +24,46 @@ infiniopStatus_t musaCreateRearrangeDescriptor(MusaHandle_t handle, if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) { return STATUS_BAD_TENSOR_STRIDES; } - unsigned int r = 0, c = 0, b = 0; - unsigned int rsa = 0, csa = 0, rsb = 0, csb = 0; - if (ndim == 2) { - c = dst->shape[0]; - b = dst->shape[1]; - csa = dst->strides[0]; - csb = src->strides[0]; - } else if (ndim == 3) { - r = dst->shape[0]; - c = dst->shape[1]; - b = dst->shape[2]; - csa = dst->strides[1]; - csb = src->strides[1]; - rsa = dst->strides[0]; - rsb = src->strides[0]; - } else { - for (uint64_t i = ndim - 3; i >= 1; --i) { - if ((int64_t) dst->shape[i] * dst->strides[i] != dst->strides[i - 1] || (int64_t) src->shape[i] * src->strides[i] != src->strides[i - 1]) { - return STATUS_BAD_TENSOR_STRIDES; - } - } - r = std::accumulate(dst->shape, dst->shape + ndim - 2, 1, std::multiplies()); - c = dst->shape[ndim - 2]; - b = dst->shape[ndim - 1]; - csa = dst->strides[ndim - 2]; - csb = src->strides[ndim - 2]; - rsa = dst->strides[ndim - 3]; - rsb = src->strides[ndim - 3]; - } - auto contiguous_bytes = b * dst->dt.size; - if (contiguous_bytes % WARP_SIZE != 0) { - return STATUS_BAD_PARAM; - } - auto bytes_per_thread = contiguous_bytes / WARP_SIZE ; - if (bytes_per_thread <= 0 || bytes_per_thread > 32 || (bytes_per_thread & (bytes_per_thread - 1)) != 0) { - return STATUS_BAD_PARAM; + + switch (ndim) { + case 1: + *desc_ptr = new RearrangeMusaDescriptor{ + handle->device, + handle->device_id, + dt.size * dst->shape[0], + 1, 1, + 0, 0, + 0, 0}; + break; + case 2: + *desc_ptr = new RearrangeMusaDescriptor{ + handle->device, + handle->device_id, + dt.size * dst->shape[1], + 1, dst->shape[0], + 0, dst->strides[0], + 0, src->strides[0]}; + break; + case 3: + *desc_ptr = new RearrangeMusaDescriptor{ + handle->device, + handle->device_id, + dt.size * dst->shape[2], + dst->shape[0], dst->shape[1], + dst->strides[0], dst->strides[1], + src->strides[0], src->strides[1]}; + break; + default: + return STATUS_BAD_TENSOR_SHAPE; } - *desc_ptr = new RearrangeMusaDescriptor{ - handle->device, - handle->device_id, - rsa, - rsb, - csa, - csb, - r, c, b, - bytes_per_thread}; + + (*desc_ptr)->dst_rs *= dt.size; + (*desc_ptr)->dst_cs *= dt.size; + (*desc_ptr)->src_rs *= dt.size; + (*desc_ptr)->src_cs *= dt.size; + return STATUS_SUCCESS; } - infiniopStatus_t musaDestroyRearrangeDescriptor(RearrangeMusaDescriptor_t desc) { delete desc; return STATUS_SUCCESS; diff --git a/src/ops/rearrange/musa/rearrange_musa.h b/src/ops/rearrange/musa/rearrange_musa.h index 7ebdb4e5..cb33209a 100644 --- a/src/ops/rearrange/musa/rearrange_musa.h +++ b/src/ops/rearrange/musa/rearrange_musa.h @@ -7,12 +7,8 @@ struct RearrangeMusaDescriptor { Device device; int device_id; - unsigned long int rsa; - unsigned long int rsb; - unsigned long int csa; - unsigned long int csb; - unsigned long int r, c, b; - unsigned long int bytes_per_thread; + uint64_t unit, r, c; + int64_t dst_rs, dst_cs, src_rs, src_cs; }; typedef struct RearrangeMusaDescriptor *RearrangeMusaDescriptor_t; diff --git a/src/ops/rearrange/musa/rearrange_musa.mu b/src/ops/rearrange/musa/rearrange_musa.mu index 77489add..887923b3 100644 --- a/src/ops/rearrange/musa/rearrange_musa.mu +++ b/src/ops/rearrange/musa/rearrange_musa.mu @@ -4,11 +4,11 @@ template static __global__ void rearrange( void *__restrict__ dst, - unsigned int const rsa, - unsigned int const csa, + int const rsa, + int const csa, void const *__restrict__ src, - unsigned int const rsb, - unsigned int const csb, + int const rsb, + int const csb, unsigned int const ncols) { auto row = blockIdx.y, @@ -25,35 +25,43 @@ static __global__ void rearrange( void rearrange_mt_gpu(RearrangeMusaDescriptor_t desc, void *y, void const *x, void *stream) { - unsigned long int rsa = desc->rsa, csa = desc->csa, rsb = desc->rsb, csb = desc->csb; - unsigned int r = desc->r, c = desc->c, b = desc->b, bytes_per_thread = desc->bytes_per_thread; - auto dst_ptr = static_cast(reinterpret_cast(y)); - rsa /= b; - csa /= b; - auto src_ptr = static_cast(reinterpret_cast(x)); - rsb /= b; - csb /= b; auto musa_stream = reinterpret_cast(stream); - dim3 grid_dims = dim3((c + MAX_WARP_PER_BLOCK - 1) / MAX_WARP_PER_BLOCK, r); - dim3 block_dims = dim3(WARP_SIZE, (c + grid_dims.x - 1) / grid_dims.x); - switch (bytes_per_thread) { + auto unit = desc->unit, + r = desc->r, c = desc->c; + auto dst_rs = desc->dst_rs, dst_cs = desc->dst_cs, + src_rs = desc->src_rs, src_cs = desc->src_cs; + + if (r == 1 && c == 1) { + musaMemcpyAsync(y, x, unit, musaMemcpyDeviceToDevice, musa_stream); + return; + } + + auto warps = 1024 / WARP_SIZE; + auto grid = dim3((c + warps - 1) / warps, r); + auto block = dim3(WARP_SIZE, (c + grid.x - 1) / grid.x); + dst_rs /= unit; + dst_cs /= unit; + src_rs /= unit; + src_cs /= unit; + + switch (unit / WARP_SIZE) { case 1: - rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + rearrange<<>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c); break; case 2: - rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + rearrange<<>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c); break; case 4: - rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + rearrange<<>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c); break; case 8: - rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + rearrange<<>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c); break; case 16: - rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + rearrange<<>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c); break; case 32: - rearrange<<>>(dst_ptr, rsa, csa, src_ptr, rsb, csb, c); + rearrange<<>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c); break; default: break; diff --git a/src/ops/rearrange/operator.cc b/src/ops/rearrange/operator.cc index d3da887c..4a922dc7 100644 --- a/src/ops/rearrange/operator.cc +++ b/src/ops/rearrange/operator.cc @@ -20,7 +20,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/rearrange_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/rearrange_musa.h" #endif @@ -58,8 +58,8 @@ __C infiniopStatus_t infiniopCreateRearrangeDescriptor( return macaCreateRearrangeDescriptor((MacaHandle_t) handle, (RearrangeMacaDescriptor_t *) desc_ptr, dst, src); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateRearrangeDescriptor((MusaHandle_t)handle, (RearrangeMusaDescriptor_t *) desc_ptr, dst, src); } #endif @@ -97,8 +97,8 @@ __C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void return macaRearrange((RearrangeMacaDescriptor_t) desc, dst, src, stream); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaRearrange((RearrangeMusaDescriptor_t) desc, dst, src, stream); } #endif @@ -133,8 +133,8 @@ __C infiniopStatus_t infiniopDestroyRearrangeDescriptor(infiniopRearrangeDescrip return macaDestroyRearrangeDescriptor((RearrangeMacaDescriptor_t) desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyRearrangeDescriptor((RearrangeMusaDescriptor_t) desc); } #endif diff --git a/src/ops/relu/musa/relu_musa.cc b/src/ops/relu/musa/relu_musa.cc index 3e3c35fe..6baaef18 100644 --- a/src/ops/relu/musa/relu_musa.cc +++ b/src/ops/relu/musa/relu_musa.cc @@ -28,7 +28,7 @@ infiniopStatus_t musaCreateReluDescriptor(MusaHandle_t handle, uint64_t data_size = std::accumulate(y->shape, y->shape + y->ndim, 1ULL, std::multiplies()); *desc_ptr = new ReluMusaDescriptor{ - DevMtGpu, + DevMthreadsGpu, y->dt, handle->device_id, ndim, diff --git a/src/ops/relu/operator.cc b/src/ops/relu/operator.cc index 16e1d583..7a3a2e2f 100644 --- a/src/ops/relu/operator.cc +++ b/src/ops/relu/operator.cc @@ -9,7 +9,7 @@ #include "../../devices/cuda/cuda_handle.h" #include "cuda/relu.cuh" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/relu_musa.h" #endif @@ -33,8 +33,8 @@ __C infiniopStatus_t infiniopCreateReluDescriptor( #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateReluDescriptor((MusaHandle_t) handle, (ReluMusaDescriptor_t *) desc_ptr, y, x); } #endif @@ -57,8 +57,8 @@ __C infiniopStatus_t infiniopRelu(infiniopReluDescriptor_t desc, void *y, void c #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaRelu((ReluMusaDescriptor_t) desc, y, x, stream); } #endif @@ -81,8 +81,8 @@ __C infiniopStatus_t infiniopDestroyReluDescriptor(infiniopReluDescriptor_t desc #ifdef ENABLE_CAMBRICON_MLU // TODO #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyReluDescriptor((ReluMusaDescriptor_t) desc); } #endif diff --git a/src/ops/rms_norm/musa/rms_norm_musa.cc b/src/ops/rms_norm/musa/rms_norm_musa.cc index 5b053e73..99c22c6e 100644 --- a/src/ops/rms_norm/musa/rms_norm_musa.cc +++ b/src/ops/rms_norm/musa/rms_norm_musa.cc @@ -18,8 +18,8 @@ infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle, RMSNormMusaDes return STATUS_BAD_TENSOR_SHAPE; } - unsigned long int stride_y = y_desc->strides[0]; - unsigned long int stride_x = x_desc->strides[0]; + uint64_t stride_y = y_desc->strides[0]; + uint64_t stride_x = x_desc->strides[0]; auto w_datatype = w_desc->dt; *desc_ptr = new RMSNormMusaDescriptor{ handle->device, @@ -35,7 +35,7 @@ infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle, RMSNormMusaDes return STATUS_SUCCESS; } -infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, unsigned long int *size) { +infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; } diff --git a/src/ops/rms_norm/musa/rms_norm_musa.h b/src/ops/rms_norm/musa/rms_norm_musa.h index 292d5212..ee8dfb72 100644 --- a/src/ops/rms_norm/musa/rms_norm_musa.h +++ b/src/ops/rms_norm/musa/rms_norm_musa.h @@ -8,10 +8,10 @@ struct RMSNormMusaDescriptor { Device device; int device_id; DT dtype; - unsigned long int n; - unsigned long int d; - unsigned long int stride_y; - unsigned long int stride_x; + uint64_t n; + uint64_t d; + uint64_t stride_y; + uint64_t stride_x; DT w_datatype; float epsilon; }; @@ -25,11 +25,11 @@ infiniopStatus_t musaCreateRMSNormDescriptor(MusaHandle_t handle, infiniopTensorDescriptor_t w_desc, float epsilon); -infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, unsigned long int *size); +infiniopStatus_t musaGetRMSNormWorkspaceSize(RMSNormMusaDescriptor_t desc, uint64_t *size); infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc, void *workspace, - unsigned long int workspace_size, + uint64_t workspace_size, void *y, void const *x, void const *w, void *stream); diff --git a/src/ops/rms_norm/musa/rms_norm_musa.mu b/src/ops/rms_norm/musa/rms_norm_musa.mu index 0b1837ad..d80bdac9 100644 --- a/src/ops/rms_norm/musa/rms_norm_musa.mu +++ b/src/ops/rms_norm/musa/rms_norm_musa.mu @@ -158,7 +158,7 @@ void rms_norm_mt_gpu_f16(RMSNormMusaDescriptor_t desc, void *y, void const *x, v infiniopStatus_t musaRMSNorm(RMSNormMusaDescriptor_t desc, void *workspace, - unsigned long int workspace_size, + uint64_t workspace_size, void *y, void const *x, void const *w, void *stream){ int current_device; diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc index b90adef7..317e7ef2 100644 --- a/src/ops/rms_norm/operator.cc +++ b/src/ops/rms_norm/operator.cc @@ -20,7 +20,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/rms_norm_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/rms_norm_musa.h" #endif @@ -61,8 +61,8 @@ __C infiniopStatus_t infiniopCreateRMSNormDescriptor( return macaCreateRMSNormDescriptor((MacaHandle_t) handle, (RMSNormMacaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateRMSNormDescriptor((MusaHandle_t) handle, (RMSNormMusaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon); } #endif @@ -98,8 +98,8 @@ __C infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t) desc, size); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaGetRMSNormWorkspaceSize((RMSNormMusaDescriptor_t) desc, size); } #endif @@ -141,8 +141,8 @@ __C infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *wor return macaRMSNorm((RMSNormMacaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaRMSNorm((RMSNormMusaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream); } #endif @@ -177,8 +177,8 @@ __C infiniopStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_ return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t) desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyRMSNormDescriptor((RMSNormMusaDescriptor_t) desc); } #endif diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc b/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc index b5bdf33a..9ba0547d 100644 --- a/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc +++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.cc @@ -64,7 +64,7 @@ infiniopStatus_t musaCreateRoPEDescriptor(MusaHandle_t handle, return STATUS_SUCCESS; } -infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, unsigned long int *size) { +infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, uint64_t *size) { *size = 0; return STATUS_SUCCESS; } diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.h b/src/ops/rotary_embedding/musa/rotary_embedding_musa.h index 7124a76f..7a14daea 100644 --- a/src/ops/rotary_embedding/musa/rotary_embedding_musa.h +++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.h @@ -24,11 +24,11 @@ infiniopStatus_t musaCreateRoPEDescriptor(MusaHandle_t handle, infiniopTensorDescriptor_t sin_table, infiniopTensorDescriptor_t cos_table); -infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, unsigned long int *size); +infiniopStatus_t musaGetRoPEWorkspaceSize(RoPEMusaDescriptor_t desc, uint64_t *size); infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc, void *workspace, - unsigned long int workspace_size, + uint64_t workspace_size, void *t, void const *pos_ids, void const *sin_table, diff --git a/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu b/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu index 56875482..bac7ad47 100644 --- a/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu +++ b/src/ops/rotary_embedding/musa/rotary_embedding_musa.mu @@ -4,7 +4,7 @@ static __global__ void padding_f16( half *__restrict__ x_, - unsigned long const *__restrict__ pos_, + uint64_t const *__restrict__ pos_, float const *__restrict__ sin_, float const *__restrict__ cos_, long const stride0, @@ -27,7 +27,7 @@ static __global__ void padding_f16( void rotary_embedding_mt_gpu_f16( RoPEMusaDescriptor_t desc, half *t, - unsigned long const *pos, + uint64_t const *pos, float const *sin_, float const *cos_, void *stream) { auto nt = desc->seq_len, @@ -44,7 +44,7 @@ void rotary_embedding_mt_gpu_f16( infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc, void *workspace, - unsigned long int workspace_size, + uint64_t workspace_size, void *t, void const *pos_ids, void const *sin_table, @@ -56,7 +56,7 @@ infiniopStatus_t musaRoPE(RoPEMusaDescriptor_t desc, if (dtype_eq(desc->dtype, F16)) { rotary_embedding_mt_gpu_f16(desc, reinterpret_cast(t), - reinterpret_cast(pos_ids), + reinterpret_cast(pos_ids), reinterpret_cast(sin_table), reinterpret_cast(cos_table), stream); diff --git a/src/ops/rotary_embedding/operator.cc b/src/ops/rotary_embedding/operator.cc index 8f3707b2..bc2dbc09 100644 --- a/src/ops/rotary_embedding/operator.cc +++ b/src/ops/rotary_embedding/operator.cc @@ -18,7 +18,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/rotary_embedding_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/rotary_embedding_musa.h" #endif @@ -69,8 +69,8 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle, cos_table); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaCreateRoPEDescriptor((MusaHandle_t) handle, (RoPEMusaDescriptor_t *) desc_ptr, t, pos_ids, sin_table, cos_table); } #endif @@ -107,8 +107,8 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, size); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaGetRoPEWorkspaceSize((RoPEMusaDescriptor_t) desc, size); } #endif @@ -164,8 +164,8 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc, stream); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaRoPE((RoPEMusaDescriptor_t) desc, workspace, workspace_size, t, pos_ids, sin_table, cos_table, stream); } #endif @@ -200,8 +200,8 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc return macaDestroyRoPEDescriptor((RoPEMacaDescriptor_t) desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: { +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: { return musaDestroyRoPEDescriptor((RoPEMusaDescriptor_t) desc); } #endif diff --git a/src/ops/swiglu/musa/swiglu_musa.cc b/src/ops/swiglu/musa/swiglu_musa.cc index 88169be3..a1d5719b 100644 --- a/src/ops/swiglu/musa/swiglu_musa.cc +++ b/src/ops/swiglu/musa/swiglu_musa.cc @@ -34,7 +34,7 @@ infiniopStatus_t musaCreateSwiGLUDescriptor(infiniopHandle_t handle, return STATUS_BAD_PARAM; } - *desc_ptr = new SwiGLUMusaDescriptor{DevMtGpu, + *desc_ptr = new SwiGLUMusaDescriptor{DevMthreadsGpu, dtype, seq_len, di, diff --git a/src/ops/swiglu/operator.cc b/src/ops/swiglu/operator.cc index 06699b0d..3ea0bedc 100644 --- a/src/ops/swiglu/operator.cc +++ b/src/ops/swiglu/operator.cc @@ -17,7 +17,7 @@ #ifdef ENABLE_METAX_GPU #include "maca/swiglu_maca.h" #endif -#ifdef ENABLE_MT_GPU +#ifdef ENABLE_MTHREADS_GPU #include "musa/swiglu_musa.h" #endif @@ -61,8 +61,8 @@ __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle, b_desc); } #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaCreateSwiGLUDescriptor(handle, (SwiGLUMusaDescriptor_t *) desc_ptr, c_desc, a_desc, b_desc); #endif } @@ -96,8 +96,8 @@ __C infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, case DevMetaxGpu: return macaSwiGLU((SwiGLUMacaDescriptor_t) desc, c, a, b, stream); #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaSwiGLU((SwiGLUMusaDescriptor_t) desc, c, a, b, stream); #endif } @@ -127,8 +127,8 @@ __C infiniopStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t case DevMetaxGpu: return macaDestroySwiGLUDescriptor((SwiGLUMacaDescriptor_t) desc); #endif -#ifdef ENABLE_MT_GPU - case DevMtGpu: +#ifdef ENABLE_MTHREADS_GPU + case DevMthreadsGpu: return musaDestroySwiGLUDescriptor((SwiGLUMusaDescriptor_t) desc); #endif } diff --git a/xmake.lua b/xmake.lua index 4f3adfdb..f9e6f3dc 100644 --- a/xmake.lua +++ b/xmake.lua @@ -52,7 +52,7 @@ option("mthreads-gpu") set_default(false) set_showmenu(true) set_description("Enable or disable MThreads GPU kernel") - add_defines("ENABLE_MT_GPU") + add_defines("ENABLE_MTHREADS_GPU") option_end() option("sugon-dcu") @@ -181,7 +181,7 @@ end if has_config("mthreads-gpu") then - add_defines("ENABLE_MT_GPU") + add_defines("ENABLE_MTHREADS_GPU") local musa_home = os.getenv("MUSA_INSTALL_PATH") -- Add include dirs add_includedirs(musa_home .. "/include") From c9ade4dc51d03e8994ae2c9ae1e8adaba6e89157 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Mon, 10 Feb 2025 14:59:52 +0800 Subject: [PATCH 15/15] fix format and rebase dev --- include/data_type.h | 8 -- operatorspy/tests/random_sample.py | 2 +- operatorspy/tests/rotary_embedding.py | 2 +- src/devices/musa/common_musa.h | 2 +- src/devices/musa/musa_handle.cc | 4 +- src/devices/musa/musa_handle.h | 2 +- src/devices/musa/pool.h | 2 +- src/devices/musa/tensor_desc.cc | 81 ------------------- src/devices/musa/tensor_desc.h | 42 ---------- .../causal_softmax/musa/causal_softmax_musa.h | 3 +- src/ops/matmul/musa/matmul_musa.cc | 2 +- src/ops/matmul/musa/matmul_musa.h | 2 +- src/ops/matmul/musa/matmul_musa.mu | 2 +- src/ops/rearrange/musa/rearrange_musa.h | 1 + 14 files changed, 12 insertions(+), 143 deletions(-) delete mode 100644 src/devices/musa/tensor_desc.cc delete mode 100644 src/devices/musa/tensor_desc.h diff --git a/include/data_type.h b/include/data_type.h index 954a42ea..e2f24c4f 100644 --- a/include/data_type.h +++ b/include/data_type.h @@ -46,12 +46,4 @@ const static struct DataLayout F64 = {1, 1, 8, 52, 11}; // clang-format on -DT get_F16(); - -DT get_U32(); - -DT get_F32(); - -DT get_U64(); - #endif// __DATA_TYPE_H__ diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index 2c464522..85a3c681 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -94,7 +94,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ if(torch_device == 'maca'): indices = torch.zeros([1], dtype = torch.int64).to('cuda') else: - indices = torch.zeros([1], dtype = torch.uint64).to(torch_device) + indices = torch.zeros([1], dtype = torch.int64).to(torch_device) x_tensor = to_tensor(data, lib) indices_tensor = to_tensor(indices, lib) indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64 diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py index 3064e0ac..1c1122a6 100644 --- a/operatorspy/tests/rotary_embedding.py +++ b/operatorspy/tests/rotary_embedding.py @@ -245,4 +245,4 @@ def test_musa(lib, test_cases) : test_musa(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa): test_cpu(lib, test_cases) - print("\033[92mTest passed!\033[0m") \ No newline at end of file + print("\033[92mTest passed!\033[0m") diff --git a/src/devices/musa/common_musa.h b/src/devices/musa/common_musa.h index 02d97330..c42b5197 100644 --- a/src/devices/musa/common_musa.h +++ b/src/devices/musa/common_musa.h @@ -74,4 +74,4 @@ inline __device__ uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_ return res; } -#endif // __COMMON_MUSA_H__ \ No newline at end of file +#endif // __COMMON_MUSA_H__ diff --git a/src/devices/musa/musa_handle.cc b/src/devices/musa/musa_handle.cc index ab6c88ce..3a7f8174 100644 --- a/src/devices/musa/musa_handle.cc +++ b/src/devices/musa/musa_handle.cc @@ -16,7 +16,7 @@ infiniopStatus_t createMusaHandle(MusaHandle_t* handle_ptr, int device_id) { return STATUS_BAD_DEVICE; } - // set CUDA device property + // set MUSA device property musaDeviceProp prop; musaGetDeviceProperties(&prop, device_id); @@ -54,4 +54,4 @@ infiniopStatus_t deleteMusaHandle(MusaHandle_t handle_ptr) { delete handle_ptr; return STATUS_SUCCESS; -} \ No newline at end of file +} diff --git a/src/devices/musa/musa_handle.h b/src/devices/musa/musa_handle.h index 0c715b83..6de2c2d3 100644 --- a/src/devices/musa/musa_handle.h +++ b/src/devices/musa/musa_handle.h @@ -61,4 +61,4 @@ void use_mudnn(std::shared_ptr> mudnn_handles_t, int dev mudnn_handles_t->push(handle); } -#endif // __MUSA_HANDLE_H__ \ No newline at end of file +#endif // __MUSA_HANDLE_H__ diff --git a/src/devices/musa/pool.h b/src/devices/musa/pool.h index 9c6a107b..2cfb5e32 100644 --- a/src/devices/musa/pool.h +++ b/src/devices/musa/pool.h @@ -47,4 +47,4 @@ class Pool { mutable std::atomic *> _head; }; -#endif // __POOL_MUSA_H__ \ No newline at end of file +#endif // __POOL_MUSA_H__ diff --git a/src/devices/musa/tensor_desc.cc b/src/devices/musa/tensor_desc.cc deleted file mode 100644 index e706a8c6..00000000 --- a/src/devices/musa/tensor_desc.cc +++ /dev/null @@ -1,81 +0,0 @@ - -#include "tensor_desc.h" -#include -#include - -// void mudnnSqueezeTensorDim(mudnnTensorDesc_t &ldesc, mudnnTensorDesc_t &rdesc, mudnnTensorDesc_t &outdesc) { -// if (outdesc->ndims > 2) { -// if (ldesc->ndims > 2 && *ldesc->dim == 1) { -// ldesc->ndims -= 1; -// ldesc->dim = ldesc->dim+1; -// } -// if (rdesc->ndims > 2 && *rdesc->dim == 1) { -// rdesc->ndims -= 1; -// rdesc->dim = rdesc->dim+1; -// } -// } -// } - -// void mudnnCreateTensorDescriptor(mudnnTensorDesc_t *desc) { -// *desc = new mudnnTensorDesc; -// (*desc)->type = Type::FLOAT; -// (*desc)->format = Format::UNKNOWN; -// (*desc)->ndims = 0; -// (*desc)->dim = nullptr; -// (*desc)->stride = nullptr; -// (*desc)->scales = nullptr; -// (*desc)->addr = nullptr; -// } - - -// void mudnnSetTensorDescriptor(mudnnTensorDesc_t &desc, int64_t *shape, int64_t *stride, int64_t ndim, -// int64_t offset, Type type, Format format) { -// desc->type = type; -// desc->format = format; -// desc->ndims = ndim; -// desc->dim = shape; -// if (stride) { -// desc->stride = stride; -// } else { -// std::vector stride_v(ndim, 1); -// for (int64_t i = ndim - 2; i >= 0; i--) { -// stride_v[i] = shape[i + 1] * stride_v[i + 1]; -// } -// desc->stride = stride_v.data(); -// } -// } - -// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout) { -// auto dims = new int64_t(layout->ndim); -// for (uint64_t i = 0; i < layout->ndim; i++) { -// dims[i] = static_cast(layout->shape[i]); -// } -// // Cast bytes stride to element stride -// auto strides = new int64_t(layout->ndim); -// for (uint64_t i = 0; i < layout->ndim; i++) { -// strides[i] = layout->strides[i] / (layout->dt).size; -// } - -// Type type = Type::HALF; -// Format format = Format::NCHW; - -// mudnnSetTensorDescriptor(desc, dims, strides, layout->ndim, 0, type, format); -// } - -// void mudnnDestroyTensorDescriptor(mudnnTensorDesc_t &desc) { -// if (desc) { -// delete desc; -// desc = nullptr; -// } -// } - -// int mudnnCreateTensor(TensorDescriptor desc, void *data, musa::dnn::Tensor **tensor) { -// *tensor = new musa::dnn::Tensor(); - -// (*tensor)->SetAddr(data); -// // (*tensor)->SetType(musa::dnn::Tensor::Type(desc->type)); -// (*tensor)->SetFormat(musa::dnn::Tensor::Format(desc->format)); -// // (*tensor)->SetNdInfo(desc->ndims, desc->dim, desc->stride); -// (*tensor)->SetNdInfo(desc->ndims, desc->dim); -// return 0; -// } \ No newline at end of file diff --git a/src/devices/musa/tensor_desc.h b/src/devices/musa/tensor_desc.h deleted file mode 100644 index 9b896f18..00000000 --- a/src/devices/musa/tensor_desc.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef __TENSOR_DESC_H__ -#define __TENSOR_DESC_H__ - -#include "tensor.h" -#include "common_musa.h" -#include -#include -#include -#include - -// using namespace musa::dnn; - -// struct mudnnTensorDesc { -// Type type; -// Format format; -// int64_t ndims; -// int64_t *dim; -// int64_t *stride; -// int64_t *scales; -// int64_t *addr; -// }; - -// typedef mudnnTensorDesc *mudnnTensorDesc_t; - -// void mudnnCreateTensorDescriptor(mudnnTensorDesc_t *desc); - -// void mudnnSetTensorDescriptor(mudnnTensorDesc_t &desc, int64_t *shape, -// int64_t *stride, int64_t ndim, int64_t offset, -// Type type, Format format); - -// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout); - -// void mudnnDestroyTensorDescriptor(mudnnTensorDesc_t &desc); - -int mudnnCreateTensor(TensorDescriptor desc, void *data, musa::dnn::Tensor **tensor); - -// void mudnnSetTensorDescriptorFromTensorLayout(mudnnTensorDesc_t &desc, const TensorLayout *layout); - -// void mudnnSqueezeTensorDim(mudnnTensorDesc_t &ldesc, mudnnTensorDesc_t &rdesc, mudnnTensorDesc_t &outdesc); - - -#endif // __TENSOR_DESC_H__ \ No newline at end of file diff --git a/src/ops/causal_softmax/musa/causal_softmax_musa.h b/src/ops/causal_softmax/musa/causal_softmax_musa.h index 65d88423..c6f81afc 100644 --- a/src/ops/causal_softmax/musa/causal_softmax_musa.h +++ b/src/ops/causal_softmax/musa/causal_softmax_musa.h @@ -32,5 +32,4 @@ infiniopStatus_t musaCausalSoftmax(CausalSoftmaxMusaDescriptor_t desc, void *stream); infiniopStatus_t musaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMusaDescriptor_t desc); - -#endif \ No newline at end of file +#endif diff --git a/src/ops/matmul/musa/matmul_musa.cc b/src/ops/matmul/musa/matmul_musa.cc index 1b5f98fc..3256dca6 100644 --- a/src/ops/matmul/musa/matmul_musa.cc +++ b/src/ops/matmul/musa/matmul_musa.cc @@ -45,4 +45,4 @@ infiniopStatus_t musaDestroyMatmulDescriptor(MatmulMusaDescriptor_t desc) { desc->mublas_handles_t = nullptr; delete desc; return STATUS_SUCCESS; -} \ No newline at end of file +} diff --git a/src/ops/matmul/musa/matmul_musa.h b/src/ops/matmul/musa/matmul_musa.h index 617a8318..b086a494 100644 --- a/src/ops/matmul/musa/matmul_musa.h +++ b/src/ops/matmul/musa/matmul_musa.h @@ -42,4 +42,4 @@ infiniopStatus_t musaMatmul(MatmulMusaDescriptor_t desc, infiniopStatus_t musaDestroyMatmulDescriptor(MatmulMusaDescriptor_t desc); -#endif // __MUSA_MATMUL_H__ \ No newline at end of file +#endif // __MUSA_MATMUL_H__ diff --git a/src/ops/matmul/musa/matmul_musa.mu b/src/ops/matmul/musa/matmul_musa.mu index 4685beb8..b445a7b3 100644 --- a/src/ops/matmul/musa/matmul_musa.mu +++ b/src/ops/matmul/musa/matmul_musa.mu @@ -74,4 +74,4 @@ infiniopStatus_t musaMatmul(MatmulMusaDescriptor_t desc, return matmul_musa(desc, c, desc->beta, a, b, desc->alpha, stream); } return STATUS_BAD_TENSOR_DTYPE; -} \ No newline at end of file +} diff --git a/src/ops/rearrange/musa/rearrange_musa.h b/src/ops/rearrange/musa/rearrange_musa.h index cb33209a..df6ade12 100644 --- a/src/ops/rearrange/musa/rearrange_musa.h +++ b/src/ops/rearrange/musa/rearrange_musa.h @@ -27,3 +27,4 @@ infiniopStatus_t musaDestroyRearrangeDescriptor(RearrangeMusaDescriptor_t desc); void rearrange_mt_gpu(RearrangeMusaDescriptor *, void *y, void const *x, void *stream); #endif // __MUSA_REARRANGE_H__ +