diff --git a/mart/attack/composer.py b/mart/attack/composer.py index cd6300f7..7747f958 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,6 +11,12 @@ from typing import TYPE_CHECKING, Any, Iterable import torch +import torchvision +import torchvision.transforms.functional as F + +from mart.utils import pylogger + +logger = pylogger.get_pylogger(__name__) if TYPE_CHECKING: from .perturber import Perturber @@ -104,33 +110,170 @@ def forward(self, perturbation, input, target): return perturbation, input, target -class Mask(Function): +# TODO: We may decompose Overlay into: perturbation-mask, input-re-mask, additive. +class Overlay(Function): + """We assume an adversary overlays a patch to the input.""" + def __init__(self, *args, key="perturbable_mask", **kwargs): super().__init__(*args, **kwargs) self.key = key def forward(self, perturbation, input, target): + # True is mutable, False is immutable. mask = target[self.key] + + # Convert mask to a Tensor with same torch.dtype and torch.device as input, + # because some data modules (e.g. Armory) gives binary mask. + mask = mask.to(input) + perturbation = perturbation * mask + + input = input * (1 - mask) + perturbation return perturbation, input, target -class Overlay(Function): - """We assume an adversary overlays a patch to the input.""" +class InputFakeClamp(Function): + """A Clamp operation that preserves gradients.""" + + def __init__(self, *args, min_val, max_val, **kwargs): + super().__init__(*args, **kwargs) + self.min_val = min_val + self.max_val = max_val + + @staticmethod + def fake_clamp(x, *, min_val, max_val): + with torch.no_grad(): + x_clamped = x.clamp(min_val, max_val) + diff = x_clamped - x + return x + diff + + def forward(self, perturbation, input, target): + input = self.fake_clamp(input, min_val=self.min_val, max_val=self.max_val) + return perturbation, input, target + +class PerturbationMask(Function): def __init__(self, *args, key="perturbable_mask", **kwargs): super().__init__(*args, **kwargs) self.key = key def forward(self, perturbation, input, target): - # True is mutable, False is immutable. mask = target[self.key] + perturbation = perturbation * mask + return perturbation, input, target - # Convert mask to a Tensor with same torch.dtype and torch.device as input, - # because some data modules (e.g. Armory) gives binary mask. - mask = mask.to(input) - perturbation = perturbation * mask +class PerturbationRectangleCrop(Function): + def __init__(self, *args, coords_key="patch_coords", **kwargs): + super().__init__(*args, **kwargs) + self.coords_key = coords_key + + def get_smallest_rectangle_shape(self, input, patch_coords): + """Get a smallest rectangle that covers the whole patch.""" + coords = patch_coords + leading_dims = list(input.shape[:-2]) + width = coords[:, 0].max() - coords[:, 0].min() + height = coords[:, 1].max() - coords[:, 1].min() + shape = list(leading_dims) + [height, width] + return shape + + def slice_rectangle(self, perturbation, height_patch, width_patch): + """Slice a rectangle from top-left of the perturbation.""" + height_patch_index = torch.tensor(range(height_patch), device=perturbation.device) + width_patch_index = torch.tensor(range(width_patch), device=perturbation.device) + perturbation_patch = perturbation.index_select(-2, height_patch_index).index_select( + -1, width_patch_index + ) + return perturbation_patch - input = input * (1 - mask) + perturbation + def forward(self, perturbation, input, target): + coords = target[self.coords_key] + # TODO: Make composers stackable to reuse some Composer. + # The perturbation variable has the same shape as input. + # We slice a small rectangle from top-left of the perturbation variable to compose the patch. + rectangle_shape = self.get_smallest_rectangle_shape(input, coords) + # Assume perturbation is in shape of [N]CHW + height_patch, width_patch = rectangle_shape[-2:] + rectangle_patch = self.slice_rectangle(perturbation, height_patch, width_patch) + return rectangle_patch, input, target + + +class PerturbationRectanglePad(Function): + def __init__(self, *args, coords_key="patch_coords", rect_coords_key="rect_coords", **kwargs): + super().__init__(*args, **kwargs) + self.coords_key = coords_key + self.rect_coords_key = rect_coords_key + + def forward(self, perturbation_patch, input, target): + coords = target[self.coords_key] + height, width = input.shape[-2:] + # Pad rectangle to the same size of input, so that it is almost aligned with the patch. + height_patch, width_patch = perturbation_patch.shape[-2:] + pad_left = min(coords[0, 0], coords[3, 0]) + pad_top = min(coords[0, 1], coords[1, 1]) + pad_right = width - width_patch - pad_left + pad_bottom = height - height_patch - pad_top + + perturbation_padded = F.pad( + img=perturbation_patch, + padding=[pad_left, pad_top, pad_right, pad_bottom], + fill=0, + padding_mode="constant", + ) + + # Save coords of four corners of the rectangle for later transform. + top_left = [pad_left, pad_top] + top_right = [width - pad_right, pad_top] + bottom_right = [width - pad_right, height - pad_bottom] + bottom_left = [pad_left, height - pad_bottom] + target[self.rect_coords_key] = [top_left, top_right, bottom_right, bottom_left] + + return perturbation_padded, input, target + + +class PerturbationRectanglePerspectiveTransform(Function): + def __init__(self, *args, coords_key="patch_coords", rect_coords_key="rect_coords", **kwargs): + super().__init__(*args, **kwargs) + self.coords_key = coords_key + self.rect_coords_key = rect_coords_key + + def forward(self, perturbation_rect, input, target): + coords = target[self.coords_key] + # Perspective transformation: rectangle -> coords. + # Fetch four corners of the rectangle. + startpoints = target[self.rect_coords_key] + endpoints = coords + # TODO: Make interpolation configurable. + perturbation_coords = F.perspective( + img=perturbation_rect, + startpoints=startpoints, + endpoints=endpoints, + interpolation=F.InterpolationMode.BILINEAR, + fill=0, + ) + return perturbation_coords, input, target + + +class PerturbationImageAdditive(Function): + """Add an image to perturbation if specified.""" + + def __init__(self, *args, path: str | None = None, scale: int = 1, **kwargs): + super().__init__(*args, **kwargs) + + self.image = None + if path is not None: + # This is uint8 [0,255]. + self.image = torchvision.io.read_image(path, torchvision.io.ImageReadMode.RGB) + # We shouldn't need scale as we use canonical input format. + self.image = self.image / scale + + def forward(self, perturbation, input, target): + if self.image is not None: + image = self.image + + if image.shape != perturbation.shape: + logger.info(f"Resizing image from {image.shape} to {perturbation.shape}...") + image = F.resize(image, perturbation.shape[1:]) + + perturbation = perturbation + image return perturbation, input, target diff --git a/mart/configs/attack/composer/functions/input_fake_clamp.yaml b/mart/configs/attack/composer/functions/input_fake_clamp.yaml new file mode 100644 index 00000000..764f08d4 --- /dev/null +++ b/mart/configs/attack/composer/functions/input_fake_clamp.yaml @@ -0,0 +1,5 @@ +input_fake_clamp: + _target_: mart.attack.composer.InputFakeClamp + order: 0 + min_val: 0 + max_val: 255 diff --git a/mart/configs/attack/composer/functions/mask.yaml b/mart/configs/attack/composer/functions/mask.yaml deleted file mode 100644 index 04cefaf8..00000000 --- a/mart/configs/attack/composer/functions/mask.yaml +++ /dev/null @@ -1,4 +0,0 @@ -mask: - _target_: mart.attack.composer.Mask - key: perturbable_mask - order: 0 diff --git a/mart/configs/attack/composer/functions/pert_image_additive.yaml b/mart/configs/attack/composer/functions/pert_image_additive.yaml new file mode 100644 index 00000000..c6e5d9a0 --- /dev/null +++ b/mart/configs/attack/composer/functions/pert_image_additive.yaml @@ -0,0 +1,4 @@ +pert_image_additive: + _target_: mart.attack.composer.PerturbationImageAdditive + path: null + order: 0 diff --git a/mart/configs/attack/composer/functions/pert_mask.yaml b/mart/configs/attack/composer/functions/pert_mask.yaml new file mode 100644 index 00000000..c3adb784 --- /dev/null +++ b/mart/configs/attack/composer/functions/pert_mask.yaml @@ -0,0 +1,4 @@ +pert_mask: + _target_: mart.attack.composer.PerturbationMask + key: perturbable_mask + order: 0 diff --git a/mart/configs/attack/composer/functions/pert_rect_crop.yaml b/mart/configs/attack/composer/functions/pert_rect_crop.yaml new file mode 100644 index 00000000..73d4f4b1 --- /dev/null +++ b/mart/configs/attack/composer/functions/pert_rect_crop.yaml @@ -0,0 +1,4 @@ +pert_rect_crop: + _target_: mart.attack.composer.PerturbationRectangleCrop + coords_key: patch_coords + order: 0 diff --git a/mart/configs/attack/composer/functions/pert_rect_pad.yaml b/mart/configs/attack/composer/functions/pert_rect_pad.yaml new file mode 100644 index 00000000..ce45ec19 --- /dev/null +++ b/mart/configs/attack/composer/functions/pert_rect_pad.yaml @@ -0,0 +1,5 @@ +pert_rect_pad: + _target_: mart.attack.composer.PerturbationRectanglePad + coords_key: patch_coords + rect_coords_key: rect_coords + order: 0 diff --git a/mart/configs/attack/composer/functions/pert_rect_perspective_transform.yaml b/mart/configs/attack/composer/functions/pert_rect_perspective_transform.yaml new file mode 100644 index 00000000..0be6f331 --- /dev/null +++ b/mart/configs/attack/composer/functions/pert_rect_perspective_transform.yaml @@ -0,0 +1,5 @@ +pert_rect_perspective_transform: + _target_: mart.attack.composer.PerturbationRectanglePerspectiveTransform + order: 0 + coords_key: patch_coords + rect_coords_key: rect_coords diff --git a/mart/configs/attack/composer/rect_patch_additive.yaml b/mart/configs/attack/composer/rect_patch_additive.yaml new file mode 100644 index 00000000..ba5a2218 --- /dev/null +++ b/mart/configs/attack/composer/rect_patch_additive.yaml @@ -0,0 +1,25 @@ +defaults: + - default + - functions: + [ + pert_rect_crop, + pert_rect_pad, + pert_rect_perspective_transform, + pert_mask, + additive, + input_fake_clamp, + ] + +functions: + pert_rect_crop: + order: 0 + pert_rect_pad: + order: 1 + pert_rect_perspective_transform: + order: 2 + pert_mask: + order: 3 + additive: + order: 4 + input_fake_clamp: + order: 5 diff --git a/mart/configs/attack/composer/rect_patch_overlay.yaml b/mart/configs/attack/composer/rect_patch_overlay.yaml new file mode 100644 index 00000000..c773d747 --- /dev/null +++ b/mart/configs/attack/composer/rect_patch_overlay.yaml @@ -0,0 +1,25 @@ +defaults: + - default + - functions: + [ + pert_rect_crop, + pert_image_additive, + pert_rect_pad, + pert_rect_perspective_transform, + overlay, + input_fake_clamp, + ] + +functions: + pert_rect_crop: + order: 0 + pert_image_additive: + order: 1 + pert_rect_pad: + order: 2 + pert_rect_perspective_transform: + order: 3 + overlay: + order: 4 + input_fake_clamp: + order: 5 diff --git a/tests/test_composer.py b/tests/test_composer.py index 3da44a27..3aba543c 100644 --- a/tests/test_composer.py +++ b/tests/test_composer.py @@ -8,7 +8,15 @@ import torch -from mart.attack.composer import Additive, Composer, Mask, Overlay +from mart.attack.composer import ( + Additive, + Composer, + Overlay, + PerturbationMask, + PerturbationRectangleCrop, + PerturbationRectanglePad, + PerturbationRectanglePerspectiveTransform, +) def test_additive_composer_forward(input_data, target_data, perturbation): @@ -33,15 +41,87 @@ def test_overlay_composer_forward(input_data, target_data, perturbation): torch.testing.assert_close(output, expected_output, equal_nan=True) -def test_mask_additive_composer_forward(): +def test_pert_mask_additive_composer_forward(): input = torch.zeros((2, 2)) perturbation = torch.ones((2, 2)) target = {"perturbable_mask": torch.eye(2)} expected_output = torch.eye(2) perturber = Mock(return_value=perturbation) - functions = {"mask": Mask(order=0), "additive": Additive(order=1)} + functions = {"pert_mask": PerturbationMask(order=0), "additive": Additive(order=1)} composer = Composer(perturber=perturber, functions=functions) output = composer(input=input, target=target) torch.testing.assert_close(output, expected_output, equal_nan=True) + + +def test_pert_rect_crop(): + key = "patch_coords" + input = torch.zeros((3, 10, 10)) + perturbation = torch.ones_like(input) + fn = PerturbationRectangleCrop(coords_key=key) + + # FIXME: four corner points (width, height) of a patch in the order of top-left, top-right, bottom-right, bottom-left. + # A simple square patch. + patch_coords = torch.tensor(((0, 0), (5, 0), (5, 5), (5, 0))) + target = {key: patch_coords} + + rect_patch, _input, _target = fn(perturbation, input, target) + assert torch.equal(input, _input) + assert target == _target + assert rect_patch.shape == (3, 5, 5) + + # A skew patch. + patch_coords = torch.tensor(((1, 1), (5, 2), (7, 8), (3, 9))) + target = {key: patch_coords} + + rect_patch, _input, _target = fn(perturbation, input, target) + assert torch.equal(input, _input) + assert target == _target + assert rect_patch.shape == (3, 8, 6) + + +def test_pert_rect_pad(): + coords_key = "patch_coords" + rect_coords_key = "rect_coords" + + rect_patch = torch.ones(3, 5, 5) + patch_coords = torch.tensor(((0, 0), (5, 0), (5, 5), (5, 0))) + + input = torch.zeros((3, 10, 10)) + target = {coords_key: patch_coords} + + fn = PerturbationRectanglePad(coords_key=coords_key, rect_coords_key=rect_coords_key) + pert_padded, _input, _target = fn(rect_patch, input, target) + + pert_padded_expected = torch.zeros_like(input) + pert_padded_expected[:, :5, :5] = 1 + + assert torch.equal(pert_padded_expected, pert_padded) + + rect_coords_expected = [[0, 0], [5, 0], [5, 5], [0, 5]] + assert _target[rect_coords_key] == rect_coords_expected + + +def test_pert_rect_perspective_transform(): + coords_key = "patch_coords" + rect_coords_key = "rect_coords" + + rect_coords = [[0, 0], [5, 0], [5, 5], [0, 5]] + # Move from top left to bottom right. + patch_coords = torch.tensor(((5, 5), (10, 5), (10, 10), (5, 10))) + target = {coords_key: patch_coords, rect_coords_key: rect_coords} + + input = torch.zeros((3, 10, 10)) + + pert_padded = torch.zeros_like(input) + pert_padded[:, :5, :5] = 1 + + fn = PerturbationRectanglePerspectiveTransform( + coords_key=coords_key, rect_coords_key=rect_coords_key + ) + pert_coords, _input, _target = fn(pert_padded, input, target) + pert_coords_expected = torch.zeros_like(input) + pert_coords_expected[:, 5:, 5:] = 1 + # rounding numeric error from the perspective transformation. + assert torch.equal(pert_coords.round(), pert_coords_expected)