diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py index 05200334e..9e822099d 100644 --- a/basicsr/archs/vgg_arch.py +++ b/basicsr/archs/vgg_arch.py @@ -3,7 +3,7 @@ from collections import OrderedDict from torch import nn as nn from torchvision.models import vgg as vgg - +from torchvision.models import VGG19_Weights from basicsr.utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' @@ -105,7 +105,7 @@ def __init__(self, state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) vgg_net.load_state_dict(state_dict) else: - vgg_net = getattr(vgg, vgg_type)(pretrained=True) + vgg_net = getattr(vgg, vgg_type)(weights=VGG19_Weights.DEFAULT) features = vgg_net.features[:max_idx + 1] diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py index 9f5c8c6ad..2f4235acc 100644 --- a/basicsr/data/paired_image_dataset.py +++ b/basicsr/data/paired_image_dataset.py @@ -96,7 +96,7 @@ def __getitem__(self, index): # BGR to RGB, HWC to CHW, numpy to tensor img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) # normalize - if self.mean is not None or self.std is not None: + if self.mean is not None and self.std is not None: normalize(img_lq, self.mean, self.std, inplace=True) normalize(img_gt, self.mean, self.std, inplace=True)