diff --git a/mart/configs/attack/gain/rcnn_training_loss.yaml b/mart/configs/attack/gain/rcnn_training_loss.yaml index 61ad886d..4d19f364 100644 --- a/mart/configs/attack/gain/rcnn_training_loss.yaml +++ b/mart/configs/attack/gain/rcnn_training_loss.yaml @@ -1,8 +1,4 @@ _target_: mart.nn.CallWith module: _target_: mart.nn.Sum -_call_with_args_: - - "losses_and_detections.training.loss_objectness" - - "losses_and_detections.training.loss_rpn_box_reg" - - "losses_and_detections.training.loss_classifier" - - "losses_and_detections.training.loss_box_reg" +_call_with_args_: ${model.training_sequence.seq100.loss} diff --git a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index 991a87e7..3648bf4a 100644 --- a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -11,11 +11,12 @@ tags: ["regular_training"] optimized_metric: "test_metrics/map" model: + load_state_dict: null + modules: - losses_and_detections: - model: + roi_heads: + box_predictor: num_classes: 3 - weights: null optimizer: lr: 0.0125 diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml index 9259de2c..0937774b 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml @@ -28,13 +28,6 @@ datamodule: world_size: 1 model: - modules: - losses_and_detections: - model: - # Inferred by torchvision. - num_classes: null - weights: COCO_V1 - optimizer: lr: 0.0125 momentum: 0.9 diff --git a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml index dbd4541f..3b94a04a 100644 --- a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml +++ b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml @@ -20,7 +20,8 @@ callbacks: trainer: # 117,266 training images, 6 epochs, batch_size=2, 351798 max_steps: 351798 - precision: 16 + # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms(). + precision: 32 datamodule: ims_per_batch: 2 @@ -28,11 +29,9 @@ datamodule: model: modules: - losses_and_detections: - model: - # Inferred by torchvision. - num_classes: null - weights: COCO_V1 + detector: + num_classes: null # inferred by torchvision + weights: COCO_V1 optimizer: lr: 0.0125 diff --git a/mart/configs/model/torchvision_dual_faster_rcnn.yaml b/mart/configs/model/torchvision_dual_faster_rcnn.yaml new file mode 100644 index 00000000..61a252a6 --- /dev/null +++ b/mart/configs/model/torchvision_dual_faster_rcnn.yaml @@ -0,0 +1,82 @@ +# We simply wrap a torchvision object detection model for validation. +defaults: + - torchvision_object_detection + +# log all losses separately in losses. +training_step_log: + loss_objectness: "losses.loss_objectness" + loss_rpn_box_reg: "losses.loss_rpn_box_reg" + loss_classifier: "losses.loss_classifier" + loss_box_reg: "losses.loss_box_reg" + +training_sequence: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + _train_mode_: True + images: preprocessor + targets: target + + seq030: + detector: + _name_: preds + _train_mode_: False + _inference_mode_: True + images: preprocessor + targets: target + + seq100: + loss: + # Sum up the losses. 
+ [ + "losses.loss_objectness", + "losses.loss_rpn_box_reg", + "losses.loss_classifier", + "losses.loss_box_reg", + ] + +validation_sequence: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + _train_mode_: True + images: preprocessor + targets: target + + seq030: + detector: + _name_: preds + _train_mode_: False + _inference_mode_: True + images: preprocessor + targets: target + +test_sequence: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + _train_mode_: True + images: preprocessor + targets: target + + seq030: + detector: + _name_: preds + _train_mode_: False + _inference_mode_: True + images: preprocessor + targets: target + +modules: + detector: + _target_: torchvision.models.detection.fasterrcnn_resnet50_fpn + weights: "COCO_V1" diff --git a/mart/configs/model/torchvision_faster_rcnn.yaml b/mart/configs/model/torchvision_faster_rcnn.yaml index 65200579..b5b741ed 100644 --- a/mart/configs/model/torchvision_faster_rcnn.yaml +++ b/mart/configs/model/torchvision_faster_rcnn.yaml @@ -2,28 +2,84 @@ defaults: - torchvision_object_detection -# log all losses separately in training. +load_state_dict: + _model_: + _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available + path: "torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1" + +# log all losses separately in losses. training_step_log: - loss_objectness: "losses_and_detections.training.loss_objectness" - loss_rpn_box_reg: "losses_and_detections.training.loss_rpn_box_reg" - loss_classifier: "losses_and_detections.training.loss_classifier" - loss_box_reg: "losses_and_detections.training.loss_box_reg" + loss_objectness: "rpn.losses.loss_objectness" + loss_rpn_box_reg: "rpn.losses.loss_rpn_box_reg" + loss_classifier: "roi_heads.losses.loss_classifier" + loss_box_reg: "roi_heads.losses.loss_box_reg" training_sequence: seq010: preprocessor: ["input"] seq020: - losses_and_detections: ["preprocessor", "target"] + transform: + images: "preprocessor" + targets: "target" + _return_as_dict_: ["images", "targets"] seq030: + backbone: + x: "transform.images.tensors" + + seq040: + rpn: + _train_mode_: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq050: + roi_heads: + _train_mode_: True + features: "backbone" + proposals: "rpn.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq060: + rpn: + _name_: "rpn_eval" + _train_mode_: False + _inference_mode_: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq070: + roi_heads: + _name_: "roi_heads_eval" + _train_mode_: False + _inference_mode_: True + features: "backbone" + proposals: "rpn_eval.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq080: + preds: + result: "roi_heads_eval.detections" + image_shapes: "transform.images.image_sizes" + original_images: "preprocessor" + + seq100: loss: # Sum up the losses. 
[ - "losses_and_detections.training.loss_objectness", - "losses_and_detections.training.loss_rpn_box_reg", - "losses_and_detections.training.loss_classifier", - "losses_and_detections.training.loss_box_reg", + "rpn.losses.loss_objectness", + "rpn.losses.loss_rpn_box_reg", + "roi_heads.losses.loss_classifier", + "roi_heads.losses.loss_box_reg", ] validation_sequence: @@ -31,20 +87,198 @@ validation_sequence: preprocessor: ["input"] seq020: - losses_and_detections: ["preprocessor", "target"] + transform: + images: "preprocessor" + targets: "target" + _return_as_dict_: ["images", "targets"] + + seq030: + backbone: + x: "transform.images.tensors" + + seq040: + rpn: + _train_mode_: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq050: + roi_heads: + _train_mode_: True + features: "backbone" + proposals: "rpn.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq060: + rpn: + _name_: "rpn_eval" + _train_mode_: False + _inference_mode_: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq070: + roi_heads: + _name_: "roi_heads_eval" + _train_mode_: False + _inference_mode_: True + features: "backbone" + proposals: "rpn_eval.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq080: + preds: + result: "roi_heads_eval.detections" + image_shapes: "transform.images.image_sizes" + original_images: "preprocessor" test_sequence: seq010: preprocessor: ["input"] seq020: - losses_and_detections: ["preprocessor", "target"] + transform: + images: "preprocessor" + targets: "target" + _return_as_dict_: ["images", "targets"] + + seq030: + backbone: + x: "transform.images.tensors" + + seq040: + rpn: + _train_mode_: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq050: + roi_heads: + _train_mode_: True + features: "backbone" + proposals: "rpn.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq060: + rpn: + _name_: "rpn_eval" + _train_mode_: False + _inference_mode_: True + images: "transform.images" + features: "backbone" + targets: "transform.targets" + _return_as_dict_: ["proposals", "losses"] + + seq070: + roi_heads: + _name_: "roi_heads_eval" + _train_mode_: False + _inference_mode_: True + features: "backbone" + proposals: "rpn_eval.proposals" + image_shapes: "transform.images.image_sizes" + targets: "transform.targets" + _return_as_dict_: ["detections", "losses"] + + seq080: + preds: + result: "roi_heads_eval.detections" + image_shapes: "transform.images.image_sizes" + original_images: "preprocessor" modules: - losses_and_detections: - # 17s: DualModeGeneralizedRCNN - # 23s: DualMode - _target_: mart.models.DualModeGeneralizedRCNN - model: - _target_: torchvision.models.detection.fasterrcnn_resnet50_fpn - num_classes: ??? 
+ transform: + _target_: torchvision.models.detection.transform.GeneralizedRCNNTransform + min_size: 800 + max_size: 1333 + image_mean: [0.485, 0.456, 0.406] + image_std: [0.229, 0.224, 0.225] + + backbone: + _target_: torchvision.models.detection.backbone_utils.BackboneWithFPN + backbone: + _target_: torchvision.models.resnet.ResNet + block: + _target_: hydra.utils.get_method + path: torchvision.models.resnet.Bottleneck + layers: [3, 4, 6, 3] + num_classes: 1000 + zero_init_residual: False + groups: 1 + width_per_group: 64 + replace_stride_with_dilation: null + norm_layer: + _target_: hydra.utils.get_method + path: torchvision.ops.misc.FrozenBatchNorm2d + return_layers: + { "layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3" } + in_channels_list: [256, 512, 1024, 2048] + out_channels: 256 + + rpn: + _target_: torchvision.models.detection.rpn.RegionProposalNetwork + anchor_generator: + _target_: torchvision.models.detection.anchor_utils.AnchorGenerator + sizes: + - [32] + - [64] + - [128] + - [256] + - [512] + aspect_ratios: + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + - [0.5, 1.0, 2.0] + head: + _target_: torchvision.models.detection.rpn.RPNHead + in_channels: 256 # backbone.out_channels + num_anchors: 3 + fg_iou_thresh: 0.7 + bg_iou_thresh: 0.3 + batch_size_per_image: 256 + positive_fraction: 0.5 + pre_nms_top_n: { "training": 2000, "testing": 1000 } + post_nms_top_n: { "training": 2000, "testing": 1000 } + nms_thresh: 0.7 + score_thresh: 0.0 + + roi_heads: + _target_: torchvision.models.detection.roi_heads.RoIHeads + box_roi_pool: + _target_: torchvision.ops.MultiScaleRoIAlign + featmap_names: ["0", "1", "2", "3"] # backbone.return_layers + output_size: 7 + sampling_ratio: 2 + box_head: + _target_: torchvision.models.detection.faster_rcnn.TwoMLPHead + in_channels: 12544 # backbone.out_channels * (box_roi_pool.output_size ** 2) + representation_size: 1024 + box_predictor: + _target_: torchvision.models.detection.faster_rcnn.FastRCNNPredictor + in_channels: 1024 # box_head.representation_size + num_classes: 91 # coco classes + background + fg_iou_thresh: 0.5 + bg_iou_thresh: 0.5 + batch_size_per_image: 512 + positive_fraction: 0.25 + bbox_reg_weights: null + score_thresh: 0.05 + nms_thresh: 0.5 + detections_per_img: 100 + + preds: + _target_: mart.models.detection.GeneralizedRCNNPostProcessor diff --git a/mart/configs/model/torchvision_object_detection.yaml b/mart/configs/model/torchvision_object_detection.yaml index 1bbd678c..c4b54c98 100644 --- a/mart/configs/model/torchvision_object_detection.yaml +++ b/mart/configs/model/torchvision_object_detection.yaml @@ -6,17 +6,12 @@ defaults: training_step_log: loss: "loss" -training_sequence: ??? +training_sequence: + seq100: + loss: ??? validation_sequence: ??? test_sequence: ??? -output_preds_key: "losses_and_detections.eval" - modules: - losses_and_detections: - # Return losses in the training mode and predictions in the eval mode in one pass. - _target_: mart.models.DualMode - model: ??? 
- loss: _target_: mart.nn.Sum diff --git a/mart/configs/model/torchvision_retinanet.yaml b/mart/configs/model/torchvision_retinanet.yaml index 34b66945..67523685 100644 --- a/mart/configs/model/torchvision_retinanet.yaml +++ b/mart/configs/model/torchvision_retinanet.yaml @@ -2,32 +2,79 @@ defaults: - torchvision_object_detection +load_state_dict: + detector: + _target_: hydra.utils._locate # FIXME: Use hydra.utils.get_object when available + path: "torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.COCO_V1" + # log all losses separately in training. training_step_log: - loss_classifier: "losses_and_detections.training.classification" - loss_box_reg: "losses_and_detections.training.bbox_regression" + loss_classifier: "losses.classification" + loss_box_reg: "losses.bbox_regression" training_sequence: - - preprocessor: ["input"] - - losses_and_detections: ["preprocessor", "target"] - - loss: + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + _train_mode_: True + images: preprocessor + targets: target + + seq030: + detector: + _name_: preds + _train_mode_: False + _inference_mode_: True + images: preprocessor + targets: target + + seq100: + loss: # Sum up the losses. - [ - "losses_and_detections.training.classification", - "losses_and_detections.training.bbox_regression", - ] + ["losses.classification", "losses.bbox_regression"] validation_sequence: - - preprocessor: ["input"] - - losses_and_detections: ["preprocessor", "target"] + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + _train_mode_: True + images: preprocessor + targets: target + + seq030: + detector: + _name_: preds + _train_mode_: False + _inference_mode_: True + images: preprocessor + targets: target test_sequence: - - preprocessor: ["input"] - - losses_and_detections: ["preprocessor", "target"] + seq010: + preprocessor: ["input"] + + seq020: + detector: + _name_: losses + _train_mode_: True + images: preprocessor + targets: target + + seq030: + detector: + _name_: preds + _train_mode_: False + _inference_mode_: True + images: preprocessor + targets: target modules: - losses_and_detections: - # _target_: mart.models.DualMode - model: - _target_: torchvision.models.detection.retinanet_resnet50_fpn - num_classes: ??? + detector: + _target_: torchvision.models.detection.retinanet_resnet50_fpn + num_classes: ??? 
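Note: in the RetinaNet and Faster R-CNN configs above, the removed mart.models.DualMode wrapper is replaced by calling the same detector module twice per step — seq020 with _train_mode_: True to collect the loss dict, and seq030 with _train_mode_: False / _inference_mode_: True to collect predictions. A minimal plain-PyTorch sketch of that dual-pass pattern (the dual_pass helper and the images/targets inputs are illustrative, not part of this change; images is a list of CHW float tensors, targets a list of dicts with "boxes" and "labels"):

    import torch
    from torchvision.models.detection import retinanet_resnet50_fpn

    detector = retinanet_resnet50_fpn(weights="COCO_V1")

    def dual_pass(detector, images, targets):
        was_training = detector.training

        # First pass: train mode returns the loss dict
        # ({"classification": ..., "bbox_regression": ...} for RetinaNet).
        detector.train(True)
        losses = detector(images, targets)

        # Second pass: eval mode under inference_mode returns per-image
        # predictions with "boxes", "scores" and "labels".
        detector.train(False)
        with torch.inference_mode():
            preds = detector(images)

        detector.train(was_training)
        return losses, preds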
diff --git a/mart/models/__init__.py b/mart/models/__init__.py index a6e8d57f..96544d96 100644 --- a/mart/models/__init__.py +++ b/mart/models/__init__.py @@ -4,5 +4,4 @@ # SPDX-License-Identifier: BSD-3-Clause # -from .dual_mode import * # noqa: F403 from .modular import * # noqa: F403 diff --git a/mart/models/detection.py b/mart/models/detection.py new file mode 100644 index 00000000..f895faf8 --- /dev/null +++ b/mart/models/detection.py @@ -0,0 +1,55 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import torch +from torch import Tensor, nn +from torchvision.models.detection.transform import ( + paste_masks_in_image, + resize_boxes, + resize_keypoints, +) + +__all__ = ["GeneralizedRCNNPostProcessor"] + + +class GeneralizedRCNNPostProcessor(nn.Module): + def forward( + self, + result: list[dict[str, Tensor]], + image_shapes: list[tuple[int, int]], + original_images: list[Tensor], + ) -> list[dict[str, Tensor]]: + original_image_sizes = get_image_sizes(original_images) + + for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): + boxes = pred["boxes"] + boxes = resize_boxes(boxes, im_s, o_im_s) + result[i]["boxes"] = boxes + if "masks" in pred: + masks = pred["masks"] + masks = paste_masks_in_image(masks, boxes, o_im_s) + result[i]["masks"] = masks + if "keypoints" in pred: + keypoints = pred["keypoints"] + keypoints = resize_keypoints(keypoints, im_s, o_im_s) + result[i]["keypoints"] = keypoints + return result + + +def get_image_sizes(images: list[Tensor]): + image_sizes: list[tuple[int, int]] = [] + + for img in images: + val = img.shape[-2:] + torch._assert( + len(val) == 2, + f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", + ) + image_sizes.append((val[0], val[1])) + + return image_sizes diff --git a/mart/models/dual_mode.py b/mart/models/dual_mode.py deleted file mode 100644 index 5cc780a2..00000000 --- a/mart/models/dual_mode.py +++ /dev/null @@ -1,153 +0,0 @@ -# -# Copyright (C) 2022 Intel Corporation -# -# SPDX-License-Identifier: BSD-3-Clause -# - -from collections import OrderedDict -from typing import Dict, List, Optional, Tuple - -import torch -from torch import Tensor -from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN - -from mart.utils.monkey_patch import MonkeyPatch - -__all__ = ["DualMode", "DualModeGeneralizedRCNN"] - - -class DualMode(torch.nn.Module): - """Run model.forward() in both the training mode and the eval mode, then aggregate results in a - dictionary {"training": ..., "eval": ...}. - - Some object detection models are implemented to return losses in the training mode and - predictions in the eval mode, but we want both the losses and the predictions when attacking a - model in the test mode. 
- """ - - def __init__(self, model): - super().__init__() - - self.model = model - - def forward(self, *args, **kwargs): - original_training_status = self.model.training - ret = {} - - # TODO: Reuse the feature map in dual mode to improve efficiency - self.model.train(True) - ret["training"] = self.model(*args, **kwargs) - - self.model.train(False) - with torch.no_grad(): - ret["eval"] = self.model(*args, **kwargs) - - self.model.train(original_training_status) - return ret - - -class DualModeGeneralizedRCNN(torch.nn.Module): - """Efficient dual mode for GeneralizedRCNN from torchvision, by reusing feature maps from - backbone.""" - - def __init__(self, model): - super().__init__() - - self.model = model - - def forward(self, *args, **kwargs): - bound_method = self.forward_dual_mode.__get__(self.model, self.model.__class__) - with MonkeyPatch(self.model, "forward", bound_method): - ret = self.model(*args, **kwargs) - return ret - - # Adapted from: https://github.com/pytorch/vision/blob/32757a260dfedebf71eb470bd0a072ed20beddc3/torchvision/models/detection/generalized_rcnn.py#L46 - @staticmethod - def forward_dual_mode(self, images, targets=None): - # type: (GeneralizedRCNN, List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] - """ - Args: - self (GeneralizedRCNN): the model. - images (list[Tensor]): images to be processed - targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional) - - Returns: - result (list[BoxList] or dict[Tensor]): the output from the model. - During training, it returns a dict[Tensor] which contains the losses. - During testing, it returns list[BoxList] contains additional fields - like `scores`, `labels` and `mask` (for Mask R-CNN models). - - """ - # Validate targets in both training and eval mode. - if targets is not None: - for target in targets: - boxes = target["boxes"] - if isinstance(boxes, torch.Tensor): - torch._assert( - len(boxes.shape) == 2 and boxes.shape[-1] == 4, - f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.", - ) - else: - torch._assert( - False, - f"Expected target boxes to be of type Tensor, got {type(boxes)}.", - ) - - original_image_sizes: List[Tuple[int, int]] = [] - for img in images: - val = img.shape[-2:] - torch._assert( - len(val) == 2, - f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}", - ) - original_image_sizes.append((val[0], val[1])) - - images, targets = self.transform(images, targets) - - # Check for degenerate boxes - # TODO: Move this to a function - if targets is not None: - for target_idx, target in enumerate(targets): - boxes = target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - # print the first degenerate box - bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] - degen_bb: List[float] = boxes[bb_idx].tolist() - torch._assert( - False, - "All bounding boxes should have positive height and width." - f" Found invalid box {degen_bb} for target at index {target_idx}.", - ) - - features = self.backbone(images.tensors) - if isinstance(features, torch.Tensor): - features = OrderedDict([("0", features)]) - - original_training_status = self.training - ret = {} - - # Training mode. 
- self.train(True) - proposals, proposal_losses = self.rpn(images, features, targets) - detections, detector_losses = self.roi_heads( - features, proposals, images.image_sizes, targets - ) - losses = {} - losses.update(detector_losses) - losses.update(proposal_losses) - ret["training"] = losses - - # Eval mode. - self.train(False) - with torch.no_grad(): - proposals, proposal_losses = self.rpn(images, features, targets) - detections, detector_losses = self.roi_heads( - features, proposals, images.image_sizes, targets - ) - detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # type: ignore[operator] - ret["eval"] = detections - - self.train(original_training_status) - - return ret diff --git a/mart/models/modular.py b/mart/models/modular.py index a27c6867..f8355f60 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -95,10 +95,26 @@ def __init__( # commandlines converts dotted paths to nested dictionaries. load_state_dict = flatten_dict(load_state_dict or {}) - for name, path in load_state_dict.items(): - module = attrgetter(name)(self.model) - logger.info(f"Loading state_dict {path} for {module.__class__.__name__}...") - module.load_state_dict(torch.load(path, map_location="cpu")) + for name, state_dict in load_state_dict.items(): + if name == "_model_": + module = self.model + else: + module = attrgetter(name)(self.model) + + if isinstance(state_dict, str): + logger.info(f"Loading {state_dict} for {module.__class__.__name__}...") + state_dict = torch.load(state_dict, map_location="cpu") + + elif hasattr(state_dict, "get_state_dict"): + logger.info( + f"Loading {state_dict.__class__.__name__} for {module.__class__.__name__}..." + ) + state_dict = state_dict.get_state_dict(progress=True) + + else: + raise ValueError(f"Unsupported state_dict: {state_dict}") + + module.load_state_dict(state_dict, strict=True) self.output_loss_key = output_loss_key self.output_preds_key = output_preds_key
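Note: the load_state_dict change in mart/models/modular.py above accepts, per module name (or "_model_" for the whole model), either a checkpoint path or any object exposing get_state_dict() — which is what the hydra.utils._locate entries in the model configs resolve to (a torchvision WeightsEnum member). A standalone sketch of what those two pieces amount to at runtime, assuming the COCO_V1 Faster R-CNN weights referenced in torchvision_faster_rcnn.yaml:

    import torch
    from hydra.utils import _locate
    from torchvision.models.detection import fasterrcnn_resnet50_fpn

    # What `load_state_dict._model_` resolves to via hydra.utils._locate:
    weights = _locate("torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1")

    # Bare model skeleton; the COCO_V1 checkpoint expects 91 classes (COCO classes + background).
    model = fasterrcnn_resnet50_fpn(weights=None, num_classes=91)

    if isinstance(weights, str):
        # A plain filesystem path to a checkpoint.
        state_dict = torch.load(weights, map_location="cpu")
    elif hasattr(weights, "get_state_dict"):
        # A torchvision WeightsEnum member: download (or hit the cache) and return its state_dict.
        state_dict = weights.get_state_dict(progress=True)
    else:
        raise ValueError(f"Unsupported state_dict: {weights}")

    model.load_state_dict(state_dict, strict=True)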