5 changes: 4 additions & 1 deletion mart/attack/perturber.py
@@ -116,4 +116,7 @@ def forward(self, **batch):
             )
         )
 
-        return {"input_adv": input_adv, "total_variation": total_variation}
+        targets = batch["target"]["target"]
+        targets = torch.zeros_like(targets)
+
+        return {"input_adv": input_adv, "total_variation": total_variation, "targets": targets}
5 changes: 3 additions & 2 deletions mart/callbacks/visualizer.py
@@ -42,12 +42,13 @@ def on_train_end(self, trainer, model):


 class PerturbationVisualizer(Callback):
-    def __init__(self, frequency: int = 100):
+    def __init__(self, frequency: int = 100, pixel_scale: float = 1.0):
         self.frequency = frequency
+        self.pixel_scale = pixel_scale
 
     def log_perturbation(self, trainer, pl_module):
         # FIXME: Generalize this by using DotDict?
-        perturbation = pl_module.model.perturber.perturbation
+        perturbation = pl_module.model.perturber.perturbation / self.pixel_scale
 
         # Add image to each logger
         for logger in trainer.loggers:
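Image loggers generally expect float pixels in [0, 1]; since the ShapeShifter patch below is optimized in [0, 255] pixel space, pixel_scale rescales it before logging. A hypothetical instantiation matching the experiment config later in this diff:

from mart.callbacks.visualizer import PerturbationVisualizer

visualizer = PerturbationVisualizer(frequency=500, pixel_scale=255)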
36 changes: 36 additions & 0 deletions mart/configs/datamodule/coco_yolox.yaml
@@ -0,0 +1,36 @@
defaults:
  - default.yaml

num_workers: 2

train_dataset:
  _target_: yolox.data.datasets.coco.COCODataset
  data_dir: ${paths.data_dir}/coco
  name: train2017
  json_file: instances_train2017.json
  img_size: ???
  preproc:
    _target_: yolox.data.TrainTransform
    max_labels: 50
    flip_prob: 0.5
    hsv_prob: 1.0

val_dataset:
  _target_: yolox.data.datasets.coco.COCODataset
  data_dir: ${paths.data_dir}/coco
  name: val2017
  json_file: instances_val2017.json
  img_size: ${..train_dataset.img_size}
  preproc:
    # Use TrainTransform instead of ValTransform, since TrainTransform supplies
    # targets we can measure against.
    _target_: yolox.data.TrainTransform
    max_labels: 100
    flip_prob: 0.
    hsv_prob: 0.

test_dataset: ${.val_dataset}

collate_fn:
  _target_: hydra.utils.get_method
  path: mart.datamodules.coco.collate_yolox_fn
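Here hydra.utils.get_method resolves the dotted path to the callable itself rather than calling it, so the datamodule receives the collate function directly. A hedged sketch of the resolution step:

from hydra.utils import get_method

collate_fn = get_method("mart.datamodules.coco.collate_yolox_fn")
# The datamodule then hands collate_fn straight to its DataLoader(s).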
34 changes: 34 additions & 0 deletions mart/configs/experiment/COCO_YOLOX.yaml
@@ -0,0 +1,34 @@
# @package _global_

defaults:
  - override /datamodule: coco_yolox
  - override /model: yolox
  - override /metric: average_precision
  - override /optimization: super_convergence

task_name: "COCO_YOLOX"
tags: ["evaluation"]

optimized_metric: "test_metrics/map"

trainer:
  # 117,266 training images, 6 epochs, batch_size=16: ceil(117266 * 6 / 16) = ceil(43974.75)
  max_steps: 43975
  # FIXME: "nms_kernel" not implemented for 'BFloat16'; see torch.ops.torchvision.nms().
  precision: 32

datamodule:
  num_workers: 32
  ims_per_batch: 16

  train_dataset:
    img_size: [416, 416]

model:
  # The YOLOX model does not produce preds/targets in the training sequence.
  training_metrics: null

  optimizer:
    lr: 0.001
    momentum: 0.9
    weight_decay: 0.0005
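A quick sanity check on the step budget (illustrative arithmetic only):

import math

images, epochs, batch_size = 117_266, 6, 16
max_steps = math.ceil(images * epochs / batch_size)  # 43974.75 -> 43975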
118 changes: 118 additions & 0 deletions mart/configs/experiment/COCO_YOLOX_ShapeShifter.yaml
@@ -0,0 +1,118 @@
# @package _global_

defaults:
  - /attack/perturber@model.modules.perturber: default
  - /attack/perturber/initializer@model.modules.perturber.initializer: uniform
  - /attack/perturber/composer@model.modules.perturber.composer: color_jitter_warp_overlay
  - /attack/perturber/projector@model.modules.perturber.projector: range
  - /attack/gradient_modifier@model.gradient_modifier: lp_normalizer
  - override /optimization: super_convergence
  - override /datamodule: coco_yolox
  - override /model: yolox
  - override /metric: average_precision
  - override /callbacks: [perturbation_visualizer, lr_monitor, gradient_monitor]

task_name: "COCO_YOLOX_ShapeShifter"
tags: ["adv"]

optimized_metric: "test_metrics/map"

trainer:
  # 118,287 training images, batch_size=16: floor(118287 / 16) = 7,392 steps per epoch
  max_steps: 73920 # 10 epochs
  # mAP can be slow to compute, so limit the number of images
  limit_val_batches: 100
  limit_test_batches: 100
  precision: 32

callbacks:
  perturbation_visualizer:
    frequency: 500
    pixel_scale: 255

datamodule:
  num_workers: 32
  ims_per_batch: 16

  train_dataset:
    img_size: [416, 416]

model:
  modules:
    perturber:
      size: [3, 416, 234]

      initializer:
        min: 127
        max: 129

      composer:
        warp:
          _target_: torchvision.transforms.Compose
          transforms:
            - _target_: torchvision.transforms.RandomErasing
              p: 0.75
              scale: [0.2, 0.7]
              ratio: [0.3, 3.3]
            - _target_: torchvision.transforms.RandomAffine
              degrees: [-5, 5]
              scale: [0.3, 0.5]
              shear: [-3, 3, -3, 3]
              interpolation: 2 # BILINEAR
        clamp: [0, 255]
        brightness: [0.5, 1.5]
        contrast: [0.5, 1.5]
        saturation: [0.5, 1.5]
        hue: [-0.05, 0.05]
        pixel_scale: 255

    loss:
      weights: [1, 0.00001] # minimize total_loss and total_variation

  freeze: "losses_or_predictions"

  optimizer:
    lr: 25.5
    momentum: 0.9

  gradient_modifier: null

  training_metrics: null

  training_sequence:
    seq005: perturber
    seq010:
      losses_or_predictions:
        x: perturber.input_adv
        targets: perturber.targets
    seq020:
      loss:
        - losses_or_predictions.total_loss
        - perturber.total_variation

    seq030:
      output:
        total_variation: perturber.total_variation

  training_step_log:
    - loss
    - total_loss
    - iou_loss
    - l1_loss
    - conf_loss
    - cls_loss
    - num_fg
    - total_variation

  validation_sequence:
    seq005: perturber
    seq010:
      losses_or_predictions:
        x: perturber.input_adv
        targets: perturber.targets

  test_sequence:
    seq005: perturber
    seq010:
      losses_or_predictions:
        x: perturber.input_adv
        targets: perturber.targets
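Note the unusually large learning rate: the patch is optimized directly in [0, 255] pixel space, so hyperparameters scale by 255 relative to a [0, 1] parameterization. An illustrative conversion (an interpretation, not stated in the diff):

lr_pixel = 25.5
lr_unit = lr_pixel / 255   # 0.1 in [0, 1] image space
init_range = (127, 129)    # i.e., start near mid-gray, 128/255 ~= 0.5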
84 changes: 84 additions & 0 deletions mart/configs/model/yolox.yaml
@@ -0,0 +1,84 @@
defaults:
  - modular

modules:
  losses_or_predictions:
    _target_: yolox.models.build.create_yolox_model
    name: "yolov3"
    pretrained: True
    num_classes: 80
    device: "cpu" # PL handles this

  loss:
    _target_: mart.nn.Sum

  detections:
    _target_: mart.models.yolox.Detections
    conf_thre: 0.1
    nms_thre: 0.4

  output:
    _target_: mart.nn.ReturnKwargs

training_metrics: null

training_sequence:
  seq010:
    losses_or_predictions:
      x: input
      targets: target

  seq020:
    loss:
      - losses_or_predictions.total_loss

  seq030:
    output:
      loss: loss
      total_loss: losses_or_predictions.total_loss
      iou_loss: losses_or_predictions.iou_loss
      l1_loss: losses_or_predictions.l1_loss
      conf_loss: losses_or_predictions.conf_loss
      cls_loss: losses_or_predictions.cls_loss
      num_fg: losses_or_predictions.num_fg

validation_sequence:
  seq010:
    losses_or_predictions:
      x: input
      targets: target.target

  seq020:
    detections:
      predictions: losses_or_predictions
      targets: target.target

  seq030:
    output:
      preds: detections.preds
      target: detections.target

test_sequence:
  seq010:
    losses_or_predictions:
      x: input
      targets: target.target

  seq020:
    detections:
      predictions: losses_or_predictions
      targets: target.target

  seq030:
    output:
      preds: detections.preds
      target: detections.target

training_step_log:
  - loss
  - total_loss
  - iou_loss
  - l1_loss
  - conf_loss
  - cls_loss
  - num_fg
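For reference, the losses_or_predictions module resolves to YOLOX's model builder; the direct call below simply mirrors the _target_ block above (Hydra performs the equivalent at instantiation time):

from yolox.models.build import create_yolox_model

# Same arguments as the config; PL moves the model to the right device afterwards.
model = create_yolox_model(name="yolov3", pretrained=True, num_classes=80, device="cpu")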
7 changes: 7 additions & 0 deletions mart/datamodules/coco.py
@@ -8,6 +8,7 @@
 from typing import Any, Callable, List, Optional
 
 import numpy as np
+from torch.utils.data import default_collate
 from torchvision.datasets.coco import CocoDetection as CocoDetection_
 from torchvision.datasets.folder import default_loader
 
@@ -89,3 +90,9 @@ def __getitem__(self, index: int):
 # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203
 def collate_fn(batch):
     return tuple(zip(*batch))
+
+
+def collate_yolox_fn(batch):
+    batch = default_collate(batch)
+    image, target, *_ = batch
+    return image, {"target": target}
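A quick illustration of the new collate function, under the assumption that YOLOX's COCODataset yields (img, target, img_info, img_id) tuples; the tensor shapes are invented for the example:

import torch
from mart.datamodules.coco import collate_yolox_fn

samples = [
    (torch.rand(3, 416, 416), torch.zeros(50, 5), torch.tensor([416, 416]), torch.tensor(i))
    for i in range(4)
]
images, target_dict = collate_yolox_fn(samples)
images.shape                 # torch.Size([4, 3, 416, 416])
target_dict["target"].shape  # torch.Size([4, 50, 5])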
70 changes: 70 additions & 0 deletions mart/models/yolox.py
@@ -0,0 +1,70 @@
#
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: BSD-3-Clause
#

import torch
from yolox.utils import postprocess


class Detections(torch.nn.Module):
    def __init__(self, num_classes=80, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
        super().__init__()

        self.num_classes = num_classes
        self.conf_thre = conf_thre
        self.nms_thre = nms_thre
        self.class_agnostic = class_agnostic

    @staticmethod
    def cxcywh2xyxy(bboxes):
        bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
        bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
        bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
        bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
        return bboxes

    @staticmethod
    def tensor_to_dict(detection):
        if detection is None:
            # Handle images with no detections
            boxes = torch.empty((0, 4), device="cuda")  # HACK
            labels = torch.empty((0,), device="cuda")  # HACK
            scores = torch.empty((0,), device="cuda")  # HACK

        elif detection.shape[1] > 5:
            # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
            boxes = detection[:, 0:4]
            labels = detection[:, 6].to(int)
            scores = detection[:, 4] * detection[:, 5]

        else:  # targets have no scores
            # [class, xc, yc, w, h]
            boxes = detection[:, 1:5]
            boxes = Detections.cxcywh2xyxy(boxes)
            labels = detection[:, 0].to(int)
            scores = torch.ones_like(labels)

        length = (labels > 0).sum()

        boxes = boxes[:length]
        labels = labels[:length]
        scores = scores[:length]

        return {"boxes": boxes, "labels": labels, "scores": scores}

    def forward(self, predictions, targets):
        detections = postprocess(
            predictions,
            self.num_classes,
            conf_thre=self.conf_thre,
            nms_thre=self.nms_thre,
            class_agnostic=self.class_agnostic,
        )

        # Convert preds and targets to format acceptable to torchmetrics
        preds = [Detections.tensor_to_dict(det) for det in detections]
        targets = [Detections.tensor_to_dict(tar) for tar in targets]

        return {"preds": preds, "target": targets}
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [

     # ----- object detection----- #
     "pycocotools ~= 2.0.5",
+    "yolox @ git+https://github.com/Megvii-BaseDetection/YOLOX.git@0.3.0",
 
     # -------- Adversary ---------#
     "robustbench @ git+https://github.com/RobustBench/robustbench.git@9a590683b7daecf963244dea402529f0d728c727",