diff --git a/mart/attack/__init__.py b/mart/attack/__init__.py
index f86d5ad8..1cc0c32b 100644
--- a/mart/attack/__init__.py
+++ b/mart/attack/__init__.py
@@ -1,6 +1,6 @@
 from .adversary import *
 from .adversary_in_art import *
-from .adversary_wrapper import *
+from .attacker_wrapper import *
 from .callbacks import Callback
 from .composer import *
 from .enforcer import *
diff --git a/mart/attack/adversary_wrapper.py b/mart/attack/attacker_wrapper.py
similarity index 52%
rename from mart/attack/adversary_wrapper.py
rename to mart/attack/attacker_wrapper.py
index c4b02953..7ff0df14 100644
--- a/mart/attack/adversary_wrapper.py
+++ b/mart/attack/attacker_wrapper.py
@@ -10,54 +10,56 @@

 import torch

-__all__ = ["NormalizedAdversaryAdapter"]
+__all__ = ["NormalizedAttackerAdapter"]


-class NormalizedAdversaryAdapter(torch.nn.Module):
+class NormalizedAttackerAdapter(torch.nn.Module):
     """A wrapper for running external classification adversaries in MART.

-    External adversaries commonly take input of NCWH-[0,1] and return input_adv in the same format.
+    External attack algorithms commonly take input of NCHW-[0,1] and return input_adv in the same
+    format.
     """

     def __init__(
         self,
-        adversary: Callable[[Callable], Callable],
-        enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None],
+        attacker: Callable[[Callable], Callable],
     ):
         """
         Args:
-            adversary (functools.partial): A partial of an adversary object which awaits model.
-            enforcer (Callable): Enforcing constraints of an adversary.
+            attacker (functools.partial): A partial of an attacker object that awaits a model.
         """
         super().__init__()
-        self.adversary = adversary
-        self.enforcer = enforcer
+        self.attacker = attacker
+        self.input_adv = None

     def forward(
         self,
+        *,
         input: torch.Tensor | tuple,
         target: torch.Tensor | dict[str, Any] | tuple,
-        model: torch.nn.Module | None = None,
-        **kwargs,
     ):
-
-        # Shortcut. Input is already updated in the attack loop.
-        if model is None:
+        # Return the cached adversarial input if fit() has already run in the attack loop.
+        if self.input_adv is None:
             return input
+        else:
+            return self.input_adv

+    def fit(self, *, input, target, model, **kwargs):
         # Input NCHW [0,1]; Output logits.
         def model_wrapper(x):
             output = model(input=x * 255, target=target, model=None, **kwargs)
             logits = output["logits"]
             return logits

-        attack = self.adversary(model_wrapper)
+        attack = self.attacker(model_wrapper)
         input_adv = attack(input / 255, target)

         # Round to integer, in case of imprecise scaling.
         input_adv = (input_adv * 255).round()

-        self.enforcer(input_adv, input=input, target=target)
+
+        # Save to return later in forward().
+        self.input_adv = input_adv

         return input_adv
diff --git a/mart/configs/attack/classification_autoattack.yaml b/mart/configs/attack/classification_autoattack.yaml
index bae6a815..e6f77c5a 100644
--- a/mart/configs/attack/classification_autoattack.yaml
+++ b/mart/configs/attack/classification_autoattack.yaml
@@ -2,26 +2,30 @@ defaults:
   - enforcer: default
   - enforcer/constraints: [lp, pixel_range]

-_target_: mart.attack.NormalizedAdversaryAdapter
-adversary:
-  _target_: mart.utils.adapters.PartialInstanceWrapper
-  partial:
-    _target_: autoattack.AutoAttack
-    _partial_: true
-    # AutoAttack needs to specify device for PyTorch tensors: cpu/cuda
-    # We can not use ${trainer.accelerator} because the vocabulary is different: cpu/gpu
-    # device: cpu
-    norm: Linf
-    # 8/255
-    eps: 0.03137254901960784
-    version: custom
-    attacks_to_run:
-      - apgd-dlr
-  wrapper:
-    _target_: mart.utils.adapters.CallableAdapter
-    _partial_: true
-    redirecting_fn: run_standard_evaluation
+_target_: mart.attack.Adversary

 enforcer:
   constraints:
     lp:
       eps: 8
+
+attacker:
+  _target_: mart.attack.NormalizedAttackerAdapter
+  attacker:
+    _target_: mart.utils.adapters.PartialInstanceWrapper
+    partial:
+      _target_: autoattack.AutoAttack
+      _partial_: true
+      # AutoAttack needs to specify device for PyTorch tensors: cpu/cuda
+      # We cannot use ${trainer.accelerator} because the vocabulary is different: cpu/gpu
+      # device: cpu
+      norm: Linf
+      # 8/255
+      eps: 0.03137254901960784
+      version: custom
+      attacks_to_run:
+        - apgd-dlr
+    wrapper:
+      _target_: mart.utils.adapters.CallableAdapter
+      _partial_: true
+      redirecting_fn: run_standard_evaluation
diff --git a/tests/test_experiments.py b/tests/test_experiments.py
index d128c1df..3451f2c3 100644
--- a/tests/test_experiments.py
+++ b/tests/test_experiments.py
@@ -158,7 +158,7 @@ def test_cifar10_cnn_autoattack_experiment(classification_cfg, tmp_path):
         "++datamodule.train_dataset.num_classes=10",
         "fit=false",
         "+attack@model.modules.input_adv_test=classification_autoattack",
-        '+model.modules.input_adv_test.adversary.partial.device="cpu"',
+        '+model.modules.input_adv_test.attacker.attacker.partial.device="cpu"',
         "+trainer.limit_test_batches=1",
     ] + overrides
     run_sh_command(command)
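
Usage note: a minimal sketch of the two-phase protocol this diff introduces, where fit() runs the external attack once and caches the result, and forward() replays it. NormalizedAttackerAdapter and the fit()/forward() signatures are taken from the diff; DummyModel and toy_attacker are hypothetical stand-ins for a real MART model and a real attacker (e.g. AutoAttack wrapped via PartialInstanceWrapper/CallableAdapter as in the YAML config above).

    import torch

    from mart.attack import NormalizedAttackerAdapter


    class DummyModel(torch.nn.Module):
        """Stand-in for a MART model module: returns a dict with a "logits" key."""

        def forward(self, input, target=None, model=None, **kwargs):
            return {"logits": torch.randn(input.shape[0], 10)}


    def toy_attacker(model_fn):
        # model_fn maps NCHW-[0,1] inputs to logits; a real attacker would
        # query it repeatedly while searching for a perturbation.
        def attack(x, y):
            return (x + 8 / 255).clamp(0, 1)  # placeholder perturbation
        return attack


    adapter = NormalizedAttackerAdapter(attacker=toy_attacker)

    x = torch.randint(0, 256, (2, 3, 32, 32)).float()  # NCHW batch in [0, 255]
    y = torch.tensor([3, 7])

    # Before fit() runs, forward() passes the input through unchanged.
    assert torch.equal(adapter(input=x, target=y), x)

    # The attack loop calls fit() with the model; the adapter rescales to
    # [0,1], runs the external attack, rescales back to [0, 255], rounds,
    # and caches the result in self.input_adv.
    adapter.fit(input=x, target=y, model=DummyModel())

    # Subsequent forward() calls return the cached adversarial input.
    x_adv = adapter(input=x, target=y)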