diff --git a/checkpoint_engine/distributed/__init__.py b/checkpoint_engine/distributed/__init__.py
new file mode 100644
index 0000000..fee44a9
--- /dev/null
+++ b/checkpoint_engine/distributed/__init__.py
@@ -0,0 +1,28 @@
+from .base import (
+    Distributed,
+    DistributedProcessGroup,
+    all_gather_object,
+    all_reduce,
+    barrier,
+    broadcast,
+    destroy_process_group,
+    init_process_group,
+    is_initialized,
+    new_group,
+    use_backend,
+)
+
+
+__all__ = [
+    "Distributed",
+    "DistributedProcessGroup",
+    "all_gather_object",
+    "all_reduce",
+    "barrier",
+    "broadcast",
+    "destroy_process_group",
+    "init_process_group",
+    "is_initialized",
+    "new_group",
+    "use_backend",
+]
diff --git a/checkpoint_engine/distributed/base.py b/checkpoint_engine/distributed/base.py
new file mode 100644
index 0000000..4299394
--- /dev/null
+++ b/checkpoint_engine/distributed/base.py
@@ -0,0 +1,288 @@
+import importlib
+import io
+import pickle
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from typing import Any, Protocol
+
+import torch
+import torch.distributed as torch_dist
+
+
+class CommunicatorProtocol(Protocol):
+    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
+
+
+class CommGroup:
+    def __init__(self, comm_handle: int, ranks: list[int]):
+        self._comm = comm_handle
+        self._ranks = ranks
+
+    @property
+    def handle(self) -> int:
+        return self._comm
+
+    @property
+    def ranks(self) -> list[int]:
+        return self._ranks
+
+
+DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
+
+
+class Distributed(ABC):
+    @abstractmethod
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def destroy_process_group(
+        self,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def is_initialized(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_gather_object(
+        self,
+        object_list: list[Any],
+        obj: Any,
+        group: DistributedProcessGroup | None = None,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def barrier(
+        self,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def new_group(
+        self,
+        ranks: list[int],
+        **kwargs,
+    ):
+        raise NotImplementedError
+
+
+class TorchBackend(Distributed):
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch_dist.TCPStore,
+        **kwargs,
+    ):
+        backend = kwargs.get("backend", "nccl")
+        timeout = kwargs.get("timeout", timedelta(minutes=10))
+
+        torch_dist.init_process_group(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            timeout=timeout,
+            store=store,
+        )
+
+    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
+        torch_dist.destroy_process_group(group)
+
+    def is_initialized(self) -> bool:
+        return torch_dist.is_initialized()
+
+    def all_gather_object(
+        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
+    ):
+        torch_dist.all_gather_object(object_list, obj, group)
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.all_reduce(tensor, op, group, **kwargs)
+
+    def broadcast(
+        self,
+        tensor: torch.Tensor,
+        src: int = 0,
+        group: DistributedProcessGroup | None = None,
+        **kwargs,
+    ):
+        torch_dist.broadcast(tensor, src, group, **kwargs)
+
+    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
+        torch_dist.barrier(group, **kwargs)
+
+    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+        return torch_dist.new_group(ranks, **kwargs)
+
+
+# specific device instance
+_BACKEND_INSTANCE: Distributed = TorchBackend()
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+
+def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
+    return byte_tensor, local_size
+
+
+def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
+def _flatten_for_scatter_gather(
+    tensor_list: list[torch.Tensor], copy: bool = False
+) -> torch.Tensor:
+    if not tensor_list:
+        raise RuntimeError("Received an empty list.")
+    t = tensor_list[0]
+    buffer_shape = [len(tensor_list)] + list(t.shape)
+
+    buffer = torch.empty(tuple(buffer_shape), dtype=t.dtype, device=t.device)
+    if copy:
+        for i, tensor in enumerate(tensor_list):
+            buffer[i].copy_(tensor)
+    return buffer
+
+
+def _common_all_gather_object(
+    comm: CommunicatorProtocol,
+    device: torch.device,
+    world_size: int,
+    object_list: list[Any],
+    object: Any,
+):
+    input_tensor, local_size = _object_to_tensor(object, device)
+    object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
+    comm.all_gather(object_sizes_tensor, local_size)
+    object_size_list = [object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size)]
+    max_object_size = int(max(object_size_list).item())
+    input_tensor.resize_(max_object_size)
+    coalesced_output_tensor = torch.empty(
+        max_object_size * world_size, dtype=torch.uint8, device=device
+    )
+
+    comm.all_gather(coalesced_output_tensor, input_tensor)
+    output_tensors = [
+        coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
+        for i in range(world_size)
+    ]
+    for i, tensor in enumerate(output_tensors):
+        tensor = tensor.type(torch.uint8)
+        tensor_size = object_size_list[i]
+        object_list[i] = _tensor_to_object(tensor, tensor_size)
+
+
+def use_backend(backend: str | None):
+    global _BACKEND_INSTANCE
+
+    if not backend:
+        return
+
+    mapping = {
+        "vllm_nccl": ".vllm_nccl.DistributedNccl",
+        "vllm_hccl": ".vllm_hccl.DistributedHccl",
+    }
+    if backend not in mapping:
+        raise ValueError(f"Unsupported custom backend: {backend}")
+
+    module_path, class_name = mapping[backend].rsplit(".", 1)
+    module = importlib.import_module(module_path, "checkpoint_engine.distributed")
+    backend_class = getattr(module, class_name)
+    _BACKEND_INSTANCE = backend_class()
+
+
+def init_process_group(
+    rank: int,
+    world_size: int,
+    store: torch_dist.TCPStore,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.init_process_group(rank, world_size, store, **kwargs)
+
+
+def destroy_process_group(group: DistributedProcessGroup | None = None):
+    _BACKEND_INSTANCE.destroy_process_group(group)
+
+
+def is_initialized() -> bool:
+    return _BACKEND_INSTANCE.is_initialized()
+
+
+def all_gather_object(
+    object_list: list[Any],
+    obj: Any,
+    group: DistributedProcessGroup | None = None,
+):
+    _BACKEND_INSTANCE.all_gather_object(object_list, obj, group)
+
+
+def all_reduce(
+    tensor: torch.Tensor,
+    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)
+
+
+def broadcast(
+    tensor: torch.Tensor,
+    src: int = 0,
+    group: DistributedProcessGroup | None = None,
+    **kwargs,
+):
+    _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)
+
+
+def barrier(group: DistributedProcessGroup | None = None, **kwargs):
+    _BACKEND_INSTANCE.barrier(group, **kwargs)
+
+
+def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
+    return _BACKEND_INSTANCE.new_group(ranks, **kwargs)
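Reviewer note (not part of the patch): a minimal usage sketch of the facade added in base.py. The environment variables CUSTOM_DIST, MASTER_ADDR, and MASTER_PORT below are illustrative assumptions, as is the one-CUDA-device-per-rank setup; the real wiring lives in ParameterServer and examples/update.py. It shows that use_backend() swaps the process-group implementation while call sites stay identical to torch.distributed.

import os
from datetime import timedelta

import torch
import torch.distributed as torch_dist

import checkpoint_engine.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Optionally route collectives through a vLLM-based communicator instead of
# torch.distributed ("vllm_nccl" on CUDA, "vllm_hccl" on Ascend NPU).
# With no value set, the default TorchBackend is kept.
dist.use_backend(os.environ.get("CUSTOM_DIST") or None)

# Hypothetical rendezvous; any TCPStore shared by all ranks works.
store = torch_dist.TCPStore(
    os.environ.get("MASTER_ADDR", "127.0.0.1"),
    int(os.environ.get("MASTER_PORT", "29500")),
    world_size,
    is_master=(rank == 0),
    timeout=timedelta(minutes=10),
)
dist.init_process_group(rank=rank, world_size=world_size, store=store)

# The module-level wrappers mirror torch.distributed, so ps.py only needs
# its import swapped. Assumes one CUDA device per rank for the nccl default.
t = torch.ones(1, device="cuda")
dist.all_reduce(t, op=torch_dist.ReduceOp.SUM)

objs = [None] * world_size
dist.all_gather_object(objs, {"rank": rank})

group = dist.new_group(list(range(world_size)))
dist.barrier(group=group)
dist.destroy_process_group()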
diff --git a/checkpoint_engine/distributed/vllm_hccl.py b/checkpoint_engine/distributed/vllm_hccl.py
new file mode 100644
index 0000000..fbdab0c
--- /dev/null
+++ b/checkpoint_engine/distributed/vllm_hccl.py
@@ -0,0 +1,323 @@
+import ctypes
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
+from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
+    Function,
+    HCCLLibrary,
+    aclrtStream_t,
+    buffer_type,
+    hcclComm_t,
+    hcclDataType_t,
+    hcclDataTypeEnum,
+    hcclResult_t,
+)
+from vllm_ascend.utils import current_stream
+
+from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
+
+
+class HcclCommConfig(ctypes.Structure):
+    _fields_: ClassVar[list[tuple[str, Any]]] = [
+        ("size", ctypes.c_size_t),
+        ("magic_word", ctypes.c_uint32),
+        ("version", ctypes.c_uint32),
+        ("reserved", ctypes.c_uint64),
+        ("hccl_buffer_size", ctypes.c_uint32),
+        ("hccl_deterministic", ctypes.c_uint32),
+        ("hccl_comm_name", ctypes.c_char * 128),
+        ("hccl_udi", ctypes.c_char * 128),
+        ("hccl_op_expansion_mode", ctypes.c_uint32),
+        ("hccl_rdma_traffic_class", ctypes.c_uint32),
+        ("hccl_rdma_service_level", ctypes.c_uint32),
+        ("hccl_world_rank_id", ctypes.c_uint32),
+        ("hccl_job_id", ctypes.c_uint64),
+        ("comm_engine", ctypes.c_int32),
+        ("thread_num", ctypes.c_uint32),
+        ("notify_num_per_thread", ctypes.c_uint32),
+        ("acl_graph_zero_copy_enable", ctypes.c_uint8),
+    ]
+
+
+orig_exported_functions = HCCLLibrary.exported_functions
+extended_functions = [
+    # HcclResult HcclAllGather(
+    #     void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType,
+    #     HcclComm comm, aclrtStream stream
+    # )
+    Function(
+        "HcclAllGather",
+        hcclResult_t,
+        [
+            buffer_type,
+            buffer_type,
+            ctypes.c_uint64,
+            hcclDataType_t,
+            hcclComm_t,
+            aclrtStream_t,
+        ],
+    ),
+    # HcclResult HcclCreateSubCommConfig(
+    #     HcclComm *comm, uint32_t rankNum, uint32_t *rankIds, uint64_t subCommId,
+    #     uint32_t subCommRankId, HcclCommConfig *config, HcclComm *subComm
+    # )
+    Function(
+        "HcclCreateSubCommConfig",
+        hcclResult_t,
+        [
+            ctypes.POINTER(hcclComm_t),
+            ctypes.c_uint32,
+            ctypes.POINTER(ctypes.c_uint32),
+            ctypes.c_uint64,
+            ctypes.c_uint32,
+            ctypes.POINTER(HcclCommConfig),
+            ctypes.POINTER(hcclComm_t),
+        ],
+    ),
+]
+
+
+def hccl_all_gather(
+    self,  # noqa: ANN001
+    send_buf: buffer_type,
+    recv_buf: buffer_type,
+    count: ctypes.c_uint64,
+    data_type: hcclDataType_t,
+    comm: hcclComm_t,
+    stream: aclrtStream_t,
+):
+    self.HCCL_CHECK(
+        self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
+    )
+
+
+def hccl_create_subcomm_config(
+    self,  # noqa: ANN001
+    comm: hcclComm_t,
+    ranks_size: ctypes.c_uint32,
+    c_rank_ids: ctypes.POINTER(ctypes.c_uint32),
+    subcomm_id: ctypes.c_uint64,
+    subcomm_rank: ctypes.c_uint32,
+    comm_config: HcclCommConfig,
+) -> hcclComm_t:
+    subcomm = hcclComm_t()
+    self.HCCL_CHECK(
+        self._funcs["HcclCreateSubCommConfig"](
+            ctypes.byref(comm),
+            ranks_size,
+            c_rank_ids,
+            subcomm_id,
+            subcomm_rank,
+            ctypes.byref(comm_config),
+            ctypes.byref(subcomm),
+        )
+    )
+    return subcomm
+
+
+# extend HCCLLibrary
+HCCLLibrary.exported_functions = orig_exported_functions + extended_functions
+HCCLLibrary.hcclAllGather = hccl_all_gather
+HCCLLibrary.hcclCreateSubCommConfig = hccl_create_subcomm_config
+
+
+class PyHcclCommunicatorEx(PyHcclCommunicator):
+    def __init__(self, group: StatelessProcessGroup, device: torch.device):
+        super().__init__(group, device)
+        self.subcomm_id = 1
+
+    def destroy_comm(self, comm: hcclComm_t = None):
+        if comm:
+            self.hccl.hcclCommDestroy(comm)
+        else:
+            self.hccl.hcclCommDestroy(self.comm)
+
+    def all_gather(
+        self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream: torch.npu.Stream = None
+    ) -> torch.Tensor:
+        if self.disabled:
+            return
+        assert in_tensor.device == self.device, (
+            f"this hccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {in_tensor.device}"
+        )
+        if stream is None:
+            stream = current_stream()
+        self.hccl.hcclAllGather(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            hcclDataTypeEnum.from_torch(in_tensor.dtype),
+            self.comm,  # todo
+            aclrtStream_t(stream.npu_stream),
+        )
+        return out_tensor
+
+    def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
+        comm_config = HcclCommConfig(
+            size=312,
+            magic_word=0xF0F0F0F0,
+            version=6,
+            reserved=0,
+            hccl_buffer_size=0xFFFFFFFF,
+            hccl_deterministic=0xFFFFFFFF,
+            hccl_comm_name=b"\0",
+            hccl_udi=b"\0",
+            hccl_op_expansion_mode=0,
+            hccl_rdma_traffic_class=0xFFFFFFFF,
+            hccl_rdma_service_level=0xFFFFFFFF,
+            hccl_world_rank_id=0,
+            hccl_job_id=0,
+            comm_engine=-1,
+            thread_num=0xFFFFFFFF,
+            notify_num_per_thread=0xFFFFFFFF,
+            acl_graph_zero_copy_enable=0,
+        )
+        uint32_array = ctypes.c_uint32 * len(ranks)
+        c_rank_ids = uint32_array(*ranks)
+        subcomm_rank = ranks.index(self.rank)
+        ranks_size = len(ranks)
+        subcomm_id = self.subcomm_id
+
+        subcomm = self.hccl.hcclCreateSubCommConfig(
+            self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
+        )
+        self.subcomm_id += 1
+        return subcomm
+
+
+class DistributedHccl(Distributed):
+    def __init__(self):
+        self.pg: StatelessProcessGroup = None
+        self.pyhccl: PyHcclCommunicatorEx = None
+        self.sub_groups: dict[int, CommGroup] = {}
+        self.comm: hcclComm_t = None
+
+        self.host: str = None
+        self.port: int = None
+        self.rank: int = None
+        self.world_size: int = None
+        self.device: torch.device = None
+
+        self.initialized: bool = False
+
+    @contextmanager
+    def _use_group(self, group: CommGroup | None, src: int | None = None):
+        active_src = src
+        if group:
+            assert group.handle in self.sub_groups, "invalid sub_group"
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.comm = newcomm
+
+            if src is not None:
+                assert src in group.ranks, "src rank not in group"
+                # convert src rank id in default world to newcomm
+                active_src = group.ranks.index(src)
+                self.pyhccl.rank = group.ranks.index(self.rank)
+
+        try:
+            yield active_src
+        finally:
+            if group:
+                self.pyhccl.comm = self.comm
+                if src is not None:
+                    self.pyhccl.rank = self.rank
+
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch.distributed.TCPStore,
+        **kwargs,
+    ):
+        assert not self.initialized, "already initialized"
+
+        self.rank = rank
+        self.world_size = world_size
+        self.device = torch.device("npu", torch.npu.current_device())
+
+        self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
+        self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
+        self.comm = self.pyhccl.comm
+        self.initialized = True
+
+    def destroy_process_group(
+        self,
+        group: CommGroup | None = None,
+    ):
+        assert self.initialized, "not initialized"
+
+        if group and group.handle in self.sub_groups:
+            subcomm = ctypes.c_void_p(group.handle)
+            self.pyhccl.destroy_comm(subcomm)
+            del self.sub_groups[group.handle]
+            return
+
+        self.pyhccl.destroy_comm()
+        self.pyhccl = None
+        self.pg = None
+        self.initialized = False
+
+    def is_initialized(self) -> bool:
+        return self.initialized
+
+    def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            _common_all_gather_object(self.pyhccl, self.device, self.world_size, object_list, obj)
+            current_stream().synchronize()
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
+        group: CommGroup | None = None,
+        **kwargs,
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            out_tensor = self.pyhccl.all_reduce(tensor, op)
+            current_stream().synchronize()
+            tensor.copy_(out_tensor)
+
+    def broadcast(
+        self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group, src) as local_rank:
+            self.pyhccl.broadcast(tensor, local_rank)
+            current_stream().synchronize()
+
+    def barrier(self, group: CommGroup | None = None, **kwargs):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            data = torch.zeros(1, device=self.device)
+            self.pyhccl.all_reduce(data)
+            current_stream().synchronize()
+
+    def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
+        assert self.initialized, "not initialized"
+
+        # ranks is None or []
+        if not ranks:
+            ranks = list(range(self.world_size))
+        else:
+            ranks.sort()
+
+        group: CommGroup = None
+        if self.rank not in ranks:
+            return group
+
+        subcomm = self.pyhccl.create_subcomm(ranks)
+        if subcomm:
+            group = CommGroup(subcomm.value, ranks)
+            self.sub_groups[subcomm.value] = group
+        return group
diff --git a/checkpoint_engine/distributed/vllm_nccl.py b/checkpoint_engine/distributed/vllm_nccl.py
new file mode 100644
index 0000000..2ffe253
--- /dev/null
+++ b/checkpoint_engine/distributed/vllm_nccl.py
@@ -0,0 +1,223 @@
+import ctypes
+from contextlib import contextmanager
+from typing import Any, ClassVar
+
+import torch
+from torch.distributed import ReduceOp
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl_wrapper import (
+    Function,
+    NCCLLibrary,
+    ncclComm_t,
+    ncclResult_t,
+)
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import current_stream
+
+from checkpoint_engine.distributed.base import CommGroup, Distributed, _common_all_gather_object
+
+
+class NcclConfigT(ctypes.Structure):
+    _fields_: ClassVar[list[tuple[str, Any]]] = [
+        ("size", ctypes.c_size_t),
+        ("magic", ctypes.c_uint),
+        ("version", ctypes.c_uint),
+        ("blocking", ctypes.c_int),
+        ("cgaClusterSize", ctypes.c_int),
+        ("minCTAs", ctypes.c_int),
+        ("maxCTAs", ctypes.c_int),
+        ("netName", ctypes.c_char_p),
+        ("splitShare", ctypes.c_int),
+        ("trafficClass", ctypes.c_int),
+        ("commName", ctypes.c_char_p),
+        ("collnetEnable", ctypes.c_int),
+        ("CTAPolicy", ctypes.c_int),
+        ("shrinkShare", ctypes.c_int),
+        ("nvlsCTAs", ctypes.c_int),
+        ("nChannelsPerNetPeer", ctypes.c_int),
+        ("nvlinkCentricSched", ctypes.c_int),
+        ("graphUsageMode", ctypes.c_int),
+        ("numRmaCtx", ctypes.c_int),
+    ]
+
+
+nccl_orig_exported_functions = NCCLLibrary.exported_functions
+nccl_extended_functions = [
+    # ncclResult_t ncclCommSplit(
+    #     ncclComm_t comm, int color, int key, ncclComm_t *newcomm, NcclConfigT *config
+    # )
+    Function(
+        "ncclCommSplit",
+        ncclResult_t,
+        [
+            ncclComm_t,
+            ctypes.c_int,
+            ctypes.c_int,
+            ctypes.POINTER(ncclComm_t),
+            ctypes.POINTER(NcclConfigT),
+        ],
+    ),
+]
+
+
+def nccl_comm_split(
+    self,  # noqa: ANN001
+    comm: ncclComm_t,
+    color: int,
+    key: int,
+) -> ncclComm_t:
+    newcomm = ncclComm_t()
+
+    self.NCCL_CHECK(self._funcs["ncclCommSplit"](comm, color, key, ctypes.byref(newcomm), None))
+    return newcomm
+
+
+# extend NCCLLibrary
+NCCLLibrary.exported_functions = nccl_orig_exported_functions + nccl_extended_functions
+NCCLLibrary.ncclCommSplit = nccl_comm_split
+
+
+class PyNcclCommunicatorEx(PyNcclCommunicator):
+    def destroy_comm(self, comm: ncclComm_t = None):
+        if comm:
+            self.nccl.ncclCommDestroy(comm)
+        else:
+            self.nccl.ncclCommDestroy(self.comm)
+
+    def create_newcomm(self, ranks: list[int]) -> ncclComm_t:
+        if self.rank in ranks:
+            color = 0
+        else:
+            color = -1  # NCCL_SPLIT_NOCOLOR
+        newcomm = self.nccl.ncclCommSplit(self.comm, color, self.rank)
+        return newcomm
+
+
+class DistributedNccl(Distributed):
+    def __init__(self):
+        self.pg: StatelessProcessGroup = None
+        self.pynccl: PyNcclCommunicatorEx = None
+        self.sub_groups: dict[int, CommGroup] = {}
+        self.comm: ncclComm_t = None
+
+        self.host: str = None
+        self.port: int = None
+        self.rank: int = None
+        self.world_size: int = None
+        self.device: torch.device = None
+
+        self.initialized: bool = False
+
+    @contextmanager
+    def _use_group(self, group: CommGroup | None, src: int | None = None):
+        active_src = src
+        if group:
+            assert group.handle in self.sub_groups, "invalid sub_group"
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pynccl.comm = newcomm
+
+            if src is not None:
+                assert src in group.ranks, "src rank not in group"
+                # convert src rank id in default world to newcomm
+                active_src = group.ranks.index(src)
+                self.pynccl.rank = group.ranks.index(self.rank)
+
+        try:
+            yield active_src
+        finally:
+            if group:
+                self.pynccl.comm = self.comm
+                if src is not None:
+                    self.pynccl.rank = self.rank
+
+    def init_process_group(
+        self,
+        rank: int,
+        world_size: int,
+        store: torch.distributed.TCPStore,
+        **kwargs,
+    ):
+        assert not self.initialized, "already initialized"
+
+        self.rank = rank
+        self.world_size = world_size
+        self.device = torch.device("cuda", torch.cuda.current_device())
+
+        self.pg = StatelessProcessGroup(rank=rank, world_size=world_size, store=store, socket=None)
+        self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self.device)
+        self.comm = self.pynccl.comm
+        self.initialized = True
+
+    def destroy_process_group(
+        self,
+        group: CommGroup | None = None,
+    ):
+        assert self.initialized, "not initialized"
+
+        if group and group.handle in self.sub_groups:
+            newcomm = ctypes.c_void_p(group.handle)
+            self.pynccl.destroy_comm(newcomm)
+            del self.sub_groups[group.handle]
+            return
+
+        self.pynccl.destroy_comm()
+        self.pynccl = None
+        self.pg = None
+        self.initialized = False
+
+    def is_initialized(self) -> bool:
+        return self.initialized
+
+    def all_gather_object(self, object_list: list[Any], obj: Any, group: CommGroup | None = None):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            _common_all_gather_object(self.pynccl, self.device, self.world_size, object_list, obj)
+            current_stream().synchronize()
+
+    def all_reduce(
+        self,
+        tensor: torch.Tensor,
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
+        group: CommGroup | None = None,
+        **kwargs,
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            out_tensor = self.pynccl.all_reduce(in_tensor=tensor, op=op)
+            current_stream().synchronize()
+            tensor.copy_(out_tensor)
+
+    def broadcast(
+        self, tensor: torch.Tensor, src: int | None = None, group: CommGroup | None = None, **kwargs
+    ):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group, src) as local_src:
+            self.pynccl.broadcast(tensor, local_src)
+            current_stream().synchronize()
+
+    def barrier(self, group: CommGroup | None = None, **kwargs):
+        assert self.initialized, "not initialized"
+
+        with self._use_group(group):
+            data = torch.zeros(1, device=self.device)
+            self.pynccl.all_reduce(data)
+            current_stream().synchronize()
+
+    def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
+        assert self.initialized, "not initialized"
+
+        # ranks is None or []
+        if not ranks:
+            ranks = list(range(self.world_size))
+        else:
+            ranks.sort()
+
+        group: CommGroup = None
+        newcomm = self.pynccl.create_newcomm(ranks)
+        if newcomm:
+            group = CommGroup(newcomm.value, ranks)
+            self.sub_groups[newcomm.value] = group
+        return group
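Reviewer note (not part of the patch): one subtle point shared by both custom backends is that broadcast() still takes the source rank in the global numbering, while the split NCCL/HCCL communicator only knows sub-group ranks; _use_group() translates between the two via ranks.index(src). A pure-Python illustration of that mapping follows; the 0xDEAD handle is a made-up placeholder, not a real communicator.

# Illustration of the rank translation performed by _use_group().
from checkpoint_engine.distributed.base import CommGroup

# A sub-group built from global ranks 2, 5 and 7.
group = CommGroup(comm_handle=0xDEAD, ranks=[2, 5, 7])

global_src = 5
local_src = group.ranks.index(global_src)  # rank the split comm sees inside the sub-group
assert local_src == 1

# A rank outside the group never reaches this code path: new_group() returns
# None for it on HCCL, and on NCCL it is split out with NCCL_SPLIT_NOCOLOR,
# so ranks.index() is only evaluated on members of the sub-group.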
""" - dist.distributed_c10d._store_based_barrier( + torch.distributed.distributed_c10d._store_based_barrier( rank=self._rank, store=self._store, group_name="parameter_server_barrier", @@ -600,7 +601,10 @@ def zmq_handle(device_uuid: str) -> str: return socket, socket_paths def _detect_bucket_size( - self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False + self, + ranks_group: dist.DistributedProcessGroup | None, + *, + disable_h2d_buffer: bool = False, ) -> tuple[int, bool]: GiB = 1 << 30 # noqa: N806 # auto detect bucket size @@ -617,7 +621,7 @@ def _detect_bucket_size( dtype=torch.int64, device=self.device_manager.device_type, ) - dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group) + dist.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN, group=ranks_group) tensor = tensor.cpu() free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item() max_tensor_bytes = 0 @@ -719,7 +723,7 @@ def _update_per_bucket( self, checkpoint_name: str, req_func: Callable[[list[tuple[str, str]]], None], - ranks_group: dist.ProcessGroup | None, + ranks_group: dist.DistributedProcessGroup | None, ranks: list[int] | None = None, ): assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty" @@ -838,7 +842,7 @@ def _update_per_bucket( f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}" ) ret_code.fill_(1) - dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group) + dist.all_reduce(ret_code, op=torch.distributed.ReduceOp.SUM, group=ranks_group) self.device_manager.device_module.synchronize() if ret_code.item() != 0: # quit early if any rank failed diff --git a/examples/update.py b/examples/update.py index f5605cf..d331a12 100644 --- a/examples/update.py +++ b/examples/update.py @@ -10,10 +10,10 @@ import httpx import torch -import torch.distributed as dist from loguru import logger from safetensors import safe_open +import checkpoint_engine.distributed as dist from checkpoint_engine import request_inference_to_update from checkpoint_engine.ps import ParameterServer @@ -159,10 +159,13 @@ def join( parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0") parser.add_argument("--update-method", type=str, default="broadcast") parser.add_argument("--uds", type=str, default=None) + parser.add_argument("--custom-dist", type=str, default=None) args = parser.parse_args() rank = int(os.getenv("RANK")) world_size = int(os.getenv("WORLD_SIZE")) + req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds) + dist.use_backend(args.custom_dist) ps = ParameterServer(auto_pg=True) if args.load_metas_file: join(