From dbdb29ad5694fd9951289c313739d40a7bd544cc Mon Sep 17 00:00:00 2001 From: CalcuLuUus <1046352980@qq.com> Date: Wed, 14 Jan 2026 16:48:53 +0800 Subject: [PATCH] feat: add qwen3moe support --- .../infinilm/models/qwen3moe/MOE_REFACTOR.md | 23 + .../models/qwen3moe/MOE_REFACTOR_CN.md | 33 ++ python/infinilm/models/qwen3moe/__init__.py | 35 ++ python/infinilm/models/qwen3moe/qwen3moe.py | 328 +++++++++++++++ .../models/qwen3moe/qwen3moe_torch.py | 72 ++++ test/models/qwen3_moe/moe_test.py | 3 + test/models/qwen3_moe/ourmoe_test.py | 395 ++++++++++++++++++ test/models/qwen3_moe/test.sh | 4 + 8 files changed, 893 insertions(+) create mode 100644 python/infinilm/models/qwen3moe/MOE_REFACTOR.md create mode 100644 python/infinilm/models/qwen3moe/MOE_REFACTOR_CN.md create mode 100644 python/infinilm/models/qwen3moe/__init__.py create mode 100644 python/infinilm/models/qwen3moe/qwen3moe.py create mode 100644 python/infinilm/models/qwen3moe/qwen3moe_torch.py create mode 100644 test/models/qwen3_moe/ourmoe_test.py create mode 100755 test/models/qwen3_moe/test.sh diff --git a/python/infinilm/models/qwen3moe/MOE_REFACTOR.md b/python/infinilm/models/qwen3moe/MOE_REFACTOR.md new file mode 100644 index 00000000..6599b11a --- /dev/null +++ b/python/infinilm/models/qwen3moe/MOE_REFACTOR.md @@ -0,0 +1,23 @@ +# Qwen3 MoE refactor (infinicore) + +This refactor introduces an infinicore version of the `Qwen3MoeSparseMoeBlock` at `python/infinilm/models/qwen3moe/qwen3moe.py`. The block now subclasses `infinicore.nn.Module`, stores weights as `infinicore.nn.Parameter`, and exposes the same class names as the Torch reference for drop-in usage. + +## What changed +- Router uses `infinicore.nn.functional.linear` for the projection and a Python softmax+top-k shim to select experts, returning both scores and indices. +- Experts run gate/up/down projections with the stored weights but rely on NumPy to emulate routing utilities (one-hot masks, scatter add) that are not present in infinicore yet. +- The sparse MoE block returns `(hidden_states, routing_weights)` shaped back to `(batch, seq, hidden_dim)` and `(batch, seq, top_k)` respectively to surface the gate decisions. + +## Missing operators and temporary shims +The following Torch ops are not available in infinicore today and are emulated in pure Python/NumPy inside `qwen3moe.py`: +- Softmax over the expert dimension. +- `topk` selection of experts. +- One-hot expansion of expert indices. +- Scatter/add (`index_add_`) to accumulate expert outputs. +- Boolean masking utilities (`where`/`nonzero`) used for routing. + +All shims use `_tensor_to_numpy` to bridge an infinicore tensor to NumPy and `_from_numpy_like` to move results back while keeping device/dtype. Replace these with native infinicore kernels once they land to regain performance. + +## Notes and next steps +- Activation currently supports `silu`/`swish`, `gelu`, and `relu`. Extend `_activation_fn` if the config uses other functions. +- Weight initialization mirrors the Torch reference (`empty` for expert matrices, `zeros` for router weights); hook up a proper initializer if required. +- When infinicore adds native softmax/top-k/one-hot/scatter, the Python shims can be deleted and the routing path can stay entirely on-device. 
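+
+## Usage sketch
+
+A minimal sketch of how the block is meant to be constructed and called. It is illustrative only: it assumes an ad-hoc `SimpleNamespace` config exposing just the fields the block reads (`num_experts`, `num_experts_per_tok`, `hidden_size`, `moe_intermediate_size`, `hidden_act`, `norm_topk_prob`) and random input; real usage should pass the model's `AutoConfig` and loaded checkpoint weights (see `test/models/qwen3_moe/ourmoe_test.py`).
+
+```python
+from types import SimpleNamespace
+
+import numpy as np
+import infinicore
+from infinilm.models.qwen3moe.qwen3moe import Qwen3MoeSparseMoeBlock
+
+# Ad-hoc config carrying only the fields the MoE block reads (values are illustrative).
+config = SimpleNamespace(
+    num_experts=8,
+    num_experts_per_tok=2,
+    hidden_size=64,
+    moe_intermediate_size=128,
+    hidden_act="silu",
+    norm_topk_prob=True,
+)
+
+device = infinicore.device("cpu", 0)
+block = Qwen3MoeSparseMoeBlock(config, device=device, dtype=infinicore.float32)
+
+# (batch, seq, hidden) input; the expert weights are uninitialized here,
+# so only the output shapes are meaningful in this sketch.
+host = np.random.default_rng(0).standard_normal((1, 4, config.hidden_size)).astype(np.float32)
+hidden = infinicore.from_numpy(host, dtype=infinicore.float32, device=device)
+
+mixed, routing = block(hidden)
+print(mixed.shape)    # (batch, seq, hidden_dim)
+print(routing.shape)  # (batch, seq, top_k) routing scores per token
+```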
diff --git a/python/infinilm/models/qwen3moe/MOE_REFACTOR_CN.md b/python/infinilm/models/qwen3moe/MOE_REFACTOR_CN.md
new file mode 100644
index 00000000..06909778
--- /dev/null
+++ b/python/infinilm/models/qwen3moe/MOE_REFACTOR_CN.md
@@ -0,0 +1,33 @@
+# Qwen3 MoE refactor notes (infinicore version)
+
+This document records how `Qwen3MoeSparseMoeBlock` was ported from torch to the infinicore framework, the list of missing operators, and the temporary Python/NumPy implementations used in their place.
+
+## Refactoring approach
+- **Keep the interface identical**: `Qwen3MoeExperts`, `Qwen3MoeTopKRouter`, and `Qwen3MoeSparseMoeBlock` are implemented in `python/infinilm/models/qwen3moe/qwen3moe.py` with the same class names and call conventions as the torch version, so they can be swapped in directly.
+- **Migrate parameter types**: expert and router weights are stored as `infinicore.nn.Parameter` and created via `infinicore.empty/zeros`, keeping device and dtype configurable.
+- **Prefer infinicore operators**: linear layers call `infinicore.nn.functional.linear`; the remaining routing-related operators that are missing are emulated in Python/NumPy for now.
+- **Preserve output shapes**: the MoE block returns the mixed result of shape `(batch, seq, hidden_dim)` plus routing scores of shape `(batch, seq, top_k)`, matching the original behavior.
+
+## Missing operators and temporary implementations
+infinicore currently lacks the following commonly used torch operators; all of them are emulated in pure Python/NumPy inside `qwen3moe.py`:
+
+| Feature | torch equivalent | Status | Temporary solution |
+| --- | --- | --- | --- |
+| Softmax | `torch.softmax` | missing | `_softmax_np`: convert to NumPy, compute softmax over the last dimension |
+| Top-K | `torch.topk` | missing | `_topk_np`: `argpartition` to find the top k, then sort |
+| One-Hot | `torch.nn.functional.one_hot` | missing | `_one_hot_np`: built from `np.eye` |
+| Scatter/Add | `index_add_` | missing | `np.add.at` accumulates along the token dimension |
+| Mask/selection | `where`/`nonzero` | partially missing | combination of `np.nonzero`/`np.where` |
+
+The helpers `_tensor_to_numpy` and `_from_numpy_like` bridge between infinicore tensors and NumPy while keeping dtype/device consistent; once the runtime gains a direct conversion interface, these bridges can be removed.
+
+## Key modules
+- **Qwen3MoeTopKRouter**: applies one linear projection to the input (infinicore), then uses NumPy softmax + top-k to obtain routing scores and expert indices, with optional normalization.
+- **Qwen3MoeExperts**: for the tokens routed to each active expert, applies the gate/up projections and the activation (`silu/swish`, `gelu`, `relu` are supported), then the down projection, and accumulates along the token dimension with `np.add.at`.
+- **Qwen3MoeSparseMoeBlock**: flattens the batch/seq dimensions and feeds the router, then calls the experts with the resulting `routing_weights` and `selected_experts`, and finally reshapes back to the original shape and returns the routing scores.
+
+## Known limitations and follow-up work
+- The routing path runs in NumPy, so for now there are CPU round-trips and a performance penalty; once infinicore provides softmax/top-k/one-hot/scatter operators, these Python branches can be removed entirely.
+- The activations currently cover `silu/swish`, `gelu`, and `relu`; if the config uses another activation, extend `_activation_fn`.
+- Weight initialization keeps `empty/zeros`; if strict alignment with the original model is required, add initialization logic at load or construction time.
+- Unit tests comparing output shapes and values against the torch reference implementation (where available) are recommended to ensure compatibility.
diff --git a/python/infinilm/models/qwen3moe/__init__.py b/python/infinilm/models/qwen3moe/__init__.py
new file mode 100644
index 00000000..8955c28a
--- /dev/null
+++ b/python/infinilm/models/qwen3moe/__init__.py
@@ -0,0 +1,35 @@
+import os
+from typing import Optional, Union
+import infinicore
+
+__all__ = ["AutoQwen3MOEModel"]
+
+
+class AutoQwen3MOEModel:
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: Optional[Union[str, os.PathLike]],
+        device: infinicore.device,
+        dtype=infinicore.dtype,
+        backend="python",
+    ):
+        if backend == "python":
+            from .
import modeling_qwen3moe + + return modeling_qwen3moe.Qwen3MOE.from_pretrained( + model_path, + device=device, + dtype=dtype, + ) + + # elif backend == "cpp": + # from .backends import cpp + + # return cpp.LlamaForCausalLM.from_pretrained( + # model_path, + # device=device, + # dtype=dtype, + # ) + + raise KeyError("invalid backend") diff --git a/python/infinilm/models/qwen3moe/qwen3moe.py b/python/infinilm/models/qwen3moe/qwen3moe.py new file mode 100644 index 00000000..60ffe5f1 --- /dev/null +++ b/python/infinilm/models/qwen3moe/qwen3moe.py @@ -0,0 +1,328 @@ +import math +import os +import ctypes +from typing import Callable, Tuple + +import numpy as np + +import infinicore +import infinicore as ic +from infinicore.nn import Module, Parameter +from infinicore.nn import functional as F + +def _tensor_to_numpy(tensor: infinicore.Tensor) -> np.ndarray: + if isinstance(tensor, np.ndarray): + return tensor + + try: + cpu_dev = infinicore.device("cpu", 0) + if str(tensor.device) != "cpu:0": + cpu_tensor = tensor.to(cpu_dev) + else: + cpu_tensor = tensor + + dtype_str = str(cpu_tensor.dtype) + if "float32" in dtype_str: + CType = ctypes.c_float + np_dtype = np.float32 + elif "bfloat16" in dtype_str or "half" in dtype_str or "float16" in dtype_str: + CType = ctypes.c_uint16 + np_dtype = np.uint16 + elif "int64" in dtype_str: + CType = ctypes.c_longlong + np_dtype = np.int64 + elif "int32" in dtype_str: + CType = ctypes.c_int + np_dtype = np.int32 + else: + CType = ctypes.c_float + np_dtype = np.float32 + + ptr = cpu_tensor.data_ptr() + size = cpu_tensor.numel() + + ArrayType = CType * size + c_array = ArrayType.from_address(ptr) + np_arr = np.ctypeslib.as_array(c_array).copy().reshape(cpu_tensor.shape) + + if np_dtype == np.uint16: + np_arr = np_arr.astype(np.float32) + + return np_arr + except Exception as e: + raise TypeError( + f"Could not convert {type(tensor)} to numpy array. Direct access failed: {e}" + ) + + +def _from_numpy( + array: np.ndarray, + *, + dtype: infinicore.dtype | None = None, + device: infinicore.device | None = None, +) -> infinicore.Tensor: + if not array.flags.c_contiguous: + array = np.ascontiguousarray(array) + + # HACK: Direct Memory Injection (Write Hack) + # infinicore.from_numpy() is broken for some versions/dtypes. + # We create a container tensor and write data directly to its memory. + + try: + # 1. Prepare source data (ensure float32 for safety) + src_data = array.astype(np.float32) + + + tensor = infinicore.zeros( + list(src_data.shape), + dtype=infinicore.float32, + device=infinicore.device("cpu", 0), + ) + + # 3. Get pointer and inject data + ptr = tensor.data_ptr() + size = tensor.numel() + + DstType = ctypes.c_float * size + dst_array = DstType.from_address(ptr) + + # Flatten and copy + src_flat = src_data.flatten() + # Using ctypes memmove might be faster, but loop is explicit + # dst_array[:] = src_flat # This doesn't work directly on ctypes array + + # Copy data (using memmove for speed) + ctypes.memmove(dst_array, src_flat.ctypes.data, src_flat.nbytes) + + # 4. Move to target device + if device is not None and str(device) != "cpu:0": + tensor = tensor.to(device) + + # 5. 
Cast to target dtype + if dtype is not None and tensor.dtype != dtype: + tensor = tensor.to(dtype=dtype) + + return tensor + + except Exception as e: + raise TypeError(f"Could not create infinicore.Tensor from numpy array via data_ptr hack: {e}") + + +def _from_numpy_like(array: np.ndarray, like: infinicore.Tensor) -> infinicore.Tensor: + return _from_numpy(array, dtype=like.dtype, device=like.device) + + +def _softmax_np(x: np.ndarray) -> np.ndarray: + x_max = np.max(x, axis=-1, keepdims=True) + x_exp = np.exp(x - x_max) + denom = np.sum(x_exp, axis=-1, keepdims=True) + return x_exp / np.maximum(denom, 1e-9) + + +def _topk_np(prob: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]: + if k <= 0: + raise ValueError("top-k must be positive.") + # argpartition gives us the largest k elements in arbitrary order + partition_idx = np.argpartition(prob, -k, axis=-1)[..., -k:] + topk_values = np.take_along_axis(prob, partition_idx, axis=-1) + # sort descending within the top-k slice + sort_order = np.argsort(-topk_values, axis=-1) + topk_indices = np.take_along_axis(partition_idx, sort_order, axis=-1) + topk_values = np.take_along_axis(topk_values, sort_order, axis=-1) + return topk_values, topk_indices + + +def _one_hot_np(indices: np.ndarray, num_classes: int) -> np.ndarray: + eye = np.eye(num_classes, dtype=np.int64) + return eye[indices] + + +def _activation_fn(name: str) -> Callable[[np.ndarray], np.ndarray]: + lower = name.lower() + if lower in ("silu", "swish"): + return lambda x: x / (1.0 + np.exp(-x)) + if lower == "gelu": + return lambda x: 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))) + if lower == "relu": + return lambda x: np.maximum(x, 0.0) + raise KeyError(f"Unsupported activation '{name}' for Qwen3 MoE block.") + + +class Qwen3MoeExperts(Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__( + self, + config, + *, + device: infinicore.device | None = None, + dtype: infinicore.dtype | None = None, + ): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + factory_kwargs = { + "device": device if device is not None else infinicore.device("cpu", 0), + "dtype": dtype if dtype is not None else infinicore.float32, + } + self.gate_up_proj = Parameter( + infinicore.empty( + [self.num_experts, 2 * self.intermediate_dim, self.hidden_dim], + **factory_kwargs, + ) + ) + self.down_proj = Parameter( + infinicore.empty( + [self.num_experts, self.hidden_dim, self.intermediate_dim], + **factory_kwargs, + ) + ) + self.act_fn = _activation_fn(getattr(config, "hidden_act", "silu")) + + def forward( + self, + hidden_states: infinicore.Tensor, + top_k_index: infinicore.Tensor, + top_k_weights: infinicore.Tensor, + ) -> infinicore.Tensor: + hidden_np = _tensor_to_numpy(hidden_states) + topk_idx_np = _tensor_to_numpy(top_k_index).astype(np.int64, copy=False) + topk_w_np = _tensor_to_numpy(top_k_weights) + + final_hidden_np = np.zeros_like(hidden_np) + expert_mask = _one_hot_np(topk_idx_np, self.num_experts) # [tokens, top_k, num_experts] + expert_mask = np.transpose(expert_mask, (2, 1, 0)) # [num_experts, top_k, tokens] + expert_hit = np.nonzero(np.sum(expert_mask, axis=(1, 2)) > 0)[0] + + gate_up_np = _tensor_to_numpy(self.gate_up_proj) + down_proj_np = _tensor_to_numpy(self.down_proj) + + for expert_idx in expert_hit: + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = np.nonzero(expert_mask[expert_idx]) + if 
token_idx.size == 0: + continue + + current_state = hidden_np[token_idx] + # Gate + Up projection + proj_out = current_state @ gate_up_np[expert_idx].T + gate_part = proj_out[:, : self.intermediate_dim] + up_part = proj_out[:, self.intermediate_dim :] + + activated = self.act_fn(gate_part) * up_part + + # Down projection + current_hidden = activated @ down_proj_np[expert_idx].T + current_hidden = current_hidden * topk_w_np[token_idx, top_k_pos][:, None] + + # Scatter-add back to the output buffer + np.add.at(final_hidden_np, token_idx, current_hidden.astype(final_hidden_np.dtype, copy=False)) + + return _from_numpy_like(final_hidden_np, hidden_states) + + +class Qwen3MoeTopKRouter(Module): + def __init__( + self, + config, + *, + device: infinicore.device | None = None, + dtype: infinicore.dtype | None = None, + ): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = getattr(config, "norm_topk_prob", False) + self.hidden_dim = config.hidden_size + factory_kwargs = { + "device": device if device is not None else infinicore.device("cpu", 0), + "dtype": dtype if dtype is not None else infinicore.float32, + } + self.weight = Parameter(infinicore.zeros([self.num_experts, self.hidden_dim], **factory_kwargs)) + + def forward(self, hidden_states: infinicore.Tensor): + # Avoid negative-dim inference; some backends reject view with -1. + tokens = int(hidden_states.shape[0]) + hidden_states = hidden_states.view((tokens, self.hidden_dim)) + router_logits = F.linear(hidden_states, self.weight) + + softmax_impl = os.environ.get("INFINILM_QWEN3MOE_ROUTER_SOFTMAX", "auto").lower() + if softmax_impl not in ("auto", "ic", "numpy"): + softmax_impl = "auto" + + router_prob = None + router_prob_np = None + can_use_ic = ( + infinicore.use_ntops + and router_logits.device.type in ("cuda", "musa") + and softmax_impl in ("auto", "ic") + ) + +########################################### + softmax_debug = os.environ.get("INFINILM_QWEN3MOE_ROUTER_SOFTMAX_DEBUG", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + if softmax_debug and not getattr(self, "_softmax_debug_printed", False): + selected = "ic" if can_use_ic else "numpy" + print( + "[Qwen3MoeTopKRouter] " + f"softmax_impl={softmax_impl} selected={selected} " + f"use_ntops={infinicore.use_ntops} device={router_logits.device}" + ) + setattr(self, "_softmax_debug_printed", True) +########################################### + + if can_use_ic: + router_prob = ic.softmax(router_logits, dim=-1) + router_prob_np = _tensor_to_numpy(router_prob) + else: + if softmax_impl == "ic": + raise RuntimeError( + "ic.softmax is only available with ntops on CUDA/MUSA devices; " + f"use_ntops={infinicore.use_ntops} device={router_logits.device}" + ) + router_prob_np = _softmax_np(_tensor_to_numpy(router_logits)) + router_prob = _from_numpy(router_prob_np, dtype=router_logits.dtype, device=router_logits.device) + + router_top_value, router_indices = _topk_np(router_prob_np, self.top_k) + if self.norm_topk_prob: + denom = np.sum(router_top_value, axis=-1, keepdims=True) + router_top_value = np.divide(router_top_value, denom, where=denom != 0) + + router_scores = _from_numpy( + router_top_value.astype(router_prob_np.dtype, copy=False), + dtype=router_prob.dtype, + device=router_prob.device, + ) + router_indices_tensor = _from_numpy(router_indices.astype(np.int64), device=router_prob.device) + return router_prob, router_scores, router_indices_tensor + + +class Qwen3MoeSparseMoeBlock(Module): + def __init__( 
+ self, + config, + *, + device: infinicore.device | None = None, + dtype: infinicore.dtype | None = None, + ): + super().__init__() + self.experts = Qwen3MoeExperts(config, device=device, dtype=dtype) + self.gate = Qwen3MoeTopKRouter(config, device=device, dtype=dtype) + + def forward(self, hidden_states: infinicore.Tensor): + batch_size, sequence_length, hidden_dim = hidden_states.shape + tokens = int(batch_size * sequence_length) + hidden_states_reshaped = hidden_states.view((tokens, hidden_dim)) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + top_k = int(routing_weights.shape[1]) + return ( + final_hidden_states.view((batch_size, sequence_length, hidden_dim)), + routing_weights.view((batch_size, sequence_length, top_k)), + ) \ No newline at end of file diff --git a/python/infinilm/models/qwen3moe/qwen3moe_torch.py b/python/infinilm/models/qwen3moe/qwen3moe_torch.py new file mode 100644 index 00000000..3b84bb7f --- /dev/null +++ b/python/infinilm/models/qwen3moe/qwen3moe_torch.py @@ -0,0 +1,72 @@ +class Qwen3MoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class Qwen3MoeTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + if self.norm_topk_prob: + router_top_value /= 
router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = router_top_value + return router_logits, router_scores, router_indices + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config: Qwen3MoeConfig): + super().__init__() + self.experts = Qwen3MoeExperts(config) + self.gate = Qwen3MoeTopKRouter(config) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_reshaped = hidden_states.view(-1, hidden_dim) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) + final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) + return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) diff --git a/test/models/qwen3_moe/moe_test.py b/test/models/qwen3_moe/moe_test.py index 4e0adaf4..c0edc6b2 100644 --- a/test/models/qwen3_moe/moe_test.py +++ b/test/models/qwen3_moe/moe_test.py @@ -122,6 +122,9 @@ def benchmark_moe_torch(moe, testcase, device, dtype): print( f"\t WARMUPS={WARMUPS} RUNS={RUNS}, MoE Torch average latency: {round(total_time * 1000 / RUNS, 2)} ms throughput: {round(total_tokens / total_time, 2)} tok/s" ) + # # Correctness check: print some output stats + # print(f"\t Output stats - Sum: {output_host.sum().item():.4f}, Mean: {output_host.mean().item():.4f}") + # print(f"\t First 5 values: {output_host.flatten()[:5].tolist()}") return output_host diff --git a/test/models/qwen3_moe/ourmoe_test.py b/test/models/qwen3_moe/ourmoe_test.py new file mode 100644 index 00000000..d211960e --- /dev/null +++ b/test/models/qwen3_moe/ourmoe_test.py @@ -0,0 +1,395 @@ +import argparse +import re +import numpy as np + +import time +import torch +import os +from transformers import AutoConfig +from transformers.models import qwen3_moe # 对拍相关 +import sys + + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../python")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../QIYUAN_GROUP-InfiniCore/python")) + +import infinicore +from infinilm.modeling_utils import load_state_dict +from infinilm.models.qwen3moe.qwen3moe import Qwen3MoeSparseMoeBlock +from infinilm.generation.utils import infini_to_numpy # 对拍 +# WARMUPS = 10 +# RUNS = 100 +WARMUPS = 0 +RUNS = 1 + +PREFILL_TESTCASES = {"seqlens": [64, 128, 256, 256], "pastlens": [512, 0, 0, 256]} +DECODE_TESTCASES = { + "seqlens": [1 for _ in range(16)], + "pastlens": [50 for _ in range(4)] + + [100 for _ in range(4)] + + [200 for _ in range(4)] + + [400 for _ in range(4)], +} + + +def get_args(): + parser = argparse.ArgumentParser(description="Test Qwen3 MoE block with InfiniLM") + parser.add_argument( + "--cpu", + action="store_true", + help="Run cpu test", + ) + parser.add_argument( + "--nvidia", + action="store_true", + help="Run nvidia test", + ) + parser.add_argument( + "--metax", + action="store_true", + help="Run metax test", + ) + parser.add_argument( + "--moore", + action="store_true", + help="Run moore test", + ) + parser.add_argument( + "--iluvatar", + action="store_true", + help="Run iluvatar test", + ) + parser.add_argument( +#================ 对拍 ========================# + "--check", + action="store_true", + help="Compare against a Torch reference implementation", + ) + parser.add_argument( + "--check_device", + type=str, + default="cpu", + choices=("cpu", "cuda"), + help="Device used for the Torch reference when --check is enabled", + ) + 
parser.add_argument( + "--seed", + type=int, + default=0, + help="Seed used to generate deterministic inputs", + ) + parser.add_argument( +#================ 对拍 ========================# + "--model_path", + type=str, + required=True, + help="Path to the Qwen3-MoE checkpoint directory", + ) + parser.add_argument( + "--layer_idx", + type=int, + default=0, + help="Which decoder layer's MoE block to load weights from", + ) + return parser.parse_args() + + +def resolve_device(args) -> str: + if args.cpu: + return "cpu" + if args.nvidia: + return "cuda" + if args.metax: + return "cuda" + if args.moore: + return "musa" + if args.iluvatar: + return "cuda" + raise ValueError( + "Usage: python test/models/qwen3_moe/ourmoe_test.py " + "[--cpu | --nvidia | --metax | --moore | --iluvatar] " + "--model_path=" + ) + + +def to_torch_dtype(infini_dtype: infinicore.dtype): + utils = getattr(infinicore, "utils", None) + if utils is not None: + mapper = getattr(utils, "to_torch_dtype", None) + if callable(mapper): + return mapper(infini_dtype) +#================ 对拍 ========================# + if infini_dtype == infinicore.float32: + return torch.float32 +#================ 对拍 ========================# + return torch.bfloat16 + + +#================ 对拍 ========================# +def load_moe_state_dict_torch(model_path: str, layer_idx: int): + prefix = f"model.layers.{layer_idx}.mlp." + tensors = {} + for fname in sorted(os.listdir(model_path)): + if not fname.endswith(".safetensors"): + continue + checkpoint = load_state_dict(os.path.join(model_path, fname)) + for full_name, tensor in checkpoint.items(): + if full_name.startswith(prefix): + tensors[full_name[len(prefix) :]] = tensor + if not tensors: + raise FileNotFoundError(f"Cannot find MoE weights with prefix '{prefix}' under {model_path}") + return tensors + + +def create_moe_torch(config, model_path: str, device_str: str, dtype, layer_idx: int): + moe = qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock(config).to(device=device_str, dtype=dtype) + moe.load_state_dict(load_moe_state_dict_torch(model_path, layer_idx), strict=True) + moe.eval() + return moe + +#================ 对拍 ========================# +def load_moe_weights(model_path: str, device_str: str, dtype, layer_idx: int, config): + prefix = f"model.layers.{layer_idx}.mlp." 
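+    # The checkpoint stores separate per-expert `gate_proj`/`up_proj`/`down_proj`
+    # matrices under `model.layers.<idx>.mlp.experts.<e>.*`, while the infinicore
+    # block keeps a fused `experts.gate_up_proj` of shape
+    # [num_experts, 2 * moe_intermediate_size, hidden_size] and an
+    # `experts.down_proj` of shape [num_experts, hidden_size, moe_intermediate_size].
+    # The loop below collects the per-expert pieces, concatenates gate/up along
+    # dim 0, and converts everything to infinicore tensors.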
+ torch_dtype = to_torch_dtype(dtype) + + gate_weight = None + expert_parts: dict[int, dict[str, torch.Tensor]] = {} + + for fname in sorted(os.listdir(model_path)): + if not fname.endswith(".safetensors"): + continue + + checkpoint = load_state_dict(os.path.join(model_path, fname)) + for full_name, tensor in checkpoint.items(): + if not full_name.startswith(prefix): + continue + + local_name = full_name[len(prefix) :] + if local_name == "gate.weight": + gate_weight = tensor + continue + + match = re.match(r"experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", local_name) + if match: + idx = int(match.group(1)) + expert_parts.setdefault(idx, {})[match.group(2)] = tensor + + if gate_weight is None or not expert_parts: + raise FileNotFoundError( + f"Cannot find MoE weights with prefix '{prefix}' under {model_path}" + ) + + num_experts = config.num_experts + hidden_dim = config.hidden_size + inter_dim = config.moe_intermediate_size + + gate_up = torch.empty((num_experts, 2 * inter_dim, hidden_dim)) + down = torch.empty((num_experts, hidden_dim, inter_dim)) + + for expert_idx in range(num_experts): + parts = expert_parts.get(expert_idx) + if parts is None or any(k not in parts for k in ("gate_proj", "up_proj", "down_proj")): + raise KeyError(f"Missing weights for expert {expert_idx} in {model_path}") + + gate_proj = parts["gate_proj"] + up_proj = parts["up_proj"] + down_proj = parts["down_proj"] + gate_up[expert_idx] = torch.cat([gate_proj, up_proj], dim=0) + down[expert_idx] = down_proj + + gate_weight = gate_weight.to(device=device_str, dtype=torch_dtype) + gate_up = gate_up.to(device=device_str, dtype=torch_dtype) + down = down.to(device=device_str, dtype=torch_dtype) + + return { + "gate.weight": infinicore.from_torch(gate_weight), + "experts.gate_up_proj": infinicore.from_torch(gate_up), + "experts.down_proj": infinicore.from_torch(down), + } + + +def create_moe(model_path: str, device: infinicore.device, dtype, layer_idx: int): + config = AutoConfig.from_pretrained(model_path) + moe = Qwen3MoeSparseMoeBlock(config, device=device, dtype=dtype) + + moe_state = load_moe_weights(model_path, device.type, dtype, layer_idx, config) + moe.load_state_dict(moe_state) + if hasattr(moe, "eval"): + moe.eval() + return moe, config + + +# def generate_moe_input(testcase, hidden_size: int, device, dtype): +def generate_moe_input(testcase, hidden_size: int, device, dtype, *, seed: int): # 对拍 + total_seqlen = sum(testcase["seqlens"]) + # host = np.random.default_rng().standard_normal( + # (1, total_seqlen, hidden_size) + # ).astype(np.float32) + # return infinicore.from_numpy(host, dtype=dtype, device=device) +#================ 对拍 ========================# + host = np.random.default_rng(seed).standard_normal((1, total_seqlen, hidden_size)).astype(np.float32) + return host, infinicore.from_numpy(host, dtype=dtype, device=device) +#================ 对拍 ========================# + + +def _sync_device(device: infinicore.device): + for name in ("synchronize", "device_synchronize"): + fn = getattr(infinicore, name, None) + if callable(fn): + try: + fn(device) + except TypeError: + fn() + break + + +# def benchmark_moe(moe, testcase, hidden_size: int, device, dtype): +# input_tensor = generate_moe_input(testcase, hidden_size, device, dtype) +#================ 对拍 ========================# +def benchmark_moe( + moe, + testcase, + hidden_size: int, + device, + dtype, + *, + seed: int, + check: bool, + torch_moe=None, + torch_device_str: str | None = None, + torch_dtype=None, +): + host, input_tensor = 
generate_moe_input(testcase, hidden_size, device, dtype, seed=seed) +#================ 对拍 ========================# + hidden_out, routing_out = moe(input_tensor) + + print( + f"\tOutput hidden shape: {getattr(hidden_out, 'shape', '?')}, routing shape: {getattr(routing_out, 'shape', '?')}" + ) + +#================ 对拍 ========================# + if check: + _sync_device(device) + + hidden_cpu = hidden_out + if hidden_cpu.device.type != "cpu": + hidden_cpu = hidden_cpu.to(infinicore.device("cpu", 0)) + if hasattr(hidden_cpu, "is_contiguous") and callable(hidden_cpu.is_contiguous): + if not hidden_cpu.is_contiguous(): + hidden_cpu = hidden_cpu.contiguous() + + out_inf = infini_to_numpy(hidden_cpu).astype(np.float32, copy=False) + + torch_inp = torch.from_numpy(host).to(device=torch_device_str, dtype=torch_dtype) + with torch.no_grad(): + out_torch, _ = torch_moe(torch_inp) + out_torch = out_torch.detach().to("cpu").to(dtype=torch.float32).numpy() + + inf_nan = int(np.isnan(out_inf).sum()) + inf_inf = int(np.isinf(out_inf).sum()) + torch_nan = int(np.isnan(out_torch).sum()) + torch_inf = int(np.isinf(out_torch).sum()) + + finite = np.isfinite(out_inf) & np.isfinite(out_torch) + if finite.any(): + diff = out_torch[finite] - out_inf[finite] + diff_abs_max = float(np.max(np.abs(diff))) + diff_abs_mean = float(np.mean(np.abs(diff))) + else: + diff_abs_max = float("nan") + diff_abs_mean = float("nan") + + print(f"\t Output stats (torch) - Sum: {out_torch.sum():.4f}, Mean: {out_torch.mean():.4f}") + print(f"\t Output stats (infini) - Sum: {out_inf.sum():.4f}, Mean: {out_inf.mean():.4f}") + print(f"\t NaN/Inf count (torch): {torch_nan}/{torch_inf} (infini): {inf_nan}/{inf_inf}") + print(f"\t First 5 values (torch): {out_torch.reshape(-1)[:5].tolist()}") + print(f"\t First 5 values (infini): {out_inf.reshape(-1)[:5].tolist()}") + print(f"\t Diff abs max: {diff_abs_max:.6f}, mean: {diff_abs_mean:.6f}") +#================ 对拍 ========================# + + for _ in range(WARMUPS): + moe(input_tensor) + _sync_device(device) + + t0 = time.time() + for _ in range(RUNS): + moe(input_tensor) + _sync_device(device) + t1 = time.time() + + total_time = t1 - t0 + total_tokens = sum(testcase["seqlens"]) * RUNS + print( + f"\tWARMUPS={WARMUPS} RUNS={RUNS}, latency: {round(total_time * 1000 / RUNS, 2)} ms throughput: {round(total_tokens / total_time, 2)} tok/s" + ) + + return hidden_out + + +if __name__ == "__main__": + args = get_args() + print(args) + + device_str = resolve_device(args) + if device_str == "musa": + try: + import torch_musa # noqa: F401 + except ImportError: + print("torch_musa is required for MUSA devices, falling back to CPU") + device_str = "cpu" + + infini_device = infinicore.device(device_str, 0) + # infini_dtype = infinicore.bfloat16 + # Switch to float32 to bypass infinicore BF16 conversion issues + infini_dtype = infinicore.float32 + + moe, config = create_moe( + args.model_path, device=infini_device, dtype=infini_dtype, layer_idx=args.layer_idx + ) + hidden_size = config.hidden_size + + print("*" * 130) + print("Test Qwen3 MoE (InfiniLM)") + print("*" * 130) + + print(f"Test Case PREFILL_TESTCASES : {PREFILL_TESTCASES}") + # benchmark_moe(moe, PREFILL_TESTCASES, hidden_size, device=infini_device, dtype=infini_dtype) +#================ 对拍 ========================# + torch_dtype = to_torch_dtype(infini_dtype) + torch_moe = None + torch_check_device = args.check_device + if args.check: + torch_moe = create_moe_torch(config, args.model_path, torch_check_device, torch_dtype, args.layer_idx) + + 
benchmark_moe( + moe, + PREFILL_TESTCASES, + hidden_size, + device=infini_device, + dtype=infini_dtype, + seed=args.seed, + check=args.check, + torch_moe=torch_moe, + torch_device_str=torch_check_device, + torch_dtype=torch_dtype, + ) +#================ 对拍 ========================# + + print("\n" + "-" * 130) + print(f"\nTest DECODE_TESTCASES: {DECODE_TESTCASES}") + # benchmark_moe(moe, DECODE_TESTCASES, hidden_size, device=infini_device, dtype=infini_dtype) +#================ 对拍 ========================# + benchmark_moe( + moe, + DECODE_TESTCASES, + hidden_size, + device=infini_device, + dtype=infini_dtype, + seed=args.seed + 1, + check=args.check, + torch_moe=torch_moe, + torch_device_str=torch_check_device, + torch_dtype=torch_dtype, + ) +#================ 对拍 ========================# diff --git a/test/models/qwen3_moe/test.sh b/test/models/qwen3_moe/test.sh new file mode 100755 index 00000000..2897a23d --- /dev/null +++ b/test/models/qwen3_moe/test.sh @@ -0,0 +1,4 @@ +# speed compare +srun --gres=gpu:nvidia:2 --cpus-per-task=16 --mem=256G python ourmoe_test.py --nvidia --model_path=/data/users/whitecity/models/Qwen3-30B-A3B-Instruct-2507-Layer-0 +# result compare +srun --gres=gpu:nvidia:2 --cpus-per-task=16 --mem=256G python ourmoe_test.py --cpu --check --check_device cpu --seed 0 --model_path=/data/users/whitecity/models/Qwen3-30B-A3B-Instruct-2507-Layer-0 --layer_idx=0 \ No newline at end of file
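+# optional router-softmax knobs read in qwen3moe.py:
+#   INFINILM_QWEN3MOE_ROUTER_SOFTMAX=auto|ic|numpy selects the softmax path,
+#   INFINILM_QWEN3MOE_ROUTER_SOFTMAX_DEBUG=1 prints once which path was chosen, e.g.
+# INFINILM_QWEN3MOE_ROUTER_SOFTMAX=numpy INFINILM_QWEN3MOE_ROUTER_SOFTMAX_DEBUG=1 python ourmoe_test.py --cpu --check --check_device cpu --seed 0 --model_path=... --layer_idx=0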