├── data
│   ├── __init__.py
│   ├── Matlab
│   │   ├── Demo.m
│   │   ├── generate_structure_images.m
│   │   ├── tsmooth.m
│   │   └── dirPlus.m
│   └── dataprocess.py
├── models
│   ├── __init__.py
│   ├── model.py
│   ├── FAM
│   │   ├── FeatureAlignment.py
│   │   ├── DeformableBlock.py
│   │   ├── Dynamic_offset_estimator.py
│   │   ├── Model_utils.py
│   │   └── non_local_embedded_gaussian.py
│   ├── InnerCos.py
│   ├── Discriminator.py
│   ├── Decoder.py
│   ├── base_model.py
│   ├── networks.py
│   ├── Encoder.py
│   ├── loss.py
│   ├── RGTSI.py
│   └── PCconv.py
├── util
│   ├── __init__.py
│   ├── se_module.py
│   ├── Selfpatch.py
│   └── util.py
├── options
│   ├── __init__.py
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── imgs
│   └── pipeline.png
├── train.py
├── requirements.txt
├── test.py
├── README.md
└── LICENSE

--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/imgs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cameltr/RGTSI/HEAD/imgs/pipeline.png
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | from models.RGTSI import RGTSI
2 | import torch
3 | 
4 | 
5 | def create_model(opt):
6 |     model = RGTSI(opt)
7 |     # model = torch.nn.DataParallel(model.to(opt.device), device_ids=opt.gpu_ids, output_device=opt.gpu_ids[0])
8 |     print("model [%s] was created" % (model.name()))
9 |     return model
10 | 
--------------------------------------------------------------------------------
/models/FAM/FeatureAlignment.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | 
4 | from models.FAM.DeformableBlock import DeformableConvBlock
5 | from util.util import showpatch
6 | 
7 | class FAM(nn.Module):
8 |     def __init__(self, in_channels):
9 |         super(FAM, self).__init__()
10 |         self.deformblock = DeformableConvBlock(input_channels=in_channels*2)
11 | 
12 |     def forward(self, ist_feature, rst_feature):
13 | 
14 |         st_out = self.deformblock(ist_feature, rst_feature)  # outputs the aligned feature
15 |         out = torch.add(ist_feature, st_out)
16 | 
17 |         return out
--------------------------------------------------------------------------------
/data/Matlab/Demo.m:
--------------------------------------------------------------------------------
1 | % Demo script
2 | % Uncomment each case to see the results
3 | 
4 | I = (imread('imgs/Bishapur_zan.jpg'));
5 | S = tsmooth(I,0.015,3);
6 | figure, imshow(I), figure, imshow(S);
7 | 
8 | % I = (imread('imgs/graffiti.jpg'));
9 | % S = tsmooth(I,0.015,3);
10 | % figure, imshow(I), figure, imshow(S);
11 | 
12 | % I = (imread('imgs/crossstitch.jpg'));
13 | % S = tsmooth(I,0.015,3);
14 | % figure, imshow(I), figure, imshow(S);
15 | 
16 | 
17 | % I = (imread('imgs/mosaicfloor.jpg'));
18 | % S = tsmooth(I, 0.01, 3, 0.02, 5);
19 | % figure,
imshow(I), figure, imshow(S); 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /data/Matlab/generate_structure_images.m: -------------------------------------------------------------------------------- 1 | dataset_path='/project/liutaorong/ICIP/data/DPED10K/train/image/' 2 | output_path='/project/liutaorong/ICIP/data/DPED10K/train/structure/' 3 | 4 | image_list = dirPlus(dataset_path, 'FileFilter', '\.(jpg|png|tif)$'); 5 | num_image = numel(image_list); 6 | for i=1:num_image 7 | image_name = image_list{i}; 8 | image = im2double(imread(image_name)); 9 | S = tsmooth(image, 0.015, 3, 0.001, 3); 10 | write_name = strrep(image_name, dataset_path, output_path); 11 | [filepath,~,~] = fileparts(write_name); 12 | if ~exist(filepath, 'dir') 13 | mkdir(filepath); 14 | end 15 | imwrite(S, write_name); 16 | 17 | if mod(i,100)==0 18 | fprintf('total: %d; output: %d; completed: %f%% \n',num_image, i, (i/num_image)*100) ; 19 | end 20 | end 21 | 22 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | parser.add_argument('--which_epoch', type=str, default='16', help='which epoch to load? 
set to latest to use latest cached model')
12 |         parser.add_argument('--how_many', type=int, default=200, help='how many test images to run')
13 |         self.isTrain = False
14 | 
15 |         return parser
16 | 
--------------------------------------------------------------------------------
/util/se_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | 
5 | class SELayer(nn.Module):
6 |     def __init__(self, channel, reduction=16):
7 |         super(SELayer, self).__init__()
8 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
9 |         self.fc = nn.Sequential(
10 |             nn.Conv2d(channel, channel // reduction, kernel_size=1, stride=1, padding=0),
11 |             nn.ReLU(inplace=True),
12 |             nn.Conv2d(channel // reduction, channel, kernel_size=1, stride=1, padding=0),
13 |             nn.Sigmoid()
14 |         )
15 | 
16 |     def forward(self, x):
17 |         b, c, _, _ = x.size()
18 |         y = self.avg_pool(x).view(b, c, 1, 1)
19 |         y = self.fc(y)
20 |         return x * y.expand_as(x)
21 |         # b, c, _, _ = x.size()
22 |         # latter=x.clone()
23 |         # y = self.avg_pool(x).view(b, c,1,1)
24 |         #
25 |         # y = self.fc(y)
26 |         # top,ind=torch.topk(y,int(c/2),1)
27 |         # ind=ind.view(-1)
28 |         #
29 |         # x=torch.index_select(x,1,ind)
30 |         #
31 |         # return torch.cat([x, latter], 1)
32 | 
--------------------------------------------------------------------------------
/models/FAM/DeformableBlock.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from models.FAM.Dynamic_offset_estimator import Dynamic_offset_estimator
5 | from mmcv.ops.deform_conv import DeformConv2d
6 | from util.util import saveoffset, showpatch
7 | 
8 | class DeformableConvBlock(nn.Module):
9 |     def __init__(self, input_channels):
10 |         super(DeformableConvBlock, self).__init__()
11 | 
12 |         self.offset_estimator = Dynamic_offset_estimator(input_channelsize=input_channels)
13 |         self.offset_conv = nn.Conv2d(in_channels=input_channels, out_channels=1 * 2 * 9, kernel_size=3, padding=1, bias=False)
14 | 
15 |         self.deformconv = DeformConv2d(in_channels=768, out_channels=768, kernel_size=3,
16 |                                        padding=1, bias=False)  # NOTE: 768 channels are hard-coded and must match the reference feature maps
17 | 
18 |     def forward(self, input_features, reference_features):
19 | 
20 |         input_offset = torch.cat((input_features, reference_features), dim=1)
21 |         estimated_offset = self.offset_estimator(input_offset)
22 |         estimated_offset = self.offset_conv(estimated_offset)
23 |         output = self.deformconv(x=reference_features, offset=estimated_offset)
24 | 
25 | 
26 |         return output
27 |         # returns the aligned feature
--------------------------------------------------------------------------------
/models/InnerCos.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import util.util as util
6 | class InnerCos(nn.Module):
7 |     def __init__(self):
8 |         super(InnerCos, self).__init__()
9 |         self.criterion = nn.L1Loss()
10 |         self.target = None
11 |         self.down_model = nn.Sequential(
12 |             nn.Conv2d(256, 3, kernel_size=1, stride=1, padding=0),
13 |             nn.Tanh()
14 |         )
15 | 
16 |     def set_target(self, targetde, targetst):
17 |         self.targetst = F.interpolate(targetst, size=(32, 32), mode='bilinear')
18 |         self.targetde = F.interpolate(targetde, size=(32, 32), mode='bilinear')
19 | 
20 |     def get_target(self):
21 |         return self.target
22 | 
23 |     def forward(self, in_data):
24 |         self.ST = self.down_model(in_data[6])
25 |         self.DE = self.down_model(in_data[7])
26 |         self.loss = self.criterion(self.ST, self.targetst) + self.criterion(self.DE, self.targetde)
27 |         self.output = [in_data[0], in_data[1], in_data[2], in_data[3], in_data[4], in_data[5]]
28 |         return self.output
29 | 
30 |     def backward(self, retain_graph=True):
31 | 
32 |         self.loss.backward(retain_graph=retain_graph)
33 |         return self.loss
34 | 
35 |     def __repr__(self):
36 | 
37 |         return self.__class__.__name__
--------------------------------------------------------------------------------
/models/Discriminator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import functools
3 | def spectral_norm(module, mode=True):
4 |     if mode:
5 |         return nn.utils.spectral_norm(module)
6 | 
7 |     return module
8 | 
9 | class NLayerDiscriminator(nn.Module):
10 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
11 |         super(NLayerDiscriminator, self).__init__()
12 |         if type(norm_layer) == functools.partial:
13 |             use_bias = norm_layer.func == nn.InstanceNorm2d
14 |         else:
15 |             use_bias = norm_layer == nn.InstanceNorm2d
16 | 
17 |         kw = 4
18 |         padw = 1
19 |         sequence = [
20 |             spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), True),
21 |             nn.LeakyReLU(0.2, True)
22 |         ]
23 | 
24 |         nf_mult = 1
25 |         nf_mult_prev = 1
26 |         for n in range(1, n_layers):
27 |             nf_mult_prev = nf_mult
28 |             nf_mult = min(2**n, 8)
29 |             sequence += [
30 |                 spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
31 |                               kernel_size=kw, stride=2, padding=padw, bias=use_bias), True),
32 | 
33 |                 nn.LeakyReLU(0.2, True),
34 |             ]
35 | 
36 |         nf_mult_prev = nf_mult
37 |         nf_mult = min(2**n_layers, 8)
38 |         sequence += [
39 |             spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
40 |                           kernel_size=kw, stride=2, padding=padw, bias=use_bias), True),
41 | 
42 |             nn.LeakyReLU(0.2, True)
43 |         ]
44 | 
45 |         sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
46 |                                    kernel_size=kw, stride=2, padding=padw, bias=use_bias), True)]
47 | 
48 |         if use_sigmoid:
49 |             sequence += [nn.Sigmoid()]
50 | 
51 |         self.model = nn.Sequential(*sequence)
52 | 
53 |     def forward(self, input):
54 |         return self.model(input)
--------------------------------------------------------------------------------
/data/dataprocess.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.utils.data
4 | from PIL import Image
5 | from glob import glob
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 | 
9 | class DataProcess(torch.utils.data.Dataset):
10 |     def __init__(self, de_root, st_root, input_mask_root, ref_root, opt, train=True):
11 |         super(DataProcess, self).__init__()
12 |         self.img_transform = transforms.Compose([
13 |             transforms.Resize((opt.fineSize, opt.fineSize)),
14 |             transforms.ToTensor(),
15 |             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
16 |         ])
17 |         # masks must not be normalized; they only contain the values 0 and 1
18 |         self.mask_transform = transforms.Compose([
19 |             transforms.Resize((opt.fineSize, opt.fineSize)),
20 |             transforms.ToTensor()
21 |         ])
22 |         self.Train = False
23 |         self.opt = opt
24 |         if train:
25 |             self.de_paths = sorted(glob('{:s}/*'.format(de_root), recursive=True))
26 |             self.st_paths = sorted(glob('{:s}/*'.format(st_root), recursive=True))
27 |             self.mask_paths = sorted(glob('{:s}/*'.format(input_mask_root), recursive=True))
28 | 
29 |             self.ref_paths = sorted(glob('{:s}/*'.format(ref_root), recursive=True))
30 |             self.Train = True
31 | 
32 |         self.N_mask = len(self.mask_paths)
33 | 
print(self.N_mask) 34 | 35 | def __getitem__(self, index): 36 | 37 | de_img = Image.open(self.de_paths[index]) 38 | st_img = Image.open(self.st_paths[index]) 39 | ref_img = Image.open(self.ref_paths[index]) 40 | mask_img = Image.open(self.mask_paths[random.randint(0, self.N_mask - 1)]) 41 | 42 | de_img = self.img_transform(de_img.convert('RGB')) 43 | st_img = self.img_transform(st_img.convert('RGB')) 44 | ref_img = self.img_transform(ref_img.convert('RGB')) 45 | mask_img = self.mask_transform(mask_img.convert('RGB')) 46 | 47 | return de_img, st_img,mask_img,ref_img 48 | 49 | def __len__(self): 50 | return len(self.de_paths) 51 | -------------------------------------------------------------------------------- /models/FAM/Dynamic_offset_estimator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from models.FAM.non_local_embedded_gaussian import NONLocalBlock2D 5 | from models.FAM.Model_utils import DOE_downsample_block, DOE_upsample_block 6 | 7 | class Dynamic_offset_estimator(nn.Module): 8 | def __init__(self,input_channelsize): 9 | super(Dynamic_offset_estimator, self).__init__() 10 | self.downblock1 = DOE_downsample_block(input_channelsize) 11 | self.downblock2 = DOE_downsample_block(64) 12 | self.downblock3 = DOE_downsample_block(64) 13 | 14 | self.attentionblock1 = NONLocalBlock2D(in_channels=64) 15 | self.attentionblock2 = NONLocalBlock2D(in_channels=64) 16 | self.attentionblock3 = NONLocalBlock2D(in_channels=64) 17 | 18 | self.upblock1 = DOE_upsample_block(in_channels=64,out_channels=64) 19 | self.upblock2 = DOE_upsample_block(in_channels=64,out_channels=64) 20 | self.upblock3 = DOE_upsample_block(in_channels=64,out_channels=64) 21 | 22 | self.channelscaling_block = nn.Conv2d(in_channels= 64, out_channels=input_channelsize, kernel_size=3, padding=1, bias=True) 23 | 24 | def forward(self,x): 25 | halfscale_feature = self.downblock1(x)#1/2 26 | quarterscale_feature = self.downblock2(halfscale_feature)#1/4 27 | octascale_feature = self.downblock3(quarterscale_feature)#1/8 28 | 29 | 30 | octascale_NLout = self.attentionblock1(octascale_feature) 31 | octascale_NLout = torch.add(octascale_NLout, octascale_feature) 32 | octascale_upsampled = self.upblock1(octascale_NLout) 33 | 34 | quarterscale_NLout = self.attentionblock2(octascale_upsampled) 35 | quarterscale_NLout = torch.add(quarterscale_NLout, quarterscale_feature) 36 | quarterscale_upsampled = self.upblock2(quarterscale_NLout) 37 | 38 | halfscale_NLout = self.attentionblock3(quarterscale_upsampled) 39 | halfscale_NLout = torch.add(halfscale_NLout,halfscale_feature) 40 | halfscale_upsampled = self.upblock3(halfscale_NLout) 41 | 42 | out = self.channelscaling_block(halfscale_upsampled) 43 | 44 | return out 45 | 46 | 47 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | # Here is the options especially for training 4 | 5 | class TrainOptions(BaseOptions): 6 | def initialize(self, parser): 7 | parser = BaseOptions.initialize(self, parser) 8 | parser.add_argument('--log_dir', type=str, default='./logs', help='the path to record log') 9 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 10 | parser.add_argument('--print_freq', type=int, default=400, help='frequency of showing training results on 
console')
11 |         parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
12 |         parser.add_argument('--save_latest_freq', type=int, default=100, help='frequency of saving the latest results')
13 |         parser.add_argument('--save_epoch_freq', type=int, default=2, help='frequency of saving checkpoints at the end of epochs')
14 |         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
15 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
16 |         parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
17 |         parser.add_argument('--which_epoch', type=str, default='70', help='which epoch to load? set to latest to use latest cached model')
18 |         parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate')
19 |         parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
20 |         parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
21 |         parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
22 |         parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
23 |         parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
24 |         self.isTrain = True
25 |         return parser
26 | 
--------------------------------------------------------------------------------
/util/Selfpatch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class Selfpatch(object):
6 |     def buildAutoencoder(self, target_img, target_img_2, target_img_3, patch_size=1, stride=1):
7 |         nDim = 3
8 |         assert target_img.dim() == nDim, 'target image must be of dimension 3.'
9 |         C = target_img.size(0)
10 | 
11 |         self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
12 | 
13 |         patches_features = self._extract_patches(target_img, patch_size, stride)
14 |         patches_features_f = self._extract_patches(target_img_3, patch_size, stride)
15 | 
16 |         patches_on = self._extract_patches(target_img_2, 1, stride)
17 | 
18 |         return patches_features_f, patches_features, patches_on
19 | 
20 |     def build(self, target_img, patch_size=5, stride=1):
21 |         nDim = 3
22 |         assert target_img.dim() == nDim, 'target image must be of dimension 3.'
23 |         C = target_img.size(0)
24 | 
25 |         self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
26 | 
27 |         patches_features = self._extract_patches(target_img, patch_size, stride)
28 | 
29 |         return patches_features
30 | 
31 |     def _build(self, patch_size, stride, C, target_patches, npatches, normalize, interpolate, type):
32 |         # for each patch, divide by its L2 norm.
33 |         if type == 1:
34 |             enc_patches = target_patches.clone()
35 |             for i in range(npatches):
36 |                 enc_patches[i] = enc_patches[i]*(1/(enc_patches[i].norm(2)+1e-8))
37 | 
38 |             conv_enc = nn.Conv2d(npatches, npatches, kernel_size=1, stride=stride, bias=False, groups=npatches)
39 |             conv_enc.weight.data = enc_patches
40 |             return conv_enc
41 | 
42 |         # normalize is not needed, it doesn't change the result!
43 |         if normalize:
44 |             raise NotImplementedError
45 | 
46 |         if interpolate:
47 |             raise NotImplementedError
48 |         else:
49 | 
50 |             conv_dec = nn.ConvTranspose2d(npatches, C, kernel_size=patch_size, stride=stride, bias=False)
51 |             conv_dec.weight.data = target_patches
52 |             return conv_dec
53 | 
54 |     def _extract_patches(self, img, patch_size, stride):
55 |         n_dim = 3
56 |         assert img.dim() == n_dim, 'image must be of dimension 3.'
57 |         kH, kW = patch_size, patch_size
58 |         dH, dW = stride, stride
59 |         input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)
60 |         i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(2), input_windows.size(3), input_windows.size(4)
61 |         input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2*i_3, i_1, i_4, i_5)
62 |         patches_all = input_windows
63 |         return patches_all
64 | 
65 | 
66 | 
--------------------------------------------------------------------------------
/models/Decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from models import model
5 | 
6 | class UnetSkipConnectionDBlock(nn.Module):
7 |     def __init__(self, inner_nc, outer_nc, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
8 |                  use_dropout=False):
9 |         super(UnetSkipConnectionDBlock, self).__init__()
10 |         uprelu = nn.ReLU(True)
11 |         upnorm = norm_layer(outer_nc, affine=True)
12 |         upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
13 |                                     kernel_size=4, stride=2,
14 |                                     padding=1)
15 |         up = [uprelu, upconv, upnorm]
16 | 
17 |         if outermost:
18 |             up = [uprelu, upconv, nn.Tanh()]
19 |             model = up
20 |         elif innermost:
21 |             up = [uprelu, upconv, upnorm]
22 |             model = up
23 |         else:
24 |             up = [uprelu, upconv, upnorm]
25 |             model = up
26 | 
27 |         self.model = nn.Sequential(*model)
28 | 
29 |     def forward(self, x):
30 |         x = x.clone()
31 |         x = self.model(x)
32 |         return x
33 | 
34 | 
35 | 
36 | class Decoder(nn.Module):
37 |     def __init__(self, input_nc, output_nc, ngf=64,
38 |                  norm_layer=nn.BatchNorm2d, use_dropout=False):
39 |         super(Decoder, self).__init__()
40 | 
41 |         # construct unet structure
42 |         Decoder_1 = UnetSkipConnectionDBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout,
43 |                                              innermost=True)
44 |         Decoder_2 = UnetSkipConnectionDBlock(ngf * 16, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout)
45 |         Decoder_3 = UnetSkipConnectionDBlock(ngf * 16, ngf * 4, norm_layer=norm_layer, use_dropout=use_dropout)
46 |         Decoder_4 = UnetSkipConnectionDBlock(ngf * 8, ngf * 2, norm_layer=norm_layer, use_dropout=use_dropout)
47 |         Decoder_5 = UnetSkipConnectionDBlock(ngf * 4, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
48 |         Decoder_6 = UnetSkipConnectionDBlock(ngf * 2, output_nc, norm_layer=norm_layer, use_dropout=use_dropout, outermost=True)
49 | 
50 |         self.Decoder_1 = Decoder_1
51 |         self.Decoder_2 = Decoder_2
52 |         self.Decoder_3 = Decoder_3
53 |         self.Decoder_4 = Decoder_4
54 |         self.Decoder_5 = Decoder_5
55 |         self.Decoder_6 = Decoder_6
56 | 
57 |     def forward(self, input_1, input_2, input_3, input_4, input_5, input_6):
58 |         y_1 = self.Decoder_1(input_6)
59 |         # input_6 is the innermost (smallest) feature map
60 |         y_2 = self.Decoder_2(torch.cat([y_1, input_5], 1))
61 |         y_3 = self.Decoder_3(torch.cat([y_2, input_4], 1))
62 |         y_4 = self.Decoder_4(torch.cat([y_3, input_3], 1))
63 |         y_5 = self.Decoder_5(torch.cat([y_4, input_2], 1))
64 |         y_6 = self.Decoder_6(torch.cat([y_5, input_1], 1))
65 |         out = y_6
66 | 
67 |         return out
68 | 
69 | # torch.cat concatenates the two tensors along dimension 1 (the channel axis)
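
# --- editor's sketch (not part of the original file) -------------------------
# A minimal shape check for the decoder above, assuming the six skip features
# produced by the 6-level Encoder in models/Encoder.py for a 256x256 input
# with ngf=64:
if __name__ == "__main__":
    decoder = Decoder(input_nc=6, output_nc=3, ngf=64)
    # skip features from shallow (64 ch, 128x128) to deep (512 ch, 4x4)
    feats = [torch.randn(1, c, s, s)
             for c, s in [(64, 128), (128, 64), (256, 32),
                          (512, 16), (512, 8), (512, 4)]]
    out = decoder(*feats)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])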
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data.dataprocess import DataProcess 4 | from models.model import create_model 5 | import torchvision 6 | from torch.utils import data 7 | from torch.utils.tensorboard import SummaryWriter 8 | import os 9 | import torch 10 | if __name__ == "__main__": 11 | 12 | opt = TrainOptions().parse() 13 | # define the dataset 14 | dataset = DataProcess(opt.de_root, opt.st_root, opt.input_mask_root, opt.ref_root, opt, opt.isTrain) 15 | #dataset = DataProcess(opt.de_root, opt.st_root, opt.input_mask_root, opt.ref_root, opt, opt.isTrain) 16 | iterator_train = (data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=False, num_workers=opt.num_workers, drop_last=False, pin_memory=True)) 17 | # Create model 18 | model = create_model(opt) 19 | total_steps=0 20 | # Create the logs 21 | dir = os.path.join(opt.log_dir, opt.name).replace('\\', '/') 22 | if not os.path.exists(dir): 23 | os.mkdir(dir) 24 | writer = SummaryWriter(log_dir=dir, comment=opt.name) 25 | # Start Training 26 | for epoch in range (opt.epoch_count, opt.niter + opt.niter_decay + 1): 27 | epoch_start_time = time.time() 28 | epoch_iter = 0 29 | for detail, structure,mask,reference in iterator_train: 30 | iter_start_time = time.time() 31 | total_steps += opt.batchSize 32 | epoch_iter += opt.batchSize 33 | model.set_input(detail,structure,mask,reference) 34 | model.optimize_parameters() 35 | # display the training processing 36 | if total_steps % opt.display_freq == 0: 37 | input,reference,output, GT = model.get_current_visuals() 38 | image_out = torch.cat([reference,input,output,GT], 0) 39 | grid = torchvision.utils.make_grid(image_out) 40 | writer.add_image('Epoch_(%d)_(%d)' % (epoch, total_steps + 1), grid, total_steps + 1) 41 | # display the training loss 42 | if total_steps % opt.print_freq == 0: 43 | errors = model.get_current_errors() 44 | t = (time.time() - iter_start_time) / opt.batchSize 45 | writer.add_scalar('G_GAN', errors['G_GAN'], total_steps + 1) 46 | writer.add_scalar('G_L1', errors['G_L1'], total_steps + 1) 47 | writer.add_scalar('G_stde', errors['G_stde'], total_steps + 1) 48 | writer.add_scalar('D_loss', errors['D'], total_steps + 1) 49 | writer.add_scalar('F_loss', errors['F'], total_steps + 1) 50 | print('iteration time: %d' % t) 51 | if epoch % opt.save_epoch_freq == 0: 52 | print('saving the model at the end of epoch %d, iters %d' % 53 | (epoch, total_steps)) 54 | model.save_networks(epoch) 55 | print('End of epoch %d / %d \t Time Taken: %d sec' % 56 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 57 | model.update_learning_rate() 58 | writer.close() 59 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | absl-py=1.2.0=pypi_0 7 | addict=2.4.0=pypi_0 8 | blas=1.0=mkl 9 | ca-certificates=2022.07.19=h06a4308_0 10 | cachetools=5.2.0=pypi_0 11 | certifi=2022.9.14=py38h06a4308_0 12 | charset-normalizer=2.1.1=pypi_0 13 | cycler=0.11.0=pypi_0 14 | einops=0.4.1=pypi_0 15 | fftw=3.3.9=h27cfd23_1 16 | flopth=0.1.2=pypi_0 17 | fonttools=4.37.1=pypi_0 18 | 
fvcore=0.1.5.post20220512=pypi_0 19 | google-auth=2.11.0=pypi_0 20 | google-auth-oauthlib=0.4.6=pypi_0 21 | grpcio=1.48.1=pypi_0 22 | idna=3.3=pypi_0 23 | imageio=2.21.2=pypi_0 24 | importlib-metadata=4.12.0=pypi_0 25 | intel-openmp=2021.4.0=h06a4308_3561 26 | iopath=0.1.10=pypi_0 27 | joblib=1.1.0=pypi_0 28 | kiwisolver=1.4.4=pypi_0 29 | ld_impl_linux-64=2.38=h1181459_1 30 | libffi=3.3=he6710b0_2 31 | libgcc-ng=11.2.0=h1234567_1 32 | libgfortran-ng=11.2.0=h00389a5_1 33 | libgfortran5=11.2.0=h1234567_1 34 | libgomp=11.2.0=h1234567_1 35 | libstdcxx-ng=11.2.0=h1234567_1 36 | lpips=0.1.4=pypi_0 37 | markdown=3.4.1=pypi_0 38 | markupsafe=2.1.1=pypi_0 39 | matplotlib=3.5.3=pypi_0 40 | mkl=2021.4.0=h06a4308_640 41 | mkl-service=2.4.0=py38h7f8727e_0 42 | mkl_fft=1.3.1=py38hd3c417c_0 43 | mkl_random=1.2.2=py38h51133e4_0 44 | mmcv=1.6.1=pypi_0 45 | mmcv-full=1.6.2=pypi_0 46 | monai=0.9.1=pypi_0 47 | munch=2.5.0=pyhd3eb1b0_0 48 | ncurses=6.3=h5eee18b_3 49 | networkx=2.8.6=pypi_0 50 | nibabel=4.0.2=pypi_0 51 | numpy=1.23.2=pypi_0 52 | oauthlib=3.2.0=pypi_0 53 | opencv-python=4.6.0.66=pypi_0 54 | openssl=1.1.1q=h7f8727e_0 55 | packaging=21.3=pypi_0 56 | pandas=1.4.4=pypi_0 57 | pillow=9.2.0=pypi_0 58 | pip=22.1.2=py38h06a4308_0 59 | portalocker=2.5.1=pypi_0 60 | protobuf=3.19.4=pypi_0 61 | pyasn1=0.4.8=pypi_0 62 | pyasn1-modules=0.2.8=pypi_0 63 | pyparsing=3.0.9=pypi_0 64 | python=3.8.13=h12debd9_0 65 | python-dateutil=2.8.2=pypi_0 66 | pytz=2022.2.1=pypi_0 67 | pywavelets=1.3.0=pypi_0 68 | pyyaml=6.0=pypi_0 69 | readline=8.1.2=h7f8727e_1 70 | requests=2.28.1=pypi_0 71 | requests-oauthlib=1.3.1=pypi_0 72 | rsa=4.9=pypi_0 73 | scikit-image=0.19.3=pypi_0 74 | scikit-learn=1.1.2=pypi_0 75 | scipy=1.4.1=pypi_0 76 | setuptools=63.4.1=py38h06a4308_0 77 | simpleitk=2.2.0=pypi_0 78 | six=1.16.0=pyhd3eb1b0_1 79 | sqlite=3.39.2=h5082296_0 80 | tabulate=0.8.10=pypi_0 81 | tensorboard=2.10.0=pypi_0 82 | tensorboard-data-server=0.6.1=pypi_0 83 | tensorboard-plugin-wit=1.8.1=pypi_0 84 | termcolor=1.1.0=pypi_0 85 | thop=0.1.1-2209072238=pypi_0 86 | threadpoolctl=3.1.0=pypi_0 87 | tifffile=2022.8.12=pypi_0 88 | timm=0.6.7=pypi_0 89 | tk=8.6.12=h1ccaba5_0 90 | torch=1.12.1=pypi_0 91 | torchvision=0.13.1=pypi_0 92 | tqdm=4.64.1=pypi_0 93 | typing-extensions=4.3.0=pypi_0 94 | urllib3=1.26.12=pypi_0 95 | werkzeug=2.2.2=pypi_0 96 | wheel=0.37.1=pyhd3eb1b0_0 97 | xz=5.2.5=h7f8727e_1 98 | yacs=0.1.8=pypi_0 99 | yapf=0.32.0=pypi_0 100 | zipp=3.8.1=pypi_0 101 | zlib=1.2.12=h5eee18b_3 102 | -------------------------------------------------------------------------------- /models/FAM/Model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class L1_Charbonnier_loss(nn.Module): 6 | """L1 Charbonnierloss.""" 7 | def __init__(self): 8 | super(L1_Charbonnier_loss, self).__init__() 9 | self.eps = 1e-6 10 | 11 | def forward(self, X, Y): 12 | diff = torch.add(X, -Y) 13 | error = torch.sqrt( diff * diff + self.eps) 14 | loss = torch.sum(error) 15 | return loss 16 | 17 | class residual_block(nn.Module): 18 | def __init__(self, input_channel = 256, output_channel = 256, bias = False): 19 | super(residual_block, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(in_channels=input_channel,out_channels=input_channel, kernel_size=3, padding=1, bias=bias) 22 | self.conv2 = nn.Conv2d(in_channels=input_channel,out_channels=output_channel, kernel_size=3, padding=1, bias = bias) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | def forward(self,x): 26 | 
out = self.relu(self.conv1(x)) 27 | out = self.conv2(out) 28 | # out *= 0.1 for bigmodel 29 | out = torch.add(out,x) 30 | 31 | return out 32 | 33 | def make_residual_block(blocknum=32, input_channel = 64, output_channel = 64, bias = False): 34 | residual_layers = [] 35 | #residual_layers.append(residual_block(input_channel=input_channel, output_channel = output_channel,bias=bias)) 36 | for i in range(blocknum): 37 | residual_layers.append(residual_block(input_channel=output_channel, output_channel = output_channel, bias = bias)) 38 | blockpart_model = nn.Sequential(*residual_layers) 39 | return blockpart_model 40 | 41 | def make_downsampling_network(layernum = 2, in_channels = 3, out_channels = 64): 42 | layers = [] 43 | layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, bias=False, padding=1)) 44 | for _ in range(layernum-1): 45 | layers.append(nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=2, bias=False,padding=1)) 46 | print(layers) 47 | model = nn.Sequential(*layers) 48 | return model 49 | 50 | def DOE_downsample_block(input_channelsize): 51 | layers = [] 52 | layers.append( 53 | nn.Conv2d(in_channels=input_channelsize, out_channels=64, kernel_size=3, stride=2, padding=1, bias=True)) 54 | layers.append(nn.LeakyReLU(inplace=True)) 55 | 56 | pre_model = nn.Sequential(*layers) 57 | return pre_model 58 | 59 | def DOE_upsample_block(in_odd = True, in_channels = 64, out_channels = 64): 60 | layers = [] 61 | 62 | if in_odd: 63 | layers.append( 64 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, 65 | output_padding=1, bias=True)) 66 | else: 67 | layers.append( 68 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=0, 69 | bias=True)) 70 | layers.append(nn.LeakyReLU(inplace=True)) 71 | 72 | post_model = nn.Sequential(*layers) 73 | return post_model -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class BaseModel(): 6 | def __init__(self, opt): 7 | self.opt = opt 8 | self.gpu_ids = opt.gpu_ids 9 | self.isTrain = opt.isTrain 10 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 11 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 12 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 13 | self.modelname = opt.model 14 | #self.num_channel = 64 #256 15 | #self.mode = "add" 16 | 17 | def name(self): 18 | return 'BaseModel' 19 | 20 | def set_input(self, input): 21 | self.input = input 22 | 23 | def forward(self): 24 | pass 25 | 26 | def test(self): 27 | pass 28 | 29 | def get_image_paths(self): 30 | pass 31 | 32 | def optimize_parameters(self): 33 | pass 34 | 35 | def get_current_visuals(self): 36 | return self.input 37 | 38 | def get_current_errors(self): 39 | return {} 40 | 41 | def save(self, label): 42 | pass 43 | 44 | # helper saving function that can be used by subclasses 45 | 46 | 47 | def save_networks(self, which_epoch): 48 | for name in self.model_names: 49 | if isinstance(name, str): 50 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 51 | save_path = os.path.join(self.save_dir, save_filename).replace('\\', '/') 52 | net = getattr(self, 'net' + name) 53 | optimize = getattr(self, 'optimizer_' + name) 54 | 55 | if 
len(self.gpu_ids) > 0 and torch.cuda.is_available():
56 |                     torch.save({'net': net.module.cpu().state_dict(), 'optimize': optimize.state_dict()}, save_path)
57 |                     net.cuda(self.gpu_ids[0])
58 |                 else:
59 |                     torch.save({'net': net.cpu().state_dict(), 'optimize': optimize.state_dict()}, save_path)  # match the GPU-branch format so load_networks can read it
60 | 
61 |     # helper loading function that can be used by subclasses
62 |     def load_networks(self, which_epoch):
63 |         for name in self.model_names:
64 |             if isinstance(name, str):
65 |                 load_filename = '%s_net_%s.pth' % (which_epoch, name)
66 |                 load_path = os.path.join(self.save_dir, load_filename)
67 | 
68 |                 net = getattr(self, 'net' + name)
69 |                 optimize = getattr(self, 'optimizer_' + name)
70 |                 if isinstance(net, torch.nn.DataParallel):
71 |                     net = net.module
72 |                 # if you are using PyTorch newer than 0.4 (e.g., built from
73 |                 # GitHub source), you can remove str() on self.device
74 |                 state_dict = torch.load(load_path.replace('\\', '/'), map_location=str(self.device))
75 |                 optimize.load_state_dict(state_dict['optimize'])
76 |                 net.load_state_dict(state_dict['net'])
77 | 
78 |     # update learning rate (called once every epoch)
79 |     def update_learning_rate(self):
80 |         for scheduler in self.schedulers:
81 |             scheduler.step()
82 |         lr = self.optimizers[0].param_groups[0]['lr']
83 |         print('learning rate = %.7f' % lr)
84 |     def set_requires_grad(self, nets, requires_grad=False):
85 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
86 |         Parameters:
87 |             nets (network list)   -- a list of networks
88 |             requires_grad (bool)  -- whether the networks require gradients or not
89 |         """
90 |         if not isinstance(nets, list):
91 |             nets = [nets]
92 |         for net in nets:
93 |             if net is not None:
94 |                 for param in net.parameters():
95 |                     param.requires_grad = requires_grad
96 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pdb
3 | from options.test_options import TestOptions
4 | from data.dataprocess import DataProcess
5 | from models.model import create_model
6 | import torchvision
7 | from torch.utils import data
8 | from torch.utils.tensorboard import SummaryWriter
9 | import os
10 | import torch
11 | from PIL import Image
12 | import numpy as np
13 | from glob import glob
14 | from tqdm import tqdm
15 | import torchvision.transforms as transforms
16 | if __name__ == "__main__":
17 | 
18 |     img_transform = transforms.Compose([
19 |         transforms.Resize((256, 256)),
20 |         transforms.ToTensor(),
21 |         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
22 |     ])
23 |     mask_transform = transforms.Compose([
24 |         transforms.Resize((256, 256)),
25 |         transforms.ToTensor()
26 |     ])
27 | 
28 |     results_dir = r'./result/'
29 |     if not os.path.exists(results_dir):
30 |         os.mkdir(results_dir)
31 | 
32 | 
33 |     opt = TestOptions().parse()
34 |     writer = SummaryWriter(log_dir=results_dir, comment=opt.name)  # was log_dir=dir, i.e. the Python builtin; log next to the results instead
35 |     model = create_model(opt)
36 | 
37 |     net_EN = torch.load("./checkpoints/RGTSI/net_EN.pth")
38 |     net_RefEN = torch.load("./checkpoints/RGTSI/net_RefEN.pth")
39 |     net_DE = torch.load("./checkpoints/RGTSI/net_DE.pth")
40 |     net_RGTSI = torch.load("./checkpoints/RGTSI/net_RGTSI.pth")
41 | 
42 |     model.netEN.module.load_state_dict(net_EN['net'])
43 |     model.netRefEN.module.load_state_dict(net_RefEN['net'])
44 |     model.netDE.module.load_state_dict(net_DE['net'])
45 |     model.netRGTSI.module.load_state_dict(net_RGTSI['net'])
46 | 
47 |     input_mask_paths = glob('{:s}/*'.format("/project/liutaorong/RGTSI/data/DPED10K/test/input_mask/"))
48 | 
input_mask_paths.sort() 49 | de_paths = glob('{:s}/*'.format("/project/liutaorong/RGTSI/data/DPED10K/test/images/")) 50 | de_paths.sort() 51 | st_path = glob('{:s}/*'.format("/project/liutaorong/RGTSI/data/DPED10K/test/structure/")) 52 | st_path.sort() 53 | ref_paths = glob('{:s}/*'.format("/project/liutaorong/RGTSI/data/DPED10K/test/reference/")) 54 | ref_paths.sort() 55 | 56 | image_len = len(de_paths) 57 | 58 | for i in tqdm(range(image_len)): 59 | 60 | path_im = input_mask_paths[i] 61 | path_de = de_paths[i] 62 | (filepath,tempfilename) = os.path.split(path_de) 63 | (filename,extension) = os.path.splitext(tempfilename) 64 | path_st = st_path[i] 65 | path_rf = ref_paths[i] 66 | 67 | input_mask = Image.open(path_im).convert("RGB") 68 | detail = Image.open(path_de).convert("RGB") 69 | structure = Image.open(path_st).convert("RGB") 70 | reference = Image.open(path_rf).convert("RGB") 71 | 72 | input_mask = mask_transform(input_mask) 73 | detail = img_transform(detail) 74 | structure = img_transform(structure) 75 | reference = img_transform(reference) 76 | 77 | input_mask = torch.unsqueeze(input_mask, 0) 78 | detail = torch.unsqueeze(detail, 0) 79 | structure = torch.unsqueeze(structure,0) 80 | reference = torch.unsqueeze(reference,0) 81 | 82 | with torch.no_grad(): 83 | model.set_input(detail,structure,input_mask,reference) 84 | model.forward() 85 | fake_out = model.fake_out 86 | fake_out = fake_out.detach().cpu() * input_mask + detail*(1-input_mask) 87 | fake_image = (fake_out+1)/2.0 88 | output = fake_image.detach().numpy()[0].transpose((1, 2, 0))*255 89 | output = Image.fromarray(output.astype(np.uint8)) 90 | output.save(results_dir+filename+".jpg") 91 | 92 | input, reference, output, GT = model.get_current_visuals() 93 | image_out = torch.cat([input,reference,output,GT], 0) 94 | grid = torchvision.utils.make_grid(image_out) 95 | writer.add_image('picture(%d)' % i,grid,i) 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICIP 2022] Reference-Guided Texture and Structure Inference for Image Inpainting 2 | [![paper](https://img.shields.io/badge/IEEE-Paper-red)](https://ieeexplore.ieee.org/abstract/document/9897592) 3 | 4 | 5 | This is the repository of the paper **Reference-Guided Texture and Structure Inference for Image Inpainting**, accepted by [ICIP 2022](https://2022.ieeeicip.org/). 6 | 7 | > **Abstract:** *Existing learning-based image inpainting methods are still in challenge when facing complex semantic environments and diverse hole patterns. The prior information learned from the large scale training data is still insufficient for these situations. Reference images captured covering the same scenes share similar texture and structure priors with the corrupted images, which offers new prospects for the image inpainting tasks. Inspired by this, we first build a benchmark dataset containing 10K pairs of input and reference images for reference-guided inpainting. Then we adopt an encoder-decoder structure to separately infer the texture and structure features of the input image considering their pattern discrepancy of texture and structure during inpainting. A feature alignment module is further designed to refine these features of the input image with the guidance of a reference image. 
Both quantitative and qualitative evaluations demonstrate the superiority of our method over the state-of-the-art methods in terms of completing complex holes.*
8 | 
9 | ![](./imgs/pipeline.png)
10 | 
11 | 
12 | ## Usage Instructions
13 | 
14 | ### Environment
15 | Please install Anaconda and PyTorch. For the other libraries, please refer to the file requirements.txt.
16 | ```
17 | git clone https://github.com/Cameltr/RGTSI.git
18 | conda create -n RGTSI python=3.8
19 | conda activate RGTSI
20 | pip install -r requirements.txt
21 | ```
22 | 
23 | ### Dataset Preparation
24 | 
25 | Please download the DPED10K dataset from [Google Drive](https://drive.google.com/drive/folders/1CdtWeEqQaZM8RWcPX3m1PyC1BGDcmq-N?usp=share_link) or [Baidu Netdisk](https://pan.baidu.com/s/18mwRhUdKsKaL6J-08mdlLQ) (Password: roqs). Create a folder, unzip the dataset into it, and then edit the paths to the folder in `options/base_options.py`.
26 | 
27 | Our model is trained on the irregular mask dataset provided by [Liu et al](https://arxiv.org/abs/1804.07723). You can download the publicly available Irregular Mask Dataset from their [website](http://masc.cs.gmu.edu/wiki/partialconv).
28 | 
29 | For the structure images of the dataset, we follow [StructureFlow](https://github.com/RenYurui/StructureFlow) and utilize the [RTV smoothing method](http://www.cse.cuhk.edu.hk/~leojia/projects/texturesep/). Run the generation function [data/Matlab/generate_structure_images.m](./data/Matlab/generate_structure_images.m) in MATLAB. For example, if you want to generate smooth images, you can run the following code:
30 | 
31 | ```
32 | generate_structure_images("path to dataset root", "path to output folder");
33 | ```
34 | 
35 | ### Training and Testing
36 | ```bash
37 | # To train on your dataset, for example:
38 | python train.py --st_root=[the path of structure images] --de_root=[the path of ground truth images] --input_mask_root=[the path of mask images] --ref_root=[the path of reference images]
39 | ```
40 | There are many options you can specify. Please use `python train.py --help` or see the options in `options/`.
41 | 
42 | For the current version, the batch size needs to be set to 1.
43 | 
44 | To log training, use `--log_dir ./logs` for TensorBoard. The logs are stored at `logs/[name]`.
45 | 
46 | ```bash
47 | # To test on your dataset, for example:
48 | python test.py
49 | ```
50 | Please edit the paths of the test images in `test.py` when testing on your own dataset (the expected directory layout is sketched at the end of this README).
51 | 
52 | ### Pre-trained weights and test model
53 | Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1nBFG6EAQTW-G55Nh4YwFaiU_gnIo8Qks?usp=share_link) or [Baidu Netdisk](https://pan.baidu.com/s/1Oh4cqFNgJorOjdxDAugkng) (Password: bb0j).
54 | 
55 | 
56 | ## Citation
57 | If you find our code or datasets helpful for your research, please cite our paper.
58 | ```
59 | @inproceedings{liu2022reference,
60 |   title={Reference-guided texture and structure inference for image inpainting},
61 |   author={Liu, Taorong and Liao, Liang and Wang, Zheng and Satoh, Shin’Ichi},
62 |   booktitle={2022 IEEE International Conference on Image Processing (ICIP)},
63 |   pages={1996--2000},
64 |   year={2022},
65 |   organization={IEEE}
66 | }
67 | ```
68 | 
69 | ## Acknowledgments
70 | RGTSI is built upon [MEDFE](https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE) and inspired by [SSEN](https://github.com/Slime0519/CVPR_2020_SSEN). We appreciate the authors' excellent work!
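
## Test Data Layout

For reference, the default paths hard-coded in `test.py` expect the test split to be organized as follows (a sketch inferred from the `glob` patterns in `test.py`; replace the root with your own):

```
data/DPED10K/test/
├── images/       # ground-truth (detail) images
├── structure/    # RTV structure maps generated by the MATLAB script
├── reference/    # reference images
└── input_mask/   # irregular input masks
```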
71 | 
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | 
6 | class BaseOptions():
7 |     def __init__(self):
8 |         self.initialized = False
9 | 
10 |     def initialize(self, parser):
11 | 
12 |         parser.add_argument('--st_root', type=str, default=r'./data/datasets/structure', help='path to structure images')
13 |         parser.add_argument('--de_root', type=str, default=r'./data/datasets/images', help='path to detail images (which are the ground truth)')
14 |         parser.add_argument('--input_mask_root', type=str, default=r'./data/datasets/input_mask', help='path to the input masks; we use the mask dataset of partial conv here')
15 |         parser.add_argument('--ref_root', type=str, default=r'./data/datasets/reference', help='path to the reference images')
16 |         parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
17 |         parser.add_argument('--num_workers', type=int, default=8, help='number of CPU workers for data loading')
18 |         parser.add_argument('--name', type=str, default='RBED',
19 |                             help='name of the experiment. It decides where to store samples and models')
20 |         parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
21 |         parser.add_argument('--input_nc', type=int, default=6, help='# of input image channels')
22 |         parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
23 |         parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
24 |         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
25 |         parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
26 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2')
27 |         parser.add_argument('--model', type=str, default='training1', help='set the names of current training process')
28 |         parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
29 |         parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
30 |         parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
31 |         parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
32 |         parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
33 | 
34 |         parser.add_argument('--lambda_L1', type=float, default=1, help='weight on L1 term in objective')
35 |         parser.add_argument('--lambda_S', type=float, default=250, help='weight on Style loss in objective')
36 |         parser.add_argument('--lambda_P', type=float, default=0.2, help='weight on Perceptual loss in objective')
37 |         parser.add_argument('--lambda_Gan', type=float, default=0.2, help='weight on GAN term in objective')
38 |         parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
39 |         self.initialized = True
40 |         return parser
41 | 
42 |     def gather_options(self):
43 |         # initialize parser with basic options
44 |         if not self.initialized:
45 |             parser = argparse.ArgumentParser(
46 |                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
47 |             parser = self.initialize(parser)
48 | 
49 | 
50 |         self.parser = parser
51 |         return parser.parse_args()
52 | 
53 |     def print_options(self, opt):
54 |         message = ''
55 |         message += '----------------- Options ---------------\n'
56 |         for k, v in sorted(vars(opt).items()):
57 |             comment = ''
58 |             default = self.parser.get_default(k)
59 |             if v != default:
60 |                 comment = '\t[default: %s]' % str(default)
61 |             message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
62 |         message += '----------------- End -------------------'
63 |         print(message)
64 | 
65 |         # save to the disk
66 |         expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
67 |         util.mkdirs(expr_dir)
68 |         file_name = os.path.join(expr_dir, 'opt.txt')
69 |         with open(file_name, 'wt') as opt_file:
70 |             opt_file.write(message)
71 |             opt_file.write('\n')
72 | 
73 |     def parse(self):
74 | 
75 |         opt = self.gather_options()
76 |         opt.isTrain = self.isTrain   # train or test
77 | 
78 |         # process opt.suffix
79 | 
80 |         self.print_options(opt)
81 | 
82 |         # set gpu ids
83 |         str_ids = opt.gpu_ids.split(',')
84 |         opt.gpu_ids = []
85 |         for str_id in str_ids:
86 |             id = int(str_id)
87 |             if id >= 0:
88 |                 opt.gpu_ids.append(id)
89 |         if len(opt.gpu_ids) > 0:
90 |             torch.cuda.set_device(opt.gpu_ids[0])
91 | 
92 |         self.opt = opt
93 |         return self.opt
94 | 
95 | 
--------------------------------------------------------------------------------
/data/Matlab/tsmooth.m:
--------------------------------------------------------------------------------
1 | function S = tsmooth(I,lambda,sigma,sharpness,maxIter)
2 | %tsmooth - Structure Extraction from Texture via Relative Total Variation
3 | %   S = tsmooth(I, lambda, sigma, maxIter) extracts structure S from
4 | %   structure+texture input I, with smoothness weight lambda, scale
5 | %   parameter sigma and iteration number maxIter.
6 | %
7 | %   Paras:
8 | %   @I         : Input UINT8 image, both grayscale and color images are acceptable.
9 | %   @lambda    : Parameter controlling the degree of smooth.
10 | %                Range (0, 0.05], 0.01 by default.
11 | %   @sigma     : Parameter specifying the maximum size of texture elements.
12 | %                Range (0, 6], 3 by default.
13 | %   @sharpness : Parameter controlling the sharpness of the final results,
14 | %                which corresponds to \epsilon_s in the paper [1]. The smaller the value, the sharper the result.
15 | %                Range (1e-3, 0.03], 0.02 by default.
16 | %   @maxIter   : Number of iterations, 4 by default.
17 | %
18 | %   Example
19 | %   ==========
20 | %   I = imread('Bishapur_zan.jpg');
21 | %   S = tsmooth(I); % Default Parameters (lambda = 0.01, sigma = 3, sharpness = 0.02, maxIter = 4)
22 | %   figure, imshow(I), figure, imshow(S);
23 | %
24 | % ==========
25 | % The Code is created based on the method described in the following paper
26 | % [1] "Structure Extraction from Texture via Relative Total Variation", Li Xu, Qiong Yan, Yang Xia, Jiaya Jia, ACM Transactions on Graphics,
27 | % (SIGGRAPH Asia 2012), 2012.
28 | % The code and the algorithm are for non-commercial use only.
29 | %
30 | % Author: Li Xu (xuli@cse.cuhk.edu.hk)
31 | % Date  : 08/25/2012
32 | % Version : 1.0
33 | % Copyright 2012, The Chinese University of Hong Kong.
34 | %
35 | 
36 | if (~exist('lambda','var'))
37 |     lambda=0.01;
38 | end
39 | if (~exist('sigma','var'))
40 |     sigma=3.0;
41 | end
42 | if (~exist('sharpness','var'))
43 |     sharpness = 0.02;
44 | end
45 | if (~exist('maxIter','var'))
46 |     maxIter=4;
47 | end
48 | I = im2double(I);
49 | x = I;
50 | sigma_iter = sigma;
51 | lambda = lambda/2.0;
52 | dec=2.0;
53 | for iter = 1:maxIter
54 |     [wx, wy] = computeTextureWeights(x, sigma_iter, sharpness);
55 |     x = solveLinearEquation(I, wx, wy, lambda);
56 |     sigma_iter = sigma_iter/dec;
57 |     if sigma_iter < 0.5
58 |         sigma_iter = 0.5;
59 |     end
60 | end
61 | S = x;
62 | end
63 | 
64 | function [retx, rety] = computeTextureWeights(fin, sigma,sharpness)
65 | 
66 | fx = diff(fin,1,2);
67 | fx = padarray(fx, [0 1 0], 'post');
68 | fy = diff(fin,1,1);
69 | fy = padarray(fy, [1 0 0], 'post');
70 | 
71 | vareps_s = sharpness;
72 | vareps = 0.001;
73 | 
74 | wto = max(sum(sqrt(fx.^2+fy.^2),3)/size(fin,3),vareps_s).^(-1);
75 | fbin = lpfilter(fin, sigma);
76 | gfx = diff(fbin,1,2);
77 | gfx = padarray(gfx, [0 1], 'post');
78 | gfy = diff(fbin,1,1);
79 | gfy = padarray(gfy, [1 0], 'post');
80 | wtbx = max(sum(abs(gfx),3)/size(fin,3),vareps).^(-1);
81 | wtby = max(sum(abs(gfy),3)/size(fin,3),vareps).^(-1);
82 | retx = wtbx.*wto;
83 | rety = wtby.*wto;
84 | 
85 | retx(:,end) = 0;
86 | rety(end,:) = 0;
87 | 
88 | end
89 | 
90 | function ret = conv2_sep(im, sigma)
91 | ksize = bitor(round(5*sigma),1);
92 | g = fspecial('gaussian', [1,ksize], sigma);
93 | ret = conv2(im,g,'same');
94 | ret = conv2(ret,g','same');
95 | end
96 | 
97 | function FBImg = lpfilter(FImg, sigma)
98 | FBImg = FImg;
99 | for ic = 1:size(FBImg,3)
100 |     FBImg(:,:,ic) = conv2_sep(FImg(:,:,ic), sigma);
101 | end
102 | end
103 | 
104 | function OUT = solveLinearEquation(IN, wx, wy, lambda)
105 | %
106 | % The code for constructing the inhomogeneous Laplacian is adapted from
107 | % the implementation of the wlsFilter.
108 | %
109 | % For color images, we enforce wx and wy to be the same for the three channels,
110 | % and thus the pre-conditioner only needs to be computed once.
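% [Editor's sketch, not in the original code] In matrix form, the assembly
% below solves one iteration of the quadratic sub-problem
%     (Id + lambda * L_w) * s = i,
% where Id is the identity, L_w is the weighted five-point Laplacian built
% from wx and wy (the e/w/s/n off-diagonals below), s is the new structure
% estimate, and i is the input channel.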
111 | %
112 | [r,c,ch] = size(IN);
113 | k = r*c;
114 | dx = -lambda*wx(:);
115 | dy = -lambda*wy(:);
116 | B(:,1) = dx;
117 | B(:,2) = dy;
118 | d = [-r,-1];
119 | A = spdiags(B,d,k,k);
120 | e = dx;
121 | w = padarray(dx, r, 'pre'); w = w(1:end-r);
122 | s = dy;
123 | n = padarray(dy, 1, 'pre'); n = n(1:end-1);
124 | D = 1-(e+w+s+n);
125 | A = A + A' + spdiags(D, 0, k, k);
126 | if exist('ichol','builtin')
127 |     L = ichol(A,struct('michol','on'));
128 |     OUT = IN;
129 |     for ii=1:ch
130 |         tin = IN(:,:,ii);
131 |         [tout, flag] = pcg(A, tin(:),0.1,100, L, L');
132 |         OUT(:,:,ii) = reshape(tout, r, c);
133 |     end
134 | else
135 |     OUT = IN;
136 |     for ii=1:ch
137 |         tin = IN(:,:,ii);
138 |         tout = A\tin(:);
139 |         OUT(:,:,ii) = reshape(tout, r, c);
140 |     end
141 | end
142 | 
143 | end
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | # Define networks, init networks
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import functools
6 | from torch.optim import lr_scheduler
7 | from models.PCconv import PCconv
8 | from models.InnerCos import InnerCos
9 | from models.Encoder import Encoder, RefEncoder
10 | from models.Discriminator import NLayerDiscriminator
11 | from models.Decoder import Decoder
12 | 
13 | 
14 | 
15 | def get_norm_layer(norm_type='instance'):
16 |     if norm_type == 'batch':
17 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
18 |     elif norm_type == 'instance':
19 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
20 |     elif norm_type == 'none':
21 |         norm_layer = None
22 |     else:
23 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
24 |     return norm_layer
25 | 
26 | def get_scheduler(optimizer, opt):
27 |     if opt.lr_policy == 'lambda':
28 |         def lambda_rule(epoch):
29 |             lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
30 |             return lr_l
31 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
32 |     elif opt.lr_policy == 'step':
33 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
34 |     elif opt.lr_policy == 'plateau':
35 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
36 |     elif opt.lr_policy == 'cosine':
37 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
38 |     else:
39 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
40 |     return scheduler
41 | 
42 | 
43 | def init_weights(net, init_type='normal', gain=0.02):
44 |     def init_func(m):
45 |         classname = m.__class__.__name__
46 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
47 |             if init_type == 'normal':
48 |                 init.normal_(m.weight.data, 0.0, gain)
49 |             elif init_type == 'xavier':
50 |                 init.xavier_normal_(m.weight.data, gain=gain)
51 |             elif init_type == 'kaiming':
52 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
53 |             elif init_type == 'orthogonal':
54 |                 init.orthogonal_(m.weight.data, gain=gain)
55 |             else:
56 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
57 |             if hasattr(m, 'bias') and m.bias is not None:
58 |                 init.constant_(m.bias.data, 0.0)
59 |         elif classname.find('BatchNorm2d') != -1:
60 |             init.normal_(m.weight.data, 1.0, gain)
61 |             init.constant_(m.bias.data, 0.0)
62 | 
63 |     print('initialize network with %s' %
init_type) 64 | net.apply(init_func) 65 | 66 | 67 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 68 | if len(gpu_ids) > 0: 69 | assert(torch.cuda.is_available()) 70 | net.to(gpu_ids[0]) 71 | net = torch.nn.DataParallel(net, gpu_ids) 72 | init_weights(net, init_type, gain=init_gain) 73 | return net 74 | 75 | 76 | def define_G(input_nc, output_nc, ngf, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[], init_gain=0.02): 77 | 78 | norm_layer = get_norm_layer(norm_type=norm) 79 | 80 | stde_list = [] 81 | netEN = Encoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 82 | netRefEN = RefEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 83 | netDE = Decoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 84 | 85 | PCBlock = PCblock(stde_list) 86 | 87 | return init_net(netEN, init_type, init_gain, gpu_ids), init_net(netRefEN, init_type, init_gain, gpu_ids), init_net(netDE, init_type, init_gain, gpu_ids), init_net(PCBlock, init_type, init_gain, gpu_ids),stde_list 88 | 89 | 90 | def define_D(input_nc, ndf, n_layers_D=3, norm='batch', init_type='normal', gpu_ids=[], init_gain=0.02): 91 | netD = None 92 | norm_layer = get_norm_layer(norm_type=norm) 93 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_layer, use_sigmoid=False) 94 | 95 | return init_net(netD, init_type, init_gain, gpu_ids) 96 | 97 | 98 | def print_network(net): 99 | num_params = 0 100 | for param in net.parameters(): 101 | num_params += param.numel() 102 | print(net) 103 | print('Total number of parameters: %d' % num_params) 104 | 105 | 106 | class PCblock(nn.Module): 107 | def __init__(self, stde_list): 108 | super(PCblock, self).__init__() 109 | self.pc_block = PCconv() 110 | innerloss = InnerCos() 111 | stde_list.append(innerloss) 112 | loss = [innerloss] 113 | self.loss=nn.Sequential(*loss) 114 | def forward(self,input,reference,mask): 115 | out = self.pc_block(input,reference,mask) 116 | out = self.loss(out) 117 | return out 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /models/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # Define the resnet block 5 | class ResnetBlock(nn.Module): 6 | def __init__(self, dim, dilation=1): 7 | super(ResnetBlock, self).__init__() 8 | self.conv_block = nn.Sequential( 9 | nn.ReflectionPad2d(dilation), 10 | nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=False), 11 | nn.InstanceNorm2d(dim, track_running_stats=False), 12 | nn.ReLU(True), 13 | nn.ReflectionPad2d(1), 14 | nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=False), 15 | nn.InstanceNorm2d(dim, track_running_stats=False), 16 | ) 17 | 18 | def forward(self, x): 19 | out = x + self.conv_block(x) 20 | return out 21 | 22 | 23 | # define the Encoder unit 24 | class UnetSkipConnectionEBlock(nn.Module): 25 | def __init__(self, outer_nc, inner_nc, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, 26 | use_dropout=False): 27 | super(UnetSkipConnectionEBlock, self).__init__() 28 | downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, 29 | stride=2, padding=1) 30 | 31 | downrelu = nn.LeakyReLU(0.2, True) 32 | 33 | downnorm = norm_layer(inner_nc, affine=True) 34 | if outermost: 35 | down = [downconv] 36 | model = down 37 | elif 
innermost: 38 | down = [downrelu, downconv] 39 | model = down 40 | else: 41 | down = [downrelu, downconv, downnorm] 42 | if use_dropout: 43 | model = down + [nn.Dropout(0.5)] 44 | else: 45 | model = down 46 | self.model = nn.Sequential(*model) 47 | 48 | def forward(self, x): 49 | return self.model(x) 50 | 51 | 52 | class Encoder(nn.Module): 53 | def __init__(self, input_nc, output_nc, ngf=64, res_num=4, norm_layer=nn.BatchNorm2d, use_dropout=False): 54 | super(Encoder, self).__init__() 55 | 56 | # construct unet structure 57 | Encoder_1 = UnetSkipConnectionEBlock(input_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, outermost=True) 58 | Encoder_2 = UnetSkipConnectionEBlock(ngf, ngf * 2, norm_layer=norm_layer, use_dropout=use_dropout) 59 | Encoder_3 = UnetSkipConnectionEBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, use_dropout=use_dropout) 60 | Encoder_4 = UnetSkipConnectionEBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout) 61 | Encoder_5 = UnetSkipConnectionEBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout) 62 | Encoder_6 = UnetSkipConnectionEBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout, innermost=True) 63 | 64 | blocks = [] 65 | for _ in range(res_num): 66 | block = ResnetBlock(ngf * 8, 2) 67 | blocks.append(block) 68 | 69 | self.middle = nn.Sequential(*blocks) 70 | 71 | self.Encoder_1 = Encoder_1 72 | self.Encoder_2 = Encoder_2 73 | self.Encoder_3 = Encoder_3 74 | self.Encoder_4 = Encoder_4 75 | self.Encoder_5 = Encoder_5 76 | self.Encoder_6 = Encoder_6 77 | 78 | def forward(self, input): 79 | y_1 = self.Encoder_1(input) 80 | y_2 = self.Encoder_2(y_1) 81 | y_3 = self.Encoder_3(y_2) 82 | y_4 = self.Encoder_4(y_3) 83 | y_5 = self.Encoder_5(y_4) 84 | y_6 = self.Encoder_6(y_5) 85 | y_7 = self.middle(y_6) 86 | 87 | return y_1, y_2, y_3, y_4, y_5, y_7 88 | 89 | class RefEncoder(nn.Module): 90 | def __init__(self, input_nc, output_nc, ngf=64, res_num=4, norm_layer=nn.BatchNorm2d, use_dropout=False): 91 | super(RefEncoder, self).__init__() 92 | 93 | # construct unet structure 94 | Encoder_1 = UnetSkipConnectionEBlock(3, ngf, norm_layer=norm_layer, use_dropout=use_dropout, outermost=True) 95 | Encoder_2 = UnetSkipConnectionEBlock(ngf, ngf * 2, norm_layer=norm_layer, use_dropout=use_dropout) 96 | Encoder_3 = UnetSkipConnectionEBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, use_dropout=use_dropout) 97 | Encoder_4 = UnetSkipConnectionEBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout) 98 | Encoder_5 = UnetSkipConnectionEBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout) 99 | Encoder_6 = UnetSkipConnectionEBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout, innermost=True) 100 | 101 | blocks = [] 102 | for _ in range(res_num): 103 | block = ResnetBlock(ngf * 8, 2) 104 | blocks.append(block) 105 | 106 | self.middle = nn.Sequential(*blocks) 107 | 108 | self.Encoder_1 = Encoder_1 109 | self.Encoder_2 = Encoder_2 110 | self.Encoder_3 = Encoder_3 111 | self.Encoder_4 = Encoder_4 112 | self.Encoder_5 = Encoder_5 113 | self.Encoder_6 = Encoder_6 114 | 115 | def forward(self, input): 116 | y_1 = self.Encoder_1(input) 117 | y_2 = self.Encoder_2(y_1) 118 | y_3 = self.Encoder_3(y_2) 119 | y_4 = self.Encoder_4(y_3) 120 | y_5 = self.Encoder_5(y_4) 121 | y_6 = self.Encoder_6(y_5) 122 | y_7 = self.middle(y_6) 123 | 124 | return y_1, y_2, y_3, y_4, y_5, y_7 -------------------------------------------------------------------------------- 
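Each UnetSkipConnectionEBlock above halves the spatial resolution with a stride-2 4x4 convolution, so Encoder.forward returns a feature pyramid rather than a single tensor. A minimal shape walk-through as a sketch; input_nc=6 and the 256x256 input size are assumptions taken from how RGTSI.py later feeds the encoder a 3-channel image concatenated with a 3-channel inverted mask:

import torch
from models.Encoder import Encoder

enc = Encoder(input_nc=6, output_nc=3, ngf=64)
x = torch.randn(1, 6, 256, 256)          # image + inverted mask
feats = enc(x)                           # (y_1, y_2, y_3, y_4, y_5, y_7)
for name, f in zip(('y_1', 'y_2', 'y_3', 'y_4', 'y_5', 'y_7'), feats):
    print(name, tuple(f.shape))
# Each stage halves H and W:
# y_1 (1, 64, 128, 128), y_2 (1, 128, 64, 64), y_3 (1, 256, 32, 32),
# y_4 (1, 512, 16, 16), y_5 (1, 512, 8, 8), y_7 (1, 512, 4, 4);
# the dilated ResnetBlock middle preserves the innermost 4x4 size.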
/models/FAM/non_local_embedded_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _NonLocalBlockND(nn.Module): 7 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 8 | """ 9 | :param in_channels: 10 | :param inter_channels: 11 | :param dimension: 12 | :param sub_sample: 13 | :param bn_layer: 14 | """ 15 | 16 | super(_NonLocalBlockND, self).__init__() 17 | 18 | assert dimension in [1, 2, 3] 19 | 20 | self.dimension = dimension 21 | self.sub_sample = sub_sample 22 | 23 | self.in_channels = in_channels 24 | self.inter_channels = inter_channels 25 | 26 | if self.inter_channels is None: 27 | self.inter_channels = in_channels // 2 28 | if self.inter_channels == 0: 29 | self.inter_channels = 1 30 | 31 | if dimension == 3: 32 | conv_nd = nn.Conv3d 33 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 34 | bn = nn.BatchNorm3d 35 | elif dimension == 2: 36 | conv_nd = nn.Conv2d 37 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 38 | bn = nn.BatchNorm2d 39 | else: 40 | conv_nd = nn.Conv1d 41 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 42 | bn = nn.BatchNorm1d 43 | 44 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | 47 | if bn_layer: 48 | self.W = nn.Sequential( 49 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 50 | kernel_size=1, stride=1, padding=0), 51 | bn(self.in_channels) 52 | ) 53 | nn.init.constant_(self.W[1].weight, 0) 54 | nn.init.constant_(self.W[1].bias, 0) 55 | else: 56 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 57 | kernel_size=1, stride=1, padding=0) 58 | nn.init.constant_(self.W.weight, 0) 59 | nn.init.constant_(self.W.bias, 0) 60 | 61 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 62 | kernel_size=1, stride=1, padding=0) 63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | 66 | if sub_sample: 67 | self.g = nn.Sequential(self.g, max_pool_layer) 68 | self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | def forward(self, x, return_nl_map=False): 71 | """ 72 | :param x: (b, c, t, h, w) 73 | :param return_nl_map: if True return z, nl_map, else only return z. 
74 | :return: 75 | """ 76 | 77 | batch_size = x.size(0) 78 | 79 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 80 | g_x = g_x.permute(0, 2, 1) 81 | 82 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 83 | theta_x = theta_x.permute(0, 2, 1) 84 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 85 | f = torch.matmul(theta_x, phi_x) 86 | f_div_C = F.softmax(f, dim=-1) 87 | 88 | y = torch.matmul(f_div_C, g_x) 89 | y = y.permute(0, 2, 1).contiguous() 90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 91 | W_y = self.W(y) 92 | z = W_y + x 93 | 94 | if return_nl_map: 95 | return z, f_div_C 96 | return z 97 | 98 | 99 | class NONLocalBlock1D(_NonLocalBlockND): 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 101 | super(NONLocalBlock1D, self).__init__(in_channels, 102 | inter_channels=inter_channels, 103 | dimension=1, sub_sample=sub_sample, 104 | bn_layer=bn_layer) 105 | 106 | 107 | class NONLocalBlock2D(_NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NONLocalBlock2D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=2, sub_sample=sub_sample, 112 | bn_layer=bn_layer,) 113 | 114 | 115 | class NONLocalBlock3D(_NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NONLocalBlock3D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=3, sub_sample=sub_sample, 120 | bn_layer=bn_layer,) 121 | 122 | 123 | if __name__ == '__main__': 124 | import torch 125 | 126 | for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]: 127 | img = torch.zeros(2, 3, 20) 128 | net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 129 | out = net(img) 130 | print(out.size()) 131 | 132 | img = torch.zeros(2, 3, 20, 20) 133 | net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 134 | out = net(img) 135 | print(out.size()) 136 | 137 | img = torch.randn(2, 3, 8, 20, 20) 138 | net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_) 139 | out = net(img) 140 | print(out.size()) 141 | 142 | 143 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | import torch.nn.functional as F 6 | class VGG16(torch.nn.Module): 7 | def __init__(self): 8 | super(VGG16, self).__init__() 9 | features = models.vgg16(pretrained=True).features 10 | self.relu1_1 = torch.nn.Sequential() 11 | self.relu1_2 = torch.nn.Sequential() 12 | 13 | self.relu2_1 = torch.nn.Sequential() 14 | self.relu2_2 = torch.nn.Sequential() 15 | 16 | self.relu3_1 = torch.nn.Sequential() 17 | self.relu3_2 = torch.nn.Sequential() 18 | self.relu3_3 = torch.nn.Sequential() 19 | self.max3 = torch.nn.Sequential() 20 | 21 | 22 | self.relu4_1 = torch.nn.Sequential() 23 | self.relu4_2 = torch.nn.Sequential() 24 | self.relu4_3 = torch.nn.Sequential() 25 | 26 | 27 | self.relu5_1 = torch.nn.Sequential() 28 | self.relu5_2 = torch.nn.Sequential() 29 | self.relu5_3 = torch.nn.Sequential() 30 | 31 | for x in range(2): 32 | self.relu1_1.add_module(str(x), features[x]) 33 | 34 | for x in range(2, 4): 35 | self.relu1_2.add_module(str(x), features[x]) 36 | 37 | for x in range(4, 7): 38 | 
self.relu2_1.add_module(str(x), features[x]) 39 | 40 | for x in range(7, 9): 41 | self.relu2_2.add_module(str(x), features[x]) 42 | 43 | for x in range(9, 12): 44 | self.relu3_1.add_module(str(x), features[x]) 45 | 46 | for x in range(12, 14): 47 | self.relu3_2.add_module(str(x), features[x]) 48 | 49 | for x in range(14, 16): 50 | self.relu3_3.add_module(str(x), features[x]) 51 | for x in range(16, 17): 52 | self.max3.add_module(str(x), features[x]) 53 | 54 | for x in range(17, 19): 55 | self.relu4_1.add_module(str(x), features[x]) 56 | 57 | for x in range(19, 21): 58 | self.relu4_2.add_module(str(x), features[x]) 59 | 60 | for x in range(21, 23): 61 | self.relu4_3.add_module(str(x), features[x]) 62 | 63 | for x in range(23, 26): 64 | self.relu5_1.add_module(str(x), features[x]) 65 | 66 | for x in range(26, 28): 67 | self.relu5_2.add_module(str(x), features[x]) 68 | 69 | for x in range(28, 30): 70 | self.relu5_3.add_module(str(x), features[x]) 71 | 72 | 73 | # don't need the gradients, just want the features 74 | for param in self.parameters(): 75 | param.requires_grad = False 76 | 77 | def forward(self, x): 78 | relu1_1 = self.relu1_1(x) 79 | relu1_2 = self.relu1_2(relu1_1) 80 | 81 | relu2_1 = self.relu2_1(relu1_2) 82 | relu2_2 = self.relu2_2(relu2_1) 83 | 84 | relu3_1 = self.relu3_1(relu2_2) 85 | relu3_2 = self.relu3_2(relu3_1) 86 | relu3_3 = self.relu3_3(relu3_2) 87 | max_3 = self.max3(relu3_3) 88 | 89 | 90 | relu4_1 = self.relu4_1(max_3) 91 | relu4_2 = self.relu4_2(relu4_1) 92 | relu4_3 = self.relu4_3(relu4_2) 93 | 94 | 95 | relu5_1 = self.relu5_1(relu4_3) 96 | relu5_2 = self.relu5_2(relu5_1) 97 | relu5_3 = self.relu5_3(relu5_2) 98 | out = { 99 | 'relu1_1': relu1_1, 100 | 'relu1_2': relu1_2, 101 | 102 | 'relu2_1': relu2_1, 103 | 'relu2_2': relu2_2, 104 | 105 | 'relu3_1': relu3_1, 106 | 'relu3_2': relu3_2, 107 | 'relu3_3': relu3_3, 108 | 'max_3':max_3, 109 | 110 | 111 | 'relu4_1': relu4_1, 112 | 'relu4_2': relu4_2, 113 | 'relu4_3': relu4_3, 114 | 115 | 116 | 'relu5_1': relu5_1, 117 | 'relu5_2': relu5_2, 118 | 'relu5_3': relu5_3, 119 | } 120 | return out 121 | class StyleLoss(nn.Module): 122 | r""" 123 | Style loss, VGG-based, computed on Gram matrices 124 | https://arxiv.org/abs/1603.08155 125 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 126 | """ 127 | 128 | def __init__(self): 129 | super(StyleLoss, self).__init__() 130 | self.add_module('vgg', VGG16().cuda()) 131 | self.criterion = torch.nn.L1Loss() 132 | 133 | def compute_gram(self, x): 134 | b, ch, h, w = x.size() 135 | f = x.view(b, ch, w * h) 136 | f_T = f.transpose(1, 2) 137 | G = f.bmm(f_T) / (h * w * ch) 138 | 139 | return G 140 | 141 | def __call__(self, x, y): 142 | # Compute features 143 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 144 | 145 | # Compute loss 146 | style_loss = 0.0 147 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 148 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_3']), self.compute_gram(y_vgg['relu3_3'])) 149 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_3']), self.compute_gram(y_vgg['relu4_3'])) 150 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 151 | 152 | return style_loss
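StyleLoss above matches Gram matrices: features are flattened to (b, ch, h*w) and correlated channel against channel, normalised by h*w*ch, so the statistic captures texture while discarding spatial layout. A small self-contained sketch of that computation and of its permutation invariance:

import torch

def compute_gram(x):                       # same maths as StyleLoss.compute_gram
    b, ch, h, w = x.size()
    f = x.view(b, ch, h * w)
    return f.bmm(f.transpose(1, 2)) / (h * w * ch)

feat = torch.randn(2, 8, 16, 16)           # stand-in for a relu2_2 activation
G = compute_gram(feat)
print(G.shape)                             # torch.Size([2, 8, 8])
# Shuffling pixel positions leaves the Gram matrix unchanged (up to fp error):
shuffled = feat[:, :, torch.randperm(16)]
print(torch.allclose(G, compute_gram(shuffled), atol=1e-6))   # True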
153 | class PerceptualLoss(nn.Module): 154 | r""" 155 | Perceptual loss, VGG-based 156 | https://arxiv.org/abs/1603.08155 157 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 158 | """ 159 | 160 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 161 | super(PerceptualLoss, self).__init__() 162 | self.add_module('vgg', VGG16().cuda()) 163 | self.criterion = torch.nn.L1Loss() 164 | self.weights = weights 165 | 166 | def __call__(self, x, y): 167 | # Compute features 168 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 169 | 170 | content_loss = 0.0 171 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 172 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 173 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 174 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 175 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 176 | 177 | 178 | return content_loss 179 | 180 | class GANLoss(nn.Module): 181 | def __init__(self, target_real_label=1.0, target_fake_label=0.0, 182 | tensor=torch.FloatTensor): 183 | super(GANLoss, self).__init__() 184 | self.real_label = target_real_label 185 | self.fake_label = target_fake_label 186 | self.real_label_var = None 187 | self.fake_label_var = None 188 | self.Tensor = tensor 189 | 190 | # This target construction could be swapped for a GAN objective of my own 191 | def get_target_tensor(self, input, target_is_real): 192 | target_tensor = None 193 | if target_is_real: 194 | create_label = ((self.real_label_var is None) or 195 | (self.real_label_var.numel() != input.numel())) 196 | if create_label: 197 | self.real_label_var = self.Tensor(input.size()).fill_(self.real_label) 198 | target_tensor = self.real_label_var 199 | else: 200 | create_label = ((self.fake_label_var is None) or 201 | (self.fake_label_var.numel() != input.numel())) 202 | if create_label: 203 | self.fake_label_var = self.Tensor(input.size()).fill_(self.real_label)  # both branches use a unit target; the relativistic terms below only flip its sign 204 | target_tensor = self.fake_label_var 205 | return target_tensor 206 | 207 | def __call__(self, y_pred_fake, y_pred, target_is_real): 208 | target_tensor = self.get_target_tensor(y_pred_fake, target_is_real) 209 | if target_is_real: 210 | errD = (torch.mean((y_pred - torch.mean(y_pred_fake) - target_tensor) ** 2) + torch.mean( 211 | (y_pred_fake - torch.mean(y_pred) + target_tensor) ** 2)) / 2 212 | return errD 213 | else: 214 | errG = (torch.mean((y_pred - torch.mean(y_pred_fake) + target_tensor) ** 2) + torch.mean( 215 | (y_pred_fake - torch.mean(y_pred) - target_tensor) ** 2)) / 2 216 | return errG 217 | 218 | # class DESTLOSS(nn.Module): 219 | # def __init__(self): 220 | # super(DESTLOSS, self).__init__() 221 | # self.criterion = torch.nn.L1Loss() 222 | # 223 | # def __call__(self, Gt_de, Gt_st, Fake_de, Fake_st): 224 | # Gt_de = F.interpolate (Gt_de, size=(32,32), mode='bilinear') 225 | # Gt_st = F.interpolate (Gt_st, size=(32,32), mode='bilinear') 226 | # 227 | # 228 | # 229 | # return content_loss --------------------------------------------------------------------------------
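GANLoss above is a relativistic average least-squares objective (RaLSGAN-style): each prediction is scored against the mean prediction of the opposite side, and the same unit target enters both branches with opposite signs. A minimal usage sketch with dummy discriminator scores; the call order (fake predictions first, then real, then the real/fake flag) mirrors how RGTSI.py uses it:

import torch
from models.loss import GANLoss

criterion = GANLoss(tensor=torch.FloatTensor)
d_real = torch.randn(4, 1)       # discriminator scores on real images
d_fake = torch.randn(4, 1)       # discriminator scores on fakes

loss_D = criterion(d_fake.detach(), d_real, True)   # discriminator step
loss_G = criterion(d_fake, d_real, False)           # generator step
print(loss_D.item(), loss_G.item())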
/util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import random 6 | import inspect, re 7 | 8 | import os 9 | import collections 10 | import math 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | import torch.nn as nn 14 | import matplotlib.pyplot as plt 15 | 16 | # Converts a Tensor into a Numpy array 17 | # |imtype|: the desired type of the converted numpy array 18 | def tensor2im(image_tensor, imtype=np.uint8): 19 | image_numpy = image_tensor[0].cpu().float().numpy() 20 | if image_numpy.shape[0] == 1: 21 | image_numpy = np.tile(image_numpy, (3,1,1)) 22 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 23 | return image_numpy.astype(imtype) 24 | 25 | 26 | def diagnose_network(net, name='network'): 27 | mean = 0.0 28 | count = 0 29 | for param in net.parameters(): 30 | if param.grad is not None: 31 | mean += torch.mean(torch.abs(param.grad.data)) 32 | count += 1 33 | if count > 0: 34 | mean = mean / count 35 | print(name) 36 | print(mean) 37 | 38 | def binary_mask(in_mask, threshold): 39 | assert in_mask.dim() == 2, "mask must be 2-dimensional" 40 | 41 | # threshold the input mask itself to obtain a binary {0, 1} float mask 42 | output = (in_mask > threshold).float() 43 | 44 | return output 45 | 46 | def gussin(v): 47 | outk = [] 48 | 49 | for i in range(32): 50 | for k in range(32): 51 | 52 | out = [] 53 | for x in range(32): 54 | row = [] 55 | for y in range(32): 56 | cord_x = i 57 | cord_y = k 58 | dis_x = np.abs(x - cord_x) 59 | dis_y = np.abs(y - cord_y) 60 | dis_add = -(dis_x * dis_x + dis_y * dis_y) 61 | dis_add = dis_add / (2 * v * v) 62 | dis_add = math.exp(dis_add) / (2 * math.pi * v * v) 63 | 64 | row.append(dis_add) 65 | out.append(row) 66 | 67 | outk.append(out) 68 | 69 | out = np.array(outk) 70 | f = out.sum(-1).sum(-1) 71 | q = [] 72 | for i in range(1024): 73 | g = out[i] / f[i] 74 | q.append(g) 75 | out = np.array(q) 76 | return torch.from_numpy(out) 77 | 78 | def cal_feat_mask(inMask, conv_layers, threshold): 79 | assert inMask.dim() == 4, "mask must be 4-dimensional" 80 | assert inMask.size(0) == 1, "the first dimension must be 1 for mask" 81 | inMask = inMask.float() 82 | convs = [] 83 | inMask = Variable(inMask, requires_grad = False) 84 | for id_net in range(conv_layers): 85 | conv = nn.Conv2d(1,1,4,2,1, bias=False) 86 | conv.weight.data.fill_(1/16) 87 | convs.append(conv) 88 | lnet = nn.Sequential(*convs) 89 | if inMask.is_cuda: 90 | 91 | lnet = lnet.cuda() 92 | output = lnet(inMask) 93 | output = (output > threshold).float().mul_(1) 94 | 95 | return output 96 |
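gussin() above builds, with four nested loops, one normalised 32x32 Gaussian map per grid position (1024 maps in total); the 1/(2*pi*v*v) factor cancels when each map is divided by its own sum. A vectorised sketch of the same maths, useful for sanity-checking the loops (gaussian_maps is a hypothetical helper, not part of util.py):

import torch

def gaussian_maps(v, size=32):
    pos = torch.arange(size, dtype=torch.float32)
    d_row = (pos.view(-1, 1) - pos.view(1, -1)) ** 2    # (size, size) squared offsets
    # d2[cy, cx, y, x] = (y - cy)**2 + (x - cx)**2 via broadcasting
    d2 = d_row.view(size, 1, size, 1) + d_row.view(1, size, 1, size)
    g = torch.exp(-d2 / (2 * v * v)).reshape(size * size, size, size)
    return g / g.sum(dim=(-2, -1), keepdim=True)        # each map sums to 1

maps = gaussian_maps(1.5)
print(maps.shape)       # torch.Size([1024, 32, 32])
print(maps[0].sum())    # tensor(1.0000)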
97 | def cal_mask_given_mask_thred(img, mask, patch_size, stride, mask_thred): 98 | assert img.dim() == 3, 'img has to be 3-dimensional!' 99 | assert mask.dim() == 2, 'mask has to be 2-dimensional!' 100 | dim = img.dim() 101 | # math.floor rounds down 102 | _, H, W = img.size(dim-3), img.size(dim-2), img.size(dim-1) 103 | nH = int(math.floor((H-patch_size)/stride + 1)) 104 | nW = int(math.floor((W-patch_size)/stride + 1)) 105 | N = nH*nW 106 | 107 | flag = torch.zeros(N).long() 108 | offsets_tmp_vec = torch.zeros(N).long() 109 | # the returned values are list-like index tensors 110 | 111 | nonmask_point_idx_all = torch.zeros(N).long() 112 | 113 | tmp_non_mask_idx = 0 114 | 115 | 116 | mask_point_idx_all = torch.zeros(N).long() 117 | 118 | tmp_mask_idx = 0 119 | # visit every patch position once 120 | for i in range(N): 121 | h = int(math.floor(i/nW)) 122 | w = int(math.floor(i%nW)) 123 | # print(h, w) 124 | # crop the small 1x1 square patches one by one 125 | mask_tmp = mask[h*stride:h*stride + patch_size, 126 | w*stride:w*stride + patch_size] 127 | 128 | 129 | if torch.sum(mask_tmp) < mask_thred: 130 | nonmask_point_idx_all[tmp_non_mask_idx] = i 131 | tmp_non_mask_idx += 1 132 | else: 133 | mask_point_idx_all[tmp_mask_idx] = i 134 | tmp_mask_idx += 1 135 | flag[i] = 1 136 | offsets_tmp_vec[i] = -1 137 | # print(flag) #checked 138 | # print(offsets_tmp_vec) # checked 139 | 140 | non_mask_num = tmp_non_mask_idx 141 | mask_num = tmp_mask_idx 142 | 143 | nonmask_point_idx = nonmask_point_idx_all.narrow(0, 0, non_mask_num) 144 | mask_point_idx = mask_point_idx_all.narrow(0, 0, mask_num) 145 | 146 | # get flatten_offsets 147 | flatten_offsets_all = torch.LongTensor(N).zero_() 148 | for i in range(N): 149 | offset_value = torch.sum(offsets_tmp_vec[0:i+1]) 150 | if flag[i] == 1: 151 | offset_value = offset_value + 1 152 | # print(i+offset_value) 153 | flatten_offsets_all[i+offset_value] = -offset_value 154 | 155 | flatten_offsets = flatten_offsets_all.narrow(0, 0, non_mask_num) 156 | 157 | # print('flatten_offsets') 158 | # print(flatten_offsets) # checked 159 | 160 | 161 | # print('nonmask_point_idx') 162 | # print(nonmask_point_idx) #checked 163 | 164 | return flag, nonmask_point_idx, flatten_offsets, mask_point_idx 165 | 166 | 167 | # sp_x: LongTensor 168 | # sp_y: LongTensor 169 | def cal_sps_for_Advanced_Indexing(h, w): 170 | sp_y = torch.arange(0, w).long() 171 | sp_y = torch.cat([sp_y]*h) 172 | 173 | lst = [] 174 | for i in range(h): 175 | lst.extend([i]*w) 176 | sp_x = torch.from_numpy(np.array(lst)) 177 | return sp_x, sp_y 178 | 179 | 180 | def save_image(image_numpy, image_path): 181 | image_pil = Image.fromarray(image_numpy) 182 | image_pil.save(image_path) 183 |
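cal_mask_given_mask_thred above slides a patch_size window with the given stride over the mask and labels each of the nH*nW patches as masked or unmasked depending on whether its mask sum reaches mask_thred. A tiny worked example as a sketch; the 4x4 sizes are chosen only to keep the index arithmetic easy to follow:

import torch
from util.util import cal_mask_given_mask_thred

img = torch.zeros(3, 4, 4)
mask = torch.zeros(4, 4)
mask[0:2, 2:4] = 1        # only the top-right 2x2 patch is fully masked
flag, nonmask_idx, offsets, mask_idx = cal_mask_given_mask_thred(
    img, mask, 2, 2, 4)   # patch_size=2, stride=2, mask_thred=4
print(flag.tolist())         # [0, 1, 0, 0] -> patch 1 (top-right) is masked
print(nonmask_idx.tolist())  # [0, 2, 3]
print(mask_idx.tolist())     # [1]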
184 | def info(object, spacing=10, collapse=1): 185 | """Print methods and doc strings. 186 | Takes module, class, list, dictionary, or string.""" 187 | methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] 188 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) 189 | print( "\n".join(["%s %s" % 190 | (method.ljust(spacing), 191 | processFunc(str(getattr(object, method).__doc__))) 192 | for method in methodList]) ) 193 | 194 | def varname(p): 195 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: 196 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) 197 | if m: 198 | return m.group(1) 199 | 200 | def print_numpy(x, val=True, shp=False): 201 | x = x.astype(np.float64) 202 | if shp: 203 | print('shape,', x.shape) 204 | if val: 205 | x = x.flatten() 206 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 207 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 208 | 209 | 210 | def showpatch(imagepatch, modelname, foldername=None, istensor = True): 211 | batchsize = imagepatch.shape[0] 212 | channelsize = imagepatch.shape[1] 213 | #print(imagepatch.shape) 214 | #print(batchsize) 215 | 216 | if istensor: 217 | imagepatch = np.array(imagepatch.cpu().detach()) 218 | folderpath = os.path.join("Network_patches",modelname, foldername) 219 | 220 | print("start visualization {}, channelsize: {}".format(foldername,channelsize)) 221 | 222 | if not os.path.isdir(os.path.join("Network_patches",modelname)): 223 | os.mkdir(os.path.join("Network_patches",modelname)) 224 | if not os.path.isdir(folderpath): 225 | os.mkdir(folderpath) 226 | 227 | for index in range(batchsize): 228 | patches = imagepatch[index] 229 | for channel in range(0,channelsize,128): 230 | #for channel in range(0,channelsize): 231 | image = regularization_image(patches[channel]) 232 | image = (image*255).astype(np.uint8) 233 | #PIL_Input_Image = Image.fromarray(image) 234 | #PIL_Input_Image.save(os.path.join(folderpath,"image{}.png".format(index))) 235 | plt.imshow(image, 'gray') 236 | plt.savefig(os.path.join(folderpath,"image{}.png".format(channel))) 237 | 238 | 239 | 240 | def saveoffset(offsetbatch, modelname, foldername, istensor = False): 241 | if istensor: 242 | offsetbatch = np.array(offsetbatch.cpu().detach()) 243 | offsetbatch = np.transpose(offsetbatch, (0, 2, 3, 1)) 244 | offsetbatch = np.squeeze(offsetbatch) 245 | sizetemp = offsetbatch.shape[:-1] 246 | offset_coord = np.zeros((*sizetemp, int(offsetbatch.shape[-1] / 2), 2), dtype=np.float32) 247 | 248 | for y in range(offset_coord.shape[0]): 249 | for x in range(offset_coord.shape[1]): 250 | for i in range(offset_coord.shape[2]): 251 | coordtuple = offsetbatch[y,x,i*2:(i+1)*2] 252 | offset_coord[y,x,i] = coordtuple 253 | 254 | folderpath = os.path.join("Network_patches",modelname, foldername) 255 | 256 | if not os.path.isdir(folderpath): 257 | os.mkdir(folderpath) 258 | 259 | for i in range(offsetbatch.shape[0]): 260 | np.save(os.path.join(folderpath, "offset_{}.npy".format(i)), offsetbatch[i]) 261 | 262 | 263 | 264 | return offset_coord 265 | 266 | def regularization_image(image): 267 | min = np.min(image) 268 | temp_image = image-min 269 | 270 | max = np.max(temp_image) 271 | temp_image = temp_image/max 272 | 273 | return temp_image 274 | 275 | def mkdirs(paths): 276 | if isinstance(paths, list) and not isinstance(paths, str): 277 | for path in paths: 278 | mkdir(path) 279 | else: 280 | mkdir(paths) 281
| 282 | 283 | def mkdir(path): 284 | if not os.path.exists(path): 285 | os.makedirs(path) 286 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /data/Matlab/dirPlus.m: -------------------------------------------------------------------------------- 1 | function output = dirPlus(rootPath, varargin) 2 | %dirPlus Recursively collect files or directories within a folder. 3 | % LIST = dirPlus(ROOTPATH) will search recursively through the folder 4 | % tree beneath ROOTPATH and collect a cell array LIST of all files it 5 | % finds. The list will contain the absolute paths to each file starting 6 | % at ROOTPATH. 7 | % 8 | % LIST = dirPlus(ROOTPATH, 'PropertyName', PropertyValue, ...) will 9 | % modify how files and directories are selected, as well as the format of 10 | % LIST, based on the property/value pairs specified. Valid properties 11 | % that the user can set are: 12 | % 13 | % GENERAL: 14 | % 'Struct' - A logical value determining if the output LIST should 15 | % instead be a structure array of the form returned by 16 | % the DIR function. If TRUE, LIST will be an N-by-1 17 | % structure array instead of a cell array. 18 | % 'Depth' - A non-negative integer value for the maximum folder 19 | % tree depth that dirPlus will search through. A value 20 | % of 0 will only search in ROOTPATH, a value of 1 will 21 | % search in ROOTPATH and its subfolders, etc. Default 22 | % (and maximum allowable) value is the current 23 | % recursion limit set on the root object (i.e. 24 | % get(0, 'RecursionLimit')). 25 | % 'ReturnDirs' - A logical value determining if the output will be a 26 | % list of files or subdirectories. If TRUE, LIST will 27 | % be a cell array of subdirectory names/paths. Default 28 | % is FALSE. 29 | % 'PrependPath' - A logical value determining if the full path from 30 | % ROOTPATH to the file/subdirectory is prepended to 31 | % each item in LIST. The default TRUE will prepend the 32 | % full path, otherwise just the file/subdirectory name 33 | % is returned. This setting is ignored if the 'Struct' 34 | % argument is TRUE. 35 | % 36 | % FILE-SPECIFIC: 37 | % 'FileFilter' - A string defining a regular-expression pattern 38 | % that will be applied to the file name. Only files 39 | % matching the pattern will be included in LIST. 40 | % Default is '' (i.e. all files are included). 
41 | % 'ValidateFileFcn' - A handle to a function that takes as input a 42 | % structure of the form returned by the DIR 43 | % function and returns a logical value. This 44 | % function will be applied to all files found and 45 | % only files that have a TRUE return value will be 46 | % included in LIST. Default is [] (i.e. all files 47 | % are included). 48 | % 49 | % DIRECTORY-SPECIFIC: 50 | % 'DirFilter' - A string defining a regular-expression pattern 51 | % that will be applied to the subdirectory name. 52 | % Only subdirectories matching the pattern will be 53 | % considered valid (i.e. included in LIST themselves 54 | % or having their files included in LIST). Default 55 | % is '' (i.e. all subdirectories are valid). The 56 | % setting of the 'RecurseInvalid' argument 57 | % determines if invalid subdirectories are still 58 | % recursed down. 59 | % 'ValidateDirFcn' - A handle to a function that takes as input a 60 | % structure of the form returned by the DIR function 61 | % and returns a logical value. This function will be 62 | % applied to all subdirectories found and only 63 | % subdirectories that have a TRUE return value will 64 | % be considered valid (i.e. included in LIST 65 | % themselves or having their files included in 66 | % LIST). Default is [] (i.e. all subdirectories are 67 | % valid). The setting of the 'RecurseInvalid' 68 | % argument determines if invalid subdirectories are 69 | % still recursed down. 70 | % 'RecurseInvalid' - A logical value determining if invalid 71 | % subdirectories (as identified by the 'DirFilter' 72 | % and 'ValidateDirFcn' arguments) should still be 73 | % recursed down. Default is FALSE (i.e. the recursive 74 | % searching stops at invalid subdirectories). 75 | % 76 | % Examples: 77 | % 78 | % 1) Find all '.m' files: 79 | % 80 | % fileList = dirPlus(rootPath, 'FileFilter', '\.m$'); 81 | % 82 | % 2) Find all '.m' files, returning the list as a structure array: 83 | % 84 | % fileList = dirPlus(rootPath, 'Struct', true, ... 85 | % 'FileFilter', '\.m$'); 86 | % 87 | % 3) Find all '.jpg', '.png', and '.tif' files: 88 | % 89 | % fileList = dirPlus(rootPath, 'FileFilter', '\.(jpg|png|tif)$'); 90 | % 91 | % 4) Find all '.m' files in the given folder and its subfolders: 92 | % 93 | % fileList = dirPlus(rootPath, 'Depth', 1, 'FileFilter', '\.m$'); 94 | % 95 | % 5) Find all '.m' files, returning only the file names: 96 | % 97 | % fileList = dirPlus(rootPath, 'FileFilter', '\.m$', ... 98 | % 'PrependPath', false); 99 | % 100 | % 6) Find all '.jpg' files with a size of more than 1MB: 101 | % 102 | % bigFcn = @(s) (s.bytes > 1024^2); 103 | % fileList = dirPlus(rootPath, 'FileFilter', '\.jpg$', ... 104 | % 'ValidateFileFcn', bigFcn); 105 | % 106 | % 7) Find all '.m' files contained in folders containing the string 107 | % 'addons', recursing without restriction: 108 | % 109 | % fileList = dirPlus(rootPath, 'DirFilter', 'addons', ... 110 | % 'FileFilter', '\.m$', ... 111 | % 'RecurseInvalid', true); 112 | % 113 | % See also dir, regexp. 114 | 115 | % Author: Ken Eaton 116 | % Version: MATLAB R2016b - R2011a 117 | % Last modified: 4/14/17 118 | % Copyright 2017 by Kenneth P.
Eaton 119 | % Copyright 2017 by Stephen Larroque - backwards compatibility 120 | %-------------------------------------------------------------------------- 121 | 122 | % Create input parser (only have to do this once, hence the use of a 123 | % persistent variable): 124 | 125 | persistent parser 126 | if isempty(parser) 127 | recursionLimit = get(0, 'RecursionLimit'); 128 | parser = inputParser(); 129 | parser.FunctionName = 'dirPlus'; 130 | if verLessThan('matlab', '8.2') % MATLAB R2013b = 8.2 131 | addPVPair = @addParamValue; 132 | else 133 | parser.PartialMatching = true; 134 | addPVPair = @addParameter; 135 | end 136 | 137 | % Add general parameters: 138 | 139 | addRequired(parser, 'rootPath', ... 140 | @(s) validateattributes(s, {'char'}, {'nonempty'})); 141 | addPVPair(parser, 'Struct', false, ... 142 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 143 | addPVPair(parser, 'Depth', recursionLimit, ... 144 | @(s) validateattributes(s, {'numeric'}, ... 145 | {'scalar', 'nonnegative', ... 146 | 'nonnan', 'integer', ... 147 | '<=', recursionLimit})); 148 | addPVPair(parser, 'ReturnDirs', false, ... 149 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 150 | addPVPair(parser, 'PrependPath', true, ... 151 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 152 | 153 | % Add file-specific parameters: 154 | 155 | addPVPair(parser, 'FileFilter', '', ... 156 | @(s) validateattributes(s, {'char'}, {'row'})); 157 | addPVPair(parser, 'ValidateFileFcn', [], ... 158 | @(f) validateattributes(f, {'function_handle'}, {'scalar'})); 159 | 160 | % Add directory-specific parameters: 161 | 162 | addPVPair(parser, 'DirFilter', '', ... 163 | @(s) validateattributes(s, {'char'}, {'row'})); 164 | addPVPair(parser, 'ValidateDirFcn', [], ... 165 | @(f) validateattributes(f, {'function_handle'}, {'scalar'})); 166 | addPVPair(parser, 'RecurseInvalid', false, ... 167 | @(b) validateattributes(b, {'logical'}, {'scalar'})); 168 | 169 | end 170 | 171 | % Parse input and recursively find contents: 172 | 173 | parse(parser, rootPath, varargin{:}); 174 | output = dirPlus_core(parser.Results.rootPath, ... 175 | rmfield(parser.Results, 'rootPath'), 0, true); 176 | if parser.Results.Struct 177 | output = vertcat(output{:}); 178 | end 179 | 180 | end 181 | 182 | %~~~Begin local functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 183 | 184 | %-------------------------------------------------------------------------- 185 | % Core recursive function to find files and directories. 186 | function output = dirPlus_core(rootPath, optionStruct, depth, isValid) 187 | 188 | % Backwards compatibility for fullfile: 189 | 190 | persistent fullfilecell 191 | if isempty(fullfilecell) 192 | if verLessThan('matlab', '8.0') % MATLAB R2012b = 8.0 193 | fullfilecell = @(P, C) cellfun(@(S) fullfile(P, S), C, ... 
194 | 'UniformOutput', false); 195 | else 196 | fullfilecell = @fullfile; 197 | end 198 | end 199 | 200 | % Get current directory contents: 201 | 202 | rootData = dir(rootPath); 203 | dirIndex = [rootData.isdir]; 204 | subDirs = {}; 205 | validIndex = []; 206 | 207 | % Find valid subdirectories, only if necessary: 208 | 209 | if (depth < optionStruct.Depth) || optionStruct.ReturnDirs 210 | 211 | % Get subdirectories, not counting current or parent: 212 | 213 | dirData = rootData(dirIndex); 214 | subDirs = {dirData.name}.'; 215 | index = ~ismember(subDirs, {'.', '..'}); 216 | dirData = dirData(index); 217 | subDirs = subDirs(index); 218 | validIndex = true(size(subDirs)); 219 | if any(validIndex) 220 | % Apply directory name filter, if specified: 221 | nameFilter = optionStruct.DirFilter; 222 | if ~isempty(nameFilter) 223 | validIndex = ~cellfun(@isempty, regexp(subDirs, nameFilter)); 224 | end 225 | if any(validIndex) 226 | % Apply validation function to the directory list, if specified: 227 | validateFcn = optionStruct.ValidateDirFcn; 228 | if ~isempty(validateFcn) 229 | validIndex(validIndex) = arrayfun(validateFcn, ... 230 | dirData(validIndex)); 231 | end 232 | end 233 | end 234 | end 235 | % Determine if files or subdirectories are being returned: 236 | if optionStruct.ReturnDirs % Return directories 237 | % Use structure data or prepend full path, if specified: 238 | if optionStruct.Struct 239 | output = {dirData(validIndex)}; 240 | elseif any(validIndex) && optionStruct.PrependPath 241 | output = fullfilecell(rootPath, subDirs(validIndex)); 242 | else 243 | output = subDirs(validIndex); 244 | end 245 | elseif isValid % Return files 246 | % Find all files in the current directory: 247 | fileData = rootData(~dirIndex); 248 | output = {fileData.name}.'; 249 | 250 | if ~isempty(output) 251 | 252 | % Apply file name filter, if specified: 253 | 254 | fileFilter = optionStruct.FileFilter; 255 | if ~isempty(fileFilter) 256 | filterIndex = ~cellfun(@isempty, regexp(output, fileFilter)); 257 | fileData = fileData(filterIndex); 258 | output = output(filterIndex); 259 | end 260 | 261 | if ~isempty(output) 262 | 263 | % Apply validation function to the file list, if specified: 264 | 265 | validateFcn = optionStruct.ValidateFileFcn; 266 | if ~isempty(validateFcn) 267 | validateIndex = arrayfun(validateFcn, fileData); 268 | fileData = fileData(validateIndex); 269 | output = output(validateIndex); 270 | end 271 | 272 | % Use structure data or prepend full path, if specified: 273 | 274 | if optionStruct.Struct 275 | output = {fileData}; 276 | elseif ~isempty(output) && optionStruct.PrependPath 277 | output = fullfilecell(rootPath, output); 278 | end 279 | 280 | end 281 | 282 | end 283 | 284 | else % Return nothing 285 | 286 | output = {}; 287 | 288 | end 289 | 290 | % Check recursion depth: 291 | 292 | if (depth < optionStruct.Depth) 293 | 294 | % Select subdirectories to recurse down: 295 | 296 | if ~optionStruct.RecurseInvalid 297 | subDirs = subDirs(validIndex); 298 | validIndex = validIndex(validIndex); 299 | end 300 | 301 | % Recursively collect output from subdirectories: 302 | 303 | nSubDirs = numel(subDirs); 304 | if (nSubDirs > 0) 305 | subDirs = fullfilecell(rootPath, subDirs); 306 | output = {output; cell(nSubDirs, 1)}; 307 | for iSub = 1:nSubDirs 308 | output{iSub+1} = dirPlus_core(subDirs{iSub}, optionStruct, ... 
309 | depth+1, validIndex(iSub)); 310 | end 311 | output = vertcat(output{:}); 312 | end 313 | 314 | end 315 | 316 | end 317 | 318 | %~~~End local functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -------------------------------------------------------------------------------- /models/RGTSI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from collections import OrderedDict 4 | from torch.autograd import Variable 5 | from PIL import Image 6 | import torch.nn.functional as F 7 | 8 | from models.base_model import BaseModel 9 | from models import networks 10 | 11 | from .loss import VGG16, PerceptualLoss, StyleLoss, GANLoss 12 | 13 | 14 | class RGTSI(BaseModel): 15 | def __init__(self, opt): 16 | super(RGTSI, self).__init__(opt) 17 | self.isTrain = opt.isTrain 18 | self.opt = opt 19 | self.device = torch.device('cuda') 20 | # define tensors 21 | self.vgg = VGG16() 22 | self.PerceptualLoss = PerceptualLoss() 23 | self.StyleLoss = StyleLoss() 24 | self.input_DE = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 25 | self.input_ST = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 26 | self.ref_DE = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 27 | self.fake_input_p_1 = self.Tensor(opt.batchSize, 6, opt.fineSize, opt.fineSize) 28 | self.Gt_Local = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 29 | self.Gt_DE = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 30 | self.Gt_ST = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 31 | self.Gt_RF = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 32 | self.input_mask_global = self.Tensor(opt.batchSize, 1, opt.fineSize, opt.fineSize) 33 | 34 | 35 | self.model_names = [] 36 | if len(opt.gpu_ids) > 0: 37 | self.use_gpu = True 38 | self.vgg = self.vgg.to(self.gpu_ids[0]) 39 | self.vgg = torch.nn.DataParallel(self.vgg, self.gpu_ids) 40 | # load/define networks EN:Encoder RefEN:RefEncoder DE:Decoder RGTSI: Reference-Guided Texture and Structure Inference 41 | self.netEN, self.netRefEN, self.netDE, self.netRGTSI, self.stde_loss = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.norm, 42 | opt.use_dropout, opt.init_type, 43 | self.gpu_ids, 44 | opt.init_gain) 45 | 46 | self.model_names=[ 'EN','RefEN','DE', 'RGTSI'] 47 | 48 | if self.isTrain: 49 | 50 | self.netD = networks.define_D(3, opt.ndf, 51 | opt.n_layers_D, opt.norm, opt.init_type, self.gpu_ids, opt.init_gain) 52 | self.netF = networks.define_D(3, opt.ndf, 53 | opt.n_layers_D, opt.norm, opt.init_type, self.gpu_ids, opt.init_gain) 54 | self.model_names.append('D') 55 | self.model_names.append('F') 56 | if self.isTrain: 57 | self.old_lr = opt.lr 58 | # define loss functions 59 | self.criterionGAN = GANLoss(tensor=self.Tensor) 60 | self.criterionL1 = torch.nn.L1Loss() 61 | self.criterionL2 = torch.nn.MSELoss() 62 | 63 | # initialize optimizers 64 | self.schedulers = [] 65 | self.optimizers = [] 66 | 67 | self.optimizer_EN = torch.optim.Adam(self.netEN.parameters(), 68 | lr=opt.lr, betas=(opt.beta1, 0.999)) 69 | self.optimizer_RefEN = torch.optim.Adam(self.netRefEN.parameters(), 70 | lr=opt.lr, betas=(opt.beta1, 0.999)) 71 | 72 | self.optimizer_DE = torch.optim.Adam(self.netDE.parameters(), 73 | lr=opt.lr, betas=(opt.beta1, 0.999)) 74 | 75 | self.optimizer_RGTSI = torch.optim.Adam(self.netRGTSI.parameters(), 76 | lr=opt.lr, betas=(opt.beta1, 0.999)) 77 | 78 | 
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 79 | lr=opt.lr, betas=(opt.beta1, 0.999)) 80 | self.optimizer_F = torch.optim.Adam(self.netF.parameters(), 81 | lr=opt.lr, betas=(opt.beta1, 0.999)) 82 | self.optimizers.append(self.optimizer_EN) 83 | self.optimizers.append(self.optimizer_RefEN) 84 | self.optimizers.append(self.optimizer_DE) 85 | 86 | self.optimizers.append(self.optimizer_RGTSI) 87 | self.optimizers.append(self.optimizer_D) 88 | self.optimizers.append(self.optimizer_F) 89 | for optimizer in self.optimizers: 90 | self.schedulers.append(networks.get_scheduler(optimizer, opt)) 91 | 92 | print('---------- Networks initialized -------------') 93 | networks.print_network(self.netEN) 94 | networks.print_network(self.netRefEN) 95 | networks.print_network(self.netDE) 96 | networks.print_network(self.netRGTSI) 97 | if self.isTrain: 98 | networks.print_network(self.netD) 99 | networks.print_network(self.netF) 100 | print('-----------------------------------------------') 101 | #####modified 102 | if self.isTrain: 103 | if opt.continue_train: 104 | print('Loading pre-trained network!') 105 | self.load_networks(self.netEN, 'EN', opt.which_epoch) 106 | self.load_networks(self.netRefEN, 'RefEN', opt.which_epoch) 107 | self.load_networks(self.netDE, 'DE', opt.which_epoch) 108 | self.load_networks(self.netRGTSI, 'RGTSI', opt.which_epoch) 109 | self.load_networks(self.netD, 'D', opt.which_epoch) 110 | self.load_networks(self.netF, 'F', opt.which_epoch) 111 | 112 | def name(self): 113 | return 'RGTSI' 114 | 115 | def mask_process(self, mask): 116 | mask = mask[0][0] 117 | mask = torch.unsqueeze(mask, 0) 118 | mask = torch.unsqueeze(mask, 1) 119 | mask = mask.byte() 120 | return mask 121 | 122 | def set_input(self, input_De, input_St, input_Mask, ref_De): 123 | self.Gt_DE = input_De.to(self.device) 124 | self.Gt_ST = input_St.to(self.device) 125 | self.input_DE = input_De.to(self.device) 126 | self.ref_DE = ref_De.to(self.device) 127 | 128 | self.input_mask_global = self.mask_process(input_Mask.to(self.device)) 129 | 130 | 131 | self.Gt_Local = input_De.to(self.device) 132 | # define the local area that is sent to the local discriminator 133 | self.crop_x = random.randint(0, 191) 134 | self.crop_y = random.randint(0, 191) 135 | self.Gt_Local = self.Gt_Local[:, :, self.crop_x:self.crop_x + 64, self.crop_y:self.crop_y + 64] 136 | self.ex_input_mask = self.input_mask_global.expand(self.input_mask_global.size(0), 3, self.input_mask_global.size(2), 137 | self.input_mask_global.size(3)) 138 | 139 | 140 | # inverse of the original mask 141 | 142 | self.inv_ex_input_mask = torch.add(torch.neg(self.ex_input_mask.float()), 1).float() 143 | 144 | # set the loss ground truth for the two branches 145 | self.stde_loss[0].set_target(self.Gt_DE, self.Gt_ST) 146 | 147 | # Fill the masked regions with the dataset mean colour rather than 0 148 | self.input_DE.narrow(1, 0, 1).masked_fill_(self.input_mask_global.narrow(1, 0, 1).bool(), 2 * 123.0 / 255.0 - 1.0) 149 | self.input_DE.narrow(1, 1, 1).masked_fill_(self.input_mask_global.narrow(1, 0, 1).bool(), 2 * 104.0 / 255.0 - 1.0) 150 | self.input_DE.narrow(1, 2, 1).masked_fill_(self.input_mask_global.narrow(1, 0, 1).bool(), 2 * 117.0 / 255.0 - 1.0) 151 |
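The masked_fill_ calls at the end of set_input above overwrite the hole region with a constant per-channel colour rather than zeros: the mean pixel values 123/104/117 (on a 0-255 scale) are remapped into the [-1, 1] range the networks operate in via x_norm = 2*x/255 - 1. A quick check of that arithmetic:

means_255 = [123.0, 104.0, 117.0]
means_norm = [2.0 * m / 255.0 - 1.0 for m in means_255]
print([round(m, 4) for m in means_norm])   # [-0.0353, -0.1843, -0.0824]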
152 | def forward(self): 153 | 154 | fake_input_p_1, fake_input_p_2, fake_input_p_3, fake_input_p_4, fake_input_p_5, fake_input_p_6 = self.netEN( 155 | torch.cat([self.input_DE, self.inv_ex_input_mask], 1)) 156 | De_in = [fake_input_p_1, fake_input_p_2, fake_input_p_3, fake_input_p_4, fake_input_p_5, fake_input_p_6] 157 | 158 | fake_ref_p_1, fake_ref_p_2, fake_ref_p_3, fake_ref_p_4, fake_ref_p_5, fake_ref_p_6 = self.netRefEN(self.ref_DE) 159 | 160 | Ref_in = [fake_ref_p_1, fake_ref_p_2, fake_ref_p_3, fake_ref_p_4, fake_ref_p_5, fake_ref_p_6] 161 | 162 | #De_in=[fake_p_1,fake_p_2,fake_p_3,fake_p_4,fake_p_5,fake_p_6] 163 | x_out = self.netRGTSI(De_in, Ref_in, self.input_mask_global) 164 | 165 | 166 | # x_out returns the features plus the losses: [x_1, x_2, x_3, x_4, x_5, x_6, x_ST_fi, x_DE_fi] 167 | self.fake_out = self.netDE(x_out[0], x_out[1], x_out[2], x_out[3], x_out[4], x_out[5]) 168 | 169 | def backward_D(self): 170 | fake_AB = self.fake_out 171 | real_AB = self.Gt_DE # GroundTruth 172 | real_local = self.Gt_Local 173 | fake_local = self.fake_out[:, :, self.crop_x:self.crop_x + 64, self.crop_y:self.crop_y + 64] 174 | # Global Discriminator 175 | self.pred_fake = self.netD(fake_AB.detach()) 176 | self.pred_real = self.netD(real_AB) 177 | self.loss_D_fake = self.criterionGAN(self.pred_fake, self.pred_real, True) 178 | 179 | # Local discriminator 180 | self.pred_fake_F = self.netF(fake_local.detach()) 181 | self.pred_real_F = self.netF(real_local) 182 | self.loss_F_fake = self.criterionGAN(self.pred_fake_F, self.pred_real_F, True) 183 | 184 | self.loss_D = self.loss_D_fake + self.loss_F_fake 185 | self.loss_D.backward() 186 | 187 | def backward_G(self): 188 | # First, the generator should fool the discriminator 189 | real_AB = self.Gt_DE 190 | fake_AB = self.fake_out 191 | real_local = self.Gt_Local 192 | fake_local = self.fake_out[:, :, self.crop_x:self.crop_x + 64, self.crop_y:self.crop_y + 64] 193 | # Global discriminator 194 | pred_real = self.netD(real_AB) 195 | pred_fake = self.netD(fake_AB) 196 | # Local discriminator 197 | pred_real_F = self.netF(real_local) 198 | pred_fake_f = self.netF(fake_local) 199 | self.loss_G_GAN = self.criterionGAN(pred_fake, pred_real, False) + self.criterionGAN(pred_fake_f, pred_real_F, 200 | False) 201 | # Second, the reconstruction losses 202 | self.loss_L1 = self.criterionL1(self.fake_out, self.Gt_DE) 203 | self.Perceptual_loss = self.PerceptualLoss(self.fake_out, self.Gt_DE) 204 | self.Style_Loss = self.StyleLoss(self.fake_out, self.Gt_DE) 205 | 206 | # self.loss_G = self.loss_G_L1 + self.loss_G_GAN *0.2 + self.Perceptual_loss * 0.2 + self.Style_Loss *250 207 | self.loss_G = self.loss_L1 * self.opt.lambda_L1 + self.loss_G_GAN * self.opt.lambda_Gan + \ 208 | self.Perceptual_loss * self.opt.lambda_P + self.Style_Loss * self.opt.lambda_S 209 | 210 | 211 | self.stde_loss_value = 0 212 | 213 | for loss in self.stde_loss: 214 | 215 | self.stde_loss_value += loss.backward() 216 | self.stde_loss_value += loss.loss 217 | self.loss_G += self.stde_loss_value 218 | self.loss_G.backward() 219 |
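backward_D above scores both the full image (netD) and a random 64x64 crop (netF), so the model is penalised globally and locally at once. The crop offsets crop_x/crop_y are drawn once per batch in set_input and reused here, keeping the real and fake local patches spatially aligned. A minimal sketch of that bookkeeping; the 256-pixel fineSize is an assumption inferred from randint(0, 191) + 64 = 255:

import random
import torch

fine_size, local = 256, 64
crop_x = random.randint(0, fine_size - local - 1)   # matches randint(0, 191)
crop_y = random.randint(0, fine_size - local - 1)
real = torch.randn(1, 3, fine_size, fine_size)
fake = torch.randn(1, 3, fine_size, fine_size)
real_local = real[:, :, crop_x:crop_x + local, crop_y:crop_y + local]
fake_local = fake[:, :, crop_x:crop_x + local, crop_y:crop_y + local]
print(real_local.shape, fake_local.shape)           # both (1, 3, 64, 64)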
    def optimize_parameters(self):
        self.forward()
        # Optimize D and F first
        self.set_requires_grad(self.netF, True)
        self.set_requires_grad(self.netD, True)
        self.set_requires_grad(self.netEN, False)
        self.set_requires_grad(self.netRefEN, False)
        self.set_requires_grad(self.netDE, False)
        self.set_requires_grad(self.netRGTSI, False)
        self.optimizer_D.zero_grad()
        self.optimizer_F.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        self.optimizer_F.step()

        # Then optimize EN, RefEN, DE and RGTSI
        self.set_requires_grad(self.netF, False)
        self.set_requires_grad(self.netD, False)
        self.set_requires_grad(self.netEN, True)
        self.set_requires_grad(self.netRefEN, True)
        self.set_requires_grad(self.netDE, True)
        self.set_requires_grad(self.netRGTSI, True)
        self.optimizer_EN.zero_grad()
        self.optimizer_RefEN.zero_grad()
        self.optimizer_DE.zero_grad()
        self.optimizer_RGTSI.zero_grad()
        self.backward_G()
        self.optimizer_RGTSI.step()
        self.optimizer_EN.step()
        self.optimizer_RefEN.step()
        self.optimizer_DE.step()

    def get_current_errors(self):
        # Return the current losses for logging
        # (report the L1 term itself, not the total generator loss)
        return OrderedDict([('G_GAN', self.loss_G_GAN.data),
                            ('G_L1', self.loss_L1.data),
                            ('G_stde', self.stde_loss_value.data),
                            ('D', self.loss_D_fake.data),
                            ('F', self.loss_F_fake.data)
                            ])

    # These images can also be viewed in TensorBoard
    def get_current_visuals(self):
        input_image = (self.input_DE.data.cpu() + 1) / 2.0
        ref_image = (self.ref_DE.data.cpu() + 1) / 2.0
        fake_image = (self.fake_out.data.cpu() + 1) / 2.0
        real_gt = (self.Gt_DE.data.cpu() + 1) / 2.0
        return input_image, ref_image, fake_image, real_gt
--------------------------------------------------------------------------------
/models/PCconv.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.nn.functional as F
import torch
import torch.nn as nn

from models.FAM.FeatureAlignment import FAM

import util.util as util
from util.Selfpatch import Selfpatch
from util.util import saveoffset, showpatch


# SE block (squeeze-and-excitation channel attention)
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c, 1, 1)
        y = self.fc(y)
        return x * y.expand_as(x)
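# Usage sketch (assumed shapes): SELayer squeezes each channel to a scalar via
# global average pooling, passes it through a bottleneck MLP, and rescales the
# input channel-wise, e.g.
#   se = SELayer(channel=512, reduction=16)
#   y = se(torch.randn(2, 512, 32, 32))   # same shape, channels reweighted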
class Convnorm(nn.Module):
    def __init__(self, in_ch, out_ch, sample='none-3', activ='leaky'):
        super().__init__()
        self.bn = nn.InstanceNorm2d(out_ch, affine=True)

        if sample == 'down-3':
            self.conv = nn.Conv2d(in_ch, out_ch, 3, 2, 1, bias=False)
        else:
            self.conv = nn.Conv2d(in_ch, out_ch, 3, 1)
        if activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input):
        out = self.conv(input)
        out = self.bn(out)
        if hasattr(self, 'activation'):
            # Apply the activation to the whole tensor (was `out[0]`, which
            # would have dropped the batch dimension)
            out = self.activation(out)
        return out

class PCBActiv(nn.Module):
    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='leaky',
                 conv_bias=False, innorm=False, inner=False, outer=False):
        super().__init__()
        if sample == 'same-5':
            self.conv = PartialConv(in_ch, out_ch, 5, 1, 2, bias=conv_bias)
        elif sample == 'same-7':
            self.conv = PartialConv(in_ch, out_ch, 7, 1, 3, bias=conv_bias)
        elif sample == 'down-3':
            self.conv = PartialConv(in_ch, out_ch, 3, 2, 1, bias=conv_bias)
        else:
            self.conv = PartialConv(in_ch, out_ch, 3, 1, 1, bias=conv_bias)

        if bn:
            self.bn = nn.InstanceNorm2d(out_ch, affine=True)
        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)
        self.innorm = innorm
        self.inner = inner
        self.outer = outer

    def forward(self, input):
        # `input` is a [feature, mask] pair; PartialConv returns the same structure
        out = input
        if self.inner:
            out[0] = self.bn(out[0])
            out[0] = self.activation(out[0])
            out = self.conv(out)
            out[0] = self.bn(out[0])
            out[0] = self.activation(out[0])
        elif self.innorm:
            out = self.conv(out)
            out[0] = self.bn(out[0])
            out[0] = self.activation(out[0])
        elif self.outer:
            out = self.conv(out)
            out[0] = self.bn(out[0])
        else:
            out = self.conv(out)
            out[0] = self.bn(out[0])
            if hasattr(self, 'activation'):
                out[0] = self.activation(out[0])
        return out

class ConvDown(nn.Module):
    def __init__(self, in_c, out_c, kernel, stride, padding=0, dilation=1, groups=1, bias=False, layers=1, activ=True):
        super().__init__()
        nf_mult = 1
        nums = out_c / 64
        sequence = []

        for i in range(1, layers + 1):
            nf_mult_prev = nf_mult
            if nums == 8:
                if in_c == 512:
                    nf_mult = 1  # fixed: was the typo `nfmult = 1`, which left nf_mult unset
                else:
                    nf_mult = 2
            else:
                nf_mult = min(2 ** i, 8)
            if kernel != 1:
                if not activ and layers == 1:
                    sequence += [
                        nn.Conv2d(nf_mult_prev * in_c, nf_mult * in_c,
                                  kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                        nn.InstanceNorm2d(nf_mult * in_c)
                    ]
                else:
                    sequence += [
                        nn.Conv2d(nf_mult_prev * in_c, nf_mult * in_c,
                                  kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                        nn.InstanceNorm2d(nf_mult * in_c),
                        nn.LeakyReLU(0.2, True)
                    ]
            else:
                sequence += [
                    nn.Conv2d(in_c, out_c,
                              kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                    nn.InstanceNorm2d(out_c),
                    nn.LeakyReLU(0.2, True)
                ]

            if not activ:
                if i + 1 == layers:
                    if layers == 2:
                        sequence += [
                            nn.Conv2d(nf_mult * in_c, nf_mult * in_c,
                                      kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                            nn.InstanceNorm2d(nf_mult * in_c)
                        ]
                    else:
                        sequence += [
                            nn.Conv2d(nf_mult_prev * in_c, nf_mult * in_c,
                                      kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                            nn.InstanceNorm2d(nf_mult * in_c)
                        ]
                    break

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)

class ConvUp(nn.Module):
    def __init__(self, in_c, out_c, kernel, stride, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, kernel,
                              stride, padding, dilation, groups, bias)
        self.bn = nn.InstanceNorm2d(out_c)
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, size):
        out = F.interpolate(input=input, size=size, mode='bilinear')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        return out
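# Usage sketch (assumed shapes, CUDA device required since PartialConv moves
# masks to the GPU): every PCBActiv call threads a [feature, mask] pair
# through a partial convolution, e.g.
#   block = PCBActiv(256, 256, innorm=True).cuda()
#   feat, mask = block([torch.randn(1, 256, 32, 32).cuda(),
#                       torch.ones(1, 256, 32, 32).cuda()])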
class BASE(nn.Module):
    def __init__(self, inner_nc):
        super(BASE, self).__init__()
        se = SELayer(inner_nc, 16)
        model = [se]
        gus = util.gussin(1.5).cuda()
        self.gus = torch.unsqueeze(gus, 1).double()
        self.model = nn.Sequential(*model)
        self.down = nn.Sequential(
            nn.Conv2d(1024, 512, 1, 1, 0, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )

    def forward(self, x):
        Nonparm = Selfpatch()
        out_32 = self.model(x)
        b, c, h, w = out_32.size()
        gus = self.gus.float()
        gus_out = out_32[0].expand(h * w, c, h, w)
        gus_out = gus * gus_out
        gus_out = torch.sum(gus_out, -1)
        gus_out = torch.sum(gus_out, -1)
        gus_out = gus_out.contiguous().view(b, c, h, w)
        csa2_in = torch.sigmoid(out_32)
        csa2_f = F.pad(csa2_in, (1, 1, 1, 1))
        csa2_ff = F.pad(out_32, (1, 1, 1, 1))
        csa2_fff, csa2_f, csa2_conv = Nonparm.buildAutoencoder(csa2_f[0], csa2_in[0], csa2_ff[0], 3, 1)
        csa2_conv = csa2_conv.expand_as(csa2_f)
        csa_a = csa2_conv * csa2_f
        csa_a = torch.mean(csa_a, 1)
        a_c, a_h, a_w = csa_a.size()
        csa_a = csa_a.contiguous().view(a_c, -1)
        csa_a = F.softmax(csa_a, dim=1)
        # reshape with (a_h, a_w); the original used a_h twice, which only
        # worked because the feature maps are square
        csa_a = csa_a.contiguous().view(a_c, 1, a_h, a_w)
        out = csa_a * csa2_fff
        out = torch.sum(out, -1)
        out = torch.sum(out, -1)
        out_csa = out.contiguous().view(b, c, h, w)
        out_32 = torch.cat([gus_out, out_csa], 1)
        out_32 = self.down(out_32)
        return out_32

class PartialConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)

        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # The mask convolution is fixed: it only counts valid pixels
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, inputt):
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T * (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        input = inputt[0]
        mask = inputt[1].float().cuda()

        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(output)
        else:
            output_bias = torch.zeros_like(output)

        with torch.no_grad():
            output_mask = self.mask_conv(mask)

        no_update_holes = output_mask == 0
        mask_sum = output_mask.masked_fill_(no_update_holes.bool(), 1.0)
        output_pre = (output - output_bias) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes.bool(), 0.0)
        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes.bool(), 0.0)
        return [output, new_mask]
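# Usage sketch (CUDA assumed, since the mask is moved to the GPU above): a
# partial convolution renormalizes each output by the number of valid input
# pixels under the kernel and zeroes positions whose window saw no valid pixel:
#   pconv = PartialConv(256, 256, 3, 1, 1).cuda()
#   feat, new_mask = pconv([features, valid_mask])   # both (B, 256, H, W)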
class PCconv(nn.Module):
    def __init__(self):
        super(PCconv, self).__init__()
        self.down_128 = ConvDown(64, 128, 4, 2, padding=1, layers=2)
        self.down_64 = ConvDown(128, 256, 4, 2, padding=1)
        self.down_32 = ConvDown(256, 256, 1, 1)
        self.down_16 = ConvDown(512, 512, 4, 2, padding=1, activ=False)
        self.down_8 = ConvDown(512, 512, 4, 2, padding=1, layers=2, activ=False)
        self.down_4 = ConvDown(512, 512, 4, 2, padding=1, layers=3, activ=False)
        self.down = ConvDown(768, 256, 1, 1)
        self.fuse = ConvDown(512, 512, 1, 1)
        self.up = ConvUp(512, 256, 1, 1)
        self.up_128 = ConvUp(512, 64, 1, 1)
        self.up_64 = ConvUp(512, 128, 1, 1)
        self.up_32 = ConvUp(512, 256, 1, 1)
        self.base = BASE(512)
        sequence_3 = []
        sequence_5 = []
        sequence_7 = []
        for i in range(5):
            sequence_3 += [PCBActiv(256, 256, innorm=True)]
            sequence_5 += [PCBActiv(256, 256, sample='same-5', innorm=True)]
            sequence_7 += [PCBActiv(256, 256, sample='same-7', innorm=True)]

        self.cov_3 = nn.Sequential(*sequence_3)
        self.cov_5 = nn.Sequential(*sequence_5)
        self.cov_7 = nn.Sequential(*sequence_7)
        self.activation = nn.LeakyReLU(negative_slope=0.2)

        self.TextureAlignment = FAM(in_channels=768)
        self.StructureAlignment = FAM(in_channels=768)

    def forward(self, input, reference, mask):
        mask = util.cal_feat_mask(mask, 3, 1)

        # input[2]: 256 channels at 32x32
        b, c, h, w = input[2].size()
        mask_1 = torch.add(torch.neg(mask.float()), 1)
        mask_1 = mask_1.expand(b, c, h, w)

        x_1 = self.activation(input[0])
        x_2 = self.activation(input[1])
        x_3 = self.activation(input[2])
        x_4 = self.activation(input[3])
        x_5 = self.activation(input[4])
        x_6 = self.activation(input[5])

        y_1 = self.activation(reference[0])
        y_2 = self.activation(reference[1])
        y_3 = self.activation(reference[2])
        y_4 = self.activation(reference[3])
        y_5 = self.activation(reference[4])
        y_6 = self.activation(reference[5])
        # Resize each level so low-level and high-level features can be integrated
        x_1 = self.down_128(x_1)
        x_2 = self.down_64(x_2)
        x_3 = self.down_32(x_3)
        x_4 = self.up(x_4, (32, 32))
        x_5 = self.up(x_5, (32, 32))
        x_6 = self.up(x_6, (32, 32))

        y_1 = self.down_128(y_1)
        y_2 = self.down_64(y_2)
        y_3 = self.down_32(y_3)
        y_4 = self.up(y_4, (32, 32))
        y_5 = self.up(y_5, (32, 32))
        y_6 = self.up(y_6, (32, 32))
        # The first three levels carry texture (detail);
        # the last three levels carry structure
        x_INDE = torch.cat([x_1, x_2, x_3], 1)
        x_INST = torch.cat([x_4, x_5, x_6], 1)

        y_RFDE = torch.cat([y_1, y_2, y_3], 1)
        y_RFST = torch.cat([y_4, y_5, y_6], 1)

        # Align the reference features to the input features and merge them
        x_DE = self.TextureAlignment(x_INDE, y_RFDE)
        x_ST = self.StructureAlignment(x_INST, y_RFST)

        x_ST = self.down(x_ST)
        x_DE = self.down(x_DE)

        x_ST = [x_ST, mask_1]
        x_DE = [x_DE, mask_1]

        # Multi-scale partial convolutions fill in the details
        x_DE_3 = self.cov_3(x_DE)
        x_DE_5 = self.cov_5(x_DE)
        x_DE_7 = self.cov_7(x_DE)
        x_DE_fuse = torch.cat([x_DE_3[0], x_DE_5[0], x_DE_7[0]], 1)
        x_DE_fi = self.down(x_DE_fuse)

        # Multi-scale partial convolutions fill in the structure
        x_ST_3 = self.cov_3(x_ST)
        x_ST_5 = self.cov_5(x_ST)
        x_ST_7 = self.cov_7(x_ST)
        x_ST_fuse = torch.cat([x_ST_3[0], x_ST_5[0], x_ST_7[0]], 1)
        x_ST_fi = self.down(x_ST_fuse)

        x_cat = torch.cat([x_ST_fi, x_DE_fi], 1)
        x_cat_fuse = self.fuse(x_cat)

        # Feature equalization
        x_final = self.base(x_cat_fuse)

        # Add the equalized features back onto the input pyramid
        x_ST = x_final
        x_DE = x_final
        x_1 = self.up_128(x_DE, (128, 128)) + input[0]
        x_2 = self.up_64(x_DE, (64, 64)) + input[1]
        x_3 = self.up_32(x_DE, (32, 32)) + input[2]
        x_4 = self.down_16(x_ST) + input[3]
        x_5 = self.down_8(x_ST) + input[4]
        x_6 = self.down_4(x_ST) + input[5]

        out_final = [x_1, x_2, x_3, x_4, x_5, x_6, x_ST_fi, x_DE_fi]
        return out_final
--------------------------------------------------------------------------------
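A minimal sketch of the PCconv forward contract (hypothetical shapes, inferred
from the down_*/up_* modules for 256x256 inputs; the mask argument mirrors
input_mask_global above, and a CUDA device is required because PartialConv and
BASE move tensors to the GPU):

    import torch
    from models.PCconv import PCconv

    sizes = [(64, 128, 128), (128, 64, 64), (256, 32, 32),
             (512, 16, 16), (512, 8, 8), (512, 4, 4)]
    feats = [torch.randn(1, c, h, w).cuda() for c, h, w in sizes]   # encoder pyramid
    refs = [torch.randn(1, c, h, w).cuda() for c, h, w in sizes]    # reference pyramid
    mask = torch.zeros(1, 1, 256, 256).cuda().byte()

    pc = PCconv().cuda()
    out = pc(feats, refs, mask)   # eight tensors: six fused levels + x_ST_fi, x_DE_fi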