From ec17a86b2123d88404002b00fe13329c467053e1 Mon Sep 17 00:00:00 2001 From: pengcheng888 Date: Tue, 13 Jan 2026 14:43:17 +0800 Subject: [PATCH] =?UTF-8?q?issue/890=20-=20=E4=B8=BApython=E7=AB=AF?= =?UTF-8?q?=E7=9A=84nn.module=E6=B7=BB=E5=8A=A0to=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/infinicore/nn/modules/module.py | 39 ++++++++--- test/infinicore/nn/module.py | 90 +++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 12 deletions(-) diff --git a/python/infinicore/nn/modules/module.py b/python/infinicore/nn/modules/module.py index d21223903..013f8331d 100644 --- a/python/infinicore/nn/modules/module.py +++ b/python/infinicore/nn/modules/module.py @@ -32,6 +32,7 @@ import infinicore +from ...device import device as InfiniCoreDevice from ...tensor import Tensor from ..parameter import InfiniCoreParameter as Parameter @@ -481,15 +482,14 @@ def _load_from_state_dict( f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}" ) - if ( - (param.shape == input_param.shape) - and (param.dtype == input_param.dtype) - and (param.device == input_param.device) + if (param.shape == input_param.shape) and ( + param.dtype == input_param.dtype ): param.copy_(input_param) else: - print(f"param '{name}' don't match input_param '{key}'") - setattr(self, name, input_param) + raise KeyError( + f"param '{name}' don't match input_param '{key}' with shape or dtype" + ) elif strict: missing_keys.append(key) @@ -848,10 +848,29 @@ def eval(self: T) -> T: Returns: Module: self """ - pass + raise KeyError("not support") def _apply(self, fn, recurse=True): - raise KeyError("not support") + if recurse: + for module in self.children(): + module._apply(fn) - def to(self, *args, **kwargs): - raise KeyError("not support") + for key, param in self._parameters.items(): + if param is not None: + setattr(self, key, fn(param)) + + for key, buf in self._buffers.items(): + if buf is not None: + setattr(self, key, fn(buf)) + + return self + + def to(self, device: str | InfiniCoreDevice): + if device is None: + raise ValueError("device cannot be None") + device = InfiniCoreDevice(device) + + def convert(t): + return t.to(device) + + return self._apply(convert) diff --git a/test/infinicore/nn/module.py b/test/infinicore/nn/module.py index 69e341fa2..2dafacee6 100644 --- a/test/infinicore/nn/module.py +++ b/test/infinicore/nn/module.py @@ -44,6 +44,7 @@ def __init__(self): def forward(self): return infinicore.add(self.a, self.b) + infinicore_model_infer = InfiniCoreNet() # ============================================================ # 2. 加载权重 @@ -75,6 +76,91 @@ def forward(self): # ============================================================ -# 5. to测试,buffer测试 +# 5. to测试 - 测试模型在不同设备间的转换 # ============================================================ -# 等待添加 +print("\n" + "=" * 60) +print("5. to测试 - 设备转换测试") +print("=" * 60) + + +def print_model_state(model, title="状态"): + """打印模型的参数状态""" + print(f"\n{title}:") + print("-" * 40) + print("Parameters:") + for name, param in model.named_parameters(): + print( + f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}" + ) + + +def verify_device_conversion(model, target_device, use_type_check=False): + """验证模型参数的设备转换""" + print("转换后的Parameters:") + for name, param in model.named_parameters(): + print( + f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}" + ) + if use_type_check: + # 当使用字符串参数时,只检查设备类型 + expected_type = ( + target_device if isinstance(target_device, str) else target_device.type + ) + assert param.device.type == expected_type, ( + f"参数 {name} 的设备转换失败: 期望类型 {expected_type}, 实际 {param.device.type}" + ) + else: + # 使用device对象时,进行完整比较 + assert param.device == target_device, ( + f"参数 {name} 的设备转换失败: 期望 {target_device}, 实际 {param.device}" + ) + + +# 5.1 打印初始状态 +print_model_state(infinicore_model_infer, "5.1 初始状态") + +# 定义设备转换测试用例列表 +device_conversion_cases = [ + { + "name": "5.2 转换到CUDA设备", + "description": "使用 infinicore.device('cuda', 0)", + "target": infinicore.device("cuda", 0), + "use_type_check": False, + "success_msg": "✓ CUDA设备转换验证通过", + }, + { + "name": "5.3 转换到CPU设备", + "description": "使用 infinicore.device('cpu', 0)", + "target": infinicore.device("cpu", 0), + "use_type_check": False, + "success_msg": "✓ CPU设备转换验证通过", + }, + { + "name": "5.4 转换到CUDA设备", + "description": "使用字符串 'cuda'", + "target": "cuda", + "use_type_check": True, + "success_msg": "✓ 字符串参数设备转换验证通过", + }, +] + +# 循环测试每个设备转换用例 +for case in device_conversion_cases: + print(f"\n{case['name']} ({case['description']}):") + print("-" * 40) + infinicore_model_infer.to(case["target"]) + verify_device_conversion( + infinicore_model_infer, case["target"], use_type_check=case["use_type_check"] + ) + print(case["success_msg"]) + +# 5.5 验证to方法返回self(链式调用支持) +print("\n5.5 测试to方法的返回值(链式调用):") +print("-" * 40) +result = infinicore_model_infer.to(infinicore.device("cpu", 0)) +assert result is infinicore_model_infer, "to方法应该返回self以支持链式调用" +print("✓ to方法返回值验证通过") + +print("\n" + "=" * 60) +print("所有to测试通过!") +print("=" * 60 + "\n")