Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 152 additions & 9 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from typing import TYPE_CHECKING, Any, Iterable

import torch
import torchvision
import torchvision.transforms.functional as F

from mart.utils import pylogger

logger = pylogger.get_pylogger(__name__)

if TYPE_CHECKING:
from .perturber import Perturber
Expand Down Expand Up @@ -104,33 +110,170 @@ def forward(self, perturbation, input, target):
return perturbation, input, target


class Mask(Function):
# TODO: We may decompose Overlay into: perturbation-mask, input-re-mask, additive.
class Overlay(Function):
"""We assume an adversary overlays a patch to the input."""

def __init__(self, *args, key="perturbable_mask", **kwargs):
super().__init__(*args, **kwargs)
self.key = key

def forward(self, perturbation, input, target):
# True is mutable, False is immutable.
mask = target[self.key]

# Convert mask to a Tensor with same torch.dtype and torch.device as input,
# because some data modules (e.g. Armory) gives binary mask.
mask = mask.to(input)

perturbation = perturbation * mask

input = input * (1 - mask) + perturbation
return perturbation, input, target


class Overlay(Function):
"""We assume an adversary overlays a patch to the input."""
class InputFakeClamp(Function):
"""A Clamp operation that preserves gradients."""

def __init__(self, *args, min_val, max_val, **kwargs):
super().__init__(*args, **kwargs)
self.min_val = min_val
self.max_val = max_val

@staticmethod
def fake_clamp(x, *, min_val, max_val):
with torch.no_grad():
x_clamped = x.clamp(min_val, max_val)
diff = x_clamped - x
return x + diff

def forward(self, perturbation, input, target):
input = self.fake_clamp(input, min_val=self.min_val, max_val=self.max_val)
return perturbation, input, target


class PerturbationMask(Function):
def __init__(self, *args, key="perturbable_mask", **kwargs):
super().__init__(*args, **kwargs)
self.key = key

def forward(self, perturbation, input, target):
# True is mutable, False is immutable.
mask = target[self.key]
perturbation = perturbation * mask
return perturbation, input, target

# Convert mask to a Tensor with same torch.dtype and torch.device as input,
# because some data modules (e.g. Armory) gives binary mask.
mask = mask.to(input)

perturbation = perturbation * mask
class PerturbationRectangleCrop(Function):
def __init__(self, *args, coords_key="patch_coords", **kwargs):
super().__init__(*args, **kwargs)
self.coords_key = coords_key

def get_smallest_rectangle_shape(self, input, patch_coords):
"""Get a smallest rectangle that covers the whole patch."""
coords = patch_coords
leading_dims = list(input.shape[:-2])
width = coords[:, 0].max() - coords[:, 0].min()
height = coords[:, 1].max() - coords[:, 1].min()
shape = list(leading_dims) + [height, width]
return shape

def slice_rectangle(self, perturbation, height_patch, width_patch):
"""Slice a rectangle from top-left of the perturbation."""
height_patch_index = torch.tensor(range(height_patch), device=perturbation.device)
width_patch_index = torch.tensor(range(width_patch), device=perturbation.device)
perturbation_patch = perturbation.index_select(-2, height_patch_index).index_select(
-1, width_patch_index
)
return perturbation_patch

input = input * (1 - mask) + perturbation
def forward(self, perturbation, input, target):
coords = target[self.coords_key]
# TODO: Make composers stackable to reuse some Composer.
# The perturbation variable has the same shape as input.
# We slice a small rectangle from top-left of the perturbation variable to compose the patch.
rectangle_shape = self.get_smallest_rectangle_shape(input, coords)
# Assume perturbation is in shape of [N]CHW
height_patch, width_patch = rectangle_shape[-2:]
rectangle_patch = self.slice_rectangle(perturbation, height_patch, width_patch)
return rectangle_patch, input, target


class PerturbationRectanglePad(Function):
def __init__(self, *args, coords_key="patch_coords", rect_coords_key="rect_coords", **kwargs):
super().__init__(*args, **kwargs)
self.coords_key = coords_key
self.rect_coords_key = rect_coords_key

def forward(self, perturbation_patch, input, target):
coords = target[self.coords_key]
height, width = input.shape[-2:]
# Pad rectangle to the same size of input, so that it is almost aligned with the patch.
height_patch, width_patch = perturbation_patch.shape[-2:]
pad_left = min(coords[0, 0], coords[3, 0])
pad_top = min(coords[0, 1], coords[1, 1])
pad_right = width - width_patch - pad_left
pad_bottom = height - height_patch - pad_top

