diff --git a/include/device.h b/include/device.h index 701b6632..bdeb1dc9 100644 --- a/include/device.h +++ b/include/device.h @@ -2,10 +2,12 @@ #define __DEVICE_H__ enum DeviceEnum { - DevCpu, - DevNvGpu, - DevCambriconMlu, - DevAscendNpu, + DevCpu = 0, + DevNvGpu = 1, + DevCambriconMlu = 2, + DevAscendNpu = 3, + DevMetaxGpu = 4, + DevMthreadsGpu = 5, }; typedef enum DeviceEnum Device; diff --git a/operatorspy/devices.py b/operatorspy/devices.py index 4984502a..23bd2a5c 100644 --- a/operatorspy/devices.py +++ b/operatorspy/devices.py @@ -3,3 +3,5 @@ class DeviceEnum: DEVICE_CUDA = 1 DEVICE_BANG = 2 DEVICE_ASCEND = 3 + DEVICE_MACA = 4 + DEVICE_MUSA = 5 diff --git a/operatorspy/tests/causal_softmax.py b/operatorspy/tests/causal_softmax.py index 1ad304b2..623c0fac 100644 --- a/operatorspy/tests/causal_softmax.py +++ b/operatorspy/tests/causal_softmax.py @@ -111,6 +111,14 @@ def test_ascend(lib, test_cases): destroy_handle(lib, handle) +def test_maca(lib, test_cases): + device = DeviceEnum.DEVICE_MACA + handle = create_handle(lib, device) + for x_shape, x_stride in test_cases: + test(lib, handle, "cuda", x_shape, x_stride) + + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ # x_shape, x_stride @@ -151,6 +159,8 @@ def test_ascend(lib, test_cases): test_bang(lib, test_cases) if args.ascend: test_ascend(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend): + if args.maca: + test_maca(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index ac4b0f7f..ba590447 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -293,6 +293,38 @@ def test_ascend(lib, test_cases): destroy_handle(lib, handle) +def test_maca(lib, test_cases): + device = DeviceEnum.DEVICE_MACA + handle = create_handle(lib, device) + + for ( + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride, + b_stride, + c_stride, + dtype, + ) in test_cases: + test( + lib, + handle, + "cuda", + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride, + b_stride, + c_stride, + dtype, + ) + + destroy_handle(lib, handle) + if __name__ == "__main__": test_cases = [ # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype @@ -353,6 +385,8 @@ def test_ascend(lib, test_cases): test_bang(lib, test_cases) if args.ascend: test_ascend(lib, test_cases) - if not (args.cpu or args.cuda or args.bang or args.ascend): + if args.maca: + test_maca(lib, test_cases) + if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca): test_cpu(lib, test_cases) print("\033[92mTest passed!\033[0m") diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index 98a8dceb..4b0c2a10 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -83,12 +83,18 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ ) data = torch.arange(voc).float() * 0.0001 _perm = torch.randperm(voc) - data = data[_perm].to(x_dtype).to(torch_device) + if (torch_device == 'maca'): + data = data[_perm].to(x_dtype).to('cuda') + else: + data = data[_perm].to(x_dtype).to(torch_device) if(topp > 0 and topk > 1): ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu") else: ans = random_sample_0(data) - indices = torch.zeros([1], dtype=torch.int64).to(torch_device) + if(torch_device == 'maca'): + indices = 
torch.zeros([1], dtype=torch.int64).to('cuda')
+    else:
+        indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
     x_tensor = to_tensor(data, lib)
     indices_tensor = to_tensor(indices, lib)
     indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
@@ -163,7 +169,15 @@ def test_ascend(lib, test_cases):
     handle = create_handle(lib, device)
     for (voc, random_val, topp, topk, temperature) in test_cases:
         test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
-    destroy_handle(lib, handle)
+    destroy_handle(lib, handle)
+
+def test_maca(lib, test_cases):
+    device = DeviceEnum.DEVICE_MACA
+    handle = create_handle(lib, device)
+    for (voc, random_val, topp, topk, temperature) in test_cases:
+        test(lib, handle, "maca", voc, random_val, topp, topk, temperature)
+    destroy_handle(lib, handle)
+
 
 if __name__ == "__main__":
@@ -220,6 +234,8 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
+    if args.maca:
+        test_maca(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/rearrange.py b/operatorspy/tests/rearrange.py
index e9cc81b9..124fe552 100644
--- a/operatorspy/tests/rearrange.py
+++ b/operatorspy/tests/rearrange.py
@@ -108,6 +108,15 @@ def test_ascend(lib, test_cases):
         test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
     destroy_handle(lib, handle)
 
+def test_maca(lib, test_cases):
+    device = DeviceEnum.DEVICE_MACA
+    handle = create_handle(lib, device)
+    for test_case in test_cases:
+        x_shape, x_stride = test_case[0]
+        y_shape, y_stride = test_case[1]
+        test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     args = get_args()
     test_cases = [
@@ -145,4 +154,6 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
+    if args.maca:
+        test_maca(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/rms_norm.py b/operatorspy/tests/rms_norm.py
index 13cf1ccf..8176af64 100644
--- a/operatorspy/tests/rms_norm.py
+++ b/operatorspy/tests/rms_norm.py
@@ -117,6 +117,14 @@ def test_ascend(lib, test_cases):
 
     destroy_handle(lib, handle)
 
+def test_maca(lib, test_cases):
+    device = DeviceEnum.DEVICE_MACA
+    handle = create_handle(lib, device)
+    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+        test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
+
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         # y_shape, x_shape, w_shape, dtype, w_dtype
@@ -164,6 +172,8 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
+    if args.maca:
+        test_maca(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
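The tests above all repeat the same special case: MACA hardware is addressed through torch's CUDA backend, so the 'maca' tag has to be mapped to 'cuda' before any tensor is moved. A one-line helper would capture the pattern; this is a sketch only, and to_torch_device is hypothetical, not part of this diff:

    import torch

    def to_torch_device(device_tag: str) -> str:
        # 'maca' tensors live on torch's CUDA backend; every other tag maps to itself
        return 'cuda' if device_tag == 'maca' else device_tag

    # equivalent to the branches added in random_sample.py above
    data = torch.arange(32).float().to(torch.float16).to(to_torch_device('maca'))
    indices = torch.zeros([1], dtype=torch.int64).to(to_torch_device('maca'))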
diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py
index 081d2f91..b7123052 100644
--- a/operatorspy/tests/rotary_embedding.py
+++ b/operatorspy/tests/rotary_embedding.py
@@ -82,6 +82,10 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
         pos = pos.to(torch_device)
         t = t.to(torch_device)
+    elif torch_device == 'maca':
+        ans = rotary_embedding(t, posTmp, theta, "cpu").to('cuda')
+        pos = pos.to('cuda')
+        t = t.to('cuda')
     else:
         t = t.to(torch_device)
         pos = pos.to(torch_device)
@@ -172,6 +176,13 @@ def test_ascend(lib, test_cases) :
         test(lib, handle, "npu", shape, strides, dtype)
     destroy_handle(lib, handle)
 
+def test_maca(lib, test_cases) :
+    device = DeviceEnum.DEVICE_MACA
+    handle = create_handle(lib, device)
+    for shape, strides, dtype in test_cases:
+        test(lib, handle, "maca", shape, strides, dtype)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         ((1, 32, 128), None, torch.float16),
@@ -222,6 +233,8 @@ def test_ascend(lib, test_cases) :
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
+    if args.maca:
+        test_maca(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/swiglu.py b/operatorspy/tests/swiglu.py
index 7fb447a1..fcd044f1 100644
--- a/operatorspy/tests/swiglu.py
+++ b/operatorspy/tests/swiglu.py
@@ -250,6 +250,18 @@ def test_ascend(lib, test_cases):
 
     destroy_handle(lib, handle)
 
+def test_maca(lib, test_cases):
+    device = DeviceEnum.DEVICE_MACA
+    handle = create_handle(lib, device)
+
+    for shape, a_stride, b_stride, c_stride, dtype in test_cases:
+        test_out_of_place(
+            lib, handle, "cuda", shape, a_stride, b_stride, c_stride, dtype)
+        test_in_place1(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
+        test_in_place2(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
+
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
 
     test_cases = [
@@ -293,4 +305,6 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
+    if args.maca:
+        test_maca(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/test_utils.py b/operatorspy/tests/test_utils.py
index 47635b6e..68b71bc4 100644
--- a/operatorspy/tests/test_utils.py
+++ b/operatorspy/tests/test_utils.py
@@ -27,6 +27,11 @@ def get_args():
         action="store_true",
         help="Run ASCEND NPU test",
     )
+    parser.add_argument(
+        "--maca",
+        action="store_true",
+        help="Run METAX GPU test",
+    )
     return parser.parse_args()
 
diff --git a/operatorspy/utils.py b/operatorspy/utils.py
index b079d871..bb095658 100644
--- a/operatorspy/utils.py
+++ b/operatorspy/utils.py
@@ -50,6 +50,8 @@ def create_workspace(size, torch_device):
     if size == 0:
         return None
     import torch
+    if (torch_device == 'maca'):
+        return torch.zeros(size=(size,), dtype=torch.uint8, device='cuda')
     return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)
 
 def create_handle(lib, device, id=0):
diff --git a/src/devices/handle.cc b/src/devices/handle.cc
index 97126a9d..45779776 100644
--- a/src/devices/handle.cc
+++ b/src/devices/handle.cc
@@ -11,6 +11,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "./ascend/ascend_handle.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "./maca/maca_handle.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device device, int device_id) {
@@ -40,6 +43,11 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d
         case DevAscendNpu: {
             return createAscendHandle((AscendHandle_t *) handle_ptr, device_id);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return createMacaHandle((MacaHandle_t *) handle_ptr, device_id);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -68,6 +76,11 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
         case DevAscendNpu: {
             return deleteAscendHandle((AscendHandle_t) handle);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return deleteMacaHandle((MacaHandle_t) handle);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/devices/maca/common_maca.h b/src/devices/maca/common_maca.h
new file mode 100644
index 00000000..9fa82e78
--- /dev/null
+++ b/src/devices/maca/common_maca.h
@@ -0,0 +1,87 @@
+#ifndef __COMMON_MACA_H__
+#define __COMMON_MACA_H__
+
+#define MAX_THREADS_PER_BLOCK 1024
+#define MAX_WARP_PER_BLOCK 32
+#define WARP_SIZE 32
+
+#include <iostream>
+
+#define checkMacaErrorWithCode(call, errorCode)                         \
+    do {                                                                \
+        if (auto status = call; status != hcSuccess) {                  \
+            std::cerr << "MACA error: " << hcGetErrorString(status)     \
+                      << " in file " << __FILE__                        \
+                      << ", function " << __func__                      \
+                      << ", line " << __LINE__ << std::endl;            \
+            return errorCode;                                           \
+        }                                                               \
+    } while (0)
+
+#define checkMacaError(call) checkMacaErrorWithCode(call, STATUS_BAD_DEVICE)
+
+#define checkMcdnnError(call)                                           \
+    do {                                                                \
+        if (auto status = call; status != HCDNN_STATUS_SUCCESS) {       \
+            std::cerr << "MCDNN error: " << hcdnnGetErrorString(status) \
+                      << " in file " << __FILE__                        \
+                      << ", function " << __func__                      \
+                      << ", line " << __LINE__ << std::endl;            \
+            return STATUS_EXECUTION_FAILED;                             \
+        }                                                               \
+    } while (0)
+
+#include "data_type.h"
+#include <hcdnn.h>
+
+typedef struct DTMcdnnMapping {
+    DT layout;
+    hcdnnDataType_t hcdnn_type;
+} DTMcdnnMapping;
+
+// DT -> hcdnnDataType_t mapping table
+const DTMcdnnMapping dtMappings[] = {
+    {F16, HCDNN_DATA_HALF},
+    {F32, HCDNN_DATA_FLOAT},
+    {F64, HCDNN_DATA_DOUBLE},
+    {BF16, HCDNN_DATA_BFLOAT16},
+    {I8, HCDNN_DATA_INT8},
+    {I32, HCDNN_DATA_INT32},
+    {I64, HCDNN_DATA_INT64},
+    {U8, HCDNN_DATA_UINT8},
+};
+
+typedef struct DataLayoutMap {
+    int operator[](const DataLayout &layout) const {
+        for (const auto &mapping : dtMappings) {
+            if (mapping.layout == layout) {
+                return mapping.hcdnn_type;
+            }
+        }
+        return -1;
+    }
+} DTMap;
+
+constexpr DTMap dataTypeMap;
+
+// get the corresponding offset in the destination given the flat index of the source (for element mapping in shape broadcast)
+inline __device__ uint64_t getDstOffset(uint64_t flat_index, uint64_t ndim, int64_t const *src_strides, int64_t const *dst_strides) {
+    uint64_t res = 0;
+    for (uint64_t i = 0; i < ndim; ++i) {
+        res += flat_index / src_strides[i] * dst_strides[i];
+        flat_index %= src_strides[i];
+    }
+    return res;
+}
+
+// get the memory offset of the given element in a tensor given its flat index
+inline __device__ uint64_t getOffset(uint64_t flat_index, uint64_t ndim, uint64_t const *shape, int64_t const *strides) {
+    uint64_t res = 0;
+    for (long i = ndim - 1; i >= 0; --i) {
+        res += (flat_index % shape[i]) * strides[i];
+        flat_index /= shape[i];
+    }
+    return res;
+}
+
+#endif// __COMMON_MACA_H__
diff --git a/src/devices/maca/maca_handle.cc b/src/devices/maca/maca_handle.cc
new file mode 100644
index 00000000..9b1b52b8
--- /dev/null
+++ b/src/devices/maca/maca_handle.cc
@@ -0,0 +1,55 @@
+#include "maca_handle.h"
+
+infiniopStatus_t createMacaHandle(MacaHandle_t *handle_ptr, int device_id) {
+    // Check if device_id is valid
+    int device_count;
+    hcGetDeviceCount(&device_count);
+    if (device_id >= device_count) {
+        return STATUS_BAD_DEVICE;
+    }
+
+    // Create a new mcblas handle pool
+    auto pool = std::make_shared<Pool<hcblasHandle_t>>();
+    if (hcSetDevice(device_id) != hcSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    hcblasHandle_t handle;
+    hcblasCreate(&handle);
+    pool->push(std::move(handle));
+
+    // create an mcdnn handle pool
+    auto mcdnn_pool = std::make_shared<Pool<hcdnnHandle_t>>();
+    hcdnnHandle_t mcdnn_handle;
+    checkMcdnnError(hcdnnCreate(&mcdnn_handle));
+    mcdnn_pool->push(std::move(mcdnn_handle));
+
+    // set MACA device property
+    hcDeviceProp_t prop;
+    hcGetDeviceProperties(&prop, device_id);
+
+    // set device compute capability numbers
+    int capability_major;
+    int capability_minor;
+    hcDeviceGetAttribute(&capability_major, hcDeviceAttributeComputeCapabilityMajor, device_id);
+    hcDeviceGetAttribute(&capability_minor, hcDeviceAttributeComputeCapabilityMinor, device_id);
+
+    *handle_ptr = new MacaContext{
+        DevMetaxGpu,
+        device_id,
+        std::move(pool),
+        std::move(mcdnn_pool),
+        std::move(prop),
+        capability_major,
+        capability_minor,
+    };
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t deleteMacaHandle(MacaHandle_t handle_ptr) {
+    handle_ptr->mcblas_handles_t = nullptr;
+    handle_ptr->mcdnn_handles_t = nullptr;
+    delete handle_ptr;
+
+    return STATUS_SUCCESS;
+}
diff --git a/src/devices/maca/maca_handle.h b/src/devices/maca/maca_handle.h
new file mode 100644
index 00000000..41485099
--- /dev/null
+++ b/src/devices/maca/maca_handle.h
@@ -0,0 +1,52 @@
+#ifndef MACA_HANDLE_H
+#define MACA_HANDLE_H
+
+#include "../pool.h"
+#include "common_maca.h"
+#include "device.h"
+#include "status.h"
+#include <hcblas.h>
+#include <hcdnn.h>
+#include <memory>
+
+struct MacaContext {
+    Device device;
+    int device_id;
+    std::shared_ptr<Pool<hcblasHandle_t>> mcblas_handles_t;
+    std::shared_ptr<Pool<hcdnnHandle_t>> mcdnn_handles_t;
+    hcDeviceProp_t prop;
+    int compute_capability_major;
+    int compute_capability_minor;
+};
+typedef struct MacaContext *MacaHandle_t;
+
+infiniopStatus_t createMacaHandle(MacaHandle_t *handle_ptr, int device_id);
+
+infiniopStatus_t deleteMacaHandle(MacaHandle_t handle_ptr);
+
+template <typename T>
+void use_mcblas(std::shared_ptr<Pool<hcblasHandle_t>> mcblas_handles_t, int device_id, hcStream_t stream, T const &f) {
+    auto handle = mcblas_handles_t->pop();
+    if (!handle) {
+        hcSetDevice(device_id);
+        hcblasCreate(&(*handle));
+    }
+    hcblasSetStream(*handle, (hcStream_t) stream);
+    f(*handle);
+    mcblas_handles_t->push(std::move(*handle));
+}
+
+template <typename T>
+hcdnnStatus_t use_mcdnn(std::shared_ptr<Pool<hcdnnHandle_t>> mcdnn_handles_t, int device_id, hcStream_t stream, T const &f) {
+    auto handle = mcdnn_handles_t->pop();
+    if (!handle) {
+        hcSetDevice(device_id);
+        hcdnnCreate(&(*handle));
+    }
+    hcdnnSetStream(*handle, stream);
+    hcdnnStatus_t status = f(*handle);
+    mcdnn_handles_t->push(std::move(*handle));
+    return status;
+}
+
+#endif
diff --git a/src/ops/causal_softmax/maca/causal_softmax_maca.cc b/src/ops/causal_softmax/maca/causal_softmax_maca.cc
new file mode 100644
index 00000000..5a3803e7
--- /dev/null
+++ b/src/ops/causal_softmax/maca/causal_softmax_maca.cc
@@ -0,0 +1,55 @@
+#include "causal_softmax_maca.h"
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+
+infiniopStatus_t macaCreateCausalSoftmaxDescriptor(MacaHandle_t handle,
+                                                   CausalSoftmaxMacaDescriptor_t *desc_ptr,
+                                                   infiniopTensorDescriptor_t y)
{
+    uint64_t ndim = y->ndim;
+    // TODO: only support 2d or 3d tensor
+    if (ndim != 2 && ndim != 3) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    if (!dtype_eq(y->dt, F16)) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    uint64_t total_seq_len = y->shape[ndim - 1];
+    uint64_t seq_len = y->shape[ndim - 2];
+    uint64_t batch_size = 1;
+    uint64_t stride_b = 0;
+    uint64_t stride_i = y->strides[ndim - 2];
+    uint64_t stride_j = y->strides[ndim - 1];
+    if (stride_j != 1) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    for (int i = 0; i < ndim - 2; i++) {
+        batch_size *= y->shape[i];
+    }
+    if (ndim == 3)
+        stride_b = y->strides[ndim - 3];
+    unsigned int max_items_per_thread = ROUND_UP_DIV(total_seq_len, MAX_THREADS_PER_BLOCK);
+
+    *desc_ptr = new CausalSoftmaxMacaDescriptor{
+        handle->device,
+        handle->device_id,
+        y->dt,
+        batch_size,
+        stride_b,
+        seq_len,
+        stride_i,
+        total_seq_len,
+        stride_j,
+        max_items_per_thread};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMacaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMacaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/causal_softmax/maca/causal_softmax_maca.h b/src/ops/causal_softmax/maca/causal_softmax_maca.h
new file mode 100644
index 00000000..daa198b7
--- /dev/null
+++ b/src/ops/causal_softmax/maca/causal_softmax_maca.h
@@ -0,0 +1,36 @@
+#ifndef __MACA_CAUSAL_SOFTMAX_H__
+#define __MACA_CAUSAL_SOFTMAX_H__
+
+#include "../../../devices/maca/maca_handle.h"
+#include "operators.h"
+
+struct CausalSoftmaxMacaDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    uint64_t batch_size;
+    uint64_t stride_b;
+    uint64_t seq_len;
+    uint64_t stride_i;
+    uint64_t total_seq_len;
+    uint64_t stride_j;
+    unsigned int max_items_per_thread;
+};
+
+typedef struct CausalSoftmaxMacaDescriptor *CausalSoftmaxMacaDescriptor_t;
+
+infiniopStatus_t macaCreateCausalSoftmaxDescriptor(MacaHandle_t handle,
+                                                   CausalSoftmaxMacaDescriptor_t *desc_ptr,
+                                                   infiniopTensorDescriptor_t y_desc);
+
+infiniopStatus_t macaGetCausalSoftmaxWorkspaceSize(CausalSoftmaxMacaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t macaCausalSoftmax(CausalSoftmaxMacaDescriptor_t desc,
+                                   void *workspace,
+                                   uint64_t workspace_size,
+                                   void *data,
+                                   void *stream);
+
+infiniopStatus_t macaDestroyCausalSoftmaxDescriptor(CausalSoftmaxMacaDescriptor_t desc);
+
+#endif
diff --git a/src/ops/causal_softmax/maca/causal_softmax_maca.maca b/src/ops/causal_softmax/maca/causal_softmax_maca.maca
new file mode 100644
index 00000000..94b884e8
--- /dev/null
+++ b/src/ops/causal_softmax/maca/causal_softmax_maca.maca
@@ -0,0 +1,262 @@
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include "causal_softmax_maca.h"
+#include <cub/block/block_reduce.cuh>
+
+struct AttentionCausualMask {
+    __forceinline__ __device__ bool
+    operator()(int tok_id, int seq_len,
+               int pos_id, int total_seq_len) {
+        // tok_id ↓ |<-total_seq_len->|
+        //       0 | * * * ... *     |
+        //       1 | * * * ... * *   |
+        //       2 | * * * ... * * * |
+        // seq_len: 3  pos_id->
+        return total_seq_len + tok_id >= pos_id + seq_len;
+    }
+};
+
+template<unsigned int BLOCK_SIZE, class Tdata, class Tmask>
+static __device__ void block_padding(
+    Tdata *__restrict__ att,
+    Tmask mask,
+    unsigned int const token_idx,
+    unsigned int const seq_len) {
+    auto att_idx = threadIdx.x;
+    auto total_seq_len = blockDim.x;
+    auto thread_data = mask(token_idx, seq_len, att_idx, total_seq_len)
+                           ? float(att[att_idx])
+                           : -__FLT_MAX__;
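+    // positions that fail the causal mask contribute -FLT_MAX, so after the
+    // exp() below they become zero and drop out of the normalizing sum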
+
+    using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockOp::TempStorage temp_storage;
+    auto block_op = BlockOp(temp_storage);
+
+    __shared__ float max;
+    {
+        auto acc = block_op.Reduce(thread_data, cub::Max(), total_seq_len);
+        if (threadIdx.x == 0) { max = acc; }
+    }
+    __syncthreads();
+
+    __shared__ float mean;
+    {
+        auto acc = block_op.Sum(thread_data = expf(thread_data - max), total_seq_len);
+        if (threadIdx.x == 0) { mean = fdividef(1, acc); }
+    }
+    __syncthreads();
+
+    att[att_idx] = Tdata(thread_data * mean);
+}
+
+template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata, class Tmask>
+static __device__ void block_folding(
+    Tdata *__restrict__ att,
+    Tmask mask,
+    unsigned int const token_idx,
+    unsigned int const seq_len,
+    unsigned int const total_seq_len) {
+
+    auto local = (total_seq_len + blockDim.x - 1) / blockDim.x;
+
+    auto thread_offset = threadIdx.x * local;
+    att += thread_offset;
+
+    float thread_data[ITEMS_PER_THREAD], thread_max = -__FLT_MAX__, thread_sum = 0;
+    for (unsigned int i = 0; i < local; ++i) {
+        auto att_idx = thread_offset + i;
+        thread_data[i] = att_idx < total_seq_len && mask(token_idx, seq_len, att_idx, total_seq_len)
+                             ? float(att[i])
+                             : -__FLT_MAX__;
+        thread_max = cub::Max()(thread_max, thread_data[i]);
+    }
+
+    using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockOp::TempStorage temp_storage;
+    auto block_op = BlockOp(temp_storage);
+
+    __shared__ float max;
+    {
+        auto acc = block_op.Reduce(thread_max, cub::Max());
+        if (threadIdx.x == 0) { max = acc; }
+    }
+    __syncthreads();
+
+    __shared__ float mean;
+    {
+        for (unsigned int i = 0; i < local; ++i) {
+            thread_data[i] = expf(thread_data[i] - max);
+            thread_sum += thread_data[i];
+        }
+        auto acc = block_op.Sum(thread_sum);
+        if (threadIdx.x == 0) { mean = fdividef(1, acc); }
+    }
+    __syncthreads();
+
+    for (unsigned int i = 0; i < local; ++i) {
+        if (auto att_idx = thread_offset + i; att_idx < total_seq_len) {
+            att[i] = Tdata(thread_data[i] * mean);
+        }
+    }
+}
+
+// assert BLOCK_SIZE >= blockDim.x
+template<unsigned int BLOCK_SIZE, class Tdata, class Tmask>
+static __forceinline__ __device__ void padding(
+    Tdata *__restrict__ att,
+    Tmask mask,
+    int const stride_x,
+    int const stride_y,
+    int const stride_z) {
+    auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y,
+         token_idx = blockIdx.y,
+         seq_len = gridDim.y;
+    block_padding<BLOCK_SIZE>(
+        att + offset, mask, token_idx, seq_len);
+}
+
+template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata, class Tmask>
+static __forceinline__ __device__ void folding(
+    Tdata *__restrict__ att,
+    Tmask mask,
+    unsigned int const total_seq_len,
+    int const stride_x,
+    int const stride_y,
+    int const stride_z) {
+    auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y,
+         token_idx = blockIdx.y,
+         seq_len = gridDim.y;
+    block_folding<BLOCK_SIZE, ITEMS_PER_THREAD>(
+        att + offset, mask, token_idx, seq_len, total_seq_len);
+}
+
+template<unsigned int BLOCK_SIZE, class Tdata>
+__global__ void fused_softmax_padding(
+    Tdata *__restrict__ att,
+    unsigned int const stride_x,
+    unsigned int const stride_y,
+    unsigned int const stride_z) {
+
+    padding<BLOCK_SIZE>(att, AttentionCausualMask(), stride_x, stride_y, stride_z);
+}
+
+template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata>
+__global__ void fused_softmax_folding(
+    Tdata *__restrict__ att,
+    unsigned int const stride_x,
+    unsigned int const stride_y,
+    unsigned int const stride_z,
+    unsigned int const total_seq_len) {
+    {
+        folding<BLOCK_SIZE, ITEMS_PER_THREAD>(att, AttentionCausualMask(), total_seq_len, stride_x, stride_y, stride_z);
+    }
+}
+
+template<unsigned int BLOCK_SIZE, class Tdata>
+__global__ void fused_softmax_standard(
+    Tdata *__restrict__ att_,
+    unsigned int const stride_x,
+    unsigned int const stride_y,
+    unsigned int const stride_z,
+    unsigned int const total_seq_len) {
+    {
+        auto offset = blockIdx.x * stride_x + blockIdx.y * stride_y,
+             token_idx = blockIdx.y,
+             seq_len = gridDim.y;
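+        // one block per attention row: blockIdx.x selects the batch, blockIdx.y the token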
+
+        auto att = att_ + offset;
+        auto att_idx = threadIdx.x;
+
+        float partial;
+        __shared__ float max_;
+        __shared__ float sum_;
+        using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        auto block_op = BlockOp(temp_storage);
+
+        // Partial max
+        partial = -__FLT_MAX__;
+        for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) {
+            if (i <= total_seq_len - seq_len + token_idx) {
+                partial = max(partial, float(att[i]));
+            }
+        }
+        __syncthreads();
+        // Block reduce max
+        {
+            auto acc = block_op.Reduce(partial, cub::Max());
+            if (threadIdx.x == 0) { max_ = acc; }
+        }
+        __syncthreads();
+
+        // Partial sum
+        partial = 0.;
+        for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) {
+            if (i <= total_seq_len - seq_len + token_idx) {
+                float e = expf(float(att[i]) - max_);
+                partial += e;
+            }
+        }
+        __syncthreads();
+
+        // Block reduce sum
+        {
+            auto acc = block_op.Reduce(partial, cub::Sum());
+            if (threadIdx.x == 0) { sum_ = acc; }
+        }
+        __syncthreads();
+
+        // Softmax
+        for (unsigned int i = att_idx; i < total_seq_len; i += BLOCK_SIZE) {
+            if (i <= total_seq_len - seq_len + token_idx) {
+                float e = expf(float(att[i]) - max_);
+                att[i] = e / sum_;
+            } else {
+                att[i] = half(0);
+            }
+        }
+    }
+}
+
+
+void causal_softmax_mc_gpu_f16(CausalSoftmaxMacaDescriptor_t desc, void *y, void *stream) {
+    uint64_t total_seq_len = desc->total_seq_len;
+    uint64_t seq_len = desc->seq_len;
+    uint64_t batch_size = desc->batch_size;
+    uint64_t stride_x = desc->stride_b;
+    uint64_t stride_y = desc->stride_i;
+    uint64_t stride_z = desc->stride_j;// strides are in elements, not bytes
+    unsigned int max_items_per_thread = desc->max_items_per_thread;
+
+    dim3 grid(batch_size, seq_len);
+
+    if (max_items_per_thread == 1) {
+        fused_softmax_padding<MAX_THREADS_PER_BLOCK>
+            <<<grid, total_seq_len, 0, (hcStream_t) stream>>>((half *) (y), stride_x, stride_y, stride_z);
+    } else if (max_items_per_thread <= 16) {
+        fused_softmax_folding<MAX_THREADS_PER_BLOCK, 16>
+            <<<grid, MAX_THREADS_PER_BLOCK, 0, (hcStream_t) stream>>>((half *) (y), stride_x, stride_y, stride_z, total_seq_len);
+    } else {
+        fused_softmax_standard<MAX_THREADS_PER_BLOCK>
+            <<<grid, MAX_THREADS_PER_BLOCK, 0, (hcStream_t) stream>>>((half *) (y), stride_x, stride_y, stride_z, total_seq_len);
+    }
+}
+
+infiniopStatus_t macaCausalSoftmax(CausalSoftmaxMacaDescriptor_t desc,
+                                   void *workspace,
+                                   uint64_t workspace_size,
+                                   void *data,
+                                   void *stream) {
+    if (hcSetDevice(desc->device_id) != hcSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (dtype_eq(desc->dtype, F16)) {
+        causal_softmax_mc_gpu_f16(desc, data, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/causal_softmax/operator.cc b/src/ops/causal_softmax/operator.cc
index ef10919f..c9d87dda 100644
--- a/src/ops/causal_softmax/operator.cc
+++ b/src/ops/causal_softmax/operator.cc
@@ -18,6 +18,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/causal_softmax_aclnn.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/causal_softmax_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
     infiniopHandle_t handle,
@@ -44,6 +47,11 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
         case DevAscendNpu: {
             return aclnnCreateCausalSoftmaxDescriptor((AscendHandle_t) handle, (CausalSoftmaxAclnnDescriptor_t *) desc_ptr, y_desc);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateCausalSoftmaxDescriptor((MacaHandle_t) handle, (CausalSoftmaxMacaDescriptor_t *) desc_ptr, y_desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -72,6 +80,11 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax
         case DevAscendNpu: {
             return aclnnGetCausalSoftmaxWorkspaceSize((CausalSoftmaxAclnnDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaGetCausalSoftmaxWorkspaceSize((CausalSoftmaxMacaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -99,6 +112,11 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des
         case DevAscendNpu: {
             return aclnnCausalSoftmax((CausalSoftmaxAclnnDescriptor_t) desc, workspace, workspace_size, data, stream);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCausalSoftmax((CausalSoftmaxMacaDescriptor_t) desc, workspace, workspace_size, data, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -126,6 +144,11 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma
         case DevAscendNpu: {
             return aclnnDestroyCausalSoftmaxDescriptor((CausalSoftmaxAclnnDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaDestroyCausalSoftmaxDescriptor((CausalSoftmaxMacaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/matmul/maca/matmul_maca.cc b/src/ops/matmul/maca/matmul_maca.cc
new file mode 100644
index 00000000..2d6658f7
--- /dev/null
+++ b/src/ops/matmul/maca/matmul_maca.cc
@@ -0,0 +1,44 @@
+#include "matmul_maca.h"
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+
+infiniopStatus_t macaCreateMatmulDescriptor(MacaHandle_t handle,
+                                            MatmulMacaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            float alpha,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc,
+                                            float beta) {
+    DT dtype = c_desc->dt;
+
+    if (dtype != F16 && dtype != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    infiniopStatus_t status = STATUS_EXECUTION_FAILED;
+    auto info = MatmulInfo(c_desc, a_desc, b_desc, &status);
+    if (status != STATUS_SUCCESS) {
+        return status;
+    }
+
+    *desc_ptr = new MatmulMacaDescriptor{
+        DevMetaxGpu,
+        dtype,
+        handle->device_id,
+        info,
+        alpha,
+        beta,
+        handle->mcblas_handles_t};
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaGetMatmulWorkspaceSize(MatmulMacaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaDestroyMatmulDescriptor(MatmulMacaDescriptor_t desc) {
+    desc->mcblas_handles_t = nullptr;
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/matmul/maca/matmul_maca.h b/src/ops/matmul/maca/matmul_maca.h
new file mode 100644
index 00000000..2264cdc4
--- /dev/null
+++ b/src/ops/matmul/maca/matmul_maca.h
@@ -0,0 +1,43 @@
+#ifndef __MACA_MATMUL_H__
+#define __MACA_MATMUL_H__
+
+#include "../../../devices/maca/maca_handle.h"
+#include "../blas.h"
+#include "operators.h"
+#include <memory>
+
+typedef struct MatmulMacaDescriptor {
+    Device device;
+    DT dtype;
+    int device_id;
+    MatmulInfo info;
+    float alpha;
+    float beta;
+    std::shared_ptr<Pool<hcblasHandle_t>> mcblas_handles_t;
+} MatmulMacaDescriptor;
+
+typedef struct MatmulMacaDescriptor *MatmulMacaDescriptor_t;
+
+infiniopStatus_t macaCreateMatmulDescriptor(MacaHandle_t handle,
+                                            MatmulMacaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            float alpha,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc,
+                                            float beta);
+
+infiniopStatus_t macaGetMatmulWorkspaceSize(MatmulMacaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t macaMatmul(MatmulMacaDescriptor_t desc,
+                            void *workspace,
+                            uint64_t workspace_size,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream);
+
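+// Destroying a descriptor drops only its reference to the pooled hcblas
+// handles; the pool itself is shared with the owning MacaHandle via shared_ptr.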
+infiniopStatus_t macaDestroyMatmulDescriptor(MatmulMacaDescriptor_t desc);
+
+#endif// __MACA_MATMUL_H__
diff --git a/src/ops/matmul/maca/matmul_maca.maca b/src/ops/matmul/maca/matmul_maca.maca
new file mode 100644
index 00000000..d944c85a
--- /dev/null
+++ b/src/ops/matmul/maca/matmul_maca.maca
@@ -0,0 +1,77 @@
+#include "../../../devices/maca/maca_handle.h"
+#include "../../utils.h"
+#include "../blas.h"
+#include "matmul_maca.h"
+#include <hcblas.h>
+#include
+
+template <typename Tdata>
+infiniopStatus_t matmul_maca(MatmulMacaDescriptor_t desc, void *c, float beta, void const *a, void const *b, float alpha, void *stream) {
+    auto info = desc->info;
+
+    if (info.is_transed) {
+        std::swap(a, b);
+    }
+
+    Tdata alpha_, beta_;
+    hpccDataType a_type, b_type, c_type;
+    hcblasComputeType_t compute_type;
+
+    if constexpr (std::is_same<Tdata, half>::value) {
+        alpha_ = __float2half(alpha);
+        beta_ = __float2half(beta);
+        a_type = b_type = c_type = HPCC_R_16F;
+        compute_type = HCBLAS_COMPUTE_16F;
+    } else {
+        alpha_ = alpha;
+        beta_ = beta;
+        a_type = b_type = c_type = HPCC_R_32F;
+        compute_type = HCBLAS_COMPUTE_32F_FAST_TF32;
+    }
+
+    auto op_a = info.a_matrix.row_stride == 1 ? HCBLAS_OP_N : HCBLAS_OP_T;
+    auto op_b = info.b_matrix.row_stride == 1 ? HCBLAS_OP_N : HCBLAS_OP_T;
+
+    use_mcblas(desc->mcblas_handles_t, desc->device_id, (hcStream_t) stream,
+               [&](hcblasHandle_t handle) { hcblasGemmStridedBatchedEx(
+                                                handle,
+                                                op_a,
+                                                op_b,
+                                                info.m,
+                                                info.n,
+                                                info.k,
+                                                &alpha_,
+                                                a,
+                                                a_type,
+                                                info.a_matrix.ld(),
+                                                info.a_matrix.stride,
+                                                b,
+                                                b_type,
+                                                info.b_matrix.ld(),
+                                                info.b_matrix.stride,
+                                                &beta_,
+                                                c,
+                                                c_type,
+                                                info.c_matrix.ld(),
+                                                info.c_matrix.stride,
+                                                info.batch,
+                                                compute_type,
+                                                HCBLAS_GEMM_DEFAULT_TENSOR_OP); });
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaMatmul(MatmulMacaDescriptor_t desc,
+                            void *workspace,
+                            uint64_t workspace_size,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream) {
+    if (desc->dtype == F16) {
+        return matmul_maca<half>(desc, c, desc->beta, a, b, desc->alpha, stream);
+    }
+    if (desc->dtype == F32) {
+        return matmul_maca<float>(desc, c, desc->beta, a, b, desc->alpha, stream);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/matmul/operator.cc b/src/ops/matmul/operator.cc
index 444168b6..14748b99 100644
--- a/src/ops/matmul/operator.cc
+++ b/src/ops/matmul/operator.cc
@@ -14,6 +14,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/matmul_aclnn.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/matmul_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle,
                                                     infiniopMatmulDescriptor_t *desc_ptr,
@@ -48,6 +51,11 @@ __C infiniopStatus_t infiniopCreateMatmulDescriptor(infiniopHandle_t handle,
                                            beta,
                                            1);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateMatmulDescriptor((MacaHandle_t) handle, (MatmulMacaDescriptor_t *) desc_ptr, c_desc, alpha, a_desc, b_desc, beta);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -75,6 +83,11 @@ __C infiniopStatus_t infiniopGetMatmulWorkspaceSize(infiniopMatmulDescriptor_t d
 
             return aclnnGetMatmulWorkspaceSize((MatmulAclnnDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaGetMatmulWorkspaceSize((MatmulMacaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -104,6 +117,11 @@ __C infiniopStatus_t infiniopMatmul(infiniopMatmulDescriptor_t desc, void *works
                           a,
                           b,
                           stream);
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaMatmul((MatmulMacaDescriptor_t) desc, workspace, workspace_size, c, a, b, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -130,6 +148,11 @@ __C infiniopStatus_t infiniopDestroyMatmulDescriptor(infiniopMatmulDescriptor_t
         case DevAscendNpu: {
             return aclnnDestroyMatmulDescriptor((MatmulAclnnDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaDestroyMatmulDescriptor((MatmulMacaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/random_sample/maca/random_sample_maca.cc b/src/ops/random_sample/maca/random_sample_maca.cc
new file mode 100644
index 00000000..1cb0fe74
--- /dev/null
+++ b/src/ops/random_sample/maca/random_sample_maca.cc
@@ -0,0 +1,37 @@
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include "random_sample_maca.h"
+
+infiniopStatus_t macaCreateRandomSampleDescriptor(MacaHandle_t handle,
+                                                  RandomSampleMacaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
+                                                  infiniopTensorDescriptor_t probs) {
+    if (probs->ndim != 1) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    if (!dtype_eq(result->dt, U64))
+        return STATUS_BAD_TENSOR_DTYPE;
+    int voc = probs->shape[0];
+    int rLength = result->shape[0];
+    if (result->ndim != 1 && rLength != 1) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    *desc_ptr = new RandomSampleMacaDescriptor{
+        handle->device,
+        handle->device_id,
+        probs->dt,
+        voc,
+        result->dt,
+        rLength};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaGetRandomSampleWorkspaceSize(RandomSampleMacaDescriptor_t desc, uint64_t *size) {
+    *size = desc->voc * (2 * sizeof(uint64_t) + sizeof(desc->dtype));
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaDestroyRandomSampleDescriptor(RandomSampleMacaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/random_sample/maca/random_sample_maca.h b/src/ops/random_sample/maca/random_sample_maca.h
new file mode 100644
index 00000000..3cf1ab59
--- /dev/null
+++ b/src/ops/random_sample/maca/random_sample_maca.h
@@ -0,0 +1,38 @@
+#ifndef __MACA_RANDOM_SAMPLE_H__
+#define __MACA_RANDOM_SAMPLE_H__
+
+#include "../../../devices/maca/maca_handle.h"
+#include "operators.h"
+
+struct RandomSampleMacaDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    int voc;
+    DT rDtype;
+    int rLength;
+};
+
+typedef struct RandomSampleMacaDescriptor *RandomSampleMacaDescriptor_t;
+
+infiniopStatus_t macaCreateRandomSampleDescriptor(MacaHandle_t handle,
+                                                  RandomSampleMacaDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result,
+                                                  infiniopTensorDescriptor_t probs);
+
+infiniopStatus_t macaGetRandomSampleWorkspaceSize(RandomSampleMacaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t macaRandomSample(RandomSampleMacaDescriptor_t desc,
+                                  void *workspace,
+                                  uint64_t workspace_size,
+                                  void *result,
+                                  void const *probs,
+                                  float random_val,
+                                  float topp,
+                                  int topk,
+                                  float temperature,
+                                  void *stream);
+
+infiniopStatus_t macaDestroyRandomSampleDescriptor(RandomSampleMacaDescriptor_t desc);
+
+
+#endif
diff --git a/src/ops/random_sample/maca/random_sample_maca.maca b/src/ops/random_sample/maca/random_sample_maca.maca
new file mode 100644
index 00000000..310343fb
--- /dev/null
+++ b/src/ops/random_sample/maca/random_sample_maca.maca
@@ -0,0 +1,181 @@
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include "random_sample_maca.h"
+#include <cub/cub.cuh>
+#include
+
+template <class T, int BLOCK_DIM>
+__global__ void softmax(
+    T *val_out,
+    int topk,
+    float temperature, int voc) {
+    float sum_s = 0.0f;
+    for (int i = threadIdx.x; i < topk; i += BLOCK_DIM) {
+        sum_s += __expf(static_cast<float>(val_out[i] - val_out[0]) / temperature);
+    }
+    __shared__ float sum_inverse_total;
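+    // sum_s holds this thread's share; BlockReduce combines the shares into the full normalizer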
+
+    typedef cub::BlockReduce<float, BLOCK_DIM> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    float block_sum = BlockReduce(temp_storage).Reduce(sum_s, cub::Sum());
+    if (threadIdx.x == 0) {
+        sum_inverse_total = __fdividef(1.0F, block_sum);// fast reciprocal of the normalizer
+    }
+
+    __syncthreads();
+    int tid = threadIdx.x + blockIdx.x * blockDim.x;
+    if (tid < topk) {
+        val_out[tid] = static_cast<T>(__expf(static_cast<float>(val_out[tid] - val_out[0]) / temperature) * sum_inverse_total);
+    }
+}
+
+__global__ void index(uint64_t *key_in, int voc) {
+    int ind = threadIdx.x + blockIdx.x * blockDim.x;
+    if (ind < voc) {
+        key_in[ind] = static_cast<uint64_t>(ind);
+    }
+}
+template <class T>
+__global__ void random_sample_kernel(uint64_t *result,
+                                     T *val_out,
+                                     float random_val,
+                                     float topp,
+                                     int topk,
+                                     uint64_t *key_out) {
+    int end = 0;
+    for (end = 0; end < topk; end++) {
+        if (val_out[end] >= static_cast<T>(topp)) {
+            break;
+        }
+    }
+    if (end < topk - 1) {
+        end += 1;
+    } else {
+        end = topk;
+    }
+
+    random_val *= static_cast<float>(val_out[end - 1]);
+    for (int i = 0; i < end; i++) {
+        if (random_val < static_cast<float>(val_out[i])) {
+            result[0] = key_out[i];
+            break;
+        }
+    }
+}
+template <class T, class I>
+void sort_pairs_descending(
+    void *workspace, size_t &size_radix_sort,
+    T const *val_in, T *val_out,
+    I *key_in, I *key_out,
+    int voc, hcStream_t stream) {
+    cub::DeviceRadixSort::SortPairsDescending(
+        workspace, size_radix_sort,
+        val_in, val_out,
+        key_in, key_out,
+        voc, 0, sizeof(T) * 8, stream);
+}
+template <class T>
+void inclusive_sum(
+    void *workspace, size_t &size_scan,
+    T *data, int voc,
+    hcStream_t stream) {
+    cub::DeviceScan::InclusiveSum(
+        workspace, size_scan,
+        data, data, voc,
+        stream);
+}
+template <class T, class I>
+void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
+                             int voc, hcStream_t stream) {
+
+
+    sort_pairs_descending<T, I>(nullptr, size_radix_sort,
+                                nullptr, nullptr,
+                                nullptr, nullptr,
+                                voc, stream);
+
+    inclusive_sum<T>(
+        nullptr, size_scan,
+        nullptr, voc,
+        stream);
+}
+__global__ void random_sample_kernel(uint64_t *result,
+                                     uint64_t *key_out) {
+    result[0] = key_out[0];
+}
+void random_sample_mc_gpu_f16(RandomSampleMacaDescriptor_t desc, void *workspace, void *result,
+                              void const *probs,
+                              float random_val,
+                              float topp,
+                              int topk,
+                              float temperature,
+                              void *stream) {
+    int voc = desc->voc;
+    // stage the sort: val_out receives sorted probabilities, key_in/key_out the token indices
+    char *origin = reinterpret_cast<char *>(workspace);
+    char *keyTmp = origin + voc * sizeof(half);
+    half *val_out = (half *) origin;
+
+    uint64_t *key_in = (uint64_t *) keyTmp;
+    uint64_t *key_out = key_in + voc;
+
+    index<<<(voc + 1023) / 1024, 1024, 0, (hcStream_t) stream>>>(key_in, voc);
+    // query how much extra workspace the radix sort and the scan need
+    size_t size_radix_sort;
+    size_t size_scan;
+    random_sample_workspace<half, uint64_t>(size_radix_sort, size_scan,
+                                            voc, (hcStream_t) stream);
+    void *workspace_extra;
+    hcMalloc(&workspace_extra, size_radix_sort + size_scan);
+    sort_pairs_descending(
+        workspace_extra, size_radix_sort,
+        (half *) probs, val_out,
+        key_in, key_out,
+        voc, (hcStream_t) stream);// sorted values land in val_out, matching indices in key_out
+    // sorting done; apply the softmax transform over the top-k prefix
+    if (topp > 0 && topk > 1) {
+        int BLOCK_DIM = 1024;
+        int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
+        softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (hcStream_t) stream>>>(val_out, topk,
+                                                                               temperature, voc);
+
+
+        inclusive_sum(
+            workspace_extra, size_scan,
+            val_out, voc,
+            (hcStream_t) stream);// inclusive prefix sum: val_out becomes cumulative probabilities
+        random_sample_kernel<<<1, 1, 0, (hcStream_t) stream>>>((uint64_t *) result,
+                                                               val_out,
+                                                               random_val,
+                                                               topp,
+                                                               topk,
+                                                               key_out);
+
+    } else {
+        random_sample_kernel<<<1, 1, 0, (hcStream_t) stream>>>((uint64_t *) result,
+                                                               key_out);
+    }
+    hcFree(workspace_extra);
+}
+
+infiniopStatus_t macaRandomSample(RandomSampleMacaDescriptor_t desc,
+                                  void *workspace,
+                                  uint64_t workspace_size,
+                                  void *result,
+                                  void const *probs,
+                                  float random_val,
+                                  float topp,
+                                  int topk,
+                                  float temperature,
+                                  void *stream) {
+    if (hcSetDevice(desc->device_id) != hcSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (dtype_eq(desc->dtype, F16)) {
+        random_sample_mc_gpu_f16(desc, workspace, result, probs, random_val, topp, topk, temperature, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/random_sample/operator.cc b/src/ops/random_sample/operator.cc
index ff241e77..b9cf3ded 100644
--- a/src/ops/random_sample/operator.cc
+++ b/src/ops/random_sample/operator.cc
@@ -14,6 +14,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/random_sample.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/random_sample_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandle_t handle, infiniopRandomSampleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t result, infiniopTensorDescriptor_t probs) {
     switch (handle->device) {
@@ -37,6 +40,13 @@ __C infiniopStatus_t infiniopCreateRandomSampleDescriptor(infiniopHandl
             return ascendCreateRandomSampleDescriptor((AscendHandle_t) handle, (RandomSampleAscendDescriptor_t *) desc_ptr, result,
                                                       probs);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateRandomSampleDescriptor((MacaHandle_t) handle,
+                                                    (RandomSampleMacaDescriptor_t *) desc_ptr, result,
+                                                    probs);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -64,6 +74,11 @@ __C infiniopStatus_t infiniopGetRandomSampleWorkspaceSize(infiniopRandomSampleDe
         case DevAscendNpu: {
             return ascendGetRandomSampleWorkspaceSize((RandomSampleAscendDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaGetRandomSampleWorkspaceSize((RandomSampleMacaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -97,6 +112,11 @@ __C infiniopStatus_t infiniopRandomSample(infiniopRandomSampleDescriptor_t desc,
         case DevAscendNpu: {
             return ascendRandomSample((RandomSampleAscendDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaRandomSample((RandomSampleMacaDescriptor_t) desc, workspace, workspace_size, result, probs, random_val, topp, topk, temperature, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -121,6 +141,11 @@ __C infiniopStatus_t infiniopDestroyRandomSampleDescriptor(infiniopRandomSampleD
         case DevAscendNpu: {
             return ascendDestroyRandomSampleDescriptor((RandomSampleAscendDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaDestroyRandomSampleDescriptor((RandomSampleMacaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/rearrange/maca/rearrange_maca.cc b/src/ops/rearrange/maca/rearrange_maca.cc
new file mode 100644
index 00000000..ac33fe06
--- /dev/null
+++ b/src/ops/rearrange/maca/rearrange_maca.cc
@@ -0,0 +1,70 @@
+#include "rearrange_maca.h"
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include
+
+infiniopStatus_t macaCreateRearrangeDescriptor(MacaHandle_t handle,
+                                               RearrangeMacaDescriptor_t *desc_ptr,
+                                               infiniopTensorDescriptor_t dst,
+                                               infiniopTensorDescriptor_t src) {
+    auto dt = dst->dt;
+    if (!dtype_eq(src->dt, dt)) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    auto ndim = dst->ndim;
+    if (src->ndim != ndim || ndim == 0) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    for (int i = 0; i < ndim; ++i) {
+        if (dst->shape[i] != src->shape[i]) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    if (dst->strides[ndim - 1] != 1 || src->strides[ndim - 1] != 1) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    switch (ndim) {
+        case 1:
+            *desc_ptr = new RearrangeMacaDescriptor{
+                handle->device,
+                handle->device_id,
+                dt.size * dst->shape[0],
+                1, 1,
+                0, 0,
+                0, 0};
+            break;
+        case 2:
+            *desc_ptr = new RearrangeMacaDescriptor{
+                handle->device,
+                handle->device_id,
+                dt.size * dst->shape[1],
+                1, dst->shape[0],
+                0, dst->strides[0],
+                0, src->strides[0]};
+            break;
+        case 3:
+            *desc_ptr = new RearrangeMacaDescriptor{
+                handle->device,
+                handle->device_id,
+                dt.size * dst->shape[2],
+                dst->shape[0], dst->shape[1],
+                dst->strides[0], dst->strides[1],
+                src->strides[0], src->strides[1]};
+            break;
+        default:
+            return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    (*desc_ptr)->dst_rs *= dt.size;
+    (*desc_ptr)->dst_cs *= dt.size;
+    (*desc_ptr)->src_rs *= dt.size;
+    (*desc_ptr)->src_cs *= dt.size;
+
+    return STATUS_SUCCESS;
+}
+infiniopStatus_t macaDestroyRearrangeDescriptor(RearrangeMacaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rearrange/maca/rearrange_maca.h b/src/ops/rearrange/maca/rearrange_maca.h
new file mode 100644
index 00000000..701f55bb
--- /dev/null
+++ b/src/ops/rearrange/maca/rearrange_maca.h
@@ -0,0 +1,29 @@
+#ifndef __MACA_REARRANGE_H__
+#define __MACA_REARRANGE_H__
+
+#include "../../../devices/maca/maca_handle.h"
+#include "operators.h"
+
+struct RearrangeMacaDescriptor {
+    Device device;
+    int device_id;
+    uint64_t unit, r, c;
+    int64_t dst_rs, dst_cs, src_rs, src_cs;
+};
+
+typedef struct RearrangeMacaDescriptor *RearrangeMacaDescriptor_t;
+
+infiniopStatus_t macaCreateRearrangeDescriptor(MacaHandle_t handle,
+                                               RearrangeMacaDescriptor_t *desc_ptr,
+                                               infiniopTensorDescriptor_t dst,
+                                               infiniopTensorDescriptor_t src);
+
+infiniopStatus_t macaRearrange(RearrangeMacaDescriptor_t desc,
+                               void *dst,
+                               void const *src,
+                               void *stream);
+
+infiniopStatus_t macaDestroyRearrangeDescriptor(RearrangeMacaDescriptor_t desc);
+
+void rearrange_mc_gpu(RearrangeMacaDescriptor_t, void *y, void const *x, void *stream);
+#endif// __MACA_REARRANGE_H__
diff --git a/src/ops/rearrange/maca/rearrange_maca.maca b/src/ops/rearrange/maca/rearrange_maca.maca
new file mode 100644
index 00000000..b5152c15
--- /dev/null
+++ b/src/ops/rearrange/maca/rearrange_maca.maca
@@ -0,0 +1,77 @@
+#include "../../../devices/maca/common_maca.h"
+#include "rearrange_maca.h"
+
+template <class Tmem>
+static __global__ void rearrange(
+    void *__restrict__ dst,
+    int const rsa,
+    int const csa,
+    void const *__restrict__ src,
+    int const rsb,
+    int const csb,
+    unsigned int const ncols) {
+
+    auto row = blockIdx.y,
+         col = blockIdx.x * blockDim.y + threadIdx.y;
+    if (col >= ncols) return;
+
+    auto thread = threadIdx.x;
+    auto warp_size = blockDim.x;
+    auto i = (row * rsa + col * csa) * warp_size + thread;
+    auto j = (row * rsb + col * csb) * warp_size + thread;
+
+    reinterpret_cast<Tmem *>(dst)[i] = reinterpret_cast<Tmem const *>(src)[j];
+}
+
+void rearrange_mc_gpu(RearrangeMacaDescriptor_t desc, void *y, void const *x, void *stream) {
+    auto maca_stream = reinterpret_cast<hcStream_t>(stream);
+    auto unit = desc->unit,
+         r = desc->r, c = desc->c;
+    auto dst_rs = desc->dst_rs, dst_cs = desc->dst_cs,
+         src_rs = desc->src_rs, src_cs = desc->src_cs;
+
+    if (r == 1 && c == 1) {
+        hcMemcpyAsync(y, x, unit, hcMemcpyDeviceToDevice, maca_stream);
+        return;
+    }
+
+    auto warps = 1024 / WARP_SIZE;
+    auto grid = dim3((c + warps - 1) / warps, r);
+    auto block = dim3(WARP_SIZE, (c + grid.x - 1) / grid.x);
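+    // WARP_SIZE lanes stream one contiguous unit; blockDim.y packs several columns per block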
+    dst_rs /= unit;
+    dst_cs /= unit;
+    src_rs /= unit;
+    src_cs /= unit;
+
+    switch (unit / WARP_SIZE) {
+        case 1:
+            rearrange<uchar1><<<grid, block, 0, maca_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 2:
+            rearrange<uchar2><<<grid, block, 0, maca_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 4:
+            rearrange<float1><<<grid, block, 0, maca_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 8:
+            rearrange<float2><<<grid, block, 0, maca_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 16:
+            rearrange<float4><<<grid, block, 0, maca_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        case 32:
+            rearrange<double4><<<grid, block, 0, maca_stream>>>(y, dst_rs, dst_cs, x, src_rs, src_cs, c);
+            break;
+        default:
+            break;
+    }
+}
+infiniopStatus_t macaRearrange(RearrangeMacaDescriptor_t desc,
+                               void *dst, void const *src, void *stream) {
+    if (hcSetDevice(desc->device_id) != hcSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    rearrange_mc_gpu(desc, dst, src, stream);
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rearrange/operator.cc b/src/ops/rearrange/operator.cc
index a1084d48..752211e5 100644
--- a/src/ops/rearrange/operator.cc
+++ b/src/ops/rearrange/operator.cc
@@ -17,6 +17,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/rearrange_aclnn.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/rearrange_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRearrangeDescriptor(
     infiniopHandle_t handle,
@@ -46,6 +49,11 @@ __C infiniopStatus_t infiniopCreateRearrangeDescriptor(
                                                dst,
                                                src);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateRearrangeDescriptor((MacaHandle_t) handle, (RearrangeMacaDescriptor_t *) desc_ptr, dst, src);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -75,6 +83,11 @@ __C infiniopStatus_t infiniopRearrange(infiniopRearrangeDescriptor_t desc, void
                              src,
                              stream);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaRearrange((RearrangeMacaDescriptor_t) desc, dst, src, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -101,6 +114,11 @@ __C infiniopStatus_t infiniopDestroyRearrangeDescrip
         case DevAscendNpu: {
             return aclnnDestroyRearrangeDescriptor((RearrangeAclnnDescriptor_t) desc);
        }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaDestroyRearrangeDescriptor((RearrangeMacaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/rms_norm/maca/rms_norm_maca.cc b/src/ops/rms_norm/maca/rms_norm_maca.cc
new file mode 100644
index 00000000..054be969
--- /dev/null
+++ b/src/ops/rms_norm/maca/rms_norm_maca.cc
@@ -0,0 +1,46 @@
+#include "rms_norm_maca.h"
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+
+infiniopStatus_t macaCreateRMSNormDescriptor(MacaHandle_t handle, RMSNormMacaDescriptor_t *desc_ptr,
+                                             infiniopTensorDescriptor_t y_desc,
+                                             infiniopTensorDescriptor_t x_desc,
+                                             infiniopTensorDescriptor_t w_desc,
+                                             float epsilon) {
+    if (y_desc->ndim != 2 || x_desc->ndim != 2 || w_desc->ndim != 1) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    auto n = y_desc->shape[0],
+         d = y_desc->shape[1];
+
+    if (x_desc->shape[0] != n || x_desc->shape[1] != d || w_desc->shape[0] != d) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    int64_t stride_y = y_desc->strides[0];
+    int64_t stride_x = x_desc->strides[0];
+    auto w_datatype = w_desc->dt;
+    *desc_ptr = new RMSNormMacaDescriptor{
+        handle->device,
+        handle->device_id,
+        y_desc->dt,
+        n,
+        d,
+        stride_y,
+        stride_x,
+        w_datatype,
+        epsilon};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaGetRMSNormWorkspaceSize(RMSNormMacaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaDestroyRMSNormDescriptor(RMSNormMacaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/rms_norm/maca/rms_norm_maca.h b/src/ops/rms_norm/maca/rms_norm_maca.h
new file mode 100644
index 00000000..f244ce97
--- /dev/null
+++ b/src/ops/rms_norm/maca/rms_norm_maca.h
@@ -0,0 +1,40 @@
+#ifndef __MACA_RMS_NORM_H__
+#define __MACA_RMS_NORM_H__
+
+#include "../../../devices/maca/maca_handle.h"
+#include "operators.h"
+
+struct RMSNormMacaDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    uint64_t n;
+    uint64_t d;
+    int64_t stride_y;
+    int64_t stride_x;
+    DT w_datatype;
+    float epsilon;
+};
+
+typedef struct RMSNormMacaDescriptor *RMSNormMacaDescriptor_t;
+
+infiniopStatus_t macaCreateRMSNormDescriptor(MacaHandle_t handle,
+                                             RMSNormMacaDescriptor_t *desc_ptr,
+                                             infiniopTensorDescriptor_t y_desc,
+                                             infiniopTensorDescriptor_t x_desc,
+                                             infiniopTensorDescriptor_t w_desc,
+                                             float epsilon);
+
+infiniopStatus_t macaGetRMSNormWorkspaceSize(RMSNormMacaDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t macaRMSNorm(RMSNormMacaDescriptor_t desc,
+                             void *workspace,
+                             uint64_t workspace_size,
+                             void *y, void const *x, void const *w,
+                             void *stream);
+
+infiniopStatus_t macaDestroyRMSNormDescriptor(RMSNormMacaDescriptor_t desc);
+
+void rms_norm_mc_gpu_f16(RMSNormMacaDescriptor_t desc, void *y, void const *x, void const *w, void *stream);
+
+#endif// __MACA_RMS_NORM_H__
diff --git a/src/ops/rms_norm/maca/rms_norm_maca.maca b/src/ops/rms_norm/maca/rms_norm_maca.maca
new file mode 100644
index 00000000..3becfab6
--- /dev/null
+++ b/src/ops/rms_norm/maca/rms_norm_maca.maca
@@ -0,0 +1,174 @@
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include "rms_norm_maca.h"
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_reduce.cuh>
+
+// assert BLOCK_SIZE >= blockDim.x
+template <unsigned int BLOCK_SIZE, class Tdata, class Wdata>
+static __global__ void rms_norm_padding(
+    Tdata *__restrict__ o_,
+    unsigned int const stride_y,
+    Tdata const *__restrict__ x_,
+    unsigned int const stride_x,
+    Wdata const *__restrict__ w_,
+    float const epsilon) {
+    auto y = o_ + blockIdx.x * stride_y + threadIdx.x;
+    auto x = x_[blockIdx.x * stride_x + threadIdx.x];
+    auto w = w_[threadIdx.x];
+
+    using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockOp::TempStorage temp_storage;
+    auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum());
+
+    __shared__ Tdata rms;
+    if (threadIdx.x == 0) {
+        rms = Tdata(rsqrtf(acc / float(blockDim.x) + epsilon));
+    }
+    __syncthreads();
+
+    *y = rms * x * (Tdata) w;
+}
+
+template <unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata, class Wdata>
+static __global__ void rms_norm_folding(
+    Tdata *__restrict__ y,
+    unsigned int const stride_y,
+    Tdata const *__restrict__ x,
+    unsigned int const stride_x,
+    Wdata const *__restrict__ w,
+    float const epsilon,
+    unsigned int const items_size) {
+    y += blockIdx.x * stride_y;
+    x += blockIdx.x * stride_x;
+
+    float thread_data[ITEMS_PER_THREAD];
+    {
+        using BlockOp = cub::BlockLoad<float, BLOCK_SIZE, ITEMS_PER_THREAD>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        BlockOp(temp_storage).Load(x, thread_data, items_size, 0.f);
+    }
+
+    float squared[ITEMS_PER_THREAD];
+#pragma unroll
+    for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) {
+        squared[i] = thread_data[i] * thread_data[i];
+    }
+
+    float acc;
+    {
+        using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        acc = BlockOp(temp_storage).Reduce(squared, cub::Sum());
+    }
+
+    __shared__ Tdata rms;
+    if (threadIdx.x == 0) {
+        rms = Tdata(rsqrtf(acc / float(items_size) + epsilon));
+    }
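+    // rms is written by thread 0 only; the barrier below publishes it to the whole block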
+    __syncthreads();
+
+#pragma unroll
+    for (unsigned int i = 0; i < ITEMS_PER_THREAD; ++i) {
+        if (auto j = i + threadIdx.x * ITEMS_PER_THREAD; j < items_size) {
+            y[j] = Tdata(float(rms) * float(thread_data[i]) * float(w[j]));
+        }
+    }
+}
+
+template <unsigned int BLOCK_SIZE, class Tdata, class Wdata>
+static __global__ void rms_norm_standard(
+    Tdata *__restrict__ y_,
+    unsigned int const stride_y,
+    Tdata const *__restrict__ x_,
+    unsigned int const stride_x,
+    Wdata const *__restrict__ w,
+    float const epsilon,
+    unsigned int const d) {
+    auto y = y_ + blockIdx.x * stride_y;
+    auto x = x_ + blockIdx.x * stride_x;
+
+    __shared__ float partial_sum[BLOCK_SIZE];
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < d; i += BLOCK_SIZE) {
+        sum += float(x[i]) * float(x[i]);
+    }
+
+    partial_sum[threadIdx.x] = sum;
+    __syncthreads();
+    for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
+        if (threadIdx.x < stride) {
+            partial_sum[threadIdx.x] += partial_sum[threadIdx.x + stride];
+        }
+        __syncthreads();
+    }
+
+    __shared__ Tdata rms;
+    if (threadIdx.x == 0) {
+        float row_sum = partial_sum[0];
+        rms = Tdata(rsqrtf(row_sum / float(d) + epsilon));
+    }
+    __syncthreads();
+
+    for (int i = threadIdx.x; i < d; i += BLOCK_SIZE) {
+        y[i] = rms * x[i] * (Tdata) w[i];
+    }
+}
+
+void rms_norm_mc_gpu_f16(RMSNormMacaDescriptor_t desc, void *y, void const *x, void const *w, void *stream) {
+    auto n = desc->n, d = desc->d;
+    auto y_ = reinterpret_cast<half *>(y);
+    auto x_ = reinterpret_cast<half const *>(x);
+    auto epsilon = desc->epsilon;
+
+    // Get strides in terms of elements
+    auto stride_y = desc->stride_y;
+    auto stride_x = desc->stride_x;
+
+    auto maca_stream = reinterpret_cast<hcStream_t>(stream);
+    unsigned int items_per_thread = ROUND_UP_DIV(d, MAX_THREADS_PER_BLOCK);
+    auto w_datatype = desc->w_datatype;
+    if (dtype_eq(w_datatype, F16)) {
+        auto w_ = reinterpret_cast<half const *>(w);
+        if (items_per_thread == 1) {
+            rms_norm_padding<MAX_THREADS_PER_BLOCK, half, half>
+                <<<n, d, 0, maca_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon);
+        } else if (items_per_thread <= 16) {
+            rms_norm_folding<MAX_THREADS_PER_BLOCK, 16, half, half>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, maca_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        } else {
+            rms_norm_standard<MAX_THREADS_PER_BLOCK, half, half>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, maca_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        }
+    } else {
+        auto w_ = reinterpret_cast<float const *>(w);
+        if (items_per_thread == 1) {
+            rms_norm_padding<MAX_THREADS_PER_BLOCK, half, float>
+                <<<n, d, 0, maca_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon);
+        } else if (items_per_thread <= 16) {
+            rms_norm_folding<MAX_THREADS_PER_BLOCK, 16, half, float>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, maca_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        } else {
+            rms_norm_standard<MAX_THREADS_PER_BLOCK, half, float>
+                <<<n, MAX_THREADS_PER_BLOCK, 0, maca_stream>>>(y_, stride_y, x_, stride_x, w_, epsilon, d);
+        }
+    }
+}
+
+infiniopStatus_t macaRMSNorm(RMSNormMacaDescriptor_t desc,
+                             void *workspace,
+                             uint64_t workspace_size,
+                             void *y, void const *x, void const *w,
+                             void *stream) {
+    if (hcSetDevice(desc->device_id) != hcSuccess) {
+        return STATUS_BAD_DEVICE;
+    }
+    if (dtype_eq(desc->dtype, F16)) {
+        rms_norm_mc_gpu_f16(desc, y, x, w, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc
index 9aa4b206..dff9573b 100644
--- a/src/ops/rms_norm/operator.cc
+++ b/src/ops/rms_norm/operator.cc
@@ -17,6 +17,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/rms_norm_aclnn.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/rms_norm_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
     infiniopHandle_t handle,
@@ -49,6 +52,11 @@ __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
                                             w_desc,
                                             epsilon);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateRMSNormDescriptor((MacaHandle_t) handle, (RMSNormMacaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon);
+        }
diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc
index 9aa4b206..dff9573b 100644
--- a/src/ops/rms_norm/operator.cc
+++ b/src/ops/rms_norm/operator.cc
@@ -17,6 +17,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/rms_norm_aclnn.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/rms_norm_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
     infiniopHandle_t handle,
@@ -49,6 +52,11 @@ __C infiniopStatus_t infiniopCreateRMSNormDescriptor(
                                                w_desc, epsilon);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateRMSNormDescriptor((MacaHandle_t) handle, (RMSNormMacaDescriptor_t *) desc_ptr, y_desc, x_desc, w_desc, epsilon);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -76,6 +84,11 @@ __C infiniopStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t
             return aclnnGetRMSNormWorkspaceSize((RMSNormAclnnDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaGetRMSNormWorkspaceSize((RMSNormMacaDescriptor_t) desc, size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -109,6 +122,11 @@ __C infiniopStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *wor
                               w, stream);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaRMSNorm((RMSNormMacaDescriptor_t) desc, workspace, workspace_size, y, x, w, stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -136,6 +154,11 @@ __C infiniopStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_
             return aclnnDestroyRMSNormDescriptor((RMSNormAclnnDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaDestroyRMSNormDescriptor((RMSNormMacaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/rotary_embedding/maca/rotary_embedding_maca.cc b/src/ops/rotary_embedding/maca/rotary_embedding_maca.cc
new file mode 100644
index 00000000..171f1c57
--- /dev/null
+++ b/src/ops/rotary_embedding/maca/rotary_embedding_maca.cc
@@ -0,0 +1,76 @@
+#include "rotary_embedding_maca.h"
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+
+infiniopStatus_t macaCreateRoPEDescriptor(MacaHandle_t handle,
+                                          RoPEMacaDescriptor_t *desc_ptr,
+                                          infiniopTensorDescriptor_t t,
+                                          infiniopTensorDescriptor_t pos_ids,
+                                          infiniopTensorDescriptor_t sin_table,
+                                          infiniopTensorDescriptor_t cos_table) {
+    if (desc_ptr == nullptr)
+        return STATUS_MEMORY_NOT_ALLOCATED;
+
+    if (t->ndim != 3 ||
+        pos_ids->ndim != 1 ||
+        sin_table->ndim != 2 ||
+        cos_table->ndim != 2)
+        return STATUS_BAD_TENSOR_SHAPE;
+
+    auto seq_len = t->shape[0];
+    auto nhead = t->shape[1];
+    auto dim = t->shape[2];
+    auto total_seq_len = sin_table->shape[0];
+
+    if (dim % 2 != 0)
+        return STATUS_BAD_TENSOR_SHAPE;
+
+    if (pos_ids->shape[0] != seq_len ||
+        sin_table->shape[1] != dim ||
+        cos_table->shape[1] != dim ||
+        sin_table->shape[0] != cos_table->shape[0])
+        return STATUS_BAD_TENSOR_SHAPE;
+
+    // TODO: support larger dim in the future
+    if (dim / 2 > MAX_THREADS_PER_BLOCK) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    if (t->strides[2] != 1 ||
+        pos_ids->strides[0] != 1 ||
+        sin_table->strides[1] != 1 ||
+        cos_table->strides[1] != 1)
+        return STATUS_BAD_TENSOR_STRIDES;
+
+    if (!dtype_eq(t->dt, F16))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    if (!dtype_eq(sin_table->dt, F32) || !dtype_eq(cos_table->dt, F32))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    if (!dtype_eq(pos_ids->dt, U64))
+        return STATUS_BAD_TENSOR_DTYPE;
+
+    *desc_ptr = new RoPEMacaDescriptor{
+        handle->device,
+        handle->device_id,
+        t->dt,
+        seq_len,
+        nhead,
+        dim,
+        total_seq_len,
+        {t->strides[0], t->strides[1]}};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaGetRoPEWorkspaceSize(RoPEMacaDescriptor_t desc, uint64_t *size) {
+    *size = 0;
+    return STATUS_SUCCESS;
+}
+
+
+infiniopStatus_t macaDestroyRoPEDescriptor(RoPEMacaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
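The shape and dtype checks above pin down the table layout the kernel consumes: sin_table and cos_table are row-major [total_seq_len, dim] fp32, and each (2k, 2k+1) pair of a row holds the values for the same angle. A hypothetical host-side construction (the 10000 base is the conventional RoPE default, not something this patch fixes):

    #include <cmath>
    #include <cstdint>

    // hypothetical helper: fills the tables so entries [p][2k] and [p][2k+1]
    // both hold the angle for pair k at position p, matching
    // sincos_offset = pos * dim + 2 * k in padding_f16 below
    void fill_rope_tables(float *sin_table, float *cos_table,
                          uint64_t total_seq_len, uint64_t dim) {
        for (uint64_t p = 0; p < total_seq_len; ++p) {
            for (uint64_t k = 0; k < dim / 2; ++k) {
                double theta = double(p) / std::pow(10000.0, double(2 * k) / double(dim));
                sin_table[p * dim + 2 * k] = sin_table[p * dim + 2 * k + 1] = float(std::sin(theta));
                cos_table[p * dim + 2 * k] = cos_table[p * dim + 2 * k + 1] = float(std::cos(theta));
            }
        }
    }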
"../../../devices/maca/maca_handle.h" +#include "operators.h" + +struct RoPEMacaDescriptor { + Device device; + int device_id; + DT dtype; + uint64_t seq_len; + uint64_t nhead; + uint64_t dim; + uint64_t total_seq_len; + int64_t strides[2]; +}; + +typedef struct RoPEMacaDescriptor *RoPEMacaDescriptor_t; + +infiniopStatus_t macaCreateRoPEDescriptor(MacaHandle_t handle, + RoPEMacaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t t, + infiniopTensorDescriptor_t pos_ids, + infiniopTensorDescriptor_t sin_table, + infiniopTensorDescriptor_t cos_table); + +infiniopStatus_t macaGetRoPEWorkspaceSize(RoPEMacaDescriptor_t desc, uint64_t *size); + +infiniopStatus_t macaRoPE(RoPEMacaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *t, + void const *pos_ids, + void const *sin_table, + void const *cos_table, + void *stream); + +infiniopStatus_t macaDestroyRoPEDescriptor(RoPEMacaDescriptor_t desc); + +#endif// __METAX_GPU_ROTARY_EMBEDDING_H__ diff --git a/src/ops/rotary_embedding/maca/rotary_embedding_maca.maca b/src/ops/rotary_embedding/maca/rotary_embedding_maca.maca new file mode 100644 index 00000000..aaa52250 --- /dev/null +++ b/src/ops/rotary_embedding/maca/rotary_embedding_maca.maca @@ -0,0 +1,70 @@ +#include "../../utils.h" +#include "rotary_embedding_maca.h" +#include + +static __global__ void padding_f16( + half *__restrict__ x_, + uint64_t const *__restrict__ pos_, + float const *__restrict__ sin_, + float const *__restrict__ cos_, + long const stride0, + long const stride1) { + auto dk = blockDim.x; + auto k = threadIdx.x; + auto offset = blockIdx.x * stride0 + blockIdx.y * stride1 + k * 2; + auto &x = reinterpret_cast(x_[offset]); + auto pos = pos_[blockIdx.x]; + auto sincos_offset = pos * dk * 2 + k * 2; + + float sin0 = sin_[sincos_offset], cos0 = cos_[sincos_offset], + sin1 = sin_[sincos_offset + 1], cos1 = cos_[sincos_offset + 1]; + float x0 = __half2float(x.x) * cos0 - __half2float(x.y) * sin0; + float x1 = __half2float(x.y) * cos1 + __half2float(x.x) * sin1; + x = half2(x0, x1); +} + + +void rotary_embedding_mc_gpu_f16( + RoPEMacaDescriptor_t desc, + half *t, + uint64_t const *pos, + float const *sin_, float const *cos_, + void *stream) { + auto nt = desc->seq_len, + nh = desc->nhead, + dh = desc->dim; + + // batching 2 half together + auto stride0 = desc->strides[0], + stride1 = desc->strides[1]; + + auto maca_stream = reinterpret_cast(stream); + padding_f16<<>>(t, pos, sin_, cos_, stride0, stride1); +} + +infiniopStatus_t macaRoPE(RoPEMacaDescriptor_t desc, + void *workspace, + uint64_t workspace_size, + void *t, + void const *pos_ids, + void const *sin_table, + void const *cos_table, + void *stream) { + if (t == nullptr || pos_ids == nullptr || sin_table == nullptr || cos_table == nullptr) + return STATUS_BAD_PARAM; + + checkMacaError(hcSetDevice(desc->device_id)); + + if (dtype_eq(desc->dtype, F16)) { + rotary_embedding_mc_gpu_f16(desc, + reinterpret_cast(t), + reinterpret_cast(pos_ids), + reinterpret_cast(sin_table), + reinterpret_cast(cos_table), + stream); + } else { + return STATUS_BAD_TENSOR_DTYPE; + } + + return STATUS_SUCCESS; +} diff --git a/src/ops/rotary_embedding/operator.cc b/src/ops/rotary_embedding/operator.cc index 33ac8ad3..5c1d4aec 100644 --- a/src/ops/rotary_embedding/operator.cc +++ b/src/ops/rotary_embedding/operator.cc @@ -15,6 +15,9 @@ #ifdef ENABLE_ASCEND_NPU #include "ascend/rotary_embedding.h" #endif +#ifdef ENABLE_METAX_GPU +#include "maca/rotary_embedding_maca.h" +#endif struct RoPEDescriptor { Device device; @@ -52,6 +55,16 @@ 
diff --git a/src/ops/rotary_embedding/operator.cc b/src/ops/rotary_embedding/operator.cc
index 33ac8ad3..5c1d4aec 100644
--- a/src/ops/rotary_embedding/operator.cc
+++ b/src/ops/rotary_embedding/operator.cc
@@ -15,6 +15,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/rotary_embedding.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/rotary_embedding_maca.h"
+#endif
 
 struct RoPEDescriptor {
     Device device;
@@ -52,6 +55,16 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle,
                                              sin_table,
                                              cos_table);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateRoPEDescriptor((MacaHandle_t) handle,
+                                            (RoPEMacaDescriptor_t *) desc_ptr,
+                                            t,
+                                            pos_ids,
+                                            sin_table,
+                                            cos_table);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -79,6 +92,12 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
             return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaGetRoPEWorkspaceSize((RoPEMacaDescriptor_t) desc,
+                                            size);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -119,6 +138,18 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
                             cos_table,
                             stream);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaRoPE((RoPEMacaDescriptor_t) desc,
+                            workspace,
+                            workspace_size,
+                            t,
+                            pos_ids,
+                            sin_table,
+                            cos_table,
+                            stream);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -145,6 +176,11 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc
         case DevAscendNpu: {
             return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaDestroyRoPEDescriptor((RoPEMacaDescriptor_t) desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/swiglu/maca/swiglu_maca.cc b/src/ops/swiglu/maca/swiglu_maca.cc
new file mode 100644
index 00000000..71c2af70
--- /dev/null
+++ b/src/ops/swiglu/maca/swiglu_maca.cc
@@ -0,0 +1,51 @@
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include "swiglu_maca.h"
+
+infiniopStatus_t macaCreateSwiGLUDescriptor(MacaHandle_t handle,
+                                            SwiGLUMacaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc) {
+    if (c_desc->ndim != 2 || a_desc->ndim != 2 || b_desc->ndim != 2) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    DT dtype = c_desc->dt;
+
+    if (!dtype_eq(dtype, F16)) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    if (a_desc->strides[1] != 1 || b_desc->strides[1] != 1 || c_desc->strides[1] != 1) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    uint64_t seq_len = c_desc->shape[0],
+             di = c_desc->shape[1];
+
+    uint64_t stride_a = a_desc->strides[0],
+             stride_b = b_desc->strides[0],
+             stride_c = c_desc->strides[0];
+
+
+    if (a_desc->shape[0] != seq_len || a_desc->shape[1] != di || !dtype_eq(a_desc->dt, dtype) ||
+        b_desc->shape[0] != seq_len || b_desc->shape[1] != di || !dtype_eq(b_desc->dt, dtype)) {
+        return STATUS_BAD_PARAM;
+    }
+
+    *desc_ptr = new SwiGLUMacaDescriptor{DevMetaxGpu,
+                                         handle->device_id,
+                                         dtype,
+                                         seq_len,
+                                         di,
+                                         stride_a,
+                                         stride_b,
+                                         stride_c};
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t macaDestroySwiGLUDescriptor(SwiGLUMacaDescriptor_t desc) {
+    delete desc;
+    return STATUS_SUCCESS;
+}
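One convention worth noting: the kernel in the .maca file below applies silu to b (the gate) and multiplies by a, i.e. c = silu(b) * a elementwise. A hypothetical fp32 CPU reference of the same computation, with the same row strides:

    #include <cmath>
    #include <cstdint>

    // reference for swiglu_mc_gpu_f16: c[s][j] = silu(b[s][j]) * a[s][j]
    void swiglu_reference(float *c, float const *a, float const *b,
                          uint64_t seq_len, uint64_t di,
                          uint64_t stride_a, uint64_t stride_b, uint64_t stride_c) {
        for (uint64_t s = 0; s < seq_len; ++s) {
            for (uint64_t j = 0; j < di; ++j) {
                float x = b[s * stride_b + j];// gate input
                float y = a[s * stride_a + j];// value input
                c[s * stride_c + j] = x / (1.0f + std::exp(-x)) * y;
            }
        }
    }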
diff --git a/src/ops/swiglu/maca/swiglu_maca.h b/src/ops/swiglu/maca/swiglu_maca.h
new file mode 100644
index 00000000..3ea7c661
--- /dev/null
+++ b/src/ops/swiglu/maca/swiglu_maca.h
@@ -0,0 +1,36 @@
+#ifndef __MACA_SWIGLU_H__
+#define __MACA_SWIGLU_H__
+#include "../../../devices/maca/maca_handle.h"
+#include "../../utils.h"
+#include "operators.h"
+
+struct SwiGLUMacaDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    uint64_t seq_len;
+    uint64_t di;
+    uint64_t stride_a;
+    uint64_t stride_b;
+    uint64_t stride_c;
+};
+
+typedef struct SwiGLUMacaDescriptor *SwiGLUMacaDescriptor_t;
+
+infiniopStatus_t macaCreateSwiGLUDescriptor(MacaHandle_t handle,
+                                            SwiGLUMacaDescriptor_t *desc_ptr,
+                                            infiniopTensorDescriptor_t c_desc,
+                                            infiniopTensorDescriptor_t a_desc,
+                                            infiniopTensorDescriptor_t b_desc);
+
+infiniopStatus_t macaSwiGLU(SwiGLUMacaDescriptor_t desc,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream);
+
+infiniopStatus_t macaDestroySwiGLUDescriptor(SwiGLUMacaDescriptor_t desc);
+
+void swiglu_mc_gpu_f16(SwiGLUMacaDescriptor_t desc, void *c, void const *a, void const *b, void *stream);
+
+#endif// __MACA_SWIGLU_H__
diff --git a/src/ops/swiglu/maca/swiglu_maca.maca b/src/ops/swiglu/maca/swiglu_maca.maca
new file mode 100644
index 00000000..68692c04
--- /dev/null
+++ b/src/ops/swiglu/maca/swiglu_maca.maca
@@ -0,0 +1,70 @@
+#include "../../../devices/maca/common_maca.h"
+#include "../../utils.h"
+#include "swiglu_maca.h"
+#include <cuda_fp16.h>
+
+static __forceinline__ __device__ float silu(float x) {
+    return x * fdividef(1, 1 + expf(-x));
+}
+
+inline int gcd(int a, int b) {
+    while (b != 0) {
+        int rem = a % b;
+        a = b;
+        b = rem;
+    }
+    return a;
+}
+
+template<class Tdata>
+static __global__ void swiglu(
+    Tdata *__restrict__ c,
+    int const stride_c,
+    Tdata const *__restrict__ a,
+    int const stride_a,
+    Tdata const *__restrict__ b,
+    int const stride_b) {
+    auto i = blockIdx.y * stride_b + blockIdx.x * blockDim.x + threadIdx.x,
+         j = blockIdx.y * stride_a + blockIdx.x * blockDim.x + threadIdx.x,
+         k = blockIdx.y * stride_c + blockIdx.x * blockDim.x + threadIdx.x;
+    auto x = float(b[i]),
+         y = float(a[j]);
+    c[k] = Tdata(silu(x) * y);
+}
+
+void swiglu_mc_gpu_f16(SwiGLUMacaDescriptor_t desc, void *c, void const *a, void const *b, void *stream) {
+
+    auto seq_len = desc->seq_len,
+         di = desc->di;
+
+    auto stride_a = desc->stride_a,
+         stride_b = desc->stride_b,
+         stride_c = desc->stride_c;
+
+    dim3 block_dims = gcd(MAX_THREADS_PER_BLOCK, di);
+    dim3 grid_dims = dim3(di / block_dims.x, seq_len);
+
+    auto a_ptr = reinterpret_cast<half const *>(a);
+    auto b_ptr = reinterpret_cast<half const *>(b);
+    auto c_ptr = reinterpret_cast<half *>(c);
+
+    auto maca_stream = reinterpret_cast<hcStream_t>(stream);
+
+    swiglu<<<grid_dims, block_dims, 0, maca_stream>>>(
+        c_ptr, stride_c, a_ptr, stride_a, b_ptr, stride_b);
+}
+
+infiniopStatus_t macaSwiGLU(SwiGLUMacaDescriptor_t desc,
+                            void *c,
+                            void const *a,
+                            void const *b,
+                            void *stream) {
+    checkMacaError(hcSetDevice(desc->device_id));
+
+    if (dtype_eq(desc->dtype, F16)) {
+        swiglu_mc_gpu_f16(desc, c, a, b, stream);
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
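Why gcd for the block size: it is the largest thread count that both fits under MAX_THREADS_PER_BLOCK and divides di exactly, so the grid tiles each row with no remainder and the kernel needs no bounds check. A quick worked check (the 1024 limit is an assumption; the real value comes from common_maca.h):

    // assuming MAX_THREADS_PER_BLOCK == 1024 and di == 11264 (= 11 * 1024):
    //   block_dims = gcd(1024, 11264) = 1024
    //   grid_dims  = (11264 / 1024, seq_len) = (11, seq_len)
    // caveat: an odd di gives gcd(1024, di) == 1, i.e. one thread per block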
diff --git a/src/ops/swiglu/operator.cc b/src/ops/swiglu/operator.cc
index b0bcb35c..3eb68a97 100644
--- a/src/ops/swiglu/operator.cc
+++ b/src/ops/swiglu/operator.cc
@@ -14,6 +14,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/swiglu.h"
 #endif
+#ifdef ENABLE_METAX_GPU
+#include "maca/swiglu_maca.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
                                                     infiniopSwiGLUDescriptor_t *desc_ptr,
@@ -45,6 +48,15 @@ __C infiniopStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t handle,
                                               c_desc,
                                               a_desc,
                                               b_desc);
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu: {
+            return macaCreateSwiGLUDescriptor((MacaHandle_t) handle,
+                                              (SwiGLUMacaDescriptor_t *) desc_ptr,
+                                              c_desc,
+                                              a_desc,
+                                              b_desc);
+        }
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -72,6 +84,10 @@ __C infiniopStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
 #ifdef ENABLE_ASCEND_NPU
         case DevAscendNpu:
             return ascendSwiGLU((SwiGLUAscendDescriptor_t) desc, c, a, b, stream);
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu:
+            return macaSwiGLU((SwiGLUMacaDescriptor_t) desc, c, a, b, stream);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -95,6 +111,10 @@ __C infiniopStatus_t infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t
 #ifdef ENABLE_ASCEND_NPU
         case DevAscendNpu:
             return ascendDestroySwiGLUDescriptor((SwiGLUAscendDescriptor_t) desc);
+#endif
+#ifdef ENABLE_METAX_GPU
+        case DevMetaxGpu:
+            return macaDestroySwiGLUDescriptor((SwiGLUMacaDescriptor_t) desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/xmake.lua b/xmake.lua
index 327e91ef..dcb14715 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -40,6 +40,14 @@ option("ascend-npu")
     add_defines("ENABLE_ASCEND_NPU")
 option_end()
 
+option("metax-gpu")
+    set_default(false)
+    set_showmenu(true)
+    set_description("Enable or disable Metax GPU kernel")
+    add_defines("ENABLE_METAX_GPU")
+option_end()
+
+
 if is_mode("debug") then
     add_cxflags("-g -O0")
     add_defines("DEBUG_MODE")
@@ -212,6 +220,53 @@ if has_config("ascend-npu") then
 target_end()
 end
 
+if has_config("metax-gpu") then
+
+    add_defines("ENABLE_METAX_GPU")
+    local MACA_ROOT = os.getenv("MACA_PATH") or os.getenv("MACA_HOME") or os.getenv("MACA_ROOT")
+
+    add_includedirs(MACA_ROOT .. "/include")
+    add_linkdirs(MACA_ROOT .. "/lib")
+    -- add_linkdirs(MACA_ROOT .. "htgpu_llvm/lib")
+    add_links("hcdnn")
+    add_links("hcblas")
+    add_links("hcruntime")
+
+    rule("maca")
+        set_extensions(".maca")
+
+        on_load(function (target)
+            target:add("includedirs", "include")
+        end)
+
+        on_build_file(function (target, sourcefile)
+            local objectfile = target:objectfile(sourcefile)
+            os.mkdir(path.directory(objectfile))
+            local htcc = "/opt/hpcc/htgpu_llvm/bin/htcc"
+
+            local args = { "-x", "hpcc", "-c", sourcefile, "-o", objectfile, "-I/opt/hpcc/include", "-O3", "-fPIC", "-Werror", "-std=c++17"}
+
+            for _, includedir in ipairs(target:get("includedirs")) do
+                table.insert(args, "-I" .. includedir)
+            end
+
+            os.execv(htcc, args)
+            table.insert(target:objectfiles(), objectfile)
+        end)
+    rule_end()
+
+    target("metax-gpu")
+        set_kind("static")
+        on_install(function (target) end)
+        set_languages("cxx17")
+        add_files("src/devices/maca/*.cc", "src/ops/*/maca/*.cc")
+        add_files("src/ops/*/maca/*.maca", {rule = "maca"})
+        add_cxflags("-lstdc++ -Werror -fPIC")
+    target_end()
+
+end
+
 target("infiniop")
     set_kind("shared")
@@ -227,6 +282,9 @@ target("infiniop")
     if has_config("ascend-npu") then
         add_deps("ascend-npu")
     end
+    if has_config("metax-gpu") then
+        add_deps("metax-gpu")
+    end
     set_languages("cxx17")
     add_files("src/devices/handle.cc")
     add_files("src/ops/*/operator.cc")
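With this wiring, Metax support is opt-in end to end: configuring xmake with the new option (e.g. `xmake f --metax-gpu=y`, with MACA_PATH pointing at the MACA toolkit so the hcdnn/hcblas/hcruntime libraries resolve) defines ENABLE_METAX_GPU, compiles every src/ops/*/maca/*.maca source through htcc via the custom "maca" rule, archives the result into the static metax-gpu target, and links it into the shared infiniop library, where the DevMetaxGpu cases added to the operator.cc dispatchers become reachable.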