├── .idea ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── misc.xml ├── modules.xml └── Invertible-Image-Rescaling-master.iml ├── .style.yapf ├── figures ├── InvMIHNet_8images.png ├── overview_InvMIHNet.png └── very_large_capacity.png ├── experiments └── pretrained_models │ └── link ├── .flake8 ├── codes ├── models │ ├── __init__.py │ ├── modules │ │ ├── Quantization.py │ │ ├── Subnet_constructor.py │ │ ├── loss.py │ │ ├── discriminator_vgg_arch.py │ │ ├── module_util.py │ │ ├── Inv_arch.py │ │ └── Unet_common.py │ ├── model.py │ ├── networks.py │ ├── rrdb_denselayer.py │ ├── invblock.py │ ├── IIH_module.py │ ├── bicubic.py │ ├── base_model.py │ ├── lr_scheduler.py │ └── InvMIHNet_model.py ├── options │ ├── test │ │ ├── test_InvMIHNet_16images.yml │ │ ├── test_InvMIHNet_8images.yml │ │ ├── test_InvMIHNet_9images.yml │ │ ├── test_InvMIHNet_6images.yml │ │ └── test_InvMIHNet_4images.yml │ ├── train │ │ ├── train_InvMIHNet_8images.yml │ │ ├── train_InvMIHNet_4images.yml │ │ ├── train_InvMIHNet_9images.yml │ │ ├── train_InvMIHNet_6images.yml │ │ └── train_InvMIHNet_16images.yml │ └── options.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── Steg_dataset.py │ └── util.py ├── test.py ├── utils │ └── util.py └── train.py ├── .gitignore ├── README.md └── LICENSE /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | COLUMN_LIMIT = 100 4 | SPLIT_BEFORE_NAMED_ASSIGNS = false -------------------------------------------------------------------------------- /figures/InvMIHNet_8images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brittany-Chen/InvMIHNet/HEAD/figures/InvMIHNet_8images.png -------------------------------------------------------------------------------- /experiments/pretrained_models/link: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/17GRiwaJN8yqmLtiAO-bJcgpG8aWHysz3/view?usp=sharing 2 | -------------------------------------------------------------------------------- /figures/overview_InvMIHNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brittany-Chen/InvMIHNet/HEAD/figures/overview_InvMIHNet.png -------------------------------------------------------------------------------- /figures/very_large_capacity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Brittany-Chen/InvMIHNet/HEAD/figures/very_large_capacity.png -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # Too many leading '#' for block comment (E266) 4 | E266 5 | 6 | max-line-length=100 -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | model = opt['model'] 7 | 8 | if model == 'InvMIHNet': 9 | from .InvMIHNet_model import InvMIHNet as M 10 | else: 11 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 12 | m = M(opt) 13 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 14 | return m 15 | -------------------------------------------------------------------------------- /.idea/Invertible-Image-Rescaling-master.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /codes/models/modules/Quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Quant(torch.autograd.Function): 5 | 6 | @staticmethod 7 | def forward(ctx, input): 8 | input = torch.clamp(input, 0, 1) 9 | output = (input * 255.).round() / 255. 10 | return output 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | return grad_output 15 | 16 | class Quantization(nn.Module): 17 | def __init__(self): 18 | super(Quantization, self).__init__() 19 | 20 | def forward(self, input): 21 | return Quant.apply(input) 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # folder 2 | .vscode 3 | 4 | experiments/* 5 | !experiments/pretrained_models 6 | experiments/pretrained_models/* 7 | # !experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth 8 | !experiments/pretrained_models/README.md 9 | 10 | results/* 11 | tb_logger/* 12 | 13 | # file type 14 | *.svg 15 | *.pyc 16 | *.t7 17 | *.caffemodel 18 | *.mat 19 | *.npy 20 | 21 | # latex 22 | *.aux 23 | *.bbl 24 | *.blg 25 | *.log 26 | *.out 27 | *.synctex.gz 28 | 29 | # TODO 30 | data_samples/samples_byteimg 31 | data_samples/samples_colorimg 32 | data_samples/samples_segprob 33 | data_samples/samples_result -------------------------------------------------------------------------------- /codes/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | import torch.nn as nn 3 | from models.IIH_module import IIH_module 4 | 5 | 6 | class Model(nn.Module): 7 | def __init__(self): 8 | super(Model, self).__init__() 9 | 10 | self.model = IIH_module() 11 | 12 | def forward(self, x, rev=False): 13 | 14 | if not rev: 15 | out = self.model(x) 16 | 17 | else: 18 | out = self.model(x, rev=True) 19 | 20 | return out 21 | 22 | 23 | def init_model(mod): 24 | for key, param in mod.named_parameters(): 25 | split = key.split('.') 26 | if param.requires_grad: 27 | param.data = 0.01 * torch.randn(param.data.shape).cuda() 28 | if split[-2] == 'conv5': 29 | param.data.fill_(0.) 30 | -------------------------------------------------------------------------------- /codes/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import models.modules.discriminator_vgg_arch as SRGAN_arch 4 | from models.modules.Inv_arch import * 5 | from models.modules.Subnet_constructor import subnet 6 | import math 7 | logger = logging.getLogger('base') 8 | 9 | 10 | #################### 11 | # define network 12 | #################### 13 | def define_G(opt): 14 | opt_net = opt['network_G'] 15 | which_model = opt_net['which_model_G'] 16 | subnet_type = which_model['subnet_type'] 17 | if opt_net['init']: 18 | init = opt_net['init'] 19 | else: 20 | init = 'xavier' 21 | 22 | down_num = len(opt_net['block_num']) 23 | # down_num = int(math.log(opt_net['scale_W'], 2)) 24 | 25 | use_ConvDownsampling = False 26 | 27 | if which_model['use_ConvDownsampling']: 28 | use_ConvDownsampling = True 29 | 30 | netG = InvRescaleNet(opt_net['in_nc'], opt_net['out_nc'], subnet(subnet_type, init), opt_net['block_num'], down_num, use_ConvDownsampling=use_ConvDownsampling, down_scale_W=opt_net['scale_W'], down_scale_H=opt_net['scale_H']) 31 | 32 | return netG 33 | -------------------------------------------------------------------------------- /codes/models/rrdb_denselayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import models.modules.module_util as mutil 4 | 5 | 6 | # Dense connection 7 | class ResidualDenseBlock_out(nn.Module): 8 | def __init__(self, input, output, bias=True): 9 | super(ResidualDenseBlock_out, self).__init__() 10 | self.conv1 = nn.Conv2d(input, 32, 3, 1, 1, bias=bias) 11 | self.conv2 = nn.Conv2d(input + 32, 32, 3, 1, 1, bias=bias) 12 | self.conv3 = nn.Conv2d(input + 2 * 32, 32, 3, 1, 1, bias=bias) 13 | self.conv4 = nn.Conv2d(input + 3 * 32, 32, 3, 1, 1, bias=bias) 14 | self.conv5 = nn.Conv2d(input + 4 * 32, output, 3, 1, 1, bias=bias) 15 | self.lrelu = nn.LeakyReLU(inplace=True) 16 | # initialization 17 | mutil.initialize_weights([self.conv5], 0.) 18 | 19 | def forward(self, x): 20 | x1 = self.lrelu(self.conv1(x)) 21 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 22 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 23 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 24 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 25 | return x5 26 | 27 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | -------------------------------------------------------------------------------- /codes/models/invblock.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import torch 3 | import torch.nn as nn 4 | from models.rrdb_denselayer import ResidualDenseBlock_out 5 | 6 | 7 | class INV_block(nn.Module): 8 | def __init__(self, subnet_constructor=ResidualDenseBlock_out, clamp=2.0, harr=True, in_1=3, in_2=3): 9 | super().__init__() 10 | if harr: 11 | self.split_len1 = in_1 * 4 12 | self.split_len2 = in_2 * 4 13 | self.clamp = clamp 14 | # ρ 15 | self.r = subnet_constructor(self.split_len1, self.split_len2) 16 | # η 17 | self.y = subnet_constructor(self.split_len1, self.split_len2) 18 | # φ 19 | self.f = subnet_constructor(self.split_len2, self.split_len1) 20 | 21 | 22 | def e(self, s): 23 | return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5)) 24 | 25 | def forward(self, x, rev=False): 26 | x1, x2 = (x.narrow(1, 0, self.split_len1), 27 | x.narrow(1, self.split_len1, self.split_len2)) 28 | 29 | if not rev: 30 | 31 | t2 = self.f(x2) 32 | y1 = x1 + t2 33 | s1, t1 = self.r(y1), self.y(y1) 34 | y2 = self.e(s1) * x2 + t1 35 | 36 | else: 37 | 38 | s1, t1 = self.r(x1), self.y(x1) 39 | y2 = (x2 - t1) / self.e(s1) 40 | t2 = self.f(y2) 41 | y1 = (x1 - t2) 42 | 43 | return torch.cat((y1, y2), 1) 44 | 45 | -------------------------------------------------------------------------------- /codes/options/test/test_InvMIHNet_16images.yml: -------------------------------------------------------------------------------- 1 | name: Hiding_16_images_test 2 | suffix: ~ # add suffix to saved images 3 | model: InvMIHNet 4 | distortion: sr 5 | scale_W: 4 6 | scale_H: 4 7 | crop_border: 0 # crop border when evaluation. If None(~), crop the scale pixels 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test_1: 12 | name: val_DIV2K 13 | mode: Steg 14 | batch_size: 17 15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024 16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 17 | # test_2: 18 | # name: val_COCO 19 | # mode: LQGT 20 | # dataroot_GT: # path to COCO testing dataset 21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 22 | # test_3: 23 | # name: val_ImageNet 24 | # mode: LQGT 25 | # dataroot_GT: # path to ImageNet testing dataset 26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 27 | 28 | 29 | #### network 30 | network_G: 31 | which_model_G: 32 | subnet_type: DBNet 33 | in_nc: 3 34 | out_nc: 3 35 | block_num: [8, 8] 36 | scale_W: 4 37 | scale_H: 4 38 | init: xavier 39 | 40 | 41 | #### path 42 | path: 43 | pretrain_model_G: ../experiments/pretrained_models/IIR_16images.pth 44 | pretrain_model_H: ../experiments/pretrained_models/IIH_16images.pth 45 | -------------------------------------------------------------------------------- /codes/options/test/test_InvMIHNet_8images.yml: -------------------------------------------------------------------------------- 1 | name: Hiding_8_images_test 2 | suffix: ~ # add suffix to saved images 3 | model: InvMIHNet 4 | distortion: sr 5 | scale_H: 4 6 | scale_W: 2 7 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test_1: 12 | name: val_DIV2K 13 | mode: Steg 14 | batch_size: 9 15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024 16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 17 | # test_2: 18 | # name: val_COCO 19 | # mode: LQGT 20 | # dataroot_GT: # path to COCO testing dataset 21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 22 | # test_3: 23 | # name: val_ImageNet 24 | # mode: LQGT 25 | # dataroot_GT: # path to ImageNet testing dataset 26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 27 | 28 | 29 | 30 | #### network 31 | network_G: 32 | which_model_G: 33 | subnet_type: DBNet 34 | in_nc: 3 35 | out_nc: 3 36 | block_num: [8, 8] 37 | scale_H: 4 38 | scale_W: 2 39 | init: xavier 40 | 41 | 42 | #### path 43 | path: 44 | pretrain_model_G: ../experiments/pretrained_models/IIR_8images.pth 45 | pretrain_model_H: ../experiments/pretrained_models/IIH_8images.pth 46 | -------------------------------------------------------------------------------- /codes/options/test/test_InvMIHNet_9images.yml: -------------------------------------------------------------------------------- 1 | name: Hiding_9_images_test 2 | suffix: ~ # add suffix to saved images 3 | model: InvMIHNet 4 | distortion: sr 5 | scale_W: 3 6 | scale_H: 3 7 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test_1: 12 | name: val_DIV2K 13 | mode: Steg 14 | batch_size: 10 15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024 16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 17 | # test_2: 18 | # name: val_COCO 19 | # mode: LQGT 20 | # dataroot_GT: # path to COCO testing dataset 21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 22 | # test_3: 23 | # name: val_ImageNet 24 | # mode: LQGT 25 | # dataroot_GT: # path to ImageNet testing dataset 26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 27 | 28 | 29 | #### network 30 | network_G: 31 | which_model_G: 32 | subnet_type: DBNet 33 | use_ConvDownsampling: True 34 | in_nc: 3 35 | out_nc: 3 36 | block_num: [12] 37 | scale_W: 3 38 | scale_H: 3 39 | init: xavier 40 | 41 | 42 | #### path 43 | path: 44 | pretrain_model_G: ../experiments/pretrained_models/IIR_9images.pth 45 | pretrain_model_H: ../experiments/pretrained_models/IIH_9images.pth -------------------------------------------------------------------------------- /codes/options/test/test_InvMIHNet_6images.yml: -------------------------------------------------------------------------------- 1 | name: Hiding_6_images_test 2 | suffix: ~ # add suffix to saved images 3 | model: InvMIHNet 4 | distortion: sr 5 | scale_W: 2 6 | scale_H: 3 7 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test_1: 12 | name: val_DIV2K 13 | mode: Steg 14 | batch_size: 7 15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024 16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 17 | # test_2: 18 | # name: val_COCO 19 | # mode: LQGT 20 | # dataroot_GT: # path to COCO testing dataset 21 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 22 | # test_3: 23 | # name: val_ImageNet 24 | # mode: LQGT 25 | # dataroot_GT: # path to ImageNet testing dataset 26 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 27 | 28 | 29 | 30 | 31 | #### network 32 | network_G: 33 | which_model_G: 34 | subnet_type: DBNet 35 | use_ConvDownsampling: True 36 | in_nc: 3 37 | out_nc: 3 38 | block_num: [12] 39 | scale_W: 2 40 | scale_H: 3 41 | init: xavier 42 | 43 | 44 | #### path 45 | path: 46 | pretrain_model_G: ../experiments/pretrained_models/IIR_6images.pth #../experiments/pretrained_models_IRN_all/100_G.pth 47 | pretrain_model_H: ../experiments/pretrained_models/IIH_6images.pth 48 | -------------------------------------------------------------------------------- /codes/options/test/test_InvMIHNet_4images.yml: -------------------------------------------------------------------------------- 1 | name: Hiding_4_images_test 2 | suffix: ~ # add suffix to saved images 3 | model: InvMIHNet 4 | distortion: sr 5 | scale_W: 2 6 | scale_H: 2 7 | crop_border: 0 # crop border when evaluation. If None(~), crop the scale pixels 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test_1: 12 | name: val_DIV2K 13 | mode: Steg 14 | batch_size: 5 15 | dataroot_GT: D:\dataset\DIV2K\DIV2K_valid_HR_1024 # path to DIV2K testing dataset 16 | dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 17 | # test_2: 18 | # name: val_COCO 19 | # mode: LQGT 20 | # batch_size: 5 21 | # dataroot_GT: # path to COCO testing dataset 22 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 23 | # test_3: 24 | # name: val_ImageNet 25 | # mode: LQGT 26 | # batch_size: 5 27 | # dataroot_GT: # path to ImageNet testing dataset 28 | # dataroot_LQ: ~ # path to test reference LR images, not necessary, if not provided, LR images will be generated in dataloader 29 | 30 | 31 | #### network 32 | network_G: 33 | which_model_G: 34 | subnet_type: DBNet 35 | in_nc: 3 36 | out_nc: 3 37 | block_num: [8] 38 | scale_W: 2 39 | scale_H: 2 40 | init: xavier 41 | 42 | network_H: 43 | lamda_reconstruction: 1 44 | lamda_guide: 10 45 | lamda_low_frequency: 1 46 | 47 | #### path 48 | path: 49 | pretrain_model_G: ../experiments/pretrained_models/IIR_4images.pth 50 | pretrain_model_H: ../experiments/pretrained_models/IIH_4images.pth -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 | pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=dataset_opt['batch_size'], shuffle=False, num_workers=0, 25 | pin_memory=True, drop_last=True) 26 | 27 | 28 | def create_dataset(dataset_opt): 29 | mode = dataset_opt['mode'] 30 | if mode == 'Steg': 31 | from data.Steg_dataset import StegDataset as S 32 | else: 33 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 34 | dataset = S(dataset_opt) 35 | 36 | logger = logging.getLogger('base') 37 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 38 | dataset_opt['name'])) 39 | return dataset 40 | -------------------------------------------------------------------------------- /codes/models/modules/Subnet_constructor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import models.modules.module_util as mutil 5 | 6 | class DenseBlock(nn.Module): 7 | def __init__(self, channel_in, channel_out, init='xavier', gc=32, bias=True): 8 | super(DenseBlock, self).__init__() 9 | self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias) 10 | self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias) 11 | self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias) 12 | self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias) 13 | self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias) 14 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 15 | 16 | if init == 'xavier': 17 | mutil.initialize_weights_xavier([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) 18 | else: 19 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) 20 | mutil.initialize_weights(self.conv5, 0) 21 | 22 | def forward(self, x): 23 | x1 = self.lrelu(self.conv1(x)) 24 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 25 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 26 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 27 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 28 | return x5 29 | 30 | 31 | def subnet(net_structure, init='xavier'): 32 | def constructor(channel_in, channel_out): 33 | if net_structure == 'DBNet': 34 | if init == 'xavier': 35 | return DenseBlock(channel_in, channel_out, init) 36 | else: 37 | return DenseBlock(channel_in, channel_out) 38 | else: 39 | return None 40 | 41 | return constructor -------------------------------------------------------------------------------- /codes/options/train/train_InvMIHNet_8images.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: InvMIHNet_8images_train 5 | use_tb_logger: true 6 | model: InvMIHNet 7 | distortion: sr 8 | scale_H: 4 9 | scale_W: 2 10 | gpu_ids: [0] 11 | 12 | 13 | #### datasets 14 | 15 | datasets: 16 | train: 17 | name: DIV2K 18 | mode: Steg 19 | dataroot_GT: # path to training images 20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader 21 | 22 | use_shuffle: true 23 | n_workers: 0 # per GPU 24 | batch_size: 9 25 | GT_size: 144 26 | color: RGB 27 | 28 | val: 29 | name: val_DIV2K 30 | mode: Steg 31 | batch_size: 9 32 | dataroot_GT: # path to validation images 33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader 34 | 35 | 36 | #### network structures 37 | 38 | network_G: 39 | which_model_G: 40 | subnet_type: DBNet 41 | in_nc: 3 42 | out_nc: 3 43 | block_num: [8, 8] 44 | scale_H: 4 45 | scale_W: 2 46 | init: xavier 47 | 48 | 49 | #### path 50 | 51 | path: 52 | pretrain_model_G: ~ 53 | pretrain_model_H: ~ 54 | strict_load: true 55 | resume_state: ~ 56 | 57 | 58 | #### training settings: learning rate scheme, loss 59 | 60 | train: 61 | lr_G: !!float 2e-4 62 | lr_H: !!float 2e-4 63 | beta1: 0.9 64 | beta2: 0.999 65 | beta1_H: 0.5 66 | beta2_H: 0.999 67 | niter: 500000 68 | warmup_iter: -1 # no warm up 69 | 70 | lr_scheme: MultiStepLR 71 | lr_steps: [20000, 40000, 80000, 100000] 72 | lr_gamma: 0.5 73 | 74 | pixel_criterion_forw: l2 75 | pixel_criterion_back: l1 76 | 77 | manual_seed: 10 78 | 79 | val_freq: 1000 80 | lambda_fit_forw: 4 81 | lambda_rec_back: 1 82 | lambda_ce_forw: 1 83 | lamda_reconstruction: 5 84 | lamda_guide: 1 85 | lamda_low_frequency: 1 86 | 87 | weight_decay_G: !!float 1e-5 88 | weight_decay_H: !!float 1e-5 89 | gradient_clipping: 10 90 | 91 | 92 | weight_step: 1000 93 | 94 | 95 | #### logger 96 | 97 | logger: 98 | save_checkpoint_freq: 1000 -------------------------------------------------------------------------------- /codes/options/train/train_InvMIHNet_4images.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: InvMIHNet_4images_train 5 | use_tb_logger: true 6 | model: InvMIHNet 7 | distortion: sr 8 | scale_W: 2 9 | scale_H: 2 10 | gpu_ids: [0] 11 | 12 | 13 | #### datasets 14 | 15 | datasets: 16 | train: 17 | name: DIV2K 18 | mode: Steg 19 | dataroot_GT: # path to training images 20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader 21 | 22 | use_shuffle: true 23 | n_workers: 0 # per GPU 24 | batch_size: 5 25 | GT_size: 144 26 | color: RGB 27 | 28 | val: 29 | name: val_DIV2K 30 | mode: Steg 31 | batch_size: 5 32 | dataroot_GT: # path to validation images 33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader 34 | 35 | 36 | #### network structures 37 | 38 | network_G: 39 | which_model_G: 40 | subnet_type: DBNet 41 | in_nc: 3 42 | out_nc: 3 43 | block_num: [8] 44 | scale_W: 2 45 | scale_H: 2 46 | init: xavier 47 | 48 | 49 | #### path 50 | 51 | path: 52 | pretrain_model_G: ~ 53 | pretrain_model_H: ~ 54 | strict_load: true 55 | resume_state: ~ 56 | 57 | 58 | #### training settings: learning rate scheme, loss 59 | train: 60 | lr_G: !!float 2e-4 61 | lr_H: !!float 2e-4 62 | beta1: 0.9 63 | beta2: 0.999 64 | beta1_H: 0.5 65 | beta2_H: 0.999 66 | niter: 500000 67 | warmup_iter: -1 # no warm up 68 | 69 | 70 | lr_scheme: MultiStepLR 71 | lr_steps: [20000, 40000, 80000, 100000] 72 | lr_gamma: 0.5 73 | 74 | pixel_criterion_forw: l2 75 | pixel_criterion_back: l1 76 | 77 | manual_seed: 10 78 | 79 | val_freq: 1000 80 | 81 | lambda_fit_forw: 4 82 | lambda_rec_back: 1 83 | lambda_ce_forw: 1 84 | lamda_reconstruction: 5 85 | lamda_guide: 1 86 | lamda_low_frequency: 1 87 | 88 | weight_decay_G: !!float 1e-5 89 | weight_decay_H: !!float 1e-5 90 | gradient_clipping: 10 91 | 92 | 93 | weight_step: 1000 94 | 95 | 96 | #### logger 97 | 98 | logger: 99 | save_checkpoint_freq: 1000 -------------------------------------------------------------------------------- /codes/options/train/train_InvMIHNet_9images.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: InvMIHNet_9images_train 5 | use_tb_logger: true 6 | model: InvMIHNet 7 | distortion: sr 8 | scale_W: 3 9 | scale_H: 3 10 | gpu_ids: [0] 11 | 12 | 13 | #### datasets 14 | 15 | datasets: 16 | train: 17 | name: DIV2K 18 | mode: Steg 19 | dataroot_GT: # path to training images 20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader 21 | 22 | use_shuffle: true 23 | n_workers: 0 # per GPU 24 | batch_size: 10 25 | GT_size: 144 26 | color: RGB 27 | 28 | val: 29 | name: val_DIV2K 30 | mode: Steg 31 | batch_size: 10 32 | dataroot_GT: # path to validation images 33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader 34 | 35 | 36 | #### network structures 37 | 38 | network_G: 39 | which_model_G: 40 | subnet_type: DBNet 41 | use_ConvDownsampling: True 42 | in_nc: 3 43 | out_nc: 3 44 | block_num: [12] 45 | scale_W: 3 46 | scale_H: 3 47 | init: xavier 48 | 49 | 50 | #### path 51 | 52 | path: 53 | pretrain_model_G: ~ 54 | pretrain_model_H: ~ 55 | strict_load: true 56 | resume_state: ~ 57 | 58 | 59 | #### training settings: learning rate scheme, loss 60 | 61 | train: 62 | lr_G: !!float 2e-4 63 | lr_H: !!float 2e-4 64 | beta1: 0.9 65 | beta2: 0.999 66 | beta1_H: 0.5 67 | beta2_H: 0.999 68 | niter: 500000 69 | warmup_iter: -1 # no warm up 70 | 71 | lr_scheme: MultiStepLR 72 | lr_steps: [20000, 40000, 80000, 100000] 73 | lr_gamma: 0.5 74 | 75 | pixel_criterion_forw: l2 76 | pixel_criterion_back: l1 77 | 78 | manual_seed: 10 79 | 80 | val_freq: 1000 81 | 82 | lambda_fit_forw: 4 83 | lambda_rec_back: 1 84 | lambda_ce_forw: 1 85 | lamda_reconstruction: 5 86 | lamda_guide: 1 87 | lamda_low_frequency: 1 88 | 89 | weight_decay_G: !!float 1e-5 90 | weight_decay_H: !!float 1e-5 91 | gradient_clipping: 10 92 | 93 | 94 | weight_step: 1000 95 | 96 | 97 | #### logger 98 | 99 | logger: 100 | save_checkpoint_freq: 1000 -------------------------------------------------------------------------------- /codes/options/train/train_InvMIHNet_6images.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: InvMIHNet_6images_train 5 | use_tb_logger: true 6 | model: InvMIHNet 7 | distortion: sr 8 | scale_W: 2 9 | scale_H: 3 10 | gpu_ids: [0] 11 | 12 | 13 | #### datasets 14 | 15 | datasets: 16 | train: 17 | name: DIV2K 18 | mode: Steg 19 | dataroot_GT: # path to training images 20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader 21 | 22 | use_shuffle: true 23 | n_workers: 0 # per GPU 24 | batch_size: 7 25 | GT_size: 144 26 | color: RGB 27 | 28 | val: 29 | name: val_DIV2K 30 | mode: Steg 31 | batch_size: 7 32 | dataroot_GT: # path to validation images 33 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader 34 | 35 | 36 | #### network structures 37 | 38 | network_G: 39 | which_model_G: 40 | subnet_type: DBNet 41 | use_ConvDownsampling: True 42 | in_nc: 3 43 | out_nc: 3 44 | block_num: [12] 45 | scale_W: 2 46 | scale_H: 3 47 | init: xavier 48 | 49 | 50 | #### path 51 | 52 | path: 53 | pretrain_model_G: ~ 54 | pretrain_model_H: ~ 55 | strict_load: true 56 | resume_state: ~ 57 | 58 | 59 | #### training settings: learning rate scheme, loss 60 | train: 61 | lr_G: !!float 2e-4 62 | lr_H: !!float 2e-4 63 | beta1: 0.9 64 | beta2: 0.999 65 | beta1_H: 0.5 66 | beta2_H: 0.999 67 | niter: 500000 68 | warmup_iter: -1 # no warm up 69 | 70 | 71 | lr_scheme: MultiStepLR 72 | lr_steps: [20000, 40000, 80000, 100000] 73 | lr_gamma: 0.5 74 | 75 | pixel_criterion_forw: l2 76 | pixel_criterion_back: l1 77 | 78 | manual_seed: 10 79 | 80 | val_freq: 1000 81 | 82 | lambda_fit_forw: 4 83 | lambda_rec_back: 1 84 | lambda_ce_forw: 1 85 | lamda_reconstruction: 5 86 | lamda_guide: 1 87 | lamda_low_frequency: 1 88 | 89 | weight_decay_G: !!float 1e-5 90 | weight_decay_H: !!float 1e-5 91 | gradient_clipping: 10 92 | 93 | 94 | weight_step: 1000 95 | 96 | 97 | #### logger 98 | 99 | logger: 100 | save_checkpoint_freq: 1000 101 | -------------------------------------------------------------------------------- /codes/options/train/train_InvMIHNet_16images.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: InvMIHNet_16images_train 5 | use_tb_logger: true 6 | model: InvMIHNet 7 | distortion: sr 8 | scale_W: 4 9 | scale_H: 4 10 | gpu_ids: [0] 11 | 12 | 13 | #### datasets 14 | 15 | datasets: 16 | train: 17 | name: DIV2K 18 | mode: Steg 19 | dataroot_GT: # path to training images 20 | dataroot_LQ: ~ # path to training reference LR images, not necessary, if not provided, LR images will be generated in dataloader 21 | 22 | use_shuffle: true 23 | n_workers: 6 # per GPU 24 | batch_size: 17 25 | GT_size: 144 26 | use_flip: true 27 | use_rot: true 28 | color: RGB 29 | 30 | val: 31 | name: val_DIV2K 32 | mode: Steg 33 | batch_size: 17 34 | dataroot_GT: # path to validation images 35 | dataroot_LQ: ~ # path to validation reference LR images, not necessary, if not provided, LR images will be generated in dataloader 36 | 37 | 38 | #### network structures 39 | 40 | network_G: 41 | which_model_G: 42 | subnet_type: DBNet 43 | in_nc: 3 44 | out_nc: 3 45 | block_num: [8, 8] 46 | scale_W: 4 47 | scale_H: 4 48 | init: xavier 49 | 50 | 51 | #### path 52 | 53 | path: 54 | pretrain_model_G: ~ 55 | pretrain_model_H: ~ 56 | strict_load: true 57 | resume_state: ~ 58 | 59 | 60 | #### training settings: learning rate scheme, loss 61 | 62 | train: 63 | lr_G: !!float 2e-4 64 | lr_H: !!float 2e-4 65 | beta1: 0.9 66 | beta2: 0.999 67 | beta1_H: 0.5 68 | beta2_H: 0.999 69 | niter: 500000 70 | warmup_iter: -1 # no warm up 71 | 72 | 73 | lr_scheme: MultiStepLR 74 | lr_steps: [20000, 40000, 80000, 100000] 75 | lr_gamma: 0.5 76 | 77 | pixel_criterion_forw: l2 78 | pixel_criterion_back: l1 79 | 80 | manual_seed: 10 81 | 82 | val_freq: 1000 83 | 84 | lambda_fit_forw: 4 85 | lambda_rec_back: 1 86 | lambda_ce_forw: 1 87 | lamda_reconstruction: 5 88 | lamda_guide: 1 89 | lamda_low_frequency: 1 90 | 91 | weight_decay_G: !!float 1e-5 92 | weight_decay_H: !!float 1e-5 93 | gradient_clipping: 10 94 | 95 | 96 | weight_step: 1000 97 | 98 | 99 | #### logger 100 | 101 | logger: 102 | save_checkpoint_freq: 1000 -------------------------------------------------------------------------------- /codes/models/IIH_module.py: -------------------------------------------------------------------------------- 1 | from models.model import * 2 | from models.invblock import INV_block 3 | 4 | 5 | class IIH_module(nn.Module): 6 | 7 | def __init__(self): 8 | super(IIH_module, self).__init__() 9 | 10 | self.inv1 = INV_block() 11 | self.inv2 = INV_block() 12 | self.inv3 = INV_block() 13 | self.inv4 = INV_block() 14 | self.inv5 = INV_block() 15 | self.inv6 = INV_block() 16 | self.inv7 = INV_block() 17 | self.inv8 = INV_block() 18 | 19 | self.inv9 = INV_block() 20 | self.inv10 = INV_block() 21 | self.inv11 = INV_block() 22 | self.inv12 = INV_block() 23 | self.inv13 = INV_block() 24 | self.inv14 = INV_block() 25 | self.inv15 = INV_block() 26 | self.inv16 = INV_block() 27 | 28 | def forward(self, x, rev=False): 29 | 30 | if not rev: 31 | out = self.inv1(x) 32 | out = self.inv2(out) 33 | out = self.inv3(out) 34 | out = self.inv4(out) 35 | out = self.inv5(out) 36 | out = self.inv6(out) 37 | out = self.inv7(out) 38 | out = self.inv8(out) 39 | 40 | out = self.inv9(out) 41 | out = self.inv10(out) 42 | out = self.inv11(out) 43 | out = self.inv12(out) 44 | out = self.inv13(out) 45 | out = self.inv14(out) 46 | out = self.inv15(out) 47 | out = self.inv16(out) 48 | 49 | else: 50 | out = self.inv16(x, rev=True) 51 | out = self.inv15(out, rev=True) 52 | out = self.inv14(out, rev=True) 53 | out = self.inv13(out, rev=True) 54 | out = self.inv12(out, rev=True) 55 | out = self.inv11(out, rev=True) 56 | out = self.inv10(out, rev=True) 57 | out = self.inv9(out, rev=True) 58 | 59 | out = self.inv8(out, rev=True) 60 | out = self.inv7(out, rev=True) 61 | out = self.inv6(out, rev=True) 62 | out = self.inv5(out, rev=True) 63 | out = self.inv4(out, rev=True) 64 | out = self.inv3(out, rev=True) 65 | out = self.inv2(out, rev=True) 66 | out = self.inv1(out, rev=True) 67 | 68 | return out 69 | 70 | 71 | -------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /codes/models/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class ReconstructionLoss(nn.Module): 6 | def __init__(self, losstype='l2', eps=1e-6): 7 | super(ReconstructionLoss, self).__init__() 8 | self.losstype = losstype 9 | self.eps = eps 10 | 11 | def forward(self, x, target): 12 | if self.losstype == 'l2': 13 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3))) 14 | elif self.losstype == 'l1': 15 | diff = x - target 16 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) 17 | else: 18 | print("reconstruction loss type error!") 19 | return 0 20 | 21 | 22 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 23 | class GANLoss(nn.Module): 24 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 25 | super(GANLoss, self).__init__() 26 | self.gan_type = gan_type.lower() 27 | self.real_label_val = real_label_val 28 | self.fake_label_val = fake_label_val 29 | 30 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 31 | self.loss = nn.BCEWithLogitsLoss() 32 | elif self.gan_type == 'lsgan': 33 | self.loss = nn.MSELoss() 34 | elif self.gan_type == 'wgan-gp': 35 | 36 | def wgan_loss(input, target): 37 | # target is boolean 38 | return -1 * input.mean() if target else input.mean() 39 | 40 | self.loss = wgan_loss 41 | else: 42 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 43 | 44 | def get_target_label(self, input, target_is_real): 45 | if self.gan_type == 'wgan-gp': 46 | return target_is_real 47 | if target_is_real: 48 | return torch.empty_like(input).fill_(self.real_label_val) 49 | else: 50 | return torch.empty_like(input).fill_(self.fake_label_val) 51 | 52 | def forward(self, input, target_is_real): 53 | target_label = self.get_target_label(input, target_is_real) 54 | loss = self.loss(input, target_label) 55 | return loss 56 | 57 | 58 | class GradientPenaltyLoss(nn.Module): 59 | def __init__(self, device=torch.device('cpu')): 60 | super(GradientPenaltyLoss, self).__init__() 61 | self.register_buffer('grad_outputs', torch.Tensor()) 62 | self.grad_outputs = self.grad_outputs.to(device) 63 | 64 | def get_grad_outputs(self, input): 65 | if self.grad_outputs.size() != input.size(): 66 | self.grad_outputs.resize_(input.size()).fill_(1.0) 67 | return self.grad_outputs 68 | 69 | def forward(self, interp, interp_crit): 70 | grad_outputs = self.get_grad_outputs(interp_crit) 71 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 72 | grad_outputs=grad_outputs, create_graph=True, 73 | retain_graph=True, only_inputs=True)[0] 74 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 75 | grad_interp_norm = grad_interp.norm(2, dim=1) 76 | 77 | loss = ((grad_interp_norm - 1)**2).mean() 78 | return loss 79 | -------------------------------------------------------------------------------- /codes/models/bicubic.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | import sys, time 5 | 6 | # Interpolation kernel 7 | def u(s,a): 8 | if (abs(s) >=0) & (abs(s) <=1): 9 | return (a+2)*(abs(s)**3)-(a+3)*(abs(s)**2)+1 10 | elif (abs(s) > 1) & (abs(s) <= 2): 11 | return a*(abs(s)**3)-(5*a)*(abs(s)**2)+(8*a)*abs(s)-4*a 12 | return 0 13 | 14 | #Paddnig 15 | def padding(img,H,W,C): 16 | zimg = np.zeros((H+4,W+4,C)) 17 | zimg[2:H+2,2:W+2,:C] = img 18 | #Pad the first/last two col and row 19 | zimg[2:H+2,0:2,:C]=img[:,0:1,:C] 20 | zimg[H+2:H+4,2:W+2,:]=img[H-1:H,:,:] 21 | zimg[2:H+2,W+2:W+4,:]=img[:,W-1:W,:] 22 | zimg[0:2,2:W+2,:C]=img[0:1,:,:C] 23 | #Pad the missing eight points 24 | zimg[0:2,0:2,:C]=img[0,0,:C] 25 | zimg[H+2:H+4,0:2,:C]=img[H-1,0,:C] 26 | zimg[H+2:H+4,W+2:W+4,:C]=img[H-1,W-1,:C] 27 | zimg[0:2,W+2:W+4,:C]=img[0,W-1,:C] 28 | return zimg 29 | 30 | # https://github.com/yunabe/codelab/blob/master/misc/terminal_progressbar/progress.py 31 | def get_progressbar_str(progress): 32 | END = 170 33 | MAX_LEN = 30 34 | BAR_LEN = int(MAX_LEN * progress) 35 | return ('Progress:[' + '=' * BAR_LEN + 36 | ('>' if BAR_LEN < MAX_LEN else '') + 37 | ' ' * (MAX_LEN - BAR_LEN) + 38 | '] %.1f%%' % (progress * 100.)) 39 | 40 | # Bicubic operation 41 | def bicubic(img, ratio, a): 42 | #Get image size 43 | H,W,C = img.shape 44 | 45 | img = padding(img,H,W,C) 46 | #Create new image 47 | dH = math.floor(H*ratio) 48 | dW = math.floor(W*ratio) 49 | dst = np.zeros((dH, dW, 3)) 50 | 51 | h = 1/ratio 52 | 53 | print('Start bicubic interpolation') 54 | print('It will take a little while...') 55 | inc = 0 56 | for c in range(C): 57 | for j in range(dH): 58 | for i in range(dW): 59 | x, y = i * h + 2 , j * h + 2 60 | 61 | x1 = 1 + x - math.floor(x) 62 | x2 = x - math.floor(x) 63 | x3 = math.floor(x) + 1 - x 64 | x4 = math.floor(x) + 2 - x 65 | 66 | y1 = 1 + y - math.floor(y) 67 | y2 = y - math.floor(y) 68 | y3 = math.floor(y) + 1 - y 69 | y4 = math.floor(y) + 2 - y 70 | 71 | mat_l = np.matrix([[u(x1,a),u(x2,a),u(x3,a),u(x4,a)]]) 72 | mat_m = np.matrix([[img[int(y-y1),int(x-x1),c],img[int(y-y2),int(x-x1),c],img[int(y+y3),int(x-x1),c],img[int(y+y4),int(x-x1),c]], 73 | [img[int(y-y1),int(x-x2),c],img[int(y-y2),int(x-x2),c],img[int(y+y3),int(x-x2),c],img[int(y+y4),int(x-x2),c]], 74 | [img[int(y-y1),int(x+x3),c],img[int(y-y2),int(x+x3),c],img[int(y+y3),int(x+x3),c],img[int(y+y4),int(x+x3),c]], 75 | [img[int(y-y1),int(x+x4),c],img[int(y-y2),int(x+x4),c],img[int(y+y3),int(x+x4),c],img[int(y+y4),int(x+x4),c]]]) 76 | mat_r = np.matrix([[u(y1,a)],[u(y2,a)],[u(y3,a)],[u(y4,a)]]) 77 | dst[j, i, c] = np.dot(np.dot(mat_l, mat_m),mat_r) 78 | 79 | # Print progress 80 | inc = inc + 1 81 | sys.stderr.write('\r\033[K' + get_progressbar_str(inc/(C*dH*dW))) 82 | sys.stderr.flush() 83 | sys.stderr.write('\n') 84 | sys.stderr.flush() 85 | return dst 86 | 87 | -------------------------------------------------------------------------------- /codes/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import time 5 | import argparse 6 | import torchvision 7 | from collections import OrderedDict 8 | 9 | import numpy as np 10 | import options.options as option 11 | import utils.util as util 12 | from data.util import bgr2ycbcr 13 | from data import create_dataset, create_dataloader 14 | from models import create_model 15 | 16 | 17 | #### options 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') 20 | opt = option.parse(parser.parse_args().opt, is_train=False) 21 | opt = option.dict_to_nonedict(opt) 22 | 23 | util.mkdirs( 24 | (path for key, path in opt['path'].items() 25 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 26 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 27 | screen=True, tofile=True) 28 | logger = logging.getLogger('base') 29 | logger.info(option.dict2str(opt)) 30 | 31 | #### Create test dataset and dataloader 32 | test_loaders = [] 33 | for phase, dataset_opt in sorted(opt['datasets'].items()): 34 | test_set = create_dataset(dataset_opt) 35 | test_loader = create_dataloader(test_set, dataset_opt) 36 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 37 | test_loaders.append(test_loader) 38 | 39 | model = create_model(opt) 40 | count=0 41 | for test_loader in test_loaders: 42 | test_set_name = test_loader.dataset.opt['name'] 43 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 44 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 45 | test_start_time = time.time() 46 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 47 | if not os.path.exists(dataset_dir): 48 | os.makedirs(dataset_dir) 49 | 50 | for data in test_loader: 51 | model.feed_data(data) 52 | 53 | model.test() 54 | 55 | visuals = model.get_current_visuals() 56 | 57 | cover = visuals['cover'] 58 | secret = visuals['secret'] 59 | secret_recover = visuals['secret_recover'] 60 | steg = visuals['steg'] 61 | 62 | cover_path = dataset_dir + "\\" + "cover" 63 | secret_path = dataset_dir + "\\" + "secret" 64 | secret_recover_path = dataset_dir + "\\" + "secret_recover" 65 | steg_path = dataset_dir + "\\" + "steg" 66 | 67 | 68 | if not os.path.exists(cover_path): 69 | os.makedirs(cover_path) 70 | if not os.path.exists(secret_path): 71 | os.makedirs(secret_path) 72 | if not os.path.exists(secret_recover_path): 73 | os.makedirs(secret_recover_path) 74 | if not os.path.exists(steg_path): 75 | os.makedirs(steg_path) 76 | 77 | 78 | save_cover_path = osp.join(cover_path, str(count)+'_clean_cover.png') 79 | torchvision.utils.save_image(cover, save_cover_path) 80 | 81 | save_steg_path = osp.join(steg_path, str(count) + '_steg.png') 82 | torchvision.utils.save_image(steg, save_steg_path) 83 | 84 | for i in range(int(secret.shape[0])): 85 | save_secret_path = osp.join(secret_path, str(count) + '_' + str(i) + '_clean_secret.png') 86 | save_secret_recover_path = osp.join(secret_recover_path, str(count) + '_' + str(i) + '_secret_recover.png') 87 | torchvision.utils.save_image(secret[i:i+1,:,:,:], save_secret_path) 88 | torchvision.utils.save_image(secret_recover[i:i + 1, :, :, :], save_secret_recover_path) 89 | 90 | count = count + 1 91 | 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Invertible Mosaic Image Hiding Network for Very Large Capacity Image Steganography 2 | [Zihan Chen](https://brittany-chen.github.io/)(chenzihan21@nudt.edu.cn), Tianrui Liu#, [Jun-Jie Huang](https://jjhuangcs.github.io/), Wentao Zhao#, Xing Bi and Meng Wang (#corressponding author) 3 | 4 | Pytorch implementation for "Invertible Mosaic Image Hiding Network for Very Large Capacity Image Steganography" (ICASSP'2024). 5 | 6 | 7 | 8 | The existing image steganography methods either sequentially conceal secret images or conceal a concatenation of multiple images. In such ways, the interference of information among multiple images will become increasingly severe when the number of secret images becomes larger, thus restrict the development of very large capacity image steganography. 9 | In this paper, we propose an Invertible Mosaic Image Hiding Network (InvMIHNet) which realizes very large capacity image steganography with high quality by concealing a single mosaic secret image. InvMIHNet consists of an Invertible Image Rescaling (IIR) module and an Invertible Image Hiding (IIH) module. 10 | The IIR module works for downscaling the single mosaic secret image form by spatially splicing the multiple secret images, and the IIH module then conceal this mosaic image under the cover image. 11 | The proposed InvMIHNet successfully conceal and reveal up to 16 secret images with a small number of parameters and memory consumption. 12 | Extensive experiments on ImageNet-1K, COCO and DIV2K show InvMIHNet outperforms state-of-the-art methods in terms of both the imperceptibility of stego image, recover accuracy of secret image and security against steganlysis methods. 13 | 14 |
15 | 16 |
17 | 18 | ## Requisites 19 | - Python >= 3.7 20 | - PyTorch >= 1.0 21 | - NVIDIA GPU + CUDA CuDNN 22 | 23 | ## Dataset Preparation 24 | - DIV2K 25 | - COCO 26 | - ImageNet 27 | 28 | 29 | ## Get Started 30 | ### Pretrained models 31 | Download and unzip [pretrained models](https://drive.google.com/file/d/17GRiwaJN8yqmLtiAO-bJcgpG8aWHysz3/view?usp=drive_link), and then copy their path to ```experiments/pretrained_models```. 32 | 33 | ### Training for image steganography 34 | First set a config file in options/train/, then run as following: 35 | 36 | python train.py -opt options/train/train_InvMIHNet_4images.yml 37 | 38 | ### Testing for image steganography 39 | First set a config file in options/test/, then run as following: 40 | 41 | python test.py -opt options/test/test_InvMIHNet_4images.yml 42 | 43 | You can choose to conceal and reveal **4, 6, 8, 9, 16 images**. 44 | 45 | ## Description of the files in this repository 46 | 47 | 1) [`data/`](./data): A data loader to provide data for training, validation and testing. 48 | 2) [`models/`](./models): Construct models for training and testing. 49 | 3) [`options/`](./options): Configure the options for data loader, network structure, model, training strategies and etc. 50 | 4) [`experiments/`](./experiments): Save the parameters of InvMHINet. 51 | 52 | 53 | ## Citation 54 | 55 | If you find this code and data useful, please consider citing the original work by authors: 56 | 57 | ``` 58 | @inproceedings{Chen2024InvMIHNet, 59 | title={Invertible Mosaic Image Hiding Network for Very Large Capacity Image Steganography}, 60 | author={Zihan Chen, Tianrui Liu, Jun-Jie Huang, Wentao Zhao, Xing Bi and Meng Wang}, 61 | booktitle={IEEE International Conference on Acoustics, Speech, and Signal Processing}, 62 | volume={}, 63 | number={}, 64 | pages={}, 65 | year={2024} 66 | } 67 | ``` 68 | 69 | ## Acknowledgement 70 | The code is based on [IRN](https://github.com/pkuxmq/Invertible-Image-Rescaling), with reference of [HiNet](https://github.com/TomTomTommi/HiNet). 71 | 72 | ## Contact 73 | If you have any questions, please contact . 74 | -------------------------------------------------------------------------------- /codes/models/modules/discriminator_vgg_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class Discriminator_VGG_128(nn.Module): 7 | def __init__(self, in_nc, nf): 8 | super(Discriminator_VGG_128, self).__init__() 9 | # [64, 128, 128] 10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 13 | # [64, 64, 64] 14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 18 | # [128, 32, 32] 19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 21 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 23 | # [256, 16, 16] 24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 28 | # [512, 8, 8] 29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 33 | 34 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 35 | self.linear2 = nn.Linear(100, 1) 36 | 37 | # activation function 38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 39 | 40 | def forward(self, x): 41 | fea = self.lrelu(self.conv0_0(x)) 42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 43 | 44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 46 | 47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 49 | 50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 52 | 53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 55 | 56 | fea = fea.view(fea.size(0), -1) 57 | fea = self.lrelu(self.linear1(fea)) 58 | out = self.linear2(fea) 59 | return out 60 | 61 | 62 | class VGGFeatureExtractor(nn.Module): 63 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, 64 | device=torch.device('cpu')): 65 | super(VGGFeatureExtractor, self).__init__() 66 | self.use_input_norm = use_input_norm 67 | if use_bn: 68 | model = torchvision.models.vgg19_bn(pretrained=True) 69 | else: 70 | model = torchvision.models.vgg19(pretrained=True) 71 | if self.use_input_norm: 72 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 73 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 75 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 76 | self.register_buffer('mean', mean) 77 | self.register_buffer('std', std) 78 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 79 | # No need to BP to variable 80 | for k, v in self.features.named_parameters(): 81 | v.requires_grad = False 82 | 83 | def forward(self, x): 84 | # Assume input range is [0, 1] 85 | if self.use_input_norm: 86 | x = (x - self.mean) / self.std 87 | output = self.features(x) 88 | return output 89 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | ''' set learning rate for warmup, 39 | lr_groups_l: list for lr_groups. each for a optimizer''' 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | # get the initial lr, which is set by the scheduler 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | #### set up warm up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | # return self.schedulers[0].get_lr()[0] 67 | return self.optimizers[0].param_groups[0]['lr'] 68 | 69 | def get_network_description(self, network): 70 | '''Get the string and total parameters of the network''' 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | s = str(network) 74 | n = sum(map(lambda x: x.numel(), network.parameters())) 75 | return s, n 76 | 77 | def save_network(self, network, network_label, iter_label): 78 | save_filename = '{}_{}.pth'.format(network_label, iter_label) 79 | save_path = os.path.join(self.opt['path']['models'], save_filename) 80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 81 | network = network.module 82 | state_dict = network.state_dict() 83 | for key, param in state_dict.items(): 84 | state_dict[key] = param.cpu() 85 | torch.save(state_dict, save_path) 86 | 87 | def load_network(self, load_path, network, strict=True): 88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 89 | network = network.module 90 | load_net = torch.load(load_path) 91 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 92 | for k, v in load_net.items(): 93 | if k.startswith('module.'): 94 | load_net_clean[k[7:]] = v 95 | else: 96 | load_net_clean[k] = v 97 | network.load_state_dict(load_net_clean, strict=strict) 98 | 99 | def save_training_state(self, epoch, iter_step): 100 | '''Saves training state during training, which will be used for resuming''' 101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 102 | for s in self.schedulers: 103 | state['schedulers'].append(s.state_dict()) 104 | for o in self.optimizers: 105 | state['optimizers'].append(o.state_dict()) 106 | save_filename = '{}.state'.format(iter_step) 107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 108 | torch.save(state, save_path) 109 | 110 | def resume_training(self, resume_state): 111 | '''Resume the optimizers and schedulers for training''' 112 | resume_optimizers = resume_state['optimizers'] 113 | resume_schedulers = resume_state['schedulers'] 114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 116 | for i, o in enumerate(resume_optimizers): 117 | self.optimizers[i].load_state_dict(o) 118 | for i, s in enumerate(resume_schedulers): 119 | self.schedulers[i].load_state_dict(s) 120 | -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale_W = opt['scale_W'] 20 | scale_H = opt['scale_H'] 21 | 22 | # datasets 23 | for phase, dataset in opt['datasets'].items(): 24 | phase = phase.split('_')[0] 25 | dataset['phase'] = phase 26 | if opt['distortion'] == 'sr': 27 | dataset['scale_W'] = scale_W 28 | dataset['scale_H'] = scale_H 29 | is_lmdb = False 30 | if dataset.get('dataroot_GT', None) is not None: 31 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 32 | if dataset['dataroot_GT'].endswith('lmdb'): 33 | is_lmdb = True 34 | # if dataset.get('dataroot_GT_bg', None) is not None: 35 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) 36 | if dataset.get('dataroot_LQ', None) is not None: 37 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 38 | if dataset['dataroot_LQ'].endswith('lmdb'): 39 | is_lmdb = True 40 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 41 | if dataset['mode'].endswith('mc'): # for memcached 42 | dataset['data_type'] = 'mc' 43 | dataset['mode'] = dataset['mode'].replace('_mc', '') 44 | 45 | # path 46 | for key, path in opt['path'].items(): 47 | if path and key in opt['path'] and key != 'strict_load': 48 | opt['path'][key] = osp.expanduser(path) 49 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 50 | if is_train: 51 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 52 | opt['path']['experiments_root'] = experiments_root 53 | opt['path']['models'] = osp.join(experiments_root, 'models') 54 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 55 | opt['path']['log'] = experiments_root 56 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 57 | 58 | # change some options for debug mode 59 | if 'debug' in opt['name']: 60 | opt['train']['val_freq'] = 8 61 | opt['logger']['print_freq'] = 1 62 | opt['logger']['save_checkpoint_freq'] = 8 63 | else: # test 64 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 65 | opt['path']['results_root'] = results_root 66 | opt['path']['log'] = results_root 67 | 68 | # network 69 | if opt['distortion'] == 'sr': 70 | opt['network_G']['scale_W'] = scale_W 71 | opt['network_G']['scale_H'] = scale_H 72 | return opt 73 | 74 | 75 | def dict2str(opt, indent_l=1): 76 | '''dict to string for logger''' 77 | msg = '' 78 | for k, v in opt.items(): 79 | if isinstance(v, dict): 80 | msg += ' ' * (indent_l * 2) + k + ':[\n' 81 | msg += dict2str(v, indent_l + 1) 82 | msg += ' ' * (indent_l * 2) + ']\n' 83 | else: 84 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 85 | return msg 86 | 87 | 88 | class NoneDict(dict): 89 | def __missing__(self, key): 90 | return None 91 | 92 | 93 | # convert to NoneDict, which return None for missing key. 94 | def dict_to_nonedict(opt): 95 | if isinstance(opt, dict): 96 | new_opt = dict() 97 | for key, sub_opt in opt.items(): 98 | new_opt[key] = dict_to_nonedict(sub_opt) 99 | return NoneDict(**new_opt) 100 | elif isinstance(opt, list): 101 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 102 | else: 103 | return opt 104 | 105 | 106 | def check_resume(opt, resume_iter): 107 | '''Check resume states and pretrain_model paths''' 108 | logger = logging.getLogger('base') 109 | if opt['path']['resume_state']: 110 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 111 | 'pretrain_model_D', None) is not None: 112 | logger.warning('pretrain_model path will be ignored when resuming training.') 113 | 114 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 115 | '{}_G.pth'.format(resume_iter)) 116 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 117 | if 'gan' in opt['model']: 118 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 119 | '{}_D.pth'.format(resume_iter)) 120 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 121 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restart_weights = weights if weights else [1] 16 | assert len(self.restarts) == len( 17 | self.restart_weights), 'restarts and their weights do not match.' 18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | if self.last_epoch in self.restarts: 22 | if self.clear_state: 23 | self.optimizer.state = defaultdict(dict) 24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 26 | if self.last_epoch not in self.milestones: 27 | return [group['lr'] for group in self.optimizer.param_groups] 28 | return [ 29 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 30 | for group in self.optimizer.param_groups 31 | ] 32 | 33 | 34 | class CosineAnnealingLR_Restart(_LRScheduler): 35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 36 | self.T_period = T_period 37 | self.T_max = self.T_period[0] # current T period 38 | self.eta_min = eta_min 39 | self.restarts = restarts if restarts else [0] 40 | self.restart_weights = weights if weights else [1] 41 | self.last_restart = 0 42 | assert len(self.restarts) == len( 43 | self.restart_weights), 'restarts and their weights do not match.' 44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if self.last_epoch == 0: 48 | return self.base_lrs 49 | elif self.last_epoch in self.restarts: 50 | self.last_restart = self.last_epoch 51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 55 | return [ 56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 61 | (group['lr'] - self.eta_min) + self.eta_min 62 | for group in self.optimizer.param_groups] 63 | 64 | 65 | if __name__ == "__main__": 66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 67 | betas=(0.9, 0.99)) 68 | ############################## 69 | # MultiStepLR_Restart 70 | ############################## 71 | ## Original 72 | lr_steps = [200000, 400000, 600000, 800000] 73 | restarts = None 74 | restart_weights = None 75 | 76 | ## two 77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 78 | restarts = [500000] 79 | restart_weights = [1] 80 | 81 | ## four 82 | lr_steps = [ 83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 85 | ] 86 | restarts = [250000, 500000, 750000] 87 | restart_weights = [1, 1, 1] 88 | 89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 90 | clear_state=False) 91 | 92 | ############################## 93 | # Cosine Annealing Restart 94 | ############################## 95 | ## two 96 | T_period = [500000, 500000] 97 | restarts = [500000] 98 | restart_weights = [1] 99 | 100 | ## four 101 | T_period = [250000, 250000, 250000, 250000] 102 | restarts = [250000, 500000, 750000] 103 | restart_weights = [1, 1, 1] 104 | 105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 106 | weights=restart_weights) 107 | 108 | ############################## 109 | # Draw figure 110 | ############################## 111 | N_iter = 1000000 112 | lr_l = list(range(N_iter)) 113 | for i in range(N_iter): 114 | scheduler.step() 115 | current_lr = optimizer.param_groups[0]['lr'] 116 | lr_l[i] = current_lr 117 | 118 | import matplotlib as mpl 119 | from matplotlib import pyplot as plt 120 | import matplotlib.ticker as mtick 121 | mpl.style.use('default') 122 | import seaborn 123 | seaborn.set(style='whitegrid') 124 | seaborn.set_context('paper') 125 | 126 | plt.figure(1) 127 | plt.subplot(111) 128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 129 | plt.title('Title', fontsize=16, color='k') 130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 131 | legend = plt.legend(loc='upper right', shadow=False) 132 | ax = plt.gca() 133 | labels = ax.get_xticks().tolist() 134 | for k, v in enumerate(labels): 135 | labels[k] = str(int(v / 1000)) + 'K' 136 | ax.set_xticklabels(labels) 137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 138 | 139 | ax.set_ylabel('Learning rate') 140 | ax.set_xlabel('Iteration') 141 | fig = plt.gcf() 142 | plt.show() 143 | -------------------------------------------------------------------------------- /codes/data/Steg_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | 9 | 10 | class StegDataset(data.Dataset): 11 | ''' 12 | Read LQ (Low Quality, here is LR) and GT image pairs. 13 | If only GT image is provided, generate LQ image on-the-fly. 14 | The pair is ensured by 'sorted' function, so please check the name convention. 15 | ''' 16 | 17 | def __init__(self, opt): 18 | super(StegDataset, self).__init__() 19 | self.opt = opt 20 | self.data_type = self.opt['data_type'] 21 | self.paths_LQ, self.paths_GT = None, None 22 | self.sizes_LQ, self.sizes_GT = None, None 23 | self.LQ_env, self.GT_env = None, None # environment for lmdb 24 | 25 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) 26 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 27 | assert self.paths_GT, 'Error: GT path is empty.' 28 | if self.paths_LQ and self.paths_GT: 29 | assert len(self.paths_LQ) == len( 30 | self.paths_GT 31 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format( 32 | len(self.paths_LQ), len(self.paths_GT)) 33 | self.random_scale_list = [1] 34 | 35 | def _init_lmdb(self): 36 | # https://github.com/chainer/chainermn/issues/129 37 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 38 | meminit=False) 39 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 40 | meminit=False) 41 | 42 | def __getitem__(self, index): 43 | if self.data_type == 'lmdb': 44 | if (self.GT_env is None) or (self.LQ_env is None): 45 | self._init_lmdb() 46 | GT_path, LQ_path = None, None 47 | scale_W = self.opt['scale_W'] 48 | scale_H = self.opt['scale_H'] 49 | GT_size = self.opt['GT_size'] 50 | 51 | # get GT image 52 | GT_path = self.paths_GT[index] 53 | if self.data_type == 'lmdb': 54 | resolution = [int(s) for s in self.sizes_GT[index].split('_')] 55 | else: 56 | resolution = None 57 | img_GT = util.read_img(self.GT_env, GT_path, resolution) 58 | # modcrop in the validation / test phase 59 | if self.opt['phase'] != 'train': 60 | img_GT = util.modcrop(img_GT, scale_W, scale_H) 61 | # change color space if necessary 62 | if self.opt['color']: 63 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 64 | 65 | # get LQ image 66 | if self.paths_LQ: 67 | LQ_path = self.paths_LQ[index] 68 | if self.data_type == 'lmdb': 69 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')] 70 | else: 71 | resolution = None 72 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 73 | else: # down-sampling on-the-fly 74 | # randomly scale during training 75 | if self.opt['phase'] == 'train': 76 | random_scale = random.choice(self.random_scale_list) 77 | H_s, W_s, _ = img_GT.shape 78 | 79 | def _mod(n, random_scale, scale, thres): 80 | rlt = int(n * random_scale) 81 | rlt = (rlt // scale) * scale 82 | return thres if rlt < thres else rlt 83 | 84 | H_s = _mod(H_s, random_scale, scale_H, GT_size) 85 | W_s = _mod(W_s, random_scale, scale_W, GT_size) 86 | img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) 87 | # force to 3 channels 88 | if img_GT.ndim == 2: 89 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) 90 | 91 | H, W, _ = img_GT.shape 92 | # using matlab imresize 93 | img_LQ = util.imresize_np(img_GT, 1 / scale_W, 1 / scale_H, True) 94 | if img_LQ.ndim == 2: 95 | img_LQ = np.expand_dims(img_LQ, axis=2) 96 | 97 | if self.opt['phase'] == 'train': 98 | # if the image size is too small 99 | H, W, _ = img_GT.shape 100 | if H < GT_size or W < GT_size: 101 | img_GT = cv2.resize(np.copy(img_GT), (GT_size, GT_size), 102 | interpolation=cv2.INTER_LINEAR) 103 | # using matlab imresize 104 | img_LQ = util.imresize_np(img_GT, 1 / scale_W, 1 / scale_H, True) 105 | if img_LQ.ndim == 2: 106 | img_LQ = np.expand_dims(img_LQ, axis=2) 107 | 108 | H, W, C = img_LQ.shape 109 | LQ_size_W = GT_size // scale_W 110 | LQ_size_H = GT_size // scale_H 111 | 112 | # randomly crop 113 | rnd_h = random.randint(0, max(0, H - LQ_size_H)) 114 | rnd_w = random.randint(0, max(0, W - LQ_size_W)) 115 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size_H, rnd_w:rnd_w + LQ_size_W, :] 116 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale_H), int(rnd_w * scale_W) 117 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 118 | 119 | # augmentation - flip, rotate 120 | # img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], 121 | # self.opt['use_rot']) 122 | 123 | # change color space if necessary 124 | # if self.opt['color']: 125 | # img_LQ = util.channel_convert(C, self.opt['color'], 126 | # [img_LQ])[0] # TODO during val no definition 127 | 128 | # BGR to RGB, HWC to CHW, numpy to tensor 129 | if img_GT.shape[2] == 3: 130 | img_GT = img_GT[:, :, [2, 1, 0]] 131 | img_LQ = img_LQ[:, :, [2, 1, 0]] 132 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 133 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 134 | 135 | if LQ_path is None: 136 | LQ_path = GT_path 137 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} 138 | 139 | def __len__(self): 140 | return len(self.paths_GT) 141 | -------------------------------------------------------------------------------- /codes/models/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def initialize_weights_xavier(net_l, scale=1): 28 | if not isinstance(net_l, list): 29 | net_l = [net_l] 30 | for net in net_l: 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.xavier_normal_(m.weight) 34 | m.weight.data *= scale # for residual block 35 | if m.bias is not None: 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.Linear): 38 | init.xavier_normal_(m.weight) 39 | m.weight.data *= scale 40 | if m.bias is not None: 41 | m.bias.data.zero_() 42 | elif isinstance(m, nn.BatchNorm2d): 43 | init.constant_(m.weight, 1) 44 | init.constant_(m.bias.data, 0.0) 45 | 46 | 47 | def make_layer(block, n_layers): 48 | layers = [] 49 | for _ in range(n_layers): 50 | layers.append(block()) 51 | return nn.Sequential(*layers) 52 | 53 | 54 | class ResidualBlock_noBN(nn.Module): 55 | '''Residual block w/o BN 56 | ---Conv-ReLU-Conv-+- 57 | |________________| 58 | ''' 59 | 60 | def __init__(self, nf=64): 61 | super(ResidualBlock_noBN, self).__init__() 62 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | 65 | # initialization 66 | initialize_weights([self.conv1, self.conv2], 0.1) 67 | 68 | def forward(self, x): 69 | identity = x 70 | out = F.relu(self.conv1(x), inplace=True) 71 | out = self.conv2(out) 72 | return identity + out 73 | 74 | 75 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 76 | """Warp an image or feature map with optical flow 77 | Args: 78 | x (Tensor): size (N, C, H, W) 79 | flow (Tensor): size (N, H, W, 2), normal value 80 | interp_mode (str): 'nearest' or 'bilinear' 81 | padding_mode (str): 'zeros' or 'border' or 'reflection' 82 | 83 | Returns: 84 | Tensor: warped image or feature map 85 | """ 86 | assert x.size()[-2:] == flow.size()[1:3] 87 | B, C, H, W = x.size() 88 | # mesh grid 89 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 90 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 91 | grid.requires_grad = False 92 | grid = grid.type_as(x) 93 | vgrid = grid + flow 94 | # scale grid to [-1,1] 95 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 96 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 97 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 98 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 99 | return output 100 | 101 | 102 | 103 | def initialize_weights(net_l, scale=1): 104 | if not isinstance(net_l, list): 105 | net_l = [net_l] 106 | for net in net_l: 107 | for m in net.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 110 | m.weight.data *= scale # for residual block 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.Linear): 114 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 115 | m.weight.data *= scale 116 | if m.bias is not None: 117 | m.bias.data.zero_() 118 | elif isinstance(m, nn.BatchNorm2d): 119 | init.constant_(m.weight, 1) 120 | init.constant_(m.bias.data, 0.0) 121 | 122 | 123 | def make_layer(block, n_layers): 124 | layers = [] 125 | for _ in range(n_layers): 126 | layers.append(block()) 127 | return nn.Sequential(*layers) 128 | 129 | 130 | class ResidualBlock_noBN(nn.Module): 131 | '''Residual block w/o BN 132 | ---Conv-ReLU-Conv-+- 133 | |________________| 134 | ''' 135 | 136 | def __init__(self, nf=64): 137 | super(ResidualBlock_noBN, self).__init__() 138 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 139 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 140 | 141 | # initialization 142 | initialize_weights([self.conv1, self.conv2], 0.1) 143 | 144 | def forward(self, x): 145 | identity = x 146 | out = F.relu(self.conv1(x), inplace=True) 147 | out = self.conv2(out) 148 | return identity + out 149 | 150 | 151 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 152 | """Warp an image or feature map with optical flow 153 | Args: 154 | x (Tensor): size (N, C, H, W) 155 | flow (Tensor): size (N, H, W, 2), normal value 156 | interp_mode (str): 'nearest' or 'bilinear' 157 | padding_mode (str): 'zeros' or 'border' or 'reflection' 158 | Returns: 159 | Tensor: warped image or feature map 160 | """ 161 | flow = flow.permute(0,2,3,1) 162 | assert x.size()[-2:] == flow.size()[1:3] 163 | B, C, H, W = x.size() 164 | # mesh grid 165 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 166 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 167 | grid.requires_grad = False 168 | grid = grid.type_as(x) 169 | vgrid = grid + flow 170 | # scale grid to [-1,1] 171 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 172 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 173 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 174 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 175 | return output -------------------------------------------------------------------------------- /codes/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | from datetime import datetime 6 | import random 7 | import logging 8 | from collections import OrderedDict 9 | import numpy as np 10 | import cv2 11 | import torch 12 | from torchvision.utils import make_grid 13 | from shutil import get_terminal_size 14 | 15 | import yaml 16 | try: 17 | from yaml import CLoader as Loader, CDumper as Dumper 18 | except ImportError: 19 | from yaml import Loader, Dumper 20 | 21 | 22 | def OrderedYaml(): 23 | '''yaml orderedDict support''' 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | #################### 38 | # miscellaneous 39 | #################### 40 | 41 | 42 | def get_timestamp(): 43 | return datetime.now().strftime('%y%m%d-%H%M%S') 44 | 45 | 46 | def mkdir(path): 47 | if not os.path.exists(path): 48 | os.makedirs(path) 49 | 50 | 51 | def mkdirs(paths): 52 | if isinstance(paths, str): 53 | mkdir(paths) 54 | else: 55 | for path in paths: 56 | mkdir(path) 57 | 58 | 59 | def mkdir_and_rename(path): 60 | if os.path.exists(path): 61 | new_name = path + '_archived_' + get_timestamp() 62 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 63 | logger = logging.getLogger('base') 64 | logger.info('Path already exists. Rename it to [{:s}]'.format(new_name)) 65 | os.rename(path, new_name) 66 | os.makedirs(path) 67 | 68 | 69 | def set_random_seed(seed): 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | torch.manual_seed(seed) 73 | torch.cuda.manual_seed_all(seed) 74 | 75 | 76 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 77 | '''set up logger''' 78 | lg = logging.getLogger(logger_name) 79 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 80 | datefmt='%y-%m-%d %H:%M:%S') 81 | lg.setLevel(level) 82 | if tofile: 83 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 84 | fh = logging.FileHandler(log_file, mode='w') 85 | fh.setFormatter(formatter) 86 | lg.addHandler(fh) 87 | if screen: 88 | sh = logging.StreamHandler() 89 | sh.setFormatter(formatter) 90 | lg.addHandler(sh) 91 | 92 | 93 | #################### 94 | # image convert 95 | #################### 96 | 97 | 98 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 99 | ''' 100 | Converts a torch Tensor into an image Numpy array 101 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 102 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 103 | ''' 104 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 105 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 106 | n_dim = tensor.dim() 107 | if n_dim == 4: 108 | n_img = len(tensor) 109 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 110 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 111 | elif n_dim == 3: 112 | img_np = tensor.numpy() 113 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 114 | elif n_dim == 2: 115 | img_np = tensor.numpy() 116 | else: 117 | raise TypeError( 118 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 119 | if out_type == np.uint8: 120 | img_np = (img_np * 255.0).round() 121 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 122 | return img_np.astype(out_type) 123 | 124 | 125 | def save_img(img, img_path, mode='RGB'): 126 | cv2.imwrite(img_path, img) 127 | 128 | 129 | #################### 130 | # metric 131 | #################### 132 | 133 | 134 | def calculate_psnr(img1, img2): 135 | # img1 and img2 have range [0, 255] 136 | img1 = img1.astype(np.float64) 137 | img2 = img2.astype(np.float64) 138 | mse = np.mean((img1 - img2)**2) 139 | if mse == 0: 140 | return float('inf') 141 | return 20 * math.log10(255.0 / math.sqrt(mse)) 142 | 143 | 144 | def ssim(img1, img2): 145 | C1 = (0.01 * 255)**2 146 | C2 = (0.03 * 255)**2 147 | 148 | img1 = img1.astype(np.float64) 149 | img2 = img2.astype(np.float64) 150 | kernel = cv2.getGaussianKernel(11, 1.5) 151 | window = np.outer(kernel, kernel.transpose()) 152 | 153 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 154 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 155 | mu1_sq = mu1**2 156 | mu2_sq = mu2**2 157 | mu1_mu2 = mu1 * mu2 158 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 159 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 160 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 161 | 162 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 163 | (sigma1_sq + sigma2_sq + C2)) 164 | return ssim_map.mean() 165 | 166 | 167 | def calculate_ssim(img1, img2): 168 | '''calculate SSIM 169 | the same outputs as MATLAB's 170 | img1, img2: [0, 255] 171 | ''' 172 | if not img1.shape == img2.shape: 173 | raise ValueError('Input images must have the same dimensions.') 174 | if img1.ndim == 2: 175 | return ssim(img1, img2) 176 | elif img1.ndim == 3: 177 | if img1.shape[2] == 3: 178 | ssims = [] 179 | for i in range(3): 180 | ssims.append(ssim(img1, img2)) 181 | return np.array(ssims).mean() 182 | elif img1.shape[2] == 1: 183 | return ssim(np.squeeze(img1), np.squeeze(img2)) 184 | else: 185 | raise ValueError('Wrong input image dimensions.') 186 | 187 | 188 | class ProgressBar(object): 189 | '''A progress bar which can print the progress 190 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 191 | ''' 192 | 193 | def __init__(self, task_num=0, bar_width=50, start=True): 194 | self.task_num = task_num 195 | max_bar_width = self._get_max_bar_width() 196 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 197 | self.completed = 0 198 | if start: 199 | self.start() 200 | 201 | def _get_max_bar_width(self): 202 | terminal_width, _ = get_terminal_size() 203 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 204 | if max_bar_width < 10: 205 | print('terminal width is too small ({}), please consider widen the terminal for better ' 206 | 'progressbar visualization'.format(terminal_width)) 207 | max_bar_width = 10 208 | return max_bar_width 209 | 210 | def start(self): 211 | if self.task_num > 0: 212 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 213 | ' ' * self.bar_width, self.task_num, 'Start...')) 214 | else: 215 | sys.stdout.write('completed: 0, elapsed: 0s') 216 | sys.stdout.flush() 217 | self.start_time = time.time() 218 | 219 | def update(self, msg='In progress...'): 220 | self.completed += 1 221 | elapsed = time.time() - self.start_time 222 | fps = self.completed / elapsed 223 | if self.task_num > 0: 224 | percentage = self.completed / float(self.task_num) 225 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 226 | mark_width = int(self.bar_width * percentage) 227 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 228 | sys.stdout.write('\033[2F') # cursor up 2 lines 229 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 230 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 231 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 232 | else: 233 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 234 | self.completed, int(elapsed + 0.5), fps)) 235 | sys.stdout.flush() 236 | -------------------------------------------------------------------------------- /codes/models/modules/Inv_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import torchvision 7 | 8 | 9 | class InvBlockExp(nn.Module): 10 | def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=1.): 11 | super(InvBlockExp, self).__init__() 12 | 13 | self.split_len1 = channel_split_num 14 | self.split_len2 = channel_num - channel_split_num 15 | 16 | self.clamp = clamp 17 | 18 | self.F = subnet_constructor(self.split_len2, self.split_len1) 19 | self.G = subnet_constructor(self.split_len1, self.split_len2) 20 | self.H = subnet_constructor(self.split_len1, self.split_len2) 21 | 22 | def forward(self, x, rev=False): 23 | x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2)) 24 | 25 | if not rev: 26 | y1 = x1 + self.F(x2) 27 | self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1) 28 | y2 = x2.mul(torch.exp(self.s)) + self.G(y1) 29 | else: 30 | self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1) 31 | y2 = (x2 - self.G(x1)).div(torch.exp(self.s)) 32 | y1 = x1 - self.F(y2) 33 | 34 | return torch.cat((y1, y2), 1) 35 | 36 | def jacobian(self, x, rev=False): 37 | if not rev: 38 | jac = torch.sum(self.s) 39 | else: 40 | jac = -torch.sum(self.s) 41 | 42 | return jac / x.shape[0] 43 | 44 | 45 | class HaarDownsampling(nn.Module): 46 | def __init__(self, channel_in): 47 | super(HaarDownsampling, self).__init__() 48 | self.channel_in = channel_in 49 | 50 | self.haar_weights = torch.ones(4, 1, 2, 2) 51 | 52 | self.haar_weights[1, 0, 0, 1] = -1 53 | self.haar_weights[1, 0, 1, 1] = -1 54 | 55 | self.haar_weights[2, 0, 1, 0] = -1 56 | self.haar_weights[2, 0, 1, 1] = -1 57 | 58 | self.haar_weights[3, 0, 1, 0] = -1 59 | self.haar_weights[3, 0, 0, 1] = -1 60 | 61 | self.haar_weights = torch.cat([self.haar_weights] * self.channel_in, 0) 62 | self.haar_weights = nn.Parameter(self.haar_weights) 63 | self.haar_weights.requires_grad = False 64 | 65 | def forward(self, x, rev=False): 66 | if not rev: 67 | self.elements = x.shape[1] * x.shape[2] * x.shape[3] 68 | self.last_jac = self.elements / 4 * np.log(1/16.) 69 | 70 | out = F.conv2d(x, self.haar_weights, bias=None, stride=2, groups=self.channel_in) / 4.0 71 | out = out.reshape([x.shape[0], self.channel_in, 4, x.shape[2] // 2, x.shape[3] // 2]) 72 | out = torch.transpose(out, 1, 2) 73 | out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2] // 2, x.shape[3] // 2]) 74 | return out 75 | else: 76 | self.elements = x.shape[1] * x.shape[2] * x.shape[3] 77 | self.last_jac = self.elements / 4 * np.log(16.) 78 | out = x.reshape([x.shape[0], 4, self.channel_in, x.shape[2], x.shape[3]]) 79 | out = torch.transpose(out, 1, 2) 80 | out = out.reshape([x.shape[0], self.channel_in * 4, x.shape[2], x.shape[3]]) 81 | return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=2, groups = self.channel_in) 82 | 83 | def jacobian(self, x, rev=False): 84 | return self.last_jac 85 | 86 | class ConvDownsampling_downsampleH(nn.Module): 87 | def __init__(self, channel_in): 88 | super(ConvDownsampling_downsampleH, self).__init__() 89 | 90 | self.channel_in = channel_in 91 | 92 | self.haar_weights = torch.ones(2, 1, 2, 1) 93 | self.haar_weights[0, 0, 0, 0] = 1/2 94 | self.haar_weights[0, 0, 1, 0] = 1/2 95 | 96 | self.haar_weights[1, 0, 0, 0] = -1/2 97 | self.haar_weights[1, 0, 1, 0] = 1/2 98 | 99 | self.haar_weights = torch.cat([self.haar_weights] * self.channel_in, 0) 100 | self.haar_weights = nn.Parameter(self.haar_weights) 101 | self.haar_weights.requires_grad = True 102 | 103 | def forward(self, x, rev=False): 104 | if not rev: 105 | self.elements = x.shape[1] * x.shape[2] * x.shape[3] 106 | self.last_jac = self.elements / 4 * np.log(1/16.) 107 | 108 | out = F.conv2d(x, self.haar_weights, bias=None, stride=(2, 1), groups=self.channel_in) 109 | out = out.reshape([x.shape[0], self.channel_in, 2, x.shape[2] // 2, x.shape[3]]) 110 | out = torch.transpose(out, 1, 2) 111 | out = out.reshape([x.shape[0], self.channel_in * 2, x.shape[2] // 2, x.shape[3]]) 112 | return out 113 | else: 114 | r = 2 115 | in_batch, in_channel, in_height, in_width = x.size() 116 | out_batch, out_channel, out_height, out_width = in_batch, int( 117 | in_channel / r), r * in_height, in_width 118 | 119 | self.elements = x.shape[1] * x.shape[2] * x.shape[3] 120 | self.last_jac = self.elements / 4 * np.log(16.) 121 | 122 | out = x.reshape([x.shape[0], 2, self.channel_in, x.shape[2], x.shape[3]]) 123 | out = torch.transpose(out, 1, 2) 124 | out = out.reshape([x.shape[0], self.channel_in * 2, x.shape[2], x.shape[3]]) 125 | return F.conv_transpose2d(out, self.haar_weights, bias=None, stride=(2, 1), groups=self.channel_in) * 2.0 126 | 127 | 128 | class ConvDownsampling(nn.Module): 129 | def __init__(self, scale_W, scale_H): 130 | super(ConvDownsampling, self).__init__() 131 | self.scale_W = scale_W 132 | self.scale_H = scale_H 133 | self.scale2 = self.scale_H * self.scale_W 134 | 135 | self.conv_weights = torch.eye(self.scale2) 136 | 137 | if self.scale_W == self.scale_H == 2: # haar init 138 | self.conv_weights[0] = torch.Tensor([1./4, 1./4, 1./4, 1./4]) 139 | self.conv_weights[1] = torch.Tensor([1./4, -1./4, 1./4, -1./4]) 140 | self.conv_weights[2] = torch.Tensor([1./4, 1./4, -1./4, -1./4]) 141 | self.conv_weights[3] = torch.Tensor([1./4, -1./4, -1./4, 1./4]) 142 | else: 143 | self.conv_weights[0] = torch.Tensor([1./(self.scale2)] * (self.scale2)) 144 | 145 | self.conv_weights = nn.Parameter(self.conv_weights) 146 | 147 | def forward(self, x, rev=False): 148 | if not rev: 149 | h = x.shape[2] 150 | w = x.shape[3] 151 | 152 | [B, C, H, W] = list(x.size()) 153 | x = x.reshape(B, C, H // self.scale_H, self.scale_H, W // self.scale_W, self.scale_W) 154 | x = x.permute(0, 1, 3, 5, 2, 4) 155 | x = x.reshape(B, C * self.scale2, H // self.scale_H, W // self.scale_W) 156 | 157 | conv_weights = self.conv_weights.reshape(self.scale2, self.scale2, 1, 1) 158 | conv_weights = conv_weights.repeat(C, 1, 1, 1) 159 | 160 | out = F.conv2d(x, conv_weights, bias=None, stride=1, groups=C) 161 | 162 | out = out.reshape(B, C, self.scale2, H // self.scale_H, W // self.scale_W) 163 | out = torch.transpose(out, 1, 2) 164 | out = out.reshape(B, C * self.scale2, H // self.scale_H, W // self.scale_W) 165 | 166 | 167 | return out 168 | else: 169 | inv_weights = torch.inverse(self.conv_weights) 170 | inv_weights = inv_weights.reshape(self.scale2, self.scale2, 1, 1) 171 | 172 | [B, C_, H_, W_] = list(x.size()) 173 | C = C_ // self.scale2 174 | H = H_ * self.scale_H 175 | W = W_ * self.scale_W 176 | 177 | inv_weights = inv_weights.repeat(C, 1, 1, 1) 178 | 179 | x = x.reshape(B, self.scale2, C, H_, W_) 180 | x = torch.transpose(x, 1, 2) 181 | x = x.reshape(B, C_, H_, W_) 182 | 183 | out = F.conv2d(x, inv_weights, bias=None, stride=1, groups=C) 184 | 185 | out = out.reshape(B, C, self.scale_H, self.scale_W, H_, W_) 186 | out = out.permute(0, 1, 4, 2, 5, 3) 187 | out = out.reshape(B, C, H, W) 188 | 189 | return out 190 | 191 | class InvRescaleNet(nn.Module): 192 | def __init__(self, channel_in=3, channel_out=3, subnet_constructor=None, block_num=[], down_num=2, use_ConvDownsampling=False, down_scale_W=2, down_scale_H=2): 193 | super(InvRescaleNet, self).__init__() 194 | 195 | operations = [] 196 | 197 | if use_ConvDownsampling: 198 | down_num = 1 199 | 200 | current_channel = channel_in 201 | for i in range(down_num): 202 | if use_ConvDownsampling: 203 | b = ConvDownsampling(down_scale_W, down_scale_H) 204 | current_channel *= down_scale_W * down_scale_H 205 | elif down_scale_W==2 and down_scale_H==4 and i==1: 206 | b = ConvDownsampling_downsampleH(current_channel) 207 | current_channel *= 2 208 | else: 209 | b = HaarDownsampling(current_channel) 210 | current_channel *= 4 211 | operations.append(b) 212 | for j in range(block_num[i]): 213 | b = InvBlockExp(subnet_constructor, current_channel, channel_out) 214 | operations.append(b) 215 | 216 | self.operations = nn.ModuleList(operations) 217 | 218 | def forward(self, x, rev=False, cal_jacobian=False): 219 | out = x 220 | jacobian = 0 221 | 222 | if not rev: 223 | for op in self.operations: 224 | out = op.forward(out, rev) 225 | if cal_jacobian: 226 | jacobian += op.jacobian(out, rev) 227 | else: 228 | for op in reversed(self.operations): 229 | out = op.forward(out, rev) 230 | if cal_jacobian: 231 | jacobian += op.jacobian(out, rev) 232 | 233 | if cal_jacobian: 234 | return out, jacobian 235 | else: 236 | return out 237 | 238 | -------------------------------------------------------------------------------- /codes/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import logging 6 | import warnings 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from data.data_sampler import DistIterSampler 12 | 13 | import options.options as option 14 | from utils import util 15 | from data import create_dataloader, create_dataset 16 | from models import create_model 17 | 18 | import torchvision 19 | import os.path as osp 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | def init_dist(backend='nccl', **kwargs): 24 | ''' initialization for distributed training''' 25 | # if mp.get_start_method(allow_none=True) is None: 26 | if mp.get_start_method(allow_none=True) != 'spawn': 27 | mp.set_start_method('spawn') 28 | rank = int(os.environ['RANK']) 29 | num_gpus = torch.cuda.device_count() 30 | torch.cuda.set_device(rank % num_gpus) 31 | dist.init_process_group(backend=backend, **kwargs) 32 | 33 | def load(name, net): 34 | state_dicts = torch.load(name) 35 | network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k} 36 | net.load_state_dict(network_state_dict) 37 | 38 | 39 | def main(): 40 | #### options 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('-opt', type=str, help='Path to option YMAL file.') 43 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 44 | help='job launcher') 45 | parser.add_argument('--local_rank', type=int, default=0) 46 | args = parser.parse_args() 47 | opt = option.parse(args.opt, is_train=True) 48 | 49 | 50 | #### distributed training settings 51 | if args.launcher == 'none': # disabled distributed training 52 | opt['dist'] = False 53 | rank = -1 54 | print('Disabled distributed training.') 55 | else: 56 | opt['dist'] = True 57 | init_dist() 58 | world_size = torch.distributed.get_world_size() 59 | rank = torch.distributed.get_rank() 60 | 61 | #### loading resume state if exists 62 | if opt['path'].get('resume_state', None): 63 | # distributed resuming: all load into default GPU 64 | device_id = torch.cuda.current_device() 65 | resume_state = torch.load(opt['path']['resume_state'], 66 | map_location=lambda storage, loc: storage.cuda(device_id)) 67 | option.check_resume(opt, resume_state['iter']) # check resume options 68 | else: 69 | resume_state = None 70 | 71 | #### mkdir and loggers 72 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 73 | if resume_state is None: 74 | util.mkdir_and_rename( 75 | opt['path']['experiments_root']) # rename experiment folder if exists 76 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 77 | and 'pretrain_model' not in key and 'resume' not in key)) 78 | 79 | # config loggers. Before it, the log will not work 80 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 81 | screen=True, tofile=True) 82 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 83 | screen=True, tofile=True) 84 | logger = logging.getLogger('base') 85 | logger.info(option.dict2str(opt)) 86 | # tensorboard logger 87 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 88 | version = float(torch.__version__[0:3]) 89 | if version >= 1.1: # PyTorch 1.1 90 | from torch.utils.tensorboard import SummaryWriter 91 | else: 92 | logger.info( 93 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) 94 | from tensorboardX import SummaryWriter 95 | tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) 96 | else: 97 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 98 | logger = logging.getLogger('base') 99 | 100 | # convert to NoneDict, which returns None for missing keys 101 | opt = option.dict_to_nonedict(opt) 102 | 103 | #### random seed 104 | seed = opt['train']['manual_seed'] 105 | if seed is None: 106 | seed = random.randint(1, 10000) 107 | if rank <= 0: 108 | logger.info('Random seed: {}'.format(seed)) 109 | util.set_random_seed(seed) 110 | 111 | torch.backends.cudnn.benchmark = True 112 | # torch.backends.cudnn.deterministic = True 113 | 114 | #### create train and val dataloader 115 | dataset_ratio = 1 # enlarge the size of each epoch 116 | for phase, dataset_opt in opt['datasets'].items(): 117 | if phase == 'train': 118 | train_set = create_dataset(dataset_opt) 119 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 120 | total_iters = int(opt['train']['niter']) 121 | total_epochs = int(math.ceil(total_iters / train_size)) 122 | if opt['dist']: 123 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) 124 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) 125 | else: 126 | train_sampler = None 127 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) 128 | if rank <= 0: 129 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 130 | len(train_set), train_size)) 131 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 132 | total_epochs, total_iters)) 133 | elif phase == 'val': 134 | val_set = create_dataset(dataset_opt) 135 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 136 | if rank <= 0: 137 | logger.info('Number of val images in [{:s}]: {:d}'.format( 138 | dataset_opt['name'], len(val_set))) 139 | else: 140 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 141 | assert train_loader is not None 142 | 143 | #### create model 144 | model = create_model(opt) 145 | 146 | 147 | #### resume training 148 | if resume_state: 149 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 150 | resume_state['epoch'], resume_state['iter'])) 151 | 152 | start_epoch = resume_state['epoch'] 153 | current_step = resume_state['iter'] 154 | model.resume_training(resume_state) # handle optimizers and schedulers 155 | else: 156 | current_step = 0 157 | start_epoch = 0 158 | 159 | #### training 160 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 161 | for epoch in range(start_epoch, total_epochs + 1): 162 | if opt['dist']: 163 | train_sampler.set_epoch(epoch) 164 | for _, train_data in enumerate(train_loader): 165 | current_step += 1 166 | if current_step > total_iters: 167 | break 168 | #### training 169 | model.feed_data(train_data) 170 | model.optimize_parameters(current_step) 171 | 172 | 173 | #### update learning rate 174 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) 175 | 176 | # validation 177 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0: 178 | avg_psnr_secret = 0.0 179 | avg_psnr_cover = 0.0 180 | idx = 0 181 | count = 0 182 | for val_data in val_loader: 183 | idx += 1 184 | # img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] 185 | img_dir = os.path.join(opt['path']['val_images'], str(current_step)) 186 | util.mkdir(img_dir) 187 | 188 | model.feed_data(val_data) 189 | model.test() 190 | count = count + 1 191 | 192 | visuals = model.get_current_visuals() 193 | 194 | cover = visuals['cover'] 195 | secret = visuals['secret'] 196 | secret_recover = visuals['secret_recover'] 197 | steg = visuals['steg'] 198 | 199 | cover_path = img_dir + "\\" + "cover" 200 | secret_path = img_dir + "\\" + "secret" 201 | secret_recover_path = img_dir + "\\" + "secret_recover" 202 | steg_path = img_dir + "\\" + "steg" 203 | 204 | if not os.path.exists(cover_path): 205 | os.makedirs(cover_path) 206 | if not os.path.exists(secret_path): 207 | os.makedirs(secret_path) 208 | if not os.path.exists(secret_recover_path): 209 | os.makedirs(secret_recover_path) 210 | if not os.path.exists(steg_path): 211 | os.makedirs(steg_path) 212 | 213 | save_cover_path = osp.join(cover_path, str(count) + '_clean_cover.png') 214 | torchvision.utils.save_image(cover, save_cover_path) 215 | 216 | save_steg_path = osp.join(steg_path, str(count) + '_steg.png') 217 | torchvision.utils.save_image(steg, save_steg_path) 218 | 219 | for i in range(int(secret.shape[0])): 220 | save_secret_path = osp.join(secret_path, str(count) + '_' + str(i) + '_clean_secret.png') 221 | save_secret_recover_path = osp.join(secret_recover_path, 222 | str(count) + '_' + str(i) + '_secret_recover.png') 223 | torchvision.utils.save_image(secret[i:i + 1, :, :, :], save_secret_path) 224 | torchvision.utils.save_image(secret_recover[i:i + 1, :, :, :], save_secret_recover_path) 225 | 226 | 227 | #### save models and training states 228 | if current_step % opt['logger']['save_checkpoint_freq'] == 0: 229 | if rank <= 0: 230 | logger.info('Saving models and training states.') 231 | model.save(current_step) 232 | model.save_training_state(epoch, current_step) 233 | 234 | if rank <= 0: 235 | logger.info('Saving the final model.') 236 | model.save('latest') 237 | logger.info('End of training.') 238 | 239 | 240 | if __name__ == '__main__': 241 | main() 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /codes/data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pickle 4 | import random 5 | import numpy as np 6 | import torch 7 | import cv2 8 | from PIL import Image 9 | 10 | #################### 11 | # Files & IO 12 | #################### 13 | 14 | ###################### get image path list ###################### 15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 16 | 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 20 | 21 | 22 | def _get_paths_from_images(path): 23 | '''get image path list from image folder''' 24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 25 | images = [] 26 | for dirpath, _, fnames in sorted(os.walk(path)): 27 | for fname in sorted(fnames): 28 | if is_image_file(fname): 29 | img_path = os.path.join(dirpath, fname) 30 | images.append(img_path) 31 | assert images, '{:s} has no valid image file'.format(path) 32 | return images 33 | 34 | 35 | def _get_paths_from_lmdb(dataroot): 36 | '''get image path list from lmdb meta info''' 37 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) 38 | paths = meta_info['keys'] 39 | sizes = meta_info['resolution'] 40 | if len(sizes) == 1: 41 | sizes = sizes * len(paths) 42 | return paths, sizes 43 | 44 | 45 | def get_image_paths(data_type, dataroot): 46 | '''get image path list 47 | support lmdb or image files''' 48 | paths, sizes = None, None 49 | if dataroot is not None: 50 | if data_type == 'lmdb': 51 | paths, sizes = _get_paths_from_lmdb(dataroot) 52 | elif data_type == 'img': 53 | paths = sorted(_get_paths_from_images(dataroot)) 54 | else: 55 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) 56 | return paths, sizes 57 | 58 | 59 | ###################### read images ###################### 60 | def _read_img_lmdb(env, key, size): 61 | '''read image from lmdb with key (w/ and w/o fixed size) 62 | size: (C, H, W) tuple''' 63 | with env.begin(write=False) as txn: 64 | buf = txn.get(key.encode('ascii')) 65 | img_flat = np.frombuffer(buf, dtype=np.uint8) 66 | C, H, W = size 67 | img = img_flat.reshape(H, W, C) 68 | return img 69 | 70 | 71 | def read_img(env, path, size=None): 72 | '''read image by cv2 or from lmdb 73 | return: Numpy float32, HWC, BGR, [0,1]''' 74 | if env is None: # img 75 | #img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 76 | img = cv2.imread(path, cv2.IMREAD_COLOR) 77 | # img = Image.open(path,'r') 78 | else: 79 | img = _read_img_lmdb(env, path, size) 80 | img = img.astype(np.float32) / 255. 81 | if img.ndim == 2: 82 | img = np.expand_dims(img, axis=2) 83 | # some images have 4 channels 84 | if img.shape[2] > 3: 85 | img = img[:, :, :3] 86 | return img 87 | 88 | 89 | #################### 90 | # image processing 91 | # process on numpy image 92 | #################### 93 | 94 | 95 | def augment(img_list, hflip=True, rot=True): 96 | # horizontal flip OR rotate 97 | hflip = hflip and random.random() < 0.5 98 | vflip = rot and random.random() < 0.5 99 | rot90 = rot and random.random() < 0.5 100 | 101 | def _augment(img): 102 | if hflip: 103 | img = img[:, ::-1, :] 104 | if vflip: 105 | img = img[::-1, :, :] 106 | if rot90: 107 | img = img.transpose(1, 0, 2) 108 | return img 109 | 110 | return [_augment(img) for img in img_list] 111 | 112 | 113 | def augment_flow(img_list, flow_list, hflip=True, rot=True): 114 | # horizontal flip OR rotate 115 | hflip = hflip and random.random() < 0.5 116 | vflip = rot and random.random() < 0.5 117 | rot90 = rot and random.random() < 0.5 118 | 119 | def _augment(img): 120 | if hflip: 121 | img = img[:, ::-1, :] 122 | if vflip: 123 | img = img[::-1, :, :] 124 | if rot90: 125 | img = img.transpose(1, 0, 2) 126 | return img 127 | 128 | def _augment_flow(flow): 129 | if hflip: 130 | flow = flow[:, ::-1, :] 131 | flow[:, :, 0] *= -1 132 | if vflip: 133 | flow = flow[::-1, :, :] 134 | flow[:, :, 1] *= -1 135 | if rot90: 136 | flow = flow.transpose(1, 0, 2) 137 | flow = flow[:, :, [1, 0]] 138 | return flow 139 | 140 | rlt_img_list = [_augment(img) for img in img_list] 141 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list] 142 | 143 | return rlt_img_list, rlt_flow_list 144 | 145 | 146 | def channel_convert(in_c, tar_type, img_list): 147 | # conversion among BGR, gray and y 148 | if in_c == 3 and tar_type == 'gray': # BGR to gray 149 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] 150 | return [np.expand_dims(img, axis=2) for img in gray_list] 151 | elif in_c == 3 and tar_type == 'y': # BGR to y 152 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] 153 | return [np.expand_dims(img, axis=2) for img in y_list] 154 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR 155 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] 156 | else: 157 | return img_list 158 | 159 | 160 | def rgb2ycbcr(img, only_y=True): 161 | '''same as matlab rgb2ycbcr 162 | only_y: only return Y channel 163 | Input: 164 | uint8, [0, 255] 165 | float, [0, 1] 166 | ''' 167 | in_img_type = img.dtype 168 | img.astype(np.float32) 169 | if in_img_type != np.uint8: 170 | img *= 255. 171 | # convert 172 | if only_y: 173 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 174 | else: 175 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 176 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 177 | if in_img_type == np.uint8: 178 | rlt = rlt.round() 179 | else: 180 | rlt /= 255. 181 | return rlt.astype(in_img_type) 182 | 183 | 184 | def bgr2ycbcr(img, only_y=True): 185 | '''bgr version of rgb2ycbcr 186 | only_y: only return Y channel 187 | Input: 188 | uint8, [0, 255] 189 | float, [0, 1] 190 | ''' 191 | in_img_type = img.dtype 192 | img.astype(np.float32) 193 | if in_img_type != np.uint8: 194 | img *= 255. 195 | # convert 196 | if only_y: 197 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 198 | else: 199 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 200 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 201 | if in_img_type == np.uint8: 202 | rlt = rlt.round() 203 | else: 204 | rlt /= 255. 205 | return rlt.astype(in_img_type) 206 | 207 | 208 | def ycbcr2rgb(img): 209 | '''same as matlab ycbcr2rgb 210 | Input: 211 | uint8, [0, 255] 212 | float, [0, 1] 213 | ''' 214 | in_img_type = img.dtype 215 | img.astype(np.float32) 216 | if in_img_type != np.uint8: 217 | img *= 255. 218 | # convert 219 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 220 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 221 | if in_img_type == np.uint8: 222 | rlt = rlt.round() 223 | else: 224 | rlt /= 255. 225 | return rlt.astype(in_img_type) 226 | 227 | 228 | def modcrop(img_in, scale_W, scale_H): 229 | # img_in: Numpy, HWC or HW 230 | img = np.copy(img_in) 231 | if img.ndim == 2: 232 | H, W = img.shape 233 | H_r, W_r = H % scale_H, W % scale_W 234 | img = img[:H - H_r, :W - W_r] 235 | elif img.ndim == 3: 236 | H, W, C = img.shape 237 | H_r, W_r = H % (scale_H * 4), W % (scale_W * 4) 238 | img = img[:H - H_r, :W - W_r, :] 239 | else: 240 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) 241 | return img 242 | 243 | 244 | #################### 245 | # Functions 246 | #################### 247 | 248 | 249 | # matlab 'imresize' function, now only support 'bicubic' 250 | def cubic(x): 251 | absx = torch.abs(x) 252 | absx2 = absx**2 253 | absx3 = absx**3 254 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ( 255 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( 256 | (absx > 1) * (absx <= 2)).type_as(absx)) 257 | 258 | 259 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): 260 | if (scale < 1) and (antialiasing): 261 | # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width 262 | kernel_width = kernel_width / scale 263 | 264 | # Output-space coordinates 265 | x = torch.linspace(1, out_length, out_length) 266 | 267 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 268 | # in output space maps to 0.5 in input space, and 0.5+scale in output 269 | # space maps to 1.5 in input space. 270 | u = x / scale + 0.5 * (1 - 1 / scale) 271 | 272 | # What is the left-most pixel that can be involved in the computation? 273 | left = torch.floor(u - kernel_width / 2) 274 | 275 | # What is the maximum number of pixels that can be involved in the 276 | # computation? Note: it's OK to use an extra pixel here; if the 277 | # corresponding weights are all zero, it will be eliminated at the end 278 | # of this function. 279 | P = math.ceil(kernel_width) + 2 280 | 281 | # The indices of the input pixels involved in computing the k-th output 282 | # pixel are in row k of the indices matrix. 283 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( 284 | 1, P).expand(out_length, P) 285 | 286 | # The weights used to compute the k-th output pixel are in row k of the 287 | # weights matrix. 288 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices 289 | # apply cubic kernel 290 | if (scale < 1) and (antialiasing): 291 | weights = scale * cubic(distance_to_center * scale) 292 | else: 293 | weights = cubic(distance_to_center) 294 | # Normalize the weights matrix so that each row sums to 1. 295 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 296 | weights = weights / weights_sum.expand(out_length, P) 297 | 298 | # If a column in weights is all zero, get rid of it. only consider the first and last column. 299 | weights_zero_tmp = torch.sum((weights == 0), 0) 300 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 301 | indices = indices.narrow(1, 1, P - 2) 302 | weights = weights.narrow(1, 1, P - 2) 303 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 304 | indices = indices.narrow(1, 0, P - 2) 305 | weights = weights.narrow(1, 0, P - 2) 306 | weights = weights.contiguous() 307 | indices = indices.contiguous() 308 | sym_len_s = -indices.min() + 1 309 | sym_len_e = indices.max() - in_length 310 | indices = indices + sym_len_s - 1 311 | return weights, indices, int(sym_len_s), int(sym_len_e) 312 | 313 | 314 | def imresize(img, scale, antialiasing=True): 315 | # Now the scale should be the same for H and W 316 | # input: img: CHW RGB [0,1] 317 | # output: CHW RGB [0,1] w/o round 318 | 319 | in_C, in_H, in_W = img.size() 320 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 321 | kernel_width = 4 322 | kernel = 'cubic' 323 | 324 | # Return the desired dimension order for performing the resize. The 325 | # strategy is to perform the resize first along the dimension with the 326 | # smallest scale factor. 327 | # Now we do not support this. 328 | 329 | # get weights and indices 330 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 331 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 332 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 333 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 334 | # process H dimension 335 | # symmetric copying 336 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) 337 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) 338 | 339 | sym_patch = img[:, :sym_len_Hs, :] 340 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 341 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 342 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) 343 | 344 | sym_patch = img[:, -sym_len_He:, :] 345 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 346 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 347 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 348 | 349 | out_1 = torch.FloatTensor(in_C, out_H, in_W) 350 | kernel_width = weights_H.size(1) 351 | for i in range(out_H): 352 | idx = int(indices_H[i][0]) 353 | out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 354 | out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 355 | out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 356 | 357 | # process W dimension 358 | # symmetric copying 359 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) 360 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) 361 | 362 | sym_patch = out_1[:, :, :sym_len_Ws] 363 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 364 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 365 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) 366 | 367 | sym_patch = out_1[:, :, -sym_len_We:] 368 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 369 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 370 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 371 | 372 | out_2 = torch.FloatTensor(in_C, out_H, out_W) 373 | kernel_width = weights_W.size(1) 374 | for i in range(out_W): 375 | idx = int(indices_W[i][0]) 376 | out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) 377 | out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) 378 | out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) 379 | 380 | return out_2 381 | 382 | 383 | def imresize_np(img, scale_W, scale_H, antialiasing=True): 384 | # Now the scale should be the same for H and W 385 | # input: img: Numpy, HWC BGR [0,1] 386 | # output: HWC BGR [0,1] w/o round 387 | img = torch.from_numpy(img) 388 | 389 | in_H, in_W, in_C = img.size() 390 | _, out_H, out_W = in_C, math.ceil(in_H * scale_H), math.ceil(in_W * scale_W) 391 | kernel_width = 4 392 | kernel = 'cubic' 393 | 394 | # Return the desired dimension order for performing the resize. The 395 | # strategy is to perform the resize first along the dimension with the 396 | # smallest scale factor. 397 | # Now we do not support this. 398 | 399 | # get weights and indices 400 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 401 | in_H, out_H, scale_H, kernel, kernel_width, antialiasing) 402 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 403 | in_W, out_W, scale_W, kernel, kernel_width, antialiasing) 404 | # process H dimension 405 | # symmetric copying 406 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) 407 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) 408 | 409 | sym_patch = img[:sym_len_Hs, :, :] 410 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 411 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 412 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) 413 | 414 | sym_patch = img[-sym_len_He:, :, :] 415 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 416 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 417 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 418 | 419 | out_1 = torch.FloatTensor(out_H, in_W, in_C) 420 | kernel_width = weights_H.size(1) 421 | for i in range(out_H): 422 | idx = int(indices_H[i][0]) 423 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) 424 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) 425 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) 426 | 427 | # process W dimension 428 | # symmetric copying 429 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) 430 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) 431 | 432 | sym_patch = out_1[:, :sym_len_Ws, :] 433 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 434 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 435 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) 436 | 437 | sym_patch = out_1[:, -sym_len_We:, :] 438 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 439 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 440 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 441 | 442 | out_2 = torch.FloatTensor(out_H, out_W, in_C) 443 | kernel_width = weights_W.size(1) 444 | for i in range(out_W): 445 | idx = int(indices_W[i][0]) 446 | out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) 447 | out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) 448 | out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) 449 | 450 | return out_2.numpy() 451 | 452 | 453 | if __name__ == '__main__': 454 | # test imresize function 455 | # read images 456 | img = cv2.imread('test.png') 457 | img = img * 1.0 / 255 458 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() 459 | # imresize 460 | scale = 1 / 4 461 | import time 462 | total_time = 0 463 | for i in range(10): 464 | start_time = time.time() 465 | rlt = imresize(img, scale, antialiasing=True) 466 | use_time = time.time() - start_time 467 | total_time += use_time 468 | print('average time: {}'.format(total_time / 10)) 469 | 470 | import torchvision.utils 471 | torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, 472 | normalize=False) 473 | -------------------------------------------------------------------------------- /codes/models/InvMIHNet_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import OrderedDict 4 | 5 | import torch 6 | from torch.nn.parallel import DistributedDataParallel 7 | import models.networks as networks 8 | import models.lr_scheduler as lr_scheduler 9 | from .base_model import BaseModel 10 | from models.modules.loss import ReconstructionLoss 11 | from models.modules.Quantization import Quantization 12 | import numpy as np 13 | import models.modules.Unet_common as common 14 | from models.model import * 15 | from .model import * 16 | from PIL import Image 17 | 18 | logger = logging.getLogger('base') 19 | 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | def gauss_noise(shape): 23 | noise = torch.zeros(shape).cuda() 24 | for i in range(noise.shape[0]): 25 | noise[i] = torch.randn(noise[i].shape).cuda() 26 | 27 | return noise 28 | 29 | 30 | def guide_loss(output, bicubic_image): 31 | loss_fn = torch.nn.MSELoss(reduce=True, size_average=False) 32 | loss = loss_fn(output, bicubic_image) 33 | return loss.to(device) 34 | 35 | 36 | def reconstruction_loss(rev_input, input): 37 | loss_fn = torch.nn.MSELoss(reduce=True, size_average=False) 38 | loss = loss_fn(rev_input, input) 39 | return loss.to(device) 40 | 41 | 42 | def low_frequency_loss(ll_input, gt_input): 43 | loss_fn = torch.nn.MSELoss(reduce=True, size_average=False) 44 | loss = loss_fn(ll_input, gt_input) 45 | return loss.to(device) 46 | 47 | 48 | 49 | def image_save(img,save_dir,img_name): 50 | if not os.path.exists(save_dir): 51 | os.makedirs(save_dir) 52 | img_path = os.path.join(save_dir, img_name) 53 | at_images_np = img.detach().cpu().numpy() 54 | adv_img = at_images_np[0] 55 | adv_img = np.moveaxis(adv_img, 0, 2) 56 | img_pil = Image.fromarray(adv_img.astype(np.uint8)) 57 | img_pil.save(img_path) 58 | 59 | 60 | dwt = common.DWT() 61 | iwt = common.IWT() 62 | 63 | 64 | 65 | class InvMIHNet(BaseModel): 66 | def __init__(self, opt): 67 | super(InvMIHNet, self).__init__(opt) 68 | 69 | if opt['dist']: 70 | self.rank = torch.distributed.get_rank() 71 | else: 72 | self.rank = -1 # non dist training 73 | train_opt = opt['train'] 74 | test_opt = opt['test'] 75 | self.train_opt = train_opt 76 | self.test_opt = test_opt 77 | 78 | self.netG = networks.define_G(opt).to(self.device) 79 | self.netH = Model().to(self.device) 80 | init_model(self.netH) 81 | 82 | # print network 83 | self.print_network() 84 | self.load() 85 | self.load_H(self.netH) 86 | 87 | self.Quantization = Quantization() 88 | 89 | if self.is_train: 90 | self.netG.train() 91 | self.netH.train() 92 | 93 | # loss 94 | self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw']) 95 | self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back']) 96 | 97 | 98 | # optimizers 99 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 100 | optim_params_G = [] 101 | for k, v in self.netG.named_parameters(): 102 | if v.requires_grad: 103 | optim_params_G.append(v) 104 | else: 105 | if self.rank <= 0: 106 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 107 | self.optimizer_G = torch.optim.Adam(optim_params_G, lr=train_opt['lr_G'], 108 | weight_decay=wd_G, 109 | betas=(train_opt['beta1'], train_opt['beta2'])) 110 | self.optimizers.append(self.optimizer_G) 111 | 112 | optim_params_H = [] 113 | for k, v in self.netH.named_parameters(): 114 | if v.requires_grad: 115 | optim_params_H.append(v) 116 | else: 117 | if self.rank <= 0: 118 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 119 | self.optimizer_H = torch.optim.Adam(optim_params_H, lr=train_opt['lr_H'], betas=(train_opt['beta1_H'], train_opt['beta2_H']), eps=1e-6, weight_decay=train_opt['weight_decay_H']) 120 | 121 | # schedulers 122 | if train_opt['lr_scheme'] == 'MultiStepLR': 123 | for optimizer in self.optimizers: 124 | self.schedulers.append( 125 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 126 | restarts=train_opt['restarts'], 127 | weights=train_opt['restart_weights'], 128 | gamma=train_opt['lr_gamma'], 129 | clear_state=train_opt['clear_state'])) 130 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 131 | for optimizer in self.optimizers: 132 | self.schedulers.append( 133 | lr_scheduler.CosineAnnealingLR_Restart( 134 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 135 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 136 | else: 137 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 138 | 139 | self.schedulers_H = torch.optim.lr_scheduler.StepLR(self.optimizer_H, train_opt['weight_step'], gamma=train_opt['lr_gamma']) 140 | 141 | self.log_dict = OrderedDict() 142 | 143 | def feed_data(self, data): 144 | self.ref_L = data['LQ'].to(self.device) # LQ 145 | self.real_H = data['GT'].to(self.device) # GT 146 | 147 | def gaussian_batch(self, dims): 148 | return torch.randn(tuple(dims)).to(self.device) 149 | 150 | def loss_forward(self, out, y, z): 151 | l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y) 152 | 153 | z = z.reshape([out.shape[0], -1]) 154 | l_forw_ce = self.train_opt['lambda_ce_forw'] * torch.sum(z**2) / z.shape[0] 155 | 156 | return l_forw_fit, l_forw_ce 157 | 158 | def loss_backward(self, x, y): 159 | x_samples = self.netG(x=y, rev=True) 160 | x_samples_image = x_samples[:, :3, :, :] 161 | l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image) 162 | 163 | return l_back_rec 164 | 165 | def get_parameter_number(net): 166 | total_num = sum(p.numel() for p in net.parameters()) 167 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 168 | return {'Total': total_num, 'Trainable': trainable_num} 169 | 170 | def optimize_parameters(self, step): 171 | # torch.autograd.set_detect_anomaly(True) 172 | 173 | # downscaling 174 | self.input = self.real_H[1: ,: , :, :] 175 | self.output = self.netG(x=self.input) 176 | 177 | zshape = self.output[:, 3:, :, :].shape 178 | LR_ref = self.ref_L[1: ,: , :, :].detach() 179 | 180 | secret = self.output[:, :3, :, :] 181 | new_H = int(self.input.shape[2] / self.opt['scale_H']) 182 | new_W = int(self.input.shape[3] / self.opt['scale_W']) 183 | if self.opt['scale_H'] * self.opt['scale_W'] == 4: 184 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:], 185 | self.real_H[:1, :3, new_H:, :new_W], self.real_H[:1, :3, new_H:, new_W:]), dim=0) 186 | elif self.opt['scale_H'] * self.opt['scale_W'] == 6: 187 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], 188 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], 189 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2], 190 | self.real_H[:1, :3, new_H * 2:new_H * 4, :new_W], 191 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2],), dim=0) 192 | elif self.opt['scale_H'] * self.opt['scale_W'] == 8: 193 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], 194 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], 195 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2], 196 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], 197 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2], 198 | self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W], 199 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2]), dim=0) 200 | elif self.opt['scale_H'] * self.opt['scale_W'] == 9: 201 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_H * 2], 202 | self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3], 203 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], 204 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_H * 2], 205 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3], 206 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], 207 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_H * 2], 208 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3],), dim=0) 209 | elif self.opt['scale_H'] * self.opt['scale_W'] == 16: 210 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], 211 | self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3], 212 | self.real_H[:1, :3, :new_H, new_W * 3:new_W * 4], 213 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], 214 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2], 215 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3], 216 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 3:new_W * 4], 217 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], 218 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2], 219 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3], 220 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 3:new_W * 4], 221 | self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W], 222 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2], 223 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 2:new_W * 3], 224 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 3:new_W * 4], 225 | ), dim=0) 226 | 227 | cover_input = dwt(cover) 228 | secret_input = dwt(secret) 229 | input_img = torch.cat((cover_input, secret_input), 1) 230 | 231 | # hiding 232 | output = self.netH(input_img) 233 | 234 | channel_in = self.opt['network_G']['in_nc'] 235 | 236 | output_steg = output.narrow(1, 0, 4 * channel_in) 237 | output_z = output.narrow(1, 4 * channel_in, output.shape[1] - 4 * channel_in) 238 | steg_img = iwt(output_steg) 239 | steg_img = self.Quantization(steg_img) 240 | 241 | # concealing 242 | output_z_guass = gauss_noise(output_z.shape) 243 | 244 | output_rev = torch.cat((output_steg, output_z_guass), 1) 245 | output_image = self.netH(output_rev, rev=True) 246 | 247 | secret_rev = output_image.narrow(1, 4 * channel_in, output_image.shape[1] - 4 * channel_in) 248 | secret_rev_1 = iwt(secret_rev) 249 | 250 | 251 | # loss functions 252 | g_loss = guide_loss(steg_img.cuda(), cover.cuda()) 253 | r_loss = reconstruction_loss(secret_rev_1, secret[:,:3,:,:]) 254 | steg_low = output_steg.narrow(1, 0, channel_in) 255 | cover_low = cover_input.narrow(1, 0, channel_in) 256 | l_loss = low_frequency_loss(steg_low, cover_low) 257 | 258 | total_loss = self.train_opt['lamda_reconstruction'] * r_loss + self.train_opt['lamda_guide'] * g_loss + self.train_opt['lamda_low_frequency'] * l_loss 259 | 260 | l_forw_fit, l_forw_ce = self.loss_forward(secret_rev_1, LR_ref, self.output[:, 3:, :, :]) 261 | 262 | # upscaling 263 | LR = self.Quantization(secret_rev_1) 264 | gaussian_scale = self.train_opt['gaussian_scale'] if self.train_opt['gaussian_scale'] != None else 1 265 | y_ = torch.cat((LR, gaussian_scale * self.gaussian_batch(zshape)), dim=1) 266 | 267 | l_back_rec = self.loss_backward(self.real_H[1:,:,:,:], y_) 268 | 269 | total_loss.backward(retain_graph=True) 270 | 271 | loss = l_forw_fit + l_back_rec + l_forw_ce 272 | print("step", step, "_loss:", loss) 273 | loss.backward() 274 | 275 | # gradient clipping 276 | if self.train_opt['gradient_clipping']: 277 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) 278 | nn.utils.clip_grad_norm_(self.netH.parameters(), self.train_opt['gradient_clipping']) 279 | 280 | self.optimizer_G.step() 281 | self.optimizer_H.step() 282 | self.optimizer_H.zero_grad() 283 | self.optimizer_G.zero_grad() 284 | 285 | if step % self.opt['logger']['save_checkpoint_freq'] == 0: 286 | # save_path = os.path.join(self.opt['path']['models'], save_filename) 287 | torch.save({'net': self.netH.state_dict()}, os.path.join(self.opt['path']['models'], 'IIH_' + str(step)+'.pth')) 288 | 289 | 290 | def test(self): 291 | Lshape = self.ref_L[1:, :, :, :].shape 292 | 293 | input_dim = Lshape[1] 294 | self.input = self.real_H[1:] 295 | 296 | zshape = [Lshape[0], input_dim * (self.opt['scale_W'] * self.opt['scale_H']) - Lshape[1], Lshape[2], Lshape[3]] 297 | 298 | gaussian_scale = 1 299 | if self.test_opt and self.test_opt['gaussian_scale'] != None: 300 | gaussian_scale = self.test_opt['gaussian_scale'] 301 | 302 | self.netG.eval() 303 | self.netH.eval() 304 | new_H = int(self.input.shape[2] / self.opt['scale_H']) 305 | new_W = int(self.input.shape[3] / self.opt['scale_W']) 306 | if self.opt['scale_H']*self.opt['scale_W'] == 4: 307 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:], self.real_H[:1, :3, new_H:, :new_W], self.real_H[:1, :3, new_H:, new_W:]), dim=0) 308 | elif self.opt['scale_H']*self.opt['scale_W'] == 6: 309 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], 310 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2], 311 | self.real_H[:1, :3, new_H*2:new_H*4, :new_W], self.real_H[:1, :3, new_H*2:new_H*3, new_W:new_W * 2],), dim=0) 312 | elif self.opt['scale_H']*self.opt['scale_W'] == 8: 313 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], self.real_H[:1, :3, new_H:new_H * 2, :new_W], self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W*2], 314 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2], self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W], self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2]), dim=0) 315 | elif self.opt['scale_H']*self.opt['scale_W'] == 9: 316 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_H * 2], self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3], 317 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], self.real_H[:1, :3, new_H:new_H * 2, new_W:new_H * 2], self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3], 318 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_H * 2], self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3],), dim=0) 319 | elif self.opt['scale_H']*self.opt['scale_W'] == 16: 320 | cover = torch.cat((self.real_H[:1, :3, :new_H, :new_W], self.real_H[:1, :3, :new_H, new_W:new_W * 2], 321 | self.real_H[:1, :3, :new_H, new_W * 2:new_W * 3], 322 | self.real_H[:1, :3, :new_H, new_W * 3:new_W * 4], 323 | self.real_H[:1, :3, new_H:new_H * 2, :new_W], 324 | self.real_H[:1, :3, new_H:new_H * 2, new_W:new_W * 2], 325 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 2:new_W * 3], 326 | self.real_H[:1, :3, new_H:new_H * 2, new_W * 3:new_W * 4], 327 | self.real_H[:1, :3, new_H * 2:new_H * 3, :new_W], 328 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W:new_W * 2], 329 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 2:new_W * 3], 330 | self.real_H[:1, :3, new_H * 2:new_H * 3, new_W * 3:new_W * 4], 331 | self.real_H[:1, :3, new_H * 3:new_H * 4, :new_W], 332 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W:new_W * 2], 333 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 2:new_W * 3], 334 | self.real_H[:1, :3, new_H * 3:new_H * 4, new_W * 3:new_W * 4], 335 | ), dim=0) 336 | 337 | self.cover = self.real_H[:1, :3, :, :] 338 | self.secret = self.real_H[1:, :3, :, :] 339 | 340 | with torch.no_grad(): 341 | output = self.netG(x=self.input)[:, :3, :, :] 342 | cover_input = dwt(cover) 343 | secret_input = dwt(output) 344 | input_img = torch.cat((cover_input, secret_input), 1) 345 | channel_in = self.opt['network_G']['in_nc'] 346 | 347 | output_inn = self.netH(input_img) 348 | output_steg = output_inn.narrow(1, 0, 4 * channel_in) 349 | output_z = output_inn.narrow(1, 4 * channel_in, output_inn.shape[1] - 4 * channel_in) 350 | steg_img = iwt(output_steg) 351 | 352 | output_z_guass = gauss_noise(output_z.shape) 353 | 354 | output_rev = torch.cat((output_steg, output_z_guass), 1) 355 | output_image = self.netH(output_rev, rev=True) 356 | 357 | secret_rev = output_image.narrow(1, 4 * channel_in, output_image.shape[1] - 4 * channel_in) 358 | secret_rev = iwt(secret_rev) 359 | 360 | self.forw_L = secret_rev 361 | self.forw_L = self.Quantization(self.forw_L).cuda() 362 | y_forw = torch.cat((self.forw_L, gaussian_scale * self.gaussian_batch(zshape)), dim=1) 363 | self.fake_H = self.netG(x=y_forw, rev=True)[:, :3, :, :] 364 | 365 | self.secret_recover = self.fake_H 366 | 367 | if self.opt['scale_H'] * self.opt['scale_W'] == 4: 368 | steg_1 = torch.cat((steg_img[0], steg_img[1]), 2) 369 | steg_2 = torch.cat((steg_img[2], steg_img[3]), 2) 370 | steg = torch.cat((steg_1, steg_2), 1) 371 | elif self.opt['scale_H'] * self.opt['scale_W'] == 6: 372 | steg_1 = torch.cat((steg_img[0], steg_img[1]), 2) 373 | steg_2 = torch.cat((steg_img[2], steg_img[3]), 2) 374 | steg_3 = torch.cat((steg_img[4], steg_img[5]), 2) 375 | steg = torch.cat((steg_1, steg_2, steg_3), 1) 376 | elif self.opt['scale_H'] * self.opt['scale_W'] == 8: 377 | steg_1 = torch.cat((steg_img[0], steg_img[1]), 2) 378 | steg_2 = torch.cat((steg_img[2], steg_img[3]), 2) 379 | steg_3 = torch.cat((steg_img[4], steg_img[5]), 2) 380 | steg_4 = torch.cat((steg_img[6], steg_img[7]), 2) 381 | steg = torch.cat((steg_1, steg_2, steg_3, steg_4), 1) 382 | elif self.opt['scale_H'] * self.opt['scale_W'] == 9: 383 | steg_1 = torch.cat((steg_img[0], steg_img[1], steg_img[2]), 2) 384 | steg_2 = torch.cat((steg_img[3], steg_img[4], steg_img[5]), 2) 385 | steg_3 = torch.cat((steg_img[6], steg_img[7], steg_img[8]), 2) 386 | steg = torch.cat((steg_1, steg_2, steg_3), 1) 387 | elif self.opt['scale_H'] * self.opt['scale_W'] == 16: 388 | steg_1 = torch.cat((steg_img[0], steg_img[1], steg_img[2], steg_img[3]), 2) 389 | steg_2 = torch.cat((steg_img[4], steg_img[5], steg_img[6], steg_img[7]), 2) 390 | steg_3 = torch.cat((steg_img[8], steg_img[9], steg_img[10], steg_img[11]), 2) 391 | steg_4 = torch.cat((steg_img[12], steg_img[13], steg_img[14], steg_img[15]), 2) 392 | steg = torch.cat((steg_1, steg_2, steg_3, steg_4), 1) 393 | self.steg = self.Quantization(steg) 394 | 395 | def get_current_log(self): 396 | return self.log_dict 397 | 398 | def get_current_visuals(self): 399 | out_dict = OrderedDict() 400 | out_dict['cover'] = self.cover.detach().float().cpu() 401 | out_dict['secret'] = self.secret.detach().float().cpu() 402 | out_dict['secret_recover'] = self.secret_recover.detach().float().cpu() 403 | out_dict['steg'] = self.steg.detach().float().cpu() 404 | return out_dict 405 | 406 | def print_network(self): 407 | s, n = self.get_network_description(self.netG) 408 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 409 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 410 | self.netG.module.__class__.__name__) 411 | else: 412 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 413 | if self.rank <= 0: 414 | logger.info('Network IIR structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 415 | logger.info(s) 416 | 417 | def load(self): 418 | load_path_G = self.opt['path']['pretrain_model_G'] 419 | if load_path_G is not None: 420 | logger.info('Loading model for IIR [{:s}] ...'.format(load_path_G)) 421 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 422 | 423 | def load_H(self, net): 424 | load_path_H = self.opt['path']['pretrain_model_H'] 425 | if load_path_H is not None: 426 | logger.info('Loading model for IIH [{:s}] ...'.format(load_path_H)) 427 | state_dicts = torch.load(load_path_H) 428 | network_state_dict = {k.replace("module.", ""): v for k, v in state_dicts['net'].items() if 429 | 'tmp_var' not in k} 430 | network_state_dict = {k: v for k, v in network_state_dict.items() if 'rect' not in k} 431 | net.load_state_dict(network_state_dict) 432 | 433 | def save(self, iter_label): 434 | self.save_network(self.netG, 'IIR', iter_label) 435 | -------------------------------------------------------------------------------- /codes/models/modules/Unet_common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.modules import module_util as mutil 7 | import functools 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1, use_snorm=False): 10 | if use_snorm: 11 | return nn.utils.spectral_norm(nn.Conv2d( 12 | in_channels, out_channels, kernel_size, 13 | padding=(kernel_size//2)+dilation-1, bias=bias, dilation=dilation)) 14 | else: 15 | return nn.Conv2d( 16 | in_channels, out_channels, kernel_size, 17 | padding=(kernel_size//2)+dilation-1, bias=bias, dilation=dilation) 18 | 19 | 20 | def default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3, use_snorm=False): 21 | if use_snorm: 22 | return nn.utils.spectral_norm(nn.Conv2d( 23 | in_channels,out_channels, kernel_size, 24 | padding=(kernel_size//2), bias=bias, groups=groups)) 25 | else: 26 | return nn.Conv2d( 27 | in_channels,out_channels, kernel_size, 28 | padding=(kernel_size//2), bias=bias, groups=groups) 29 | 30 | def default_conv3d(in_channels, out_channels, kernel_size, t_kernel=3, bias=True, dilation=1, groups=1, use_snorm=False): 31 | if use_snorm: 32 | return nn.utils.spectral_norm(nn.Conv3d( 33 | in_channels,out_channels, (t_kernel, kernel_size, kernel_size), stride=1, 34 | padding=(0,kernel_size//2,kernel_size//2), bias=bias, dilation=dilation, groups=groups)) 35 | else: 36 | return nn.Conv3d( 37 | in_channels,out_channels, (t_kernel, kernel_size, kernel_size), stride=1, 38 | padding=(0,kernel_size//2,kernel_size//2), bias=bias, dilation=dilation, groups=groups) 39 | 40 | #def shuffle_channel() 41 | 42 | def channel_shuffle(x, groups): 43 | batchsize, num_channels, height, width = x.size() 44 | 45 | channels_per_group = num_channels // groups 46 | 47 | # reshape 48 | x = x.view(batchsize, groups, 49 | channels_per_group, height, width) 50 | 51 | x = torch.transpose(x, 1, 2).contiguous() 52 | 53 | # flatten 54 | x = x.view(batchsize, -1, height, width) 55 | 56 | return x 57 | 58 | def pixel_down_shuffle(x, downsacale_factor): 59 | batchsize, num_channels, height, width = x.size() 60 | 61 | out_height = height // downsacale_factor 62 | out_width = width // downsacale_factor 63 | input_view = x.contiguous().view(batchsize, num_channels, out_height, downsacale_factor, out_width, 64 | downsacale_factor) 65 | 66 | num_channels *= downsacale_factor ** 2 67 | unshuffle_out = input_view.permute(0,1,3,5,2,4).contiguous() 68 | 69 | return unshuffle_out.view(batchsize, num_channels, out_height, out_width) 70 | 71 | 72 | 73 | def sp_init(x): 74 | 75 | x01 = x[:, :, 0::2, :] 76 | x02 = x[:, :, 1::2, :] 77 | x_LL = x01[:, :, :, 0::2] 78 | x_HL = x02[:, :, :, 0::2] 79 | x_LH = x01[:, :, :, 1::2] 80 | x_HH = x02[:, :, :, 1::2] 81 | 82 | 83 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 84 | 85 | def dwt_init3d(x): 86 | 87 | x01 = x[:, :, :, 0::2, :] / 2 88 | x02 = x[:, :, :, 1::2, :] / 2 89 | x1 = x01[:, :, :, :, 0::2] 90 | x2 = x02[:, :, :, :, 0::2] 91 | x3 = x01[:, :, :, :, 1::2] 92 | x4 = x02[:, :, :, :, 1::2] 93 | x_LL = x1 + x2 + x3 + x4 94 | x_HL = -x1 - x2 + x3 + x4 95 | x_LH = -x1 + x2 - x3 + x4 96 | x_HH = x1 - x2 - x3 + x4 97 | 98 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 99 | 100 | def dwt_init(x): 101 | 102 | x01 = x[:, :, 0::2, :] / 2 103 | x02 = x[:, :, 1::2, :] / 2 104 | x1 = x01[:, :, :, 0::2] 105 | x2 = x02[:, :, :, 0::2] 106 | x3 = x01[:, :, :, 1::2] 107 | x4 = x02[:, :, :, 1::2] 108 | x_LL = x1 + x2 + x3 + x4 109 | x_HL = -x1 - x2 + x3 + x4 110 | x_LH = -x1 + x2 - x3 + x4 111 | x_HH = x1 - x2 - x3 + x4 112 | 113 | return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) 114 | 115 | def iwt_init(x): 116 | r = 2 117 | in_batch, in_channel, in_height, in_width = x.size() 118 | #print([in_batch, in_channel, in_height, in_width]) 119 | out_batch, out_channel, out_height, out_width = in_batch, int( 120 | in_channel / (r ** 2)), r * in_height, r * in_width 121 | x1 = x[:, 0:out_channel, :, :] / 2 122 | x2 = x[:, out_channel:out_channel * 2, :, :] / 2 123 | x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 124 | x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 125 | 126 | 127 | h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda() 128 | 129 | h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 130 | h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 131 | h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 132 | h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 133 | 134 | return h 135 | 136 | class ResidualDenseBlock(nn.Module): 137 | def __init__(self, nf=64, gc=32, kernel_size = 3, bias=True, use_snorm=False): 138 | super(ResidualDenseBlock, self).__init__() 139 | # gc: growth channel, i.e. intermediate channels 140 | if use_snorm: 141 | self.conv1 = nn.utils.spectral_norm(nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)) 142 | self.conv2 = nn.utils.spectral_norm(nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)) 143 | self.conv3 = nn.utils.spectral_norm(nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)) 144 | self.conv4 = nn.utils.spectral_norm(nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)) 145 | self.conv5 = nn.utils.spectral_norm(nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)) 146 | else: 147 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 148 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 149 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 150 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 151 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 152 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 153 | 154 | # initialization 155 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 156 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4], 0.1) 157 | 158 | def forward(self, x): 159 | x1 = self.lrelu(self.conv1(x)) 160 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 161 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 162 | # x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) 163 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 164 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 165 | return x5 * 0.2 + x 166 | 167 | class RRDB(nn.Module): 168 | '''Residual in Residual Dense Block''' 169 | 170 | def __init__(self, nf, gc=32, use_snorm=False): 171 | super(RRDB, self).__init__() 172 | self.RDB1 = ResidualDenseBlock(nf, gc, use_snorm) 173 | # self.RDB2 = ResidualDenseBlock(nf, gc) 174 | # self.RDB3 = ResidualDenseBlock(nf, gc) 175 | 176 | def forward(self, x): 177 | out = self.RDB1(x) 178 | # out = self.RDB2(out) 179 | # out = self.RDB3(out) 180 | return out * 0.2 + x 181 | 182 | class RRDBblock(nn.Module): 183 | '''Residual in Residual Dense Block''' 184 | 185 | def __init__(self, nf, gc=32, nb=23, use_snorm=False): 186 | super(RRDBblock, self).__init__() 187 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc, use_snorm=use_snorm) 188 | 189 | self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) 190 | if use_snorm: 191 | self.trunk_conv = nn.utils.spectral_norm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)) 192 | else: 193 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 194 | 195 | def forward(self, x): 196 | 197 | return self.trunk_conv(self.RRDB_trunk(x)) 198 | 199 | class Channel_Shuffle(nn.Module): 200 | def __init__(self, conv_groups): 201 | super(Channel_Shuffle, self).__init__() 202 | self.conv_groups = conv_groups 203 | self.requires_grad = False 204 | 205 | def forward(self, x): 206 | return channel_shuffle(x, self.conv_groups) 207 | 208 | class SP(nn.Module): 209 | def __init__(self): 210 | super(SP, self).__init__() 211 | self.requires_grad = False 212 | 213 | def forward(self, x): 214 | return sp_init(x) 215 | 216 | class Pixel_Down_Shuffle(nn.Module): 217 | def __init__(self): 218 | super(Pixel_Down_Shuffle, self).__init__() 219 | self.requires_grad = False 220 | 221 | def forward(self, x): 222 | return pixel_down_shuffle(x, 2) 223 | 224 | class DWT(nn.Module): 225 | def __init__(self): 226 | super(DWT, self).__init__() 227 | self.requires_grad = False 228 | 229 | def forward(self, x): 230 | return dwt_init(x) 231 | 232 | class DWT3d(nn.Module): 233 | def __init__(self): 234 | super(DWT3d, self).__init__() 235 | self.requires_grad = False 236 | 237 | def forward(self, x): 238 | return dwt_init3d(x) 239 | 240 | class IWT(nn.Module): 241 | def __init__(self): 242 | super(IWT, self).__init__() 243 | self.requires_grad = False 244 | 245 | def forward(self, x): 246 | return iwt_init(x) 247 | 248 | 249 | class MeanShift(nn.Conv2d): 250 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 251 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 252 | std = torch.Tensor(rgb_std) 253 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 254 | self.weight.data.div_(std.view(3, 1, 1, 1)) 255 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 256 | self.bias.data.div_(std) 257 | self.requires_grad = False 258 | if sign==-1: 259 | self.create_graph = False 260 | self.volatile = True 261 | class MeanShift2(nn.Conv2d): 262 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 263 | super(MeanShift2, self).__init__(4, 4, kernel_size=1) 264 | std = torch.Tensor(rgb_std) 265 | self.weight.data = torch.eye(4).view(4, 4, 1, 1) 266 | self.weight.data.div_(std.view(4, 1, 1, 1)) 267 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 268 | self.bias.data.div_(std) 269 | self.requires_grad = False 270 | if sign==-1: 271 | self.volatile = True 272 | 273 | class BasicBlock(nn.Sequential): 274 | def __init__( 275 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 276 | bn=False, act=nn.LeakyReLU(True), use_snorm=False): 277 | 278 | if use_snorm: 279 | m = [nn.utils.spectral_norm(nn.Conv2d( 280 | in_channels, out_channels, kernel_size, 281 | padding=(kernel_size//2), stride=stride, bias=bias)) 282 | ] 283 | else: 284 | m = [nn.Conv2d( 285 | in_channels, out_channels, kernel_size, 286 | padding=(kernel_size//2), stride=stride, bias=bias) 287 | ] 288 | if bn: m.append(nn.BatchNorm2d(out_channels)) 289 | if act is not None: m.append(act) 290 | super(BasicBlock, self).__init__(*m) 291 | 292 | class Block3d(nn.Sequential): 293 | def __init__( 294 | self, in_channels, out_channels, kernel_size, t_kernel=3, 295 | bias=True, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 296 | 297 | super(Block3d, self).__init__() 298 | m = [] 299 | 300 | m.append(default_conv3d(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm)) 301 | m.append(act) 302 | m.append(default_conv3d(out_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm)) 303 | m.append(act) 304 | 305 | 306 | self.body = nn.Sequential(*m) 307 | self.res_scale = res_scale 308 | 309 | def forward(self, x): 310 | x = self.body(x) 311 | return x 312 | 313 | class BBlock(nn.Module): 314 | def __init__( 315 | self, conv, in_channels, out_channels, kernel_size, 316 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 317 | 318 | super(BBlock, self).__init__() 319 | m = [] 320 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm)) 321 | 322 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 323 | m.append(act) 324 | 325 | 326 | self.body = nn.Sequential(*m) 327 | self.res_scale = res_scale 328 | 329 | def forward(self, x): 330 | x = self.body(x) 331 | return x 332 | 333 | class DBlock_com(nn.Module): 334 | def __init__( 335 | self, conv, in_channels, out_channels, kernel_size, 336 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 337 | 338 | super(DBlock_com, self).__init__() 339 | m = [] 340 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 341 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 342 | m.append(act) 343 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3, use_snorm=use_snorm)) 344 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 345 | m.append(act) 346 | 347 | 348 | self.body = nn.Sequential(*m) 349 | self.res_scale = res_scale 350 | 351 | def forward(self, x): 352 | x = self.body(x) 353 | return x 354 | 355 | class DBlock_inv(nn.Module): 356 | def __init__( 357 | self, conv, in_channels, out_channels, kernel_size, 358 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 359 | 360 | super(DBlock_inv, self).__init__() 361 | m = [] 362 | 363 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=3, use_snorm=use_snorm)) 364 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 365 | m.append(act) 366 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 367 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 368 | m.append(act) 369 | 370 | 371 | self.body = nn.Sequential(*m) 372 | self.res_scale = res_scale 373 | 374 | def forward(self, x): 375 | x = self.body(x) 376 | return x 377 | 378 | class DBlock_com1(nn.Module): 379 | def __init__( 380 | self, conv, in_channels, out_channels, kernel_size, 381 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 382 | 383 | super(DBlock_com1, self).__init__() 384 | m = [] 385 | 386 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 387 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 388 | m.append(act) 389 | m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=1, use_snorm=use_snorm)) 390 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 391 | m.append(act) 392 | 393 | 394 | self.body = nn.Sequential(*m) 395 | self.res_scale = res_scale 396 | 397 | def forward(self, x): 398 | x = self.body(x) 399 | return x 400 | 401 | class DBlock_inv1(nn.Module): 402 | def __init__( 403 | self, conv, in_channels, out_channels, kernel_size, 404 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 405 | 406 | super(DBlock_inv1, self).__init__() 407 | m = [] 408 | 409 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 410 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 411 | m.append(act) 412 | m.append(conv(out_channels, out_channels, kernel_size, bias=bias, dilation=1, use_snorm=use_snorm)) 413 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 414 | m.append(act) 415 | 416 | 417 | self.body = nn.Sequential(*m) 418 | self.res_scale = res_scale 419 | 420 | def forward(self, x): 421 | x = self.body(x) 422 | return x 423 | 424 | class DBlock_com2(nn.Module): 425 | def __init__( 426 | self, conv, in_channels, out_channels, kernel_size, 427 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 428 | 429 | super(DBlock_com2, self).__init__() 430 | m = [] 431 | 432 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 433 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 434 | m.append(act) 435 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 436 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 437 | m.append(act) 438 | 439 | 440 | self.body = nn.Sequential(*m) 441 | self.res_scale = res_scale 442 | 443 | def forward(self, x): 444 | x = self.body(x) 445 | return x 446 | 447 | class DBlock_inv2(nn.Module): 448 | def __init__( 449 | self, conv, in_channels, out_channels, kernel_size, 450 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 451 | 452 | super(DBlock_inv2, self).__init__() 453 | m = [] 454 | 455 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 456 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 457 | m.append(act) 458 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, dilation=2, use_snorm=use_snorm)) 459 | if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95)) 460 | m.append(act) 461 | 462 | 463 | self.body = nn.Sequential(*m) 464 | self.res_scale = res_scale 465 | 466 | def forward(self, x): 467 | x = self.body(x) 468 | return x 469 | 470 | class ShuffleBlock(nn.Module): 471 | def __init__( 472 | self, conv, in_channels, out_channels, kernel_size, 473 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1,conv_groups=1, use_snorm=False): 474 | 475 | super(ShuffleBlock, self).__init__() 476 | m = [] 477 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm)) 478 | m.append(Channel_Shuffle(conv_groups)) 479 | if bn: m.append(nn.BatchNorm2d(out_channels)) 480 | m.append(act) 481 | 482 | 483 | self.body = nn.Sequential(*m) 484 | self.res_scale = res_scale 485 | 486 | def forward(self, x): 487 | x = self.body(x).mul(self.res_scale) 488 | return x 489 | 490 | 491 | class DWBlock(nn.Module): 492 | def __init__( 493 | self, conv, conv1, in_channels, out_channels, kernel_size, 494 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 495 | 496 | super(DWBlock, self).__init__() 497 | m = [] 498 | m.append(conv(in_channels, out_channels, kernel_size, bias=bias, use_snorm=use_snorm)) 499 | if bn: m.append(nn.BatchNorm2d(out_channels)) 500 | m.append(act) 501 | 502 | m.append(conv1(in_channels, out_channels, 1, bias=bias, use_snorm=use_snorm)) 503 | if bn: m.append(nn.BatchNorm2d(out_channels)) 504 | m.append(act) 505 | 506 | 507 | self.body = nn.Sequential(*m) 508 | self.res_scale = res_scale 509 | 510 | def forward(self, x): 511 | x = self.body(x).mul(self.res_scale) 512 | return x 513 | 514 | class ResBlock(nn.Module): 515 | def __init__( 516 | self, conv, n_feat, kernel_size, 517 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 518 | 519 | super(ResBlock, self).__init__() 520 | m = [] 521 | for i in range(2): 522 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias, use_snorm=use_snorm)) 523 | if bn: m.append(nn.BatchNorm2d(n_feat)) 524 | if i == 0: m.append(act) 525 | 526 | self.body = nn.Sequential(*m) 527 | self.res_scale = res_scale 528 | 529 | def forward(self, x): 530 | res = self.body(x).mul(self.res_scale) 531 | res += x 532 | 533 | return res 534 | 535 | class Block(nn.Module): 536 | def __init__( 537 | self, conv, n_feat, kernel_size, 538 | bias=True, bn=False, act=nn.LeakyReLU(True), res_scale=1, use_snorm=False): 539 | 540 | super(Block, self).__init__() 541 | m = [] 542 | for i in range(4): 543 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias, use_snorm=use_snorm)) 544 | if bn: m.append(nn.BatchNorm2d(n_feat)) 545 | if i == 0: m.append(act) 546 | 547 | self.body = nn.Sequential(*m) 548 | self.res_scale = res_scale 549 | 550 | def forward(self, x): 551 | res = self.body(x).mul(self.res_scale) 552 | # res += x 553 | 554 | return res 555 | 556 | class Upsampler(nn.Sequential): 557 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True, use_snorm=False): 558 | 559 | m = [] 560 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 561 | for _ in range(int(math.log(scale, 2))): 562 | m.append(conv(n_feat, 4 * n_feat, 3, bias, use_snorm=use_snorm)) 563 | m.append(nn.PixelShuffle(2)) 564 | if bn: m.append(nn.BatchNorm2d(n_feat)) 565 | if act: m.append(act()) 566 | elif scale == 3: 567 | m.append(conv(n_feat, 9 * n_feat, 3, bias, use_snorm=use_snorm)) 568 | m.append(nn.PixelShuffle(3)) 569 | if bn: m.append(nn.BatchNorm2d(n_feat)) 570 | if act: m.append(act()) 571 | else: 572 | raise NotImplementedError 573 | 574 | super(Upsampler, self).__init__(*m) 575 | 576 | class VGG_conv0(nn.Module): 577 | def __init__(self, in_nc, nf): 578 | 579 | super(VGG_conv0, self).__init__() 580 | # [64, 128, 128] 581 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 582 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 583 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 584 | # [64, 64, 64] 585 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 586 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 587 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 588 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 589 | # [128, 32, 32] 590 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 591 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 592 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 593 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 594 | # [256, 16, 16] 595 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 596 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 597 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 598 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 599 | # [512, 8, 8] 600 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 601 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 602 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 603 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 604 | 605 | # self.avg_pool = nn.AvgPool2d(3, stride=2, padding=0, ceil_mode=True) # /2 606 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 607 | 608 | def forward(self, x): 609 | fea = self.lrelu(self.conv0_0(x)) 610 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 611 | 612 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 613 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 614 | 615 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 616 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 617 | 618 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 619 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 620 | 621 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 622 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 623 | # fea = self.avg_pool(fea) 624 | 625 | return fea 626 | 627 | class VGG_conv1(nn.Module): 628 | def __init__(self, in_nc, nf): 629 | 630 | super(VGG_conv1, self).__init__() 631 | # [64, 128, 128] 632 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 633 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 634 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 635 | # [64, 64, 64] 636 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 637 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 638 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 639 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 640 | # [128, 32, 32] 641 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 642 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 643 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 644 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 645 | # [256, 16, 16] 646 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 647 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 648 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 649 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 650 | # [512, 8, 8] 651 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 652 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 653 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 654 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 655 | 656 | # self.avg_pool = nn.AvgPool2d(2, stride=1, padding=0, ceil_mode=True) # /2 657 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 658 | 659 | def forward(self, x): 660 | fea = self.lrelu(self.conv0_0(x)) 661 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 662 | 663 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 664 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 665 | 666 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 667 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 668 | 669 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 670 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 671 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 672 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 673 | # fea = self.avg_pool(fea) 674 | 675 | return fea 676 | 677 | class VGG_conv2(nn.Module): 678 | def __init__(self, in_nc, nf): 679 | 680 | super(VGG_conv2, self).__init__() 681 | # [64, 128, 128] 682 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 683 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 684 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 685 | # [64, 64, 64] 686 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 687 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 688 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 689 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 690 | # [128, 32, 32] 691 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 692 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 693 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 694 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 695 | # [256, 16, 16] 696 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 697 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 698 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 699 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 700 | # [512, 8, 8] 701 | # self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 702 | # self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 703 | # self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 704 | # self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 705 | 706 | # self.avg_pool = nn.AvgPool2d(3, stride=2, padding=0, ceil_mode=True) # /2 707 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 708 | 709 | def forward(self, x): 710 | fea = self.lrelu(self.conv0_0(x)) 711 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 712 | 713 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 714 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 715 | 716 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 717 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 718 | 719 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 720 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 721 | 722 | # fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 723 | # fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 724 | # fea = self.avg_pool(fea) 725 | 726 | return fea --------------------------------------------------------------------------------