├── codes
    ├── utils
    │   └── __init__.py
    ├── options
    │   ├── __init__.py
    │   ├── train
    │   │   ├── train_FSTRN_RealVSR_YCbCr_Combine.yml
    │   │   ├── train_FSTRN_Vimeo90K.yml
    │   │   ├── train_TOF_RealVSR_YCbCr_Combine.yml
    │   │   ├── train_TDAN_RealVSR_YCbCr_Combine.yml
    │   │   ├── train_TOF_Vimeo90K.yml
    │   │   ├── train_TDAN_Vimeo90K.yml
    │   │   ├── train_FSTRN_RealVSR_YCbCr_Split.yml
    │   │   ├── train_TOF_RealVSR_YCbCr_Split.yml
    │   │   ├── train_TDAN_RealVSR_YCbCr_Split.yml
    │   │   ├── train_RCAN_RealVSR_YCbCr_Combine.yml
    │   │   ├── train_EDVR_woTSA_RealVSR_YCbCr_Combine.yml
    │   │   ├── train_RCAN_Vimeo90K.yml
    │   │   ├── train_EDVR_woTSA_Vimeo90K.yml
    │   │   ├── train_RCAN_RealVSR_YCbCr_Split.yml
    │   │   ├── train_EDVR_woTSA_RealVSR_YCbCr_Split.yml
    │   │   ├── train_TOF-GAN_RealVSR_YCbCr_Split.yml
    │   │   └── train_EDVR-GAN_woTSA_RealVSR_YCbCr_Split.yml
    │   └── options.py
    ├── models
    │   ├── archs
    │   │   ├── __init__.py
    │   │   ├── dcn
    │   │   │   ├── __init__.py
    │   │   │   └── setup.py
    │   │   ├── SRResNet_arch.py
    │   │   ├── FSTRN_arch.py
    │   │   ├── TDAN_arch.py
    │   │   ├── arch_util.py
    │   │   ├── RCAN_arch.py
    │   │   ├── TOF_arch.py
    │   │   └── VGG_arch.py
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── lr_scheduler.py
    │   ├── VideoSR_archs.py
    │   ├── loss.py
    │   ├── VideoSR_AllPair_model_YCbCr_Split.py
    │   └── VideoSR_AllPair_model_YCbCr_Combine.py
    ├── metrics
    │   ├── models
    │   │   ├── niqe_model_realvsr_all.mat
    │   │   └── fit_niqe_model.m
    │   ├── evaluate_realvsr_no_reference_metrics.m
    │   ├── evaluate_niqe_brisque.m
    │   └── evaluate_realvsr_full_reference_metrics.py
    ├── scripts
    │   ├── generate_LR_BI_Vimeo90K.m
    │   └── prepare_data.py
    ├── data
    │   ├── __init__.py
    │   ├── data_sampler.py
    │   ├── augments_video_allpair.py
    │   ├── VideoTestDataset.py
    │   └── Vimeo90K_dataset.py
    ├── test_RealVSR_wo_GT.py
    └── test_RealVSR_wi_GT.py
├── keys
    ├── remove_seqs.pkl
    └── realvsr_keys.pkl
├── imgs
    └── dataset_samples.png
├── .style.yapf
├── .flake8
├── requirements.txt
├── .gitignore
├── README.md
└── LICENSE
/codes/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /codes/options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /codes/models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keys/remove_seqs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanYeung/RealVSR/HEAD/keys/remove_seqs.pkl -------------------------------------------------------------------------------- /keys/realvsr_keys.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanYeung/RealVSR/HEAD/keys/realvsr_keys.pkl -------------------------------------------------------------------------------- /imgs/dataset_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanYeung/RealVSR/HEAD/imgs/dataset_samples.png -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | COLUMN_LIMIT = 100 4 | SPLIT_BEFORE_NAMED_ASSIGNS = false -------------------------------------------------------------------------------- /.flake8:
-------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # Too many leading '#' for block comment (E266) 4 | E266 5 | 6 | max-line-length=100 -------------------------------------------------------------------------------- /codes/metrics/models/niqe_model_realvsr_all.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IanYeung/RealVSR/HEAD/codes/metrics/models/niqe_model_realvsr_all.mat -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | lmdb 4 | pyyaml 5 | tensorboard 6 | future 7 | matplotlib 8 | ffmpeg-python 9 | IQA_pytorch 10 | ipython 11 | scikit-image 12 | torch 13 | torchvision 14 | tqdm 15 | kornia -------------------------------------------------------------------------------- /codes/models/archs/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, 2 | deform_conv, modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /codes/metrics/evaluate_realvsr_no_reference_metrics.m: -------------------------------------------------------------------------------- 1 | root = '/home/xiyang/Datasets/RealVSR/'; 2 | expn = 'LQ_test'; 3 | result_path = '/home/xiyang/Results/RealVSR/degradation_no_reference_metrics_inp_all.txt'; 4 | if_niqe = true; 5 | if_brisque = true; 6 | niqe_model_path = './models/niqe_model_realvsr_all.mat'; 7 | evaluate_no_reference_metrics(root, expn, if_niqe, if_brisque, result_path, niqe_model_path) -------------------------------------------------------------------------------- /codes/metrics/models/fit_niqe_model.m: -------------------------------------------------------------------------------- 1 | function fit_niqe_model(realvsr_gt_dir) 2 | filepaths = dir(fullfile(realvsr_gt_dir, '*', '*.png')); 3 | img_cell_1 = {}; 4 | for i = 1 : length(filepaths) 5 | [~,imname,ext] = fileparts(filepaths(i).name); 6 | folder_path = filepaths(i).folder; 7 | img_cell_1{i} = fullfile(folder_path, [imname, ext]); 8 | end 9 | img_cell = [img_cell_1]; 10 | imds = imageDatastore(img_cell); 11 | niqe_model = fitniqe(imds); % default block size is 96x96 12 | save('./niqe_model_realvsr_all.mat', 'niqe_model'); 13 | end -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | model = opt['model'] 7 | if model == 'VideoSR_AllPair_YCbCr_Combine': 8 | from .VideoSR_AllPair_model_YCbCr_Combine import VideoSRModel as M 9 | elif model == 'VideoSR_AllPair_YCbCr_Split': 10 | from .VideoSR_AllPair_model_YCbCr_Split import VideoSRModel as M 11 | elif model == 'VideoSRGAN_AllPair_YCbCr_Split': 12 | from .VideoSRGAN_AllPair_model_YCbCr_Split import VideoSRGANModel as M 13 | else: 14 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 15 | m = M(opt) 16 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 17 | return m 18 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | def make_cuda_ext(name, sources): 6 | 7 | return CUDAExtension( 8 | name='{}'.format(name), sources=[p for p in sources], extra_compile_args={ 9 | 'cxx': [], 10 | 'nvcc': [ 11 | '-D__CUDA_NO_HALF_OPERATORS__', 12 | '-D__CUDA_NO_HALF_CONVERSIONS__', 13 | '-D__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | }) 16 | 17 | 18 | setup( 19 | name='deform_conv', ext_modules=[ 20 | make_cuda_ext(name='deform_conv_cuda', 21 | sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']) 22 | ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) 23 | -------------------------------------------------------------------------------- /codes/scripts/generate_LR_BI_Vimeo90K.m: -------------------------------------------------------------------------------- 1 | function generate_LR_BI_Vimeo90K() 2 | %% matlab code to generate bicubic-downsampled LR images for the Vimeo90K dataset 3 | up_scale = 2; 4 | mod_scale = 2; 5 | idx = 0; 6 | filepaths = dir('/home/xiyang/Datasets/Vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); 7 | for i = 1 : length(filepaths) 8 | [~,imname,ext] = fileparts(filepaths(i).name); 9 | folder_path = filepaths(i).folder; 10 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_BIx2_same'); 11 | if ~exist(save_LR_folder, 'dir') 12 | mkdir(save_LR_folder); 13 | end 14 | if isempty(imname) 15 | disp('Ignore . folder.'); 16 | elseif strcmp(imname, '.') 17 | disp('Ignore .. folder.'); 18 | else 19 | idx = idx + 1; 20 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 21 | fprintf(str_rlt); 22 | % read image 23 | img = imread(fullfile(folder_path, [imname, ext])); 24 | img = im2double(img); 25 | % modcrop 26 | img = modcrop(img, mod_scale); 27 | % LR 28 | im_LR = imresize(imresize(img, 1/up_scale, 'bicubic'), up_scale, 'bicubic'); 29 | if exist('save_LR_folder', 'var') 30 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 31 | end 32 | end 33 | end 34 | end 35 | 36 | function img = modcrop(img, modulo) 37 | %% modcrop 38 | if size(img,3) == 1 39 | sz = size(img); 40 | sz = sz - mod(sz, modulo); 41 | img = img(1:sz(1), 1:sz(2)); 42 | else 43 | tmpsz = size(img); 44 | sz = tmpsz(1:2); 45 | sz = sz - mod(sz, modulo); 46 | img = img(1:sz(1), 1:sz(2),:); 47 | end 48 | end -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | experiments/* 2 | results/* 3 | tb_logger/* 4 | .vscode 5 | .idea 6 | 7 | *.html 8 | *.png 9 | *.jpg 10 | *.gif 11 | *.pth 12 | *.pytorch 13 | 14 | *.zip 15 | 16 | # template 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
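--------------------------------------------------------------------------------
[editor's note] The create_model factory in codes/models/__init__.py above is driven by an option dict parsed from one of the YAML files under codes/options/train. A minimal usage sketch of that flow, assuming options/options.py follows the usual EDVR-style interface; the helper names parse and dict_to_nonedict are assumptions, not confirmed by this listing:

import options.options as option  # assumed EDVR-style option parser
from models import create_model

# parse() is assumed to resolve paths and !!float tags from the YAML file;
# dict_to_nonedict() is assumed to make missing keys read as None.
opt = option.parse('options/train/train_TOF_RealVSR_YCbCr_Split.yml', is_train=True)
opt = option.dict_to_nonedict(opt)
model = create_model(opt)  # dispatches on opt['model'] -> the *_Split VideoSRModel
--------------------------------------------------------------------------------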
48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 | pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 25 | pin_memory=True) 26 | 27 | 28 | def create_dataset(dataset_opt): 29 | 30 | mode = dataset_opt['mode'] 31 | if mode == 'VideoTest': 32 | from data.VideoTestDataset import VideoTestDataset as D 33 | elif mode == 'Vimeo90k': 34 | from data.Vimeo90K_dataset import Vimeo90kDataset as D 35 | elif mode == 'Vimeo90k_AllPair': 36 | from data.Vimeo90K_dataset import Vimeo90kAllPairDataset as D 37 | elif mode == 'RealVSR': 38 | from data.RealVSR_dataset import RealVSRDataset as D 39 | elif mode == 'RealVSR_AllPair': 40 | from data.RealVSR_dataset import RealVSRAllPairDataset as D 41 | else: 42 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 43 | 44 | dataset = D(dataset_opt) 45 | 46 | logger = logging.getLogger('base') 47 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, dataset_opt['name'])) 48 | 49 | return dataset 50 | -------------------------------------------------------------------------------- /codes/options/train/train_FSTRN_RealVSR_YCbCr_Combine.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_FSTRN_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Combine 5 | distortion: 
sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: FSTRN 43 | k: 3 44 | nf: 64 45 | nframes: 3 46 | 47 | #### path 48 | path: 49 | pretrain_model_G: ~ 50 | strict_load: true 51 | resume_state: ~ 52 | 53 | #### training settings: learning rate scheme, loss 54 | train: 55 | lr_G: !!float 1e-4 56 | lr_scheme: CosineAnnealingLR_Restart 57 | beta1: 0.9 58 | beta2: 0.99 59 | niter: 150000 60 | warmup_iter: -1 # -1: no warm up 61 | T_period: [150000, 150000, 150000, 150000] 62 | restarts: [150000, 300000, 450000] 63 | restart_weights: [1, 1, 1] 64 | eta_min: !!float 1e-7 65 | 66 | pixel_criterion: cb 67 | pixel_weight: 1.0 68 | val_freq: !!float 1e4 69 | 70 | manual_seed: 0 71 | 72 | #### logger 73 | logger: 74 | print_freq: 100 75 | save_checkpoint_freq: !!float 1e4 76 | 77 | #### augment 78 | augment: 79 | augs: ["none", "cutblur"] 80 | probs: [1.0, 1.0] 81 | mix_p: [0.95, 0.05] 82 | alphas: [1.0, 0.7] 83 | -------------------------------------------------------------------------------- /codes/options/train/train_FSTRN_Vimeo90K.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_FSTRN_scratch_lr1e-4_150k_Vimeo90k_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: Vimeo90k_Train 13 | mode: Vimeo90k_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_ycbcr/sequences 18 | dataroot_LQ: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_BIx2_same_ycbcr/sequences 19 | cache_keys: /home/yangxi/datasets/Vimeo90k/keys/vimeo90k_train_keys.pkl 20 | N_frames: 3 21 | use_shuffle: true 22 | n_workers: 8 # per GPU 23 | batch_size: 32 24 | GT_size: 192 25 | LQ_size: 192 26 | use_flip: true 27 | use_rot: true 28 | color: ycbcr 29 | val: 30 | name: RealVSR_Test 31 | mode: VideoTest 32 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 33 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 34 | cache_data: true 35 | N_frames: 3 36 | padding: new_info 37 | color: ycbcr 38 | 39 | #### network structures 40 | network_G: 41 | which_model_G: FSTRN 42 | k: 3 43 | nf: 64 44 | nframes: 3 45 | 46 | #### path 47 | path: 48 | pretrain_model_G: ~ 49 | strict_load: true 50 | resume_state: ~ 51 | 52 | #### training settings: learning rate scheme, loss 53 | train: 54 | lr_G: !!float 1e-4 55 | lr_scheme: CosineAnnealingLR_Restart 56 | beta1: 0.9 57 | beta2: 0.99 58 | niter: 150000 59 | warmup_iter: -1 # -1: no warm up 60 | T_period: [150000, 
150000, 150000, 150000] 61 | restarts: [150000, 300000, 450000] 62 | restart_weights: [1, 1, 1] 63 | eta_min: !!float 1e-7 64 | 65 | pixel_criterion: cb 66 | pixel_weight: 1.0 67 | val_freq: !!float 1e4 68 | 69 | manual_seed: 0 70 | 71 | #### logger 72 | logger: 73 | print_freq: 100 74 | save_checkpoint_freq: !!float 1e4 75 | 76 | #### augment 77 | augment: 78 | augs: ["none", "cutblur"] 79 | probs: [1.0, 1.0] 80 | mix_p: [0.95, 0.05] 81 | alphas: [1.0, 0.7] -------------------------------------------------------------------------------- /codes/options/train/train_TOF_RealVSR_YCbCr_Combine.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_TOF_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Combine 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: TOF 43 | nframes: 3 44 | K: 3 45 | nc: 3 46 | nf: 64 47 | nb: 10 48 | 49 | #### path 50 | path: 51 | pretrain_model_G: ~ 52 | strict_load: true 53 | resume_state: ~ 54 | 55 | #### training settings: learning rate scheme, loss 56 | train: 57 | lr_G: !!float 1e-4 58 | lr_scheme: CosineAnnealingLR_Restart 59 | beta1: 0.9 60 | beta2: 0.99 61 | niter: 150000 62 | warmup_iter: -1 # -1: no warm up 63 | T_period: [150000, 150000, 150000, 150000] 64 | restarts: [150000, 300000, 450000] 65 | restart_weights: [1, 1, 1] 66 | eta_min: !!float 1e-7 67 | 68 | pixel_criterion: cb 69 | pixel_weight: 1.0 70 | val_freq: !!float 1e4 71 | 72 | manual_seed: 0 73 | 74 | #### logger 75 | logger: 76 | print_freq: 100 77 | save_checkpoint_freq: !!float 1e4 78 | 79 | #### augment 80 | augment: 81 | augs: ["none", "cutblur"] 82 | probs: [1.0, 1.0] 83 | mix_p: [0.95, 0.05] 84 | alphas: [1.0, 0.7] 85 | -------------------------------------------------------------------------------- /codes/options/train/train_TDAN_RealVSR_YCbCr_Combine.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_TDAN_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Combine 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per 
GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: TDAN 43 | nf: 64 44 | nc: 3 45 | nframes: 3 46 | nb_f: 5 47 | nb_b: 10 48 | groups: 8 49 | 50 | #### path 51 | path: 52 | pretrain_model_G: ~ 53 | strict_load: true 54 | resume_state: ~ 55 | 56 | #### training settings: learning rate scheme, loss 57 | train: 58 | lr_G: !!float 1e-4 59 | lr_scheme: CosineAnnealingLR_Restart 60 | beta1: 0.9 61 | beta2: 0.99 62 | niter: 150000 63 | warmup_iter: -1 # -1: no warm up 64 | T_period: [150000, 150000, 150000, 150000] 65 | restarts: [150000, 300000, 450000] 66 | restart_weights: [1, 1, 1] 67 | eta_min: !!float 1e-7 68 | 69 | pixel_criterion: cb 70 | pixel_weight: 1.0 71 | val_freq: !!float 1e4 72 | 73 | manual_seed: 0 74 | 75 | #### logger 76 | logger: 77 | print_freq: 100 78 | save_checkpoint_freq: !!float 1e4 79 | 80 | #### augment 81 | augment: 82 | augs: ["none", "cutblur"] 83 | probs: [1.0, 1.0] 84 | mix_p: [0.95, 0.05] 85 | alphas: [1.0, 0.7] 86 | -------------------------------------------------------------------------------- /codes/options/train/train_TOF_Vimeo90K.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_TOF_scratch_lr1e-4_150k_Vimeo90k_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: Vimeo90k_Train 13 | mode: Vimeo90k_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_ycbcr/sequences 18 | dataroot_LQ: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_BIx2_same_ycbcr/sequences 19 | cache_keys: /home/yangxi/datasets/Vimeo90k/keys/vimeo90k_train_keys.pkl 20 | N_frames: 3 21 | use_shuffle: true 22 | n_workers: 8 # per GPU 23 | batch_size: 32 24 | GT_size: 192 25 | LQ_size: 192 26 | use_flip: true 27 | use_rot: true 28 | color: ycbcr 29 | val: 30 | name: RealVSR_Test 31 | mode: VideoTest 32 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 33 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 34 | cache_data: true 35 | N_frames: 3 36 | padding: new_info 37 | color: ycbcr 38 | 39 | #### network structures 40 | network_G: 41 | which_model_G: TOF 42 | nframes: 3 43 | K: 3 44 | nc: 3 45 | nf: 64 46 | nb: 10 47 | 48 | #### path 49 | path: 50 | pretrain_model_G: ~ 51 | strict_load: true 52 | resume_state: ~ 53 | 54 | #### training settings: learning rate scheme, loss 55 | train: 56 | lr_G: !!float 1e-4 57 | lr_scheme: CosineAnnealingLR_Restart 58 | beta1: 0.9 59 | beta2: 0.99 60 | niter: 150000 61 | warmup_iter: -1 # -1: no warm up 62 | T_period: [150000, 150000, 150000, 150000] 63 | restarts: [150000, 300000, 450000] 64 | restart_weights: [1, 1, 1] 65 | eta_min: !!float 1e-7 66 | 67 | pixel_criterion: cb 68 | pixel_weight: 1.0 69 | val_freq: !!float 1e4 70 | 71 | manual_seed: 0 72 | 73 | #### logger 74 | logger: 75 | print_freq: 100 76 | save_checkpoint_freq: !!float 1e4 77 | 78 | #### augment 79 | augment: 80 | augs: ["none", "cutblur"] 81 | probs: [1.0, 1.0] 82 | 
mix_p: [0.95, 0.05] 83 | alphas: [1.0, 0.7] 84 | 85 | -------------------------------------------------------------------------------- /codes/options/train/train_TDAN_Vimeo90K.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_TDAN_scratch_lr1e-4_150k_Vimeo90k_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: Vimeo90k_Train 13 | mode: Vimeo90k_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_ycbcr/sequences 18 | dataroot_LQ: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_BIx2_same_ycbcr/sequences 19 | cache_keys: /home/yangxi/datasets/Vimeo90k/keys/vimeo90k_train_keys.pkl 20 | N_frames: 3 21 | use_shuffle: true 22 | n_workers: 8 # per GPU 23 | batch_size: 32 24 | GT_size: 192 25 | LQ_size: 192 26 | use_flip: true 27 | use_rot: true 28 | color: ycbcr 29 | val: 30 | name: RealVSR_Test 31 | mode: VideoTest 32 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 33 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 34 | cache_data: true 35 | N_frames: 3 36 | padding: new_info 37 | color: ycbcr 38 | 39 | #### network structures 40 | network_G: 41 | which_model_G: TDAN 42 | nf: 64 43 | nc: 3 44 | nframes: 3 45 | nb_f: 5 46 | nb_b: 10 47 | groups: 8 48 | 49 | #### path 50 | path: 51 | pretrain_model_G: ~ 52 | strict_load: true 53 | resume_state: ~ 54 | 55 | #### training settings: learning rate scheme, loss 56 | train: 57 | lr_G: !!float 1e-4 58 | lr_scheme: CosineAnnealingLR_Restart 59 | beta1: 0.9 60 | beta2: 0.99 61 | niter: 150000 62 | warmup_iter: -1 # -1: no warm up 63 | T_period: [150000, 150000, 150000, 150000] 64 | restarts: [150000, 300000, 450000] 65 | restart_weights: [1, 1, 1] 66 | eta_min: !!float 1e-7 67 | 68 | pixel_criterion: cb 69 | pixel_weight: 1.0 70 | val_freq: !!float 1e4 71 | 72 | manual_seed: 0 73 | 74 | #### logger 75 | logger: 76 | print_freq: 100 77 | save_checkpoint_freq: !!float 1e4 78 | 79 | #### augment 80 | augment: 81 | augs: ["none", "cutblur"] 82 | probs: [1.0, 1.0] 83 | mix_p: [0.95, 0.05] 84 | alphas: [1.0, 0.7] -------------------------------------------------------------------------------- /codes/options/train/train_FSTRN_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_FSTRN_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: 
new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: FSTRN 43 | k: 3 44 | nf: 64 45 | nframes: 3 46 | 47 | #### path 48 | path: 49 | pretrain_model_G: ~ 50 | strict_load: true 51 | resume_state: ~ 52 | 53 | #### training settings: learning rate scheme, loss 54 | train: 55 | lr_G: !!float 1e-4 56 | lr_scheme: CosineAnnealingLR_Restart 57 | beta1: 0.9 58 | beta2: 0.99 59 | niter: 150000 60 | warmup_iter: -1 # -1: no warm up 61 | T_period: [150000, 150000, 150000, 150000] 62 | restarts: [150000, 300000, 450000] 63 | restart_weights: [1, 1, 1] 64 | eta_min: !!float 1e-7 65 | 66 | pixel_criterion_y: lappyr 67 | pixel_weight_y: 1.0 68 | pixel_criterion_c: gw 69 | pixel_weight_c: 1.0 70 | val_freq: !!float 1e4 71 | 72 | manual_seed: 0 73 | 74 | #### logger 75 | logger: 76 | print_freq: 100 77 | save_checkpoint_freq: !!float 1e4 78 | 79 | #### augment 80 | augment: 81 | augs: ["none", "cutblur"] 82 | probs: [1.0, 1.0] 83 | mix_p: [0.95, 0.05] 84 | alphas: [1.0, 0.7] 85 | -------------------------------------------------------------------------------- /codes/options/train/train_TOF_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_TOF_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: TOF 43 | nframes: 3 44 | K: 3 45 | nc: 3 46 | nf: 64 47 | nb: 10 48 | 49 | #### path 50 | path: 51 | pretrain_model_G: ~ 52 | strict_load: true 53 | resume_state: ~ 54 | 55 | #### training settings: learning rate scheme, loss 56 | train: 57 | lr_G: !!float 1e-4 58 | lr_scheme: CosineAnnealingLR_Restart 59 | beta1: 0.9 60 | beta2: 0.99 61 | niter: 150000 62 | warmup_iter: -1 # -1: no warm up 63 | T_period: [150000, 150000, 150000, 150000] 64 | restarts: [150000, 300000, 450000] 65 | restart_weights: [1, 1, 1] 66 | eta_min: !!float 1e-7 67 | 68 | pixel_criterion_y: lappyr 69 | pixel_weight_y: 1.0 70 | pixel_criterion_c: gw 71 | pixel_weight_c: 1.0 72 | val_freq: !!float 1e4 73 | 74 | manual_seed: 0 75 | 76 | #### logger 77 | logger: 78 | print_freq: 100 79 | save_checkpoint_freq: !!float 1e4 80 | 81 | #### augment 82 | augment: 83 | augs: ["none", "cutblur"] 84 | probs: [1.0, 1.0] 85 | mix_p: [0.95, 0.05] 86 | alphas: [1.0, 0.7] 87 | -------------------------------------------------------------------------------- /codes/options/train/train_TDAN_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 
001_TDAN_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: TDAN 43 | nf: 64 44 | nc: 3 45 | nframes: 3 46 | nb_f: 5 47 | nb_b: 10 48 | groups: 8 49 | 50 | #### path 51 | path: 52 | pretrain_model_G: ~ 53 | strict_load: true 54 | resume_state: ~ 55 | 56 | #### training settings: learning rate scheme, loss 57 | train: 58 | lr_G: !!float 1e-4 59 | lr_scheme: CosineAnnealingLR_Restart 60 | beta1: 0.9 61 | beta2: 0.99 62 | niter: 150000 63 | warmup_iter: -1 # -1: no warm up 64 | T_period: [150000, 150000, 150000, 150000] 65 | restarts: [150000, 300000, 450000] 66 | restart_weights: [1, 1, 1] 67 | eta_min: !!float 1e-7 68 | 69 | pixel_criterion_y: lappyr 70 | pixel_weight_y: 1.0 71 | pixel_criterion_c: gw 72 | pixel_weight_c: 1.0 73 | val_freq: !!float 1e4 74 | 75 | manual_seed: 0 76 | 77 | #### logger 78 | logger: 79 | print_freq: 100 80 | save_checkpoint_freq: !!float 1e4 81 | 82 | #### augment 83 | augment: 84 | augs: ["none", "cutblur"] 85 | probs: [1.0, 1.0] 86 | mix_p: [0.95, 0.05] 87 | alphas: [1.0, 0.7] 88 | -------------------------------------------------------------------------------- /codes/options/train/train_RCAN_RealVSR_YCbCr_Combine.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_RCAN_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Combine 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: RCAN 43 | num_in_ch: 3 44 | num_out_ch: 3 45 | num_frames: 3 46 | num_feat: 64 47 | num_group: 5 48 | num_block: 2 49 | squeeze_factor: 16 50 | res_scale: 1 51 | 52 | #### path 53 | 
path: 54 | pretrain_model_G: ~ 55 | strict_load: true 56 | resume_state: ~ 57 | 58 | #### training settings: learning rate scheme, loss 59 | train: 60 | lr_G: !!float 1e-4 61 | lr_scheme: CosineAnnealingLR_Restart 62 | beta1: 0.9 63 | beta2: 0.99 64 | niter: 150000 65 | warmup_iter: -1 # -1: no warm up 66 | T_period: [150000, 150000, 150000, 150000] 67 | restarts: [150000, 300000, 450000] 68 | restart_weights: [1, 1, 1] 69 | eta_min: !!float 1e-7 70 | 71 | pixel_criterion: cb 72 | pixel_weight: 1.0 73 | val_freq: !!float 1e4 74 | 75 | manual_seed: 0 76 | 77 | #### logger 78 | logger: 79 | print_freq: 100 80 | save_checkpoint_freq: !!float 1e4 81 | 82 | #### augment 83 | augment: 84 | augs: ["none", "cutblur"] 85 | probs: [1.0, 1.0] 86 | mix_p: [0.95, 0.05] 87 | alphas: [1.0, 0.7] 88 | -------------------------------------------------------------------------------- /codes/options/train/train_EDVR_woTSA_RealVSR_YCbCr_Combine.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_EDVR_NoUp_woTSA_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Combine 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: EDVR_NoUp 43 | nf: 64 44 | nc: 3 45 | nframes: 3 46 | groups: 8 47 | front_RBs: 5 48 | back_RBs: 10 49 | predeblur: false 50 | HR_in: false 51 | w_TSA: false 52 | 53 | #### path 54 | path: 55 | pretrain_model_G: ~ 56 | strict_load: true 57 | resume_state: ~ 58 | 59 | #### training settings: learning rate scheme, loss 60 | train: 61 | lr_G: !!float 1e-4 62 | lr_scheme: CosineAnnealingLR_Restart 63 | beta1: 0.9 64 | beta2: 0.99 65 | niter: 150000 66 | warmup_iter: -1 # -1: no warm up 67 | T_period: [150000, 150000, 150000, 150000] 68 | restarts: [150000, 300000, 450000] 69 | restart_weights: [1, 1, 1] 70 | eta_min: !!float 1e-7 71 | 72 | pixel_criterion: cb 73 | pixel_weight: 1.0 74 | val_freq: !!float 1e4 75 | 76 | manual_seed: 0 77 | 78 | #### logger 79 | logger: 80 | print_freq: 100 81 | save_checkpoint_freq: !!float 1e4 82 | 83 | #### augment 84 | augment: 85 | augs: ["none", "cutblur"] 86 | probs: [1.0, 1.0] 87 | mix_p: [0.95, 0.05] 88 | alphas: [1.0, 0.7] 89 | -------------------------------------------------------------------------------- /codes/options/train/train_RCAN_Vimeo90K.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_RCAN_scratch_lr1e-4_150k_Vimeo90k_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### 
datasets 10 | datasets: 11 | train: 12 | name: Vimeo90k_Train 13 | mode: Vimeo90k_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_ycbcr/sequences 18 | dataroot_LQ: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_BIx2_same_ycbcr/sequences 19 | cache_keys: /home/yangxi/datasets/Vimeo90k/keys/vimeo90k_train_keys.pkl 20 | N_frames: 3 21 | use_shuffle: true 22 | n_workers: 8 # per GPU 23 | batch_size: 32 24 | GT_size: 192 25 | LQ_size: 192 26 | use_flip: true 27 | use_rot: true 28 | color: ycbcr 29 | val: 30 | name: RealVSR_Test 31 | mode: VideoTest 32 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 33 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 34 | cache_data: true 35 | N_frames: 3 36 | padding: new_info 37 | color: ycbcr 38 | 39 | #### network structures 40 | network_G: 41 | which_model_G: RCAN 42 | num_in_ch: 3 43 | num_out_ch: 3 44 | num_frames: 3 45 | num_feat: 64 46 | num_group: 5 47 | num_block: 2 48 | squeeze_factor: 16 49 | res_scale: 1 50 | 51 | #### path 52 | path: 53 | pretrain_model_G: ~ 54 | strict_load: true 55 | resume_state: ~ 56 | 57 | #### training settings: learning rate scheme, loss 58 | train: 59 | lr_G: !!float 1e-4 60 | lr_scheme: CosineAnnealingLR_Restart 61 | beta1: 0.9 62 | beta2: 0.99 63 | niter: 150000 64 | warmup_iter: -1 # -1: no warm up 65 | T_period: [150000, 150000, 150000, 150000] 66 | restarts: [150000, 300000, 450000] 67 | restart_weights: [1, 1, 1] 68 | eta_min: !!float 1e-7 69 | 70 | pixel_criterion: cb 71 | pixel_weight: 1.0 72 | val_freq: !!float 1e4 73 | 74 | manual_seed: 0 75 | 76 | #### logger 77 | logger: 78 | print_freq: 100 79 | save_checkpoint_freq: !!float 1e4 80 | 81 | #### augment 82 | augment: 83 | augs: ["none", "cutblur"] 84 | probs: [1.0, 1.0] 85 | mix_p: [0.95, 0.05] 86 | alphas: [1.0, 0.7] 87 | 88 | -------------------------------------------------------------------------------- /codes/options/train/train_EDVR_woTSA_Vimeo90K.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_EDVR_NoUp_woTSA_scratch_lr1e-4_150k_Vimeo90k_3frame_WiCutBlur_YCbCr_CB 3 | use_tb_logger: true 4 | model: VideoSR_AllPair 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: Vimeo90k_Train 13 | mode: Vimeo90k_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_ycbcr/sequences 18 | dataroot_LQ: /home/yangxi/datasets/Vimeo90k/vimeo_septuplet_BIx2_same_ycbcr/sequences 19 | cache_keys: /home/yangxi/datasets/Vimeo90k/keys/vimeo90k_train_keys.pkl 20 | N_frames: 3 21 | use_shuffle: true 22 | n_workers: 8 # per GPU 23 | batch_size: 32 24 | GT_size: 192 25 | LQ_size: 192 26 | use_flip: true 27 | use_rot: true 28 | color: ycbcr 29 | val: 30 | name: RealVSR_Test 31 | mode: VideoTest 32 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 33 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 34 | cache_data: true 35 | N_frames: 3 36 | padding: new_info 37 | color: ycbcr 38 | 39 | #### network structures 40 | network_G: 41 | which_model_G: EDVR_NoUp 42 | nf: 64 43 | nc: 3 44 | nframes: 3 45 | groups: 8 46 | front_RBs: 5 47 | back_RBs: 10 48 | predeblur: false 49 | HR_in: false 50 | w_TSA: false 51 | 52 | #### path 53 | path: 54 | pretrain_model_G: ~ 55 | strict_load: true 56 | resume_state: ~ 57 | 58 | 
#### training settings: learning rate scheme, loss 59 | train: 60 | lr_G: !!float 1e-4 61 | lr_scheme: CosineAnnealingLR_Restart 62 | beta1: 0.9 63 | beta2: 0.99 64 | niter: 150000 65 | warmup_iter: -1 # -1: no warm up 66 | T_period: [150000, 150000, 150000, 150000] 67 | restarts: [150000, 300000, 450000] 68 | restart_weights: [1, 1, 1] 69 | eta_min: !!float 1e-7 70 | 71 | pixel_criterion: cb 72 | pixel_weight: 1.0 73 | val_freq: !!float 1e4 74 | 75 | manual_seed: 0 76 | 77 | #### logger 78 | logger: 79 | print_freq: 100 80 | save_checkpoint_freq: !!float 1e4 81 | 82 | #### augment 83 | augment: 84 | augs: ["none", "cutblur"] 85 | probs: [1.0, 1.0] 86 | mix_p: [0.95, 0.05] 87 | alphas: [1.0, 0.7] 88 | -------------------------------------------------------------------------------- /codes/options/train/train_RCAN_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_RCAN_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: RCAN 43 | num_in_ch: 3 44 | num_out_ch: 3 45 | num_frames: 3 46 | num_feat: 64 47 | num_group: 5 48 | num_block: 2 49 | squeeze_factor: 16 50 | res_scale: 1 51 | 52 | #### path 53 | path: 54 | pretrain_model_G: ~ 55 | strict_load: true 56 | resume_state: ~ 57 | 58 | #### training settings: learning rate scheme, loss 59 | train: 60 | lr_G: !!float 1e-4 61 | lr_scheme: CosineAnnealingLR_Restart 62 | beta1: 0.9 63 | beta2: 0.99 64 | niter: 150000 65 | warmup_iter: -1 # -1: no warm up 66 | T_period: [150000, 150000, 150000, 150000] 67 | restarts: [150000, 300000, 450000] 68 | restart_weights: [1, 1, 1] 69 | eta_min: !!float 1e-7 70 | 71 | pixel_criterion_y: lappyr 72 | pixel_weight_y: 1.0 73 | pixel_criterion_c: gw 74 | pixel_weight_c: 1.0 75 | val_freq: !!float 1e4 76 | 77 | manual_seed: 0 78 | 79 | #### logger 80 | logger: 81 | print_freq: 100 82 | save_checkpoint_freq: !!float 1e4 83 | 84 | #### augment 85 | augment: 86 | augs: ["none", "cutblur"] 87 | probs: [1.0, 1.0] 88 | mix_p: [0.95, 0.05] 89 | alphas: [1.0, 0.7] 90 | -------------------------------------------------------------------------------- /codes/options/train/train_EDVR_woTSA_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_EDVR_NoUp_woTSA_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW 3 | use_tb_logger: true 4 | model: VideoSR_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### 
datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: EDVR_NoUp 43 | nf: 64 44 | nc: 3 45 | nframes: 3 46 | groups: 8 47 | front_RBs: 5 48 | back_RBs: 10 49 | predeblur: false 50 | HR_in: false 51 | w_TSA: false 52 | 53 | #### path 54 | path: 55 | pretrain_model_G: ~ 56 | strict_load: true 57 | resume_state: ~ 58 | 59 | #### training settings: learning rate scheme, loss 60 | train: 61 | lr_G: !!float 1e-4 62 | lr_scheme: CosineAnnealingLR_Restart 63 | beta1: 0.9 64 | beta2: 0.99 65 | niter: 150000 66 | warmup_iter: -1 # -1: no warm up 67 | T_period: [150000, 150000, 150000, 150000] 68 | restarts: [150000, 300000, 450000] 69 | restart_weights: [1, 1, 1] 70 | eta_min: !!float 1e-7 71 | 72 | pixel_criterion_y: lappyr 73 | pixel_weight_y: 1.0 74 | pixel_criterion_c: gw 75 | pixel_weight_c: 1.0 76 | val_freq: !!float 1e4 77 | 78 | manual_seed: 0 79 | 80 | #### logger 81 | logger: 82 | print_freq: 100 83 | save_checkpoint_freq: !!float 1e4 84 | 85 | #### augment 86 | augment: 87 | augs: ["none", "cutblur"] 88 | probs: [1.0, 1.0] 89 | mix_p: [0.95, 0.05] 90 | alphas: [1.0, 0.7] 91 | -------------------------------------------------------------------------------- /codes/models/archs/SRResNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import models.archs.arch_util as arch_util 6 | 7 | 8 | class MSRResNet(nn.Module): 9 | """ modified SRResNet""" 10 | 11 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 12 | super(MSRResNet, self).__init__() 13 | self.upscale = upscale 14 | 15 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 16 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 17 | self.recon_trunk = arch_util.make_layer(basic_block, nb) 18 | 19 | # upsampling 20 | if self.upscale == 2: 21 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 22 | self.pixel_shuffle = nn.PixelShuffle(2) 23 | elif self.upscale == 3: 24 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 25 | self.pixel_shuffle = nn.PixelShuffle(3) 26 | elif self.upscale == 4: 27 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 28 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 29 | self.pixel_shuffle = nn.PixelShuffle(2) 30 | 31 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 32 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 33 | 34 | # activation function 35 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 36 | 37 | # initialization 38 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 39 | 0.1) 40 | if self.upscale == 4: 41 
| arch_util.initialize_weights(self.upconv2, 0.1) 42 | 43 | def forward(self, x): 44 | fea = self.lrelu(self.conv_first(x)) 45 | out = self.recon_trunk(fea) 46 | 47 | if self.upscale == 4: 48 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 49 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 50 | elif self.upscale == 3 or self.upscale == 2: 51 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 52 | 53 | out = self.conv_last(self.lrelu(self.HRconv(out))) 54 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 55 | out += base 56 | return out 57 | 58 | -------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iter-oriented* training, saving time when restarting the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas.
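ratio (optional): virtual enlarging ratio of the dataset; each epoch draws
    ratio * len(dataset) indices (wrapped modulo the real dataset size), so
    iter-oriented training need not restart the dataloader every epoch.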
28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /codes/data/augments_video_allpair.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def apply_augment( 7 | im1, im2, 8 | augs, probs, alphas, mix_p=None 9 | ): 10 | idx = np.random.choice(len(augs), p=mix_p) 11 | aug = augs[idx] 12 | prob = float(probs[idx]) 13 | alpha = float(alphas[idx]) 14 | 15 | if aug == "none": 16 | im1_aug, im2_aug = im1.clone(), im2.clone() 17 | elif aug == "blend": 18 | im1_aug, im2_aug = blend( 19 | im1.clone(), im2.clone(), 20 | prob=prob, alpha=alpha 21 | ) 22 | elif aug == "cutblur": 23 | im1_aug, im2_aug = cutblur( 24 | im1.clone(), im2.clone(), 25 | prob=prob, alpha=alpha 26 | ) 27 | elif aug == "rgb": 28 | im1_aug, im2_aug = rgb( 29 | im1.clone(), im2.clone(), 30 | prob=prob 31 | ) 32 | else: 33 | raise ValueError("{} is not a valid augmentation.".format(aug)) 34 | 35 | return im1_aug, im2_aug 36 | 37 | 38 | def blend(im1, im2, prob=1.0, alpha=0.6): 39 | if alpha <= 0 or np.random.rand(1) >= prob: 40 | return im1, im2 41 | 42 | c = torch.empty((im2.size(0), im2.size(1), 3, 1, 1), device=im2.device).uniform_(0, 1) 43 | rim2 = c.repeat((1, 1, 1, im2.size(3), im2.size(4))) 44 | rim1 = c.repeat((1, 1, 1, im1.size(3), im1.size(4))) 45 | 46 | v = np.random.uniform(alpha, 1) 47 | im1 = v * im1 + (1-v) * rim1 48 | im2 = v * im2 + (1-v) * rim2 49 | 50 | return im1, im2 51 | 52 | 53 | def cutblur(im1, im2, prob=1.0, alpha=1.0): 54 | if im1.size() != im2.size(): 55 | raise ValueError("im1 and im2 have to be the same resolution.") 56 | 57 | if alpha <= 0 or np.random.rand(1) >= prob: 58 | return im1, im2 59 | 60 | cut_ratio = np.random.randn() * 0.01 + alpha 61 | 62 | h, w = im2.size(2), im2.size(3) 63 | ch, cw = int(h*cut_ratio), int(w*cut_ratio) 64 | cy = np.random.randint(0, h-ch+1) 65 | cx = np.random.randint(0, w-cw+1) 66 | 67 | # apply CutBlur to inside or outside 68 | if np.random.random() > 0.5: 69 | im2[..., cy:cy+ch, cx:cx+cw] = im1[..., cy:cy+ch, cx:cx+cw] 70 | else: 71 | im2_aug = im1.clone() 72 | im2_aug[..., cy:cy+ch, cx:cx+cw] = im2[..., cy:cy+ch, cx:cx+cw] 73 | im2 = im2_aug 74 | 75 | return im1, im2 76 | 77 | 78 | def rgb(im1, im2, prob=1.0): 79 | if np.random.rand(1) >= prob: 80 | return im1, im2 81 | 82 | perm = np.random.permutation(3) 83 | im1 = im1[:, :, perm, :, :] 84 | im2 = im2[:, :, perm, :, :] 85 | 86 | return im1, im2 87 | -------------------------------------------------------------------------------- /codes/scripts/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import cv2 5 | import pickle 6 | import random 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | from shutil import copy, copytree 13 | 14 | try: 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | import data.util as data_util 17 | import utils.util as util 18 | except ImportError: 19 | pass 20 | 21 | 22 | def rgb2ycbcr(src_root, dst_root, only_y=True): 23 | util.mkdir(dst_root) 24 | src_img_paths = sorted(glob.glob(os.path.join(src_root, '*.png'))) 25 | 26 | for src_img_path in src_img_paths: 27 | print(src_img_path) 28 | src_img = cv2.imread(src_img_path) 29 | dst_img = data_util.bgr2ycbcr(src_img, only_y=only_y) 30 | cv2.imwrite(os.path.join(dst_root, '{}'.format(os.path.basename(src_img_path))), dst_img if only_y else dst_img[:, :, [2, 1, 0]])  # Y-only output is 2-D, so skip the channel reorder 31 | 32 | 33 | def realvsr(src_root, dst_root, only_y): 34 | seq_paths = sorted(glob.glob(os.path.join(src_root, '*'))) 35 | seqs = [os.path.basename(seq_path) for seq_path in seq_paths] 36 | 37 | for seq in seqs: 38 | print('Processing {}'.format(seq)) 39 | src_img_paths = sorted(glob.glob(os.path.join(src_root, seq, '*.png'))) 40 | 41 | for src_img_path in src_img_paths: 42 | src_img = cv2.imread(src_img_path) 43 | dst_img = data_util.bgr2ycbcr(src_img, only_y=only_y) 44 | util.mkdir(os.path.join(dst_root, seq)) 45 | cv2.imwrite(os.path.join(dst_root, seq, '{}'.format(os.path.basename(src_img_path))), dst_img if only_y else dst_img[:, :, [2, 1, 0]]) 46 | 47 | 48 | def vimeo90k(src_root, dst_root): 49 | seq_paths = sorted(glob.glob(os.path.join(src_root, '*', '*', '*.png'))) 50 | 51 | for src_img_path in seq_paths: 52 | print(src_img_path) 53 | tmp_list = src_img_path.split('/') 54 | name_a, name_b, img_name = tmp_list[-3], tmp_list[-2], tmp_list[-1] 55 | src_img = cv2.imread(src_img_path) 56 | dst_img = data_util.bgr2ycbcr(src_img, only_y=False) 57 | util.mkdir(os.path.join(dst_root, name_a, name_b)) 58 | cv2.imwrite(os.path.join(dst_root, name_a, name_b, img_name), dst_img[:, :, [2, 1, 0]]) 59 | 60 | 61 | def save_keys_realvsr(save_path): 62 | key_list = [] 63 | for seq_idx in range(500): 64 | for img_idx in range(50): 65 | key_list.append('{:03d}_{:05d}'.format(seq_idx, img_idx)) 66 | with open(save_path, 'wb') as f: 67 | pickle.dump({'keys': key_list}, f) 68 | 69 | 70 | if __name__ == '__main__': 71 | pass 72 | -------------------------------------------------------------------------------- /codes/metrics/evaluate_niqe_brisque.m: -------------------------------------------------------------------------------- 1 | function evaluate_niqe_brisque(root, expn, if_niqe, if_brisque, result_path, niqe_model_path) 2 | 3 | fileID = fopen(result_path,'a'); 4 | 5 | root_dir = strcat(root, expn); 6 | 7 | seq_struct = dir(fullfile(root_dir, '*')); 8 | seq_struct = seq_struct([seq_struct.isdir]); 9 | 10 | if niqe_model_path ~= -1 11 | load(niqe_model_path, 'niqe_model'); 12 | fprintf('load custom model\n'); 13 | else 14 | niqe_model = niqeModel; 15 | fprintf('load default model\n'); 16 | end 17 | 18 | % filter away '.' and '..'
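% (editor's note) regexpi on the next line keeps only directory names that
% contain a three-digit sequence id (e.g. '012'); '.' and '..' yield no match
% and are dropped by the isempty filter.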
19 | seq_cell = regexpi({seq_struct.name}, '[0-9]{3}', 'match', 'once'); 20 | seq_cell = seq_cell(~cellfun('isempty', seq_cell)); 21 | 22 | if if_niqe 23 | niqe_results = cell(1, length(seq_cell)); 24 | end 25 | if if_brisque 26 | brisque_results = cell(1, length(seq_cell)); 27 | end 28 | % loop for sequence 29 | for i = 1:length(seq_cell) 30 | 31 | frm_struct = dir(fullfile(root_dir, seq_cell{i}, '*.png')); 32 | frm_cell = {frm_struct.name}; 33 | 34 | niqe_folder_sum = 0; 35 | brisque_folder_sum = 0; 36 | 37 | % loop for frame 38 | for j = 1:length(frm_cell) 39 | frm_path = fullfile(root_dir, seq_cell{i}, frm_cell{j}); 40 | img = imread(frm_path); 41 | if if_niqe 42 | niqe_score = niqe(img, niqe_model); 43 | niqe_folder_sum = niqe_folder_sum + niqe_score; 44 | end 45 | if if_brisque 46 | brisque_score = brisque(img); 47 | brisque_folder_sum = brisque_folder_sum + brisque_score; 48 | end 49 | end 50 | 51 | if if_niqe 52 | fprintf('%s NIQE: %.4f\n', seq_cell{i}, niqe_folder_sum / length(frm_cell)); 53 | fprintf(fileID, '%s NIQE: %.4f\n', seq_cell{i}, niqe_folder_sum / length(frm_cell)); 54 | niqe_results{i} = niqe_folder_sum / length(frm_cell); 55 | end 56 | if if_brisque 57 | fprintf('%s BRISQUE: %.4f\n', seq_cell{i}, brisque_folder_sum / length(frm_cell)); 58 | fprintf(fileID, '%s BRISQUE: %.4f\n', seq_cell{i}, brisque_folder_sum / length(frm_cell)); 59 | brisque_results{i} = brisque_folder_sum / length(frm_cell); 60 | end 61 | end 62 | 63 | if if_niqe 64 | niqe_mean = mean(cell2mat(niqe_results)); 65 | fprintf(fileID, '%s, NIQE: %.4f\n', expn, niqe_mean); 66 | end 67 | if if_brisque 68 | brisque_mean = mean(cell2mat(brisque_results)); 69 | fprintf(fileID, '%s, BRISQUE: %.4f\n', expn, brisque_mean); 70 | end 71 | fclose(fileID); 72 | end 73 | 74 | -------------------------------------------------------------------------------- /codes/options/train/train_TOF-GAN_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_TOF-GAN_scratch_lr5e-5_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW+EdgeGAN 3 | use_tb_logger: true 4 | model: VideoSRGAN_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: TOF 43 | nframes: 3 44 | K: 3 45 | nc: 3 46 | nf: 64 47 | nb: 10 48 | 49 | network_D: 50 | which_model_D: MultiscaleDiscriminator_v4 51 | in_nc: 1 52 | nf: 64 53 | num_D: 2 54 | gan_type: patch 55 | 56 | #### path 57 | path: 58 | pretrain_model_G: ../experiments/pretrained_models/001_VSEDSR_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW.pth 59 | strict_load: true 60 | resume_state: ~ 61 | 62 | #### training 
settings: learning rate scheme, loss 63 | train: 64 | lr_G: !!float 5e-5 65 | weight_decay_G: 0 66 | beta1_G: 0.9 67 | beta2_G: 0.99 68 | lr_D: !!float 5e-5 69 | weight_decay_D: 0 70 | beta1_D: 0.9 71 | beta2_D: 0.99 72 | lr_scheme: CosineAnnealingLR_Restart 73 | niter: 150000 74 | warmup_iter: -1 # -1: no warm up 75 | T_period: [150000, 150000, 150000, 150000] 76 | restarts: [150000, 300000, 450000] 77 | restart_weights: [1, 1, 1] 78 | eta_min: !!float 1e-7 79 | 80 | pixel_criterion_s: ssim 81 | pixel_weight_s: 1.0 82 | pixel_criterion_d: cb 83 | pixel_weight_d: 1.0 84 | pixel_criterion_c: gw 85 | pixel_weight_c: 1.0 86 | feature_criterion: cb 87 | feature_weight: 0.0 88 | gan_type: ragan 89 | gan_weight: !!float 1e-4 90 | val_freq: !!float 1e4 91 | 92 | manual_seed: 0 93 | 94 | #### logger 95 | logger: 96 | print_freq: 100 97 | save_checkpoint_freq: !!float 1e4 98 | 99 | #### augment 100 | augment: 101 | augs: ["none", "cutblur"] 102 | probs: [1.0, 1.0] 103 | mix_p: [0.95, 0.05] 104 | alphas: [1.0, 0.7] 105 | -------------------------------------------------------------------------------- /codes/options/train/train_EDVR-GAN_woTSA_RealVSR_YCbCr_Split.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_EDVR_NoUp_woTSA_scratch_lr5e-5_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW+EdgeGAN 3 | use_tb_logger: true 4 | model: VideoSRGAN_AllPair_YCbCr_Split 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0,1,2,3] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: RealVSR_Train 13 | mode: RealVSR_AllPair 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr 18 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr 19 | cache_keys: ../keys/realvsr_keys.pkl 20 | remove_list: ../keys/remove_seqs.pkl 21 | N_frames: 3 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 192 26 | LQ_size: 192 27 | use_flip: true 28 | use_rot: true 29 | color: ycbcr 30 | val: 31 | name: RealVSR_Test 32 | mode: VideoTest 33 | dataroot_GT: /home/yangxi/datasets/RealVSR/GT_YCbCr_test_10 34 | dataroot_LQ: /home/yangxi/datasets/RealVSR/LQ_YCbCr_test_10 35 | cache_data: true 36 | N_frames: 3 37 | padding: new_info 38 | color: ycbcr 39 | 40 | #### network structures 41 | network_G: 42 | which_model_G: EDVR_NoUp 43 | nf: 64 44 | nc: 3 45 | nframes: 3 46 | groups: 8 47 | front_RBs: 5 48 | back_RBs: 10 49 | predeblur: false 50 | HR_in: false 51 | w_TSA: false 52 | 53 | network_D: 54 | which_model_D: MultiscaleDiscriminator_v4 55 | in_nc: 1 56 | nf: 64 57 | num_D: 2 58 | gan_type: patch 59 | 60 | #### path 61 | path: 62 | pretrain_model_G: ../experiments/pretrained_models/001_EDVR_NoUp_woTSA_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW.pth 63 | strict_load: true 64 | resume_state: ~ 65 | 66 | #### training settings: learning rate scheme, loss 67 | train: 68 | lr_G: !!float 5e-5 69 | weight_decay_G: 0 70 | beta1_G: 0.9 71 | beta2_G: 0.99 72 | lr_D: !!float 5e-5 73 | weight_decay_D: 0 74 | beta1_D: 0.9 75 | beta2_D: 0.99 76 | lr_scheme: CosineAnnealingLR_Restart 77 | niter: 150000 78 | warmup_iter: -1 # -1: no warm up 79 | T_period: [150000, 150000, 150000, 150000] 80 | restarts: [150000, 300000, 450000] 81 | restart_weights: [1, 1, 1] 82 | eta_min: !!float 1e-7 83 | 84 | pixel_criterion_s: ssim 85 | pixel_weight_s: 1.0 86 | pixel_criterion_d: cb 87 | pixel_weight_d: 1.0 88 | pixel_criterion_c: gw 89 | pixel_weight_c: 1.0 90 | 
feature_criterion: cb 91 | feature_weight: 0.0 92 | gan_type: ragan 93 | gan_weight: !!float 1e-4 94 | val_freq: !!float 1e4 95 | 96 | manual_seed: 0 97 | 98 | #### logger 99 | logger: 100 | print_freq: 100 101 | save_checkpoint_freq: !!float 1e4 102 | 103 | #### augment 104 | augment: 105 | augs: ["none", "cutblur"] 106 | probs: [1.0, 1.0] 107 | mix_p: [0.95, 0.05] 108 | alphas: [1.0, 0.7] 109 | -------------------------------------------------------------------------------- /codes/models/archs/FSTRN_arch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network architecture for FSTRN: 3 | Fast Spatio-Temporal Residual Network for Video Super-Resolution (CVPR 2019) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class FRB(nn.Module): 12 | """Fast spatial-temporal residual block""" 13 | def __init__(self, k=3, nf=64): 14 | super(FRB, self).__init__() 15 | self.prelu = nn.PReLU() 16 | self.conv3d_1 = nn.Conv3d(nf, nf, (1, k, k), stride=(1, 1, 1), padding=(0, 1, 1), bias=True) 17 | self.conv3d_2 = nn.Conv3d(nf, nf, (k, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=True) 18 | 19 | def forward(self, x): 20 | res = x 21 | out = self.conv3d_2(self.conv3d_1(self.prelu(x))) 22 | return res + out 23 | 24 | 25 | class FSTRN(nn.Module): 26 | """Fast spatial-temporal residual network""" 27 | def __init__(self, k=3, nf=64, scale=4, nframes=5): 28 | super(FSTRN, self).__init__() 29 | self.k = k 30 | self.nf = nf 31 | self.scale = scale 32 | self.center = nframes // 2 33 | #### LFENet 34 | self.conv3d_fe = nn.Conv3d(3, nf, (k, k, k), stride=(1, 1, 1), padding=(1, 1, 1), bias=True) 35 | #### FRBs 36 | self.frb_1 = FRB(k=k, nf=nf) 37 | self.frb_2 = FRB(k=k, nf=nf) 38 | self.frb_3 = FRB(k=k, nf=nf) 39 | self.frb_4 = FRB(k=k, nf=nf) 40 | self.frb_5 = FRB(k=k, nf=nf) 41 | #### LSRNet 42 | self.prelu = nn.PReLU() 43 | self.dropout = nn.Dropout(p=0.3, inplace=False) 44 | self.conv3d_1 = nn.Conv3d(nf, nf, (k, k, k), stride=(1, 1, 1), padding=(1, 1, 1), bias=True) 45 | self.upsample = nn.ConvTranspose3d(nf, nf, (1, self.scale, self.scale), 46 | stride=(1, self.scale, self.scale), bias=True) 47 | self.conv3d_2 = nn.Conv3d(nf, 3, (k, k, k), stride=(1, 1, 1), padding=(1, 1, 1), bias=True) 48 | 49 | def forward(self, x): 50 | """ 51 | x: [B, T, C, H, W], reshape to [B, C, T, H, W] for Conv3D 52 | """ 53 | x = x.permute(0, 2, 1, 3, 4) 54 | #### LFENet 55 | cs_res = x 56 | out = self.conv3d_fe(x) 57 | #### FRBs (with LR residual connection) 58 | lr_res = out 59 | out = self.frb_5(self.frb_4(self.frb_3(self.frb_2(self.frb_1(out))))) 60 | out = lr_res + out 61 | #### LSRNet 62 | out = self.dropout(self.prelu(out)) 63 | out = self.conv3d_1(out) 64 | out = self.upsample(out) 65 | out = self.conv3d_2(out) 66 | #### Cross-space residual connection 67 | cs_out = F.interpolate(cs_res, scale_factor=(1, self.scale, self.scale), mode='trilinear', align_corners=False) 68 | out = cs_out + out 69 | return out[:, :, self.center, :, :] 70 | 71 | 72 | if __name__ == '__main__': 73 | device = torch.device('cuda') 74 | model = FSTRN(k=3, nf=64, nframes=5, scale=1).to(device) 75 | x = torch.randn(32, 5, 3, 48, 48).to(device) 76 | out = model(x) 77 | print(out.shape) 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dataset and Code for RealVSR 2 | 3 | >[Real-world Video Super-resolution: A Benchmark 
Dataset and A Decomposition based Learning Scheme](https://openaccess.thecvf.com/content/ICCV2021/papers/Yang_Real-World_Video_Super-Resolution_A_Benchmark_Dataset_and_a_Decomposition_Based_ICCV_2021_paper.pdf) \
4 | >Xi Yang, Wangmeng Xiang, Hui Zeng and Lei Zhang \
5 | >International Conference on Computer Vision, 2021.
6 | 
7 | ## Dataset
8 | 
9 | The dataset is hosted on [Google Drive](https://drive.google.com/drive/folders/1-8MvMEYMOeOE713DjI7TJKyRE-LnrM3Y?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1rBIGo5xrY2VtpoUF2gf_HA) (code: 43ph). Some example scenes are shown below.
10 | 
11 | ![dataset_samples](./imgs/dataset_samples.png)
12 | 
13 | The structure of the dataset is illustrated below.
14 | 
15 | | File | Description |
16 | | ------------------------ |:-------------------------------------------:|
17 | | GT.zip | All ground truth sequences in RGB format |
18 | | LQ.zip | All low quality sequences in RGB format |
19 | | GT_YCbCr.zip | All ground truth sequences in YCbCr format |
20 | | LQ_YCbCr.zip | All low quality sequences in YCbCr format |
21 | | GT_test.zip | Ground truth test sequences in RGB format |
22 | | LQ_test.zip | Low quality test sequences in RGB format |
23 | | GT_YCbCr_test.zip | Ground truth test sequences in YCbCr format |
24 | | LQ_YCbCr_test.zip | Low quality test sequences in YCbCr format |
25 | | videos.zip | Original videos (more than 500 LR-HR pairs) |
26 | 
27 | ## Code
28 | 
29 | ### Dependencies
30 | * Linux *(tested on Ubuntu 18.04)*
31 | * Python 3 *(tested on python 3.7)*
32 | * NVIDIA GPU + CUDA *(tested on CUDA 10.2 and 11.1)*
33 | 
34 | ### Installation
35 | ```
36 | # Create a new anaconda python environment (realvsr)
37 | conda create -n realvsr python=3.7 -y
38 | 
39 | # Activate the created environment
40 | conda activate realvsr
41 | 
42 | # Install dependencies
43 | pip install -r requirements.txt
44 | 
45 | # Build the DCN module
46 | cd codes/models/archs/dcn
47 | python setup.py develop
48 | ```
49 | 
50 | ### Training
51 | 
52 | Modify the configuration files in the codes/options/train folder accordingly and run the following command (distributed training is currently not implemented):
53 | ```
54 | python train.py -opt xxxxx.yml
55 | ```
56 | 
57 | ### Testing
58 | 
59 | #### Test on RealVSR testing set sequences:
60 | Modify the configuration in test_RealVSR_wi_GT.py and run the following command:
61 | ```
62 | python test_RealVSR_wi_GT.py
63 | ```
64 | 
65 | #### Test on real-world captured sequences:
66 | Modify the configuration in test_RealVSR_wo_GT.py and run the following command:
67 | ```
68 | python test_RealVSR_wo_GT.py
69 | ```
70 | 
71 | #### Pre-trained Models
72 | Some pretrained models can be found on [Google Drive](https://drive.google.com/drive/folders/1nMXhsNbTrRUBUX8EEzeD_gqmqoHbcmDz?usp=sharing) and [Baidu Drive](https://pan.baidu.com/s/1zYupxTDBRAyxbzc5fQThwQ) (code: n1n0).
73 | 
74 | ## License
75 | 
76 | This project is released under the Apache 2.0 license.
77 | 
78 | ## Citation
79 | 
80 | If you find this code useful in your research, please consider citing:
81 | ```latex
82 | @inproceedings{yang2021real,
83 |   title={Real-world Video Super-resolution: A Benchmark Dataset and A Decomposition based Learning Scheme},
84 |   author={Yang, Xi and Xiang, Wangmeng and Zeng, Hui and Zhang, Lei},
85 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
86 |   year={2021}
87 | }
88 | ```
89 | 
90 | ## Acknowledgement
91 | 
92 | This implementation largely depends on [EDVR](https://github.com/xinntao/EDVR). Thanks for the excellent codebase!
You may also consider migrating it to [BasicSR](https://github.com/xinntao/BasicSR).
93 | 
-------------------------------------------------------------------------------- /codes/data/VideoTestDataset.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | import torch.utils.data as data
4 | import data.util as util
5 | 
6 | 
7 | class VideoTestDataset(data.Dataset):
8 |     """
9 |     A video test dataset. Supported datasets:
10 |     Vid4
11 |     REDS4
12 |     Vimeo90K-Test
13 | 
14 |     No need to prepare LMDB files.
15 |     """
16 | 
17 |     def __init__(self, opt):
18 |         super(VideoTestDataset, self).__init__()
19 |         self.opt = opt
20 |         self.cache_data = opt['cache_data']
21 |         self.half_N_frames = opt['N_frames'] // 2
22 |         self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
23 |         self.data_type = self.opt['data_type']
24 |         self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
25 |         if self.data_type == 'lmdb':
26 |             raise ValueError('No need to use LMDB during validation/test.')
27 |         #### Generate data info and cache data
28 |         self.imgs_LQ, self.imgs_GT = {}, {}
29 |         if opt['name'].lower() in ['vid4', 'reds4', 'realvsr_test']:
30 |             subfolders_LQ = util.glob_file_list(self.LQ_root)
31 |             subfolders_GT = util.glob_file_list(self.GT_root)
32 |             for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
33 |                 subfolder_name = osp.basename(subfolder_GT)
34 |                 img_paths_LQ = util.glob_file_list(subfolder_LQ)
35 |                 img_paths_GT = util.glob_file_list(subfolder_GT)
36 |                 max_idx = len(img_paths_LQ)
37 |                 assert max_idx == len(img_paths_GT), 'Different number of images in LQ and GT folders'
38 |                 self.data_info['path_LQ'].extend(img_paths_LQ)
39 |                 self.data_info['path_GT'].extend(img_paths_GT)
40 |                 self.data_info['folder'].extend([subfolder_name] * max_idx)
41 |                 for i in range(max_idx):
42 |                     self.data_info['idx'].append('{}/{}'.format(i, max_idx))
43 |                 border_l = [0] * max_idx
44 |                 for i in range(self.half_N_frames):
45 |                     border_l[i] = 1
46 |                     border_l[max_idx - i - 1] = 1
47 |                 self.data_info['border'].extend(border_l)
48 | 
49 |                 if self.cache_data:
50 |                     self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ, color=opt['color'])
51 |                     self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT, color=opt['color'])
52 |         elif opt['name'].lower() in ['vimeo90k-test']:
53 |             pass  # TODO
54 |         else:
55 |             raise ValueError(
56 |                 'Unsupported video test dataset. Supported datasets: Vid4, REDS4 and Vimeo90K-Test.'
57 | ) 58 | 59 | def __getitem__(self, index): 60 | # path_LQ = self.data_info['path_LQ'][index] 61 | # path_GT = self.data_info['path_GT'][index] 62 | folder = self.data_info['folder'][index] 63 | idx, max_idx = self.data_info['idx'][index].split('/') 64 | idx, max_idx = int(idx), int(max_idx) 65 | border = self.data_info['border'][index] 66 | 67 | if self.cache_data: 68 | select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'], 69 | padding=self.opt['padding']) 70 | imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx)) 71 | img_GT = self.imgs_GT[folder][idx] 72 | else: 73 | pass # TODO 74 | 75 | return { 76 | 'LQs': imgs_LQ, 77 | 'GT': img_GT, 78 | 'folder': folder, 79 | 'idx': self.data_info['idx'][index], 80 | 'border': border 81 | } 82 | 83 | def __len__(self): 84 | return len(self.data_info['path_GT']) 85 | -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | if dataset.get('dataroot_LQ', None) is not None: 33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 34 | if dataset['dataroot_LQ'].endswith('lmdb'): 35 | is_lmdb = True 36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 37 | if dataset['mode'].endswith('mc'): # for memcached 38 | dataset['data_type'] = 'mc' 39 | dataset['mode'] = dataset['mode'].replace('_mc', '') 40 | 41 | # path 42 | for key, path in opt['path'].items(): 43 | if path and key in opt['path'] and key != 'strict_load': 44 | opt['path'][key] = osp.expanduser(path) 45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 46 | if is_train: 47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 48 | opt['path']['experiments_root'] = experiments_root 49 | opt['path']['models'] = osp.join(experiments_root, 'models') 50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 51 | opt['path']['log'] = experiments_root 52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 53 | 54 | # change some options for debug mode 55 | if 'debug' in opt['name']: 56 | opt['train']['val_freq'] = 8 57 | opt['logger']['print_freq'] = 1 58 | opt['logger']['save_checkpoint_freq'] = 8 59 | else: # test 60 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 61 | opt['path']['results_root'] = results_root 62 | opt['path']['log'] = results_root 63 | 64 | # network 65 | if opt['distortion'] == 'sr': 66 | 
opt['network_G']['scale'] = scale 67 | 68 | return opt 69 | 70 | 71 | def dict2str(opt, indent_l=1): 72 | """dict to string for logger""" 73 | msg = '' 74 | for k, v in opt.items(): 75 | if isinstance(v, dict): 76 | msg += ' ' * (indent_l * 2) + k + ':[\n' 77 | msg += dict2str(v, indent_l + 1) 78 | msg += ' ' * (indent_l * 2) + ']\n' 79 | else: 80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 81 | return msg 82 | 83 | 84 | def dict_to_nonedict(opt): 85 | """convert to NoneDict, which return None for missing key.""" 86 | if isinstance(opt, dict): 87 | new_opt = dict() 88 | for key, sub_opt in opt.items(): 89 | new_opt[key] = dict_to_nonedict(sub_opt) 90 | return NoneDict(**new_opt) 91 | elif isinstance(opt, list): 92 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 93 | else: 94 | return opt 95 | 96 | 97 | def check_resume(opt, resume_iter): 98 | """Check resume states and pretrain_model paths""" 99 | logger = logging.getLogger('base') 100 | if opt['path']['resume_state']: 101 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 102 | 'pretrain_model_D', None) is not None: 103 | logger.warning('pretrain_model path will be ignored when resuming training.') 104 | 105 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 106 | '{}_G.pth'.format(resume_iter)) 107 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 108 | if 'gan' in opt['model']: 109 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 110 | '{}_D.pth'.format(resume_iter)) 111 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 112 | 113 | 114 | class NoneDict(dict): 115 | def __missing__(self, key): 116 | return None 117 | -------------------------------------------------------------------------------- /codes/models/archs/TDAN_arch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Network architecture for TDAN: 3 | TDAN: Temporally Deformable Alignment Network for Video Super-Resolution 4 | ''' 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import models.archs.arch_util as arch_util 11 | try: 12 | from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN 13 | except ImportError: 14 | raise ImportError('Failed to import DCNv2 module.') 15 | 16 | 17 | class Align(nn.Module): 18 | 19 | def __init__(self, channel=1, nf=64, nb=5, groups=8): 20 | super(Align, self).__init__() 21 | 22 | self.initial_conv = nn.Conv2d(channel, nf, 3, padding=1, bias=True) 23 | self.residual_layers = arch_util.make_layer(arch_util.ResidualBlock_noBN, nb) 24 | 25 | self.bottle_neck = nn.Conv2d(nf * 2, nf, 3, padding=1, bias=True) 26 | 27 | self.offset_conv_1 = nn.Conv2d(nf, nf, 3, padding=1, bias=True) 28 | self.deform_conv_1 = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 29 | extra_offset_mask=True) 30 | self.offset_conv_2 = nn.Conv2d(nf, nf, 3, padding=1, bias=True) 31 | self.deform_conv_2 = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 32 | extra_offset_mask=True) 33 | self.offset_conv_3 = nn.Conv2d(nf, nf, 3, padding=1, bias=True) 34 | self.deform_conv_3 = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 35 | extra_offset_mask=True) 36 | 37 | self.offset_conv = nn.Conv2d(nf, nf, 3, padding=1, bias=True) 38 | self.deform_conv = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 39 | 
extra_offset_mask=True) 40 | self.reconstruction = nn.Conv2d(nf, channel, 3, padding=1, bias=True) 41 | 42 | def forward(self, x): 43 | B, N, C, W, H = x.size() 44 | 45 | # extract features 46 | y = x.view(-1, C, W, H) 47 | out = F.relu(self.initial_conv(y), inplace=True) 48 | out = self.residual_layers(out) 49 | out = out.view(B, N, -1, W, H) 50 | 51 | # reference frame 52 | ref_index = N // 2 53 | ref_frame = out[:, ref_index, :, :, :].clone().contiguous() 54 | # neighbor frames 55 | y = [] 56 | for i in range(N): 57 | nei_frame = out[:, i, :, :, :].contiguous() 58 | fea = torch.cat([ref_frame, nei_frame], dim=1) 59 | fea = self.bottle_neck(fea) 60 | # feature transformation 61 | offset1 = self.offset_conv_1(fea) 62 | fea = self.deform_conv_1([fea, offset1]) 63 | offset2 = self.offset_conv_2(fea) 64 | fea = self.deform_conv_2([fea, offset2]) 65 | offset3 = self.offset_conv_3(fea) 66 | fea = self.deform_conv_3([nei_frame, offset3]) 67 | offset = self.offset_conv(fea) 68 | aligned_fea = (self.deform_conv([fea, offset])) 69 | im = self.reconstruction(aligned_fea) 70 | y.append(im) 71 | y = torch.cat(y, dim=1) 72 | return y 73 | 74 | 75 | class Trunk(nn.Module): 76 | 77 | def __init__(self, channel=1, nframes=5, scale=4, nf=64, nb=10): 78 | super(Trunk, self).__init__() 79 | self.feature_extractor = nn.Sequential(nn.Conv2d(nframes * channel, 64, 3, padding=1, bias=True), 80 | nn.ReLU(inplace=True)) 81 | self.residual_layers = arch_util.make_layer(arch_util.ResidualBlock_noBN, nb) 82 | self.upsampler = nn.Sequential(arch_util.Upsampler(arch_util.default_conv, scale, 64, act=False), 83 | nn.Conv2d(64, 3, 3, padding=1, bias=False)) 84 | 85 | def forward(self, x): 86 | ''' 87 | :param x: (B, C*T, H, W) 88 | :return: (B, C, s*H, s*W) 89 | ''' 90 | out = self.feature_extractor(x) 91 | out = self.residual_layers(out) 92 | out = self.upsampler(out) 93 | return out 94 | 95 | 96 | class TDAN(nn.Module): 97 | '''Temporally Deformable Alignment Network''' 98 | def __init__(self, channel=1, nframes=5, scale=4, nf=64, nb_f=5, nb_b=10, groups=8): 99 | super(TDAN, self).__init__() 100 | 101 | self.align = Align(channel=channel, nf=nf, nb=nb_f, groups=groups) 102 | self.trunk = Trunk(channel=channel, nframes=nframes, scale=scale, nf=nf, nb=nb_b) 103 | 104 | def forward(self, x): 105 | ''' 106 | :param x: (B, T, C, H, W) 107 | :return: (B, C, s*H, s*W) 108 | ''' 109 | out = self.align(x) 110 | out = self.trunk(out) 111 | return out 112 | 113 | 114 | if __name__ == '__main__': 115 | B, N, C, W, H = 1, 7, 3, 64, 64 116 | model = TDAN(channel=C, nf=64, nframes=N, groups=8, scale=1).to(device=torch.device('cuda')) 117 | x = torch.randn(B, N, C, W, H).to(device=torch.device('cuda')) 118 | out = model(x) 119 | print(out.shape) 120 | -------------------------------------------------------------------------------- /codes/test_RealVSR_wo_GT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import logging 5 | import numpy as np 6 | import cv2 7 | import torch 8 | 9 | import utils.util as util 10 | import data.util as data_util 11 | import models.archs.TOF_arch as TOF_arch 12 | import models.archs.RCAN_arch as RCAN_arch 13 | import models.archs.EDVR_arch as EDVR_arch 14 | import models.archs.TDAN_arch as TDAN_arch 15 | import models.archs.FSTRN_arch as FSTRN_arch 16 | 17 | 18 | def center_crop(img, tar_h, tar_w): 19 | N, C, inp_h, inp_w = img.shape # LQ size 20 | # center crop 21 | start_h = int((inp_h - tar_h) / 2) 22 | start_w = 
int((inp_w - tar_w) / 2) 23 | img = img[:, :, start_h:start_h + tar_h, start_w:start_w + tar_w] 24 | return img 25 | 26 | 27 | def main(): 28 | ################# 29 | # configurations 30 | ################# 31 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 32 | data_mode = 'RealSeq' 33 | 34 | # TODO: Modify the configurations here 35 | # model 36 | N_ch = 3 37 | N_in = 3 38 | model = 'EDVR' 39 | model_name = '001_EDVR_NoUp_woTSA_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW' 40 | model_path = '../experiments/pretrained_models/{}.pth'.format(model_name) 41 | # dataset 42 | read_folder = '/home/yangxi/datasets/RealVSR/test_frames_YCbCr' 43 | save_folder = '/home/yangxi/datasets/RealVSR/results/{}/{}'.format(data_mode, model_name) 44 | # color mode 45 | color = 'YCbCr' 46 | # device 47 | device = torch.device('cuda') 48 | 49 | if model == 'RCAN': 50 | model = RCAN_arch.RCAN(num_in_ch=N_ch, num_out_ch=N_ch, num_frames=N_in, num_feat=64, 51 | num_group=5, num_block=2, squeeze_factor=16, upscale=1, res_scale=1) 52 | elif model == 'FSTRN': 53 | model = FSTRN_arch.FSTRN(k=3, nf=64, scale=1, nframes=N_in) 54 | elif model == 'TOF': 55 | model = TOF_arch.TOF(nframes=3, K=3, in_nc=N_ch, out_nc=N_ch, nf=64, nb=10, upscale=1) 56 | elif model == 'TDAN': 57 | model = TDAN_arch.TDAN(channel=N_ch, nf=64, nframes=N_in, groups=8, scale=1) 58 | elif model == 'EDVR': 59 | model = EDVR_arch.EDVR_NoUp(nf=64, nc=N_ch, nframes=N_in, groups=8, front_RBs=5, back_RBs=10, 60 | predeblur=False, HR_in=False, w_TSA=False) 61 | else: 62 | raise ValueError() 63 | 64 | # temporal padding mode 65 | padding = 'new_info' # different from the official setting 66 | save_imgs = True 67 | 68 | util.mkdirs(save_folder) 69 | util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) 70 | logger = logging.getLogger('base') 71 | 72 | #### log info 73 | logger.info('Data: {} - {}'.format(data_mode, read_folder)) 74 | logger.info('Padding mode: {}'.format(padding)) 75 | logger.info('Model path: {}'.format(model_path)) 76 | logger.info('Save images: {}'.format(save_imgs)) 77 | 78 | subfolder_l = sorted(glob.glob(os.path.join(read_folder, '*'))) 79 | 80 | #### set up the models 81 | model.load_state_dict(torch.load(model_path), strict=True) 82 | model.eval() 83 | model = model.to(device) 84 | 85 | subfolder_name_l = [] 86 | 87 | # for each sub-folder 88 | for subfolder in subfolder_l: 89 | subfolder_name = subfolder.split('/')[-1] 90 | subfolder_name_l.append(subfolder_name) 91 | save_subfolder = os.path.join(save_folder, subfolder_name) 92 | 93 | img_path_l = sorted(glob.glob(os.path.join(subfolder, '*'))) 94 | max_idx = len(img_path_l) 95 | 96 | if save_imgs: 97 | util.mkdirs(save_subfolder) 98 | logger.info('Folder {} '.format(subfolder_name)) 99 | time_list = [] 100 | 101 | # process each image 102 | for img_idx, img_path in enumerate(img_path_l): 103 | img_name = os.path.splitext(os.path.basename(img_path))[0] 104 | select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) 105 | img_l = [] 106 | for idx in select_idx: 107 | inp_path = os.path.join(subfolder, '{:05d}.png'.format(idx + 1)) 108 | img = data_util.read_img(None, inp_path)[:, :, [2, 1, 0]] 109 | img_l.append(img) 110 | imgs = np.stack(img_l, axis=0) 111 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() 112 | imgs = center_crop(imgs, tar_h=imgs.shape[2], tar_w=imgs.shape[3]) 113 | imgs = imgs.unsqueeze(0).to(device) 114 | t_1 = time.time() 115 | output = 
util.single_forward(model, imgs) 116 | t_2 = time.time() 117 | logger.info('Processing: {}, Time: {:.4f} s'.format(img_name, t_2 - t_1)) 118 | time_list.append(t_2 - t_1) 119 | output = output.squeeze(0).cpu().numpy() 120 | output = np.transpose(output, (1, 2, 0)) 121 | output = data_util.ycbcr2bgr(output) 122 | output = (np.clip(output, 0, 1) * 255.).round().astype(np.uint8) 123 | # save imgs 124 | if save_imgs: 125 | cv2.imwrite(os.path.join(save_subfolder, '{}.png'.format(img_name)), output) 126 | logger.info('Average inference time: {:.4f} s'.format(sum(time_list) / len(time_list))) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /codes/models/archs/arch_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | 8 | def initialize_weights(net_l, scale=1): 9 | if not isinstance(net_l, list): 10 | net_l = [net_l] 11 | for net in net_l: 12 | for m in net.modules(): 13 | if isinstance(m, nn.Conv2d): 14 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 15 | m.weight.data *= scale # for residual block 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | m.weight.data *= scale 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif isinstance(m, nn.BatchNorm2d): 24 | init.constant_(m.weight, 1) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def make_layer(basic_block, num_basic_block, **kwarg): 29 | """Make layers by stacking the same blocks. 30 | Args: 31 | basic_block (nn.module): nn.module class for basic block. 32 | num_basic_block (int): number of blocks. 33 | Returns: 34 | nn.Sequential: Stacked blocks in nn.Sequential. 35 | """ 36 | layers = [] 37 | for _ in range(num_basic_block): 38 | layers.append(basic_block(**kwarg)) 39 | return nn.Sequential(*layers) 40 | 41 | 42 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 43 | return nn.Conv2d(in_channels, out_channels, kernel_size, 44 | padding=(kernel_size // 2), bias=bias) 45 | 46 | 47 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 48 | """Warp an image or feature map with optical flow. 49 | Args: 50 | x (Tensor): Tensor with size (n, c, h, w). 51 | flow (Tensor): Tensor with size (n, h, w, 2), normal value. 52 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. 53 | padding_mode (str): 'zeros' or 'border' or 'reflection'. 54 | Default: 'zeros'. 55 | Returns: 56 | Tensor: Warped image or feature map. 
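    Example (an illustrative sketch, not part of the original code; with a
    zero flow field the warp should return the input essentially unchanged):
        x = torch.randn(1, 3, 64, 64)
        flow = torch.zeros(1, 64, 64, 2)
        out = flow_warp(x, flow)  # (1, 3, 64, 64), out ~ x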
57 | """ 58 | assert x.size()[-2:] == flow.size()[1:3] 59 | _, _, h, w = x.size() 60 | # create mesh grid 61 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) 62 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 63 | grid.requires_grad = False 64 | grid = grid.type_as(x) 65 | 66 | vgrid = grid + flow 67 | # scale grid to [-1,1] 68 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 69 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 70 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 71 | output = F.grid_sample( 72 | x, 73 | vgrid_scaled, 74 | mode=interp_mode, 75 | padding_mode=padding_mode, 76 | align_corners=True) 77 | # Before pytorch 1.3, the default value is align_corners=True 78 | # After pytorch 1.3, the default value is align_corners=False 79 | # TODO, what if align_corners=False 80 | return output 81 | 82 | 83 | class BasicBlock(nn.Sequential): 84 | 85 | def __init__(self, conv, in_channels, out_channels, kernel_size, stride=1, 86 | bias=False, bn=True, act=nn.ReLU(True)): 87 | 88 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 89 | if bn: 90 | m.append(nn.BatchNorm2d(out_channels)) 91 | if act is not None: 92 | m.append(act) 93 | 94 | super(BasicBlock, self).__init__(*m) 95 | 96 | 97 | class ResBlock(nn.Module): 98 | 99 | def __init__(self, conv, n_feats, kernel_size, 100 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 101 | 102 | super(ResBlock, self).__init__() 103 | m = [] 104 | for i in range(2): 105 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 106 | if bn: 107 | m.append(nn.BatchNorm2d(n_feats)) 108 | if i == 0: 109 | m.append(act) 110 | 111 | self.body = nn.Sequential(*m) 112 | self.res_scale = res_scale 113 | 114 | def forward(self, x): 115 | res = self.body(x).mul(self.res_scale) 116 | res += x 117 | 118 | return res 119 | 120 | 121 | class ResidualBlock_noBN(nn.Module): 122 | """Residual block w/o BN 123 | ---Conv-ReLU-Conv-+- 124 | |________________| 125 | """ 126 | 127 | def __init__(self, nf=64): 128 | super(ResidualBlock_noBN, self).__init__() 129 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 130 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 131 | 132 | # initialization 133 | initialize_weights([self.conv1, self.conv2], 0.1) 134 | 135 | def forward(self, x): 136 | identity = x 137 | out = F.relu(self.conv1(x), inplace=True) 138 | out = self.conv2(out) 139 | return identity + out 140 | 141 | 142 | class Upsampler(nn.Sequential): 143 | """Upsampler""" 144 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 145 | modules = [] 146 | 147 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
148 | for _ in range(int(math.log(scale, 2))): 149 | modules.append(conv(n_feat, 4 * n_feat, 3, bias)) 150 | modules.append(nn.PixelShuffle(2)) 151 | if bn: 152 | modules.append(nn.BatchNorm2d(n_feat)) 153 | if act: 154 | modules.append(act()) 155 | elif scale == 3: 156 | modules.append(conv(n_feat, 9 * n_feat, 3, bias)) 157 | modules.append(nn.PixelShuffle(3)) 158 | if bn: 159 | modules.append(nn.BatchNorm2d(n_feat)) 160 | if act: 161 | modules.append(act()) 162 | else: 163 | raise NotImplementedError() 164 | 165 | super(Upsampler, self).__init__(*modules) 166 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | 10 | def __init__(self, opt): 11 | self.opt = opt 12 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 13 | self.is_train = opt['is_train'] 14 | self.schedulers = [] 15 | self.optimizers = [] 16 | 17 | def feed_data(self, data): 18 | pass 19 | 20 | def optimize_parameters(self, step): 21 | pass 22 | 23 | def get_current_visuals(self): 24 | pass 25 | 26 | def get_current_losses(self): 27 | pass 28 | 29 | def print_network(self): 30 | pass 31 | 32 | def save(self, label): 33 | pass 34 | 35 | def load(self): 36 | pass 37 | 38 | def _set_lr(self, lr_groups_l): 39 | ''' set learning rate for warmup, 40 | lr_groups_l: list for lr_groups. each for a optimizer''' 41 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 42 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 43 | param_group['lr'] = lr 44 | 45 | def _get_init_lr(self): 46 | # get the initial lr, which is set by the scheduler 47 | init_lr_groups_l = [] 48 | for optimizer in self.optimizers: 49 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 50 | return init_lr_groups_l 51 | 52 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 53 | for scheduler in self.schedulers: 54 | scheduler.step() 55 | #### set up warm up learning rate 56 | if cur_iter < warmup_iter: 57 | # get initial lr for each group 58 | init_lr_g_l = self._get_init_lr() 59 | # modify warming-up learning rates 60 | warm_up_lr_l = [] 61 | for init_lr_g in init_lr_g_l: 62 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 63 | # set learning rate 64 | self._set_lr(warm_up_lr_l) 65 | 66 | def get_current_learning_rate(self): 67 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups] 68 | 69 | def get_network_description(self, network): 70 | '''Get the string and total parameters of the network''' 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | s = str(network) 74 | n = sum(map(lambda x: x.numel(), network.parameters())) 75 | return s, n 76 | 77 | def save_network(self, network, network_label, iter_label): 78 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 79 | save_path = os.path.join(self.opt['path']['models'], save_filename) 80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 81 | network = network.module 82 | state_dict = network.state_dict() 83 | for key, param in state_dict.items(): 84 | state_dict[key] = param.cpu() 85 | torch.save(state_dict, save_path) 86 | 87 | def 
load_network(self, load_path, network, strict=True): 88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 89 | network = network.module 90 | load_net = torch.load(load_path) 91 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 92 | for k, v in load_net.items(): 93 | if k.startswith('module.'): 94 | load_net_clean[k[7:]] = v 95 | else: 96 | load_net_clean[k] = v 97 | network.load_state_dict(load_net_clean, strict=strict) 98 | 99 | def load_network_separately(self, load_path_a, load_path_b, name_a, name_b, network, strict=False): 100 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 101 | network = network.module 102 | 103 | load_net_a = torch.load(load_path_a) 104 | load_net_a_clean = OrderedDict() # remove unnecessary 'module.' 105 | for k, v in load_net_a.items(): 106 | if k.startswith('module.'): 107 | load_net_a_clean['{}{}'.format(name_a, k[7:])] = v 108 | else: 109 | load_net_a_clean['{}{}'.format(name_a, k)] = v 110 | network.load_state_dict(load_net_a_clean, strict=strict) 111 | 112 | load_net_b = torch.load(load_path_b) 113 | load_net_b_clean = OrderedDict() # remove unnecessary 'module.' 114 | for k, v in load_net_b.items(): 115 | if k.startswith('module.'): 116 | load_net_b_clean['{}{}'.format(name_b, k[7:])] = v 117 | else: 118 | load_net_b_clean['{}{}'.format(name_b, k)] = v 119 | network.load_state_dict(load_net_b_clean, strict=strict) 120 | 121 | def save_training_state(self, epoch, iter_step): 122 | '''Saves training state during training, which will be used for resuming''' 123 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 124 | for s in self.schedulers: 125 | state['schedulers'].append(s.state_dict()) 126 | for o in self.optimizers: 127 | state['optimizers'].append(o.state_dict()) 128 | save_filename = '{}.state'.format(iter_step) 129 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 130 | torch.save(state, save_path) 131 | 132 | def resume_training(self, resume_state): 133 | '''Resume the optimizers and schedulers for training''' 134 | resume_optimizers = resume_state['optimizers'] 135 | resume_schedulers = resume_state['schedulers'] 136 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 137 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 138 | for i, o in enumerate(resume_optimizers): 139 | self.optimizers[i].load_state_dict(o) 140 | for i, s in enumerate(resume_schedulers): 141 | self.schedulers[i].load_state_dict(s) 142 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restarts = [v + 1 for v in self.restarts] 16 | self.restart_weights = weights if weights else [1] 17 | assert len(self.restarts) == len( 18 | self.restart_weights), 'restarts and their weights do not match.' 
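# Note: the restart points were shifted by +1 above, so the learning-rate
# reset is applied on the step immediately after each configured restart
# iteration.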
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch in self.restarts: 23 | if self.clear_state: 24 | self.optimizer.state = defaultdict(dict) 25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 27 | if self.last_epoch not in self.milestones: 28 | return [group['lr'] for group in self.optimizer.param_groups] 29 | return [ 30 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 31 | for group in self.optimizer.param_groups 32 | ] 33 | 34 | 35 | class CosineAnnealingLR_Restart(_LRScheduler): 36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 37 | self.T_period = T_period 38 | self.T_max = self.T_period[0] # current T period 39 | self.eta_min = eta_min 40 | self.restarts = restarts if restarts else [0] 41 | self.restarts = [v + 1 for v in self.restarts] 42 | self.restart_weights = weights if weights else [1] 43 | self.last_restart = 0 44 | assert len(self.restarts) == len( 45 | self.restart_weights), 'restarts and their weights do not match.' 46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | if self.last_epoch == 0: 50 | return self.base_lrs 51 | elif self.last_epoch in self.restarts: 52 | self.last_restart = self.last_epoch 53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 57 | return [ 58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 60 | ] 61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 63 | (group['lr'] - self.eta_min) + self.eta_min 64 | for group in self.optimizer.param_groups] 65 | 66 | 67 | if __name__ == "__main__": 68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 69 | betas=(0.9, 0.99)) 70 | ############################## 71 | # MultiStepLR_Restart 72 | ############################## 73 | ## Original 74 | lr_steps = [200000, 400000, 600000, 800000] 75 | restarts = None 76 | restart_weights = None 77 | 78 | ## two 79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 80 | restarts = [500000] 81 | restart_weights = [1] 82 | 83 | ## four 84 | lr_steps = [ 85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 87 | ] 88 | restarts = [250000, 500000, 750000] 89 | restart_weights = [1, 1, 1] 90 | 91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 92 | clear_state=False) 93 | 94 | ############################## 95 | # Cosine Annealing Restart 96 | ############################## 97 | ## two 98 | T_period = [500000, 500000] 99 | restarts = [500000] 100 | restart_weights = [1] 101 | 102 | ## four 103 | T_period = [250000, 250000, 250000, 250000] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1, 1, 1] 106 | 107 | scheduler = 
CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 108 | weights=restart_weights) 109 | 110 | ############################## 111 | # Draw figure 112 | ############################## 113 | N_iter = 1000000 114 | lr_l = list(range(N_iter)) 115 | for i in range(N_iter): 116 | scheduler.step() 117 | current_lr = optimizer.param_groups[0]['lr'] 118 | lr_l[i] = current_lr 119 | 120 | import matplotlib as mpl 121 | from matplotlib import pyplot as plt 122 | import matplotlib.ticker as mtick 123 | mpl.style.use('default') 124 | import seaborn 125 | seaborn.set(style='whitegrid') 126 | seaborn.set_context('paper') 127 | 128 | plt.figure(1) 129 | plt.subplot(111) 130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 131 | plt.title('Title', fontsize=16, color='k') 132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 133 | legend = plt.legend(loc='upper right', shadow=False) 134 | ax = plt.gca() 135 | labels = ax.get_xticks().tolist() 136 | for k, v in enumerate(labels): 137 | labels[k] = str(int(v / 1000)) + 'K' 138 | ax.set_xticklabels(labels) 139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 140 | 141 | ax.set_ylabel('Learning rate') 142 | ax.set_xlabel('Iteration') 143 | fig = plt.gcf() 144 | plt.show() 145 | -------------------------------------------------------------------------------- /codes/models/archs/RCAN_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | 5 | from models.archs.arch_util import make_layer 6 | 7 | 8 | class Upsample(nn.Sequential): 9 | """Upsample module. 10 | Args: 11 | scale (int): Scale factor. Supported scales: 2^n and 3. 12 | num_feat (int): Channel number of intermediate features. 13 | """ 14 | 15 | def __init__(self, scale, num_feat): 16 | m = [] 17 | if (scale & (scale - 1)) == 0: # scale = 2^n 18 | for _ in range(int(math.log(scale, 2))): 19 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 20 | m.append(nn.PixelShuffle(2)) 21 | elif scale == 3: 22 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 23 | m.append(nn.PixelShuffle(3)) 24 | else: 25 | raise ValueError(f'scale {scale} is not supported. ' 26 | 'Supported scales: 2^n and 3.') 27 | super(Upsample, self).__init__(*m) 28 | 29 | 30 | class ChannelAttention(nn.Module): 31 | """Channel attention used in RCAN. 32 | Args: 33 | num_feat (int): Channel number of intermediate features. 34 | squeeze_factor (int): Channel squeeze factor. Default: 16. 35 | """ 36 | 37 | def __init__(self, num_feat, squeeze_factor=16): 38 | super(ChannelAttention, self).__init__() 39 | self.attention = nn.Sequential( 40 | nn.AdaptiveAvgPool2d(1), 41 | nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), 44 | nn.Sigmoid()) 45 | 46 | def forward(self, x): 47 | y = self.attention(x) 48 | return x * y 49 | 50 | 51 | class RCAB(nn.Module): 52 | """Residual Channel Attention Block (RCAB) used in RCAN. 53 | Args: 54 | num_feat (int): Channel number of intermediate features. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 56 | res_scale (float): Scale the residual. Default: 1. 
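    The block computes x + res_scale * CA(Conv(ReLU(Conv(x)))), i.e. a plain
    residual block whose branch output is re-weighted by channel attention
    (see forward below).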
57 | """ 58 | 59 | def __init__(self, num_feat, squeeze_factor=16, res_scale=1): 60 | super(RCAB, self).__init__() 61 | self.res_scale = res_scale 62 | 63 | self.rcab = nn.Sequential( 64 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), 65 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), 66 | ChannelAttention(num_feat, squeeze_factor)) 67 | 68 | def forward(self, x): 69 | res = self.rcab(x) * self.res_scale 70 | return res + x 71 | 72 | 73 | class ResidualGroup(nn.Module): 74 | """Residual Group of RCAB. 75 | Args: 76 | num_feat (int): Channel number of intermediate features. 77 | num_block (int): Block number in the body network. 78 | squeeze_factor (int): Channel squeeze factor. Default: 16. 79 | res_scale (float): Scale the residual. Default: 1. 80 | """ 81 | 82 | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): 83 | super(ResidualGroup, self).__init__() 84 | 85 | self.residual_group = make_layer( 86 | RCAB, 87 | num_block, 88 | num_feat=num_feat, 89 | squeeze_factor=squeeze_factor, 90 | res_scale=res_scale) 91 | self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 92 | 93 | def forward(self, x): 94 | res = self.conv(self.residual_group(x)) 95 | return res + x 96 | 97 | 98 | class RCAN(nn.Module): 99 | """Residual Channel Attention Networks. 100 | Paper: Image Super-Resolution Using Very Deep Residual Channel Attention 101 | Networks 102 | Ref git repo: https://github.com/yulunzhang/RCAN. 103 | Args: 104 | num_in_ch (int): Channel number of inputs. 105 | num_out_ch (int): Channel number of outputs. 106 | num_feat (int): Channel number of intermediate features. 107 | Default: 64. 108 | num_group (int): Number of ResidualGroup. Default: 10. 109 | num_block (int): Number of RCAB in ResidualGroup. Default: 16. 110 | squeeze_factor (int): Channel squeeze factor. Default: 16. 111 | upscale (int): Upsampling factor. Support 2^n and 3. 112 | Default: 4. 113 | res_scale (float): Used to scale the residual in residual block. 114 | Default: 1. 115 | img_range (float): Image range. Default: 255. 116 | rgb_mean (tuple[float]): Image mean in RGB orders. 117 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
118 | """ 119 | 120 | def __init__(self, 121 | num_in_ch, 122 | num_out_ch, 123 | num_frames, 124 | num_feat=64, 125 | num_group=10, 126 | num_block=16, 127 | squeeze_factor=16, 128 | upscale=4, 129 | res_scale=1, 130 | img_range=255., 131 | rgb_mean=(0.4488, 0.4371, 0.4040)): 132 | super(RCAN, self).__init__() 133 | 134 | # self.img_range = img_range 135 | # self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 136 | 137 | self.conv_first = nn.Conv2d(num_in_ch * num_frames, num_feat, 3, 1, 1) 138 | self.body = make_layer( 139 | ResidualGroup, 140 | num_group, 141 | num_feat=num_feat, 142 | num_block=num_block, 143 | squeeze_factor=squeeze_factor, 144 | res_scale=res_scale) 145 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 146 | self.upsample = Upsample(upscale, num_feat) 147 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 148 | 149 | def forward(self, x): 150 | # self.mean = self.mean.type_as(x) 151 | # x = (x - self.mean) * self.img_range 152 | if x.dim() == 5: 153 | B, N, C, H, W = x.shape 154 | x = x.view(B, N*C, H, W) 155 | 156 | x = self.conv_first(x) 157 | res = self.conv_after_body(self.body(x)) 158 | res += x 159 | 160 | x = self.conv_last(self.upsample(res)) 161 | # x = x / self.img_range + self.mean 162 | 163 | return x 164 | 165 | 166 | if __name__ == '__main__': 167 | x = torch.randn(4, 3, 3, 64, 64) 168 | model = RCAN(num_in_ch=3, 169 | num_out_ch=3, 170 | num_frames=3, 171 | num_feat=64, 172 | num_group=5, 173 | num_block=2, 174 | squeeze_factor=16, 175 | upscale=1, 176 | res_scale=1) 177 | out = model(x) 178 | print(out.shape) 179 | -------------------------------------------------------------------------------- /codes/models/archs/TOF_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import models.archs.arch_util as arch_util 6 | 7 | from models.archs.SRResNet_arch import MSRResNet 8 | 9 | 10 | class SpyNet_Block(nn.Module): 11 | """ 12 | A submodule of SpyNet. 
13 | """ 14 | 15 | def __init__(self, ic=8): 16 | super(SpyNet_Block, self).__init__() 17 | 18 | self.block = nn.Sequential( 19 | nn.Conv2d(in_channels=ic, out_channels=32, kernel_size=7, stride=1, padding=3), 20 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 21 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), 22 | nn.BatchNorm2d(64), nn.ReLU(inplace=True), 23 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), 24 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 25 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), 26 | nn.BatchNorm2d(16), nn.ReLU(inplace=True), 27 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 28 | 29 | # initialization 30 | arch_util.initialize_weights(self.block, 0.1) 31 | 32 | def forward(self, x): 33 | """ 34 | input: x: [ref im, nbr im, initial flow] - (B, 8, H, W) 35 | output: estimated flow - (B, 2, H, W) 36 | """ 37 | return self.block(x) 38 | 39 | 40 | class SpyNet(nn.Module): 41 | """ 42 | SpyNet for estimating optical flow 43 | Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016 44 | """ 45 | 46 | def __init__(self, K=3): 47 | super(SpyNet, self).__init__() 48 | 49 | self.K = K 50 | ## modify input block 51 | self.block0 = SpyNet_Block(ic=6) 52 | self.blocks = nn.ModuleList([SpyNet_Block(ic=8) for _ in range(K)]) 53 | 54 | def forward(self, ref, nbr): 55 | """Estimating optical flow in coarse level, upsample, and estimate in fine level 56 | Note: the size of input should be divisible by 8, if not, pad them before input 57 | input: ref: reference image - [B, 3, H, W] 58 | nbr: the neighboring image to be warped - [B, 3, H, W] 59 | output: warpped nbr by estimated optical flow - [B, 3, H, W] 60 | flow: estimated optical flow (absolute displacement) - [B, 2, H, W] 61 | """ 62 | B, C, H, W = ref.size() 63 | ref = [ref] 64 | nbr = [nbr] 65 | 66 | for _ in range(self.K): 67 | ref.insert( 68 | 0, 69 | nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2, 70 | count_include_pad=False) 71 | ) 72 | nbr.insert( 73 | 0, 74 | nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2, 75 | count_include_pad=False) 76 | ) 77 | 78 | flow = self.block0(torch.cat([ref[0], nbr[0]], 1)) # [H//2^K, W//2^K] 79 | 80 | for i in range(self.K): 81 | flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear', 82 | align_corners=True) * 2.0 83 | flow = flow_up + self.blocks[i]( 84 | torch.cat([ref[i+1], arch_util.flow_warp(nbr[i+1], flow_up.permute(0, 2, 3, 1)), flow_up], 1) 85 | ) 86 | 87 | output = arch_util.flow_warp(nbr[-1], flow.permute(0, 2, 3, 1)) 88 | return output, flow 89 | 90 | 91 | class MSRResNet(nn.Module): 92 | 93 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 94 | super(MSRResNet, self).__init__() 95 | self.upscale = upscale 96 | 97 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 98 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 99 | self.recon_trunk = arch_util.make_layer(basic_block, nb) 100 | 101 | # upsampling 102 | if self.upscale == 2: 103 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 104 | self.pixel_shuffle = nn.PixelShuffle(2) 105 | elif self.upscale == 3: 106 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 107 | self.pixel_shuffle = nn.PixelShuffle(3) 108 | elif self.upscale == 4: 109 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 110 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 
111 | self.pixel_shuffle = nn.PixelShuffle(2) 112 | 113 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 114 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 115 | 116 | # activation function 117 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 118 | 119 | # initialization 120 | if self.upscale == 2 or self.upscale == 3 or self.upscale == 4: 121 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1) 122 | if self.upscale == 4: 123 | arch_util.initialize_weights(self.upconv2, 0.1) 124 | 125 | def forward(self, x): 126 | C = x.size(1) 127 | if C > 3: 128 | # for video SR with multi-frame input: use the 3 channels of the center frame as the base 129 | x_base = x[:, C//2-1:C//2+2, :, :] 130 | else: 131 | x_base = x 132 | 133 | fea = self.lrelu(self.conv_first(x)) 134 | out = self.recon_trunk(fea) 135 | 136 | if self.upscale == 4: 137 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 138 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 139 | elif self.upscale == 3 or self.upscale == 2: 140 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 141 | 142 | out = self.conv_last(self.lrelu(self.HRconv(out))) 143 | base = F.interpolate(x_base, scale_factor=self.upscale, mode='bilinear', align_corners=False) 144 | out += base 145 | return out 146 | 147 | 148 | class TOF(nn.Module): 149 | """ 150 | Video SR based on SpyNet and MSRResNet 151 | Note: [in_nc] is the number of input channels of a single frame. 152 | """ 153 | def __init__(self, nframes=3, K=3, in_nc=3, out_nc=3, nf=32, nb=12, upscale=2): 154 | super(TOF, self).__init__() 155 | 156 | self.nframes = nframes 157 | 158 | self.align_arch = SpyNet(K=K) 159 | self.sr_arch = MSRResNet(in_nc=nframes * in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=upscale) 160 | 161 | def forward(self, x): 162 | """ 163 | x: [B, T, C, H, W], T = nframes.
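returns: the super-resolved center frame - [B, out_nc, upscale*H, upscale*W]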
164 | """ 165 | B, T, C, H, W = x.size() 166 | assert T == self.nframes 167 | 168 | ## warp neighbour frames to reference frame 169 | # reference frame 170 | ref_index = T // 2 171 | ref_frame = x[:, ref_index, :, :, :] 172 | # neighbour frames 173 | y = [] 174 | nbrs = [] 175 | flows = [] 176 | for i in range(T): 177 | if i == ref_index: 178 | y.append(ref_frame) 179 | else: 180 | warp_nbr_frame, flow = self.align_arch(ref_frame, x[:, i, :, :, :]) 181 | y.append(warp_nbr_frame) 182 | nbrs.append(warp_nbr_frame) 183 | flows.append(flow) 184 | 185 | ## cat frames as input of sr module 186 | y = torch.cat(y, dim=1) 187 | out = self.sr_arch(y) 188 | 189 | return out 190 | 191 | 192 | if __name__ == '__main__': 193 | device = torch.device('cuda') 194 | model = TOF(nframes=1, K=3, in_nc=3, out_nc=3, nf=32, nb=10, upscale=1).to(device) 195 | x = torch.randn(4, 1, 3, 128, 128).to(device) 196 | out = model(x) 197 | print(out[0].shape) 198 | 199 | -------------------------------------------------------------------------------- /codes/models/VideoSR_archs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import models.archs.VGG_arch as VGG_arch 5 | import models.archs.TOF_arch as TOF_arch 6 | import models.archs.RCAN_arch as RCAN_arch 7 | import models.archs.EDVR_arch as EDVR_arch 8 | import models.archs.TDAN_arch as TDAN_arch 9 | import models.archs.FSTRN_arch as FSTRN_arch 10 | import models.archs.discriminator_arch as discriminator_arch 11 | 12 | 13 | #################### 14 | # define network 15 | #################### 16 | 17 | 18 | def define_G(opt): 19 | """Discriminator""" 20 | opt_net = opt['network_G'] 21 | which_model = opt_net['which_model_G'] 22 | 23 | if which_model == 'EDVR': 24 | netG = EDVR_arch.EDVR(nf=opt_net['nf'], nc=opt_net['nc'], nframes=opt_net['nframes'], 25 | groups=opt_net['groups'], front_RBs=opt_net['front_RBs'], 26 | back_RBs=opt_net['back_RBs'], center=opt_net['center'], 27 | predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'], 28 | w_TSA=opt_net['w_TSA']) 29 | 30 | elif which_model == 'EDVR_NoUp': 31 | netG = EDVR_arch.EDVR_NoUp(nf=opt_net['nf'], nc=opt_net['nc'], nframes=opt_net['nframes'], 32 | groups=opt_net['groups'], front_RBs=opt_net['front_RBs'], 33 | back_RBs=opt_net['back_RBs'], center=opt_net['center'], 34 | predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'], 35 | w_TSA=opt_net['w_TSA']) 36 | 37 | elif which_model == 'TDAN': 38 | netG = TDAN_arch.TDAN(nf=opt_net['nf'], channel=opt_net['nc'], nframes=opt_net['nframes'], 39 | nb_f=opt_net['nb_f'], nb_b=opt_net['nb_b'], groups=opt_net['groups'], 40 | scale=opt['scale']) 41 | 42 | elif which_model == 'TOF': 43 | netG = TOF_arch.TOF(nframes=opt_net['nframes'], K=opt_net['K'], in_nc=opt_net['nc'], 44 | out_nc=opt_net['nc'], nf=opt_net['nf'], nb=opt_net['nb'], 45 | upscale=opt['scale']) 46 | 47 | elif which_model == 'FSTRN': 48 | netG = FSTRN_arch.FSTRN(k=opt_net['k'], nf=opt_net['nf'], scale=opt['scale'], 49 | nframes=opt_net['nframes']) 50 | 51 | elif which_model == 'RCAN': 52 | netG = RCAN_arch.RCAN(num_in_ch=opt_net['num_in_ch'], num_out_ch=opt_net['num_out_ch'], 53 | num_frames=opt_net['num_frames'], num_feat=opt_net['num_feat'], 54 | num_group=opt_net['num_group'], num_block=opt_net['num_block'], 55 | squeeze_factor=opt_net['squeeze_factor'], upscale=opt['scale'], 56 | res_scale=opt_net['res_scale']) 57 | else: 58 | raise NotImplementedError('Generator model [{:s}] not 
recognized'.format(which_model)) 59 | 60 | return netG 61 | 62 | 63 | # Discriminator 64 | def define_D(opt): 65 | opt_net = opt['network_D'] 66 | which_model = opt_net['which_model_D'] 67 | 68 | if which_model == 'discriminator_vgg_192': 69 | netD = VGG_arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 70 | 71 | elif which_model == 'PatchDiscriminator': 72 | netD = discriminator_arch.PatchDiscriminator(input_nc=opt_net['in_nc'], ndf=opt_net['nf'], 73 | norm_layer=nn.BatchNorm2d) 74 | elif which_model == 'PixelDiscriminator': 75 | netD = discriminator_arch.PixelDiscriminator(input_nc=opt_net['in_nc'], ndf=opt_net['nf'], 76 | norm_layer=nn.BatchNorm2d) 77 | elif which_model == 'UNetDiscriminator': 78 | netD = discriminator_arch.UNetDiscriminator(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 79 | 80 | elif which_model == 'MultiscaleDiscriminator_v1': 81 | netD = discriminator_arch.MultiscaleDiscriminator_v1(input_nc=opt_net['in_nc'], 82 | ndf=opt_net['nf'], 83 | num_D=opt_net['num_D'], 84 | norm_layer=nn.BatchNorm2d, 85 | gan_type=opt_net['gan_type']) 86 | elif which_model == 'MultiscaleDiscriminator_v2': 87 | netD = discriminator_arch.MultiscaleDiscriminator_v2(input_nc=opt_net['in_nc'], 88 | ndf=opt_net['nf'], 89 | num_D=opt_net['num_D'], 90 | norm_layer=nn.BatchNorm2d, 91 | gan_type=opt_net['gan_type']) 92 | elif which_model == 'MultiscaleDiscriminator_v3': 93 | netD = discriminator_arch.MultiscaleDiscriminator_v3(input_nc=opt_net['in_nc'], 94 | ndf=opt_net['nf'], 95 | num_D=opt_net['num_D'], 96 | norm_layer=nn.BatchNorm2d, 97 | gan_type=opt_net['gan_type']) 98 | elif which_model == 'MultiscaleDiscriminator_v4': 99 | netD = discriminator_arch.MultiscaleDiscriminator_v4(input_nc=opt_net['in_nc'], 100 | ndf=opt_net['nf'], 101 | num_D=opt_net['num_D'], 102 | norm_layer=nn.BatchNorm2d, 103 | gan_type=opt_net['gan_type']) 104 | elif which_model == 'LaplacePyramidDiscriminator': 105 | netD = discriminator_arch.LaplacePyramidDiscriminator(input_nc=opt_net['in_nc'], 106 | ndf=opt_net['nf'], 107 | num_D=opt_net['num_D'], 108 | norm_layer=nn.BatchNorm2d, 109 | gan_type=opt_net['gan_type']) 110 | elif which_model == 'GaussianPyramidDiscriminator': 111 | netD = discriminator_arch.GaussianPyramidDiscriminator(input_nc=opt_net['in_nc'], 112 | ndf=opt_net['nf'], 113 | num_D=opt_net['num_D'], 114 | norm_layer=nn.BatchNorm2d, 115 | gan_type=opt_net['gan_type']) 116 | elif which_model == 'ImageGradientPyramidDiscriminator_v1': 117 | netD = discriminator_arch.ImageGradientPyramidDiscriminator_v1(input_nc=opt_net['in_nc'], 118 | ndf=opt_net['nf'], 119 | num_D=opt_net['num_D'], 120 | norm_layer=nn.BatchNorm2d, 121 | gan_type=opt_net['gan_type']) 122 | elif which_model == 'ImageGradientPyramidDiscriminator_v2': 123 | netD = discriminator_arch.ImageGradientPyramidDiscriminator_v2(input_nc=opt_net['in_nc'], 124 | ndf=opt_net['nf'], 125 | num_D=opt_net['num_D'], 126 | norm_layer=nn.BatchNorm2d, 127 | gan_type=opt_net['gan_type']) 128 | else: 129 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) 130 | return netD 131 | 132 | 133 | def define_F(opt, use_bn=False): 134 | """Network for Perceptual Loss""" 135 | gpu_ids = opt['gpu_ids'] 136 | device = torch.device('cuda' if gpu_ids else 'cpu') 137 | # PyTorch pretrained VGG19-54, before ReLU. 
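# ("VGG19-54": the 4th conv layer of the 5th block - index 34 in torchvision's vgg19().features, index 49 in the vgg19_bn variant)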
138 | if use_bn: 139 | feature_layer = 49 140 | else: 141 | feature_layer = 34 142 | netF = VGG_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, use_input_norm=True, device=device) 143 | netF.eval() # No need to train 144 | 145 | return netF -------------------------------------------------------------------------------- /codes/models/archs/VGG_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | import torchvision 6 | 7 | 8 | class Discriminator_VGG_128(nn.Module): 9 | def __init__(self, in_nc, nf): 10 | super(Discriminator_VGG_128, self).__init__() 11 | # [64, 128, 128] 12 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 13 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 14 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 15 | # [64, 64, 64] 16 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 17 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 18 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 19 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 20 | # [128, 32, 32] 21 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 22 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 23 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 24 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 25 | # [256, 16, 16] 26 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 27 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 28 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 29 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 30 | # [512, 8, 8] 31 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 32 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 33 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 34 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 35 | 36 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 37 | self.linear2 = nn.Linear(100, 1) 38 | 39 | # activation function 40 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 41 | 42 | def forward(self, x): 43 | fea = self.lrelu(self.conv0_0(x)) 44 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 45 | 46 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 47 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 48 | 49 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 50 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 51 | 52 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 53 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 54 | 55 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 56 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 57 | 58 | fea = fea.view(fea.size(0), -1) 59 | fea = self.lrelu(self.linear1(fea)) 60 | out = self.linear2(fea) 61 | return out 62 | 63 | 64 | class Discriminator_VGG_192(nn.Module): 65 | def __init__(self, in_nc, nf): 66 | super(Discriminator_VGG_192, self).__init__() 67 | # [64, 192, 192] 68 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 69 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 70 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 71 | # [64, 96, 96] 72 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 73 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 74 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 75 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 76 | # [128, 48, 48] 77 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 78 | self.bn2_0 =
nn.BatchNorm2d(nf * 4, affine=True) 79 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 80 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 81 | # [256, 24, 24] 82 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 83 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 84 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 85 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 86 | # [512, 12, 12] 87 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 88 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 89 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 90 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 91 | 92 | self.linear1 = nn.Linear(512 * 6 * 6, 100) 93 | self.linear2 = nn.Linear(100, 1) 94 | 95 | # activation function 96 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 97 | 98 | def forward(self, x): 99 | fea = self.lrelu(self.conv0_0(x)) 100 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 101 | 102 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 103 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 104 | 105 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 106 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 107 | 108 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 109 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 110 | 111 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 112 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 113 | 114 | fea = fea.view(fea.size(0), -1) 115 | fea = self.lrelu(self.linear1(fea)) 116 | out = self.linear2(fea) 117 | return out 118 | 119 | 120 | class VGGFeatureExtractor(nn.Module): 121 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, 122 | device=torch.device('cpu')): 123 | super(VGGFeatureExtractor, self).__init__() 124 | self.use_input_norm = use_input_norm 125 | if use_bn: 126 | model = torchvision.models.vgg19_bn(pretrained=True) 127 | else: 128 | model = torchvision.models.vgg19(pretrained=True) 129 | if self.use_input_norm: 130 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 131 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 132 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 133 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 134 | self.register_buffer('mean', mean) 135 | self.register_buffer('std', std) 136 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 137 | # no gradients needed for the fixed feature extractor 138 | for k, v in self.features.named_parameters(): 139 | v.requires_grad = False 140 | 141 | def forward(self, x): 142 | # Assume input range is [0, 1] 143 | if self.use_input_norm: 144 | x = (x - self.mean) / self.std 145 | output = self.features(x) 146 | return output 147 | 148 | 149 | class Vgg19(nn.Module): 150 | 151 | def __init__(self, use_input_norm=True, device=torch.device('cpu')): 152 | super(Vgg19, self).__init__() 153 | self.use_input_norm = use_input_norm 154 | self.slice1 = torch.nn.Sequential() 155 | self.slice2 = torch.nn.Sequential() 156 | self.slice3 = torch.nn.Sequential() 157 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 158 | if self.use_input_norm: 159 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 160 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 161 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 162 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 163 | self.register_buffer('mean',
mean) 164 | self.register_buffer('std', std) 165 | for x in range(4): 166 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 167 | for x in range(4, 9): 168 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 169 | for x in range(9, 14): 170 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 171 | for param in self.parameters(): 172 | param.requires_grad = False 173 | 174 | def forward(self, x): 175 | # Assume input range is [0, 1] 176 | if self.use_input_norm: 177 | x = (x - self.mean) / self.std 178 | h = self.slice1(x) 179 | h_relu1_2 = h 180 | h = self.slice2(h) 181 | h_relu2_2 = h 182 | h = self.slice3(h) 183 | h_relu3_2 = h 184 | return h_relu1_2, h_relu2_2, h_relu3_2 185 | 186 | 187 | if __name__ == '__main__': 188 | # x = torch.randn(4, 3, 224, 224) 189 | # vgg = torchvision.models.vgg19(pretrained=True) 190 | # out = vgg(x) 191 | 192 | x = torch.randn(4, 3, 128, 128) 193 | vgg128 = Discriminator_VGG_128(in_nc=3, nf=64) 194 | out = vgg128(x) 195 | print(out.shape) 196 | 197 | x = torch.randn(4, 3, 192, 192) 198 | vgg192 = Discriminator_VGG_192(in_nc=3, nf=64) 199 | out = vgg192(x) 200 | print(out.shape) 201 | -------------------------------------------------------------------------------- /codes/metrics/evaluate_realvsr_full_reference_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import glob 4 | import logging 5 | import numpy as np 6 | import cv2 7 | import torch 8 | from PIL import Image 9 | from torchvision import transforms 10 | 11 | import utils.util as util 12 | import data.util as data_util 13 | from metrics import calculate_psnr, calculate_ssim, calculate_niqe 14 | from IQA_pytorch import LPIPSvgg, DISTS 15 | 16 | 17 | def setup_logger(logger_name, log_file, level=logging.INFO, screen=False, tofile=False): 18 | """set up logger""" 19 | lg = logging.getLogger(logger_name) 20 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 21 | datefmt='%y-%m-%d %H:%M:%S') 22 | lg.setLevel(level) 23 | if tofile: 24 | fh = logging.FileHandler(log_file, mode='w') 25 | fh.setFormatter(formatter) 26 | lg.addHandler(fh) 27 | if screen: 28 | sh = logging.StreamHandler() 29 | sh.setFormatter(formatter) 30 | lg.addHandler(sh) 31 | 32 | 33 | def prepare_image(image, resize=False, repeatNum=1): 34 | if resize and min(image.size) > 256: 35 | image = transforms.functional.resize(image, 256) 36 | image = transforms.ToTensor()(image) 37 | return image.unsqueeze(0).repeat(repeatNum, 1, 1, 1) 38 | 39 | 40 | def evaluate_psnr(model_name, GT_folder, log_file=None, color='y'): 41 | 42 | if log_file: 43 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 44 | logger = logging.getLogger('base') 45 | else: 46 | log_file = '/home/xiyang/Results/RealVSR/PSNR_{}.log'.format(model_name) 47 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 48 | logger = logging.getLogger('base') 49 | 50 | subfolder_l = sorted(glob.glob(GT_folder)) 51 | avg_psnr_l = [] 52 | 53 | # for each sub-folder 54 | for subfolder in subfolder_l: 55 | subfolder_name = subfolder.split('/')[-1] 56 | logger.info(subfolder_name) 57 | avg_psnr = 0 58 | for img_idx, img_GT_path in enumerate(sorted(glob.glob(osp.join(subfolder, '[0-9]*')))): 59 | if color == 'y': 60 | img_GT = data_util.read_img(None, img_GT_path) 61 | img_GT = data_util.bgr2ycbcr(img_GT, only_y=True) 62 | else: 63 | img_GT = data_util.read_img(None, img_GT_path) 64 | # TODO: 
modify accordingly 65 | img_LQ_path = img_GT_path.replace('/GT_test/', '/test_results/{}/'.format(model_name)) 66 | assert img_LQ_path != img_GT_path 67 | if color == 'y': 68 | img_LQ = data_util.read_img(None, img_LQ_path) 69 | img_LQ = data_util.bgr2ycbcr(img_LQ, only_y=True) 70 | else: 71 | img_LQ = data_util.read_img(None, img_LQ_path) 72 | 73 | # PSNR is computed identically for both color modes; 74 | # the images were already converted to Y (or kept as BGR) above 75 | psnr = util.calculate_psnr(img_LQ * 255, img_GT * 255) 76 | 77 | logger.info('{:3d} - {:25} \tPSNR: {:.2f} dB'.format(img_idx + 1, os.path.basename(img_LQ_path), psnr)) 78 | avg_psnr += psnr 79 | 80 | avg_psnr = avg_psnr / (img_idx + 1)  # average over the frames in this sequence 81 | avg_psnr_l.append(avg_psnr) 82 | logger.info(model_name) 83 | logger.info('PSNR: {:.2f} dB'.format(sum(avg_psnr_l) / len(avg_psnr_l))) 84 | 85 | 86 | def evaluate_ssim(model_name, GT_folder, log_file=None, color='y'): 87 | 88 | if log_file: 89 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 90 | logger = logging.getLogger('base') 91 | else: 92 | log_file = '/home/xiyang/Results/RealVSR/SSIM_{}.log'.format(model_name) 93 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 94 | logger = logging.getLogger('base') 95 | 96 | subfolder_l = sorted(glob.glob(GT_folder)) 97 | avg_ssim_l = [] 98 | 99 | # for each sub-folder 100 | for subfolder in subfolder_l: 101 | subfolder_name = subfolder.split('/')[-1] 102 | logger.info(subfolder_name) 103 | avg_ssim = 0 104 | for img_idx, img_GT_path in enumerate(sorted(glob.glob(osp.join(subfolder, '[0-9]*')))): 105 | if color == 'y': 106 | img_GT = data_util.read_img(None, img_GT_path) 107 | img_GT = data_util.bgr2ycbcr(img_GT, only_y=True) 108 | else: 109 | img_GT = data_util.read_img(None, img_GT_path) 110 | # TODO: modify accordingly 111 | img_LQ_path = img_GT_path.replace('/GT_test/', '/test_results/{}/'.format(model_name)) 112 | assert img_LQ_path != img_GT_path 113 | if color == 'y': 114 | img_LQ = data_util.read_img(None, img_LQ_path) 115 | img_LQ = data_util.bgr2ycbcr(img_LQ, only_y=True) 116 | else: 117 | img_LQ = data_util.read_img(None, img_LQ_path) 118 | 119 | # SSIM is computed identically for both color modes; 120 | # the images were already converted to Y (or kept as BGR) above 121 | ssim = util.calculate_ssim(img_LQ * 255, img_GT * 255) 122 | 123 | logger.info('{:3d} - {:25} \tSSIM: {:.4f}'.format(img_idx + 1, os.path.basename(img_LQ_path), ssim)) 124 | avg_ssim += ssim 125 | 126 | avg_ssim = avg_ssim / (img_idx + 1)  # average over the frames in this sequence 127 | avg_ssim_l.append(avg_ssim) 128 | logger.info(model_name) 129 | logger.info('SSIM: {:.4f}'.format(sum(avg_ssim_l) / len(avg_ssim_l))) 130 | 131 | 132 | def evaluate_lpips(model_name, GT_folder, device='cuda:0', log_file=None): 133 | 134 | if log_file: 135 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 136 | logger = logging.getLogger('base') 137 | else: 138 | log_file = '/home/xiyang/Results/RealVSR/LPIPS_{}.log'.format(model_name) 139 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 140 | logger = logging.getLogger('base') 141 | 142 | avg_lpips_l = [] 143 | subfolder_l = sorted(glob.glob(GT_folder)) 144 | metric = LPIPSvgg().to(device)  # build the LPIPS model once, outside the loops 145 | # for each sub-folder 146 | for subfolder in subfolder_l: 147 | subfolder_name = subfolder.split('/')[-1] 148 | logger.info(subfolder_name) 149 | avg_lpips = 0 150 | for img_idx, img_GT_path in enumerate(sorted(glob.glob(osp.join(subfolder, '[0-9]*')))): 151 | img_GT = Image.open(img_GT_path).convert("RGB") 152 | # TODO: modify accordingly 153 | img_LQ_path =
img_GT_path.replace('/GT_test/', '/test_results/{}/'.format(model_name)) 154 | assert img_LQ_path != img_GT_path 155 | img_LQ = Image.open(img_LQ_path).convert("RGB") 156 | 157 | lq = prepare_image(img_LQ, resize=False).to(device) 158 | gt = prepare_image(img_GT, resize=False).to(device) 159 | 160 | img_name = os.path.basename(img_LQ_path) 161 | # (the LPIPS model is constructed once, before the loops) 162 | score = metric(lq, gt, as_loss=False) 163 | logger.info('{:3d} - {:25} \t LPIPS: {:.4f}'.format(img_idx + 1, img_name, score.item())) 164 | avg_lpips += score.item() 165 | 166 | avg_lpips = avg_lpips / (img_idx + 1)  # average over the frames in this sequence 167 | avg_lpips_l.append(avg_lpips) 168 | logger.info('{}'.format(model_name)) 169 | logger.info('LPIPS: {:.4f}'.format(sum(avg_lpips_l) / len(avg_lpips_l))) 170 | 171 | 172 | def evaluate_dists(model_name, GT_folder, device='cuda:0', log_file=None): 173 | 174 | if log_file: 175 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 176 | logger = logging.getLogger('base') 177 | else: 178 | log_file = '/home/xiyang/Results/RealVSR/DISTS_{}.log'.format(model_name) 179 | setup_logger('base', log_file, level=logging.INFO, screen=True, tofile=True) 180 | logger = logging.getLogger('base') 181 | 182 | avg_dists_l = [] 183 | subfolder_l = sorted(glob.glob(GT_folder)) 184 | metric = DISTS().to(device)  # build the DISTS model once, outside the loops 185 | # for each sub-folder 186 | for subfolder in subfolder_l: 187 | subfolder_name = subfolder.split('/')[-1] 188 | logger.info(subfolder_name) 189 | avg_dists = 0 190 | for img_idx, img_GT_path in enumerate(sorted(glob.glob(osp.join(subfolder, '[0-9]*')))): 191 | img_GT = Image.open(img_GT_path).convert("RGB") 192 | # TODO: modify accordingly 193 | img_LQ_path = img_GT_path.replace('/GT_test/', '/test_results/{}/'.format(model_name)) 194 | assert img_LQ_path != img_GT_path 195 | img_LQ = Image.open(img_LQ_path).convert("RGB") 196 | 197 | lq = prepare_image(img_LQ, resize=False).to(device) 198 | gt = prepare_image(img_GT, resize=False).to(device) 199 | 200 | img_name = os.path.basename(img_LQ_path) 201 | # (the DISTS model is constructed once, before the loops) 202 | score = metric(lq, gt, as_loss=False) 203 | logger.info('{:3d} - {:25} \t DISTS: {:.4f}'.format(img_idx + 1, img_name, score.item())) 204 | avg_dists += score.item() 205 | 206 | avg_dists = avg_dists / (img_idx + 1)  # average over the frames in this sequence 207 | avg_dists_l.append(avg_dists) 208 | logger.info('{}'.format(model_name)) 209 | logger.info('DISTS: {:.4f}'.format(sum(avg_dists_l) / len(avg_dists_l))) 210 | 211 | 212 | if __name__ == '__main__': 213 | pass 214 | -------------------------------------------------------------------------------- /codes/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import data.util as data_util 5 | import utils.util as util 6 | 7 | from IQA_pytorch import SSIM, MS_SSIM 8 | 9 | 10 | class CharbonnierLoss(nn.Module): 11 | """Charbonnier Loss (L1)""" 12 | def __init__(self, eps=1e-6, reduction='mean'): 13 | super(CharbonnierLoss, self).__init__() 14 | self.eps = eps 15 | self.reduction = reduction 16 | 17 | def forward(self, x, y): 18 | diff = x - y 19 | if self.reduction == 'mean': 20 | loss = torch.mean(torch.sqrt(diff * diff + self.eps)) 21 | else: 22 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 23 | return loss 24 | 25 | 26 | class HuberLoss(nn.Module): 27 | """Huber Loss (smoothed L1)""" 28 | def __init__(self, delta=1e-2, reduction='mean'): 29 | super(HuberLoss, self).__init__() 30 | self.delta = delta 31 | self.reduction = reduction 32 | 33 | def forward(self, x,
y): 34 | abs_diff = torch.abs(x - y) 35 | q_term = torch.min(abs_diff, torch.full_like(abs_diff, self.delta)) 36 | l_term = abs_diff - q_term 37 | if self.reduction == 'mean': 38 | loss = torch.mean(0.5 * q_term ** 2 + self.delta * l_term) 39 | else: 40 | loss = torch.sum(0.5 * q_term ** 2 + self.delta * l_term) 41 | return loss 42 | 43 | 44 | class TVLoss(nn.Module): 45 | """Total Variation Loss""" 46 | def __init__(self): 47 | super(TVLoss, self).__init__() 48 | 49 | def forward(self, x): 50 | return torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) + \ 51 | torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) 52 | 53 | 54 | class GWLoss(nn.Module): 55 | """Gradient Weighted Loss""" 56 | def __init__(self, w=4, reduction='mean'): 57 | super(GWLoss, self).__init__() 58 | self.w = w 59 | self.reduction = reduction 60 | sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float) 61 | sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float) 62 | self.weight_x = nn.Parameter(data=sobel_x, requires_grad=False) 63 | self.weight_y = nn.Parameter(data=sobel_y, requires_grad=False) 64 | 65 | def forward(self, x1, x2): 66 | b, c, h, w = x1.shape 67 | weight_x = self.weight_x.expand(c, 1, 3, 3).type_as(x1) 68 | weight_y = self.weight_y.expand(c, 1, 3, 3).type_as(x1) 69 | Ix1 = F.conv2d(x1, weight_x, stride=1, padding=1, groups=c) 70 | Ix2 = F.conv2d(x2, weight_x, stride=1, padding=1, groups=c) 71 | Iy1 = F.conv2d(x1, weight_y, stride=1, padding=1, groups=c) 72 | Iy2 = F.conv2d(x2, weight_y, stride=1, padding=1, groups=c) 73 | dx = torch.abs(Ix1 - Ix2) 74 | dy = torch.abs(Iy1 - Iy2) 75 | # loss = torch.exp(2*(dx + dy)) * torch.abs(x1 - x2) 76 | loss = (1 + self.w * dx) * (1 + self.w * dy) * torch.abs(x1 - x2) 77 | if self.reduction == 'mean': 78 | return torch.mean(loss) 79 | else: 80 | return torch.sum(loss) 81 | 82 | 83 | class StyleLoss(nn.Module): 84 | """Style Loss""" 85 | def __init__(self): 86 | super(StyleLoss, self).__init__() 87 | 88 | @staticmethod 89 | def gram_matrix(x):  # static method: no self argument 90 | B, C, H, W = x.size() 91 | features = x.view(B * C, H * W) 92 | G = torch.mm(features, features.t()) 93 | return G.div(B * C * H * W) 94 | 95 | def forward(self, input, target): 96 | G_i = self.gram_matrix(input) 97 | G_t = self.gram_matrix(target).detach() 98 | loss = F.mse_loss(G_i, G_t) 99 | return loss 100 | 101 | 102 | class GANLoss(nn.Module): 103 | """GAN loss (vanilla | lsgan | wgan-gp)""" 104 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 105 | super(GANLoss, self).__init__() 106 | self.gan_type = gan_type.lower() 107 | self.real_label_val = real_label_val 108 | self.fake_label_val = fake_label_val 109 | 110 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 111 | self.loss = nn.BCEWithLogitsLoss() 112 | elif self.gan_type == 'lsgan': 113 | self.loss = nn.MSELoss() 114 | elif self.gan_type == 'wgan-gp': 115 | def wgan_loss(input, target): 116 | # target is boolean 117 | return -1 * input.mean() if target else input.mean() 118 | self.loss = wgan_loss 119 | else: 120 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 121 | 122 | def get_target_label(self, input, target_is_real): 123 | if self.gan_type == 'wgan-gp': 124 | return target_is_real 125 | if target_is_real: 126 | return torch.empty_like(input).fill_(self.real_label_val) 127 | else: 128 | return torch.empty_like(input).fill_(self.fake_label_val) 129 | 130 | def forward(self, input, target_is_real): 131 | target_label =
self.get_target_label(input, target_is_real) 132 | loss = self.loss(input, target_label) 133 | return loss 134 | 135 | 136 | class GradientPenaltyLoss(nn.Module): 137 | """Gradient Penalty Loss""" 138 | def __init__(self, device=torch.device('cpu')): 139 | super(GradientPenaltyLoss, self).__init__() 140 | self.register_buffer('grad_outputs', torch.Tensor()) 141 | self.grad_outputs = self.grad_outputs.to(device) 142 | 143 | def get_grad_outputs(self, input): 144 | if self.grad_outputs.size() != input.size(): 145 | self.grad_outputs.resize_(input.size()).fill_(1.0) 146 | return self.grad_outputs 147 | 148 | def forward(self, interp, interp_crit): 149 | grad_outputs = self.get_grad_outputs(interp_crit) 150 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 151 | grad_outputs=grad_outputs, create_graph=True, 152 | retain_graph=True, only_inputs=True)[0] 153 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 154 | grad_interp_norm = grad_interp.norm(2, dim=1) 155 | 156 | loss = ((grad_interp_norm - 1) ** 2).mean() 157 | return loss 158 | 159 | 160 | class PyramidLoss(nn.Module): 161 | """Pyramid Loss""" 162 | def __init__(self, num_levels=3, pyr_mode='gau', loss_mode='l1', reduction='mean'): 163 | super(PyramidLoss, self).__init__() 164 | self.num_levels = num_levels 165 | self.pyr_mode = pyr_mode 166 | self.loss_mode = loss_mode 167 | assert self.pyr_mode == 'gau' or self.pyr_mode == 'lap' 168 | if self.loss_mode == 'l1': 169 | self.loss = nn.L1Loss(reduction=reduction) 170 | elif self.loss_mode == 'l2': 171 | self.loss = nn.MSELoss(reduction=reduction) 172 | elif self.loss_mode == 'hb': 173 | self.loss = HuberLoss(reduction=reduction) 174 | elif self.loss_mode == 'cb': 175 | self.loss = CharbonnierLoss(reduction=reduction) 176 | else: 177 | raise ValueError('Loss mode [{:s}] is not recognized.'.format(self.loss_mode)) 178 | 179 | def forward(self, x, y): 180 | B, C, H, W = x.shape 181 | device = x.device 182 | gauss_kernel = util.gauss_kernel(size=5, device=device, channels=C) 183 | if self.pyr_mode == 'gau': 184 | pyr_x = util.gau_pyramid(img=x, kernel=gauss_kernel, max_levels=self.num_levels) 185 | pyr_y = util.gau_pyramid(img=y, kernel=gauss_kernel, max_levels=self.num_levels) 186 | else: 187 | pyr_x = util.lap_pyramid(img=x, kernel=gauss_kernel, max_levels=self.num_levels) 188 | pyr_y = util.lap_pyramid(img=y, kernel=gauss_kernel, max_levels=self.num_levels) 189 | loss = 0 190 | for i in range(self.num_levels): 191 | loss += self.loss(pyr_x[i], pyr_y[i]) 192 | return loss 193 | 194 | 195 | class LapPyrLoss(nn.Module): 196 | """Laplacian Pyramid Loss""" 197 | def __init__(self, num_levels=3, lf_mode='ssim', hf_mode='cb', reduction='mean'): 198 | super(LapPyrLoss, self).__init__() 199 | self.num_levels = num_levels 200 | self.lf_mode = lf_mode 201 | self.hf_mode = hf_mode 202 | if lf_mode == 'ssim': 203 | self.lf_loss = SSIM(channels=1) 204 | elif lf_mode == 'cb': 205 | self.lf_loss = CharbonnierLoss(reduction=reduction) 206 | else: 207 | raise ValueError('LF mode [{:s}] is not recognized.'.format(lf_mode)) 208 | if hf_mode == 'ssim': 209 | self.hf_loss = SSIM(channels=1) 210 | elif hf_mode == 'cb': 211 | self.hf_loss = CharbonnierLoss(reduction=reduction) 212 | else: 213 | raise ValueError('HF mode [{:s}] is not recognized.'.format(hf_mode)) 214 | 215 | def forward(self, x, y): 216 | B, C, H, W = x.shape 217 | device = x.device 218 | gauss_kernel = util.gauss_kernel(size=5, device=device, channels=C) 219 | pyr_x = util.laplacian_pyramid(img=x, kernel=gauss_kernel, max_levels=self.num_levels) 220 | pyr_y = util.laplacian_pyramid(img=y, kernel=gauss_kernel, max_levels=self.num_levels) 221 | loss = self.lf_loss(pyr_x[-1], pyr_y[-1])
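# pyr[-1] is the low-frequency residual (lf criterion above); pyr[0..num_levels-2] are the high-frequency band-pass levels (hf criterion below)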
222 | for i in range(self.num_levels - 1): 223 | loss += self.hf_loss(pyr_x[i], pyr_y[i]) 224 | return loss 225 | 226 | 227 | if __name__ == '__main__': 228 | device = torch.device('cuda') 229 | x1 = torch.randn(4, 3, 64, 64).to(device) 230 | x1.requires_grad = True 231 | x2 = torch.randn(4, 3, 64, 64).to(device) 232 | x2.requires_grad = True 233 | loss = GWLoss().to(device) 234 | l = loss(x1, x2) 235 | print(l) 236 | l.backward() 237 | 238 | -------------------------------------------------------------------------------- /codes/test_RealVSR_wi_GT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import logging 5 | import numpy as np 6 | import cv2 7 | import torch 8 | 9 | import utils.util as util 10 | import data.util as data_util 11 | import models.archs.TOF_arch as TOF_arch 12 | import models.archs.RCAN_arch as RCAN_arch 13 | import models.archs.EDVR_arch as EDVR_arch 14 | import models.archs.TDAN_arch as TDAN_arch 15 | import models.archs.FSTRN_arch as FSTRN_arch 16 | 17 | 18 | def main(): 19 | ################# 20 | # configurations 21 | ################# 22 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 23 | data_mode = 'RealVSR' 24 | 25 | # TODO: Modify the configurations here 26 | # model 27 | N_ch = 3 28 | N_in = 3 29 | model = 'EDVR' 30 | model_name = '001_EDVR_NoUp_woTSA_scratch_lr1e-4_150k_RealVSR_3frame_WiCutBlur_YCbCr_LapPyr+GW' 31 | model_path = '../experiments/pretrained_models/{}.pth'.format(model_name) 32 | # dataset 33 | read_folder = '/home/yangxi/datasets/RealVSR/release/LQ_YCbCr_test' 34 | save_folder = '/home/yangxi/datasets/RealVSR/results/{}/{}'.format(data_mode, model_name) 35 | # color mode 36 | color = 'YCbCr' 37 | # device 38 | device = torch.device('cuda') 39 | 40 | if model == 'RCAN': 41 | model = RCAN_arch.RCAN(num_in_ch=N_ch, num_out_ch=N_ch, num_frames=N_in, num_feat=64, 42 | num_group=5, num_block=2, squeeze_factor=16, upscale=1, res_scale=1) 43 | elif model == 'FSTRN': 44 | model = FSTRN_arch.FSTRN(k=3, nf=64, scale=1, nframes=N_in) 45 | elif model == 'TOF': 46 | model = TOF_arch.TOF(nframes=3, K=3, in_nc=N_ch, out_nc=N_ch, nf=64, nb=10, upscale=1) 47 | elif model == 'TDAN': 48 | model = TDAN_arch.TDAN(channel=N_ch, nf=64, nframes=N_in, groups=8, scale=1) 49 | elif model == 'EDVR': 50 | model = EDVR_arch.EDVR_NoUp(nf=64, nc=N_ch, nframes=N_in, groups=8, front_RBs=5, back_RBs=10, 51 | predeblur=False, HR_in=False, w_TSA=False) 52 | else: 53 | raise ValueError('Model [{:s}] not recognized.'.format(model)) 54 | 55 | #### evaluation 56 | flip_test = False 57 | crop_border = 0 58 | border_frame = N_in // 2 # number of border frames when evaluating 59 | 60 | # temporal padding mode 61 | padding = 'replicate' # different from the official setting 62 | save_imgs = True 63 | 64 | util.mkdirs(save_folder) 65 | util.setup_logger('base', save_folder, 'test', level=logging.INFO, screen=True, tofile=True) 66 | logger = logging.getLogger('base') 67 | 68 | #### log info 69 | logger.info('Data: {} - {}'.format(data_mode, read_folder)) 70 | logger.info('Padding mode: {}'.format(padding)) 71 | logger.info('Model path: {}'.format(model_path)) 72 | logger.info('Save images: {}'.format(save_imgs)) 73 | 74 | subfolder_l = sorted(glob.glob(os.path.join(read_folder, '*'))) 75 | 76 | #### set up the models 77 | model.load_state_dict(torch.load(model_path), strict=True) 78 | model.eval() 79 | model = model.to(device) 80 | 81 | avg_psnr_l, avg_psnr_center_l, avg_psnr_border_l = [], [], [] 82 | avg_ssim_l, avg_ssim_center_l, avg_ssim_border_l = [], [], []
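# border frames = the first and last N_in//2 frames of each sequence; they are tallied separately from center frames below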
83 | subfolder_name_l = [] 84 | 85 | # for each sub-folder 86 | for subfolder in subfolder_l: 87 | subfolder_name = subfolder.split('/')[-1] 88 | subfolder_name_l.append(subfolder_name) 89 | save_subfolder = os.path.join(save_folder, subfolder_name) 90 | 91 | img_path_l = sorted(glob.glob(os.path.join(subfolder, '*'))) 92 | max_idx = len(img_path_l) 93 | 94 | if save_imgs: 95 | util.mkdirs(save_subfolder) 96 | 97 | #### read LR images 98 | imgs = data_util.read_img_seq(subfolder, color=color) 99 | #### read GT images 100 | img_GT_l = [] 101 | subfolder_GT = os.path.join(subfolder.replace('/LQ_YCbCr_test/', '/GT_YCbCr_test/'), '*') 102 | for img_GT_path in sorted(glob.glob(subfolder_GT)): 103 | if color == 'YCbCr': 104 | tmp_img = data_util.read_img(None, img_GT_path)[:, :, [2, 1, 0]] 105 | else: 106 | tmp_img = data_util.read_img(None, img_GT_path) 107 | img_GT_l.append(tmp_img) 108 | 109 | avg_psnr, avg_psnr_border, avg_psnr_center = 0, 0, 0 110 | avg_ssim, avg_ssim_border, avg_ssim_center = 0, 0, 0 111 | N_border, N_center = 0, 0 112 | 113 | # process each image 114 | for img_idx, img_path in enumerate(img_path_l): 115 | img_name = os.path.splitext(os.path.basename(img_path))[0] 116 | select_idx = data_util.index_generation(img_idx, max_idx, N_in, padding=padding) 117 | # get input images 118 | imgs_in = imgs.index_select(0, torch.LongTensor(select_idx)).unsqueeze(0).to(device) 119 | output = util.single_forward(model, imgs_in) 120 | 121 | if color == 'YCbCr': 122 | output = util.tensor2img(output.squeeze(0), out_type=np.float32, reverse_channel=False) 123 | img = (np.clip(data_util.ycbcr2bgr(output), 0, 1) * 255.).round().astype(np.uint8) 124 | # save imgs 125 | if save_imgs: 126 | cv2.imwrite(os.path.join(save_subfolder, '{}.png'.format(img_name)), img) 127 | else: 128 | output = util.tensor2img(output.squeeze(0), out_type=np.uint8, reverse_channel=True) 129 | img = output 130 | # save imgs 131 | if save_imgs: 132 | cv2.imwrite(os.path.join(save_subfolder, '{}.png'.format(img_name)), img) 133 | 134 | #### calculate PSNR and SSIM 135 | if color == 'YCbCr': 136 | output = output / 255. 137 | GT = np.copy(img_GT_l[img_idx]) 138 | GT = np.squeeze(GT) 139 | output, GT = util.crop_border([output, GT], crop_border) 140 | output = (output * 255.0).round().astype(np.uint8) 141 | GT = (GT * 255.0).round().astype(np.uint8) 142 | crt_psnr = util.calculate_psnr(output[:, :, 0], GT[:, :, 0]) 143 | crt_ssim = util.calculate_ssim(output[:, :, 0], GT[:, :, 0]) 144 | else: 145 | output = output / 255. 146 | GT = np.copy(img_GT_l[img_idx]) 147 | GT = np.squeeze(GT) 148 | output, GT = util.crop_border([output, GT], crop_border) 149 | crt_psnr = util.calculate_psnr(output * 255, GT * 255) 150 | crt_ssim = util.calculate_ssim(output * 255, GT * 255) 151 | logger.info('{:3d} - {:25} \tPSNR: {:.2f} dB \tSSIM: {:.4f}'. 
152 | format(img_idx + 1, img_name, crt_psnr, crt_ssim)) 153 | 154 | if border_frame <= img_idx < max_idx - border_frame: # center frames 155 | avg_psnr_center += crt_psnr 156 | avg_ssim_center += crt_ssim 157 | N_center += 1 158 | else: # border frames 159 | avg_psnr_border += crt_psnr 160 | avg_ssim_border += crt_ssim 161 | N_border += 1 162 | 163 | avg_psnr = (avg_psnr_center + avg_psnr_border) / (N_center + N_border) 164 | avg_ssim = (avg_ssim_center + avg_ssim_border) / (N_center + N_border) 165 | avg_psnr_center = avg_psnr_center / N_center 166 | avg_ssim_center = avg_ssim_center / N_center 167 | avg_psnr_border = 0 if N_border == 0 else avg_psnr_border / N_border 168 | avg_ssim_border = 0 if N_border == 0 else avg_ssim_border / N_border 169 | 170 | avg_psnr_l.append(avg_psnr) 171 | avg_psnr_center_l.append(avg_psnr_center) 172 | avg_psnr_border_l.append(avg_psnr_border) 173 | avg_ssim_l.append(avg_ssim) 174 | avg_ssim_center_l.append(avg_ssim_center) 175 | avg_ssim_border_l.append(avg_ssim_border) 176 | 177 | logger.info('Folder {} - Average PSNR: {:.2f} dB for {} frames; ' 178 | 'Center PSNR: {:.2f} dB for {} frames; ' 179 | 'Border PSNR: {:.2f} dB for {} frames.'.format(subfolder_name, avg_psnr, 180 | (N_center + N_border), 181 | avg_psnr_center, N_center, 182 | avg_psnr_border, N_border)) 183 | logger.info('Folder {} - Average SSIM: {:.4f} for {} frames; ' 184 | 'Center SSIM: {:.4f} for {} frames; ' 185 | 'Border SSIM: {:.4f} for {} frames.'.format(subfolder_name, avg_ssim, 186 | (N_center + N_border), 187 | avg_ssim_center, N_center, 188 | avg_ssim_border, N_border)) 189 | 190 | logger.info('################ Tidy Outputs ################') 191 | for name, psnr, psnr_center, psnr_border in zip(subfolder_name_l, avg_psnr_l, 192 | avg_psnr_center_l, avg_psnr_border_l): 193 | logger.info('Folder {} - Average PSNR: {:.2f} dB. ' 194 | 'Center PSNR: {:.2f} dB. ' 195 | 'Border PSNR: {:.2f} dB.'.format(name, psnr, psnr_center, psnr_border)) 196 | for name, ssim, ssim_center, ssim_border in zip(subfolder_name_l, avg_ssim_l, 197 | avg_ssim_center_l, avg_ssim_border_l): 198 | logger.info('Folder {} - Average SSIM: {:.4f}. ' 199 | 'Center SSIM: {:.4f}. ' 200 | 'Border SSIM: {:.4f}.'.format(name, ssim, ssim_center, ssim_border)) 201 | logger.info('################ Final Results ################') 202 | logger.info('Data: {} - {}'.format(data_mode, read_folder)) 203 | logger.info('Padding mode: {}'.format(padding)) 204 | logger.info('Model path: {}'.format(model_path)) 205 | logger.info('Save images: {}'.format(save_imgs)) 206 | logger.info('Flip Test: {}'.format(flip_test)) 207 | logger.info('Total Average PSNR: {:.2f} dB for {} clips. ' 208 | 'Center PSNR: {:.2f} dB. Border PSNR: {:.2f} dB.'.format( 209 | sum(avg_psnr_l) / len(avg_psnr_l), len(subfolder_l), 210 | sum(avg_psnr_center_l) / len(avg_psnr_center_l), 211 | sum(avg_psnr_border_l) / len(avg_psnr_border_l))) 212 | logger.info('Total Average SSIM: {:.4f} for {} clips. ' 213 | 'Center SSIM: {:.4f}. 
Border SSIM: {:.4f}.'.format( 214 | sum(avg_ssim_l) / len(avg_ssim_l), len(subfolder_l), 215 | sum(avg_ssim_center_l) / len(avg_ssim_center_l), 216 | sum(avg_ssim_border_l) / len(avg_ssim_border_l))) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | -------------------------------------------------------------------------------- /codes/models/VideoSR_AllPair_model_YCbCr_Split.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import kornia 8 | from torch.nn.parallel import DataParallel, DistributedDataParallel 9 | import models.VideoSR_archs as networks 10 | import models.lr_scheduler as lr_scheduler 11 | import data.augments_video_allpair as augments 12 | from .base_model import BaseModel 13 | from models.loss import CharbonnierLoss, HuberLoss, GWLoss, PyramidLoss, LapPyrLoss 14 | from IQA_pytorch import SSIM, MS_SSIM, DISTS 15 | 16 | 17 | logger = logging.getLogger('base') 18 | 19 | 20 | class VideoSRModel(BaseModel): 21 | 22 | def __init__(self, opt): 23 | super(VideoSRModel, self).__init__(opt) 24 | 25 | if opt['dist']: 26 | self.rank = torch.distributed.get_rank() 27 | else: 28 | self.rank = -1 # non dist training 29 | train_opt = opt['train'] 30 | 31 | # define network and load pretrained models 32 | self.netG = networks.define_G(opt).to(self.device) 33 | if opt['dist']: 34 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 35 | else: 36 | self.netG = DataParallel(self.netG) 37 | # print network 38 | self.print_network() 39 | self.load() 40 | 41 | if self.is_train: 42 | self.netG.train() 43 | 44 | #### loss 45 | loss_type = train_opt['pixel_criterion_y'] 46 | if loss_type == 'l1': 47 | self.cri_pix_y = nn.L1Loss(reduction='mean').to(self.device) 48 | elif loss_type == 'l2': 49 | self.cri_pix_y = nn.MSELoss(reduction='mean').to(self.device) 50 | elif loss_type == 'cb': 51 | self.cri_pix_y = CharbonnierLoss(reduction='mean').to(self.device) 52 | elif loss_type == 'hb': 53 | self.cri_pix_y = HuberLoss(reduction='mean').to(self.device) 54 | elif loss_type == 'gw': 55 | self.cri_pix_y = GWLoss(w=4, reduction='mean').to(self.device) 56 | elif loss_type == 'pyr': 57 | self.cri_pix_y = PyramidLoss(num_levels=3, pyr_mode='gau', loss_mode='cb', reduction='mean') 58 | elif loss_type == 'lappyr': 59 | self.cri_pix_y = LapPyrLoss(num_levels=3, lf_mode='ssim', hf_mode='cb', reduction='mean') 60 | elif loss_type == 'msssim': 61 | channels = opt['network_G']['nc'] 62 | self.cri_pix_y = MS_SSIM(channels=channels).to(self.device) 63 | else: 64 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 65 | self.l_pix_w_y = train_opt['pixel_weight_y'] 66 | 67 | loss_type = train_opt['pixel_criterion_c'] 68 | if loss_type == 'l1': 69 | self.cri_pix_c = nn.L1Loss(reduction='mean').to(self.device) 70 | elif loss_type == 'l2': 71 | self.cri_pix_c = nn.MSELoss(reduction='mean').to(self.device) 72 | elif loss_type == 'cb': 73 | self.cri_pix_c = CharbonnierLoss(reduction='mean').to(self.device) 74 | elif loss_type == 'hb': 75 | self.cri_pix_c = HuberLoss(reduction='mean').to(self.device) 76 | elif loss_type == 'gw': 77 | self.cri_pix_c = GWLoss(w=4, reduction='mean').to(self.device) 78 | elif loss_type == 'pyr': 79 | self.cri_pix_c = PyramidLoss(num_levels=3, pyr_mode='gau', loss_mode='cb', reduction='mean') 80 | elif loss_type == 'lappyr': 81 | 
self.cri_pix_c = LapPyrLoss(num_levels=3, lf_mode='ssim', hf_mode='cb', reduction='mean') 82 | elif loss_type == 'msssim': 83 | channels = opt['network_G']['nc'] 84 | self.cri_pix_c = MS_SSIM(channels=channels).to(self.device) 85 | else: 86 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 87 | self.l_pix_w_c = train_opt['pixel_weight_c'] 88 | 89 | #### optimizers 90 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 91 | if train_opt['ft_tsa_only']: 92 | normal_params = [] 93 | tsa_fusion_params = [] 94 | for k, v in self.netG.named_parameters(): 95 | if v.requires_grad: 96 | if 'tsa_fusion' in k: 97 | tsa_fusion_params.append(v) 98 | else: 99 | normal_params.append(v) 100 | else: 101 | if self.rank <= 0: 102 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 103 | optim_params = [ 104 | { # add normal params first 105 | 'params': normal_params, 106 | 'lr': train_opt['lr_G'] 107 | }, 108 | { 109 | 'params': tsa_fusion_params, 110 | 'lr': train_opt['lr_G'] 111 | }, 112 | ] 113 | else: 114 | optim_params = [] 115 | for k, v in self.netG.named_parameters(): 116 | if v.requires_grad: 117 | optim_params.append(v) 118 | else: 119 | if self.rank <= 0: 120 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 121 | 122 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 123 | weight_decay=wd_G, 124 | betas=(train_opt['beta1'], train_opt['beta2'])) 125 | self.optimizers.append(self.optimizer_G) 126 | 127 | #### schedulers 128 | if train_opt['lr_scheme'] == 'MultiStepLR_Restart': 129 | for optimizer in self.optimizers: 130 | self.schedulers.append( 131 | lr_scheduler.MultiStepLR_Restart( 132 | optimizer, train_opt['lr_steps'], 133 | restarts=train_opt['restarts'], 134 | weights=train_opt['restart_weights'], 135 | gamma=train_opt['lr_gamma'], 136 | clear_state=train_opt['clear_state'] 137 | ) 138 | ) 139 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 140 | for optimizer in self.optimizers: 141 | self.schedulers.append( 142 | lr_scheduler.CosineAnnealingLR_Restart( 143 | optimizer, train_opt['T_period'], 144 | eta_min=train_opt['eta_min'], 145 | restarts=train_opt['restarts'], 146 | weights=train_opt['restart_weights'] 147 | ) 148 | ) 149 | else: 150 | raise NotImplementedError() 151 | 152 | self.log_dict = OrderedDict() 153 | 154 | def feed_data(self, data, need_GT=True): 155 | self.var_L = data['LQs'].to(self.device) 156 | if need_GT: 157 | self.var_H = data['GT'].to(self.device) 158 | 159 | def set_params_lr_zero(self): 160 | # fix normal module 161 | self.optimizers[0].param_groups[0]['lr'] = 0 162 | 163 | def optimize_parameters(self, step): 164 | if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: 165 | self.set_params_lr_zero() 166 | 167 | # clear gradient and forward propagate 168 | self.optimizer_G.zero_grad() 169 | if self.opt['augment']: 170 | opt = self.opt['augment'] 171 | self.var_H, self.var_L = augments.apply_augment( 172 | self.var_H, self.var_L, 173 | opt['augs'], opt['probs'], opt['alphas'], opt['mix_p'] 174 | ) 175 | self.fake_H = self.netG(self.var_L) 176 | else: 177 | self.fake_H = self.netG(self.var_L) 178 | 179 | # calculate loss and back propagate 180 | center_idx = self.var_L.size(1) // 2 181 | # loss for luminance channel 182 | l_pix_y = self.l_pix_w_y * self.cri_pix_y(self.fake_H[:, 0:1, :, :], self.var_H[:, center_idx, 0:1, :, :].contiguous()) 183 | # loss for color channel 184 | l_pix_c = self.l_pix_w_c * 
self.cri_pix_c(self.fake_H[:, 1:3, :, :], self.var_H[:, center_idx, 1:3, :, :].contiguous()) 185 | l_pix = l_pix_y + l_pix_c 186 | l_pix.backward() 187 | self.optimizer_G.step() 188 | # set log 189 | self.log_dict['l_pix_y'] = l_pix_y.item() 190 | self.log_dict['l_pix_c'] = l_pix_c.item() 191 | self.log_dict['l_pix'] = l_pix.item() 192 | 193 | def test(self): 194 | self.netG.eval() 195 | with torch.no_grad(): 196 | self.fake_H = self.netG(self.var_L) 197 | self.netG.train() 198 | 199 | def get_current_log(self): 200 | return self.log_dict 201 | 202 | def get_current_visuals(self, need_GT=True): 203 | out_dict = OrderedDict() 204 | out_dict['LQs'] = self.var_L.detach()[0].float().cpu() 205 | out_dict['HQ'] = self.fake_H.detach()[0].float().cpu() 206 | if need_GT: 207 | out_dict['GT'] = self.var_H.detach()[0].float().cpu() 208 | return out_dict 209 | 210 | def print_network(self): 211 | s, n = self.get_network_description(self.netG) 212 | if isinstance(self.netG, nn.DataParallel): 213 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 214 | self.netG.module.__class__.__name__) 215 | else: 216 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 217 | if self.rank <= 0: 218 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 219 | logger.info(s) 220 | 221 | def load(self): 222 | load_path_G = self.opt['path']['pretrain_model_G'] 223 | if load_path_G is not None: 224 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 225 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 226 | 227 | def load_separately(self): 228 | load_path_G_a = self.opt['path']['pretrain_model_G_a'] 229 | load_path_G_b = self.opt['path']['pretrain_model_G_b'] 230 | name_a = self.opt['path']['name_a'] 231 | name_b = self.opt['path']['name_b'] 232 | if load_path_G_a is not None and load_path_G_b is not None: 233 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G_a)) 234 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G_b)) 235 | self.load_network_separately(load_path_G_a, load_path_G_b, name_a, name_b, 236 | self.netG, self.opt['path']['strict_load']) 237 | 238 | def save(self, iter_step): 239 | self.save_network(self.netG, 'G', iter_step) 240 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /codes/data/Vimeo90K_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Vimeo90K dataset 3 | support reading images from lmdb, image folder and memcached 4 | ''' 5 | import os.path as osp 6 | import random 7 | import pickle 8 | import logging 9 | import numpy as np 10 | import cv2 11 | import lmdb 12 | import torch 13 | import torch.utils.data as data 14 | import utils.util as util 15 | import data.util as data_util 16 | try: 17 | import mc # import memcached 18 | except ImportError: 19 | pass 20 | 21 | logger = logging.getLogger('base') 22 | 23 | 24 | class Vimeo90kDataset(data.Dataset): 25 | ''' 26 | Reading the training Vimeo90K dataset: LQ is precomputed from GT 27 | key example: 00001_0001 (_1, ..., _7) 28 | ''' 29 | 30 | def __init__(self, opt): 31 | super(Vimeo90kDataset, self).__init__() 32 | self.opt = opt 33 | # temporal augmentation 34 | self.interval_list = opt['interval_list'] 35 | self.random_reverse = opt['random_reverse'] 36 | logger.info( 37 | 'Temporal augmentation interval list: [{}], with random reverse is {}.' 38 | .format(','.join(str(x) for x in opt['interval_list']), self.random_reverse) 39 | ) 40 | 41 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] 42 | self.data_type = self.opt['data_type'] 43 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs 44 | 45 | #### determine the LQ frame list 46 | ''' 47 | N | frames 48 | 1 | 4 49 | 3 | 3,4,5 50 | 5 | 2,3,4,5,6 51 | 7 | 1,2,3,4,5,6,7 52 | ''' 53 | self.frame_list = [] 54 | self.center = opt['N_frames'] // 2 55 | for i in range(opt['N_frames']): 56 | self.frame_list.append(i + (9 - opt['N_frames']) // 2) 57 | 58 | if self.data_type == 'lmdb': 59 | self.paths_GT, _ = data_util.get_image_paths(self.data_type, opt['dataroot_GT']) 60 | logger.info('Using lmdb meta info for cache keys.') 61 | elif opt['cache_keys']: 62 | logger.info('Using cache keys: {}'.format(opt['cache_keys'])) 63 | self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] 64 | else: 65 | raise ValueError('Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') 66 | assert self.paths_GT, 'Error: GT path is empty.' 
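        # The lmdb environments below start as None and are opened lazily in
        # _init_lmdb(), which runs on the first __getitem__ call inside each
        # DataLoader worker: an lmdb environment opened in the parent process
        # cannot be shared safely across forked workers, so every worker opens
        # its own handle (see the chainermn issue referenced in _init_lmdb).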
67 | 68 | if self.data_type == 'lmdb': 69 | self.GT_env, self.LQ_env = None, None 70 | elif self.data_type == 'img': 71 | pass 72 | else: 73 | raise ValueError('Wrong data type: {}'.format(self.data_type)) 74 | 75 | def _init_lmdb(self): 76 | # https://github.com/chainer/chainermn/issues/129 77 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, meminit=False) 78 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, meminit=False) 79 | 80 | def __getitem__(self, index): 81 | if self.data_type == 'lmdb': 82 | if (self.GT_env is None) or (self.LQ_env is None): 83 | self._init_lmdb() 84 | 85 | scale = self.opt['scale'] 86 | GT_size, LQ_size = self.opt['GT_size'], self.opt['LQ_size'] 87 | key = self.paths_GT[index] 88 | name_a, name_b = key.split('_') 89 | 90 | #### temporal augmentation: random reverse 91 | if self.random_reverse and random.random() < 0.5: 92 | self.frame_list.reverse() 93 | 94 | #### get GT image 95 | GT_size_tuple = (3, 256, 448) 96 | if self.data_type == 'lmdb': 97 | img_GT = data_util.read_img(self.GT_env, key + '_4', GT_size_tuple) 98 | else: 99 | img_GT = data_util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png')) 100 | 101 | #### get LQ images 102 | LQ_size_tuple = (3, 256 // scale, 448 // scale) if self.LR_input else (3, 256, 448) 103 | img_LQ_l = [] 104 | for v in self.frame_list: 105 | if self.data_type == 'lmdb': 106 | img_LQ = data_util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple) 107 | else: 108 | img_LQ = data_util.read_img(None, osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v))) 109 | img_LQ_l.append(img_LQ) 110 | 111 | if self.opt['phase'] == 'train': 112 | C, H, W = LQ_size_tuple # LQ size 113 | # randomly crop 114 | if self.LR_input: 115 | LQ_size = GT_size // scale 116 | rnd_h_LQ = random.randint(0, max(0, H - LQ_size)) 117 | rnd_w_LQ = random.randint(0, max(0, W - LQ_size)) 118 | img_LQ_l = [v[rnd_h_LQ:rnd_h_LQ + LQ_size, rnd_w_LQ:rnd_w_LQ + LQ_size, :] for v in img_LQ_l] 119 | rnd_h_GT = int(rnd_h_LQ * scale) 120 | rnd_w_GT = int(rnd_w_LQ * scale) 121 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 122 | else: 123 | rnd_h_LQ = random.randint(0, max(0, H - GT_size)) 124 | rnd_w_LQ = random.randint(0, max(0, W - GT_size)) 125 | img_LQ_l = [v[rnd_h_LQ:rnd_h_LQ + GT_size, rnd_w_LQ:rnd_w_LQ + GT_size, :] for v in img_LQ_l] 126 | rnd_h_GT = rnd_h_LQ 127 | rnd_w_GT = rnd_w_LQ 128 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 129 | 130 | # augmentation - flip, rotate 131 | img_LQ_l.append(img_GT) 132 | rlt = data_util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) 133 | img_LQ_l = rlt[0:-1] 134 | img_GT = rlt[-1] 135 | 136 | img_LQs = np.stack(img_LQ_l, axis=0) # stack LQ images to NHWC, N is the frame number 137 | 138 | #### BGR to RGB, HWC to CHW, numpy to tensor 139 | img_GT = img_GT[:, :, [2, 1, 0]] 140 | img_LQs = img_LQs[:, :, :, [2, 1, 0]] 141 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 142 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, (0, 3, 1, 2)))).float() 143 | 144 | return {'LQs': img_LQs, 'GT': img_GT, 'key': key} 145 | 146 | def __len__(self): 147 | return len(self.paths_GT) 148 | 149 | 150 | class Vimeo90kAllPairDataset(data.Dataset): 151 | ''' 152 | Reading the training Vimeo90K dataset: LQ is precomputed from GT 153 | key example: 00001_0001 (_1, ..., _7) 154 | ''' 155 | 156 
| def __init__(self, opt): 157 | super(Vimeo90kAllPairDataset, self).__init__() 158 | self.opt = opt 159 | # temporal augmentation 160 | self.interval_list = opt['interval_list'] 161 | self.random_reverse = opt['random_reverse'] 162 | logger.info( 163 | 'Temporal augmentation interval list: [{}], with random reverse is {}.' 164 | .format(','.join(str(x) for x in opt['interval_list']), self.random_reverse) 165 | ) 166 | 167 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] 168 | self.data_type = self.opt['data_type'] 169 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs 170 | 171 | #### determine the LQ frame list 172 | ''' 173 | N | frames 174 | 1 | 4 175 | 3 | 3,4,5 176 | 5 | 2,3,4,5,6 177 | 7 | 1,2,3,4,5,6,7 178 | ''' 179 | self.frame_list = [] 180 | self.center = opt['N_frames'] // 2 181 | for i in range(opt['N_frames']): 182 | self.frame_list.append(i + (9 - opt['N_frames']) // 2) 183 | 184 | if self.data_type == 'lmdb': 185 | self.paths_GT, _ = data_util.get_image_paths(self.data_type, opt['dataroot_GT']) 186 | logger.info('Using lmdb meta info for cache keys.') 187 | elif opt['cache_keys']: 188 | logger.info('Using cache keys: {}'.format(opt['cache_keys'])) 189 | self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb')) 190 | else: 191 | raise ValueError('Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') 192 | assert self.paths_GT, 'Error: GT path is empty.' 193 | 194 | if self.data_type == 'lmdb': 195 | self.GT_env, self.LQ_env = None, None 196 | elif self.data_type == 'img': 197 | pass 198 | else: 199 | raise ValueError('Wrong data type: {}'.format(self.data_type)) 200 | 201 | def _init_lmdb(self): 202 | # https://github.com/chainer/chainermn/issues/129 203 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, meminit=False) 204 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, meminit=False) 205 | 206 | def __getitem__(self, index): 207 | if self.data_type == 'lmdb': 208 | if (self.GT_env is None) or (self.LQ_env is None): 209 | self._init_lmdb() 210 | 211 | scale = self.opt['scale'] 212 | GT_size, LQ_size = self.opt['GT_size'], self.opt['LQ_size'] 213 | key = self.paths_GT[index] 214 | name_a, name_b = key.split('_') 215 | 216 | #### temporal augmentation: random reverse 217 | if self.random_reverse and random.random() < 0.5: 218 | self.frame_list.reverse() 219 | 220 | #### get GT image 221 | GT_size_tuple = (3, 256, 448) 222 | img_GT_l = [] 223 | for v in self.frame_list: 224 | if self.data_type == 'lmdb': 225 | img_GT = data_util.read_img(self.GT_env, key + '_{}'.format(v), GT_size_tuple) 226 | else: 227 | img_GT = data_util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im{}.png'.format(v))) 228 | img_GT_l.append(img_GT) 229 | 230 | #### get LQ images 231 | LQ_size_tuple = (3, 256 // scale, 448 // scale) if self.LR_input else (3, 256, 448) 232 | img_LQ_l = [] 233 | for v in self.frame_list: 234 | if self.data_type == 'lmdb': 235 | img_LQ = data_util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple) 236 | else: 237 | img_LQ = data_util.read_img(None, osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v))) 238 | img_LQ_l.append(img_LQ) 239 | 240 | if self.opt['phase'] == 'train': 241 | C, H, W = LQ_size_tuple # LQ size 242 | # randomly crop 243 | if self.LR_input: 244 | LQ_size = GT_size // scale 245 | rnd_h_LQ = random.randint(0, max(0, H - LQ_size)) 246 | rnd_w_LQ = random.randint(0, 
max(0, W - LQ_size)) 247 | img_LQ_l = [v[rnd_h_LQ:rnd_h_LQ + LQ_size, rnd_w_LQ:rnd_w_LQ + LQ_size, :] for v in img_LQ_l] 248 | rnd_h_GT = int(rnd_h_LQ * scale) 249 | rnd_w_GT = int(rnd_w_LQ * scale) 250 | img_GT_l = [v[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] for v in img_GT_l] 251 | else: 252 | rnd_h_LQ = random.randint(0, max(0, H - GT_size)) 253 | rnd_w_LQ = random.randint(0, max(0, W - GT_size)) 254 | img_LQ_l = [v[rnd_h_LQ:rnd_h_LQ + GT_size, rnd_w_LQ:rnd_w_LQ + GT_size, :] for v in img_LQ_l] 255 | rnd_h_GT = rnd_h_LQ 256 | rnd_w_GT = rnd_w_LQ 257 | img_GT_l = [v[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] for v in img_GT_l] 258 | 259 | # augmentation - flip, rotate 260 | rlt = data_util.augment([*img_LQ_l, *img_GT_l], self.opt['use_flip'], self.opt['use_rot']) 261 | img_LQ_l = rlt[:len(self.frame_list)] 262 | img_GT_l = rlt[len(self.frame_list):] 263 | 264 | img_LQs = np.stack(img_LQ_l, axis=0) # stack LQ images to NHWC, N is the frame number 265 | img_GTs = np.stack(img_GT_l, axis=0) # stack GT images to NHWC, N is the frame number 266 | 267 | #### BGR to RGB, HWC to CHW, numpy to tensor 268 | img_GTs = img_GTs[:, :, :, [2, 1, 0]] 269 | img_LQs = img_LQs[:, :, :, [2, 1, 0]] 270 | img_GTs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GTs, (0, 3, 1, 2)))).float() 271 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, (0, 3, 1, 2)))).float() 272 | 273 | return {'LQs': img_LQs, 'GT': img_GTs, 'key': key} 274 | 275 | def __len__(self): 276 | return len(self.paths_GT) 277 | 278 | 279 | if __name__ == '__main__': 280 | pass -------------------------------------------------------------------------------- /codes/models/VideoSR_AllPair_model_YCbCr_Combine.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import kornia 8 | from torch.nn.parallel import DataParallel, DistributedDataParallel 9 | import models.VideoSR_archs as networks 10 | import models.lr_scheduler as lr_scheduler 11 | import data.augments_video_allpair as augments 12 | from .base_model import BaseModel 13 | from models.loss import CharbonnierLoss, HuberLoss, GWLoss, PyramidLoss, LapPyrLoss 14 | from IQA_pytorch import SSIM, MS_SSIM, DISTS 15 | 16 | 17 | logger = logging.getLogger('base') 18 | 19 | 20 | class VideoSRModel(BaseModel): 21 | 22 | def __init__(self, opt): 23 | super(VideoSRModel, self).__init__(opt) 24 | 25 | if opt['dist']: 26 | self.rank = torch.distributed.get_rank() 27 | else: 28 | self.rank = -1 # non dist training 29 | train_opt = opt['train'] 30 | 31 | # define network and load pretrained models 32 | self.netG = networks.define_G(opt).to(self.device) 33 | if opt['dist']: 34 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 35 | else: 36 | self.netG = DataParallel(self.netG) 37 | # print network 38 | self.print_network() 39 | self.load() 40 | 41 | if self.is_train: 42 | self.netG.train() 43 | 44 | #### loss 45 | loss_type = train_opt['pixel_criterion'] 46 | if loss_type == 'l1': 47 | self.cri_pix = nn.L1Loss(reduction='mean').to(self.device) 48 | elif loss_type == 'l2': 49 | self.cri_pix = nn.MSELoss(reduction='mean').to(self.device) 50 | elif loss_type == 'cb': 51 | self.cri_pix = CharbonnierLoss(reduction='mean').to(self.device) 52 | elif loss_type == 'hb': 53 | self.cri_pix = HuberLoss(reduction='mean').to(self.device) 
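            # Remaining pixel_criterion options handled below: 'pyr' (Gaussian-pyramid
            # Charbonnier via PyramidLoss), 'lappyr' (Laplacian-pyramid SSIM/Charbonnier
            # via LapPyrLoss) and 'msssim' (MS-SSIM from IQA_pytorch, with the channel
            # count taken from network_G.nc).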
54 | elif loss_type == 'pyr': 55 | self.cri_pix = PyramidLoss(num_levels=3, pyr_mode='gau', loss_mode='cb', reduction='mean') 56 | elif loss_type == 'lappyr': 57 | self.cri_pix = LapPyrLoss(num_levels=3, lf_mode='ssim', hf_mode='cb', reduction='mean') 58 | elif loss_type == 'msssim': 59 | channels = opt['network_G']['nc'] 60 | self.cri_pix = MS_SSIM(channels=channels).to(self.device) 61 | else: 62 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 63 | self.l_pix_w = train_opt['pixel_weight'] 64 | 65 | if train_opt.get('edge_criterion') and train_opt.get('edge_weight'): 66 | loss_type = train_opt['edge_criterion'] 67 | if loss_type == 'l1': 68 | self.cri_edg = nn.L1Loss(reduction='mean').to(self.device) 69 | elif loss_type == 'l2': 70 | self.cri_edg = nn.MSELoss(reduction='mean').to(self.device) 71 | elif loss_type == 'cb': 72 | self.cri_edg = CharbonnierLoss(reduction='mean').to(self.device) 73 | elif loss_type == 'hb': 74 | self.cri_edg = HuberLoss(reduction='mean').to(self.device) 75 | elif loss_type == 'pyr': 76 | self.cri_edg = PyramidLoss(num_levels=3, pyr_mode='lap', loss_mode='cb', reduction='mean') 77 | elif loss_type == 'lappyr': 78 | self.cri_edg = LapPyrLoss(num_levels=3, lf_mode='ssim', hf_mode='cb', reduction='mean') 79 | elif loss_type == 'msssim': 80 | channels = opt['network_G']['nc'] 81 | self.cri_edg = MS_SSIM(channels=channels).to(self.device) 82 | else: 83 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 84 | self.l_edg_w = train_opt['edge_weight'] 85 | else: 86 | logger.info('Remove edge loss.') 87 | self.cri_edg = None 88 | 89 | # G feature loss; the criterion is selected by l_fea_type throughout 90 | if train_opt.get('feature_criterion') and train_opt.get('feature_weight'): 91 | l_fea_type = train_opt['feature_criterion'] 92 | if l_fea_type == 'l1': 93 | self.cri_fea = nn.L1Loss().to(self.device) 94 | elif l_fea_type == 'l2': 95 | self.cri_fea = nn.MSELoss().to(self.device) 96 | elif l_fea_type == 'cb': 97 | self.cri_fea = CharbonnierLoss().to(self.device) 98 | elif l_fea_type == 'hb': 99 | self.cri_fea = HuberLoss().to(self.device) 100 | else: 101 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) 102 | self.l_fea_w = train_opt['feature_weight'] 103 | else: 104 | logger.info('Remove feature loss.') 105 | self.cri_fea = None 106 | if self.cri_fea: # load VGG perceptual loss 107 | self.netF = networks.define_F(opt, use_bn=False).to(self.device) 108 | if opt['dist']: 109 | pass # do not need to use DistributedDataParallel for netF 110 | else: 111 | self.netF = DataParallel(self.netF) 112 | 113 | #### optimizers 114 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 115 | if train_opt['ft_tsa_only']: 116 | normal_params = [] 117 | tsa_fusion_params = [] 118 | for k, v in self.netG.named_parameters(): 119 | if v.requires_grad: 120 | if 'tsa_fusion' in k: 121 | tsa_fusion_params.append(v) 122 | else: 123 | normal_params.append(v) 124 | else: 125 | if self.rank <= 0: 126 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 127 | optim_params = [ 128 | { # add normal params first 129 | 'params': normal_params, 130 | 'lr': train_opt['lr_G'] 131 | }, 132 | { 133 | 'params': tsa_fusion_params, 134 | 'lr': train_opt['lr_G'] 135 | }, 136 | ] 137 | else: 138 | optim_params = [] 139 | for k, v in self.netG.named_parameters(): 140 | if v.requires_grad: 141 | optim_params.append(v) 142 | else: 143 | if self.rank <= 0: 144 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 145 | 146 |
self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 147 | weight_decay=wd_G, 148 | betas=(train_opt['beta1'], train_opt['beta2'])) 149 | self.optimizers.append(self.optimizer_G) 150 | 151 | #### schedulers 152 | if train_opt['lr_scheme'] == 'MultiStepLR_Restart': 153 | for optimizer in self.optimizers: 154 | self.schedulers.append( 155 | lr_scheduler.MultiStepLR_Restart( 156 | optimizer, train_opt['lr_steps'], 157 | restarts=train_opt['restarts'], 158 | weights=train_opt['restart_weights'], 159 | gamma=train_opt['lr_gamma'], 160 | clear_state=train_opt['clear_state'] 161 | ) 162 | ) 163 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 164 | for optimizer in self.optimizers: 165 | self.schedulers.append( 166 | lr_scheduler.CosineAnnealingLR_Restart( 167 | optimizer, train_opt['T_period'], 168 | eta_min=train_opt['eta_min'], 169 | restarts=train_opt['restarts'], 170 | weights=train_opt['restart_weights'] 171 | ) 172 | ) 173 | else: 174 | raise NotImplementedError() 175 | 176 | self.log_dict = OrderedDict() 177 | 178 | def feed_data(self, data, need_GT=True): 179 | self.var_L = data['LQs'].to(self.device) 180 | if need_GT: 181 | self.var_H = data['GT'].to(self.device) 182 | 183 | def set_params_lr_zero(self): 184 | # fix normal module 185 | self.optimizers[0].param_groups[0]['lr'] = 0 186 | 187 | def optimize_parameters(self, step): 188 | if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: 189 | self.set_params_lr_zero() 190 | 191 | # clear gradient and forward propagate 192 | self.optimizer_G.zero_grad() 193 | if self.opt['augment']: 194 | opt = self.opt['augment'] 195 | self.var_H, self.var_L = augments.apply_augment( 196 | self.var_H, self.var_L, 197 | opt['augs'], opt['probs'], opt['alphas'], opt['mix_p'] 198 | ) 199 | self.fake_H = self.netG(self.var_L) 200 | else: 201 | self.fake_H = self.netG(self.var_L) 202 | 203 | # calculate loss and back propagate 204 | center_idx = self.var_L.size(1) // 2 205 | l_tot = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H[:, center_idx, :, :, :].contiguous()) 206 | if self.cri_edg: 207 | l_edg = self.l_edg_w * self.cri_edg(self.fake_H, self.var_H[:, center_idx, :, :, :].contiguous()) 208 | l_tot += l_edg 209 | if self.cri_fea: 210 | real_fea = self.netF(self.var_H[:, center_idx, :, :, :]).detach() 211 | fake_fea = self.netF(self.fake_H) 212 | l_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) 213 | l_tot += l_fea 214 | l_tot.backward() 215 | self.optimizer_G.step() 216 | # set log 217 | self.log_dict['l_tot'] = l_tot.item() 218 | if self.cri_fea: 219 | self.log_dict['l_fea'] = l_fea.item() 220 | if self.cri_edg: 221 | self.log_dict['l_edg'] = l_edg.item() 222 | 223 | def test(self): 224 | self.netG.eval() 225 | with torch.no_grad(): 226 | self.fake_H = self.netG(self.var_L) 227 | self.netG.train() 228 | 229 | def get_current_log(self): 230 | return self.log_dict 231 | 232 | def get_current_visuals(self, need_GT=True): 233 | out_dict = OrderedDict() 234 | out_dict['LQs'] = self.var_L.detach()[0].float().cpu() 235 | out_dict['HQ'] = self.fake_H.detach()[0].float().cpu() 236 | if need_GT: 237 | out_dict['GT'] = self.var_H.detach()[0].float().cpu() 238 | return out_dict 239 | 240 | def print_network(self): 241 | s, n = self.get_network_description(self.netG) 242 | if isinstance(self.netG, nn.DataParallel): 243 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 244 | self.netG.module.__class__.__name__) 245 | else: 246 | net_struc_str = 
'{}'.format(self.netG.__class__.__name__) 247 | if self.rank <= 0: 248 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 249 | logger.info(s) 250 | 251 | def load(self): 252 | load_path_G = self.opt['path']['pretrain_model_G'] 253 | if load_path_G is not None: 254 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 255 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 256 | 257 | def load_separately(self): 258 | load_path_G_a = self.opt['path']['pretrain_model_G_a'] 259 | load_path_G_b = self.opt['path']['pretrain_model_G_b'] 260 | name_a = self.opt['path']['name_a'] 261 | name_b = self.opt['path']['name_b'] 262 | if load_path_G_a is not None and load_path_G_b is not None: 263 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G_a)) 264 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G_b)) 265 | self.load_network_separately(load_path_G_a, load_path_G_b, name_a, name_b, 266 | self.netG, self.opt['path']['strict_load']) 267 | 268 | def save(self, iter_step): 269 | self.save_network(self.netG, 'G', iter_step) 270 | --------------------------------------------------------------------------------
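The two training modes above differ only in how the pixel loss treats the YCbCr channels: the Combine model (this file) applies a single criterion to all three channels of the centre output frame, while the Split model applies separate criteria to the luma channel (Y, channel 0) and the chroma channels (Cb/Cr, channels 1:3) and sums the two terms, as in the l_pix = l_pix_y + l_pix_c step shown earlier. Below is a minimal, self-contained sketch of that split computation; it is not code from this repository, the per-term weights w_y and w_c and the L1 criteria are illustrative assumptions, and the tensor shapes follow the conventions of the models here (output B x 3 x H x W, ground truth B x N x 3 x H x W).

import torch
import torch.nn as nn

# Illustrative sketch (not from the RealVSR codebase): split YCbCr pixel loss.
def split_ycbcr_pixel_loss(fake_H, var_H, cri_pix_y, cri_pix_c, w_y=1.0, w_c=1.0):
    # fake_H: (B, 3, H, W) network output in YCbCr
    # var_H:  (B, N, 3, H, W) ground-truth clip; the centre frame is supervised
    center_idx = var_H.size(1) // 2
    gt = var_H[:, center_idx].contiguous()                   # (B, 3, H, W)
    l_pix_y = w_y * cri_pix_y(fake_H[:, 0:1], gt[:, 0:1])    # luma (Y)
    l_pix_c = w_c * cri_pix_c(fake_H[:, 1:3], gt[:, 1:3])    # chroma (Cb, Cr)
    return l_pix_y + l_pix_c, l_pix_y, l_pix_c

# Usage with hypothetical shapes: 5-frame clips, 64x64 patches.
if __name__ == '__main__':
    fake_H = torch.rand(2, 3, 64, 64)
    var_H = torch.rand(2, 5, 3, 64, 64)
    l_pix, l_y, l_c = split_ycbcr_pixel_loss(fake_H, var_H, nn.L1Loss(), nn.L1Loss())
    print(l_pix.item(), l_y.item(), l_c.item())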