diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index b5512c3..d25b750 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -175,6 +175,8 @@ def __init__(
         auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -228,6 +230,17 @@ def __init__(
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+        master_addr = master_addr or os.getenv("MASTER_ADDR")
+        assert master_addr, "master_addr is required"
+        self._store = torch.distributed.TCPStore(
+            master_addr,
+            _get_master_port(master_port),
+            self._world_size,
+            timeout=timedelta(minutes=10),
+            is_master=self._rank == 0,
+        )
+        self._store_counter = 0
+
     def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
         if checkpoint_name == self._current_shared_memory_pool_user:
             assert self._memory_pool[self.shared_memory_pool_name], (
@@ -487,8 +500,6 @@ def gather_metas(self, checkpoint_name: str):
     def init_process_group(
         self,
         *,
-        master_addr: str | None = None,
-        master_port: int | None = None,
         timeout: timedelta = timedelta(minutes=10),
     ):
         """
@@ -498,27 +509,18 @@ def init_process_group(
             master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
             timeout: The timeout of the process group.
         """
-        master_addr = master_addr or os.getenv("MASTER_ADDR")
-        assert master_addr, "master_addr is required"
-        store = dist.TCPStore(
-            master_addr,
-            _get_master_port(master_port),
-            self._world_size,
-            timeout=timeout,
-            is_master=self._rank == 0,
-        )
+        self._store_counter += 1
+        sub_store = torch.distributed.PrefixStore(f"prefix-{self._store_counter}", self._store)
         dist.init_process_group(
             backend=self.device_manager.backend,
             world_size=self._world_size,
             rank=self._rank,
             timeout=timeout,
-            store=store,
+            store=sub_store,
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
-    def store_based_barrier(
-        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
-    ) -> None:
+    def store_based_barrier(self, timeout: timedelta = timedelta(minutes=5)) -> None:
         """
         Perform a store-based barrier synchronization across all ranks.
 
@@ -531,7 +533,7 @@ def store_based_barrier(
         """
         dist.distributed_c10d._store_based_barrier(
             rank=self._rank,
-            store=store,
+            store=self._store,
             group_name="parameter_server_barrier",
             rendezvous_count=self._world_size,
             timeout=timeout,
@@ -544,8 +546,6 @@ def update(
         *,
         timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
-        master_addr: str | None = None,
-        master_port: int | None = None,
    ) -> None:
        """
        Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -566,28 +566,12 @@ def update(
         assert req_func is not None, "req_func is required"
         ranks_group = None
         try:
-            master_addr = os.getenv("MASTER_ADDR") or master_addr
-            assert master_addr, "master_addr is required"
-            if self._auto_pg:
-                if not dist.is_initialized():
-                    self.init_process_group(
-                        timeout=timeout, master_addr=master_addr, master_port=master_port
-                    )
-                manager_store = dist.distributed_c10d._get_default_store()
-            else:
-                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
-                # If master_port is provided, use master_port+1 for barrier store
-                manager_store = dist.TCPStore(
-                    master_addr,
-                    _get_master_port(master_port) + 1,
-                    self._world_size,
-                    timeout=timeout,
-                    is_master=self._rank == 0,
-                )
+            if self._auto_pg and not dist.is_initialized():
+                self.init_process_group(timeout=timeout)
             # if ranks is None or [], it will use fully broadcast to update to all ranks
             ranks_group = dist.new_group(ranks) if ranks else None
             self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
-            self.store_based_barrier(manager_store)
+            self.store_based_barrier()
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
diff --git a/tests/test_reuse_pin_memory.py b/tests/test_reuse_pin_memory.py
index bb698b7..c2fb325 100644
--- a/tests/test_reuse_pin_memory.py
+++ b/tests/test_reuse_pin_memory.py
@@ -23,6 +23,8 @@ def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
 def test_register_pin_memory():
     os.environ["RANK"] = "0"
     os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "25400"
     ps = ParameterServer()
     checkpoint1 = generate_dummy_checkpoint()
     checkpoint_shared1 = generate_dummy_checkpoint()
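
For reviewers, a rough usage sketch of the API after this patch: `master_addr`/`master_port` now go to the `ParameterServer` constructor, which creates the long-lived `TCPStore`, while `init_process_group()` and `store_based_barrier()` lose their `master_*`/store arguments and reuse that store. The import path, environment values, and single-rank setup below are illustrative assumptions, not part of the patch.

```python
import os

# Assumed import path, matching the file touched by this diff.
from checkpoint_engine.ps import ParameterServer

# Illustrative single-rank environment; real deployments set these externally.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "25400"

# The constructor now owns the TCPStore; master_addr/master_port can be passed
# explicitly or fall back to MASTER_ADDR / _get_master_port().
ps = ParameterServer(master_addr="localhost", master_port=25400)

# No master_addr/master_port kwargs any more; each call wraps the shared
# TCPStore in a fresh PrefixStore ("prefix-1", "prefix-2", ...).
ps.init_process_group()

# The barrier uses the internal store instead of a caller-supplied one.
ps.store_based_barrier()
```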