From 48b62bb9e5f7860c0f31c6b9858e9f21d51009f5 Mon Sep 17 00:00:00 2001 From: hongchao Date: Fri, 30 Jan 2026 05:22:35 +0000 Subject: [PATCH] fix ps alloc err & avoid mem fragmentation --- checkpoint_engine/worker.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 3979cda..8278e82 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -177,7 +177,7 @@ def _load_weights(weights: _WEIGHTS_TYPE): ): self.model_runner.drafter.model.load_weights(weights=weights) - def _post_hook(): + def _process_weight_after_loading(): process_weights_after_loading(self.model_runner.model, self.model_config, self.device) # Also trigger drafter model's post processing if MTP is enabled if ( @@ -188,10 +188,15 @@ def _post_hook(): self.model_runner.drafter.model, self.model_config, self.device ) + torch.cuda.empty_cache() + update_weights_from_ipc( self._zmq_ctx, zmq_handles[self._device_uuid], device_id=self.device.index, run=_load_weights, - post_hook=_post_hook, + post_hook=_process_weight_after_loading, ) + + if getattr(self, "_sampler_warmup", None) is not None: + self._sampler_warmup()