├── LICENSE
├── README.md
├── data
│   └── dataloader.py
├── evaluatuion.py
├── example
│   ├── all_images
│   │   └── 2_black.png
│   ├── all_labels
│   │   └── 2_black.png
│   └── mask
│       └── 2_black.png
├── gauss.py
├── image
│   ├── 1.png
│   ├── 11.png
│   ├── 2.png
│   ├── 22.png
│   ├── 3.png
│   ├── 33.png
│   ├── 4.png
│   └── 44.png
├── loss
│   └── Loss.py
├── models
│   ├── Model.py
│   ├── discriminator.py
│   ├── networks.py
│   └── sa_gan.py
├── test.py
├── train.py
└── yz_gen.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Chongyu-Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# remove-stamp

This repository implements a GAN-based neural network for end-to-end stamp removal. The same approach can also be used to erase other kinds of noise, such as watermarks or handwriting.


## Environment

```
python = 3.7
pytorch = 1.3.1
torchvision = 0.4.2
```

## Training

```
python train.py --batchSize 2 \
                --dataRoot 'your path' \
                --modelsSavePath 'your path' \
                --logPath 'your path'
```

## Testing


```
python test.py --dataRoot 'your path' \
               --batchSize 1 \
               --pretrained 'your path' \
               --savePath 'your path'
```
## Model

- Pretrained model (Baidu Netdisk): https://pan.baidu.com/s/1WnfQfvsxRyhFx0ODdWztxg (extraction code: gsfe)

## Dataset

- Data generation and augmentation scripts are included (see yz_gen.py); extend them with your own data.
- Dataset (Baidu Netdisk): https://pan.baidu.com/s/1CTzcDNHuT0OI4o4rrC1aWA (extraction code: bay3)

## Case
![0](https://github.com/tommyMessi/remove-stamp/blob/main/image/1.png)
![00](https://github.com/tommyMessi/remove-stamp/blob/main/image/11.png)
![0](https://github.com/tommyMessi/remove-stamp/blob/main/image/2.png)
![00](https://github.com/tommyMessi/remove-stamp/blob/main/image/22.png)
![0](https://github.com/tommyMessi/remove-stamp/blob/main/image/3.png)
![00](https://github.com/tommyMessi/remove-stamp/blob/main/image/33.png)
![0](https://github.com/tommyMessi/remove-stamp/blob/main/image/4.png)
![00](https://github.com/tommyMessi/remove-stamp/blob/main/image/44.png)

## Other

- For more OCR and document-parsing content, follow the WeChat official account hulugeAI.
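
## Quick inference sketch

For a quick single-image check without the dataloader, something like the following works. This is a minimal sketch distilled from test.py, not part of the original repo; the checkpoint name `STE_500.pth` is a placeholder for whichever .pth you trained or downloaded:

```python
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.utils import save_image
from models.sa_gan import STRnet2

net = STRnet2(3)
net.load_state_dict(torch.load('STE_500.pth', map_location='cpu'))  # placeholder path
net.eval()

img = Image.open('example/all_images/2_black.png').convert('RGB')
x = Compose([Resize((512, 512), interpolation=Image.BICUBIC), ToTensor()])(img).unsqueeze(0)
with torch.no_grad():
    x_o1, x_o2, x_o_unet, out, mask_pred = net(x)  # the same 5-tuple test.py unpacks
save_image(out, 'result.png')
```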
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import cv2
from os import listdir, walk
from os.path import join
from random import randint
import random
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, RandomHorizontalFlip


def random_horizontal_flip(imgs):
    if random.random() < 0.3:
        for i in range(len(imgs)):
            imgs[i] = imgs[i].transpose(Image.FLIP_LEFT_RIGHT)
    return imgs

def random_rotate(imgs):
    if random.random() < 0.3:
        max_angle = 10
        angle = random.random() * 2 * max_angle - max_angle
        for i in range(len(imgs)):
            img = np.array(imgs[i])
            h, w = img.shape[:2]
            rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
            img_rotation = cv2.warpAffine(img, rotation_matrix, (w, h))
            imgs[i] = Image.fromarray(img_rotation)
    return imgs

def CheckImageFile(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP'])

def ImageTransform(loadSize):
    return Compose([
        Resize(size=loadSize, interpolation=Image.BICUBIC),
        ToTensor(),
    ])

class ErasingData(Dataset):
    def __init__(self, dataRoot, loadSize, training=True):
        super(ErasingData, self).__init__()
        self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot)
                           for files in filenames if CheckImageFile(files)]
        self.loadSize = loadSize
        self.ImgTrans = ImageTransform(loadSize)
        self.training = training

    def __getitem__(self, index):
        img = Image.open(self.imageFiles[index])
        mask = Image.open(self.imageFiles[index].replace('all_images', 'mask'))
        gt = Image.open(self.imageFiles[index].replace('all_images', 'all_labels'))
        if self.training:
            ### for data augmentation
            all_input = [img, mask, gt]
            all_input = random_horizontal_flip(all_input)
            all_input = random_rotate(all_input)
            img = all_input[0]
            mask = all_input[1]
            gt = all_input[2]
            ### for data augmentation
        inputImage = self.ImgTrans(img.convert('RGB'))
        mask = self.ImgTrans(mask.convert('RGB'))
        groundTruth = self.ImgTrans(gt.convert('RGB'))
        path = self.imageFiles[index].split('/')[-1]

        return inputImage, groundTruth, mask, path

    def __len__(self):
        return len(self.imageFiles)

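# Expected layout under dataRoot (sibling folders, same filenames; the paths
# are derived by the string replacements in ErasingData.__getitem__ above):
#   all_images/xxx.png  - input image with the stamp
#   mask/xxx.png        - stamp mask
#   all_labels/xxx.png  - clean ground-truth image
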
class devdata(Dataset):
    def __init__(self, dataRoot, gtRoot, loadSize=512):
        super(devdata, self).__init__()
        self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot)
                           for files in filenames if CheckImageFile(files)]
        self.gtFiles = [join(gtRootK, files) for gtRootK, dn, filenames in walk(gtRoot)
                        for files in filenames if CheckImageFile(files)]
        self.loadSize = loadSize
        self.ImgTrans = ImageTransform(loadSize)

    def __getitem__(self, index):
        img = Image.open(self.imageFiles[index])
        gt = Image.open(self.gtFiles[index])
        inputImage = self.ImgTrans(img.convert('RGB'))
        groundTruth = self.ImgTrans(gt.convert('RGB'))
        path = self.imageFiles[index].split('/')[-1]

        return inputImage, groundTruth, path

    def __len__(self):
        return len(self.imageFiles)
--------------------------------------------------------------------------------
/evaluatuion.py:
--------------------------------------------------------------------------------
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor  # used by ImageTransform below
from data.dataloader import devdata
from scipy import signal, ndimage
import gauss


parser = argparse.ArgumentParser()
parser.add_argument('--target_path', type=str, default='',
                    help='results')
parser.add_argument('--gt_path', type=str, default='',
                    help='labels')
args = parser.parse_args()

sum_psnr = 0
sum_ssim = 0
sum_AGE = 0
sum_pCEPS = 0
sum_pEPS = 0
sum_mse = 0

count = 0
sum_time = 0.0
l1_loss = 0

img_path = args.target_path
gt_path = args.gt_path


def ssim(img1, img2, cs_map=False):
    """Return the Structural Similarity Map corresponding to input images img1
    and img2 (images are assumed to be uint8)

    This function attempts to precisely mimic the functionality of ssim.m, the
    MATLAB implementation provided by the authors of SSIM
    https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m
    """
    img1 = img1.astype(float)
    img2 = img2.astype(float)

    size = min(img1.shape[0], 11)
    sigma = 1.5
    window = gauss.fspecial_gauss(size, sigma)
    K1 = 0.01
    K2 = 0.03
    L = 255  # dynamic range of 8-bit images
    C1 = (K1 * L) ** 2
    C2 = (K2 * L) ** 2
    mu1 = signal.fftconvolve(img1, window, mode='valid')
    mu2 = signal.fftconvolve(img2, window, mode='valid')
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = signal.fftconvolve(img1 * img1, window, mode='valid') - mu1_sq
    sigma2_sq = signal.fftconvolve(img2 * img2, window, mode='valid') - mu2_sq
    sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') - mu1_mu2
    if cs_map:
        return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)),
                (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2))
    else:
        return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))

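# With K1 = 0.01, K2 = 0.03 and L = 255, the stabilizing constants above come
# out to C1 = 6.5025 and C2 = 58.5225, the standard values from Wang et al. (2004).
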

def msssim(img1, img2):
    """This function implements Multi-Scale Structural Similarity (MSSSIM) Image
    Quality Assessment according to Z. Wang's "Multi-scale structural similarity
    for image quality assessment" Invited Paper, IEEE Asilomar Conference on
    Signals, Systems and Computers, Nov. 2003

    Author's MATLAB implementation:
    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
    """
    level = 5
    weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    downsample_filter = np.ones((2, 2)) / 4.0
    mssim = np.array([])
    mcs = np.array([])
    for l in range(level):
        ssim_map, cs_map = ssim(img1, img2, cs_map=True)
        mssim = np.append(mssim, ssim_map.mean())
        mcs = np.append(mcs, cs_map.mean())
        filtered_im1 = ndimage.filters.convolve(img1, downsample_filter,
                                                mode='reflect')
        filtered_im2 = ndimage.filters.convolve(img2, downsample_filter,
                                                mode='reflect')
        # low-pass filter and downsample so the next level runs at half resolution
        img1 = filtered_im1[::2, ::2]
        img2 = filtered_im2[::2, ::2]

    # Note: Remove the negative and add it later to avoid NaN in exponential.
    sign_mcs = np.sign(mcs[0:level - 1])
    sign_mssim = np.sign(mssim[level - 1])
    mcs_power = np.power(np.abs(mcs[0:level - 1]), weight[0:level - 1])
    mssim_power = np.power(np.abs(mssim[level - 1]), weight[level - 1])
    return np.prod(sign_mcs * mcs_power) * sign_mssim * mssim_power

def ImageTransform(loadSize, cropSize):
    return Compose([
        Resize(size=loadSize, interpolation=Image.BICUBIC),
        # RandomCrop(size=cropSize),
        # RandomHorizontalFlip(p=0.5),
        ToTensor(),
    ])

def visual(image):
    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()

imgData = devdata(dataRoot=img_path, gtRoot=gt_path)
data_loader = DataLoader(imgData, batch_size=1, shuffle=True, num_workers=0, drop_last=False)

for k, (img, lbl, path) in enumerate(data_loader):
    mse = ((lbl - img)**2).mean()
    sum_mse += mse
    print(path, count, 'mse: ', mse)
    if mse == 0:
        continue
    count += 1
    psnr = 10 * math.log10(1 / mse)
    sum_psnr += psnr
    print(path, count, ' psnr: ', psnr)
    # l1_loss += nn.L1Loss()(img, lbl)


    R = lbl[0,0,:, :]
    G = lbl[0,1,:, :]
    B = lbl[0,2,:, :]

    YGT = .299 * R + .587 * G + .114 * B

    R = img[0,0,:, :]
    G = img[0,1,:, :]
    B = img[0,2,:, :]

    YBC = .299 * R + .587 * G + .114 * B
    Diff = abs(np.array(YBC*255) - np.array(YGT*255)).round().astype(np.uint8)
    AGE = np.mean(Diff)
    print(' AGE: ', AGE)
    mssim = msssim(np.array(YGT*255), np.array(YBC*255))
    sum_ssim += mssim
    print(count, ' ssim:', mssim)
    threshold = 20

    Errors = Diff > threshold
    EPs = sum(sum(Errors)).astype(float)
    pEPs = EPs / float(512*512)
    print(' pEPS: ', pEPs)
    sum_pEPS += pEPs
    ########################## CEPs and pCEPs ################################
    structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
    sum_AGE += AGE
    erodedErrors = ndimage.binary_erosion(Errors, structure).astype(Errors.dtype)
    CEPs = sum(sum(erodedErrors))
    pCEPs = CEPs / float(512*512)
    print(' pCEPS: ', pCEPs)
    sum_pCEPS += pCEPs

print('sum psnr:', sum_psnr)
print('avg mse:', sum_mse / count)
print('average psnr:', sum_psnr / count)
print('average ssim:', sum_ssim / count)
print('average AGE:', sum_AGE / count)
print('average pEPS:', sum_pEPS / count)
print('average pCEPS:', sum_pCEPS / count)

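# Note: mse and psnr above are computed on ToTensor outputs scaled to [0, 1],
# so PSNR = 10 * log10(MAX**2 / MSE) with MAX = 1; pEPs/pCEPs assume 512x512 inputs.
# Usage sketch: python evaluatuion.py --target_path <results dir> --gt_path <labels dir>
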
--------------------------------------------------------------------------------
/example/all_images/2_black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/example/all_images/2_black.png
--------------------------------------------------------------------------------
/example/all_labels/2_black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/example/all_labels/2_black.png
--------------------------------------------------------------------------------
/example/mask/2_black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/example/mask/2_black.png
--------------------------------------------------------------------------------
/gauss.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Module providing functionality surrounding the Gaussian function.
"""
SVN_REVISION = '$LastChangedRevision: 16541 $'

import sys
import numpy

def gaussian2(size, sigma):
    """Returns a normalized circularly symmetric 2D gauss kernel array

    f(x,y) = A * e^{-(x^2 + y^2) / (2*sigma^2)}  where

    A = 1 / (2*pi*sigma^2)

    as defined by Wolfram MathWorld
    http://mathworld.wolfram.com/GaussianFunction.html
    """
    A = 1 / (2.0 * numpy.pi * sigma**2)
    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
    g = A * numpy.exp(-((x**2 / (2.0*sigma**2)) + (y**2 / (2.0*sigma**2))))
    return g

def fspecial_gauss(size, sigma):
    """Function to mimic the 'fspecial' gaussian MATLAB function
    """
    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]
    g = numpy.exp(-((x**2 + y**2) / (2.0*sigma**2)))
    return g / g.sum()

def main():
    """Show simple use cases for functionality provided by this module."""
    from mpl_toolkits.mplot3d.axes3d import Axes3D
    import pylab
    argv = sys.argv
    if len(argv) != 3:
        print('usage: python -m pim.sp.gauss size sigma', file=sys.stderr)
        sys.exit(2)
    size = int(argv[1])
    sigma = float(argv[2])
    x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]

    fig = pylab.figure()
    fig.suptitle('Some 2-D Gauss Functions')
    ax = fig.add_subplot(2, 1, 1, projection='3d')
    ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1,
                    linewidth=0, antialiased=False, cmap=pylab.jet())
    ax = fig.add_subplot(2, 1, 2, projection='3d')
    ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1,
                    linewidth=0, antialiased=False, cmap=pylab.jet())
    pylab.show()
    return 0

if __name__ == '__main__':
    sys.exit(main())
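
# Note: fspecial_gauss normalizes the kernel to sum to 1 (matching MATLAB's
# fspecial), while gaussian2 keeps the analytic peak A = 1/(2*pi*sigma^2).
# ssim() in evaluatuion.py uses the normalized fspecial_gauss as its local
# averaging window.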
--------------------------------------------------------------------------------
/image/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/1.png
--------------------------------------------------------------------------------
/image/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/11.png
--------------------------------------------------------------------------------
/image/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/2.png
--------------------------------------------------------------------------------
/image/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/22.png
--------------------------------------------------------------------------------
/image/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/3.png
--------------------------------------------------------------------------------
/image/33.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/33.png
--------------------------------------------------------------------------------
/image/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/4.png
--------------------------------------------------------------------------------
/image/44.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tommyMessi/remove-stamp/2880f96359d9c90e89457e482d6aff397814535c/image/44.png
--------------------------------------------------------------------------------
/loss/Loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from models.discriminator import Discriminator_STE
from PIL import Image
import numpy as np

def gram_matrix(feat):
    # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
    (b, ch, h, w) = feat.size()
    feat = feat.view(b, ch, h * w)
    feat_t = feat.transpose(1, 2)
    gram = torch.bmm(feat, feat_t) / (ch * h * w)
    return gram

def visual(image):
    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()

def dice_loss(input, target):
    input = torch.sigmoid(input)

    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1)

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    dice_loss = torch.mean(d)
    return 1 - dice_loss

class LossWithGAN_STE(nn.Module):
    def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)):
        super(LossWithGAN_STE, self).__init__()
        self.l1 = nn.L1Loss()
        self.extractor = extractor
        self.discriminator = Discriminator_STE(3)  ## local-global SN patch GAN
        self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit)
        self.cudaAvailable = torch.cuda.is_available()
        self.numOfGPUs = torch.cuda.device_count()
        self.lamda = Lamda
        self.writer = SummaryWriter(logPath)

    def forward(self, input, mask, x_o1, x_o2, x_o3, output, mm, gt, count, epoch):
        self.discriminator.zero_grad()
        D_real = self.discriminator(gt, mask)
        D_real = D_real.mean().sum() * -1
        D_fake = self.discriminator(output, mask)
        D_fake = D_fake.mean().sum() * 1
        D_loss = torch.mean(F.relu(1. + D_real)) + torch.mean(F.relu(1. + D_fake))  # SN-patch-GAN loss
        D_fake = -torch.mean(D_fake)  # SN-Patch-GAN loss

        self.D_optimizer.zero_grad()
        D_loss.backward(retain_graph=True)
        self.D_optimizer.step()

        self.writer.add_scalar('LossD/Discriminator loss', D_loss.item(), count)

        output_comp = mask * input + (1 - mask) * output
        holeLoss = 10 * self.l1((1 - mask) * output, (1 - mask) * gt)
        validAreaLoss = 2 * self.l1(mask * output, mask * gt)

        mask_loss = dice_loss(mm, 1 - mask)
        ### MSR loss ###
        masks_a = F.interpolate(mask, scale_factor=0.25)
        masks_b = F.interpolate(mask, scale_factor=0.5)
        imgs1 = F.interpolate(gt, scale_factor=0.25)
        imgs2 = F.interpolate(gt, scale_factor=0.5)
        msrloss = 8 * self.l1((1-mask)*x_o3, (1-mask)*gt) + 0.8*self.l1(mask*x_o3, mask*gt)+\
            6 * self.l1((1-masks_b)*x_o2, (1-masks_b)*imgs2) + 1*self.l1(masks_b*x_o2, masks_b*imgs2)+\
            5 * self.l1((1-masks_a)*x_o1, (1-masks_a)*imgs1) + 0.8*self.l1(masks_a*x_o1, masks_a*imgs1)

        feat_output_comp = self.extractor(output_comp)
        feat_output = self.extractor(output)
        feat_gt = self.extractor(gt)

        prcLoss = 0.0
        for i in range(3):
            prcLoss += 0.01 * self.l1(feat_output[i], feat_gt[i])
            prcLoss += 0.01 * self.l1(feat_output_comp[i], feat_gt[i])

        styleLoss = 0.0
        for i in range(3):
            styleLoss += 120 * self.l1(gram_matrix(feat_output[i]),
                                       gram_matrix(feat_gt[i]))
            styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]),
                                       gram_matrix(feat_gt[i]))
        """ if self.numOfGPUs > 1:
            holeLoss = holeLoss.sum() / self.numOfGPUs
            validAreaLoss = validAreaLoss.sum() / self.numOfGPUs
            prcLoss = prcLoss.sum() / self.numOfGPUs
            styleLoss = styleLoss.sum() / self.numOfGPUs """
        self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count)
        self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count)
        self.writer.add_scalar('LossG/msr loss', msrloss.item(), count)
        self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count)
        self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count)

        GLoss = msrloss + holeLoss + validAreaLoss + prcLoss + styleLoss + 0.1 * D_fake + 1 * mask_loss
        self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count)
        return GLoss.sum()

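# Overall generator objective, as composed above:
#   G = msr + 10*hole_L1 + 2*valid_L1 + 0.01*perceptual + 120*style
#       + 0.1*adversarial + 1*dice(mask branch)
# The 10/2 weighting treats the (1 - mask) stamp region as the hole to inpaint.
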
--------------------------------------------------------------------------------
/models/Model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models

# VGG16 feature extractor
class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super(VGG16FeatureExtractor, self).__init__()
        vgg16 = models.vgg16(pretrained=True)
        # vgg16.load_state_dict(torch.load('./vgg16-397923af.pth'))
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from .networks import ConvWithActivation, get_pad

## discriminator
class Discriminator_STE(nn.Module):
    def __init__(self, inputChannels):
        super(Discriminator_STE, self).__init__()
        cnum = 32
        self.globalDis = nn.Sequential(
            ConvWithActivation(inputChannels, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
            ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
            ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)),
        )

        self.localDis = nn.Sequential(
            ConvWithActivation(inputChannels, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
            ConvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
            ConvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)),
            ConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)),
        )

        self.fusion = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=4),
            nn.Sigmoid()
        )

    def forward(self, input, masks):
        global_feat = self.globalDis(input)
        local_feat = self.localDis(input * (1 - masks))

        concat_feat = torch.cat((global_feat, local_feat), 1)

        return self.fusion(concat_feat).view(input.size()[0], -1)
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

def get_pad(in_, ksize, stride, atrous=1):
    out_ = np.ceil(float(in_) / stride)
    return int(((out_ - 1) * stride + atrous * (ksize - 1) + 1 - in_) / 2)

class ConvWithActivation(torch.nn.Module):
    """
    Convolution with spectral normalization (SN conv)
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(ConvWithActivation, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.activation = activation
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, input):
        x = self.conv2d(input)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x
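
# get_pad example: for the 256x256 discriminator input with ksize=5, stride=2:
#   out = ceil(256 / 2) = 128
#   pad = int(((128 - 1) * 2 + (5 - 1) + 1 - 256) / 2) = int(1.5) = 1
# i.e. get_pad(256, 5, 2) == 1, giving roughly "same" padding for strided convs.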

class DeConvWithActivation(torch.nn.Module):
    """
    Transposed convolution with spectral normalization (SN deconv)
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(DeConvWithActivation, self).__init__()
        # note: passed positionally, `dilation` here fills ConvTranspose2d's
        # output_padding slot (value 1), which is what makes the stride-2 deconvs
        # in this repo exactly double the spatial size; kept as in the original
        self.conv2d = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.activation = activation
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, input):
        x = self.conv2d(input)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x
--------------------------------------------------------------------------------
/models/sa_gan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torch.autograd import Variable
from .networks import get_pad, ConvWithActivation, DeConvWithActivation

def img2photo(imgs):
    return ((imgs+1)*127.5).transpose(1,2).transpose(2,3).detach().cpu().numpy()

def visual(imgs):
    im = img2photo(imgs)
    Image.fromarray(im[0].astype(np.uint8)).show()

class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, same_shape=True, **kwargs):
        super(Residual,self).__init__()
        self.same_shape = same_shape
        strides = 1 if same_shape else 2
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # self.conv2 = torch.nn.utils.spectral_norm(self.conv2)
        if not same_shape:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=strides)
            # self.conv3 = torch.nn.utils.spectral_norm(self.conv3)
        self.batch_norm2d = nn.BatchNorm2d(out_channels)

    def forward(self,x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        if not self.same_shape:
            x = self.conv3(x)
        out = self.batch_norm2d(out + x)
        # out = out + x
        return F.relu(out)

class ASPP(nn.Module):
    def __init__(self, in_channel=512, depth=256):
        super(ASPP,self).__init__()
        self.mean = nn.AdaptiveAvgPool2d((1, 1))
        self.conv = nn.Conv2d(in_channel, depth, 1, 1)
        # k=1 s=1 no pad
        self.atrous_block1 = nn.Conv2d(in_channel, depth, 1, 1)
        self.atrous_block6 = nn.Conv2d(in_channel, depth, 3, 1, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channel, depth, 3, 1, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channel, depth, 3, 1, padding=18, dilation=18)

        self.conv_1x1_output = nn.Conv2d(depth * 5, depth, 1, 1)

    def forward(self, x):
        size = x.shape[2:]

        image_features = self.mean(x)
        image_features = self.conv(image_features)
        image_features = F.interpolate(image_features, size=size, mode='bilinear')

        atrous_block1 = self.atrous_block1(x)

        atrous_block6 = self.atrous_block6(x)

        atrous_block12 = self.atrous_block12(x)

        atrous_block18 = self.atrous_block18(x)

        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
                                              atrous_block12, atrous_block18], dim=1))
        return net

class STRnet2(nn.Module):
    def __init__(self, n_in_channel=3):
        super(STRnet2, self).__init__()
        #### U-Net ####
        # downsample
        self.conv1 = ConvWithActivation(3, 32, kernel_size=4, stride=2, padding=1)
        self.conva = ConvWithActivation(32, 32, kernel_size=3, stride=1, padding=1)
        self.convb = ConvWithActivation(32, 64, kernel_size=4, stride=2, padding=1)
        self.res1 = Residual(64, 64)
        self.res2 = Residual(64, 64)
        self.res3 = Residual(64, 128, same_shape=False)
        self.res4 = Residual(128, 128)
        self.res5 = Residual(128, 256, same_shape=False)
        # self.nn = ConvWithActivation(256, 512, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2))
        self.res6 = Residual(256, 256)
        self.res7 = Residual(256, 512, same_shape=False)
        self.res8 = Residual(512, 512)
        self.conv2 = ConvWithActivation(512, 512, kernel_size=1)

        # upsample
        self.deconv1 = DeConvWithActivation(512, 256, kernel_size=3, padding=1, stride=2)
        self.deconv2 = DeConvWithActivation(256*2, 128, kernel_size=3, padding=1, stride=2)
        self.deconv3 = DeConvWithActivation(128*2, 64, kernel_size=3, padding=1, stride=2)
        self.deconv4 = DeConvWithActivation(64*2, 32, kernel_size=3, padding=1, stride=2)
        self.deconv5 = DeConvWithActivation(64, 3, kernel_size=3, padding=1, stride=2)

        # lateral connections
        self.lateral_connection1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(512, 256, kernel_size=1, padding=0, stride=1),)
        self.lateral_connection2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(256, 128, kernel_size=1, padding=0, stride=1),)
        self.lateral_connection3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(128, 64, kernel_size=1, padding=0, stride=1),)
        self.lateral_connection4 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(64, 32, kernel_size=1, padding=0, stride=1),)

        # self.relu = nn.ELU(alpha=1.0)
        self.conv_o1 = nn.Conv2d(64, 3, kernel_size=1)
        self.conv_o2 = nn.Conv2d(32, 3, kernel_size=1)
        ##### U-Net #####

        ### ASPP ###
        # self.aspp = ASPP(512, 256)
        ### ASPP ###

        ### mask branch decoder ###
        self.mask_deconv_a = DeConvWithActivation(512, 256, kernel_size=3, padding=1, stride=2)
        self.mask_conv_a = ConvWithActivation(256, 128, kernel_size=3, padding=1, stride=1)
        self.mask_deconv_b = DeConvWithActivation(256, 128, kernel_size=3, padding=1, stride=2)
        self.mask_conv_b = ConvWithActivation(128, 64, kernel_size=3, padding=1, stride=1)
        self.mask_deconv_c = DeConvWithActivation(128, 64, kernel_size=3, padding=1, stride=2)
        self.mask_conv_c = ConvWithActivation(64, 32, kernel_size=3, padding=1, stride=1)
        self.mask_deconv_d = DeConvWithActivation(64, 32, kernel_size=3, padding=1, stride=2)
        self.mask_conv_d = nn.Conv2d(32, 3, kernel_size=1)
        ### mask branch ###

        ##### Refine sub-network ######
        n_in_channel = 3
        cnum = 32
        #### downsample
        self.coarse_conva = ConvWithActivation(n_in_channel, cnum, kernel_size=5, stride=1, padding=2)
        self.coarse_convb = ConvWithActivation(cnum, 2*cnum, kernel_size=4, stride=2, padding=1)
        self.coarse_convc = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_convd = ConvWithActivation(2*cnum, 4*cnum, kernel_size=4, stride=2, padding=1)
        self.coarse_conve = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_convf = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        ### atrous (dilated) convolutions
        self.astrous_net = nn.Sequential(
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)),
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=4, padding=get_pad(64, 3, 1, 4)),
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)),
            ConvWithActivation(4*cnum, 4*cnum, 3, 1, dilation=16, padding=get_pad(64, 3, 1, 16)),
        )
        ### upsample
        self.coarse_convk = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_convl = ConvWithActivation(4*cnum, 4*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_deconva = DeConvWithActivation(4*cnum*3, 2*cnum, kernel_size=3, padding=1, stride=2)
        self.coarse_convm = ConvWithActivation(2*cnum, 2*cnum, kernel_size=3, stride=1, padding=1)
        self.coarse_deconvb = DeConvWithActivation(2*cnum*3, cnum, kernel_size=3, padding=1, stride=2)
        self.coarse_convn = nn.Sequential(
            ConvWithActivation(cnum, cnum//2, kernel_size=3, stride=1, padding=1),
            # Self_Attn(cnum//2, 'relu'),
            ConvWithActivation(cnum//2, 3, kernel_size=3, stride=1, padding=1, activation=None),
        )
        self.c1 = nn.Conv2d(32, 64, kernel_size=1)
        self.c2 = nn.Conv2d(64, 128, kernel_size=1)
        ##### Refine network ######

    def forward(self, x):
        # downsample
        x = self.conv1(x)
        x = self.conva(x)
        con_x1 = x
        x = self.convb(x)
        x = self.res1(x)
        con_x2 = x
        x = self.res2(x)
        x = self.res3(x)
        con_x3 = x
        x = self.res4(x)
        x = self.res5(x)
        con_x4 = x
        x = self.res6(x)
        # x_mask = self.nn(con_x4)   ### for mask branch aspp
        # x_mask = self.aspp(x_mask) ### for mask branch aspp
        x_mask = x  ### no aspp
        x = self.res7(x)
        x = self.res8(x)
        x = self.conv2(x)
        # upsample
        x = self.deconv1(x)
        x = torch.cat([self.lateral_connection1(con_x4), x], dim=1)
        x = self.deconv2(x)
        x = torch.cat([self.lateral_connection2(con_x3), x], dim=1)
        x = self.deconv3(x)
        xo1 = x
        x = torch.cat([self.lateral_connection3(con_x2), x], dim=1)
        x = self.deconv4(x)
        xo2 = x
        x = torch.cat([self.lateral_connection4(con_x1), x], dim=1)
        x = self.deconv5(x)
        x_o1 = self.conv_o1(xo1)
        x_o2 = self.conv_o2(xo2)
        x_o_unet = x

        ### mask branch ###
        mm = self.mask_deconv_a(torch.cat([x_mask, con_x4], dim=1))
        mm = self.mask_conv_a(mm)
        mm = self.mask_deconv_b(torch.cat([mm, con_x3], dim=1))
        mm = self.mask_conv_b(mm)
        mm = self.mask_deconv_c(torch.cat([mm, con_x2], dim=1))
        mm = self.mask_conv_c(mm)
        mm = self.mask_deconv_d(torch.cat([mm, con_x1], dim=1))
        mm = self.mask_conv_d(mm)
        ### mask branch ###

        ### refine sub-network ###
        x = self.coarse_conva(x_o_unet)
        x = self.coarse_convb(x)
        x = self.coarse_convc(x)
        x_c1 = x  ### concat feature 1
        x = self.coarse_convd(x)
        x = self.coarse_conve(x)
        x = self.coarse_convf(x)
        x_c2 = x  ### concat feature 2
        x = self.astrous_net(x)
        x = self.coarse_convk(x)
        x = self.coarse_convl(x)
        x = self.coarse_deconva(torch.cat([x, x_c2, self.c2(con_x2)], dim=1))
        x = self.coarse_convm(x)
        x = self.coarse_deconvb(torch.cat([x, x_c1, self.c1(con_x1)], dim=1))
        x = self.coarse_convn(x)
        return x_o1, x_o2, x_o_unet, x, mm
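
# STRnet2.forward returns five tensors:
#   x_o1 (1/4 resolution) and x_o2 (1/2 resolution): intermediate decoder
#     outputs, supervised by the multi-scale (MSR) term in loss/Loss.py
#   x_o_unet: full-resolution coarse result of the U-Net branch
#   x: final refined result
#   mm: predicted stamp mask, trained with dice_loss against (1 - mask)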
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from data.dataloader import ErasingData
from models.sa_gan import STRnet2


parser = argparse.ArgumentParser()
parser.add_argument('--numOfWorkers', type=int, default=0,
                    help='workers for dataloader')
parser.add_argument('--modelsSavePath', type=str, default='',
                    help='path for saving models')
parser.add_argument('--logPath', type=str,
                    default='')
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--loadSize', type=int, default=512,
                    help='image loading size')
parser.add_argument('--dataRoot', type=str,
                    default='')
parser.add_argument('--pretrained', type=str, default='', help='pretrained models for finetuning')
parser.add_argument('--savePath', type=str, default='./results/sn_tv/')
args = parser.parse_args()

cuda = torch.cuda.is_available()
if cuda:
    print('Cuda is available!')
    cudnn.benchmark = True


def visual(image):
    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()

batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)
dataRoot = args.dataRoot
savePath = args.savePath
result_with_mask = savePath + 'WithMaskOutput/'
result_straight = savePath + 'StrOutput/'

os.makedirs(savePath, exist_ok=True)
os.makedirs(result_with_mask, exist_ok=True)
os.makedirs(result_straight, exist_ok=True)


Erase_data = ErasingData(dataRoot, loadSize, training=False)
Erase_data = DataLoader(Erase_data, batch_size=batchSize, shuffle=True, num_workers=args.numOfWorkers, drop_last=False)


netG = STRnet2(3)

netG.load_state_dict(torch.load(args.pretrained))

if cuda:
    netG = netG.cuda()

for param in netG.parameters():
    param.requires_grad = False

print('OK!')

import time
start = time.time()
netG.eval()
for imgs, gt, masks, path in Erase_data:
    if cuda:
        imgs = imgs.cuda()
        gt = gt.cuda()
        masks = masks.cuda()
    out1, out2, out3, g_images, mm = netG(imgs)
    g_image = g_images.data.cpu()
    gt = gt.data.cpu()
    mask = masks.data.cpu()
    g_image_with_mask = gt * (mask) + g_image * (1 - mask)

    save_image(g_image_with_mask, result_with_mask + path[0])
    save_image(g_image, result_straight + path[0])
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import utils
from data.dataloader import ErasingData
from loss.Loss import LossWithGAN_STE
from models.Model import VGG16FeatureExtractor
from models.sa_gan import STRnet2

torch.set_num_threads(5)

os.environ["CUDA_VISIBLE_DEVICES"] = "3"  ### select which GPU to use

parser = argparse.ArgumentParser()
parser.add_argument('--numOfWorkers', type=int, default=0,
                    help='workers for dataloader')
parser.add_argument('--modelsSavePath', type=str, default='',
                    help='path for saving models')
parser.add_argument('--logPath', type=str,
                    default='')
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--loadSize', type=int, default=512,
                    help='image loading size')
parser.add_argument('--dataRoot', type=str,
                    default='')
parser.add_argument('--pretrained', type=str, default='', help='pretrained models for finetuning')
parser.add_argument('--num_epochs', type=int, default=500, help='epochs')
args = parser.parse_args()


def visual(image):
    im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()
    Image.fromarray(im[0].astype(np.uint8)).show()


cuda = torch.cuda.is_available()
if cuda:
    print('Cuda is available!')
    cudnn.enabled = True
    cudnn.benchmark = True

batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)

if not os.path.exists(args.modelsSavePath):
    os.makedirs(args.modelsSavePath)

dataRoot = args.dataRoot

Erase_data = ErasingData(dataRoot, loadSize, training=True)
Erase_data = DataLoader(Erase_data, batch_size=batchSize,
                        shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True)

netG = STRnet2(3)

if args.pretrained != '':
    print('loaded ')
    netG.load_state_dict(torch.load(args.pretrained))

numOfGPUs = torch.cuda.device_count()

if cuda:
    netG = netG.cuda()
    if numOfGPUs > 1:
        netG = nn.DataParallel(netG, device_ids=range(numOfGPUs))

count = 1


G_optimizer = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.9))


criterion = LossWithGAN_STE(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0)

if cuda:
    criterion = criterion.cuda()

if numOfGPUs > 1:
    criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs))

print('OK!')
num_epochs = args.num_epochs

for i in range(1, num_epochs + 1):
    netG.train()

    for k, (imgs, gt, masks, path) in enumerate(Erase_data):
        if cuda:
            imgs = imgs.cuda()
            gt = gt.cuda()
            masks = masks.cuda()
        netG.zero_grad()

        x_o1, x_o2, x_o3, fake_images, mm = netG(imgs)
        G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, i)
        G_loss = G_loss.sum()
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        print('[{}/{}] Generator Loss of epoch {} is {}'.format(k, len(Erase_data), i, G_loss.item()))

        count += 1

    if i % 10 == 0:
        if numOfGPUs > 1:
            torch.save(netG.module.state_dict(), args.modelsSavePath +
                       '/STE_{}.pth'.format(i))
        else:
            torch.save(netG.state_dict(), args.modelsSavePath +
                       '/STE_{}.pth'.format(i))
--------------------------------------------------------------------------------
/yz_gen.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import cv2
import os
import numpy as np
yz_img_root = r'\all_images_org'
mask_root = r'\img_mask_aug'
gt_root = r'\img_gt_aug'

img_aug = r'\img_aug'
def vertical_grad(src, color_start, color_end):
    h = src.shape[0]

    # create an image the same size as the source to hold the vertical gradient
    grad_img = np.ndarray(src.shape, dtype=np.uint8)

    # OpenCV uses BGR channel order by default, not RGB
    g_b = float(color_end[0] - color_start[0]) / h
    g_g = float(color_end[1] - color_start[1]) / h
    g_r = float(color_end[2] - color_start[2]) / h

    for i in range(h):
        for j in range(src.shape[1]):

            grad_img[i,j,0] = color_start[0] + i * g_b
            grad_img[i,j,1] = color_start[1] + i * g_g
            grad_img[i,j,2] = color_start[2] + i * g_r

    return grad_img

def remove_red_seal(input_img):
    blue_c, green_c, red_c = cv2.split(input_img)

    thresh, ret = cv2.threshold(red_c, 100, 255, cv2.THRESH_OTSU)
    thresh, ret_f = cv2.threshold(ret, 100, 255, cv2.THRESH_BINARY_INV)
    thresh1, ret1 = cv2.threshold(blue_c, 100, 255, cv2.THRESH_OTSU)
    result = ret_f + ret1

    result_img_gray = np.expand_dims(green_c, axis=2)
    result_img_gray = np.concatenate((result_img_gray, result_img_gray, result_img_gray), axis=-1)

    aug_img = vertical_grad(input_img, (30,255,0), (30,255,4))
    blend = cv2.addWeighted(input_img, 1.0, aug_img, 0.6, 0.0)
    # overlay a gradient mask as augmentation
    # result_img_inv = np.concatenate((red_c, blue_c, green_c), axis=2)

    return ret, result, result_img_gray, blend


yz_name_list = os.listdir(yz_img_root)
for yz_name in yz_name_list:
    print(yz_name)
    yz_path = os.path.join(yz_img_root, yz_name)
    yz_name_b = yz_name.replace('.png', '_black.png')
    yz_name_v = yz_name.replace('.png', '_inv.png')
    yz_path_b = os.path.join(img_aug, yz_name_b)
    gt_path_b = os.path.join(gt_root, yz_name_b)
    mask_path_b = os.path.join(mask_root, yz_name_b)

    yz_path_v = os.path.join(yz_img_root, yz_name_v)
    gt_path_v = os.path.join(gt_root, yz_name_v)

    mask_path = os.path.join(mask_root, yz_name)
    image = cv2.imread(yz_path)
    image_aug = cv2.imread(yz_path, 0)
    image = cv2.resize(image, (512,512))
    gt, mask, result_img_gray, result_img_inv = remove_red_seal(image)

    cv2.imwrite(yz_path_b, result_img_gray)
    cv2.imwrite(gt_path_b, gt)
    cv2.imwrite(mask_path_b, mask)
--------------------------------------------------------------------------------