├── code
│   ├── utils
│   │   ├── __init__.py
│   │   └── util.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── train
│   │   │   ├── train_LF-VSN_3video.yml
│   │   │   ├── train_LF-VSN_4video.yml
│   │   │   ├── train_LF-VSN_5video.yml
│   │   │   ├── train_LF-VSN_6video.yml
│   │   │   ├── train_LF-VSN_7video.yml
│   │   │   ├── train_LF-VSN_1video.yml
│   │   │   └── train_LF-VSN_2video.yml
│   │   └── options.py
│   ├── models
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── Quantization.py
│   │   │   ├── common.py
│   │   │   ├── loss.py
│   │   │   ├── Subnet_constructor.py
│   │   │   ├── module_util.py
│   │   │   └── Inv_arch.py
│   │   ├── __init__.py
│   │   ├── networks.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   ├── discrim.py
│   │   └── LFVSN.py
│   ├── data
│   │   ├── video_test_dataset.py
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── Vimeo90K_dataset.py
│   │   └── util.py
│   ├── test.py
│   └── train.py
├── assets
│   ├── overview.PNG
│   └── performance.PNG
└── README.md
/code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/overview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/LF-VSN/HEAD/assets/overview.PNG -------------------------------------------------------------------------------- /assets/performance.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MC-E/LF-VSN/HEAD/assets/performance.PNG -------------------------------------------------------------------------------- /code/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | def create_model(opt): 5 | model = opt['model'] 6 | from .LFVSN import Model_VSN as M 7 | 8 | m = M(opt) 9 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 10 | return m -------------------------------------------------------------------------------- /code/models/modules/Quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Quant(torch.autograd.Function): 5 | 6 | @staticmethod 7 | def forward(ctx, input): 8 | input = torch.clamp(input, 0, 1) 9 | output = (input * 255.).round() / 255.
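        # Rounding to multiples of 1/255 simulates saving the network output as 8-bit images/video.
        # round() has zero gradient almost everywhere, so backward() below simply passes the incoming
        # gradient through unchanged (a straight-through estimator).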
10 | return output 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | return grad_output 15 | 16 | class Quantization(nn.Module): 17 | def __init__(self): 18 | super(Quantization, self).__init__() 19 | 20 | def forward(self, input): 21 | return Quant.apply(input) 22 | -------------------------------------------------------------------------------- /code/models/networks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from models.modules.Inv_arch import * 5 | from models.modules.Subnet_constructor import subnet 6 | 7 | logger = logging.getLogger('base') 8 | 9 | #################### 10 | # define network 11 | #################### 12 | def define_G_v2(opt): 13 | opt_net = opt['network_G'] 14 | which_model = opt_net['which_model_G'] 15 | subnet_type = which_model['subnet_type'] 16 | opt_datasets = opt['datasets'] 17 | down_num = int(math.log(opt_net['scale'], 2)) 18 | if opt['num_video'] == 1: 19 | netG = VSN(opt, subnet(subnet_type, 'xavier'), subnet(subnet_type, 'xavier'), down_num) 20 | else: 21 | netG = VSN(opt, subnet(subnet_type, 'xavier'), subnet(subnet_type, 'xavier_v2'), down_num) 22 | 23 | return netG 24 | -------------------------------------------------------------------------------- /code/data/video_test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | import torch.utils.data as data 5 | import data.util as util 6 | 7 | class VideoTestDataset(data.Dataset): 8 | """ 9 | A video test dataset. Support: 10 | Vid4 11 | REDS4 12 | Vimeo90K-Test 13 | 14 | no need to prepare LMDB files 15 | """ 16 | 17 | def __init__(self, opt): 18 | super(VideoTestDataset, self).__init__() 19 | self.opt = opt 20 | self.half_N_frames = opt['N_frames'] // 2 21 | self.data_path = opt['data_path'] 22 | self.txt_path = self.opt['txt_path'] 23 | self.num_video = self.opt['num_video'] 24 | with open(self.txt_path) as f: 25 | self.list_video = f.readlines() 26 | self.list_video = [line.strip('\n') for line in self.list_video] 27 | self.list_video.sort() 28 | self.list_video = self.list_video[:200] 29 | l = len(self.list_video) // (self.num_video + 1) 30 | self.video_list_gt = self.list_video[:l] 31 | self.video_list_lq = self.list_video[l:l * (self.num_video + 1)] 32 | 33 | def __getitem__(self, index): 34 | path_GT = self.video_list_gt[index] 35 | 36 | img_GT = util.read_img_seq(os.path.join(self.data_path, path_GT)) 37 | list_h = [] 38 | for i in range(self.num_video): 39 | path_LQ = self.video_list_lq[index*self.num_video+i] 40 | imgs_LQ = util.read_img_seq(os.path.join(self.data_path, path_LQ)) 41 | list_h.append(imgs_LQ) 42 | list_h = torch.stack(list_h, dim=0) 43 | return { 44 | 'LQ': list_h, 45 | 'GT': img_GT 46 | } 47 | 48 | def __len__(self): 49 | return len(self.video_list_gt) 50 | -------------------------------------------------------------------------------- /code/data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 7 | phase = dataset_opt['phase'] 8 | if phase == 'train': 9 | if opt['dist']: 10 | world_size = torch.distributed.get_world_size() 11 | num_workers = dataset_opt['n_workers'] 12 | assert dataset_opt['batch_size'] % world_size == 0 13 | batch_size = 
dataset_opt['batch_size'] // world_size 14 | shuffle = False 15 | else: 16 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 17 | batch_size = dataset_opt['batch_size'] 18 | shuffle = True 19 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 20 | num_workers=num_workers, sampler=sampler, drop_last=True, 21 | pin_memory=False) 22 | else: 23 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 24 | pin_memory=True) 25 | 26 | 27 | def create_dataset(dataset_opt): 28 | mode = dataset_opt['mode'] 29 | if mode == 'test': 30 | from data.video_test_dataset import VideoTestDataset as D 31 | elif mode == 'train': 32 | from data.Vimeo90K_dataset import Vimeo90KDataset as D 33 | else: 34 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 35 | print(mode) 36 | dataset = D(dataset_opt) 37 | 38 | logger = logging.getLogger('base') 39 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 40 | dataset_opt['name'])) 41 | return dataset 42 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_3video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_3video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 3 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 3 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 3 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | N_frames: 1 44 | padding: 'new_info' 45 | pred_interval: -1 46 | 47 | 48 | #### network structures 49 | 50 | network_G: 51 | which_model_G: 52 | subnet_type: DBNet 53 | in_nc: 12 54 | out_nc: 12 55 | block_num: [8, 8] 56 | scale: 2 57 | init: xavier_group 58 | block_num_rbm: 8 59 | 60 | 61 | #### path 62 | 63 | path: 64 | pretrain_model_G: 65 | models: ckp/base 66 | strict_load: true 67 | resume_state: ~ 68 | 69 | 70 | #### training settings: learning rate scheme, loss 71 | 72 | train: 73 | 74 | lr_G: !!float 1e-4 75 | beta1: 0.9 76 | beta2: 0.5 77 | niter: 250000 78 | warmup_iter: -1 # no warm up 79 | 80 | lr_scheme: MultiStepLR 81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 82 | lr_gamma: 0.5 83 | 84 | pixel_criterion_forw: l2 85 | pixel_criterion_back: l1 86 | 87 | manual_seed: 10 88 | 89 | val_freq: !!float 1000 #!!float 5e3 90 | 91 | lambda_fit_forw: 64. 
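  # The loss-weight names below are not documented in this file; judging from models/modules/loss.py and the
  # forward/backward naming, lambda_fit_forw appears to weight the forward (hiding) l2 loss on the container
  # frames, lambda_rec_back the backward (recovery) l1 loss on the secret frames, and lambda_center: 0
  # disables the extra 'center' loss type defined in loss.py.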
92 | lambda_rec_back: 1 93 | lambda_center: 0 94 | 95 | weight_decay_G: !!float 1e-12 96 | gradient_clipping: 10 97 | 98 | 99 | #### logger 100 | 101 | logger: 102 | print_freq: 100 103 | save_checkpoint_freq: !!float 5e3 104 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_4video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_4video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 4 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 4 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 4 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | N_frames: 1 44 | padding: 'new_info' 45 | pred_interval: -1 46 | 47 | 48 | #### network structures 49 | 50 | network_G: 51 | which_model_G: 52 | subnet_type: DBNet 53 | in_nc: 12 54 | out_nc: 12 55 | block_num: [8, 8] 56 | scale: 2 57 | init: xavier_group 58 | block_num_rbm: 8 59 | 60 | 61 | #### path 62 | 63 | path: 64 | pretrain_model_G: 65 | models: ckp/base 66 | strict_load: true 67 | resume_state: ~ 68 | 69 | 70 | #### training settings: learning rate scheme, loss 71 | 72 | train: 73 | 74 | lr_G: !!float 1e-4 75 | beta1: 0.9 76 | beta2: 0.5 77 | niter: 250000 78 | warmup_iter: -1 # no warm up 79 | 80 | lr_scheme: MultiStepLR 81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 82 | lr_gamma: 0.5 83 | 84 | pixel_criterion_forw: l2 85 | pixel_criterion_back: l1 86 | 87 | manual_seed: 10 88 | 89 | val_freq: !!float 1000 #!!float 5e3 90 | 91 | lambda_fit_forw: 64. 
92 | lambda_rec_back: 1 93 | lambda_center: 0 94 | 95 | weight_decay_G: !!float 1e-12 96 | gradient_clipping: 10 97 | 98 | 99 | #### logger 100 | 101 | logger: 102 | print_freq: 100 103 | save_checkpoint_freq: !!float 5e3 104 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_5video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_5video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 5 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 5 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 5 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | N_frames: 1 44 | padding: 'new_info' 45 | pred_interval: -1 46 | 47 | 48 | #### network structures 49 | 50 | network_G: 51 | which_model_G: 52 | subnet_type: DBNet 53 | in_nc: 12 54 | out_nc: 12 55 | block_num: [8, 8] 56 | scale: 2 57 | init: xavier_group 58 | block_num_rbm: 8 59 | 60 | 61 | #### path 62 | 63 | path: 64 | pretrain_model_G: 65 | models: ckp/base 66 | strict_load: true 67 | resume_state: ~ 68 | 69 | 70 | #### training settings: learning rate scheme, loss 71 | 72 | train: 73 | 74 | lr_G: !!float 1e-4 75 | beta1: 0.9 76 | beta2: 0.5 77 | niter: 250000 78 | warmup_iter: -1 # no warm up 79 | 80 | lr_scheme: MultiStepLR 81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 82 | lr_gamma: 0.5 83 | 84 | pixel_criterion_forw: l2 85 | pixel_criterion_back: l1 86 | 87 | manual_seed: 10 88 | 89 | val_freq: !!float 1000 #!!float 5e3 90 | 91 | lambda_fit_forw: 64. 
92 | lambda_rec_back: 1 93 | lambda_center: 0 94 | 95 | weight_decay_G: !!float 1e-12 96 | gradient_clipping: 10 97 | 98 | 99 | #### logger 100 | 101 | logger: 102 | print_freq: 100 103 | save_checkpoint_freq: !!float 5e3 104 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_6video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_6video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 6 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 6 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 6 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | N_frames: 1 44 | padding: 'new_info' 45 | pred_interval: -1 46 | 47 | 48 | #### network structures 49 | 50 | network_G: 51 | which_model_G: 52 | subnet_type: DBNet 53 | in_nc: 12 54 | out_nc: 12 55 | block_num: [8, 8] 56 | scale: 2 57 | init: xavier_group 58 | block_num_rbm: 8 59 | 60 | 61 | #### path 62 | 63 | path: 64 | pretrain_model_G: 65 | models: ckp/base 66 | strict_load: true 67 | resume_state: ~ 68 | 69 | 70 | #### training settings: learning rate scheme, loss 71 | 72 | train: 73 | 74 | lr_G: !!float 1e-4 75 | beta1: 0.9 76 | beta2: 0.5 77 | niter: 250000 78 | warmup_iter: -1 # no warm up 79 | 80 | lr_scheme: MultiStepLR 81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 82 | lr_gamma: 0.5 83 | 84 | pixel_criterion_forw: l2 85 | pixel_criterion_back: l1 86 | 87 | manual_seed: 10 88 | 89 | val_freq: !!float 1000 #!!float 5e3 90 | 91 | lambda_fit_forw: 64. 
92 | lambda_rec_back: 1 93 | lambda_center: 0 94 | 95 | weight_decay_G: !!float 1e-12 96 | gradient_clipping: 10 97 | 98 | 99 | #### logger 100 | 101 | logger: 102 | print_freq: 100 103 | save_checkpoint_freq: !!float 5e3 104 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_7video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_7video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 7 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 7 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 7 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | N_frames: 1 44 | padding: 'new_info' 45 | pred_interval: -1 46 | 47 | 48 | #### network structures 49 | 50 | network_G: 51 | which_model_G: 52 | subnet_type: DBNet 53 | in_nc: 12 54 | out_nc: 12 55 | block_num: [8, 8] 56 | scale: 2 57 | init: xavier_group 58 | block_num_rbm: 8 59 | 60 | 61 | #### path 62 | 63 | path: 64 | pretrain_model_G: 65 | models: ckp/base 66 | strict_load: true 67 | resume_state: ~ 68 | 69 | 70 | #### training settings: learning rate scheme, loss 71 | 72 | train: 73 | 74 | lr_G: !!float 1e-4 75 | beta1: 0.9 76 | beta2: 0.5 77 | niter: 250000 78 | warmup_iter: -1 # no warm up 79 | 80 | lr_scheme: MultiStepLR 81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 82 | lr_gamma: 0.5 83 | 84 | pixel_criterion_forw: l2 85 | pixel_criterion_back: l1 86 | 87 | manual_seed: 10 88 | 89 | val_freq: !!float 1000 #!!float 5e3 90 | 91 | lambda_fit_forw: 64. 
92 | lambda_rec_back: 1 93 | lambda_center: 0 94 | 95 | weight_decay_G: !!float 1e-12 96 | gradient_clipping: 10 97 | 98 | 99 | #### logger 100 | 101 | logger: 102 | print_freq: 100 103 | save_checkpoint_freq: !!float 5e3 104 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_1video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_1video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 1 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 1 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 1 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | 44 | N_frames: 1 45 | padding: 'new_info' 46 | pred_interval: -1 47 | 48 | 49 | #### network structures 50 | 51 | network_G: 52 | which_model_G: 53 | subnet_type: DBNet 54 | in_nc: 12 55 | out_nc: 12 56 | block_num: [8, 8] 57 | scale: 2 58 | init: xavier_group 59 | block_num_rbm: 8 60 | 61 | 62 | #### path 63 | 64 | path: 65 | pretrain_model_G: 66 | models: ckp/base 67 | strict_load: true 68 | resume_state: ~ 69 | 70 | 71 | #### training settings: learning rate scheme, loss 72 | 73 | train: 74 | 75 | lr_G: !!float 1e-4 76 | beta1: 0.9 77 | beta2: 0.5 78 | niter: 250000 79 | warmup_iter: -1 # no warm up 80 | 81 | lr_scheme: MultiStepLR 82 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 83 | lr_gamma: 0.5 84 | 85 | pixel_criterion_forw: l2 86 | pixel_criterion_back: l1 87 | 88 | manual_seed: 10 89 | 90 | val_freq: !!float 1000 #!!float 5e3 91 | 92 | lambda_fit_forw: 64. 
93 | lambda_rec_back: 1 94 | lambda_center: 0 95 | 96 | weight_decay_G: !!float 1e-12 97 | gradient_clipping: 10 98 | 99 | 100 | #### logger 101 | 102 | logger: 103 | print_freq: 100 104 | save_checkpoint_freq: !!float 5e3 105 | -------------------------------------------------------------------------------- /code/options/train/train_LF-VSN_2video.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: train_LF-VSN_2video 4 | use_tb_logger: true 5 | model: MIMO-VRN-h 6 | distortion: sr 7 | scale: 4 8 | gpu_ids: [0, 1] 9 | gop: 3 10 | num_video: 2 11 | 12 | #### datasets 13 | 14 | datasets: 15 | train: 16 | name: Vimeo90K 17 | mode: train 18 | interval_list: [1] 19 | random_reverse: false 20 | border_mode: false 21 | data_path: vimeo90k/sequences 22 | txt_path: vimeo90k/sep_trainlist.txt 23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb 24 | cache_keys: Vimeo90K_train_keys.pkl 25 | num_video: 2 26 | 27 | N_frames: 7 28 | use_shuffle: true 29 | n_workers: 24 # per GPU 30 | batch_size: 8 31 | GT_size: 144 32 | LQ_size: 36 33 | use_flip: true 34 | use_rot: true 35 | color: RGB 36 | 37 | val: 38 | num_video: 2 39 | name: Vid4 40 | mode: test 41 | data_path: vimeo90k/sequences 42 | txt_path: vimeo90k/sep_testlist.txt 43 | 44 | N_frames: 1 45 | padding: 'new_info' 46 | pred_interval: -1 47 | 48 | 49 | #### network structures 50 | 51 | network_G: 52 | which_model_G: 53 | subnet_type: DBNet 54 | in_nc: 12 55 | out_nc: 12 56 | block_num: [8, 8] 57 | scale: 2 58 | init: xavier_group 59 | block_num_rbm: 8 60 | 61 | 62 | #### path 63 | 64 | path: 65 | pretrain_model_G: 66 | models: ckp/base 67 | strict_load: true 68 | resume_state: ~ 69 | 70 | 71 | #### training settings: learning rate scheme, loss 72 | 73 | train: 74 | 75 | lr_G: !!float 1e-4 76 | beta1: 0.9 77 | beta2: 0.5 78 | niter: 250000 79 | warmup_iter: -1 # no warm up 80 | 81 | lr_scheme: MultiStepLR 82 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000] 83 | lr_gamma: 0.5 84 | 85 | pixel_criterion_forw: l2 86 | pixel_criterion_back: l1 87 | 88 | manual_seed: 10 89 | 90 | val_freq: !!float 1000 #!!float 5e3 91 | 92 | lambda_fit_forw: 64. 
93 | lambda_rec_back: 1 94 | lambda_center: 0 95 | 96 | weight_decay_G: !!float 1e-12 97 | gradient_clipping: 10 98 | 99 | 100 | #### logger 101 | 102 | logger: 103 | print_freq: 100 104 | save_checkpoint_freq: !!float 5e3 105 | -------------------------------------------------------------------------------- /code/models/modules/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def dwt_init3d(x): 7 | 8 | x01 = x[:, :, :, 0::2, :] / 2 9 | x02 = x[:, :, :, 1::2, :] / 2 10 | x1 = x01[:, :, :, :, 0::2] 11 | x2 = x02[:, :, :, :, 0::2] 12 | x3 = x01[:, :, :, :, 1::2] 13 | x4 = x02[:, :, :, :, 1::2] 14 | x_LL = x1 + x2 + x3 + x4 15 | x_HL = -x1 - x2 + x3 + x4 16 | x_LH = -x1 + x2 - x3 + x4 17 | x_HH = x1 - x2 - x3 + x4 18 | 19 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 20 | 21 | def dwt_init(x): 22 | 23 | x01 = x[:, :, 0::2, :] / 2 24 | x02 = x[:, :, 1::2, :] / 2 25 | x1 = x01[:, :, :, 0::2] 26 | x2 = x02[:, :, :, 0::2] 27 | x3 = x01[:, :, :, 1::2] 28 | x4 = x02[:, :, :, 1::2] 29 | x_LL = x1 + x2 + x3 + x4 30 | x_HL = -x1 - x2 + x3 + x4 31 | x_LH = -x1 + x2 - x3 + x4 32 | x_HH = x1 - x2 - x3 + x4 33 | 34 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 35 | 36 | def iwt_init(x): 37 | r = 2 38 | in_batch, in_channel, in_height, in_width = x.size() 39 | #print([in_batch, in_channel, in_height, in_width]) 40 | out_batch, out_channel, out_height, out_width = in_batch, int( 41 | in_channel / (r ** 2)), r * in_height, r * in_width 42 | x1 = x[:, 0:out_channel, :, :] / 2 43 | x2 = x[:, out_channel:out_channel * 2, :, :] / 2 44 | x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 45 | x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 46 | 47 | 48 | h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda() 49 | 50 | h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 51 | h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 52 | h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 53 | h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 54 | 55 | return h 56 | 57 | class DWT(nn.Module): 58 | def __init__(self): 59 | super(DWT, self).__init__() 60 | self.requires_grad = False 61 | 62 | def forward(self, x): 63 | return dwt_init(x) 64 | 65 | class DWT3d(nn.Module): 66 | def __init__(self): 67 | super(DWT3d, self).__init__() 68 | self.requires_grad = False 69 | 70 | def forward(self, x): 71 | return dwt_init3d(x) 72 | 73 | class IWT(nn.Module): 74 | def __init__(self): 75 | super(IWT, self).__init__() 76 | self.requires_grad = False 77 | 78 | def forward(self, x): 79 | return iwt_init(x) -------------------------------------------------------------------------------- /code/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. 
note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /code/models/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ReconstructionLoss(nn.Module): 7 | def __init__(self, losstype='l2', eps=1e-6): 8 | super(ReconstructionLoss, self).__init__() 9 | self.losstype = losstype 10 | self.eps = eps 11 | 12 | def forward(self, x, target): 13 | if self.losstype == 'l2': 14 | return torch.mean(torch.sum((x - target) ** 2, (1, 2, 3))) 15 | elif self.losstype == 'l1': 16 | diff = x - target 17 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) 18 | elif self.losstype == 'center': 19 | return torch.sum((x - target) ** 2, (1, 2, 3)) 20 | 21 | else: 22 | print("reconstruction loss type error!") 23 | return 0 24 | 25 | 26 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 27 | class GANLoss(nn.Module): 28 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 29 | super(GANLoss, self).__init__() 30 | self.gan_type = gan_type.lower() 31 | self.real_label_val = real_label_val 32 | self.fake_label_val = fake_label_val 33 | 34 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 35 | self.loss = nn.BCEWithLogitsLoss() 36 | elif self.gan_type == 'lsgan': 37 | self.loss = nn.MSELoss() 38 | elif self.gan_type == 'wgan-gp': 39 | 40 | def wgan_loss(input, target): 41 | # target is boolean 42 | return -1 * input.mean() if target else input.mean() 43 | 44 | self.loss = wgan_loss 45 | else: 46 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 47 | 48 | def get_target_label(self, input, target_is_real): 49 | if self.gan_type == 'wgan-gp': 50 | return target_is_real 51 | if target_is_real: 52 | return torch.empty_like(input).fill_(self.real_label_val) 53 | else: 54 | return torch.empty_like(input).fill_(self.fake_label_val) 55 | 56 | def forward(self, input, target_is_real): 57 | 
target_label = self.get_target_label(input, target_is_real) 58 | loss = self.loss(input, target_label) 59 | return loss 60 | 61 | 62 | class GradientPenaltyLoss(nn.Module): 63 | def __init__(self, device=torch.device('cpu')): 64 | super(GradientPenaltyLoss, self).__init__() 65 | self.register_buffer('grad_outputs', torch.Tensor()) 66 | self.grad_outputs = self.grad_outputs.to(device) 67 | 68 | def get_grad_outputs(self, input): 69 | if self.grad_outputs.size() != input.size(): 70 | self.grad_outputs.resize_(input.size()).fill_(1.0) 71 | return self.grad_outputs 72 | 73 | def forward(self, interp, interp_crit): 74 | grad_outputs = self.get_grad_outputs(interp_crit) 75 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 76 | grad_outputs=grad_outputs, create_graph=True, 77 | retain_graph=True, only_inputs=True)[0] 78 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 79 | grad_interp_norm = grad_interp.norm(2, dim=1) 80 | 81 | loss = ((grad_interp_norm - 1) ** 2).mean() 82 | return loss 83 | -------------------------------------------------------------------------------- /code/models/modules/Subnet_constructor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import models.modules.module_util as mutil 5 | from basicsr.archs.arch_util import flow_warp, ResidualBlockNoBN 6 | from models.modules.module_util import initialize_weights_xavier 7 | 8 | class DenseBlock(nn.Module): 9 | def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True): 10 | super(DenseBlock, self).__init__() 11 | self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) 12 | self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) 13 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias) 14 | self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias) 15 | self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) 16 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 17 | self.H = None 18 | 19 | if init == 'xavier': 20 | mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) 21 | else: 22 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) 23 | mutil.initialize_weights(self.conv5, 0) 24 | 25 | def forward(self, x): 26 | if isinstance(x, list): 27 | x = x[0] 28 | x1 = self.lrelu(self.conv1(x)) 29 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 30 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 31 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 32 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 33 | 34 | return x5 35 | 36 | class DenseBlock_v2(nn.Module): 37 | def __init__(self, channel_in, channel_out, groups, init='xavier', gc=32, bias=True): 38 | super(DenseBlock_v2, self).__init__() 39 | self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) 40 | self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) 41 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias) 42 | self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias) 43 | self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) 44 | self.conv_final = nn.Conv2d(channel_out*groups, channel_out, 3, 1, 1, bias=bias) 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 46 | 47 | if init == 'xavier': 48 | mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 
0.1) 49 | else: 50 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 51 | mutil.initialize_weights(self.conv_final, 0) 52 | 53 | def forward(self, x): 54 | res = [] 55 | for xi in x: 56 | x1 = self.lrelu(self.conv1(xi)) 57 | x2 = self.lrelu(self.conv2(torch.cat((xi, x1), 1))) 58 | x3 = self.lrelu(self.conv3(torch.cat((xi, x1, x2), 1))) 59 | x4 = self.lrelu(self.conv4(torch.cat((xi, x1, x2, x3), 1))) 60 | x5 = self.lrelu(self.conv5(torch.cat((xi, x1, x2, x3, x4), 1))) 61 | res.append(x5) 62 | res = torch.cat(res, dim=1) 63 | res = self.conv_final(res) 64 | 65 | return res 66 | 67 | def subnet(net_structure, init='xavier'): 68 | def constructor(channel_in, channel_out, groups=None): 69 | if net_structure == 'DBNet': 70 | if init == 'xavier': 71 | return DenseBlock(channel_in, channel_out, init) 72 | elif init == 'xavier_v2': 73 | return DenseBlock_v2(channel_in, channel_out, groups, 'xavier') 74 | else: 75 | return DenseBlock(channel_in, channel_out) 76 | else: 77 | return None 78 | 79 | return constructor 80 | -------------------------------------------------------------------------------- /code/models/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def initialize_weights_xavier(net_l, scale=1): 28 | if not isinstance(net_l, list): 29 | net_l = [net_l] 30 | for net in net_l: 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.xavier_normal_(m.weight) 34 | m.weight.data *= scale # for residual block 35 | if m.bias is not None: 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.Linear): 38 | init.xavier_normal_(m.weight) 39 | m.weight.data *= scale 40 | if m.bias is not None: 41 | m.bias.data.zero_() 42 | elif isinstance(m, nn.BatchNorm2d): 43 | init.constant_(m.weight, 1) 44 | init.constant_(m.bias.data, 0.0) 45 | 46 | 47 | def make_layer(block, n_layers): 48 | layers = [] 49 | for _ in range(n_layers): 50 | layers.append(block()) 51 | return nn.Sequential(*layers) 52 | 53 | 54 | class ResidualBlock_noBN(nn.Module): 55 | '''Residual block w/o BN 56 | ---Conv-ReLU-Conv-+- 57 | |________________| 58 | ''' 59 | 60 | def __init__(self, nf=64): 61 | super(ResidualBlock_noBN, self).__init__() 62 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | 65 | # initialization 66 | initialize_weights([self.conv1, self.conv2], 0.1) 67 | 68 | def forward(self, x): 69 | identity = x 70 | out = F.relu(self.conv1(x), inplace=True) 71 | out = self.conv2(out) 72 | return identity + out 73 | 74 | 75 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 76 | """Warp an image or feature map with optical flow 77 | Args: 78 | x (Tensor): size (N, C, H, W) 79 | flow 
(Tensor): size (N, H, W, 2), normal value 80 | interp_mode (str): 'nearest' or 'bilinear' 81 | padding_mode (str): 'zeros' or 'border' or 'reflection' 82 | 83 | Returns: 84 | Tensor: warped image or feature map 85 | """ 86 | assert x.size()[-2:] == flow.size()[1:3] 87 | B, C, H, W = x.size() 88 | # mesh grid 89 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 90 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 91 | grid.requires_grad = False 92 | grid = grid.type_as(x) 93 | vgrid = grid + flow 94 | # scale grid to [-1,1] 95 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 96 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 97 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 98 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 99 | return output 100 | -------------------------------------------------------------------------------- /code/data/Vimeo90K_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Vimeo90K dataset 3 | support reading images from lmdb, image folder and memcached 4 | ''' 5 | import logging 6 | import os 7 | import os.path as osp 8 | import pickle 9 | import random 10 | 11 | import cv2 12 | import lmdb 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | 17 | import data.util as util 18 | 19 | try: 20 | import mc # import memcached 21 | except ImportError: 22 | pass 23 | logger = logging.getLogger('base') 24 | 25 | class Vimeo90KDataset(data.Dataset): 26 | ''' 27 | Reading the training Vimeo90K dataset 28 | key example: 00001_0001 (_1, ..., _7) 29 | GT (Ground-Truth): 4th frame; 30 | LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame 31 | ''' 32 | 33 | def __init__(self, opt): 34 | super(Vimeo90KDataset, self).__init__() 35 | self.opt = opt 36 | # get train indexes 37 | self.data_path = self.opt['data_path'] 38 | self.txt_path = self.opt['txt_path'] 39 | with open(self.txt_path) as f: 40 | self.list_video = f.readlines() 41 | self.list_video = [line.strip('\n') for line in self.list_video] 42 | # temporal augmentation 43 | self.interval_list = opt['interval_list'] 44 | self.random_reverse = opt['random_reverse'] 45 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( 46 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) 47 | self.data_type = self.opt['data_type'] 48 | random.shuffle(self.list_video) 49 | self.LR_input = True 50 | self.num_video = self.opt['num_video'] 51 | 52 | def _ensure_memcached(self): 53 | if self.mclient is None: 54 | # specify the config files 55 | server_list_config_file = None 56 | client_config_file = None 57 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, 58 | client_config_file) 59 | 60 | def __getitem__(self, index): 61 | GT_size = self.opt['GT_size'] 62 | video_name = self.list_video[index] 63 | path_frame = os.path.join(self.data_path, video_name) 64 | frames = [] 65 | for im_name in os.listdir(path_frame): 66 | if im_name.endswith('.png'): 67 | frames.append(util.read_img(None, osp.join(path_frame, im_name))) 68 | list_index_h = [] 69 | index_h = random.randint(0, len(self.list_video) - 1) 70 | list_index_h.append(index_h) 71 | for _ in range(self.num_video-1): 72 | index_h_i = random.randint(0, len(self.list_video) - 1) 73 | while index_h_i == index or index_h_i in list_index_h: 74 | index_h_i = random.randint(0, len(self.list_video) - 
1) 75 | list_index_h.append(index_h_i) 76 | 77 | # random crop 78 | H, W, C = frames[0].shape 79 | rnd_h = random.randint(0, max(0, H - GT_size)) 80 | rnd_w = random.randint(0, max(0, W - GT_size)) 81 | frames = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in frames] 82 | # stack HQ images to NHWC, N is the frame number 83 | img_frames = np.stack(frames, axis=0) 84 | # BGR to RGB, HWC to CHW, numpy to tensor 85 | img_frames = img_frames[:, :, :, [2, 1, 0]] 86 | img_frames = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames, (0, 3, 1, 2)))).float() 87 | # process h_list 88 | list_h = [] 89 | for index_h_i in list_index_h: 90 | video_name_h = self.list_video[index_h_i] 91 | path_frame_h = os.path.join(self.data_path, video_name_h) 92 | frames_h = [] 93 | for im_name in os.listdir(path_frame_h): 94 | if im_name.endswith('.png'): 95 | frames_h.append(util.read_img(None, osp.join(path_frame_h, im_name))) 96 | frames_h = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in frames_h] 97 | img_frames_h = np.stack(frames_h, axis=0) 98 | img_frames_h = img_frames_h[:, :, :, [2, 1, 0]] 99 | img_frames_h = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames_h, (0, 3, 1, 2)))).float() 100 | list_h.append(img_frames_h.clone()) 101 | list_h = torch.stack(list_h, dim=0) 102 | 103 | return {'GT': img_frames, 'LQ': list_h} 104 | 105 | def __len__(self): 106 | return len(self.list_video) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Large-capacity and Flexible Video Steganography via Invertible Neural Network (CVPR 2023) 2 | [Chong Mou](https://scholar.google.com.hk/citations?user=SYQoDk0AAAAJ&hl=en), Youmin Xu, Jiechong Song, [Chen Zhao](https://scholar.google.com/citations?hl=zh-CN&user=dUWdX5EAAAAJ), [Bernard Ghanem](https://www.bernardghanem.com/), [Jian Zhang](https://jianzhang.tech/) 3 | 4 | Official implementation of **[Large-capacity and Flexible Video Steganography via Invertible Neural Network](https://arxiv.org/abs/2304.12300)**. 5 | 6 | ## **Introduction** 7 |

8 | 9 |

10 | 11 | 12 | Video steganography is the art of unobtrusively concealing secret data in a cover video and then recovering the secret data through a decoding protocol at the receiver end. Although several attempts have been made, most of them are limited to low-capacity and fixed steganography. To rectify these weaknesses, we propose a **L**arge-capacity and **F**lexible **V**ideo **S**teganography **N**etwork (**LF-VSN**) in this paper. For large-capacity, we present a reversible pipeline to perform multiple videos hiding and recovering through a single invertible neural network (INN). Our method can **hide/recover 7 secret videos in/from 1 cover video** with promising performance. For flexibility, we propose a key-controllable scheme, enabling different receivers to recover particular secret videos from the same cover video through specific keys. Moreover, we further improve the flexibility by proposing a scalable strategy in multiple videos hiding, which can hide variable numbers of secret videos in a cover video with a single model and a single training session. Extensive experiments demonstrate that with the significant improvement of the video steganography performance, our proposed LF-VSN has high security, large hiding capacity, and flexibility. 13 | 14 | --- 15 | 16 | 20 | 21 | ## 🔧 **Dependencies and Installation** 22 | - Python 3.6 23 | - PyTorch >= 1.4.0 24 | - numpy 25 | - skimage 26 | - cv2 27 | 28 | ## ⏬ **Download Models** 29 | 30 | The pre-trained models are available at: 31 | 32 | 33 | | Mode | Download link | 34 | | :------------------- | :--------------------------------------------: | 35 | | One video hiding | [Google Drive](https://drive.google.com/file/d/1aEMZaigkMd2NUNXnOu2r0oa5IuLPCtTh/view?usp=share_link) | 36 | | Two video hiding | [Google Drive](https://drive.google.com/file/d/1Yd7tK9Y-J4fkXoL-5u8VifEVsW7OmZN0/view?usp=share_link) | 37 | | Three video hiding | [Google Drive](https://drive.google.com/file/d/1oeDDzkYMZ6tKpPnIUwSI2v_Rbn7vLQJo/view?usp=share_link) | 38 | | Four video hiding | [Google Drive](https://drive.google.com/file/d/1kyMKdfAG_gq6ArWChv6ZMLBsqT-QpS9j/view?usp=share_link) | 39 | | Five video hiding | [Google Drive](https://drive.google.com/file/d/1OlTL6_ZgsThPeYfxbpGrGvNoaisqThq2/view?usp=share_link) | 40 | | Six video hiding | [Google Drive](https://drive.google.com/file/d/1dr-ZIL-VP0ol4fRO7bGZYQoxRetA-GXW/view?usp=share_link) | 41 | | Seven video hiding | [Google Drive](https://drive.google.com/file/d/178cqpz_vS-mPlYwLuZP2qFc7pV7vrXrr/view?usp=share_link) | 42 | 43 | ## **Data Preparing** 44 | Please download the training and evaluation dataset from [Vimeo-90K](http://toflow.csail.mit.edu/). 45 | 46 | ## **Train** 47 | 48 | Training the desired model by changing the config file. 49 | 50 | ```bash 51 | python train.py -opt options/train/train_LF-VSN_1video.yml 52 | ``` 53 | 54 | ## **Test** 55 | 56 | Testing the desired model by changing the config file. 57 | 58 | ```bash 59 | python test.py -opt options/train/train_LF-VSN_1video.yml 60 | ``` 61 | 62 | ## **Qualitative Results** 63 |

64 | 65 |

66 | 67 | ## 🤗 **Acknowledgements** 68 | This code is built on [MIMO-VRN (PyTorch)](https://github.com/ding3820/MIMO-VRN). We thank the authors for sharing the code of MIMO-VRN. 69 | 70 | ## :e-mail: Contact 71 | 72 | If you have any questions, please email `eechongm@gmail.com`. 73 | 74 | ## **Citation** 75 | 76 | If you find our work helpful in your research or work, please cite the following paper. 77 | ``` 78 | @inproceedings{mou2023lfvsn, 79 | title={Large-capacity and Flexible Video Steganography via Invertible Neural Network}, 80 | author={Chong Mou and Youmin Xu and Jiechong Song and Chen Zhao and Bernard Ghanem and Jian Zhang}, 81 | booktitle={CVPR}, 82 | year={2023} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /code/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | # if dataset.get('dataroot_GT_bg', None) is not None: 33 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) 34 | if dataset.get('dataroot_LQ', None) is not None: 35 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 36 | if dataset['dataroot_LQ'].endswith('lmdb'): 37 | is_lmdb = True 38 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 39 | if dataset['mode'].endswith('mc'): # for memcached 40 | dataset['data_type'] = 'mc' 41 | dataset['mode'] = dataset['mode'].replace('_mc', '') 42 | 43 | # path 44 | for key, path in opt['path'].items(): 45 | if path and key in opt['path'] and key != 'strict_load': 46 | opt['path'][key] = osp.expanduser(path) 47 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 48 | if is_train: 49 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 50 | opt['path']['experiments_root'] = experiments_root 51 | opt['path']['models'] = osp.join(experiments_root, 'models') 52 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 53 | opt['path']['log'] = experiments_root 54 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 55 | 56 | # change some options for debug mode 57 | if 'debug' in opt['name']: 58 | opt['train']['val_freq'] = 8 59 | opt['logger']['print_freq'] = 1 60 | opt['logger']['save_checkpoint_freq'] = 8 61 | else: # test 62 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 63 | opt['path']['results_root'] = results_root 64 | opt['path']['log'] = results_root 65 | 66 | # network 67 | if opt['distortion'] == 'sr': 68 | opt['network_G']['scale'] = scale 69 |
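    # Note: opt['dist'], which data.create_dataloader() checks, is never set in this function;
    # it has to be added to opt by the caller before the dataloaders are built.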
70 | return opt 71 | 72 | 73 | def dict2str(opt, indent_l=1): 74 | '''dict to string for logger''' 75 | msg = '' 76 | for k, v in opt.items(): 77 | if isinstance(v, dict): 78 | msg += ' ' * (indent_l * 2) + k + ':[\n' 79 | msg += dict2str(v, indent_l + 1) 80 | msg += ' ' * (indent_l * 2) + ']\n' 81 | else: 82 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 83 | return msg 84 | 85 | 86 | class NoneDict(dict): 87 | def __missing__(self, key): 88 | return None 89 | 90 | 91 | # convert to NoneDict, which return None for missing key. 92 | def dict_to_nonedict(opt): 93 | if isinstance(opt, dict): 94 | new_opt = dict() 95 | for key, sub_opt in opt.items(): 96 | new_opt[key] = dict_to_nonedict(sub_opt) 97 | return NoneDict(**new_opt) 98 | elif isinstance(opt, list): 99 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 100 | else: 101 | return opt 102 | 103 | 104 | def check_resume(opt, resume_iter): 105 | '''Check resume states and pretrain_model paths''' 106 | logger = logging.getLogger('base') 107 | if opt['path']['resume_state']: 108 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 109 | 'pretrain_model_D', None) is not None: 110 | logger.warning('pretrain_model path will be ignored when resuming training.') 111 | 112 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 113 | '{}_G.pth'.format(resume_iter)) 114 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 115 | -------------------------------------------------------------------------------- /code/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | ''' set learning rate for warmup, 39 | lr_groups_l: list for lr_groups. 
each for a optimizer''' 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | # get the initial lr, which is set by the scheduler 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | #### set up warm up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | # return self.schedulers[0].get_lr()[0] 67 | return self.optimizers[0].param_groups[0]['lr'] 68 | 69 | def get_network_description(self, network): 70 | '''Get the string and total parameters of the network''' 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | s = str(network) 74 | n = sum(map(lambda x: x.numel(), network.parameters())) 75 | return s, n 76 | 77 | def save_network(self, network, network_label, iter_label): 78 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 79 | save_path = os.path.join(self.opt['path']['models'], save_filename) 80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 81 | network = network.module 82 | state_dict = network.state_dict() 83 | for key, param in state_dict.items(): 84 | state_dict[key] = param.cpu() 85 | torch.save(state_dict, save_path) 86 | 87 | def load_network(self, load_path, network, strict=True): 88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 89 | network = network.module 90 | load_net = torch.load(load_path) 91 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
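        # Checkpoints saved from an nn.DataParallel / DistributedDataParallel wrapper prefix every
        # parameter key with 'module.'; the loop below strips that prefix so the unwrapped network
        # can load the same weights.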
92 | for k, v in load_net.items(): 93 | if k.startswith('module.'): 94 | load_net_clean[k[7:]] = v 95 | else: 96 | load_net_clean[k] = v 97 | network.load_state_dict(load_net_clean, strict=strict) 98 | 99 | def save_training_state(self, epoch, iter_step): 100 | '''Saves training state during training, which will be used for resuming''' 101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 102 | for s in self.schedulers: 103 | state['schedulers'].append(s.state_dict()) 104 | for o in self.optimizers: 105 | state['optimizers'].append(o.state_dict()) 106 | save_filename = '{}.state'.format(iter_step) 107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 108 | torch.save(state, save_path) 109 | 110 | def resume_training(self, resume_state): 111 | '''Resume the optimizers and schedulers for training''' 112 | resume_optimizers = resume_state['optimizers'] 113 | resume_schedulers = resume_state['schedulers'] 114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 116 | for i, o in enumerate(resume_optimizers): 117 | self.optimizers[i].load_state_dict(o) 118 | for i, s in enumerate(resume_schedulers): 119 | self.schedulers[i].load_state_dict(s) 120 | -------------------------------------------------------------------------------- /code/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restart_weights = weights if weights else [1] 16 | assert len(self.restarts) == len( 17 | self.restart_weights), 'restarts and their weights do not match.' 18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | if self.last_epoch in self.restarts: 22 | if self.clear_state: 23 | self.optimizer.state = defaultdict(dict) 24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 26 | if self.last_epoch not in self.milestones: 27 | return [group['lr'] for group in self.optimizer.param_groups] 28 | return [ 29 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 30 | for group in self.optimizer.param_groups 31 | ] 32 | 33 | 34 | class CosineAnnealingLR_Restart(_LRScheduler): 35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 36 | self.T_period = T_period 37 | self.T_max = self.T_period[0] # current T period 38 | self.eta_min = eta_min 39 | self.restarts = restarts if restarts else [0] 40 | self.restart_weights = weights if weights else [1] 41 | self.last_restart = 0 42 | assert len(self.restarts) == len( 43 | self.restart_weights), 'restarts and their weights do not match.' 
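        # Example (taken from the __main__ demo at the bottom of this file): T_period=[250000]*4 with
        # restarts=[250000, 500000, 750000] and weights [1, 1, 1] runs four cosine cycles of 250k
        # iterations each, resetting the learning rate to its full initial value at the start of each cycle.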
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if self.last_epoch == 0: 48 | return self.base_lrs 49 | elif self.last_epoch in self.restarts: 50 | self.last_restart = self.last_epoch 51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 55 | return [ 56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 61 | (group['lr'] - self.eta_min) + self.eta_min 62 | for group in self.optimizer.param_groups] 63 | 64 | 65 | if __name__ == "__main__": 66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 67 | betas=(0.9, 0.99)) 68 | ############################## 69 | # MultiStepLR_Restart 70 | ############################## 71 | ## Original 72 | lr_steps = [200000, 400000, 600000, 800000] 73 | restarts = None 74 | restart_weights = None 75 | 76 | ## two 77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 78 | restarts = [500000] 79 | restart_weights = [1] 80 | 81 | ## four 82 | lr_steps = [ 83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 85 | ] 86 | restarts = [250000, 500000, 750000] 87 | restart_weights = [1, 1, 1] 88 | 89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 90 | clear_state=False) 91 | 92 | ############################## 93 | # Cosine Annealing Restart 94 | ############################## 95 | ## two 96 | T_period = [500000, 500000] 97 | restarts = [500000] 98 | restart_weights = [1] 99 | 100 | ## four 101 | T_period = [250000, 250000, 250000, 250000] 102 | restarts = [250000, 500000, 750000] 103 | restart_weights = [1, 1, 1] 104 | 105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 106 | weights=restart_weights) 107 | 108 | ############################## 109 | # Draw figure 110 | ############################## 111 | N_iter = 1000000 112 | lr_l = list(range(N_iter)) 113 | for i in range(N_iter): 114 | scheduler.step() 115 | current_lr = optimizer.param_groups[0]['lr'] 116 | lr_l[i] = current_lr 117 | 118 | import matplotlib as mpl 119 | from matplotlib import pyplot as plt 120 | import matplotlib.ticker as mtick 121 | mpl.style.use('default') 122 | import seaborn 123 | seaborn.set(style='whitegrid') 124 | seaborn.set_context('paper') 125 | 126 | plt.figure(1) 127 | plt.subplot(111) 128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 129 | plt.title('Title', fontsize=16, color='k') 130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 131 | legend = plt.legend(loc='upper right', shadow=False) 132 | ax = plt.gca() 133 | labels = ax.get_xticks().tolist() 134 | for k, v in enumerate(labels): 135 | labels[k] = str(int(v / 1000)) + 'K' 136 | ax.set_xticklabels(labels) 137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 138 | 139 | 
ax.set_ylabel('Learning rate') 140 | ax.set_xlabel('Iteration') 141 | fig = plt.gcf() 142 | plt.show() 143 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import logging 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | from data.data_sampler import DistIterSampler 11 | 12 | import options.options as option 13 | from utils import util 14 | from data import create_dataloader, create_dataset 15 | from models import create_model 16 | 17 | 18 | def init_dist(backend='nccl', **kwargs): 19 | ''' initialization for distributed training''' 20 | # if mp.get_start_method(allow_none=True) is None: 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn') 23 | rank = int(os.environ['RANK']) 24 | num_gpus = torch.cuda.device_count() 25 | torch.cuda.set_device(rank % num_gpus) 26 | dist.init_process_group(backend=backend, **kwargs) 27 | 28 | def cal_pnsr(sr_img, gt_img): 29 | # calculate PSNR 30 | gt_img = gt_img / 255. 31 | sr_img = sr_img / 255. 32 | 33 | psnr = util.calculate_psnr(sr_img * 255, gt_img * 255) 34 | 35 | return psnr 36 | 37 | def main(): 38 | # options 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('-opt', type=str, help='Path to option YMAL file.') # config 文件 41 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 42 | help='job launcher') 43 | parser.add_argument('--ckpt', type=str, default='/group/30042/chongmou/ft_local/LF-VSN-git/LF-VSN/ckpt/LF-VSN_2video_hiding_250k.pth', help='Path to pre-trained model.') 44 | parser.add_argument('--local_rank', type=int, default=0) 45 | args = parser.parse_args() 46 | opt = option.parse(args.opt, is_train=True) 47 | 48 | # distributed training settings 49 | if args.launcher == 'none': # disabled distributed training 50 | opt['dist'] = False 51 | rank = -1 52 | print('Disabled distributed training.') 53 | else: 54 | opt['dist'] = True 55 | init_dist() 56 | world_size = torch.distributed.get_world_size() 57 | rank = torch.distributed.get_rank() 58 | 59 | # loading resume state if exists 60 | if opt['path'].get('resume_state', None): 61 | # distributed resuming: all load into default GPU 62 | device_id = torch.cuda.current_device() 63 | resume_state = torch.load(opt['path']['resume_state'], 64 | map_location=lambda storage, loc: storage.cuda(device_id)) 65 | option.check_resume(opt, resume_state['iter']) # check resume options 66 | else: 67 | resume_state = None 68 | 69 | # convert to NoneDict, which returns None for missing keys 70 | opt = option.dict_to_nonedict(opt) 71 | 72 | torch.backends.cudnn.benchmark = True 73 | # torch.backends.cudnn.deterministic = True 74 | 75 | #### create train and val dataloader 76 | dataset_ratio = 200 # enlarge the size of each epoch 77 | for phase, dataset_opt in opt['datasets'].items(): 78 | if phase == 'train': 79 | train_set = create_dataset(dataset_opt) 80 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 81 | total_iters = int(opt['train']['niter']) 82 | total_epochs = int(math.ceil(total_iters / train_size)) 83 | if opt['dist']: 84 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) 85 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) 86 | else: 87 | train_sampler = None 88 | train_loader = 
create_dataloader(train_set, dataset_opt, opt, train_sampler) 89 | elif phase == 'val': 90 | val_set = create_dataset(dataset_opt) 91 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 92 | else: 93 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 94 | assert train_loader is not None 95 | 96 | # create model 97 | model = create_model(opt) 98 | model.load_test(args.ckpt) 99 | 100 | # validation 101 | avg_psnr = 0.0 102 | avg_psnr_h = [0.0]*opt['num_video'] 103 | avg_psnr_lr = 0.0 104 | idx = 0 105 | for video_id, val_data in enumerate(val_loader): 106 | img_dir = os.path.join('results',opt['name']) 107 | util.mkdir(img_dir) 108 | 109 | model.feed_data(val_data) 110 | model.test() 111 | 112 | visuals = model.get_current_visuals() 113 | 114 | t_step = visuals['SR'].shape[0] 115 | idx += t_step 116 | n = len(visuals['SR_h']) 117 | 118 | for i in range(t_step): 119 | 120 | sr_img = util.tensor2img(visuals['SR'][i]) # uint8 121 | sr_img_h = [] 122 | for j in range(n): 123 | sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i])) # uint8 124 | gt_img = util.tensor2img(visuals['GT'][i]) # uint8 125 | lr_img = util.tensor2img(visuals['LR'][i]) 126 | lrgt_img = [] 127 | for j in range(n): 128 | lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i])) 129 | 130 | # Save SR images for reference 131 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'SR')) 132 | util.save_img(sr_img, save_img_path) 133 | 134 | for j in range(n): 135 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'SR_h')) 136 | util.save_img(sr_img_h[j], save_img_path) 137 | 138 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'GT')) 139 | util.save_img(gt_img, save_img_path) 140 | 141 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'LR')) 142 | util.save_img(lr_img, save_img_path) 143 | 144 | for j in range(n): 145 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'LRGT')) 146 | util.save_img(lrgt_img[j], save_img_path) 147 | 148 | psnr = cal_pnsr(sr_img, gt_img) 149 | psnr_h = [] 150 | for j in range(n): 151 | psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j])) 152 | psnr_lr = cal_pnsr(lr_img, gt_img) 153 | 154 | avg_psnr += psnr 155 | for j in range(n): 156 | avg_psnr_h[j] += psnr_h[j] 157 | avg_psnr_lr += psnr_lr 158 | 159 | avg_psnr = avg_psnr / idx 160 | avg_psnr_h = [psnr / idx for psnr in avg_psnr_h] 161 | avg_psnr_lr = avg_psnr_lr / idx 162 | res_psnr_h = '' 163 | for p in avg_psnr_h: 164 | res_psnr_h+=('_{:.4e}'.format(p)) 165 | print('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr)) 166 | 167 | 168 | if __name__ == '__main__': 169 | main() -------------------------------------------------------------------------------- /code/models/discrim.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | from torch.nn.utils import spectral_norm 4 | 5 | 6 | class UNetDiscriminatorSN(nn.Module): 7 | """Defines a U-Net discriminator with spectral normalization (SN) 8 | 9 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 10 | 11 | Arg: 12 | num_in_ch (int): Channel number of inputs. Default: 3. 13 | num_feat (int): Channel number of base intermediate features. Default: 64. 
14 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 15 | """ 16 | 17 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 18 | super(UNetDiscriminatorSN, self).__init__() 19 | self.skip_connection = skip_connection 20 | norm = spectral_norm 21 | # the first convolution 22 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 23 | # downsample 24 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 25 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 26 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 27 | # upsample 28 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 29 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 30 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 31 | # extra convolutions 32 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 33 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 34 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 35 | 36 | def forward(self, x): 37 | # downsample 38 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 39 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 40 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 41 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 42 | 43 | # upsample 44 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 45 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 46 | 47 | if self.skip_connection: 48 | x4 = x4 + x2 49 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 50 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 51 | 52 | if self.skip_connection: 53 | x5 = x5 + x1 54 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 55 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 56 | 57 | if self.skip_connection: 58 | x6 = x6 + x0 59 | 60 | # extra convolutions 61 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 62 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 63 | out = self.conv9(out) 64 | 65 | return out 66 | 67 | 68 | class GANLoss(nn.Module): 69 | """Define GAN loss. 70 | 71 | Args: 72 | gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. 73 | real_label_val (float): The value for real label. Default: 1.0. 74 | fake_label_val (float): The value for fake label. Default: 0.0. 75 | loss_weight (float): Loss weight. Default: 1.0. 76 | Note that loss_weight is only for generators; and it is always 1.0 77 | for discriminators. 
78 | """ 79 | 80 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): 81 | super(GANLoss, self).__init__() 82 | self.gan_type = gan_type 83 | self.loss_weight = loss_weight 84 | self.real_label_val = real_label_val 85 | self.fake_label_val = fake_label_val 86 | 87 | if self.gan_type == 'vanilla': 88 | self.loss = nn.BCEWithLogitsLoss() 89 | elif self.gan_type == 'lsgan': 90 | self.loss = nn.MSELoss() 91 | elif self.gan_type == 'wgan': 92 | self.loss = self._wgan_loss 93 | elif self.gan_type == 'wgan_softplus': 94 | self.loss = self._wgan_softplus_loss 95 | elif self.gan_type == 'hinge': 96 | self.loss = nn.ReLU() 97 | else: 98 | raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') 99 | 100 | def _wgan_loss(self, input, target): 101 | """wgan loss. 102 | 103 | Args: 104 | input (Tensor): Input tensor. 105 | target (bool): Target label. 106 | 107 | Returns: 108 | Tensor: wgan loss. 109 | """ 110 | return -input.mean() if target else input.mean() 111 | 112 | def _wgan_softplus_loss(self, input, target): 113 | """wgan loss with soft plus. softplus is a smooth approximation to the 114 | ReLU function. 115 | 116 | In StyleGAN2, it is called: 117 | Logistic loss for discriminator; 118 | Non-saturating loss for generator. 119 | 120 | Args: 121 | input (Tensor): Input tensor. 122 | target (bool): Target label. 123 | 124 | Returns: 125 | Tensor: wgan loss. 126 | """ 127 | return F.softplus(-input).mean() if target else F.softplus(input).mean() 128 | 129 | def get_target_label(self, input, target_is_real): 130 | """Get target label. 131 | 132 | Args: 133 | input (Tensor): Input tensor. 134 | target_is_real (bool): Whether the target is real or fake. 135 | 136 | Returns: 137 | (bool | Tensor): Target tensor. Return bool for wgan, otherwise, 138 | return Tensor. 139 | """ 140 | 141 | if self.gan_type in ['wgan', 'wgan_softplus']: 142 | return target_is_real 143 | target_val = (self.real_label_val if target_is_real else self.fake_label_val) 144 | return input.new_ones(input.size()) * target_val 145 | 146 | def forward(self, input, target_is_real, is_disc=False): 147 | """ 148 | Args: 149 | input (Tensor): The input for the loss module, i.e., the network 150 | prediction. 151 | target_is_real (bool): Whether the targe is real or fake. 152 | is_disc (bool): Whether the loss for discriminators or not. 153 | Default: False. 154 | 155 | Returns: 156 | Tensor: GAN loss value. 
157 | """ 158 | target_label = self.get_target_label(input, target_is_real) 159 | if self.gan_type == 'hinge': 160 | if is_disc: # for discriminators in hinge-gan 161 | input = -input if target_is_real else input 162 | loss = self.loss(1 + input).mean() 163 | else: # for generators in hinge-gan 164 | loss = -input.mean() 165 | else: # other gan types 166 | loss = self.loss(input, target_label) 167 | 168 | # loss_weight is always 1.0 for discriminators 169 | return loss if is_disc else loss * self.loss_weight -------------------------------------------------------------------------------- /code/models/modules/Inv_arch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .module_util import initialize_weights_xavier 7 | from torch.nn import init 8 | from .common import DWT,IWT 9 | import cv2 10 | from basicsr.archs.arch_util import flow_warp 11 | from models.modules.Subnet_constructor import subnet 12 | import numpy as np 13 | 14 | dwt=DWT() 15 | iwt=IWT() 16 | 17 | def thops_mean(tensor, dim=None, keepdim=False): 18 | if dim is None: 19 | # mean all dim 20 | return torch.mean(tensor) 21 | else: 22 | if isinstance(dim, int): 23 | dim = [dim] 24 | dim = sorted(dim) 25 | for d in dim: 26 | tensor = tensor.mean(dim=d, keepdim=True) 27 | if not keepdim: 28 | for i, d in enumerate(dim): 29 | tensor.squeeze_(d-i) 30 | return tensor 31 | 32 | 33 | class ResidualBlockNoBN(nn.Module): 34 | def __init__(self, nf=64, model='MIMO-VRN'): 35 | super(ResidualBlockNoBN, self).__init__() 36 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 37 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 38 | # honestly, there's no significant difference between ReLU and leaky ReLU in terms of performance here 39 | # but this is how we trained the model in the first place and what we reported in the paper 40 | if model == 'LSTM-VRN': 41 | self.relu = nn.ReLU(inplace=True) 42 | elif model == 'MIMO-VRN': 43 | self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 44 | 45 | # initialization 46 | initialize_weights_xavier([self.conv1, self.conv2], 0.1) 47 | 48 | def forward(self, x): 49 | identity = x 50 | out = self.relu(self.conv1(x)) 51 | out = self.conv2(out) 52 | return identity + out 53 | 54 | 55 | class InvBlock(nn.Module): 56 | def __init__(self, subnet_constructor, subnet_constructor_v2, channel_num_ho, channel_num_hi, groups, clamp=1.): 57 | super(InvBlock, self).__init__() 58 | self.split_len1 = channel_num_ho # channel_split_num 59 | self.split_len2 = channel_num_hi # channel_num - channel_split_num 60 | self.clamp = clamp 61 | 62 | self.F = subnet_constructor_v2(self.split_len2, self.split_len1, groups=groups) 63 | if groups == 1: 64 | self.G = subnet_constructor(self.split_len1, self.split_len2, groups=groups) 65 | self.H = subnet_constructor(self.split_len1, self.split_len2, groups=groups) 66 | else: 67 | self.G = subnet_constructor(self.split_len1, self.split_len2) 68 | self.H = subnet_constructor(self.split_len1, self.split_len2) 69 | 70 | def forward(self, x1, x2, rev=False): 71 | if not rev: 72 | y1 = x1 + self.F(x2) 73 | self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1) 74 | y2 = [x2i.mul(torch.exp(self.s)) + self.G(y1) for x2i in x2] 75 | else: 76 | self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1) 77 | # print(x2[0].shape, self.G(x1).shape) 78 | y2 = [(x2i - self.G(x1)).div(torch.exp(self.s)) for x2i in x2] 79 | y1 = x1 - self.F(y2) 
80 | 81 | return y1, y2 # torch.cat((y1, y2), 1) 82 | 83 | def jacobian(self, x, rev=False): 84 | if not rev: 85 | jac = torch.sum(self.s) 86 | else: 87 | jac = -torch.sum(self.s) 88 | 89 | return jac / x.shape[0] 90 | 91 | class InvNN(nn.Module): 92 | def __init__(self, channel_in_ho=3, channel_in_hi=3, subnet_constructor=None, subnet_constructor_v2=None, block_num=[], down_num=2, groups=None): 93 | super(InvNN, self).__init__() 94 | operations = [] 95 | # current_channel = channel_in 96 | current_channel_ho = channel_in_ho 97 | current_channel_hi = channel_in_hi 98 | for i in range(down_num): 99 | for j in range(block_num[i]): 100 | b = InvBlock(subnet_constructor, subnet_constructor_v2, current_channel_ho, current_channel_hi, groups=groups) 101 | operations.append(b) 102 | 103 | self.operations = nn.ModuleList(operations) 104 | 105 | def forward(self, x, x_h, rev=False, cal_jacobian=False): 106 | # out = x 107 | jacobian = 0 108 | 109 | if not rev: 110 | for op in self.operations: 111 | x, x_h = op.forward(x, x_h, rev) 112 | if cal_jacobian: 113 | jacobian += op.jacobian(x, rev) 114 | else: 115 | for op in reversed(self.operations): 116 | x, x_h = op.forward(x, x_h, rev) 117 | if cal_jacobian: 118 | jacobian += op.jacobian(x, rev) 119 | 120 | if cal_jacobian: 121 | return x, x_h, jacobian 122 | else: 123 | return x, x_h 124 | 125 | class PredictiveModuleMIMO(nn.Module): 126 | def __init__(self, channel_in, nf, block_num_rbm=8): 127 | super(PredictiveModuleMIMO, self).__init__() 128 | self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True) 129 | residual_block = [] 130 | for i in range(block_num_rbm): 131 | residual_block.append(ResidualBlockNoBN(nf)) 132 | self.residual_block = nn.Sequential(*residual_block) 133 | 134 | def forward(self, x): 135 | x = self.conv_in(x) 136 | res = self.residual_block(x) 137 | 138 | return res 139 | 140 | def gauss_noise(shape): 141 | noise = torch.zeros(shape).cuda() 142 | for i in range(noise.shape[0]): 143 | noise[i] = torch.randn(noise[i].shape).cuda() 144 | 145 | return noise 146 | 147 | def gauss_noise_mul(shape): 148 | noise = torch.randn(shape).cuda() 149 | 150 | return noise 151 | 152 | class VSN(nn.Module): 153 | def __init__(self, opt, subnet_constructor=None, subnet_constructor_v2=None, down_num=2): 154 | super(VSN, self).__init__() 155 | self.model = opt['model'] 156 | opt_net = opt['network_G'] 157 | self.num_video = opt['num_video'] 158 | self.gop = opt['gop'] 159 | self.channel_in = opt_net['in_nc'] * self.gop 160 | self.channel_out = opt_net['out_nc'] * self.gop 161 | self.channel_in_hi = opt_net['in_nc'] * self.gop 162 | self.channel_in_ho = opt_net['in_nc'] * self.gop 163 | 164 | self.block_num = opt_net['block_num'] 165 | self.block_num_rbm = opt_net['block_num_rbm'] 166 | self.nf = self.channel_in_hi 167 | self.irn = InvNN(self.channel_in_ho, self.channel_in_hi, subnet_constructor, subnet_constructor_v2, self.block_num, down_num, groups=self.num_video) 168 | self.pm = PredictiveModuleMIMO(self.channel_in_ho, self.nf* self.num_video, block_num_rbm=self.block_num_rbm) 169 | 170 | def forward(self, x, x_h=None, rev=False, hs=[], direction='f'): 171 | if not rev: 172 | out_y, out_y_h = self.irn(x, x_h, rev) 173 | return out_y, out_y_h 174 | else: 175 | out_z = self.pm(x).unsqueeze(1) 176 | out_z_new = out_z.view(-1, self.num_video, self.channel_in, x.shape[-2], x.shape[-1]) 177 | out_z_new = [out_z_new[:,i] for i in range(self.num_video)] 178 | out_x, out_x_h = self.irn(x, out_z_new, rev) 179 | 180 | return out_x, out_x_h, out_z 181 | 
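# --- Illustrative smoke test (added sketch, not part of the original file) ---
# A minimal check of the self-contained pieces above. It assumes the module is run as a
# package module (e.g. `python -m models.modules.Inv_arch` from the code/ directory) so the
# relative imports resolve, and that basicsr is installed.
if __name__ == '__main__':
    x = torch.randn(2, 64, 32, 32)
    block = ResidualBlockNoBN(nf=64)            # residual block preserves the feature shape
    assert block(x).shape == x.shape
    assert thops_mean(x, dim=[1, 2, 3]).shape == (2,)   # reduced dims are squeezed away
    print('Inv_arch smoke test passed.')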
-------------------------------------------------------------------------------- /code/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | from datetime import datetime 6 | import random 7 | import logging 8 | from collections import OrderedDict 9 | import numpy as np 10 | import cv2 11 | import torch 12 | from torchvision.utils import make_grid 13 | from shutil import get_terminal_size 14 | 15 | import yaml 16 | try: 17 | from yaml import CLoader as Loader, CDumper as Dumper 18 | except ImportError: 19 | from yaml import Loader, Dumper 20 | 21 | 22 | def OrderedYaml(): 23 | '''yaml orderedDict support''' 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | #################### 38 | # miscellaneous 39 | #################### 40 | 41 | 42 | def get_timestamp(): 43 | return datetime.now().strftime('%y%m%d-%H%M%S') 44 | 45 | 46 | def mkdir(path): 47 | if not os.path.exists(path): 48 | os.makedirs(path) 49 | 50 | 51 | def mkdirs(paths): 52 | if isinstance(paths, str): 53 | mkdir(paths) 54 | else: 55 | for path in paths: 56 | mkdir(path) 57 | 58 | 59 | def mkdir_and_rename(path): 60 | # print(path) 61 | # exit(0) 62 | if os.path.exists(path): 63 | new_name = path + '_archived_' + get_timestamp() 64 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 65 | logger = logging.getLogger('base') 66 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 67 | # path = new_name 68 | os.rename(path, new_name) 69 | os.makedirs(path) 70 | # return path 71 | 72 | 73 | def set_random_seed(seed): 74 | random.seed(seed) 75 | np.random.seed(seed) 76 | torch.manual_seed(seed) 77 | torch.cuda.manual_seed_all(seed) 78 | 79 | 80 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 81 | '''set up logger''' 82 | lg = logging.getLogger(logger_name) 83 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 84 | datefmt='%y-%m-%d %H:%M:%S') 85 | lg.setLevel(level) 86 | if tofile: 87 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 88 | fh = logging.FileHandler(log_file, mode='w') 89 | fh.setFormatter(formatter) 90 | lg.addHandler(fh) 91 | if screen: 92 | sh = logging.StreamHandler() 93 | sh.setFormatter(formatter) 94 | lg.addHandler(sh) 95 | 96 | 97 | #################### 98 | # image convert 99 | #################### 100 | 101 | 102 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 103 | ''' 104 | Converts a torch Tensor into an image Numpy array 105 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 106 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 107 | ''' 108 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 109 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 110 | n_dim = tensor.dim() 111 | if n_dim == 4: 112 | n_img = len(tensor) 113 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 114 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 115 | elif n_dim == 3: 116 | img_np = tensor.numpy() 117 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 118 | elif n_dim == 2: 119 | img_np = tensor.numpy() 120 | else: 121 | raise TypeError( 122 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 123 | if out_type == np.uint8: 124 | img_np = (img_np * 255.0).round() 125 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 
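# NOTE (annotation added for clarity): rounding before the cast matters because
# astype(np.uint8) truncates toward zero; without the .round() above, a value such as
# 254.9 would be written out as 254.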
126 | return img_np.astype(out_type) 127 | 128 | 129 | def save_img(img, img_path, mode='RGB'): 130 | cv2.imwrite(img_path, img) 131 | 132 | 133 | #################### 134 | # metric 135 | #################### 136 | 137 | 138 | def calculate_psnr(img1, img2): 139 | # img1 and img2 have range [0, 255] 140 | img1 = img1.astype(np.float64) 141 | img2 = img2.astype(np.float64) 142 | mse = np.mean((img1 - img2)**2) 143 | if mse == 0: 144 | return float('inf') 145 | return 20 * math.log10(255.0 / math.sqrt(mse)) 146 | 147 | 148 | def ssim(img1, img2): 149 | C1 = (0.01 * 255)**2 150 | C2 = (0.03 * 255)**2 151 | 152 | img1 = img1.astype(np.float64) 153 | img2 = img2.astype(np.float64) 154 | kernel = cv2.getGaussianKernel(11, 1.5) 155 | window = np.outer(kernel, kernel.transpose()) 156 | 157 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 158 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 159 | mu1_sq = mu1**2 160 | mu2_sq = mu2**2 161 | mu1_mu2 = mu1 * mu2 162 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 163 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 164 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 165 | 166 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 167 | (sigma1_sq + sigma2_sq + C2)) 168 | return ssim_map.mean() 169 | 170 | 171 | def calculate_ssim(img1, img2): 172 | '''calculate SSIM 173 | the same outputs as MATLAB's 174 | img1, img2: [0, 255] 175 | ''' 176 | if not img1.shape == img2.shape: 177 | raise ValueError('Input images must have the same dimensions.') 178 | if img1.ndim == 2: 179 | return ssim(img1, img2) 180 | elif img1.ndim == 3: 181 | if img1.shape[2] == 3: 182 | ssims = [] 183 | for i in range(3): 184 | ssims.append(ssim(img1, img2)) 185 | return np.array(ssims).mean() 186 | elif img1.shape[2] == 1: 187 | return ssim(np.squeeze(img1), np.squeeze(img2)) 188 | else: 189 | raise ValueError('Wrong input image dimensions.') 190 | 191 | 192 | class ProgressBar(object): 193 | '''A progress bar which can print the progress 194 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 195 | ''' 196 | 197 | def __init__(self, task_num=0, bar_width=50, start=True): 198 | self.task_num = task_num 199 | max_bar_width = self._get_max_bar_width() 200 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 201 | self.completed = 0 202 | if start: 203 | self.start() 204 | 205 | def _get_max_bar_width(self): 206 | terminal_width, _ = get_terminal_size() 207 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 208 | if max_bar_width < 10: 209 | print('terminal width is too small ({}), please consider widen the terminal for better ' 210 | 'progressbar visualization'.format(terminal_width)) 211 | max_bar_width = 10 212 | return max_bar_width 213 | 214 | def start(self): 215 | if self.task_num > 0: 216 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 217 | ' ' * self.bar_width, self.task_num, 'Start...')) 218 | else: 219 | sys.stdout.write('completed: 0, elapsed: 0s') 220 | sys.stdout.flush() 221 | self.start_time = time.time() 222 | 223 | def update(self, msg='In progress...'): 224 | self.completed += 1 225 | elapsed = time.time() - self.start_time 226 | fps = self.completed / elapsed 227 | if self.task_num > 0: 228 | percentage = self.completed / float(self.task_num) 229 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 230 | mark_width = int(self.bar_width * percentage) 231 | 
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 232 | sys.stdout.write('\033[2F') # cursor up 2 lines 233 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 234 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 235 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 236 | else: 237 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 238 | self.completed, int(elapsed + 0.5), fps)) 239 | sys.stdout.flush() 240 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import logging 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | from data.data_sampler import DistIterSampler 11 | 12 | import options.options as option 13 | from utils import util 14 | from data import create_dataloader, create_dataset 15 | from models import create_model 16 | 17 | 18 | def init_dist(backend='nccl', **kwargs): 19 | ''' initialization for distributed training''' 20 | # if mp.get_start_method(allow_none=True) is None: 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn') 23 | rank = int(os.environ['RANK']) 24 | num_gpus = torch.cuda.device_count() 25 | torch.cuda.set_device(rank % num_gpus) 26 | dist.init_process_group(backend=backend, **kwargs) 27 | 28 | def cal_pnsr(sr_img, gt_img): 29 | # calculate PSNR 30 | gt_img = gt_img / 255. 31 | sr_img = sr_img / 255. 32 | psnr = util.calculate_psnr(sr_img * 255, gt_img * 255) 33 | 34 | return psnr 35 | 36 | def main(): 37 | # options 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('-opt', type=str, help='Path to option YMAL file.') # config 文件 40 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 41 | help='job launcher') 42 | parser.add_argument('--local_rank', type=int, default=0) 43 | args = parser.parse_args() 44 | opt = option.parse(args.opt, is_train=True) 45 | 46 | # distributed training settings 47 | if args.launcher == 'none': # disabled distributed training 48 | opt['dist'] = False 49 | rank = -1 50 | print('Disabled distributed training.') 51 | else: 52 | opt['dist'] = True 53 | init_dist() 54 | world_size = torch.distributed.get_world_size() 55 | rank = torch.distributed.get_rank() 56 | 57 | # loading resume state if exists 58 | if opt['path'].get('resume_state', None): 59 | # distributed resuming: all load into default GPU 60 | device_id = torch.cuda.current_device() 61 | resume_state = torch.load(opt['path']['resume_state'], 62 | map_location=lambda storage, loc: storage.cuda(device_id)) 63 | option.check_resume(opt, resume_state['iter']) # check resume options 64 | else: 65 | resume_state = None 66 | 67 | # mkdir and loggers 68 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 69 | if resume_state is None: 70 | util.mkdir_and_rename( 71 | opt['path']['experiments_root']) # rename experiment folder if exists 72 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 73 | and 'pretrain_model' not in key and 'resume' not in key)) 74 | 75 | # config loggers. 
Before it, the log will not work 76 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 77 | screen=True, tofile=True) 78 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 79 | screen=True, tofile=True) 80 | logger = logging.getLogger('base') 81 | logger.info(option.dict2str(opt)) 82 | # tensorboard logger 83 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 84 | version = float(torch.__version__[0:3]) 85 | if version >= 1.1: # PyTorch 1.1 86 | from torch.utils.tensorboard import SummaryWriter 87 | else: 88 | logger.info( 89 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) 90 | from tensorboardX import SummaryWriter 91 | tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) 92 | else: 93 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 94 | logger = logging.getLogger('base') 95 | 96 | # convert to NoneDict, which returns None for missing keys 97 | opt = option.dict_to_nonedict(opt) 98 | 99 | # random seed 100 | seed = opt['train']['manual_seed'] 101 | if seed is None: 102 | seed = random.randint(1, 10000) 103 | if rank <= 0: 104 | logger.info('Random seed: {}'.format(seed)) 105 | util.set_random_seed(seed) 106 | 107 | torch.backends.cudnn.benchmark = True 108 | # torch.backends.cudnn.deterministic = True 109 | 110 | #### create train and val dataloader 111 | dataset_ratio = 200 # enlarge the size of each epoch 112 | for phase, dataset_opt in opt['datasets'].items(): 113 | if phase == 'train': 114 | train_set = create_dataset(dataset_opt) 115 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 116 | total_iters = int(opt['train']['niter']) 117 | total_epochs = int(math.ceil(total_iters / train_size)) 118 | if opt['dist']: 119 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) 120 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) 121 | else: 122 | train_sampler = None 123 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) 124 | if rank <= 0: 125 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 126 | len(train_set), train_size)) 127 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 128 | total_epochs, total_iters)) 129 | elif phase == 'val': 130 | val_set = create_dataset(dataset_opt) 131 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 132 | if rank <= 0: 133 | logger.info('Number of val images in [{:s}]: {:d}'.format( 134 | dataset_opt['name'], len(val_set))) 135 | else: 136 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 137 | assert train_loader is not None 138 | 139 | # create model 140 | model = create_model(opt) 141 | # resume training 142 | if resume_state: 143 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 144 | resume_state['epoch'], resume_state['iter'])) 145 | 146 | start_epoch = resume_state['epoch'] 147 | current_step = resume_state['iter'] 148 | model.resume_training(resume_state) # handle optimizers and schedulers 149 | else: 150 | current_step = 0 151 | start_epoch = 0 152 | 153 | # training 154 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 155 | for epoch in range(start_epoch, total_epochs + 1): 156 | if opt['dist']: 157 | train_sampler.set_epoch(epoch) 158 | for _, train_data in enumerate(train_loader): 159 | current_step += 1 160 | if current_step > 
total_iters:
161 | break
162 | # training
163 | model.feed_data(train_data)
164 | model.optimize_parameters(current_step)
165 |
166 | # update learning rate
167 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
168 |
169 | # log
170 | if current_step % opt['logger']['print_freq'] == 0:
171 | logs = model.get_current_log()
172 | message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
173 | epoch, current_step, model.get_current_learning_rate())
174 | for k, v in logs.items():
175 | message += '{:s}: {:.4e} '.format(k, v)
176 | # tensorboard logger
177 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
178 | if rank <= 0:
179 | tb_logger.add_scalar(k, v, current_step)
180 | if rank <= 0:
181 | logger.info(message)
182 |
183 | # validation
184 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
185 | avg_psnr = 0.0
186 | avg_psnr_h = [0.0]*opt['num_video']
187 | avg_psnr_lr = 0.0
188 | idx = 0
189 | for video_id, val_data in enumerate(val_loader):
190 | img_dir = os.path.join(opt['path']['val_images'])
191 | util.mkdir(img_dir)
192 |
193 | model.feed_data(val_data)
194 | model.test()
195 |
196 | visuals = model.get_current_visuals()
197 |
198 | t_step = visuals['SR'].shape[0]
199 | idx += t_step
200 | n = len(visuals['SR_h'])
201 |
202 | for i in range(t_step):
203 |
204 | sr_img = util.tensor2img(visuals['SR'][i]) # uint8
205 | sr_img_h = []
206 | for j in range(n):
207 | sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i])) # uint8
208 | gt_img = util.tensor2img(visuals['GT'][i]) # uint8
209 | lr_img = util.tensor2img(visuals['LR'][i])
210 | lrgt_img = []
211 | for j in range(n):
212 | lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))
213 |
214 | # Save SR images for reference
215 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'SR'))
216 | util.save_img(sr_img, save_img_path)
217 |
218 | for j in range(n):
219 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'SR_h'))
220 | util.save_img(sr_img_h[j], save_img_path)
221 |
222 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'GT'))
223 | util.save_img(gt_img, save_img_path)
224 |
225 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'LR'))
226 | util.save_img(lr_img, save_img_path)
227 |
228 | for j in range(n):
229 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'LRGT'))
230 | util.save_img(lrgt_img[j], save_img_path)
231 |
232 | psnr = cal_pnsr(sr_img, gt_img)
233 | psnr_h = []
234 | for j in range(n):
235 | psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j]))
236 | psnr_lr = cal_pnsr(lr_img, gt_img)
237 |
238 | avg_psnr += psnr
239 | for j in range(n):
240 | avg_psnr_h[j] += psnr_h[j]
241 | avg_psnr_lr += psnr_lr
242 |
243 | avg_psnr = avg_psnr / idx
244 | avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
245 | avg_psnr_lr = avg_psnr_lr / idx
246 |
247 | # log
248 | res_psnr_h = ''
249 | for p in avg_psnr_h:
250 | res_psnr_h+=('_{:.4e}'.format(p))
251 |
252 | logger.info('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr))
253 | logger_val = logging.getLogger('val') # validation logger
254 | logger_val.info('<epoch:{:3d}, iter:{:8,d}> PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}'.format(
255 | epoch, current_step, avg_psnr, res_psnr_h, avg_psnr_lr))
256 | # tensorboard logger
257 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
258 | tb_logger.add_scalar('psnr', avg_psnr, current_step)
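# NOTE (annotation added for clarity): PSNR_Cover compares the recovered cover frame with the
# ground-truth host, PSNR_Secret compares each recovered hidden video with its original, and
# PSNR_Stego compares the embedded container with the host (i.e. how imperceptible the hiding is).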
259 | 260 | # save models and training states 261 | if current_step % opt['logger']['save_checkpoint_freq'] == 0: 262 | if rank <= 0: 263 | logger.info('Saving models and training states.') 264 | model.save(current_step) 265 | model.save_training_state(epoch, current_step) 266 | 267 | if rank <= 0: 268 | logger.info('Saving the final model.') 269 | model.save('latest') 270 | logger.info('End of training.') 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /code/models/LFVSN.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | 8 | import models.networks as networks 9 | import models.lr_scheduler as lr_scheduler 10 | from .base_model import BaseModel 11 | from models.modules.loss import ReconstructionLoss 12 | from models.modules.Quantization import Quantization 13 | from .modules.common import DWT,IWT 14 | 15 | logger = logging.getLogger('base') 16 | dwt=DWT() 17 | iwt=IWT() 18 | 19 | 20 | class Model_VSN(BaseModel): 21 | def __init__(self, opt): 22 | super(Model_VSN, self).__init__(opt) 23 | 24 | if opt['dist']: 25 | self.rank = torch.distributed.get_rank() 26 | else: 27 | self.rank = -1 # non dist training 28 | 29 | self.gop = opt['gop'] 30 | train_opt = opt['train'] 31 | test_opt = opt['test'] 32 | self.opt = opt 33 | self.train_opt = train_opt 34 | self.test_opt = test_opt 35 | self.opt_net = opt['network_G'] 36 | self.center = self.gop // 2 37 | self.num_video = opt['num_video'] 38 | self.idxx = 0 39 | 40 | self.netG = networks.define_G_v2(opt).to(self.device) 41 | if opt['dist']: 42 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 43 | else: 44 | self.netG = DataParallel(self.netG) 45 | # print network 46 | self.print_network() 47 | self.load() 48 | 49 | self.Quantization = Quantization() 50 | 51 | if self.is_train: 52 | self.netG.train() 53 | 54 | # loss 55 | self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw']) 56 | self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back']) 57 | self.Reconstruction_center = ReconstructionLoss(losstype="center") 58 | 59 | # optimizers 60 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 61 | optim_params = [] 62 | for k, v in self.netG.named_parameters(): 63 | if v.requires_grad: 64 | optim_params.append(v) 65 | else: 66 | if self.rank <= 0: 67 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 68 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 69 | weight_decay=wd_G, 70 | betas=(train_opt['beta1'], train_opt['beta2'])) 71 | self.optimizers.append(self.optimizer_G) 72 | 73 | # schedulers 74 | if train_opt['lr_scheme'] == 'MultiStepLR': 75 | for optimizer in self.optimizers: 76 | self.schedulers.append( 77 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 78 | restarts=train_opt['restarts'], 79 | weights=train_opt['restart_weights'], 80 | gamma=train_opt['lr_gamma'], 81 | clear_state=train_opt['clear_state'])) 82 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 83 | for optimizer in self.optimizers: 84 | self.schedulers.append( 85 | lr_scheduler.CosineAnnealingLR_Restart( 86 | optimizer, train_opt['T_period'], 
eta_min=train_opt['eta_min'], 87 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 88 | else: 89 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 90 | 91 | self.log_dict = OrderedDict() 92 | 93 | def feed_data(self, data): 94 | self.ref_L = data['LQ'].to(self.device) 95 | self.real_H = data['GT'].to(self.device) 96 | 97 | def init_hidden_state(self, z): 98 | b, c, h, w = z.shape 99 | h_t = [] 100 | c_t = [] 101 | for _ in range(self.opt_net['block_num_rbm']): 102 | h_t.append(torch.zeros([b, c, h, w]).cuda()) 103 | c_t.append(torch.zeros([b, c, h, w]).cuda()) 104 | memory = torch.zeros([b, c, h, w]).cuda() 105 | 106 | return h_t, c_t, memory 107 | 108 | def loss_forward(self, out, y): 109 | if self.opt['model'] == 'LSTM-VRN': 110 | l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y) 111 | return l_forw_fit 112 | elif self.opt['model'] == 'MIMO-VRN-h': 113 | l_forw_fit = 0 114 | for i in range(out.shape[1]): 115 | l_forw_fit += self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out[:, i], y[:, i]) 116 | return l_forw_fit 117 | 118 | def loss_back_rec(self, out, x): 119 | if self.opt['model'] == 'LSTM-VRN': 120 | l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x) 121 | return l_back_rec 122 | elif self.opt['model'] == 'MIMO-VRN-h': 123 | l_back_rec = 0 124 | for i in range(x.shape[1]): 125 | l_back_rec += self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out[:, i], x[:, i]) 126 | return l_back_rec 127 | 128 | def loss_back_rec_mul(self, out, x): 129 | out = torch.chunk(out,self.num_video,dim=1) 130 | out = [outi.squeeze(1) for outi in out] 131 | x = torch.chunk(x,self.num_video,dim=1) 132 | x = [xi.squeeze(1) for xi in x] 133 | l_back_rec = 0 134 | for i in range(len(x)): 135 | for j in range(x[i].shape[1]): 136 | l_back_rec += self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out[i][:, j], x[i][:, j]) 137 | return l_back_rec 138 | 139 | def loss_center(self, out, x): 140 | # x.shape: (b, t, c, h, w) 141 | b, t = x.shape[:2] 142 | l_center = 0 143 | for i in range(b): 144 | mse_s = self.Reconstruction_center(out[i], x[i]) 145 | mse_mean = torch.mean(mse_s) 146 | for j in range(t): 147 | l_center += torch.sqrt((mse_s[j] - mse_mean.detach()) ** 2 + 1e-18) 148 | l_center = self.train_opt['lambda_center'] * l_center / b 149 | 150 | return l_center 151 | 152 | def optimize_parameters(self, current_step): 153 | self.optimizer_G.zero_grad() 154 | 155 | b, n, t, c, h, w = self.ref_L.shape 156 | center = t // 2 157 | intval = self.gop // 2 158 | 159 | self.host = self.real_H[:, center - intval:center + intval + 1] 160 | self.secret = self.ref_L[:, :, center - intval:center + intval + 1] 161 | self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)] 162 | self.output, out_h = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret) 163 | self.output = iwt(self.output) 164 | 165 | Gt_ref = self.real_H[:, center - intval:center + intval + 1].detach() 166 | container = self.output[:, :3 * self.gop, :, :].reshape(-1, self.gop, 3, h, w)[:,self.gop//2] 167 | l_forw_fit = self.loss_forward(container.unsqueeze(1), Gt_ref[:,self.gop//2].unsqueeze(1)) 168 | 169 | y = self.Quantization(self.output[:, :3 * self.gop, :, :].view(-1, self.gop, 3, h, w)[:,self.gop//2].unsqueeze(1).repeat(1,self.gop,1,1,1).reshape(b, -1, h, w)) 170 | out_x, out_x_h, out_z = self.netG(x=dwt(y), rev=True) 171 | out_x = iwt(out_x) 172 | out_x_h = [iwt(out_x_h_i) for 
out_x_h_i in out_x_h] 173 | 174 | l_back_rec = self.loss_back_rec(out_x.reshape(-1, self.gop, 3, h, w)[:,self.gop//2].unsqueeze(1), self.host[:,self.gop//2].unsqueeze(1)) 175 | out_x_h = torch.stack(out_x_h, dim=1) 176 | 177 | l_center_x = 0 178 | for i in range(n): 179 | l_center_x += self.loss_back_rec(out_x_h.reshape(-1, n, self.gop, 3, h, w)[:, :, self.gop//2].unsqueeze(2)[:,i], self.ref_L[:, :, center - intval:center + intval + 1][:,:,self.gop//2].unsqueeze(2)[:, i]) 180 | 181 | loss = l_forw_fit*2 + l_back_rec + l_center_x*4 182 | loss.backward() 183 | 184 | if self.train_opt['lambda_center'] != 0: 185 | self.log_dict['l_center_x'] = l_center_x.item() 186 | 187 | # set log 188 | self.log_dict['l_back_rec'] = l_back_rec.item() 189 | self.log_dict['l_forw_fit'] = l_forw_fit.item() 190 | 191 | self.log_dict['l_h'] = (l_center_x*10).item() 192 | 193 | # gradient clipping 194 | if self.train_opt['gradient_clipping']: 195 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) 196 | 197 | self.optimizer_G.step() 198 | 199 | def test(self): 200 | self.netG.eval() 201 | with torch.no_grad(): 202 | forw_L = [] 203 | forw_L_h = [] 204 | fake_H = [] 205 | fake_H_h = [] 206 | pred_z = [] 207 | b, t, c, h, w = self.real_H.shape 208 | center = t // 2 209 | intval = self.gop // 2 210 | ids=[-1,0,1] 211 | b, n, t, c, h, w = self.ref_L.shape 212 | for j in range(3): 213 | id=ids[j] 214 | # forward downscaling 215 | self.host = self.real_H[:, center - intval+id:center + intval + 1+id] 216 | self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id] 217 | self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)] 218 | self.output, out_h = self.netG(x=dwt(self.host.reshape(b, -1, h, w)),x_h=self.secret) 219 | self.output = iwt(self.output) 220 | out_lrs = self.output[:, :3 * self.gop, :, :].reshape(-1, self.gop, 3, h, w) 221 | 222 | # backward upscaling 223 | y = self.Quantization(self.output[:, :3 * self.gop, :, :].view(-1, self.gop, 3, h, w)[:,self.gop//2].unsqueeze(1).repeat(1,self.gop,1,1,1).reshape(b, -1, h, w)) 224 | out_x, out_x_h, out_z = self.netG(x=dwt(y), rev=True) 225 | out_x = iwt(out_x) 226 | out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h] 227 | out_x = out_x.reshape(-1, self.gop, 3, h, w) 228 | out_x_h = torch.stack(out_x_h, dim=1) 229 | out_x_h = out_x_h.reshape(-1, n, self.gop, 3, h, w) 230 | forw_L.append(out_lrs[:, self.gop//2]) 231 | fake_H.append(out_x[:, self.gop//2]) 232 | fake_H_h.append(out_x_h[:,:, self.gop//2]) 233 | 234 | self.fake_H = torch.clamp(torch.stack(fake_H, dim=1),0,1) 235 | self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2),0,1) 236 | self.forw_L = torch.clamp(torch.stack(forw_L, dim=1),0,1) 237 | self.netG.train() 238 | 239 | def get_current_log(self): 240 | return self.log_dict 241 | 242 | def get_current_visuals(self): 243 | b, n, t, c, h, w = self.ref_L.shape 244 | center = t // 2 245 | intval = 3 // 2 246 | out_dict = OrderedDict() 247 | LR_ref = self.ref_L[:, :, center - intval:center + intval + 1].detach()[0].float().cpu() 248 | LR_ref = torch.chunk(LR_ref, self.num_video, dim=0) 249 | out_dict['LR_ref'] = [video.squeeze(0) for video in LR_ref] 250 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu() 251 | SR_h = self.fake_H_h.detach()[0].float().cpu() 252 | SR_h = torch.chunk(SR_h, self.num_video, dim=0) 253 | out_dict['SR_h'] = [video.squeeze(0) for video in SR_h] 254 | out_dict['LR'] = self.forw_L.detach()[0].float().cpu() 255 | out_dict['GT'] = self.real_H[:, center - 
intval:center + intval + 1].detach()[0].float().cpu() 256 | 257 | return out_dict 258 | 259 | def print_network(self): 260 | s, n = self.get_network_description(self.netG) 261 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 262 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 263 | self.netG.module.__class__.__name__) 264 | else: 265 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 266 | if self.rank <= 0: 267 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 268 | logger.info(s) 269 | 270 | def load(self): 271 | load_path_G = self.opt['path']['pretrain_model_G'] 272 | if load_path_G is not None: 273 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 274 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 275 | 276 | def load_test(self,load_path_G): 277 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 278 | 279 | def save(self, iter_label): 280 | self.save_network(self.netG, 'G', iter_label) 281 | -------------------------------------------------------------------------------- /code/data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pickle 4 | import random 5 | import numpy as np 6 | import glob 7 | import torch 8 | import cv2 9 | 10 | #################### 11 | # Files & IO 12 | #################### 13 | 14 | ###################### get image path list ###################### 15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 16 | 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 20 | 21 | 22 | def _get_paths_from_images(path): 23 | '''get image path list from image folder''' 24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 25 | images = [] 26 | for dirpath, _, fnames in sorted(os.walk(path)): 27 | for fname in sorted(fnames): 28 | if is_image_file(fname): 29 | img_path = os.path.join(dirpath, fname) 30 | images.append(img_path) 31 | assert images, '{:s} has no valid image file'.format(path) 32 | return images 33 | 34 | 35 | def _get_paths_from_lmdb(dataroot): 36 | '''get image path list from lmdb meta info''' 37 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) 38 | paths = meta_info['keys'] 39 | sizes = meta_info['resolution'] 40 | if len(sizes) == 1: 41 | sizes = sizes * len(paths) 42 | return paths, sizes 43 | 44 | 45 | def get_image_paths(data_type, dataroot): 46 | '''get image path list 47 | support lmdb or image files''' 48 | paths, sizes = None, None 49 | if dataroot is not None: 50 | if data_type == 'lmdb': 51 | paths, sizes = _get_paths_from_lmdb(dataroot) 52 | elif data_type == 'img': 53 | paths = sorted(_get_paths_from_images(dataroot)) 54 | else: 55 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) 56 | return paths, sizes 57 | 58 | 59 | def glob_file_list(root): 60 | return sorted(glob.glob(os.path.join(root, '*'))) 61 | 62 | 63 | ###################### read images ###################### 64 | def _read_img_lmdb(env, key, size): 65 | '''read image from lmdb with key (w/ and w/o fixed size) 66 | size: (C, H, W) tuple''' 67 | with env.begin(write=False) as txn: 68 | buf = txn.get(key.encode('ascii')) 69 | img_flat = np.frombuffer(buf, dtype=np.uint8) 70 | C, H, W = size 71 | img = img_flat.reshape(H, W, C) 
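# NOTE (annotation added for clarity): the LMDB value is the raw image buffer flattened in
# HWC order, so unpacking the (C, H, W) size tuple and reshaping to (H, W, C) restores the image.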
72 | return img 73 | 74 | 75 | def read_img(env, path, size=None): 76 | '''read image by cv2 or from lmdb 77 | return: Numpy float32, HWC, BGR, [0,1]''' 78 | if env is None: # img 79 | # print(path) 80 | #img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 81 | img = cv2.imread(path, cv2.IMREAD_COLOR) 82 | else: 83 | img = _read_img_lmdb(env, path, size) 84 | # print(img.shape) 85 | # if img is None: 86 | # print(path) 87 | # print(img.shape) 88 | img = img.astype(np.float32) / 255. 89 | if img.ndim == 2: 90 | img = np.expand_dims(img, axis=2) 91 | # some images have 4 channels 92 | if img.shape[2] > 3: 93 | img = img[:, :, :3] 94 | return img 95 | 96 | 97 | def read_img_seq(path): 98 | """Read a sequence of images from a given folder path 99 | Args: 100 | path (list/str): list of image paths/image folder path 101 | 102 | Returns: 103 | imgs (Tensor): size (T, C, H, W), RGB, [0, 1] 104 | """ 105 | if type(path) is list: 106 | img_path_l = path 107 | else: 108 | img_path_l = sorted(glob.glob(os.path.join(path, '*.png'))) 109 | # print(path) 110 | # print(path,img_path_l) 111 | img_l = [read_img(None, v) for v in img_path_l] 112 | # stack to Torch tensor 113 | imgs = np.stack(img_l, axis=0) 114 | imgs = imgs[:, :, :, [2, 1, 0]] 115 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() 116 | return imgs 117 | 118 | 119 | def index_generation(crt_i, max_n, N, padding='reflection'): 120 | """Generate an index list for reading N frames from a sequence of images 121 | Args: 122 | crt_i (int): current center index 123 | max_n (int): max number of the sequence of images (calculated from 1) 124 | N (int): reading N frames 125 | padding (str): padding mode, one of replicate | reflection | new_info | circle 126 | Example: crt_i = 0, N = 5 127 | replicate: [0, 0, 0, 1, 2] 128 | reflection: [2, 1, 0, 1, 2] 129 | new_info: [4, 3, 0, 1, 2] 130 | circle: [3, 4, 0, 1, 2] 131 | 132 | Returns: 133 | return_l (list [int]): a list of indexes 134 | """ 135 | max_n = max_n - 1 136 | n_pad = N // 2 137 | return_l = [] 138 | 139 | for i in range(crt_i - n_pad, crt_i + n_pad + 1): 140 | if i < 0: 141 | if padding == 'replicate': 142 | add_idx = 0 143 | elif padding == 'reflection': 144 | add_idx = -i 145 | elif padding == 'new_info': 146 | add_idx = (crt_i + n_pad) + (-i) 147 | elif padding == 'circle': 148 | add_idx = N + i 149 | else: 150 | raise ValueError('Wrong padding mode') 151 | elif i > max_n: 152 | if padding == 'replicate': 153 | add_idx = max_n 154 | elif padding == 'reflection': 155 | add_idx = max_n * 2 - i 156 | elif padding == 'new_info': 157 | add_idx = (crt_i - n_pad) - (i - max_n) 158 | elif padding == 'circle': 159 | add_idx = i - N 160 | else: 161 | raise ValueError('Wrong padding mode') 162 | else: 163 | add_idx = i 164 | return_l.append(add_idx) 165 | return return_l 166 | 167 | 168 | #################### 169 | # image processing 170 | # process on numpy image 171 | #################### 172 | 173 | 174 | def augment(img_list, hflip=True, rot=True): 175 | # horizontal flip OR rotate 176 | hflip = hflip and random.random() < 0.5 177 | vflip = rot and random.random() < 0.5 178 | rot90 = rot and random.random() < 0.5 179 | 180 | def _augment(img): 181 | if hflip: 182 | img = img[:, ::-1, :] 183 | if vflip: 184 | img = img[::-1, :, :] 185 | if rot90: 186 | img = img.transpose(1, 0, 2) 187 | return img 188 | 189 | return [_augment(img) for img in img_list] 190 | 191 | 192 | def augment_flow(img_list, flow_list, hflip=True, rot=True): 193 | # horizontal flip OR rotate 194 
| hflip = hflip and random.random() < 0.5 195 | vflip = rot and random.random() < 0.5 196 | rot90 = rot and random.random() < 0.5 197 | 198 | def _augment(img): 199 | if hflip: 200 | img = img[:, ::-1, :] 201 | if vflip: 202 | img = img[::-1, :, :] 203 | if rot90: 204 | img = img.transpose(1, 0, 2) 205 | return img 206 | 207 | def _augment_flow(flow): 208 | if hflip: 209 | flow = flow[:, ::-1, :] 210 | flow[:, :, 0] *= -1 211 | if vflip: 212 | flow = flow[::-1, :, :] 213 | flow[:, :, 1] *= -1 214 | if rot90: 215 | flow = flow.transpose(1, 0, 2) 216 | flow = flow[:, :, [1, 0]] 217 | return flow 218 | 219 | rlt_img_list = [_augment(img) for img in img_list] 220 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list] 221 | 222 | return rlt_img_list, rlt_flow_list 223 | 224 | 225 | def channel_convert(in_c, tar_type, img_list): 226 | # conversion among BGR, gray and y 227 | if in_c == 3 and tar_type == 'gray': # BGR to gray 228 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] 229 | return [np.expand_dims(img, axis=2) for img in gray_list] 230 | elif in_c == 3 and tar_type == 'y': # BGR to y 231 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] 232 | return [np.expand_dims(img, axis=2) for img in y_list] 233 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR 234 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] 235 | else: 236 | return img_list 237 | 238 | 239 | def rgb2ycbcr(img, only_y=True): 240 | '''same as matlab rgb2ycbcr 241 | only_y: only return Y channel 242 | Input: 243 | uint8, [0, 255] 244 | float, [0, 1] 245 | ''' 246 | in_img_type = img.dtype 247 | img = img.astype(np.float32) 248 | if in_img_type != np.uint8: 249 | img *= 255. 250 | # convert 251 | if only_y: 252 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 253 | else: 254 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 255 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 256 | if in_img_type == np.uint8: 257 | rlt = rlt.round() 258 | else: 259 | rlt /= 255. 260 | return rlt.astype(in_img_type) 261 | 262 | 263 | def bgr2ycbcr(img, only_y=True): 264 | '''bgr version of rgb2ycbcr 265 | only_y: only return Y channel 266 | Input: 267 | uint8, [0, 255] 268 | float, [0, 1] 269 | ''' 270 | in_img_type = img.dtype 271 | img = img.astype(np.float32) 272 | if in_img_type != np.uint8: 273 | img *= 255. 274 | # convert 275 | if only_y: 276 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 277 | else: 278 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 279 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 280 | if in_img_type == np.uint8: 281 | rlt = rlt.round() 282 | else: 283 | rlt /= 255. 284 | return rlt.astype(in_img_type) 285 | 286 | 287 | def ycbcr2rgb(img): 288 | '''same as matlab ycbcr2rgb 289 | Input: 290 | uint8, [0, 255] 291 | float, [0, 1] 292 | ''' 293 | in_img_type = img.dtype 294 | img = img.astype(np.float32) 295 | if in_img_type != np.uint8: 296 | img *= 255. 297 | # convert 298 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 299 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 300 | if in_img_type == np.uint8: 301 | rlt = rlt.round() 302 | else: 303 | rlt /= 255.
304 | return rlt.astype(in_img_type) 305 | 306 | 307 | def modcrop(img_in, scale): 308 | # img_in: Numpy, HWC or HW 309 | img = np.copy(img_in) 310 | if img.ndim == 2: 311 | H, W = img.shape 312 | H_r, W_r = H % scale, W % scale 313 | img = img[:H - H_r, :W - W_r] 314 | elif img.ndim == 3: 315 | H, W, C = img.shape 316 | H_r, W_r = H % scale, W % scale 317 | img = img[:H - H_r, :W - W_r, :] 318 | else: 319 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) 320 | return img 321 | 322 | 323 | #################### 324 | # Functions 325 | #################### 326 | 327 | 328 | # matlab 'imresize' function, now only support 'bicubic' 329 | def cubic(x): 330 | absx = torch.abs(x) 331 | absx2 = absx**2 332 | absx3 = absx**3 333 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ( 334 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( 335 | (absx > 1) * (absx <= 2)).type_as(absx)) 336 | 337 | 338 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): 339 | if (scale < 1) and (antialiasing): 340 | # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width 341 | kernel_width = kernel_width / scale 342 | 343 | # Output-space coordinates 344 | x = torch.linspace(1, out_length, out_length) 345 | 346 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 347 | # in output space maps to 0.5 in input space, and 0.5+scale in output 348 | # space maps to 1.5 in input space. 349 | u = x / scale + 0.5 * (1 - 1 / scale) 350 | 351 | # What is the left-most pixel that can be involved in the computation? 352 | left = torch.floor(u - kernel_width / 2) 353 | 354 | # What is the maximum number of pixels that can be involved in the 355 | # computation? Note: it's OK to use an extra pixel here; if the 356 | # corresponding weights are all zero, it will be eliminated at the end 357 | # of this function. 358 | P = math.ceil(kernel_width) + 2 359 | 360 | # The indices of the input pixels involved in computing the k-th output 361 | # pixel are in row k of the indices matrix. 362 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( 363 | 1, P).expand(out_length, P) 364 | 365 | # The weights used to compute the k-th output pixel are in row k of the 366 | # weights matrix. 367 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices 368 | # apply cubic kernel 369 | if (scale < 1) and (antialiasing): 370 | weights = scale * cubic(distance_to_center * scale) 371 | else: 372 | weights = cubic(distance_to_center) 373 | # Normalize the weights matrix so that each row sums to 1. 374 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 375 | weights = weights / weights_sum.expand(out_length, P) 376 | 377 | # If a column in weights is all zero, get rid of it. only consider the first and last column. 
378 | weights_zero_tmp = torch.sum((weights == 0), 0) 379 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 380 | indices = indices.narrow(1, 1, P - 2) 381 | weights = weights.narrow(1, 1, P - 2) 382 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 383 | indices = indices.narrow(1, 0, P - 2) 384 | weights = weights.narrow(1, 0, P - 2) 385 | weights = weights.contiguous() 386 | indices = indices.contiguous() 387 | sym_len_s = -indices.min() + 1 388 | sym_len_e = indices.max() - in_length 389 | indices = indices + sym_len_s - 1 390 | return weights, indices, int(sym_len_s), int(sym_len_e) 391 | 392 | 393 | def imresize(img, scale, antialiasing=True): 394 | # Now the scale should be the same for H and W 395 | # input: img: CHW RGB [0,1] 396 | # output: CHW RGB [0,1] w/o round 397 | 398 | in_C, in_H, in_W = img.size() 399 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 400 | kernel_width = 4 401 | kernel = 'cubic' 402 | 403 | # Return the desired dimension order for performing the resize. The 404 | # strategy is to perform the resize first along the dimension with the 405 | # smallest scale factor. 406 | # Now we do not support this. 407 | 408 | # get weights and indices 409 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 410 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 411 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 412 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 413 | # process H dimension 414 | # symmetric copying 415 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) 416 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) 417 | 418 | sym_patch = img[:, :sym_len_Hs, :] 419 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 420 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 421 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) 422 | 423 | sym_patch = img[:, -sym_len_He:, :] 424 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 425 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 426 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 427 | 428 | out_1 = torch.FloatTensor(in_C, out_H, in_W) 429 | kernel_width = weights_H.size(1) 430 | for i in range(out_H): 431 | idx = int(indices_H[i][0]) 432 | out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 433 | out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 434 | out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 435 | 436 | # process W dimension 437 | # symmetric copying 438 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) 439 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) 440 | 441 | sym_patch = out_1[:, :, :sym_len_Ws] 442 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 443 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 444 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) 445 | 446 | sym_patch = out_1[:, :, -sym_len_We:] 447 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 448 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 449 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 450 | 451 | out_2 = torch.FloatTensor(in_C, out_H, out_W) 452 | kernel_width = weights_W.size(1) 453 | for i in range(out_W): 454 | idx = int(indices_W[i][0]) 455 | out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) 456 | 
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) 457 | out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) 458 | 459 | return out_2 460 | 461 | 462 | def imresize_np(img, scale, antialiasing=True): 463 | # Now the scale should be the same for H and W 464 | # input: img: Numpy, HWC BGR [0,1] 465 | # output: HWC BGR [0,1] w/o round 466 | img = torch.from_numpy(img) 467 | 468 | in_H, in_W, in_C = img.size() 469 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 470 | kernel_width = 4 471 | kernel = 'cubic' 472 | 473 | # Return the desired dimension order for performing the resize. The 474 | # strategy is to perform the resize first along the dimension with the 475 | # smallest scale factor. 476 | # Now we do not support this. 477 | 478 | # get weights and indices 479 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 480 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 481 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 482 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 483 | # process H dimension 484 | # symmetric copying 485 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) 486 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) 487 | 488 | sym_patch = img[:sym_len_Hs, :, :] 489 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 490 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 491 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) 492 | 493 | sym_patch = img[-sym_len_He:, :, :] 494 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 495 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 496 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 497 | 498 | out_1 = torch.FloatTensor(out_H, in_W, in_C) 499 | kernel_width = weights_H.size(1) 500 | for i in range(out_H): 501 | idx = int(indices_H[i][0]) 502 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) 503 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) 504 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) 505 | 506 | # process W dimension 507 | # symmetric copying 508 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) 509 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) 510 | 511 | sym_patch = out_1[:, :sym_len_Ws, :] 512 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 513 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 514 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) 515 | 516 | sym_patch = out_1[:, -sym_len_We:, :] 517 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 518 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 519 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 520 | 521 | out_2 = torch.FloatTensor(out_H, out_W, in_C) 522 | kernel_width = weights_W.size(1) 523 | for i in range(out_W): 524 | idx = int(indices_W[i][0]) 525 | out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) 526 | out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) 527 | out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) 528 | 529 | return out_2.numpy() 530 | 531 | 532 | if __name__ == '__main__': 533 | # test imresize function 534 | # read images 535 | img = cv2.imread('test.png') 536 | img = img * 1.0 / 255 537 | img = torch.from_numpy(np.transpose(img[:, 
:, [2, 1, 0]], (2, 0, 1))).float() 538 | # imresize 539 | scale = 1 / 4 540 | import time 541 | total_time = 0 542 | for i in range(10): 543 | start_time = time.time() 544 | rlt = imresize(img, scale, antialiasing=True) 545 | use_time = time.time() - start_time 546 | total_time += use_time 547 | print('average time: {}'.format(total_time / 10)) 548 | 549 | import torchvision.utils 550 | torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, 551 | normalize=False) 552 | --------------------------------------------------------------------------------
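A minimal usage sketch of the utilities in code/data/util.py (illustrative only, not part of the repository): it assumes the code/ directory is on PYTHONPATH so that data.util resolves, and the frame-folder path in the last step is a hypothetical placeholder.

import numpy as np
import data.util as util  # assumption: run from code/ or with code/ on PYTHONPATH

# Frame-index padding at the start of a sequence (cf. the index_generation docstring).
idx = util.index_generation(crt_i=0, max_n=30, N=5, padding='reflection')
print(idx)  # [2, 1, 0, 1, 2]

# Bicubic resize of an HWC float image in [0, 1]; output stays HWC float, no rounding.
img = np.random.rand(64, 96, 3).astype(np.float32)
small = util.imresize_np(img, 0.25, antialiasing=True)
print(small.shape)  # (16, 24, 3)

# Load a folder of PNG frames as a (T, C, H, W) RGB tensor in [0, 1].
# 'datasets/Vid4/calendar' is a hypothetical path, not shipped with this repo.
# frames = util.read_img_seq('datasets/Vid4/calendar')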