39 changes: 29 additions & 10 deletions python/infinicore/nn/modules/module.py
@@ -32,6 +32,7 @@

import infinicore

from ...device import device as InfiniCoreDevice
from ...tensor import Tensor
from ..parameter import InfiniCoreParameter as Parameter

@@ -481,15 +482,14 @@ def _load_from_state_dict(
f"While copying the parameter named {key}, expected Tensor from checkpoint but received {type(input_param)}"
)

if (
(param.shape == input_param.shape)
and (param.dtype == input_param.dtype)
and (param.device == input_param.device)
if (param.shape == input_param.shape) and (
param.dtype == input_param.dtype
):
param.copy_(input_param)
Collaborator Author:

The copy operation between two infinicore tensors supports copying directly from CPU to GPU.

The weight-loading check is now: if the model weight and the checkpoint weight have the same shape and dtype, the data is copied; otherwise an error is raised.
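To make the intent concrete, a minimal sketch of the check as a hypothetical standalone helper, using only the Tensor API visible in this diff (.shape, .dtype, .copy_()):

def load_param_sketch(param, input_param, name, key):
    # Device is deliberately not compared: copy_() can copy a CPU
    # checkpoint tensor directly onto a GPU parameter.
    if param.shape == input_param.shape and param.dtype == input_param.dtype:
        param.copy_(input_param)
    else:
        raise KeyError(
            f"param '{name}' does not match input_param '{key}' in shape or dtype"
        )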

else:
print(f"param '{name}' don't match input_param '{key}'")
setattr(self, name, input_param)
raise KeyError(
f"param '{name}' don't match input_param '{key}' with shape or dtype"
)

elif strict:
missing_keys.append(key)
@@ -848,10 +848,29 @@ def eval(self: T) -> T:
Returns:
Module: self
"""
pass
raise KeyError("not support")

def _apply(self, fn, recurse=True):
Collaborator Author:

The to function partly follows torch's implementation.

raise KeyError("not support")
if recurse:
for module in self.children():
module._apply(fn)

def to(self, *args, **kwargs):
raise KeyError("not support")
for key, param in self._parameters.items():
if param is not None:
setattr(self, key, fn(param))
Collaborator Author (@pengcheng888, Jan 7, 2026):

Assignment with the = operator did not take effect; not sure why.

Collaborator Author:

In the end I used setattr(self, key, fn(param)).
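A plain-Python note on the likely cause, independent of infinicore and shown with a hypothetical Box class: rebinding the loop variable param only changes a local name, while setattr(self, key, ...) assigns through the instance's attribute machinery under the dynamic name key.

class Box:
    """Hypothetical stand-in for a module holding one parameter."""

    def __init__(self):
        self._parameters = {"w": 1}
        self.w = 1


def fn(t):
    return t + 10


box = Box()

for key, param in box._parameters.items():
    param = fn(param)  # rebinds the local name 'param' only
print(box.w)  # still 1

for key, param in box._parameters.items():
    setattr(box, key, fn(param))  # assigns box.<key> to the converted value
print(box.w)  # 11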


for key, buf in self._buffers.items():
if buf is not None:
setattr(self, key, fn(buf))

return self

def to(self, device: str | InfiniCoreDevice):
if device is None:
raise ValueError("device cannot be None")
device = InfiniCoreDevice(device)

def convert(t):
return t.to(device)

return self._apply(convert)
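For reference, usage as exercised by the test below (a sketch; it assumes a CUDA-capable infinicore build and the InfiniCoreNet test model defined in test/infinicore/nn/module.py):

model = InfiniCoreNet()
model.to(infinicore.device("cuda", 0))  # device object
model.to("cuda")                        # string form, normalized via InfiniCoreDevice
out = model.to(infinicore.device("cpu", 0)).forward()  # to() returns self, so calls chain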
90 changes: 88 additions & 2 deletions test/infinicore/nn/module.py
@@ -44,6 +44,7 @@ def __init__(self):
def forward(self):
return infinicore.add(self.a, self.b)


infinicore_model_infer = InfiniCoreNet()
# ============================================================
# 2. Load weights
@@ -75,6 +76,91 @@ def forward(self):


# ============================================================
# 5. to test, buffer test
# 5. to test - converting the model between devices
# ============================================================
# To be added
print("\n" + "=" * 60)
print("5. to test - device conversion test")
print("=" * 60)


def print_model_state(model, title="State"):
"""Print the model's parameter state."""
print(f"\n{title}:")
print("-" * 40)
print("Parameters:")
for name, param in model.named_parameters():
print(
f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}"
)


def verify_device_conversion(model, target_device, use_type_check=False):
"""验证模型参数的设备转换"""
print("转换后的Parameters:")
for name, param in model.named_parameters():
print(
f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}"
)
if use_type_check:
# When a string argument is used, only check the device type
expected_type = (
target_device if isinstance(target_device, str) else target_device.type
)
assert param.device.type == expected_type, (
f"Device conversion failed for parameter {name}: expected type {expected_type}, got {param.device.type}"
)
else:
# When a device object is used, compare the full device
assert param.device == target_device, (
f"Device conversion failed for parameter {name}: expected {target_device}, got {param.device}"
)


# 5.1 Print the initial state
print_model_state(infinicore_model_infer, "5.1 Initial state")

# Define the list of device conversion test cases
device_conversion_cases = [
{
"name": "5.2 Convert to CUDA device",
"description": "using infinicore.device('cuda', 0)",
"target": infinicore.device("cuda", 0),
"use_type_check": False,
"success_msg": "✓ CUDA device conversion verified",
},
{
"name": "5.3 Convert to CPU device",
"description": "using infinicore.device('cpu', 0)",
"target": infinicore.device("cpu", 0),
"use_type_check": False,
"success_msg": "✓ CPU device conversion verified",
},
{
"name": "5.4 Convert to CUDA device",
"description": "using the string 'cuda'",
"target": "cuda",
"use_type_check": True,
"success_msg": "✓ String-argument device conversion verified",
},
]

# Run each device conversion test case
for case in device_conversion_cases:
print(f"\n{case['name']} ({case['description']}):")
print("-" * 40)
infinicore_model_infer.to(case["target"])
verify_device_conversion(
infinicore_model_infer, case["target"], use_type_check=case["use_type_check"]
)
print(case["success_msg"])

# 5.5 Verify that to returns self (supports chained calls)
print("\n5.5 Test the return value of to (chained calls):")
print("-" * 40)
result = infinicore_model_infer.to(infinicore.device("cpu", 0))
assert result is infinicore_model_infer, "to should return self to support chained calls"
print("✓ to return value verified")

print("\n" + "=" * 60)
print("所有to测试通过!")
print("=" * 60 + "\n")