├── Agreement_form.doc ├── Evaluate.py ├── README.md ├── UTILITY.py ├── config.py ├── config.yml ├── data_load.py ├── gauss.py ├── images └── arch_new.png ├── loss.py ├── metrics.py ├── model.py ├── network.py ├── test.py ├── train.py └── untils.py /Agreement_form.doc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wcq19941215/SceneTextRemoval/0cd36d54476b869e2e6de170a7e115acb830f56f/Agreement_form.doc -------------------------------------------------------------------------------- /Evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import signal, ndimage 3 | from math import floor 4 | import gauss 5 | 6 | 7 | def ssim(img1, img2, cs_map=False): 8 | """Return the Structural Similarity Map corresponding to input images img1 9 | and img2 (images are assumed to be uint8) 10 | 11 | This function attempts to precisely mimic the functionality of ssim.m, the 12 | MATLAB implementation provided by the authors of SSIM 13 | https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m 14 | """ 15 | img1 = img1.astype(float) 16 | img2 = img2.astype(float) 17 | 18 | size = min(img1.shape[0], 11) 19 | sigma = 1.5 20 | window = gauss.fspecial_gauss(size, sigma) 21 | K1 = 0.01 22 | K2 = 0.03 23 | L = 255 # dynamic range of an 8-bit image 24 | C1 = (K1 * L) ** 2 25 | C2 = (K2 * L) ** 2 26 | 27 | mu1 = signal.fftconvolve(img1, window, mode = 'valid') 28 | mu2 = signal.fftconvolve(img2, window, mode = 'valid') 29 | mu1_sq = mu1 * mu1 30 | mu2_sq = mu2 * mu2 31 | mu1_mu2 = mu1 * mu2 32 | sigma1_sq = signal.fftconvolve(img1 * img1, window, mode = 'valid') - mu1_sq 33 | sigma2_sq = signal.fftconvolve(img2 * img2, window, mode = 'valid') - mu2_sq 34 | sigma12 = signal.fftconvolve(img1 * img2, window, mode = 'valid') - mu1_mu2 35 | if cs_map: 36 | return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)), 37 | (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) 38 | else: 39 | return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 40 | (sigma1_sq + sigma2_sq + C2)) 41 | 42 | 43 | def msssim(img1, img2): 44 | """This function implements Multi-Scale Structural Similarity (MSSSIM) Image 45 | Quality Assessment according to Z. Wang's "Multi-scale structural similarity 46 | for image quality assessment" Invited Paper, IEEE Asilomar Conference on 47 | Signals, Systems and Computers, Nov. 2003 48 | 49 | The authors' MATLAB implementation: 50 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 51 | """ 52 | level = 5 53 | weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 54 | downsample_filter = np.ones((2, 2)) / 4.0 55 | im1 = img1.astype(np.float64) 56 | im2 = img2.astype(np.float64) 57 | mssim = np.array([]) 58 | mcs = np.array([]) 59 | for l in range(level): 60 | ssim_map, cs_map = ssim(im1, im2, cs_map = True) 61 | mssim = np.append(mssim, ssim_map.mean()) 62 | mcs = np.append(mcs, cs_map.mean()) 63 | filtered_im1 = ndimage.filters.convolve(im1, downsample_filter, 64 | mode = 'reflect') 65 | filtered_im2 = ndimage.filters.convolve(im2, downsample_filter, 66 | mode = 'reflect') 67 | im1 = filtered_im1[: : 2, : : 2] 68 | im2 = filtered_im2[: : 2, : : 2] 69 | 70 | # Note: Remove the negative and add it later to avoid NaN in exponential.
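# MS-SSIM combines the mean contrast/structure (cs) term of the first
# level-1 scales with the full SSIM mean at the coarsest scale:
#   MS-SSIM = prod_{l=1..level-1} cs_l^{w_l} * ssim_level^{w_level}
# The sign/abs splitting below evaluates exactly this product while staying
# safe when one of the terms is negative.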
71 | sign_mcs = np.sign(mcs[0 : level - 1]) 72 | sign_mssim = np.sign(mssim[level - 1]) 73 | mcs_power = np.power(np.abs(mcs[0 : level - 1]), weight[0 : level - 1]) 74 | mssim_power = np.power(np.abs(mssim[level - 1]), weight[level - 1]) 75 | return np.prod(sign_mcs * mcs_power) * sign_mssim * mssim_power 76 | #return (np.prod(mcs[0 : level - 1] ** weight[0 : level - 1]) * (mssim[level - 1] ** weight[level - 1])) 77 | 78 | def mae(img1, img2): 79 | r = np.asarray(img1, dtype=np.float64).ravel() 80 | # print(r.shape) 81 | c = np.asarray(img2, dtype=np.float64).ravel() 82 | return np.mean(np.abs(r - c))/255 83 | 84 | 85 | def PeakSignaltoNoiseRatio(origImg, distImg, max_value=255): 86 | origImg = origImg.astype(float) 87 | distImg = distImg.astype(float) 88 | 89 | M, N = np.shape(origImg) 90 | error = origImg - distImg 91 | MSE = sum(sum(error * error)) / (M * N) 92 | 93 | if MSE > 0: 94 | PSNR = 10 * np.log10(max_value * max_value / MSE) 95 | else: 96 | PSNR = 99 # identical images: cap the PSNR at 99 dB 97 | # print(PSNR) 98 | # print(MSE) 99 | 100 | return PSNR, MSE 101 | 102 | 103 | def cqm(orig_img, dist_img): 104 | M, N, C = np.shape(orig_img) 105 | 106 | if C != 3: 107 | CQM = float("inf") 108 | return CQM 109 | 110 | Ro = orig_img[:, :, 0] 111 | Go = orig_img[:, :, 1] 112 | Bo = orig_img[:, :, 2] 113 | 114 | Rd = dist_img[:, :, 0] 115 | Gd = dist_img[:, :, 1] 116 | Bd = dist_img[:, :, 2] 117 | 118 | ################################################ 119 | ### Reversible YUV Transformation ### 120 | ################################################ 121 | YUV_img1 = np.zeros((M, N, 3)) 122 | YUV_img2 = np.zeros((M, N, 3)) 123 | 124 | for i in range(M): 125 | for j in range(N): 126 | ### Original Image Transformation ### 127 | # Y=(R+2*G+B)/4 128 | YUV_img1[i, j, 0] = floor((Ro[i, j] + Go[i, j] * 2 + Bo[i, j]) / 4) 129 | YUV_img2[i, j, 0] = floor((Rd[i, j] + Gd[i, j] * 2 + Bd[i, j]) / 4) 130 | # U=R-G 131 | YUV_img1[i, j, 1] = max(0, Ro[i, j] - Go[i, j]) 132 | YUV_img2[i, j, 1] = max(0, Rd[i, j] - Gd[i, j]) 133 | # V=B-G 134 | YUV_img1[i, j, 2] = max(0, Bo[i, j] - Go[i, j]) 135 | YUV_img2[i, j, 2] = max(0, Bd[i, j] - Gd[i, j]) 136 | 137 | ################################################ 138 | ### CQM Calculation ### 139 | ################################################ 140 | Y_psnr, _ = PeakSignaltoNoiseRatio(YUV_img1[:, :, 0], YUV_img2[:, :, 0]) # PSNR for the Y channel (the returned MSE is unused) 141 | U_psnr, _ = PeakSignaltoNoiseRatio(YUV_img1[:, :, 1], YUV_img2[:, :, 1]) # PSNR for the U channel 142 | V_psnr, _ = PeakSignaltoNoiseRatio(YUV_img1[:, :, 2], YUV_img2[:, :, 2]) # PSNR for the V channel 143 | 144 | CQM = (Y_psnr * 0.9449) + (U_psnr + V_psnr) / 2 * 0.0551 145 | 146 | return CQM 147 | 148 | 149 | def Evaluate(GT, BC): 150 | # print(np.shape(GT)) 151 | [M, N, C] = np.shape(GT) 152 | dimension = M * N 153 | 154 | GT = np.ndarray((M, N, 3), 'u1', GT.tostring()).astype(float) 155 | BC = np.ndarray((M, N, 3), 'u1', BC.tostring()).astype(float) 156 | 157 | if C == 3: # In case of color images, use the luminance in YCbCr 158 | R = GT[:, :, 0] 159 | G = GT[:, :, 1] 160 | B = GT[:, :, 2] 161 | 162 | YGT = .299 * R + .587 * G + .114 * B 163 | 164 | R = BC[:, :, 0] 165 | G = BC[:, :, 1] 166 | B = BC[:, :, 2] 167 | 168 | YBC = .299 * R + .587 * G + .114 * B 169 | 170 | else: 171 | YGT = GT 172 | YBC = BC 173 | 174 | ############################# AGE ######################################## 175 | Diff = abs(YGT - YBC).round().astype(np.uint8) 176 | AGE = np.mean(Diff) 177 | 178 | ########################### EPs and pEPs ################################# 179 | threshold = 20 180 |
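# "Error pixels" (EPs) are pixels whose absolute luminance error exceeds the
# threshold; pEPs is their fraction of the image. CEPs/pCEPs further below
# count only clustered error pixels, i.e. those that survive a binary erosion
# with a 4-neighborhood structuring element.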
181 | Errors = Diff > threshold 182 | EPs = sum(sum(Errors)).astype(float) 183 | pEPs = EPs / float(dimension) 184 | 185 | ########################## CEPs and pCEPs ################################ 186 | structure = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) 187 | erodedErrors = ndimage.binary_erosion( 188 | Errors, structure).astype(Errors.dtype) 189 | CEPs = sum(sum(erodedErrors)) 190 | pCEPs = CEPs / float(dimension) 191 | 192 | ############################# MSSSIM ##################################### 193 | MSSSIM = msssim(YGT, YBC) 194 | # print("MSSSIM", MSSSIM) 195 | SSIM = np.mean(ssim(YGT, YBC)) 196 | # print("SSIM", SSIM) 197 | MAE = mae(GT, BC) 198 | ############################# PSNR ####################################### 199 | PSNR, MSE = PeakSignaltoNoiseRatio(YGT, YBC) 200 | 201 | ############################# CQM ######################################## 202 | # if C == 3: 203 | # CQM = cqm(GT, BC) 204 | 205 | return (AGE, pEPs, pCEPs, MSSSIM, PSNR, SSIM, MSE, MAE) 206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scene text removal via cascaded text stroke detection and erasing 2 | 3 | The training set of the synthetic database consists of 617,401 images and the test set contains 2,000 images; all training and test samples are resized to 256 × 256. The code for generating the synthetic dataset (and more synthetic text images), as described in "Ankush Gupta, Andrea Vedaldi, Andrew Zisserman, Synthetic Data for Text Localisation in Natural Images, CVPR 2016", can be found at https://github.com/ankush-me/SynthText. 4 | All the real scene text images are also resized to 256 × 256. 5 | 6 | For more details, please refer to our CVM 2021 paper: https://arxiv.org/abs/2011.09768 7 | ![](images/arch_new.png) 8 | 9 | ## Requirements 10 | 1. Tensorflow==1.13.1 11 | 2. Python==3.6.13 12 | 3. CUDA==10.0 13 | 4. Opencv==4.5.1 14 | 5. Numpy 15 | 16 | ## Installation 17 | 1. Clone this repository. 18 | ``` 19 | git clone https://github.com/wcq19941215/SceneTextRemoval.git 20 | ``` 21 | ## Running 22 | ### 1. Image Prepare 23 | You can modify the paths of the training set, the validation set, and other hyperparameters in `config.yml`. 24 | Note that each training sample is a single wide image in which the ground truth, the text image, and the masks are concatenated side by side; they are split apart automatically during training (see the sketch below).
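For reference, here is a minimal sketch of how such a concatenated sample can be assembled (the file names and the content of the fifth panel are assumptions): `model.py` slices each 256 × 1280 sample into five 256-pixel-wide panels and reads the first four as ground truth, text image, text-region mask, and text-stroke mask.

```python
import cv2
import numpy as np

# Hypothetical per-sample inputs, each 256x256 (BGR, as loaded by OpenCV).
gt     = cv2.imread('gt/0001.jpg')      # clean background image
image  = cv2.imread('image/0001.jpg')   # the same image with text
mask   = cv2.imread('mask/0001.png')    # text-region mask, white = text box
stroke = cv2.imread('stroke/0001.png')  # text-stroke mask, white = strokes
extra  = np.zeros_like(gt)              # fifth panel, unused by model.py

# Panels side by side -> 256x1280; data_load.py resizes samples to [256, 1280].
sample = np.concatenate([gt, image, mask, stroke, extra], axis=1)
cv2.imwrite('concated/0001.jpg', sample)
```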
25 | ### 2. Training 26 | Once `config.yml` is configured, simply run: 27 | ``` 28 | python train.py 29 | ``` 30 | ### 3. Testing 31 | `test.py` only produces the output images; it does not compute evaluation metrics such as `PSNR`. Run the following: 32 | ``` 33 | python test.py \ 34 | --image=[the path of test images] \ 35 | --mask=[the path of test masks] \ 36 | --output=[where to save output images] \ 37 | --checkpoint_dir=[the directory of the tensorflow checkpoint] 38 | ``` 39 | For a fair comparison, we use the same evaluation method as [Ensnet](https://github.com/HCIILAB/Scene-Text-Removal); the evaluation code is available as [PythonCode.zip](http://pione.dinf.usherbrooke.ca/static/code). You can also use `UTILITY.py` to compute PSNR and SSIM (a usage sketch follows this README). 40 | ### 4. Pretrained models 41 | Please download our pretrained models: [TextRemoval](https://pan.baidu.com/s/1Bj1YM5RqNqZ_PRkvetmy9Q), password: 1234. 42 | 43 | ### 5. Dataset 44 | The dataset can be obtained by sending a request email to us. Specifically, researchers should download and fill out this [Agreement Form](https://github.com/wcq19941215/SceneTextRemoval/blob/main/Agreement_form.doc) and send it back to Weize Quan (weize.quan AT nlpr.ia.ac.cn; Email title: Scene Text Removal Dataset Request). We will then send you the download instructions at our discretion. 45 | 46 | ## Paper 47 | 48 | Please consider citing our paper if you use our database: 49 | ``` 50 | @article{Bian2021Scence, 51 | title = {Scene text removal via cascaded text stroke detection and erasing}, 52 | author = {Xuewei Bian and Chaoqun Wang and Weize Quan and Juntao Ye and Xiaopeng Zhang and Dong-Ming Yan}, 53 | publisher = {Computational Visual Media}, 54 | year = {2022}, 55 | journal = {Computational Visual Media}, 56 | volume = {8}, 57 | number = {2}, 58 | numpages = {15}, 59 | keywords = {scene text removal; text stroke detection; generative adversarial networks; cascaded network design; real-world dataset}, 60 | doi = {10.1007/s41095-021-0242-8} 61 | } 62 | ``` 63 |
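The evaluation script follows below; as a minimal sketch (the paths are placeholders), it can be driven like this — `Utility` walks the result folder, pairs each output with its ground truth, and writes per-image and averaged PSNR/SSIM/MSE/MAE to `cm.txt`:

```python
from UTILITY import Utility

# num_path selects the filename convention used to pair results with
# ground truth (2000: '<id>.jpg'; 1080: '<id>_gt.png' vs. '<name>.png').
Utility(GT_path='/path/to/test_2000_256/gt/',
        evaluated_path='./result_2000/',
        num_path=2000)
```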
-------------------------------------------------------------------------------- /UTILITY.py: -------------------------------------------------------------------------------- 1 | import os 2 | import Evaluate 3 | import scipy.misc 4 | # import tkMessageBox 5 | from tkinter import messagebox 6 | import sys 7 | # import ipdb 8 | def Utility(GT_path, evaluated_path, num_path): 9 | ''' 10 | Function to evaluate your results (adapted from the SBMnet evaluation code); it generates a 'cm.txt' file in your result path to save all the metrics. 11 | input: GT_path: the path of the ground truth folder. 12 | evaluated_path: the path of your evaluated results folder. 13 | ''' 14 | result_file = os.path.join(evaluated_path, 'cm.txt') 15 | 16 | with open(result_file, 'w') as fid: 17 | fid.write('\t\timage_name\tPSNR\tSSIM\tMSE\tMAE\r\n') 18 | 19 | m_AGE = 0 20 | m_pEPs = 0 21 | m_pCEPs = 0 22 | m_MSSSIM = 0 23 | m_PSNR = 0 24 | m_SSIM = 0 25 | m_MSE = 0 26 | m_MAE = 0 27 | # ipdb.set_trace() 28 | 29 | c_AGE = 0 30 | c_pEPs = 0 31 | c_pCEPs = 0 32 | c_MSSSIM = 0 33 | c_PSNR = 0 34 | c_SSIM = 0 35 | c_MSE = 0 36 | c_MAE = 0 37 | 38 | image_num = 0 39 | 40 | for root, dirs, files in os.walk(evaluated_path): 41 | MSSSIM_max = 0 42 | for i in files: 43 | # only evaluate image files (.jpg/.png) 44 | if i.endswith('.JPG') or i.endswith('.jpg') or i.endswith('.PNG') or i.endswith('.png'): 45 | picname = i.split('.')[0] 46 | print("picname:", picname) 47 | num = picname.split('_')[0] 48 | 49 | # if more than one GT exists for the video, we keep the 50 | # metrics with the highest MSSSIM value. 51 | if(num_path==2000): 52 | GT_img = scipy.misc.imread(GT_path+num+".jpg") # background ground truth 53 | result_img = scipy.misc.imread(evaluated_path+num+".jpg") 54 | if(num_path==1080): 55 | GT_img = scipy.misc.imread(GT_path+num+"_gt.png") # background ground truth 56 | result_img = scipy.misc.imread(evaluated_path+picname+".png") 57 | 58 | AGE, pEPs, pCEPs, MSSSIM, PSNR, SSIM, MSE, MAE = Evaluate.Evaluate(GT_img, result_img) 59 | if MSSSIM > MSSSIM_max: 60 | MSSSIM_max = MSSSIM 61 | v_AGE = AGE 62 | v_pEPs = pEPs 63 | v_pCEPs = pCEPs 64 | v_MSSSIM = MSSSIM 65 | v_PSNR = PSNR 66 | v_SSIM = SSIM 67 | v_MSE = MSE 68 | v_MAE = MAE 69 | 70 | 71 | # save the per-image evaluation results 72 | with open(result_file, 'a+') as fid: 73 | fid.write('\t\t' + picname + ':\t' + str(round(v_PSNR, 4)) + '\t' + str(round(v_SSIM, 4)) + '\t' + str(round(v_MSE, 4)) + '\t' + str(round(v_MAE, 4)) + '\r\n') 74 | 75 | c_AGE = c_AGE + v_AGE 76 | c_pEPs = c_pEPs + v_pEPs 77 | c_pCEPs = c_pCEPs + v_pCEPs 78 | c_MSSSIM = c_MSSSIM + v_MSSSIM 79 | c_PSNR = c_PSNR + v_PSNR 80 | c_SSIM = c_SSIM + v_SSIM 81 | c_MSE = c_MSE + v_MSE 82 | c_MAE = c_MAE + v_MAE 83 | image_num = image_num + 1 84 | 85 | c_AGE = c_AGE / float(image_num) 86 | c_pEPs = c_pEPs / float(image_num) 87 | c_pCEPs = c_pCEPs / float(image_num) 88 | c_MSSSIM = c_MSSSIM / float(image_num) 89 | c_PSNR = c_PSNR / float(image_num) 90 | c_SSIM = c_SSIM / float(image_num) 91 | c_MSE = c_MSE / float(image_num) 92 | c_MAE = c_MAE / float(image_num) 93 | 94 | # save the category evaluation results 95 | with open(result_file, 'a+') as fid: 96 | fid.write('\t\timage_name\tPSNR\tSSIM\tMSE\tMAE\r\n') 97 | fid.write('\r\n' + 'gt' + '_AVG::\t\t' + str(round(c_PSNR, 4)) + '\t' + str(round(c_SSIM, 4)) + '\t' + str(round(c_MSE, 4))+ '\t' + str(round(c_MAE, 4)) + '\r\n\r\n') 98 | 99 | m_AGE = m_AGE + c_AGE 100 | m_pEPs = m_pEPs + c_pEPs 101 | m_pCEPs = m_pCEPs + c_pCEPs 102 | m_MSSSIM = m_MSSSIM + c_MSSSIM 103 | m_PSNR = m_PSNR + c_PSNR 104 | m_SSIM = m_SSIM + c_SSIM 105 | m_MSE = m_MSE + c_MSE 106 | m_MAE = m_MAE + c_MAE 107 | 108 | 109 | 110 | with open(result_file, 'a+') as fid: 111 | fid.write('Total:\t\t\t' + str(round(m_PSNR, 4)) + '\t' + str(round(m_SSIM*100, 4)) + '\t' + str(round(m_MSE, 4)) + '\t' + str(round(m_MAE*100, 4)) + '\r\n') 112 | 113 | if __name__ == '__main__': 114 | if len(sys.argv) < 3: 115 | print("Usage: python {0} (no arguments needed; the paths are hard-coded below)".format(sys.argv[0])) 116 | num_path=2000 117 | GT_path = '/data/cqwang/64_backup/cqwang/dataset/IJCIA_dataset/test/test_{}_256/gt/'.format(str(num_path)) 118 | evaluated_path = './result_{}/'.format(str(num_path)) 119 | Utility(GT_path, evaluated_path, num_path) 120 | 121 | num_path=1080 122 | GT_path = '/data/cqwang/64_backup/cqwang/dataset/IJCIA_dataset/test/test_{}_256/gt/'.format(str(num_path)) 123 | evaluated_path = './result_{}/'.format(str(num_path)) 124 | Utility(GT_path, evaluated_path, num_path) 125 | --------------------------------------------------------------------------------
/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | class Config(object): 5 | def __init__(self, config_path): 6 | assert os.path.exists(config_path), "ERROR: Config file doesn't exist" 7 | with open(config_path, 'r') as f: 8 | self._yaml = f.read() 9 | self._dict = yaml.safe_load(self._yaml) 10 | 11 | def __getattr__(self, name): 12 | if self._dict.get(name) is not None: 13 | return self._dict[name] 14 | 15 | if DEFAULT_CONFIG.get(name) is not None: 16 | return DEFAULT_CONFIG[name] 17 | 18 | return None # unknown keys resolve to None 19 | 20 | def print(self): 21 | print('Model configurations:') 22 | print('---------------------------------') 23 | print(self._yaml) 24 | print('') 25 | print('---------------------------------') 26 | print('') 27 | 28 | DEFAULT_CONFIG = { 29 | 'SEED': 10, # random seed 30 | 'GPU': [5], # list of gpu ids 31 | 'LR': 0.0001, # learning rate 32 | 'BETA1': 0.0, # adam optimizer beta1 33 | 'BETA2': 0.9, # adam optimizer beta2 34 | 'BATCH_SIZE': 8, # input batch size for training 35 | 'INPUT_SIZE': 256, # input image size for training, 0 for original size 36 | 'SAVE_INTERVAL': 1000, # how many iterations to wait before saving the model (0: never) 37 | 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never) 38 | 'SAMPLE_SIZE': 12, # number of images to sample 39 | 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never) 40 | 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never) 41 | } 42 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | SEED: 10 # random seed 2 | GPU: [3] # list of gpu ids 3 | 4 | CHECKPOINTS: ./model_logs/text_output_0513/ # directory where checkpoint files are saved 5 | LOAD_MODEL: ./model_logs/text_output_0513/ # if a checkpoint meta file exists here, it is loaded and training continues 6 | 7 | TRAIN_CONCAT_FLIST: /data/dataset/IJCIA_dataset/train/concated.list # training image list 8 | VAL_CONCAT_FLIST: /data/dataset/IJCIA_dataset/test/test_1080/concated.txt # validation image list 9 | 10 | LR: 0.0001 # learning rate 11 | BETA1: 0.5 # adam optimizer beta1 12 | BETA2: 0.9 # adam optimizer beta2 13 | BATCH_SIZE: 4 # input batch size for training 14 | VAL_BATCH_SIZE: 8 # input batch size for validation 15 | INPUT_SIZE: 256 # input image size for training, 0 for original size 16 | EPOCH: 200 # epoch number to train the model 17 | SAVE_INTERVAL: 500 # how many iterations to wait before saving a checkpoint (0: never) 18 | SUMMARY_INTERVAL: 500 # how many iterations to wait before writing TensorBoard summaries (0: never) 19 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never) 20 | SAMPLE_SIZE: 12 # number of images to sample 21 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never) 22 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never) 23 |
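A minimal sketch of how this configuration is consumed (only `config.py` above is assumed): attribute lookups resolve from `config.yml` first, then `DEFAULT_CONFIG`, then `None`.

```python
from config import Config

cfg = Config('config.yml')
cfg.print()                    # dump the raw YAML
print(cfg.LR, cfg.BATCH_SIZE)  # values from config.yml
print(cfg.SAMPLE_SIZE)         # falls back to DEFAULT_CONFIG when missing
print(cfg.NOT_A_KEY)           # unknown keys resolve to None
```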
-------------------------------------------------------------------------------- /data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | class Dataset(object): 7 | def __init__(self, config): 8 | super(Dataset, self).__init__() 9 | self.train_concat_list = self.load_flist(config.TRAIN_CONCAT_FLIST) 10 | self.val_concat_list = self.load_flist(config.VAL_CONCAT_FLIST) 11 | 12 | self.len_train = len(self.train_concat_list) 13 | self.len_val = len(self.val_concat_list) 14 | 15 | self.input_size = config.INPUT_SIZE 16 | self.epoch = config.EPOCH 17 | self.batch_size = config.BATCH_SIZE 18 | self.val_batch_size = config.VAL_BATCH_SIZE 19 | 20 | self.data_batch() 21 | 22 | def load_flist(self, flist): 23 | 24 | if isinstance(flist, list): 25 | return flist 26 | # flist: image file path, image directory path, text file flist path 27 | if isinstance(flist, str): 28 | if os.path.isdir(flist): 29 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 30 | flist.sort() 31 | return flist 32 | 33 | if os.path.isfile(flist): 34 | try: 35 | return np.genfromtxt(flist, dtype=str, encoding='utf-8') 36 | except: 37 | return [flist] 38 | return [] 39 | 40 | def data_batch(self): 41 | train_concat = tf.data.Dataset.from_tensor_slices(self.train_concat_list) 42 | val_concat = tf.data.Dataset.from_tensor_slices(self.val_concat_list) 43 | 44 | def image_fn(img_path): 45 | x = tf.read_file(img_path) 46 | x_decode = tf.image.decode_jpeg(x, channels=3) 47 | img = tf.image.resize_images(x_decode, [256,1280]) 48 | return img 49 | 50 | train_concat = train_concat.map(image_fn, num_parallel_calls=self.batch_size) 51 | train_dataset = tf.data.Dataset.zip((train_concat)) 52 | train_dataset = train_dataset.apply(tf.data.experimental.shuffle_and_repeat(1000, 10*self.epoch)).batch(self.batch_size, drop_remainder=True) 53 | 54 | val_concat = val_concat.map(image_fn, num_parallel_calls=self.val_batch_size) 55 | val_dataset = tf.data.Dataset.zip((val_concat)) 56 | val_dataset = val_dataset.apply(tf.data.experimental.shuffle_and_repeat(1000, 10*self.epoch)).batch(self.val_batch_size, drop_remainder=True) 57 | 58 | self.batch_concat = train_dataset.make_one_shot_iterator().get_next() 59 | self.val_concat = val_dataset.make_one_shot_iterator().get_next() 60 | 61 | # keep a reference to the training dataset 62 | self.dataset = train_dataset 63 | 64 | 65 | -------------------------------------------------------------------------------- /gauss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module providing functionality surrounding gaussian function. 3 | """ 4 | SVN_REVISION = '$LastChangedRevision: 16541 $' 5 | 6 | import sys 7 | import numpy 8 | 9 | def gaussian2(size, sigma): 10 | """Returns a normalized circularly symmetric 2D gauss kernel array 11 | 12 | f(x,y) = A.e^{-(x^2/2*sigma^2 + y^2/2*sigma^2)} where 13 | 14 | A = 1/(2*pi*sigma^2) 15 | 16 | as defined by Wolfram MathWorld 17 | http://mathworld.wolfram.com/GaussianFunction.html 18 | """ 19 | A = 1/(2.0*numpy.pi*sigma**2) 20 | x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 21 | g = A*numpy.exp(-((x**2/(2.0*sigma**2))+(y**2/(2.0*sigma**2)))) 22 | return g 23 | 24 | def fspecial_gauss(size, sigma): 25 | """Function to mimic the 'fspecial' gaussian MATLAB function 26 | """ 27 | x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 28 | g = numpy.exp(-((x**2 + y**2)/(2.0*sigma**2))) 29 | return g/g.sum() 30 | 31 | def main(): 32 | """Show simple use cases for functionality provided by this module.""" 33 | from mpl_toolkits.mplot3d.axes3d import Axes3D 34 | import pylab 35 | argv = sys.argv 36 | if len(argv) != 3: 37 | print('usage: python -m pim.sp.gauss size sigma', file=sys.stderr) 38 | sys.exit(2) 39 | size = int(argv[1]) 40 | sigma = float(argv[2]) 41 | x, y = numpy.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 42 | 43 | fig = pylab.figure() 44 | fig.suptitle('Some 2-D Gauss Functions') 45 | ax = fig.add_subplot(2, 1, 1, projection='3d') 46 | ax.plot_surface(x, y, fspecial_gauss(size, sigma), rstride=1, cstride=1, 47 | linewidth=0, antialiased=False, cmap=pylab.jet()) 48 | ax = fig.add_subplot(2, 1, 2, projection='3d') 49 | ax.plot_surface(x, y, gaussian2(size, sigma), rstride=1, cstride=1, 50 | linewidth=0, antialiased=False, cmap=pylab.jet()) 51 | pylab.show() 52 | return 0 53 | 54 | if __name__ == '__main__': 55 | sys.exit(main())
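A quick sanity check for `fspecial_gauss` (assumes only `gauss.py` and NumPy): `Evaluate.ssim` builds exactly this 11 × 11, sigma-1.5 window, and the kernel should be normalized.

```python
import numpy as np
import gauss

window = gauss.fspecial_gauss(11, 1.5)  # the window Evaluate.ssim uses
print(window.shape)                     # (11, 11)
print(np.isclose(window.sum(), 1.0))    # True: the kernel sums to 1
```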
-------------------------------------------------------------------------------- /images/arch_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wcq19941215/SceneTextRemoval/0cd36d54476b869e2e6de170a7e115acb830f56f/images/arch_new.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import numpy as np 4 | import inspect 5 | 6 | def hinge_gan_loss(discriminator_data, discriminator_z): 7 | loss_discriminator_data = tf.reduce_mean(tf.nn.relu(1 - discriminator_data)) 8 | loss_discriminator_z = tf.reduce_mean(tf.nn.relu(1 + discriminator_z)) 9 | loss_discriminator = (loss_discriminator_data + loss_discriminator_z) 10 | 11 | loss_generator_adversarial = -tf.reduce_mean(discriminator_z) 12 | return loss_discriminator, loss_generator_adversarial 13 | def adversarial_loss(outputs, is_real, is_disc=None, type='nsgan'): 14 | r""" 15 | Adversarial loss 16 | https://arxiv.org/abs/1711.10337 17 | """ 18 | outputs = tf.reshape(outputs, [-1]) 19 | if type == 'hinge': 20 | if is_disc: 21 | if is_real: 22 | outputs = -outputs 23 | return tf.reduce_mean(tf.nn.relu(1 + outputs)) 24 | else: 25 | return tf.reduce_mean(-outputs) 26 | 27 | elif type == 'nsgan': 28 | labels = tf.ones_like(outputs) if is_real else tf.zeros_like(outputs) 29 | loss = tf.keras.metrics.binary_crossentropy(labels, outputs) 30 | return loss 31 | elif type == 'lsgan': 32 | labels = tf.ones_like(outputs) if is_real else tf.zeros_like(outputs) 33 | loss = tf.keras.metrics.mean_squared_error(labels, outputs) 34 | return loss 35 | 36 | def l1_loss(inputs, targets): 37 | inputs = tf.reshape(inputs, [-1]) 38 | targets = tf.reshape(targets, [-1]) 39 | loss = tf.reduce_mean(tf.abs(inputs - targets)) 40 | return loss 41 | 42 | def tv_loss_mask(y_comp, mask, margin=3): 43 | """Total variation loss, used for smoothing the hole region; see Eq. 6""" 44 | 45 | # Create a dilated hole region using a 3x3 kernel of all 1s. 46 | kernel = tf.ones([margin, margin, tf.shape(mask)[3], tf.shape(mask)[3]]) 47 | dilated_mask = tf.nn.conv2d(mask, kernel, strides=[1,1,1,1], padding='SAME') 48 | 49 | # Cast values to [0., 1.], and compute the dilated hole region of y_comp 50 | dilated_mask = tf.cast(dilated_mask > 0, tf.float32) 51 | # tf.debugging.assert_less(tf.reduce_sum(mask, axis=[1,2,3]), tf.reduce_sum(dilated_mask, axis=[1,2,3])) 52 | P = dilated_mask * y_comp 53 | 54 | # Calculate total variation loss 55 | a = l1_loss_mask(P[:,1:,:,:], P[:,:-1,:,:], dilated_mask) 56 | b = l1_loss_mask(P[:,:,1:,:], P[:,:,:-1,:], dilated_mask) 57 | return a + b 58 | 59 | def tv_loss(inputs): 60 | r"""A smoothness loss, like the smoothness prior in an MRF. 61 | V(y) = || y_{n+1} - y_n ||_1 62 | """ 63 | dy = inputs[:, :-1, ...] - inputs[:, 1:, ...] 64 | dx = inputs[:, :, :-1, ...] - inputs[:, :, 1:, ...] 65 | dy_loss = l1_loss(dy, tf.zeros_like(dy)) 66 | dx_loss = l1_loss(dx, tf.zeros_like(dx)) 67 | return dy_loss + dx_loss 68 | 69 | def l1_loss_mask(inputs, targets, mask): 70 | loss = tf.reduce_sum(tf.abs(inputs - targets), axis=[1,2,3]) 71 | xs = targets.get_shape().as_list() 72 | # xs = tf.cast(xs, tf.float32) 73 | ratio = tf.reduce_sum(mask, axis=[1,2,3]) * xs[3] # mask: BWH1.
if mask = BWH3, then remove '*3' 74 | loss_mean = tf.reduce_mean(loss / (ratio + 1e-12)) # avoid mask = 0 75 | return loss_mean 76 | 77 | def style_loss(x, y): 78 | r""" 79 | Style loss (Gram-matrix matching), VGG-based 80 | https://arxiv.org/abs/1603.08155 81 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 82 | """ 83 | def compute_gram(x): 84 | 85 | b, h, w, ch = x.get_shape().as_list() 86 | f = tf.reshape(x, [-1, ch, w * h]) 87 | f_T = tf.transpose(f, perm=[0, 2, 1]) 88 | # G = tf.matmul(f, f_T) / (h * w * ch) 89 | for i in range(b): 90 | G = tf.matmul(f[i], f_T[i]) 91 | G = tf.expand_dims(G, axis=0) 92 | if i == 0: 93 | g = G 94 | else: 95 | g = tf.concat([g, G], axis=0) 96 | # normalize the stacked Gram matrices once by the feature dimensions 97 | return g / (h * w * ch) 98 | 99 | x_vgg = Vgg19(x) 100 | y_vgg = Vgg19(y) 101 | 102 | # Compute loss 103 | style_loss = 0.0 104 | style_loss += l1_loss(compute_gram(x_vgg.conv2_2), compute_gram(y_vgg.conv2_2)) 105 | style_loss += l1_loss(compute_gram(x_vgg.conv3_4), compute_gram(y_vgg.conv3_4)) 106 | style_loss += l1_loss(compute_gram(x_vgg.conv4_4), compute_gram(y_vgg.conv4_4)) 107 | style_loss += l1_loss(compute_gram(x_vgg.conv5_2), compute_gram(y_vgg.conv5_2)) 108 | 109 | return style_loss 110 | 111 | def perceptual_loss(x, y, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 112 | r""" 113 | Perceptual loss, VGG-based 114 | https://arxiv.org/abs/1603.08155 115 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 116 | """ 117 | x_vgg = Vgg19(x) 118 | y_vgg = Vgg19(y) 119 | 120 | content_loss = 0.0 121 | content_loss += weights[0] * l1_loss(x_vgg.conv1_1, y_vgg.conv1_1) 122 | content_loss += weights[1] * l1_loss(x_vgg.conv2_1, y_vgg.conv2_1) 123 | content_loss += weights[2] * l1_loss(x_vgg.conv3_1, y_vgg.conv3_1) 124 | content_loss += weights[3] * l1_loss(x_vgg.conv4_1, y_vgg.conv4_1) 125 | content_loss += weights[4] * l1_loss(x_vgg.conv5_1, y_vgg.conv5_1) 126 | 127 | 128 | return content_loss 129 | 130 | 131 | class Vgg19: 132 | def __init__(self, img, vgg19_npy_path=None): 133 | with tf.variable_scope('VGG19'): 134 | if vgg19_npy_path is None: 135 | path = inspect.getfile(Vgg19) 136 | path = os.path.abspath(os.path.join(path, os.pardir)) 137 | path = os.path.join(path, "vgg19.npy") 138 | vgg19_npy_path = path 139 | # print(vgg19_npy_path) 140 | 141 | self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item() 142 | # print("npy file loaded") 143 | 144 | self.build(img) 145 | 146 | # def __setitem__(self, key, value): 147 | 148 | def build(self, x): 149 | """ 150 | load variables from the npy file to build the VGG 151 | :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1] 152 | """ 153 | self.conv1_1 = self.conv_layer(x, "conv1_1") 154 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2") 155 | 156 | self.pool1 = self.max_pool(self.conv1_2, 'pool1') 157 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 158 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2") 159 | 160 | self.pool2 = self.max_pool(self.conv2_2, 'pool2') 161 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 162 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2") 163 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3") 164 | self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4") 165 | 166 | self.pool3 = self.max_pool(self.conv3_4, 'pool3') 167 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 168 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2") 169 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3") 170 | self.conv4_4 =
self.conv_layer(self.conv4_3, "conv4_4") 171 | 172 | self.pool4 = self.max_pool(self.conv4_4, 'pool4') 173 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1") 174 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2") 175 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3") 176 | self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4") 177 | 178 | 179 | def avg_pool(self, bottom, name): 180 | return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 181 | 182 | def max_pool(self, bottom, name): 183 | return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name) 184 | 185 | def conv_layer(self, bottom, name): 186 | with tf.variable_scope(name): 187 | filt = self.get_conv_filter(name) 188 | 189 | conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME') 190 | 191 | conv_biases = self.get_bias(name) 192 | bias = tf.nn.bias_add(conv, conv_biases) 193 | 194 | relu = tf.nn.relu(bias) 195 | return relu 196 | 197 | def get_conv_filter(self, name): 198 | return tf.constant(self.data_dict[name][0], name="filter") 199 | 200 | def get_bias(self, name): 201 | return tf.constant(self.data_dict[name][1], name="biases") 202 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | class PSNR(object): 7 | def __init__(self, max_val=255.0): 8 | super(PSNR, self).__init__() 9 | 10 | self.base10 = tf.log(tf.constant(10.0)) 11 | max_val = tf.constant(max_val) 12 | 13 | self.max_val = 20 * tf.log(max_val) / self.base10 14 | 15 | def __call__(self, a, b): 16 | a = tf.cast(a, tf.float32) 17 | b = tf.cast(b, tf.float32) 18 | mse = tf.reduce_mean((a - b) ** 2) 19 | 20 | if mse == 0: # note: mse is a symbolic tensor, so this check never fires in graph mode 21 | return 0 22 | 23 | return self.max_val - 10 * tf.log(mse) / self.base10 24 | 25 | def mean_psnr(a, b, max_val=255.0): 26 | psnr_value = tf.image.psnr(a, b, max_val) 27 | return psnr_value, tf.reduce_mean(psnr_value) 28 | 29 | def mean_ssim(a, b, max_val=255.0): 30 | ssim_value = tf.image.ssim(a, b, max_val) 31 | return ssim_value, tf.reduce_mean(ssim_value) 32 | 33 | class Progbar(object): 34 | """Displays a progress bar. 35 | 36 | Arguments: 37 | target: Total number of steps expected, None if unknown. 38 | width: Progress bar width on screen. 39 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 40 | stateful_metrics: Iterable of string names of metrics that 41 | should *not* be averaged over time. Metrics in this list 42 | will be displayed as-is. All others will be averaged 43 | by the progbar before display. 44 | interval: Minimum visual progress update interval (in seconds).
45 | """ 46 | 47 | def __init__(self, target, width=25, verbose=1, interval=0.05, 48 | stateful_metrics=None): 49 | self.target = target 50 | self.width = width 51 | self.verbose = verbose 52 | self.interval = interval 53 | if stateful_metrics: 54 | self.stateful_metrics = set(stateful_metrics) 55 | else: 56 | self.stateful_metrics = set() 57 | 58 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 59 | sys.stdout.isatty()) or 60 | 'ipykernel' in sys.modules or 61 | 'posix' in sys.modules) 62 | self._total_width = 0 63 | self._seen_so_far = 0 64 | # We use a dict + list to avoid garbage collection 65 | # issues found in OrderedDict 66 | self._values = {} 67 | self._values_order = [] 68 | self._start = time.time() 69 | self._last_update = 0 70 | 71 | def update(self, current, values=None): 72 | """Updates the progress bar. 73 | 74 | Arguments: 75 | current: Index of current step. 76 | values: List of tuples: 77 | `(name, value_for_last_step)`. 78 | If `name` is in `stateful_metrics`, 79 | `value_for_last_step` will be displayed as-is. 80 | Else, an average of the metric over time will be displayed. 81 | """ 82 | values = values or [] 83 | for k, v in values: 84 | if k not in self._values_order: 85 | self._values_order.append(k) 86 | if k not in self.stateful_metrics: 87 | if k not in self._values: 88 | self._values[k] = [v * (current - self._seen_so_far), 89 | current - self._seen_so_far] 90 | else: 91 | self._values[k][0] += v * (current - self._seen_so_far) 92 | self._values[k][1] += (current - self._seen_so_far) 93 | else: 94 | self._values[k] = v 95 | self._seen_so_far = current 96 | 97 | now = time.time() 98 | info = ' - %.0fs' % (now - self._start) 99 | if self.verbose == 1: 100 | if (now - self._last_update < self.interval and 101 | self.target is not None and current < self.target): 102 | return 103 | 104 | prev_total_width = self._total_width 105 | if self._dynamic_display: 106 | sys.stdout.write('\b' * prev_total_width) 107 | sys.stdout.write('\r') 108 | else: 109 | sys.stdout.write('\n') 110 | 111 | if self.target is not None: 112 | numdigits = int(np.floor(np.log10(self.target))) + 1 113 | barstr = '%%%dd/%d [' % (numdigits, self.target) 114 | bar = barstr % current 115 | prog = float(current) / self.target 116 | prog_width = int(self.width * prog) 117 | if prog_width > 0: 118 | bar += ('=' * (prog_width - 1)) 119 | if current < self.target: 120 | bar += '>' 121 | else: 122 | bar += '=' 123 | bar += ('.' 
* (self.width - prog_width)) 124 | bar += ']' 125 | else: 126 | bar = '%7d/Unknown' % current 127 | 128 | self._total_width = len(bar) 129 | sys.stdout.write(bar) 130 | 131 | if current: 132 | time_per_unit = (now - self._start) / current 133 | else: 134 | time_per_unit = 0 135 | if self.target is not None and current < self.target: 136 | eta = time_per_unit * (self.target - current) 137 | if eta > 3600: 138 | eta_format = '%d:%02d:%02d' % (eta // 3600, 139 | (eta % 3600) // 60, 140 | eta % 60) 141 | elif eta > 60: 142 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 143 | else: 144 | eta_format = '%ds' % eta 145 | 146 | info = ' - ETA: %s' % eta_format 147 | else: 148 | if time_per_unit >= 1: 149 | info += ' %.0fs/step' % time_per_unit 150 | elif time_per_unit >= 1e-3: 151 | info += ' %.0fms/step' % (time_per_unit * 1e3) 152 | else: 153 | info += ' %.0fus/step' % (time_per_unit * 1e6) 154 | 155 | for k in self._values_order: 156 | info += ' - %s:' % k 157 | if isinstance(self._values[k], list): 158 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 159 | if abs(avg) > 1e-3: 160 | info += ' %.4f' % avg 161 | else: 162 | info += ' %.4e' % avg 163 | else: 164 | info += ' %s' % self._values[k] 165 | 166 | self._total_width += len(info) 167 | if prev_total_width > self._total_width: 168 | info += (' ' * (prev_total_width - self._total_width)) 169 | 170 | if self.target is not None and current >= self.target: 171 | info += '\n' 172 | 173 | sys.stdout.write(info) 174 | sys.stdout.flush() 175 | 176 | elif self.verbose == 2: 177 | if self.target is None or current >= self.target: 178 | for k in self._values_order: 179 | info += ' - %s:' % k 180 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 181 | if avg > 1e-3: 182 | info += ' %.4f' % avg 183 | else: 184 | info += ' %.4e' % avg 185 | info += '\n' 186 | 187 | sys.stdout.write(info) 188 | sys.stdout.flush() 189 | 190 | self._last_update = now 191 | 192 | def add(self, n, values=None): 193 | self.update(self._seen_so_far + n, values) 194 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from network import * 2 | from metrics import * 3 | from loss import * 4 | from tensorflow.contrib.framework.python.ops import arg_scope 5 | class TextRemoval(object): 6 | def __init__(self, config): 7 | self.config = config 8 | self.res_num = config.RES_NUM 9 | self.base_channel = config.BASE_CHANNEL 10 | self.sample_num = config.SAMPLE_NUM 11 | self.model_name = 'textremoval' 12 | self.rate=config.RATE 13 | 14 | self.gen_optimizer = tf.train.AdamOptimizer( 15 | learning_rate=float(config.LR), 16 | beta1=float(config.BETA1), 17 | beta2=float(config.BETA2) 18 | ) 19 | self.dis_optimizer = tf.train.AdamOptimizer( 20 | learning_rate=float(config.LR), 21 | beta1=float(config.BETA1), 22 | beta2=float(config.BETA2) 23 | ) 24 | 25 | def build_whole_model(self, batch_data,is_training=True): 26 | batch_predicted, batch_complete, batch_gt, gen_loss, dis_loss=self.textremoval_model(batch_data, training=is_training) 27 | outputs_merged = (batch_complete + 1) / 2 * 255 28 | gt = (batch_gt + 1) / 2 * 255 29 | _, psnr = mean_psnr(gt, outputs_merged) 30 | _, ssim = mean_ssim(gt, outputs_merged) 31 | tf.summary.scalar('train/psnr', psnr) 32 | tf.summary.scalar('train/ssim', ssim) 33 | tf.summary.scalar('train_loss/gen_loss', gen_loss) 34 | tf.summary.scalar('train_loss/dis_loss', dis_loss) 35 | return 
gen_loss,dis_loss,psnr,ssim 36 | 37 | # def build_validation_model(self, batch_data): 38 | def build_validation_model(self, batch_data, reuse=True, is_training=False): 39 | batch_batch = batch_data 40 | batch_width = int(batch_batch.get_shape().as_list()[2]/5) 41 | batch_gt = batch_batch[:, :, :batch_width,:] / 127.5 - 1. 42 | batch_img = batch_batch[:, :, batch_width:batch_width * 2,:] / 127.5 - 1. 43 | batch_mask = tf.cast(batch_batch[:, :, batch_width*2:batch_width * 3,0:1] > 127.5, tf.float32) 44 | batch_text = batch_mask 45 | # process outputs 46 | stroke_mask1, output1, stroke_mask2, output2 = self.generator( 47 | batch_img, batch_mask, reuse=reuse, training=is_training,name=self.model_name + '_generator', 48 | padding='SAME') 49 | batch_predicted = output2 50 | 51 | batch_complete = batch_predicted * batch_mask + batch_gt * (1.-batch_mask) 52 | 53 | _, psnr = mean_psnr((batch_gt+1.)*127.5, (batch_complete+1.)*127.5) 54 | _, ssim = mean_ssim((batch_gt+1.)*127.5, (batch_complete+1.)*127.5) 55 | return psnr,ssim 56 | 57 | # def build_optim(self, gen_loss, dis_loss): 58 | def build_optim(self, gen_loss, dis_loss): 59 | g_vars = tf.get_collection( 60 | tf.GraphKeys.TRAINABLE_VARIABLES, self.model_name + '_generator') 61 | d_vars = tf.get_collection( 62 | tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator') 63 | g_gradient = self.gen_optimizer.compute_gradients(gen_loss, var_list=g_vars) 64 | d_gradient = self.dis_optimizer.compute_gradients(dis_loss, var_list=d_vars) 65 | return self.gen_optimizer.apply_gradients(g_gradient), self.dis_optimizer.apply_gradients(d_gradient) 66 | 67 | 68 | #def inpaint_model(self,batch_data,is_training=False,reuse=False): 69 | def textremoval_model(self, batch_data, training=True,reuse=False): 70 | batch_batch = batch_data 71 | batch_width = int(batch_batch.get_shape().as_list()[2]/5) 72 | batch_gt = batch_batch[:, :, :batch_width,:] / 127.5 - 1. 73 | batch_img = batch_batch[:, :, batch_width:batch_width*2,:] / 127.5 - 1. # raw image with text 74 | batch_mask = tf.cast(batch_batch[:, :, batch_width*2:batch_width*3,0:1] > 127.5, tf.float32) # text region mask 75 | batch_text = tf.cast(batch_batch[:, :, batch_width*3:batch_width*4,0:1] > 127.5, tf.float32) # text stroke mask 76 | # process outputs x1, x2, s1, s2 77 | 78 | stroke_mask1, output1, stroke_mask2, output2 = self.generator( 79 | batch_img, batch_mask, reuse=reuse, training=training,name=self.model_name + '_generator', 80 | padding='SAME') 81 | batch_predicted = output2 82 | 83 | batch_complete = batch_predicted * batch_mask + batch_img * (1. - batch_mask) 84 | if training: 85 | losses = {} 86 | l1_alpha = 1.2 87 | losses['output1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_gt - output1))# 88 | losses['output1_loss'] += 10. * tf.reduce_mean(tf.abs(batch_gt - output1) * batch_mask) 89 | losses['output2_loss'] = tf.reduce_mean(tf.abs(batch_gt - output2))# 90 | losses['output2_loss'] += 10. * tf.reduce_mean(tf.abs(batch_gt - output2) * batch_mask) 91 | 92 | losses['stroke_mask1_loss'] = tf.reduce_mean(tf.abs(batch_text - stroke_mask1))# * (1.-bbox_mask)) 93 | losses['stroke_mask1_loss'] += 10. * tf.reduce_mean(tf.abs(batch_text - stroke_mask1) * batch_mask) 94 | losses['stroke_mask2_loss'] = tf.reduce_mean(tf.abs(batch_text - stroke_mask2))# * (1.-bbox_mask)) 95 | losses['stroke_mask2_loss'] += 10. 
* tf.reduce_mean(tf.abs(batch_text - stroke_mask2) * batch_mask) 96 | 97 | # run the discriminator separately on real and generated images 98 | batch_pos_feature = self.sngan_discriminator(batch_gt, training=training, reuse=reuse) 99 | batch_neg_feature = self.sngan_discriminator(batch_complete, training=training, reuse=tf.AUTO_REUSE) 100 | 101 | # hinge gan loss 102 | loss_discriminator, loss_generator = hinge_gan_loss(batch_pos_feature, batch_neg_feature) 103 | 104 | losses['g_loss'] = 0.001 * loss_generator 105 | losses['d_loss'] = loss_discriminator 106 | 107 | losses['g_loss'] = 0.001 * losses['g_loss'] # note: the adversarial term is scaled by 0.001 here and again above, i.e. an effective weight of 1e-6 108 | losses['g_loss'] += 1. * losses['output1_loss'] 109 | losses['g_loss'] += 5. * losses['output2_loss'] 110 | 111 | losses['g_loss'] += losses['stroke_mask1_loss'] 112 | losses['g_loss'] += losses['stroke_mask2_loss'] 113 | viz_img = [batch_gt, batch_img, tf.tile(stroke_mask1,[1,1,1,3]), tf.tile(stroke_mask2,[1,1,1,3]), output1, output2] 114 | 115 | images_summary( 116 | tf.concat(viz_img, axis=2), 117 | 'batchgt_batchimg_strokemask1_strokemask2_output1_output2', 10) 118 | 119 | return batch_predicted,batch_complete,batch_gt,losses['g_loss'],losses['d_loss'] 120 | else: 121 | return batch_predicted,batch_complete,batch_gt,batch_img,batch_mask,batch_text 122 | 123 | def generator(self, image, mask, reuse=False, 124 | training=True, padding='SAME', name='generator'): 125 | """Text-removal generator: two cascaded stages, each predicting a text-stroke mask and an erased image. 126 | 127 | Args: 128 | image: input image with text, [-1, 1] 129 | mask: text-region mask {0, 1} 130 | Returns: 131 | stroke masks and [-1, 1] predicted images from both stages 132 | """ 133 | 134 | image2 = image 135 | ones_image = tf.ones_like(image)[:, :, :, 0:1] 136 | image = tf.concat([image, ones_image, ones_image * mask], axis=3) 137 | # two stage network 138 | cnum = 32 139 | with tf.variable_scope(name, reuse=reuse), \ 140 | arg_scope([gen_conv, gen_deconv], 141 | training=training, padding=padding): 142 | # stage 1 stroke mask 143 | t1_conv1 = gen_conv(image, cnum//2, 3, 1, name='t1conv1') 144 | t1_conv2 = gen_conv(t1_conv1, cnum//2, 3, 1, name='t1conv2') 145 | t1_conv3 = gen_conv(t1_conv2, cnum, 3, 2, name='t1conv3_128') 146 | t1_conv4 = gen_conv(t1_conv3, cnum, 3, 1, name='t1conv4') 147 | t1_conv5 = gen_conv(t1_conv4, cnum, 3, 1, name='t1conv5') 148 | t1_conv6 = gen_conv(t1_conv5, 2*cnum, 3, 2, name='t1conv6_64') 149 | t1_conv7 = gen_conv(t1_conv6, 2*cnum, 3, 1, name='t1conv7') 150 | t1_conv8 = gen_conv(t1_conv7, 2*cnum, 3, 1, name='t1conv8') 151 | t1_conv9 = gen_conv(t1_conv8, 4*cnum, 3, 2, name='t1conv9_32') 152 | t1_conv10 = gen_conv(t1_conv9, 4*cnum, 3, 1, name='t1conv10') 153 | t1_conv11 = gen_deconv(t1_conv10, 2*cnum, name='t1conv11_64') 154 | t1_conv11 = tf.concat([t1_conv8, t1_conv11], axis=3) 155 | t1_conv12 = gen_conv(t1_conv11, 2*cnum, 3, 1, name='t1conv12') 156 | t1_conv13 = gen_conv(t1_conv12, 2*cnum, 3, 1, name='t1conv13') 157 | t1_conv14 = gen_conv(t1_conv13, 2*cnum, 3, 1, name='t1conv14') 158 | t1_conv15 = gen_deconv(t1_conv14, cnum, name='t1conv15_128') 159 | t1_conv15 = tf.concat([t1_conv5, t1_conv15], axis=3) 160 | t1_conv16 = gen_conv(t1_conv15, cnum, 3, 1, name='t1conv16') 161 | t1_conv17 = gen_conv(t1_conv16, cnum, 3, 1, name='t1conv17') 162 | t1_conv18 = gen_conv(t1_conv17, cnum, 3, 1, name='t1conv18') 163 | t1_conv19 = gen_deconv(t1_conv18, cnum//2, name='t1conv19_256') 164 | t1_conv19 = tf.concat([t1_conv2, t1_conv19], axis=3) 165 | t1_conv20 = gen_conv(t1_conv19, cnum//2, 3, 1, name='t1conv20') 166 | 167 | stroke_mask1 = gen_conv(t1_conv20, 1, 3, 1, name='stroke_mask1') 168 | 169 | # stage 1 output 170 | xnow = tf.concat([image2, ones_image,
ones_image * mask, stroke_mask1 * mask], axis=3) 171 | s1_conv1 = gen_conv(xnow, cnum, 5, 1, name='conv1') 172 | s1_conv2 = gen_conv(s1_conv1, 2*cnum, 3, 2, name='conv2_downsample') 173 | s1_conv3 = gen_conv(s1_conv2, 2*cnum, 3, 1, name='conv3') 174 | s1_conv4 = gen_conv(s1_conv3, 4*cnum, 3, 2, name='conv4_downsample') 175 | s1_conv5 = gen_conv(s1_conv4, 4*cnum, 3, 1, name='conv5') 176 | s1_conv6 = gen_conv(s1_conv5, 4*cnum, 3, 1, name='conv6') 177 | 178 | s1_conv7 = res_block(s1_conv6, name='s1res_block1') 179 | s1_conv8 = res_block(s1_conv7, name='s1res_block2') 180 | s1_conv9 = res_block(s1_conv8, name='s1res_block3') 181 | s1_conv10 = res_block(s1_conv9, name='s1res_block4') 182 | 183 | s1_conv11 = gen_conv(s1_conv10, 4*cnum, 3, 1, name='conv11') 184 | s1_conv11 = tf.concat([s1_conv6, s1_conv11], axis=3) 185 | s1_conv12 = gen_conv(s1_conv11, 4*cnum, 3, 1, name='conv12') 186 | s1_conv12 = tf.concat([s1_conv5, s1_conv12], axis=3) 187 | s1_conv13 = gen_deconv(s1_conv12, 2*cnum, name='conv13_upsample') 188 | s1_conv13 = tf.concat([s1_conv3, s1_conv13], axis=3) 189 | s1_conv14 = gen_conv(s1_conv13, 2*cnum, 3, 1, name='conv14') 190 | s1_conv14 = tf.concat([s1_conv2, s1_conv14], axis=3) 191 | s1_conv15 = gen_deconv(s1_conv14, cnum, name='conv15_upsample') 192 | s1_conv15 = tf.concat([s1_conv1, s1_conv15], axis=3) 193 | s1_conv16 = gen_conv(s1_conv15, cnum//2, 3, 1, name='conv16') 194 | s1_conv17 = gen_conv(s1_conv16, 3, 3, 1, activation=None, name='conv17') 195 | s1_conv = tf.clip_by_value(s1_conv17, -1., 1., name='stage1') 196 | output1 = s1_conv 197 | 198 | # stage 2 stroke mask 199 | sin = tf.concat([output1, ones_image, ones_image * mask, stroke_mask1 * mask], axis=3) 200 | t2_conv1 = gen_conv(sin, cnum//2, 3, 1, name='t2conv1') 201 | t2_conv2 = gen_conv(t2_conv1, cnum//2, 3, 1, name='t2conv2') 202 | t2_conv3 = gen_conv(t2_conv2, cnum, 3, 2, name='t2conv3_128') 203 | t2_conv4 = gen_conv(t2_conv3, cnum, 3, 1, name='t2conv4') 204 | t2_conv5 = gen_conv(t2_conv4, cnum, 3, 1, name='t2conv5') 205 | t2_conv6 = gen_conv(t2_conv5, 2*cnum, 3, 2, name='t2conv6_64') 206 | t2_conv7 = gen_conv(t2_conv6, 2*cnum, 3, 1, name='t2conv7') 207 | t2_conv8 = gen_conv(t2_conv7, 2*cnum, 3, 1, name='t2conv8') 208 | t2_conv9 = gen_conv(t2_conv8, 4*cnum, 3, 2, name='t2conv9_32') 209 | t2_conv10 = gen_conv(t2_conv9, 4*cnum, 3, 1, name='t2conv10') 210 | t2_conv11 = gen_deconv(t2_conv10, 2*cnum, name='t2conv11_64') 211 | t2_conv11 = tf.concat([t2_conv8, t2_conv11], axis=3) 212 | t2_conv12 = gen_conv(t2_conv11, 2*cnum, 3, 1, name='t2conv12') 213 | t2_conv13 = gen_conv(t2_conv12, 2*cnum, 3, 1, name='t2conv13') 214 | t2_conv14 = gen_conv(t2_conv13, 2*cnum, 3, 1, name='t2conv14') 215 | t2_conv15 = gen_deconv(t2_conv14, cnum, name='t2conv15_128') 216 | t2_conv15 = tf.concat([t2_conv5, t2_conv15], axis=3) 217 | t2_conv16 = gen_conv(t2_conv15, cnum, 3, 1, name='t2conv16') 218 | t2_conv17 = gen_conv(t2_conv16, cnum, 3, 1, name='t2conv17') 219 | t2_conv18 = gen_conv(t2_conv17, cnum, 3, 1, name='t2conv18') 220 | t2_conv19 = gen_deconv(t2_conv18, cnum//2, name='t2conv19_256') 221 | t2_conv19 = tf.concat([t2_conv2, t2_conv19], axis=3) 222 | t2_conv20 = gen_conv(t2_conv19, cnum//2, 3, 1, name='t2conv20') 223 | 224 | stroke_mask2 = gen_conv(t2_conv20, 1, 3, 1, name='stroke_mask2') 225 | 226 | # stage 2 output 227 | xnow = tf.concat([output1, ones_image, ones_image * mask, stroke_mask2 * mask], axis=3) 228 | s2c_conv1 = gen_conv(xnow, cnum, 5, 1, name='s2conv1') 229 | s2c_conv2 = gen_conv(s2c_conv1, cnum, 3, 2, 
name='s2conv2_downsample') 230 | s2c_conv3 = gen_conv(s2c_conv2, 2*cnum, 3, 1, name='s2conv3') 231 | s2c_conv4 = gen_conv(s2c_conv3, 2*cnum, 3, 2, name='s2conv4_downsample') 232 | s2c_conv5 = gen_conv(s2c_conv4, 4*cnum, 3, 1, name='s2conv5') 233 | s2c_conv6 = gen_conv(s2c_conv5, 4*cnum, 3, 1, name='s2conv6') 234 | 235 | s2c_conv7 = res_block(s2c_conv6, name='s2res_block1') 236 | s2c_conv8 = res_block(s2c_conv7, name='s2res_block2') 237 | s2c_conv9 = res_block(s2c_conv8, name='s2res_block3') 238 | s2c_conv10 = res_block(s2c_conv9, name='s2res_block4') 239 | 240 | s2_conv11 = gen_conv(s2c_conv10, 4*cnum, 3, 1, name='s2conv11') 241 | s2_conv11 = tf.concat([s2c_conv6, s2_conv11], axis=3) 242 | s2_conv12 = gen_conv(s2_conv11, 4*cnum, 3, 1, name='s2conv12') 243 | s2_conv12 = tf.concat([s2c_conv5, s2_conv12], axis=3) 244 | s2_conv13 = gen_deconv(s2_conv12, 2*cnum, name='s2conv13_upsample') 245 | s2_conv13 = tf.concat([s2c_conv3, s2_conv13], axis=3) 246 | s2_conv14 = gen_conv(s2_conv13, 2*cnum, 3, 1, name='s2conv14') 247 | s2_conv14 = tf.concat([s2c_conv2, s2_conv14], axis=3) 248 | s2_conv15 = gen_deconv(s2_conv14, cnum, name='s2conv15_upsample') 249 | s2_conv15 = tf.concat([s2c_conv1, s2_conv15], axis=3) 250 | s2_conv16 = gen_conv(s2_conv15, cnum//2, 3, 1, name='s2conv16') 251 | s2_conv17 = gen_conv(s2_conv16, 3, 3, 1, activation=None, name='s2conv17') 252 | output2 = tf.clip_by_value(s2_conv17, -1., 1., name='output') 253 | return stroke_mask1, output1, stroke_mask2, output2 254 | 255 | def sngan_discriminator(self, x, reuse=False, training=True): 256 | with tf.variable_scope('discriminator', reuse=reuse): 257 | cnum = 64 258 | x = sndis_conv(x, cnum, 5, 1, name='conv1', training=training) 259 | x = sndis_conv(x, cnum*2, 5, 2, name='conv2', training=training) 260 | x = sndis_conv(x, cnum*4, 5, 2, name='conv3', training=training) 261 | x = sndis_conv(x, cnum*4, 5, 2, name='conv4', training=training) 262 | x = sndis_conv(x, cnum*4, 5, 2, name='conv5', training=training) 263 | x = sndis_conv(x, cnum*4, 5, 2, name='conv6', training=training) 264 | return x 265 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import add_arg_scope 3 | # import tensorflow.contrib as tf_contrib 4 | # from tensorflow.contrib.framework.python.ops import add_arg_scope 5 | # Xavier : tf_contrib.layers.xavier_initializer() 6 | # He : tf_contrib.layers.variance_scaling_initializer() 7 | # Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02) 8 | # l2_decay : tf_contrib.layers.l2_regularizer(0.0001) 9 | 10 | # weight_init = tf.contrib.layers.xavier_initializer(uniform=False) 11 | # weight_regularizer = None 12 | 13 | ################################################################################## 14 | # Layer 15 | ################################################################################## 16 | def images_summary(images, name, max_outs, color_format='BGR'): 17 | """Summary images. 18 | 19 | **Note** that images should be scaled to [-1, 1] for 'RGB' or 'BGR', 20 | [0, 1] for 'GREY'. 
21 | 22 | :param images: images tensor (in NHWC format) 23 | :param name: name of the images summary 24 | :param max_outs: max_outputs for the images summary 25 | :param color_format: 'BGR', 'RGB' or 'GREY' 26 | :return: None 27 | """ 28 | with tf.variable_scope(name), tf.device('/cpu:0'): 29 | if color_format == 'BGR': 30 | img = tf.clip_by_value( 31 | (tf.reverse(images, [-1])+1.)*127.5, 0., 255.) 32 | elif color_format == 'RGB': 33 | img = tf.clip_by_value((images+1.)*127.5, 0, 255) 34 | elif color_format == 'GREY': 35 | img = tf.clip_by_value(images*255., 0, 255) 36 | else: 37 | raise NotImplementedError("color format is not supported.") 38 | tf.summary.image(name, img, max_outputs=max_outs) 39 | @add_arg_scope 40 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv', 41 | padding='SAME', activation=tf.nn.leaky_relu, training=True): 42 | """Define conv for generator. 43 | 44 | Args: 45 | x: Input. 46 | cnum: Channel number. 47 | ksize: Kernel size. 48 | stride: Convolution stride. 49 | rate: Dilation rate for dilated conv. 50 | name: Name of layers. 51 | padding: Padding mode; defaults to 'SAME'. 52 | activation: Activation function after convolution. 53 | training: If current graph is for training or inference, used for bn. 54 | 55 | Returns: 56 | tf.Tensor: output 57 | 58 | """ 59 | assert padding in ['SYMMETRIC', 'SAME', 'REFLECT'] 60 | if padding == 'SYMMETRIC' or padding == 'REFLECT': 61 | p = int(rate*(ksize-1)/2) 62 | x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding) 63 | padding = 'VALID' 64 | x = tf.layers.conv2d( 65 | x, cnum, ksize, stride, dilation_rate=rate, 66 | activation=activation, padding=padding, name=name) 67 | return x 68 | @add_arg_scope 69 | def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True): 70 | """Define deconv for generator. 71 | The deconv is defined as a 2x resize_nearest_neighbor operation followed by 72 | an additional gen_conv operation. 73 | 74 | Args: 75 | x: Input. 76 | cnum: Channel number. 77 | name: Name of layers. 78 | training: If current graph is for training or inference, used for bn. 79 | 80 | Returns: 81 | tf.Tensor: output 82 | """ 83 | with tf.variable_scope(name): 84 | x = resize(x, func=tf.image.resize_nearest_neighbor) 85 | x = gen_conv(x, cnum, 3, 1, name=name+'_conv', padding=padding) 86 | return x 87 | 88 | def spectral_norm(w, iteration=1): 89 | w_shape = w.shape.as_list() 90 | w = tf.reshape(w, [-1, w_shape[-1]]) 91 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 92 | u_hat = u 93 | v_hat = None 94 | for i in range(iteration): 95 | """ 96 | power iteration 97 | Usually iteration = 1 will be enough 98 | """ 99 | v_ = tf.matmul(u_hat, tf.transpose(w)) 100 | v_hat = tf.nn.l2_normalize(v_) 101 | u_ = tf.matmul(v_hat, w) 102 | u_hat = tf.nn.l2_normalize(u_) 103 | u_hat = tf.stop_gradient(u_hat) 104 | v_hat = tf.stop_gradient(v_hat) 105 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 106 | with tf.control_dependencies([u.assign(u_hat)]): 107 | w_norm = w / sigma 108 | w_norm = tf.reshape(w_norm, w_shape) 109 | return w_norm 110 | 111 | 112 | def sndis_conv(x, cnum, ksize=5, stride=2, padding='SAME', name='conv', training=True): 113 | """Define conv for sn-patch discriminator. 114 | Activation is set to leaky_relu. 115 | Args: 116 | x: Input. 117 | cnum: Channel number. 118 | ksize: Kernel size. 119 | stride: Convolution stride. 120 | name: Name of layers. 121 | training: If current graph is for training or inference, used for bn.
122 | Returns: 123 | tf.Tensor: output 124 | """ 125 | with tf.variable_scope(name): 126 | in_channel = x.get_shape().as_list()[-1] 127 | kernel = tf.get_variable('kernel', [ksize, ksize, in_channel, cnum], 128 | initializer=tf.variance_scaling_initializer(), trainable=training) 129 | x = tf.nn.conv2d(x, spectral_norm(kernel), strides=[1, stride, stride, 1], 130 | padding=padding, name=name) 131 | x = tf.nn.leaky_relu(x) 132 | return x 133 | 134 | 135 | def res_block(x, activation = tf.nn.leaky_relu, padding = 'SAME', name = 'res_block'): 136 | cnum = x.get_shape().as_list()[-1] 137 | xin = x # keep the input for the residual skip connection 138 | x = tf.layers.conv2d(x, cnum // 4, kernel_size = 1, strides = 1, activation = activation, padding = padding, name = name + '_conv1') 139 | x = tf.layers.conv2d(x, cnum // 4, kernel_size = 3, strides = 1, activation = activation, padding = padding, name = name + '_conv2') 140 | x = tf.layers.conv2d(x, cnum, kernel_size = 1, strides = 1, activation = None, padding = padding, name = name + '_conv3') 141 | x = tf.add(xin, x, name = name + '_add') 142 | x = tf.layers.batch_normalization(x, name = name + '_bn') 143 | x = activation(x, name = name + '_out') 144 | return x 145 | 146 | def hinge_gan_loss(discriminator_data, discriminator_z): 147 | loss_discriminator_data = tf.reduce_mean(tf.nn.relu(1 - discriminator_data)) 148 | loss_discriminator_z = tf.reduce_mean(tf.nn.relu(1 + discriminator_z)) 149 | loss_discriminator = (loss_discriminator_data + loss_discriminator_z) 150 | loss_generator_adversarial = -tf.reduce_mean(discriminator_z) 151 | return loss_discriminator, loss_generator_adversarial 152 | 153 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False, 154 | func=tf.image.resize_bilinear, name='resize'): 155 | if dynamic: 156 | xs = tf.cast(tf.shape(x), tf.float32) 157 | new_xs = [tf.cast(xs[1]*scale, tf.int32), 158 | tf.cast(xs[2]*scale, tf.int32)] 159 | else: 160 | xs = x.get_shape().as_list() 161 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)] 162 | with tf.variable_scope(name): 163 | if to_shape is None: 164 | x = func(x, new_xs, align_corners=align_corners) 165 | else: 166 | x = func(x, [to_shape[0], to_shape[1]], 167 | align_corners=align_corners) 168 | return x 169 | 170 |
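A minimal shape check for the layers above (assumes TF 1.x, as used by the repo): `gen_conv` with stride 2 halves the spatial size, and `gen_deconv` is a nearest-neighbor 2x resize followed by a 3x3 `gen_conv`.

```python
import tensorflow as tf
from network import gen_conv, gen_deconv

x = tf.zeros([1, 64, 64, 32])
y = gen_conv(x, 64, 3, stride=2, name='down')  # -> (1, 32, 32, 64)
z = gen_deconv(y, 32, name='up')               # -> (1, 64, 64, 32)
print(y.get_shape().as_list(), z.get_shape().as_list())
```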
161 | 
162 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False,
163 |            func=tf.image.resize_bilinear, name='resize'):
164 |     if dynamic:
165 |         # Shape known only at run time: compute the new size with tensor ops.
166 |         xs = tf.cast(tf.shape(x), tf.float32)
167 |         new_xs = [tf.cast(xs[1] * scale, tf.int32),
168 |                   tf.cast(xs[2] * scale, tf.int32)]
169 |     else:
170 |         xs = x.get_shape().as_list()
171 |         new_xs = [int(xs[1] * scale), int(xs[2] * scale)]
172 |     with tf.variable_scope(name):
173 |         if to_shape is None:
174 |             x = func(x, new_xs, align_corners=align_corners)
175 |         else:
176 |             x = func(x, [to_shape[0], to_shape[1]],
177 |                      align_corners=align_corners)
178 |     return x
179 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import tensorflow as tf
4 | import os
5 | import cv2
6 | import glob
7 | from model import TextRemoval
8 | from config import Config
9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--image', default='./examples/images', type=str,
12 |                     help='The directory of images to be completed.')
13 | parser.add_argument('--mask', default='./examples/masks', type=str,
14 |                     help='The directory of masks; value 255 indicates masked pixels.')
15 | parser.add_argument('--output', default='./examples/results/', type=str,
16 |                     help='The directory where output images are saved.')
17 | parser.add_argument('--checkpoint_dir', default='./model_logs/text_output_0308', type=str,
18 |                     help='The directory of the tensorflow checkpoint.')
19 | 
20 | def data_batch(gt_list, img_list, mask_list):
21 |     test_dataset = tf.data.Dataset.from_tensor_slices((gt_list, img_list, mask_list))
22 |     input_size = 256
23 |     def image_fn(gt_path, img_path, mask_path):
24 |         x = tf.read_file(gt_path)
25 |         x_decode = tf.image.decode_jpeg(x, channels=3)
26 |         gt = tf.image.resize_images(x_decode, [input_size, input_size])
27 |         gt = tf.cast(gt, tf.float32)
28 | 
29 |         x = tf.read_file(img_path)
30 |         x_decode = tf.image.decode_jpeg(x, channels=3)
31 |         img = tf.image.resize_images(x_decode, [input_size, input_size])
32 |         img = tf.cast(img, tf.float32)
33 | 
34 |         x = tf.read_file(mask_path)
35 |         x_decode = tf.image.decode_jpeg(x, channels=1)
36 |         mask = tf.image.resize_images(x_decode, [input_size, input_size])
37 |         mask = tf.cast(mask, tf.float32)
38 |         return gt, img, mask
39 | 
40 |     test_dataset = test_dataset. \
41 |         repeat(1). \
42 |         map(image_fn). \
43 |         apply(tf.contrib.data.batch_and_drop_remainder(1)). \
44 |         prefetch(1)
45 | 
46 |     test_gt, test_image, test_mask = test_dataset.make_one_shot_iterator().get_next()
47 |     return test_gt, test_image, test_mask
48 | 
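data_batch is TF1 code: tf.contrib.data.batch_and_drop_remainder and make_one_shot_iterator were removed in TensorFlow 2. For readers porting the script, a hedged TF2 sketch of the same pipeline, assuming the same three path lists, might look like:

    import tensorflow as tf  # TF 2.x

    def image_fn(gt_path, img_path, mask_path, input_size=256):
        def load(path, channels):
            x = tf.io.decode_jpeg(tf.io.read_file(path), channels=channels)
            x = tf.image.resize(x, [input_size, input_size])
            return tf.cast(x, tf.float32)
        return load(gt_path, 3), load(img_path, 3), load(mask_path, 1)

    def data_batch(gt_list, img_list, mask_list):
        ds = tf.data.Dataset.from_tensor_slices((gt_list, img_list, mask_list))
        ds = ds.map(image_fn).batch(1, drop_remainder=True).prefetch(1)
        return ds  # iterate with: for gt, img, mask in ds: ...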
49 | 
50 | if __name__ == "__main__":
51 |     # ng.get_gpus(1)
52 |     os.environ['CUDA_VISIBLE_DEVICES'] = '1'
53 |     args = parser.parse_args()
54 | 
55 |     config_path = os.path.join('config.yml')
56 |     config = Config(config_path)
57 |     model = TextRemoval(config)
58 | 
59 |     # dataset_name = 2000
60 |     # path_img = '/data/cqwang/64_backup/cqwang/dataset/IJCIA_dataset/test/test_{}_256/text/'.format(str(dataset_name))
61 |     # path_mask = '/data/cqwang/64_backup/cqwang/dataset/IJCIA_dataset/test/test_{}_256/mask/'.format(str(dataset_name))
62 |     path_img = args.image
63 |     path_mask = args.mask
64 | 
65 |     list_img = list(glob.glob(path_img + '/*.jpg')) + list(glob.glob(path_img + '/*.png'))
66 |     list_img.sort()
67 |     list_mask = list(glob.glob(path_mask + '/*.jpg')) + list(glob.glob(path_mask + '/*.png'))
68 |     list_mask.sort()
69 | 
70 |     # No ground truth is available at test time, so the image list is passed twice.
71 |     gt, images, masks = data_batch(list_img, list_img, list_mask)
72 | 
73 |     # Normalize images to [-1, 1] and masks to [0, 1].
74 |     images = (images / 255 - 0.5) / 0.5
75 |     masks = masks / 255
76 | 
77 |     # White out the masked region of the input.
78 |     images_masked = (images * (1 - masks)) + masks
79 |     # input of the model (unused below; the generator takes images and masks directly)
80 |     inputs = tf.concat([images_masked, masks], axis=3)
81 | 
82 |     # process outputs
83 |     stroke_mask1, output1, stroke_mask2, output2 = model.generator(
84 |         images, masks, reuse=False, training=False, name='textremoval_generator',
85 |         padding='SAME')
86 |     output = output2
87 | 
88 |     # Keep the prediction inside the mask and the original pixels elsewhere.
89 |     outputs_merged = (output * masks) + (images * (1 - masks))
90 |     # Map everything back to [0, 255].
91 |     images = (images + 1) / 2 * 255
92 |     images_masked = (images_masked + 1) / 2 * 255
93 |     outputs = (output + 1) / 2 * 255
94 |     masks = masks * 255
95 |     outputs_merged = (outputs_merged + 1) / 2 * 255
96 | 
97 |     sess_config = tf.ConfigProto()
98 |     sess_config.gpu_options.allow_growth = True
99 |     with tf.Session(config=sess_config) as sess:
100 |         # load pretrained model
101 |         vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
102 |         assign_ops = []
103 |         for var in vars_list:
104 |             vname = var.name
105 |             from_name = vname
106 |             var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, from_name)
107 |             assign_ops.append(tf.assign(var, var_value))
108 | 
109 |         sess.run(assign_ops)
110 | 
111 |         # res_path="./result_{}/".format(str(dataset_name))
112 |         res_path = args.output
113 |         if os.path.exists(res_path):
114 |             print("res_path already exists")
115 |         else:
116 |             os.makedirs(res_path)
117 | 
118 |         for num in range(0, len(list_img)):
119 |             outputs_merge, mas, out, img = sess.run([outputs_merged, masks, outputs, images])
120 |             # RGB float -> BGR uint8 for cv2.imwrite.
121 |             outputs_merge = outputs_merge[0][:, :, ::-1].astype(np.uint8)
122 | 
123 |             picname = list_img[num].split('/')[-1]
124 |             cv2.imwrite(res_path + picname, outputs_merge)
125 |             print(res_path + picname)
126 | 
--------------------------------------------------------------------------------
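The arithmetic in test.py maps uint8 pixels into the generator's [-1, 1] input range, whites out the masked region, and after inference keeps the prediction only inside the mask. A standalone NumPy sketch of that round trip (the shapes and the zero "prediction" are illustrative):

    import numpy as np

    img = np.random.randint(0, 256, (256, 256, 3)).astype(np.float32)
    mask = np.zeros((256, 256, 1), np.float32)
    mask[64:128, 64:128] = 255.

    x = (img / 255 - 0.5) / 0.5           # [0, 255] -> [-1, 1]
    m = mask / 255                        # 255 -> 1 inside the mask
    x_masked = x * (1 - m) + m            # masked pixels become +1 (white)
    pred = np.zeros_like(x)               # stand-in for the generator output
    merged = pred * m + x * (1 - m)       # prediction only inside the mask
    out = np.rint((merged + 1) / 2 * 255).astype(np.uint8)
    # Outside the mask, the round trip returns the original pixels:
    print(np.abs(out[m[..., 0] == 0] - img[m[..., 0] == 0]).max())  # ~0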
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import tensorflow as tf
5 | from untils import Progbar
6 | from config import Config
7 | from data_load import Dataset
8 | from model import TextRemoval
9 | 
10 | def main():
11 |     config_path = os.path.join('config.yml')
12 |     config = Config(config_path)
13 |     config.print()
14 |     # Init cuda environment
15 |     os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)
16 | 
17 |     # Fix random seeds to make results reproducible
18 |     tf.set_random_seed(config.SEED)
19 |     np.random.seed(config.SEED)
20 |     random.seed(config.SEED)
21 | 
22 |     # Init training data
23 |     dataset = Dataset(config)
24 |     batch_concat = dataset.batch_concat
25 |     val_concat = dataset.val_concat
26 | 
27 |     # Init the model
28 |     model = TextRemoval(config)
29 | 
30 |     gen_loss, dis_loss, t_psnr, t_ssim = model.build_whole_model(batch_concat)
31 |     gen_optim, dis_optim = model.build_optim(gen_loss, dis_loss)
32 | 
33 |     val_psnr, val_ssim = model.build_validation_model(val_concat)
34 | 
35 |     # Create the graph
36 |     config_graph = tf.ConfigProto()
37 |     config_graph.gpu_options.allow_growth = True
38 | 
39 |     with tf.Session(config=config_graph) as sess:
40 |         # Merge all the summaries
41 |         merged = tf.summary.merge_all()
42 | 
43 |         train_writer = tf.summary.FileWriter(config.CHECKPOINTS + 'train', sess.graph)
44 |         eval_writer = tf.summary.FileWriter(config.CHECKPOINTS + 'eval')
45 |         saver = tf.train.Saver()
46 | 
47 |         sess.run(tf.global_variables_initializer())
48 |         sess.run(tf.local_variables_initializer())
49 | 
50 |         # Restore training state if a checkpoint exists
51 |         checkpoint = tf.train.get_checkpoint_state(config.LOAD_MODEL)
52 |         if (checkpoint and checkpoint.model_checkpoint_path):
53 |             print(checkpoint.model_checkpoint_path)
54 |             meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
55 |             restore = tf.train.import_meta_graph(meta_graph_path)
56 |             restore.restore(sess, tf.train.latest_checkpoint(config.LOAD_MODEL))
57 |             # Checkpoints are saved as "<prefix>-<epoch>"; recover epoch and step from the name.
58 |             epoch = int(meta_graph_path.split("-")[-1].split(".")[0])
59 |             step = int(epoch * dataset.len_train / dataset.batch_size)
60 |         else:
61 |             step = 0
62 |             epoch = 0
63 | 
64 |         # Set up the progress bar
65 |         progbar = Progbar(dataset.len_train // dataset.batch_size, width=20, stateful_metrics=['epoch', 'iter', 'gen_loss', 'dis_loss', 'psnr', 'ssim'])
66 |         tmp_epoch = epoch
67 |         while epoch < config.EPOCH:
68 |             step += 1
69 |             epoch = int(step * dataset.batch_size / dataset.len_train)
70 |             if (tmp_epoch < epoch):
71 |                 tmp_epoch = epoch
72 |                 # Start a fresh progress bar for the new epoch
73 |                 progbar = Progbar(dataset.len_train // dataset.batch_size, width=20, stateful_metrics=['epoch', 'iter', 'gen_loss', 'dis_loss', 'psnr', 'ssim'])
74 | 
75 |             g_loss, _ = sess.run([gen_loss, gen_optim])
76 |             d_loss, _ = sess.run([dis_loss, dis_optim])
77 |             tr_psnr, tr_ssim = sess.run([t_psnr, t_ssim])
78 |             logs = [
79 |                 ("epoch", epoch),
80 |                 ("iter", step),
81 |                 ("g_loss", g_loss),
82 |                 ("d_loss", d_loss),
83 |                 ("psnr", tr_psnr),
84 |                 ("ssim", tr_ssim)
85 |             ]
86 |             progbar.add(1, values=logs)
87 | 
88 |             if step % config.SUMMARY_INTERVAL == 0:
89 |                 # Run validation
90 |                 v_psnr = []
91 |                 v_ssim = []
92 |                 for i in range(dataset.len_val // dataset.val_batch_size):
93 |                     val_psnr_tmp, val_ssim_tmp = sess.run([val_psnr, val_ssim])
94 |                     v_psnr.append(val_psnr_tmp)
95 |                     v_ssim.append(val_ssim_tmp)
96 |                 eval_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='val_psnr', simple_value=np.mean(v_psnr))]), epoch)
97 |                 eval_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='val_ssim', simple_value=np.mean(v_ssim))]), epoch)
98 | 
99 |                 # Train summary
100 |                 summary = sess.run(merged)
101 |                 train_writer.add_summary(summary, epoch)
102 |             if step % config.SAVE_INTERVAL == 0:
103 |                 # When resuming, the meta graph already exists, so skip rewriting it.
104 |                 if (checkpoint and checkpoint.model_checkpoint_path):
105 |                     saver.save(sess, config.CHECKPOINTS + 'textremoval', global_step=epoch, write_meta_graph=False)
106 |                 else:
107 |                     saver.save(sess, config.CHECKPOINTS + 'textremoval', global_step=epoch, write_meta_graph=True)
108 |         sess.close()
109 | 
110 | if __name__ == "__main__":
111 |     main()
112 | 
--------------------------------------------------------------------------------
/untils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import numpy as np
4 | import tensorflow as tf
5 | 
6 | def mean_psnr(a, b, max_val=255.0):
7 |     psnr_value = tf.image.psnr(a, b, max_val)
8 |     return psnr_value, tf.reduce_mean(psnr_value)
9 | 
10 | def mean_ssim(a, b, max_val=255.0):
11 |     ssim_value = tf.image.ssim(a, b, max_val)
12 |     return ssim_value, tf.reduce_mean(ssim_value)
13 | 
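mean_psnr and mean_ssim return the per-image metric tensor together with its batch mean; max_val must match the dynamic range of the inputs (255.0 for [0, 255] tensors, 1.0 for [0, 1]). A small TF1-style usage sketch with made-up tensors:

    import numpy as np
    import tensorflow as tf
    from untils import mean_psnr, mean_ssim

    a = tf.placeholder(tf.float32, [None, 256, 256, 3])
    b = tf.placeholder(tf.float32, [None, 256, 256, 3])
    _, psnr = mean_psnr(a, b)   # scalar mean over the batch
    _, ssim = mean_ssim(a, b)
    with tf.Session() as sess:
        x = np.random.uniform(0, 255, (2, 256, 256, 3)).astype(np.float32)
        # Identical inputs give psnr = inf and ssim = 1.0.
        print(sess.run([psnr, ssim], feed_dict={a: x, b: x}))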
14 | class Progbar(object):
15 |     """Displays a progress bar.
16 | 
17 |     Arguments:
18 |         target: Total number of steps expected, None if unknown.
19 |         width: Progress bar width on screen.
20 |         verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
21 |         stateful_metrics: Iterable of string names of metrics that
22 |             should *not* be averaged over time. Metrics in this list
23 |             will be displayed as-is. All others will be averaged
24 |             by the progbar before display.
25 |         interval: Minimum visual progress update interval (in seconds).
26 |     """
27 | 
28 |     def __init__(self, target, width=25, verbose=1, interval=0.05,
29 |                  stateful_metrics=None):
30 |         self.target = target
31 |         self.width = width
32 |         self.verbose = verbose
33 |         self.interval = interval
34 |         if stateful_metrics:
35 |             self.stateful_metrics = set(stateful_metrics)
36 |         else:
37 |             self.stateful_metrics = set()
38 | 
39 |         self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
40 |                                   sys.stdout.isatty()) or
41 |                                  'ipykernel' in sys.modules or
42 |                                  'posix' in sys.modules)
43 |         self._total_width = 0
44 |         self._seen_so_far = 0
45 |         # We use a dict + list to avoid garbage collection
46 |         # issues found in OrderedDict
47 |         self._values = {}
48 |         self._values_order = []
49 |         self._start = time.time()
50 |         self._last_update = 0
51 | 
52 |     def update(self, current, values=None):
53 |         """Updates the progress bar.
54 | 
55 |         Arguments:
56 |             current: Index of current step.
57 |             values: List of tuples:
58 |                 `(name, value_for_last_step)`.
59 |                 If `name` is in `stateful_metrics`,
60 |                 `value_for_last_step` will be displayed as-is.
61 |                 Else, an average of the metric over time will be displayed.
62 |         """
63 |         values = values or []
64 |         for k, v in values:
65 |             if k not in self._values_order:
66 |                 self._values_order.append(k)
67 |             if k not in self.stateful_metrics:
68 |                 if k not in self._values:
69 |                     self._values[k] = [v * (current - self._seen_so_far),
70 |                                        current - self._seen_so_far]
71 |                 else:
72 |                     self._values[k][0] += v * (current - self._seen_so_far)
73 |                     self._values[k][1] += (current - self._seen_so_far)
74 |             else:
75 |                 self._values[k] = v
76 |         self._seen_so_far = current
77 | 
78 |         now = time.time()
79 |         info = ' - %.0fs' % (now - self._start)
80 |         if self.verbose == 1:
81 |             if (now - self._last_update < self.interval and
82 |                     self.target is not None and current < self.target):
83 |                 return
84 | 
85 |             prev_total_width = self._total_width
86 |             if self._dynamic_display:
87 |                 sys.stdout.write('\b' * prev_total_width)
88 |                 sys.stdout.write('\r')
89 |             else:
90 |                 sys.stdout.write('\n')
91 | 
92 |             if self.target is not None:
93 |                 numdigits = int(np.floor(np.log10(self.target))) + 1
94 |                 barstr = '%%%dd/%d [' % (numdigits, self.target)
95 |                 bar = barstr % current
96 |                 prog = float(current) / self.target
97 |                 prog_width = int(self.width * prog)
98 |                 if prog_width > 0:
99 |                     bar += ('=' * (prog_width - 1))
100 |                     if current < self.target:
101 |                         bar += '>'
102 |                     else:
103 |                         bar += '='
104 |                 bar += ('.' * (self.width - prog_width))
105 |                 bar += ']'
106 |             else:
107 |                 bar = '%7d/Unknown' % current
108 | 
109 |             self._total_width = len(bar)
110 |             sys.stdout.write(bar)
111 | 
112 |             if current:
113 |                 time_per_unit = (now - self._start) / current
114 |             else:
115 |                 time_per_unit = 0
116 |             if self.target is not None and current < self.target:
117 |                 eta = time_per_unit * (self.target - current)
118 |                 if eta > 3600:
119 |                     eta_format = '%d:%02d:%02d' % (eta // 3600,
120 |                                                    (eta % 3600) // 60,
121 |                                                    eta % 60)
122 |                 elif eta > 60:
123 |                     eta_format = '%d:%02d' % (eta // 60, eta % 60)
124 |                 else:
125 |                     eta_format = '%ds' % eta
126 | 
127 |                 info = ' - ETA: %s' % eta_format
128 |             else:
129 |                 if time_per_unit >= 1:
130 |                     info += ' %.0fs/step' % time_per_unit
131 |                 elif time_per_unit >= 1e-3:
132 |                     info += ' %.0fms/step' % (time_per_unit * 1e3)
133 |                 else:
134 |                     info += ' %.0fus/step' % (time_per_unit * 1e6)
135 | 
136 |             for k in self._values_order:
137 |                 info += ' - %s:' % k
138 |                 if isinstance(self._values[k], list):
139 |                     avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
140 |                     if abs(avg) > 1e-3:
141 |                         info += ' %.4f' % avg
142 |                     else:
143 |                         info += ' %.4e' % avg
144 |                 else:
145 |                     info += ' %s' % self._values[k]
146 | 
147 |             self._total_width += len(info)
148 |             if prev_total_width > self._total_width:
149 |                 info += (' ' * (prev_total_width - self._total_width))
150 | 
151 |             if self.target is not None and current >= self.target:
152 |                 info += '\n'
153 | 
154 |             sys.stdout.write(info)
155 |             sys.stdout.flush()
156 | 
157 |         elif self.verbose == 2:
158 |             if self.target is None or current >= self.target:
159 |                 for k in self._values_order:
160 |                     info += ' - %s:' % k
161 |                     avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
162 |                     if abs(avg) > 1e-3:
163 |                         info += ' %.4f' % avg
164 |                     else:
165 |                         info += ' %.4e' % avg
166 |                 info += '\n'
167 | 
168 |                 sys.stdout.write(info)
169 |                 sys.stdout.flush()
170 | 
171 |         self._last_update = now
172 | 
173 |     def add(self, n, values=None):
174 |         self.update(self._seen_so_far + n, values)
175 | 
--------------------------------------------------------------------------------
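For completeness, a minimal usage sketch of the Progbar class above, mirroring how train.py drives it (the metric names and values are illustrative):

    import time
    from untils import Progbar

    progbar = Progbar(100, width=20, stateful_metrics=['epoch'])
    for step in range(100):
        time.sleep(0.01)  # stand-in for one training step
        progbar.add(1, values=[('epoch', 1), ('g_loss', 0.5 / (step + 1))])

Stateful metrics ('epoch' here) are printed as-is, while the others are shown as running averages over the steps seen so far.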