├── code
│   ├── utils
│   │   ├── __init__.py
│   │   └── util.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── train
│   │   │   ├── train_LF-VSN_3video.yml
│   │   │   ├── train_LF-VSN_4video.yml
│   │   │   ├── train_LF-VSN_5video.yml
│   │   │   ├── train_LF-VSN_6video.yml
│   │   │   ├── train_LF-VSN_7video.yml
│   │   │   ├── train_LF-VSN_1video.yml
│   │   │   └── train_LF-VSN_2video.yml
│   │   └── options.py
│   ├── models
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── Quantization.py
│   │   │   ├── common.py
│   │   │   ├── loss.py
│   │   │   ├── Subnet_constructor.py
│   │   │   ├── module_util.py
│   │   │   └── Inv_arch.py
│   │   ├── __init__.py
│   │   ├── networks.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   ├── discrim.py
│   │   └── LFVSN.py
│   ├── data
│   │   ├── video_test_dataset.py
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── Vimeo90K_dataset.py
│   │   └── util.py
│   ├── test.py
│   └── train.py
├── assets
│   ├── overview.PNG
│   └── performance.PNG
└── README.md
/code/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/options/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/overview.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MC-E/LF-VSN/HEAD/assets/overview.PNG
--------------------------------------------------------------------------------
/assets/performance.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MC-E/LF-VSN/HEAD/assets/performance.PNG
--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 | def create_model(opt):
5 | model = opt['model']
6 | from .LFVSN import Model_VSN as M
7 |
8 | m = M(opt)
9 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
10 | return m
--------------------------------------------------------------------------------
/code/models/modules/Quantization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Quant(torch.autograd.Function):
5 |
6 | @staticmethod
7 | def forward(ctx, input):
8 | input = torch.clamp(input, 0, 1)
9 | output = (input * 255.).round() / 255.
10 | return output
11 |
12 | @staticmethod
13 | def backward(ctx, grad_output):
14 | return grad_output
15 |
16 | class Quantization(nn.Module):
17 | def __init__(self):
18 | super(Quantization, self).__init__()
19 |
20 | def forward(self, input):
21 | return Quant.apply(input)
22 |
--------------------------------------------------------------------------------
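A minimal sketch of what this module does (assuming PyTorch is installed and the working directory is `code/`): inputs are clamped to [0, 1] and rounded to 8-bit levels in the forward pass, while the custom backward passes gradients through unchanged, i.e. a straight-through estimator.

```python
import torch
from models.modules.Quantization import Quantization

quant = Quantization()
x = torch.tensor([[-0.2, 0.4999, 1.3]], requires_grad=True)
y = quant(x)
print(y)          # roughly tensor([[0.0000, 0.4980, 1.0000]]): clamped, then rounded to 1/255 steps
y.sum().backward()
print(x.grad)     # tensor([[1., 1., 1.]]): the gradient is passed through unchanged
```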
/code/models/networks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | from models.modules.Inv_arch import *
5 | from models.modules.Subnet_constructor import subnet
6 |
7 | logger = logging.getLogger('base')
8 |
9 | ####################
10 | # define network
11 | ####################
12 | def define_G_v2(opt):
13 | opt_net = opt['network_G']
14 | which_model = opt_net['which_model_G']
15 | subnet_type = which_model['subnet_type']
16 | opt_datasets = opt['datasets']
17 | down_num = int(math.log(opt_net['scale'], 2))
18 | if opt['num_video'] == 1:
19 | netG = VSN(opt, subnet(subnet_type, 'xavier'), subnet(subnet_type, 'xavier'), down_num)
20 | else:
21 | netG = VSN(opt, subnet(subnet_type, 'xavier'), subnet(subnet_type, 'xavier_v2'), down_num)
22 |
23 | return netG
24 |
--------------------------------------------------------------------------------
/code/data/video_test_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import torch
4 | import torch.utils.data as data
5 | import data.util as util
6 |
7 | class VideoTestDataset(data.Dataset):
8 | """
9 | A video test dataset. Support:
10 | Vid4
11 | REDS4
12 | Vimeo90K-Test
13 |
14 | no need to prepare LMDB files
15 | """
16 |
17 | def __init__(self, opt):
18 | super(VideoTestDataset, self).__init__()
19 | self.opt = opt
20 | self.half_N_frames = opt['N_frames'] // 2
21 | self.data_path = opt['data_path']
22 | self.txt_path = self.opt['txt_path']
23 | self.num_video = self.opt['num_video']
24 | with open(self.txt_path) as f:
25 | self.list_video = f.readlines()
26 | self.list_video = [line.strip('\n') for line in self.list_video]
27 | self.list_video.sort()
28 | self.list_video = self.list_video[:200]
29 | l = len(self.list_video) // (self.num_video + 1)
30 | self.video_list_gt = self.list_video[:l]
31 | self.video_list_lq = self.list_video[l:l * (self.num_video + 1)]
32 |
33 | def __getitem__(self, index):
34 | path_GT = self.video_list_gt[index]
35 |
36 | img_GT = util.read_img_seq(os.path.join(self.data_path, path_GT))
37 | list_h = []
38 | for i in range(self.num_video):
39 | path_LQ = self.video_list_lq[index*self.num_video+i]
40 | imgs_LQ = util.read_img_seq(os.path.join(self.data_path, path_LQ))
41 | list_h.append(imgs_LQ)
42 | list_h = torch.stack(list_h, dim=0)
43 | return {
44 | 'LQ': list_h,
45 | 'GT': img_GT
46 | }
47 |
48 | def __len__(self):
49 | return len(self.video_list_gt)
50 |
--------------------------------------------------------------------------------
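A hedged, pure-Python illustration (with a hypothetical 8-entry list and `num_video = 3`) of how `__init__` partitions the test list: the first `len // (num_video + 1)` entries act as cover (GT) sequences, the rest as secret (LQ) sequences, and item `i` pairs `video_list_gt[i]` with `video_list_lq[i*num_video : (i+1)*num_video]`.

```python
# Hypothetical stand-in for the entries read from txt_path
list_video = sorted('seq_%02d' % i for i in range(8))
num_video = 3

l = len(list_video) // (num_video + 1)              # 8 // 4 = 2
video_list_gt = list_video[:l]                      # ['seq_00', 'seq_01'] -> cover sequences
video_list_lq = list_video[l:l * (num_video + 1)]   # the remaining 6 secret sequences

index = 1
secrets = [video_list_lq[index * num_video + i] for i in range(num_video)]
print(video_list_gt[index], secrets)                # seq_01 ['seq_05', 'seq_06', 'seq_07']
```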
/code/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
7 | phase = dataset_opt['phase']
8 | if phase == 'train':
9 | if opt['dist']:
10 | world_size = torch.distributed.get_world_size()
11 | num_workers = dataset_opt['n_workers']
12 | assert dataset_opt['batch_size'] % world_size == 0
13 | batch_size = dataset_opt['batch_size'] // world_size
14 | shuffle = False
15 | else:
16 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
17 | batch_size = dataset_opt['batch_size']
18 | shuffle = True
19 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
20 | num_workers=num_workers, sampler=sampler, drop_last=True,
21 | pin_memory=False)
22 | else:
23 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
24 | pin_memory=True)
25 |
26 |
27 | def create_dataset(dataset_opt):
28 | mode = dataset_opt['mode']
29 | if mode == 'test':
30 | from data.video_test_dataset import VideoTestDataset as D
31 | elif mode == 'train':
32 | from data.Vimeo90K_dataset import Vimeo90KDataset as D
33 | else:
34 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
35 | print(mode)
36 | dataset = D(dataset_opt)
37 |
38 | logger = logging.getLogger('base')
39 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
40 | dataset_opt['name']))
41 | return dataset
42 |
--------------------------------------------------------------------------------
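A minimal usage sketch of the two factory functions (run from the `code/` directory, with paths pointing to a local Vimeo-90K copy; the field names mirror the `val` block of the training configs below):

```python
from data import create_dataset, create_dataloader

# Hypothetical test-time options; keys follow the 'val' section of the YAML configs.
dataset_opt = {
    'phase': 'val',
    'name': 'Vid4',
    'mode': 'test',
    'data_path': 'vimeo90k/sequences',
    'txt_path': 'vimeo90k/sep_testlist.txt',
    'N_frames': 1,
    'num_video': 3,
}
val_set = create_dataset(dataset_opt)
val_loader = create_dataloader(val_set, dataset_opt)  # non-train phase: batch_size=1, no shuffling
```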
/code/options/train/train_LF-VSN_3video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_3video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 3
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 3
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 3
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 | N_frames: 1
44 | padding: 'new_info'
45 | pred_interval: -1
46 |
47 |
48 | #### network structures
49 |
50 | network_G:
51 | which_model_G:
52 | subnet_type: DBNet
53 | in_nc: 12
54 | out_nc: 12
55 | block_num: [8, 8]
56 | scale: 2
57 | init: xavier_group
58 | block_num_rbm: 8
59 |
60 |
61 | #### path
62 |
63 | path:
64 | pretrain_model_G:
65 | models: ckp/base
66 | strict_load: true
67 | resume_state: ~
68 |
69 |
70 | #### training settings: learning rate scheme, loss
71 |
72 | train:
73 |
74 | lr_G: !!float 1e-4
75 | beta1: 0.9
76 | beta2: 0.5
77 | niter: 250000
78 | warmup_iter: -1 # no warm up
79 |
80 | lr_scheme: MultiStepLR
81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
82 | lr_gamma: 0.5
83 |
84 | pixel_criterion_forw: l2
85 | pixel_criterion_back: l1
86 |
87 | manual_seed: 10
88 |
89 | val_freq: !!float 1000 #!!float 5e3
90 |
91 | lambda_fit_forw: 64.
92 | lambda_rec_back: 1
93 | lambda_center: 0
94 |
95 | weight_decay_G: !!float 1e-12
96 | gradient_clipping: 10
97 |
98 |
99 | #### logger
100 |
101 | logger:
102 | print_freq: 100
103 | save_checkpoint_freq: !!float 5e3
104 |
--------------------------------------------------------------------------------
/code/options/train/train_LF-VSN_4video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_4video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 4
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 4
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 4
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 | N_frames: 1
44 | padding: 'new_info'
45 | pred_interval: -1
46 |
47 |
48 | #### network structures
49 |
50 | network_G:
51 | which_model_G:
52 | subnet_type: DBNet
53 | in_nc: 12
54 | out_nc: 12
55 | block_num: [8, 8]
56 | scale: 2
57 | init: xavier_group
58 | block_num_rbm: 8
59 |
60 |
61 | #### path
62 |
63 | path:
64 | pretrain_model_G:
65 | models: ckp/base
66 | strict_load: true
67 | resume_state: ~
68 |
69 |
70 | #### training settings: learning rate scheme, loss
71 |
72 | train:
73 |
74 | lr_G: !!float 1e-4
75 | beta1: 0.9
76 | beta2: 0.5
77 | niter: 250000
78 | warmup_iter: -1 # no warm up
79 |
80 | lr_scheme: MultiStepLR
81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
82 | lr_gamma: 0.5
83 |
84 | pixel_criterion_forw: l2
85 | pixel_criterion_back: l1
86 |
87 | manual_seed: 10
88 |
89 | val_freq: !!float 1000 #!!float 5e3
90 |
91 | lambda_fit_forw: 64.
92 | lambda_rec_back: 1
93 | lambda_center: 0
94 |
95 | weight_decay_G: !!float 1e-12
96 | gradient_clipping: 10
97 |
98 |
99 | #### logger
100 |
101 | logger:
102 | print_freq: 100
103 | save_checkpoint_freq: !!float 5e3
104 |
--------------------------------------------------------------------------------
/code/options/train/train_LF-VSN_5video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_5video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 5
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 5
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 5
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 | N_frames: 1
44 | padding: 'new_info'
45 | pred_interval: -1
46 |
47 |
48 | #### network structures
49 |
50 | network_G:
51 | which_model_G:
52 | subnet_type: DBNet
53 | in_nc: 12
54 | out_nc: 12
55 | block_num: [8, 8]
56 | scale: 2
57 | init: xavier_group
58 | block_num_rbm: 8
59 |
60 |
61 | #### path
62 |
63 | path:
64 | pretrain_model_G:
65 | models: ckp/base
66 | strict_load: true
67 | resume_state: ~
68 |
69 |
70 | #### training settings: learning rate scheme, loss
71 |
72 | train:
73 |
74 | lr_G: !!float 1e-4
75 | beta1: 0.9
76 | beta2: 0.5
77 | niter: 250000
78 | warmup_iter: -1 # no warm up
79 |
80 | lr_scheme: MultiStepLR
81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
82 | lr_gamma: 0.5
83 |
84 | pixel_criterion_forw: l2
85 | pixel_criterion_back: l1
86 |
87 | manual_seed: 10
88 |
89 | val_freq: !!float 1000 #!!float 5e3
90 |
91 | lambda_fit_forw: 64.
92 | lambda_rec_back: 1
93 | lambda_center: 0
94 |
95 | weight_decay_G: !!float 1e-12
96 | gradient_clipping: 10
97 |
98 |
99 | #### logger
100 |
101 | logger:
102 | print_freq: 100
103 | save_checkpoint_freq: !!float 5e3
104 |
--------------------------------------------------------------------------------
/code/options/train/train_LF-VSN_6video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_6video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 6
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 6
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 6
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 | N_frames: 1
44 | padding: 'new_info'
45 | pred_interval: -1
46 |
47 |
48 | #### network structures
49 |
50 | network_G:
51 | which_model_G:
52 | subnet_type: DBNet
53 | in_nc: 12
54 | out_nc: 12
55 | block_num: [8, 8]
56 | scale: 2
57 | init: xavier_group
58 | block_num_rbm: 8
59 |
60 |
61 | #### path
62 |
63 | path:
64 | pretrain_model_G:
65 | models: ckp/base
66 | strict_load: true
67 | resume_state: ~
68 |
69 |
70 | #### training settings: learning rate scheme, loss
71 |
72 | train:
73 |
74 | lr_G: !!float 1e-4
75 | beta1: 0.9
76 | beta2: 0.5
77 | niter: 250000
78 | warmup_iter: -1 # no warm up
79 |
80 | lr_scheme: MultiStepLR
81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
82 | lr_gamma: 0.5
83 |
84 | pixel_criterion_forw: l2
85 | pixel_criterion_back: l1
86 |
87 | manual_seed: 10
88 |
89 | val_freq: !!float 1000 #!!float 5e3
90 |
91 | lambda_fit_forw: 64.
92 | lambda_rec_back: 1
93 | lambda_center: 0
94 |
95 | weight_decay_G: !!float 1e-12
96 | gradient_clipping: 10
97 |
98 |
99 | #### logger
100 |
101 | logger:
102 | print_freq: 100
103 | save_checkpoint_freq: !!float 5e3
104 |
--------------------------------------------------------------------------------
/code/options/train/train_LF-VSN_7video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_7video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 7
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 7
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 7
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 | N_frames: 1
44 | padding: 'new_info'
45 | pred_interval: -1
46 |
47 |
48 | #### network structures
49 |
50 | network_G:
51 | which_model_G:
52 | subnet_type: DBNet
53 | in_nc: 12
54 | out_nc: 12
55 | block_num: [8, 8]
56 | scale: 2
57 | init: xavier_group
58 | block_num_rbm: 8
59 |
60 |
61 | #### path
62 |
63 | path:
64 | pretrain_model_G:
65 | models: ckp/base
66 | strict_load: true
67 | resume_state: ~
68 |
69 |
70 | #### training settings: learning rate scheme, loss
71 |
72 | train:
73 |
74 | lr_G: !!float 1e-4
75 | beta1: 0.9
76 | beta2: 0.5
77 | niter: 250000
78 | warmup_iter: -1 # no warm up
79 |
80 | lr_scheme: MultiStepLR
81 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
82 | lr_gamma: 0.5
83 |
84 | pixel_criterion_forw: l2
85 | pixel_criterion_back: l1
86 |
87 | manual_seed: 10
88 |
89 | val_freq: !!float 1000 #!!float 5e3
90 |
91 | lambda_fit_forw: 64.
92 | lambda_rec_back: 1
93 | lambda_center: 0
94 |
95 | weight_decay_G: !!float 1e-12
96 | gradient_clipping: 10
97 |
98 |
99 | #### logger
100 |
101 | logger:
102 | print_freq: 100
103 | save_checkpoint_freq: !!float 5e3
104 |
--------------------------------------------------------------------------------
/code/options/train/train_LF-VSN_1video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_1video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 1
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 1
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 1
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 |
44 | N_frames: 1
45 | padding: 'new_info'
46 | pred_interval: -1
47 |
48 |
49 | #### network structures
50 |
51 | network_G:
52 | which_model_G:
53 | subnet_type: DBNet
54 | in_nc: 12
55 | out_nc: 12
56 | block_num: [8, 8]
57 | scale: 2
58 | init: xavier_group
59 | block_num_rbm: 8
60 |
61 |
62 | #### path
63 |
64 | path:
65 | pretrain_model_G:
66 | models: ckp/base
67 | strict_load: true
68 | resume_state: ~
69 |
70 |
71 | #### training settings: learning rate scheme, loss
72 |
73 | train:
74 |
75 | lr_G: !!float 1e-4
76 | beta1: 0.9
77 | beta2: 0.5
78 | niter: 250000
79 | warmup_iter: -1 # no warm up
80 |
81 | lr_scheme: MultiStepLR
82 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
83 | lr_gamma: 0.5
84 |
85 | pixel_criterion_forw: l2
86 | pixel_criterion_back: l1
87 |
88 | manual_seed: 10
89 |
90 | val_freq: !!float 1000 #!!float 5e3
91 |
92 | lambda_fit_forw: 64.
93 | lambda_rec_back: 1
94 | lambda_center: 0
95 |
96 | weight_decay_G: !!float 1e-12
97 | gradient_clipping: 10
98 |
99 |
100 | #### logger
101 |
102 | logger:
103 | print_freq: 100
104 | save_checkpoint_freq: !!float 5e3
105 |
--------------------------------------------------------------------------------
/code/options/train/train_LF-VSN_2video.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 |
3 | name: train_LF-VSN_2video
4 | use_tb_logger: true
5 | model: MIMO-VRN-h
6 | distortion: sr
7 | scale: 4
8 | gpu_ids: [0, 1]
9 | gop: 3
10 | num_video: 2
11 |
12 | #### datasets
13 |
14 | datasets:
15 | train:
16 | name: Vimeo90K
17 | mode: train
18 | interval_list: [1]
19 | random_reverse: false
20 | border_mode: false
21 | data_path: vimeo90k/sequences
22 | txt_path: vimeo90k/sep_trainlist.txt
23 | dataroot_LQ: ~/vimeo90k/vimeo90k_train_LR7frames.lmdb
24 | cache_keys: Vimeo90K_train_keys.pkl
25 | num_video: 2
26 |
27 | N_frames: 7
28 | use_shuffle: true
29 | n_workers: 24 # per GPU
30 | batch_size: 8
31 | GT_size: 144
32 | LQ_size: 36
33 | use_flip: true
34 | use_rot: true
35 | color: RGB
36 |
37 | val:
38 | num_video: 2
39 | name: Vid4
40 | mode: test
41 | data_path: vimeo90k/sequences
42 | txt_path: vimeo90k/sep_testlist.txt
43 |
44 | N_frames: 1
45 | padding: 'new_info'
46 | pred_interval: -1
47 |
48 |
49 | #### network structures
50 |
51 | network_G:
52 | which_model_G:
53 | subnet_type: DBNet
54 | in_nc: 12
55 | out_nc: 12
56 | block_num: [8, 8]
57 | scale: 2
58 | init: xavier_group
59 | block_num_rbm: 8
60 |
61 |
62 | #### path
63 |
64 | path:
65 | pretrain_model_G:
66 | models: ckp/base
67 | strict_load: true
68 | resume_state: ~
69 |
70 |
71 | #### training settings: learning rate scheme, loss
72 |
73 | train:
74 |
75 | lr_G: !!float 1e-4
76 | beta1: 0.9
77 | beta2: 0.5
78 | niter: 250000
79 | warmup_iter: -1 # no warm up
80 |
81 | lr_scheme: MultiStepLR
82 | lr_steps: [30000, 60000, 90000, 150000, 180000, 210000]
83 | lr_gamma: 0.5
84 |
85 | pixel_criterion_forw: l2
86 | pixel_criterion_back: l1
87 |
88 | manual_seed: 10
89 |
90 | val_freq: !!float 1000 #!!float 5e3
91 |
92 | lambda_fit_forw: 64.
93 | lambda_rec_back: 1
94 | lambda_center: 0
95 |
96 | weight_decay_G: !!float 1e-12
97 | gradient_clipping: 10
98 |
99 |
100 | #### logger
101 |
102 | logger:
103 | print_freq: 100
104 | save_checkpoint_freq: !!float 5e3
105 |
--------------------------------------------------------------------------------
/code/models/modules/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | def dwt_init3d(x):
7 |
8 | x01 = x[:, :, :, 0::2, :] / 2
9 | x02 = x[:, :, :, 1::2, :] / 2
10 | x1 = x01[:, :, :, :, 0::2]
11 | x2 = x02[:, :, :, :, 0::2]
12 | x3 = x01[:, :, :, :, 1::2]
13 | x4 = x02[:, :, :, :, 1::2]
14 | x_LL = x1 + x2 + x3 + x4
15 | x_HL = -x1 - x2 + x3 + x4
16 | x_LH = -x1 + x2 - x3 + x4
17 | x_HH = x1 - x2 - x3 + x4
18 |
19 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
20 |
21 | def dwt_init(x):
22 |
23 | x01 = x[:, :, 0::2, :] / 2
24 | x02 = x[:, :, 1::2, :] / 2
25 | x1 = x01[:, :, :, 0::2]
26 | x2 = x02[:, :, :, 0::2]
27 | x3 = x01[:, :, :, 1::2]
28 | x4 = x02[:, :, :, 1::2]
29 | x_LL = x1 + x2 + x3 + x4
30 | x_HL = -x1 - x2 + x3 + x4
31 | x_LH = -x1 + x2 - x3 + x4
32 | x_HH = x1 - x2 - x3 + x4
33 |
34 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
35 |
36 | def iwt_init(x):
37 | r = 2
38 | in_batch, in_channel, in_height, in_width = x.size()
39 | #print([in_batch, in_channel, in_height, in_width])
40 | out_batch, out_channel, out_height, out_width = in_batch, int(
41 | in_channel / (r ** 2)), r * in_height, r * in_width
42 | x1 = x[:, 0:out_channel, :, :] / 2
43 | x2 = x[:, out_channel:out_channel * 2, :, :] / 2
44 | x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
45 | x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
46 |
47 |
48 | h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
49 |
50 | h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
51 | h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
52 | h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
53 | h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
54 |
55 | return h
56 |
57 | class DWT(nn.Module):
58 | def __init__(self):
59 | super(DWT, self).__init__()
60 | self.requires_grad = False
61 |
62 | def forward(self, x):
63 | return dwt_init(x)
64 |
65 | class DWT3d(nn.Module):
66 | def __init__(self):
67 | super(DWT3d, self).__init__()
68 | self.requires_grad = False
69 |
70 | def forward(self, x):
71 | return dwt_init3d(x)
72 |
73 | class IWT(nn.Module):
74 | def __init__(self):
75 | super(IWT, self).__init__()
76 | self.requires_grad = False
77 |
78 | def forward(self, x):
79 | return iwt_init(x)
--------------------------------------------------------------------------------
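A quick sanity check of the wavelet pair (a sketch that assumes a CUDA device, since `iwt_init` allocates its output with `.cuda()`): `DWT` halves the spatial size and stacks the four Haar sub-bands along the channel dimension, and `IWT` inverts it exactly.

```python
import torch
from models.modules.common import DWT, IWT

dwt, iwt = DWT(), IWT()
x = torch.rand(1, 3, 8, 8).cuda()            # iwt_init builds its output on the GPU
y = dwt(x)
print(y.shape)                               # torch.Size([1, 12, 4, 4]): 4 sub-bands x 3 channels
x_rec = iwt(y)
print(torch.allclose(x, x_rec, atol=1e-6))   # True: the transform is invertible
```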
/code/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler
3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the
4 | dataloader after each epoch
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | num_replicas = dist.get_world_size()
35 | if rank is None:
36 | if not dist.is_available():
37 | raise RuntimeError("Requires distributed package to be available")
38 | rank = dist.get_rank()
39 | self.dataset = dataset
40 | self.num_replicas = num_replicas
41 | self.rank = rank
42 | self.epoch = 0
43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
44 | self.total_size = self.num_samples * self.num_replicas
45 |
46 | def __iter__(self):
47 | # deterministically shuffle based on epoch
48 | g = torch.Generator()
49 | g.manual_seed(self.epoch)
50 | indices = torch.randperm(self.total_size, generator=g).tolist()
51 |
52 | dsize = len(self.dataset)
53 | indices = [v % dsize for v in indices]
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.num_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
--------------------------------------------------------------------------------
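A small sketch of what the sampler does (passing `num_replicas` and `rank` explicitly so that no distributed process group is needed): with `ratio=100` each epoch is enlarged to roughly 100x the dataset length, and every index is mapped back into range with a modulo, so iteration-oriented training can run long between dataloader restarts.

```python
from data.data_sampler import DistIterSampler

class DummyDataset:
    def __len__(self):
        return 10

sampler = DistIterSampler(DummyDataset(), num_replicas=2, rank=0, ratio=100)
sampler.set_epoch(0)                          # seeds the per-epoch shuffle
indices = list(iter(sampler))
print(len(indices))                           # 500 = ceil(10 * 100 / 2): this rank's share of the enlarged epoch
print(all(0 <= i < 10 for i in indices))      # True: indices wrap around the real dataset
```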
/code/models/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class ReconstructionLoss(nn.Module):
7 | def __init__(self, losstype='l2', eps=1e-6):
8 | super(ReconstructionLoss, self).__init__()
9 | self.losstype = losstype
10 | self.eps = eps
11 |
12 | def forward(self, x, target):
13 | if self.losstype == 'l2':
14 | return torch.mean(torch.sum((x - target) ** 2, (1, 2, 3)))
15 | elif self.losstype == 'l1':
16 | diff = x - target
17 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3)))
18 | elif self.losstype == 'center':
19 | return torch.sum((x - target) ** 2, (1, 2, 3))
20 |
21 | else:
22 | print("reconstruction loss type error!")
23 | return 0
24 |
25 |
26 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
27 | class GANLoss(nn.Module):
28 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
29 | super(GANLoss, self).__init__()
30 | self.gan_type = gan_type.lower()
31 | self.real_label_val = real_label_val
32 | self.fake_label_val = fake_label_val
33 |
34 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
35 | self.loss = nn.BCEWithLogitsLoss()
36 | elif self.gan_type == 'lsgan':
37 | self.loss = nn.MSELoss()
38 | elif self.gan_type == 'wgan-gp':
39 |
40 | def wgan_loss(input, target):
41 | # target is boolean
42 | return -1 * input.mean() if target else input.mean()
43 |
44 | self.loss = wgan_loss
45 | else:
46 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
47 |
48 | def get_target_label(self, input, target_is_real):
49 | if self.gan_type == 'wgan-gp':
50 | return target_is_real
51 | if target_is_real:
52 | return torch.empty_like(input).fill_(self.real_label_val)
53 | else:
54 | return torch.empty_like(input).fill_(self.fake_label_val)
55 |
56 | def forward(self, input, target_is_real):
57 | target_label = self.get_target_label(input, target_is_real)
58 | loss = self.loss(input, target_label)
59 | return loss
60 |
61 |
62 | class GradientPenaltyLoss(nn.Module):
63 | def __init__(self, device=torch.device('cpu')):
64 | super(GradientPenaltyLoss, self).__init__()
65 | self.register_buffer('grad_outputs', torch.Tensor())
66 | self.grad_outputs = self.grad_outputs.to(device)
67 |
68 | def get_grad_outputs(self, input):
69 | if self.grad_outputs.size() != input.size():
70 | self.grad_outputs.resize_(input.size()).fill_(1.0)
71 | return self.grad_outputs
72 |
73 | def forward(self, interp, interp_crit):
74 | grad_outputs = self.get_grad_outputs(interp_crit)
75 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
76 | grad_outputs=grad_outputs, create_graph=True,
77 | retain_graph=True, only_inputs=True)[0]
78 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
79 | grad_interp_norm = grad_interp.norm(2, dim=1)
80 |
81 | loss = ((grad_interp_norm - 1) ** 2).mean()
82 | return loss
83 |
--------------------------------------------------------------------------------
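A brief sketch of the two pixel criteria referenced by the training configs (`pixel_criterion_forw: l2`, `pixel_criterion_back: l1`): both sum over channels and spatial dimensions and average over the batch, and the 'l1' variant is a Charbonnier-style smooth L1 with `eps` for numerical stability.

```python
import torch
from models.modules.loss import ReconstructionLoss

x = torch.zeros(2, 3, 4, 4)
target = torch.ones(2, 3, 4, 4)

l2 = ReconstructionLoss(losstype='l2')
l1 = ReconstructionLoss(losstype='l1')
print(l2(x, target))   # tensor(48.): squared error summed over 3*4*4, averaged over the batch
print(l1(x, target))   # ~tensor(48.): sqrt(diff^2 + eps) summed, then averaged
```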
/code/models/modules/Subnet_constructor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import models.modules.module_util as mutil
5 | from basicsr.archs.arch_util import flow_warp, ResidualBlockNoBN
6 | from models.modules.module_util import initialize_weights_xavier
7 |
8 | class DenseBlock(nn.Module):
9 | def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
10 | super(DenseBlock, self).__init__()
11 | self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
12 | self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
13 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
14 | self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
15 | self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
16 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
17 | self.H = None
18 |
19 | if init == 'xavier':
20 | mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
21 | else:
22 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
23 | mutil.initialize_weights(self.conv5, 0)
24 |
25 | def forward(self, x):
26 | if isinstance(x, list):
27 | x = x[0]
28 | x1 = self.lrelu(self.conv1(x))
29 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
30 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
31 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
32 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
33 |
34 | return x5
35 |
36 | class DenseBlock_v2(nn.Module):
37 | def __init__(self, channel_in, channel_out, groups, init='xavier', gc=32, bias=True):
38 | super(DenseBlock_v2, self).__init__()
39 | self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
40 | self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
41 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
42 | self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
43 | self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
44 | self.conv_final = nn.Conv2d(channel_out*groups, channel_out, 3, 1, 1, bias=bias)
45 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
46 |
47 | if init == 'xavier':
48 | mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
49 | else:
50 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
51 | mutil.initialize_weights(self.conv_final, 0)
52 |
53 | def forward(self, x):
54 | res = []
55 | for xi in x:
56 | x1 = self.lrelu(self.conv1(xi))
57 | x2 = self.lrelu(self.conv2(torch.cat((xi, x1), 1)))
58 | x3 = self.lrelu(self.conv3(torch.cat((xi, x1, x2), 1)))
59 | x4 = self.lrelu(self.conv4(torch.cat((xi, x1, x2, x3), 1)))
60 | x5 = self.lrelu(self.conv5(torch.cat((xi, x1, x2, x3, x4), 1)))
61 | res.append(x5)
62 | res = torch.cat(res, dim=1)
63 | res = self.conv_final(res)
64 |
65 | return res
66 |
67 | def subnet(net_structure, init='xavier'):
68 | def constructor(channel_in, channel_out, groups=None):
69 | if net_structure == 'DBNet':
70 | if init == 'xavier':
71 | return DenseBlock(channel_in, channel_out, init)
72 | elif init == 'xavier_v2':
73 | return DenseBlock_v2(channel_in, channel_out, groups, 'xavier')
74 | else:
75 | return DenseBlock(channel_in, channel_out)
76 | else:
77 | return None
78 |
79 | return constructor
80 |
--------------------------------------------------------------------------------
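A usage sketch of the constructor pattern (assuming `basicsr` is installed, since it is imported at the top of this file): `subnet(...)` returns a factory that the invertible architecture calls with channel counts, yielding either a plain `DenseBlock` or a `DenseBlock_v2` that processes a list of inputs and fuses one branch per entry.

```python
import torch
from models.modules.Subnet_constructor import subnet

constructor = subnet('DBNet', init='xavier')       # factory used by networks.define_G_v2
block = constructor(12, 12)                        # DenseBlock: 12 -> 12 channels
x = torch.rand(1, 12, 36, 36)
print(block(x).shape)                              # torch.Size([1, 12, 36, 36])

constructor_v2 = subnet('DBNet', init='xavier_v2')
block_v2 = constructor_v2(12, 12, groups=3)        # DenseBlock_v2: one branch per list entry, then fused
print(block_v2([x, x, x]).shape)                   # torch.Size([1, 12, 36, 36])
```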
/code/models/modules/module_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def initialize_weights_xavier(net_l, scale=1):
28 | if not isinstance(net_l, list):
29 | net_l = [net_l]
30 | for net in net_l:
31 | for m in net.modules():
32 | if isinstance(m, nn.Conv2d):
33 | init.xavier_normal_(m.weight)
34 | m.weight.data *= scale # for residual block
35 | if m.bias is not None:
36 | m.bias.data.zero_()
37 | elif isinstance(m, nn.Linear):
38 | init.xavier_normal_(m.weight)
39 | m.weight.data *= scale
40 | if m.bias is not None:
41 | m.bias.data.zero_()
42 | elif isinstance(m, nn.BatchNorm2d):
43 | init.constant_(m.weight, 1)
44 | init.constant_(m.bias.data, 0.0)
45 |
46 |
47 | def make_layer(block, n_layers):
48 | layers = []
49 | for _ in range(n_layers):
50 | layers.append(block())
51 | return nn.Sequential(*layers)
52 |
53 |
54 | class ResidualBlock_noBN(nn.Module):
55 | '''Residual block w/o BN
56 | ---Conv-ReLU-Conv-+-
57 | |________________|
58 | '''
59 |
60 | def __init__(self, nf=64):
61 | super(ResidualBlock_noBN, self).__init__()
62 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
63 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
64 |
65 | # initialization
66 | initialize_weights([self.conv1, self.conv2], 0.1)
67 |
68 | def forward(self, x):
69 | identity = x
70 | out = F.relu(self.conv1(x), inplace=True)
71 | out = self.conv2(out)
72 | return identity + out
73 |
74 |
75 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
76 | """Warp an image or feature map with optical flow
77 | Args:
78 | x (Tensor): size (N, C, H, W)
79 | flow (Tensor): size (N, H, W, 2), normal value
80 | interp_mode (str): 'nearest' or 'bilinear'
81 | padding_mode (str): 'zeros' or 'border' or 'reflection'
82 |
83 | Returns:
84 | Tensor: warped image or feature map
85 | """
86 | assert x.size()[-2:] == flow.size()[1:3]
87 | B, C, H, W = x.size()
88 | # mesh grid
89 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
90 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
91 | grid.requires_grad = False
92 | grid = grid.type_as(x)
93 | vgrid = grid + flow
94 | # scale grid to [-1,1]
95 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
96 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
97 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
98 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
99 | return output
100 |
--------------------------------------------------------------------------------
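A short sketch of the two main helpers in this file: `make_layer` stacks identically-constructed blocks into an `nn.Sequential`, and `flow_warp` resamples a feature map at positions displaced by a per-pixel flow field (the example below only exercises the API and checks shapes).

```python
import functools
import torch
from models.modules.module_util import make_layer, ResidualBlock_noBN, flow_warp

# Stack three 64-channel residual blocks.
trunk = make_layer(functools.partial(ResidualBlock_noBN, nf=64), 3)
feat = torch.rand(1, 64, 32, 32)
print(trunk(feat).shape)                 # torch.Size([1, 64, 32, 32])

# Warp with an all-zero (N, H, W, 2) flow field just to exercise the API.
flow = torch.zeros(1, 32, 32, 2)
print(flow_warp(feat, flow).shape)       # torch.Size([1, 64, 32, 32])
```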
/code/data/Vimeo90K_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Vimeo90K dataset
3 | support reading images from lmdb, image folder and memcached
4 | '''
5 | import logging
6 | import os
7 | import os.path as osp
8 | import pickle
9 | import random
10 |
11 | import cv2
12 | import lmdb
13 | import numpy as np
14 | import torch
15 | import torch.utils.data as data
16 |
17 | import data.util as util
18 |
19 | try:
20 | import mc # import memcached
21 | except ImportError:
22 | pass
23 | logger = logging.getLogger('base')
24 |
25 | class Vimeo90KDataset(data.Dataset):
26 | '''
27 | Reading the training Vimeo90K dataset
28 | key example: 00001_0001 (_1, ..., _7)
29 | GT (Ground-Truth): 4th frame;
30 | LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
31 | '''
32 |
33 | def __init__(self, opt):
34 | super(Vimeo90KDataset, self).__init__()
35 | self.opt = opt
36 | # get train indexes
37 | self.data_path = self.opt['data_path']
38 | self.txt_path = self.opt['txt_path']
39 | with open(self.txt_path) as f:
40 | self.list_video = f.readlines()
41 | self.list_video = [line.strip('\n') for line in self.list_video]
42 | # temporal augmentation
43 | self.interval_list = opt['interval_list']
44 | self.random_reverse = opt['random_reverse']
45 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
46 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
47 | self.data_type = self.opt['data_type']
48 | random.shuffle(self.list_video)
49 | self.LR_input = True
50 | self.num_video = self.opt['num_video']
51 |
52 | def _ensure_memcached(self):
53 | if self.mclient is None:
54 | # specify the config files
55 | server_list_config_file = None
56 | client_config_file = None
57 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
58 | client_config_file)
59 |
60 | def __getitem__(self, index):
61 | GT_size = self.opt['GT_size']
62 | video_name = self.list_video[index]
63 | path_frame = os.path.join(self.data_path, video_name)
64 | frames = []
65 | for im_name in os.listdir(path_frame):
66 | if im_name.endswith('.png'):
67 | frames.append(util.read_img(None, osp.join(path_frame, im_name)))
68 | list_index_h = []
69 | index_h = random.randint(0, len(self.list_video) - 1)
70 | list_index_h.append(index_h)
71 | for _ in range(self.num_video-1):
72 | index_h_i = random.randint(0, len(self.list_video) - 1)
73 | while index_h_i == index or index_h_i in list_index_h:
74 | index_h_i = random.randint(0, len(self.list_video) - 1)
75 | list_index_h.append(index_h_i)
76 |
77 | # random crop
78 | H, W, C = frames[0].shape
79 | rnd_h = random.randint(0, max(0, H - GT_size))
80 | rnd_w = random.randint(0, max(0, W - GT_size))
81 | frames = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in frames]
82 | # stack HQ images to NHWC, N is the frame number
83 | img_frames = np.stack(frames, axis=0)
84 | # BGR to RGB, HWC to CHW, numpy to tensor
85 | img_frames = img_frames[:, :, :, [2, 1, 0]]
86 | img_frames = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames, (0, 3, 1, 2)))).float()
87 | # process h_list
88 | list_h = []
89 | for index_h_i in list_index_h:
90 | video_name_h = self.list_video[index_h_i]
91 | path_frame_h = os.path.join(self.data_path, video_name_h)
92 | frames_h = []
93 | for im_name in os.listdir(path_frame_h):
94 | if im_name.endswith('.png'):
95 | frames_h.append(util.read_img(None, osp.join(path_frame_h, im_name)))
96 | frames_h = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in frames_h]
97 | img_frames_h = np.stack(frames_h, axis=0)
98 | img_frames_h = img_frames_h[:, :, :, [2, 1, 0]]
99 | img_frames_h = torch.from_numpy(np.ascontiguousarray(np.transpose(img_frames_h, (0, 3, 1, 2)))).float()
100 | list_h.append(img_frames_h.clone())
101 | list_h = torch.stack(list_h, dim=0)
102 |
103 | return {'GT': img_frames, 'LQ': list_h}
104 |
105 | def __len__(self):
106 | return len(self.list_video)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Large-capacity and Flexible Video Steganography via Invertible Neural Network (CVPR 2023)
2 | [Chong Mou](https://scholar.google.com.hk/citations?user=SYQoDk0AAAAJ&hl=en), Youmin Xu, Jiechong Song, [Chen Zhao](https://scholar.google.com/citations?hl=zh-CN&user=dUWdX5EAAAAJ), [Bernard Ghanem](https://www.bernardghanem.com/), [Jian Zhang](https://jianzhang.tech/)
3 |
4 | Official implementation of **[Large-capacity and Flexible Video Steganography via Invertible Neural Network](https://arxiv.org/abs/2304.12300)**.
5 |
6 | ## **Introduction**
7 |
8 |
9 |
10 |
11 |
12 | Video steganography is the art of unobtrusively concealing secret data in a cover video and then recovering the secret data through a decoding protocol at the receiver end. Although several attempts have been made, most of them are limited to low-capacity and fixed steganography. To rectify these weaknesses, we propose a **L**arge-capacity and **F**lexible **V**ideo **S**teganography **N**etwork (**LF-VSN**) in this paper. For large-capacity, we present a reversible pipeline to perform multiple videos hiding and recovering through a single invertible neural network (INN). Our method can **hide/recover 7 secret videos in/from 1 cover video** with promising performance. For flexibility, we propose a key-controllable scheme, enabling different receivers to recover particular secret videos from the same cover video through specific keys. Moreover, we further improve the flexibility by proposing a scalable strategy in multiple videos hiding, which can hide variable numbers of secret videos in a cover video with a single model and a single training session. Extensive experiments demonstrate that with the significant improvement of the video steganography performance, our proposed LF-VSN has high security, large hiding capacity, and flexibility.
13 |
14 | ---
15 |
16 |
20 |
21 | ## 🔧 **Dependencies and Installation**
22 | - Python 3.6
23 | - PyTorch >= 1.4.0
24 | - numpy
25 | - scikit-image (skimage)
26 | - opencv-python (cv2)
27 |
28 | ## ⏬ **Download Models**
29 |
30 | The pre-trained models are available at:
31 |
32 |
33 | | Mode | Download link |
34 | | :------------------- | :--------------------------------------------: |
35 | | One video hiding | [Google Drive](https://drive.google.com/file/d/1aEMZaigkMd2NUNXnOu2r0oa5IuLPCtTh/view?usp=share_link) |
36 | | Two video hiding | [Google Drive](https://drive.google.com/file/d/1Yd7tK9Y-J4fkXoL-5u8VifEVsW7OmZN0/view?usp=share_link) |
37 | | Three video hiding | [Google Drive](https://drive.google.com/file/d/1oeDDzkYMZ6tKpPnIUwSI2v_Rbn7vLQJo/view?usp=share_link) |
38 | | Four video hiding | [Google Drive](https://drive.google.com/file/d/1kyMKdfAG_gq6ArWChv6ZMLBsqT-QpS9j/view?usp=share_link) |
39 | | Five video hiding | [Google Drive](https://drive.google.com/file/d/1OlTL6_ZgsThPeYfxbpGrGvNoaisqThq2/view?usp=share_link) |
40 | | Six video hiding | [Google Drive](https://drive.google.com/file/d/1dr-ZIL-VP0ol4fRO7bGZYQoxRetA-GXW/view?usp=share_link) |
41 | | Seven video hiding | [Google Drive](https://drive.google.com/file/d/178cqpz_vS-mPlYwLuZP2qFc7pV7vrXrr/view?usp=share_link) |
42 |
43 | ## **Data Preparation**
44 | Please download the training and evaluation data from [Vimeo-90K](http://toflow.csail.mit.edu/).
45 |
46 | ## **Train**
47 |
48 | Train the desired model by specifying the corresponding config file:
49 |
50 | ```bash
51 | python train.py -opt options/train/train_LF-VSN_1video.yml
52 | ```
53 |
54 | ## **Test**
55 |
56 | Test the desired model by specifying the corresponding config file:
57 |
58 | ```bash
59 | python test.py -opt options/train/train_LF-VSN_1video.yml
60 | ```
61 |
62 | ## **Qualitative Results**
63 |
64 |
65 |
66 |
67 | ## 🤗 **Acknowledgements**
68 | This code is built on [MIMO-VRN (PyTorch)](https://github.com/ding3820/MIMO-VRN). We thank the authors for sharing the code of MIMO-VRN.
69 |
70 | ## :e-mail: Contact
71 |
72 | If you have any questions, please email `eechongm@gmail.com`.
73 |
74 | ## **Citation**
75 |
76 | If you find our work helpful in your research, please cite the following paper.
77 | ```
78 | @inproceedings{mou2023lfvsn,
79 | title={Large-capacity and Flexible Video Steganography via Invertible Neural Network},
80 | author={Mou, Chong and Xu, Youmin and Song, Jiechong and Zhao, Chen and Ghanem, Bernard and Zhang, Jian},
81 | booktitle={CVPR},
82 | year={2023}
83 | }
84 | ```
85 |
--------------------------------------------------------------------------------
/code/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | if opt['distortion'] == 'sr':
19 | scale = opt['scale']
20 |
21 | # datasets
22 | for phase, dataset in opt['datasets'].items():
23 | phase = phase.split('_')[0]
24 | dataset['phase'] = phase
25 | if opt['distortion'] == 'sr':
26 | dataset['scale'] = scale
27 | is_lmdb = False
28 | if dataset.get('dataroot_GT', None) is not None:
29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
30 | if dataset['dataroot_GT'].endswith('lmdb'):
31 | is_lmdb = True
32 | # if dataset.get('dataroot_GT_bg', None) is not None:
33 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg'])
34 | if dataset.get('dataroot_LQ', None) is not None:
35 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
36 | if dataset['dataroot_LQ'].endswith('lmdb'):
37 | is_lmdb = True
38 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
39 | if dataset['mode'].endswith('mc'): # for memcached
40 | dataset['data_type'] = 'mc'
41 | dataset['mode'] = dataset['mode'].replace('_mc', '')
42 |
43 | # path
44 | for key, path in opt['path'].items():
45 | if path and key in opt['path'] and key != 'strict_load':
46 | opt['path'][key] = osp.expanduser(path)
47 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
48 | if is_train:
49 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
50 | opt['path']['experiments_root'] = experiments_root
51 | opt['path']['models'] = osp.join(experiments_root, 'models')
52 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
53 | opt['path']['log'] = experiments_root
54 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
55 |
56 | # change some options for debug mode
57 | if 'debug' in opt['name']:
58 | opt['train']['val_freq'] = 8
59 | opt['logger']['print_freq'] = 1
60 | opt['logger']['save_checkpoint_freq'] = 8
61 | else: # test
62 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
63 | opt['path']['results_root'] = results_root
64 | opt['path']['log'] = results_root
65 |
66 | # network
67 | if opt['distortion'] == 'sr':
68 | opt['network_G']['scale'] = scale
69 |
70 | return opt
71 |
72 |
73 | def dict2str(opt, indent_l=1):
74 | '''dict to string for logger'''
75 | msg = ''
76 | for k, v in opt.items():
77 | if isinstance(v, dict):
78 | msg += ' ' * (indent_l * 2) + k + ':[\n'
79 | msg += dict2str(v, indent_l + 1)
80 | msg += ' ' * (indent_l * 2) + ']\n'
81 | else:
82 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
83 | return msg
84 |
85 |
86 | class NoneDict(dict):
87 | def __missing__(self, key):
88 | return None
89 |
90 |
91 | # convert to NoneDict, which return None for missing key.
92 | def dict_to_nonedict(opt):
93 | if isinstance(opt, dict):
94 | new_opt = dict()
95 | for key, sub_opt in opt.items():
96 | new_opt[key] = dict_to_nonedict(sub_opt)
97 | return NoneDict(**new_opt)
98 | elif isinstance(opt, list):
99 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
100 | else:
101 | return opt
102 |
103 |
104 | def check_resume(opt, resume_iter):
105 | '''Check resume states and pretrain_model paths'''
106 | logger = logging.getLogger('base')
107 | if opt['path']['resume_state']:
108 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
109 | 'pretrain_model_D', None) is not None:
110 | logger.warning('pretrain_model path will be ignored when resuming training.')
111 |
112 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
113 | '{}_G.pth'.format(resume_iter))
114 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
115 |
--------------------------------------------------------------------------------
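A small sketch of how a config is typically loaded (run from the `code/` directory; the YAML path is one of the files above, and `dict_to_nonedict` makes missing keys read back as `None` instead of raising):

```python
import options.options as option

opt = option.parse('options/train/train_LF-VSN_3video.yml', is_train=True)
opt = option.dict_to_nonedict(opt)

print(opt['num_video'])             # 3
print(opt['network_G']['scale'])    # 4: overridden from the top-level 'scale' because distortion == 'sr'
print(opt['some_missing_key'])      # None, thanks to NoneDict
```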
/code/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12 | self.is_train = opt['is_train']
13 | self.schedulers = []
14 | self.optimizers = []
15 |
16 | def feed_data(self, data):
17 | pass
18 |
19 | def optimize_parameters(self):
20 | pass
21 |
22 | def get_current_visuals(self):
23 | pass
24 |
25 | def get_current_losses(self):
26 | pass
27 |
28 | def print_network(self):
29 | pass
30 |
31 | def save(self, label):
32 | pass
33 |
34 | def load(self):
35 | pass
36 |
37 | def _set_lr(self, lr_groups_l):
38 | ''' set learning rate for warmup,
39 | lr_groups_l: list for lr_groups. each for a optimizer'''
40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
42 | param_group['lr'] = lr
43 |
44 | def _get_init_lr(self):
45 | # get the initial lr, which is set by the scheduler
46 | init_lr_groups_l = []
47 | for optimizer in self.optimizers:
48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49 | return init_lr_groups_l
50 |
51 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
52 | for scheduler in self.schedulers:
53 | scheduler.step()
54 | #### set up warm up learning rate
55 | if cur_iter < warmup_iter:
56 | # get initial lr for each group
57 | init_lr_g_l = self._get_init_lr()
58 | # modify warming-up learning rates
59 | warm_up_lr_l = []
60 | for init_lr_g in init_lr_g_l:
61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62 | # set learning rate
63 | self._set_lr(warm_up_lr_l)
64 |
65 | def get_current_learning_rate(self):
66 | # return self.schedulers[0].get_lr()[0]
67 | return self.optimizers[0].param_groups[0]['lr']
68 |
69 | def get_network_description(self, network):
70 | '''Get the string and total parameters of the network'''
71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
72 | network = network.module
73 | s = str(network)
74 | n = sum(map(lambda x: x.numel(), network.parameters()))
75 | return s, n
76 |
77 | def save_network(self, network, network_label, iter_label):
78 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
79 | save_path = os.path.join(self.opt['path']['models'], save_filename)
80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
81 | network = network.module
82 | state_dict = network.state_dict()
83 | for key, param in state_dict.items():
84 | state_dict[key] = param.cpu()
85 | torch.save(state_dict, save_path)
86 |
87 | def load_network(self, load_path, network, strict=True):
88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
89 | network = network.module
90 | load_net = torch.load(load_path)
91 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
92 | for k, v in load_net.items():
93 | if k.startswith('module.'):
94 | load_net_clean[k[7:]] = v
95 | else:
96 | load_net_clean[k] = v
97 | network.load_state_dict(load_net_clean, strict=strict)
98 |
99 | def save_training_state(self, epoch, iter_step):
100 | '''Saves training state during training, which will be used for resuming'''
101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
102 | for s in self.schedulers:
103 | state['schedulers'].append(s.state_dict())
104 | for o in self.optimizers:
105 | state['optimizers'].append(o.state_dict())
106 | save_filename = '{}.state'.format(iter_step)
107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename)
108 | torch.save(state, save_path)
109 |
110 | def resume_training(self, resume_state):
111 | '''Resume the optimizers and schedulers for training'''
112 | resume_optimizers = resume_state['optimizers']
113 | resume_schedulers = resume_state['schedulers']
114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
116 | for i, o in enumerate(resume_optimizers):
117 | self.optimizers[i].load_state_dict(o)
118 | for i, s in enumerate(resume_schedulers):
119 | self.schedulers[i].load_state_dict(s)
120 |
--------------------------------------------------------------------------------
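A tiny sketch of the warm-up behaviour in `update_learning_rate` (using a dummy optimizer; `initial_lr` is filled in by hand here because a scheduler would normally record it): before `warmup_iter` is reached, the learning rate is scaled linearly with the current iteration.

```python
import torch
from models.base_model import BaseModel

model = BaseModel({'gpu_ids': None, 'is_train': True})
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
optimizer.param_groups[0]['initial_lr'] = 1e-4     # normally set by the LR scheduler
model.optimizers.append(optimizer)

model.update_learning_rate(cur_iter=25, warmup_iter=100)
print(model.get_current_learning_rate())           # 2.5e-05 = 1e-4 * 25 / 100
```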
/code/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restart_weights = weights if weights else [1]
16 | assert len(self.restarts) == len(
17 | self.restart_weights), 'restarts and their weights do not match.'
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
19 |
20 | def get_lr(self):
21 | if self.last_epoch in self.restarts:
22 | if self.clear_state:
23 | self.optimizer.state = defaultdict(dict)
24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
26 | if self.last_epoch not in self.milestones:
27 | return [group['lr'] for group in self.optimizer.param_groups]
28 | return [
29 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
30 | for group in self.optimizer.param_groups
31 | ]
32 |
33 |
34 | class CosineAnnealingLR_Restart(_LRScheduler):
35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
36 | self.T_period = T_period
37 | self.T_max = self.T_period[0] # current T period
38 | self.eta_min = eta_min
39 | self.restarts = restarts if restarts else [0]
40 | self.restart_weights = weights if weights else [1]
41 | self.last_restart = 0
42 | assert len(self.restarts) == len(
43 | self.restart_weights), 'restarts and their weights do not match.'
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
45 |
46 | def get_lr(self):
47 | if self.last_epoch == 0:
48 | return self.base_lrs
49 | elif self.last_epoch in self.restarts:
50 | self.last_restart = self.last_epoch
51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
55 | return [
56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
58 | ]
59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
61 | (group['lr'] - self.eta_min) + self.eta_min
62 | for group in self.optimizer.param_groups]
63 |
64 |
65 | if __name__ == "__main__":
66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
67 | betas=(0.9, 0.99))
68 | ##############################
69 | # MultiStepLR_Restart
70 | ##############################
71 | ## Original
72 | lr_steps = [200000, 400000, 600000, 800000]
73 | restarts = None
74 | restart_weights = None
75 |
76 | ## two
77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
78 | restarts = [500000]
79 | restart_weights = [1]
80 |
81 | ## four
82 | lr_steps = [
83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
85 | ]
86 | restarts = [250000, 500000, 750000]
87 | restart_weights = [1, 1, 1]
88 |
89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
90 | clear_state=False)
91 |
92 | ##############################
93 | # Cosine Annealing Restart
94 | ##############################
95 | ## two
96 | T_period = [500000, 500000]
97 | restarts = [500000]
98 | restart_weights = [1]
99 |
100 | ## four
101 | T_period = [250000, 250000, 250000, 250000]
102 | restarts = [250000, 500000, 750000]
103 | restart_weights = [1, 1, 1]
104 |
105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
106 | weights=restart_weights)
107 |
108 | ##############################
109 | # Draw figure
110 | ##############################
111 | N_iter = 1000000
112 | lr_l = list(range(N_iter))
113 | for i in range(N_iter):
114 | scheduler.step()
115 | current_lr = optimizer.param_groups[0]['lr']
116 | lr_l[i] = current_lr
117 |
118 | import matplotlib as mpl
119 | from matplotlib import pyplot as plt
120 | import matplotlib.ticker as mtick
121 | mpl.style.use('default')
122 | import seaborn
123 | seaborn.set(style='whitegrid')
124 | seaborn.set_context('paper')
125 |
126 | plt.figure(1)
127 | plt.subplot(111)
128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
129 | plt.title('Title', fontsize=16, color='k')
130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
131 | legend = plt.legend(loc='upper right', shadow=False)
132 | ax = plt.gca()
133 | labels = ax.get_xticks().tolist()
134 | for k, v in enumerate(labels):
135 | labels[k] = str(int(v / 1000)) + 'K'
136 | ax.set_xticklabels(labels)
137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
138 |
139 | ax.set_ylabel('Learning rate')
140 | ax.set_xlabel('Iteration')
141 | fig = plt.gcf()
142 | plt.show()
143 |
--------------------------------------------------------------------------------
/code/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 | from data.data_sampler import DistIterSampler
11 |
12 | import options.options as option
13 | from utils import util
14 | from data import create_dataloader, create_dataset
15 | from models import create_model
16 |
17 |
18 | def init_dist(backend='nccl', **kwargs):
19 | ''' initialization for distributed training'''
20 | # if mp.get_start_method(allow_none=True) is None:
21 | if mp.get_start_method(allow_none=True) != 'spawn':
22 | mp.set_start_method('spawn')
23 | rank = int(os.environ['RANK'])
24 | num_gpus = torch.cuda.device_count()
25 | torch.cuda.set_device(rank % num_gpus)
26 | dist.init_process_group(backend=backend, **kwargs)
27 |
28 | def cal_pnsr(sr_img, gt_img):
29 | # calculate PSNR
30 | gt_img = gt_img / 255.
31 | sr_img = sr_img / 255.
32 |
33 | psnr = util.calculate_psnr(sr_img * 255, gt_img * 255)
34 |
35 | return psnr
36 |
37 | def main():
38 | # options
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument('-opt', type=str, help='Path to option YMAL file.') # config 文件
41 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
42 | help='job launcher')
43 | parser.add_argument('--ckpt', type=str, default='/group/30042/chongmou/ft_local/LF-VSN-git/LF-VSN/ckpt/LF-VSN_2video_hiding_250k.pth', help='Path to pre-trained model.')
44 | parser.add_argument('--local_rank', type=int, default=0)
45 | args = parser.parse_args()
46 | opt = option.parse(args.opt, is_train=True)
47 |
48 | # distributed training settings
49 | if args.launcher == 'none': # disabled distributed training
50 | opt['dist'] = False
51 | rank = -1
52 | print('Disabled distributed training.')
53 | else:
54 | opt['dist'] = True
55 | init_dist()
56 | world_size = torch.distributed.get_world_size()
57 | rank = torch.distributed.get_rank()
58 |
59 | # loading resume state if exists
60 | if opt['path'].get('resume_state', None):
61 | # distributed resuming: all load into default GPU
62 | device_id = torch.cuda.current_device()
63 | resume_state = torch.load(opt['path']['resume_state'],
64 | map_location=lambda storage, loc: storage.cuda(device_id))
65 | option.check_resume(opt, resume_state['iter']) # check resume options
66 | else:
67 | resume_state = None
68 |
69 | # convert to NoneDict, which returns None for missing keys
70 | opt = option.dict_to_nonedict(opt)
71 |
72 | torch.backends.cudnn.benchmark = True
73 | # torch.backends.cudnn.deterministic = True
74 |
75 | #### create train and val dataloader
76 | dataset_ratio = 200 # enlarge the size of each epoch
77 | for phase, dataset_opt in opt['datasets'].items():
78 | if phase == 'train':
79 | train_set = create_dataset(dataset_opt)
80 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
81 | total_iters = int(opt['train']['niter'])
82 | total_epochs = int(math.ceil(total_iters / train_size))
83 | if opt['dist']:
84 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
85 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
86 | else:
87 | train_sampler = None
88 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
89 | elif phase == 'val':
90 | val_set = create_dataset(dataset_opt)
91 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
92 | else:
93 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
94 | assert train_loader is not None
95 |
96 | # create model
97 | model = create_model(opt)
98 | model.load_test(args.ckpt)
99 |
100 | # validation
101 | avg_psnr = 0.0
102 | avg_psnr_h = [0.0]*opt['num_video']
103 | avg_psnr_lr = 0.0
104 | idx = 0
105 | for video_id, val_data in enumerate(val_loader):
106 | img_dir = os.path.join('results',opt['name'])
107 | util.mkdir(img_dir)
108 |
109 | model.feed_data(val_data)
110 | model.test()
111 |
112 | visuals = model.get_current_visuals()
113 |
114 | t_step = visuals['SR'].shape[0]
115 | idx += t_step
116 | n = len(visuals['SR_h'])
117 |
118 | for i in range(t_step):
119 |
120 | sr_img = util.tensor2img(visuals['SR'][i]) # uint8
121 | sr_img_h = []
122 | for j in range(n):
123 | sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i])) # uint8
124 | gt_img = util.tensor2img(visuals['GT'][i]) # uint8
125 | lr_img = util.tensor2img(visuals['LR'][i])
126 | lrgt_img = []
127 | for j in range(n):
128 | lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))
129 |
130 | # Save SR images for reference
131 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'SR'))
132 | util.save_img(sr_img, save_img_path)
133 |
134 | for j in range(n):
135 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'SR_h'))
136 | util.save_img(sr_img_h[j], save_img_path)
137 |
138 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'GT'))
139 | util.save_img(gt_img, save_img_path)
140 |
141 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'LR'))
142 | util.save_img(lr_img, save_img_path)
143 |
144 | for j in range(n):
145 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'LRGT'))
146 | util.save_img(lrgt_img[j], save_img_path)
147 |
148 | psnr = cal_pnsr(sr_img, gt_img)
149 | psnr_h = []
150 | for j in range(n):
151 | psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j]))
152 | psnr_lr = cal_pnsr(lr_img, gt_img)
153 |
154 | avg_psnr += psnr
155 | for j in range(n):
156 | avg_psnr_h[j] += psnr_h[j]
157 | avg_psnr_lr += psnr_lr
158 |
159 | avg_psnr = avg_psnr / idx
160 | avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
161 | avg_psnr_lr = avg_psnr_lr / idx
162 | res_psnr_h = ''
163 | for p in avg_psnr_h:
164 | res_psnr_h+=('_{:.4e}'.format(p))
165 | print('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr))
166 |
167 |
168 | if __name__ == '__main__':
169 | main()
--------------------------------------------------------------------------------
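Note on running test.py (not part of the repository): a hypothetical invocation, assuming
the current directory is code/, one of the provided option files, and a locally downloaded
checkpoint; both paths must be adapted to your setup.

    python test.py -opt options/train/train_LF-VSN_2video.yml --ckpt /path/to/LF-VSN_2video_hiding.pth

The script builds the dataloaders described in the option file, restores the weights via
model.load_test(), and prints the averaged PSNR of the cover, secret, and stego frames.

--------------------------------------------------------------------------------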
/code/models/discrim.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 | from torch.nn.utils import spectral_norm
4 |
5 |
6 | class UNetDiscriminatorSN(nn.Module):
7 | """Defines a U-Net discriminator with spectral normalization (SN)
8 |
9 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
10 |
11 |     Args:
12 |         num_in_ch (int): Channel number of inputs. Default: 3.
13 |         num_feat (int): Channel number of base intermediate features. Default: 64.
14 |         skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
15 | """
16 |
17 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
18 | super(UNetDiscriminatorSN, self).__init__()
19 | self.skip_connection = skip_connection
20 | norm = spectral_norm
21 | # the first convolution
22 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
23 | # downsample
24 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
25 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
26 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
27 | # upsample
28 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
29 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
30 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
31 | # extra convolutions
32 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
33 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
34 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
35 |
36 | def forward(self, x):
37 | # downsample
38 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
39 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
40 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
41 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
42 |
43 | # upsample
44 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
45 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
46 |
47 | if self.skip_connection:
48 | x4 = x4 + x2
49 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
50 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
51 |
52 | if self.skip_connection:
53 | x5 = x5 + x1
54 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
55 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
56 |
57 | if self.skip_connection:
58 | x6 = x6 + x0
59 |
60 | # extra convolutions
61 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
62 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
63 | out = self.conv9(out)
64 |
65 | return out
66 |
67 |
68 | class GANLoss(nn.Module):
69 | """Define GAN loss.
70 |
71 | Args:
72 |         gan_type (str): Supported types: 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
73 | real_label_val (float): The value for real label. Default: 1.0.
74 | fake_label_val (float): The value for fake label. Default: 0.0.
75 | loss_weight (float): Loss weight. Default: 1.0.
76 | Note that loss_weight is only for generators; and it is always 1.0
77 | for discriminators.
78 | """
79 |
80 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
81 | super(GANLoss, self).__init__()
82 | self.gan_type = gan_type
83 | self.loss_weight = loss_weight
84 | self.real_label_val = real_label_val
85 | self.fake_label_val = fake_label_val
86 |
87 | if self.gan_type == 'vanilla':
88 | self.loss = nn.BCEWithLogitsLoss()
89 | elif self.gan_type == 'lsgan':
90 | self.loss = nn.MSELoss()
91 | elif self.gan_type == 'wgan':
92 | self.loss = self._wgan_loss
93 | elif self.gan_type == 'wgan_softplus':
94 | self.loss = self._wgan_softplus_loss
95 | elif self.gan_type == 'hinge':
96 | self.loss = nn.ReLU()
97 | else:
98 | raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
99 |
100 | def _wgan_loss(self, input, target):
101 | """wgan loss.
102 |
103 | Args:
104 | input (Tensor): Input tensor.
105 | target (bool): Target label.
106 |
107 | Returns:
108 | Tensor: wgan loss.
109 | """
110 | return -input.mean() if target else input.mean()
111 |
112 | def _wgan_softplus_loss(self, input, target):
113 | """wgan loss with soft plus. softplus is a smooth approximation to the
114 | ReLU function.
115 |
116 | In StyleGAN2, it is called:
117 | Logistic loss for discriminator;
118 | Non-saturating loss for generator.
119 |
120 | Args:
121 | input (Tensor): Input tensor.
122 | target (bool): Target label.
123 |
124 | Returns:
125 | Tensor: wgan loss.
126 | """
127 | return F.softplus(-input).mean() if target else F.softplus(input).mean()
128 |
129 | def get_target_label(self, input, target_is_real):
130 | """Get target label.
131 |
132 | Args:
133 | input (Tensor): Input tensor.
134 | target_is_real (bool): Whether the target is real or fake.
135 |
136 | Returns:
137 | (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
138 | return Tensor.
139 | """
140 |
141 | if self.gan_type in ['wgan', 'wgan_softplus']:
142 | return target_is_real
143 | target_val = (self.real_label_val if target_is_real else self.fake_label_val)
144 | return input.new_ones(input.size()) * target_val
145 |
146 | def forward(self, input, target_is_real, is_disc=False):
147 | """
148 | Args:
149 | input (Tensor): The input for the loss module, i.e., the network
150 | prediction.
151 |             target_is_real (bool): Whether the target is real or fake.
152 |             is_disc (bool): Whether the loss is for a discriminator or not.
153 | Default: False.
154 |
155 | Returns:
156 | Tensor: GAN loss value.
157 | """
158 | target_label = self.get_target_label(input, target_is_real)
159 | if self.gan_type == 'hinge':
160 | if is_disc: # for discriminators in hinge-gan
161 | input = -input if target_is_real else input
162 | loss = self.loss(1 + input).mean()
163 | else: # for generators in hinge-gan
164 | loss = -input.mean()
165 | else: # other gan types
166 | loss = self.loss(input, target_label)
167 |
168 | # loss_weight is always 1.0 for discriminators
169 | return loss if is_disc else loss * self.loss_weight
--------------------------------------------------------------------------------
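Note (not part of the repository): UNetDiscriminatorSN and GANLoss do not appear to be
referenced by the training code included in this dump; the minimal sketch below only
illustrates how the two classes compose for one adversarial step. Tensor shapes and the
`models.discrim` import path (run from code/) are assumptions.

    import torch
    from models.discrim import UNetDiscriminatorSN, GANLoss

    disc = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
    cri_gan = GANLoss('lsgan', real_label_val=1.0, fake_label_val=0.0, loss_weight=5e-3)

    real = torch.rand(1, 3, 64, 64)   # e.g. a cover frame
    fake = torch.rand(1, 3, 64, 64)   # e.g. a generated container frame

    # generator update: treat the fake as real; loss_weight is applied (is_disc=False)
    l_g_gan = cri_gan(disc(fake), True, is_disc=False)

    # discriminator update: loss_weight is ignored (is_disc=True)
    l_d = cri_gan(disc(real), True, is_disc=True) \
        + cri_gan(disc(fake.detach()), False, is_disc=True)

--------------------------------------------------------------------------------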
/code/models/modules/Inv_arch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .module_util import initialize_weights_xavier
7 | from torch.nn import init
8 | from .common import DWT,IWT
9 | import cv2
10 | from basicsr.archs.arch_util import flow_warp
11 | from models.modules.Subnet_constructor import subnet
12 | import numpy as np
13 |
14 | dwt=DWT()
15 | iwt=IWT()
16 |
17 | def thops_mean(tensor, dim=None, keepdim=False):
18 | if dim is None:
19 | # mean all dim
20 | return torch.mean(tensor)
21 | else:
22 | if isinstance(dim, int):
23 | dim = [dim]
24 | dim = sorted(dim)
25 | for d in dim:
26 | tensor = tensor.mean(dim=d, keepdim=True)
27 | if not keepdim:
28 | for i, d in enumerate(dim):
29 | tensor.squeeze_(d-i)
30 | return tensor
31 |
32 |
33 | class ResidualBlockNoBN(nn.Module):
34 | def __init__(self, nf=64, model='MIMO-VRN'):
35 | super(ResidualBlockNoBN, self).__init__()
36 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
37 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
38 | # honestly, there's no significant difference between ReLU and leaky ReLU in terms of performance here
39 | # but this is how we trained the model in the first place and what we reported in the paper
40 | if model == 'LSTM-VRN':
41 | self.relu = nn.ReLU(inplace=True)
42 | elif model == 'MIMO-VRN':
43 | self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
44 |
45 | # initialization
46 | initialize_weights_xavier([self.conv1, self.conv2], 0.1)
47 |
48 | def forward(self, x):
49 | identity = x
50 | out = self.relu(self.conv1(x))
51 | out = self.conv2(out)
52 | return identity + out
53 |
54 |
55 | class InvBlock(nn.Module):
56 | def __init__(self, subnet_constructor, subnet_constructor_v2, channel_num_ho, channel_num_hi, groups, clamp=1.):
57 | super(InvBlock, self).__init__()
58 | self.split_len1 = channel_num_ho # channel_split_num
59 | self.split_len2 = channel_num_hi # channel_num - channel_split_num
60 | self.clamp = clamp
61 |
62 | self.F = subnet_constructor_v2(self.split_len2, self.split_len1, groups=groups)
63 | if groups == 1:
64 | self.G = subnet_constructor(self.split_len1, self.split_len2, groups=groups)
65 | self.H = subnet_constructor(self.split_len1, self.split_len2, groups=groups)
66 | else:
67 | self.G = subnet_constructor(self.split_len1, self.split_len2)
68 | self.H = subnet_constructor(self.split_len1, self.split_len2)
69 |
70 | def forward(self, x1, x2, rev=False):
71 | if not rev:
72 | y1 = x1 + self.F(x2)
73 | self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
74 | y2 = [x2i.mul(torch.exp(self.s)) + self.G(y1) for x2i in x2]
75 | else:
76 | self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
77 | # print(x2[0].shape, self.G(x1).shape)
78 | y2 = [(x2i - self.G(x1)).div(torch.exp(self.s)) for x2i in x2]
79 | y1 = x1 - self.F(y2)
80 |
81 | return y1, y2 # torch.cat((y1, y2), 1)
82 |
83 | def jacobian(self, x, rev=False):
84 | if not rev:
85 | jac = torch.sum(self.s)
86 | else:
87 | jac = -torch.sum(self.s)
88 |
89 | return jac / x.shape[0]
90 |
91 | class InvNN(nn.Module):
92 | def __init__(self, channel_in_ho=3, channel_in_hi=3, subnet_constructor=None, subnet_constructor_v2=None, block_num=[], down_num=2, groups=None):
93 | super(InvNN, self).__init__()
94 | operations = []
95 | # current_channel = channel_in
96 | current_channel_ho = channel_in_ho
97 | current_channel_hi = channel_in_hi
98 | for i in range(down_num):
99 | for j in range(block_num[i]):
100 | b = InvBlock(subnet_constructor, subnet_constructor_v2, current_channel_ho, current_channel_hi, groups=groups)
101 | operations.append(b)
102 |
103 | self.operations = nn.ModuleList(operations)
104 |
105 | def forward(self, x, x_h, rev=False, cal_jacobian=False):
106 | # out = x
107 | jacobian = 0
108 |
109 | if not rev:
110 | for op in self.operations:
111 | x, x_h = op.forward(x, x_h, rev)
112 | if cal_jacobian:
113 | jacobian += op.jacobian(x, rev)
114 | else:
115 | for op in reversed(self.operations):
116 | x, x_h = op.forward(x, x_h, rev)
117 | if cal_jacobian:
118 | jacobian += op.jacobian(x, rev)
119 |
120 | if cal_jacobian:
121 | return x, x_h, jacobian
122 | else:
123 | return x, x_h
124 |
125 | class PredictiveModuleMIMO(nn.Module):
126 | def __init__(self, channel_in, nf, block_num_rbm=8):
127 | super(PredictiveModuleMIMO, self).__init__()
128 | self.conv_in = nn.Conv2d(channel_in, nf, 3, 1, 1, bias=True)
129 | residual_block = []
130 | for i in range(block_num_rbm):
131 | residual_block.append(ResidualBlockNoBN(nf))
132 | self.residual_block = nn.Sequential(*residual_block)
133 |
134 | def forward(self, x):
135 | x = self.conv_in(x)
136 | res = self.residual_block(x)
137 |
138 | return res
139 |
140 | def gauss_noise(shape):
141 | noise = torch.zeros(shape).cuda()
142 | for i in range(noise.shape[0]):
143 | noise[i] = torch.randn(noise[i].shape).cuda()
144 |
145 | return noise
146 |
147 | def gauss_noise_mul(shape):
148 | noise = torch.randn(shape).cuda()
149 |
150 | return noise
151 |
152 | class VSN(nn.Module):
153 | def __init__(self, opt, subnet_constructor=None, subnet_constructor_v2=None, down_num=2):
154 | super(VSN, self).__init__()
155 | self.model = opt['model']
156 | opt_net = opt['network_G']
157 | self.num_video = opt['num_video']
158 | self.gop = opt['gop']
159 | self.channel_in = opt_net['in_nc'] * self.gop
160 | self.channel_out = opt_net['out_nc'] * self.gop
161 | self.channel_in_hi = opt_net['in_nc'] * self.gop
162 | self.channel_in_ho = opt_net['in_nc'] * self.gop
163 |
164 | self.block_num = opt_net['block_num']
165 | self.block_num_rbm = opt_net['block_num_rbm']
166 | self.nf = self.channel_in_hi
167 | self.irn = InvNN(self.channel_in_ho, self.channel_in_hi, subnet_constructor, subnet_constructor_v2, self.block_num, down_num, groups=self.num_video)
168 | self.pm = PredictiveModuleMIMO(self.channel_in_ho, self.nf* self.num_video, block_num_rbm=self.block_num_rbm)
169 |
170 | def forward(self, x, x_h=None, rev=False, hs=[], direction='f'):
171 | if not rev:
172 | out_y, out_y_h = self.irn(x, x_h, rev)
173 | return out_y, out_y_h
174 | else:
175 | out_z = self.pm(x).unsqueeze(1)
176 | out_z_new = out_z.view(-1, self.num_video, self.channel_in, x.shape[-2], x.shape[-1])
177 | out_z_new = [out_z_new[:,i] for i in range(self.num_video)]
178 | out_x, out_x_h = self.irn(x, out_z_new, rev)
179 |
180 | return out_x, out_x_h, out_z
181 |
--------------------------------------------------------------------------------
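Note on the coupling layer in Inv_arch.py (added for readability; this restates the code of
InvBlock.forward above, where x2/y2 are lists with one tensor per hidden video, sigmoid is
applied elementwise, and clamp defaults to 1):

    forward (rev=False):  y1 = x1 + F(x2)
                          s  = clamp * (2 * sigmoid(H(y1)) - 1)
                          y2 = x2 * exp(s) + G(y1)        (per list entry)

    reverse (rev=True):   s  = clamp * (2 * sigmoid(H(y1)) - 1)
                          x2 = (y2 - G(y1)) / exp(s)      (per list entry)
                          x1 = y1 - F(x2)

    log-Jacobian per block (jacobian()): +sum(s) forward, -sum(s) reverse, divided by the batch size.

--------------------------------------------------------------------------------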
/code/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | from datetime import datetime
6 | import random
7 | import logging
8 | from collections import OrderedDict
9 | import numpy as np
10 | import cv2
11 | import torch
12 | from torchvision.utils import make_grid
13 | from shutil import get_terminal_size
14 |
15 | import yaml
16 | try:
17 | from yaml import CLoader as Loader, CDumper as Dumper
18 | except ImportError:
19 | from yaml import Loader, Dumper
20 |
21 |
22 | def OrderedYaml():
23 | '''yaml orderedDict support'''
24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25 |
26 | def dict_representer(dumper, data):
27 | return dumper.represent_dict(data.items())
28 |
29 | def dict_constructor(loader, node):
30 | return OrderedDict(loader.construct_pairs(node))
31 |
32 | Dumper.add_representer(OrderedDict, dict_representer)
33 | Loader.add_constructor(_mapping_tag, dict_constructor)
34 | return Loader, Dumper
35 |
36 |
37 | ####################
38 | # miscellaneous
39 | ####################
40 |
41 |
42 | def get_timestamp():
43 | return datetime.now().strftime('%y%m%d-%H%M%S')
44 |
45 |
46 | def mkdir(path):
47 | if not os.path.exists(path):
48 | os.makedirs(path)
49 |
50 |
51 | def mkdirs(paths):
52 | if isinstance(paths, str):
53 | mkdir(paths)
54 | else:
55 | for path in paths:
56 | mkdir(path)
57 |
58 |
59 | def mkdir_and_rename(path):
60 | # print(path)
61 | # exit(0)
62 | if os.path.exists(path):
63 | new_name = path + '_archived_' + get_timestamp()
64 |         print('Path already exists. Renaming it to [{:s}]'.format(new_name))
65 |         logger = logging.getLogger('base')
66 |         logger.info('Path already exists. Renaming it to [{:s}]'.format(new_name))
67 | # path = new_name
68 | os.rename(path, new_name)
69 | os.makedirs(path)
70 | # return path
71 |
72 |
73 | def set_random_seed(seed):
74 | random.seed(seed)
75 | np.random.seed(seed)
76 | torch.manual_seed(seed)
77 | torch.cuda.manual_seed_all(seed)
78 |
79 |
80 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
81 | '''set up logger'''
82 | lg = logging.getLogger(logger_name)
83 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
84 | datefmt='%y-%m-%d %H:%M:%S')
85 | lg.setLevel(level)
86 | if tofile:
87 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
88 | fh = logging.FileHandler(log_file, mode='w')
89 | fh.setFormatter(formatter)
90 | lg.addHandler(fh)
91 | if screen:
92 | sh = logging.StreamHandler()
93 | sh.setFormatter(formatter)
94 | lg.addHandler(sh)
95 |
96 |
97 | ####################
98 | # image convert
99 | ####################
100 |
101 |
102 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
103 | '''
104 | Converts a torch Tensor into an image Numpy array
105 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
106 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
107 | '''
108 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
109 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
110 | n_dim = tensor.dim()
111 | if n_dim == 4:
112 | n_img = len(tensor)
113 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
114 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
115 | elif n_dim == 3:
116 | img_np = tensor.numpy()
117 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
118 | elif n_dim == 2:
119 | img_np = tensor.numpy()
120 | else:
121 | raise TypeError(
122 |             'Only supports 4D, 3D and 2D tensors, but received a tensor with dimension: {:d}'.format(n_dim))
123 | if out_type == np.uint8:
124 | img_np = (img_np * 255.0).round()
125 |         # Important. Unlike MATLAB, numpy.uint8() will NOT round by default.
126 | return img_np.astype(out_type)
127 |
128 |
129 | def save_img(img, img_path, mode='RGB'):
130 | cv2.imwrite(img_path, img)
131 |
132 |
133 | ####################
134 | # metric
135 | ####################
136 |
137 |
138 | def calculate_psnr(img1, img2):
139 | # img1 and img2 have range [0, 255]
140 | img1 = img1.astype(np.float64)
141 | img2 = img2.astype(np.float64)
142 | mse = np.mean((img1 - img2)**2)
143 | if mse == 0:
144 | return float('inf')
145 | return 20 * math.log10(255.0 / math.sqrt(mse))
146 |
147 |
148 | def ssim(img1, img2):
149 | C1 = (0.01 * 255)**2
150 | C2 = (0.03 * 255)**2
151 |
152 | img1 = img1.astype(np.float64)
153 | img2 = img2.astype(np.float64)
154 | kernel = cv2.getGaussianKernel(11, 1.5)
155 | window = np.outer(kernel, kernel.transpose())
156 |
157 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
158 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
159 | mu1_sq = mu1**2
160 | mu2_sq = mu2**2
161 | mu1_mu2 = mu1 * mu2
162 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
163 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
164 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
165 |
166 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
167 | (sigma1_sq + sigma2_sq + C2))
168 | return ssim_map.mean()
169 |
170 |
171 | def calculate_ssim(img1, img2):
172 | '''calculate SSIM
173 | the same outputs as MATLAB's
174 | img1, img2: [0, 255]
175 | '''
176 | if not img1.shape == img2.shape:
177 | raise ValueError('Input images must have the same dimensions.')
178 | if img1.ndim == 2:
179 | return ssim(img1, img2)
180 | elif img1.ndim == 3:
181 | if img1.shape[2] == 3:
182 | ssims = []
183 | for i in range(3):
184 |                 ssims.append(ssim(img1[..., i], img2[..., i]))
185 | return np.array(ssims).mean()
186 | elif img1.shape[2] == 1:
187 | return ssim(np.squeeze(img1), np.squeeze(img2))
188 | else:
189 | raise ValueError('Wrong input image dimensions.')
190 |
191 |
192 | class ProgressBar(object):
193 | '''A progress bar which can print the progress
194 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
195 | '''
196 |
197 | def __init__(self, task_num=0, bar_width=50, start=True):
198 | self.task_num = task_num
199 | max_bar_width = self._get_max_bar_width()
200 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
201 | self.completed = 0
202 | if start:
203 | self.start()
204 |
205 | def _get_max_bar_width(self):
206 | terminal_width, _ = get_terminal_size()
207 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
208 | if max_bar_width < 10:
209 |             print('terminal width is too small ({}), please consider widening the terminal for better '
210 |                   'progress bar visualization'.format(terminal_width))
211 | max_bar_width = 10
212 | return max_bar_width
213 |
214 | def start(self):
215 | if self.task_num > 0:
216 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
217 | ' ' * self.bar_width, self.task_num, 'Start...'))
218 | else:
219 | sys.stdout.write('completed: 0, elapsed: 0s')
220 | sys.stdout.flush()
221 | self.start_time = time.time()
222 |
223 | def update(self, msg='In progress...'):
224 | self.completed += 1
225 | elapsed = time.time() - self.start_time
226 | fps = self.completed / elapsed
227 | if self.task_num > 0:
228 | percentage = self.completed / float(self.task_num)
229 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
230 | mark_width = int(self.bar_width * percentage)
231 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
232 | sys.stdout.write('\033[2F') # cursor up 2 lines
233 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
234 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
235 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
236 | else:
237 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
238 | self.completed, int(elapsed + 0.5), fps))
239 | sys.stdout.flush()
240 |
--------------------------------------------------------------------------------
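Note (not part of the repository): the metric helpers above expect uint8-range images in
HWC layout, exactly what tensor2img() returns. A minimal sketch, assuming it is run from
code/ so that `utils.util` is importable:

    import numpy as np
    from utils import util

    img = (np.random.rand(64, 64, 3) * 255).round().astype(np.uint8)
    noisy = np.clip(img.astype(np.float64) + np.random.randn(64, 64, 3), 0, 255)

    print(util.calculate_psnr(img.astype(np.float64), noisy))  # finite PSNR in dB
    print(util.calculate_ssim(img, img))                       # 1.0 for identical images

--------------------------------------------------------------------------------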
/code/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 | from data.data_sampler import DistIterSampler
11 |
12 | import options.options as option
13 | from utils import util
14 | from data import create_dataloader, create_dataset
15 | from models import create_model
16 |
17 |
18 | def init_dist(backend='nccl', **kwargs):
19 | ''' initialization for distributed training'''
20 | # if mp.get_start_method(allow_none=True) is None:
21 | if mp.get_start_method(allow_none=True) != 'spawn':
22 | mp.set_start_method('spawn')
23 | rank = int(os.environ['RANK'])
24 | num_gpus = torch.cuda.device_count()
25 | torch.cuda.set_device(rank % num_gpus)
26 | dist.init_process_group(backend=backend, **kwargs)
27 |
28 | def cal_pnsr(sr_img, gt_img):
29 | # calculate PSNR
30 | gt_img = gt_img / 255.
31 | sr_img = sr_img / 255.
32 | psnr = util.calculate_psnr(sr_img * 255, gt_img * 255)
33 |
34 | return psnr
35 |
36 | def main():
37 | # options
38 | parser = argparse.ArgumentParser()
39 |     parser.add_argument('-opt', type=str, help='Path to option YAML file.')  # config file
40 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
41 | help='job launcher')
42 | parser.add_argument('--local_rank', type=int, default=0)
43 | args = parser.parse_args()
44 | opt = option.parse(args.opt, is_train=True)
45 |
46 | # distributed training settings
47 | if args.launcher == 'none': # disabled distributed training
48 | opt['dist'] = False
49 | rank = -1
50 | print('Disabled distributed training.')
51 | else:
52 | opt['dist'] = True
53 | init_dist()
54 | world_size = torch.distributed.get_world_size()
55 | rank = torch.distributed.get_rank()
56 |
57 | # loading resume state if exists
58 | if opt['path'].get('resume_state', None):
59 | # distributed resuming: all load into default GPU
60 | device_id = torch.cuda.current_device()
61 | resume_state = torch.load(opt['path']['resume_state'],
62 | map_location=lambda storage, loc: storage.cuda(device_id))
63 | option.check_resume(opt, resume_state['iter']) # check resume options
64 | else:
65 | resume_state = None
66 |
67 | # mkdir and loggers
68 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
69 | if resume_state is None:
70 | util.mkdir_and_rename(
71 | opt['path']['experiments_root']) # rename experiment folder if exists
72 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
73 | and 'pretrain_model' not in key and 'resume' not in key))
74 |
75 | # config loggers. Before it, the log will not work
76 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
77 | screen=True, tofile=True)
78 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
79 | screen=True, tofile=True)
80 | logger = logging.getLogger('base')
81 | logger.info(option.dict2str(opt))
82 | # tensorboard logger
83 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
84 | version = float(torch.__version__[0:3])
85 | if version >= 1.1: # PyTorch 1.1
86 | from torch.utils.tensorboard import SummaryWriter
87 | else:
88 | logger.info(
89 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
90 | from tensorboardX import SummaryWriter
91 | tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
92 | else:
93 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
94 | logger = logging.getLogger('base')
95 |
96 | # convert to NoneDict, which returns None for missing keys
97 | opt = option.dict_to_nonedict(opt)
98 |
99 | # random seed
100 | seed = opt['train']['manual_seed']
101 | if seed is None:
102 | seed = random.randint(1, 10000)
103 | if rank <= 0:
104 | logger.info('Random seed: {}'.format(seed))
105 | util.set_random_seed(seed)
106 |
107 | torch.backends.cudnn.benchmark = True
108 | # torch.backends.cudnn.deterministic = True
109 |
110 | #### create train and val dataloader
111 | dataset_ratio = 200 # enlarge the size of each epoch
112 | for phase, dataset_opt in opt['datasets'].items():
113 | if phase == 'train':
114 | train_set = create_dataset(dataset_opt)
115 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
116 | total_iters = int(opt['train']['niter'])
117 | total_epochs = int(math.ceil(total_iters / train_size))
118 | if opt['dist']:
119 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
120 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
121 | else:
122 | train_sampler = None
123 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
124 | if rank <= 0:
125 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
126 | len(train_set), train_size))
127 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
128 | total_epochs, total_iters))
129 | elif phase == 'val':
130 | val_set = create_dataset(dataset_opt)
131 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
132 | if rank <= 0:
133 | logger.info('Number of val images in [{:s}]: {:d}'.format(
134 | dataset_opt['name'], len(val_set)))
135 | else:
136 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
137 | assert train_loader is not None
138 |
139 | # create model
140 | model = create_model(opt)
141 | # resume training
142 | if resume_state:
143 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
144 | resume_state['epoch'], resume_state['iter']))
145 |
146 | start_epoch = resume_state['epoch']
147 | current_step = resume_state['iter']
148 | model.resume_training(resume_state) # handle optimizers and schedulers
149 | else:
150 | current_step = 0
151 | start_epoch = 0
152 |
153 | # training
154 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
155 | for epoch in range(start_epoch, total_epochs + 1):
156 | if opt['dist']:
157 | train_sampler.set_epoch(epoch)
158 | for _, train_data in enumerate(train_loader):
159 | current_step += 1
160 | if current_step > total_iters:
161 | break
162 | # training
163 | model.feed_data(train_data)
164 | model.optimize_parameters(current_step)
165 |
166 | # update learning rate
167 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
168 |
169 | # log
170 | if current_step % opt['logger']['print_freq'] == 0:
171 | logs = model.get_current_log()
172 |                 message = '<epoch:{:3d}, iter:{:8,d}, lr:{}> '.format(
173 |                     epoch, current_step, model.get_current_learning_rate())
174 | for k, v in logs.items():
175 | message += '{:s}: {:.4e} '.format(k, v)
176 | # tensorboard logger
177 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
178 | if rank <= 0:
179 | tb_logger.add_scalar(k, v, current_step)
180 | if rank <= 0:
181 | logger.info(message)
182 |
183 | # validation
184 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
185 | avg_psnr = 0.0
186 | avg_psnr_h = [0.0]*opt['num_video']
187 | avg_psnr_lr = 0.0
188 | idx = 0
189 | for video_id, val_data in enumerate(val_loader):
190 | img_dir = os.path.join(opt['path']['val_images'])
191 | util.mkdir(img_dir)
192 |
193 | model.feed_data(val_data)
194 | model.test()
195 |
196 | visuals = model.get_current_visuals()
197 |
198 | t_step = visuals['SR'].shape[0]
199 | idx += t_step
200 | n = len(visuals['SR_h'])
201 |
202 | for i in range(t_step):
203 |
204 | sr_img = util.tensor2img(visuals['SR'][i]) # uint8
205 | sr_img_h = []
206 | for j in range(n):
207 | sr_img_h.append(util.tensor2img(visuals['SR_h'][j][i])) # uint8
208 | gt_img = util.tensor2img(visuals['GT'][i]) # uint8
209 | lr_img = util.tensor2img(visuals['LR'][i])
210 | lrgt_img = []
211 | for j in range(n):
212 | lrgt_img.append(util.tensor2img(visuals['LR_ref'][j][i]))
213 |
214 | # Save SR images for reference
215 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'SR'))
216 | util.save_img(sr_img, save_img_path)
217 |
218 | for j in range(n):
219 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'SR_h'))
220 | util.save_img(sr_img_h[j], save_img_path)
221 |
222 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'GT'))
223 | util.save_img(gt_img, save_img_path)
224 |
225 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:s}.png'.format(video_id, i, 'LR'))
226 | util.save_img(lr_img, save_img_path)
227 |
228 | for j in range(n):
229 | save_img_path = os.path.join(img_dir,'{:d}_{:d}_{:d}_{:s}.png'.format(video_id, i, j, 'LRGT'))
230 | util.save_img(lrgt_img[j], save_img_path)
231 |
232 | psnr = cal_pnsr(sr_img, gt_img)
233 | psnr_h = []
234 | for j in range(n):
235 | psnr_h.append(cal_pnsr(sr_img_h[j], lrgt_img[j]))
236 | psnr_lr = cal_pnsr(lr_img, gt_img)
237 |
238 | avg_psnr += psnr
239 | for j in range(n):
240 | avg_psnr_h[j] += psnr_h[j]
241 | avg_psnr_lr += psnr_lr
242 |
243 | avg_psnr = avg_psnr / idx
244 | avg_psnr_h = [psnr / idx for psnr in avg_psnr_h]
245 | avg_psnr_lr = avg_psnr_lr / idx
246 |
247 | # log
248 | res_psnr_h = ''
249 | for p in avg_psnr_h:
250 | res_psnr_h+=('_{:.4e}'.format(p))
251 |
252 | logger.info('# Validation # PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}'.format(avg_psnr, res_psnr_h, avg_psnr_lr))
253 | logger_val = logging.getLogger('val') # validation logger
254 |                 logger_val.info('<epoch:{:3d}, iter:{:8,d}> PSNR_Cover: {:.4e}, PSNR_Secret: {:s}, PSNR_Stego: {:.4e}'.format(
255 |                     epoch, current_step, avg_psnr, res_psnr_h, avg_psnr_lr))
256 | # tensorboard logger
257 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
258 | tb_logger.add_scalar('psnr', avg_psnr, current_step)
259 |
260 | # save models and training states
261 | if current_step % opt['logger']['save_checkpoint_freq'] == 0:
262 | if rank <= 0:
263 | logger.info('Saving models and training states.')
264 | model.save(current_step)
265 | model.save_training_state(epoch, current_step)
266 |
267 | if rank <= 0:
268 | logger.info('Saving the final model.')
269 | model.save('latest')
270 | logger.info('End of training.')
271 |
272 |
273 | if __name__ == '__main__':
274 | main()
275 |
--------------------------------------------------------------------------------
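Note on launching train.py (not part of the repository): hypothetical commands, assuming
the current directory is code/ and one of the provided option files; the option path and
GPU count must be adapted to your setup. For `--launcher pytorch`, init_dist() reads the
RANK environment variable, which the PyTorch launcher sets.

    # single-process run
    python train.py -opt options/train/train_LF-VSN_2video.yml

    # distributed run on 4 GPUs
    python -m torch.distributed.launch --nproc_per_node=4 train.py \
        -opt options/train/train_LF-VSN_2video.yml --launcher pytorch

--------------------------------------------------------------------------------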
/code/models/LFVSN.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 |
8 | import models.networks as networks
9 | import models.lr_scheduler as lr_scheduler
10 | from .base_model import BaseModel
11 | from models.modules.loss import ReconstructionLoss
12 | from models.modules.Quantization import Quantization
13 | from .modules.common import DWT,IWT
14 |
15 | logger = logging.getLogger('base')
16 | dwt=DWT()
17 | iwt=IWT()
18 |
19 |
20 | class Model_VSN(BaseModel):
21 | def __init__(self, opt):
22 | super(Model_VSN, self).__init__(opt)
23 |
24 | if opt['dist']:
25 | self.rank = torch.distributed.get_rank()
26 | else:
27 | self.rank = -1 # non dist training
28 |
29 | self.gop = opt['gop']
30 | train_opt = opt['train']
31 | test_opt = opt['test']
32 | self.opt = opt
33 | self.train_opt = train_opt
34 | self.test_opt = test_opt
35 | self.opt_net = opt['network_G']
36 | self.center = self.gop // 2
37 | self.num_video = opt['num_video']
38 | self.idxx = 0
39 |
40 | self.netG = networks.define_G_v2(opt).to(self.device)
41 | if opt['dist']:
42 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
43 | else:
44 | self.netG = DataParallel(self.netG)
45 | # print network
46 | self.print_network()
47 | self.load()
48 |
49 | self.Quantization = Quantization()
50 |
51 | if self.is_train:
52 | self.netG.train()
53 |
54 | # loss
55 | self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw'])
56 | self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back'])
57 | self.Reconstruction_center = ReconstructionLoss(losstype="center")
58 |
59 | # optimizers
60 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
61 | optim_params = []
62 | for k, v in self.netG.named_parameters():
63 | if v.requires_grad:
64 | optim_params.append(v)
65 | else:
66 | if self.rank <= 0:
67 | logger.warning('Params [{:s}] will not optimize.'.format(k))
68 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
69 | weight_decay=wd_G,
70 | betas=(train_opt['beta1'], train_opt['beta2']))
71 | self.optimizers.append(self.optimizer_G)
72 |
73 | # schedulers
74 | if train_opt['lr_scheme'] == 'MultiStepLR':
75 | for optimizer in self.optimizers:
76 | self.schedulers.append(
77 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
78 | restarts=train_opt['restarts'],
79 | weights=train_opt['restart_weights'],
80 | gamma=train_opt['lr_gamma'],
81 | clear_state=train_opt['clear_state']))
82 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
83 | for optimizer in self.optimizers:
84 | self.schedulers.append(
85 | lr_scheduler.CosineAnnealingLR_Restart(
86 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
87 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
88 | else:
89 |                 raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')
90 |
91 | self.log_dict = OrderedDict()
92 |
93 | def feed_data(self, data):
94 | self.ref_L = data['LQ'].to(self.device)
95 | self.real_H = data['GT'].to(self.device)
96 |
97 | def init_hidden_state(self, z):
98 | b, c, h, w = z.shape
99 | h_t = []
100 | c_t = []
101 | for _ in range(self.opt_net['block_num_rbm']):
102 | h_t.append(torch.zeros([b, c, h, w]).cuda())
103 | c_t.append(torch.zeros([b, c, h, w]).cuda())
104 | memory = torch.zeros([b, c, h, w]).cuda()
105 |
106 | return h_t, c_t, memory
107 |
108 | def loss_forward(self, out, y):
109 | if self.opt['model'] == 'LSTM-VRN':
110 | l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)
111 | return l_forw_fit
112 | elif self.opt['model'] == 'MIMO-VRN-h':
113 | l_forw_fit = 0
114 | for i in range(out.shape[1]):
115 | l_forw_fit += self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out[:, i], y[:, i])
116 | return l_forw_fit
117 |
118 | def loss_back_rec(self, out, x):
119 | if self.opt['model'] == 'LSTM-VRN':
120 | l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out, x)
121 | return l_back_rec
122 | elif self.opt['model'] == 'MIMO-VRN-h':
123 | l_back_rec = 0
124 | for i in range(x.shape[1]):
125 | l_back_rec += self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out[:, i], x[:, i])
126 | return l_back_rec
127 |
128 | def loss_back_rec_mul(self, out, x):
129 | out = torch.chunk(out,self.num_video,dim=1)
130 | out = [outi.squeeze(1) for outi in out]
131 | x = torch.chunk(x,self.num_video,dim=1)
132 | x = [xi.squeeze(1) for xi in x]
133 | l_back_rec = 0
134 | for i in range(len(x)):
135 | for j in range(x[i].shape[1]):
136 | l_back_rec += self.train_opt['lambda_rec_back'] * self.Reconstruction_back(out[i][:, j], x[i][:, j])
137 | return l_back_rec
138 |
139 | def loss_center(self, out, x):
140 | # x.shape: (b, t, c, h, w)
141 | b, t = x.shape[:2]
142 | l_center = 0
143 | for i in range(b):
144 | mse_s = self.Reconstruction_center(out[i], x[i])
145 | mse_mean = torch.mean(mse_s)
146 | for j in range(t):
147 | l_center += torch.sqrt((mse_s[j] - mse_mean.detach()) ** 2 + 1e-18)
148 | l_center = self.train_opt['lambda_center'] * l_center / b
149 |
150 | return l_center
151 |
152 | def optimize_parameters(self, current_step):
153 | self.optimizer_G.zero_grad()
154 |
155 | b, n, t, c, h, w = self.ref_L.shape
156 | center = t // 2
157 | intval = self.gop // 2
158 |
159 | self.host = self.real_H[:, center - intval:center + intval + 1]
160 | self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
161 | self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
162 | self.output, out_h = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret)
163 | self.output = iwt(self.output)
164 |
165 | Gt_ref = self.real_H[:, center - intval:center + intval + 1].detach()
166 | container = self.output[:, :3 * self.gop, :, :].reshape(-1, self.gop, 3, h, w)[:,self.gop//2]
167 | l_forw_fit = self.loss_forward(container.unsqueeze(1), Gt_ref[:,self.gop//2].unsqueeze(1))
168 |
169 | y = self.Quantization(self.output[:, :3 * self.gop, :, :].view(-1, self.gop, 3, h, w)[:,self.gop//2].unsqueeze(1).repeat(1,self.gop,1,1,1).reshape(b, -1, h, w))
170 | out_x, out_x_h, out_z = self.netG(x=dwt(y), rev=True)
171 | out_x = iwt(out_x)
172 | out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
173 |
174 | l_back_rec = self.loss_back_rec(out_x.reshape(-1, self.gop, 3, h, w)[:,self.gop//2].unsqueeze(1), self.host[:,self.gop//2].unsqueeze(1))
175 | out_x_h = torch.stack(out_x_h, dim=1)
176 |
177 | l_center_x = 0
178 | for i in range(n):
179 | l_center_x += self.loss_back_rec(out_x_h.reshape(-1, n, self.gop, 3, h, w)[:, :, self.gop//2].unsqueeze(2)[:,i], self.ref_L[:, :, center - intval:center + intval + 1][:,:,self.gop//2].unsqueeze(2)[:, i])
180 |
181 | loss = l_forw_fit*2 + l_back_rec + l_center_x*4
182 | loss.backward()
183 |
184 | if self.train_opt['lambda_center'] != 0:
185 | self.log_dict['l_center_x'] = l_center_x.item()
186 |
187 | # set log
188 | self.log_dict['l_back_rec'] = l_back_rec.item()
189 | self.log_dict['l_forw_fit'] = l_forw_fit.item()
190 |
191 | self.log_dict['l_h'] = (l_center_x*10).item()
192 |
193 | # gradient clipping
194 | if self.train_opt['gradient_clipping']:
195 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
196 |
197 | self.optimizer_G.step()
198 |
199 | def test(self):
200 | self.netG.eval()
201 | with torch.no_grad():
202 | forw_L = []
203 | forw_L_h = []
204 | fake_H = []
205 | fake_H_h = []
206 | pred_z = []
207 | b, t, c, h, w = self.real_H.shape
208 | center = t // 2
209 | intval = self.gop // 2
210 | ids=[-1,0,1]
211 | b, n, t, c, h, w = self.ref_L.shape
212 | for j in range(3):
213 | id=ids[j]
214 | # forward downscaling
215 | self.host = self.real_H[:, center - intval+id:center + intval + 1+id]
216 | self.secret = self.ref_L[:, :, center - intval+id:center + intval + 1+id]
217 | self.secret = [dwt(self.secret[:,i].reshape(b, -1, h, w)) for i in range(n)]
218 | self.output, out_h = self.netG(x=dwt(self.host.reshape(b, -1, h, w)),x_h=self.secret)
219 | self.output = iwt(self.output)
220 | out_lrs = self.output[:, :3 * self.gop, :, :].reshape(-1, self.gop, 3, h, w)
221 |
222 | # backward upscaling
223 | y = self.Quantization(self.output[:, :3 * self.gop, :, :].view(-1, self.gop, 3, h, w)[:,self.gop//2].unsqueeze(1).repeat(1,self.gop,1,1,1).reshape(b, -1, h, w))
224 | out_x, out_x_h, out_z = self.netG(x=dwt(y), rev=True)
225 | out_x = iwt(out_x)
226 | out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]
227 | out_x = out_x.reshape(-1, self.gop, 3, h, w)
228 | out_x_h = torch.stack(out_x_h, dim=1)
229 | out_x_h = out_x_h.reshape(-1, n, self.gop, 3, h, w)
230 | forw_L.append(out_lrs[:, self.gop//2])
231 | fake_H.append(out_x[:, self.gop//2])
232 | fake_H_h.append(out_x_h[:,:, self.gop//2])
233 |
234 | self.fake_H = torch.clamp(torch.stack(fake_H, dim=1),0,1)
235 | self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2),0,1)
236 | self.forw_L = torch.clamp(torch.stack(forw_L, dim=1),0,1)
237 | self.netG.train()
238 |
239 | def get_current_log(self):
240 | return self.log_dict
241 |
242 | def get_current_visuals(self):
243 | b, n, t, c, h, w = self.ref_L.shape
244 | center = t // 2
245 | intval = 3 // 2
246 | out_dict = OrderedDict()
247 | LR_ref = self.ref_L[:, :, center - intval:center + intval + 1].detach()[0].float().cpu()
248 | LR_ref = torch.chunk(LR_ref, self.num_video, dim=0)
249 | out_dict['LR_ref'] = [video.squeeze(0) for video in LR_ref]
250 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
251 | SR_h = self.fake_H_h.detach()[0].float().cpu()
252 | SR_h = torch.chunk(SR_h, self.num_video, dim=0)
253 | out_dict['SR_h'] = [video.squeeze(0) for video in SR_h]
254 | out_dict['LR'] = self.forw_L.detach()[0].float().cpu()
255 | out_dict['GT'] = self.real_H[:, center - intval:center + intval + 1].detach()[0].float().cpu()
256 |
257 | return out_dict
258 |
259 | def print_network(self):
260 | s, n = self.get_network_description(self.netG)
261 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
262 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
263 | self.netG.module.__class__.__name__)
264 | else:
265 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
266 | if self.rank <= 0:
267 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
268 | logger.info(s)
269 |
270 | def load(self):
271 | load_path_G = self.opt['path']['pretrain_model_G']
272 | if load_path_G is not None:
273 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
274 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
275 |
276 | def load_test(self,load_path_G):
277 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
278 |
279 | def save(self, iter_label):
280 | self.save_network(self.netG, 'G', iter_label)
281 |
--------------------------------------------------------------------------------
/code/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import pickle
4 | import random
5 | import numpy as np
6 | import glob
7 | import torch
8 | import cv2
9 |
10 | ####################
11 | # Files & IO
12 | ####################
13 |
14 | ###################### get image path list ######################
15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
16 |
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 |
22 | def _get_paths_from_images(path):
23 | '''get image path list from image folder'''
24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
25 | images = []
26 | for dirpath, _, fnames in sorted(os.walk(path)):
27 | for fname in sorted(fnames):
28 | if is_image_file(fname):
29 | img_path = os.path.join(dirpath, fname)
30 | images.append(img_path)
31 | assert images, '{:s} has no valid image file'.format(path)
32 | return images
33 |
34 |
35 | def _get_paths_from_lmdb(dataroot):
36 | '''get image path list from lmdb meta info'''
37 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
38 | paths = meta_info['keys']
39 | sizes = meta_info['resolution']
40 | if len(sizes) == 1:
41 | sizes = sizes * len(paths)
42 | return paths, sizes
43 |
44 |
45 | def get_image_paths(data_type, dataroot):
46 | '''get image path list
47 | support lmdb or image files'''
48 | paths, sizes = None, None
49 | if dataroot is not None:
50 | if data_type == 'lmdb':
51 | paths, sizes = _get_paths_from_lmdb(dataroot)
52 | elif data_type == 'img':
53 | paths = sorted(_get_paths_from_images(dataroot))
54 | else:
55 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
56 | return paths, sizes
57 |
58 |
59 | def glob_file_list(root):
60 | return sorted(glob.glob(os.path.join(root, '*')))
61 |
62 |
63 | ###################### read images ######################
64 | def _read_img_lmdb(env, key, size):
65 | '''read image from lmdb with key (w/ and w/o fixed size)
66 | size: (C, H, W) tuple'''
67 | with env.begin(write=False) as txn:
68 | buf = txn.get(key.encode('ascii'))
69 | img_flat = np.frombuffer(buf, dtype=np.uint8)
70 | C, H, W = size
71 | img = img_flat.reshape(H, W, C)
72 | return img
73 |
74 |
75 | def read_img(env, path, size=None):
76 | '''read image by cv2 or from lmdb
77 | return: Numpy float32, HWC, BGR, [0,1]'''
78 | if env is None: # img
79 | # print(path)
80 | #img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
81 | img = cv2.imread(path, cv2.IMREAD_COLOR)
82 | else:
83 | img = _read_img_lmdb(env, path, size)
84 | # print(img.shape)
85 | # if img is None:
86 | # print(path)
87 | # print(img.shape)
88 | img = img.astype(np.float32) / 255.
89 | if img.ndim == 2:
90 | img = np.expand_dims(img, axis=2)
91 | # some images have 4 channels
92 | if img.shape[2] > 3:
93 | img = img[:, :, :3]
94 | return img
95 |
96 |
97 | def read_img_seq(path):
98 | """Read a sequence of images from a given folder path
99 | Args:
100 | path (list/str): list of image paths/image folder path
101 |
102 | Returns:
103 | imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
104 | """
105 | if type(path) is list:
106 | img_path_l = path
107 | else:
108 | img_path_l = sorted(glob.glob(os.path.join(path, '*.png')))
109 | # print(path)
110 | # print(path,img_path_l)
111 | img_l = [read_img(None, v) for v in img_path_l]
112 | # stack to Torch tensor
113 | imgs = np.stack(img_l, axis=0)
114 | imgs = imgs[:, :, :, [2, 1, 0]]
115 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
116 | return imgs
117 |
118 |
119 | def index_generation(crt_i, max_n, N, padding='reflection'):
120 | """Generate an index list for reading N frames from a sequence of images
121 | Args:
122 | crt_i (int): current center index
123 |         max_n (int): total number of frames in the sequence (counted from 1)
124 | N (int): reading N frames
125 | padding (str): padding mode, one of replicate | reflection | new_info | circle
126 | Example: crt_i = 0, N = 5
127 | replicate: [0, 0, 0, 1, 2]
128 | reflection: [2, 1, 0, 1, 2]
129 | new_info: [4, 3, 0, 1, 2]
130 | circle: [3, 4, 0, 1, 2]
131 |
132 | Returns:
133 | return_l (list [int]): a list of indexes
134 | """
135 | max_n = max_n - 1
136 | n_pad = N // 2
137 | return_l = []
138 |
139 | for i in range(crt_i - n_pad, crt_i + n_pad + 1):
140 | if i < 0:
141 | if padding == 'replicate':
142 | add_idx = 0
143 | elif padding == 'reflection':
144 | add_idx = -i
145 | elif padding == 'new_info':
146 | add_idx = (crt_i + n_pad) + (-i)
147 | elif padding == 'circle':
148 | add_idx = N + i
149 | else:
150 | raise ValueError('Wrong padding mode')
151 | elif i > max_n:
152 | if padding == 'replicate':
153 | add_idx = max_n
154 | elif padding == 'reflection':
155 | add_idx = max_n * 2 - i
156 | elif padding == 'new_info':
157 | add_idx = (crt_i - n_pad) - (i - max_n)
158 | elif padding == 'circle':
159 | add_idx = i - N
160 | else:
161 | raise ValueError('Wrong padding mode')
162 | else:
163 | add_idx = i
164 | return_l.append(add_idx)
165 | return return_l
166 |
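# Example (illustrative): index_generation(0, 100, 5, padding='reflection') returns [2, 1, 0, 1, 2];
# indices that fall outside the sequence are mirrored back into range, matching the docstring above.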
167 |
168 | ####################
169 | # image processing
170 | # process on numpy image
171 | ####################
172 |
173 |
174 | def augment(img_list, hflip=True, rot=True):
175 |     # random horizontal flip, vertical flip and 90-degree rotation (flip controlled by hflip,
176 |     # the latter two by rot), each applied with probability 0.5
176 | hflip = hflip and random.random() < 0.5
177 | vflip = rot and random.random() < 0.5
178 | rot90 = rot and random.random() < 0.5
179 |
180 | def _augment(img):
181 | if hflip:
182 | img = img[:, ::-1, :]
183 | if vflip:
184 | img = img[::-1, :, :]
185 | if rot90:
186 | img = img.transpose(1, 0, 2)
187 | return img
188 |
189 | return [_augment(img) for img in img_list]
190 |
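# Usage sketch (illustrative): pass all crops belonging to one sample together so they share the
# same random flips/rotation, e.g. img_GT, img_LQ = augment([img_GT, img_LQ], hflip=True, rot=True).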
191 |
192 | def augment_flow(img_list, flow_list, hflip=True, rot=True):
193 |     # random horizontal flip, vertical flip and 90-degree rotation (flip controlled by hflip,
194 |     # the latter two by rot), each applied with probability 0.5; flow vectors are flipped accordingly
194 | hflip = hflip and random.random() < 0.5
195 | vflip = rot and random.random() < 0.5
196 | rot90 = rot and random.random() < 0.5
197 |
198 | def _augment(img):
199 | if hflip:
200 | img = img[:, ::-1, :]
201 | if vflip:
202 | img = img[::-1, :, :]
203 | if rot90:
204 | img = img.transpose(1, 0, 2)
205 | return img
206 |
207 | def _augment_flow(flow):
208 | if hflip:
209 | flow = flow[:, ::-1, :]
210 | flow[:, :, 0] *= -1
211 | if vflip:
212 | flow = flow[::-1, :, :]
213 | flow[:, :, 1] *= -1
214 | if rot90:
215 | flow = flow.transpose(1, 0, 2)
216 | flow = flow[:, :, [1, 0]]
217 | return flow
218 |
219 | rlt_img_list = [_augment(img) for img in img_list]
220 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
221 |
222 | return rlt_img_list, rlt_flow_list
223 |
224 |
225 | def channel_convert(in_c, tar_type, img_list):
226 | # conversion among BGR, gray and y
227 | if in_c == 3 and tar_type == 'gray': # BGR to gray
228 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
229 | return [np.expand_dims(img, axis=2) for img in gray_list]
230 | elif in_c == 3 and tar_type == 'y': # BGR to y
231 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
232 | return [np.expand_dims(img, axis=2) for img in y_list]
233 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
234 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
235 | else:
236 | return img_list
237 |
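# Example (illustrative): channel_convert(3, 'y', [img_bgr]) returns the Y (luma) channel of each
# BGR image with shape (H, W, 1); unsupported (in_c, tar_type) combinations are returned unchanged.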
238 |
239 | def rgb2ycbcr(img, only_y=True):
240 | '''same as matlab rgb2ycbcr
241 | only_y: only return Y channel
242 | Input:
243 | uint8, [0, 255]
244 | float, [0, 1]
245 | '''
246 | in_img_type = img.dtype
247 |     img = img.astype(np.float32)  # work on a float copy so the input array is not modified in place
248 | if in_img_type != np.uint8:
249 | img *= 255.
250 | # convert
251 | if only_y:
252 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
253 | else:
254 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
255 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
256 | if in_img_type == np.uint8:
257 | rlt = rlt.round()
258 | else:
259 | rlt /= 255.
260 | return rlt.astype(in_img_type)
261 |
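# For reference, the Y-only branch above computes the BT.601 limited-range luma used by MATLAB:
#   Y = (65.481 * R + 128.553 * G + 24.966 * B) / 255 + 16   (inputs scaled to [0, 255])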
262 |
263 | def bgr2ycbcr(img, only_y=True):
264 | '''bgr version of rgb2ycbcr
265 | only_y: only return Y channel
266 | Input:
267 | uint8, [0, 255]
268 | float, [0, 1]
269 | '''
270 | in_img_type = img.dtype
271 |     img = img.astype(np.float32)  # work on a float copy so the input array is not modified in place
272 | if in_img_type != np.uint8:
273 | img *= 255.
274 | # convert
275 | if only_y:
276 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
277 | else:
278 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
279 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
280 | if in_img_type == np.uint8:
281 | rlt = rlt.round()
282 | else:
283 | rlt /= 255.
284 | return rlt.astype(in_img_type)
285 |
286 |
287 | def ycbcr2rgb(img):
288 | '''same as matlab ycbcr2rgb
289 | Input:
290 | uint8, [0, 255]
291 | float, [0, 1]
292 | '''
293 | in_img_type = img.dtype
294 |     img = img.astype(np.float32)  # work on a float copy so the input array is not modified in place
295 | if in_img_type != np.uint8:
296 | img *= 255.
297 | # convert
298 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
299 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
300 | if in_img_type == np.uint8:
301 | rlt = rlt.round()
302 | else:
303 | rlt /= 255.
304 | return rlt.astype(in_img_type)
305 |
306 |
307 | def modcrop(img_in, scale):
308 | # img_in: Numpy, HWC or HW
309 | img = np.copy(img_in)
310 | if img.ndim == 2:
311 | H, W = img.shape
312 | H_r, W_r = H % scale, W % scale
313 | img = img[:H - H_r, :W - W_r]
314 | elif img.ndim == 3:
315 | H, W, C = img.shape
316 | H_r, W_r = H % scale, W % scale
317 | img = img[:H - H_r, :W - W_r, :]
318 | else:
319 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
320 | return img
321 |
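# Example (illustrative): modcrop(img, 4) crops a 1079x1919 frame to 1076x1916 so that both
# spatial dimensions become divisible by the scale factor.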
322 |
323 | ####################
324 | # Functions
325 | ####################
326 |
327 |
328 | # MATLAB 'imresize' function; currently only the 'bicubic' kernel is supported
329 | def cubic(x):
330 | absx = torch.abs(x)
331 | absx2 = absx**2
332 | absx3 = absx**3
333 | return (1.5 * absx3 - 2.5 * absx2 + 1) * (
334 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
335 | (absx > 1) * (absx <= 2)).type_as(absx))
336 |
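# Note: cubic() above is the Keys bicubic kernel with a = -0.5 (the kernel used by MATLAB's
# imresize); e.g. cubic(torch.tensor([0., 1., 2.])) evaluates to tensor([1., 0., 0.]).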
337 |
338 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
339 | if (scale < 1) and (antialiasing):
340 |         # Use a modified (wider) kernel to simultaneously interpolate and antialias when downscaling
341 | kernel_width = kernel_width / scale
342 |
343 | # Output-space coordinates
344 | x = torch.linspace(1, out_length, out_length)
345 |
346 | # Input-space coordinates. Calculate the inverse mapping such that 0.5
347 | # in output space maps to 0.5 in input space, and 0.5+scale in output
348 | # space maps to 1.5 in input space.
349 | u = x / scale + 0.5 * (1 - 1 / scale)
350 |
351 | # What is the left-most pixel that can be involved in the computation?
352 | left = torch.floor(u - kernel_width / 2)
353 |
354 | # What is the maximum number of pixels that can be involved in the
355 | # computation? Note: it's OK to use an extra pixel here; if the
356 | # corresponding weights are all zero, it will be eliminated at the end
357 | # of this function.
358 | P = math.ceil(kernel_width) + 2
359 |
360 | # The indices of the input pixels involved in computing the k-th output
361 | # pixel are in row k of the indices matrix.
362 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
363 | 1, P).expand(out_length, P)
364 |
365 | # The weights used to compute the k-th output pixel are in row k of the
366 | # weights matrix.
367 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
368 | # apply cubic kernel
369 | if (scale < 1) and (antialiasing):
370 | weights = scale * cubic(distance_to_center * scale)
371 | else:
372 | weights = cubic(distance_to_center)
373 | # Normalize the weights matrix so that each row sums to 1.
374 | weights_sum = torch.sum(weights, 1).view(out_length, 1)
375 | weights = weights / weights_sum.expand(out_length, P)
376 |
377 | # If a column in weights is all zero, get rid of it. only consider the first and last column.
378 | weights_zero_tmp = torch.sum((weights == 0), 0)
379 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
380 | indices = indices.narrow(1, 1, P - 2)
381 | weights = weights.narrow(1, 1, P - 2)
382 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
383 | indices = indices.narrow(1, 0, P - 2)
384 | weights = weights.narrow(1, 0, P - 2)
385 | weights = weights.contiguous()
386 | indices = indices.contiguous()
387 | sym_len_s = -indices.min() + 1
388 | sym_len_e = indices.max() - in_length
389 | indices = indices + sym_len_s - 1
390 | return weights, indices, int(sym_len_s), int(sym_len_e)
391 |
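# Note: sym_len_s / sym_len_e returned above are the number of samples that must be mirror-padded
# before the start / after the end of the input so that every index in `indices` stays in bounds.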
392 |
393 | def imresize(img, scale, antialiasing=True):
394 |     # The same scale factor is applied to both H and W
395 | # input: img: CHW RGB [0,1]
396 | # output: CHW RGB [0,1] w/o round
397 |
398 | in_C, in_H, in_W = img.size()
399 |     out_H, out_W = math.ceil(in_H * scale), math.ceil(in_W * scale)
400 | kernel_width = 4
401 | kernel = 'cubic'
402 |
403 |     # MATLAB's imresize chooses the order of dimensions so that the resize is performed
404 |     # first along the dimension with the smaller scale factor.
405 |     # That optimization is not implemented here:
406 |     # the H dimension is always processed first, then W.
407 |
408 | # get weights and indices
409 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
410 | in_H, out_H, scale, kernel, kernel_width, antialiasing)
411 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
412 | in_W, out_W, scale, kernel, kernel_width, antialiasing)
413 | # process H dimension
414 | # symmetric copying
415 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
416 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
417 |
418 | sym_patch = img[:, :sym_len_Hs, :]
419 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
420 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
421 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
422 |
423 | sym_patch = img[:, -sym_len_He:, :]
424 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
425 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
426 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
427 |
428 | out_1 = torch.FloatTensor(in_C, out_H, in_W)
429 | kernel_width = weights_H.size(1)
430 | for i in range(out_H):
431 | idx = int(indices_H[i][0])
432 | out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
433 | out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
434 | out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
435 |
436 | # process W dimension
437 | # symmetric copying
438 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
439 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
440 |
441 | sym_patch = out_1[:, :, :sym_len_Ws]
442 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
443 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
444 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
445 |
446 | sym_patch = out_1[:, :, -sym_len_We:]
447 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
448 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
449 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
450 |
451 | out_2 = torch.FloatTensor(in_C, out_H, out_W)
452 | kernel_width = weights_W.size(1)
453 | for i in range(out_W):
454 | idx = int(indices_W[i][0])
455 | out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
456 | out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
457 | out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
458 |
459 | return out_2
460 |
461 |
462 | def imresize_np(img, scale, antialiasing=True):
463 |     # The same scale factor is applied to both H and W
464 | # input: img: Numpy, HWC BGR [0,1]
465 | # output: HWC BGR [0,1] w/o round
466 | img = torch.from_numpy(img)
467 |
468 | in_H, in_W, in_C = img.size()
469 |     out_H, out_W = math.ceil(in_H * scale), math.ceil(in_W * scale)
470 | kernel_width = 4
471 | kernel = 'cubic'
472 |
473 |     # MATLAB's imresize chooses the order of dimensions so that the resize is performed
474 |     # first along the dimension with the smaller scale factor.
475 |     # That optimization is not implemented here:
476 |     # the H dimension is always processed first, then W.
477 |
478 | # get weights and indices
479 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
480 | in_H, out_H, scale, kernel, kernel_width, antialiasing)
481 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
482 | in_W, out_W, scale, kernel, kernel_width, antialiasing)
483 | # process H dimension
484 | # symmetric copying
485 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
486 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
487 |
488 | sym_patch = img[:sym_len_Hs, :, :]
489 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
490 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
491 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
492 |
493 | sym_patch = img[-sym_len_He:, :, :]
494 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
495 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
496 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
497 |
498 | out_1 = torch.FloatTensor(out_H, in_W, in_C)
499 | kernel_width = weights_H.size(1)
500 | for i in range(out_H):
501 | idx = int(indices_H[i][0])
502 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
503 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
504 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
505 |
506 | # process W dimension
507 | # symmetric copying
508 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
509 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
510 |
511 | sym_patch = out_1[:, :sym_len_Ws, :]
512 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
513 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
514 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
515 |
516 | sym_patch = out_1[:, -sym_len_We:, :]
517 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
518 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
519 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
520 |
521 | out_2 = torch.FloatTensor(out_H, out_W, in_C)
522 | kernel_width = weights_W.size(1)
523 | for i in range(out_W):
524 | idx = int(indices_W[i][0])
525 | out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
526 | out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
527 | out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
528 |
529 | return out_2.numpy()
530 |
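# Usage sketch (illustrative; 'frame.png' is a hypothetical file): imresize_np expects an HWC
# float image in [0, 1], e.g. lq = imresize_np(read_img(None, 'frame.png'), 1 / 4, antialiasing=True)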
531 |
532 | if __name__ == '__main__':
533 | # test imresize function
534 | # read images
535 | img = cv2.imread('test.png')
536 | img = img * 1.0 / 255
537 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
538 | # imresize
539 | scale = 1 / 4
540 | import time
541 | total_time = 0
542 | for i in range(10):
543 | start_time = time.time()
544 | rlt = imresize(img, scale, antialiasing=True)
545 | use_time = time.time() - start_time
546 | total_time += use_time
547 | print('average time: {}'.format(total_time / 10))
548 |
549 | import torchvision.utils
550 | torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
551 | normalize=False)
552 |
--------------------------------------------------------------------------------