diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 084b4f68..30923b8a 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -8,6 +8,7 @@ import logging from collections import OrderedDict +from contextlib import nullcontext from typing import Callable, Iterable import torch @@ -126,6 +127,8 @@ def __init__( module: Callable, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -134,18 +137,24 @@ def __init__( self.arg_keys = _call_with_args_ self.kwarg_keys = kwarg_keys self.return_keys = _return_as_dict_ + self.train_mode = _train_mode_ + self.inference_mode = _inference_mode_ def __call__( self, *args, _args_: Iterable[str] | None = None, _return_keys_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwargs, ): module_name = self.module.__class__.__name__ arg_keys = _args_ or self.arg_keys kwarg_keys = self.kwarg_keys + _train_mode_ = _train_mode_ or self.train_mode + _inference_mode_ = _inference_mode_ or self.inference_mode # Change and replace args and kwargs that we call module with if arg_keys is not None or len(kwarg_keys) > 0: @@ -183,8 +192,24 @@ def __call__( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex - # FIXME: Add better error message - ret = self.module(*args, **kwargs) + # Apply train mode and inference mode, if necessary, and call module with args and kwargs + context = nullcontext() + if isinstance(self.module, torch.nn.Module): + old_train_mode = self.module.training + + if _train_mode_ is not None: + self.module.train(_train_mode_) + + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) + + with context: + # FIXME: Add better error message + ret = self.module(*args, **kwargs) + + if isinstance(self.module, torch.nn.Module): + if _train_mode_ is not None: + self.module.train(old_train_mode) # Change returned values into dictionary, if necessary return_keys = _return_keys_ or self.return_keys