diff --git a/mart/attack/__init__.py b/mart/attack/__init__.py index f86d5ad8..19870873 100644 --- a/mart/attack/__init__.py +++ b/mart/attack/__init__.py @@ -1,12 +1,11 @@ from .adversary import * from .adversary_in_art import * from .adversary_wrapper import * -from .callbacks import Callback from .composer import * from .enforcer import * from .gain import * from .gradient_modifier import * from .initializer import * -from .objective import Objective +from .objective import * from .perturber import * from .projector import * diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 55c32310..d158f6af 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -6,338 +6,116 @@ from __future__ import annotations -from collections import OrderedDict -from typing import Any +from functools import partial +from itertools import cycle +from typing import TYPE_CHECKING, Any +import pytorch_lightning as pl import torch -from .callbacks import Callback -from .composer import Composer -from .enforcer import Enforcer -from .gain import Gain -from .objective import Objective -from .perturber import BatchPerturber, Perturber +from mart.utils import silent -__all__ = ["Adversary", "Attacker"] +from .perturber import Perturber +if TYPE_CHECKING: + from .enforcer import Enforcer -class AttackerCallbackHookMixin(Callback): - """Define event hooks in the Adversary Loop for callbacks.""" +__all__ = ["Adversary"] - callbacks = {} - def on_run_start(self, **kwargs) -> None: - """Prepare the attack loop state.""" - for _name, callback in self.callbacks.items(): - # FIXME: Skip incomplete callback instance. - # Give access of self to callbacks by `adversary=self`. - callback.on_run_start(**kwargs) - - def on_examine_start(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_examine_start(**kwargs) - - def on_examine_end(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_examine_end(**kwargs) - - def on_advance_start(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_advance_start(**kwargs) - - def on_advance_end(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_advance_end(**kwargs) - - def on_run_end(self, **kwargs) -> None: - for _name, callback in self.callbacks.items(): - callback.on_run_end(**kwargs) - - -class Attacker(AttackerCallbackHookMixin, torch.nn.Module): - """The attack optimization loop. - - This class implements the following loop structure: - - .. code-block:: python - - on_run_start() - - while true: - on_examine_start() - examine() - on_examine_end() - - if not done: - on_advance_start() - advance() - on_advance_end() - else: - break - - on_run_end() - """ +class Adversary(torch.nn.Module): + """An adversary module which generates and applies perturbation to input.""" def __init__( self, *, - perturber: BatchPerturber | Perturber, - composer: Composer, - optimizer: torch.optim.Optimizer, - max_iters: int, - gain: Gain, - objective: Objective | None = None, - callbacks: dict[str, Callback] | None = None, + enforcer: Enforcer, + perturber: Perturber | None = None, + attacker: pl.Trainer | None = None, + **kwargs, ): """_summary_ Args: - perturber (BatchPerturber | Perturber): A module that stores perturbations. - composer (Composer): A module which composes adversarial examples from input and perturbation. - optimizer (torch.optim.Optimizer): A PyTorch optimizer. - max_iters (int): The max number of attack iterations. 
-            gain (Gain): An adversarial gain function, which is a differentiable estimate of adversarial objective.
-            objective (Objective | None): A function for computing adversarial objective, which returns True or False. Optional.
-            callbacks (dict[str, Callback] | None): A dictionary of callback objects. Optional.
+            enforcer (Enforcer): A Callable that enforces constraints on the adversarial input.
+            perturber (Perturber): A Perturber that manages perturbations.
+            attacker (Trainer): A PyTorch-Lightning Trainer object used to fit the perturber.
         """
         super().__init__()

-        self.perturber = perturber
-        self.composer = composer
-        self.optimizer_fn = optimizer
-
-        self.max_iters = max_iters
-        self.callbacks = OrderedDict()
-
-        # Register perturber as callback if it implements Callback interface
-        if isinstance(self.perturber, Callback):
-            # FIXME: Use self.perturber.__class__.__name__ as key?
-            self.callbacks["_perturber"] = self.perturber
-
-        if callbacks is not None:
-            self.callbacks.update(callbacks)
-
-        self.objective_fn = objective
-        # self.gain is a tensor.
-        self.gain_fn = gain
-
-    @property
-    def done(self) -> bool:
-        # Reach the max iteration;
-        if self.cur_iter >= self.max_iters:
-            return True
-
-        # All adv. examples are found;
-        if hasattr(self, "found") and bool(self.found.all()) is True:
-            return True
-
-        # Compatible with models which return None gain when objective is reached.
-        # TODO: Remove gain==None stopping criteria in all models,
-        #       because the BestPerturbation callback relies on gain to determine which pert is the best.
-        if self.gain is None:
-            return True
-
-        return False
-
-    def on_run_start(
-        self,
-        *,
-        adversary: torch.nn.Module,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-        model: torch.nn.Module,
-        **kwargs,
-    ):
-        super().on_run_start(
-            adversary=adversary, input=input, target=target, model=model, **kwargs
-        )
-
-        # FIXME: We should probably just register IterativeAdversary as a callback.
-        # Set up the optimizer.
-        self.cur_iter = 0
-
-        # param_groups with learning rate and other optim params.
-        param_groups = self.perturber.parameter_groups()
-
-        self.opt = self.optimizer_fn(param_groups)
-
-    def on_run_end(
-        self,
-        *,
-        adversary: torch.nn.Module,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-        model: torch.nn.Module,
-        **kwargs,
-    ):
-        super().on_run_end(adversary=adversary, input=input, target=target, model=model, **kwargs)
-
-        # Release optimization resources
-        del self.opt
-
-    # Disable mixed-precision optimization for attacks,
-    # since we haven't implemented it yet.
-    @torch.autocast("cuda", enabled=False)
-    @torch.autocast("cpu", enabled=False)
-    def fit(
-        self,
-        *,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-        model: torch.nn.Module,
-        **kwargs,
-    ):
-
-        self.on_run_start(adversary=self, input=input, target=target, model=model, **kwargs)
-
-        while True:
-            try:
-                self.on_examine_start(
-                    adversary=self, input=input, target=target, model=model, **kwargs
-                )
-                self.examine(input=input, target=target, model=model, **kwargs)
-                self.on_examine_end(
-                    adversary=self, input=input, target=target, model=model, **kwargs
-                )
-
-                # Check the done condition here, so that every update of perturbation is examined.
- if not self.done: - self.on_advance_start( - adversary=self, - input=input, - target=target, - model=model, - **kwargs, - ) - self.advance( - input=input, - target=target, - model=model, - **kwargs, - ) - self.on_advance_end( - adversary=self, - input=input, - target=target, - model=model, - **kwargs, - ) - # Update cur_iter at the end so that all hooks get the correct cur_iter. - self.cur_iter += 1 - else: - break - except StopIteration: - break - - self.on_run_end(adversary=self, input=input, target=target, model=model, **kwargs) - - # Make sure we can do autograd. - # Earlier Pytorch Lightning uses no_grad(), but later PL uses inference_mode(): - # https://github.com/Lightning-AI/lightning/pull/12715 - @torch.enable_grad() - @torch.inference_mode(False) - def examine( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, - ): - """Examine current perturbation, update self.gain and self.found.""" - - # Clone tensors for autograd, in case it was created in the inference mode. - # FIXME: object detection uses non-pure-tensor data, but it may have cloned somewhere else implicitly? - if isinstance(input, torch.Tensor): - input = input.clone() - if isinstance(target, torch.Tensor): - target = target.clone() + self.attacker = attacker + + if self.attacker is None: + # Enable attack to be late bound in forward + self.attacker = partial( + pl.Trainer, + num_sanity_val_steps=0, + logger=False, + max_epochs=0, + limit_train_batches=kwargs.pop("max_iters", 10), + callbacks=list(kwargs.pop("callbacks", {}).values()), # dict to list of values + enable_model_summary=False, + enable_checkpointing=False, + enable_progress_bar=False, + ) - # Set model as None, because no need to update perturbation. - # Save everything to self.outputs so that callbacks have access to them. - self.outputs = model(input=input, target=target, model=None, **kwargs) - - # Use CallWith to dispatch **outputs. - self.gain = self.gain_fn(**self.outputs) - - # objective_fn is optional, because adversaries may never reach their objective. - if self.objective_fn is not None: - self.found = self.objective_fn(**self.outputs) - if self.gain.shape == torch.Size([]): - # A reduced gain value, not an input-wise gain vector. - self.total_gain = self.gain - else: - # No need to calculate new gradients if adversarial examples are already found. - self.total_gain = self.gain[~self.found].sum() else: - self.total_gain = self.gain.sum() - - # Make sure we can do autograd. - @torch.enable_grad() - @torch.inference_mode(False) - def advance( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - model: torch.nn.Module, - **kwargs, - ): - """Run one attack iteration.""" - - self.opt.zero_grad() - - # Do not flip the gain value, because we set maximize=True in optimizer. - self.total_gain.backward() - - self.opt.step() - - def forward( - self, - *, - input: torch.Tensor | tuple, - target: torch.Tensor | dict[str, Any] | tuple, - **kwargs, - ): - perturbation = self.perturber(input, target) - output = self.composer(perturbation, input=input, target=target) - - return output - - -class Adversary(torch.nn.Module): - """An adversary module which generates and applies perturbation to input.""" - - def __init__(self, *, enforcer: Enforcer, attacker: Attacker | None = None, **kwargs): - """_summary_ - - Args: - enforcer (Enforcer): A module which checks if adversarial examples satisfy constraints. 
-        attacker (Attacker): A trainer-like object that computes attacks.
-        """
-        super().__init__()
+            # We feed the same batch to the attack every time so we treat each step as an
+            # attack iteration. As such, attackers must start at max_epochs == 0 (each
+            # _attack() call advances fit_loop.max_epochs by one, i.e. one epoch per attack)
+            # and must limit the number of attack steps via limit_train_batches.
+            assert self.attacker.max_epochs == 0
+            assert self.attacker.limit_train_batches > 0

+        self.perturber = perturber or Perturber(**kwargs)
         self.enforcer = enforcer
-        self.attacker = attacker or Attacker(**kwargs)

-    def forward(
-        self,
-        input: torch.Tensor | tuple,
-        target: torch.Tensor | dict[str, Any] | tuple,
-        model: torch.nn.Module | None = None,
-        **kwargs,
-    ):
-        # Generate a perturbation only if we have a model. This will update
-        # the parameters of self.perturber.
-        if model is not None:
-            self.attacker.fit(input=input, target=target, model=model, **kwargs)
-
-        # Get perturbation and apply threat model
-        # The mask projector in perturber may require information from target.
-        output = self.attacker(input=input, target=target)
+    @silent()
+    def forward(self, **batch):
+        # Adversary lives within a sequence of modules in the model. To signal that the adversary
+        # should attack, one must pass a model when calling the adversary. Since we do not know
+        # where the Adversary lives inside the model, we also need the remaining sequence to be
+        # able to get a loss.
+        if "model" in batch and batch["model"] is not None and "sequence" in batch:
+            self._attack(**batch)
+
+        # Always perturb the current input.
+        input_adv = self.perturber(**batch)
+
+        # Enforce constraints after the attack optimization ends.
+        if "model" in batch and batch["model"] is not None and "sequence" in batch:
+            self.enforcer(input_adv, **batch)
+
+        return input_adv
+
+    def _attack(self, input, **batch):
+        batch = {"input": input, **batch}
+
+        # Configure and reset perturber to use batch inputs.
+        self.perturber.configure_perturbation(input)
+
+        # Attack, aka fit a perturbation, for one epoch by cycling over the same input batch.
+        # We use Trainer.limit_train_batches to control the number of attack iterations.
+        attacker = self._get_attacker(input)
+        attacker.fit_loop.max_epochs += 1
+        attacker.fit(self.perturber, train_dataloaders=cycle([batch]))
+
+    def _get_attacker(self, input):
+        if not isinstance(self.attacker, partial):
+            return self.attacker
+
+        # Convert torch.device to a PL accelerator.
+        device = self.perturber.device
+
+        if device.type == "cuda":
+            accelerator = "gpu"
+            devices = [device.index]
+        elif device.type == "cpu":
+            accelerator = "cpu"
+            devices = None
+        else:
+            accelerator = device.type
+            devices = [device.index]

-        if model is not None:
-            # We only enforce constraints after the attack optimization ends.
- self.enforcer(output, input=input, target=target) + self.attacker = self.attacker(accelerator=accelerator, devices=devices) - return output + return self.attacker diff --git a/mart/attack/callbacks/__init__.py b/mart/attack/callbacks/__init__.py index 736f7dd1..7ce8b2cf 100644 --- a/mart/attack/callbacks/__init__.py +++ b/mart/attack/callbacks/__init__.py @@ -1,4 +1,3 @@ -from .base import * from .eval_mode import * from .no_grad_mode import * from .progress_bar import * diff --git a/mart/attack/callbacks/eval_mode.py b/mart/attack/callbacks/eval_mode.py index de5eef75..be3b6397 100644 --- a/mart/attack/callbacks/eval_mode.py +++ b/mart/attack/callbacks/eval_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from .base import Callback +from pytorch_lightning.callbacks import Callback __all__ = ["AttackInEvalMode"] @@ -15,11 +15,11 @@ class AttackInEvalMode(Callback): def __init__(self): self.training_mode_status = None - def on_run_start(self, *, model, **kwargs): + def on_train_start(self, trainer, model): self.training_mode_status = model.training model.train(False) - def on_run_end(self, *, model, **kwargs): + def on_train_end(self, trainer, model): assert self.training_mode_status is not None # Resume the previous training status of the model. diff --git a/mart/attack/callbacks/no_grad_mode.py b/mart/attack/callbacks/no_grad_mode.py index bca4d971..cfb90ead 100644 --- a/mart/attack/callbacks/no_grad_mode.py +++ b/mart/attack/callbacks/no_grad_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from .base import Callback +from pytorch_lightning.callbacks import Callback __all__ = ["ModelParamsNoGrad"] @@ -15,10 +15,10 @@ class ModelParamsNoGrad(Callback): This callback should not change the result. Don't use unless an attack runs faster. """ - def on_run_start(self, *, model, **kwargs): + def on_train_start(self, trainer, model): for param in model.parameters(): param.requires_grad_(False) - def on_run_end(self, *, model, **kwargs): + def on_train_end(self, trainer, model): for param in model.parameters(): param.requires_grad_(True) diff --git a/mart/attack/callbacks/progress_bar.py b/mart/attack/callbacks/progress_bar.py index d175aa5d..564f311c 100644 --- a/mart/attack/callbacks/progress_bar.py +++ b/mart/attack/callbacks/progress_bar.py @@ -5,29 +5,18 @@ # import tqdm - -from .base import Callback +from pytorch_lightning.callbacks import TQDMProgressBar __all__ = ["ProgressBar"] -class ProgressBar(Callback): +class ProgressBar(TQDMProgressBar): """Display progress bar of attack iterations with the gain value.""" - def on_run_start(self, *, adversary, **kwargs): - self.pbar = tqdm.tqdm(total=adversary.max_iters, leave=False, desc="Attack", unit="iter") - - def on_examine_end(self, *, input, adversary, **kwargs): - msg = "" - if hasattr(adversary, "found"): - # there is no adversary.found if adversary.objective_fn() is not defined. 
- msg += f"found={int(sum(adversary.found))}/{len(input)}, " - - msg += f"avg_gain={float(adversary.gain.mean()):.2f}, " - - self.pbar.set_description(msg) - self.pbar.update(1) + def init_train_tqdm(self): + bar = super().init_train_tqdm() + bar.leave = False + bar.set_description("Attack") + bar.unit = "iter" - def on_run_end(self, **kwargs): - self.pbar.close() - del self.pbar + return bar diff --git a/mart/attack/callbacks/visualizer.py b/mart/attack/callbacks/visualizer.py index d0eb0c58..c1f0e35b 100644 --- a/mart/attack/callbacks/visualizer.py +++ b/mart/attack/callbacks/visualizer.py @@ -6,29 +6,39 @@ import os +from pytorch_lightning.callbacks import Callback from torchvision.transforms import ToPILImage -from .base import Callback - __all__ = ["PerturbedImageVisualizer"] class PerturbedImageVisualizer(Callback): """Save adversarial images as files.""" - def __init__(self, folder): + def __init__(self, folder, modality="rgb"): super().__init__() + # FIXME: This should use the Trainer's logging directory. self.folder = folder + self.modality = modality self.convert = ToPILImage() if not os.path.isdir(self.folder): os.makedirs(self.folder) - def on_run_end(self, *, adversary, input, target, model, **kwargs): - adv_input = adversary(input=input, target=target, model=None, **kwargs) + def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): + # Save input and target for on_train_end + self.input = batch["input"] + self.target = batch["target"] + + def on_train_end(self, trainer, model): + # FIXME: We should really just save this to outputs instead of recomputing adv_input + adv_input = model(input=self.input, target=self.target) - for img, tgt in zip(adv_input, target): + for img, tgt in zip(adv_input, self.target): + # Modality aware. 
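+            # e.g. img may be {"rgb": <tensor>, ...} for multimodal inputs; we save only
+            # the configured modality (illustrative layout; actual keys come from the dataset).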
+ if isinstance(img, dict): + img = img[self.modality] fname = tgt["file_name"] fpath = os.path.join(self.folder, fname) im = self.convert(img / 255) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index babc44e6..427d4891 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -10,6 +10,9 @@ from typing import Any import torch +from torch import Tensor + +from mart.utils import modality_dispatch __all__ = ["Enforcer"] @@ -21,20 +24,20 @@ class ConstraintViolated(Exception): class Constraint(abc.ABC): def __call__( self, - input_adv: torch.Tensor, + input_adv: Tensor, *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], + input: Tensor, + target: Tensor | dict[str, Any], ) -> None: self.verify(input_adv, input=input, target=target) @abc.abstractmethod def verify( self, - input_adv: torch.Tensor, + input_adv: Tensor, *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], + input: Tensor, + target: Tensor | dict[str, Any], ) -> None: raise NotImplementedError @@ -101,10 +104,10 @@ def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> @torch.no_grad() def _enforce( self, - input_adv: torch.Tensor, + input_adv: Tensor, *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], + input: Tensor, + target: Tensor | dict[str, Any], modality: str, ): for constraint in self.modality_constraints[modality].values(): @@ -112,28 +115,12 @@ def _enforce( def __call__( self, - input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], + input_adv: Tensor | list[Tensor] | list[dict[str, Tensor]], *, - input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor], - target: torch.Tensor | dict[str, Any], - modality: str = "constraints", + input: Tensor | list[Tensor] | list[dict[str, Tensor]], + target: Tensor | dict[str, Any], **kwargs, ): - assert type(input_adv) == type(input) - - if isinstance(input_adv, torch.Tensor): - # Finally we can verify constraints on tensor, per its modality. - # Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities. - self._enforce(input_adv, input=input, target=target, modality=modality) - elif isinstance(input_adv, dict): - # The dict input has modalities specified in keys, passing them recursively. - for modality in input_adv: - self(input_adv[modality], input=input[modality], target=target, modality=modality) - elif isinstance(input_adv, (list, tuple)): - # We assume a modality-dictionary only contains tensors, but not list/tuple. - assert modality == "constraints" - # The list or tuple input is a collection of sub-input and sub-target. 
-            for input_adv_i, input_i, target_i in zip(input_adv, input, target):
-                self(input_adv_i, input=input_i, target=target_i, modality=modality)
-        else:
-            raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
+        modality_dispatch(
+            self._enforce, input_adv, input=input, target=target, modality="constraints"
+        )
diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py
index fcb9b0db..dd680a95 100644
--- a/mart/attack/gradient_modifier.py
+++ b/mart/attack/gradient_modifier.py
@@ -4,8 +4,10 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
+from __future__ import annotations
+
 import abc
-from typing import Union
+from typing import Iterable
 
 import torch
 
@@ -15,25 +17,33 @@
 class GradientModifier(abc.ABC):
     """Gradient modifier base class."""
 
-    @abc.abstractmethod
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
         pass
 
 
 class Sign(GradientModifier):
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
-        return grad.sign()
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        parameters = [p for p in parameters if p.grad is not None]
+
+        for p in parameters:
+            p.grad.detach().sign_()
 
 
 class LpNormalizer(GradientModifier):
     """Scale gradients by a certain L-p norm."""
 
-    def __init__(self, p: Union[int, float]):
-        super().__init__
-
+    def __init__(self, p: int | float):
         self.p = p
 
-    def __call__(self, grad: torch.Tensor) -> torch.Tensor:
-        grad_norm = grad.norm(p=self.p)
-        grad_normalized = grad / grad_norm
-        return grad_normalized
+    def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+
+        parameters = [p for p in parameters if p.grad is not None]
+
+        for p in parameters:
+            p_norm = torch.norm(p.grad.detach(), p=self.p)
+            p.grad.detach().div_(p_norm)
diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py
new file mode 100644
index 00000000..5ba079b8
--- /dev/null
+++ b/mart/attack/perturber.py
@@ -0,0 +1,205 @@
+#
+# Copyright (C) 2022 Intel Corporation
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from mart.utils import modality_dispatch
+
+from .gradient_modifier import GradientModifier
+from .projector import Projector
+
+if TYPE_CHECKING:
+    from .composer import Composer
+    from .gain import Gain
+    from .initializer import Initializer
+    from .objective import Objective
+
+__all__ = ["Perturber"]
+
+
+class Perturber(pl.LightningModule):
+    """Perturbation optimization module."""
+
+    MODALITY_DEFAULT = "default"
+
+    def __init__(
+        self,
+        *,
+        optimizer: Callable,
+        gain: Gain,
+        composer: Composer | dict[str, Composer],
+        initializer: Initializer | dict[str, Initializer],
+        gradient_modifier: GradientModifier | dict[str, GradientModifier] | None = None,
+        projector: Projector | dict[str, Projector] | None = None,
+        objective: Objective | None = None,
+        optim_params: dict[str, dict[str, Any]] | None = None,
+    ):
+        """_summary_
+
+        Args:
+            optimizer: A partial of PyTorch optimizer that awaits parameters to optimize.
+            gain: An adversarial gain function, which is a differentiable estimate of adversarial objective.
+            composer: A module which composes adversarial input from input and perturbation. Modality-aware.
+            initializer: To initialize the perturbation. Modality-aware.
+            gradient_modifier: To modify the gradient of perturbation. Modality-aware.
+            projector: To project the perturbation into some space. Modality-aware.
+            objective: A function for computing adversarial objective, which returns True or False. Optional.
+            optim_params: A dictionary of optimization hyper-parameters. E.g. {"rgb": {"lr": 0.1}}.
+        """
+        super().__init__()
+
+        # Modality-neutral objects.
+        self.optimizer_fn = optimizer
+        self.gain_fn = gain
+        self.objective_fn = objective
+
+        # Replace None with nop().
+        gradient_modifier = gradient_modifier or GradientModifier()
+        projector = projector or Projector()
+
+        # Wrap modality-specific objects in dictionaries.
+        # Backward compatibility, in case modality is unknown, and not given in input.
+        if not isinstance(initializer, dict):
+            initializer = {self.MODALITY_DEFAULT: initializer}
+        if not isinstance(gradient_modifier, dict):
+            gradient_modifier = {self.MODALITY_DEFAULT: gradient_modifier}
+        if not isinstance(projector, dict):
+            projector = {self.MODALITY_DEFAULT: projector}
+        if not isinstance(composer, dict):
+            composer = {self.MODALITY_DEFAULT: composer}
+
+        # Backward compatibility, in case optimization parameters are not given.
+        if optim_params is None:
+            optim_params = {modality: {} for modality in initializer.keys()}
+
+        # Modality-specific objects.
+        self.initializer = initializer
+        self.gradient_modifier = gradient_modifier
+        self.projector = projector
+        self.composer = composer
+        self.optim_params = optim_params
+
+        self.perturbation = None
+
+    def configure_perturbation(self, input: torch.Tensor | tuple | tuple[dict[str, torch.Tensor]]):
+        def create_and_initialize(data, *, input, target, modality):
+            # Though data and target are not used, they are required placeholders for modality_dispatch().
+            # TODO: we don't want an integer tensor, but make sure it does not affect mixed precision training.
+            pert = torch.empty_like(input, dtype=torch.float, requires_grad=True)
+            self.initializer[modality](pert)
+            return pert
+
+        # Recursively create and initialize perturbation tensors that mirror the structure of input.
+        # Though only input=input is used, we have to fill the placeholders of data and target.
+        self.perturbation = modality_dispatch(
+            create_and_initialize, input, input=input, target=None, modality=self.MODALITY_DEFAULT
+        )
+
+    def parameter_groups(self):
+        """Extract parameter groups for optimization from perturbation tensor(s)."""
+        param_groups = self._parameter_groups(self.perturbation, modality=self.MODALITY_DEFAULT)
+        return param_groups
+
+    def _parameter_groups(self, pert, *, modality):
+        """Recursively return parameter groups as a list of dictionaries."""
+
+        if isinstance(pert, torch.Tensor):
+            # Return a list of dictionaries instead of a single dictionary; easier to extend later.
+            # Add the modality notation so that we can perform gradient modification later.
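+            # A resulting group looks like {"params": pert, "modality": "rgb", "lr": 0.1},
+            # where the trailing keys come from optim_params[modality] via dict union (Python 3.9+).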
+            return [{"params": pert, "modality": modality} | self.optim_params[modality]]
+        elif isinstance(pert, dict):
+            param_list = []
+            for modality, pert_i in pert.items():
+                ret_modality = self._parameter_groups(pert_i, modality=modality)
+                param_list.extend(ret_modality)
+            return param_list
+        elif isinstance(pert, (list, tuple)):
+            param_list = []
+            for pert_i in pert:
+                ret_i = self._parameter_groups(pert_i, modality=modality)
+                param_list.extend(ret_i)
+            return param_list
+        else:
+            raise ValueError(f"Unsupported data type of input: {type(pert)}.")
+
+    def project_(self, perturbation, *, input, target, **kwargs):
+        """In-place projection."""
+        modality_dispatch(
+            self.projector,
+            perturbation,
+            input=input,
+            target=target,
+            modality=self.MODALITY_DEFAULT,
+        )
+
+    def compose(self, perturbation, *, input, target, **kwargs):
+        return modality_dispatch(
+            self.composer, perturbation, input=input, target=target, modality=self.MODALITY_DEFAULT
+        )
+
+    def configure_optimizers(self):
+        # parameter_groups is generated from perturbation.
+        if self.perturbation is None:
+            raise MisconfigurationException(
+                "You need to call configure_perturbation before fit."
+            )
+        return self.optimizer_fn(self.parameter_groups())
+
+    def training_step(self, batch, batch_idx):
+        # Copy the batch since we modify it and it is used internally.
+        batch = batch.copy()
+
+        # We need to evaluate the perturbation against the whole model, so call it normally to get a gain.
+        model = batch.pop("model")
+        # When an Adversary takes input from another module in the sequence, we would have to specify kwargs of Adversary, and model would be a required kwarg.
+        outputs = model(**batch, model=None)
+
+        # FIXME: This should really be just `return outputs`. But this might require a new sequence?
+        # FIXME: Everything below here should live in the model as modules.
+        # Use CallWith to dispatch **outputs.
+        gain = self.gain_fn(**outputs)
+
+        # objective_fn is optional, because adversaries may never reach their objective.
+        if self.objective_fn is not None:
+            found = self.objective_fn(**outputs)
+
+            # No need to calculate new gradients if adversarial examples are already found.
+            if len(gain.shape) > 0:
+                gain = gain[~found]
+
+        if len(gain.shape) > 0:
+            gain = gain.sum()
+
+        return gain
+
+    def configure_gradient_clipping(
+        self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
+    ):
+        # Configuring gradient clipping in pl.Trainer is still useful, so use it.
+        super().configure_gradient_clipping(
+            optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm
+        )
+
+        for group in optimizer.param_groups:
+            modality = self.MODALITY_DEFAULT if "modality" not in group else group["modality"]
+            self.gradient_modifier[modality](group["params"])
+
+    def forward(self, **batch):
+        if self.perturbation is None:
+            raise MisconfigurationException(
+                "You need to call configure_perturbation before forward."
+ ) + + self.project_(self.perturbation, **batch) + input_adv = self.compose(self.perturbation, **batch) + + return input_adv diff --git a/mart/attack/perturber/__init__.py b/mart/attack/perturber/__init__.py deleted file mode 100644 index 60b2b5f6..00000000 --- a/mart/attack/perturber/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .batch import * -from .perturber import * diff --git a/mart/attack/perturber/batch.py b/mart/attack/perturber/batch.py deleted file mode 100644 index 83b306c2..00000000 --- a/mart/attack/perturber/batch.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Copyright (C) 2022 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -from typing import Any, Callable, Dict, Union - -import torch -from hydra.utils import instantiate - -from mart.attack.callbacks import Callback - -from ..gradient_modifier import GradientModifier -from ..initializer import Initializer -from ..projector import Projector -from .perturber import Perturber - -__all__ = ["BatchPerturber"] - - -class BatchPerturber(Callback, torch.nn.Module): - """The batch input could be a list or a NCHW tensor. - - We split input into individual examples and run different perturbers accordingly. - """ - - def __init__( - self, - perturber_factory: Callable[[Initializer, GradientModifier, Projector], Perturber], - *perturber_args, - **perturber_kwargs, - ): - super().__init__() - - self.perturber_factory = perturber_factory - self.perturber_args = perturber_args - self.perturber_kwargs = perturber_kwargs - - # Try to create a perturber using factory and kwargs - assert self.perturber_factory(*self.perturber_args, **self.perturber_kwargs) is not None - - self.perturbers = torch.nn.ModuleDict() - - def parameter_groups(self): - """Return parameters along with optim parameters.""" - params = [] - for perturber in self.perturbers.values(): - params += perturber.parameter_groups() - return params - - def on_run_start(self, adversary, input, target, model, **kwargs): - # Remove old perturbers - # FIXME: Can we do this in on_run_end instead? 
- self.perturbers.clear() - - # Create new perturber for each item in the batch - for i in range(len(input)): - perturber = self.perturber_factory(*self.perturber_args, **self.perturber_kwargs) - self.perturbers[f"input_{i}_perturber"] = perturber - - # Trigger callback - for i, (input_i, target_i) in enumerate(zip(input, target)): - perturber = self.perturbers[f"input_{i}_perturber"] - if isinstance(perturber, Callback): - perturber.on_run_start( - adversary=adversary, input=input_i, target=target_i, model=model, **kwargs - ) - - def forward(self, input: torch.Tensor, target: Union[torch.Tensor, Dict[str, Any]]) -> None: - output = [] - for i, (input_i, target_i) in enumerate(zip(input, target)): - perturber = self.perturbers[f"input_{i}_perturber"] - ret_i = perturber(input_i, target_i) - output.append(ret_i) - - if isinstance(input, torch.Tensor): - output = torch.stack(output) - else: - output = tuple(output) - - return output diff --git a/mart/attack/perturber/perturber.py b/mart/attack/perturber/perturber.py deleted file mode 100644 index d7eed81f..00000000 --- a/mart/attack/perturber/perturber.py +++ /dev/null @@ -1,101 +0,0 @@ -# -# Copyright (C) 2022 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -from typing import Any, Dict, Optional, Union - -import torch - -from mart.attack.callbacks import Callback - -from ..gradient_modifier import GradientModifier -from ..initializer import Initializer -from ..projector import Projector - -__all__ = ["Perturber"] - - -class Perturber(Callback, torch.nn.Module): - """The base class of perturbers. - - A perturber wraps a nn.Parameter and returns this parameter when called. It also enables one to - specify an initialization for this parameter, how to modify gradients computed on this - parameter, and how to project the values of the parameter. - """ - - def __init__( - self, - initializer: Initializer, - gradient_modifier: Optional[GradientModifier] = None, - projector: Optional[Projector] = None, - **optim_params, - ): - """_summary_ - - Args: - initializer (object): To initialize the perturbation. - gradient_modifier (object): To modify the gradient of perturbation. - projector (object): To project the perturbation into some space. - optim_params Optional[dict]: Optimization parameters such learning rate and momentum for perturbation. - """ - super().__init__() - - self.initializer = initializer - self.gradient_modifier = gradient_modifier - self.projector = projector - self.optim_params = optim_params - - # Pre-occupy the name of the buffer, so that extra_repr() always gets perturbation. - self.register_buffer("perturbation", torch.nn.UninitializedBuffer(), persistent=False) - - def projector_wrapper(perturber_module, args): - if isinstance(perturber_module.perturbation, torch.nn.UninitializedBuffer): - raise ValueError("Perturbation must be initialized") - - input, target = args - return projector(perturber_module.perturbation, input, target) - - # Will be called before forward() is called. - if projector is not None: - self.register_forward_pre_hook(projector_wrapper) - - def on_run_start(self, *, adversary, input, target, model, **kwargs): - # Initialize perturbation. - perturbation = torch.zeros_like(input, requires_grad=True) - - # Register perturbation as a non-persistent buffer even though we will optimize it. This is because it is not - # a parameter of the underlying model but a parameter of the adversary. 
- self.register_buffer("perturbation", perturbation, persistent=False) - - # A backward hook that will be called when a gradient w.r.t the Tensor is computed. - if self.gradient_modifier is not None: - self.perturbation.register_hook(self.gradient_modifier) - - self.initializer(self.perturbation) - - def parameter_groups(self): - """Return parameters along with the pre-defined optimization parameters. - - Example: `[{"params": perturbation, "lr":0.1, "momentum": 0.9}]` - """ - if "params" in self.optim_params: - raise ValueError( - 'Optimization parameters should not include "params" which will override the actual parameters to be optimized. ' - ) - - return [{"params": self.perturbation} | self.optim_params] - - def forward( - self, input: torch.Tensor, target: Union[torch.Tensor, Dict[str, Any]] - ) -> torch.Tensor: - return self.perturbation - - def extra_repr(self): - perturbation = self.perturbation - - return ( - f"{repr(perturbation)}, initializer={self.initializer}," - f"gradient_modifier={self.gradient_modifier}, projector={self.projector}" - ) diff --git a/mart/attack/projector.py b/mart/attack/projector.py index 4c360688..92391c67 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -4,23 +4,37 @@ # SPDX-License-Identifier: BSD-3-Clause # -import abc -from typing import Any, Dict, List, Optional, Union +from __future__ import annotations -import torch +from typing import Any -__all__ = ["Projector"] +import torch -class Projector(abc.ABC): +class Projector: """A projector modifies nn.Parameter's data.""" @torch.no_grad() def __call__( self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, + **kwargs, + ) -> None: + if isinstance(perturbation, tuple): + for perturbation_i, input_i, target_i in zip(perturbation, input, target): + self.project(perturbation_i, input=input_i, target=target_i) + else: + self.project(perturbation, input=input, target=target) + + def project( + self, + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, ) -> None: pass @@ -28,18 +42,20 @@ def __call__( class Compose(Projector): """Apply a list of perturbation modifier.""" - def __init__(self, projectors: List[Projector]): + def __init__(self, projectors: list[Projector]): self.projectors = projectors @torch.no_grad() def __call__( self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], + perturbation: torch.Tensor | tuple, + *, + input: torch.Tensor | tuple, + target: torch.Tensor | dict[str, Any] | tuple, + **kwargs, ) -> None: for projector in self.projectors: - projector(tensor, input, target) + projector(perturbation, input=input, target=target) def __repr__(self): projector_names = [repr(p) for p in self.projectors] @@ -49,26 +65,15 @@ def __repr__(self): class Range(Projector): """Clamp the perturbation so that the output is range-constrained.""" - def __init__( - self, - quantize: Optional[bool] = False, - min: Optional[Union[int, float]] = 0, - max: Optional[Union[int, float]] = 255, - ): + def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255): self.quantize = quantize self.min = min self.max = max - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: + def project(self, 
perturbation, *, input, target): if self.quantize: - tensor.round_() - tensor.clamp_(self.min, self.max) + perturbation.round_() + perturbation.clamp_(self.min, self.max) def __repr__(self): return ( @@ -82,26 +87,15 @@ class RangeAdditive(Projector): The projector assumes an additive perturbation threat model. """ - def __init__( - self, - quantize: Optional[bool] = False, - min: Optional[Union[int, float]] = 0, - max: Optional[Union[int, float]] = 255, - ): + def __init__(self, quantize: bool = False, min: int | float = 0, max: int | float = 255): self.quantize = quantize self.min = min self.max = max - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: + def project(self, perturbation, *, input, target): if self.quantize: - tensor.round_() - tensor.clamp_(self.min - input, self.max - input) + perturbation.round_() + perturbation.clamp_(self.min - input, self.max - input) def __repr__(self): return ( @@ -112,7 +106,7 @@ def __repr__(self): class Lp(Projector): """Project perturbations to Lp norm, only if the Lp norm is larger than eps.""" - def __init__(self, eps: float, p: Optional[Union[int, float]] = torch.inf): + def __init__(self, eps: int | float, p: int | float = torch.inf): """_summary_ Args: @@ -123,55 +117,32 @@ def __init__(self, eps: float, p: Optional[Union[int, float]] = torch.inf): self.p = p self.eps = eps - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: - pert_norm = tensor.norm(p=self.p) + def project(self, perturbation, *, input, target): + pert_norm = perturbation.norm(p=self.p) if pert_norm > self.eps: # We only upper-bound the norm. - tensor.mul_(self.eps / pert_norm) + perturbation.mul_(self.eps / pert_norm) class LinfAdditiveRange(Projector): """Make sure the perturbation is within the Linf norm ball, and "input + perturbation" is within the [min, max] range.""" - def __init__( - self, - eps: float, - min: Optional[Union[int, float]] = 0, - max: Optional[Union[int, float]] = 255, - ): + def __init__(self, eps: int | float, min: int | float = 0, max: int | float = 255): self.eps = eps self.min = min self.max = max - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: + def project(self, perturbation, *, input, target): eps_min = (input - self.eps).clamp(self.min, self.max) - input eps_max = (input + self.eps).clamp(self.min, self.max) - input - tensor.clamp_(eps_min, eps_max) + perturbation.clamp_(eps_min, eps_max) class Mask(Projector): - @torch.no_grad() - def __call__( - self, - tensor: torch.Tensor, - input: torch.Tensor, - target: Union[torch.Tensor, Dict[str, Any]], - ) -> None: - tensor.mul_(target["perturbable_mask"]) + def project(self, perturbation, *, input, target): + perturbation.mul_(target["perturbable_mask"]) def __repr__(self): return f"{self.__class__.__name__}()" diff --git a/mart/configs/attack/iterative.yaml b/mart/configs/attack/adversary.yaml similarity index 69% rename from mart/configs/attack/iterative.yaml rename to mart/configs/attack/adversary.yaml index 40b57864..3cbf0aff 100644 --- a/mart/configs/attack/iterative.yaml +++ b/mart/configs/attack/adversary.yaml @@ -1,12 +1,10 @@ _target_: mart.attack.Adversary -# Composition -perturber: ??? -composer: ??? - # Optimization +initializer: ??? +gradient_modifier: ??? +projector: ??? optimizer: ??? 
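+# max_iters is consumed by Adversary and becomes the attack Trainer's limit_train_batches.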
max_iters: ??? -callbacks: ??? gain: ??? # Threat model diff --git a/mart/configs/attack/callbacks/progress_bar.yaml b/mart/configs/attack/callbacks/progress_bar.yaml index 21d4c477..e528c714 100644 --- a/mart/configs/attack/callbacks/progress_bar.yaml +++ b/mart/configs/attack/callbacks/progress_bar.yaml @@ -1,2 +1,3 @@ progress_bar: _target_: mart.attack.callbacks.ProgressBar + process_position: 1 diff --git a/mart/configs/attack/classification_eps1.75_fgsm.yaml b/mart/configs/attack/classification_eps1.75_fgsm.yaml index 21420c17..3009a725 100644 --- a/mart/configs/attack/classification_eps1.75_fgsm.yaml +++ b/mart/configs/attack/classification_eps1.75_fgsm.yaml @@ -1,9 +1,9 @@ defaults: - - iterative_sgd - - perturber: default - - perturber/initializer: constant - - perturber/gradient_modifier: sign - - perturber/projector: linf_additive_range + - adversary + - optimizer: sgd + - initializer: constant + - gradient_modifier: sign + - projector: linf_additive_range - objective: misclassification - gain: cross_entropy - composer: additive @@ -20,9 +20,8 @@ optimizer: max_iters: 1 -perturber: - initializer: - constant: 0 +initializer: + constant: 0 - projector: - eps: 1.75 +projector: + eps: 1.75 diff --git a/mart/configs/attack/classification_eps2_pgd10_step1.yaml b/mart/configs/attack/classification_eps2_pgd10_step1.yaml index 37566620..9c6be04b 100644 --- a/mart/configs/attack/classification_eps2_pgd10_step1.yaml +++ b/mart/configs/attack/classification_eps2_pgd10_step1.yaml @@ -1,9 +1,9 @@ defaults: - - iterative_sgd - - perturber: default - - perturber/initializer: uniform_lp - - perturber/gradient_modifier: sign - - perturber/projector: linf_additive_range + - adversary + - optimizer: sgd + - initializer: uniform_lp + - gradient_modifier: sign + - projector: linf_additive_range - objective: misclassification - gain: cross_entropy - composer: additive @@ -20,9 +20,8 @@ optimizer: max_iters: 10 -perturber: - initializer: - eps: 2 +initializer: + eps: 2 - projector: - eps: 2 +projector: + eps: 2 diff --git a/mart/configs/attack/classification_eps8_pgd10_step1.yaml b/mart/configs/attack/classification_eps8_pgd10_step1.yaml index 9eb0dbd1..ab5ff843 100644 --- a/mart/configs/attack/classification_eps8_pgd10_step1.yaml +++ b/mart/configs/attack/classification_eps8_pgd10_step1.yaml @@ -1,9 +1,9 @@ defaults: - - iterative_sgd - - perturber: default - - perturber/initializer: uniform_lp - - perturber/gradient_modifier: sign - - perturber/projector: linf_additive_range + - adversary + - optimizer: sgd + - initializer: uniform_lp + - gradient_modifier: sign + - projector: linf_additive_range - objective: misclassification - gain: cross_entropy - composer: additive @@ -20,9 +20,8 @@ optimizer: max_iters: 10 -perturber: - initializer: - eps: 8 +initializer: + eps: 8 - projector: - eps: 8 +projector: + eps: 8 diff --git a/mart/configs/attack/perturber/gradient_modifier/lp_normalizer.yaml b/mart/configs/attack/gradient_modifier/lp_normalizer.yaml similarity index 100% rename from mart/configs/attack/perturber/gradient_modifier/lp_normalizer.yaml rename to mart/configs/attack/gradient_modifier/lp_normalizer.yaml diff --git a/mart/configs/attack/perturber/gradient_modifier/sign.yaml b/mart/configs/attack/gradient_modifier/sign.yaml similarity index 100% rename from mart/configs/attack/perturber/gradient_modifier/sign.yaml rename to mart/configs/attack/gradient_modifier/sign.yaml diff --git a/mart/configs/attack/perturber/initializer/constant.yaml b/mart/configs/attack/initializer/constant.yaml 
similarity index 100% rename from mart/configs/attack/perturber/initializer/constant.yaml rename to mart/configs/attack/initializer/constant.yaml diff --git a/mart/configs/attack/perturber/initializer/uniform.yaml b/mart/configs/attack/initializer/uniform.yaml similarity index 100% rename from mart/configs/attack/perturber/initializer/uniform.yaml rename to mart/configs/attack/initializer/uniform.yaml diff --git a/mart/configs/attack/perturber/initializer/uniform_lp.yaml b/mart/configs/attack/initializer/uniform_lp.yaml similarity index 100% rename from mart/configs/attack/perturber/initializer/uniform_lp.yaml rename to mart/configs/attack/initializer/uniform_lp.yaml diff --git a/mart/configs/attack/iterative_sgd.yaml b/mart/configs/attack/iterative_sgd.yaml deleted file mode 100644 index 5ec86235..00000000 --- a/mart/configs/attack/iterative_sgd.yaml +++ /dev/null @@ -1,4 +0,0 @@ -defaults: - - iterative - - optimizer: sgd - - callbacks: [progress_bar] diff --git a/mart/configs/attack/object_detection_mask_adversary.yaml b/mart/configs/attack/object_detection_mask_adversary.yaml index 04a66ebc..e069b45c 100644 --- a/mart/configs/attack/object_detection_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_mask_adversary.yaml @@ -1,10 +1,10 @@ defaults: - - iterative_sgd - - perturber: batch - - perturber/initializer: constant - - perturber/gradient_modifier: sign - - perturber/projector: mask_range - - callbacks: [progress_bar, image_visualizer] + - adversary + - optimizer: sgd + - initializer: constant + - gradient_modifier: sign + - projector: mask_range + - callbacks: [image_visualizer] - objective: zero_ap - gain: rcnn_training_loss - composer: overlay @@ -17,6 +17,5 @@ optimizer: max_iters: 5 -perturber: - initializer: - constant: 127 +initializer: + constant: 127 diff --git a/mart/configs/attack/object_detection_mask_adversary_missed.yaml b/mart/configs/attack/object_detection_mask_adversary_missed.yaml index e44b5342..9cf8657f 100644 --- a/mart/configs/attack/object_detection_mask_adversary_missed.yaml +++ b/mart/configs/attack/object_detection_mask_adversary_missed.yaml @@ -8,6 +8,5 @@ optimizer: max_iters: 100 -perturber: - initializer: - constant: 127 +initializer: + constant: 127 diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml index a2bb039e..1fc21948 100644 --- a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -1,13 +1,13 @@ defaults: - - iterative_sgd - - perturber: batch - - perturber/initializer: constant - - perturber/gradient_modifier: sign - - perturber/projector: mask_range - - callbacks: [progress_bar, image_visualizer] + - adversary + - optimizer: sgd + - initializer@initializer.rgb: constant + - gradient_modifier@gradient_modifier.rgb: sign + - projector@projector.rgb: mask_range + - callbacks: [image_visualizer] - objective: zero_ap - gain: rcnn_training_loss - - composer: overlay + - composer@composer.rgb: overlay - enforcer: default - enforcer/constraints@enforcer.rgb: [mask, pixel_range] @@ -17,6 +17,6 @@ optimizer: max_iters: 5 -perturber: - initializer: +initializer: + rgb: constant: 127 diff --git a/mart/configs/attack/perturber/batch.yaml b/mart/configs/attack/perturber/batch.yaml deleted file mode 100644 index b3ed0634..00000000 --- a/mart/configs/attack/perturber/batch.yaml +++ /dev/null @@ -1,7 +0,0 @@ -_target_: mart.attack.BatchPerturber -perturber_factory: - _target_: 
mart.attack.Perturber - _partial_: true -initializer: ??? -gradient_modifier: ??? -projector: ??? diff --git a/mart/configs/attack/perturber/default.yaml b/mart/configs/attack/perturber/default.yaml deleted file mode 100644 index 8025bfd5..00000000 --- a/mart/configs/attack/perturber/default.yaml +++ /dev/null @@ -1,4 +0,0 @@ -_target_: mart.attack.Perturber -initializer: ??? -gradient_modifier: ??? -projector: ??? diff --git a/mart/configs/attack/perturber/projector/linf_additive_range.yaml b/mart/configs/attack/projector/linf_additive_range.yaml similarity index 100% rename from mart/configs/attack/perturber/projector/linf_additive_range.yaml rename to mart/configs/attack/projector/linf_additive_range.yaml diff --git a/mart/configs/attack/perturber/projector/lp_additive_range.yaml b/mart/configs/attack/projector/lp_additive_range.yaml similarity index 100% rename from mart/configs/attack/perturber/projector/lp_additive_range.yaml rename to mart/configs/attack/projector/lp_additive_range.yaml diff --git a/mart/configs/attack/perturber/projector/mask_range.yaml b/mart/configs/attack/projector/mask_range.yaml similarity index 100% rename from mart/configs/attack/perturber/projector/mask_range.yaml rename to mart/configs/attack/projector/mask_range.yaml diff --git a/mart/configs/attack/perturber/projector/range.yaml b/mart/configs/attack/projector/range.yaml similarity index 100% rename from mart/configs/attack/perturber/projector/range.yaml rename to mart/configs/attack/projector/range.yaml diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 0187b99f..8a1dded6 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -68,6 +68,16 @@ test_sequence: losses_and_detections: ["preprocessor", "target"] seq030: + loss: + # Sum up the losses. 
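+      # The summed scalar is exposed as the "loss" key in seq040's output below.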
+ [ + "losses_and_detections.training.loss_objectness", + "losses_and_detections.training.loss_rpn_box_reg", + "losses_and_detections.training.loss_classifier", + "losses_and_detections.training.loss_box_reg", + ] + + seq040: output: { "preds": "losses_and_detections.eval", @@ -76,6 +86,7 @@ test_sequence: "rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", + "loss": "loss", } modules: diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py index 91c84339..4c1ae708 100644 --- a/mart/utils/__init__.py +++ b/mart/utils/__init__.py @@ -1,6 +1,8 @@ from .adapters import * from .export import * +from .modality_dispatch import * from .monkey_patch import * from .pylogger import * from .rich_utils import * +from .silent import * from .utils import * diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py new file mode 100644 index 00000000..6ea6206e --- /dev/null +++ b/mart/utils/modality_dispatch.py @@ -0,0 +1,67 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +from itertools import cycle +from typing import Any, Callable + +import torch +from torch import Tensor + +__all__ = ["modality_dispatch"] + + +def modality_dispatch( + modality_func: Callable | dict[str, Callable], + data: Tensor | tuple | list[Tensor] | dict[str, Tensor], + *, + input: Tensor | tuple | list[Tensor] | dict[str, Tensor], + target: torch.Tensor | dict[str, Any] | list[dict[str, Any]] | None, + modality: str = "default", +): + """Recursively dispatch data and input/target to functions of the same modality. + + The function returns an object that is homomorphic to input and data. + """ + + assert type(data) == type(input) + if target is None: + # Make target zips well with input. + target = cycle([None]) + + if isinstance(input, torch.Tensor): + if isinstance(modality_func, dict): + # A dictionary of Callable indexed by modality. + return modality_func[modality](data, input=input, target=target) + else: + # A Callable with modality=? as a keyword argument. + return modality_func(data, input=input, target=target, modality=modality) + elif isinstance(input, dict): + # The dict input has modalities specified in keys, passing them recursively. + output = {} + for modality in input.keys(): + output[modality] = modality_dispatch( + modality_func, + data[modality], + input=input[modality], + target=target, + modality=modality, + ) + return output + elif isinstance(input, (list, tuple)): + # The list or tuple input is a collection of sub-input and sub-target. 
+        output = []
+        for data_i, input_i, target_i in zip(data, input, target):
+            output_i = modality_dispatch(
+                modality_func, data_i, input=input_i, target=target_i, modality=modality
+            )
+            output.append(output_i)
+        if isinstance(input, tuple):
+            output = tuple(output)
+        return output
+    else:
+        raise ValueError(f"Unsupported data type of input: {type(input)}.")
diff --git a/mart/utils/silent.py b/mart/utils/silent.py
new file mode 100644
index 00000000..b9cbd1c3
--- /dev/null
+++ b/mart/utils/silent.py
@@ -0,0 +1,30 @@
+#
+# Copyright (C) 2022 Intel Corporation
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+
+import logging
+from contextlib import ContextDecorator
+
+__all__ = ["silent"]
+
+
+class silent(ContextDecorator):
+    """Suppress logging."""
+
+    DEFAULT_NAMES = ["pytorch_lightning.utilities.rank_zero", "pytorch_lightning.accelerators.gpu"]
+
+    def __init__(self, names=None):
+        if names is None:
+            names = silent.DEFAULT_NAMES
+
+        self.loggers = [logging.getLogger(name) for name in names]
+
+    def __enter__(self):
+        for logger in self.loggers:
+            logger.propagate = False
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        for logger in self.loggers:
+            # Restore propagation when leaving the context.
+            logger.propagate = True
diff --git a/tests/test_adversary.py b/tests/test_adversary.py
index c7438458..c14e1569 100644
--- a/tests/test_adversary.py
+++ b/tests/test_adversary.py
@@ -7,70 +7,53 @@
 from functools import partial
 from unittest.mock import Mock
 
+import pytorch_lightning as pl
 import torch
 from torch.optim import SGD
 
 import mart
-from mart.attack import Adversary
-from mart.attack.perturber import Perturber
+from mart.attack import Adversary, Perturber
 
 
 def test_adversary(input_data, target_data, perturbation):
-    composer = mart.attack.composer.Additive()
     enforcer = Mock()
-    perturber = Mock(return_value=perturbation)
-    optimizer = Mock()
-    max_iters = 3
-    gain = Mock()
+    perturber = Mock(return_value=perturbation + input_data)
+    attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0))
 
     adversary = Adversary(
-        composer=composer,
         enforcer=enforcer,
         perturber=perturber,
-        optimizer=optimizer,
-        max_iters=max_iters,
-        gain=gain,
+        attacker=attacker,
     )
 
-    output_data = adversary(input_data, target_data)
+    output_data = adversary(input=input_data, target=target_data)
 
-    optimizer.assert_not_called()
-    gain.assert_not_called()
-    perturber.assert_called_once()
-    # The enforcer is only called when model is not None.
+    # The enforcer and attacker should only be called when model is not None.
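+    # (Adversary.forward() only attacks and enforces when the batch carries both
+    # a non-None "model" and a "sequence" key.)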
enforcer.assert_not_called() + attacker.fit.assert_not_called() + assert attacker.fit_loop.max_epochs == 0 + + perturber.assert_called_once() + torch.testing.assert_close(output_data, input_data + perturbation) def test_adversary_with_model(input_data, target_data, perturbation): - composer = mart.attack.composer.Additive() enforcer = Mock() - initializer = Mock() - parameter_groups = Mock(return_value=[]) - perturber = Mock(return_value=perturbation, parameter_groups=parameter_groups) - optimizer = Mock() - max_iters = 3 - model = Mock(return_value={}) - gain = Mock(return_value=torch.tensor(0.0, requires_grad=True)) + perturber = Mock(return_value=input_data + perturbation) + attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0)) adversary = Adversary( - composer=composer, enforcer=enforcer, perturber=perturber, - optimizer=optimizer, - max_iters=3, - gain=gain, + attacker=attacker, ) - output_data = adversary(input_data, target_data, model=model) + output_data = adversary(input=input_data, target=target_data, model=Mock(), sequence=None) - parameter_groups.assert_called_once() - optimizer.assert_called_once() # The enforcer is only called when model is not None. enforcer.assert_called_once() - # max_iters+1 because Adversary examines one last time - assert gain.call_count == max_iters + 1 - assert model.call_count == max_iters + 1 + attacker.fit.assert_called_once() # Once with model=None to get perturbation. # When model=model, perturber.initialize_parameters() is called. @@ -79,69 +62,47 @@ def test_adversary_with_model(input_data, target_data, perturbation): torch.testing.assert_close(output_data, input_data + perturbation) -def test_adversary_perturber_hidden_params(input_data, target_data): - initializer = Mock() - perturber = Perturber(initializer) - - composer = mart.attack.composer.Additive() +def test_adversary_hidden_params(input_data, target_data, perturbation): enforcer = Mock() - optimizer = Mock() - gain = Mock(return_value=torch.tensor(0.0, requires_grad=True)) - model = Mock(return_value={}) + perturber = Mock(return_value=input_data + perturbation) + attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0)) adversary = Adversary( - composer=composer, enforcer=enforcer, perturber=perturber, - optimizer=optimizer, - max_iters=1, - gain=gain, + attacker=attacker, ) - output_data = adversary(input_data, target_data, model=model) + + output_data = adversary(input=input_data, target=target_data, model=None, sequence=None) # Adversarial perturbation should not be updated by a regular training optimizer. params = [p for p in adversary.parameters()] assert len(params) == 0 - # Adversarial perturbation should not be saved to the model checkpoint. + # Adversarial perturbation should not have any state dict items state_dict = adversary.state_dict() - assert "perturber.perturbation" not in state_dict + assert len(state_dict) == 0 -def test_adversary_perturbation(input_data, target_data): - composer = mart.attack.composer.Additive() +def test_adversary_perturbation(input_data, target_data, perturbation): enforcer = Mock() - optimizer = partial(SGD, lr=1.0, maximize=True) - - def gain(logits): - return logits.mean() - - # Perturbation initialized as zero. 
-    def initializer(x):
-        torch.nn.init.constant_(x, 0)
-
-    perturber = Perturber(initializer)
+    perturber = Mock(return_value=input_data + perturbation)
+    attacker = Mock(max_epochs=0, limit_train_batches=1, fit_loop=Mock(max_epochs=0))
 
     adversary = Adversary(
-        composer=composer,
         enforcer=enforcer,
         perturber=perturber,
-        optimizer=optimizer,
-        max_iters=1,
-        gain=gain,
+        attacker=attacker,
     )
 
-    def model(input, target, model=None, **kwargs):
-        return {"logits": adversary(input, target)}
+    _ = adversary(input=input_data, target=target_data, model=Mock(), sequence=None)
+    output_data = adversary(input=input_data, target=target_data)
 
-    output1 = adversary(input_data.requires_grad_(), target_data, model=model)
-    pert1 = perturber.perturbation.clone()
-    output2 = adversary(input_data.requires_grad_(), target_data, model=model)
-    pert2 = perturber.perturbation.clone()
+    # The enforcer is only called when model is not None.
+    enforcer.assert_called_once()
+    attacker.fit.assert_called_once()
 
-    # The perturbation from multiple runs should be the same.
-    torch.testing.assert_close(pert1, pert2)
+    # Once with model and sequence and once without
+    assert perturber.call_count == 2
 
-    # Simulate a new batch of data of different size.
-    new_input_data = torch.cat([input_data, input_data])
-    output3 = adversary(new_input_data, target_data, model=model)
+    torch.testing.assert_close(output_data, input_data + perturbation)
diff --git a/tests/test_batch.py b/tests/test_batch.py
deleted file mode 100644
index d259fb82..00000000
--- a/tests/test_batch.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#
-# Copyright (C) 2022 Intel Corporation
-#
-# SPDX-License-Identifier: BSD-3-Clause
-#
-
-from unittest.mock import Mock, patch
-
-import pytest
-import torch
-
-from mart.attack.perturber import BatchPerturber, Perturber
-
-
-@pytest.fixture(scope="function")
-def perturber_batch():
-    # function to mock perturbation
-    def perturbation(input, target):
-        return input + torch.ones(*input.shape)
-
-    # setup batch mock
-    perturber = Mock(name="perturber_mock", spec=Perturber, side_effect=perturbation)
-    perturber_factory = Mock(return_value=perturber)
-
-    batch = BatchPerturber(perturber_factory)
-
-    return batch
-
-
-@pytest.fixture(scope="function")
-def input_data_batch():
-    batch_size = 2
-    image_size = (3, 32, 32)
-
-    input_data = {}
-    input_data["image_batch"] = torch.zeros(batch_size, *image_size)
-    input_data["image_batch_list"] = [torch.zeros(*image_size) for _ in range(batch_size)]
-    input_data["target"] = {"perturbable_mask": torch.ones(*image_size)}
-
-    return input_data
-
-
-def test_batch_run_start(perturber_batch, input_data_batch):
-    assert isinstance(perturber_batch, BatchPerturber)
-
-    # start perturber batch
-    adversary = Mock()
-    model = Mock()
-    perturber_batch.on_run_start(
-        adversary, input_data_batch["image_batch"], input_data_batch["target"], model
-    )
-
-    batch_size, _, _, _ = input_data_batch["image_batch"].shape
-    assert len(perturber_batch.perturbers) == batch_size
-
-
-def test_batch_forward(perturber_batch, input_data_batch):
-    assert isinstance(perturber_batch, BatchPerturber)
-
-    # start perturber batch
-    adversary = Mock()
-    model = Mock()
-    perturber_batch.on_run_start(
-        adversary, input_data_batch["image_batch"], input_data_batch["target"], model
-    )
-
-    perturbed_images = perturber_batch(input_data_batch["image_batch"], input_data_batch["target"])
-    expected = torch.ones(*perturbed_images.shape)
-    torch.testing.assert_close(perturbed_images, expected)
-
-
-def test_tuple_batch_forward(perturber_batch, input_data_batch):
-    assert isinstance(perturber_batch, BatchPerturber)
-
-    # start perturber batch
-    adversary = Mock()
-    model = Mock()
-    perturber_batch.on_run_start(
-        adversary, input_data_batch["image_batch_list"], input_data_batch["target"], model
-    )
-
-    perturbed_images = perturber_batch(
-        input_data_batch["image_batch_list"], input_data_batch["target"]
-    )
-    expected = [
-        torch.ones(*input_data_batch["image_batch_list"][0].shape)
-        for _ in range(len(input_data_batch["image_batch_list"]))
-    ]
-
-    for output, expected_output in zip(expected, perturbed_images):
-        torch.testing.assert_close(output, expected_output)
diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index e71023d4..a4ad49ee 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -11,15 +11,23 @@
 
 def test_gradient_sign(input_data):
-    gradient = Sign()
-    output = gradient(input_data)
-    expected_output = input_data.sign()
-    torch.testing.assert_close(output, expected_output)
+    # Don't share input_data with other tests, because the gradient would be changed.
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
 
+    grad_modifier = Sign()
+    grad_modifier(input_data)
+    expected_grad = torch.tensor([-1.0, 1.0, 0.0])
+    torch.testing.assert_close(input_data.grad, expected_grad)
+
+
+def test_gradient_lp_normalizer():
+    # Don't share input_data with other tests, because the gradient would be changed.
+    input_data = torch.tensor([1.0, 2.0, 3.0])
+    input_data.grad = torch.tensor([-1.0, 3.0, 0.0])
 
-def test_gradient_lp_normalizer(input_data):
     p = 1
-    gradient = LpNormalizer(p)
-    output = gradient(input_data)
-    expected_output = input_data / input_data.norm(p=p)
-    torch.testing.assert_close(output, expected_output)
+    grad_modifier = LpNormalizer(p)
+    grad_modifier(input_data)
+    expected_grad = torch.tensor([-0.25, 0.75, 0.0])
+    torch.testing.assert_close(input_data.grad, expected_grad)
diff --git a/tests/test_perturber.py b/tests/test_perturber.py
index a3d2d196..96e8954b 100644
--- a/tests/test_perturber.py
+++ b/tests/test_perturber.py
@@ -4,42 +4,293 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 
-import importlib
-from unittest.mock import Mock, patch
+from functools import partial
+from typing import Iterable
+from unittest.mock import Mock
 
 import pytest
 import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+import mart
+from mart.attack.initializer import Constant
 from mart.attack.perturber import Perturber
 
 
-def test_perturber_repr(input_data, target_data):
+def test_configure_perturbation(input_data):
     initializer = Mock()
-    gradient_modifier = Mock()
     projector = Mock()
-    perturber = Perturber(initializer, gradient_modifier, projector)
+    composer = Mock()
+    gain = Mock()
 
-    # get additive perturber representation
-    perturbation = torch.nn.UninitializedBuffer()
-    expected_repr = (
-        f"{repr(perturbation)}, initializer={initializer},"
-        f"gradient_modifier={gradient_modifier}, projector={projector}"
+    perturber = Perturber(
+        initializer=initializer, optimizer=None, composer=composer, projector=projector, gain=gain
     )
-    representation = perturber.extra_repr()
-    assert expected_repr == representation
 
-    # generate again the perturber with an initialized
-    # perturbation
-    perturber.on_run_start(adversary=None, input=input_data, target=target_data, model=None)
-    representation = perturber.extra_repr()
-    assert expected_repr != representation
+    perturber.configure_perturbation(input_data)
+
+    initializer.assert_called_once()
+    projector.assert_not_called()
+    composer.assert_not_called()
+    gain.assert_not_called()
 
 
-def test_perturber_forward(input_data, target_data):
-    initializer = Mock()
-    perturber = Perturber(initializer)
-
-    perturber.on_run_start(adversary=None, input=input_data, target=target_data, model=None)
-    output = perturber(input_data, target_data)
-    expected_output = perturber.perturbation
-    torch.testing.assert_close(output, expected_output, equal_nan=True)
+def test_forward(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock()
+
+    perturber = Perturber(
+        initializer=initializer, optimizer=None, composer=composer, projector=projector, gain=gain
+    )
+
+    perturber.configure_perturbation(input_data)
+
+    for _ in range(2):
+        output_data = perturber(input=input_data, target=target_data)
+
+        torch.testing.assert_close(output_data, input_data + 1337)
+
+    # perturber needs to project and compose perturbation on every call
+    assert projector.call_count == 2
+    gain.assert_not_called()
+
+
+def test_forward_fails(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock()
+
+    perturber = Perturber(
+        initializer=initializer, optimizer=None, composer=composer, projector=projector, gain=gain
+    )
+
+    with pytest.raises(MisconfigurationException):
+        output_data = perturber(input=input_data, target=target_data)
+
+
+def test_configure_optimizers(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    optimizer = Mock()
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock()
+
+    perturber = Perturber(
+        initializer=initializer,
+        optimizer=optimizer,
+        composer=composer,
+        projector=projector,
+        gain=gain,
+    )
+
+    perturber.configure_perturbation(input_data)
+
+    for _ in range(2):
+        perturber.configure_optimizers()
+        perturber(input=input_data, target=target_data)
+
+    assert optimizer.call_count == 2
+    assert projector.call_count == 2
+    gain.assert_not_called()
+
+
+def test_configure_optimizers_fails():
+    initializer = mart.attack.initializer.Constant(1337)
+    optimizer = Mock()
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock()
+
+    perturber = Perturber(
+        initializer=initializer,
+        optimizer=optimizer,
+        composer=composer,
+        projector=projector,
+        gain=gain,
+    )
+
+    with pytest.raises(MisconfigurationException):
+        perturber.configure_optimizers()
+
+
+def test_optimizer_parameters_with_gradient(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    optimizer = partial(torch.optim.SGD, lr=0)
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock()
+
+    perturber = Perturber(
+        initializer=initializer,
+        optimizer=optimizer,
+        composer=composer,
+        projector=projector,
+        gain=gain,
+    )
+
+    perturber.configure_perturbation(input_data)
+    opt = perturber.configure_optimizers()
+
+    # Make sure each parameter in optimizer requires a gradient
+    for param_group in opt.param_groups:
+        for param in param_group["params"]:
+            assert param.requires_grad
+
+
+def test_training_step(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    optimizer = Mock()
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock(return_value=torch.tensor(1337))
+    model = Mock(return_value={})
+
+    perturber = Perturber(
+        initializer=initializer,
+        optimizer=optimizer,
+        composer=composer,
+        projector=projector,
+        gain=gain,
+    )
+
+    output = perturber.training_step(
+        {"input": input_data, "target": target_data, "model": model}, 0
+    )
+
+    gain.assert_called_once()
+    assert output == 1337
+
+
+def test_training_step_with_many_gain(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    optimizer = Mock()
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock(return_value=torch.tensor([1234, 5678]))
+    model = Mock(return_value={})
+
+    perturber = Perturber(
+        initializer=initializer,
+        optimizer=optimizer,
+        composer=composer,
+        projector=projector,
+        gain=gain,
+    )
+
+    output = perturber.training_step(
+        {"input": input_data, "target": target_data, "model": model}, 0
+    )
+
+    assert output == 1234 + 5678
+
+
+def test_training_step_with_objective(input_data, target_data):
+    initializer = mart.attack.initializer.Constant(1337)
+    optimizer = Mock()
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    gain = Mock(return_value=torch.tensor([1234, 5678]))
+    model = Mock(return_value={})
+    objective = Mock(return_value=torch.tensor([True, False], dtype=torch.bool))
+
+    perturber = Perturber(
+        initializer=initializer,
+        optimizer=optimizer,
+        composer=composer,
+        projector=projector,
+        objective=objective,
+        gain=gain,
+    )
+
+    output = perturber.training_step(
+        {"input": input_data, "target": target_data, "model": model}, 0
+    )
+
+    assert output == 5678
+
+    objective.assert_called_once()
+
+
+def test_configure_gradient_clipping():
+    initializer = mart.attack.initializer.Constant(1337)
+    projector = Mock()
+    composer = mart.attack.composer.Additive()
+    optimizer = Mock(param_groups=[{"params": Mock()}, {"params": Mock()}])
+    gradient_modifier = Mock()
+    gain = Mock()
+
+    perturber = Perturber(
+        optimizer=optimizer,
+        gradient_modifier=gradient_modifier,
+        initializer=None,
+        composer=None,
+        projector=None,
+        gain=gain,
+    )
+    # We need to mock a trainer since LightningModule does some checks
+    perturber.trainer = Mock(gradient_clip_val=1.0, gradient_clip_algorithm="norm")
+
+    perturber.configure_gradient_clipping(optimizer, 0)
+
+    # Once for each parameter in the optimizer
+    assert gradient_modifier.call_count == 2
+
+
+def test_perturbation_tensor_to_param_groups():
+    input_data = torch.tensor([1.0, 2.0])
+    initializer = Constant(constant=0)
+
+    perturber = Perturber(initializer=initializer, optimizer=Mock(), composer=Mock(), gain=Mock())
+
+    perturber.configure_perturbation(input_data)
+    pert = perturber.perturbation
+    assert isinstance(pert, torch.Tensor)
+    assert pert.shape == input_data.shape
+    assert (pert == 0).all()
+
+    param_groups = perturber.parameter_groups()
+    assert isinstance(param_groups, Iterable)
+    assert param_groups[0]["params"].requires_grad
+
+
+def test_perturbation_dict_to_param_groups():
+    input_data = {"rgb": torch.tensor([1.0, 2.0]), "depth": torch.tensor([1.0, 2.0])}
+    initializer = {"rgb": Constant(constant=0), "depth": Constant(constant=1)}
+    perturber = Perturber(initializer=initializer, optimizer=Mock(), composer=Mock(), gain=Mock())
+
+    perturber.configure_perturbation(input_data)
+    pert = perturber.perturbation
+    assert isinstance(pert, dict)
+    assert (pert["rgb"] == 0).all()
+    assert (pert["depth"] == 1).all()
+
+    param_groups = perturber.parameter_groups()
+    assert len(param_groups) == 2
+    param_groups = list(param_groups)
+    assert param_groups[0]["params"].requires_grad
+    # assert (param_groups[0]["params"] == 0).all()
+
+
+def test_perturbation_tuple_dict_to_param_groups():
+    input_data = (
+        {"rgb": torch.tensor([1.0, 2.0]), "depth": torch.tensor([3.0, 4.0])},
+        {"rgb": torch.tensor([-1.0, -2.0]), "depth": torch.tensor([-3.0, -4.0])},
+    )
+    initializer = {"rgb": Constant(constant=0), "depth": Constant(constant=1)}
+    perturber = Perturber(initializer=initializer, optimizer=Mock(), composer=Mock(), gain=Mock())
+
+    perturber.configure_perturbation(input_data)
+    pert = perturber.perturbation
+    assert isinstance(pert, tuple)
+    assert (pert[0]["rgb"] == 0).all()
+    assert (pert[0]["depth"] == 1).all()
+    assert (pert[1]["rgb"] == 0).all()
+    assert (pert[1]["depth"] == 1).all()
+
+    param_groups = perturber.parameter_groups()
+    assert len(param_groups) == 4
+    param_groups = list(param_groups)
+    assert param_groups[0]["params"].requires_grad
diff --git a/tests/test_projector.py b/tests/test_projector.py
index 9983fa4a..a397a98c 100644
--- a/tests/test_projector.py
+++ b/tests/test_projector.py
@@ -36,7 +36,7 @@ def test_range_projector_repr():
 @pytest.mark.parametrize("max", [10, 100, 110])
 def test_range_projector(quantize, min, max, input_data, target_data, perturbation):
     projector = Range(quantize, min, max)
-    projector(perturbation, input_data, target_data)
+    projector(perturbation, input=input_data, target=target_data)
 
     assert torch.max(perturbation) <= max
     assert torch.min(perturbation) >= min
@@ -61,7 +61,7 @@ def test_range_additive_projector(quantize, min, max, input_data, target_data, p
     expected_perturbation = torch.clone(perturbation)
 
     projector = RangeAdditive(quantize, min, max)
-    projector(perturbation, input_data, target_data)
+    projector(perturbation, input=input_data, target=target_data)
 
     # modify expected_perturbation
     if quantize:
@@ -78,7 +78,7 @@ def test_lp_projector(eps, p, input_data, target_data, perturbation):
     expected_perturbation = torch.clone(perturbation)
 
     projector = Lp(eps, p)
-    projector(perturbation, input_data, target_data)
+    projector(perturbation, input=input_data, target=target_data)
 
     # modify expected_perturbation
     pert_norm = expected_perturbation.norm(p=p)
@@ -95,7 +95,7 @@ def test_linf_additive_range_projector(min, max, eps, input_data, target_data, p
     expected_perturbation = torch.clone(perturbation)
 
     projector = LinfAdditiveRange(eps, min, max)
-    projector(perturbation, input_data, target_data)
+    projector(perturbation, input=input_data, target=target_data)
 
     # get expected result
     eps_min = (input_data - eps).clamp(min, max) - input_data
@@ -117,7 +117,7 @@ def test_mask_projector(input_data, target_data, perturbation):
     expected_perturbation = torch.clone(perturbation)
 
     projector = Mask()
-    projector(perturbation, input_data, target_data)
+    projector(perturbation, input=input_data, target=target_data)
 
     # get expected output
     expected_perturbation.mul_(target_data["perturbable_mask"])
@@ -156,7 +156,7 @@ def test_compose(input_data, target_data):
     compose = Compose(projectors)
     tensor = Mock()
     tensor.norm.return_value = 10
-    compose(tensor, input_data, target_data)
+    compose(tensor, input=input_data, target=target_data)
 
     # RangeProjector, RangeAdditiveProjector, and LinfAdditiveRangeProjector calls `clamp_`
     assert tensor.clamp_.call_count == 3
diff --git a/tests/test_visualizer.py b/tests/test_visualizer.py
index 5a269db2..c4abb4dd 100644
--- a/tests/test_visualizer.py
+++ b/tests/test_visualizer.py
@@ -19,15 +19,19 @@ def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path):
     target_list = [target_data]
 
     # simulate an addition perturbation
-    def perturb(input, target, model):
+    def perturb(input):
         result = [sample + perturbation for sample in input]
         return result
 
-    model = Mock()
+    trainer = Mock()
+    model = Mock(return_value=perturb(input_list))
+    outputs = Mock()
+    batch = {"input": input_list, "target": target_list}
     adversary = Mock(spec=Adversary, side_effect=perturb)
 
     visualizer = PerturbedImageVisualizer(folder)
-    visualizer.on_run_end(adversary=adversary, input=input_list, target=target_list, model=model)
+    visualizer.on_train_batch_end(trainer, model, outputs, batch, 0)
+    visualizer.on_train_end(trainer, model)
 
     # verify that the visualizer created the JPG file
     expected_output_path = folder / target_data["file_name"]
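
For reference, a minimal usage sketch of the new `silent` helper added in `mart/utils/silent.py`. The `Trainer` arguments and the `fit_attack` helper are illustrative assumptions rather than part of this change:

```python
import pytorch_lightning as pl

from mart.utils import silent

# As a context manager: suppress the default Lightning loggers
# ("pytorch_lightning.utilities.rank_zero" and
# "pytorch_lightning.accelerators.gpu") while constructing a Trainer.
with silent():
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=1, enable_progress_bar=False)


# As a decorator: silent subclasses contextlib.ContextDecorator, so the same
# suppression can wrap a whole helper. `fit_attack` is a hypothetical name.
@silent()
def fit_attack(trainer, perturber, dataloader):
    # Logging propagation is disabled on entry and restored on exit.
    trainer.fit(perturber, train_dataloaders=dataloader)
```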