├── Application_Form_for_Using_Scene_Text_Removal.docx
├── README.md
├── example
│   ├── mask_512
│   │   └── img_1.png
│   ├── train_128
│   │   └── img_1.png
│   ├── train_256
│   │   └── img_1.png
│   └── train_512
│       └── img_1.png
├── imagepool.py
├── model.py
├── network.py
├── test.py
├── text2.py
├── train.py
└── vis_dataset.py
/Application_Form_for_Using_Scene_Text_Removal.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HCIILAB/Scene-Text-Removal/9c65c939a5d9ad3544f2db8e80a2138fee1dafc3/Application_Form_for_Using_Scene_Text_Removal.docx
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EnsNet: Ensconce Text in the Wild
2 | 
3 | A synthetic benchmark database for scene text removal is now released by the Deep Learning and Vision Computing Lab of South China University of Technology. The database can be downloaded through the following links:
4 | * Yunpan: (link: https://pan.baidu.com/s/1wwBwgm-n2A7iykoD0i37iQ PASSWORD: vk8f) (Size = 6.3G).
5 | * Google Drive: (link: https://drive.google.com/open?id=1l_yJm1vWV7TF7vDcaVa7FqZLfW7ASYeo) (Size = 6.3G).
6 | 
7 | In addition, to enlarge the real data, we collected 1000 images from the ICDAR 2017 MLT sub-dataset that contain only English text; the background images (labels) were generated by manually erasing the text. The database can be downloaded through the following links:
8 | * Yunpan: (link: https://pan.baidu.com/s/1WBvB1kS1BcmgrDi9c1Me9Q PASSWORD: knr7).
9 | * Google Drive: (link: https://drive.google.com/file/d/1G0d6yQwYEDhJdZH-S8mYWTXltJRG3Mg1/view?usp=sharing).
10 | Note: The real scene text removal dataset can only be used for non-commercial research purposes. Scholars or organizations who want to use the database should first fill in this [Application Form](https://github.com/HCIILAB/Scene-Text-Removal/blob/master/Application_Form_for_Using_Scene_Text_Removal.docx) and send it to us by email (lianwen.jin@gmail.com). We will provide the decompression password after your application has been received and approved.
11 | ## Description
12 | 
13 | The training set of the synthetic database consists of a total of 8000 images and the test set contains 800 images; all training and test samples are resized to 512 × 512. The method for generating the synthetic dataset, and more synthetic text images, is described in "Ankush Gupta, Andrea Vedaldi, Andrew Zisserman, Synthetic Data for Text Localisation in Natural Images, CVPR 2016", and the code can be found at https://github.com/ankush-me/SynthText.
14 | All the real scene text images are also resized to 512 × 512.
15 | 
16 | For more details, please refer to our [AAAI 2019 paper](https://www.aaai.org/ojs/index.php/AAAI/article/view/3859/3737). arXiv: http://arxiv.org/abs/1812.00723
17 | 
18 | ## Requirements
19 | 1. MXNet==1.3.1
20 | 2. Python 2.
21 | 3. NVIDIA GPU + CUDA 8.0.
22 | 4. Matplotlib.
23 | 5. NumPy.
24 | 
25 | ## Installation
26 | 1. Clone this repository.
27 | ```
28 | git clone https://github.com/HCIILAB/Scene-Text-Removal
29 | ```
30 | ## Running
31 | ### 1. Image Preparation
32 | You can refer to the provided `example` directory for how to organize your data.
33 | ### 2. Training
34 | To train our model, you may need to change the dataset path or the network parameters, etc. A minimal sketch of how the example data is loaded is shown below.
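A minimal loading sketch using the dataset class from `text2.py`; this is a sketch rather than one of the original scripts, and the `./example` path and the unpacking order follow `text2.py` and `train.py`:
```
from mxnet.gluon.data import DataLoader
from text2 import MyDataSet

# MyDataSet(root, split) reads <root>/<split>/{train_512, train_256, train_128, mask_512}
my_train = MyDataSet('./example', '')
loader = DataLoader(my_train, batch_size=1, shuffle=True, last_batch='rollover')
data, label, mask, data_256, data_128 = next(iter(loader))  # same unpacking as train.py
print(data.shape, mask.shape)  # (1, 3, 512, 512) (1, 1, 512, 512)
```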
Then run the following code:
35 | ```
36 | python train.py \
37 | --trainset_path=[the path of the dataset] \
38 | --checkpoint=[path to save the model] \
39 | --gpu=[use gpu] \
40 | --lr=[learning rate] \
41 | --n_epoch=[number of epochs]
42 | ```
43 | ### 3. Testing
44 | To output the generated results for the inputs, you can use [test.py](https://github.com/HCIILAB/Scene-Text-Removal/blob/master/test.py). Please run the following code:
45 | ```
46 | python test.py \
47 | --test_image=[the path of the test images] \
48 | --model=[the model to test] \
49 | --vis=[visualize the results] \
50 | --result=[path to save the output images]
51 | ```
52 | To evaluate the model performance over a dataset, you can find the evaluation metrics at this website: [PythonCode.zip](http://pione.dinf.usherbrooke.ca/static/code)
53 | ### 4. Pretrained models
54 | Please download the ImageNet pretrained model [vgg16](https://pan.baidu.com/s/1Ep83Wc0DHY8rQHaNZM8oPQ) (PASSWORD: 8tof) and put it under
55 | ```
56 | root/.mxnet/models/
57 | ```
58 | ## Paper
59 | 
60 | Please consider citing our paper when you use our database:
61 | ```
62 | @article{zhang2019EnsNet,
63 |   title   = {EnsNet: Ensconce Text in the Wild},
64 |   author  = {Shuaitao Zhang and Yuliang Liu and Lianwen Jin and Yaoxiong Huang and Songxuan Lai},
65 |   journal = {AAAI},
66 |   year    = {2019}
67 | }
68 | ```
69 | ## Feedback
70 | 
71 | Suggestions and opinions on this dataset (both positive and negative) are greatly welcomed. Please contact the authors by sending email to eestzhang@mail.scut.edu.cn.
72 | 
73 | ## Copyright
74 | 
75 | The synthetic database can only be used for non-commercial research purposes.
76 | 
77 | For commercial usage, please contact Dr. Lianwen Jin: lianwen.jin@gmail.com.
78 | 
79 | Copyright 2018, Deep Learning and Vision Computing Lab, South China University of Technology. http://www.dlvc-lab.net
80 | 
--------------------------------------------------------------------------------
/example/mask_512/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HCIILAB/Scene-Text-Removal/9c65c939a5d9ad3544f2db8e80a2138fee1dafc3/example/mask_512/img_1.png
--------------------------------------------------------------------------------
/example/train_128/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HCIILAB/Scene-Text-Removal/9c65c939a5d9ad3544f2db8e80a2138fee1dafc3/example/train_128/img_1.png
--------------------------------------------------------------------------------
/example/train_256/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HCIILAB/Scene-Text-Removal/9c65c939a5d9ad3544f2db8e80a2138fee1dafc3/example/train_256/img_1.png
--------------------------------------------------------------------------------
/example/train_512/img_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HCIILAB/Scene-Text-Removal/9c65c939a5d9ad3544f2db8e80a2138fee1dafc3/example/train_512/img_1.png
--------------------------------------------------------------------------------
/imagepool.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | from mxnet import nd
3 | import numpy as np
4 | class ImagePool():
5 |     def __init__(self, pool_size):
6 |         self.pool_size = pool_size
7 |         if self.pool_size > 0:
8 | 
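# ImagePool keeps a history buffer of up to pool_size previous generator
# outputs (the buffer trick used in pix2pix/CycleGAN training): while the
# buffer is filling, query() returns the incoming batch unchanged; afterwards,
# each image is swapped with a randomly stored one with probability 0.5, so
# the discriminator also sees older generator samples.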
self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | ret_imgs = [] 15 | for i in range(images.shape[0]): 16 | image = nd.expand_dims(images[i], axis=0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | ret_imgs.append(image) 21 | else: 22 | p = nd.random_uniform(0, 1, shape=(1,)).asscalar() 23 | if p > 0.5: 24 | random_id = nd.random_uniform(0, self.pool_size - 1, shape=(1,)).astype(np.uint8).asscalar() 25 | tmp = self.images[random_id].copy() 26 | self.images[random_id] = image 27 | ret_imgs.append(tmp) 28 | else: 29 | ret_imgs.append(image) 30 | ret_imgs = nd.concat(*ret_imgs, dim=0) 31 | return ret_imgs -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import matplotlib as mpl 4 | import tarfile 5 | import matplotlib.image as mpimg 6 | from matplotlib import pyplot as plt 7 | import mxnet as mx 8 | from mxnet import gluon 9 | from mxnet import ndarray as nd 10 | from mxnet.gluon import nn, utils 11 | from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout,BatchNorm 12 | from mxnet import autograd 13 | import numpy as np 14 | from mxnet.gluon.model_zoo import vision as models 15 | import mxnet 16 | # Define Unet generator skip block 17 | class UnetSkipUnit(HybridBlock): 18 | def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False, 19 | use_dropout=False, use_bias=False): 20 | super(UnetSkipUnit, self).__init__() 21 | 22 | with self.name_scope(): 23 | self.outermost = outermost 24 | en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1, 25 | in_channels=outer_channels, use_bias=use_bias) 26 | en_relu = LeakyReLU(alpha=0.2) 27 | en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels) 28 | de_relu = Activation(activation='relu') 29 | de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels) 30 | 31 | if innermost: 32 | de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1, 33 | in_channels=inner_channels, use_bias=use_bias) 34 | encoder = [en_relu, en_conv] 35 | decoder = [de_relu, de_conv, de_norm] 36 | model = encoder + decoder 37 | elif outermost: 38 | de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1, 39 | in_channels=inner_channels * 2) 40 | encoder = [en_conv] 41 | decoder = [de_relu, de_conv, Activation(activation='tanh')] 42 | model = encoder + [inner_block] + decoder 43 | else: 44 | de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1, 45 | in_channels=inner_channels * 2, use_bias=use_bias) 46 | encoder = [en_relu, en_conv, en_norm] 47 | decoder = [de_relu, de_conv, de_norm] 48 | model = encoder + [inner_block] + decoder 49 | if use_dropout: 50 | model += [Dropout(rate=0.5)] 51 | 52 | self.model = HybridSequential() 53 | with self.model.name_scope(): 54 | for block in model: 55 | self.model.add(block) 56 | 57 | def hybrid_forward(self, F, x): 58 | if self.outermost: 59 | return self.model(x) 60 | else: 61 | return F.concat(self.model(x), x, dim=1) 62 | 63 | # Define Unet generator 64 | class UnetGenerator(HybridBlock): 65 | def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True): 66 | 
super(UnetGenerator, self).__init__() 67 | 68 | #Build unet generator structure 69 | unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True) 70 | for _ in range(num_downs - 5): 71 | unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout) 72 | unet = UnetSkipUnit(ngf * 8, ngf * 4, unet) 73 | unet = UnetSkipUnit(ngf * 4, ngf * 2, unet) 74 | unet = UnetSkipUnit(ngf * 2, ngf * 1, unet) 75 | unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True) 76 | 77 | with self.name_scope(): 78 | self.model = unet 79 | 80 | def hybrid_forward(self, F, x): 81 | return self.model(x) 82 | 83 | # Define the PatchGAN discriminator 84 | class Discriminator(HybridBlock): 85 | def __init__(self, in_channels, ndf=64, n_layers=3, use_sigmoid=False, use_bias=False): 86 | super(Discriminator, self).__init__() 87 | 88 | with self.name_scope(): 89 | self.model = HybridSequential() 90 | kernel_size = 4 91 | padding = int(np.ceil((kernel_size - 1)/2)) 92 | self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=2, 93 | padding=padding, in_channels=in_channels)) 94 | self.model.add(LeakyReLU(alpha=0.2)) 95 | 96 | nf_mult = 1 97 | for n in range(1, n_layers): 98 | nf_mult_prev = nf_mult 99 | nf_mult = min(2 ** n, 8) 100 | self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=2, 101 | padding=padding, in_channels=ndf * nf_mult_prev, 102 | use_bias=use_bias)) 103 | self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult)) 104 | self.model.add(LeakyReLU(alpha=0.2)) 105 | 106 | nf_mult_prev = nf_mult 107 | nf_mult = min(2 ** n_layers, 8) 108 | self.model.add(Conv2D(channels=ndf * nf_mult, kernel_size=kernel_size, strides=1, 109 | padding=padding, in_channels=ndf * nf_mult_prev, 110 | use_bias=use_bias)) 111 | self.model.add(BatchNorm(momentum=0.1, in_channels=ndf * nf_mult)) 112 | self.model.add(LeakyReLU(alpha=0.2)) 113 | self.model.add(Conv2D(channels=1, kernel_size=kernel_size, strides=1, 114 | padding=padding, in_channels=ndf * nf_mult)) 115 | if use_sigmoid: 116 | self.model.add(Activation(activation='sigmoid')) 117 | 118 | def hybrid_forward(self, F, x): 119 | out = self.model(x) 120 | # print (self.model) 121 | #print(out) 122 | return out 123 | 124 | class label_Discriminator(HybridBlock): 125 | def __init__(self, in_channels, ndf=1, n_layers=3, use_sigmoid=False, use_bias=False): 126 | super(label_Discriminator, self).__init__() 127 | 128 | with self.name_scope(): 129 | self.model = HybridSequential() 130 | kernel_size = 70 131 | padding = 24 132 | self.model.add(Conv2D(channels=ndf, kernel_size=kernel_size, strides=8, 133 | padding=padding, in_channels=in_channels, use_bias=use_bias)) 134 | if use_sigmoid: 135 | self.model.add(Activation(activation='sigmoid')) 136 | 137 | def hybrid_forward(self, F, x): 138 | out = self.model(x) 139 | return out 140 | class LC(HybridBlock): 141 | def __init__(self, outer_channels): 142 | super(LC, self).__init__() 143 | with self.name_scope(): 144 | channels = int(np.ceil(outer_channels/2)) 145 | self.model = HybridSequential() 146 | self.model.add(nn.Conv2D(channels,kernel_size=1)) 147 | self.model.add(nn.Conv2D(channels, kernel_size=3, padding=1, 148 | strides=1)) 149 | self.model.add(nn.Conv2D(channels, kernel_size=3, padding=1, 150 | strides=1)) 151 | def hybrid_forward(self, F, x): 152 | out = self.model(x) 153 | return out 154 | class Residual(nn.HybridBlock): 155 | def __init__(self, channels, same_shape=True, **kwargs): 156 | super(Residual, self).__init__(**kwargs) 157 | self.same_shape = same_shape 158 | with 
self.name_scope():
159 |             strides = 1 if same_shape else 2
160 |             self.conv1 = nn.Conv2D(channels, kernel_size=3, padding=1,
161 |                                    strides=strides)
162 |             self.conv2 = nn.Conv2D(channels, kernel_size=3, padding=1)
163 |             if not same_shape:
164 |                 self.conv3 = nn.Conv2D(channels, kernel_size=1,
165 |                                        # alternative that was tried: kernel_size=3, padding=1,
166 |                                        strides=strides)
167 |     def hybrid_forward(self, F, x):
168 |         out = F.relu(self.conv1(x))
169 |         out = self.conv2(out)
170 |         if not self.same_shape:
171 |             x = self.conv3(x)
172 |         return F.relu(out + x)
173 | def get_net(style_layers):
174 |     ctx = mx.gpu(0)
175 |     vgg19 = models.vgg19(pretrained=True)  # note: this unused helper loads VGG19; network.py builds its own extractor from VGG16
176 |     net = nn.Sequential()
177 |     for i in range(max(style_layers)+1):
178 |         net.add(vgg19.features[i])
179 |     # net.collect_params().reset_ctx(ctx)
180 |     return net
181 | def extract_features(x, in_size=224):
182 |     B, C, H, W = x.shape
183 |     Img = (x + 1.0) * 127.5  # map from [-1, 1] back to [0, 255]
184 |     Img_channels = [nd.expand_dims(Img[:, i, :, :], axis=1) for i in range(3)]
185 |     Img_channels[0] = (Img_channels[0]/255 - 0.485) / 0.229  # normalize R channel with ImageNet mean/std
186 |     Img_channels[1] = (Img_channels[1]/255 - 0.456) / 0.224  # normalize G channel with ImageNet mean/std
187 |     Img_channels[2] = (Img_channels[2]/255 - 0.406) / 0.225  # normalize B channel with ImageNet mean/std
188 |     Img = nd.concat(*Img_channels, dim=1)
189 |     limx = H - in_size
190 |     limy = W - in_size
191 |     xs = np.random.randint(0, limx, B)
192 |     ys = np.random.randint(0, limy, B)
193 |     lis = [nd.expand_dims(Img[i, :, x:x+in_size, y:y+in_size], axis=0) for i, (x, y) in enumerate(zip(xs, ys))]
194 |     Img_cropped = nd.concat(*lis, dim=0)  # random 224x224 crops for the VGG feature extractor
195 |     return Img_cropped
196 | class STE(nn.Block):
197 |     """Scene-text-erasure generator: residual encoder with lateral connections (LC) and a multi-scale decoder."""
198 |     def __init__(self, **kwargs):
199 |         super(STE, self).__init__(**kwargs)
200 |         # self.verbose = verbose
201 |         # self.in_planes = 64
202 |         with self.name_scope():
203 |             self.layer1 = nn.Conv2D(64, kernel_size=4, strides=2, padding=1)
204 |             self.lc1 = LC(64)
205 |             self.conv1 = nn.Conv2D(64, kernel_size=1)
206 |             # self.bn1 = nn.BatchNorm()
207 |             # self.relu_conv1 = nn.Activation(activation='relu')
208 |             self.a1 = nn.MaxPool2D(pool_size=2, strides=2)
209 |             self.a2 = Residual(64)
210 |             self.layer2 = Residual(64)
211 |             self.lc2 = LC(64)
212 |             self.conv2 = nn.Conv2D(64, kernel_size=1)
213 |             # self.bn2 = nn.BatchNorm()
214 |             # self.relu_conv2 = nn.Activation(activation='relu')
215 |             self.b1 = Residual(128, same_shape=False)
216 |             self.layer3 = Residual(128)
217 |             self.lc3 = LC(128)
218 |             self.conv3 = nn.Conv2D(128, kernel_size=1)
219 |             # self.bn3 = nn.BatchNorm()
220 |             # self.relu_conv3 = nn.Activation(activation='relu')
221 |             self.c1 = Residual(256, same_shape=False)
222 |             self.layer4 = Residual(256)
223 |             self.lc4 = LC(256)
224 |             self.conv4 = nn.Conv2D(256, kernel_size=1)
225 |             # self.bn4 = nn.BatchNorm()
226 |             # self.relu_conv4 = nn.Activation(activation='relu')
227 |             self.d1 = Residual(512, same_shape=False)
228 |             self.layer5 = Residual(512)
229 | 
230 |             # block 6
231 |             # b6 = nn.Sequential()
232 |             # b6.add(
233 |             #     nn.AvgPool2D(pool_size=3),
234 |             #     nn.Dense(num_classes)
235 |             # )
236 |             self.layer6 = nn.Conv2D(2, kernel_size=1)
237 |             self.delayer1 = nn.Conv2DTranspose(256, kernel_size=4, padding=1, strides=2)
238 |             # self.debn1 = nn.BatchNorm()
239 |             self.relu1 = nn.ELU(alpha=1.0)
240 |             # self.relu1 = nn.ELU(alpha=0.2)
241 |             # self.relu1 = nn.ELU(alpha=0.2)
242 |             # self.relu11 = nn.Activation(activation='relu')
243 |             self.relu11 = nn.ELU(alpha=1.0)
244 |             # self.relu11 = nn.ELU(alpha=1.0)
245 |             #
mxnet.ndarray.add(lhs, rhs) 246 | self.delayer2 = nn.Conv2DTranspose(128, kernel_size=4, padding=1,strides=2) 247 | # self.debn2 = nn.BatchNorm() 248 | self.relu2 = nn.ELU(alpha=1.0) 249 | self.relu22 = nn.ELU(alpha=1.0) 250 | self.delayer3 = nn.Conv2DTranspose(64, kernel_size=4, padding=1,strides=2) 251 | self.convs_1 = Conv2D(channels=3, kernel_size=1, strides=1, padding=0,use_bias=False) 252 | # self.debn3 = nn.BatchNorm() 253 | self.relu3 = nn.ELU(alpha=1.0) 254 | self.relu33 = nn.ELU(alpha=1.0) 255 | self.delayer4 = nn.Conv2DTranspose(64, kernel_size=4, padding=1,strides=2) 256 | self.convs_2 =Conv2D(channels=3, kernel_size=1, strides=1, padding=0,use_bias=False) 257 | # self.debn4 = nn.BatchNorm() 258 | self.relu4 = nn.ELU(alpha=1.0) 259 | self.relu44 = nn.ELU(alpha=1.0) 260 | self.delayer5 = nn.Conv2DTranspose(3, kernel_size=4, padding=1,strides=2) 261 | # self.debn5 = nn.BatchNorm() 262 | self.relu5 = nn.ELU(alpha=1.0) 263 | 264 | 265 | def forward(self, x): 266 | c1 = self.layer1(x) 267 | lc1 = self.lc1(c1) 268 | a1 = self.a1(c1) 269 | a2 = self.a2(a1) 270 | c2 = self.layer2(a2) 271 | lc2 = self.lc2(c2) 272 | b1 = self.b1(c2) 273 | c3 = self.layer3(b1) 274 | lc3 = self.lc3(c3) 275 | C1 = self.c1(c3) 276 | c4 = self.layer4(C1) 277 | lc4 = self.lc4(c4) 278 | d1 = self.d1(c4) 279 | c5 = self.layer5(d1) 280 | p51 = self.layer6(c5) 281 | p5 = self.relu11(self.conv4(lc4) + self.relu1(self.delayer1(p51))) 282 | p6 = self.relu22(self.conv3(lc3) + self.relu2(self.delayer2(p5))) 283 | p7 = self.relu33(self.conv2(lc2) + self.relu3(self.delayer3(p6))) 284 | p7_o = self.convs_1(p7) 285 | p8 = self.relu44(self.conv1(lc1) + self.relu4(self.delayer4(p7))) 286 | p8_o = self.convs_2(p8) 287 | p9 = self.relu5(self.delayer5(p8)) 288 | return p5,p6,p7_o,p8_o,p9 289 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import matplotlib as mpl 4 | import tarfile 5 | import matplotlib.image as mpimg 6 | from matplotlib import pyplot as plt 7 | import mxnet as mx 8 | from mxnet import gluon 9 | from mxnet import ndarray as nd 10 | from mxnet.gluon import nn, utils 11 | from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, \ 12 | BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout 13 | from mxnet import autograd 14 | import numpy as np 15 | from mxnet import init 16 | from model import UnetGenerator,Discriminator,STE,label_Discriminator 17 | from mxnet.gluon.model_zoo import vision as models 18 | ctx = mx.gpu(0) 19 | def get_net(pretrained_net,style_layers): 20 | net = nn.Sequential() 21 | for i in range(max(style_layers)+1): 22 | net.add(pretrained_net.features[i]) 23 | return net 24 | def param_init(param): 25 | # ctx = mx.cpu() 26 | if param.name.find('ste0_conv0') != -1: 27 | param.initialize(init=mx.init.Zero(), ctx=ctx) 28 | elif param.name.find('conv') != -1: 29 | if param.name.find('weight') != -1: 30 | param.initialize(init=mx.init.Normal(0.02), ctx=ctx) 31 | else: 32 | param.initialize(init=mx.init.Zero(), ctx=ctx) 33 | elif param.name.find('ste0_instancenorm') != -1: 34 | param.initialize(init=mx.init.Zero(), ctx=ctx) 35 | # elif param.name.find('batchnorm1') != -1: 36 | # param.initialize(init=mx.init.Zero(), ctx=ctx) 37 | # if param.name.find('gamma') != -1: 38 | # param.set_data(nd.random_normal(1, 0.02, param.data().shape)) 39 | elif param.name.find('batchnorm') != -1: 40 | 
param.initialize(init=mx.init.Zero(), ctx=ctx)
41 |         # Initialize gamma from a normal distribution with mean 1 and std 0.02
42 |         if param.name.find('gamma') != -1:
43 |             param.set_data(nd.random_normal(1, 0.02, param.data().shape))
44 | 
45 | def network_init(net):
46 |     for param in net.collect_params().values():
47 |         param_init(param)
48 | 
49 | def set_network(args):
50 |     style_layers = [4,9,16]
51 |     # Pixel2pixel networks
52 |     # netG = UnetGenerator(in_channels=3, num_downs=8)
53 |     net_label = label_Discriminator(in_channels=1, use_sigmoid=False)
54 |     netG = STE()
55 |     netD = Discriminator(in_channels=6, use_sigmoid=False)
56 |     netvgg = models.vgg16(pretrained=True)
57 |     net = get_net(netvgg, style_layers)
58 |     net.collect_params().reset_ctx(ctx)
59 |     # net.collect_params().setattr('grad_req', 'null')
60 |     # Initialize parameters
61 |     netG.initialize(ctx=ctx, init=init.Xavier())
62 |     if args.model:
63 |         netG.collect_params().load(args.model, ctx=ctx)
64 |         netG.collect_params().reset_ctx(ctx)
65 |     network_init(netD)
66 |     net_label.initialize(ctx=ctx, init=mx.initializer.One())
67 |     net_label.collect_params().setattr('grad_req', 'null')
68 | 
69 |     net_label.collect_params().reset_ctx(ctx)
70 |     # net.collect_params().setattr('grad_req', 'null')
71 |     # net.collect_params().reset_ctx(ctx)
72 |     # trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': args.lr, 'beta': args.beta})
73 |     # trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': args.lr, 'beta': args.beta})
74 |     # trainerV = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': args.lr, 'beta': args.beta})
75 |     # trainerL = gluon.Trainer(net_label.collect_params(), 'adam', {'learning_rate': args.lr, 'beta': args.beta})
76 |     trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': args.lr})
77 |     trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': args.lr})
78 |     trainerV = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': args.lr})
79 |     trainerL = gluon.Trainer(net_label.collect_params(), 'adam', {'learning_rate': args.lr})
80 |     return netG, netD, net, net_label, trainerG, trainerD, trainerV, trainerL
81 | 
82 | # Loss
83 | # GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
84 | # L1_loss = gluon.loss.L1Loss()
85 | #
86 | # netG, netD, net, net_label, trainerG, trainerD, trainerV, trainerL = set_network()
87 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import matplotlib as mpl
4 | import tarfile
5 | import matplotlib.image as mpimg
6 | from matplotlib import pyplot as plt
7 | import mxnet as mx
8 | from mxnet import gluon
9 | from mxnet import ndarray as nd
10 | from mxnet.gluon import nn, utils
11 | from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, \
12 |     BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout
13 | from mxnet import autograd
14 | import numpy as np
15 | from datetime import datetime
16 | import time
17 | import logging
18 | from network import set_network
19 | from imagepool import ImagePool
20 | from text2 import MyDataSet  # the dataset class lives in text2.py; there is no dataset.py in this repo
21 | from vis_dataset import visualize
22 | from mxnet.gluon.data import Dataset, DataLoader
23 | import glob
24 | from mxnet import image, nd
25 | import argparse
26 | def test(args):
27 |     use_gpu = args.gpu
28 |     ctx = mx.gpu(0) if use_gpu else mx.cpu(0)
29 |     img_lists = glob.glob(args.test_image +
'/*')
30 |     netG, netD, net, net_label, trainerG, trainerD, trainerV, trainerL = set_network(args)
31 |     netG.collect_params().reset_ctx(ctx)
32 |     netG.collect_params().load(args.model, ctx=ctx)
33 |     # FPS = 0
34 |     # all_time = 0
35 |     # btic = time.time()
36 |     for i, x in enumerate(img_lists):
37 |         time1 = time.time()
38 |         prefix = x.split('/')[-1].split('.')[0]
39 |         data1 = image.imread(x)
40 |         data = data1.astype(np.float32)/127.5 - 1
41 |         data = image.imresize(data, args.input_size, args.input_size)
42 |         data = nd.transpose(data, (2,0,1))
43 |         data = data.reshape((1,) + data.shape)
44 |         img_name = x.split('/')[-1].split('.')[0]
45 |         real_in = data.as_in_context(ctx)
46 |         # all_time = all_time + time.time()-time1
47 |         # btic = time.time()
48 |         p5, p6, p7, p8, fake_out = netG(real_in)
49 |         # speed = time.time() - btic
50 |         # FPS = FPS + speed
51 |         # print (FPS, all_time)
52 |         fake_img = fake_out[0]
53 |         predict = ((fake_img.asnumpy().transpose(1, 2, 0) + 1.0).clip(0, 2) * 127.5).astype(np.uint8)
54 | 
55 |         plt.imshow(predict)
56 |         if args.vis:
57 |             plt.show()
58 |         # plt.show()
59 |         prefix = x.split('/')[-1].split('.')[0]
60 |         save_path = args.result + prefix + '.png'
61 |         plt.imsave(save_path, predict)  # save the raw output array (plt.savefig would include figure axes and margins)
62 | if __name__ == '__main__':
63 |     parser = argparse.ArgumentParser(description='Hyperparams')
64 |     parser.add_argument('--test_image', nargs='?', type=str, default='',
65 |                         help='Test image path')
66 |     parser.add_argument('--model', nargs='?', type=str, default='',
67 |                         help='Path to saved model to restart from')
68 |     parser.add_argument('--input_size', nargs='?', type=int, default=512,
69 |                         help='Input image size')
70 |     parser.add_argument('--lr', nargs='?', type=float, default=0.0002,
71 |                         help='Learning Rate')
72 |     parser.add_argument('--beta', nargs='?', type=float, default=0.0002,
73 |                         help='beta')
74 |     parser.add_argument('--gpu', nargs='?', type=bool, default=True,
75 |                         help='use_gpu')
76 |     parser.add_argument('--vis', nargs='?', type=bool, default=True,
77 |                         help='vis result')
78 |     parser.add_argument('--result', nargs='?', type=str, default='',
79 |                         help='Path to save the output images')
80 |     args = parser.parse_args()
81 |     test(args)
--------------------------------------------------------------------------------
/text2.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import os
3 | import mxnet as mx
4 | from mxnet import nd
5 | import numpy as np
6 | from matplotlib import pyplot as plt
7 | from mxnet.gluon.data import Dataset, DataLoader
8 | from vis_dataset import visualize
9 | import random
10 | import cv2
11 | pool_size = 50
12 | img_wd = 512
13 | img_ht = 512
14 | def random_horizontal_flip(imgs):
15 |     if random.random() < 0.5:
16 |         for i in range(len(imgs)):
17 |             imgs[i] = nd.image.flip_left_right(imgs[i]).copy()
18 |     return imgs
19 | def random_rotate(imgs):
20 |     max_angle = 10
21 |     angle = random.random() * 2 * max_angle - max_angle  # uniform in [-10, 10] degrees
22 |     # print(angle)
23 |     for i in range(len(imgs)):
24 |         img = imgs[i].asnumpy()
25 |         w, h = img.shape[:2]
26 |         rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
27 |         img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w))
28 |         imgs[i] = nd.array(img_rotation)
29 |     return imgs
30 | class MyDataSet(Dataset):
31 |     def __init__(self, root, split, is_transform=False, is_train=True):
32 |         self.root = os.path.join(root, split)
33 |         self.is_transform = is_transform
34 |         self.img_paths = []
35 |         self._img_512 = os.path.join(root, split, 'train_512', '{}.png')
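# Each sample id must exist in four parallel folders: train_512 holds the
# input and its text-free label stored side by side (resized to 1024x512 and
# split by fixed_crop in __getitem__), mask_512 holds the text mask, and
# train_256 / train_128 hold downscaled labels for the multi-scale L1 terms
# in train.py.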
36 | self._mask_512 = os.path.join(root, split, 'mask_512', '{}.png') 37 | self._lbl_512 = os.path.join(root, split, 'train_512', '{}.png') 38 | self._img_256 = os.path.join(root, split, 'train_256', '{}.png') 39 | self._lbl_256 = os.path.join(root, split, 'train_256', '{}.png') 40 | self._img_128 = os.path.join(root, split, 'train_128', '{}.png') 41 | for fn in os.listdir(os.path.join(root, split, 'train_512')): 42 | if len(fn) > 3 and fn[-4:] == '.png': 43 | self.img_paths.append(fn[:-4]) 44 | 45 | def __len__(self): 46 | return len(self.img_paths) 47 | 48 | def __getitem__(self, idx): 49 | img_path_512 = self._img_512.format(self.img_paths[idx]) 50 | img_path_256 = self._img_256.format(self.img_paths[idx]) 51 | img_path_128 = self._img_128.format(self.img_paths[idx]) 52 | lbl_path_256 = self._lbl_256.format(self.img_paths[idx]) 53 | mask_path_512 = self._mask_512.format(self.img_paths[idx]) 54 | lbl_path_512 = self._lbl_512.format(self.img_paths[idx]) 55 | img_arr_256 = mx.image.imread(img_path_256).astype(np.float32)/127.5 - 1 56 | img_arr_512 = mx.image.imread(img_path_512).astype(np.float32)/127.5 - 1 57 | img_arr_128 = mx.image.imread(img_path_128).astype(np.float32)/127.5 - 1 58 | img_arr_512 = mx.image.imresize(img_arr_512, img_wd * 2, img_ht) 59 | img_arr_in_512, img_arr_out_512 = [mx.image.fixed_crop(img_arr_512, 0, 0, img_wd, img_ht), 60 | mx.image.fixed_crop(img_arr_512, img_wd, 0, img_wd, img_ht)] 61 | if os.path.exists(mask_path_512): 62 | mask_512 = mx.image.imread(mask_path_512) 63 | else: 64 | mask_512 = mx.image.imread(mask_path_512.replace(".png",'.jpg',1)) 65 | tep_mask_512 = nd.slice_axis(mask_512, axis=2, begin=0, end=1)/255 66 | if self.is_transform: 67 | imgs = [img_arr_out_512, img_arr_in_512, tep_mask_512,img_arr_256,img_arr_128] 68 | imgs = random_horizontal_flip(imgs) 69 | imgs = random_rotate(imgs) 70 | img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_256,img_arr_128 = imgs[0], imgs[1], imgs[2], imgs[3],imgs[4] 71 | img_arr_in_512, img_arr_out_512 = [nd.transpose(img_arr_in_512, (2,0,1)), 72 | nd.transpose(img_arr_out_512, (2,0,1))] 73 | img_arr_out_256 = nd.transpose(img_arr_256, (2,0,1)) 74 | img_arr_out_128 = nd.transpose(img_arr_128, (2,0,1)) 75 | tep_mask_512 = tep_mask_512.reshape(tep_mask_512.shape[0],tep_mask_512.shape[1],1) 76 | tep_mask_512 = nd.transpose(tep_mask_512,(2,0,1)) 77 | return img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_out_256,img_arr_out_128 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import matplotlib as mpl 4 | import tarfile 5 | import matplotlib.image as mpimg 6 | from matplotlib import pyplot as plt 7 | import mxnet as mx 8 | from mxnet import gluon 9 | from mxnet import ndarray as nd 10 | from mxnet.gluon import nn, utils 11 | from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout 12 | from mxnet import autograd 13 | import numpy as np 14 | from datetime import datetime 15 | import time 16 | import logging 17 | from network import set_network 18 | from imagepool import ImagePool 19 | #from dataset import load_data 20 | from vis_dataset import visualize 21 | from mxnet.gluon.data import Dataset, DataLoader 22 | from text2 import MyDataSet 23 | import argparse 24 | def facc(label, pred): 25 | pred = pred.ravel() 26 | label = label.ravel() 27 | return 
((pred > 0.5) == label).mean()
28 | def extract_features(x, style_layers, net, ctx, in_size=224):
29 |     B, C, H, W = x.shape
30 |     Img = (x + 1.0) * 127.5  # map from [-1, 1] back to [0, 255]
31 |     Img_channels = [nd.expand_dims(Img[:, i, :, :], axis=1) for i in range(3)]
32 |     Img_channels[0] = (Img_channels[0]/255 - 0.485) / 0.229  # normalize R channel with ImageNet mean/std
33 |     Img_channels[1] = (Img_channels[1]/255 - 0.456) / 0.224  # normalize G channel with ImageNet mean/std
34 |     Img_channels[2] = (Img_channels[2]/255 - 0.406) / 0.225  # normalize B channel with ImageNet mean/std
35 |     Img = nd.concat(*Img_channels, dim=1)
36 |     limx = H - in_size
37 |     limy = W - in_size
38 |     xs = np.random.randint(0, limx, B)
39 |     ys = np.random.randint(0, limy, B)
40 |     lis = [nd.expand_dims(Img[i, :, x:x+in_size, y:y+in_size], axis=0) for i, (x, y) in enumerate(zip(xs, ys))]
41 |     Img_cropped = nd.concat(*lis, dim=0)
42 |     styles = []
43 |     for k in range(len(net)):
44 |         Img_cropped = net[k](Img_cropped.as_in_context(ctx))
45 |         if k in style_layers:
46 |             styles.append(Img_cropped)
47 |     return styles
48 | def gram(x):
49 |     c = x.shape[1]
50 |     n = x.size / x.shape[1]
51 |     y = x.reshape((c, int(n)))
52 |     return nd.dot(y, y.T) / n
53 | def style_loss(yhat, y):
54 |     return nd.abs(gram(yhat) - gram(y)).mean()
55 | def cal_loss_style(hout, hcomp, hgt):
56 |     for i in range(3):
57 |         if i == 0:
58 |             L_style_out = style_loss(hout[0], hgt[0])
59 |             L_style_comp = style_loss(hcomp[0], hgt[0])
60 |         else:
61 |             L_style_out = L_style_out + style_loss(hout[i], hgt[i])
62 |             L_style_comp = L_style_comp + style_loss(hcomp[i], hgt[i])
63 |     return L_style_comp + L_style_out
64 | def calc_loss_perceptual(hout, hcomp, hgt):
65 |     for j in range(3):
66 |         if j == 0:
67 |             loss = nd.abs(hout[0] - hgt[0]).mean()
68 |             loss = loss + nd.abs(hcomp[0] - hgt[0]).mean()
69 |         else:
70 |             loss = loss + nd.abs(hout[j] - hgt[j]).mean()
71 |             loss = loss + nd.abs(hcomp[j] - hgt[j]).mean()
72 |     return loss
73 | def tv_loss(yhat):
74 |     return 0.5*((yhat[:,:,1:,:] - yhat[:,:,:-1,:]).abs().mean() +
75 |                 (yhat[:,:,:,1:] - yhat[:,:,:,:-1]).abs().mean())
76 | def train(args):
77 |     use_gpu = args.gpu
78 |     ctx = mx.gpu(0) if use_gpu else mx.cpu()
79 |     pool_size = 50
80 |     lambda1 = 100
81 |     img_wd = args.img_size
82 |     img_ht = args.img_size
83 |     style_layers = [4,9,16]
84 |     my_train = MyDataSet(args.trainset_path, '')
85 |     train_loader = DataLoader(my_train, batch_size=args.batch_size, shuffle=True, last_batch='rollover')
86 |     GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
87 |     L1_loss = gluon.loss.L1Loss()
88 |     netG, netD, net, net_label, trainerG, trainerD, trainerV, trainerL = set_network(args)
89 |     image_pool = ImagePool(pool_size)
90 |     metric = mx.metric.CustomMetric(facc)
91 |     stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
92 |     logging.basicConfig(filename='pixel2pixel.log', level=logging.DEBUG)
93 |     for epoch in range(args.n_epoch):
94 |         tic = time.time()
95 |         btic = time.time()
96 |         iter = 0
97 |         # print(trainerG.learning_rate)
98 |         if epoch > 0 and epoch % 200 == 0:
99 |             trainerG.set_learning_rate(trainerG.learning_rate * 0.2)
100 |             trainerD.set_learning_rate(trainerD.learning_rate * 0.2)
101 |             trainerV.set_learning_rate(trainerV.learning_rate * 0.2)  # decay from trainerV's own rate (trainerD has already been decayed above)
102 |             # print(trainerG.learning_rate)
103 |         for data, label, mask, data_256, data_128 in train_loader:
104 |             batch_size = data.shape[0]
105 |             ############################
106 |             # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
107 |             ###########################
108 |             real_in = data.as_in_context(ctx)
109 |             real_out = label.as_in_context(ctx)
110 |             real_out_256 = data_256.as_in_context(ctx)
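#             the float text mask loaded from mask_512 drives two things below:
#             a boolean copy (mask_b) used to composite I_comp = where(mask, real, fake)
#             for the perceptual/style terms, and net_label (a frozen all-ones 70x70
#             convolution) which pools the mask into patch-level real/fake targets for
#             the local-aware discriminator.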
111 |             real_out_128 = data_128.as_in_context(ctx)
112 |             mask = mask.as_in_context(ctx).astype('float32')
113 |             mask_b = mask.asnumpy().astype(bool)
114 |             mask = mask.astype('float32')
115 |             _, _, _, _, fake_out = netG(real_in)
116 |             fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
117 |             with autograd.record():
118 |                 # Use image pooling to utilize history images
119 |                 mask_patch = 1 - net_label(nd.array(mask).as_in_context(ctx)).asnumpy().astype(bool).astype(np.int8)
120 |                 fake_label = nd.array(mask_patch).as_in_context(ctx)
121 |                 output = netD(fake_concat)
122 |                 errD_fake = GAN_loss(output, fake_label)
123 |                 metric.update([fake_label,], [output,])
124 |                 # Train with real image
125 |                 real_concat = nd.concat(real_in, real_out, dim=1)
126 |                 output = netD(real_concat)
127 |                 real_label = nd.ones(output.shape, ctx=ctx)
128 |                 errD_real = GAN_loss(output, real_label)
129 |                 errD = (errD_real + errD_fake) * 0.5
130 |                 errD.backward()
131 |                 metric.update([real_label,], [output,])
132 |             trainerD.step(data.shape[0])
133 |             ############################
134 |             # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
135 |             ###########################
136 |             with autograd.record():
137 |                 p5, p6, p7, p8, fake_out = netG(real_in)
138 |                 I_comp_1 = nd.array(np.where(mask_b, real_out.asnumpy(), fake_out.asnumpy())).as_in_context(ctx)
139 |                 fake_concat = nd.concat(real_in, fake_out, dim=1)
140 |                 output = netD(fake_concat)
141 |                 real_label = nd.ones(output.shape, ctx=ctx)
142 |                 errG = GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + L1_loss(real_out*(1-mask), fake_out*(1-mask))*lambda1*6 + L1_loss(real_out_256, p8) * lambda1*0.8 + L1_loss(real_out_128, p7) * lambda1*0.6
143 |                 errG.backward()
144 |             trainerG.step(data.shape[0])
145 |             name, acc = metric.get()
146 |             print('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
147 |             print('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' % (nd.mean(errD).asscalar(),
148 |                   nd.mean(errG).asscalar(), acc, iter, epoch))
149 |             # print ('L_perceptual = %f, L_style = %f, L_tv = %f, L_total = %f'%(nd.mean(L_perceptual).asscalar(),nd.mean(L_style).asscalar(), nd.mean(L_tv).asscalar(),nd.mean(L_total).asscalar()))
150 |             if (epoch+1) % 50 == 0:
151 |                 netG.collect_params().save(args.checkpoint + '/net_%d.params' % (epoch))
152 |             ############################
153 |             # (3) cal vgg16: style_loss + perceptual_loss + tv_loss
154 |             ###########################
155 |             with autograd.record():
156 |                 _, _, _, _, fake_out = netG(real_in)
157 |                 I_comp = nd.array(np.where(mask_b, real_out.asnumpy(), fake_out.asnumpy())).as_in_context(ctx)
158 |                 hout = extract_features(fake_out, style_layers, net, ctx)
159 |                 hgt = extract_features(real_out, style_layers, net, ctx)
160 |                 hcomp = extract_features(I_comp, style_layers, net, ctx)
161 |                 L_perceptual = calc_loss_perceptual(hout, hcomp, hgt)
162 |                 L_style = cal_loss_style(hout, hcomp, hgt)  # style loss on both the raw output and the composited image
163 |                 L_tv = tv_loss(fake_out)
164 |                 # L_total = 0.5 * L_perceptual + 50.0 * L_style + 25.0 * L_tv + GAN_loss(output, real_label) + L1_loss(real_out, fake_out) * lambda1 + L1_loss(real_out*(1-mask), fake_out*(1-mask))*lambda1*6 + L1_loss(real_out_256, p8) * lambda1*0.8 + L1_loss(real_out_128, p7) * lambda1*0.6
165 |                 L_total = 0.5 * L_perceptual + 50.0 * L_style + 25.0 * L_tv
166 |                 L_total.backward()
167 |             trainerV.step(data.shape[0])
168 |             print('L_perceptual = %f, L_style = %f, L_tv = %f, L_total = %f' % (nd.mean(L_perceptual).asscalar(), nd.mean(L_style).asscalar(), nd.mean(L_tv).asscalar(), nd.mean(L_total).asscalar()))
169 |             # Print log information every ten batches
170 |             if iter % 10 == 0:
171 |                 name, acc = metric.get()
172 |                 logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
173 |                 logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
174 |                              % (nd.mean(errD).asscalar(),
175 |                                 nd.mean(errG).asscalar(), acc, iter, epoch))
176 |             iter = iter + 1
177 |             btic = time.time()
178 |         name, acc = metric.get()
179 |         metric.reset()
180 |         logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
181 |         logging.info('time: %f' % (time.time() - tic))
182 | if __name__ == '__main__':
183 |     parser = argparse.ArgumentParser(description='Hyperparams')
184 |     parser.add_argument('--img_size', nargs='?', type=int, default=512,
185 |                         help='Height of the input image')
186 |     parser.add_argument('--n_epoch', nargs='?', type=int, default=1000,
187 |                         help='Number of epochs')
188 |     parser.add_argument('--batch_size', nargs='?', type=int, default=1,
189 |                         help='Batch Size')
190 |     parser.add_argument('--lr', nargs='?', type=float, default=0.0005,
191 |                         help='Learning Rate')
192 |     parser.add_argument('--beta', nargs='?', type=float, default=0.0002,
193 |                         help='beta')
194 |     parser.add_argument('--trainset_path', nargs='?', type=str, default=None,
195 |                         help='Path to train images')
196 |     parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
197 |                         help='path to save checkpoint (default: checkpoint)')
198 |     parser.add_argument('--model', nargs='?', type=str, default='',
199 |                         help='Path to saved model to restart from')
200 |     parser.add_argument('--gpu', nargs='?', type=bool, default=True,
201 |                         help='use_gpu')
202 |     args = parser.parse_args()
203 | 
204 |     train(args)
205 | 
--------------------------------------------------------------------------------
/vis_dataset.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | # from dataset import load_data
3 | from matplotlib import pyplot as plt
4 | import numpy as np
5 | epochs = 100
6 | batch_size = 10
7 | 
8 | use_gpu = True
9 | ctx = mx.gpu(0) if use_gpu else mx.cpu()
10 | 
11 | lr = 0.0002
12 | beta1 = 0.5
13 | lambda1 = 100
14 | 
15 | pool_size = 50
16 | img_wd = 512
17 | img_ht = 512
18 | dataset = ''
19 | train_img_path = '%s/train' % (dataset)
20 | val_img_path = '%s/val' % (dataset)
21 | # dataset = 'facades'
22 | # train_data = load_data(train_img_path, batch_size, is_reversed=True)
23 | # val_data = load_data(val_img_path, batch_size, is_reversed=True)
24 | def visualize(img_arr):
25 |     plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))
26 |     # plt.show()
27 |     # plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 0.0) * 255.0).astype(np.uint8))
28 |     plt.axis('off')
29 | 
30 | def preview_train_data():
31 |     img_in_list, img_out_list = train_data.next().data
32 |     for i in range(4):
33 |         plt.subplot(2, 4, i+1)
34 |         visualize(img_in_list[i])
35 |         plt.subplot(2, 4, i+5)
36 |         visualize(img_out_list[i])
37 |     plt.show()
38 | 
39 | # preview_train_data()
--------------------------------------------------------------------------------
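As a companion to `test.py`, the following is a minimal single-image inference sketch; the checkpoint and image paths are hypothetical placeholders, and it assumes a checkpoint saved by `train.py`:
```
import mxnet as mx
import numpy as np
from mxnet import image, nd
from model import STE

ctx = mx.gpu(0)
netG = STE()
netG.collect_params().load('./checkpoints/net_999.params', ctx=ctx)  # hypothetical checkpoint path

img = image.imread('./test_images/img_1.png').astype(np.float32) / 127.5 - 1.0  # hypothetical image
img = image.imresize(img, 512, 512)
x = nd.transpose(img, (2, 0, 1)).expand_dims(axis=0).as_in_context(ctx)
_, _, _, _, out = netG(x)  # STE returns multi-scale outputs; the last one is the 512x512 result
result = ((out[0].asnumpy().transpose(1, 2, 0) + 1.0).clip(0, 2) * 127.5).astype(np.uint8)
```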