diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 3979cda..8278e82 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -177,7 +177,7 @@ def _load_weights(weights: _WEIGHTS_TYPE): ): self.model_runner.drafter.model.load_weights(weights=weights) - def _post_hook(): + def _process_weight_after_loading(): process_weights_after_loading(self.model_runner.model, self.model_config, self.device) # Also trigger drafter model's post processing if MTP is enabled if ( @@ -188,10 +188,15 @@ def _post_hook(): self.model_runner.drafter.model, self.model_config, self.device ) + torch.cuda.empty_cache() + update_weights_from_ipc( self._zmq_ctx, zmq_handles[self._device_uuid], device_id=self.device.index, run=_load_weights, - post_hook=_post_hook, + post_hook=_process_weight_after_loading, ) + + if getattr(self, "_sampler_warmup", None) is not None: + self._sampler_warmup()