24 changes: 24 additions & 0 deletions checkpoint_engine/ps.py
@@ -1,4 +1,5 @@
import ctypes
import gc
import os
import threading
from collections import defaultdict
@@ -852,6 +853,29 @@ def _update_per_bucket(
gidx += 1

socket.recv()
device_mem = self.device_manager.device_module.mem_get_info()
logger.info(
f"[rank{self._rank}] weights broadcast done, device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, allocated memory: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, reserved memory: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
)
# Notify worker to release handle
socket.send_pyobj(None)
socket.recv()
# Delete references in the correct order (views first, then base tensors)
del buffer_b, h2d_buffer, buffer, handle
self.device_manager.device_module.synchronize()
gc.collect()
self.device_manager.device_module.ipc_collect()
self.device_manager.device_module.empty_cache()
self.device_manager.device_module.synchronize()

# Log actual memory usage
device_mem = self.device_manager.device_module.mem_get_info()
logger.info(
f"[rank{self._rank}] post-release: device mem usage: {(device_mem[1] - device_mem[0]) / 1024 / 1024:.2f} MB, "
f"allocated: {self.device_manager.device_module.memory_allocated() / 1024 / 1024:.2f} MB, "
f"reserved: {self.device_manager.device_module.memory_reserved() / 1024 / 1024:.2f} MB"
)
# Notify worker to call post_hook
socket.send_pyobj(None)
socket.recv()
finally:
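For context beyond the diff itself: on the parameter-server side the new code logs how much device memory is in use (mem_get_info() returns a (free, total) tuple, so total minus free is the usage), sends a first None so the worker drops its IPC handle, releases its own references, and only then sends a second None to let the worker run its post_hook. Below is a minimal sketch of that two-phase handshake, not part of the PR, assuming torch.cuda as the device module and a zmq REQ socket; release_and_notify and refs are illustrative names.

# Illustrative only -- not part of the PR. Assumes torch.cuda and a zmq.REQ socket.
import gc

import torch
import zmq


def release_and_notify(socket: zmq.Socket, refs: list[torch.Tensor]) -> None:
    # Phase 1: ask the worker to drop its IPC handle, then wait for the ack.
    socket.send_pyobj(None)
    socket.recv()

    # Drop our own references before asking the allocator to return memory.
    refs.clear()
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.ipc_collect()   # reclaim memory backing CUDA IPC-shared tensors
    torch.cuda.empty_cache()   # hand cached blocks back to the driver

    free, total = torch.cuda.mem_get_info()
    print(f"post-release device mem in use: {(total - free) / 1024 / 1024:.2f} MB")

    # Phase 2: tell the worker it is now safe to run post_hook, then wait for the ack.
    socket.send_pyobj(None)
    socket.recv()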
22 changes: 21 additions & 1 deletion checkpoint_engine/worker.py
@@ -70,15 +70,35 @@ def update_weights_from_ipc(
socket.send_string(msg)
socket.recv() # wait for ack
raise
# State machine:
# + receive tensor_metadata -> update_weights
# + receive Exception -> raise and stop
# + receive None first time -> release resources
# + receive None second time -> call post_hook and stop
try:
released = False
while True:
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
if released:
assert payload is None, "Should not receive any payload after released"
if post_hook is not None:
post_hook()
device_manager.device_module.synchronize()
socket.send(b"")
break
if payload is None: # done signal
# TODO: wrap all messages to an object instead of None and Exception
device_manager.device_module.synchronize()
released = True
buffer = None
del ipc_handle

gc.collect()
device_manager.device_module.ipc_collect()
device_manager.device_module.empty_cache()
device_manager.device_module.synchronize()
socket.send(b"")
continue
if isinstance(payload, list): # still updating weights
try:
run(_extract_weights(payload, buffer))
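The worker side mirrors this with a small state machine: a list payload means more weights to apply, an Exception is re-raised, the first None releases the shared buffer and IPC handle, and the second None runs post_hook and exits. The following self-contained sketch is not part of the PR and uses the same assumptions (zmq REP socket, torch.cuda); worker_loop and apply_weights are illustrative stand-ins for the real run(_extract_weights(...)) path.

# Illustrative only -- not part of the PR. Assumes torch.cuda and a zmq.REP socket;
# apply_weights stands in for run(_extract_weights(payload, buffer)).
import gc
from typing import Any, Callable, Optional

import torch
import zmq


def worker_loop(
    socket: zmq.Socket,
    buffer: Any,
    ipc_handle: Any,
    apply_weights: Callable,
    post_hook: Optional[Callable],
) -> None:
    released = False
    while True:
        payload = socket.recv_pyobj()
        if released:
            # After the first None, only the final None may arrive.
            assert payload is None, "Should not receive any payload after released"
            if post_hook is not None:
                post_hook()
            torch.cuda.synchronize()
            socket.send(b"")
            break
        if payload is None:
            # First None: drop the shared buffer and IPC handle, then reclaim memory.
            torch.cuda.synchronize()
            released = True
            buffer = None
            ipc_handle = None
            gc.collect()
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()
            socket.send(b"")
            continue
        if isinstance(payload, Exception):
            socket.send(b"")
            raise payload
        apply_weights(payload, buffer)  # payload: list of tensor metadata
        socket.send(b"")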
16 changes: 16 additions & 0 deletions tests/test_update.py
@@ -91,6 +91,9 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten
name: tensor.to(device_manager.device_type) for name, tensor in named_tensors.items()
}
_zmq_ctx = zmq.Context()
mem_info = device_manager.device_module.mem_get_info()
memory_usage = mem_info[1] - mem_info[0]
memory_history: list[int] = [memory_usage]

def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor]]):
for name, weight in weights:
@@ -108,6 +111,11 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
run=lambda weights: check(names_to_check, weights),
post_hook=lambda: device_manager.device_module.synchronize(),
)
device_manager.device_module.synchronize()
device_manager.device_module.empty_cache()
mem_info = device_manager.device_module.mem_get_info()
memory_usage = mem_info[1] - mem_info[0]
memory_history.append(memory_usage)
assert all(names_to_check.values())

while True:
@@ -117,6 +125,12 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
names_to_check = dict.fromkeys(named_tensors.keys(), False)
check_weights(names_to_check, socket_paths)

mem_info = device_manager.device_module.mem_get_info()
memory_usage = mem_info[1] - mem_info[0]
memory_history.append(memory_usage)
for memory in memory_history[1:]:
print(f"[rank{rank}] Memory change: {memory - memory_history[0]}")


def run(
checker_func: callable,
@@ -318,6 +332,8 @@ def test_update_with_files(test_name: str = "test_with_files"):
rank_list = json.loads(sys.argv[2])
if test_type == "test_no_error":
run(checker_proc, rank_list, need_error=False)
mem_info = device_manager.device_module.mem_get_info()
print(f"Memory usage: {mem_info[1] - mem_info[0]}")
elif test_type == "test_with_remote_error":
run(
checker_proc_with_error,
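The test records used device memory (total minus free from mem_get_info()) before the first update and again after each round, then prints the deltas against the baseline, so a leak in the release path would show up as a growing number. Here is a standalone sketch of that pattern, not part of the PR, assuming torch.cuda; track_memory and do_update are illustrative names.

# Illustrative only -- not part of the PR. Assumes torch.cuda; do_update stands in
# for one check_weights round.
from typing import Callable

import torch


def track_memory(rounds: int, do_update: Callable[[], None]) -> list[int]:
    def used_bytes() -> int:
        free, total = torch.cuda.mem_get_info()
        return total - free

    history = [used_bytes()]              # baseline before any update
    for _ in range(rounds):
        do_update()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        history.append(used_bytes())
    for used in history[1:]:
        # A delta near zero means the released buffers were really freed.
        print(f"memory change vs baseline: {used - history[0]} bytes")
    return history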