perturbation_padded = F.pad(
img=perturbation_patch,
padding=[pad_left, pad_top, pad_right, pad_bottom],
fill=0,
padding_mode="constant",
)

# Save coords of four corners of the rectangle for later transform.
top_left = [pad_left, pad_top]
top_right = [width - pad_right, pad_top]
bottom_right = [width - pad_right, height - pad_bottom]
bottom_left = [pad_left, height - pad_bottom]
target[self.rect_coords_key] = [top_left, top_right, bottom_right, bottom_left]

return perturbation_padded, input, target


class PerturbationRectanglePerspectiveTransform(Function):
def __init__(self, *args, coords_key="patch_coords", rect_coords_key="rect_coords", **kwargs):
super().__init__(*args, **kwargs)
self.coords_key = coords_key
self.rect_coords_key = rect_coords_key

def forward(self, perturbation_rect, input, target):
coords = target[self.coords_key]
# Perspective transformation: rectangle -> coords.
# Fetch four corners of the rectangle.
startpoints = target[self.rect_coords_key]
endpoints = coords
# TODO: Make interpolation configurable.
perturbation_coords = F.perspective(
img=perturbation_rect,
startpoints=startpoints,
endpoints=endpoints,
interpolation=F.InterpolationMode.BILINEAR,
fill=0,
)
return perturbation_coords, input, target


class PerturbationImageAdditive(Function):
"""Add an image to perturbation if specified."""

def __init__(self, *args, path: str | None = None, scale: int = 1, **kwargs):
super().__init__(*args, **kwargs)

self.image = None
if path is not None:
# This is uint8 [0,255].
self.image = torchvision.io.read_image(path, torchvision.io.ImageReadMode.RGB)
# We shouldn't need scale as we use canonical input format.
self.image = self.image / scale

def forward(self, perturbation, input, target):
if self.image is not None:
image = self.image

if image.shape != perturbation.shape:
logger.info(f"Resizing image from {image.shape} to {perturbation.shape}...")
image = F.resize(image, perturbation.shape[1:])

perturbation = perturbation + image
return perturbation, input, target
5 changes: 5 additions & 0 deletions mart/configs/attack/composer/functions/input_fake_clamp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
input_fake_clamp:
_target_: mart.attack.composer.InputFakeClamp
order: 0
min_val: 0
max_val: 255
4 changes: 0 additions & 4 deletions mart/configs/attack/composer/functions/mask.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pert_image_additive:
_target_: mart.attack.composer.PerturbationImageAdditive
path: null
order: 0
4 changes: 4 additions & 0 deletions mart/configs/attack/composer/functions/pert_mask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pert_mask:
_target_: mart.attack.composer.PerturbationMask
key: perturbable_mask
order: 0
4 changes: 4 additions & 0 deletions mart/configs/attack/composer/functions/pert_rect_crop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pert_rect_crop:
_target_: mart.attack.composer.PerturbationRectangleCrop
coords_key: patch_coords
order: 0
5 changes: 5 additions & 0 deletions mart/configs/attack/composer/functions/pert_rect_pad.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pert_rect_pad:
_target_: mart.attack.composer.PerturbationRectanglePad
coords_key: patch_coords
rect_coords_key: rect_coords
order: 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pert_rect_perspective_transform:
_target_: mart.attack.composer.PerturbationRectanglePerspectiveTransform
order: 0
coords_key: patch_coords
rect_coords_key: rect_coords
25 changes: 25 additions & 0 deletions mart/configs/attack/composer/rect_patch_additive.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
defaults:
- default
- functions:
[
pert_rect_crop,
pert_rect_pad,
pert_rect_perspective_transform,
pert_mask,
additive,
input_fake_clamp,
]

functions:
pert_rect_crop:
order: 0
pert_rect_pad:
order: 1
pert_rect_perspective_transform:
order: 2
pert_mask:
order: 3
additive:
order: 4
input_fake_clamp:
order: 5
25 changes: 25 additions & 0 deletions mart/configs/attack/composer/rect_patch_overlay.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
defaults:
- default
- functions:
[
pert_rect_crop,
pert_image_additive,
pert_rect_pad,
pert_rect_perspective_transform,
overlay,
input_fake_clamp,
]

