diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 6b40950a..524ef9d5 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -10,10 +10,14 @@ from typing import Any, Iterable import torch +import torchvision +import torchvision.transforms as T +import torchvision.transforms.functional as F +from torchvision.transforms.functional import InterpolationMode -class Composer(abc.ABC): - def __call__( +class Composer(torch.nn.Module): + def forward( self, perturbation: torch.Tensor | Iterable[torch.Tensor], *, @@ -24,6 +28,17 @@ def __call__( if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): return self.compose(perturbation, input=input, target=target) + elif ( + isinstance(perturbation, torch.Tensor) + and isinstance(input, Iterable) # noqa: W503 + and isinstance(target, Iterable) # noqa: W503 + ): + # FIXME: replace tuple with whatever input's type is + return tuple( + self.compose(perturbation, input=input_i, target=target_i) + for input_i, target_i in zip(input, target) + ) + elif ( isinstance(perturbation, Iterable) and isinstance(input, Iterable) # noqa: W503 @@ -56,8 +71,13 @@ def compose(self, perturbation, *, input, target): return input + perturbation -class Overlay(Composer): - """We assume an adversary overlays a patch to the input.""" +class Composite(Composer): + """We assume an adversary underlays a patch to the input.""" + + def __init__(self, premultiplied_alpha=False): + super().__init__() + + self.premultiplied_alpha = premultiplied_alpha def compose(self, perturbation, *, input, target): # True is mutable, False is immutable. @@ -67,14 +87,7 @@ def compose(self, perturbation, *, input, target): # because some data modules (e.g. Armory) gives binary mask. mask = mask.to(input) - return input * (1 - mask) + perturbation * mask - - -class MaskAdditive(Composer): - """We assume an adversary adds masked perturbation to the input.""" - - def compose(self, perturbation, *, input, target): - mask = target["perturbable_mask"] - masked_perturbation = perturbation * mask + if not self.premultiplied_alpha: + perturbation = perturbation * mask - return input + masked_perturbation + return input * (1 - mask) + perturbation diff --git a/mart/configs/attack/composer/composite.yaml b/mart/configs/attack/composer/composite.yaml new file mode 100644 index 00000000..c75347ff --- /dev/null +++ b/mart/configs/attack/composer/composite.yaml @@ -0,0 +1 @@ +_target_: mart.attack.composer.Composite diff --git a/mart/configs/attack/composer/mask_additive.yaml b/mart/configs/attack/composer/mask_additive.yaml deleted file mode 100644 index 4bca36f8..00000000 --- a/mart/configs/attack/composer/mask_additive.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: mart.attack.composer.MaskAdditive diff --git a/mart/configs/attack/composer/overlay.yaml b/mart/configs/attack/composer/overlay.yaml deleted file mode 100644 index 469f7245..00000000 --- a/mart/configs/attack/composer/overlay.yaml +++ /dev/null @@ -1 +0,0 @@ -_target_: mart.attack.composer.Overlay diff --git a/mart/configs/attack/object_detection_mask_adversary.yaml b/mart/configs/attack/object_detection_mask_adversary.yaml index ad99dda0..580493a7 100644 --- a/mart/configs/attack/object_detection_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_mask_adversary.yaml @@ -3,7 +3,7 @@ defaults: - perturber: default - perturber/initializer: constant - perturber/projector: mask_range - - composer: overlay + - composer: composite - /optimizer@optimizer: sgd - gain: rcnn_training_loss - gradient_modifier: sign diff --git a/tests/test_composer.py b/tests/test_composer.py index 9fc15cf8..dbe26d74 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -6,7 +6,7 @@ import torch -from mart.attack.composer import Additive, MaskAdditive, Overlay +from mart.attack.composer import Additive, Composite def test_additive_composer_forward(input_data, target_data, perturbation): @@ -17,22 +17,11 @@ def test_additive_composer_forward(input_data, target_data, perturbation): torch.testing.assert_close(output, expected_output, equal_nan=True) -def test_overlay_composer_forward(input_data, target_data, perturbation): - composer = Overlay() +def test_composite_composer_forward(input_data, target_data, perturbation): + composer = Composite() output = composer(perturbation, input=input_data, target=target_data) mask = target_data["perturbable_mask"] mask = mask.to(input_data) expected_output = input_data * (1 - mask) + perturbation torch.testing.assert_close(output, expected_output, equal_nan=True) - - -def test_mask_additive_composer_forward(): - input = torch.zeros((2, 2)) - perturbation = torch.ones((2, 2)) - target = {"perturbable_mask": torch.eye(2)} - expected_output = torch.eye(2) - - composer = MaskAdditive() - output = composer(perturbation, input=input, target=target) - torch.testing.assert_close(output, expected_output, equal_nan=True)