From 56014e65d85ca88550ed90e8b7e6f2fa39609df7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:25:36 -0700 Subject: [PATCH 1/5] Replace tuple with Iterable[torch.Tensor] --- mart/attack/adversary_in_art.py | 12 ++++--- mart/attack/adversary_wrapper.py | 11 +++--- mart/attack/callbacks/base.py | 26 +++++++------- mart/attack/composer.py | 27 +++++++++------ mart/attack/enforcer.py | 59 ++++++++++++++------------------ mart/attack/initializer.py | 41 ++++++++++++---------- mart/attack/projector.py | 56 +++++++++++++++++++----------- 7 files changed, 129 insertions(+), 103 deletions(-) diff --git a/mart/attack/adversary_in_art.py b/mart/attack/adversary_in_art.py index 2a993349..d48f669c 100644 --- a/mart/attack/adversary_in_art.py +++ b/mart/attack/adversary_in_art.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional import hydra import numpy @@ -82,17 +82,18 @@ def convert_input_art_to_mart(self, x: numpy.ndarray): x (np.ndarray): NHWC, [0, 1] Returns: - tuple: a tuple of tensors in CHW, [0, 255]. + Iterable[torch.Tensor]: an Iterable of tensors in CHW, [0, 255]. """ input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255 + # FIXME: replace tuple with whatever input's type is input = tuple(inp_ for inp_ in input) return input - def convert_input_mart_to_art(self, input: tuple): + def convert_input_mart_to_art(self, input: Iterable[torch.Tensor]): """Convert MART input to the ART's format. Args: - input (tuple): a tuple of tensors in CHW, [0, 255]. + input (Iterable[torch.Tensor]): an Iterable of tensors in CHW, [0, 255]. Returns: np.ndarray: NHWC, [0, 1] @@ -112,7 +113,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List): y_patch_metadata (_type_): _description_ Returns: - tuple: a tuple of target dictionaies. + Iterable[dict[str, Any]]: an Iterable of target dictionaies. """ # Copy y to target, and convert ndarray to pytorch tensors accordingly. target = [] @@ -132,6 +133,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List): target_i["file_name"] = f"{yi['image_id'][0]}.jpg" target.append(target_i) + # FIXME: replace tuple with input type? target = tuple(target) return target diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/adversary_wrapper.py index c4b02953..a40ee644 100644 --- a/mart/attack/adversary_wrapper.py +++ b/mart/attack/adversary_wrapper.py @@ -6,10 +6,13 @@ from __future__ import annotations -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Iterable import torch +if TYPE_CHECKING: + from .enforcer import Enforcer + __all__ = ["NormalizedAdversaryAdapter"] @@ -22,7 +25,7 @@ class NormalizedAdversaryAdapter(torch.nn.Module): def __init__( self, adversary: Callable[[Callable], Callable], - enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None], + enforcer: Enforcer, ): """ @@ -37,8 +40,8 @@ def __init__( def forward( self, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module | None = None, **kwargs, ): diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py index 97541ecb..a982aa8e 100644 --- a/mart/attack/callbacks/base.py +++ b/mart/attack/callbacks/base.py @@ -7,7 +7,7 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable import torch @@ -24,8 +24,8 @@ def on_run_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -35,8 +35,8 @@ def on_examine_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -46,8 +46,8 @@ def on_examine_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -57,8 +57,8 @@ def on_advance_start( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -68,8 +68,8 @@ def on_advance_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -79,8 +79,8 @@ def on_run_end( self, *, adversary: Adversary, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], model: torch.nn.Module, **kwargs, ): diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 5bc4edb7..ef8f3417 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -7,7 +7,7 @@ from __future__ import annotations import abc -from typing import Any +from typing import Any, Iterable import torch @@ -15,21 +15,28 @@ class Composer(abc.ABC): def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], **kwargs, - ) -> torch.Tensor | tuple: - if isinstance(perturbation, tuple): - input_adv = tuple( + ) -> torch.Tensor | Iterable[torch.Tensor]: + if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): + return self.compose(perturbation, input=input, target=target) + + elif ( + isinstance(perturbation, Iterable) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + # FIXME: replace tuple with whatever input's type is + return tuple( self.compose(perturbation_i, input=input_i, target=target_i) for perturbation_i, input_i, target_i in zip(perturbation, input, target) ) - else: - input_adv = self.compose(perturbation, input=input, target=target) - return input_adv + else: + raise NotImplementedError @abc.abstractmethod def compose( diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index babc44e6..95e6716b 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -7,7 +7,7 @@ from __future__ import annotations import abc -from typing import Any +from typing import Any, Iterable import torch @@ -95,45 +95,38 @@ def verify(self, input_adv, *, input, target): class Enforcer: - def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None: - self.modality_constraints = modality_constraints + def __init__(self, constraints: dict[str, Constraint]) -> None: + self.constraints = list(constraints.values()) # intentionally ignore keys @torch.no_grad() - def _enforce( + def __call__( self, - input_adv: torch.Tensor, + input_adv: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], - modality: str, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + **kwargs, ): - for constraint in self.modality_constraints[modality].values(): - constraint(input_adv, input=input, target=target) + if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor): + self.enforce(input_adv, input=input, target=target) + + elif ( + isinstance(input_adv, Iterable) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + [ + self.enforce(input_adv_i, input=input_i, target=target_i) + for input_adv_i, input_i, target_i in zip(input_adv, input, target) + ] - def __call__( + @torch.no_grad() + def enforce( self, - input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], + input_adv: torch.Tensor, *, - input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], + input: torch.Tensor, target: torch.Tensor | dict[str, Any], - modality: str = "constraints", - **kwargs, ): - assert type(input_adv) == type(input) - - if isinstance(input_adv, torch.Tensor): - # Finally we can verify constraints on tensor, per its modality. - # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities. - self._enforce(input_adv, input=input, target=target, modality=modality) - elif isinstance(input_adv, dict): - # The dict input has modalities specified in keys, passing them recursively. - for modality in input_adv: - self(input_adv[modality], input=input[modality], target=target, modality=modality) - elif isinstance(input_adv, (list, tuple)): - # We assume a modality-dictionary only contains tensors, but not list/tuple. - assert modality == "constraints" - # The list or tuple input is a collection of sub-input and sub-target. - for input_adv_i, input_i, target_i in zip(input_adv, input, target): - self(input_adv_i, input=input_i, target=target_i, modality=modality) - else: - raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.") + for constraint in self.constraints: + constraint(input_adv, input=input, target=target) diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py index cd05c6c6..d66bcf9f 100644 --- a/mart/attack/initializer.py +++ b/mart/attack/initializer.py @@ -4,52 +4,57 @@ # SPDX-License-Identifier: BSD-3-Clause # -import abc -from typing import Optional, Union +from __future__ import annotations -import torch +from typing import Iterable -__all__ = ["Initializer"] +import torch -class Initializer(abc.ABC): +class Initializer: """Initializer base class.""" @torch.no_grad() - @abc.abstractmethod - def __call__(self, perturbation: torch.Tensor) -> None: + def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + + [self.initialize_(parameter) for parameter in parameters] + + @torch.no_grad() + def initialize_(self, parameter: torch.Tensor) -> None: pass class Constant(Initializer): - def __init__(self, constant: Optional[Union[int, float]] = 0): + def __init__(self, constant: int | float = 0): self.constant = constant @torch.no_grad() - def __call__(self, perturbation: torch.Tensor) -> None: - torch.nn.init.constant_(perturbation, self.constant) + def initialize_(self, parameter: torch.Tensor) -> None: + torch.nn.init.constant_(parameter, self.constant) class Uniform(Initializer): - def __init__(self, min: Union[int, float], max: Union[int, float]): + def __init__(self, min: int | float, max: int | float): self.min = min self.max = max @torch.no_grad() - def __call__(self, perturbation: torch.Tensor) -> None: - torch.nn.init.uniform_(perturbation, self.min, self.max) + def initialize_(self, parameter: torch.Tensor) -> None: + torch.nn.init.uniform_(parameter, self.min, self.max) class UniformLp(Initializer): - def __init__(self, eps: Union[int, float], p: Optional[Union[int, float]] = torch.inf): + def __init__(self, eps: int | float, p: int | float = torch.inf): self.eps = eps self.p = p @torch.no_grad() - def __call__(self, perturbation: torch.Tensor) -> None: - torch.nn.init.uniform_(perturbation, -self.eps, self.eps) + def initialize_(self, parameter: torch.Tensor) -> None: + torch.nn.init.uniform_(parameter, -self.eps, self.eps) # TODO: make sure the first dim is the batch dim. if self.p is not torch.inf: # We don't do tensor.renorm_() because the first dim is not the batch dim. - pert_norm = perturbation.norm(p=self.p) - perturbation.mul_(self.eps / pert_norm) + pert_norm = parameter.norm(p=self.p) + parameter.mul_(self.eps / pert_norm) diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 92391c67..095d2601 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Iterable import torch @@ -17,24 +17,35 @@ class Projector: @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.tensor | dict[str, Any]], **kwargs, ) -> None: - if isinstance(perturbation, tuple): - for perturbation_i, input_i, target_i in zip(perturbation, input, target): - self.project(perturbation_i, input=input_i, target=target_i) + if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): + self.project_(perturbation, input=input, target=target) + + elif ( + isinstance(perturbation, Iterable) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + [ + self.project_(perturbation_i, input=input_i, target=target_i) + for perturbation_i, input_i, target_i in zip(perturbation, input, target) + ] + else: - self.project(perturbation, input=input, target=target) + raise NotImplementedError - def project( + @torch.no_grad() + def project_( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], ) -> None: pass @@ -48,10 +59,10 @@ def __init__(self, projectors: list[Projector]): @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | tuple, + perturbation: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], **kwargs, ) -> None: for projector in self.projectors: @@ -70,7 +81,8 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa self.min = min self.max = max - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): if self.quantize: perturbation.round_() perturbation.clamp_(self.min, self.max) @@ -92,7 +104,8 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa self.min = min self.max = max - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): if self.quantize: perturbation.round_() perturbation.clamp_(self.min - input, self.max - input) @@ -117,7 +130,8 @@ def __init__(self, eps: int | float, p: int | float = torch.inf): self.p = p self.eps = eps - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): pert_norm = perturbation.norm(p=self.p) if pert_norm > self.eps: # We only upper-bound the norm. @@ -133,7 +147,8 @@ def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 25 self.min = min self.max = max - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): eps_min = (input - self.eps).clamp(self.min, self.max) - input eps_max = (input + self.eps).clamp(self.min, self.max) - input @@ -141,7 +156,8 @@ def project(self, perturbation, *, input, target): class Mask(Projector): - def project(self, perturbation, *, input, target): + @torch.no_grad() + def project_(self, perturbation, *, input, target): perturbation.mul_(target["perturbable_mask"]) def __repr__(self): From 1c47cc049a7130802521080ee800f4c7ead55dc2 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:41:15 -0700 Subject: [PATCH 2/5] Fix tests --- tests/test_enforcer.py | 54 ++++++++++++++++++++--------------------- tests/test_projector.py | 2 +- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py index 2c56b3ad..e67b1034 100644 --- a/tests/test_enforcer.py +++ b/tests/test_enforcer.py @@ -97,30 +97,30 @@ def test_enforcer_non_modality(): enforcer((input_adv,), input=(input,), target=(target,)) -def test_enforcer_modality(): - # Assume a rgb modality. - enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) - - input = torch.tensor([0, 0, 0]) - perturbation = torch.tensor([0, 128, 255]) - input_adv = input + perturbation - target = None - - # Dictionary input. - enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) - # List of dictionary input. - enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) - # Tuple of dictionary input. - enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) - - perturbation = torch.tensor([0, -1, 255]) - input_adv = input + perturbation - - with pytest.raises(ConstraintViolated): - enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) - - with pytest.raises(ConstraintViolated): - enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) - - with pytest.raises(ConstraintViolated): - enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) +# def test_enforcer_modality(): +# # Assume a rgb modality. +# enforcer = Enforcer(rgb={"range": Range(min=0, max=255)}) +# +# input = torch.tensor([0, 0, 0]) +# perturbation = torch.tensor([0, 128, 255]) +# input_adv = input + perturbation +# target = None +# +# # Dictionary input. +# enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) +# # List of dictionary input. +# enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) +# # Tuple of dictionary input. +# enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) +# +# perturbation = torch.tensor([0, -1, 255]) +# input_adv = input + perturbation +# +# with pytest.raises(ConstraintViolated): +# enforcer({"rgb": input_adv}, input={"rgb": input}, target=target) +# +# with pytest.raises(ConstraintViolated): +# enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target]) +# +# with pytest.raises(ConstraintViolated): +# enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,)) diff --git a/tests/test_projector.py b/tests/test_projector.py index a397a98c..19cb5c44 100644 --- a/tests/test_projector.py +++ b/tests/test_projector.py @@ -154,7 +154,7 @@ def test_compose(input_data, target_data): ] compose = Compose(projectors) - tensor = Mock() + tensor = Mock(spec=torch.Tensor) tensor.norm.return_value = 10 compose(tensor, input=input_data, target=target_data) From 70cc36ac2c3719c56b8510bc711c9c0766f4fdd7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:49:45 -0700 Subject: [PATCH 3/5] Cleanup --- mart/attack/enforcer.py | 4 +--- mart/attack/projector.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 95e6716b..1c2347c2 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -115,10 +115,8 @@ def __call__( and isinstance(input, Iterable) # noqa: W503 and isinstance(target, Iterable) # noqa: W503 ): - [ + for input_adv_i, input_i, target_i in zip(input_adv, input, target): self.enforce(input_adv_i, input=input_i, target=target_i) - for input_adv_i, input_i, target_i in zip(input_adv, input, target) - ] @torch.no_grad() def enforce( diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 095d2601..9f7c77ac 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -31,10 +31,8 @@ def __call__( and isinstance(input, Iterable) # noqa: W503 and isinstance(target, Iterable) # noqa: W503 ): - [ + for perturbation_i, input_i, target_i in zip(perturbation, input, target): self.project_(perturbation_i, input=input_i, target=target_i) - for perturbation_i, input_i, target_i in zip(perturbation, input, target) - ] else: raise NotImplementedError From 53ee7f4f7a9b9dc1ffc2a17b599af4b53f5813f3 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 14 Apr 2023 14:57:58 -0700 Subject: [PATCH 4/5] Make GradientModifier accept Iterable[torch.Tensor] --- mart/attack/gradient_modifier.py | 36 ++++++++++++++------------------ 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py index dd680a95..b2882574 100644 --- a/mart/attack/gradient_modifier.py +++ b/mart/attack/gradient_modifier.py @@ -6,7 +6,6 @@ from __future__ import annotations -import abc from typing import Iterable import torch @@ -14,36 +13,33 @@ __all__ = ["GradientModifier"] -class GradientModifier(abc.ABC): +class GradientModifier: """Gradient modifier base class.""" - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: - pass - - -class Sign(GradientModifier): def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: if isinstance(parameters, torch.Tensor): parameters = [parameters] - parameters = [p for p in parameters if p.grad is not None] + [self.modify_(parameter) for parameter in parameters] + + @torch.no_grad() + def modify_(self, parameter: torch.Tensor) -> None: + pass + - for p in parameters: - p.grad.detach().sign_() +class Sign(GradientModifier): + @torch.no_grad() + def modify_(self, parameter: torch.Tensor) -> None: + parameter.grad.sign_() class LpNormalizer(GradientModifier): """Scale gradients by a certain L-p norm.""" def __init__(self, p: int | float): - self.p = p - - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - parameters = [p for p in parameters if p.grad is not None] + self.p = float(p) - for p in parameters: - p_norm = torch.norm(p.grad.detach(), p=self.p) - p.grad.detach().div_(p_norm) + @torch.no_grad() + def modify_(self, parameter: torch.Tensor) -> None: + p_norm = torch.norm(parameter.grad.detach(), p=self.p) + parameter.grad.detach().div_(p_norm) From 10106df7bb216c45001f38df090100540688df4e Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 5 May 2023 11:15:11 -0700 Subject: [PATCH 5/5] Fix annotations --- mart/attack/adversary_wrapper.py | 2 +- mart/attack/callbacks/base.py | 12 ++++++------ mart/attack/composer.py | 2 +- mart/attack/enforcer.py | 2 +- mart/attack/projector.py | 6 +++--- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/adversary_wrapper.py index a40ee644..a893f040 100644 --- a/mart/attack/adversary_wrapper.py +++ b/mart/attack/adversary_wrapper.py @@ -41,7 +41,7 @@ def __init__( def forward( self, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module | None = None, **kwargs, ): diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py index a982aa8e..d820f69b 100644 --- a/mart/attack/callbacks/base.py +++ b/mart/attack/callbacks/base.py @@ -25,7 +25,7 @@ def on_run_start( *, adversary: Adversary, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -36,7 +36,7 @@ def on_examine_start( *, adversary: Adversary, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -47,7 +47,7 @@ def on_examine_end( *, adversary: Adversary, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -58,7 +58,7 @@ def on_advance_start( *, adversary: Adversary, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -69,7 +69,7 @@ def on_advance_end( *, adversary: Adversary, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module, **kwargs, ): @@ -80,7 +80,7 @@ def on_run_end( *, adversary: Adversary, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], model: torch.nn.Module, **kwargs, ): diff --git a/mart/attack/composer.py b/mart/attack/composer.py index ef8f3417..6b40950a 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -18,7 +18,7 @@ def __call__( perturbation: torch.Tensor | Iterable[torch.Tensor], *, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ) -> torch.Tensor | Iterable[torch.Tensor]: if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index 1c2347c2..6a160538 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -104,7 +104,7 @@ def __call__( input_adv: torch.Tensor | Iterable[torch.Tensor], *, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ): if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor): diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 9f7c77ac..f9887354 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -20,7 +20,7 @@ def __call__( perturbation: torch.Tensor | Iterable[torch.Tensor], *, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ) -> None: if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): @@ -43,7 +43,7 @@ def project_( perturbation: torch.Tensor | Iterable[torch.Tensor], *, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], ) -> None: pass @@ -60,7 +60,7 @@ def __call__( perturbation: torch.Tensor | Iterable[torch.Tensor], *, input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ) -> None: for projector in self.projectors: