diff --git a/mart/attack/adversary_in_art.py b/mart/attack/adversary_in_art.py
index 2a993349..d48f669c 100644
--- a/mart/attack/adversary_in_art.py
+++ b/mart/attack/adversary_in_art.py
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
 
 import hydra
 import numpy
@@ -82,17 +82,18 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
             x (np.ndarray): NHWC, [0, 1]
 
         Returns:
-            tuple: a tuple of tensors in CHW, [0, 255].
+            Iterable[torch.Tensor]: an Iterable of tensors in CHW, [0, 255].
         """
         input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
+        # FIXME: replace tuple with whatever input's type is
         input = tuple(inp_ for inp_ in input)
         return input
 
-    def convert_input_mart_to_art(self, input: tuple):
+    def convert_input_mart_to_art(self, input: Iterable[torch.Tensor]):
         """Convert MART input to the ART's format.
 
         Args:
-            input (tuple): a tuple of tensors in CHW, [0, 255].
+            input (Iterable[torch.Tensor]): an Iterable of tensors in CHW, [0, 255].
 
         Returns:
             np.ndarray: NHWC, [0, 1]
@@ -112,7 +113,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             y_patch_metadata (_type_): _description_
 
         Returns:
-            tuple: a tuple of target dictionaies.
+            Iterable[dict[str, Any]]: an Iterable of target dictionaries.
         """
         # Copy y to target, and convert ndarray to pytorch tensors accordingly.
         target = []
@@ -132,6 +133,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
             target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
             target.append(target_i)
 
+        # FIXME: replace tuple with input type?
         target = tuple(target)
 
         return target
diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/adversary_wrapper.py
index c4b02953..a893f040 100644
--- a/mart/attack/adversary_wrapper.py
+++ b/mart/attack/adversary_wrapper.py
@@ -6,10 +6,13 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Iterable
 
 import torch
 
+if TYPE_CHECKING:
+    from .enforcer import Enforcer
+
 __all__ = ["NormalizedAdversaryAdapter"]
 
@@ -22,7 +25,7 @@ class NormalizedAdversaryAdapter(torch.nn.Module):
     def __init__(
         self,
         adversary: Callable[[Callable], Callable],
-        enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None],
+        enforcer: Enforcer,
     ):
         """
 
@@ -37,8 +40,8 @@ def __init__(
 
     def forward(
         self,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
        model: torch.nn.Module | None = None,
         **kwargs,
     ):
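For orientation, a minimal sketch of the NHWC/CHW round trip these converter signatures describe. The body of `convert_input_mart_to_art` is not part of this hunk, so the stack/permute below is an assumption consistent with the docstrings, not the actual implementation:

```python
import numpy
import torch

# ART convention: NHWC float arrays in [0, 1].
x_art = numpy.random.rand(2, 32, 32, 3).astype(numpy.float32)

# MART convention, per convert_input_art_to_mart: an Iterable of CHW tensors in [0, 255].
x_mart = tuple(torch.from_numpy(x_art).permute((0, 3, 1, 2)) * 255)

# Assumed inverse, per the convert_input_mart_to_art docstring: back to NHWC in [0, 1].
x_art_again = (torch.stack(x_mart).permute((0, 2, 3, 1)) / 255).numpy()

assert numpy.allclose(x_art, x_art_again, atol=1e-6)
```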
diff --git a/mart/attack/callbacks/base.py b/mart/attack/callbacks/base.py
index 97541ecb..d820f69b 100644
--- a/mart/attack/callbacks/base.py
+++ b/mart/attack/callbacks/base.py
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Iterable
 
 import torch
 
@@ -24,8 +24,8 @@ def on_run_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -35,8 +35,8 @@ def on_examine_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -46,8 +46,8 @@ def on_examine_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -57,8 +57,8 @@ def on_advance_start(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -68,8 +68,8 @@ def on_advance_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
@@ -79,8 +79,8 @@ def on_run_end(
         self,
         *,
         adversary: Adversary,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         model: torch.nn.Module,
         **kwargs,
     ):
diff --git a/mart/attack/composer.py b/mart/attack/composer.py
index 5bc4edb7..6b40950a 100644
--- a/mart/attack/composer.py
+++ b/mart/attack/composer.py
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 import abc
-from typing import Any
+from typing import Any, Iterable
 
 import torch
 
@@ -15,21 +15,28 @@ class Composer(abc.ABC):
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         **kwargs,
-    ) -> torch.Tensor | tuple:
-        if isinstance(perturbation, tuple):
-            input_adv = tuple(
+    ) -> torch.Tensor | Iterable[torch.Tensor]:
+        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
+            return self.compose(perturbation, input=input, target=target)
+
+        elif (
+            isinstance(perturbation, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            # FIXME: replace tuple with whatever input's type is
+            return tuple(
                 self.compose(perturbation_i, input=input_i, target=target_i)
                 for perturbation_i, input_i, target_i in zip(perturbation, input, target)
             )
-        else:
-            input_adv = self.compose(perturbation, input=input, target=target)
-
-        return input_adv
+        else:
+            raise NotImplementedError
 
     @abc.abstractmethod
     def compose(
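The new `__call__` checks `torch.Tensor` before `Iterable` because tensors are themselves iterable; only genuine collections reach the zipped branch. A small sketch of both dispatch paths, using a toy additive composer (the `Add` subclass here is illustrative, not part of the diff):

```python
import torch

from mart.attack.composer import Composer


class Add(Composer):
    """Toy composer that adds the perturbation to the input."""

    def compose(self, perturbation, *, input, target):
        return input + perturbation


composer = Add()

# Tensor path: returns a single composed tensor.
adv = composer(torch.ones(3), input=torch.zeros(3), target=None)

# Iterable path: the three collections are zipped and a tuple is returned.
advs = composer(
    (torch.ones(3), torch.ones(2)),
    input=(torch.zeros(3), torch.zeros(2)),
    target=(None, None),
)
```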
diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py
index babc44e6..6a160538 100644
--- a/mart/attack/enforcer.py
+++ b/mart/attack/enforcer.py
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 import abc
-from typing import Any
+from typing import Any, Iterable
 
 import torch
 
@@ -95,45 +95,36 @@ def verify(self, input_adv, *, input, target):
 
 
 class Enforcer:
-    def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
-        self.modality_constraints = modality_constraints
+    def __init__(self, constraints: dict[str, Constraint]) -> None:
+        self.constraints = list(constraints.values())  # intentionally ignore keys
 
     @torch.no_grad()
-    def _enforce(
+    def __call__(
         self,
-        input_adv: torch.Tensor,
+        input_adv: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor,
-        target: torch.Tensor | dict[str, Any],
-        modality: str,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+        **kwargs,
     ):
-        for constraint in self.modality_constraints[modality].values():
-            constraint(input_adv, input=input, target=target)
+        if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor):
+            self.enforce(input_adv, input=input, target=target)
+
+        elif (
+            isinstance(input_adv, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
+            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
+                self.enforce(input_adv_i, input=input_i, target=target_i)
 
-    def __call__(
+    @torch.no_grad()
+    def enforce(
         self,
-        input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
+        input_adv: torch.Tensor,
         *,
-        input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
+        input: torch.Tensor,
         target: torch.Tensor | dict[str, Any],
-        modality: str = "constraints",
-        **kwargs,
     ):
-        assert type(input_adv) == type(input)
-
-        if isinstance(input_adv, torch.Tensor):
-            # Finally we can verify constraints on tensor, per its modality.
-            # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
-            self._enforce(input_adv, input=input, target=target, modality=modality)
-        elif isinstance(input_adv, dict):
-            # The dict input has modalities specified in keys, passing them recursively.
-            for modality in input_adv:
-                self(input_adv[modality], input=input[modality], target=target, modality=modality)
-        elif isinstance(input_adv, (list, tuple)):
-            # We assume a modality-dictionary only contains tensors, but not list/tuple.
-            assert modality == "constraints"
-            # The list or tuple input is a collection of sub-input and sub-target.
-            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
-                self(input_adv_i, input=input_i, target=target_i, modality=modality)
-        else:
-            raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
+        for constraint in self.constraints:
+            constraint(input_adv, input=input, target=target)
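Note that the new `__call__` silently does nothing when neither branch matches (for example, a dict input), where the old code raised `ValueError`; an explicit `else: raise` as in `Composer` may be worth mirroring. A minimal sketch of the reworked API, assuming `Range` and `ConstraintViolated` are importable from this module as in the tests:

```python
import torch

from mart.attack.enforcer import ConstraintViolated, Enforcer, Range

# Keys of the constraints dict are ignored; only the constraint values apply.
enforcer = Enforcer(constraints={"range": Range(min=0, max=255)})

input = torch.tensor([0.0, 0.0, 0.0])
input_adv = input + torch.tensor([0.0, 128.0, 255.0])

enforcer(input_adv, input=input, target=None)           # tensor path: passes
enforcer((input_adv,), input=(input,), target=(None,))  # Iterable path: passes

try:
    enforcer(input + torch.tensor([0.0, -1.0, 255.0]), input=input, target=None)
except ConstraintViolated:
    pass  # -1 falls outside [0, 255]
```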
diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py
index dd680a95..b2882574 100644
--- a/mart/attack/gradient_modifier.py
+++ b/mart/attack/gradient_modifier.py
@@ -6,7 +6,6 @@
 
 from __future__ import annotations
 
-import abc
 from typing import Iterable
 
 import torch
@@ -14,36 +13,33 @@
 __all__ = ["GradientModifier"]
 
 
-class GradientModifier(abc.ABC):
+class GradientModifier:
     """Gradient modifier base class."""
 
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        pass
-
-
-class Sign(GradientModifier):
     def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]
 
-        parameters = [p for p in parameters if p.grad is not None]
+        [self.modify_(parameter) for parameter in parameters]
+
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        pass
+
 
-        for p in parameters:
-            p.grad.detach().sign_()
+class Sign(GradientModifier):
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        parameter.grad.sign_()
 
 
 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""
 
     def __init__(self, p: int | float):
-        self.p = p
-
-    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-
-        parameters = [p for p in parameters if p.grad is not None]
+        self.p = float(p)
 
-        for p in parameters:
-            p_norm = torch.norm(p.grad.detach(), p=self.p)
-            p.grad.detach().div_(p_norm)
+    @torch.no_grad()
+    def modify_(self, parameter: torch.Tensor) -> None:
+        p_norm = torch.norm(parameter.grad.detach(), p=self.p)
+        parameter.grad.detach().div_(p_norm)
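Both subclasses now assume `parameter.grad` is already populated; the old `__call__` filtered out parameters with `grad is None`, so callers must run a backward pass before invoking a modifier. A minimal usage sketch:

```python
import torch

from mart.attack.gradient_modifier import LpNormalizer, Sign

param = torch.zeros(3, requires_grad=True)
(param * torch.tensor([-2.0, 0.5, 3.0])).sum().backward()  # populates param.grad

Sign()(param)             # param.grad is now [-1., 1., 1.]
LpNormalizer(p=2)(param)  # divides param.grad in place by its L2 norm
```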
diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py
index cd05c6c6..d66bcf9f 100644
--- a/mart/attack/initializer.py
+++ b/mart/attack/initializer.py
@@ -4,52 +4,57 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
-import abc
-from typing import Optional, Union
+from __future__ import annotations
 
-import torch
+from typing import Iterable
 
-__all__ = ["Initializer"]
+import torch
 
 
-class Initializer(abc.ABC):
+class Initializer:
     """Initializer base class."""
 
     @torch.no_grad()
-    @abc.abstractmethod
-    def __call__(self, perturbation: torch.Tensor) -> None:
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        [self.initialize_(parameter) for parameter in parameters]
+
+    @torch.no_grad()
+    def initialize_(self, parameter: torch.Tensor) -> None:
         pass
 
 
 class Constant(Initializer):
-    def __init__(self, constant: Optional[Union[int, float]] = 0):
+    def __init__(self, constant: int | float = 0):
         self.constant = constant
 
     @torch.no_grad()
-    def __call__(self, perturbation: torch.Tensor) -> None:
-        torch.nn.init.constant_(perturbation, self.constant)
+    def initialize_(self, parameter: torch.Tensor) -> None:
+        torch.nn.init.constant_(parameter, self.constant)
 
 
 class Uniform(Initializer):
-    def __init__(self, min: Union[int, float], max: Union[int, float]):
+    def __init__(self, min: int | float, max: int | float):
         self.min = min
         self.max = max
 
     @torch.no_grad()
-    def __call__(self, perturbation: torch.Tensor) -> None:
-        torch.nn.init.uniform_(perturbation, self.min, self.max)
+    def initialize_(self, parameter: torch.Tensor) -> None:
+        torch.nn.init.uniform_(parameter, self.min, self.max)
 
 
 class UniformLp(Initializer):
-    def __init__(self, eps: Union[int, float], p: Optional[Union[int, float]] = torch.inf):
+    def __init__(self, eps: int | float, p: int | float = torch.inf):
         self.eps = eps
         self.p = p
 
     @torch.no_grad()
-    def __call__(self, perturbation: torch.Tensor) -> None:
-        torch.nn.init.uniform_(perturbation, -self.eps, self.eps)
+    def initialize_(self, parameter: torch.Tensor) -> None:
+        torch.nn.init.uniform_(parameter, -self.eps, self.eps)
         # TODO: make sure the first dim is the batch dim.
         if self.p is not torch.inf:
             # We don't do tensor.renorm_() because the first dim is not the batch dim.
-            pert_norm = perturbation.norm(p=self.p)
-            perturbation.mul_(self.eps / pert_norm)
+            pert_norm = parameter.norm(p=self.p)
+            parameter.mul_(self.eps / pert_norm)
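Usage is unchanged for callers going through `__call__`, which now also broadcasts over an Iterable of tensors; a short sketch:

```python
import torch

from mart.attack.initializer import Constant, Uniform, UniformLp

pert = torch.empty(3, 32, 32)

Constant(constant=0)(pert)    # zero-fill in place
Uniform(min=-8, max=8)(pert)  # uniform in [-8, 8]
UniformLp(eps=8, p=2)(pert)   # uniform, then rescaled so the L2 norm equals eps

# The base class also accepts an Iterable of tensors.
Constant()([torch.empty(3), torch.empty(5)])
```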
diff --git a/mart/attack/projector.py b/mart/attack/projector.py
index 92391c67..f9887354 100644
--- a/mart/attack/projector.py
+++ b/mart/attack/projector.py
@@ -6,7 +6,7 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Iterable
 
 import torch
 
@@ -17,24 +17,33 @@ class Projector:
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         **kwargs,
     ) -> None:
-        if isinstance(perturbation, tuple):
+        if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
+            self.project_(perturbation, input=input, target=target)
+
+        elif (
+            isinstance(perturbation, Iterable)
+            and isinstance(input, Iterable)  # noqa: W503
+            and isinstance(target, Iterable)  # noqa: W503
+        ):
             for perturbation_i, input_i, target_i in zip(perturbation, input, target):
-                self.project(perturbation_i, input=input_i, target=target_i)
+                self.project_(perturbation_i, input=input_i, target=target_i)
+
         else:
-            self.project(perturbation, input=input, target=target)
+            raise NotImplementedError
 
-    def project(
+    @torch.no_grad()
+    def project_(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
     ) -> None:
         pass
@@ -48,10 +57,10 @@ def __init__(self, projectors: list[Projector]):
     @torch.no_grad()
     def __call__(
         self,
-        perturbation: torch.Tensor | tuple,
+        perturbation: torch.Tensor | Iterable[torch.Tensor],
         *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
+        input: torch.Tensor | Iterable[torch.Tensor],
+        target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
         **kwargs,
     ) -> None:
         for projector in self.projectors:
@@ -70,7 +79,8 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa
         self.min = min
         self.max = max
 
-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         if self.quantize:
             perturbation.round_()
         perturbation.clamp_(self.min, self.max)
@@ -92,7 +102,8 @@ def __init__(self, quantize: bool = False, min: int | float = 0, max: int | floa
         self.min = min
         self.max = max
 
-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         if self.quantize:
             perturbation.round_()
         perturbation.clamp_(self.min - input, self.max - input)
@@ -117,7 +128,8 @@ def __init__(self, eps: int | float, p: int | float = torch.inf):
         self.p = p
         self.eps = eps
 
-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         pert_norm = perturbation.norm(p=self.p)
         if pert_norm > self.eps:
             # We only upper-bound the norm.
@@ -133,7 +145,8 @@ def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 25
         self.min = min
         self.max = max
 
-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         eps_min = (input - self.eps).clamp(self.min, self.max) - input
         eps_max = (input + self.eps).clamp(self.min, self.max) - input
 
@@ -141,7 +154,8 @@ def project(self, perturbation, *, input, target):
 
 
 class Mask(Projector):
-    def project(self, perturbation, *, input, target):
+    @torch.no_grad()
+    def project_(self, perturbation, *, input, target):
         perturbation.mul_(target["perturbable_mask"])
 
     def __repr__(self):
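The dispatch mirrors `Composer`: tensors take the scalar path, collections are zipped element-wise, and anything else raises `NotImplementedError`. A minimal sketch with the two projectors named in these hunks, `Compose` and `Mask`:

```python
import torch

from mart.attack.projector import Compose, Mask

perturbation = torch.ones(3)
target = {"perturbable_mask": torch.tensor([1.0, 0.0, 1.0])}

# Compose chains projectors; each one dispatches through __call__ to project_.
projector = Compose([Mask()])
projector(perturbation, input=torch.zeros(3), target=target)
# perturbation is now [1., 0., 1.]
```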
diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py
index 2c56b3ad..e67b1034 100644
--- a/tests/test_enforcer.py
+++ b/tests/test_enforcer.py
@@ -97,30 +97,30 @@ def test_enforcer_non_modality():
     enforcer((input_adv,), input=(input,), target=(target,))
 
 
-def test_enforcer_modality():
-    # Assume a rgb modality.
-    enforcer = Enforcer(rgb={"range": Range(min=0, max=255)})
-
-    input = torch.tensor([0, 0, 0])
-    perturbation = torch.tensor([0, 128, 255])
-    input_adv = input + perturbation
-    target = None
-
-    # Dictionary input.
-    enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
-    # List of dictionary input.
-    enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
-    # Tuple of dictionary input.
-    enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))
-
-    perturbation = torch.tensor([0, -1, 255])
-    input_adv = input + perturbation
-
-    with pytest.raises(ConstraintViolated):
-        enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
-
-    with pytest.raises(ConstraintViolated):
-        enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
-
-    with pytest.raises(ConstraintViolated):
-        enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))
+# def test_enforcer_modality():
+#     # Assume a rgb modality.
+#     enforcer = Enforcer(rgb={"range": Range(min=0, max=255)})
+#
+#     input = torch.tensor([0, 0, 0])
+#     perturbation = torch.tensor([0, 128, 255])
+#     input_adv = input + perturbation
+#     target = None
+#
+#     # Dictionary input.
+#     enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
+#     # List of dictionary input.
+#     enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
+#     # Tuple of dictionary input.
+#     enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))
+#
+#     perturbation = torch.tensor([0, -1, 255])
+#     input_adv = input + perturbation
+#
+#     with pytest.raises(ConstraintViolated):
+#         enforcer({"rgb": input_adv}, input={"rgb": input}, target=target)
+#
+#     with pytest.raises(ConstraintViolated):
+#         enforcer([{"rgb": input_adv}], input=[{"rgb": input}], target=[target])
+#
+#     with pytest.raises(ConstraintViolated):
+#         enforcer(({"rgb": input_adv},), input=({"rgb": input},), target=(target,))
diff --git a/tests/test_projector.py b/tests/test_projector.py
index a397a98c..19cb5c44 100644
--- a/tests/test_projector.py
+++ b/tests/test_projector.py
@@ -154,7 +154,7 @@ def test_compose(input_data, target_data):
     ]
 
     compose = Compose(projectors)
-    tensor = Mock()
+    tensor = Mock(spec=torch.Tensor)
     tensor.norm.return_value = 10
 
     compose(tensor, input=input_data, target=target_data)
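`Mock(spec=torch.Tensor)` makes the mock pass the new `isinstance(..., torch.Tensor)` check, so it takes the single-tensor path in `Projector.__call__`. If the commented-out modality test is meant to return in some form under the constraints-dict API, a per-tensor variant might look like the following (hypothetical test, not part of this diff):

```python
import pytest
import torch

from mart.attack.enforcer import ConstraintViolated, Enforcer, Range


def test_enforcer_range_constraint():
    enforcer = Enforcer(constraints={"range": Range(min=0, max=255)})

    input = torch.tensor([0, 0, 0])
    input_adv = input + torch.tensor([0, -1, 255])  # -1 violates the range

    with pytest.raises(ConstraintViolated):
        enforcer(input_adv, input=input, target=None)
```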