diff --git a/checkpoint/orbax/checkpoint/checkpoint_args.py b/checkpoint/orbax/checkpoint/checkpoint_args.py index e0660835c..12274535c 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_args.py +++ b/checkpoint/orbax/checkpoint/checkpoint_args.py @@ -16,7 +16,7 @@ import dataclasses import inspect -from typing import Tuple, Type, Union +from typing import Tuple, Type, TypeVar, Union from orbax.checkpoint._src.handlers import checkpoint_handler from orbax.checkpoint._src.handlers import handler_type_registry @@ -77,6 +77,8 @@ class MyCheckpointRestore(ocp.args.CheckpointArgs): {} ) +_CheckpointArgsType = TypeVar('_CheckpointArgsType', bound=CheckpointArgs) + def register_with_handler( handler_cls: Type[CheckpointHandler], @@ -104,7 +106,9 @@ def register_with_handler( if not for_save and not for_restore: raise ValueError('`for_save` and `for_restore` cannot both be False.') - def decorator(cls: Type[CheckpointArgs]): + def decorator( + cls: Type[_CheckpointArgsType], + ) -> Type[_CheckpointArgsType]: if not issubclass(cls, CheckpointArgs): raise TypeError( f'{cls} must subclass `CheckpointArgs` in order to be registered.' diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index a4c4f0744..bcb3e32c2 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -536,7 +536,7 @@ class ErrorSaveArgs(pytree_checkpoint_handler.PyTreeSaveArgs): @checkpoint_args.register_with_handler(ErrorCheckpointHandler, for_restore=True) -class ErrorRestoreArgs(pytree_checkpoint_handler.PyTreeSaveArgs): +class ErrorRestoreArgs(pytree_checkpoint_handler.PyTreeRestoreArgs): pass