diff --git a/inference/inference_basicvsrpp.py b/inference/inference_basicvsrpp.py index b44aaa482..c18155477 100644 --- a/inference/inference_basicvsrpp.py +++ b/inference/inference_basicvsrpp.py @@ -7,7 +7,7 @@ from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus from basicsr.data.data_util import read_img_seq -from basicsr.utils.img_util import tensor2img +from basicsr.utils import tensor2img def inference(imgs, imgnames, model, save_path): @@ -23,7 +23,8 @@ def inference(imgs, imgnames, model, save_path): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_REDS4.pth') + parser.add_argument( + '--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_x4_SR_REDS_official.pth') parser.add_argument( '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder') parser.add_argument('--save_path', type=str, default='results/BasicVSRPP/000', help='save image path') @@ -34,7 +35,7 @@ def main(): # set up model model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7) - model.load_state_dict(torch.load(args.model_path)['params'], strict=True) + model.load_state_dict(torch.load(args.model_path), strict=True) model.eval() model = model.to(device) diff --git a/scripts/model_conversion/convert_models.py b/scripts/model_conversion/convert_models.py index 46bb085f9..930616edd 100644 --- a/scripts/model_conversion/convert_models.py +++ b/scripts/model_conversion/convert_models.py @@ -357,6 +357,50 @@ def convert_duf_model(): torch.save(crt_net, 'experiments/pretrained_models/DUF_x2_16L_official.pth') +def convert_basicvsrpp_model(): + from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus + basicvsrpp = BasicVSRPlusPlus(mid_channels=64, num_blocks=7) + crt_net = basicvsrpp.state_dict() + # for k, v in crt_net.items(): + # print(k) + + # print('=================') + + ori_net = torch.load( + 'experiments/pretrained_models/BasicVSRPP/basicvsr_plusplus_c64n7_8x1_300k_vimeo90k_bi_20210305-4ef437e2.pth') + + # for k, v in ori_net['state_dict'].items(): + # print(k) + + for ort_k, _ in ori_net['state_dict'].items(): + if 'generator' in ort_k: + # delete 'generator' + crt_k = ort_k[10:] + + # spynet module + if 'spynet.basic_module' in ort_k: + if 'weight' in ort_k: + number = int(crt_k[-13]) + crt_k = crt_k[:-13] + f'{number * 2}.weight' + elif 'bias' in ort_k: + number = int(crt_k[-11]) + crt_k = crt_k[:-11] + f'{number * 2}.bias' + + # upsample module + if 'upsample1.upsample_conv.weight' in ort_k: + crt_k = 'upconv1.weight' + elif 'upsample1.upsample_conv.bias' in ort_k: + crt_k = 'upconv1.bias' + elif 'upsample2.upsample_conv.weight' in ort_k: + crt_k = 'upconv2.weight' + elif 'upsample2.upsample_conv.bias' in ort_k: + crt_k = 'upconv2.bias' + + crt_net[crt_k] = ori_net['state_dict'][ort_k] + + torch.save(crt_net, 'experiments/pretrained_models/Converted-BasicVSRPP/BasicVSRPP_x4_SR_Vimeo90K_BI_official.pth') + + if __name__ == '__main__': # convert EDSR models # ori_net_path = 'path to original model' @@ -364,4 +408,4 @@ def convert_duf_model(): # save_path = 'save path' # convert_edsr(ori_net_path, crt_net_path, save_path, num_block=32) - convert_duf_model() + convert_basicvsrpp_model()