From 6186729656062bb3f1757cc38d9a021b41d61d55 Mon Sep 17 00:00:00 2001 From: Hubert Zhang Date: Tue, 20 Jan 2026 03:41:58 +0800 Subject: [PATCH] feat: make use_inplace_pin_memory configurable via env and request --- checkpoint_engine/api.py | 7 ++++++- checkpoint_engine/ps.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/checkpoint_engine/api.py b/checkpoint_engine/api.py index e61b41d..8e44e0d 100644 --- a/checkpoint_engine/api.py +++ b/checkpoint_engine/api.py @@ -47,6 +47,7 @@ def _init_api(ps: ParameterServer) -> Any: class RegisterRequest(BaseModel): files: list[str] + use_inplace_pin_memory: bool | None = None class UpdateRequest(BaseModel): ranks: list[int] = [] @@ -65,7 +66,11 @@ def wrap_exception(func: Callable[[], None]) -> Response: @app.post("/v1/checkpoints/{checkpoint_name}/files") async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response: - return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files)) + return wrap_exception( + lambda: ps.register_checkpoint( + checkpoint_name, files=req.files, use_inplace_pin_memory=req.use_inplace_pin_memory + ) + ) @app.delete("/v1/checkpoints/{checkpoint_name}") async def unregister_checkpoint(checkpoint_name: str) -> Response: diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 20f5be6..2e049be 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -263,7 +263,7 @@ def register_checkpoint( files: list[str] | None = None, named_tensors: dict[str, torch.Tensor] | None = None, use_shared_memory_pool: bool = False, - use_inplace_pin_memory: bool = True, + use_inplace_pin_memory: bool | None = None, ) -> None: """ Register a checkpoint to the parameter server. Both files and named_tensors will be registered together. @@ -282,6 +282,12 @@ def register_checkpoint( use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True. """ + if use_inplace_pin_memory is None: + env_str = os.getenv("PS_USE_INPLACE_PIN_MEMORY", "true") + use_inplace_pin_memory = env_str.lower() in ["true", "1", "yes", "y"] + logger.info( + f"[rank{self._rank}] use_inplace_pin_memory set to {use_inplace_pin_memory} by environment variable PS_USE_INPLACE_PIN_MEMORY={env_str}" + ) if self.device_manager.device_type != "cuda" and use_inplace_pin_memory: logger.warning( f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"