From 263adf3d099e94fb89148c0b965de5100bf2b96c Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Mon, 12 Jun 2023 15:56:55 -0700 Subject: [PATCH 01/62] Don't require output module with SequentialDict --- mart/nn/nn.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 93b0f07f..754e8657 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -49,10 +49,6 @@ class SequentialDict(torch.nn.ModuleDict): """ def __init__(self, modules, sequences=None): - - if "output" not in modules: - raise ValueError("Modules must have an module named 'output'") - super().__init__(modules) self._sequences = { @@ -121,7 +117,8 @@ def forward(self, step=None, sequence=None, **kwargs): # Pop the executed module to proceed with the sequence sequence.popitem(last=False) - return kwargs["output"] + # return kwargs as DotDict + return DotDict(kwargs) class ReturnKwargs(torch.nn.Module): From cd5d7897bea4e34385cd566695f72abf2d400fa0 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 11:58:03 -0700 Subject: [PATCH 02/62] CallWith passes non-str arguments directly to module --- mart/nn/nn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 754e8657..a5b52ce7 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -108,6 +108,7 @@ def forward(self, step=None, sequence=None, **kwargs): # Don't pop the first element yet, because it may be used to re-evaluate the model. key, module = next(iter(sequence.items())) + # FIXME: Add better error message output = module(step=step, sequence=sequence, **kwargs) if key in kwargs: @@ -149,14 +150,16 @@ def forward(self, *args, **kwargs): # as it and assume these consume the first len(args) of arg_keys. remaining_arg_keys = arg_keys[len(args) :] - for key in remaining_arg_keys + list(kwarg_keys.values()): + # Check to make sure each str exists within kwargs + str_kwarg_keys = filter(lambda k: isinstance(k, str), kwarg_keys.values()) + for key in remaining_arg_keys + list(str_kwarg_keys): if key not in kwargs: raise Exception( f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." 
) - selected_args = [kwargs[key] for key in arg_keys[len(args) :]] - selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} + selected_args = [kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :]] + selected_kwargs = {key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items()} # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) From 169582a9e8e6552176829192f364c46f802b4c1b Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Mon, 12 Jun 2023 16:05:47 -0700 Subject: [PATCH 03/62] Make *_step_log dicts where the key is the logging name and value is the output key --- mart/models/modular.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index eb5dd934..1fe4be99 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -70,13 +70,13 @@ def __init__( self.lr_scheduler = lr_scheduler - self.training_step_log = training_step_log or ["loss"] + self.training_step_log = training_step_log or {} self.training_metrics = training_metrics - self.validation_step_log = validation_step_log or [] + self.validation_step_log = validation_step_log or {} self.validation_metrics = validation_metrics - self.test_step_log = test_step_log or [] + self.test_step_log = test_step_log or {} self.test_metrics = test_metrics # Load state dict for specified modules. We flatten it because Hydra @@ -115,8 +115,8 @@ def training_step(self, batch, batch_idx): input, target = batch output = self(input=input, target=target, model=self.model, step="training") - for name in self.training_step_log: - self.log(f"training/{name}", output[name]) + for log_name, output_key in self.training_step_log.items(): + self.log(f"training/{log_name}", output[output_key], sync_dist=True) assert "loss" in output return output @@ -149,8 +149,8 @@ def validation_step(self, batch, batch_idx): input, target = batch output = self(input=input, target=target, model=self.model, step="validation") - for name in self.validation_step_log: - self.log(f"validation/{name}", output[name]) + for log_name, output_key in self.validation_step_log.items(): + self.log(f"validation/{log_name}", output[output_key], sync_dist=True) return output @@ -175,8 +175,8 @@ def test_step(self, batch, batch_idx): input, target = batch output = self(input=input, target=target, model=self.model, step="test") - for name in self.test_step_log: - self.log(f"test/{name}", output[name]) + for log_name, output_key in self.test_step_log.items(): + self.log(f"test/{log_name}", output[output_key], sync_dist=True) return output From c37a03a19bab4cbcc5b465695ef8971a82095013 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 12:00:40 -0700 Subject: [PATCH 04/62] Add better example of dual model torchvision detectors --- .../attack/gain/rcnn_training_loss.yaml | 8 +- ...CarlaOverObjDet_TorchvisionFasterRCNN.yaml | 7 +- .../COCO_TorchvisionFasterRCNN.yaml | 8 +- .../experiment/COCO_TorchvisionRetinaNet.yaml | 11 +- .../model/torchvision_faster_rcnn.yaml | 100 ++++++++---------- .../model/torchvision_object_detection.yaml | 15 +-- mart/configs/model/torchvision_retinanet.yaml | 100 +++++++++++------- mart/models/dual_mode.py | 32 +----- mart/nn/__init__.py | 1 + mart/nn/module.py | 39 +++++++ 10 files changed, 170 insertions(+), 151 deletions(-) create mode 100644 mart/nn/module.py diff --git a/mart/configs/attack/gain/rcnn_training_loss.yaml 
b/mart/configs/attack/gain/rcnn_training_loss.yaml index eb7abb9c..59d1f4ee 100644 --- a/mart/configs/attack/gain/rcnn_training_loss.yaml +++ b/mart/configs/attack/gain/rcnn_training_loss.yaml @@ -2,8 +2,8 @@ _target_: mart.nn.CallWith module: _target_: mart.nn.Sum arg_keys: - - rpn_loss.loss_objectness - - rpn_loss.loss_rpn_box_reg - - box_loss.loss_classifier - - box_loss.loss_box_reg + - losses.loss_objectness + - losses.loss_rpn_box_reg + - losses.loss_classifier + - losses.loss_box_reg kwarg_keys: null diff --git a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index 991a87e7..c5d5f75d 100644 --- a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -12,10 +12,9 @@ optimized_metric: "test_metrics/map" model: modules: - losses_and_detections: - model: - num_classes: 3 - weights: null + detector: + num_classes: 3 + weights: null optimizer: lr: 0.0125 diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml index 9259de2c..aa2e9b42 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml @@ -29,11 +29,9 @@ datamodule: model: modules: - losses_and_detections: - model: - # Inferred by torchvision. - num_classes: null - weights: COCO_V1 + detector: + num_classes: null # inferred by torchvision + weights: COCO_V1 optimizer: lr: 0.0125 diff --git a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml index dbd4541f..c336555b 100644 --- a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml +++ b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml @@ -20,7 +20,8 @@ callbacks: trainer: # 117,266 training images, 6 epochs, batch_size=2, 351798 max_steps: 351798 - precision: 16 + # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms(). + precision: 32 datamodule: ims_per_batch: 2 @@ -28,11 +29,9 @@ datamodule: model: modules: - losses_and_detections: - model: - # Inferred by torchvision. - num_classes: null - weights: COCO_V1 + detector: + num_classes: null # inferred by torchvision + weights: COCO_V1 optimizer: lr: 0.0125 diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index c5237184..731f9743 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -2,87 +2,81 @@ defaults: - torchvision_object_detection -# log all losses separately in training. +# log all losses separately in losses. training_step_log: - [ - "rpn_loss.loss_objectness", - "rpn_loss.loss_rpn_box_reg", - "box_loss.loss_classifier", - "box_loss.loss_box_reg", - "loss", - ] + loss_objectness: "losses.loss_objectness" + loss_rpn_box_reg: "losses.loss_rpn_box_reg" + loss_classifier: "losses.loss_classifier" + loss_box_reg: "losses.loss_box_reg" training_sequence: seq010: preprocessor: ["input"] seq020: - losses_and_detections: ["preprocessor", "target"] + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True + + seq040: loss: # Sum up the losses. 
[ - "losses_and_detections.training.loss_objectness", - "losses_and_detections.training.loss_rpn_box_reg", - "losses_and_detections.training.loss_classifier", - "losses_and_detections.training.loss_box_reg", + "losses.loss_objectness", + "losses.loss_rpn_box_reg", + "losses.loss_classifier", + "losses.loss_box_reg", ] - seq040: - output: - # Output all losses for logging, defined in model.training_step_log - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss": "loss", - "rpn_loss.loss_objectness": "losses_and_detections.training.loss_objectness", - "rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", - "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", - "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", - } - validation_sequence: seq010: preprocessor: ["input"] seq020: - losses_and_detections: ["preprocessor", "target"] + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True seq030: - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "rpn_loss.loss_objectness": "losses_and_detections.training.loss_objectness", - "rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", - "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", - "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", - } + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True test_sequence: seq010: preprocessor: ["input"] seq020: - losses_and_detections: ["preprocessor", "target"] + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True seq030: - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "rpn_loss.loss_objectness": "losses_and_detections.training.loss_objectness", - "rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", - "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", - "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", - } + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True modules: - losses_and_detections: - # 17s: DualModeGeneralizedRCNN - # 23s: DualMode - _target_: mart.models.DualModeGeneralizedRCNN - model: - _target_: torchvision.models.detection.fasterrcnn_resnet50_fpn - num_classes: ??? + detector: + path: torchvision.models.detection.fasterrcnn_resnet50_fpn + num_classes: ??? diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index a1495dad..6268dd83 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -3,22 +3,17 @@ defaults: - modular - /model/modules@modules.preprocessor: tuple_normalizer -training_step_log: ??? +training_step_log: + loss: "loss" training_sequence: ??? - validation_sequence: ??? - test_sequence: ??? modules: - losses_and_detections: - # Return losses in the training mode and predictions in the eval mode in one pass. - _target_: mart.models.DualMode - model: ??? + detector: + _target_: mart.nn.Module + path: ??? 
loss: _target_: mart.nn.Sum - - output: - _target_: mart.nn.ReturnKwargs diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 4c45917c..1a978ed1 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -3,52 +3,76 @@ defaults: - torchvision_object_detection # log all losses separately in training. -training_step_log: ["loss_classifier", "loss_box_reg"] +training_step_log: + loss_classifier: "losses.classification" + loss_box_reg: "losses.bbox_regression" training_sequence: - - preprocessor: ["input"] - - losses_and_detections: ["preprocessor", "target"] - - loss: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True + + seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True + + seq040: + loss: # Sum up the losses. [ - "losses_and_detections.training.classification", - "losses_and_detections.training.bbox_regression", + "losses.classification", + "losses.bbox_regression", ] - - output: - # Output all losses for logging, defined in model.training_step_log - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss": "loss", - "loss_classifier": "losses_and_detections.training.classification", - "loss_box_reg": "losses_and_detections.training.bbox_regression", - } validation_sequence: - - preprocessor: ["input"] - - losses_and_detections: ["preprocessor", "target"] - - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss_classifier": "losses_and_detections.training.classification", - "loss_box_reg": "losses_and_detections.training.bbox_regression", - } + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True + + seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True test_sequence: - - preprocessor: ["input"] - - losses_and_detections: ["preprocessor", "target"] - - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss_classifier": "losses_and_detections.training.classification", - "loss_box_reg": "losses_and_detections.training.bbox_regression", - } + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True + + seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True modules: - losses_and_detections: - # _target_: mart.models.DualMode - model: - _target_: torchvision.models.detection.retinanet_resnet50_fpn - num_classes: ??? + detector: + path: torchvision.models.detection.retinanet_resnet50_fpn + num_classes: ??? diff --git a/mart/models/dual_mode.py b/mart/models/dual_mode.py index 5cc780a2..616e4177 100644 --- a/mart/models/dual_mode.py +++ b/mart/models/dual_mode.py @@ -13,37 +13,7 @@ from mart.utils.monkey_patch import MonkeyPatch -__all__ = ["DualMode", "DualModeGeneralizedRCNN"] - - -class DualMode(torch.nn.Module): - """Run model.forward() in both the training mode and the eval mode, then aggregate results in a - dictionary {"training": ..., "eval": ...}. - - Some object detection models are implemented to return losses in the training mode and - predictions in the eval mode, but we want both the losses and the predictions when attacking a - model in the test mode. 
- """ - - def __init__(self, model): - super().__init__() - - self.model = model - - def forward(self, *args, **kwargs): - original_training_status = self.model.training - ret = {} - - # TODO: Reuse the feature map in dual mode to improve efficiency - self.model.train(True) - ret["training"] = self.model(*args, **kwargs) - - self.model.train(False) - with torch.no_grad(): - ret["eval"] = self.model(*args, **kwargs) - - self.model.train(original_training_status) - return ret +__all__ = ["DualModeGeneralizedRCNN"] class DualModeGeneralizedRCNN(torch.nn.Module): diff --git a/mart/nn/__init__.py b/mart/nn/__init__.py index e39c0d57..c257de69 100644 --- a/mart/nn/__init__.py +++ b/mart/nn/__init__.py @@ -1 +1,2 @@ from .nn import * # noqa: F403 +from .module import * # noqa: F403 diff --git a/mart/nn/module.py b/mart/nn/module.py new file mode 100644 index 00000000..41287d17 --- /dev/null +++ b/mart/nn/module.py @@ -0,0 +1,39 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +import torch + +from hydra.utils import instantiate + +__all__ = ["Module"] + + +class Module(torch.nn.Module): + def __init__(self, path: str, *args, **kwargs): + super().__init__() + + # TODO: Add _load_state_dict_ + # TODO: Add _freeze_ + + cfg = {"_target_": path} + self.module = instantiate(cfg, *args, **kwargs) + + def forward( + self, + *args, + train_mode: bool = True, + inference_mode: bool = False, + **kwargs, + ): + old_train_mode = self.module.training + + # FIXME: Would be nice if this was a context... + self.module.train(train_mode) + with torch.inference_mode(mode=inference_mode): + ret = self.module(*args, **kwargs) + self.module.train(old_train_mode) + + return ret From 48e3be991b263c29d9934caeed9e899cf8286bde Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Mon, 12 Jun 2023 16:05:47 -0700 Subject: [PATCH 05/62] Make *_step_log dicts where the key is the logging name and value is the output key --- mart/models/modular.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index eb5dd934..1fe4be99 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -70,13 +70,13 @@ def __init__( self.lr_scheduler = lr_scheduler - self.training_step_log = training_step_log or ["loss"] + self.training_step_log = training_step_log or {} self.training_metrics = training_metrics - self.validation_step_log = validation_step_log or [] + self.validation_step_log = validation_step_log or {} self.validation_metrics = validation_metrics - self.test_step_log = test_step_log or [] + self.test_step_log = test_step_log or {} self.test_metrics = test_metrics # Load state dict for specified modules. 
We flatten it because Hydra @@ -115,8 +115,8 @@ def training_step(self, batch, batch_idx): input, target = batch output = self(input=input, target=target, model=self.model, step="training") - for name in self.training_step_log: - self.log(f"training/{name}", output[name]) + for log_name, output_key in self.training_step_log.items(): + self.log(f"training/{log_name}", output[output_key], sync_dist=True) assert "loss" in output return output @@ -149,8 +149,8 @@ def validation_step(self, batch, batch_idx): input, target = batch output = self(input=input, target=target, model=self.model, step="validation") - for name in self.validation_step_log: - self.log(f"validation/{name}", output[name]) + for log_name, output_key in self.validation_step_log.items(): + self.log(f"validation/{log_name}", output[output_key], sync_dist=True) return output @@ -175,8 +175,8 @@ def test_step(self, batch, batch_idx): input, target = batch output = self(input=input, target=target, model=self.model, step="test") - for name in self.test_step_log: - self.log(f"test/{name}", output[name]) + for log_name, output_key in self.test_step_log.items(): + self.log(f"test/{log_name}", output[output_key], sync_dist=True) return output From 01a20664b2011c20727d1158b1b7476dc9d44779 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 13:57:49 -0700 Subject: [PATCH 06/62] Fix configs --- mart/configs/model/torchvision_faster_rcnn.yaml | 11 ++++------- mart/configs/model/torchvision_object_detection.yaml | 3 ++- mart/configs/model/torchvision_retinanet.yaml | 4 +++- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index c5237184..bc0ce228 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -4,13 +4,10 @@ defaults: # log all losses separately in training. training_step_log: - [ - "rpn_loss.loss_objectness", - "rpn_loss.loss_rpn_box_reg", - "box_loss.loss_classifier", - "box_loss.loss_box_reg", - "loss", - ] + rpn_loss_objectness: "rpn_loss.loss_objectness" + rpn_loss_rpn_box_reg: "rpn_loss.loss_rpn_box_reg" + box_loss_classifier: "box_loss.loss_classifier" + box_loss_box_reg: "box_loss.loss_box_reg" training_sequence: seq010: diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index a1495dad..c81930a8 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -3,7 +3,8 @@ defaults: - modular - /model/modules@modules.preprocessor: tuple_normalizer -training_step_log: ??? +training_step_log: + loss: "loss" training_sequence: ??? diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 4c45917c..695263a2 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -3,7 +3,9 @@ defaults: - torchvision_object_detection # log all losses separately in training. 
-training_step_log: ["loss_classifier", "loss_box_reg"] +training_step_log: + loss_classifier: "loss_classifier" + loss_box_reg: "loss_box_reg" training_sequence: - preprocessor: ["input"] From df1d0b266483ce62fc076573a94449574d082a6c Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 13:57:58 -0700 Subject: [PATCH 07/62] remove sync_dist --- mart/models/modular.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index 1fe4be99..e09f2c9d 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -116,7 +116,7 @@ def training_step(self, batch, batch_idx): output = self(input=input, target=target, model=self.model, step="training") for log_name, output_key in self.training_step_log.items(): - self.log(f"training/{log_name}", output[output_key], sync_dist=True) + self.log(f"training/{log_name}", output[output_key]) assert "loss" in output return output @@ -150,7 +150,7 @@ def validation_step(self, batch, batch_idx): output = self(input=input, target=target, model=self.model, step="validation") for log_name, output_key in self.validation_step_log.items(): - self.log(f"validation/{log_name}", output[output_key], sync_dist=True) + self.log(f"validation/{log_name}", output[output_key]) return output @@ -176,7 +176,7 @@ def test_step(self, batch, batch_idx): output = self(input=input, target=target, model=self.model, step="test") for log_name, output_key in self.test_step_log.items(): - self.log(f"test/{log_name}", output[output_key], sync_dist=True) + self.log(f"test/{log_name}", output[output_key]) return output From 14f4d1fa9f437849655e820f287885864939bf07 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 14:05:53 -0700 Subject: [PATCH 08/62] backwards compatibility --- mart/models/modular.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mart/models/modular.py b/mart/models/modular.py index e09f2c9d..b24ce6ae 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -70,12 +70,21 @@ def __init__( self.lr_scheduler = lr_scheduler + # Be backwards compatible by turning list into dict where each item is its own key-value + if isinstance(training_step_log, (list, tuple)): + training_step_log = { item: item for item in training_step_log} self.training_step_log = training_step_log or {} self.training_metrics = training_metrics + # Be backwards compatible by turning list into dict where each item is its own key-value + if isinstance(validation_step_log, (list, tuple)): + validation_step_log = { item: item for item in validation_step_log} self.validation_step_log = validation_step_log or {} self.validation_metrics = validation_metrics + # Be backwards compatible by turning list into dict where each item is its own key-value + if isinstance(test_step_log, (list, tuple)): + test_step_log = { item: item for item in test_step_log} self.test_step_log = test_step_log or {} self.test_metrics = test_metrics From 2e30587274d99afe8f8dc255da44af68a30a05ef Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 14:06:03 -0700 Subject: [PATCH 09/62] Revert "Fix configs" This reverts commit 01a20664b2011c20727d1158b1b7476dc9d44779. 
--- mart/configs/model/torchvision_faster_rcnn.yaml | 11 +++++++---- mart/configs/model/torchvision_object_detection.yaml | 3 +-- mart/configs/model/torchvision_retinanet.yaml | 4 +--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index bc0ce228..c5237184 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -4,10 +4,13 @@ defaults: # log all losses separately in training. training_step_log: - rpn_loss_objectness: "rpn_loss.loss_objectness" - rpn_loss_rpn_box_reg: "rpn_loss.loss_rpn_box_reg" - box_loss_classifier: "box_loss.loss_classifier" - box_loss_box_reg: "box_loss.loss_box_reg" + [ + "rpn_loss.loss_objectness", + "rpn_loss.loss_rpn_box_reg", + "box_loss.loss_classifier", + "box_loss.loss_box_reg", + "loss", + ] training_sequence: seq010: diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index c81930a8..a1495dad 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -3,8 +3,7 @@ defaults: - modular - /model/modules@modules.preprocessor: tuple_normalizer -training_step_log: - loss: "loss" +training_step_log: ??? training_sequence: ??? diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 695263a2..4c45917c 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -3,9 +3,7 @@ defaults: - torchvision_object_detection # log all losses separately in training. -training_step_log: - loss_classifier: "loss_classifier" - loss_box_reg: "loss_box_reg" +training_step_log: ["loss_classifier", "loss_box_reg"] training_sequence: - preprocessor: ["input"] From 6fef148cb94379eeb4cef983080b90740e5f0bc1 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 14:08:48 -0700 Subject: [PATCH 10/62] style --- mart/models/modular.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index b24ce6ae..d1d2752c 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -72,19 +72,19 @@ def __init__( # Be backwards compatible by turning list into dict where each item is its own key-value if isinstance(training_step_log, (list, tuple)): - training_step_log = { item: item for item in training_step_log} + training_step_log = {item: item for item in training_step_log} self.training_step_log = training_step_log or {} self.training_metrics = training_metrics # Be backwards compatible by turning list into dict where each item is its own key-value if isinstance(validation_step_log, (list, tuple)): - validation_step_log = { item: item for item in validation_step_log} + validation_step_log = {item: item for item in validation_step_log} self.validation_step_log = validation_step_log or {} self.validation_metrics = validation_metrics # Be backwards compatible by turning list into dict where each item is its own key-value if isinstance(test_step_log, (list, tuple)): - test_step_log = { item: item for item in test_step_log} + test_step_log = {item: item for item in test_step_log} self.test_step_log = test_step_log or {} self.test_metrics = test_metrics From c4e0d78813a85fc7f6176ce67def781635b7b3de Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Mon, 12 Jun 2023 16:04:53 -0700 Subject: [PATCH 11/62] Make 
metric logging keys configurable --- mart/models/modular.py | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index eb5dd934..4fcfe783 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -36,6 +36,9 @@ def __init__( test_step_log=None, test_metrics=None, load_state_dict=None, + output_loss_key="loss", + output_preds_key="preds", + output_target_key="target", ): super().__init__() @@ -88,6 +91,10 @@ def __init__( logger.info(f"Loading state_dict {path} for {module.__class__.__name__}...") module.load_state_dict(torch.load(path, map_location="cpu")) + self.output_loss_key = output_loss_key + self.output_preds_key = output_preds_key + self.output_target_key = output_target_key + def configure_optimizers(self): config = {} config["optimizer"] = self.optimizer_fn(self.model) @@ -118,19 +125,15 @@ def training_step(self, batch, batch_idx): for name in self.training_step_log: self.log(f"training/{name}", output[name]) - assert "loss" in output - return output - - def training_step_end(self, output): if self.training_metrics is not None: # Some models only return loss in the training mode. - if "preds" not in output or "target" not in output: + if self.output_preds_key not in output or self.output_target_key not in output: raise ValueError( - "You have specified training_metrics, but the model does not return preds and target during training. You can either nullify training_metrics or configure the model to return preds and target in the training output." + f"You have specified training_metrics, but the model does not return {self.output_preds_key} or {self.output_target_key} during training. You can either nullify training_metrics or configure the model to return {self.output_preds_key} and {self.output_target_key} in the training output." ) - self.training_metrics(output["preds"], output["target"]) - loss = output.pop("loss") - return loss + self.training_metrics(output[self.output_preds_key], output[self.output_target_key]) + + return output[self.output_loss_key] def training_epoch_end(self, outputs): if self.training_metrics is not None: @@ -152,13 +155,9 @@ def validation_step(self, batch, batch_idx): for name in self.validation_step_log: self.log(f"validation/{name}", output[name]) - return output + self.validation_metrics(output[self.output_preds_key], output[self.output_target_key]) - def validation_step_end(self, output): - self.validation_metrics(output["preds"], output["target"]) - - # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) - output.clear() + return None def validation_epoch_end(self, outputs): metrics = self.validation_metrics.compute() @@ -178,13 +177,9 @@ def test_step(self, batch, batch_idx): for name in self.test_step_log: self.log(f"test/{name}", output[name]) - return output - - def test_step_end(self, output): - self.test_metrics(output["preds"], output["target"]) + self.test_metrics(output[self.output_preds_key], output[self.output_target_key]) - # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) 
- output.clear() + return None def test_epoch_end(self, outputs): metrics = self.test_metrics.compute() From 508798ca12d98d2ce757bcb18779b7c3fe474cdd Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 14:15:36 -0700 Subject: [PATCH 12/62] cleanup --- mart/models/modular.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index 4fcfe783..d663dda3 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -125,6 +125,10 @@ def training_step(self, batch, batch_idx): for name in self.training_step_log: self.log(f"training/{name}", output[name]) + assert "loss" in output + return output + + def training_step_end(self, output): if self.training_metrics is not None: # Some models only return loss in the training mode. if self.output_preds_key not in output or self.output_target_key not in output: @@ -132,8 +136,8 @@ def training_step(self, batch, batch_idx): f"You have specified training_metrics, but the model does not return {self.output_preds_key} or {self.output_target_key} during training. You can either nullify training_metrics or configure the model to return {self.output_preds_key} and {self.output_target_key} in the training output." ) self.training_metrics(output[self.output_preds_key], output[self.output_target_key]) - - return output[self.output_loss_key] + loss = output.pop(self.output_loss_key) + return loss def training_epoch_end(self, outputs): if self.training_metrics is not None: @@ -155,9 +159,13 @@ def validation_step(self, batch, batch_idx): for name in self.validation_step_log: self.log(f"validation/{name}", output[name]) + return output + + def validation_step_end(self, output): self.validation_metrics(output[self.output_preds_key], output[self.output_target_key]) - return None + # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) + output.clear() def validation_epoch_end(self, outputs): metrics = self.validation_metrics.compute() @@ -177,9 +185,13 @@ def test_step(self, batch, batch_idx): for name in self.test_step_log: self.log(f"test/{name}", output[name]) + return output + + def test_step_end(self, output): self.test_metrics(output[self.output_preds_key], output[self.output_target_key]) - return None + # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) + output.clear() def test_epoch_end(self, outputs): metrics = self.test_metrics.compute() From fc770e81d7783edf7e0ef7bf04b4b12a25a2eaa0 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 17:12:51 -0700 Subject: [PATCH 13/62] Remove *_step_end --- mart/models/modular.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index d663dda3..4fcfe783 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -125,10 +125,6 @@ def training_step(self, batch, batch_idx): for name in self.training_step_log: self.log(f"training/{name}", output[name]) - assert "loss" in output - return output - - def training_step_end(self, output): if self.training_metrics is not None: # Some models only return loss in the training mode. 
if self.output_preds_key not in output or self.output_target_key not in output: @@ -136,8 +132,8 @@ def training_step_end(self, output): f"You have specified training_metrics, but the model does not return {self.output_preds_key} or {self.output_target_key} during training. You can either nullify training_metrics or configure the model to return {self.output_preds_key} and {self.output_target_key} in the training output." ) self.training_metrics(output[self.output_preds_key], output[self.output_target_key]) - loss = output.pop(self.output_loss_key) - return loss + + return output[self.output_loss_key] def training_epoch_end(self, outputs): if self.training_metrics is not None: @@ -159,13 +155,9 @@ def validation_step(self, batch, batch_idx): for name in self.validation_step_log: self.log(f"validation/{name}", output[name]) - return output - - def validation_step_end(self, output): self.validation_metrics(output[self.output_preds_key], output[self.output_target_key]) - # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) - output.clear() + return None def validation_epoch_end(self, outputs): metrics = self.validation_metrics.compute() @@ -185,13 +177,9 @@ def test_step(self, batch, batch_idx): for name in self.test_step_log: self.log(f"test/{name}", output[name]) - return output - - def test_step_end(self, output): self.test_metrics(output[self.output_preds_key], output[self.output_target_key]) - # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) - output.clear() + return None def test_epoch_end(self, outputs): metrics = self.test_metrics.compute() From c31f4deddb65f6008f8c570bafd3ed2bf1c33d77 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Mon, 12 Jun 2023 15:56:55 -0700 Subject: [PATCH 14/62] Don't require output module with SequentialDict --- mart/nn/nn.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 93b0f07f..754e8657 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -49,10 +49,6 @@ class SequentialDict(torch.nn.ModuleDict): """ def __init__(self, modules, sequences=None): - - if "output" not in modules: - raise ValueError("Modules must have an module named 'output'") - super().__init__(modules) self._sequences = { @@ -121,7 +117,8 @@ def forward(self, step=None, sequence=None, **kwargs): # Pop the executed module to proceed with the sequence sequence.popitem(last=False) - return kwargs["output"] + # return kwargs as DotDict + return DotDict(kwargs) class ReturnKwargs(torch.nn.Module): From 549f705c3a71e58cb014fd3ffe564f717580578a Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 17:42:07 -0700 Subject: [PATCH 15/62] fix configs and tests --- .../attack/gain/rcnn_training_loss.yaml | 8 ++-- mart/configs/model/classifier.yaml | 17 ------- .../model/torchvision_faster_rcnn.yaml | 46 ++----------------- .../model/torchvision_object_detection.yaml | 10 ++-- mart/configs/model/torchvision_retinanet.yaml | 27 ++--------- tests/test_experiments.py | 6 +-- 6 files changed, 18 insertions(+), 96 deletions(-) diff --git a/mart/configs/attack/gain/rcnn_training_loss.yaml b/mart/configs/attack/gain/rcnn_training_loss.yaml index eb7abb9c..9ed8671b 100644 --- a/mart/configs/attack/gain/rcnn_training_loss.yaml +++ b/mart/configs/attack/gain/rcnn_training_loss.yaml @@ -2,8 +2,8 @@ _target_: mart.nn.CallWith module: _target_: mart.nn.Sum arg_keys: - - rpn_loss.loss_objectness - - 
rpn_loss.loss_rpn_box_reg - - box_loss.loss_classifier - - box_loss.loss_box_reg + - "losses_and_detections.training.loss_objectness" + - "losses_and_detections.training.loss_rpn_box_reg" + - "losses_and_detections.training.loss_classifier" + - "losses_and_detections.training.loss_box_reg" kwarg_keys: null diff --git a/mart/configs/model/classifier.yaml b/mart/configs/model/classifier.yaml index ad664989..df1a9c5b 100644 --- a/mart/configs/model/classifier.yaml +++ b/mart/configs/model/classifier.yaml @@ -17,14 +17,6 @@ training_sequence: seq040: preds: _call_with_args_: ["logits"] - seq050: - output: - { - "preds": "preds", - "target": "target", - "logits": "logits", - "loss": "loss", - } # The kwargs-centric version. # We may use *args as **kwargs to avoid the lengthy _call_with_args_. @@ -36,10 +28,6 @@ validation_sequence: - logits: ["preprocessor"] - preds: input: logits - - output: - preds: preds - target: target - logits: logits # The simplified version. # We treat a list as the `_call_with_args_` parameter. @@ -50,8 +38,6 @@ test_sequence: logits: ["preprocessor"] seq030: preds: ["logits"] - seq040: - output: { preds: preds, target: target, logits: logits } modules: preprocessor: ??? @@ -64,6 +50,3 @@ modules: preds: _target_: torch.nn.Softmax dim: 1 - - output: - _target_: mart.nn.ReturnKwargs diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index c5237184..65200579 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -4,13 +4,10 @@ defaults: # log all losses separately in training. training_step_log: - [ - "rpn_loss.loss_objectness", - "rpn_loss.loss_rpn_box_reg", - "box_loss.loss_classifier", - "box_loss.loss_box_reg", - "loss", - ] + loss_objectness: "losses_and_detections.training.loss_objectness" + loss_rpn_box_reg: "losses_and_detections.training.loss_rpn_box_reg" + loss_classifier: "losses_and_detections.training.loss_classifier" + loss_box_reg: "losses_and_detections.training.loss_box_reg" training_sequence: seq010: @@ -29,19 +26,6 @@ training_sequence: "losses_and_detections.training.loss_box_reg", ] - seq040: - output: - # Output all losses for logging, defined in model.training_step_log - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss": "loss", - "rpn_loss.loss_objectness": "losses_and_detections.training.loss_objectness", - "rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", - "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", - "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", - } - validation_sequence: seq010: preprocessor: ["input"] @@ -49,17 +33,6 @@ validation_sequence: seq020: losses_and_detections: ["preprocessor", "target"] - seq030: - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "rpn_loss.loss_objectness": "losses_and_detections.training.loss_objectness", - "rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", - "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", - "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", - } - test_sequence: seq010: preprocessor: ["input"] @@ -67,17 +40,6 @@ test_sequence: seq020: losses_and_detections: ["preprocessor", "target"] - seq030: - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "rpn_loss.loss_objectness": "losses_and_detections.training.loss_objectness", - 
"rpn_loss.loss_rpn_box_reg": "losses_and_detections.training.loss_rpn_box_reg", - "box_loss.loss_classifier": "losses_and_detections.training.loss_classifier", - "box_loss.loss_box_reg": "losses_and_detections.training.loss_box_reg", - } - modules: losses_and_detections: # 17s: DualModeGeneralizedRCNN diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index a1495dad..1bbd678c 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -3,14 +3,15 @@ defaults: - modular - /model/modules@modules.preprocessor: tuple_normalizer -training_step_log: ??? +training_step_log: + loss: "loss" training_sequence: ??? - validation_sequence: ??? - test_sequence: ??? +output_preds_key: "losses_and_detections.eval" + modules: losses_and_detections: # Return losses in the training mode and predictions in the eval mode in one pass. @@ -19,6 +20,3 @@ modules: loss: _target_: mart.nn.Sum - - output: - _target_: mart.nn.ReturnKwargs diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 4c45917c..34b66945 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -3,7 +3,9 @@ defaults: - torchvision_object_detection # log all losses separately in training. -training_step_log: ["loss_classifier", "loss_box_reg"] +training_step_log: + loss_classifier: "losses_and_detections.training.classification" + loss_box_reg: "losses_and_detections.training.bbox_regression" training_sequence: - preprocessor: ["input"] @@ -14,37 +16,14 @@ training_sequence: "losses_and_detections.training.classification", "losses_and_detections.training.bbox_regression", ] - - output: - # Output all losses for logging, defined in model.training_step_log - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss": "loss", - "loss_classifier": "losses_and_detections.training.classification", - "loss_box_reg": "losses_and_detections.training.bbox_regression", - } validation_sequence: - preprocessor: ["input"] - losses_and_detections: ["preprocessor", "target"] - - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss_classifier": "losses_and_detections.training.classification", - "loss_box_reg": "losses_and_detections.training.bbox_regression", - } test_sequence: - preprocessor: ["input"] - losses_and_detections: ["preprocessor", "target"] - - output: - { - "preds": "losses_and_detections.eval", - "target": "target", - "loss_classifier": "losses_and_detections.training.classification", - "loss_box_reg": "losses_and_detections.training.bbox_regression", - } modules: losses_and_detections: diff --git a/tests/test_experiments.py b/tests/test_experiments.py index d128c1df..cf4ffea7 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -209,7 +209,7 @@ def test_coco_fasterrcnn_experiment(coco_cfg, tmp_path): "-m", "experiment=COCO_TorchvisionFasterRCNN", "hydra.sweep.dir=" + str(tmp_path), - "optimized_metric=training/rpn_loss.loss_objectness", + "optimized_metric=training/loss_objectness", ] + overrides run_sh_command(command) @@ -224,7 +224,7 @@ def test_coco_fasterrcnn_adv_experiment(coco_cfg, tmp_path): "-m", "experiment=COCO_TorchvisionFasterRCNN_Adv", "hydra.sweep.dir=" + str(tmp_path), - "optimized_metric=training/rpn_loss.loss_objectness", + "optimized_metric=training/loss_objectness", ] + overrides run_sh_command(command) @@ -256,7 
+256,7 @@ def test_armory_carla_fasterrcnn_experiment(carla_cfg, tmp_path): "experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN", "+attack@model.modules.input_adv_test=object_detection_mask_adversary", "hydra.sweep.dir=" + str(tmp_path), - "optimized_metric=training/rpn_loss.loss_objectness", + "optimized_metric=training/loss_objectness", ] + overrides run_sh_command(command) From 5e7381743d018ab27d3ed61fb033244192b2ebbf Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 17:55:33 -0700 Subject: [PATCH 16/62] Generalize attack objectives --- mart/configs/attack/objective/misclassification.yaml | 4 ++-- mart/configs/attack/objective/object_detection_missed.yaml | 2 +- mart/configs/attack/objective/zero_ap.yaml | 4 ++-- mart/configs/model/modular.yaml | 3 +++ 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mart/configs/attack/objective/misclassification.yaml b/mart/configs/attack/objective/misclassification.yaml index e2e9b819..82e055cd 100644 --- a/mart/configs/attack/objective/misclassification.yaml +++ b/mart/configs/attack/objective/misclassification.yaml @@ -2,6 +2,6 @@ _target_: mart.nn.CallWith module: _target_: mart.attack.objective.Mispredict arg_keys: - - preds - - target + - ${model.output_preds_key} + - ${model.output_target_key} kwarg_keys: null diff --git a/mart/configs/attack/objective/object_detection_missed.yaml b/mart/configs/attack/objective/object_detection_missed.yaml index dec2410c..7ebb1dc3 100644 --- a/mart/configs/attack/objective/object_detection_missed.yaml +++ b/mart/configs/attack/objective/object_detection_missed.yaml @@ -3,5 +3,5 @@ module: _target_: mart.attack.objective.Missed confidence_threshold: 0.0 arg_keys: - - preds + - ${model.output_preds_key} kwarg_keys: null diff --git a/mart/configs/attack/objective/zero_ap.yaml b/mart/configs/attack/objective/zero_ap.yaml index 6a43f77d..91dc5b96 100644 --- a/mart/configs/attack/objective/zero_ap.yaml +++ b/mart/configs/attack/objective/zero_ap.yaml @@ -4,6 +4,6 @@ module: iou_threshold: 0.5 confidence_threshold: 0.0 arg_keys: - - preds - - target + - ${model.output_preds_key} + - ${model.output_target_key} kwarg_keys: null diff --git a/mart/configs/model/modular.yaml b/mart/configs/model/modular.yaml index f4a6976f..6c137a53 100644 --- a/mart/configs/model/modular.yaml +++ b/mart/configs/model/modular.yaml @@ -1,6 +1,9 @@ _target_: mart.models.LitModular _convert_: all +output_preds_key: "preds" +output_target_key: "target" + modules: ??? optimizer: ??? 
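Taken together, patches 11-16 make every output key that LitModular touches configurable: the *_step_log entries become {logged name: output key} mappings, and the loss/preds/target lookups go through output_loss_key, output_preds_key and output_target_key. A condensed sketch of the behavior these diffs converge on, using key values taken from the Faster R-CNN config above (the `log` callable stands in for LightningModule.log; everything else mirrors the diffs):

    # Sketch only, not the actual LitModular code.
    training_step_log = {
        # logged name -> key in the step output
        "loss_objectness": "losses_and_detections.training.loss_objectness",
        "loss_box_reg": "losses_and_detections.training.loss_box_reg",
    }

    # Backwards compatibility (patch 08): the old list form becomes an identity mapping.
    if isinstance(training_step_log, (list, tuple)):
        training_step_log = {item: item for item in training_step_log}

    def log_training_step(output, log, output_loss_key="loss"):
        for log_name, output_key in training_step_log.items():
            log(f"training/{log_name}", output[output_key])
        # The loss handed back to Lightning is also looked up by a configurable key.
        return output[output_loss_key]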
From 54882e71e5de1439435496dff8b204d26174a643 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 07:57:02 -0700 Subject: [PATCH 17/62] style --- mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml | 2 +- mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml | 2 +- mart/configs/model/torchvision_retinanet.yaml | 5 +---- mart/nn/__init__.py | 2 +- mart/nn/module.py | 1 - mart/nn/nn.py | 8 ++++++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml index aa2e9b42..5a52b36e 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml @@ -30,7 +30,7 @@ datamodule: model: modules: detector: - num_classes: null # inferred by torchvision + num_classes: null # inferred by torchvision weights: COCO_V1 optimizer: diff --git a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml index c336555b..3b94a04a 100644 --- a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml +++ b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml @@ -30,7 +30,7 @@ datamodule: model: modules: detector: - num_classes: null # inferred by torchvision + num_classes: null # inferred by torchvision weights: COCO_V1 optimizer: diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 1a978ed1..e0fceeea 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -29,10 +29,7 @@ training_sequence: seq040: loss: # Sum up the losses. - [ - "losses.classification", - "losses.bbox_regression", - ] + ["losses.classification", "losses.bbox_regression"] validation_sequence: seq010: diff --git a/mart/nn/__init__.py b/mart/nn/__init__.py index c257de69..6333b3da 100644 --- a/mart/nn/__init__.py +++ b/mart/nn/__init__.py @@ -1,2 +1,2 @@ -from .nn import * # noqa: F403 from .module import * # noqa: F403 +from .nn import * # noqa: F403 diff --git a/mart/nn/module.py b/mart/nn/module.py index 41287d17..756a3689 100644 --- a/mart/nn/module.py +++ b/mart/nn/module.py @@ -5,7 +5,6 @@ # import torch - from hydra.utils import instantiate __all__ = ["Module"] diff --git a/mart/nn/nn.py b/mart/nn/nn.py index a5b52ce7..ad437454 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -158,8 +158,12 @@ def forward(self, *args, **kwargs): f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." 
) - selected_args = [kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :]] - selected_kwargs = {key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items()} + selected_args = [ + kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :] + ] + selected_kwargs = { + key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items() + } # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) From ed49144f8a4bab4597b8263fe63eca41be69542d Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 08:01:09 -0700 Subject: [PATCH 18/62] bugfix --- mart/configs/model/torchvision_object_detection.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index be60cbc0..6268dd83 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -10,8 +10,6 @@ training_sequence: ??? validation_sequence: ??? test_sequence: ??? -output_preds_key: "losses_and_detections.eval" - modules: detector: _target_: mart.nn.Module From c2eeebf058091598931070673dacca148f3819eb Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 13:29:18 -0700 Subject: [PATCH 19/62] _return_as_dict -> _return_as_dict_ --- mart/nn/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index ad437454..b9fa0fb7 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -27,7 +27,7 @@ class SequentialDict(torch.nn.ModuleDict): : _name_: _call_with_args_: - _return_as_dict: + _return_as_dict_: **kwargs All intermediate output from each module are stored in the dictionary `kwargs` in `forward()` @@ -82,7 +82,7 @@ def parse_sequence(self, sequence): # The module would be called with these *args. arg_keys = module_cfg.pop("_call_with_args_", None) # The module would return a dictionary with these keys instead of a tuple. - return_keys = module_cfg.pop("_return_as_dict", None) + return_keys = module_cfg.pop("_return_as_dict_", None) # The module would be called with these **kwargs. 
kwarg_keys = module_cfg From 5be1688588334fe3c25dc3f4112d4271b4eccb67 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 13:31:48 -0700 Subject: [PATCH 20/62] path -> _path_ --- mart/nn/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/module.py b/mart/nn/module.py index 756a3689..fc94d3e7 100644 --- a/mart/nn/module.py +++ b/mart/nn/module.py @@ -11,13 +11,13 @@ class Module(torch.nn.Module): - def __init__(self, path: str, *args, **kwargs): + def __init__(self, _path_: str, *args, **kwargs): super().__init__() # TODO: Add _load_state_dict_ # TODO: Add _freeze_ - cfg = {"_target_": path} + cfg = {"_target_": _path_} self.module = instantiate(cfg, *args, **kwargs) def forward( From 04a249aa37c63def98549dff42d2332b65d0760f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 14:39:14 -0700 Subject: [PATCH 21/62] Better Magic Module --- .../model/torchvision_faster_rcnn.yaml | 2 +- .../model/torchvision_object_detection.yaml | 2 +- mart/configs/model/torchvision_retinanet.yaml | 2 +- mart/nn/module.py | 27 ++++++++++++------- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 731f9743..fd8ee78b 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -78,5 +78,5 @@ test_sequence: modules: detector: - path: torchvision.models.detection.fasterrcnn_resnet50_fpn + _path_: torchvision.models.detection.fasterrcnn_resnet50_fpn num_classes: ??? diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index 6268dd83..1ea42852 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -13,7 +13,7 @@ test_sequence: ??? modules: detector: _target_: mart.nn.Module - path: ??? + _path_: ??? loss: _target_: mart.nn.Sum diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index e0fceeea..8d0b72fa 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -71,5 +71,5 @@ test_sequence: modules: detector: - path: torchvision.models.detection.retinanet_resnet50_fpn + _path_: torchvision.models.detection.retinanet_resnet50_fpn num_classes: ??? 
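With the `_path_` keyword above, Hydra instantiates `mart.nn.Module` as a thin wrapper that builds the torchvision detector and exposes the `train_mode`/`inference_mode` switches on `forward` (introduced in patch 04). A rough usage sketch; the concrete `num_classes`/`weights` values and the dummy inputs are illustrative, the experiment configs set their own:

    import torch
    from mart.nn import Module

    # Equivalent of the "detector" entry above: wrap a torchvision constructor.
    detector = Module(
        _path_="torchvision.models.detection.fasterrcnn_resnet50_fpn",
        num_classes=3,
        weights=None,
    )

    images = [torch.rand(3, 224, 224)]
    targets = [{"boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
                "labels": torch.tensor([1])}]

    # train_mode=True yields the loss dict; train_mode=False with inference_mode=True
    # yields per-image detections, matching the sequence steps in the configs.
    losses = detector(images, targets, train_mode=True)
    preds = detector(images, targets, train_mode=False, inference_mode=True)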
diff --git a/mart/nn/module.py b/mart/nn/module.py index fc94d3e7..1cbef0da 100644 --- a/mart/nn/module.py +++ b/mart/nn/module.py @@ -4,6 +4,9 @@ # SPDX-License-Identifier: BSD-3-Clause # +import types +from collections import OrderedDict + import torch from hydra.utils import instantiate @@ -11,15 +14,21 @@ class Module(torch.nn.Module): - def __init__(self, _path_: str, *args, **kwargs): - super().__init__() + """A magic Module that can override forward.""" - # TODO: Add _load_state_dict_ - # TODO: Add _freeze_ + def __new__(cls, *args, _path_: str, **kwargs): + # TODO: Add support for _load_state_dict_ + # TODO: Add support for _freeze_ cfg = {"_target_": _path_} - self.module = instantiate(cfg, *args, **kwargs) + module = instantiate(cfg, *args, **kwargs) + + module._forward = module.forward + module.forward = types.MethodType(Module.forward, module) + + return module + @staticmethod def forward( self, *args, @@ -27,12 +36,12 @@ def forward( inference_mode: bool = False, **kwargs, ): - old_train_mode = self.module.training + old_train_mode = self.training # FIXME: Would be nice if this was a context... - self.module.train(train_mode) + self.train(train_mode) with torch.inference_mode(mode=inference_mode): - ret = self.module(*args, **kwargs) - self.module.train(old_train_mode) + ret = self._forward(*args, **kwargs) + self.train(old_train_mode) return ret From d6fad0dc4764ab2c8e801e36942880f4b9973605 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 14:46:54 -0700 Subject: [PATCH 22/62] Interpolate rcnn_training_loss from training sequence --- mart/configs/attack/gain/rcnn_training_loss.yaml | 6 +----- mart/configs/model/torchvision_faster_rcnn.yaml | 2 +- mart/configs/model/torchvision_object_detection.yaml | 4 +++- mart/configs/model/torchvision_retinanet.yaml | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mart/configs/attack/gain/rcnn_training_loss.yaml b/mart/configs/attack/gain/rcnn_training_loss.yaml index bcde9696..2fe3900e 100644 --- a/mart/configs/attack/gain/rcnn_training_loss.yaml +++ b/mart/configs/attack/gain/rcnn_training_loss.yaml @@ -1,9 +1,5 @@ _target_: mart.nn.CallWith module: _target_: mart.nn.Sum -arg_keys: - - "losses.loss_objectness" - - "losses.loss_rpn_box_reg" - - "losses.loss_classifier" - - "losses.loss_box_reg" +arg_keys: ${model.training_sequence.seq100.loss} kwarg_keys: null diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index fd8ee78b..096660e9 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -28,7 +28,7 @@ training_sequence: train_mode: False inference_mode: True - seq040: + seq100: loss: # Sum up the losses. [ diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index 1ea42852..11da28a0 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -6,7 +6,9 @@ defaults: training_step_log: loss: "loss" -training_sequence: ??? +training_sequence: + seq100: + loss: ??? validation_sequence: ??? test_sequence: ??? 
diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 8d0b72fa..da2411a8 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -26,7 +26,7 @@ training_sequence: train_mode: False inference_mode: True - seq040: + seq100: loss: # Sum up the losses. ["losses.classification", "losses.bbox_regression"] From 0eb7cef7f9f2097eb6de509f6824fc2205790aea Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 14:56:06 -0700 Subject: [PATCH 23/62] Dont hardcode detector --- mart/configs/model/torchvision_faster_rcnn.yaml | 1 + mart/configs/model/torchvision_object_detection.yaml | 4 ---- mart/configs/model/torchvision_retinanet.yaml | 1 + 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 096660e9..d7f9657e 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -78,5 +78,6 @@ test_sequence: modules: detector: + _target_: mart.nn.Module _path_: torchvision.models.detection.fasterrcnn_resnet50_fpn num_classes: ??? diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index 11da28a0..c4b54c98 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -13,9 +13,5 @@ validation_sequence: ??? test_sequence: ??? modules: - detector: - _target_: mart.nn.Module - _path_: ??? - loss: _target_: mart.nn.Sum diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index da2411a8..4753ea6b 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -71,5 +71,6 @@ test_sequence: modules: detector: + _target_: mart.nn.Module _path_: torchvision.models.detection.retinanet_resnet50_fpn num_classes: ??? From b508ef64a6f400b2e3ab2304d745ef85499ba887 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 14:58:55 -0700 Subject: [PATCH 24/62] Remove DualModeGeneralizedRCNN in favor of using configuration --- .../COCO_TorchvisionFasterRCNN.yaml | 5 - .../model/torchvision_dual_faster_rcnn.yaml | 83 +++++++ .../model/torchvision_faster_rcnn.yaml | 216 ++++++++++++++---- mart/models/dual_mode.py | 123 ---------- 4 files changed, 257 insertions(+), 170 deletions(-) create mode 100644 mart/configs/model/torchvision_dual_faster_rcnn.yaml delete mode 100644 mart/models/dual_mode.py diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml index 5a52b36e..0937774b 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml @@ -28,11 +28,6 @@ datamodule: world_size: 1 model: - modules: - detector: - num_classes: null # inferred by torchvision - weights: COCO_V1 - optimizer: lr: 0.0125 momentum: 0.9 diff --git a/mart/configs/model/torchvision_dual_faster_rcnn.yaml b/mart/configs/model/torchvision_dual_faster_rcnn.yaml new file mode 100644 index 00000000..d7f9657e --- /dev/null +++ b/mart/configs/model/torchvision_dual_faster_rcnn.yaml @@ -0,0 +1,83 @@ +# We simply wrap a torchvision object detection model for validation. +defaults: + - torchvision_object_detection + +# log all losses separately in losses. 
+training_step_log: + loss_objectness: "losses.loss_objectness" + loss_rpn_box_reg: "losses.loss_rpn_box_reg" + loss_classifier: "losses.loss_classifier" + loss_box_reg: "losses.loss_box_reg" + +training_sequence: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True + + seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True + + seq100: + loss: + # Sum up the losses. + [ + "losses.loss_objectness", + "losses.loss_rpn_box_reg", + "losses.loss_classifier", + "losses.loss_box_reg", + ] + +validation_sequence: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True + + seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True + +test_sequence: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + images: preprocessor + targets: target + train_mode: True + + seq030: + detector: + _name_: preds + images: preprocessor + targets: target + train_mode: False + inference_mode: True + +modules: + detector: + _target_: mart.nn.Module + _path_: torchvision.models.detection.fasterrcnn_resnet50_fpn + num_classes: ??? diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index d7f9657e..3b1feea6 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -2,40 +2,52 @@ defaults: - torchvision_object_detection +output_preds_key: "roi_heads_eval.detections" + # log all losses separately in losses. training_step_log: - loss_objectness: "losses.loss_objectness" - loss_rpn_box_reg: "losses.loss_rpn_box_reg" - loss_classifier: "losses.loss_classifier" - loss_box_reg: "losses.loss_box_reg" + loss_objectness: "rpn.losses.loss_objectness" + loss_rpn_box_reg: "rpn.losses.loss_rpn_box_reg" + loss_classifier: "roi_heads.losses.loss_classifier" + loss_box_reg: "roi_heads.losses.loss_box_reg" training_sequence: seq010: preprocessor: ["input"] seq020: - detector: - _name_: losses - images: preprocessor - targets: target - train_mode: True + transform: + images: "preprocessor" + targets: "target" + _return_as_dict_: ["images", "targets"] seq030: - detector: - _name_: preds - images: preprocessor - targets: target - train_mode: False - inference_mode: True + backbone: + x: "transform.images.tensors" + + seq040: + rpn: + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq050: + roi_heads: + features: "backbone" + proposals: "rpn.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] seq100: loss: # Sum up the losses. 
[ - "losses.loss_objectness", - "losses.loss_rpn_box_reg", - "losses.loss_classifier", - "losses.loss_box_reg", + "rpn.losses.loss_objectness", + "rpn.losses.loss_rpn_box_reg", + "roi_heads.losses.loss_classifier", + "roi_heads.losses.loss_box_reg", ] validation_sequence: @@ -43,41 +55,161 @@ validation_sequence: preprocessor: ["input"] seq020: - detector: - _name_: losses - images: preprocessor - targets: target - train_mode: True + transform: + images: "preprocessor" + targets: "target" + _return_as_dict_: ["images", "targets"] seq030: - detector: - _name_: preds - images: preprocessor - targets: target - train_mode: False - inference_mode: True + backbone: + x: "transform.images.tensors" + + seq040: + rpn: + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq050: + roi_heads: + features: "backbone" + proposals: "rpn.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] test_sequence: seq010: preprocessor: ["input"] seq020: - detector: - _name_: losses - images: preprocessor - targets: target - train_mode: True + transform: + images: "preprocessor" + targets: "target" + _return_as_dict_: ["images", "targets"] seq030: - detector: - _name_: preds - images: preprocessor - targets: target + backbone: + x: "transform.images.tensors" + + seq040: + rpn: + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq050: + roi_heads: + features: "backbone" + proposals: "rpn.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq060: + rpn: + _name_: "rpn_eval" + train_mode: False + inference_mode: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq070: + roi_heads: + _name_: "roi_heads_eval" train_mode: False inference_mode: True + features: "backbone" + proposals: "rpn_eval.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] modules: - detector: + transform: + _target_: torchvision.models.detection.transform.GeneralizedRCNNTransform + min_size: 800 + max_size: 1333 + image_mean: [0.485, 0.456, 0.406] + image_std: [0.229, 0.224, 0.225] + + backbone: + _target_: torchvision.models.detection.backbone_utils.BackboneWithFPN + backbone: + _target_: torchvision.models.resnet.ResNet + block: + _target_: hydra.utils.get_method + path: torchvision.models.resnet.Bottleneck + layers: [3, 4, 6, 3] + num_classes: 1000 + zero_init_residual: False + groups: 1 + width_per_group: 64 + replace_stride_with_dilation: null + norm_layer: + _target_: hydra.utils.get_method + path: torchvision.ops.misc.FrozenBatchNorm2d + return_layers: + { "layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3" } + in_channels_list: [256, 512, 1024, 2048] + out_channels: 256 + + rpn: + _target_: mart.nn.Module + _path_: torchvision.models.detection.rpn.RegionProposalNetwork + #_target_: torchvision.models.detection.rpn.RegionProposalNetwork + anchor_generator: + _target_: torchvision.models.detection.anchor_utils.AnchorGenerator + sizes: + - [32] + - [64] + - [128] + - [256] + - [512] + aspect_ratios: + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + head: + _target_: 
torchvision.models.detection.rpn.RPNHead + in_channels: 256 # backbone.out_channels + num_anchors: 3 + fg_iou_thresh: 0.7 + bg_iou_thresh: 0.3 + batch_size_per_image: 256 + positive_fraction: 0.5 + pre_nms_top_n: { "training": 2000, "testing": 1000 } + post_nms_top_n: { "training": 2000, "testing": 1000 } + nms_thresh: 0.7 + score_thresh: 0.0 + + roi_heads: _target_: mart.nn.Module - _path_: torchvision.models.detection.fasterrcnn_resnet50_fpn - num_classes: ??? + _path_: torchvision.models.detection.roi_heads.RoIHeads + box_roi_pool: + _target_: torchvision.ops.MultiScaleRoIAlign + featmap_names: ["0", "1", "2", "3"] # backbone.return_layers + output_size: 7 + sampling_ratio: 2 + box_head: + _target_: torchvision.models.detection.faster_rcnn.TwoMLPHead + in_channels: 12544 # backbone.out_channels * (box_roi_pool.output_size ** 2) + representation_size: 1024 + box_predictor: + _target_: torchvision.models.detection.faster_rcnn.FastRCNNPredictor + in_channels: 1024 # box_head.representation_size + num_classes: 91 # coco classes + background + fg_iou_thresh: 0.5 + bg_iou_thresh: 0.5 + batch_size_per_image: 512 + positive_fraction: 0.25 + bbox_reg_weights: null + score_thresh: 0.05 + nms_thresh: 0.5 + detections_per_img: 100 diff --git a/mart/models/dual_mode.py b/mart/models/dual_mode.py deleted file mode 100644 index 616e4177..00000000 --- a/mart/models/dual_mode.py +++ /dev/null @@ -1,123 +0,0 @@ -# -# Copyright (C) 2022 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -from collections import OrderedDict -from typing import Dict, List, Optional, Tuple - -import torch -from torch import Tensor -from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN - -from mart.utils.monkey_patch import MonkeyPatch - -__all__ = ["DualModeGeneralizedRCNN"] - - -class DualModeGeneralizedRCNN(torch.nn.Module): - """Efficient dual mode for GeneralizedRCNN from torchvision, by reusing feature maps from - backbone.""" - - def __init__(self, model): - super().__init__() - - self.model = model - - def forward(self, *args, **kwargs): - bound_method = self.forward_dual_mode.__get__(self.model, self.model.__class__) - with MonkeyPatch(self.model, "forward", bound_method): - ret = self.model(*args, **kwargs) - return ret - - # Adapted from: https://github.com/pytorch/vision/blob/32757a260dfedebf71eb470bd0a072ed20beddc3/torchvision/models/detection/generalized_rcnn.py#L46 - @staticmethod - def forward_dual_mode(self, images, targets=None): - # type: (GeneralizedRCNN, List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] - """ - Args: - self (GeneralizedRCNN): the model. - images (list[Tensor]): images to be processed - targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional) - - Returns: - result (list[BoxList] or dict[Tensor]): the output from the model. - During training, it returns a dict[Tensor] which contains the losses. - During testing, it returns list[BoxList] contains additional fields - like `scores`, `labels` and `mask` (for Mask R-CNN models). - - """ - # Validate targets in both training and eval mode. 
- if targets is not None: - for target in targets: - boxes = target["boxes"] - if isinstance(boxes, torch.Tensor): - torch._assert( - len(boxes.shape) == 2 and boxes.shape[-1] == 4, - f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", - ) - else: - torch._assert( - False, - f"Expected target boxes to be of type Tensor, got {type(boxes)}.", - ) - - original_image_sizes: List[Tuple[int, int]] = [] - for img in images: - val = img.shape[-2:] - torch._assert( - len(val) == 2, - f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", - ) - original_image_sizes.append((val[0], val[1])) - - images, targets = self.transform(images, targets) - - # Check for degenerate boxes - # TODO: Move this to a function - if targets is not None: - for target_idx, target in enumerate(targets): - boxes = target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - # print the first degenerate box - bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] - degen_bb: List[float] = boxes[bb_idx].tolist() - torch._assert( - False, - "All bounding boxes should have positive height and width." - f" Found invalid box {degen_bb} for target at index {target_idx}.", - ) - - features = self.backbone(images.tensors) - if isinstance(features, torch.Tensor): - features = OrderedDict([("0", features)]) - - original_training_status = self.training - ret = {} - - # Training mode. - self.train(True) - proposals, proposal_losses = self.rpn(images, features, targets) - detections, detector_losses = self.roi_heads( - features, proposals, images.image_sizes, targets - ) - losses = {} - losses.update(detector_losses) - losses.update(proposal_losses) - ret["training"] = losses - - # Eval mode. - self.train(False) - with torch.no_grad(): - proposals, proposal_losses = self.rpn(images, features, targets) - detections, detector_losses = self.roi_heads( - features, proposals, images.image_sizes, targets - ) - detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator] - ret["eval"] = detections - - self.train(original_training_status) - - return ret From 54878b064a449fd575725d3adb7cd6b56ab817bc Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 15:20:35 -0700 Subject: [PATCH 25/62] Remove DualModeGeneralizedRCNN in favor of using configuration --- mart/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mart/models/__init__.py b/mart/models/__init__.py index a6e8d57f..96544d96 100644 --- a/mart/models/__init__.py +++ b/mart/models/__init__.py @@ -4,5 +4,4 @@ # SPDX-License-Identifier: BSD-3-Clause # -from .dual_mode import * # noqa: F403 from .modular import * # noqa: F403 From 7b943119c6c91a228c99a1ac39e8d993374474d7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 15:22:00 -0700 Subject: [PATCH 26/62] Remove DualModeGeneralizedRCNN in favor of using configuration --- .../model/torchvision_faster_rcnn.yaml | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 3b1feea6..c678bb5e 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -27,6 +27,8 @@ training_sequence: seq040: rpn: + train_mode: True + inference_mode: False images: "transform.images" features: "backbone" targets: "transform.targets" @@ -34,12 +36,35 @@ training_sequence: seq050: 
roi_heads: + train_mode: True + inference_mode: False features: "backbone" proposals: "rpn.proposals" image_shapes: "transform.images.image_sizes" targets: "transform.targets" _return_as_dict_: ["detections", "losses"] + seq060: + rpn: + _name_: "rpn_eval" + train_mode: False + inference_mode: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq070: + roi_heads: + _name_: "roi_heads_eval" + train_mode: False + inference_mode: True + features: "backbone" + proposals: "rpn_eval.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + seq100: loss: # Sum up the losses. @@ -66,6 +91,8 @@ validation_sequence: seq040: rpn: + train_mode: True + inference_mode: False images: "transform.images" features: "backbone" targets: "transform.targets" @@ -73,12 +100,35 @@ validation_sequence: seq050: roi_heads: + train_mode: True + inference_mode: False features: "backbone" proposals: "rpn.proposals" image_shapes: "transform.images.image_sizes" targets: "transform.targets" _return_as_dict_: ["detections", "losses"] + seq060: + rpn: + _name_: "rpn_eval" + train_mode: False + inference_mode: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq070: + roi_heads: + _name_: "roi_heads_eval" + train_mode: False + inference_mode: True + features: "backbone" + proposals: "rpn_eval.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + test_sequence: seq010: preprocessor: ["input"] @@ -95,6 +145,8 @@ test_sequence: seq040: rpn: + train_mode: True + inference_mode: False images: "transform.images" features: "backbone" targets: "transform.targets" @@ -102,6 +154,8 @@ test_sequence: seq050: roi_heads: + train_mode: True + inference_mode: False features: "backbone" proposals: "rpn.proposals" image_shapes: "transform.images.image_sizes" From b6f8ca306897401e60a05b168b987a5dd9cae8df Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Tue, 13 Jun 2023 11:58:03 -0700 Subject: [PATCH 27/62] CallWith passes non-str arguments directly to module --- mart/nn/nn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 93b0f07f..345d3c78 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -112,6 +112,7 @@ def forward(self, step=None, sequence=None, **kwargs): # Don't pop the first element yet, because it may be used to re-evaluate the model. key, module = next(iter(sequence.items())) + # FIXME: Add better error message output = module(step=step, sequence=sequence, **kwargs) if key in kwargs: @@ -152,14 +153,16 @@ def forward(self, *args, **kwargs): # as it and assume these consume the first len(args) of arg_keys. remaining_arg_keys = arg_keys[len(args) :] - for key in remaining_arg_keys + list(kwarg_keys.values()): + # Check to make sure each str exists within kwargs + str_kwarg_keys = filter(lambda k: isinstance(k, str), kwarg_keys.values()) + for key in remaining_arg_keys + list(str_kwarg_keys): if key not in kwargs: raise Exception( f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." 
) - selected_args = [kwargs[key] for key in arg_keys[len(args) :]] - selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} + selected_args = [kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :]] + selected_kwargs = {key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items()} # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) From 52f050199c9c226d22a2a04361cca4abd7470a87 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 13:29:18 -0700 Subject: [PATCH 28/62] _return_as_dict -> _return_as_dict_ --- mart/nn/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 345d3c78..92c3bc70 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -27,7 +27,7 @@ class SequentialDict(torch.nn.ModuleDict): : _name_: _call_with_args_: - _return_as_dict: + _return_as_dict_: **kwargs All intermediate output from each module are stored in the dictionary `kwargs` in `forward()` @@ -86,7 +86,7 @@ def parse_sequence(self, sequence): # The module would be called with these *args. arg_keys = module_cfg.pop("_call_with_args_", None) # The module would return a dictionary with these keys instead of a tuple. - return_keys = module_cfg.pop("_return_as_dict", None) + return_keys = module_cfg.pop("_return_as_dict_", None) # The module would be called with these **kwargs. kwarg_keys = module_cfg From a0761ee16634ed17543a7dc3ab28f3d91b83b55c Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 16:18:12 -0700 Subject: [PATCH 29/62] style --- mart/nn/nn.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 92c3bc70..90bfa3a6 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -161,8 +161,12 @@ def forward(self, *args, **kwargs): f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." 
) - selected_args = [kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :]] - selected_kwargs = {key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items()} + selected_args = [ + kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :] + ] + selected_kwargs = { + key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items() + } # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) From 2063b973cb3cc951fad74fca64d35fb397e3b7a7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 16:35:53 -0700 Subject: [PATCH 30/62] bugfix --- .../ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index c5d5f75d..09489f48 100644 --- a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -12,9 +12,9 @@ optimized_metric: "test_metrics/map" model: modules: - detector: - num_classes: 3 - weights: null + roi_heads: + box_predictor: + num_classes: 3 optimizer: lr: 0.0125 From 4f41893e6d6c64d79af73d1b091f15873b457f19 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 17:12:40 -0700 Subject: [PATCH 31/62] Improve load_state_dict --- mart/models/modular.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/mart/models/modular.py b/mart/models/modular.py index a27c6867..cf502a5b 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -95,10 +95,24 @@ def __init__( # commandlines converts dotted paths to nested dictionaries. 
load_state_dict = flatten_dict(load_state_dict or {}) - for name, path in load_state_dict.items(): - module = attrgetter(name)(self.model) - logger.info(f"Loading state_dict {path} for {module.__class__.__name__}...") - module.load_state_dict(torch.load(path, map_location="cpu")) + for name, state_dict in load_state_dict.items(): + if name == "_model_": + module = self.model + else: + module = attrgetter(name)(self.model) + + if isinstance(state_dict, str): + logger.info(f"Loading {state_dict} for {module.__class__.__name__}...") + state_dict = torch.load(state_dict, map_location="cpu") + + elif hasattr(state_dict, "get_state_dict"): + logger.info(f"Loading {state_dict.__class__.__name__} for {module.__class__.__name__}...") + state_dict = state_dict.get_state_dict(progress=True) + + else: + raise ValueError(f"Unsupported state_dict: {state_dict}") + + module.load_state_dict(state_dict, strict=True) self.output_loss_key = output_loss_key self.output_preds_key = output_preds_key From dcf222ec66596ad4a66312a6d7d889db43c172b5 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 17:12:49 -0700 Subject: [PATCH 32/62] Load COCO_V1 weights --- mart/configs/model/torchvision_faster_rcnn.yaml | 5 +++++ mart/configs/model/torchvision_retinanet.yaml | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index c678bb5e..3f7cd1ff 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -2,6 +2,11 @@ defaults: - torchvision_object_detection +load_state_dict: + _model_: + _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available + path: "torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1" + output_preds_key: "roi_heads_eval.detections" # log all losses separately in losses. diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 4753ea6b..d09bde00 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -2,6 +2,11 @@ defaults: - torchvision_object_detection +load_state_dict: + detector: + _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available + path: "torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.COCO_V1" + # log all losses separately in training. training_step_log: loss_classifier: "losses.classification" From d4b60ce53b60f2baaba94a58288ba51452f04ad3 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 17:18:11 -0700 Subject: [PATCH 33/62] Load COCO_V1 weights --- mart/configs/model/torchvision_dual_faster_rcnn.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/configs/model/torchvision_dual_faster_rcnn.yaml b/mart/configs/model/torchvision_dual_faster_rcnn.yaml index d7f9657e..37960667 100644 --- a/mart/configs/model/torchvision_dual_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_dual_faster_rcnn.yaml @@ -80,4 +80,4 @@ modules: detector: _target_: mart.nn.Module _path_: torchvision.models.detection.fasterrcnn_resnet50_fpn - num_classes: ??? 
+ weights: "COCO_V1" From b159543a731b772c41ba666f5d7ee4fe38feed62 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 17:55:38 -0700 Subject: [PATCH 34/62] Add GeneralizedRCNNPostProcessor --- .../model/torchvision_faster_rcnn.yaml | 23 +++++++- mart/models/detection.py | 55 +++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 mart/models/detection.py diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 3f7cd1ff..ad0ec78c 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -7,8 +7,6 @@ load_state_dict: _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available path: "torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1" -output_preds_key: "roi_heads_eval.detections" - # log all losses separately in losses. training_step_log: loss_objectness: "rpn.losses.loss_objectness" @@ -70,6 +68,12 @@ training_sequence: targets: "transform.targets" _return_as_dict_: ["detections", "losses"] + seq080: + preds: + result: "roi_heads_eval.detections" + image_shapes: "transform.images.image_sizes" + original_images: "preprocessor" + seq100: loss: # Sum up the losses. @@ -134,6 +138,12 @@ validation_sequence: targets: "transform.targets" _return_as_dict_: ["detections", "losses"] + seq080: + preds: + result: "roi_heads_eval.detections" + image_shapes: "transform.images.image_sizes" + original_images: "preprocessor" + test_sequence: seq010: preprocessor: ["input"] @@ -188,6 +198,12 @@ test_sequence: targets: "transform.targets" _return_as_dict_: ["detections", "losses"] + seq080: + preds: + result: "roi_heads_eval.detections" + image_shapes: "transform.images.image_sizes" + original_images: "preprocessor" + modules: transform: _target_: torchvision.models.detection.transform.GeneralizedRCNNTransform @@ -272,3 +288,6 @@ modules: score_thresh: 0.05 nms_thresh: 0.5 detections_per_img: 100 + + preds: + _target_: mart.models.detection.GeneralizedRCNNPostProcessor diff --git a/mart/models/detection.py b/mart/models/detection.py new file mode 100644 index 00000000..f895faf8 --- /dev/null +++ b/mart/models/detection.py @@ -0,0 +1,55 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import torch +from torch import Tensor, nn +from torchvision.models.detection.transform import ( + paste_masks_in_image, + resize_boxes, + resize_keypoints, +) + +__all__ = ["GeneralizedRCNNPostProcessor"] + + +class GeneralizedRCNNPostProcessor(nn.Module): + def forward( + self, + result: list[dict[str, Tensor]], + image_shapes: list[tuple[int, int]], + original_images: list[Tensor], + ) -> list[dict[str, Tensor]]: + original_image_sizes = get_image_sizes(original_images) + + for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): + boxes = pred["boxes"] + boxes = resize_boxes(boxes, im_s, o_im_s) + result[i]["boxes"] = boxes + if "masks" in pred: + masks = pred["masks"] + masks = paste_masks_in_image(masks, boxes, o_im_s) + result[i]["masks"] = masks + if "keypoints" in pred: + keypoints = pred["keypoints"] + keypoints = resize_keypoints(keypoints, im_s, o_im_s) + result[i]["keypoints"] = keypoints + return result + + +def get_image_sizes(images: list[Tensor]): + image_sizes: list[tuple[int, int]] = [] + + for img in images: + val = img.shape[-2:] + torch._assert( + len(val) == 2, + f"expecting the last 
two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", + ) + image_sizes.append((val[0], val[1])) + + return image_sizes From 0b6ce3d1ae81133fc182e4020ecbd2f7e3b1466c Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 14 Jun 2023 17:56:00 -0700 Subject: [PATCH 35/62] style --- mart/configs/model/torchvision_faster_rcnn.yaml | 2 +- mart/configs/model/torchvision_retinanet.yaml | 2 +- mart/models/modular.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index ad0ec78c..bea1013f 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -4,7 +4,7 @@ defaults: load_state_dict: _model_: - _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available + _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available path: "torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1" # log all losses separately in losses. diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index d09bde00..81bb166a 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -4,7 +4,7 @@ defaults: load_state_dict: detector: - _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available + _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available path: "torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.COCO_V1" # log all losses separately in training. diff --git a/mart/models/modular.py b/mart/models/modular.py index cf502a5b..f8355f60 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -106,7 +106,9 @@ def __init__( state_dict = torch.load(state_dict, map_location="cpu") elif hasattr(state_dict, "get_state_dict"): - logger.info(f"Loading {state_dict.__class__.__name__} for {module.__class__.__name__}...") + logger.info( + f"Loading {state_dict.__class__.__name__} for {module.__class__.__name__}..." 
+ ) state_dict = state_dict.get_state_dict(progress=True) else: From 3faa5b669f57cbfcffe2b8eba57dd15e62221347 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:02:42 -0700 Subject: [PATCH 36/62] Surround parameters by underscores --- .../model/torchvision_dual_faster_rcnn.yaml | 18 ++++---- .../model/torchvision_faster_rcnn.yaml | 42 ++++++++----------- mart/configs/model/torchvision_retinanet.yaml | 18 ++++---- mart/nn/module.py | 24 +++++++---- 4 files changed, 52 insertions(+), 50 deletions(-) diff --git a/mart/configs/model/torchvision_dual_faster_rcnn.yaml b/mart/configs/model/torchvision_dual_faster_rcnn.yaml index 37960667..fa078001 100644 --- a/mart/configs/model/torchvision_dual_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_dual_faster_rcnn.yaml @@ -16,17 +16,17 @@ training_sequence: seq020: detector: _name_: losses + _train_mode_: True images: preprocessor targets: target - train_mode: True seq030: detector: _name_: preds + _train_mode_: False + _inference_mode_: True images: preprocessor targets: target - train_mode: False - inference_mode: True seq100: loss: @@ -45,17 +45,17 @@ validation_sequence: seq020: detector: _name_: losses + _train_mode_: True images: preprocessor targets: target - train_mode: True seq030: detector: _name_: preds + _train_mode_: False + _inference_mode_: True images: preprocessor targets: target - train_mode: False - inference_mode: True test_sequence: seq010: @@ -64,17 +64,17 @@ test_sequence: seq020: detector: _name_: losses + _train_mode_: True images: preprocessor targets: target - train_mode: True seq030: detector: _name_: preds + _train_mode_: False + _inference_mode_: True images: preprocessor targets: target - train_mode: False - inference_mode: True modules: detector: diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index bea1013f..1d79c0a4 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -30,8 +30,7 @@ training_sequence: seq040: rpn: - train_mode: True - inference_mode: False + _train_mode_: True images: "transform.images" features: "backbone" targets: "transform.targets" @@ -39,8 +38,7 @@ training_sequence: seq050: roi_heads: - train_mode: True - inference_mode: False + _train_mode_: True features: "backbone" proposals: "rpn.proposals" image_shapes: "transform.images.image_sizes" @@ -50,8 +48,8 @@ training_sequence: seq060: rpn: _name_: "rpn_eval" - train_mode: False - inference_mode: True + _train_mode_: False + _inference_mode_: True images: "transform.images" features: "backbone" targets: "transform.targets" @@ -60,8 +58,8 @@ training_sequence: seq070: roi_heads: _name_: "roi_heads_eval" - train_mode: False - inference_mode: True + _train_mode_: False + _inference_mode_: True features: "backbone" proposals: "rpn_eval.proposals" image_shapes: "transform.images.image_sizes" @@ -100,8 +98,7 @@ validation_sequence: seq040: rpn: - train_mode: True - inference_mode: False + _train_mode_: True images: "transform.images" features: "backbone" targets: "transform.targets" @@ -109,8 +106,7 @@ validation_sequence: seq050: roi_heads: - train_mode: True - inference_mode: False + _train_mode_: True features: "backbone" proposals: "rpn.proposals" image_shapes: "transform.images.image_sizes" @@ -120,8 +116,8 @@ validation_sequence: seq060: rpn: _name_: "rpn_eval" - train_mode: False - inference_mode: True + _train_mode_: False + _inference_mode_: True images: "transform.images" features: "backbone" 
targets: "transform.targets" @@ -130,8 +126,8 @@ validation_sequence: seq070: roi_heads: _name_: "roi_heads_eval" - train_mode: False - inference_mode: True + _train_mode_: False + _inference_mode_: True features: "backbone" proposals: "rpn_eval.proposals" image_shapes: "transform.images.image_sizes" @@ -160,8 +156,7 @@ test_sequence: seq040: rpn: - train_mode: True - inference_mode: False + _train_mode_: True images: "transform.images" features: "backbone" targets: "transform.targets" @@ -169,8 +164,7 @@ test_sequence: seq050: roi_heads: - train_mode: True - inference_mode: False + _train_mode_: True features: "backbone" proposals: "rpn.proposals" image_shapes: "transform.images.image_sizes" @@ -180,8 +174,8 @@ test_sequence: seq060: rpn: _name_: "rpn_eval" - train_mode: False - inference_mode: True + _train_mode_: False + _inference_mode_: True images: "transform.images" features: "backbone" targets: "transform.targets" @@ -190,8 +184,8 @@ test_sequence: seq070: roi_heads: _name_: "roi_heads_eval" - train_mode: False - inference_mode: True + _train_mode_: False + _inference_mode_: True features: "backbone" proposals: "rpn_eval.proposals" image_shapes: "transform.images.image_sizes" diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 81bb166a..c0318898 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -19,17 +19,17 @@ training_sequence: seq020: detector: _name_: losses + _train_mode_: True images: preprocessor targets: target - train_mode: True seq030: detector: _name_: preds + _train_mode_: False + _inference_mode_: True images: preprocessor targets: target - train_mode: False - inference_mode: True seq100: loss: @@ -43,17 +43,17 @@ validation_sequence: seq020: detector: _name_: losses + _train_mode_: True images: preprocessor targets: target - train_mode: True seq030: detector: _name_: preds + _train_mode_: False + _inference_mode_: True images: preprocessor targets: target - train_mode: False - inference_mode: True test_sequence: seq010: @@ -62,17 +62,17 @@ test_sequence: seq020: detector: _name_: losses + _train_mode_: True images: preprocessor targets: target - train_mode: True seq030: detector: _name_: preds + _train_mode_: False + _inference_mode_: True images: preprocessor targets: target - train_mode: False - inference_mode: True modules: detector: diff --git a/mart/nn/module.py b/mart/nn/module.py index 1cbef0da..29d5d4d7 100644 --- a/mart/nn/module.py +++ b/mart/nn/module.py @@ -6,6 +6,7 @@ import types from collections import OrderedDict +from contextlib import nullcontext import torch from hydra.utils import instantiate @@ -23,7 +24,7 @@ def __new__(cls, *args, _path_: str, **kwargs): cfg = {"_target_": _path_} module = instantiate(cfg, *args, **kwargs) - module._forward = module.forward + module._forward_ = module.forward module.forward = types.MethodType(Module.forward, module) return module @@ -32,16 +33,23 @@ def __new__(cls, *args, _path_: str, **kwargs): def forward( self, *args, - train_mode: bool = True, - inference_mode: bool = False, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwargs, ): old_train_mode = self.training - # FIXME: Would be nice if this was a context... 
- self.train(train_mode) - with torch.inference_mode(mode=inference_mode): - ret = self._forward(*args, **kwargs) - self.train(old_train_mode) + if _train_mode_ is not None: + self.train(_train_mode_) + + inference_mode = nullcontext() + if _inference_mode_ is not None: + inference_mode = torch.inference_mode(mode=_inference_mode_) + + with inference_mode: + ret = self._forward_(*args, **kwargs) + + if _train_mode_ is not None: + self.train(old_train_mode) return ret From cc719342904c48337a7c238298f7777111509913 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:23:10 -0700 Subject: [PATCH 37/62] bugfix --- mart/nn/module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mart/nn/module.py b/mart/nn/module.py index 29d5d4d7..2590965a 100644 --- a/mart/nn/module.py +++ b/mart/nn/module.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD-3-Clause # +from __future__ import annotations + import types from collections import OrderedDict from contextlib import nullcontext From b2009a87c42ab74765d4cb75e6c8e2ccfe7c7257 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:27:55 -0700 Subject: [PATCH 38/62] Revert "_return_as_dict -> _return_as_dict_" This reverts commit 52f050199c9c226d22a2a04361cca4abd7470a87. --- mart/nn/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 90bfa3a6..17ba0efc 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -27,7 +27,7 @@ class SequentialDict(torch.nn.ModuleDict): : _name_: _call_with_args_: - _return_as_dict_: + _return_as_dict: **kwargs All intermediate output from each module are stored in the dictionary `kwargs` in `forward()` @@ -86,7 +86,7 @@ def parse_sequence(self, sequence): # The module would be called with these *args. arg_keys = module_cfg.pop("_call_with_args_", None) # The module would return a dictionary with these keys instead of a tuple. - return_keys = module_cfg.pop("_return_as_dict_", None) + return_keys = module_cfg.pop("_return_as_dict", None) # The module would be called with these **kwargs. 
kwarg_keys = module_cfg From cef11f38cf5d3b207beb30fb4f2507d5106b1d63 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:31:10 -0700 Subject: [PATCH 39/62] fix imports --- mart/nn/nn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 93b0f07f..8ad7c194 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -5,15 +5,14 @@ # import logging +from typing import OrderedDict -logger = logging.getLogger(__name__) - -from typing import OrderedDict # noqa: E402 - -import torch # noqa: E402 +import torch __all__ = ["GroupNorm32", "SequentialDict", "ReturnKwargs", "CallWith", "Sum"] +logger = logging.getLogger(__name__) + class SequentialDict(torch.nn.ModuleDict): """A special Sequential container where we can rewire the input and output of each module in From 9522075536bd0df387cc957367b614444cb66608 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:33:54 -0700 Subject: [PATCH 40/62] Move _call_with_args_ and _return_as_dict_ functionality into CallWith --- mart/nn/nn.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 8ad7c194..7f403b9f 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -4,8 +4,10 @@ # SPDX-License-Identifier: BSD-3-Clause # +from __future__ import annotations + import logging -from typing import OrderedDict +from typing import Callable, Iterable, OrderedDict import torch @@ -82,20 +84,9 @@ def parse_sequence(self, sequence): # The return name could be different from module_name when a module is used more than once. return_name = module_cfg.pop("_name_", module_name) - # The module would be called with these *args. - arg_keys = module_cfg.pop("_call_with_args_", None) - # The module would return a dictionary with these keys instead of a tuple. - return_keys = module_cfg.pop("_return_as_dict", None) - # The module would be called with these **kwargs. 
- kwarg_keys = module_cfg - - module = self[module_name] - - # Add CallWith to module if we have enough parameters - if arg_keys is not None or len(kwarg_keys) > 0 or return_keys is not None: - module = CallWith(module, arg_keys, kwarg_keys, return_keys) - + module = CallWith(self[module_name], **module_cfg) module_dict[return_name] = module + return module_dict def forward(self, step=None, sequence=None, **kwargs): @@ -132,13 +123,19 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): - def __init__(self, module, arg_keys, kwarg_keys, return_keys=None) -> None: + def __init__( + self, + module: Callable, + _call_with_args_: Iterable[str] | None = None, + _return_as_dict_: Iterable[str] | None = None, + **kwarg_keys, + ) -> None: super().__init__() self.module = module - self.arg_keys = arg_keys or [] + self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} - self.return_keys = return_keys + self.return_keys = _return_as_dict_ def forward(self, *args, **kwargs): orig_class = self.module.__class__ From ba582644bed9a793d90f5688a7780d1179d18e66 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 08:37:17 -0700 Subject: [PATCH 41/62] Allow overwriting _call_with_args_ and _return_as_dict_ in CallWith.forward --- mart/nn/nn.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 7f403b9f..fead5225 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -137,10 +137,17 @@ def __init__( self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ - def forward(self, *args, **kwargs): + def forward( + self, + *args, + _call_with_args_: Iterable[str] | None = None, + _return_as_dict_: Iterable[str] | None = None, + **kwargs, + ): orig_class = self.module.__class__ - arg_keys = self.arg_keys + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys + return_keys = _return_as_dict_ or self.return_keys kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -160,16 +167,16 @@ def forward(self, *args, **kwargs): # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) - if self.return_keys: + if return_keys: if not isinstance(ret, tuple): raise Exception( f"Module {orig_class} does not return multiple unnamed variables, so we can not dictionarize the return." ) - if len(self.return_keys) != len(ret): + if len(return_keys) != len(ret): raise Exception( - f"Module {orig_class} returns {len(ret)} items, but {len(self.return_keys)} return_keys were specified." + f"Module {orig_class} returns {len(ret)} items, but {len(return_keys)} return_keys were specified." 
) - ret = dict(zip(self.return_keys, ret)) + ret = dict(zip(return_keys, ret)) return ret From 60fc2ad0d2a372e03cc64133a9068496c51d8ad8 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:04:07 -0700 Subject: [PATCH 42/62] Add _train_mode_ and _inference_mode_ to CallWith --- mart/nn/nn.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index fead5225..6c2ba4bc 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,7 +7,9 @@ from __future__ import annotations import logging -from typing import Callable, Iterable, OrderedDict +from collections import OrderedDict +from contextlib import nullcontext +from typing import Iterable import torch @@ -125,9 +127,11 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: Callable, + module: torch.nn.Module, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -136,18 +140,26 @@ def __init__( self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ + self.train_mode = _train_mode_ + self.inference_mode = _inference_mode_ def forward( self, *args, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwargs, ): orig_class = self.module.__class__ + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys return_keys = _return_as_dict_ or self.return_keys + _train_mode_ = _train_mode_ or self.train_mode + _inference_mode_ = _inference_mode_ or self.inference_mode + kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -164,8 +176,21 @@ def forward( selected_args = [kwargs[key] for key in arg_keys[len(args) :]] selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} - # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) + old_train_mode = self.module.training + + if _train_mode_ is not None: + self.module.train(_train_mode_) + + context = nullcontext() + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) + + with context: + # FIXME: Add better error message + ret = self.module(*args, *selected_args, **selected_kwargs) + + if _train_mode_ is not None: + self.module.train(old_train_mode) if return_keys: if not isinstance(ret, tuple): From 2e51c662409a7aeb0e0913696330395326772fa7 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:11:25 -0700 Subject: [PATCH 43/62] Revert "Add _train_mode_ and _inference_mode_ to CallWith" This reverts commit 60fc2ad0d2a372e03cc64133a9068496c51d8ad8. 
--- mart/nn/nn.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 6c2ba4bc..fead5225 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,9 +7,7 @@ from __future__ import annotations import logging -from collections import OrderedDict -from contextlib import nullcontext -from typing import Iterable +from typing import Callable, Iterable, OrderedDict import torch @@ -127,11 +125,9 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: torch.nn.Module, + module: Callable, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, - _train_mode_: bool | None = None, - _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -140,26 +136,18 @@ def __init__( self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ - self.train_mode = _train_mode_ - self.inference_mode = _inference_mode_ def forward( self, *args, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, - _train_mode_: bool | None = None, - _inference_mode_: bool | None = None, **kwargs, ): orig_class = self.module.__class__ - arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys return_keys = _return_as_dict_ or self.return_keys - _train_mode_ = _train_mode_ or self.train_mode - _inference_mode_ = _inference_mode_ or self.inference_mode - kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -176,21 +164,8 @@ def forward( selected_args = [kwargs[key] for key in arg_keys[len(args) :]] selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} - old_train_mode = self.module.training - - if _train_mode_ is not None: - self.module.train(_train_mode_) - - context = nullcontext() - if _inference_mode_ is not None: - context = torch.inference_mode(mode=_inference_mode_) - - with context: - # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) - - if _train_mode_ is not None: - self.module.train(old_train_mode) + # FIXME: Add better error message + ret = self.module(*args, *selected_args, **selected_kwargs) if return_keys: if not isinstance(ret, tuple): From 1671755ec6d3b9a63bae7a1b1b2d4662faec86ec Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:12:24 -0700 Subject: [PATCH 44/62] Revert "Revert "Add _train_mode_ and _inference_mode_ to CallWith"" This reverts commit 2e51c662409a7aeb0e0913696330395326772fa7. 
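Re-applying this lets a sequence entry switch the wrapped module into train or eval mode and optionally run the call under torch.inference_mode, restoring the previous state afterwards; when neither flag is given the call behaves exactly as before. A minimal sketch of the pattern (the helper name call_in_mode is illustrative, not part of MART):

    from contextlib import nullcontext

    import torch


    def call_in_mode(module, *args, train_mode=None, inference_mode=None, **kwargs):
        old_mode = module.training
        if train_mode is not None:
            module.train(train_mode)

        # Only enter inference_mode when explicitly requested; otherwise leave autograd untouched.
        ctx = torch.inference_mode(mode=inference_mode) if inference_mode is not None else nullcontext()
        with ctx:
            out = module(*args, **kwargs)

        if train_mode is not None:
            module.train(old_mode)
        return out
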
--- mart/nn/nn.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index fead5225..6c2ba4bc 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,7 +7,9 @@ from __future__ import annotations import logging -from typing import Callable, Iterable, OrderedDict +from collections import OrderedDict +from contextlib import nullcontext +from typing import Iterable import torch @@ -125,9 +127,11 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: Callable, + module: torch.nn.Module, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwarg_keys, ) -> None: super().__init__() @@ -136,18 +140,26 @@ def __init__( self.arg_keys = _call_with_args_ or [] self.kwarg_keys = kwarg_keys or {} self.return_keys = _return_as_dict_ + self.train_mode = _train_mode_ + self.inference_mode = _inference_mode_ def forward( self, *args, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, + _train_mode_: bool | None = None, + _inference_mode_: bool | None = None, **kwargs, ): orig_class = self.module.__class__ + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys return_keys = _return_as_dict_ or self.return_keys + _train_mode_ = _train_mode_ or self.train_mode + _inference_mode_ = _inference_mode_ or self.inference_mode + kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -164,8 +176,21 @@ def forward( selected_args = [kwargs[key] for key in arg_keys[len(args) :]] selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} - # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) + old_train_mode = self.module.training + + if _train_mode_ is not None: + self.module.train(_train_mode_) + + context = nullcontext() + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) + + with context: + # FIXME: Add better error message + ret = self.module(*args, *selected_args, **selected_kwargs) + + if _train_mode_ is not None: + self.module.train(old_train_mode) if return_keys: if not isinstance(ret, tuple): From 731b23daa5abfc13db8205e9e39e08cb85ff9c6f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:14:18 -0700 Subject: [PATCH 45/62] cleanup --- mart/nn/nn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index fead5225..170c7274 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -7,7 +7,8 @@ from __future__ import annotations import logging -from typing import Callable, Iterable, OrderedDict +from collections import OrderedDict +from typing import Iterable import torch @@ -125,7 +126,7 @@ def __call__(self, **kwargs): class CallWith(torch.nn.Module): def __init__( self, - module: Callable, + module: torch.nn.Module, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, **kwarg_keys, From 8e664f1927690f20dc3e7131be94738434db3119 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 10:27:27 -0700 Subject: [PATCH 46/62] cleanup --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 170c7274..ed734a95 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -148,7 +148,6 @@ def forward( orig_class = 
self.module.__class__ arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys - return_keys = _return_as_dict_ or self.return_keys kwargs = DotDict(kwargs) # Sometimes we receive positional arguments because some modules use nn.Sequential @@ -168,6 +167,7 @@ def forward( # FIXME: Add better error message ret = self.module(*args, *selected_args, **selected_kwargs) + return_keys = _return_as_dict_ or self.return_keys if return_keys: if not isinstance(ret, tuple): raise Exception( From ae1b8369bb35a4645e68fb1c2c677058e7e7094f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 11:46:30 -0700 Subject: [PATCH 47/62] Fix configs --- mart/configs/attack/gain/modular.yaml | 3 +-- mart/configs/attack/gain/rcnn_training_loss.yaml | 3 +-- mart/configs/attack/objective/misclassification.yaml | 3 +-- mart/configs/attack/objective/object_detection_missed.yaml | 3 +-- mart/configs/attack/objective/zero_ap.yaml | 3 +-- 5 files changed, 5 insertions(+), 10 deletions(-) diff --git a/mart/configs/attack/gain/modular.yaml b/mart/configs/attack/gain/modular.yaml index 9b06b307..9c76b2c4 100644 --- a/mart/configs/attack/gain/modular.yaml +++ b/mart/configs/attack/gain/modular.yaml @@ -1,7 +1,6 @@ _target_: mart.nn.CallWith module: _target_: ??? -arg_keys: +_call_with_args_: - logits - target -kwarg_keys: null diff --git a/mart/configs/attack/gain/rcnn_training_loss.yaml b/mart/configs/attack/gain/rcnn_training_loss.yaml index eb7abb9c..19b9355e 100644 --- a/mart/configs/attack/gain/rcnn_training_loss.yaml +++ b/mart/configs/attack/gain/rcnn_training_loss.yaml @@ -1,9 +1,8 @@ _target_: mart.nn.CallWith module: _target_: mart.nn.Sum -arg_keys: +_call_with_args_: - rpn_loss.loss_objectness - rpn_loss.loss_rpn_box_reg - box_loss.loss_classifier - box_loss.loss_box_reg -kwarg_keys: null diff --git a/mart/configs/attack/objective/misclassification.yaml b/mart/configs/attack/objective/misclassification.yaml index e2e9b819..a2e6260e 100644 --- a/mart/configs/attack/objective/misclassification.yaml +++ b/mart/configs/attack/objective/misclassification.yaml @@ -1,7 +1,6 @@ _target_: mart.nn.CallWith module: _target_: mart.attack.objective.Mispredict -arg_keys: +_call_with_args_: - preds - target -kwarg_keys: null diff --git a/mart/configs/attack/objective/object_detection_missed.yaml b/mart/configs/attack/objective/object_detection_missed.yaml index dec2410c..efc93078 100644 --- a/mart/configs/attack/objective/object_detection_missed.yaml +++ b/mart/configs/attack/objective/object_detection_missed.yaml @@ -2,6 +2,5 @@ _target_: mart.nn.CallWith module: _target_: mart.attack.objective.Missed confidence_threshold: 0.0 -arg_keys: +_call_with_args_: - preds -kwarg_keys: null diff --git a/mart/configs/attack/objective/zero_ap.yaml b/mart/configs/attack/objective/zero_ap.yaml index 6a43f77d..e11105e6 100644 --- a/mart/configs/attack/objective/zero_ap.yaml +++ b/mart/configs/attack/objective/zero_ap.yaml @@ -3,7 +3,6 @@ module: _target_: mart.attack.objective.ZeroAP iou_threshold: 0.5 confidence_threshold: 0.0 -arg_keys: +_call_with_args_: - preds - target -kwarg_keys: null From 39f9aaa04efc5bf6e7a9040c3a4068c5e8b032ed Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 11:48:48 -0700 Subject: [PATCH 48/62] cleanup --- mart/nn/nn.py | 47 +++++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index ed734a95..9710f487 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -134,8 +134,8 @@ def 
__init__( super().__init__() self.module = module - self.arg_keys = _call_with_args_ or [] - self.kwarg_keys = kwarg_keys or {} + self.arg_keys = _call_with_args_ + self.kwarg_keys = kwarg_keys self.return_keys = _return_as_dict_ def forward( @@ -145,37 +145,52 @@ def forward( _return_as_dict_: Iterable[str] | None = None, **kwargs, ): - orig_class = self.module.__class__ + module_name = self.module.__class__.__name__ + arg_keys = _call_with_args_ or self.arg_keys kwarg_keys = self.kwarg_keys + + args = list(args) kwargs = DotDict(kwargs) - # Sometimes we receive positional arguments because some modules use nn.Sequential - # which has a __call__ function that passes positional args. So we pass along args - # as it and assume these consume the first len(args) of arg_keys. - remaining_arg_keys = arg_keys[len(args) :] + # Change and replaces args and kwargs that we call module with + if arg_keys is not None or len(kwarg_keys) > 0: + arg_keys = arg_keys or [] + + # Sometimes we receive positional arguments because some modules use nn.Sequential + # which has a __call__ function that passes positional args. So we pass along args + # as it and assume these consume the first len(args) of arg_keys. + arg_keys = arg_keys[len(args) :] - for key in remaining_arg_keys + list(kwarg_keys.values()): - if key not in kwargs: + # Append kwargs to args using arg_keys + try: + [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] + except KeyError as ex: raise Exception( - f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." - ) + f"{module_name} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + ) from ex - selected_args = [kwargs[key] for key in arg_keys[len(args) :]] - selected_kwargs = {key: kwargs[val] for key, val in kwarg_keys.items()} + # Replace kwargs with selected kwargs + try: + kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} + except KeyError as ex: + raise Exception( + f"{module_name} wants kwarg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + ) from ex # FIXME: Add better error message - ret = self.module(*args, *selected_args, **selected_kwargs) + ret = self.module(*args, **kwargs) + # Change returned values into dictionary, if necessary return_keys = _return_as_dict_ or self.return_keys if return_keys: if not isinstance(ret, tuple): raise Exception( - f"Module {orig_class} does not return multiple unnamed variables, so we can not dictionarize the return." + f"{module_name} does not return multiple unnamed variables, so we can not dictionarize the return." ) if len(return_keys) != len(ret): raise Exception( - f"Module {orig_class} returns {len(ret)} items, but {len(return_keys)} return_keys were specified." + f"Module {module_name} returns {len(ret)} items, but {len(return_keys)} return_keys were specified." ) ret = dict(zip(return_keys, ret)) From 097393384109e845a554063dc1786313d735f871 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 11:50:08 -0700 Subject: [PATCH 49/62] bugfix --- mart/nn/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 9710f487..c932bc55 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -167,7 +167,7 @@ def forward( [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] except KeyError as ex: raise Exception( - f"{module_name} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." 
+ f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex # Replace kwargs with selected kwargs @@ -175,7 +175,7 @@ def forward( kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} except KeyError as ex: raise Exception( - f"{module_name} wants kwarg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." + f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex # FIXME: Add better error message From 2ec4e4984757cffb8bdf841e4765875b379e2351 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 12:30:15 -0700 Subject: [PATCH 50/62] Only set train mode and inference mode on Modules --- mart/nn/nn.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 56ced9b6..b365a8d5 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -188,21 +188,23 @@ def forward( ) from ex # Apply train mode and inference mode, if necessary, and call module with args and kwargs - old_train_mode = self.module.training + context = nullcontext() + if isinstance(self.module, torch.nn.Module) + old_train_mode = self.module.training - if _train_mode_ is not None: - self.module.train(_train_mode_) + if _train_mode_ is not None: + self.module.train(_train_mode_) - context = nullcontext() - if _inference_mode_ is not None: - context = torch.inference_mode(mode=_inference_mode_) + if _inference_mode_ is not None: + context = torch.inference_mode(mode=_inference_mode_) with context: # FIXME: Add better error message ret = self.module(*args, **kwargs) - if _train_mode_ is not None: - self.module.train(old_train_mode) + if isinstance(self.module, torch.nn.Module): + if _train_mode_ is not None: + self.module.train(old_train_mode) # Change returned values into dictionary, if necessary return_keys = _return_as_dict_ or self.return_keys From 741d282d2fb21b4c994781d64a8afab8b30c8f1f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 12:31:07 -0700 Subject: [PATCH 51/62] bugfix --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index b365a8d5..d2053fa8 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -189,7 +189,7 @@ def forward( # Apply train mode and inference mode, if necessary, and call module with args and kwargs context = nullcontext() - if isinstance(self.module, torch.nn.Module) + if isinstance(self.module, torch.nn.Module): old_train_mode = self.module.training if _train_mode_ is not None: From 15c5a5f77f7ff0b4a1a29a8ab377378e2e7dc77d Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 12:41:08 -0700 Subject: [PATCH 52/62] CallWith is not a Module --- mart/nn/nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index c932bc55..1bb4a269 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -8,7 +8,7 @@ import logging from collections import OrderedDict -from typing import Iterable +from typing import Callable, Iterable import torch @@ -123,10 +123,10 @@ def __call__(self, **kwargs): return kwargs -class CallWith(torch.nn.Module): +class CallWith: def __init__( self, - module: torch.nn.Module, + module: Callable, _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, **kwarg_keys, @@ -138,7 +138,7 @@ def __init__( self.kwarg_keys = kwarg_keys self.return_keys = _return_as_dict_ - def forward( + def __call__( self, *args, _call_with_args_: Iterable[str] | None = None, From 
27380abb415ecaa4e520d3946f46b5e08058495f Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 14:26:28 -0700 Subject: [PATCH 53/62] Remove Module in favor of changes to SequentialDict --- .../model/torchvision_dual_faster_rcnn.yaml | 3 +- .../model/torchvision_faster_rcnn.yaml | 7 +-- mart/configs/model/torchvision_retinanet.yaml | 3 +- mart/nn/__init__.py | 1 - mart/nn/module.py | 57 ------------------- 5 files changed, 4 insertions(+), 67 deletions(-) delete mode 100644 mart/nn/module.py diff --git a/mart/configs/model/torchvision_dual_faster_rcnn.yaml b/mart/configs/model/torchvision_dual_faster_rcnn.yaml index fa078001..61a252a6 100644 --- a/mart/configs/model/torchvision_dual_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_dual_faster_rcnn.yaml @@ -78,6 +78,5 @@ test_sequence: modules: detector: - _target_: mart.nn.Module - _path_: torchvision.models.detection.fasterrcnn_resnet50_fpn + _target_: torchvision.models.detection.fasterrcnn_resnet50_fpn weights: "COCO_V1" diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 1d79c0a4..b5b741ed 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -228,9 +228,7 @@ modules: out_channels: 256 rpn: - _target_: mart.nn.Module - _path_: torchvision.models.detection.rpn.RegionProposalNetwork - #_target_: torchvision.models.detection.rpn.RegionProposalNetwork + _target_: torchvision.models.detection.rpn.RegionProposalNetwork anchor_generator: _target_: torchvision.models.detection.anchor_utils.AnchorGenerator sizes: @@ -259,8 +257,7 @@ modules: score_thresh: 0.0 roi_heads: - _target_: mart.nn.Module - _path_: torchvision.models.detection.roi_heads.RoIHeads + _target_: torchvision.models.detection.roi_heads.RoIHeads box_roi_pool: _target_: torchvision.ops.MultiScaleRoIAlign featmap_names: ["0", "1", "2", "3"] # backbone.return_layers diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index c0318898..67523685 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -76,6 +76,5 @@ test_sequence: modules: detector: - _target_: mart.nn.Module - _path_: torchvision.models.detection.retinanet_resnet50_fpn + _target_: torchvision.models.detection.retinanet_resnet50_fpn num_classes: ??? 
diff --git a/mart/nn/__init__.py b/mart/nn/__init__.py index 6333b3da..e39c0d57 100644 --- a/mart/nn/__init__.py +++ b/mart/nn/__init__.py @@ -1,2 +1 @@ -from .module import * # noqa: F403 from .nn import * # noqa: F403 diff --git a/mart/nn/module.py b/mart/nn/module.py deleted file mode 100644 index 2590965a..00000000 --- a/mart/nn/module.py +++ /dev/null @@ -1,57 +0,0 @@ -# -# Copyright (C) 2023 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -from __future__ import annotations - -import types -from collections import OrderedDict -from contextlib import nullcontext - -import torch -from hydra.utils import instantiate - -__all__ = ["Module"] - - -class Module(torch.nn.Module): - """A magic Module that can override forward.""" - - def __new__(cls, *args, _path_: str, **kwargs): - # TODO: Add support for _load_state_dict_ - # TODO: Add support for _freeze_ - - cfg = {"_target_": _path_} - module = instantiate(cfg, *args, **kwargs) - - module._forward_ = module.forward - module.forward = types.MethodType(Module.forward, module) - - return module - - @staticmethod - def forward( - self, - *args, - _train_mode_: bool | None = None, - _inference_mode_: bool | None = None, - **kwargs, - ): - old_train_mode = self.training - - if _train_mode_ is not None: - self.train(_train_mode_) - - inference_mode = nullcontext() - if _inference_mode_ is not None: - inference_mode = torch.inference_mode(mode=_inference_mode_) - - with inference_mode: - ret = self._forward_(*args, **kwargs) - - if _train_mode_ is not None: - self.train(old_train_mode) - - return ret From 96bbbd54f17ec48a02fa9daf8d903746a21cd860 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 15 Jun 2023 14:28:56 -0700 Subject: [PATCH 54/62] style --- mart/nn/nn.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 9bf925d5..288747e5 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -171,7 +171,10 @@ def __call__( # Append kwargs to args using arg_keys try: - [args.append(kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key) for kwargs_key in arg_keys] + [ + args.append(kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key) + for kwargs_key in arg_keys + ] except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." @@ -179,7 +182,10 @@ def __call__( # Replace kwargs with selected kwargs try: - kwargs = {name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key for name, kwargs_key in kwarg_keys.items()} + kwargs = { + name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key + for name, kwargs_key in kwarg_keys.items() + } except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." 
From 40dd262fcec6fdbbd6eacbd5f1306d83f25745fa Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 11:29:41 -0700 Subject: [PATCH 55/62] cleanup --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 1bb4a269..9b87d475 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -29,7 +29,7 @@ class SequentialDict(torch.nn.ModuleDict): : _name_: _call_with_args_: - _return_as_dict: + _return_as_dict_: **kwargs All intermediate output from each module are stored in the dictionary `kwargs` in `forward()` From ceaa7443e96384acbd35c6ae4fb512d786acc04b Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 13:00:32 -0700 Subject: [PATCH 56/62] bugfix --- mart/nn/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index e593d6b6..d8f5164d 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -236,7 +236,7 @@ def __getitem__(self, key): elif isinstance(value, dict) and subkey in value: value = value[subkey] else: - raise KeyError("No {subkey} in " + ".".join([key, *subkeys])) + raise KeyError(f"No {subkey} in " + ".".join([key, *subkeys])) return value From 51c2f73c89bbf7ee5777f1df78a221952ebb4fd5 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 14:17:58 -0700 Subject: [PATCH 57/62] comments --- mart/nn/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index ad437454..5459615f 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -158,6 +158,8 @@ def forward(self, *args, **kwargs): f"Module {orig_class} wants arg named '{key}' but only received kwargs: {', '.join(kwargs.keys())}." ) + # For each specified args/kwargs key, lookup its corresponding value in kwargs only if the key is a string. + # Otherwise, we just treat the key as a value. selected_args = [ kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :] ] From b8d473b06d98685745ff33dd4e998b6b75d32867 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 16:36:33 -0700 Subject: [PATCH 58/62] fix merge error --- mart/nn/nn.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 1908b9f3..902e13be 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -155,22 +155,26 @@ def __call__( if arg_keys is not None or len(kwarg_keys) > 0: arg_keys = arg_keys or [] - # Check to make sure each str exists within kwargs - str_kwarg_keys = filter(lambda k: isinstance(k, str), kwarg_keys.values()) - for key in remaining_arg_keys + list(str_kwarg_keys): - if key not in kwargs: + # Sometimes we receive positional arguments because some modules use nn.Sequential + # which has a __call__ function that passes positional args. So we pass along args + # as it and assume these consume the first len(args) of arg_keys. + arg_keys = arg_keys[len(args) :] + + # Append kwargs to args using arg_keys + try: + [args.append(kwargs[kwargs_key]) for kwargs_key in arg_keys] + except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex - # For each specified args/kwargs key, lookup its corresponding value in kwargs only if the key is a string. - # Otherwise, we just treat the key as a value. 
- selected_args = [ - kwargs[key] if isinstance(key, str) else key for key in arg_keys[len(args) :] - ] - selected_kwargs = { - key: kwargs[val] if isinstance(val, str) else val for key, val in kwarg_keys.items() - } + # Replace kwargs with selected kwargs + try: + kwargs = {name: kwargs[kwargs_key] for name, kwargs_key in kwarg_keys.items()} + except KeyError as ex: + raise Exception( + f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." + ) from ex # FIXME: Add better error message ret = self.module(*args, **kwargs) From e9cf67b20fb226001c8b6e0133bd0bd6629d4eb4 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Thu, 22 Jun 2023 16:43:18 -0700 Subject: [PATCH 59/62] Change call special arg names --- mart/nn/nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index 902e13be..04a0c8e1 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -139,13 +139,13 @@ def __init__( def __call__( self, *args, - _call_with_args_: Iterable[str] | None = None, - _return_as_dict_: Iterable[str] | None = None, + _args_: Iterable[str] | None = None, + _return_keys_: Iterable[str] | None = None, **kwargs, ): module_name = self.module.__class__.__name__ - arg_keys = _call_with_args_ or self.arg_keys + arg_keys = _args_ or self.arg_keys kwarg_keys = self.kwarg_keys args = list(args) @@ -180,7 +180,7 @@ def __call__( ret = self.module(*args, **kwargs) # Change returned values into dictionary, if necessary - return_keys = _return_as_dict_ or self.return_keys + return_keys = _return_keys_ or self.return_keys if return_keys: if not isinstance(ret, tuple): raise Exception( From 6987ef95f05dffe5ee72cece66ea56d8c6c99b2e Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Fri, 23 Jun 2023 08:16:14 -0700 Subject: [PATCH 60/62] bugfix --- .../experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index 09489f48..3648bf4a 100644 --- a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -11,6 +11,8 @@ tags: ["regular_training"] optimized_metric: "test_metrics/map" model: + load_state_dict: null + modules: roi_heads: box_predictor: From b8a692a3d70a6a8f2496c1ca6b63e197dcd995d4 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 28 Jun 2023 08:29:02 -0700 Subject: [PATCH 61/62] Fix merge error --- mart/nn/nn.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/mart/nn/nn.py b/mart/nn/nn.py index dce22ec4..02113899 100644 --- a/mart/nn/nn.py +++ b/mart/nn/nn.py @@ -80,11 +80,16 @@ def parse_sequence(self, sequence): # We can omit the key of _call_with_args_ if it is the only config. module_cfg = {"_call_with_args_": module_cfg} + # Add support for calling different functions using dot-syntax + if "." not in module_name: + module_name = f"{module_name}.__call__" + module_name, _call_ = module_name.split(".", 1) + module_cfg["_call_"] = _call_ + # The return name could be different from module_name when a module is used more than once. 
return_name = module_cfg.pop("_name_", module_name) module = CallWith(self[module_name], **module_cfg) module_dict[return_name] = module - return module_dict def forward(self, step=None, sequence=None, **kwargs): @@ -125,7 +130,8 @@ def __call__(self, **kwargs): class CallWith: def __init__( self, - module: Callable, + module: object, + _call_: str | None = "__call__", _call_with_args_: Iterable[str] | None = None, _return_as_dict_: Iterable[str] | None = None, _train_mode_: bool | None = None, @@ -135,6 +141,7 @@ def __init__( super().__init__() self.module = module + self.call_attr = _call_ self.arg_keys = _call_with_args_ self.kwarg_keys = kwarg_keys self.return_keys = _return_as_dict_ @@ -157,30 +164,32 @@ def __call__( _train_mode_ = _train_mode_ or self.train_mode _inference_mode_ = _inference_mode_ or self.inference_mode - args = list(args) - kwargs = DotDict(kwargs) - - # Change and replaces args and kwargs that we call module with + # Change and replace args and kwargs that we call module with if arg_keys is not None or len(kwarg_keys) > 0: arg_keys = arg_keys or [] + kwargs = DotDict(kwargs) # we need to lookup values using dot strings + args = list(args) # tuple -> list + # Sometimes we receive positional arguments because some modules use nn.Sequential # which has a __call__ function that passes positional args. So we pass along args # as it and assume these consume the first len(args) of arg_keys. arg_keys = arg_keys[len(args) :] - # Append kwargs to args using arg_keys + # Extend args with selected kwargs using arg_keys try: - [ - args.append(kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key) - for kwargs_key in arg_keys - ] + args.extend( + [ + kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key + for kwargs_key in arg_keys + ] + ) except KeyError as ex: raise Exception( f"{module_name} only received kwargs: {', '.join(kwargs.keys())}." ) from ex - # Replace kwargs with selected kwargs + # Replace kwargs with selected kwargs using kwarg_keys try: kwargs = { name: kwargs[kwargs_key] if isinstance(kwargs_key, str) else kwargs_key @@ -204,7 +213,8 @@ def __call__( with context: # FIXME: Add better error message - ret = self.module(*args, **kwargs) + func = getattr(self.module, self.call_attr) + ret = func(*args, **kwargs) if isinstance(self.module, torch.nn.Module): if _train_mode_ is not None: From 7e48a21e3a0e6c0eb75d0214c507439e05c9b8a9 Mon Sep 17 00:00:00 2001 From: Cory Cornelius Date: Wed, 28 Jun 2023 08:40:59 -0700 Subject: [PATCH 62/62] bugfix --- mart/configs/model/torchvision_object_detection.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index ba2feff4..c4b54c98 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -12,8 +12,6 @@ training_sequence: validation_sequence: ??? test_sequence: ??? -output_preds_key: "losses_and_detections.eval" - modules: loss: _target_: mart.nn.Sum