├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── Invertible-Image-Rescaling-master.iml
├── .style.yapf
├── figures
│   ├── InvMIHNet_8images.png
│   ├── overview_InvMIHNet.png
│   └── very_large_capacity.png
├── experiments
│   └── pretrained_models
│       └── link
├── .flake8
├── codes
│   ├── models
│   │   ├── __init__.py
│   │   ├── modules
│   │   │   ├── Quantization.py
│   │   │   ├── Subnet_constructor.py
│   │   │   ├── loss.py
│   │   │   ├── discriminator_vgg_arch.py
│   │   │   ├── module_util.py
│   │   │   ├── Inv_arch.py
│   │   │   └── Unet_common.py
│   │   ├── model.py
│   │   ├── networks.py
│   │   ├── rrdb_denselayer.py
│   │   ├── invblock.py
│   │   ├── IIH_module.py
│   │   ├── bicubic.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   └── InvMIHNet_model.py
│   ├── options
│   │   ├── test
│   │   │   ├── test_InvMIHNet_16images.yml
│   │   │   ├── test_InvMIHNet_8images.yml
│   │   │   ├── test_InvMIHNet_9images.yml
│   │   │   ├── test_InvMIHNet_6images.yml
│   │   │   └── test_InvMIHNet_4images.yml
│   │   ├── train
│   │   │   ├── train_InvMIHNet_8images.yml
│   │   │   ├── train_InvMIHNet_4images.yml
│   │   │   ├── train_InvMIHNet_9images.yml
│   │   │   ├── train_InvMIHNet_6images.yml
│   │   │   └── train_InvMIHNet_16images.yml
│   │   └── options.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── Steg_dataset.py
│   │   └── util.py
│   ├── test.py
│   ├── utils
│   │   └── util.py
│   └── train.py
├── .gitignore
├── README.md
└── LICENSE
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | BASED_ON_STYLE = pep8
3 | COLUMN_LIMIT = 100
4 | SPLIT_BEFORE_NAMED_ASSIGNS = false
--------------------------------------------------------------------------------
/figures/InvMIHNet_8images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brittany-Chen/InvMIHNet/HEAD/figures/InvMIHNet_8images.png
--------------------------------------------------------------------------------
/experiments/pretrained_models/link:
--------------------------------------------------------------------------------
1 | https://drive.google.com/file/d/17GRiwaJN8yqmLtiAO-bJcgpG8aWHysz3/view?usp=sharing
2 |
--------------------------------------------------------------------------------
/figures/overview_InvMIHNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brittany-Chen/InvMIHNet/HEAD/figures/overview_InvMIHNet.png
--------------------------------------------------------------------------------
/figures/very_large_capacity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Brittany-Chen/InvMIHNet/HEAD/figures/very_large_capacity.png
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # Too many leading '#' for block comment (E266)
4 | E266
5 |
6 | max-line-length=100
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | model = opt['model']
7 |
8 | if model == 'InvMIHNet':
9 | from .InvMIHNet_model import InvMIHNet as M
10 | else:
11 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
12 | m = M(opt)
13 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
14 | return m
15 |
--------------------------------------------------------------------------------
/.idea/Invertible-Image-Rescaling-master.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/codes/models/modules/Quantization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Quant(torch.autograd.Function):
5 |
6 | @staticmethod
7 | def forward(ctx, input):
8 | input = torch.clamp(input, 0, 1)
9 | output = (input * 255.).round() / 255.
10 | return output
11 |
12 | @staticmethod
13 | def backward(ctx, grad_output):
14 | return grad_output
15 |
16 | class Quantization(nn.Module):
17 | def __init__(self):
18 | super(Quantization, self).__init__()
19 |
20 | def forward(self, input):
21 | return Quant.apply(input)
22 |
--------------------------------------------------------------------------------
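A minimal sketch of how the straight-through estimator above behaves: the forward pass clamps to [0, 1] and snaps values to the 256-level grid, while the backward pass returns the incoming gradient unchanged. Assuming `codes/` is the working directory so the module is importable:

```
import torch
from models.modules.Quantization import Quantization

quant = Quantization()
x = torch.tensor([0.1234, 0.5, 1.7], requires_grad=True)
y = quant(x)
print(y)       # ~tensor([0.1216, 0.5020, 1.0000]): clamped, then rounded to multiples of 1/255
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.]): round (and clamp) are bypassed in the backward pass
```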
/.gitignore:
--------------------------------------------------------------------------------
1 | # folder
2 | .vscode
3 |
4 | experiments/*
5 | !experiments/pretrained_models
6 | experiments/pretrained_models/*
7 | # !experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth
8 | !experiments/pretrained_models/README.md
9 |
10 | results/*
11 | tb_logger/*
12 |
13 | # file type
14 | *.svg
15 | *.pyc
16 | *.t7
17 | *.caffemodel
18 | *.mat
19 | *.npy
20 |
21 | # latex
22 | *.aux
23 | *.bbl
24 | *.blg
25 | *.log
26 | *.out
27 | *.synctex.gz
28 |
29 | # TODO
30 | data_samples/samples_byteimg
31 | data_samples/samples_colorimg
32 | data_samples/samples_segprob
33 | data_samples/samples_result
--------------------------------------------------------------------------------
/codes/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.optim
2 | import torch.nn as nn
3 | from models.IIH_module import IIH_module
4 |
5 |
6 | class Model(nn.Module):
7 | def __init__(self):
8 | super(Model, self).__init__()
9 |
10 | self.model = IIH_module()
11 |
12 | def forward(self, x, rev=False):
13 |
14 | if not rev:
15 | out = self.model(x)
16 |
17 | else:
18 | out = self.model(x, rev=True)
19 |
20 | return out
21 |
22 |
23 | def init_model(mod):
24 | for key, param in mod.named_parameters():
25 | split = key.split('.')
26 | if param.requires_grad:
27 | param.data = 0.01 * torch.randn(param.data.shape).cuda()
28 | if split[-2] == 'conv5':
29 | param.data.fill_(0.)
30 |
--------------------------------------------------------------------------------
/codes/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import models.modules.discriminator_vgg_arch as SRGAN_arch
4 | from models.modules.Inv_arch import *
5 | from models.modules.Subnet_constructor import subnet
6 | import math
7 | logger = logging.getLogger('base')
8 |
9 |
10 | ####################
11 | # define network
12 | ####################
13 | def define_G(opt):
14 | opt_net = opt['network_G']
15 | which_model = opt_net['which_model_G']
16 | subnet_type = which_model['subnet_type']
17 | if opt_net['init']:
18 | init = opt_net['init']
19 | else:
20 | init = 'xavier'
21 |
22 | down_num = len(opt_net['block_num'])
23 | # down_num = int(math.log(opt_net['scale_W'], 2))
24 |
25 | use_ConvDownsampling = False
26 |
27 | if which_model['use_ConvDownsampling']:
28 | use_ConvDownsampling = True
29 |
30 | netG = InvRescaleNet(opt_net['in_nc'], opt_net['out_nc'], subnet(subnet_type, init), opt_net['block_num'], down_num, use_ConvDownsampling=use_ConvDownsampling, down_scale_W=opt_net['scale_W'], down_scale_H=opt_net['scale_H'])
31 |
32 | return netG
33 |
--------------------------------------------------------------------------------
/codes/models/rrdb_denselayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import models.modules.module_util as mutil
4 |
5 |
6 | # Dense connection
7 | class ResidualDenseBlock_out(nn.Module):
8 |     def __init__(self, channel_in, channel_out, bias=True):
9 |         super(ResidualDenseBlock_out, self).__init__()
10 |         self.conv1 = nn.Conv2d(channel_in, 32, 3, 1, 1, bias=bias)
11 |         self.conv2 = nn.Conv2d(channel_in + 32, 32, 3, 1, 1, bias=bias)
12 |         self.conv3 = nn.Conv2d(channel_in + 2 * 32, 32, 3, 1, 1, bias=bias)
13 |         self.conv4 = nn.Conv2d(channel_in + 3 * 32, 32, 3, 1, 1, bias=bias)
14 |         self.conv5 = nn.Conv2d(channel_in + 4 * 32, channel_out, 3, 1, 1, bias=bias)
15 | self.lrelu = nn.LeakyReLU(inplace=True)
16 | # initialization
17 | mutil.initialize_weights([self.conv5], 0.)
18 |
19 | def forward(self, x):
20 | x1 = self.lrelu(self.conv1(x))
21 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
22 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
23 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
24 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
25 | return x5
26 |
27 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/codes/models/invblock.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import torch
3 | import torch.nn as nn
4 | from models.rrdb_denselayer import ResidualDenseBlock_out
5 |
6 |
7 | class INV_block(nn.Module):
8 | def __init__(self, subnet_constructor=ResidualDenseBlock_out, clamp=2.0, harr=True, in_1=3, in_2=3):
9 | super().__init__()
10 | if harr:
11 | self.split_len1 = in_1 * 4
12 | self.split_len2 = in_2 * 4
13 | self.clamp = clamp
14 | # ρ
15 | self.r = subnet_constructor(self.split_len1, self.split_len2)
16 | # η
17 | self.y = subnet_constructor(self.split_len1, self.split_len2)
18 | # φ
19 | self.f = subnet_constructor(self.split_len2, self.split_len1)
20 |
21 |
22 | def e(self, s):
23 | return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5))
24 |
25 | def forward(self, x, rev=False):
26 | x1, x2 = (x.narrow(1, 0, self.split_len1),
27 | x.narrow(1, self.split_len1, self.split_len2))
28 |
29 | if not rev:
30 |
31 | t2 = self.f(x2)
32 | y1 = x1 + t2
33 | s1, t1 = self.r(y1), self.y(y1)
34 | y2 = self.e(s1) * x2 + t1
35 |
36 | else:
37 |
38 | s1, t1 = self.r(x1), self.y(x1)
39 | y2 = (x2 - t1) / self.e(s1)
40 | t2 = self.f(y2)
41 | y1 = (x1 - t2)
42 |
43 | return torch.cat((y1, y2), 1)
44 |
45 |
--------------------------------------------------------------------------------
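Because the coupling above is affine, the reverse branch recovers the input analytically for any subnet weights. A minimal round-trip check (a sketch, assuming `codes/` is on the import path; the default `ResidualDenseBlock_out` subnet runs fine on CPU):

```
import torch
from models.invblock import INV_block

block = INV_block()              # Haar defaults: split_len1 = split_len2 = 3 * 4 = 12
x = torch.randn(1, 24, 16, 16)   # the two coupling branches stacked on channels
y = block(x)                     # forward coupling
x_rec = block(y, rev=True)       # inverse coupling undoes it exactly (up to float error)
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```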
/codes/options/test/test_InvMIHNet_16images.yml:
--------------------------------------------------------------------------------
1 | name: Hiding_16_images_test
2 | suffix: ~ # add suffix to saved images
3 | model: InvMIHNet
4 | distortion: sr
5 | scale_W: 4
6 | scale_H: 4
7 | crop_border: 0 # crop border when evaluation. If None(~), crop the scale pixels
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test_1:
12 | name: val_DIV2K
13 | mode: Steg
14 | batch_size: 17
15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024
16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
17 | # test_2:
18 | # name: val_COCO
19 | # mode: LQGT
20 | # dataroot_GT: # path to COCO testing dataset
21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
22 | # test_3:
23 | # name: val_ImageNet
24 | # mode: LQGT
25 | # dataroot_GT: # path to ImageNet testing dataset
26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
27 |
28 |
29 | #### network
30 | network_G:
31 | which_model_G:
32 | subnet_type: DBNet
33 | in_nc: 3
34 | out_nc: 3
35 | block_num: [8, 8]
36 | scale_W: 4
37 | scale_H: 4
38 | init: xavier
39 |
40 |
41 | #### path
42 | path:
43 | pretrain_model_G: ../experiments/pretrained_models/IIR_16images.pth
44 | pretrain_model_H: ../experiments/pretrained_models/IIH_16images.pth
45 |
--------------------------------------------------------------------------------
/codes/options/test/test_InvMIHNet_8images.yml:
--------------------------------------------------------------------------------
1 | name: Hiding_8_images_test
2 | suffix: ~ # add suffix to saved images
3 | model: InvMIHNet
4 | distortion: sr
5 | scale_H: 4
6 | scale_W: 2
7 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test_1:
12 | name: val_DIV2K
13 | mode: Steg
14 | batch_size: 9
15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024
16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
17 | # test_2:
18 | # name: val_COCO
19 | # mode: LQGT
20 | # dataroot_GT: # path to COCO testing dataset
21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
22 | # test_3:
23 | # name: val_ImageNet
24 | # mode: LQGT
25 | # dataroot_GT: # path to ImageNet testing dataset
26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
27 |
28 |
29 |
30 | #### network
31 | network_G:
32 | which_model_G:
33 | subnet_type: DBNet
34 | in_nc: 3
35 | out_nc: 3
36 | block_num: [8, 8]
37 | scale_H: 4
38 | scale_W: 2
39 | init: xavier
40 |
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ../experiments/pretrained_models/IIR_8images.pth
45 | pretrain_model_H: ../experiments/pretrained_models/IIH_8images.pth
46 |
--------------------------------------------------------------------------------
/codes/options/test/test_InvMIHNet_9images.yml:
--------------------------------------------------------------------------------
1 | name: Hiding_9_images_test
2 | suffix: ~ # add suffix to saved images
3 | model: InvMIHNet
4 | distortion: sr
5 | scale_W: 3
6 | scale_H: 3
7 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test_1:
12 | name: val_DIV2K
13 | mode: Steg
14 | batch_size: 10
15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024
16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
17 | # test_2:
18 | # name: val_COCO
19 | # mode: LQGT
20 | # dataroot_GT: # path to COCO testing dataset
21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
22 | # test_3:
23 | # name: val_ImageNet
24 | # mode: LQGT
25 | # dataroot_GT: # path to ImageNet testing dataset
26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
27 |
28 |
29 | #### network
30 | network_G:
31 | which_model_G:
32 | subnet_type: DBNet
33 | use_ConvDownsampling: True
34 | in_nc: 3
35 | out_nc: 3
36 | block_num: [12]
37 | scale_W: 3
38 | scale_H: 3
39 | init: xavier
40 |
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ../experiments/pretrained_models/IIR_9images.pth
45 | pretrain_model_H: ../experiments/pretrained_models/IIH_9images.pth
--------------------------------------------------------------------------------
/codes/options/test/test_InvMIHNet_6images.yml:
--------------------------------------------------------------------------------
1 | name: Hiding_6_images_test
2 | suffix: ~ # add suffix to saved images
3 | model: InvMIHNet
4 | distortion: sr
5 | scale_W: 2
6 | scale_H: 3
7 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test_1:
12 | name: val_DIV2K
13 | mode: Steg
14 | batch_size: 7
15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024
16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
17 | # test_2:
18 | # name: val_COCO
19 | # mode: LQGT
20 | # dataroot_GT: # path to COCO testing dataset
21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
22 | # test_3:
23 | # name: val_ImageNet
24 | # mode: LQGT
25 | # dataroot_GT: # path to ImageNet testing dataset
26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
27 |
28 |
29 |
30 |
31 | #### network
32 | network_G:
33 | which_model_G:
34 | subnet_type: DBNet
35 | use_ConvDownsampling: True
36 | in_nc: 3
37 | out_nc: 3
38 | block_num: [12]
39 | scale_W: 2
40 | scale_H: 3
41 | init: xavier
42 |
43 |
44 | #### path
45 | path:
46 | pretrain_model_G: ../experiments/pretrained_models/IIR_6images.pth #../experiments/pretrained_models_IRN_all/100_G.pth
47 | pretrain_model_H: ../experiments/pretrained_models/IIH_6images.pth
48 |
--------------------------------------------------------------------------------
/codes/options/test/test_InvMIHNet_4images.yml:
--------------------------------------------------------------------------------
1 | name: Hiding_4_images_test
2 | suffix: ~ # add suffix to saved images
3 | model: InvMIHNet
4 | distortion: sr
5 | scale_W: 2
6 | scale_H: 2
7 | crop_border: 0 # crop border when evaluation. If None(~), crop the scale pixels
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test_1:
12 | name: val_DIV2K
13 | mode: Steg
14 | batch_size: 5
15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024 # path to DIV2K testing dataset
16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
17 | # test_2:
18 | # name: val_COCO
19 | # mode: LQGT
20 | # batch_size: 5
21 | # dataroot_GT: # path to COCO testing dataset
22 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
23 | # test_3:
24 | # name: val_ImageNet
25 | # mode: LQGT
26 | # batch_size: 5
27 | # dataroot_GT: # path to ImageNet testing dataset
28 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader
29 |
30 |
31 | #### network
32 | network_G:
33 | which_model_G:
34 | subnet_type: DBNet
35 | in_nc: 3
36 | out_nc: 3
37 | block_num: [8]
38 | scale_W: 2
39 | scale_H: 2
40 | init: xavier
41 |
42 | network_H:
43 | lamda_reconstruction: 1
44 | lamda_guide: 10
45 | lamda_low_frequency: 1
46 |
47 | #### path
48 | path:
49 | pretrain_model_G: ../experiments/pretrained_models/IIR_4images.pth
50 | pretrain_model_H: ../experiments/pretrained_models/IIH_4images.pth
--------------------------------------------------------------------------------
/codes/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
8 | phase = dataset_opt['phase']
9 | if phase == 'train':
10 | if opt['dist']:
11 | world_size = torch.distributed.get_world_size()
12 | num_workers = dataset_opt['n_workers']
13 | assert dataset_opt['batch_size'] % world_size == 0
14 | batch_size = dataset_opt['batch_size'] // world_size
15 | shuffle = False
16 | else:
17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
18 | batch_size = dataset_opt['batch_size']
19 | shuffle = True
20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
21 | num_workers=num_workers, sampler=sampler, drop_last=True,
22 | pin_memory=False)
23 | else:
24 | return torch.utils.data.DataLoader(dataset, batch_size=dataset_opt['batch_size'], shuffle=False, num_workers=0,
25 | pin_memory=True, drop_last=True)
26 |
27 |
28 | def create_dataset(dataset_opt):
29 | mode = dataset_opt['mode']
30 | if mode == 'Steg':
31 | from data.Steg_dataset import StegDataset as S
32 | else:
33 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
34 | dataset = S(dataset_opt)
35 |
36 | logger = logging.getLogger('base')
37 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
38 | dataset_opt['name']))
39 | return dataset
40 |
--------------------------------------------------------------------------------
/codes/models/modules/Subnet_constructor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import models.modules.module_util as mutil
5 |
6 | class DenseBlock(nn.Module):
7 | def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True):
8 | super(DenseBlock, self).__init__()
9 | self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
10 | self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
11 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
12 | self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
13 | self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
14 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
15 |
16 | if init == 'xavier':
17 | mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
18 | else:
19 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
20 | mutil.initialize_weights(self.conv5, 0)
21 |
22 | def forward(self, x):
23 | x1 = self.lrelu(self.conv1(x))
24 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
25 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
26 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
27 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
28 | return x5
29 |
30 |
31 | def subnet(net_structure, init='xavier'):
32 | def constructor(channel_in, channel_out):
33 | if net_structure == 'DBNet':
34 | if init == 'xavier':
35 | return DenseBlock(channel_in, channel_out, init)
36 | else:
37 | return DenseBlock(channel_in, channel_out)
38 | else:
39 | return None
40 |
41 | return constructor
--------------------------------------------------------------------------------
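Note that `subnet` returns a constructor rather than a module: `define_G` in networks.py hands this constructor to the invertible architecture, which later calls it with the channel split of each block. A small usage sketch:

```
import torch
from models.modules.Subnet_constructor import subnet

make_block = subnet('DBNet', init='xavier')    # a constructor, not an nn.Module
phi = make_block(12, 12)                       # DenseBlock mapping 12 -> 12 channels
print(phi(torch.randn(1, 12, 16, 16)).shape)   # torch.Size([1, 12, 16, 16])
```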
/codes/options/train/train_InvMIHNet_8images.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: InvMIHNet_8images_train
5 | use_tb_logger: true
6 | model: InvMIHNet
7 | distortion: sr
8 | scale_H: 4
9 | scale_W: 2
10 | gpu_ids: [0]
11 |
12 |
13 | #### datasets
14 |
15 | datasets:
16 | train:
17 | name: DIV2K
18 | mode: Steg
19 | dataroot_GT: # path to training images
20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader
21 |
22 | use_shuffle: true
23 | n_workers: 0 # per GPU
24 | batch_size: 9
25 | GT_size: 144
26 | color: RGB
27 |
28 | val:
29 | name: val_DIV2K
30 | mode: Steg
31 | batch_size: 9
32 | dataroot_GT: # path to validation images
33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader
34 |
35 |
36 | #### network structures
37 |
38 | network_G:
39 | which_model_G:
40 | subnet_type: DBNet
41 | in_nc: 3
42 | out_nc: 3
43 | block_num: [8, 8]
44 | scale_H: 4
45 | scale_W: 2
46 | init: xavier
47 |
48 |
49 | #### path
50 |
51 | path:
52 | pretrain_model_G: ~
53 | pretrain_model_H: ~
54 | strict_load: true
55 | resume_state: ~
56 |
57 |
58 | #### training settings: learning rate scheme, loss
59 |
60 | train:
61 | lr_G: !!float 2e-4
62 | lr_H: !!float 2e-4
63 | beta1: 0.9
64 | beta2: 0.999
65 | beta1_H: 0.5
66 | beta2_H: 0.999
67 | niter: 500000
68 | warmup_iter: -1 # no warm up
69 |
70 | lr_scheme: MultiStepLR
71 | lr_steps: [20000, 40000, 80000, 100000]
72 | lr_gamma: 0.5
73 |
74 | pixel_criterion_forw: l2
75 | pixel_criterion_back: l1
76 |
77 | manual_seed: 10
78 |
79 | val_freq: 1000
80 | lambda_fit_forw: 4
81 | lambda_rec_back: 1
82 | lambda_ce_forw: 1
83 | lamda_reconstruction: 5
84 | lamda_guide: 1
85 | lamda_low_frequency: 1
86 |
87 | weight_decay_G: !!float 1e-5
88 | weight_decay_H: !!float 1e-5
89 | gradient_clipping: 10
90 |
91 |
92 | weight_step: 1000
93 |
94 |
95 | #### logger
96 |
97 | logger:
98 | save_checkpoint_freq: 1000
--------------------------------------------------------------------------------
/codes/options/train/train_InvMIHNet_4images.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: InvMIHNet_4images_train
5 | use_tb_logger: true
6 | model: InvMIHNet
7 | distortion: sr
8 | scale_W: 2
9 | scale_H: 2
10 | gpu_ids: [0]
11 |
12 |
13 | #### datasets
14 |
15 | datasets:
16 | train:
17 | name: DIV2K
18 | mode: Steg
19 | dataroot_GT: # path to training images
20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader
21 |
22 | use_shuffle: true
23 | n_workers: 0 # per GPU
24 | batch_size: 5
25 | GT_size: 144
26 | color: RGB
27 |
28 | val:
29 | name: val_DIV2K
30 | mode: Steg
31 | batch_size: 5
32 | dataroot_GT: # path to validation images
33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader
34 |
35 |
36 | #### network structures
37 |
38 | network_G:
39 | which_model_G:
40 | subnet_type: DBNet
41 | in_nc: 3
42 | out_nc: 3
43 | block_num: [8]
44 | scale_W: 2
45 | scale_H: 2
46 | init: xavier
47 |
48 |
49 | #### path
50 |
51 | path:
52 | pretrain_model_G: ~
53 | pretrain_model_H: ~
54 | strict_load: true
55 | resume_state: ~
56 |
57 |
58 | #### training settings: learning rate scheme, loss
59 | train:
60 | lr_G: !!float 2e-4
61 | lr_H: !!float 2e-4
62 | beta1: 0.9
63 | beta2: 0.999
64 | beta1_H: 0.5
65 | beta2_H: 0.999
66 | niter: 500000
67 | warmup_iter: -1 # no warm up
68 |
69 |
70 | lr_scheme: MultiStepLR
71 | lr_steps: [20000, 40000, 80000, 100000]
72 | lr_gamma: 0.5
73 |
74 | pixel_criterion_forw: l2
75 | pixel_criterion_back: l1
76 |
77 | manual_seed: 10
78 |
79 | val_freq: 1000
80 |
81 | lambda_fit_forw: 4
82 | lambda_rec_back: 1
83 | lambda_ce_forw: 1
84 | lamda_reconstruction: 5
85 | lamda_guide: 1
86 | lamda_low_frequency: 1
87 |
88 | weight_decay_G: !!float 1e-5
89 | weight_decay_H: !!float 1e-5
90 | gradient_clipping: 10
91 |
92 |
93 | weight_step: 1000
94 |
95 |
96 | #### logger
97 |
98 | logger:
99 | save_checkpoint_freq: 1000
--------------------------------------------------------------------------------
/codes/options/train/train_InvMIHNet_9images.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: InvMIHNet_9images_train
5 | use_tb_logger: true
6 | model: InvMIHNet
7 | distortion: sr
8 | scale_W: 3
9 | scale_H: 3
10 | gpu_ids: [0]
11 |
12 |
13 | #### datasets
14 |
15 | datasets:
16 | train:
17 | name: DIV2K
18 | mode: Steg
19 | dataroot_GT: # path to training images
20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader
21 |
22 | use_shuffle: true
23 | n_workers: 0 # per GPU
24 | batch_size: 10
25 | GT_size: 144
26 | color: RGB
27 |
28 | val:
29 | name: val_DIV2K
30 | mode: Steg
31 | batch_size: 10
32 | dataroot_GT: # path to validation images
33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader
34 |
35 |
36 | #### network structures
37 |
38 | network_G:
39 | which_model_G:
40 | subnet_type: DBNet
41 | use_ConvDownsampling: True
42 | in_nc: 3
43 | out_nc: 3
44 | block_num: [12]
45 | scale_W: 3
46 | scale_H: 3
47 | init: xavier
48 |
49 |
50 | #### path
51 |
52 | path:
53 | pretrain_model_G: ~
54 | pretrain_model_H: ~
55 | strict_load: true
56 | resume_state: ~
57 |
58 |
59 | #### training settings: learning rate scheme, loss
60 |
61 | train:
62 | lr_G: !!float 2e-4
63 | lr_H: !!float 2e-4
64 | beta1: 0.9
65 | beta2: 0.999
66 | beta1_H: 0.5
67 | beta2_H: 0.999
68 | niter: 500000
69 | warmup_iter: -1 # no warm up
70 |
71 | lr_scheme: MultiStepLR
72 | lr_steps: [20000, 40000, 80000, 100000]
73 | lr_gamma: 0.5
74 |
75 | pixel_criterion_forw: l2
76 | pixel_criterion_back: l1
77 |
78 | manual_seed: 10
79 |
80 | val_freq: 1000
81 |
82 | lambda_fit_forw: 4
83 | lambda_rec_back: 1
84 | lambda_ce_forw: 1
85 | lamda_reconstruction: 5
86 | lamda_guide: 1
87 | lamda_low_frequency: 1
88 |
89 | weight_decay_G: !!float 1e-5
90 | weight_decay_H: !!float 1e-5
91 | gradient_clipping: 10
92 |
93 |
94 | weight_step: 1000
95 |
96 |
97 | #### logger
98 |
99 | logger:
100 | save_checkpoint_freq: 1000
--------------------------------------------------------------------------------
/codes/options/train/train_InvMIHNet_6images.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: InvMIHNet_6images_train
5 | use_tb_logger: true
6 | model: InvMIHNet
7 | distortion: sr
8 | scale_W: 2
9 | scale_H: 3
10 | gpu_ids: [0]
11 |
12 |
13 | #### datasets
14 |
15 | datasets:
16 | train:
17 | name: DIV2K
18 | mode: Steg
19 | dataroot_GT: # path to training images
20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader
21 |
22 | use_shuffle: true
23 | n_workers: 0 # per GPU
24 | batch_size: 7
25 | GT_size: 144
26 | color: RGB
27 |
28 | val:
29 | name: val_DIV2K
30 | mode: Steg
31 | batch_size: 7
32 | dataroot_GT: # path to validation images
33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader
34 |
35 |
36 | #### network structures
37 |
38 | network_G:
39 | which_model_G:
40 | subnet_type: DBNet
41 | use_ConvDownsampling: True
42 | in_nc: 3
43 | out_nc: 3
44 | block_num: [12]
45 | scale_W: 2
46 | scale_H: 3
47 | init: xavier
48 |
49 |
50 | #### path
51 |
52 | path:
53 | pretrain_model_G: ~
54 | pretrain_model_H: ~
55 | strict_load: true
56 | resume_state: ~
57 |
58 |
59 | #### training settings: learning rate scheme, loss
60 | train:
61 | lr_G: !!float 2e-4
62 | lr_H: !!float 2e-4
63 | beta1: 0.9
64 | beta2: 0.999
65 | beta1_H: 0.5
66 | beta2_H: 0.999
67 | niter: 500000
68 | warmup_iter: -1 # no warm up
69 |
70 |
71 | lr_scheme: MultiStepLR
72 | lr_steps: [20000, 40000, 80000, 100000]
73 | lr_gamma: 0.5
74 |
75 | pixel_criterion_forw: l2
76 | pixel_criterion_back: l1
77 |
78 | manual_seed: 10
79 |
80 | val_freq: 1000
81 |
82 | lambda_fit_forw: 4
83 | lambda_rec_back: 1
84 | lambda_ce_forw: 1
85 | lamda_reconstruction: 5
86 | lamda_guide: 1
87 | lamda_low_frequency: 1
88 |
89 | weight_decay_G: !!float 1e-5
90 | weight_decay_H: !!float 1e-5
91 | gradient_clipping: 10
92 |
93 |
94 | weight_step: 1000
95 |
96 |
97 | #### logger
98 |
99 | logger:
100 | save_checkpoint_freq: 1000
101 |
--------------------------------------------------------------------------------
/codes/options/train/train_InvMIHNet_16images.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: InvMIHNet_16images_train
5 | use_tb_logger: true
6 | model: InvMIHNet
7 | distortion: sr
8 | scale_W: 4
9 | scale_H: 4
10 | gpu_ids: [0]
11 |
12 |
13 | #### datasets
14 |
15 | datasets:
16 | train:
17 | name: DIV2K
18 | mode: Steg
19 | dataroot_GT: # path to training images
20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader
21 |
22 | use_shuffle: true
23 | n_workers: 6 # per GPU
24 | batch_size: 17
25 | GT_size: 144
26 | use_flip: true
27 | use_rot: true
28 | color: RGB
29 |
30 | val:
31 | name: val_DIV2K
32 | mode: Steg
33 | batch_size: 17
34 | dataroot_GT: # path to validation images
35 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader
36 |
37 |
38 | #### network structures
39 |
40 | network_G:
41 | which_model_G:
42 | subnet_type: DBNet
43 | in_nc: 3
44 | out_nc: 3
45 | block_num: [8, 8]
46 | scale_W: 4
47 | scale_H: 4
48 | init: xavier
49 |
50 |
51 | #### path
52 |
53 | path:
54 | pretrain_model_G: ~
55 | pretrain_model_H: ~
56 | strict_load: true
57 | resume_state: ~
58 |
59 |
60 | #### training settings: learning rate scheme, loss
61 |
62 | train:
63 | lr_G: !!float 2e-4
64 | lr_H: !!float 2e-4
65 | beta1: 0.9
66 | beta2: 0.999
67 | beta1_H: 0.5
68 | beta2_H: 0.999
69 | niter: 500000
70 | warmup_iter: -1 # no warm up
71 |
72 |
73 | lr_scheme: MultiStepLR
74 | lr_steps: [20000, 40000, 80000, 100000]
75 | lr_gamma: 0.5
76 |
77 | pixel_criterion_forw: l2
78 | pixel_criterion_back: l1
79 |
80 | manual_seed: 10
81 |
82 | val_freq: 1000
83 |
84 | lambda_fit_forw: 4
85 | lambda_rec_back: 1
86 | lambda_ce_forw: 1
87 | lamda_reconstruction: 5
88 | lamda_guide: 1
89 | lamda_low_frequency: 1
90 |
91 | weight_decay_G: !!float 1e-5
92 | weight_decay_H: !!float 1e-5
93 | gradient_clipping: 10
94 |
95 |
96 | weight_step: 1000
97 |
98 |
99 | #### logger
100 |
101 | logger:
102 | save_checkpoint_freq: 1000
--------------------------------------------------------------------------------
/codes/models/IIH_module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from models.invblock import INV_block
3 |
4 |
5 | class IIH_module(nn.Module):
6 |
7 | def __init__(self):
8 | super(IIH_module, self).__init__()
9 |
10 | self.inv1 = INV_block()
11 | self.inv2 = INV_block()
12 | self.inv3 = INV_block()
13 | self.inv4 = INV_block()
14 | self.inv5 = INV_block()
15 | self.inv6 = INV_block()
16 | self.inv7 = INV_block()
17 | self.inv8 = INV_block()
18 |
19 | self.inv9 = INV_block()
20 | self.inv10 = INV_block()
21 | self.inv11 = INV_block()
22 | self.inv12 = INV_block()
23 | self.inv13 = INV_block()
24 | self.inv14 = INV_block()
25 | self.inv15 = INV_block()
26 | self.inv16 = INV_block()
27 |
28 | def forward(self, x, rev=False):
29 |
30 | if not rev:
31 | out = self.inv1(x)
32 | out = self.inv2(out)
33 | out = self.inv3(out)
34 | out = self.inv4(out)
35 | out = self.inv5(out)
36 | out = self.inv6(out)
37 | out = self.inv7(out)
38 | out = self.inv8(out)
39 |
40 | out = self.inv9(out)
41 | out = self.inv10(out)
42 | out = self.inv11(out)
43 | out = self.inv12(out)
44 | out = self.inv13(out)
45 | out = self.inv14(out)
46 | out = self.inv15(out)
47 | out = self.inv16(out)
48 |
49 | else:
50 | out = self.inv16(x, rev=True)
51 | out = self.inv15(out, rev=True)
52 | out = self.inv14(out, rev=True)
53 | out = self.inv13(out, rev=True)
54 | out = self.inv12(out, rev=True)
55 | out = self.inv11(out, rev=True)
56 | out = self.inv10(out, rev=True)
57 | out = self.inv9(out, rev=True)
58 |
59 | out = self.inv8(out, rev=True)
60 | out = self.inv7(out, rev=True)
61 | out = self.inv6(out, rev=True)
62 | out = self.inv5(out, rev=True)
63 | out = self.inv4(out, rev=True)
64 | out = self.inv3(out, rev=True)
65 | out = self.inv2(out, rev=True)
66 | out = self.inv1(out, rev=True)
67 |
68 | return out
69 |
70 |
71 |
--------------------------------------------------------------------------------
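The sixteen blocks are kept as individually named attributes (`inv1` ... `inv16`); refactoring them into an `nn.ModuleList` would change the state_dict keys and break loading of the released checkpoints. Running the chain with `rev=True` inverts it block by block. A round-trip sketch (the `narrow` into a stego branch follows the `INV_block` channel split and is an assumption about how the surrounding model consumes the output):

```
import torch
from models.IIH_module import IIH_module

net = IIH_module().eval()
x = torch.randn(1, 24, 32, 32)       # wavelet-domain [cover | mosaic secret]
with torch.no_grad():
    out = net(x)                     # forward pass: inv1 -> inv16
    steg = out.narrow(1, 0, 12)      # first split: the stego branch (assumed usage)
    back = net(out, rev=True)        # reverse pass: inv16 -> inv1
print(torch.allclose(x, back, atol=1e-4))  # True
```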
/codes/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler
3 | Support enlarging the dataset for *iter-oriented* training, saving time when restarting the
4 | dataloader after each epoch
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | num_replicas = dist.get_world_size()
35 | if rank is None:
36 | if not dist.is_available():
37 | raise RuntimeError("Requires distributed package to be available")
38 | rank = dist.get_rank()
39 | self.dataset = dataset
40 | self.num_replicas = num_replicas
41 | self.rank = rank
42 | self.epoch = 0
43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
44 | self.total_size = self.num_samples * self.num_replicas
45 |
46 | def __iter__(self):
47 | # deterministically shuffle based on epoch
48 | g = torch.Generator()
49 | g.manual_seed(self.epoch)
50 | indices = torch.randperm(self.total_size, generator=g).tolist()
51 |
52 | dsize = len(self.dataset)
53 | indices = [v % dsize for v in indices]
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.num_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
--------------------------------------------------------------------------------
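The `ratio` argument enlarges one "epoch" so the dataloader is restarted far less often; indices simply wrap modulo the dataset size. Passing `num_replicas` and `rank` explicitly skips the distributed queries, so the behavior can be checked without an initialized process group:

```
import torch
from torch.utils.data import TensorDataset
from data.data_sampler import DistIterSampler

ds = TensorDataset(torch.arange(10))
sampler = DistIterSampler(ds, num_replicas=2, rank=0, ratio=100)
print(len(sampler))                # 500 == ceil(10 * 100 / 2)
indices = list(iter(sampler))
print(min(indices), max(indices))  # every index wraps into [0, len(ds))
```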
/codes/models/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | class ReconstructionLoss(nn.Module):
6 | def __init__(self, losstype='l2', eps=1e-6):
7 | super(ReconstructionLoss, self).__init__()
8 | self.losstype = losstype
9 | self.eps = eps
10 |
11 | def forward(self, x, target):
12 | if self.losstype == 'l2':
13 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3)))
14 | elif self.losstype == 'l1':
15 | diff = x - target
16 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3)))
17 |         else:
18 |             raise NotImplementedError('Reconstruction loss type [{:s}] is not recognized.'.format(self.losstype))
20 |
21 |
22 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
23 | class GANLoss(nn.Module):
24 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
25 | super(GANLoss, self).__init__()
26 | self.gan_type = gan_type.lower()
27 | self.real_label_val = real_label_val
28 | self.fake_label_val = fake_label_val
29 |
30 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
31 | self.loss = nn.BCEWithLogitsLoss()
32 | elif self.gan_type == 'lsgan':
33 | self.loss = nn.MSELoss()
34 | elif self.gan_type == 'wgan-gp':
35 |
36 | def wgan_loss(input, target):
37 | # target is boolean
38 | return -1 * input.mean() if target else input.mean()
39 |
40 | self.loss = wgan_loss
41 | else:
42 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
43 |
44 | def get_target_label(self, input, target_is_real):
45 | if self.gan_type == 'wgan-gp':
46 | return target_is_real
47 | if target_is_real:
48 | return torch.empty_like(input).fill_(self.real_label_val)
49 | else:
50 | return torch.empty_like(input).fill_(self.fake_label_val)
51 |
52 | def forward(self, input, target_is_real):
53 | target_label = self.get_target_label(input, target_is_real)
54 | loss = self.loss(input, target_label)
55 | return loss
56 |
57 |
58 | class GradientPenaltyLoss(nn.Module):
59 | def __init__(self, device=torch.device('cpu')):
60 | super(GradientPenaltyLoss, self).__init__()
61 | self.register_buffer('grad_outputs', torch.Tensor())
62 | self.grad_outputs = self.grad_outputs.to(device)
63 |
64 | def get_grad_outputs(self, input):
65 | if self.grad_outputs.size() != input.size():
66 | self.grad_outputs.resize_(input.size()).fill_(1.0)
67 | return self.grad_outputs
68 |
69 | def forward(self, interp, interp_crit):
70 | grad_outputs = self.get_grad_outputs(interp_crit)
71 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
72 | grad_outputs=grad_outputs, create_graph=True,
73 | retain_graph=True, only_inputs=True)[0]
74 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
75 | grad_interp_norm = grad_interp.norm(2, dim=1)
76 |
77 | loss = ((grad_interp_norm - 1)**2).mean()
78 | return loss
79 |
--------------------------------------------------------------------------------
/codes/models/bicubic.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import math
4 | import sys, time
5 |
6 | # Interpolation kernel
7 | def u(s,a):
8 | if (abs(s) >=0) & (abs(s) <=1):
9 | return (a+2)*(abs(s)**3)-(a+3)*(abs(s)**2)+1
10 | elif (abs(s) > 1) & (abs(s) <= 2):
11 | return a*(abs(s)**3)-(5*a)*(abs(s)**2)+(8*a)*abs(s)-4*a
12 | return 0
13 |
14 | # Padding
15 | def padding(img,H,W,C):
16 | zimg = np.zeros((H+4,W+4,C))
17 | zimg[2:H+2,2:W+2,:C] = img
18 | #Pad the first/last two col and row
19 | zimg[2:H+2,0:2,:C]=img[:,0:1,:C]
20 | zimg[H+2:H+4,2:W+2,:]=img[H-1:H,:,:]
21 | zimg[2:H+2,W+2:W+4,:]=img[:,W-1:W,:]
22 | zimg[0:2,2:W+2,:C]=img[0:1,:,:C]
23 | #Pad the missing eight points
24 | zimg[0:2,0:2,:C]=img[0,0,:C]
25 | zimg[H+2:H+4,0:2,:C]=img[H-1,0,:C]
26 | zimg[H+2:H+4,W+2:W+4,:C]=img[H-1,W-1,:C]
27 | zimg[0:2,W+2:W+4,:C]=img[0,W-1,:C]
28 | return zimg
29 |
30 | # https://github.com/yunabe/codelab/blob/master/misc/terminal_progressbar/progress.py
31 | def get_progressbar_str(progress):
32 | END = 170
33 | MAX_LEN = 30
34 | BAR_LEN = int(MAX_LEN * progress)
35 | return ('Progress:[' + '=' * BAR_LEN +
36 | ('>' if BAR_LEN < MAX_LEN else '') +
37 | ' ' * (MAX_LEN - BAR_LEN) +
38 | '] %.1f%%' % (progress * 100.))
39 |
40 | # Bicubic operation
41 | def bicubic(img, ratio, a):
42 | #Get image size
43 | H,W,C = img.shape
44 |
45 | img = padding(img,H,W,C)
46 | #Create new image
47 | dH = math.floor(H*ratio)
48 | dW = math.floor(W*ratio)
49 |     dst = np.zeros((dH, dW, C))
50 |
51 | h = 1/ratio
52 |
53 | print('Start bicubic interpolation')
54 | print('It will take a little while...')
55 | inc = 0
56 | for c in range(C):
57 | for j in range(dH):
58 | for i in range(dW):
59 | x, y = i * h + 2 , j * h + 2
60 |
61 | x1 = 1 + x - math.floor(x)
62 | x2 = x - math.floor(x)
63 | x3 = math.floor(x) + 1 - x
64 | x4 = math.floor(x) + 2 - x
65 |
66 | y1 = 1 + y - math.floor(y)
67 | y2 = y - math.floor(y)
68 | y3 = math.floor(y) + 1 - y
69 | y4 = math.floor(y) + 2 - y
70 |
71 | mat_l = np.matrix([[u(x1,a),u(x2,a),u(x3,a),u(x4,a)]])
72 | mat_m = np.matrix([[img[int(y-y1),int(x-x1),c],img[int(y-y2),int(x-x1),c],img[int(y+y3),int(x-x1),c],img[int(y+y4),int(x-x1),c]],
73 | [img[int(y-y1),int(x-x2),c],img[int(y-y2),int(x-x2),c],img[int(y+y3),int(x-x2),c],img[int(y+y4),int(x-x2),c]],
74 | [img[int(y-y1),int(x+x3),c],img[int(y-y2),int(x+x3),c],img[int(y+y3),int(x+x3),c],img[int(y+y4),int(x+x3),c]],
75 | [img[int(y-y1),int(x+x4),c],img[int(y-y2),int(x+x4),c],img[int(y+y3),int(x+x4),c],img[int(y+y4),int(x+x4),c]]])
76 | mat_r = np.matrix([[u(y1,a)],[u(y2,a)],[u(y3,a)],[u(y4,a)]])
77 | dst[j, i, c] = np.dot(np.dot(mat_l, mat_m),mat_r)
78 |
79 | # Print progress
80 | inc = inc + 1
81 | sys.stderr.write('\r\033[K' + get_progressbar_str(inc/(C*dH*dW)))
82 | sys.stderr.flush()
83 | sys.stderr.write('\n')
84 | sys.stderr.flush()
85 | return dst
86 |
87 |
--------------------------------------------------------------------------------
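The triple Python loop above is useful as a reference implementation but is very slow on real images. For practical use, OpenCV's cubic interpolation (which fixes the kernel parameter at a = -0.75) should give closely matching results, up to boundary handling; a hedged equivalent:

```
import cv2
import numpy as np

img = (np.random.rand(64, 64, 3) * 255).astype(np.float32)
fast = cv2.resize(img, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
# roughly bicubic(img, ratio=0.5, a=-0.75), computed in milliseconds instead of minutes
print(fast.shape)  # (32, 32, 3)
```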
/codes/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import time
5 | import argparse
6 | import torchvision
7 | from collections import OrderedDict
8 |
9 | import numpy as np
10 | import options.options as option
11 | import utils.util as util
12 | from data.util import bgr2ycbcr
13 | from data import create_dataset, create_dataloader
14 | from models import create_model
15 |
16 |
17 | #### options
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
20 | opt = option.parse(parser.parse_args().opt, is_train=False)
21 | opt = option.dict_to_nonedict(opt)
22 |
23 | util.mkdirs(
24 | (path for key, path in opt['path'].items()
25 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
26 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
27 | screen=True, tofile=True)
28 | logger = logging.getLogger('base')
29 | logger.info(option.dict2str(opt))
30 |
31 | #### Create test dataset and dataloader
32 | test_loaders = []
33 | for phase, dataset_opt in sorted(opt['datasets'].items()):
34 | test_set = create_dataset(dataset_opt)
35 | test_loader = create_dataloader(test_set, dataset_opt)
36 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
37 | test_loaders.append(test_loader)
38 |
39 | model = create_model(opt)
40 | count=0
41 | for test_loader in test_loaders:
42 | test_set_name = test_loader.dataset.opt['name']
43 |     logger.info('\nTesting [{:s}]...'.format(test_set_name))
45 | test_start_time = time.time()
46 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
47 | if not os.path.exists(dataset_dir):
48 | os.makedirs(dataset_dir)
49 |
50 | for data in test_loader:
51 | model.feed_data(data)
52 |
53 | model.test()
54 |
55 | visuals = model.get_current_visuals()
56 |
57 | cover = visuals['cover']
58 | secret = visuals['secret']
59 | secret_recover = visuals['secret_recover']
60 | steg = visuals['steg']
61 |
62 |         cover_path = osp.join(dataset_dir, 'cover')
63 |         secret_path = osp.join(dataset_dir, 'secret')
64 |         secret_recover_path = osp.join(dataset_dir, 'secret_recover')
65 |         steg_path = osp.join(dataset_dir, 'steg')
66 |
67 |
68 | if not os.path.exists(cover_path):
69 | os.makedirs(cover_path)
70 | if not os.path.exists(secret_path):
71 | os.makedirs(secret_path)
72 | if not os.path.exists(secret_recover_path):
73 | os.makedirs(secret_recover_path)
74 | if not os.path.exists(steg_path):
75 | os.makedirs(steg_path)
76 |
77 |
78 | save_cover_path = osp.join(cover_path, str(count)+'_clean_cover.png')
79 | torchvision.utils.save_image(cover, save_cover_path)
80 |
81 | save_steg_path = osp.join(steg_path, str(count) + '_steg.png')
82 | torchvision.utils.save_image(steg, save_steg_path)
83 |
84 | for i in range(int(secret.shape[0])):
85 | save_secret_path = osp.join(secret_path, str(count) + '_' + str(i) + '_clean_secret.png')
86 | save_secret_recover_path = osp.join(secret_recover_path, str(count) + '_' + str(i) + '_secret_recover.png')
87 | torchvision.utils.save_image(secret[i:i+1,:,:,:], save_secret_path)
88 | torchvision.utils.save_image(secret_recover[i:i + 1, :, :, :], save_secret_recover_path)
89 |
90 | count = count + 1
91 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Invertible Mosaic Image Hiding Network for Very Large Capacity Image Steganography
2 | [Zihan Chen](https://brittany-chen.github.io/)(chenzihan21@nudt.edu.cn), Tianrui Liu#, [Jun-Jie Huang](https://jjhuangcs.github.io/), Wentao Zhao#, Xing Bi and Meng Wang (#corresponding author)
3 |
4 | PyTorch implementation of "Invertible Mosaic Image Hiding Network for Very Large Capacity Image Steganography" (ICASSP'2024).
5 |
6 |
7 |
8 | Existing image steganography methods either conceal secret images sequentially or conceal a concatenation of multiple images. In such schemes, the interference among the multiple images becomes increasingly severe as the number of secret images grows, which restricts the development of very large capacity image steganography.
9 | In this paper, we propose an Invertible Mosaic Image Hiding Network (InvMIHNet) which realizes very large capacity image steganography with high quality by concealing a single mosaic secret image. InvMIHNet consists of an Invertible Image Rescaling (IIR) module and an Invertible Image Hiding (IIH) module.
10 | The IIR module downscales the single mosaic secret image formed by spatially splicing the multiple secret images, and the IIH module then conceals this mosaic image under the cover image.
11 | The proposed InvMIHNet successfully conceals and reveals up to 16 secret images with a small number of parameters and low memory consumption.
12 | Extensive experiments on ImageNet-1K, COCO and DIV2K show that InvMIHNet outperforms state-of-the-art methods in terms of the imperceptibility of the stego image, the recovery accuracy of the secret images, and the security against steganalysis methods.
13 |
14 |
15 |
16 |
17 |
18 | ## Requisites
19 | - Python >= 3.7
20 | - PyTorch >= 1.0
21 | - NVIDIA GPU + CUDA CuDNN
22 |
23 | ## Dataset Preparation
24 | - DIV2K
25 | - COCO
26 | - ImageNet
27 |
28 |
29 | ## Get Started
30 | ### Pretrained models
31 | Download and unzip the [pretrained models](https://drive.google.com/file/d/17GRiwaJN8yqmLtiAO-bJcgpG8aWHysz3/view?usp=drive_link), then place them under ```experiments/pretrained_models```.
32 |
33 | ### Training for image steganography
34 | First set up a config file in options/train/, then run as follows:
35 |
36 | python train.py -opt options/train/train_InvMIHNet_4images.yml
37 |
38 | ### Testing for image steganography
39 | First set up a config file in options/test/, then run as follows:
40 |
41 | python test.py -opt options/test/test_InvMIHNet_4images.yml
42 |
43 | You can choose to conceal and reveal **4, 6, 8, 9, or 16 images**.
44 |
45 | ## Description of the files in this repository
46 |
47 | 1) [`data/`](./data): A data loader to provide data for training, validation and testing.
48 | 2) [`models/`](./models): Construct models for training and testing.
49 | 3) [`options/`](./options): Configure the options for the data loader, network structure, model, training strategies, etc.
50 | 4) [`experiments/`](./experiments): Save the parameters of InvMIHNet.
51 |
52 |
53 | ## Citation
54 |
55 | If you find this code and data useful, please consider citing the original work:
56 |
57 | ```
58 | @inproceedings{Chen2024InvMIHNet,
59 | title={Invertible Mosaic Image Hiding Network for Very Large Capacity Image Steganography},
60 |   author={Chen, Zihan and Liu, Tianrui and Huang, Jun-Jie and Zhao, Wentao and Bi, Xing and Wang, Meng},
61 | booktitle={IEEE International Conference on Acoustics, Speech, and Signal Processing},
62 | volume={},
63 | number={},
64 | pages={},
65 | year={2024}
66 | }
67 | ```
68 |
69 | ## Acknowledgement
70 | The code is based on [IRN](https://github.com/pkuxmq/Invertible-Image-Rescaling), with reference to [HiNet](https://github.com/TomTomTommi/HiNet).
71 |
72 | ## Contact
73 | If you have any questions, please contact .
74 |
--------------------------------------------------------------------------------
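A minimal sketch of the spatial splicing the README describes, for the 4-image setting (scale_W = scale_H = 2). This only illustrates the mosaic idea, not the repo's exact preprocessing; it is consistent with `batch_size: 5` in the 4-image configs (four secrets plus one cover):

```
import torch

secrets = torch.randn(4, 3, 128, 128)                # four secret images
top = torch.cat([secrets[0], secrets[1]], dim=2)     # splice along width
bottom = torch.cat([secrets[2], secrets[3]], dim=2)
mosaic = torch.cat([top, bottom], dim=1)             # splice along height
print(mosaic.shape)  # torch.Size([3, 256, 256]); the IIR module then downscales this 2x2
```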
/codes/models/modules/discriminator_vgg_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | class Discriminator_VGG_128(nn.Module):
7 | def __init__(self, in_nc, nf):
8 | super(Discriminator_VGG_128, self).__init__()
9 | # [64, 128, 128]
10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
13 | # [64, 64, 64]
14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
18 | # [128, 32, 32]
19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
21 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
23 | # [256, 16, 16]
24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
28 | # [512, 8, 8]
29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
33 |
34 | self.linear1 = nn.Linear(512 * 4 * 4, 100)
35 | self.linear2 = nn.Linear(100, 1)
36 |
37 | # activation function
38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
39 |
40 | def forward(self, x):
41 | fea = self.lrelu(self.conv0_0(x))
42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
43 |
44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
46 |
47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
49 |
50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
52 |
53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
55 |
56 | fea = fea.view(fea.size(0), -1)
57 | fea = self.lrelu(self.linear1(fea))
58 | out = self.linear2(fea)
59 | return out
60 |
61 |
62 | class VGGFeatureExtractor(nn.Module):
63 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
64 | device=torch.device('cpu')):
65 | super(VGGFeatureExtractor, self).__init__()
66 | self.use_input_norm = use_input_norm
67 | if use_bn:
68 | model = torchvision.models.vgg19_bn(pretrained=True)
69 | else:
70 | model = torchvision.models.vgg19(pretrained=True)
71 | if self.use_input_norm:
72 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
73 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
75 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
76 | self.register_buffer('mean', mean)
77 | self.register_buffer('std', std)
78 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
79 | # No need to BP to variable
80 | for k, v in self.features.named_parameters():
81 | v.requires_grad = False
82 |
83 | def forward(self, x):
84 | # Assume input range is [0, 1]
85 | if self.use_input_norm:
86 | x = (x - self.mean) / self.std
87 | output = self.features(x)
88 | return output
89 |
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12 | self.is_train = opt['is_train']
13 | self.schedulers = []
14 | self.optimizers = []
15 |
16 | def feed_data(self, data):
17 | pass
18 |
19 | def optimize_parameters(self):
20 | pass
21 |
22 | def get_current_visuals(self):
23 | pass
24 |
25 | def get_current_losses(self):
26 | pass
27 |
28 | def print_network(self):
29 | pass
30 |
31 | def save(self, label):
32 | pass
33 |
34 | def load(self):
35 | pass
36 |
37 | def _set_lr(self, lr_groups_l):
38 |         '''Set learning rate for warm-up.
39 |         lr_groups_l: list of lr_groups, one for each optimizer'''
40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
42 | param_group['lr'] = lr
43 |
44 | def _get_init_lr(self):
45 | # get the initial lr, which is set by the scheduler
46 | init_lr_groups_l = []
47 | for optimizer in self.optimizers:
48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49 | return init_lr_groups_l
50 |
51 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
52 | for scheduler in self.schedulers:
53 | scheduler.step()
54 | #### set up warm up learning rate
55 | if cur_iter < warmup_iter:
56 | # get initial lr for each group
57 | init_lr_g_l = self._get_init_lr()
58 | # modify warming-up learning rates
59 | warm_up_lr_l = []
60 | for init_lr_g in init_lr_g_l:
61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62 | # set learning rate
63 | self._set_lr(warm_up_lr_l)
64 |
65 | def get_current_learning_rate(self):
66 | # return self.schedulers[0].get_lr()[0]
67 | return self.optimizers[0].param_groups[0]['lr']
68 |
69 | def get_network_description(self, network):
70 | '''Get the string and total parameters of the network'''
71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
72 | network = network.module
73 | s = str(network)
74 | n = sum(map(lambda x: x.numel(), network.parameters()))
75 | return s, n
76 |
77 | def save_network(self, network, network_label, iter_label):
78 | save_filename = '{}_{}.pth'.format(network_label, iter_label)
79 | save_path = os.path.join(self.opt['path']['models'], save_filename)
80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
81 | network = network.module
82 | state_dict = network.state_dict()
83 | for key, param in state_dict.items():
84 | state_dict[key] = param.cpu()
85 | torch.save(state_dict, save_path)
86 |
87 | def load_network(self, load_path, network, strict=True):
88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
89 | network = network.module
90 | load_net = torch.load(load_path)
91 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
92 | for k, v in load_net.items():
93 | if k.startswith('module.'):
94 | load_net_clean[k[7:]] = v
95 | else:
96 | load_net_clean[k] = v
97 | network.load_state_dict(load_net_clean, strict=strict)
98 |
99 | def save_training_state(self, epoch, iter_step):
100 | '''Saves training state during training, which will be used for resuming'''
101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
102 | for s in self.schedulers:
103 | state['schedulers'].append(s.state_dict())
104 | for o in self.optimizers:
105 | state['optimizers'].append(o.state_dict())
106 | save_filename = '{}.state'.format(iter_step)
107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename)
108 | torch.save(state, save_path)
109 |
110 | def resume_training(self, resume_state):
111 | '''Resume the optimizers and schedulers for training'''
112 | resume_optimizers = resume_state['optimizers']
113 | resume_schedulers = resume_state['schedulers']
114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
116 | for i, o in enumerate(resume_optimizers):
117 | self.optimizers[i].load_state_dict(o)
118 | for i, s in enumerate(resume_schedulers):
119 | self.schedulers[i].load_state_dict(s)
120 |
--------------------------------------------------------------------------------
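Usage note: a hedged sketch of the warmup behaviour implemented by update_learning_rate above. BaseModel is instantiated directly with a minimal option dict (only the keys its __init__ reads) and a throwaway optimizer/scheduler pair; during warmup the learning rate ramps linearly from zero to the initial value over warmup_iter steps.

    import torch
    from models.base_model import BaseModel  # assumes codes/ is on PYTHONPATH

    opt = {'gpu_ids': None, 'is_train': True}  # minimal keys read by BaseModel.__init__
    m = BaseModel(opt)

    w = torch.zeros(3, 3, requires_grad=True)
    optim = torch.optim.Adam([w], lr=2e-4)
    m.optimizers.append(optim)
    m.schedulers.append(torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[10**6]))

    for it in range(1, 11):
        m.update_learning_rate(it, warmup_iter=10)
        print(it, m.get_current_learning_rate())  # 2e-5, 4e-5, ..., 1.8e-4, then 2e-4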
/codes/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | if opt['distortion'] == 'sr':
19 | scale_W = opt['scale_W']
20 | scale_H = opt['scale_H']
21 |
22 | # datasets
23 | for phase, dataset in opt['datasets'].items():
24 | phase = phase.split('_')[0]
25 | dataset['phase'] = phase
26 | if opt['distortion'] == 'sr':
27 | dataset['scale_W'] = scale_W
28 | dataset['scale_H'] = scale_H
29 | is_lmdb = False
30 | if dataset.get('dataroot_GT', None) is not None:
31 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
32 | if dataset['dataroot_GT'].endswith('lmdb'):
33 | is_lmdb = True
34 | # if dataset.get('dataroot_GT_bg', None) is not None:
35 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg'])
36 | if dataset.get('dataroot_LQ', None) is not None:
37 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
38 | if dataset['dataroot_LQ'].endswith('lmdb'):
39 | is_lmdb = True
40 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
41 | if dataset['mode'].endswith('mc'): # for memcached
42 | dataset['data_type'] = 'mc'
43 | dataset['mode'] = dataset['mode'].replace('_mc', '')
44 |
45 | # path
46 | for key, path in opt['path'].items():
47 | if path and key in opt['path'] and key != 'strict_load':
48 | opt['path'][key] = osp.expanduser(path)
49 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
50 | if is_train:
51 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
52 | opt['path']['experiments_root'] = experiments_root
53 | opt['path']['models'] = osp.join(experiments_root, 'models')
54 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
55 | opt['path']['log'] = experiments_root
56 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
57 |
58 | # change some options for debug mode
59 | if 'debug' in opt['name']:
60 | opt['train']['val_freq'] = 8
61 | opt['logger']['print_freq'] = 1
62 | opt['logger']['save_checkpoint_freq'] = 8
63 | else: # test
64 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
65 | opt['path']['results_root'] = results_root
66 | opt['path']['log'] = results_root
67 |
68 | # network
69 | if opt['distortion'] == 'sr':
70 | opt['network_G']['scale_W'] = scale_W
71 | opt['network_G']['scale_H'] = scale_H
72 | return opt
73 |
74 |
75 | def dict2str(opt, indent_l=1):
76 | '''dict to string for logger'''
77 | msg = ''
78 | for k, v in opt.items():
79 | if isinstance(v, dict):
80 | msg += ' ' * (indent_l * 2) + k + ':[\n'
81 | msg += dict2str(v, indent_l + 1)
82 | msg += ' ' * (indent_l * 2) + ']\n'
83 | else:
84 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
85 | return msg
86 |
87 |
88 | class NoneDict(dict):
89 | def __missing__(self, key):
90 | return None
91 |
92 |
93 | # convert to NoneDict, which returns None for missing keys.
94 | def dict_to_nonedict(opt):
95 | if isinstance(opt, dict):
96 | new_opt = dict()
97 | for key, sub_opt in opt.items():
98 | new_opt[key] = dict_to_nonedict(sub_opt)
99 | return NoneDict(**new_opt)
100 | elif isinstance(opt, list):
101 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
102 | else:
103 | return opt
104 |
105 |
106 | def check_resume(opt, resume_iter):
107 | '''Check resume states and pretrain_model paths'''
108 | logger = logging.getLogger('base')
109 | if opt['path']['resume_state']:
110 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
111 | 'pretrain_model_D', None) is not None:
112 | logger.warning('pretrain_model path will be ignored when resuming training.')
113 |
114 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
115 | '{}_G.pth'.format(resume_iter))
116 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
117 | if 'gan' in opt['model']:
118 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
119 | '{}_D.pth'.format(resume_iter))
120 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
121 |
--------------------------------------------------------------------------------
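Usage note: the typical parse -> NoneDict flow that train.py and test.py follow, sketched with one of the repository's own option files (run from the codes directory so the relative path and the utils import resolve).

    import options.options as option

    opt = option.parse('options/train/train_InvMIHNet_4images.yml', is_train=True)
    opt = option.dict_to_nonedict(opt)  # missing keys now yield None instead of KeyError
    print(option.dict2str(opt))         # pretty-print the resolved options for the log
    print(opt['no_such_key'])           # -> None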
/codes/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restart_weights = weights if weights else [1]
16 | assert len(self.restarts) == len(
17 | self.restart_weights), 'restarts and their weights do not match.'
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
19 |
20 | def get_lr(self):
21 | if self.last_epoch in self.restarts:
22 | if self.clear_state:
23 | self.optimizer.state = defaultdict(dict)
24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
26 | if self.last_epoch not in self.milestones:
27 | return [group['lr'] for group in self.optimizer.param_groups]
28 | return [
29 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
30 | for group in self.optimizer.param_groups
31 | ]
32 |
33 |
34 | class CosineAnnealingLR_Restart(_LRScheduler):
35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
36 | self.T_period = T_period
37 | self.T_max = self.T_period[0] # current T period
38 | self.eta_min = eta_min
39 | self.restarts = restarts if restarts else [0]
40 | self.restart_weights = weights if weights else [1]
41 | self.last_restart = 0
42 | assert len(self.restarts) == len(
43 | self.restart_weights), 'restarts and their weights do not match.'
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
45 |
46 | def get_lr(self):
47 | if self.last_epoch == 0:
48 | return self.base_lrs
49 | elif self.last_epoch in self.restarts:
50 | self.last_restart = self.last_epoch
51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
55 | return [
56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
58 | ]
59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
61 | (group['lr'] - self.eta_min) + self.eta_min
62 | for group in self.optimizer.param_groups]
63 |
64 |
65 | if __name__ == "__main__":
66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
67 | betas=(0.9, 0.99))
68 | ##############################
69 | # MultiStepLR_Restart
70 | ##############################
71 | ## Original
72 | lr_steps = [200000, 400000, 600000, 800000]
73 | restarts = None
74 | restart_weights = None
75 |
76 | ## two
77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
78 | restarts = [500000]
79 | restart_weights = [1]
80 |
81 | ## four
82 | lr_steps = [
83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
85 | ]
86 | restarts = [250000, 500000, 750000]
87 | restart_weights = [1, 1, 1]
88 |
89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
90 | clear_state=False)
91 |
92 | ##############################
93 | # Cosine Annealing Restart
94 | ##############################
95 | ## two
96 | T_period = [500000, 500000]
97 | restarts = [500000]
98 | restart_weights = [1]
99 |
100 | ## four
101 | T_period = [250000, 250000, 250000, 250000]
102 | restarts = [250000, 500000, 750000]
103 | restart_weights = [1, 1, 1]
104 |
105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
106 | weights=restart_weights)
107 |
108 | ##############################
109 | # Draw figure
110 | ##############################
111 | N_iter = 1000000
112 | lr_l = list(range(N_iter))
113 | for i in range(N_iter):
114 | scheduler.step()
115 | current_lr = optimizer.param_groups[0]['lr']
116 | lr_l[i] = current_lr
117 |
118 | import matplotlib as mpl
119 | from matplotlib import pyplot as plt
120 | import matplotlib.ticker as mtick
121 | mpl.style.use('default')
122 | import seaborn
123 | seaborn.set(style='whitegrid')
124 | seaborn.set_context('paper')
125 |
126 | plt.figure(1)
127 | plt.subplot(111)
128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
129 | plt.title('Title', fontsize=16, color='k')
130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
131 | legend = plt.legend(loc='upper right', shadow=False)
132 | ax = plt.gca()
133 | labels = ax.get_xticks().tolist()
134 | for k, v in enumerate(labels):
135 | labels[k] = str(int(v / 1000)) + 'K'
136 | ax.set_xticklabels(labels)
137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
138 |
139 | ax.set_ylabel('Learning rate')
140 | ax.set_xlabel('Iteration')
141 | fig = plt.gcf()
142 | plt.show()
143 |
--------------------------------------------------------------------------------
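For reference, CosineAnnealingLR_Restart tracks the standard SGDR cosine-annealing schedule in recurrence form (the final return statement is the chained update also used by PyTorch's CosineAnnealingLR). Within each period between restarts it is equivalent to the closed form

    \eta_t = \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})
             \Bigl(1 + \cos\Bigl(\frac{T_{\mathrm{cur}}}{T_{\max}}\,\pi\Bigr)\Bigr),
    \qquad T_{\mathrm{cur}} = \texttt{last\_epoch} - \texttt{last\_restart},

where \eta_{\max} is the (restart-weighted) initial learning rate and T_{\max} the current entry of T_period.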
/codes/data/Steg_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | import lmdb
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 |
9 |
10 | class StegDataset(data.Dataset):
11 | '''
12 |     Read LQ (Low Quality, here low-resolution) and GT image pairs.
13 |     If only the GT image is provided, generate the LQ image on-the-fly.
14 |     Pairs are matched by sorted file names, so please check the naming convention.
15 | '''
16 |
17 | def __init__(self, opt):
18 | super(StegDataset, self).__init__()
19 | self.opt = opt
20 | self.data_type = self.opt['data_type']
21 | self.paths_LQ, self.paths_GT = None, None
22 | self.sizes_LQ, self.sizes_GT = None, None
23 | self.LQ_env, self.GT_env = None, None # environment for lmdb
24 |
25 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
26 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
27 | assert self.paths_GT, 'Error: GT path is empty.'
28 | if self.paths_LQ and self.paths_GT:
29 | assert len(self.paths_LQ) == len(
30 | self.paths_GT
31 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
32 | len(self.paths_LQ), len(self.paths_GT))
33 | self.random_scale_list = [1]
34 |
35 | def _init_lmdb(self):
36 | # https://github.com/chainer/chainermn/issues/129
37 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
38 | meminit=False)
39 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
40 | meminit=False)
41 |
42 | def __getitem__(self, index):
43 | if self.data_type == 'lmdb':
44 | if (self.GT_env is None) or (self.LQ_env is None):
45 | self._init_lmdb()
46 | GT_path, LQ_path = None, None
47 | scale_W = self.opt['scale_W']
48 | scale_H = self.opt['scale_H']
49 | GT_size = self.opt['GT_size']
50 |
51 | # get GT image
52 | GT_path = self.paths_GT[index]
53 | if self.data_type == 'lmdb':
54 | resolution = [int(s) for s in self.sizes_GT[index].split('_')]
55 | else:
56 | resolution = None
57 | img_GT = util.read_img(self.GT_env, GT_path, resolution)
58 | # modcrop in the validation / test phase
59 | if self.opt['phase'] != 'train':
60 | img_GT = util.modcrop(img_GT, scale_W, scale_H)
61 | # change color space if necessary
62 | if self.opt['color']:
63 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
64 |
65 | # get LQ image
66 | if self.paths_LQ:
67 | LQ_path = self.paths_LQ[index]
68 | if self.data_type == 'lmdb':
69 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')]
70 | else:
71 | resolution = None
72 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
73 | else: # down-sampling on-the-fly
74 | # randomly scale during training
75 | if self.opt['phase'] == 'train':
76 | random_scale = random.choice(self.random_scale_list)
77 | H_s, W_s, _ = img_GT.shape
78 |
79 | def _mod(n, random_scale, scale, thres):
80 | rlt = int(n * random_scale)
81 | rlt = (rlt // scale) * scale
82 | return thres if rlt < thres else rlt
83 |
84 | H_s = _mod(H_s, random_scale, scale_H, GT_size)
85 | W_s = _mod(W_s, random_scale, scale_W, GT_size)
86 | img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
87 | # force to 3 channels
88 | if img_GT.ndim == 2:
89 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
90 |
91 | H, W, _ = img_GT.shape
92 | # using matlab imresize
93 | img_LQ = util.imresize_np(img_GT, 1 / scale_W, 1 / scale_H, True)
94 | if img_LQ.ndim == 2:
95 | img_LQ = np.expand_dims(img_LQ, axis=2)
96 |
97 | if self.opt['phase'] == 'train':
98 | # if the image size is too small
99 | H, W, _ = img_GT.shape
100 | if H < GT_size or W < GT_size:
101 | img_GT = cv2.resize(np.copy(img_GT), (GT_size, GT_size),
102 | interpolation=cv2.INTER_LINEAR)
103 | # using matlab imresize
104 | img_LQ = util.imresize_np(img_GT, 1 / scale_W, 1 / scale_H, True)
105 | if img_LQ.ndim == 2:
106 | img_LQ = np.expand_dims(img_LQ, axis=2)
107 |
108 | H, W, C = img_LQ.shape
109 | LQ_size_W = GT_size // scale_W
110 | LQ_size_H = GT_size // scale_H
111 |
112 | # randomly crop
113 | rnd_h = random.randint(0, max(0, H - LQ_size_H))
114 | rnd_w = random.randint(0, max(0, W - LQ_size_W))
115 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size_H, rnd_w:rnd_w + LQ_size_W, :]
116 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale_H), int(rnd_w * scale_W)
117 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
118 |
119 | # augmentation - flip, rotate
120 | # img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
121 | # self.opt['use_rot'])
122 |
123 | # change color space if necessary
124 | # if self.opt['color']:
125 | # img_LQ = util.channel_convert(C, self.opt['color'],
126 | # [img_LQ])[0] # TODO during val no definition
127 |
128 | # BGR to RGB, HWC to CHW, numpy to tensor
129 | if img_GT.shape[2] == 3:
130 | img_GT = img_GT[:, :, [2, 1, 0]]
131 | img_LQ = img_LQ[:, :, [2, 1, 0]]
132 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
133 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
134 |
135 | if LQ_path is None:
136 | LQ_path = GT_path
137 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
138 |
139 | def __len__(self):
140 | return len(self.paths_GT)
141 |
--------------------------------------------------------------------------------
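For clarity, the color/layout conversion at the end of __getitem__ in isolation: OpenCV yields HWC, BGR, [0, 1] arrays, while the models consume CHW, RGB tensors.

    import numpy as np
    import torch

    img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC, BGR, as read by cv2
    img = img[:, :, [2, 1, 0]]                           # BGR -> RGB
    t = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()
    assert t.shape == (3, 64, 64)                        # CHW, ready for batching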
/codes/models/modules/module_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def initialize_weights_xavier(net_l, scale=1):
28 | if not isinstance(net_l, list):
29 | net_l = [net_l]
30 | for net in net_l:
31 | for m in net.modules():
32 | if isinstance(m, nn.Conv2d):
33 | init.xavier_normal_(m.weight)
34 | m.weight.data *= scale # for residual block
35 | if m.bias is not None:
36 | m.bias.data.zero_()
37 | elif isinstance(m, nn.Linear):
38 | init.xavier_normal_(m.weight)
39 | m.weight.data *= scale
40 | if m.bias is not None:
41 | m.bias.data.zero_()
42 | elif isinstance(m, nn.BatchNorm2d):
43 | init.constant_(m.weight, 1)
44 | init.constant_(m.bias.data, 0.0)
45 |
46 |
47 | def make_layer(block, n_layers):
48 | layers = []
49 | for _ in range(n_layers):
50 | layers.append(block())
51 | return nn.Sequential(*layers)
52 |
53 |
54 | class ResidualBlock_noBN(nn.Module):
55 | '''Residual block w/o BN
56 | ---Conv-ReLU-Conv-+-
57 | |________________|
58 | '''
59 |
60 | def __init__(self, nf=64):
61 | super(ResidualBlock_noBN, self).__init__()
62 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
63 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
64 |
65 | # initialization
66 | initialize_weights([self.conv1, self.conv2], 0.1)
67 |
68 | def forward(self, x):
69 | identity = x
70 | out = F.relu(self.conv1(x), inplace=True)
71 | out = self.conv2(out)
72 | return identity + out
73 |
74 |
75 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
76 |     """Warp an image or feature map with optical flow
77 |     Args:
78 |         x (Tensor): size (N, C, H, W)
79 |         flow (Tensor): size (N, 2, H, W), normal value (permuted to NHWC internally)
80 |         interp_mode (str): 'nearest' or 'bilinear'
81 |         padding_mode (str): 'zeros' or 'border' or 'reflection'
82 |     Returns:
83 |         Tensor: warped image or feature map
84 |     """
85 |     flow = flow.permute(0, 2, 3, 1)  # (N, 2, H, W) -> (N, H, W, 2)
86 |     assert x.size()[-2:] == flow.size()[1:3]
87 |     B, C, H, W = x.size()
88 |     # mesh grid
89 |     grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
90 |     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
91 |     grid.requires_grad = False
92 |     grid = grid.type_as(x)
93 |     vgrid = grid + flow
94 |     # scale grid to [-1, 1]
95 |     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
96 |     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
97 |     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
98 |     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
99 |     return output
--------------------------------------------------------------------------------
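Sanity-check sketch for flow_warp: with an all-zero flow the warp should reproduce its input. Note that the code does not pass align_corners to grid_sample, so on PyTorch >= 1.3 (default align_corners=False) the zero-flow warp is only approximately the identity; on older versions it is exact.

    import torch
    from models.modules.module_util import flow_warp  # assumes codes/ on PYTHONPATH

    x = torch.rand(1, 3, 8, 8)
    flow = torch.zeros(1, 2, 8, 8)  # (N, 2, H, W); permuted to (N, H, W, 2) internally
    out = flow_warp(x, flow)
    print((out - x).abs().max())    # ~0 under align_corners=True semantics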
/codes/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | from datetime import datetime
6 | import random
7 | import logging
8 | from collections import OrderedDict
9 | import numpy as np
10 | import cv2
11 | import torch
12 | from torchvision.utils import make_grid
13 | from shutil import get_terminal_size
14 |
15 | import yaml
16 | try:
17 | from yaml import CLoader as Loader, CDumper as Dumper
18 | except ImportError:
19 | from yaml import Loader, Dumper
20 |
21 |
22 | def OrderedYaml():
23 | '''yaml orderedDict support'''
24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25 |
26 | def dict_representer(dumper, data):
27 | return dumper.represent_dict(data.items())
28 |
29 | def dict_constructor(loader, node):
30 | return OrderedDict(loader.construct_pairs(node))
31 |
32 | Dumper.add_representer(OrderedDict, dict_representer)
33 | Loader.add_constructor(_mapping_tag, dict_constructor)
34 | return Loader, Dumper
35 |
36 |
37 | ####################
38 | # miscellaneous
39 | ####################
40 |
41 |
42 | def get_timestamp():
43 | return datetime.now().strftime('%y%m%d-%H%M%S')
44 |
45 |
46 | def mkdir(path):
47 | if not os.path.exists(path):
48 | os.makedirs(path)
49 |
50 |
51 | def mkdirs(paths):
52 | if isinstance(paths, str):
53 | mkdir(paths)
54 | else:
55 | for path in paths:
56 | mkdir(path)
57 |
58 |
59 | def mkdir_and_rename(path):
60 | if os.path.exists(path):
61 | new_name = path + '_archived_' + get_timestamp()
62 | print('Path already exists. Rename it to [{:s}]'.format(new_name))
63 | logger = logging.getLogger('base')
64 | logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
65 | os.rename(path, new_name)
66 | os.makedirs(path)
67 |
68 |
69 | def set_random_seed(seed):
70 | random.seed(seed)
71 | np.random.seed(seed)
72 | torch.manual_seed(seed)
73 | torch.cuda.manual_seed_all(seed)
74 |
75 |
76 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
77 | '''set up logger'''
78 | lg = logging.getLogger(logger_name)
79 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
80 | datefmt='%y-%m-%d %H:%M:%S')
81 | lg.setLevel(level)
82 | if tofile:
83 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
84 | fh = logging.FileHandler(log_file, mode='w')
85 | fh.setFormatter(formatter)
86 | lg.addHandler(fh)
87 | if screen:
88 | sh = logging.StreamHandler()
89 | sh.setFormatter(formatter)
90 | lg.addHandler(sh)
91 |
92 |
93 | ####################
94 | # image convert
95 | ####################
96 |
97 |
98 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
99 | '''
100 | Converts a torch Tensor into an image Numpy array
101 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
102 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
103 | '''
104 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
105 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
106 | n_dim = tensor.dim()
107 | if n_dim == 4:
108 | n_img = len(tensor)
109 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
110 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
111 | elif n_dim == 3:
112 | img_np = tensor.numpy()
113 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
114 | elif n_dim == 2:
115 | img_np = tensor.numpy()
116 | else:
117 | raise TypeError(
118 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
119 | if out_type == np.uint8:
120 | img_np = (img_np * 255.0).round()
121 |         # Important. Unlike MATLAB, numpy.uint8() will NOT round by default.
122 | return img_np.astype(out_type)
123 |
124 |
125 | def save_img(img, img_path, mode='RGB'):
126 | cv2.imwrite(img_path, img)
127 |
128 |
129 | ####################
130 | # metric
131 | ####################
132 |
133 |
134 | def calculate_psnr(img1, img2):
135 | # img1 and img2 have range [0, 255]
136 | img1 = img1.astype(np.float64)
137 | img2 = img2.astype(np.float64)
138 | mse = np.mean((img1 - img2)**2)
139 | if mse == 0:
140 | return float('inf')
141 | return 20 * math.log10(255.0 / math.sqrt(mse))
142 |
143 |
144 | def ssim(img1, img2):
145 | C1 = (0.01 * 255)**2
146 | C2 = (0.03 * 255)**2
147 |
148 | img1 = img1.astype(np.float64)
149 | img2 = img2.astype(np.float64)
150 | kernel = cv2.getGaussianKernel(11, 1.5)
151 | window = np.outer(kernel, kernel.transpose())
152 |
153 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
154 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
155 | mu1_sq = mu1**2
156 | mu2_sq = mu2**2
157 | mu1_mu2 = mu1 * mu2
158 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
159 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
160 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
161 |
162 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
163 | (sigma1_sq + sigma2_sq + C2))
164 | return ssim_map.mean()
165 |
166 |
167 | def calculate_ssim(img1, img2):
168 | '''calculate SSIM
169 | the same outputs as MATLAB's
170 | img1, img2: [0, 255]
171 | '''
172 | if not img1.shape == img2.shape:
173 | raise ValueError('Input images must have the same dimensions.')
174 | if img1.ndim == 2:
175 | return ssim(img1, img2)
176 | elif img1.ndim == 3:
177 | if img1.shape[2] == 3:
178 | ssims = []
179 | for i in range(3):
180 |                 ssims.append(ssim(img1[..., i], img2[..., i]))
181 | return np.array(ssims).mean()
182 | elif img1.shape[2] == 1:
183 | return ssim(np.squeeze(img1), np.squeeze(img2))
184 | else:
185 | raise ValueError('Wrong input image dimensions.')
186 |
187 |
188 | class ProgressBar(object):
189 | '''A progress bar which can print the progress
190 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
191 | '''
192 |
193 | def __init__(self, task_num=0, bar_width=50, start=True):
194 | self.task_num = task_num
195 | max_bar_width = self._get_max_bar_width()
196 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
197 | self.completed = 0
198 | if start:
199 | self.start()
200 |
201 | def _get_max_bar_width(self):
202 | terminal_width, _ = get_terminal_size()
203 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
204 | if max_bar_width < 10:
205 |             print('terminal width is too small ({}), please consider widening the terminal '
206 |                   'for better progressbar visualization'.format(terminal_width))
207 | max_bar_width = 10
208 | return max_bar_width
209 |
210 | def start(self):
211 | if self.task_num > 0:
212 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
213 | ' ' * self.bar_width, self.task_num, 'Start...'))
214 | else:
215 | sys.stdout.write('completed: 0, elapsed: 0s')
216 | sys.stdout.flush()
217 | self.start_time = time.time()
218 |
219 | def update(self, msg='In progress...'):
220 | self.completed += 1
221 | elapsed = time.time() - self.start_time
222 | fps = self.completed / elapsed
223 | if self.task_num > 0:
224 | percentage = self.completed / float(self.task_num)
225 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
226 | mark_width = int(self.bar_width * percentage)
227 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
228 | sys.stdout.write('\033[2F') # cursor up 2 lines
229 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
230 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
231 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
232 | else:
233 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
234 | self.completed, int(elapsed + 0.5), fps))
235 | sys.stdout.flush()
236 |
--------------------------------------------------------------------------------
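Quick check of calculate_psnr against its definition above, PSNR = 20 log10(255 / sqrt(MSE)) for [0, 255] images (the import path assumes codes/ is on PYTHONPATH):

    import numpy as np
    from utils.util import calculate_psnr

    a = np.full((16, 16), 100.0)
    b = np.full((16, 16), 110.0)  # constant error of 10 -> MSE = 100
    print(calculate_psnr(a, b))   # 20 * log10(255 / 10) ≈ 28.13 dB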
/codes/models/modules/Inv_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import torchvision
7 |
8 |
9 | class InvBlockExp(nn.Module):
10 | def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=1.):
11 | super(InvBlockExp, self).__init__()
12 |
13 | self.split_len1 = channel_split_num
14 | self.split_len2 = channel_num - channel_split_num
15 |
16 | self.clamp = clamp
17 |
18 | self.F = subnet_constructor(self.split_len2, self.split_len1)
19 | self.G = subnet_constructor(self.split_len1, self.split_len2)
20 | self.H = subnet_constructor(self.split_len1, self.split_len2)
21 |
22 | def forward(self, x, rev=False):
23 | x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
24 |
25 | if not rev:
26 | y1 = x1 + self.F(x2)
27 | self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
28 | y2 = x2.mul(torch.exp(self.s)) + self.G(y1)
29 | else:
30 | self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
31 | y2 = (x2 - self.G(x1)).div(torch.exp(self.s))
32 | y1 = x1 - self.F(y2)
33 |
34 | return torch.cat((y1, y2), 1)
35 |
36 | def jacobian(self, x, rev=False):
37 | if not rev:
38 | jac = torch.sum(self.s)
39 | else:
40 | jac = -torch.sum(self.s)
41 |
42 | return jac / x.shape[0]
43 |
44 |
45 | class HaarDownsampling(nn.Module):
46 | def __init__(self, channel_in):
47 | super(HaarDownsampling, self).__init__()
48 | self.channel_in = channel_in
49 |
50 | self.haar_weights = torch.ones(4, 1, 2, 2)
51 |
52 | self.haar_weights[1, 0, 0, 1] = -1
53 | self.haar_weights[1, 0, 1, 1] = -1
54 |
55 | self.haar_weights[2, 0, 1, 0] = -1
56 | self.haar_weights[2, 0, 1, 1] = -1
57 |
58 | self.haar_weights[3, 0, 1, 0] = -1
59 | self.haar_weights[3, 0, 0, 1] = -1
60 |
61 | self.haar_weights = torch.cat([self.haar_weights] * self.channel_in, 0)
62 | self.haar_weights = nn.Parameter(self.haar_weights)
63 | self.haar_weights.requires_grad = False
64 |
65 | def forward(self, x, rev=False):
66 | if not rev:
67 | self.elements = x.shape[1] * x.shape[2] * x.shape[3]
68 | self.last_jac = self.elements / 4 * np.log(1/16.)
69 |
70 | out = F.conv2d(x, self.haar_weights, bias=None, stride=2, groups=self.channel_in) / 4.0
71 | out = out.reshape([x.shape[0], self.channel_in, 4, x.shape[2] // 2, x.shape[3] // 2])
72 | out = torch.transpose(out, 1, 2)
73 | out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2] // 2, x.shape[3] // 2])
74 | return out
75 | else:
76 | self.elements = x.shape[1] * x.shape[2] * x.shape[3]
77 | self.last_jac = self.elements / 4 * np.log(16.)
78 | out = x.reshape([x.shape[0], 4, self.channel_in, x.shape[2], x.shape[3]])
79 | out = torch.transpose(out, 1, 2)
80 | out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2], x.shape[3]])
81 | return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=2, groups = self.channel_in)
82 |
83 | def jacobian(self, x, rev=False):
84 | return self.last_jac
85 |
86 | class ConvDownsampling_downsampleH(nn.Module):
87 | def __init__(self, channel_in):
88 | super(ConvDownsampling_downsampleH, self).__init__()
89 |
90 | self.channel_in = channel_in
91 |
92 | self.haar_weights = torch.ones(2, 1, 2, 1)
93 | self.haar_weights[0, 0, 0, 0] = 1/2
94 | self.haar_weights[0, 0, 1, 0] = 1/2
95 |
96 | self.haar_weights[1, 0, 0, 0] = -1/2
97 | self.haar_weights[1, 0, 1, 0] = 1/2
98 |
99 | self.haar_weights = torch.cat([self.haar_weights] * self.channel_in, 0)
100 | self.haar_weights = nn.Parameter(self.haar_weights)
101 | self.haar_weights.requires_grad = True
102 |
103 | def forward(self, x, rev=False):
104 | if not rev:
105 | self.elements = x.shape[1] * x.shape[2] * x.shape[3]
106 | self.last_jac = self.elements / 4 * np.log(1/16.)
107 |
108 | out = F.conv2d(x, self.haar_weights, bias=None, stride=(2, 1), groups=self.channel_in)
109 | out = out.reshape([x.shape[0], self.channel_in, 2, x.shape[2] // 2, x.shape[3]])
110 | out = torch.transpose(out, 1, 2)
111 | out = out.reshape([x.shape[0], self.channel_in * 2, x.shape[2] // 2, x.shape[3]])
112 | return out
113 | else:
114 | r = 2
115 | in_batch, in_channel, in_height, in_width = x.size()
116 | out_batch, out_channel, out_height, out_width = in_batch, int(
117 | in_channel / r), r * in_height, in_width
118 |
119 | self.elements = x.shape[1] * x.shape[2] * x.shape[3]
120 | self.last_jac = self.elements / 4 * np.log(16.)
121 |
122 | out = x.reshape([x.shape[0], 2, self.channel_in, x.shape[2], x.shape[3]])
123 | out = torch.transpose(out, 1, 2)
124 | out = out.reshape([x.shape[0], self.channel_in * 2, x.shape[2], x.shape[3]])
125 | return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=(2, 1), groups=self.channel_in) * 2.0
126 |
127 |
128 | class ConvDownsampling(nn.Module):
129 | def __init__(self, scale_W, scale_H):
130 | super(ConvDownsampling, self).__init__()
131 | self.scale_W = scale_W
132 | self.scale_H = scale_H
133 | self.scale2 = self.scale_H * self.scale_W
134 |
135 | self.conv_weights = torch.eye(self.scale2)
136 |
137 | if self.scale_W == self.scale_H == 2: # haar init
138 | self.conv_weights[0] = torch.Tensor([1./4, 1./4, 1./4, 1./4])
139 | self.conv_weights[1] = torch.Tensor([1./4, -1./4, 1./4, -1./4])
140 | self.conv_weights[2] = torch.Tensor([1./4, 1./4, -1./4, -1./4])
141 | self.conv_weights[3] = torch.Tensor([1./4, -1./4, -1./4, 1./4])
142 | else:
143 | self.conv_weights[0] = torch.Tensor([1./(self.scale2)] * (self.scale2))
144 |
145 | self.conv_weights = nn.Parameter(self.conv_weights)
146 |
147 | def forward(self, x, rev=False):
148 | if not rev:
149 | h = x.shape[2]
150 | w = x.shape[3]
151 |
152 | [B, C, H, W] = list(x.size())
153 | x = x.reshape(B, C, H // self.scale_H, self.scale_H, W // self.scale_W, self.scale_W)
154 | x = x.permute(0, 1, 3, 5, 2, 4)
155 | x = x.reshape(B, C * self.scale2, H // self.scale_H, W // self.scale_W)
156 |
157 | conv_weights = self.conv_weights.reshape(self.scale2, self.scale2, 1, 1)
158 | conv_weights = conv_weights.repeat(C, 1, 1, 1)
159 |
160 | out = F.conv2d(x, conv_weights, bias=None, stride=1, groups=C)
161 |
162 | out = out.reshape(B, C, self.scale2, H // self.scale_H, W // self.scale_W)
163 | out = torch.transpose(out, 1, 2)
164 | out = out.reshape(B, C * self.scale2, H // self.scale_H, W // self.scale_W)
165 |
166 |
167 | return out
168 | else:
169 | inv_weights = torch.inverse(self.conv_weights)
170 | inv_weights = inv_weights.reshape(self.scale2, self.scale2, 1, 1)
171 |
172 | [B, C_, H_, W_] = list(x.size())
173 | C = C_ // self.scale2
174 | H = H_ * self.scale_H
175 | W = W_ * self.scale_W
176 |
177 | inv_weights = inv_weights.repeat(C, 1, 1, 1)
178 |
179 | x = x.reshape(B, self.scale2, C, H_, W_)
180 | x = torch.transpose(x, 1, 2)
181 | x = x.reshape(B, C_, H_, W_)
182 |
183 | out = F.conv2d(x, inv_weights, bias=None, stride=1, groups=C)
184 |
185 | out = out.reshape(B, C, self.scale_H, self.scale_W, H_, W_)
186 | out = out.permute(0, 1, 4, 2, 5, 3)
187 | out = out.reshape(B, C, H, W)
188 |
189 | return out
190 |
191 | class InvRescaleNet(nn.Module):
192 | def __init__(self, channel_in=3, channel_out=3, subnet_constructor=None, block_num=[], down_num=2, use_ConvDownsampling=False, down_scale_W=2, down_scale_H=2):
193 | super(InvRescaleNet, self).__init__()
194 |
195 | operations = []
196 |
197 | if use_ConvDownsampling:
198 | down_num = 1
199 |
200 | current_channel = channel_in
201 | for i in range(down_num):
202 | if use_ConvDownsampling:
203 | b = ConvDownsampling(down_scale_W, down_scale_H)
204 | current_channel *= down_scale_W * down_scale_H
205 | elif down_scale_W==2 and down_scale_H==4 and i==1:
206 | b = ConvDownsampling_downsampleH(current_channel)
207 | current_channel *= 2
208 | else:
209 | b = HaarDownsampling(current_channel)
210 | current_channel *= 4
211 | operations.append(b)
212 | for j in range(block_num[i]):
213 | b = InvBlockExp(subnet_constructor, current_channel, channel_out)
214 | operations.append(b)
215 |
216 | self.operations = nn.ModuleList(operations)
217 |
218 | def forward(self, x, rev=False, cal_jacobian=False):
219 | out = x
220 | jacobian = 0
221 |
222 | if not rev:
223 | for op in self.operations:
224 | out = op.forward(out, rev)
225 | if cal_jacobian:
226 | jacobian += op.jacobian(out, rev)
227 | else:
228 | for op in reversed(self.operations):
229 | out = op.forward(out, rev)
230 | if cal_jacobian:
231 | jacobian += op.jacobian(out, rev)
232 |
233 | if cal_jacobian:
234 | return out, jacobian
235 | else:
236 | return out
237 |
238 |
--------------------------------------------------------------------------------
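Sanity-check sketch for the invertibility these modules rely on: HaarDownsampling composed with its own reverse pass is the identity, since the Haar analysis matrix W satisfies WᵀW = 4I and the forward pass divides by 4.

    import torch
    from models.modules.Inv_arch import HaarDownsampling  # assumes codes/ on PYTHONPATH

    haar = HaarDownsampling(channel_in=3)
    x = torch.rand(2, 3, 16, 16)
    y = haar(x)                  # (2, 12, 8, 8): four Haar sub-bands per channel
    x_rec = haar(y, rev=True)    # reverse pass via conv_transpose2d
    print(torch.allclose(x, x_rec, atol=1e-6))  # True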
/codes/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 | import warnings
7 |
8 | import torch
9 | import torch.distributed as dist
10 | import torch.multiprocessing as mp
11 | from data.data_sampler import DistIterSampler
12 |
13 | import options.options as option
14 | from utils import util
15 | from data import create_dataloader, create_dataset
16 | from models import create_model
17 |
18 | import torchvision
19 | import os.path as osp
20 |
21 | warnings.filterwarnings("ignore")
22 |
23 | def init_dist(backend='nccl', **kwargs):
24 | ''' initialization for distributed training'''
25 | # if mp.get_start_method(allow_none=True) is None:
26 | if mp.get_start_method(allow_none=True) != 'spawn':
27 | mp.set_start_method('spawn')
28 | rank = int(os.environ['RANK'])
29 | num_gpus = torch.cuda.device_count()
30 | torch.cuda.set_device(rank % num_gpus)
31 | dist.init_process_group(backend=backend, **kwargs)
32 |
33 | def load(name, net):
34 | state_dicts = torch.load(name)
35 | network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
36 | net.load_state_dict(network_state_dict)
37 |
38 |
39 | def main():
40 | #### options
41 | parser = argparse.ArgumentParser()
42 |     parser.add_argument('-opt', type=str, help='Path to option YAML file.')
43 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
44 | help='job launcher')
45 | parser.add_argument('--local_rank', type=int, default=0)
46 | args = parser.parse_args()
47 | opt = option.parse(args.opt, is_train=True)
48 |
49 |
50 | #### distributed training settings
51 | if args.launcher == 'none': # disabled distributed training
52 | opt['dist'] = False
53 | rank = -1
54 | print('Disabled distributed training.')
55 | else:
56 | opt['dist'] = True
57 | init_dist()
58 | world_size = torch.distributed.get_world_size()
59 | rank = torch.distributed.get_rank()
60 |
61 | #### loading resume state if exists
62 | if opt['path'].get('resume_state', None):
63 | # distributed resuming: all load into default GPU
64 | device_id = torch.cuda.current_device()
65 | resume_state = torch.load(opt['path']['resume_state'],
66 | map_location=lambda storage, loc: storage.cuda(device_id))
67 | option.check_resume(opt, resume_state['iter']) # check resume options
68 | else:
69 | resume_state = None
70 |
71 | #### mkdir and loggers
72 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
73 | if resume_state is None:
74 | util.mkdir_and_rename(
75 | opt['path']['experiments_root']) # rename experiment folder if exists
76 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
77 | and 'pretrain_model' not in key and 'resume' not in key))
78 |
79 |         # configure loggers; before this, logging will not work
80 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
81 | screen=True, tofile=True)
82 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
83 | screen=True, tofile=True)
84 | logger = logging.getLogger('base')
85 | logger.info(option.dict2str(opt))
86 | # tensorboard logger
87 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
88 | version = float(torch.__version__[0:3])
89 | if version >= 1.1: # PyTorch 1.1
90 | from torch.utils.tensorboard import SummaryWriter
91 | else:
92 | logger.info(
93 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
94 | from tensorboardX import SummaryWriter
95 | tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
96 | else:
97 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
98 | logger = logging.getLogger('base')
99 |
100 | # convert to NoneDict, which returns None for missing keys
101 | opt = option.dict_to_nonedict(opt)
102 |
103 | #### random seed
104 | seed = opt['train']['manual_seed']
105 | if seed is None:
106 | seed = random.randint(1, 10000)
107 | if rank <= 0:
108 | logger.info('Random seed: {}'.format(seed))
109 | util.set_random_seed(seed)
110 |
111 | torch.backends.cudnn.benchmark = True
112 | # torch.backends.cudnn.deterministic = True
113 |
114 | #### create train and val dataloader
115 | dataset_ratio = 1 # enlarge the size of each epoch
116 | for phase, dataset_opt in opt['datasets'].items():
117 | if phase == 'train':
118 | train_set = create_dataset(dataset_opt)
119 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
120 | total_iters = int(opt['train']['niter'])
121 | total_epochs = int(math.ceil(total_iters / train_size))
122 | if opt['dist']:
123 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
124 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
125 | else:
126 | train_sampler = None
127 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
128 | if rank <= 0:
129 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
130 | len(train_set), train_size))
131 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
132 | total_epochs, total_iters))
133 | elif phase == 'val':
134 | val_set = create_dataset(dataset_opt)
135 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
136 | if rank <= 0:
137 | logger.info('Number of val images in [{:s}]: {:d}'.format(
138 | dataset_opt['name'], len(val_set)))
139 | else:
140 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
141 | assert train_loader is not None
142 |
143 | #### create model
144 | model = create_model(opt)
145 |
146 |
147 | #### resume training
148 | if resume_state:
149 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
150 | resume_state['epoch'], resume_state['iter']))
151 |
152 | start_epoch = resume_state['epoch']
153 | current_step = resume_state['iter']
154 | model.resume_training(resume_state) # handle optimizers and schedulers
155 | else:
156 | current_step = 0
157 | start_epoch = 0
158 |
159 | #### training
160 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
161 | for epoch in range(start_epoch, total_epochs + 1):
162 | if opt['dist']:
163 | train_sampler.set_epoch(epoch)
164 | for _, train_data in enumerate(train_loader):
165 | current_step += 1
166 | if current_step > total_iters:
167 | break
168 | #### training
169 | model.feed_data(train_data)
170 | model.optimize_parameters(current_step)
171 |
172 |
173 | #### update learning rate
174 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
175 |
176 | # validation
177 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
178 | avg_psnr_secret = 0.0
179 | avg_psnr_cover = 0.0
180 | idx = 0
181 | count = 0
182 | for val_data in val_loader:
183 | idx += 1
184 | # img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
185 | img_dir = os.path.join(opt['path']['val_images'], str(current_step))
186 | util.mkdir(img_dir)
187 |
188 | model.feed_data(val_data)
189 | model.test()
190 | count = count + 1
191 |
192 | visuals = model.get_current_visuals()
193 |
194 | cover = visuals['cover']
195 | secret = visuals['secret']
196 | secret_recover = visuals['secret_recover']
197 | steg = visuals['steg']
198 |
199 |                     cover_path = osp.join(img_dir, 'cover')
200 |                     secret_path = osp.join(img_dir, 'secret')
201 |                     secret_recover_path = osp.join(img_dir, 'secret_recover')
202 |                     steg_path = osp.join(img_dir, 'steg')
203 |
204 | if not os.path.exists(cover_path):
205 | os.makedirs(cover_path)
206 | if not os.path.exists(secret_path):
207 | os.makedirs(secret_path)
208 | if not os.path.exists(secret_recover_path):
209 | os.makedirs(secret_recover_path)
210 | if not os.path.exists(steg_path):
211 | os.makedirs(steg_path)
212 |
213 | save_cover_path = osp.join(cover_path, str(count) + '_clean_cover.png')
214 | torchvision.utils.save_image(cover, save_cover_path)
215 |
216 | save_steg_path = osp.join(steg_path, str(count) + '_steg.png')
217 | torchvision.utils.save_image(steg, save_steg_path)
218 |
219 | for i in range(int(secret.shape[0])):
220 | save_secret_path = osp.join(secret_path, str(count) + '_' + str(i) + '_clean_secret.png')
221 | save_secret_recover_path = osp.join(secret_recover_path,
222 | str(count) + '_' + str(i) + '_secret_recover.png')
223 | torchvision.utils.save_image(secret[i:i + 1, :, :, :], save_secret_path)
224 | torchvision.utils.save_image(secret_recover[i:i + 1, :, :, :], save_secret_recover_path)
225 |
226 |
227 | #### save models and training states
228 | if current_step % opt['logger']['save_checkpoint_freq'] == 0:
229 | if rank <= 0:
230 | logger.info('Saving models and training states.')
231 | model.save(current_step)
232 | model.save_training_state(epoch, current_step)
233 |
234 | if rank <= 0:
235 | logger.info('Saving the final model.')
236 | model.save('latest')
237 | logger.info('End of training.')
238 |
239 |
240 | if __name__ == '__main__':
241 | main()
242 |
--------------------------------------------------------------------------------
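Usage note (inferred from the argument parser above and the repository's options/ layout): training is typically launched from the codes directory as

    python train.py -opt options/train/train_InvMIHNet_4images.yml

and a distributed run additionally passes --launcher pytorch so that init_dist() sets up the NCCL process group.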
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/codes/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import pickle
4 | import random
5 | import numpy as np
6 | import torch
7 | import cv2
8 | from PIL import Image
9 |
10 | ####################
11 | # Files & IO
12 | ####################
13 |
14 | ###################### get image path list ######################
15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
16 |
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 |
22 | def _get_paths_from_images(path):
23 | '''get image path list from image folder'''
24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
25 | images = []
26 | for dirpath, _, fnames in sorted(os.walk(path)):
27 | for fname in sorted(fnames):
28 | if is_image_file(fname):
29 | img_path = os.path.join(dirpath, fname)
30 | images.append(img_path)
31 | assert images, '{:s} has no valid image file'.format(path)
32 | return images
33 |
34 |
35 | def _get_paths_from_lmdb(dataroot):
36 | '''get image path list from lmdb meta info'''
37 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
38 | paths = meta_info['keys']
39 | sizes = meta_info['resolution']
40 | if len(sizes) == 1:
41 | sizes = sizes * len(paths)
42 | return paths, sizes
43 |
44 |
45 | def get_image_paths(data_type, dataroot):
46 | '''get image path list
47 | support lmdb or image files'''
48 | paths, sizes = None, None
49 | if dataroot is not None:
50 | if data_type == 'lmdb':
51 | paths, sizes = _get_paths_from_lmdb(dataroot)
52 | elif data_type == 'img':
53 | paths = sorted(_get_paths_from_images(dataroot))
54 | else:
55 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
56 | return paths, sizes
57 |
58 |
59 | ###################### read images ######################
60 | def _read_img_lmdb(env, key, size):
61 | '''read image from lmdb with key (w/ and w/o fixed size)
62 | size: (C, H, W) tuple'''
63 | with env.begin(write=False) as txn:
64 | buf = txn.get(key.encode('ascii'))
65 | img_flat = np.frombuffer(buf, dtype=np.uint8)
66 | C, H, W = size
67 | img = img_flat.reshape(H, W, C)
68 | return img
69 |
70 |
71 | def read_img(env, path, size=None):
72 | '''read image by cv2 or from lmdb
73 | return: Numpy float32, HWC, BGR, [0,1]'''
74 | if env is None: # img
75 | #img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
76 | img = cv2.imread(path, cv2.IMREAD_COLOR)
77 | # img = Image.open(path,'r')
78 | else:
79 | img = _read_img_lmdb(env, path, size)
80 | img = img.astype(np.float32) / 255.
81 | if img.ndim == 2:
82 | img = np.expand_dims(img, axis=2)
83 | # some images have 4 channels
84 | if img.shape[2] > 3:
85 | img = img[:, :, :3]
86 | return img
87 |
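# Usage sketch (hypothetical path): disk reads return float32 HWC BGR in [0, 1];
# lmdb reads additionally need an open env and the stored (C, H, W) size.
#   img = read_img(None, 'datasets/example.png')
#   assert img.dtype == np.float32 and img.ndim == 3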
88 |
89 | ####################
90 | # image processing
91 | # process on numpy image
92 | ####################
93 |
94 |
95 | def augment(img_list, hflip=True, rot=True):
96 | # horizontal flip OR rotate
97 | hflip = hflip and random.random() < 0.5
98 | vflip = rot and random.random() < 0.5
99 | rot90 = rot and random.random() < 0.5
100 |
101 | def _augment(img):
102 | if hflip:
103 | img = img[:, ::-1, :]
104 | if vflip:
105 | img = img[::-1, :, :]
106 | if rot90:
107 | img = img.transpose(1, 0, 2)
108 | return img
109 |
110 | return [_augment(img) for img in img_list]
111 |
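# The flip/rotation flags above are sampled once per call and applied to every
# image in img_list, so paired images stay spatially aligned, e.g.:
#   img_GT, img_LQ = augment([img_GT, img_LQ], hflip=True, rot=True)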
112 |
113 | def augment_flow(img_list, flow_list, hflip=True, rot=True):
114 | # horizontal flip OR rotate
115 | hflip = hflip and random.random() < 0.5
116 | vflip = rot and random.random() < 0.5
117 | rot90 = rot and random.random() < 0.5
118 |
119 | def _augment(img):
120 | if hflip:
121 | img = img[:, ::-1, :]
122 | if vflip:
123 | img = img[::-1, :, :]
124 | if rot90:
125 | img = img.transpose(1, 0, 2)
126 | return img
127 |
128 | def _augment_flow(flow):
129 | if hflip:
130 | flow = flow[:, ::-1, :]
131 | flow[:, :, 0] *= -1
132 | if vflip:
133 | flow = flow[::-1, :, :]
134 | flow[:, :, 1] *= -1
135 | if rot90:
136 | flow = flow.transpose(1, 0, 2)
137 | flow = flow[:, :, [1, 0]]
138 | return flow
139 |
140 | rlt_img_list = [_augment(img) for img in img_list]
141 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
142 |
143 | return rlt_img_list, rlt_flow_list
144 |
145 |
146 | def channel_convert(in_c, tar_type, img_list):
147 | # conversion among BGR, gray and y
148 | if in_c == 3 and tar_type == 'gray': # BGR to gray
149 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
150 | return [np.expand_dims(img, axis=2) for img in gray_list]
151 | elif in_c == 3 and tar_type == 'y': # BGR to y
152 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
153 | return [np.expand_dims(img, axis=2) for img in y_list]
154 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
155 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
156 | else:
157 | return img_list
158 |
159 |
160 | def rgb2ycbcr(img, only_y=True):
161 | '''same as matlab rgb2ycbcr
162 | only_y: only return Y channel
163 | Input:
164 | uint8, [0, 255]
165 | float, [0, 1]
166 | '''
167 | in_img_type = img.dtype
168 | img = img.astype(np.float32)  # astype returns a copy; bind it so the caller's array is not mutated below
169 | if in_img_type != np.uint8:
170 | img *= 255.
171 | # convert
172 | if only_y:
173 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
174 | else:
175 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
176 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
177 | if in_img_type == np.uint8:
178 | rlt = rlt.round()
179 | else:
180 | rlt /= 255.
181 | return rlt.astype(in_img_type)
182 |
183 |
184 | def bgr2ycbcr(img, only_y=True):
185 | '''bgr version of rgb2ycbcr
186 | only_y: only return Y channel
187 | Input:
188 | uint8, [0, 255]
189 | float, [0, 1]
190 | '''
191 | in_img_type = img.dtype
192 | img = img.astype(np.float32)
193 | if in_img_type != np.uint8:
194 | img *= 255.
195 | # convert
196 | if only_y:
197 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
198 | else:
199 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
200 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
201 | if in_img_type == np.uint8:
202 | rlt = rlt.round()
203 | else:
204 | rlt /= 255.
205 | return rlt.astype(in_img_type)
206 |
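# Worked example (float input in [0, 1]): a pure-white pixel (1.0, 1.0, 1.0)
# gives Y = (24.966 + 128.553 + 65.481) + 16 = 235, rescaled to 235/255 ~ 0.922,
# the BT.601 white level, matching MATLAB's rgb2ycbcr behaviour.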
207 |
208 | def ycbcr2rgb(img):
209 | '''same as matlab ycbcr2rgb
210 | Input:
211 | uint8, [0, 255]
212 | float, [0, 1]
213 | '''
214 | in_img_type = img.dtype
215 | img = img.astype(np.float32)
216 | if in_img_type != np.uint8:
217 | img *= 255.
218 | # convert
219 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
220 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
221 | if in_img_type == np.uint8:
222 | rlt = rlt.round()
223 | else:
224 | rlt /= 255.
225 | return rlt.astype(in_img_type)
226 |
227 |
228 | def modcrop(img_in, scale_W, scale_H):
229 | # img_in: Numpy, HWC or HW
230 | img = np.copy(img_in)
231 | if img.ndim == 2:
232 | H, W = img.shape
233 | H_r, W_r = H % scale_H, W % scale_W
234 | img = img[:H - H_r, :W - W_r]
235 | elif img.ndim == 3:
236 | H, W, C = img.shape
237 | H_r, W_r = H % (scale_H * 4), W % (scale_W * 4)
238 | img = img[:H - H_r, :W - W_r, :]
239 | else:
240 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
241 | return img
242 |
243 |
244 | ####################
245 | # Functions
246 | ####################
247 |
248 |
249 | # matlab 'imresize' function; currently only 'bicubic' is supported
250 | def cubic(x):
251 | absx = torch.abs(x)
252 | absx2 = absx**2
253 | absx3 = absx**3
254 | return (1.5 * absx3 - 2.5 * absx2 + 1) * (
255 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
256 | (absx > 1) * (absx <= 2)).type_as(absx))
257 |
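# cubic() above is the Keys bicubic kernel with a = -0.5 (MATLAB's convention):
# (a+2)|x|^3 - (a+3)|x|^2 + 1 for |x| <= 1, and
# a|x|^3 - 5a|x|^2 + 8a|x| - 4a for 1 < |x| <= 2.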
258 |
259 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
260 | if (scale < 1) and (antialiasing):
261 | # Use a modified kernel to simultaneously interpolate and antialias; this calls for a larger kernel width
262 | kernel_width = kernel_width / scale
263 |
264 | # Output-space coordinates
265 | x = torch.linspace(1, out_length, out_length)
266 |
267 | # Input-space coordinates. Calculate the inverse mapping such that 0.5
268 | # in output space maps to 0.5 in input space, and 0.5+scale in output
269 | # space maps to 1.5 in input space.
270 | u = x / scale + 0.5 * (1 - 1 / scale)
271 |
272 | # What is the left-most pixel that can be involved in the computation?
273 | left = torch.floor(u - kernel_width / 2)
274 |
275 | # What is the maximum number of pixels that can be involved in the
276 | # computation? Note: it's OK to use an extra pixel here; if the
277 | # corresponding weights are all zero, it will be eliminated at the end
278 | # of this function.
279 | P = math.ceil(kernel_width) + 2
280 |
281 | # The indices of the input pixels involved in computing the k-th output
282 | # pixel are in row k of the indices matrix.
283 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
284 | 1, P).expand(out_length, P)
285 |
286 | # The weights used to compute the k-th output pixel are in row k of the
287 | # weights matrix.
288 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
289 | # apply cubic kernel
290 | if (scale < 1) and (antialiasing):
291 | weights = scale * cubic(distance_to_center * scale)
292 | else:
293 | weights = cubic(distance_to_center)
294 | # Normalize the weights matrix so that each row sums to 1.
295 | weights_sum = torch.sum(weights, 1).view(out_length, 1)
296 | weights = weights / weights_sum.expand(out_length, P)
297 |
298 | # If a column in weights is all zero, get rid of it. only consider the first and last column.
299 | weights_zero_tmp = torch.sum((weights == 0), 0)
300 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
301 | indices = indices.narrow(1, 1, P - 2)
302 | weights = weights.narrow(1, 1, P - 2)
303 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
304 | indices = indices.narrow(1, 0, P - 2)
305 | weights = weights.narrow(1, 0, P - 2)
306 | weights = weights.contiguous()
307 | indices = indices.contiguous()
308 | sym_len_s = -indices.min() + 1
309 | sym_len_e = indices.max() - in_length
310 | indices = indices + sym_len_s - 1
311 | return weights, indices, int(sym_len_s), int(sym_len_e)
312 |
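# Sanity check of the inverse mapping u = x/scale + 0.5*(1 - 1/scale):
# x = 0.5 gives u = 0.5/scale + 0.5 - 0.5/scale = 0.5 for any scale, and
# x = 0.5 + scale gives u = 0.5/scale + 1 + 0.5 - 0.5/scale = 1.5,
# which are exactly the two anchor points described in the comment above.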
313 |
314 | def imresize(img, scale, antialiasing=True):
315 | # Now the scale should be the same for H and W
316 | # input: img: CHW RGB [0,1]
317 | # output: CHW RGB [0,1] w/o round
318 |
319 | in_C, in_H, in_W = img.size()
320 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
321 | kernel_width = 4
322 | kernel = 'cubic'
323 |
324 | # Return the desired dimension order for performing the resize. The
325 | # strategy is to perform the resize first along the dimension with the
326 | # smallest scale factor.
327 | # Now we do not support this.
328 |
329 | # get weights and indices
330 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
331 | in_H, out_H, scale, kernel, kernel_width, antialiasing)
332 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
333 | in_W, out_W, scale, kernel, kernel_width, antialiasing)
334 | # process H dimension
335 | # symmetric copying
336 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
337 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
338 |
339 | sym_patch = img[:, :sym_len_Hs, :]
340 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
341 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
342 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
343 |
344 | sym_patch = img[:, -sym_len_He:, :]
345 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
346 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
347 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
348 |
349 | out_1 = torch.FloatTensor(in_C, out_H, in_W)
350 | kernel_width = weights_H.size(1)
351 | for i in range(out_H):
352 | idx = int(indices_H[i][0])
353 | out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
354 | out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
355 | out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
356 |
357 | # process W dimension
358 | # symmetric copying
359 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
360 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
361 |
362 | sym_patch = out_1[:, :, :sym_len_Ws]
363 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
364 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
365 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
366 |
367 | sym_patch = out_1[:, :, -sym_len_We:]
368 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
369 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
370 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
371 |
372 | out_2 = torch.FloatTensor(in_C, out_H, out_W)
373 | kernel_width = weights_W.size(1)
374 | for i in range(out_W):
375 | idx = int(indices_W[i][0])
376 | out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
377 | out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
378 | out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
379 |
380 | return out_2
381 |
382 |
383 | def imresize_np(img, scale_W, scale_H, antialiasing=True):
384 | # scale_W and scale_H are applied to the W and H dimensions independently
385 | # input: img: Numpy, HWC BGR [0,1]
386 | # output: HWC BGR [0,1] w/o round
387 | img = torch.from_numpy(img)
388 |
389 | in_H, in_W, in_C = img.size()
390 | _, out_H, out_W = in_C, math.ceil(in_H * scale_H), math.ceil(in_W * scale_W)
391 | kernel_width = 4
392 | kernel = 'cubic'
393 |
394 | # Return the desired dimension order for performing the resize. The
395 | # strategy is to perform the resize first along the dimension with the
396 | # smallest scale factor.
397 | # Now we do not support this.
398 |
399 | # get weights and indices
400 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
401 | in_H, out_H, scale_H, kernel, kernel_width, antialiasing)
402 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
403 | in_W, out_W, scale_W, kernel, kernel_width, antialiasing)
404 | # process H dimension
405 | # symmetric copying
406 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
407 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
408 |
409 | sym_patch = img[:sym_len_Hs, :, :]
410 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
411 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
412 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
413 |
414 | sym_patch = img[-sym_len_He:, :, :]
415 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
416 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
417 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
418 |
419 | out_1 = torch.FloatTensor(out_H, in_W, in_C)
420 | kernel_width = weights_H.size(1)
421 | for i in range(out_H):
422 | idx = int(indices_H[i][0])
423 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
424 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
425 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
426 |
427 | # process W dimension
428 | # symmetric copying
429 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
430 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
431 |
432 | sym_patch = out_1[:, :sym_len_Ws, :]
433 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
434 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
435 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
436 |
437 | sym_patch = out_1[:, -sym_len_We:, :]
438 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
439 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
440 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
441 |
442 | out_2 = torch.FloatTensor(out_H, out_W, in_C)
443 | kernel_width = weights_W.size(1)
444 | for i in range(out_W):
445 | idx = int(indices_W[i][0])
446 | out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
447 | out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
448 | out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
449 |
450 | return out_2.numpy()
451 |
452 |
453 | if __name__ == '__main__':
454 | # test imresize function
455 | # read images
456 | img = cv2.imread('test.png')
457 | img = img * 1.0 / 255
458 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
459 | # imresize
460 | scale = 1 / 4
461 | import time
462 | total_time = 0
463 | for i in range(10):
464 | start_time = time.time()
465 | rlt = imresize(img, scale, antialiasing=True)
466 | use_time = time.time() - start_time
467 | total_time += use_time
468 | print('average time: {}'.format(total_time / 10))
469 |
470 | import torchvision.utils
471 | torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
472 | normalize=False)
473 |
--------------------------------------------------------------------------------
/codes/models/InvMIHNet_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from collections import OrderedDict
4 |
5 | import torch
import torch.nn as nn  # used below for nn.utils.clip_grad_norm_ and nn.DataParallel
6 | from torch.nn.parallel import DistributedDataParallel
7 | import models.networks as networks
8 | import models.lr_scheduler as lr_scheduler
9 | from .base_model import BaseModel
10 | from models.modules.loss import ReconstructionLoss
11 | from models.modules.Quantization import Quantization
12 | import numpy as np
13 | import models.modules.Unet_common as common
14 | from models.model import *
16 | from PIL import Image
17 |
18 | logger = logging.getLogger('base')
19 |
20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 |
22 | def gauss_noise(shape):
23 | noise = torch.zeros(shape).cuda()
24 | for i in range(noise.shape[0]):
25 | noise[i] = torch.randn(noise[i].shape).cuda()
26 |
27 | return noise
28 |
29 |
30 | def guide_loss(output, bicubic_image):
31 | loss_fn = torch.nn.MSELoss(reduction='sum')  # same behaviour as the deprecated reduce=True, size_average=False
32 | loss = loss_fn(output, bicubic_image)
33 | return loss.to(device)
34 |
35 |
36 | def reconstruction_loss(rev_input, input):
37 | loss_fn = torch.nn.MSELoss(reduction='sum')
38 | loss = loss_fn(rev_input, input)
39 | return loss.to(device)
40 |
41 |
42 | def low_frequency_loss(ll_input, gt_input):
43 | loss_fn = torch.nn.MSELoss(reduction='sum')
44 | loss = loss_fn(ll_input, gt_input)
45 | return loss.to(device)
46 |
47 |
48 |
49 | def image_save(img,save_dir,img_name):
50 | if not os.path.exists(save_dir):
51 | os.makedirs(save_dir)
52 | img_path = os.path.join(save_dir, img_name)
53 | at_images_np = img.detach().cpu().numpy()
54 | adv_img = at_images_np[0]
55 | adv_img = np.moveaxis(adv_img, 0, 2)
56 | img_pil = Image.fromarray(adv_img.astype(np.uint8))  # expects pixel values already scaled to [0, 255]
57 | img_pil.save(img_path)
58 |
59 |
60 | dwt = common.DWT()
61 | iwt = common.IWT()
62 |
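# Shape note: common.DWT() is a one-level Haar transform mapping
# (B, C, H, W) -> (B, 4C, H/2, W/2), channels ordered [LL, HL, LH, HH];
# common.IWT() inverts it. This is why the stego branch below keeps the
# first 4 * channel_in channels before applying iwt.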
63 |
64 |
65 | class InvMIHNet(BaseModel):
66 | def __init__(self, opt):
67 | super(InvMIHNet, self).__init__(opt)
68 |
69 | if opt['dist']:
70 | self.rank = torch.distributed.get_rank()
71 | else:
72 | self.rank = -1 # non dist training
73 | train_opt = opt['train']
74 | test_opt = opt['test']
75 | self.train_opt = train_opt
76 | self.test_opt = test_opt
77 |
78 | self.netG = networks.define_G(opt).to(self.device)
79 | self.netH = Model().to(self.device)
80 | init_model(self.netH)
81 |
82 | # print network
83 | self.print_network()
84 | self.load()
85 | self.load_H(self.netH)
86 |
87 | self.Quantization = Quantization()
88 |
89 | if self.is_train:
90 | self.netG.train()
91 | self.netH.train()
92 |
93 | # loss
94 | self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw'])
95 | self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back'])
96 |
97 |
98 | # optimizers
99 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
100 | optim_params_G = []
101 | for k, v in self.netG.named_parameters():
102 | if v.requires_grad:
103 | optim_params_G.append(v)
104 | else:
105 | if self.rank <= 0:
106 | logger.warning('Params [{:s}] will not optimize.'.format(k))
107 | self.optimizer_G = torch.optim.Adam(optim_params_G, lr=train_opt['lr_G'],
108 | weight_decay=wd_G,
109 | betas=(train_opt['beta1'], train_opt['beta2']))
110 | self.optimizers.append(self.optimizer_G)
111 |
112 | optim_params_H = []
113 | for k, v in self.netH.named_parameters():
114 | if v.requires_grad:
115 | optim_params_H.append(v)
116 | else:
117 | if self.rank <= 0:
118 | logger.warning('Params [{:s}] will not optimize.'.format(k))
119 | self.optimizer_H = torch.optim.Adam(optim_params_H, lr=train_opt['lr_H'], betas=(train_opt['beta1_H'], train_opt['beta2_H']), eps=1e-6, weight_decay=train_opt['weight_decay_H'])
120 |
121 | # schedulers
122 | if train_opt['lr_scheme'] == 'MultiStepLR':
123 | for optimizer in self.optimizers:
124 | self.schedulers.append(
125 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
126 | restarts=train_opt['restarts'],
127 | weights=train_opt['restart_weights'],
128 | gamma=train_opt['lr_gamma'],
129 | clear_state=train_opt['clear_state']))
130 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
131 | for optimizer in self.optimizers:
132 | self.schedulers.append(
133 | lr_scheduler.CosineAnnealingLR_Restart(
134 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
135 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
136 | else:
137 | raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')
138 |
139 | self.schedulers_H = torch.optim.lr_scheduler.StepLR(self.optimizer_H, train_opt['weight_step'], gamma=train_opt['lr_gamma'])
140 |
141 | self.log_dict = OrderedDict()
142 |
143 | def feed_data(self, data):
144 | self.ref_L = data['LQ'].to(self.device) # LQ
145 | self.real_H = data['GT'].to(self.device) # GT
146 |
147 | def gaussian_batch(self, dims):
148 | return torch.randn(tuple(dims)).to(self.device)
149 |
150 | def loss_forward(self, out, y, z):
151 | l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)
152 |
153 | z = z.reshape([out.shape[0], -1])
154 | l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0]
155 |
156 | return l_forw_fit, l_forw_ce
157 |
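# Note on loss_forward: sum(z**2) / batch is, up to an additive constant and
# scale, the negative log-likelihood of z under a standard normal prior, the
# same distribution gaussian_batch samples from at test time.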
158 | def loss_backward(self, x, y):
159 | x_samples = self.netG(x=y, rev=True)
160 | x_samples_image = x_samples[:, :3, :, :]
161 | l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)
162 |
163 | return l_back_rec
164 |
@staticmethod
165 | def get_parameter_number(net):
166 | total_num = sum(p.numel() for p in net.parameters())
167 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
168 | return {'Total': total_num, 'Trainable': trainable_num}
169 |
170 | def optimize_parameters(self, step):
171 | # torch.autograd.set_detect_anomaly(True)
172 |
173 | # downscaling
174 | self.input = self.real_H[1:, :, :, :]
175 | self.output = self.netG(x=self.input)
176 |
177 | zshape = self.output[:, 3:, :, :].shape
178 | LR_ref = self.ref_L[1:, :, :, :].detach()
179 |
180 | secret = self.output[:, :3, :, :]
181 | new_H = int(self.input.shape[2] / self.opt['scale_H'])
182 | new_W = int(self.input.shape[3] / self.opt['scale_W'])
183 | if self.opt['scale_H'] * self.opt['scale_W'] == 4:
184 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:],
185 | self.real_H[:1, :3, new_H:, :new_W], self.real_H[:1, :3, new_H:, new_W:]), dim=0)
186 | elif self.opt['scale_H'] * self.opt['scale_W'] == 6:
187 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2],
188 | self.real_H[:1, :3, new_H:new_H * 2, :new_W],
189 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2],
190 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W],
191 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2],), dim=0)
192 | elif self.opt['scale_H'] * self.opt['scale_W'] == 8:
193 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2],
194 | self.real_H[:1, :3, new_H:new_H * 2, :new_W],
195 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2],
196 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W],
197 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2],
198 | self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W],
199 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2]), dim=0)
200 | elif self.opt['scale_H'] * self.opt['scale_W'] == 9:
201 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2],
202 | self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3],
203 | self.real_H[:1, :3, new_H:new_H * 2, :new_W],
204 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2],
205 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3],
206 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W],
207 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2],
208 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3],), dim=0)
209 | elif self.opt['scale_H'] * self.opt['scale_W'] == 16:
210 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2],
211 | self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3],
212 | self.real_H[:1, :3, :new_H, new_W * 3:new_W * 4],
213 | self.real_H[:1, :3, new_H:new_H * 2, :new_W],
214 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2],
215 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3],
216 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 3:new_W * 4],
217 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W],
218 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2],
219 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3],
220 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 3:new_W * 4],
221 | self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W],
222 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2],
223 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 2:new_W * 3],
224 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 3:new_W * 4],
225 | ), dim=0)
226 |
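# The elif chain above tiles the first cover frame row-major into
# scale_H x scale_W patches of size new_H x new_W. A hypothetical equivalent
# (not what this repo runs), assuming H == scale_H * new_H and
# W == scale_W * new_W:
#   tiles = self.real_H[:1, :3].unfold(2, new_H, new_H).unfold(3, new_W, new_W)
#   cover = tiles.permute(0, 2, 3, 1, 4, 5).reshape(-1, 3, new_H, new_W)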
227 | cover_input = dwt(cover)
228 | secret_input = dwt(secret)
229 | input_img = torch.cat((cover_input, secret_input), 1)
230 |
231 | # hiding
232 | output = self.netH(input_img)
233 |
234 | channel_in = self.opt['network_G']['in_nc']
235 |
236 | output_steg = output.narrow(1, 0, 4 * channel_in)
237 | output_z = output.narrow(1, 4 * channel_in, output.shape[1] - 4 * channel_in)
238 | steg_img = iwt(output_steg)
239 | steg_img = self.Quantization(steg_img)
240 |
241 | # revealing: the inverse pass recovers the hidden secret from the stego image
242 | output_z_gauss = gauss_noise(output_z.shape)
243 | 
244 | output_rev = torch.cat((output_steg, output_z_gauss), 1)
245 | output_image = self.netH(output_rev, rev=True)
246 |
247 | secret_rev = output_image.narrow(1, 4 * channel_in, output_image.shape[1] - 4 * channel_in)
248 | secret_rev_1 = iwt(secret_rev)
249 |
250 |
251 | # loss functions
252 | g_loss = guide_loss(steg_img.cuda(), cover.cuda())
253 | r_loss = reconstruction_loss(secret_rev_1, secret[:,:3,:,:])
254 | steg_low = output_steg.narrow(1, 0, channel_in)
255 | cover_low = cover_input.narrow(1, 0, channel_in)
256 | l_loss = low_frequency_loss(steg_low, cover_low)
257 |
258 | total_loss = self.train_opt['lamda_reconstruction'] * r_loss + self.train_opt['lamda_guide'] * g_loss + self.train_opt['lamda_low_frequency'] * l_loss
259 |
260 | l_forw_fit, l_forw_ce = self.loss_forward(secret_rev_1, LR_ref, self.output[:, 3:, :, :])
261 |
262 | # upscaling
263 | LR = self.Quantization(secret_rev_1)
264 | gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt['gaussian_scale'] is not None else 1
265 | y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1)
266 |
267 | l_back_rec = self.loss_backward(self.real_H[1:,:,:,:], y_)
268 |
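# Two backward passes accumulate into the shared .grad buffers: the hiding
# losses first (retain_graph=True, since the rescaling losses below reuse
# parts of the same graph), then the forward-fit / back-reconstruction loss;
# both optimizers then step on the summed gradients.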
269 | total_loss.backward(retain_graph=True)
270 |
271 | loss = l_forw_fit + l_back_rec + l_forw_ce
272 | print("step", step, "loss:", loss.item())
273 | loss.backward()
274 |
275 | # gradient clipping
276 | if self.train_opt['gradient_clipping']:
277 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
278 | nn.utils.clip_grad_norm_(self.netH.parameters(), self.train_opt['gradient_clipping'])
279 |
280 | self.optimizer_G.step()
281 | self.optimizer_H.step()
282 | self.optimizer_H.zero_grad()
283 | self.optimizer_G.zero_grad()
284 |
285 | if step % self.opt['logger']['save_checkpoint_freq'] == 0:
286 | # save_path = os.path.join(self.opt['path']['models'], save_filename)
287 | torch.save({'net': self.netH.state_dict()}, os.path.join(self.opt['path']['models'], 'IIH_' + str(step)+'.pth'))
288 |
289 |
290 | def test(self):
291 | Lshape = self.ref_L[1:, :, :, :].shape
292 |
293 | input_dim = Lshape[1]
294 | self.input = self.real_H[1:]
295 |
296 | zshape = [Lshape[0], input_dim * (self.opt['scale_W'] * self.opt['scale_H']) - Lshape[1], Lshape[2], Lshape[3]]
297 |
298 | gaussian_scale = 1
299 | if self.test_opt and self.test_opt['gaussian_scale'] is not None:
300 | gaussian_scale = self.test_opt['gaussian_scale']
301 |
302 | self.netG.eval()
303 | self.netH.eval()
304 | new_H = int(self.input.shape[2] / self.opt['scale_H'])
305 | new_W = int(self.input.shape[3] / self.opt['scale_W'])
306 | if self.opt['scale_H']*self.opt['scale_W'] == 4:
307 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:], self.real_H[:1, :3, new_H:, :new_W], self.real_H[:1, :3, new_H:, new_W:]), dim=0)
308 | elif self.opt['scale_H']*self.opt['scale_W'] == 6:
309 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2],
310 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2],
311 | self.real_H[:1, :3, new_H*2:new_H*3, :new_W], self.real_H[:1, :3, new_H*2:new_H*3, new_W:new_W * 2],), dim=0)
312 | elif self.opt['scale_H']*self.opt['scale_W'] == 8:
313 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], self.real_H[:1, :3, new_H:new_H * 2, :new_W], self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W*2],
314 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2], self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W], self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2]), dim=0)
315 | elif self.opt['scale_H']*self.opt['scale_W'] == 9:
316 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3],
317 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2], self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3],
318 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2], self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3],), dim=0)
319 | elif self.opt['scale_H']*self.opt['scale_W'] == 16:
320 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2],
321 | self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3],
322 | self.real_H[:1, :3, :new_H, new_W * 3:new_W * 4],
323 | self.real_H[:1, :3, new_H:new_H * 2, :new_W],
324 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2],
325 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3],
326 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 3:new_W * 4],
327 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W],
328 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2],
329 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3],
330 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 3:new_W * 4],
331 | self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W],
332 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2],
333 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 2:new_W * 3],
334 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 3:new_W * 4],
335 | ), dim=0)
336 |
337 | self.cover = self.real_H[:1, :3, :, :]
338 | self.secret = self.real_H[1:, :3, :, :]
339 |
340 | with torch.no_grad():
341 | output = self.netG(x=self.input)[:, :3, :, :]
342 | cover_input = dwt(cover)
343 | secret_input = dwt(output)
344 | input_img = torch.cat((cover_input, secret_input), 1)
345 | channel_in = self.opt['network_G']['in_nc']
346 |
347 | output_inn = self.netH(input_img)
348 | output_steg = output_inn.narrow(1, 0, 4 * channel_in)
349 | output_z = output_inn.narrow(1, 4 * channel_in, output_inn.shape[1] - 4 * channel_in)
350 | steg_img = iwt(output_steg)
351 |
352 | output_z_gauss = gauss_noise(output_z.shape)
353 | 
354 | output_rev = torch.cat((output_steg, output_z_gauss), 1)
355 | output_image = self.netH(output_rev, rev=True)
356 |
357 | secret_rev = output_image.narrow(1, 4 * channel_in, output_image.shape[1] - 4 * channel_in)
358 | secret_rev = iwt(secret_rev)
359 |
360 | self.forw_L = secret_rev
361 | self.forw_L = self.Quantization(self.forw_L).cuda()
362 | y_forw = torch.cat((self.forw_L, gaussian_scale * self.gaussian_batch(zshape)), dim=1)
363 | self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :]
364 |
365 | self.secret_recover = self.fake_H
366 |
367 | if self.opt['scale_H'] * self.opt['scale_W'] == 4:
368 | steg_1 = torch.cat((steg_img[0], steg_img[1]), 2)
369 | steg_2 = torch.cat((steg_img[2], steg_img[3]), 2)
370 | steg = torch.cat((steg_1, steg_2), 1)
371 | elif self.opt['scale_H'] * self.opt['scale_W'] == 6:
372 | steg_1 = torch.cat((steg_img[0], steg_img[1]), 2)
373 | steg_2 = torch.cat((steg_img[2], steg_img[3]), 2)
374 | steg_3 = torch.cat((steg_img[4], steg_img[5]), 2)
375 | steg = torch.cat((steg_1, steg_2, steg_3), 1)
376 | elif self.opt['scale_H'] * self.opt['scale_W'] == 8:
377 | steg_1 = torch.cat((steg_img[0], steg_img[1]), 2)
378 | steg_2 = torch.cat((steg_img[2], steg_img[3]), 2)
379 | steg_3 = torch.cat((steg_img[4], steg_img[5]), 2)
380 | steg_4 = torch.cat((steg_img[6], steg_img[7]), 2)
381 | steg = torch.cat((steg_1, steg_2, steg_3, steg_4), 1)
382 | elif self.opt['scale_H'] * self.opt['scale_W'] == 9:
383 | steg_1 = torch.cat((steg_img[0], steg_img[1], steg_img[2]), 2)
384 | steg_2 = torch.cat((steg_img[3], steg_img[4], steg_img[5]), 2)
385 | steg_3 = torch.cat((steg_img[6], steg_img[7], steg_img[8]), 2)
386 | steg = torch.cat((steg_1, steg_2, steg_3), 1)
387 | elif self.opt['scale_H'] * self.opt['scale_W'] == 16:
388 | steg_1 = torch.cat((steg_img[0], steg_img[1], steg_img[2], steg_img[3]), 2)
389 | steg_2 = torch.cat((steg_img[4], steg_img[5], steg_img[6], steg_img[7]), 2)
390 | steg_3 = torch.cat((steg_img[8], steg_img[9], steg_img[10], steg_img[11]), 2)
391 | steg_4 = torch.cat((steg_img[12], steg_img[13], steg_img[14], steg_img[15]), 2)
392 | steg = torch.cat((steg_1, steg_2, steg_3, steg_4), 1)
393 | self.steg = self.Quantization(steg)
394 |
395 | def get_current_log(self):
396 | return self.log_dict
397 |
398 | def get_current_visuals(self):
399 | out_dict = OrderedDict()
400 | out_dict['cover'] = self.cover.detach().float().cpu()
401 | out_dict['secret'] = self.secret.detach().float().cpu()
402 | out_dict['secret_recover'] = self.secret_recover.detach().float().cpu()
403 | out_dict['steg'] = self.steg.detach().float().cpu()
404 | return out_dict
405 |
406 | def print_network(self):
407 | s, n = self.get_network_description(self.netG)
408 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
409 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
410 | self.netG.module.__class__.__name__)
411 | else:
412 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
413 | if self.rank <= 0:
414 | logger.info('Network IIR structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
415 | logger.info(s)
416 |
417 | def load(self):
418 | load_path_G = self.opt['path']['pretrain_model_G']
419 | if load_path_G is not None:
420 | logger.info('Loading model for IIR [{:s}] ...'.format(load_path_G))
421 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
422 |
423 | def load_H(self, net):
424 | load_path_H = self.opt['path']['pretrain_model_H']
425 | if load_path_H is not None:
426 | logger.info('Loading model for IIH [{:s}] ...'.format(load_path_H))
427 | state_dicts = torch.load(load_path_H)
428 | network_state_dict = {k.replace("module.", ""): v for k, v in state_dicts['net'].items() if
429 | 'tmp_var' not in k}
430 | network_state_dict = {k: v for k, v in network_state_dict.items() if 'rect' not in k}
431 | net.load_state_dict(network_state_dict)
432 |
433 | def save(self, iter_label):
434 | self.save_network(self.netG, 'IIR', iter_label)
435 |
--------------------------------------------------------------------------------
/codes/models/modules/Unet_common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.modules import module_util as mutil
7 | import functools
8 |
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1, use_snorm=False):
10 | if use_snorm:
11 | return nn.utils.spectral_norm(nn.Conv2d(
12 | in_channels, out_channels, kernel_size,
13 | padding=(kernel_size//2)+dilation-1, bias=bias, dilation=dilation))
14 | else:
15 | return nn.Conv2d(
16 | in_channels, out_channels, kernel_size,
17 | padding=(kernel_size//2)+dilation-1, bias=bias, dilation=dilation)
18 |
19 |
20 | def default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3, use_snorm=False):
21 | if use_snorm:
22 | return nn.utils.spectral_norm(nn.Conv2d(
23 | in_channels,out_channels, kernel_size,
24 | padding=(kernel_size//2), bias=bias, groups=groups))
25 | else:
26 | return nn.Conv2d(
27 | in_channels,out_channels, kernel_size,
28 | padding=(kernel_size//2), bias=bias, groups=groups)
29 |
30 | def default_conv3d(in_channels, out_channels, kernel_size, t_kernel=3, bias=True, dilation=1, groups=1, use_snorm=False):
31 | if use_snorm:
32 | return nn.utils.spectral_norm(nn.Conv3d(
33 | in_channels,out_channels, (t_kernel, kernel_size, kernel_size), stride=1,
34 | padding=(0,kernel_size//2,kernel_size//2), bias=bias, dilation=dilation, groups=groups))
35 | else:
36 | return nn.Conv3d(
37 | in_channels,out_channels, (t_kernel, kernel_size, kernel_size), stride=1,
38 | padding=(0,kernel_size//2,kernel_size//2), bias=bias, dilation=dilation, groups=groups)
39 |
40 | #def shuffle_channel()
41 |
42 | def channel_shuffle(x, groups):
43 | batchsize, num_channels, height, width = x.size()
44 |
45 | channels_per_group = num_channels // groups
46 |
47 | # reshape
48 | x = x.view(batchsize, groups,
49 | channels_per_group, height, width)
50 |
51 | x = torch.transpose(x, 1, 2).contiguous()
52 |
53 | # flatten
54 | x = x.view(batchsize, -1, height, width)
55 |
56 | return x
57 |
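# Example: with num_channels = 4 and groups = 2 the channel order [0, 1, 2, 3]
# becomes [0, 2, 1, 3], interleaving the two groups:
#   y = channel_shuffle(torch.arange(4.).view(1, 4, 1, 1), 2)
#   # y.flatten() -> tensor([0., 2., 1., 3.])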
58 | def pixel_down_shuffle(x, downscale_factor):
59 | batchsize, num_channels, height, width = x.size()
60 | 
61 | out_height = height // downscale_factor
62 | out_width = width // downscale_factor
63 | input_view = x.contiguous().view(batchsize, num_channels, out_height, downscale_factor, out_width,
64 | downscale_factor)
65 | 
66 | num_channels *= downscale_factor ** 2
67 | unshuffle_out = input_view.permute(0,1,3,5,2,4).contiguous()
68 |
69 | return unshuffle_out.view(batchsize, num_channels, out_height, out_width)
70 |
71 |
72 |
73 | def sp_init(x):
74 |
75 | x01 = x[:, :, 0::2, :]
76 | x02 = x[:, :, 1::2, :]
77 | x_LL = x01[:, :, :, 0::2]
78 | x_HL = x02[:, :, :, 0::2]
79 | x_LH = x01[:, :, :, 1::2]
80 | x_HH = x02[:, :, :, 1::2]
81 |
82 |
83 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
84 |
85 | def dwt_init3d(x):
86 |
87 | x01 = x[:, :, :, 0::2, :] / 2
88 | x02 = x[:, :, :, 1::2, :] / 2
89 | x1 = x01[:, :, :, :, 0::2]
90 | x2 = x02[:, :, :, :, 0::2]
91 | x3 = x01[:, :, :, :, 1::2]
92 | x4 = x02[:, :, :, :, 1::2]
93 | x_LL = x1 + x2 + x3 + x4
94 | x_HL = -x1 - x2 + x3 + x4
95 | x_LH = -x1 + x2 - x3 + x4
96 | x_HH = x1 - x2 - x3 + x4
97 |
98 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
99 |
100 | def dwt_init(x):
101 |
102 | x01 = x[:, :, 0::2, :] / 2
103 | x02 = x[:, :, 1::2, :] / 2
104 | x1 = x01[:, :, :, 0::2]
105 | x2 = x02[:, :, :, 0::2]
106 | x3 = x01[:, :, :, 1::2]
107 | x4 = x02[:, :, :, 1::2]
108 | x_LL = x1 + x2 + x3 + x4
109 | x_HL = -x1 - x2 + x3 + x4
110 | x_LH = -x1 + x2 - x3 + x4
111 | x_HH = x1 - x2 - x3 + x4
112 |
113 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
114 |
115 | def iwt_init(x):
116 | r = 2
117 | in_batch, in_channel, in_height, in_width = x.size()
118 | #print([in_batch, in_channel, in_height, in_width])
119 | out_batch, out_channel, out_height, out_width = in_batch, int(
120 | in_channel / (r ** 2)), r * in_height, r * in_width
121 | x1 = x[:, 0:out_channel, :, :] / 2
122 | x2 = x[:, out_channel:out_channel * 2, :, :] / 2
123 | x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
124 | x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
125 |
126 |
127 | h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda()
128 |
129 | h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
130 | h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
131 | h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
132 | h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
133 |
134 | return h
135 |
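# Round-trip check: this Haar pair reconstructs exactly, so
# iwt_init(dwt_init(x)) == x. Note iwt_init allocates its output with
# .cuda(), so the check needs a CUDA device:
#   x = torch.randn(1, 3, 8, 8).cuda()
#   assert torch.allclose(iwt_init(dwt_init(x)), x, atol=1e-6)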
136 | class ResidualDenseBlock(nn.Module):
137 | def __init__(self, nf=64, gc=32, kernel_size = 3, bias=True, use_snorm=False):
138 | super(ResidualDenseBlock, self).__init__()
139 | # gc: growth channel, i.e. intermediate channels
140 | if use_snorm:
141 | self.conv1 = nn.utils.spectral_norm(nn.Conv2d(nf, gc, 3, 1, 1, bias=bias))
142 | self.conv2 = nn.utils.spectral_norm(nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias))
143 | self.conv3 = nn.utils.spectral_norm(nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias))
144 | self.conv4 = nn.utils.spectral_norm(nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias))
145 | self.conv5 = nn.utils.spectral_norm(nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias))
146 | else:
147 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
148 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
149 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
150 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
151 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
152 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
153 |
154 | # initialization
155 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
156 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
157 |
158 | def forward(self, x):
159 | x1 = self.lrelu(self.conv1(x))
160 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
161 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
162 | # x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
163 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
164 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
165 | return x5 * 0.2 + x
166 |
167 | class RRDB(nn.Module):
168 | '''Residual in Residual Dense Block'''
169 |
170 | def __init__(self, nf, gc=32, use_snorm=False):
171 | super(RRDB, self).__init__()
172 | self.RDB1 = ResidualDenseBlock(nf, gc, use_snorm=use_snorm)  # keyword arg: the third positional parameter is kernel_size
173 | # self.RDB2 = ResidualDenseBlock(nf, gc)
174 | # self.RDB3 = ResidualDenseBlock(nf, gc)
175 |
176 | def forward(self, x):
177 | out = self.RDB1(x)
178 | # out = self.RDB2(out)
179 | # out = self.RDB3(out)
180 | return out * 0.2 + x
181 |
182 | class RRDBblock(nn.Module):
183 | '''Residual in Residual Dense Block'''
184 |
185 | def __init__(self, nf, gc=32, nb=23, use_snorm=False):
186 | super(RRDBblock, self).__init__()
187 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc, use_snorm=use_snorm)
188 |
189 | self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
190 | if use_snorm:
191 | self.trunk_conv = nn.utils.spectral_norm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
192 | else:
193 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
194 |
195 | def forward(self, x):
196 |
197 | return self.trunk_conv(self.RRDB_trunk(x))
198 |
199 | class Channel_Shuffle(nn.Module):
200 | def __init__(self, conv_groups):
201 | super(Channel_Shuffle, self).__init__()
202 | self.conv_groups = conv_groups
203 | self.requires_grad = False
204 |
205 | def forward(self, x):
206 | return channel_shuffle(x, self.conv_groups)
207 |
208 | class SP(nn.Module):
209 | def __init__(self):
210 | super(SP, self).__init__()
211 | self.requires_grad = False
212 |
213 | def forward(self, x):
214 | return sp_init(x)
215 |
216 | class Pixel_Down_Shuffle(nn.Module):
217 | def __init__(self):
218 | super(Pixel_Down_Shuffle, self).__init__()
219 | self.requires_grad = False
220 |
221 | def forward(self, x):
222 | return pixel_down_shuffle(x, 2)
223 |
224 | class DWT(nn.Module):
225 | def __init__(self):
226 | super(DWT, self).__init__()
227 | self.requires_grad = False
228 |
229 | def forward(self, x):
230 | return dwt_init(x)
231 |
232 | class DWT3d(nn.Module):
233 | def __init__(self):
234 | super(DWT3d, self).__init__()
235 | self.requires_grad = False
236 |
237 | def forward(self, x):
238 | return dwt_init3d(x)
239 |
240 | class IWT(nn.Module):
241 | def __init__(self):
242 | super(IWT, self).__init__()
243 | self.requires_grad = False
244 |
245 | def forward(self, x):
246 | return iwt_init(x)
247 |
248 |
249 | class MeanShift(nn.Conv2d):
250 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
251 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
252 | std = torch.Tensor(rgb_std)
253 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
254 | self.weight.data.div_(std.view(3, 1, 1, 1))
255 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
256 | self.bias.data.div_(std)
257 | for p in self.parameters(): p.requires_grad = False  # a bare self.requires_grad attribute would not freeze the weights
258 | if sign==-1:
259 | self.create_graph = False
260 | self.volatile = True
261 | class MeanShift2(nn.Conv2d):
262 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
263 | super(MeanShift2, self).__init__(4, 4, kernel_size=1)
264 | std = torch.Tensor(rgb_std)
265 | self.weight.data = torch.eye(4).view(4, 4, 1, 1)
266 | self.weight.data.div_(std.view(4, 1, 1, 1))
267 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
268 | self.bias.data.div_(std)
269 | for p in self.parameters(): p.requires_grad = False
270 | if sign==-1:
271 | self.volatile = True
272 |
273 | class BasicBlock(nn.Sequential):
274 | def __init__(
275 | self, in_channels, out_channels, kernel_size, stride=1, bias=False,
276 | bn=False, act=nn.LeakyReLU(True), use_snorm=False):
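# NB: nn.LeakyReLU(True) passes True as negative_slope (slope 1.0, i.e. an
# identity activation); nn.LeakyReLU(inplace=True) is the leaky form. The
# same default recurs in the blocks below.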
277 |
278 | if use_snorm:
279 | m = [nn.utils.spectral_norm(nn.Conv2d(
280 | in_channels, out_channels, kernel_size,
281 | padding=(kernel_size//2), stride=stride, bias=bias))
282 | ]
283 | else:
284 | m = [nn.Conv2d(
285 | in_channels, out_channels, kernel_size,
286 | padding=(kernel_size//2), stride=stride, bias=bias)
287 | ]
288 | if bn: m.append(nn.BatchNorm2d(out_channels))
289 | if act is not None: m.append(act)
290 | super(BasicBlock, self).__init__(*m)
291 |
292 | class Block3d(nn.Sequential):
293 | def __init__(
294 | self, in_channels, out_channels, kernel_size, t_kernel=3,
295 | bias=True, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False):
296 |
297 | super(Block3d, self).__init__()
298 | m = []
299 |
300 | m.append(default_conv3d(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm))
301 | m.append(act)
302 | m.append(default_conv3d(out_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm))
303 | m.append(act)
304 |
305 |
306 | self.body = nn.Sequential(*m)
307 | self.res_scale = res_scale
308 |
309 | def forward(self, x):
310 | x = self.body(x)
311 | return x
312 |
313 | class BBlock(nn.Module):
314 | def __init__(
315 | self, conv, in_channels, out_channels, kernel_size,
316 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False):
317 |
318 | super(BBlock, self).__init__()
319 | m = []
320 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm))
321 |
322 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
323 | m.append(act)
324 |
325 |
326 | self.body = nn.Sequential(*m)
327 | self.res_scale = res_scale
328 |
329 | def forward(self, x):
330 | x = self.body(x)
331 | return x
332 |
333 | class DBlock_com(nn.Module):
334 | def __init__(
335 | self, conv, in_channels, out_channels, kernel_size,
336 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False):
337 |
338 | super(DBlock_com, self).__init__()
339 | m = []
340 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))
341 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
342 | m.append(act)
343 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3, use_snorm=use_snorm))
344 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
345 | m.append(act)
346 |
347 |
348 | self.body = nn.Sequential(*m)
349 | self.res_scale = res_scale
350 |
351 | def forward(self, x):
352 | x = self.body(x)
353 | return x
354 |
355 | class DBlock_inv(nn.Module):
356 | def __init__(
357 | self, conv, in_channels, out_channels, kernel_size,
358 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False):
359 |
360 | super(DBlock_inv, self).__init__()
361 | m = []
362 |
363 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3, use_snorm=use_snorm))
364 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
365 | m.append(act)
366 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))
367 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
368 | m.append(act)
369 |
370 |
371 | self.body = nn.Sequential(*m)
372 | self.res_scale = res_scale
373 |
374 | def forward(self, x):
375 | x = self.body(x)
376 | return x
377 |
378 | class DBlock_com1(nn.Module):
379 | def __init__(
380 | self, conv, in_channels, out_channels, kernel_size,
381 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False):
382 |
383 | super(DBlock_com1, self).__init__()
384 | m = []
385 |
386 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))
387 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
388 | m.append(act)
389 | m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=1, use_snorm=use_snorm))
390 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
391 | m.append(act)
392 |
393 |
394 | self.body = nn.Sequential(*m)
395 | self.res_scale = res_scale
396 |
397 | def forward(self, x):
398 | x = self.body(x)
399 | return x
400 |
401 | class DBlock_inv1(nn.Module):  # identical schedule to DBlock_com1 (dilation 2, then 1)
402 |     def __init__(
403 |         self, conv, in_channels, out_channels, kernel_size,
404 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, use_snorm=False):
405 |
406 | super(DBlock_inv1, self).__init__()
407 | m = []
408 |
409 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))
410 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
411 | m.append(act)
412 | m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=1, use_snorm=use_snorm))
413 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
414 | m.append(act)
415 |
416 |
417 | self.body = nn.Sequential(*m)
418 | self.res_scale = res_scale
419 |
420 | def forward(self, x):
421 | x = self.body(x)
422 | return x
423 |
424 | class DBlock_com2(nn.Module):  # two convs, both at dilation 2
425 |     def __init__(
426 |         self, conv, in_channels, out_channels, kernel_size,
427 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, use_snorm=False):
428 |
429 | super(DBlock_com2, self).__init__()
430 | m = []
431 |
432 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))
433 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
434 | m.append(act)
435 |         m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))  # the second conv sees out_channels features
436 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
437 | m.append(act)
438 |
439 |
440 | self.body = nn.Sequential(*m)
441 | self.res_scale = res_scale
442 |
443 | def forward(self, x):
444 | x = self.body(x)
445 | return x
446 |
447 | class DBlock_inv2(nn.Module):  # identical schedule to DBlock_com2 (both convs at dilation 2)
448 |     def __init__(
449 |         self, conv, in_channels, out_channels, kernel_size,
450 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, use_snorm=False):
451 |
452 | super(DBlock_inv2, self).__init__()
453 | m = []
454 |
455 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))
456 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
457 | m.append(act)
458 |         m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm))  # the second conv sees out_channels features
459 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
460 | m.append(act)
461 |
462 |
463 | self.body = nn.Sequential(*m)
464 | self.res_scale = res_scale
465 |
466 | def forward(self, x):
467 | x = self.body(x)
468 | return x
469 |
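The six DBlock_* variants differ only in their dilation schedules: com is (2, 3), inv is (3, 2), com1/inv1 are (2, 1), and com2/inv2 are (2, 2). Stacking a 3x3 conv at dilation 2 (effective span 5) on one at dilation 3 (effective span 7) yields an 11x11 receptive field (5 + 7 - 1). A hedged round-trip sketch, assuming `default_conv` pads dilated convs to preserve spatial size:

    import torch

    enc = DBlock_com(default_conv, in_channels=64, out_channels=64, kernel_size=3)
    dec = DBlock_inv(default_conv, in_channels=64, out_channels=64, kernel_size=3)

    x = torch.randn(1, 64, 32, 32)
    assert dec(enc(x)).shape == x.shape   # holds only if padding matches dilation
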
470 | class ShuffleBlock(nn.Module):  # conv -> channel shuffle -> activation
471 |     def __init__(
472 |         self, conv, in_channels, out_channels, kernel_size,
473 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, conv_groups=1, use_snorm=False):
474 |
475 | super(ShuffleBlock, self).__init__()
476 | m = []
477 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm))
478 | m.append(Channel_Shuffle(conv_groups))
479 | if bn: m.append(nn.BatchNorm2d(out_channels))
480 | m.append(act)
481 |
482 |
483 | self.body = nn.Sequential(*m)
484 | self.res_scale = res_scale
485 |
486 | def forward(self, x):
487 | x = self.body(x).mul(self.res_scale)
488 | return x
489 |
490 |
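ShuffleBlock depends on `Channel_Shuffle`, defined earlier in this module; presumably it performs the ShuffleNet-style interleave. A self-contained sketch of that operation (the function name here is illustrative, not the module's API):

    import torch

    def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
        # Split channels into `groups`, transpose, and flatten back so
        # information mixes across group boundaries.
        n, c, h, w = x.shape
        return (x.view(n, groups, c // groups, h, w)
                 .transpose(1, 2)
                 .reshape(n, c, h, w))

    x = torch.arange(8.).view(1, 8, 1, 1)
    print(channel_shuffle(x, 2).flatten())  # tensor([0., 4., 1., 5., 2., 6., 3., 7.])
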
491 | class DWBlock(nn.Module):  # k x k conv followed by a 1 x 1 conv (depthwise-separable style)
492 |     def __init__(
493 |         self, conv, conv1, in_channels, out_channels, kernel_size,
494 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, use_snorm=False):
495 |
496 | super(DWBlock, self).__init__()
497 | m = []
498 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm))
499 | if bn: m.append(nn.BatchNorm2d(out_channels))
500 | m.append(act)
501 |
502 |         m.append(conv1(out_channels, out_channels, 1, bias=bias, use_snorm=use_snorm))  # pointwise conv consumes the out_channels features
503 | if bn: m.append(nn.BatchNorm2d(out_channels))
504 | m.append(act)
505 |
506 |
507 | self.body = nn.Sequential(*m)
508 | self.res_scale = res_scale
509 |
510 | def forward(self, x):
511 | x = self.body(x).mul(self.res_scale)
512 | return x
513 |
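DWBlock pairs a k x k conv with a 1 x 1 conv, i.e. the depthwise-separable factorisation. A hedged sketch of constructors it could be handed (`dw` and `pw` are illustrative names; `use_snorm` is accepted but ignored in this sketch):

    import torch
    import torch.nn as nn

    def dw(in_ch, out_ch, k, bias=True, use_snorm=False):
        # depthwise: one filter per input channel
        return nn.Conv2d(in_ch, out_ch, k, padding=k // 2, groups=in_ch, bias=bias)

    def pw(in_ch, out_ch, k, bias=True, use_snorm=False):
        # pointwise: 1x1 conv that mixes channels
        return nn.Conv2d(in_ch, out_ch, k, bias=bias)

    block = DWBlock(dw, pw, in_channels=64, out_channels=64, kernel_size=3)
    y = block(torch.randn(1, 64, 32, 32))  # -> [1, 64, 32, 32]
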
514 | class ResBlock(nn.Module):  # EDSR-style residual block: conv -> act -> conv, scaled skip
515 |     def __init__(
516 |         self, conv, n_feat, kernel_size,
517 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, use_snorm=False):
518 |
519 | super(ResBlock, self).__init__()
520 | m = []
521 | for i in range(2):
522 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias, use_snorm=use_snorm))
523 | if bn: m.append(nn.BatchNorm2d(n_feat))
524 | if i == 0: m.append(act)
525 |
526 | self.body = nn.Sequential(*m)
527 | self.res_scale = res_scale
528 |
529 | def forward(self, x):
530 | res = self.body(x).mul(self.res_scale)
531 | res += x
532 |
533 | return res
534 |
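ResBlock is the EDSR-style residual unit: conv -> act -> conv, with the branch output scaled by `res_scale` before the identity skip. A hedged stacking sketch (again assuming the module's `default_conv`):

    import torch.nn as nn

    # res_scale < 1 damps the residual branch, which helps deep trunks train
    # stably (the EDSR trick).
    trunk = nn.Sequential(*[
        ResBlock(default_conv, n_feat=64, kernel_size=3, res_scale=0.1)
        for _ in range(8)
    ])
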
535 | class Block(nn.Module):  # four stacked convs, activation only after the first; no skip
536 |     def __init__(
537 |         self, conv, n_feat, kernel_size,
538 |         bias=True, bn=False, act=nn.LeakyReLU(inplace=True), res_scale=1, use_snorm=False):
539 |
540 | super(Block, self).__init__()
541 | m = []
542 | for i in range(4):
543 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias, use_snorm=use_snorm))
544 | if bn: m.append(nn.BatchNorm2d(n_feat))
545 | if i == 0: m.append(act)
546 |
547 | self.body = nn.Sequential(*m)
548 | self.res_scale = res_scale
549 |
550 | def forward(self, x):
551 | res = self.body(x).mul(self.res_scale)
552 | # res += x
553 |
554 | return res
555 |
556 | class Upsampler(nn.Sequential):  # sub-pixel (PixelShuffle) upsampling for x2^n or x3 scales
557 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True, use_snorm=False):
558 |
559 | m = []
560 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
561 | for _ in range(int(math.log(scale, 2))):
562 |                 m.append(conv(n_feat, 4 * n_feat, 3, bias=bias, use_snorm=use_snorm))
563 | m.append(nn.PixelShuffle(2))
564 | if bn: m.append(nn.BatchNorm2d(n_feat))
565 | if act: m.append(act())
566 | elif scale == 3:
567 |             m.append(conv(n_feat, 9 * n_feat, 3, bias=bias, use_snorm=use_snorm))
568 | m.append(nn.PixelShuffle(3))
569 | if bn: m.append(nn.BatchNorm2d(n_feat))
570 | if act: m.append(act())
571 | else:
572 | raise NotImplementedError
573 |
574 | super(Upsampler, self).__init__(*m)
575 |
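Upsampler performs sub-pixel upscaling: each x2 step widens channels by a factor of 4 with a conv, then nn.PixelShuffle(2) folds them into a grid twice as large (x3 uses a factor of 9 and PixelShuffle(3)). A hedged sketch, assuming `default_conv` preserves spatial size:

    import torch

    up = Upsampler(default_conv, scale=4, n_feat=64)  # two conv+shuffle steps
    x = torch.randn(1, 64, 24, 24)
    y = up(x)   # -> [1, 64, 96, 96]
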
576 | class VGG_conv0(nn.Module):  # VGG-style feature extractor: five stride-2 stages (/32)
577 | def __init__(self, in_nc, nf):
578 |
579 | super(VGG_conv0, self).__init__()
580 |         # [64, 128, 128]  (shape comments assume nf=64 and a 128x128 input)
581 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
582 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
583 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
584 | # [64, 64, 64]
585 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
586 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
587 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
588 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
589 | # [128, 32, 32]
590 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
591 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
592 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
593 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
594 | # [256, 16, 16]
595 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
596 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
597 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
598 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
599 | # [512, 8, 8]
600 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
601 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
602 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
603 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
604 |
605 | # self.avg_pool = nn.AvgPool2d(3, stride=2, padding=0, ceil_mode=True) # /2
606 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
607 |
608 | def forward(self, x):
609 | fea = self.lrelu(self.conv0_0(x))
610 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
611 |
612 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
613 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
614 |
615 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
616 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
617 |
618 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
619 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
620 |
621 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
622 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
623 | # fea = self.avg_pool(fea)
624 |
625 | return fea
626 |
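With nf=64 and a 128x128 input (the setting the bracketed shape comments assume), VGG_conv0 halves the spatial size five times. A quick self-contained check:

    import torch

    net = VGG_conv0(in_nc=3, nf=64)
    x = torch.randn(1, 3, 128, 128)
    print(net(x).shape)   # torch.Size([1, 512, 4, 4]); 128 / 2**5 = 4
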
627 | class VGG_conv1(nn.Module):  # same stages as VGG_conv0, kept as a separate head
628 | def __init__(self, in_nc, nf):
629 |
630 | super(VGG_conv1, self).__init__()
631 | # [64, 128, 128]
632 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
633 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
634 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
635 | # [64, 64, 64]
636 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
637 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
638 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
639 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
640 | # [128, 32, 32]
641 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
642 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
643 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
644 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
645 | # [256, 16, 16]
646 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
647 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
648 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
649 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
650 | # [512, 8, 8]
651 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
652 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
653 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
654 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
655 |
656 | # self.avg_pool = nn.AvgPool2d(2, stride=1, padding=0, ceil_mode=True) # /2
657 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
658 |
659 | def forward(self, x):
660 | fea = self.lrelu(self.conv0_0(x))
661 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
662 |
663 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
664 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
665 |
666 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
667 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
668 |
669 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
670 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
671 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
672 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
673 | # fea = self.avg_pool(fea)
674 |
675 | return fea
676 |
677 | class VGG_conv2(nn.Module):  # VGG_conv0 without the fifth stage (/16 overall)
678 | def __init__(self, in_nc, nf):
679 |
680 | super(VGG_conv2, self).__init__()
681 | # [64, 128, 128]
682 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
683 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
684 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
685 | # [64, 64, 64]
686 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
687 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
688 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
689 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
690 | # [128, 32, 32]
691 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
692 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
693 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
694 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
695 | # [256, 16, 16]
696 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
697 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
698 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
699 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
700 | # [512, 8, 8]
701 | # self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
702 | # self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
703 | # self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
704 | # self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
705 |
706 | # self.avg_pool = nn.AvgPool2d(3, stride=2, padding=0, ceil_mode=True) # /2
707 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
708 |
709 | def forward(self, x):
710 | fea = self.lrelu(self.conv0_0(x))
711 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
712 |
713 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
714 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
715 |
716 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
717 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
718 |
719 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
720 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
721 |
722 | # fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
723 | # fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
724 | # fea = self.avg_pool(fea)
725 |
726 | return fea
--------------------------------------------------------------------------------