From 3116c97dc56fa7326625187f47d1046e636ab7f3 Mon Sep 17 00:00:00 2001 From: Shutong Li Date: Tue, 23 Dec 2025 13:33:11 -0800 Subject: [PATCH] Fix a type incompatibility issues. PiperOrigin-RevId: 848282522 --- checkpoint/orbax/checkpoint/checkpoint_args.py | 8 ++++++-- checkpoint/orbax/checkpoint/test_utils.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/checkpoint/orbax/checkpoint/checkpoint_args.py b/checkpoint/orbax/checkpoint/checkpoint_args.py index e0660835c..12274535c 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_args.py +++ b/checkpoint/orbax/checkpoint/checkpoint_args.py @@ -16,7 +16,7 @@ import dataclasses import inspect -from typing import Tuple, Type, Union +from typing import Tuple, Type, TypeVar, Union from orbax.checkpoint._src.handlers import checkpoint_handler from orbax.checkpoint._src.handlers import handler_type_registry @@ -77,6 +77,8 @@ class MyCheckpointRestore(ocp.args.CheckpointArgs): {} ) +_CheckpointArgsType = TypeVar('_CheckpointArgsType', bound=CheckpointArgs) + def register_with_handler( handler_cls: Type[CheckpointHandler], @@ -104,7 +106,9 @@ def register_with_handler( if not for_save and not for_restore: raise ValueError('`for_save` and `for_restore` cannot both be False.') - def decorator(cls: Type[CheckpointArgs]): + def decorator( + cls: Type[_CheckpointArgsType], + ) -> Type[_CheckpointArgsType]: if not issubclass(cls, CheckpointArgs): raise TypeError( f'{cls} must subclass `CheckpointArgs` in order to be registered.' diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index a4c4f0744..bcb3e32c2 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -536,7 +536,7 @@ class ErrorSaveArgs(pytree_checkpoint_handler.PyTreeSaveArgs): @checkpoint_args.register_with_handler(ErrorCheckpointHandler, for_restore=True) -class ErrorRestoreArgs(pytree_checkpoint_handler.PyTreeSaveArgs): +class ErrorRestoreArgs(pytree_checkpoint_handler.PyTreeRestoreArgs): pass