From cef11f38cf5d3b207beb30fb4f2507d5106b1d63 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:31:10 -0700 Subject: [PATCH 01/23] fix imports --- mart/nn/nn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 93b0f07f..8ad7c194 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -5,15 +5,14 @@ # import logging +from typing import OrderedDict -logger = logging.getLogger(__name__) - -from typing import OrderedDict # noqa: E402 - -import torch # noqa: E402 +import torch __all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum"] +logger = logging.getLogger(__name__) + class SequentialDict(torch.nn.ModuleDict): """A special Sequential container where we can rewire the input and output of each module in From 9522075536bd0df387cc957367b614444cb66608 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:33:54 -0700 Subject: [PATCH 02/23] Move _call_with_args_ and _return_as_dict_ functionality into CallWith --- mart/nn/nn.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 8ad7c194..7f403b9f 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -4,8 +4,10 @@ # SPDX-License-Identifier: BSD-3-Clause # +from __future__ import annotations + import logging -from typing import OrderedDict +from typing import Callable, Iterable, OrderedDict import torch @@ -82,20 +84,9 @@ def parse_sequence(self, sequence): # The return name could be different from module_name when a module is used more than once. return_name = module_cfg.pop("_name_", module_name) - # The module would be called with these *args. - arg_keys = module_cfg.pop("_call_with_args_", None) - # The module would return a dictionary with these keys instead of a tuple. - return_keys = module_cfg.pop("_return_as_dict", None) - # The module would be called with these **kwargs. - kwarg_keys = module_cfg - - module = self[module_name] - - # Add CallWith to module if we have enough parameters - if arg_keys is not None or len(kwarg_keys) > 0 or return_keys is not None: - module = CallWith(module, arg_keys, kwarg_keys, return_keys) - + module = CallWith(self[module_name], **module_cfg) module_dict[return_name] = module + return module_dict def forward(self, step=None, sequence=None, **kwargs): @@ -132,13 +123,19 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): - def __init__(self, module, arg_keys, kwarg_keys, return_keys=None) -> None: + def __init__( + self, + module: Callable, + _call_with_args_: Iterable[str] | None = None, + _return_as_dict_: Iterable[str] | None = None, + **kwarg_keys, + ) -> None: super().__init__() self.module = module - self.arg_keys = arg_keys or [] + self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} - self.return_keys = return_keys + self.return_keys = _return_as_dict_ def forward(self, *args, **kwargs): orig_class = self.module.__class__ From ba582644bed9a793d90f5688a7780d1179d18e66 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:37:17 -0700 Subject: [PATCH 03/23] Allow overwriting _call_with_args_ and _return_as_dict_ in CallWith.forward --- mart/nn/nn.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 7f403b9f..fead5225 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -137,10 +137,17 @@ def __init__( self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ - def forward(self, *args, **kwargs): + def forward( + self, + *args, + _call_with_args_: Iterable[str] | None = None, + _return_as_dict_: Iterable[str] | None = None, + **kwargs, + ): orig_class = self.module.__class__ - arg_keys = self.arg_keys + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys + return_keys = _return_as_dict_ or self.return_keys kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -160,16 +167,16 @@ def forward(self, *args, **kwargs): # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) - if self.return_keys: + if return_keys: if not isinstance(ret, tuple): raise Exception( f"Module {orig_class} does not return multiple unnamed variables, so we can not dictionarize the return." ) - if len(self.return_keys) != len(ret): + if len(return_keys) != len(ret): raise Exception( - f"Module {orig_class} returns {len(ret)} items, but {len(self.return_keys)} return_keys were specified." + f"Module {orig_class} returns {len(ret)} items, but {len(return_keys)} return_keys were specified." ) - ret = dict(zip(self.return_keys, ret)) + ret = dict(zip(return_keys, ret)) return ret From 60fc2ad0d2a372e03cc64133a9068496c51d8ad8 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:04:07 -0700 Subject: [PATCH 04/23] Add _train_mode_ and _inference_mode_ to CallWith --- mart/nn/nn.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index fead5225..6c2ba4bc 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,7 +7,9 @@ from __future__ import annotations import logging -from typing import Callable, Iterable, OrderedDict +from collections import OrderedDict +from contextlib import nullcontext +from typing import Iterable import torch @@ -125,9 +127,11 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: Callable, + module: torch.nn.Module, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -136,18 +140,26 @@ def __init__( self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ + self.train_mode = _train_mode_ + self.inference_mode = _inference_mode_ def forward( self, *args, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwargs, ): orig_class = self.module.__class__ + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys return_keys = _return_as_dict_ or self.return_keys + _train_mode_ = _train_mode_ or self.train_mode + _inference_mode_ = _inference_mode_ or self.inference_mode + kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -164,8 +176,21 @@ def forward( selected_args = [kwargs[key] for key in arg_keys[len(args) :]] selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} - # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) + old_train_mode = self.module.training + + if _train_mode_ is not None: + self.module.train(_train_mode_) + + context = nullcontext() + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) + + with context: + # FIXME: Add better error message + ret = self.module(*args, *selected_args, **selected_kwargs) + + if _train_mode_ is not None: + self.module.train(old_train_mode) if return_keys: if not isinstance(ret, tuple): From 2e51c662409a7aeb0e0913696330395326772fa7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:11:25 -0700 Subject: [PATCH 05/23] Revert "Add _train_mode_ and _inference_mode_ to CallWith" This reverts commit 60fc2ad0d2a372e03cc64133a9068496c51d8ad8. --- mart/nn/nn.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 6c2ba4bc..fead5225 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,9 +7,7 @@ from __future__ import annotations import logging -from collections import OrderedDict -from contextlib import nullcontext -from typing import Iterable +from typing import Callable, Iterable, OrderedDict import torch @@ -127,11 +125,9 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: torch.nn.Module, + module: Callable, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, - _train_mode_: bool | None = None, - _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -140,26 +136,18 @@ def __init__( self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ - self.train_mode = _train_mode_ - self.inference_mode = _inference_mode_ def forward( self, *args, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, - _train_mode_: bool | None = None, - _inference_mode_: bool | None = None, **kwargs, ): orig_class = self.module.__class__ - arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys return_keys = _return_as_dict_ or self.return_keys - _train_mode_ = _train_mode_ or self.train_mode - _inference_mode_ = _inference_mode_ or self.inference_mode - kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -176,21 +164,8 @@ def forward( selected_args = [kwargs[key] for key in arg_keys[len(args) :]] selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} - old_train_mode = self.module.training - - if _train_mode_ is not None: - self.module.train(_train_mode_) - - context = nullcontext() - if _inference_mode_ is not None: - context = torch.inference_mode(mode=_inference_mode_) - - with context: - # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) - - if _train_mode_ is not None: - self.module.train(old_train_mode) + # FIXME: Add better error message + ret = self.module(*args, *selected_args, **selected_kwargs) if return_keys: if not isinstance(ret, tuple): From 1671755ec6d3b9a63bae7a1b1b2d4662faec86ec Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:12:24 -0700 Subject: [PATCH 06/23] Revert "Revert "Add _train_mode_ and _inference_mode_ to CallWith"" This reverts commit 2e51c662409a7aeb0e0913696330395326772fa7. --- mart/nn/nn.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index fead5225..6c2ba4bc 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,7 +7,9 @@ from __future__ import annotations import logging -from typing import Callable, Iterable, OrderedDict +from collections import OrderedDict +from contextlib import nullcontext +from typing import Iterable import torch @@ -125,9 +127,11 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: Callable, + module: torch.nn.Module, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -136,18 +140,26 @@ def __init__( self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ + self.train_mode = _train_mode_ + self.inference_mode = _inference_mode_ def forward( self, *args, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwargs, ): orig_class = self.module.__class__ + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys return_keys = _return_as_dict_ or self.return_keys + _train_mode_ = _train_mode_ or self.train_mode + _inference_mode_ = _inference_mode_ or self.inference_mode + kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -164,8 +176,21 @@ def forward( selected_args = [kwargs[key] for key in arg_keys[len(args) :]] selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} - # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) + old_train_mode = self.module.training + + if _train_mode_ is not None: + self.module.train(_train_mode_) + + context = nullcontext() + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) + + with context: + # FIXME: Add better error message + ret = self.module(*args, *selected_args, **selected_kwargs) + + if _train_mode_ is not None: + self.module.train(old_train_mode) if return_keys: if not isinstance(ret, tuple): From 731b23daa5abfc13db8205e9e39e08cb85ff9c6f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:14:18 -0700 Subject: [PATCH 07/23] cleanup --- mart/nn/nn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index fead5225..170c7274 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,7 +7,8 @@ from __future__ import annotations import logging -from typing import Callable, Iterable, OrderedDict +from collections import OrderedDict +from typing import Iterable import torch @@ -125,7 +126,7 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: Callable, + module: torch.nn.Module, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, **kwarg_keys, From 8e664f1927690f20dc3e7131be94738434db3119 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:27:27 -0700 Subject: [PATCH 08/23] cleanup --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 170c7274..ed734a95 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -148,7 +148,6 @@ def forward( orig_class = self.module.__class__ arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys - return_keys = _return_as_dict_ or self.return_keys kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -168,6 +167,7 @@ def forward( # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) + return_keys = _return_as_dict_ or self.return_keys if return_keys: if not isinstance(ret, tuple): raise Exception( From ae1b8369bb35a4645e68fb1c2c677058e7e7094f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 11:46:30 -0700 Subject: [PATCH 09/23] Fix configs --- mart/configs/attack/gain/modular.yaml | 3 +-- mart/configs/attack/gain/rcnn_training_loss.yaml | 3 +-- mart/configs/attack/objective/misclassification.yaml | 3 +-- mart/configs/attack/objective/object_detection_missed.yaml | 3 +-- mart/configs/attack/objective/zero_ap.yaml | 3 +-- 5 files changed, 5 insertions(+), 10 deletions(-) diff --git a/mart/configs/attack/gain/modular.yaml b/mart/configs/attack/gain/modular.yaml index 9b06b307..9c76b2c4 100644 --- a/mart/configs/attack/gain/modular.yaml +++ b/mart/configs/attack/gain/modular.yaml @@ -1,7 +1,6 @@ _target_: mart.nn.CallWith module: _target_: ??? -arg_keys: +_call_with_args_: - logits - target -kwarg_keys: null diff --git a/mart/configs/attack/gain/rcnn_training_loss.yaml b/mart/configs/attack/gain/rcnn_training_loss.yaml index eb7abb9c..19b9355e 100644 --- a/mart/configs/attack/gain/rcnn_training_loss.yaml +++ b/mart/configs/attack/gain/rcnn_training_loss.yaml @@ -1,9 +1,8 @@ _target_: mart.nn.CallWith module: _target_: mart.nn.Sum -arg_keys: +_call_with_args_: - rpn_loss.loss_objectness - rpn_loss.loss_rpn_box_reg - box_loss.loss_classifier - box_loss.loss_box_reg -kwarg_keys: null diff --git a/mart/configs/attack/objective/misclassification.yaml b/mart/configs/attack/objective/misclassification.yaml index e2e9b819..a2e6260e 100644 --- a/mart/configs/attack/objective/misclassification.yaml +++ b/mart/configs/attack/objective/misclassification.yaml @@ -1,7 +1,6 @@ _target_: mart.nn.CallWith module: _target_: mart.attack.objective.Mispredict -arg_keys: +_call_with_args_: - preds - target -kwarg_keys: null diff --git a/mart/configs/attack/objective/object_detection_missed.yaml b/mart/configs/attack/objective/object_detection_missed.yaml index dec2410c..efc93078 100644 --- a/mart/configs/attack/objective/object_detection_missed.yaml +++ b/mart/configs/attack/objective/object_detection_missed.yaml @@ -2,6 +2,5 @@ _target_: mart.nn.CallWith module: _target_: mart.attack.objective.Missed confidence_threshold: 0.0 -arg_keys: +_call_with_args_: - preds -kwarg_keys: null diff --git a/mart/configs/attack/objective/zero_ap.yaml b/mart/configs/attack/objective/zero_ap.yaml index 6a43f77d..e11105e6 100644 --- a/mart/configs/attack/objective/zero_ap.yaml +++ b/mart/configs/attack/objective/zero_ap.yaml @@ -3,7 +3,6 @@ module: _target_: mart.attack.objective.ZeroAP iou_threshold: 0.5 confidence_threshold: 0.0 -arg_keys: +_call_with_args_: - preds - target -kwarg_keys: null From 39f9aaa04efc5bf6e7a9040c3a4068c5e8b032ed Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 11:48:48 -0700 Subject: [PATCH 10/23] cleanup --- mart/nn/nn.py | 47 +++++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index ed734a95..9710f487 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -134,8 +134,8 @@ def __init__( super().__init__() self.module = module - self.arg_keys = _call_with_args_ or [] - self.kwarg_keys = kwarg_keys or {} + self.arg_keys = _call_with_args_ + self.kwarg_keys = kwarg_keys self.return_keys = _return_as_dict_ def forward( @@ -145,37 +145,52 @@ def forward( _return_as_dict_: Iterable[str] | None = None, **kwargs, ): - orig_class = self.module.__class__ + module_name = self.module.__class__.__name__ + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys + + args = list(args) kwargs = DotDict(kwargs) - # Sometimes we receive positional arguments because some modules use nn.Sequential - # which has a __call__ function that passes positional args. So we pass along args - # as it and assume these consume the first len(args) of arg_keys. - remaining_arg_keys = arg_keys[len(args) :] + # Change and replaces args and kwargs that we call module with + if arg_keys is not None or len(kwarg_keys) > 0: + arg_keys = arg_keys or [] + + # Sometimes we receive positional arguments because some modules use nn.Sequential + # which has a __call__ function that passes positional args. So we pass along args + # as it and assume these consume the first len(args) of arg_keys. + arg_keys = arg_keys[len(args) :] - for key in remaining_arg_keys + list(kwarg_keys.values()): - if key not in kwargs: + # Append kwargs to args using arg_keys + try: + [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] + except KeyError as ex: raise Exception( - f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." - ) + f"{module_name} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + ) from ex - selected_args = [kwargs[key] for key in arg_keys[len(args) :]] - selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} + # Replace kwargs with selected kwargs + try: + kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} + except KeyError as ex: + raise Exception( + f"{module_name} wants kwarg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + ) from ex # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) + ret = self.module(*args, **kwargs) + # Change returned values into dictionary, if necessary return_keys = _return_as_dict_ or self.return_keys if return_keys: if not isinstance(ret, tuple): raise Exception( - f"Module {orig_class} does not return multiple unnamed variables, so we can not dictionarize the return." + f"{module_name} does not return multiple unnamed variables, so we can not dictionarize the return." ) if len(return_keys) != len(ret): raise Exception( - f"Module {orig_class} returns {len(ret)} items, but {len(return_keys)} return_keys were specified." + f"Module {module_name} returns {len(ret)} items, but {len(return_keys)} return_keys were specified." ) ret = dict(zip(return_keys, ret)) From 097393384109e845a554063dc1786313d735f871 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 11:50:08 -0700 Subject: [PATCH 11/23] bugfix --- mart/nn/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 9710f487..c932bc55 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -167,7 +167,7 @@ def forward( [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] except KeyError as ex: raise Exception( - f"{module_name} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex # Replace kwargs with selected kwargs @@ -175,7 +175,7 @@ def forward( kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} except KeyError as ex: raise Exception( - f"{module_name} wants kwarg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex # FIXME: Add better error message From 2ec4e4984757cffb8bdf841e4765875b379e2351 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 12:30:15 -0700 Subject: [PATCH 12/23] Only set train mode and inference mode on Modules --- mart/nn/nn.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 56ced9b6..b365a8d5 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -188,21 +188,23 @@ def forward( ) from ex # Apply train mode and inference mode, if necessary, and call module with args and kwargs - old_train_mode = self.module.training + context = nullcontext() + if isinstance(self.module, torch.nn.Module) + old_train_mode = self.module.training - if _train_mode_ is not None: - self.module.train(_train_mode_) + if _train_mode_ is not None: + self.module.train(_train_mode_) - context = nullcontext() - if _inference_mode_ is not None: - context = torch.inference_mode(mode=_inference_mode_) + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) with context: # FIXME: Add better error message ret = self.module(*args, **kwargs) - if _train_mode_ is not None: - self.module.train(old_train_mode) + if isinstance(self.module, torch.nn.Module): + if _train_mode_ is not None: + self.module.train(old_train_mode) # Change returned values into dictionary, if necessary return_keys = _return_as_dict_ or self.return_keys From 741d282d2fb21b4c994781d64a8afab8b30c8f1f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 12:31:07 -0700 Subject: [PATCH 13/23] bugfix --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index b365a8d5..d2053fa8 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -189,7 +189,7 @@ def forward( # Apply train mode and inference mode, if necessary, and call module with args and kwargs context = nullcontext() - if isinstance(self.module, torch.nn.Module) + if isinstance(self.module, torch.nn.Module): old_train_mode = self.module.training if _train_mode_ is not None: From 15c5a5f77f7ff0b4a1a29a8ab377378e2e7dc77d Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 12:41:08 -0700 Subject: [PATCH 14/23] CallWith is not a Module --- mart/nn/nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index c932bc55..1bb4a269 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -8,7 +8,7 @@ import logging from collections import OrderedDict -from typing import Iterable +from typing import Callable, Iterable import torch @@ -123,10 +123,10 @@ def __call__(self, **kwargs): return kwargs -class CallWith(torch.nn.Module): +class CallWith: def __init__( self, - module: torch.nn.Module, + module: Callable, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, **kwarg_keys, @@ -138,7 +138,7 @@ def __init__( self.kwarg_keys = kwarg_keys self.return_keys = _return_as_dict_ - def forward( + def __call__( self, *args, _call_with_args_: Iterable[str] | None = None, From 40dd262fcec6fdbbd6eacbd5f1306d83f25745fa Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 11:29:41 -0700 Subject: [PATCH 15/23] cleanup --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 1bb4a269..9b87d475 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -29,7 +29,7 @@ class SequentialDict(torch.nn.ModuleDict): : _name_: _call_with_args_: - _return_as_dict: + _return_as_dict_: **kwargs All intermediate output from each module are stored in the dictionary `kwargs` in `forward()` From ceaa7443e96384acbd35c6ae4fb512d786acc04b Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 13:00:32 -0700 Subject: [PATCH 16/23] bugfix --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index e593d6b6..d8f5164d 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -236,7 +236,7 @@ def __getitem__(self, key): elif isinstance(value, dict) and subkey in value: value = value[subkey] else: - raise KeyError("No {subkey} in " + ".".join([key, *subkeys])) + raise KeyError(f"No {subkey} in " + ".".join([key, *subkeys])) return value From b8d473b06d98685745ff33dd4e998b6b75d32867 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 16:36:33 -0700 Subject: [PATCH 17/23] fix merge error --- mart/nn/nn.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 1908b9f3..902e13be 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -155,22 +155,26 @@ def __call__( if arg_keys is not None or len(kwarg_keys) > 0: arg_keys = arg_keys or [] - # Check to make sure each str exists within kwargs - str_kwarg_keys = filter(lambda k: isinstance(k, str), kwarg_keys.values()) - for key in remaining_arg_keys + list(str_kwarg_keys): - if key not in kwargs: + # Sometimes we receive positional arguments because some modules use nn.Sequential + # which has a __call__ function that passes positional args. So we pass along args + # as it and assume these consume the first len(args) of arg_keys. + arg_keys = arg_keys[len(args) :] + + # Append kwargs to args using arg_keys + try: + [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] + except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex - # For each specified args/kwargs key, lookup its corresponding value in kwargs only if the key is a string. - # Otherwise, we just treat the key as a value. - selected_args = [ - kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :] - ] - selected_kwargs = { - key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items() - } + # Replace kwargs with selected kwargs + try: + kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} + except KeyError as ex: + raise Exception( + f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." + ) from ex # FIXME: Add better error message ret = self.module(*args, **kwargs) From e9cf67b20fb226001c8b6e0133bd0bd6629d4eb4 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 16:43:18 -0700 Subject: [PATCH 18/23] Change call special arg names --- mart/nn/nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 902e13be..04a0c8e1 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -139,13 +139,13 @@ def __init__( def __call__( self, *args, - _call_with_args_: Iterable[str] | None = None, - _return_as_dict_: Iterable[str] | None = None, + _args_: Iterable[str] | None = None, + _return_keys_: Iterable[str] | None = None, **kwargs, ): module_name = self.module.__class__.__name__ - arg_keys = _call_with_args_ or self.arg_keys + arg_keys = _args_ or self.arg_keys kwarg_keys = self.kwarg_keys args = list(args) @@ -180,7 +180,7 @@ def __call__( ret = self.module(*args, **kwargs) # Change returned values into dictionary, if necessary - return_keys = _return_as_dict_ or self.return_keys + return_keys = _return_keys_ or self.return_keys if return_keys: if not isinstance(ret, tuple): raise Exception( From 9c7b05e76ad11e185edbea8a9f51ebead5f9742b Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 23 Jun 2023 08:02:13 -0700 Subject: [PATCH 19/23] bugfix --- mart/configs/attack/gain/rcnn_class_background.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mart/configs/attack/gain/rcnn_class_background.yaml b/mart/configs/attack/gain/rcnn_class_background.yaml index 3f668abf..ddeae13d 100644 --- a/mart/configs/attack/gain/rcnn_class_background.yaml +++ b/mart/configs/attack/gain/rcnn_class_background.yaml @@ -6,7 +6,6 @@ module: # Try to classify as background. class_index: 0 targeted: true -arg_keys: +_call_with_args_: - box_head.class_logits - rpn_predictor.boxes -kwarg_keys: null From fab97633e2c6c83f36061f32ea016cea0fbbb6ec Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 23 Jun 2023 09:49:23 -0700 Subject: [PATCH 20/23] fix merge error --- mart/nn/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 04a0c8e1..c39e39f6 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -162,7 +162,7 @@ def __call__( # Append kwargs to args using arg_keys try: - [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] + [args.append(kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key) for kwargs_key in arg_keys] except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." @@ -170,7 +170,7 @@ def __call__( # Replace kwargs with selected kwargs try: - kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} + kwargs = {name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key for name, kwargs_key in kwarg_keys.items()} except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." From f9b43eb53a3ab978fc4f0c35c364aad0d7a193b2 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 23 Jun 2023 09:56:24 -0700 Subject: [PATCH 21/23] cleanup --- mart/nn/nn.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index c39e39f6..4ede1fa6 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -148,27 +148,27 @@ def __call__( arg_keys = _args_ or self.arg_keys kwarg_keys = self.kwarg_keys - args = list(args) - kwargs = DotDict(kwargs) - - # Change and replaces args and kwargs that we call module with + # Change and replace args and kwargs that we call module with if arg_keys is not None or len(kwarg_keys) > 0: arg_keys = arg_keys or [] + kwargs = DotDict(kwargs) # we need to lookup values using dot strings + args = list(args) # tuple -> list + # Sometimes we receive positional arguments because some modules use nn.Sequential # which has a __call__ function that passes positional args. So we pass along args # as it and assume these consume the first len(args) of arg_keys. arg_keys = arg_keys[len(args) :] - # Append kwargs to args using arg_keys + # Extend args with selected kwargs using arg_keys try: - [args.append(kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key) for kwargs_key in arg_keys] + args.extend([kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key for kwargs_key in arg_keys]) except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex - # Replace kwargs with selected kwargs + # Replace kwargs with selected kwargs using kwarg_keys try: kwargs = {name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key for name, kwargs_key in kwarg_keys.items()} except KeyError as ex: From 471757d6a0d583c0e26a2e37489a2141f1c62430 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 23 Jun 2023 09:57:51 -0700 Subject: [PATCH 22/23] cleanup --- mart/nn/nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 4ede1fa6..a4dd8947 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -83,7 +83,6 @@ def parse_sequence(self, sequence): return_name = module_cfg.pop("_name_", module_name) module = CallWith(self[module_name], **module_cfg) module_dict[return_name] = module - return module_dict def forward(self, step=None, sequence=None, **kwargs): From 368b87e49c6bebffeaa52aacec0fee2830b9db80 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 23 Jun 2023 14:19:35 -0700 Subject: [PATCH 23/23] style --- mart/nn/nn.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index a4dd8947..084b4f68 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -161,7 +161,12 @@ def __call__( # Extend args with selected kwargs using arg_keys try: - args.extend([kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key for kwargs_key in arg_keys]) + args.extend( + [ + kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key + for kwargs_key in arg_keys + ] + ) except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." @@ -169,7 +174,10 @@ def __call__( # Replace kwargs with selected kwargs using kwarg_keys try: - kwargs = {name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key for name, kwargs_key in kwarg_keys.items()} + kwargs = { + name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key + for name, kwargs_key in kwarg_keys.items() + } except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}."