├── test.py
├── js_api.txt
├── images
│   ├── 1.PNG
│   ├── 1.jpg
│   ├── 2.PNG
│   ├── 3.PNG
│   ├── 9.jpg
│   ├── combine_1.jpg
│   ├── combine_9.jpg
│   ├── test_step1.PNG
│   └── test_step2.PNG
├── weights
│   ├── readme.md
│   └── wdsr-b-32-x4.h5
├── crul_api.txt
├── README.md
├── prepare.py
├── predict.py
├── WDSR.py
├── utils.py
└── pix2pix.py

/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/test.py
--------------------------------------------------------------------------------
/js_api.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/js_api.txt
--------------------------------------------------------------------------------
/images/1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/1.PNG
--------------------------------------------------------------------------------
/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/1.jpg
--------------------------------------------------------------------------------
/images/2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/2.PNG
--------------------------------------------------------------------------------
/images/3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/3.PNG
--------------------------------------------------------------------------------
/images/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/9.jpg
--------------------------------------------------------------------------------
/weights/readme.md:
--------------------------------------------------------------------------------
Trained model download:
p2p: link: https://pan.baidu.com/s/1tmfMTCpAFVpC6Z7GWvkrfA  extraction code: cac2
--------------------------------------------------------------------------------
/images/combine_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/combine_1.jpg
--------------------------------------------------------------------------------
/images/combine_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/combine_9.jpg
--------------------------------------------------------------------------------
/images/test_step1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/test_step1.PNG
--------------------------------------------------------------------------------
/images/test_step2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/images/test_step2.PNG
--------------------------------------------------------------------------------
/weights/wdsr-b-32-x4.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjrelc/pix2pix-color-paint-for-gray-comic/HEAD/weights/wdsr-b-32-x4.h5
--------------------------------------------------------------------------------
/crul_api.txt:
--------------------------------------------------------------------------------
img_file_name=`curl -F 'file=@' https://momodel.cn/pyapi/file/temp_api_file | jq -r '.temp_file_name'`

curl --header "Content-Type: application/json" --request POST --data '{"app": {"input": {"img": {"val": "'${img_file_name}'", "type": "img"}}, "output": {"str": {"type": "str"}}}, "version": "dev"}' https://momodel.cn/pyapi/apps/run/5e267a93d13fba905e3323be
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Automatic colorization for comics using pix2pix and WDSR (Keras), based on the work of user wmylxmj:
https://github.com/wmylxmj/Pix2Pix-Keras
https://github.com/wmylxmj/Anime-Super-Resolution
Thanks to the author for these open-source projects.
Automatic colorization for comics, anime wallpapers, and doujinshi. So far the model has only been trained for doujinshi colorization (about 12 hours of training); the results are decent, and at least better than black and white.
Training examples:
left: grayscale input (gray), middle: prediction (predict), right: original
![example1](/images/1.PNG)
![example2](/images/2.PNG)
![example3](/images/3.PNG)

Colorizing a black-and-white doujinshi with the trained model (gray img, non-resize predict):
original: ![black-and-white doujinshi](/images/9.jpg)
prediction: ![auto-colorized](/images/combine_9.jpg)

how to use:
1. Train the colorization model you need: see wmylxmj's work: https://github.com/wmylxmj/Pix2Pix-Keras
2. Get the colorized image: run predict.py

Project directory:
./weights: stores the trained colorization model and the SR model: wdsr-b-32-x4.h5, discriminator_weights.h5 (optional), and generator_weights.h5
./datasets/OriginalImages: color images used for training
pix2pix.py: pix2pix model file
WDSR.py: WDSR model file
utils.py: data-loading utilities (how the data is loaded, etc.)
prepare.py: data preprocessing to run before training
predict.py: use the trained model to get a colorized image
test.py: colorize a test image by calling the API of my trained model
js_api.txt: how to call this project's API from JavaScript
crul_api.txt: how to call this project's API with curl

Using the project API:
Deploy and test online on the Mo platform: JavaScript, curl, and Python are supported
Workflow: call the API with an image; it returns the colorized image as a base64 string; convert that string back into an image and save it locally.
![step1](/images/test_step1.PNG)
![step2](/images/test_step2.PNG)
--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 24 16:31:16 2019

@author: wmy
"""

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf

original_images_path = r'.\datasets\OriginalImages'
color_images_path = r'.\datasets\ColorImages'
grayscale_images_path = r'.\datasets\GrayscaleImages'
combined_images_path = r'.\datasets\CombinedImages'

resize_height = 256
resize_width = 256

def find_images(path):
    result = []
    for filename in os.listdir(path):
        _, ext = os.path.splitext(filename.lower())
        if ext == ".jpg" or ext == ".png":
            result.append(os.path.join(path, filename))
            pass
        pass
    result.sort()
    return result

if __name__ == '__main__':
    # make sure the output folders exist before saving into them
    for folder in [color_images_path, grayscale_images_path, combined_images_path]:
        os.makedirs(folder, exist_ok=True)
    search_result = find_images(original_images_path)
    for image_path in search_result:
        img_name = image_path[len(original_images_path):]
        img = Image.open(image_path)
        img_color = img.resize((resize_width, resize_height), Image.ANTIALIAS)
        img_color.save(color_images_path + img_name, quality=95)
        print("Info: image '" + color_images_path + img_name + "' saved.")
        img_gray = img_color.convert('L')
        img_gray = img_gray.convert('RGB')
        img_gray.save(grayscale_images_path + img_name, quality=95)
        print("Info: image '" + grayscale_images_path + img_name + "' saved.")
        combined_image = Image.new('RGB', (resize_width*2, resize_height))
        combined_image.paste(img_color, (0, 0, resize_width, resize_height))
        combined_image.paste(img_gray, (resize_width, 0, resize_width*2, resize_height))
        combined_image.save(combined_images_path + img_name, quality=95)
        print("Info: image '" + combined_images_path + img_name + "' saved.")
        pass
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import skimage
import skimage.transform  # not pulled in by `import skimage` in older scikit-image versions
import imageio
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model, load_model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from pix2pix import Pix2Pix
from PIL import Image
import tensorflow as tf
from keras import backend as K
from keras.losses import mean_absolute_error, mean_squared_error
import random
from WDSR import wdsr_b
from utils import DataLoader
import cv2

data_loader = DataLoader()

def combine(bottom_pic, top_pic, alpha, beta, gamma, save_pth):
    # blend the SR output back over the original-size input
    bottom = cv2.imread(bottom_pic)
    top = cv2.imread(top_pic)
    h, w, _ = bottom.shape
    img2 = cv2.resize(top, (w, h), interpolation=cv2.INTER_AREA)
    overlapping = cv2.addWeighted(bottom, alpha, img2, beta, gamma)
    cv2.imwrite(save_pth, overlapping)

def predict_single_image(pix2pix, wdsr, image_path, save_path):
    pix2pix.generator.load_weights('./weights/generator_weights.h5')
    wdsr.load_weights('./weights/wdsr-b-32-x4.h5')
    image_B = imageio.imread(image_path, pilmode='RGB').astype(float)
    image_B = skimage.transform.resize(image_B, (pix2pix.nW, pix2pix.nH))
    images_B = []
    images_B.append(image_B)
    images_B = np.array(images_B)/127.5 - 1.
    generates_A = pix2pix.generator.predict(images_B)
    generate_A = generates_A[0]
    generate_A = np.uint8((np.array(generate_A) * 0.5 + 0.5) * 255)
    generate_A = Image.fromarray(generate_A)
    generated_image = Image.new('RGB', (pix2pix.nW, pix2pix.nH))
    generated_image.paste(generate_A, (0, 0, pix2pix.nW, pix2pix.nH))
    lr = np.asarray(generated_image)
    x = np.array([lr])
    y = wdsr.predict(x)
    y = np.clip(y, 0, 255)
    y = y.astype('uint8')
    sr = Image.fromarray(y[0])
    sr.save(save_path)
    combine(image_path, save_path, 0.5, 0.5, 0, save_path)
    pass

gan = Pix2Pix()
wdsr = wdsr_b(scale=4, num_res_blocks=32)

predict_single_image(gan, wdsr, '1.jpg', 'test_1.jpg')
--------------------------------------------------------------------------------
/WDSR.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import Add, Conv2D, Input, Lambda, Activation
from keras.models import Model

def SubpixelConv2D(scale, **kwargs):
    return Lambda(lambda x: tf.depth_to_space(x, scale), **kwargs)

def Normalization(**kwargs):
    # you can change this if you know the mean of your dataset
    rgb_mean = np.array([0.5, 0.5, 0.5]) * 255
    return Lambda(lambda x: (x - rgb_mean) / 127.5, **kwargs)

def Denormalization(**kwargs):
    # you can change this if you know the mean of your dataset
    rgb_mean = np.array([0.5, 0.5, 0.5]) * 255
    return Lambda(lambda x: x * 127.5 + rgb_mean, **kwargs)

def PadSymmetricInTestPhase():
    pad = Lambda(lambda x: K.in_train_phase(x, tf.pad(x, tf.constant([[0, 0], [2, 2], [2, 2], [0, 0]]), 'SYMMETRIC')))
    pad.uses_learning_phase = True
    return pad

def res_block_a(x_in, num_filters, expansion, kernel_size, scaling):
    x = Conv2D(num_filters * expansion, kernel_size, padding='same')(x_in)
    x = Activation('relu')(x)
    x = Conv2D(num_filters, kernel_size, padding='same')(x)
    x = Add()([x_in, x])
    if scaling:
        x = Lambda(lambda t: t * scaling)(x)
        pass
    return x

def res_block_b(x_in, num_filters, expansion, kernel_size, scaling):
    linear = 0.8
    x = Conv2D(num_filters * expansion, 1, padding='same')(x_in)
    x = Activation('relu')(x)
    x = Conv2D(int(num_filters * linear), 1, padding='same')(x)
    x = Conv2D(num_filters, kernel_size, padding='same')(x)
    x = Add()([x_in, x])
    if scaling:
        x = Lambda(lambda t: t * scaling)(x)
        pass
    return x

def wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block, name=None):
    if scale not in (2, 3, 4):
        raise ValueError("scale must be in (2, 3, 4)")
    x_in = Input(shape=(None, None, 3))
    x = Normalization()(x_in)
    # pad input if in test phase
    x = PadSymmetricInTestPhase()(x)
    # main branch (revise padding)
    m = Conv2D(num_filters, 3, padding='valid')(x)
    for i in range(num_res_blocks):
        m = res_block(m, num_filters, res_block_expansion, kernel_size=3, scaling=res_block_scaling)
    m = Conv2D(3 * scale ** 2, 3, padding='valid', name='conv2d_main_scale_{}'.format(scale))(m)
    m = SubpixelConv2D(scale)(m)
    # skip branch
    s = Conv2D(3 * scale ** 2, 5, padding='valid', name='conv2d_skip_scale_{}'.format(scale))(x)
    s = SubpixelConv2D(scale)(s)
    x = Add()([m, s])
    x = Denormalization()(x)
    if name is None:
        return Model(x_in, x)
    return Model(x_in, x, name=name)

def wdsr_a(scale=2, num_filters=32, num_res_blocks=8, res_block_expansion=4, res_block_scaling=None):
    return wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block_a, name='wdsr-a')

def wdsr_b(scale=2, num_filters=32, num_res_blocks=8, res_block_expansion=6, res_block_scaling=None):
    return wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block_b, name='wdsr-b')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import skimage
import skimage.transform  # not pulled in by `import skimage` in older scikit-image versions
import imageio
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import os
import random
from PIL import Image
from PIL import ImageFilter

class DataLoader(object):

    def __init__(self, dataset_path=r'.\datasets\CombinedImages', scale=4, crop_size=96, name=None):
        self.image_height = 256
        self.image_width = 256
        self.dataset_path = dataset_path
        self.__scale = 4
        self.__crop_size = 96
        self.scale = scale
        self.crop_size = crop_size
        self.name = name
        pass

    def imread(self, path):
        return imageio.imread(path, pilmode='RGB').astype(float)

    def find_images(self, path):
        result = []
        for filename in os.listdir(path):
            _, ext = os.path.splitext(filename.lower())
            if ext == ".jpg" or ext == ".png":
                result.append(os.path.join(path, filename))
                pass
            pass
        result.sort()
        return result

    def load_data(self, batch_size=1, for_testing=False):
        search_result = self.find_images(self.dataset_path)
        batch_images = np.random.choice(search_result, size=batch_size)
        images_A = []
        images_B = []
        for image_path in batch_images:
            combined_image = self.imread(image_path)
            h, w, c = combined_image.shape
            nW = int(w/2)
            image_A, image_B = combined_image[:, :nW, :], combined_image[:, nW:, :]
            image_A = skimage.transform.resize(image_A, (self.image_height, self.image_width))
            image_B = skimage.transform.resize(image_B, (self.image_height, self.image_width))
            if not for_testing and np.random.random() < 0.5:
                # data augmentation: random horizontal flip
                image_A = np.fliplr(image_A)
                image_B = np.fliplr(image_B)
                pass
            images_A.append(image_A)
            images_B.append(image_B)
            pass
        images_A = np.array(images_A)/127.5 - 1.
        images_B = np.array(images_B)/127.5 - 1.
        return images_A, images_B

    def load_batch(self, batch_size=1, for_testing=False):
        search_result = self.find_images(self.dataset_path)
        self.n_complete_batches = int(len(search_result) / batch_size)
        for i in range(self.n_complete_batches):
            batch = search_result[i*batch_size:(i+1)*batch_size]
            images_A, images_B = [], []
            for image_path in batch:
                combined_image = self.imread(image_path)
                h, w, c = combined_image.shape
                nW = int(w/2)
                image_A = combined_image[:, :nW, :]
                image_B = combined_image[:, nW:, :]
                image_A = skimage.transform.resize(image_A, (self.image_height, self.image_width))
                image_B = skimage.transform.resize(image_B, (self.image_height, self.image_width))
                if not for_testing and np.random.random() > 0.5:
                    # data augmentation: random horizontal flip
                    image_A = np.fliplr(image_A)
                    image_B = np.fliplr(image_B)
                    pass
                images_A.append(image_A)
                images_B.append(image_B)
                pass
            images_A = np.array(images_A)/127.5 - 1.
            images_B = np.array(images_B)/127.5 - 1.
            yield images_A, images_B


    @property
    def scale(self):
        return self.__scale

    @scale.setter
    def scale(self, value):
        if not isinstance(value, int):
            raise ValueError("scale must be int")
        elif value <= 0:
            raise ValueError("scale must be > 0")
        else:
            self.__scale = value
            pass
        pass

    @property
    def crop_size(self):
        return self.__crop_size

    @crop_size.setter
    def crop_size(self, value):
        if not isinstance(value, int):
            raise ValueError("crop size must be int")
        elif value <= 0:
            raise ValueError("crop size must be > 0")
        else:
            self.__crop_size = value
            pass
        pass

    def imread_pil(self, path):
        # renamed from imread: the original redefined imread and shadowed the
        # ndarray version used by load_data / load_batch above
        return Image.open(path)

    def resize(self, image, size):
        resamples = [Image.NEAREST, Image.BILINEAR, Image.HAMMING, \
                     Image.BICUBIC, Image.LANCZOS]
        resample = random.choice(resamples)
        return image.resize(size, resample=resample)

    def gaussianblur(self, image, radius=2):
        return image.filter(ImageFilter.GaussianBlur(radius=radius))

    def medianfilter(self, image, size=3):
        return image.filter(ImageFilter.MedianFilter(size=size))

    def downsampling(self, image):
        resize = (image.size[0]//self.scale, image.size[1]//self.scale)
        hidden_scale = random.uniform(1, 3)
        hidden_resize = (int(resize[0]/hidden_scale), int(resize[1]/hidden_scale))
        radius = random.uniform(1, 3)
        image = self.gaussianblur(image, radius)
        image = self.resize(image, hidden_resize)
        image = self.resize(image, resize)
        return image

    def search(self, setpath):
        results = []
        files = os.listdir(setpath)
        for file in files:
            path = os.path.join(setpath, file)
            results.append(path)
            pass
        return results

    def rotate(self, lr, hr):
        angle = random.choice([0, 90, 180, 270])
        lr = lr.rotate(angle, expand=True)
        hr = hr.rotate(angle, expand=True)
        return lr, hr

    def flip(self, lr, hr):
        mode = random.choice([0, 1, 2, 3])
        if mode == 0:
            pass
        elif mode == 1:
            lr = lr.transpose(Image.FLIP_LEFT_RIGHT)
            hr = hr.transpose(Image.FLIP_LEFT_RIGHT)
            pass
        elif mode == 2:
            lr = lr.transpose(Image.FLIP_TOP_BOTTOM)
            hr = hr.transpose(Image.FLIP_TOP_BOTTOM)
            pass
        elif mode == 3:
            lr = lr.transpose(Image.FLIP_LEFT_RIGHT)
            hr = hr.transpose(Image.FLIP_LEFT_RIGHT)
            lr = lr.transpose(Image.FLIP_TOP_BOTTOM)
            hr = hr.transpose(Image.FLIP_TOP_BOTTOM)
            pass
        return lr, hr

    def crop(self, lr, hr):
        hr_crop_size = self.crop_size
        lr_crop_size = hr_crop_size//self.scale
        lr_w = np.random.randint(lr.size[0] - lr_crop_size + 1)
        lr_h = np.random.randint(lr.size[1] - lr_crop_size + 1)
        hr_w = lr_w * self.scale
        hr_h = lr_h * self.scale
        lr = lr.crop([lr_w, lr_h, lr_w+lr_crop_size, lr_h+lr_crop_size])
        hr = hr.crop([hr_w, hr_h, hr_w+hr_crop_size, hr_h+hr_crop_size])
        return lr, hr

    def pair(self, fp):
        hr = self.imread_pil(fp)
        lr = self.downsampling(hr)
        lr, hr = self.rotate(lr, hr)
        lr, hr = self.flip(lr, hr)
        lr, hr = self.crop(lr, hr)
        lr = np.asarray(lr)
        hr = np.asarray(hr)
        return lr, hr

    def batches(self, setpath="datasets/train", batch_size=16, complete_batch_only=False):
        images = self.search(setpath)
        sizes = []
        for image in images:
            array = plt.imread(image)
            sizes.append(array.shape[0])
            sizes.append(array.shape[1])
            pass
        crop_size_max = min(sizes)
        crop_size = min(crop_size_max, self.crop_size)
        if self.crop_size != crop_size:
            self.crop_size = crop_size
            print("Info: crop size adjusted to " + str(self.crop_size) + ".")
            pass
        np.random.shuffle(images)
        n_complete_batches = int(len(images)/batch_size)
        self.n_batches = int(len(images) / batch_size)
        have_res_batch = (len(images)/batch_size) > n_complete_batches
        if have_res_batch and complete_batch_only == False:
            self.n_batches += 1
            pass
        for i in range(n_complete_batches):
            batch = images[i*batch_size:(i+1)*batch_size]
            lrs, hrs = [], []
            for image in batch:
                lr, hr = self.pair(image)
                lrs.append(lr)
                hrs.append(hr)
                pass
            lrs = np.array(lrs)
            hrs = np.array(hrs)
            yield lrs, hrs
        if self.n_batches > n_complete_batches:
            batch = images[n_complete_batches*batch_size:]
            lrs, hrs = [], []
            for image in batch:
                lr, hr = self.pair(image)
                lrs.append(lr)
                hrs.append(hr)
                pass
            lrs = np.array(lrs)
            hrs = np.array(hrs)
            yield lrs, hrs
            pass

    pass
--------------------------------------------------------------------------------
/pix2pix.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 24 19:27:09 2019

@author: wmy
"""

import scipy
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import tensorflow as tf
from keras import backend as K
from keras.layers import Add, Conv2D, Input, Lambda, Activation
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from utils import DataLoader
from PIL import Image

class Pix2Pix(object):

    def __init__(self):
        self.nH = 256
        self.nW = 256
        self.nC = 3
        self.data_loader = DataLoader()
        self.image_shape = (self.nH, self.nW, self.nC)
        self.image_A = Input(shape=self.image_shape)
        self.image_B = Input(shape=self.image_shape)
        self.discriminator = self.create_discriminator()
        self.discriminator.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
        self.generator = self.create_generator()
        self.fake_A = self.generator(self.image_B)
        self.discriminator.trainable = False
        self.valid = self.discriminator([self.fake_A, self.image_B])
        self.combined = Model(inputs=[self.image_A, self.image_B], outputs=[self.valid, self.fake_A])
        self.combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=Adam(0.0002, 0.5))
        # Calculate output shape of D (PatchGAN)
        self.disc_patch = (int(self.nH/2**4), int(self.nW/2**4), 1)
        pass

    def create_generator(self):
        # layer 0
        d0 = Input(shape=self.image_shape)
        # layer 1
        d1 = Conv2D(filters=64, kernel_size=4, strides=2, padding='same')(d0)
        d1 = LeakyReLU(alpha=0.2)(d1)
        # layer 2
        d2 = Conv2D(filters=128, kernel_size=4, strides=2, padding='same')(d1)
        d2 = LeakyReLU(alpha=0.2)(d2)
        d2 = BatchNormalization(momentum=0.8)(d2)
        # layer 3
        d3 = Conv2D(filters=256, kernel_size=4, strides=2, padding='same')(d2)
        d3 = LeakyReLU(alpha=0.2)(d3)
        d3 = BatchNormalization(momentum=0.8)(d3)
        # layer 4
        d4 = Conv2D(filters=512, kernel_size=4, strides=2, padding='same')(d3)
        d4 = LeakyReLU(alpha=0.2)(d4)
        d4 = BatchNormalization(momentum=0.8)(d4)
        # layer 5
        d5 = Conv2D(filters=512, kernel_size=4, strides=2, padding='same')(d4)
        d5 = LeakyReLU(alpha=0.2)(d5)
        d5 = BatchNormalization(momentum=0.8)(d5)
        # layer 6
        d6 = Conv2D(filters=512, kernel_size=4, strides=2, padding='same')(d5)
        d6 = LeakyReLU(alpha=0.2)(d6)
        d6 = BatchNormalization(momentum=0.8)(d6)
        # layer 7
        d7 = Conv2D(filters=512, kernel_size=4, strides=2, padding='same')(d6)
        d7 = LeakyReLU(alpha=0.2)(d7)
        d7 = BatchNormalization(momentum=0.8)(d7)
        # layer 6
        u6 = UpSampling2D(size=2)(d7)
        u6 = Conv2D(filters=512, kernel_size=4, strides=1, padding='same', activation='relu')(u6)
        u6 = BatchNormalization(momentum=0.8)(u6)
        u6 = Concatenate()([u6, d6])
        # layer 5
        u5 = UpSampling2D(size=2)(u6)
        u5 = Conv2D(filters=512, kernel_size=4, strides=1, padding='same', activation='relu')(u5)
        u5 = BatchNormalization(momentum=0.8)(u5)
        u5 = Concatenate()([u5, d5])
        # layer 4
        u4 = UpSampling2D(size=2)(u5)
        u4 = Conv2D(filters=512, kernel_size=4, strides=1, padding='same', activation='relu')(u4)
        u4 = BatchNormalization(momentum=0.8)(u4)
        u4 = Concatenate()([u4, d4])
        # layer 3
        u3 = UpSampling2D(size=2)(u4)
        u3 = Conv2D(filters=256, kernel_size=4, strides=1, padding='same', activation='relu')(u3)
        u3 = BatchNormalization(momentum=0.8)(u3)
        u3 = Concatenate()([u3, d3])
        # layer 2
        u2 = UpSampling2D(size=2)(u3)
        u2 = Conv2D(filters=128, kernel_size=4, strides=1, padding='same', activation='relu')(u2)
        u2 = BatchNormalization(momentum=0.8)(u2)
        u2 = Concatenate()([u2, d2])
        # layer 1
        u1 = UpSampling2D(size=2)(u2)
        u1 = Conv2D(filters=64, kernel_size=4, strides=1, padding='same', activation='relu')(u1)
        u1 = BatchNormalization(momentum=0.8)(u1)
        u1 = Concatenate()([u1, d1])
        # layer 0
        u0 = UpSampling2D(size=2)(u1)
        u0 = Conv2D(self.nC, kernel_size=4, strides=1, padding='same', activation='tanh')(u0)
        return Model(d0, u0)

    def create_discriminator(self):
        # layer 0
        image_A = Input(shape=self.image_shape)
        image_B = Input(shape=self.image_shape)
        combined_images = Concatenate(axis=-1)([image_A, image_B])
        # layer 1
        d1 = Conv2D(filters=64, kernel_size=4, strides=2, padding='same')(combined_images)
        d1 = LeakyReLU(alpha=0.2)(d1)
        # layer 2
        d2 = Conv2D(filters=128, kernel_size=4, strides=2, padding='same')(d1)
        d2 = LeakyReLU(alpha=0.2)(d2)
        d2 = BatchNormalization(momentum=0.8)(d2)
        # layer 3
        d3 = Conv2D(filters=128, kernel_size=4, strides=2, padding='same')(d2)
        d3 = LeakyReLU(alpha=0.2)(d3)
        d3 = BatchNormalization(momentum=0.8)(d3)
        # layer 4
        d4 = Conv2D(filters=128, kernel_size=4, strides=2, padding='same')(d3)
        d4 = LeakyReLU(alpha=0.2)(d4)
        d4 = BatchNormalization(momentum=0.8)(d4)
        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        return Model([image_A, image_B], validity)

    def train(self, epochs, batch_size=1, sample_interval=50, load_pretrained=False):
        if load_pretrained:
            print('Info: weights loaded.')
            self.generator.load_weights('./weights/generator_weights.h5')
            self.discriminator.load_weights('./weights/discriminator_weights.h5')
            pass
        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        for epoch in range(epochs):
            for batch_i, (images_A, images_B) in enumerate(self.data_loader.load_batch(batch_size)):
                # Condition on B and generate a translated version
                fake_A = self.generator.predict(images_B)
                # Train the discriminators (original images = real / generated = fake)
                d_loss_real = self.discriminator.train_on_batch([images_A, images_B], valid)
                d_loss_fake = self.discriminator.train_on_batch([fake_A, images_B], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                # Train the generators
                g_loss = self.combined.train_on_batch([images_A, images_B], [valid, images_A])
                # Plot the progress
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f]" %
                      (epoch+1, epochs, batch_i+1, self.data_loader.n_complete_batches,
                       d_loss[0], 100*d_loss[1], g_loss[0]))
                # If at save interval => save generated image samples
                if (batch_i+1) % sample_interval == 0:
                    self.save_sample_images(epoch+1, batch_i+1)
                    pass
                if (batch_i+1) % 500 == 0:
                    self.generator.save_weights('./weights/generator_weights.h5')
                    self.discriminator.save_weights('./weights/discriminator_weights.h5')
                    print('Info: weights saved.')
                    pass
                pass
            if (epoch+1) % 10 == 0:
                self.generator.save_weights('./weights/generator_weights.h5')
                self.discriminator.save_weights('./weights/discriminator_weights.h5')
                print('Info: weights saved.')
                pass
            pass
        self.generator.save_weights('./weights/generator_weights.h5')
        self.discriminator.save_weights('./weights/discriminator_weights.h5')
        print('Info: weights saved.')
        pass

    def save_sample_images(self, epoch, batch_i, save_dir=r'.\outputs'):
        batch_size = 3
        images_A, images_B = self.data_loader.load_data(batch_size=batch_size, for_testing=True)
        fake_A = self.generator.predict(images_B)
        generated_image = Image.new('RGB', (self.nW*3, self.nH*batch_size))
        for b in range(batch_size):
            image_A = np.uint8((np.array(images_A[b]) * 0.5 + 0.5) * 255)
            image_B = np.uint8((np.array(images_B[b]) * 0.5 + 0.5) * 255)
            image_fake_A = np.uint8((np.array(fake_A[b]) * 0.5 + 0.5) * 255)
            image_A = Image.fromarray(image_A)
            image_B = Image.fromarray(image_B)
            image_fake_A = Image.fromarray(image_fake_A)
            generated_image.paste(image_B, (0, b*self.nH, self.nW, (b+1)*self.nH))
            generated_image.paste(image_fake_A, (self.nW, b*self.nH, self.nW*2, (b+1)*self.nH))
            generated_image.paste(image_A, (self.nW*2, b*self.nH, self.nW*3, (b+1)*self.nH))
            pass
        generated_image.save(save_dir + "/G_%d_%d.jpg" % (epoch, batch_i), quality=95)
        pass

    pass
--------------------------------------------------------------------------------
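
The API workflow in README.md returns the colorized image as a base64 string, which the caller converts back into an image file. A minimal, self-contained sketch of that last step (the function name and the dummy payload here are illustrative, not part of the project; the actual JSON field holding the string depends on the Mo platform response shown in crul_api.txt):

```python
import base64

def save_base64_image(b64_str, out_path):
    """Decode a base64-encoded image string and write it to out_path.
    Returns the number of bytes written."""
    data = base64.b64decode(b64_str)
    with open(out_path, "wb") as f:
        f.write(data)
    return len(data)

# round-trip demo: dummy bytes stand in for the real JPEG data the API returns
raw = b"\xff\xd8\xff\xe0 fake jpeg bytes"
payload = base64.b64encode(raw).decode("ascii")
n = save_base64_image(payload, "colored.jpg")
```

Fetching the base64 string itself is done with the curl calls in crul_api.txt or the JavaScript in js_api.txt; this snippet only covers the decode-and-save step.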