
Could you provide an example, for the training/inference-disaggregated scenario, of transferring sharded data from training processes to inference processes via P2P? #85

Description

@big-ZWZ

Suppose there are 4 processes, where ranks 0-3 each generate part of the data and register it under the same ckpt_name. How can P2P transfer be used to send the shards held by ranks 0-2 to rank 3?
I implemented the code below, modeled on test/test_update.py. What I observe:

  1. With two processes, transferring from rank 0 to rank 1 via ps.update(ranks=[0,1]) works.
  2. With two processes, transferring from rank 0 to rank 1 via ps.update(ranks=[1]) hits a socket timeout.
  3. With four processes, even ps.update(ranks=[0,1,2,3]) hits a socket timeout.
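
(For context: the script reads RANK and LOCAL_RANK from the environment, so it is launched with torchrun, e.g. torchrun --nproc-per-node=4 p2p_test.py --num-trainer 3 --num-rollout 1 on a single node; the script name here is just a placeholder.)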
import os
import torch
import zmq
import time
import argparse
from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id
from checkpoint_engine.worker import update_weights_from_ipc
from torch.multiprocessing import Queue, get_context
from checkpoint_engine.device_utils import DeviceManager
from types import SimpleNamespace

# Fall back to a minimal stub if DeviceManager cannot be constructed
# (a constructor-signature mismatch raises TypeError); the stub only
# reports zero devices.
try:
    device_manager = DeviceManager()
except TypeError:
    device_manager = SimpleNamespace(device_module=SimpleNamespace(device_count=lambda: 0))


def req_func(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
    """Receiver loop: pull socket paths from the queue and verify incoming weights."""
    print(f"[Rank {rank}] req_func")
    # Single-node assumption: the global rank doubles as the local device index.
    device_manager.device_module.set_device(rank)
    named_tensors = {
        name: tensor.to(device_manager.device_type) for name, tensor in named_tensors.items()
    }
    _zmq_ctx = zmq.Context()

    def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor]]):
        # Nothing was registered locally (pure receiver), so there is nothing to verify.
        if not names_to_check:
            return
        print(f"[Rank {rank}] names_to_check={names_to_check}")
        print(f"[Rank {rank}] received weights: {[name for name, _ in weights]}")
        for name, weight in weights:
            if name not in named_tensors:
                continue
            assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
            names_to_check[name] = True

    def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
        print(f"[Rank {rank}] req_func check_weights")
        socket_paths = dict(socket_paths)
        update_weights_from_ipc(
            _zmq_ctx,
            socket_paths[device_uuid],
            device_id=rank,
            run=lambda weights: check(names_to_check, weights),
            post_hook=lambda: device_manager.device_module.synchronize(),
        )
        device_manager.device_module.synchronize()
        device_manager.device_module.empty_cache()
        assert all(names_to_check.values())

    while True:
        socket_paths: list[tuple[str, str]] | None = queue.get()
        if socket_paths is None:  # sentinel from the main process: shut down
            break
        names_to_check = dict.fromkeys(named_tensors.keys(), False)
        check_weights(names_to_check, socket_paths)



if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--num-trainer", type=int, default=1)
    parser.add_argument("--num-rollout", type=int, default=1)
    args = parser.parse_args()

    num_trainer = args.num_trainer
    num_rollout = args.num_rollout

    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    
    print(f"\n{'='*60}")
    print(f"[Rank {rank}] Starting P2P update test")
    print(f"{'='*60}")

    # Select this process's device (Ascend NPU here; use torch.cuda on GPU machines)
    torch.npu.set_device(local_rank)

    # Initialize the ParameterServer
    print(f"[Rank {rank}] Initializing ParameterServer...")
    ps = ParameterServer(auto_pg=True)
    
    # Trainer side: build this rank's shard of the weights
    if rank < num_trainer:
        tensors = []
        for i in range(3):
            name = f"rank{rank}.tensor{i}"
            tensor = torch.ones(64, 128, device=f"npu:{local_rank}").to(torch.bfloat16) * (rank * 10 + i + 1)
            tensors.append((name, tensor))
        named_tensors = dict(tensors)
    # Rollout/inference side: starts with no weights
    else:
        named_tensors = dict()
    
    checkpoint_name = "p2p_test"
    ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
    ps.gather_metas(checkpoint_name)
    print(f"metas={ps.get_metas()}")
    
    device_uuid = _get_physical_gpu_id(ps.device_manager, rank)

    ctx = get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(target=req_func, args=(rank, device_uuid, named_tensors, queue))
    proc.start()
    
    # P2P update restricted to ranks 0 and 3 (the call whose behavior is in question)
    ps.update(checkpoint_name, queue.put, ranks=[0, 3])
    
    time.sleep(2)
    
    print(f"\n[Rank {rank}] {'='*40}")
    if rank < num_trainer:
        print(f"[Rank {rank}] ✅ Sender finished")
    else:
        print(f"[Rank {rank}] ✅ Receiver finished")
    print(f"[Rank {rank}] Test completed!")
    print(f"{'='*60}")
    
    # Cleanup: drop the checkpoint and stop the receiver subprocess
    if rank == 0:
        ps.unregister_checkpoint(checkpoint_name)
    queue.put(None)
    proc.join()
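
For reference, the end state I expect on the receiving rank after a successful update is the union of all trainer shards, with the constant fill values used above. A small helper restating that expectation (plain PyTorch, nothing checkpoint-engine-specific; the function name is my own):

import torch

def expected_tensors(num_trainer: int = 3) -> dict[str, torch.Tensor]:
    # Mirrors the trainer branch above: each trainer rank r registers
    # "rank{r}.tensor{i}" for i in 0..2, filled with r * 10 + i + 1 in bfloat16.
    return {
        f"rank{r}.tensor{i}": torch.full((64, 128), r * 10 + i + 1, dtype=torch.bfloat16)
        for r in range(num_trainer)
        for i in range(3)
    }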
