diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index eb63d82..1d8c5cf 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -1,4 +1,5 @@ import ctypes +import gc import os import threading from collections import defaultdict @@ -852,6 +853,29 @@ def _update_per_bucket( gidx += 1 socket.recv() + device_mem = self.device_manager.device_module.mem_get_info() + logger.info( + f"[rank{self._rank}] weights broadcast done, device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, allocated memory: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, reserved memory: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB" + ) + # Notify worker to release handle + socket.send_pyobj(None) + socket.recv() + # Set to None in correct order (views first, then base tensors) + del buffer_b, h2d_buffer, buffer, handle + self.device_manager.device_module.synchronize() + gc.collect() + self.device_manager.device_module.ipc_collect() + self.device_manager.device_module.empty_cache() + self.device_manager.device_module.synchronize() + + # Log actual memory usage + device_mem = self.device_manager.device_module.mem_get_info() + logger.info( + f"[rank{self._rank}] post-release: device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, " + f"allocated: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, " + f"reserved: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB" + ) + # Notify worker to call post_hook socket.send_pyobj(None) socket.recv() finally: diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 3979cda..ea170ca 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -70,15 +70,35 @@ def update_weights_from_ipc( socket.send_string(msg) socket.recv() # wait for ack raise + # State machine: + # + receive tensor_metadata -> update_weights + # + receive Exception -> raise and stop + # + receive None first time -> release resources + # + receive None second time -> call post_hook and stop try: + released = False while True: payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj() - if payload is None: # done signal + if released: + assert payload is None, "Should not receive any payload after released" if post_hook is not None: post_hook() device_manager.device_module.synchronize() socket.send(b"") break + if payload is None: # done signal + # TODO: wrap all messages to an object instead of None and Exception + device_manager.device_module.synchronize() + released = True + buffer = None + del ipc_handle + + gc.collect() + device_manager.device_module.ipc_collect() + device_manager.device_module.empty_cache() + device_manager.device_module.synchronize() + socket.send(b"") + continue if isinstance(payload, list): # still updating weights try: run(_extract_weights(payload, buffer)) diff --git a/tests/test_update.py b/tests/test_update.py index e44354f..89ddcfa 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -91,6 +91,9 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten name: tensor.to(device_manager.device_type) for name, tensor in named_tensors.items() } _zmq_ctx = zmq.Context() + mem_info = device_manager.device_module.mem_get_info() + memory_usage = mem_info[1] - mem_info[0] + memory_history: list[int] = [memory_usage] def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor]]): for name, weight in weights: @@ -108,6 +111,11 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, run=lambda weights: check(names_to_check, weights), post_hook=lambda: device_manager.device_module.synchronize(), ) + device_manager.device_module.synchronize() + device_manager.device_module.empty_cache() + mem_info = device_manager.device_module.mem_get_info() + memory_usage = mem_info[1] - mem_info[0] + memory_history.append(memory_usage) assert all(names_to_check.values()) while True: @@ -117,6 +125,12 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, names_to_check = dict.fromkeys(named_tensors.keys(), False) check_weights(names_to_check, socket_paths) + mem_info = device_manager.device_module.mem_get_info() + memory_usage = mem_info[1] - mem_info[0] + memory_history.append(memory_usage) + for memory in memory_history[1:]: + print(f"[rank{rank}] Memory change: {memory - memory_history[0]}") + def run( checker_func: callable, @@ -318,6 +332,8 @@ def test_update_with_files(test_name: str = "test_with_files"): rank_list = json.loads(sys.argv[2]) if test_type == "test_no_error": run(checker_proc, rank_list, need_error=False) + mem_info = device_manager.device_module.mem_get_info() + print(f"Memory usage: {mem_info[1] - mem_info[0]}") elif test_type == "test_with_remote_error": run( checker_proc_with_error,