functions:
pert_rect_crop:
order: 0
pert_image_additive:
order: 1
pert_rect_pad:
order: 2
pert_rect_perspective_transform:
order: 3
overlay:
order: 4
input_fake_clamp:
order: 5
86 changes: 83 additions & 3 deletions tests/test_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@

import torch

from mart.attack.composer import Additive, Composer, Mask, Overlay
from mart.attack.composer import (
Additive,
Composer,
Overlay,
PerturbationMask,
PerturbationRectangleCrop,
PerturbationRectanglePad,
PerturbationRectanglePerspectiveTransform,
)


def test_additive_composer_forward(input_data, target_data, perturbation):
Expand All @@ -33,15 +41,87 @@ def test_overlay_composer_forward(input_data, target_data, perturbation):
torch.testing.assert_close(output, expected_output, equal_nan=True)


def test_mask_additive_composer_forward():
def test_pert_mask_additive_composer_forward():
input = torch.zeros((2, 2))
perturbation = torch.ones((2, 2))
target = {"perturbable_mask": torch.eye(2)}
expected_output = torch.eye(2)

perturber = Mock(return_value=perturbation)
functions = {"mask": Mask(order=0), "additive": Additive(order=1)}
functions = {"pert_mask": PerturbationMask(order=0), "additive": Additive(order=1)}
composer = Composer(perturber=perturber, functions=functions)

output = composer(input=input, target=target)
torch.testing.assert_close(output, expected_output, equal_nan=True)


def test_pert_rect_crop():
key = "patch_coords"
input = torch.zeros((3, 10, 10))
perturbation = torch.ones_like(input)
fn = PerturbationRectangleCrop(coords_key=key)

# FIXME: four corner points (width, height) of a patch in the order of top-left, top-right, bottom-right, bottom-left.
# A simple square patch.
patch_coords = torch.tensor(((0, 0), (5, 0), (5, 5), (5, 0)))
target = {key: patch_coords}

rect_patch, _input, _target = fn(perturbation, input, target)
assert torch.equal(input, _input)
assert target == _target
assert rect_patch.shape == (3, 5, 5)

# A skew patch.
patch_coords = torch.tensor(((1, 1), (5, 2), (7, 8), (3, 9)))
target = {key: patch_coords}

rect_patch, _input, _target = fn(perturbation, input, target)
assert torch.equal(input, _input)
assert target == _target
assert rect_patch.shape == (3, 8, 6)


def test_pert_rect_pad():
coords_key = "patch_coords"
rect_coords_key = "rect_coords"

rect_patch = torch.ones(3, 5, 5)
patch_coords = torch.tensor(((0, 0), (5, 0), (5, 5), (5, 0)))

input = torch.zeros((3, 10, 10))
target = {coords_key: patch_coords}

fn = PerturbationRectanglePad(coords_key=coords_key, rect_coords_key=rect_coords_key)
pert_padded, _input, _target = fn(rect_patch, input, target)

pert_padded_expected = torch.zeros_like(input)
pert_padded_expected[:, :5, :5] = 1

assert torch.equal(pert_padded_expected, pert_padded)

rect_coords_expected = [[0, 0], [5, 0], [5, 5], [0, 5]]
assert _target[rect_coords_key] == rect_coords_expected


def test_pert_rect_perspective_transform():
coords_key = "patch_coords"
rect_coords_key = "rect_coords"

rect_coords = [[0, 0], [5, 0], [5, 5], [0, 5]]
# Move from top left to bottom right.
patch_coords = torch.tensor(((5, 5), (10, 5), (10, 10), (5, 10)))
target = {coords_key: patch_coords, rect_coords_key: rect_coords}

input = torch.zeros((3, 10, 10))

pert_padded = torch.zeros_like(input)
pert_padded[:, :5, :5] = 1

fn = PerturbationRectanglePerspectiveTransform(
coords_key=coords_key, rect_coords_key=rect_coords_key
)
pert_coords, _input, _target = fn(pert_padded, input, target)
pert_coords_expected = torch.zeros_like(input)
pert_coords_expected[:, 5:, 5:] = 1
# rounding numeric error from the perspective transformation.
assert torch.equal(pert_coords.round(), pert_coords_expected)