Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions checkpoint/orbax/checkpoint/checkpoint_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import dataclasses
import inspect
from typing import Tuple, Type, Union
from typing import Tuple, Type, TypeVar, Union

from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import handler_type_registry
Expand Down Expand Up @@ -77,6 +77,8 @@ class MyCheckpointRestore(ocp.args.CheckpointArgs):
{}
)

_CheckpointArgsType = TypeVar('_CheckpointArgsType', bound=CheckpointArgs)


def register_with_handler(
handler_cls: Type[CheckpointHandler],
Expand Down Expand Up @@ -104,7 +106,9 @@ def register_with_handler(
if not for_save and not for_restore:
raise ValueError('`for_save` and `for_restore` cannot both be False.')

def decorator(cls: Type[CheckpointArgs]):
def decorator(
cls: Type[_CheckpointArgsType],
) -> Type[_CheckpointArgsType]:
if not issubclass(cls, CheckpointArgs):
raise TypeError(
f'{cls} must subclass `CheckpointArgs` in order to be registered.'
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ class ErrorSaveArgs(pytree_checkpoint_handler.PyTreeSaveArgs):


@checkpoint_args.register_with_handler(ErrorCheckpointHandler, for_restore=True)
class ErrorRestoreArgs(pytree_checkpoint_handler.PyTreeSaveArgs):
class ErrorRestoreArgs(pytree_checkpoint_handler.PyTreeRestoreArgs):
pass


Expand Down
Loading