Suppose there are 4 processes, where ranks 0-3 each generate part of the data and register it under the same ckpt_name. How can P2P transfer be used to send the partial data held by ranks 0-2 to rank 3?
I implemented the following code modeled on test/test_update.py; the runs I tried behave as follows (restated as exact calls right after this list):
- With 2 processes, transferring from rank 0 to rank 1 via ps.update(ranks=[0,1]) works.
- With 2 processes, transferring from rank 0 to rank 1 via ps.update(ranks=[1]) hits a socket timeout.
- With 4 processes, even ps.update(ranks=[0,1,2,3]) hits a socket timeout.
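Condensed, the three attempts correspond to these calls, reusing the checkpoint name and queue from the full script below:

# 2 processes, rank 0 -> rank 1: works
ps.update("p2p_test", queue.put, ranks=[0, 1])
# 2 processes, rank 0 -> rank 1: socket timeout
ps.update("p2p_test", queue.put, ranks=[1])
# 4 processes: socket timeout
ps.update("p2p_test", queue.put, ranks=[0, 1, 2, 3])

The full reproduction script follows.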
import os
import torch
import zmq
import time
import argparse
from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id
from checkpoint_engine.worker import update_weights_from_ipc
from torch.multiprocessing import Queue, get_context
from checkpoint_engine.device_utils import DeviceManager
from types import SimpleNamespace
try:
    device_manager = DeviceManager()
except TypeError:
    device_manager = SimpleNamespace(device_module=SimpleNamespace(device_count=lambda: 0))

def req_func(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
    print(f"[Rank {rank}] req_func")
    device_manager.device_module.set_device(rank)
    named_tensors = {
        name: tensor.to(device_manager.device_type) for name, tensor in named_tensors.items()
    }
    _zmq_ctx = zmq.Context()

    def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor]]):
        if len(names_to_check):
            return
        print(f"[Rank {rank}] names_to_check={names_to_check}")
        print(f"[Rank {rank}] weights={weights}")
        for name, weight in weights:
            if name not in named_tensors:
                continue
            assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
            names_to_check[name] = True

    def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
        print(f"[Rank {rank}] req_func check_weights")
        socket_paths = dict(socket_paths)
        update_weights_from_ipc(
            _zmq_ctx,
            socket_paths[device_uuid],
            device_id=rank,
            run=lambda weights: check(names_to_check, weights),
            post_hook=lambda: device_manager.device_module.synchronize(),
        )
        device_manager.device_module.synchronize()
        device_manager.device_module.empty_cache()
        assert all(names_to_check.values())

    while True:
        socket_paths: list[tuple[str, str]] = queue.get()
        if socket_paths is None:
            break
        names_to_check = dict.fromkeys(named_tensors.keys(), False)
        check_weights(names_to_check, socket_paths)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--num-trainer", type=int, default=1)
    parser.add_argument("--num-rollout", type=int, default=1)
    args = parser.parse_args()
    num_trainer = args.num_trainer
    num_rollout = args.num_rollout

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    print(f"\n{'='*60}")
    print(f"[Rank {rank}] Starting one-way P2P test (0->1)")
    print(f"{'='*60}")

    # Set the device
    torch.npu.set_device(local_rank)

    # Initialize the parameter server
    print(f"[Rank {rank}] Initializing ParameterServer...")
    ps = ParameterServer(auto_pg=True)

    # Trainer side: this rank's shard of the weights
    if rank < num_trainer:
        tensors = []
        for i in range(3):
            name = f"rank{rank}.tensor{i}"
            tensor = torch.ones(64, 128, device=f"npu:{local_rank}").to(torch.bfloat16) * (rank * 10 + i + 1)
            tensors.append((name, tensor))
        named_tensors = dict(tensors)
    # Rollout (inference) side: no weights
    else:
        named_tensors = dict()

    checkpoint_name = "p2p_test"
    ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
    ps.gather_metas(checkpoint_name)
    print(f"metas={ps.get_metas()}")

    device_uuid = _get_physical_gpu_id(ps.device_manager, rank)

    ctx = get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(target=req_func, args=(rank, device_uuid, named_tensors, queue))
    proc.start()

    ps.update(checkpoint_name, queue.put, ranks=[0, 3])

    time.sleep(2)
    print(f"\n[Rank {rank}] {'='*40}")
    if rank == 0:
        print(f"[Rank 0] ✅ Sender finished")
    elif rank == 1:
        print(f"[Rank 1] ✅ Receiver finished")
    print(f"[Rank {rank}] Test completed!")
    print(f"{'='*60}")

    # Cleanup
    if rank == 0:
        ps.unregister_checkpoint(checkpoint_name)

    queue.put(None)
    proc.join()
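The script reads RANK and LOCAL_RANK from the environment, so it is presumably launched with a torchrun-style launcher; the file name test_p2p.py below is only a placeholder, and the argument values mirror the 4-process scenario above (ranks 0-2 hold data, rank 3 should receive it):

# 4-process case on one node
torchrun --nproc_per_node=4 test_p2p.py --num-trainer 3 --num-rollout 1
# 2-process case
torchrun --nproc_per_node=2 test_p2p.py --num-trainer 1 --num-rollout 1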