From 277e0edb6db94f760bb36e5d2867dabee764ab11 Mon Sep 17 00:00:00 2001 From: purewind7 <757419920@qq.com> Date: Wed, 15 Mar 2023 10:50:33 +0800 Subject: [PATCH 1/2] Update vgg_arch.py Solve the UserWarning: "The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead." --- basicsr/archs/vgg_arch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py index 05200334e..9e822099d 100644 --- a/basicsr/archs/vgg_arch.py +++ b/basicsr/archs/vgg_arch.py @@ -3,7 +3,7 @@ from collections import OrderedDict from torch import nn as nn from torchvision.models import vgg as vgg - +from torchvision.models import VGG19_Weights from basicsr.utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' @@ -105,7 +105,7 @@ def __init__(self, state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) vgg_net.load_state_dict(state_dict) else: - vgg_net = getattr(vgg, vgg_type)(pretrained=True) + vgg_net = getattr(vgg, vgg_type)(weights=VGG19_Weights.DEFAULT) features = vgg_net.features[:max_idx + 1] From cf89d567cdff2675f8c54a260674c7d3185dd7ff Mon Sep 17 00:00:00 2001 From: purewind7 <757419920@qq.com> Date: Sun, 23 Apr 2023 19:04:16 +0800 Subject: [PATCH 2/2] Update paired_image_dataset.py fix normalize in dataset --- basicsr/data/paired_image_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 9f5c8c6ad..2f4235acc 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -96,7 +96,7 @@ def __getitem__(self, index): # BGR to RGB, HWC to CHW, numpy to tensor img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) # normalize - if self.mean is not None or self.std is not None: + if self.mean is not None and self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) normalize(img_gt, self.mean, self.std, inplace=True)