├── .gitignore ├── README.md ├── augment_image.py ├── checkpoint └── .gitignore ├── data ├── prepared_data │ └── .gitignore └── raw_data │ └── .gitignore ├── dataloader.py ├── outpaint.ipynb ├── prepare_data.py ├── prepare_data.sh ├── requirements.txt └── saved_images └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras implementation of Image OutPainting 2 | 3 | This is an implementation of the [Painting Outside the Box: Image Outpainting](https://cs230.stanford.edu/projects_spring_2018/posters/8265861.pdf) paper from Stanford University. 4 | Some changes have been made to work with 256*256 images: 5 | - Added an identity loss, i.e. from the generated image to the original image 6 | - Removed patches from training data. (training pipeline) 7 | - Replaced masking with cropping. (training pipeline) 8 | - Added convolution layers. 9 | 10 | ## Results 11 | The model was trained on [3500 scraped beach images](https://drive.google.com/open?id=1hKIn-Z8Uf3voESbJZVsapLHESPabjjrb) with augmentation, totalling up to 10500 images, for 25 epochs. 12 | ![Demo](https://i.imgur.com/ZHtoeDF.jpg) 13 | 14 | #### Recursive painting 15 | ![Demo](http://i.imgur.com/pDUpzcY.jpg) 16 | 17 | ### Install Requirements 18 | ``` 19 | sudo apt-get install curl 20 | sudo pip3 install -r requirements.txt 21 | ``` 22 | 23 | ## Get Started 24 | 25 | 1. Prepare Data: 26 | ```sh 27 | # Downloads the beach data and converts it to NumPy batch data 28 | # saves the NumPy batch data to 'data/prepared_data/' 29 | sh prepare_data.sh 30 | ``` 31 | 2. Build Model 32 | * To build the model from scratch, you can directly run 'outpaint.ipynb' 33 |
OR
34 | * You can [Download](https://drive.google.com/open?id=1MfXsRwjx5CTRGBoLx154S0h-Q3rIUNH0) my trained model and move it to 'checkpoint/' and run it. 35 | 36 | ## References 37 | * [Painting Outside the Box: Image Outpainting](https://cs230.stanford.edu/projects_spring_2018/posters/8265861.pdf) 38 | -------------------------------------------------------------------------------- /augment_image.py: -------------------------------------------------------------------------------- 1 | import imgaug as ia 2 | from imgaug import augmenters as iaa 3 | import numpy as np 4 | import random 5 | 6 | 7 | brightness = iaa.Add((-7, 7), per_channel=0.5) 8 | contrast = iaa.ContrastNormalization((0.8, 1.6), per_channel=0.5) 9 | perspective = iaa.PerspectiveTransform(scale=(0.025, 0.090)) 10 | gaussian_noise = iaa.AdditiveGaussianNoise(loc=0, scale=(0.03*255, 0.04*255), per_channel=0.5) 11 | crop = iaa.Crop(px=(0, 25)) 12 | 13 | 14 | def aug_image(my_image): 15 | image = my_image.copy() 16 | if random.choice([0,0,1]): 17 | image = perspective.augment_image(image) 18 | if random.choice([0,0,1]): 19 | image = brightness.augment_image(image) 20 | if random.choice([0,0,1]): 21 | image = contrast.augment_image(image) 22 | if random.choice([0,0,1]): 23 | image = gaussian_noise.augment_image(image) 24 | if random.choice([0,0,1]): 25 | image = crop.augment_image(image) 26 | return image 27 | 28 | 29 | if __name__ == "__main__": 30 | import cv2 31 | image = cv2.imread('/home/ben/work/compare_myntra/test_image/test_images/taken_15324282418.jpg') 32 | aug_images = aug_image(image) 33 | aug_images = [aug_images] 34 | print(len(aug_images)) 35 | image = cv2.resize(image, (600,600)) 36 | image_1 = cv2.resize(aug_images[0], (600,600)) 37 | cv2.imshow('1', image) 38 | cv2.waitKey(0) 39 | cv2.imshow('2', image_1) 40 | cv2.waitKey(0) 41 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: 
-------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /data/prepared_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /data/raw_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from random import shuffle 4 | 5 | 6 | DATA_PATH = "data/prepared_data/train" 7 | TEST_PATH = "data/prepared_data/test" 8 | 9 | 10 | class Data(): 11 | 12 | def __init__(self): 13 | self.X_counter = 0 14 | self.file_counter = 0 15 | self.files = os.listdir(DATA_PATH) 16 | self.files = [file for file in self.files if '.npy' in file] 17 | shuffle(self.files) 18 | self._load_data() 19 | 20 | def _load_data(self): 21 | datas = np.load(os.path.join(DATA_PATH, self.files[self.file_counter])) 22 | self.X = [] 23 | for data in datas: 24 | self.X.append(data) 25 | shuffle(self.X) 26 | self.X = np.asarray(self.X) 27 | self.file_counter += 1 28 | 29 | def get_data(self, batch_size): 30 | if self.X_counter >= len(self.X): 31 | if self.file_counter > len(self.files) - 1: 32 | print("Data exhausted, Re Initialize") 33 | self.__init__() 34 | return None 35 | else: 36 | self._load_data() 37 | self.X_counter = 0 38 | 39 | if self.X_counter + batch_size <= len(self.X): 40 | remaining = len(self.X) - (self.X_counter) 41 | X = self.X[self.X_counter: self.X_counter + batch_size] 42 | else: 43 | X = self.X[self.X_counter: ] 44 | 45 | self.X_counter += batch_size 46 | return X 47 | 48 | 49 | class TestData(): 50 | 51 | def __init__(self): 52 | self.X_counter = 0 53 | self.file_counter = 0 
54 | self.files = os.listdir(TEST_PATH) 55 | self.files = [file for file in self.files if '.npy' in file] 56 | shuffle(self.files) 57 | self._load_data() 58 | 59 | def _load_data(self): 60 | datas = np.load(os.path.join(TEST_PATH, self.files[self.file_counter])) 61 | self.X = [] 62 | for data in datas: 63 | self.X.append(data) 64 | shuffle(self.X) 65 | self.X = np.asarray(self.X) 66 | self.file_counter += 1 67 | 68 | def get_data(self, batch_size): 69 | if self.X_counter >= len(self.X): 70 | if self.file_counter > len(self.files) - 1: 71 | print("Data exhausted, Re Initialize") 72 | self.__init__() 73 | return None 74 | else: 75 | self._load_data() 76 | self.X_counter = 0 77 | 78 | if self.X_counter + batch_size <= len(self.X): 79 | remaining = len(self.X) - (self.X_counter) 80 | X = self.X[self.X_counter: self.X_counter + batch_size] 81 | else: 82 | X = self.X[self.X_counter: ] 83 | 84 | self.X_counter += batch_size 85 | return X 86 | -------------------------------------------------------------------------------- /outpaint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Out Paint" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from keras.layers.convolutional import Conv2D, AtrousConvolution2D\n", 17 | "from keras.layers import Activation, Dense, Input, Conv2DTranspose, Dense, Flatten\n", 18 | "from keras.layers import ReLU, Dropout, Concatenate, BatchNormalization, Reshape\n", 19 | "from keras.layers.advanced_activations import LeakyReLU\n", 20 | "from keras.models import Model, model_from_json\n", 21 | "from keras.optimizers import Adam\n", 22 | "from keras.layers.convolutional import UpSampling2D\n", 23 | "import keras.backend as K\n", 24 | "import tensorflow as tf\n", 25 | "\n", 26 | "import os\n", 27 | "import numpy as np\n", 28 | 
"import PIL\n", 29 | "import cv2\n", 30 | "import IPython.display\n", 31 | "from IPython.display import clear_output\n", 32 | "from datetime import datetime\n", 33 | "from dataloader import Data, TestData" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "try:\n", 43 | " from keras_contrib.layers.normalization import InstanceNormalization\n", 44 | "except Exception:\n", 45 | " from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# Initialize dataloader\n", 55 | "data = Data()\n", 56 | "test_data = Data()" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# Saves Model in every N minutes\n", 66 | "TIME_INTERVALS = 2\n", 67 | "SHOW_SUMMARY = True\n", 68 | "\n", 69 | "INPUT_SHAPE = (256, 256, 3)\n", 70 | "EPOCHS = 500\n", 71 | "BATCH = 1\n", 72 | "\n", 73 | "# 25% i.e 64 width size will be mask from both side\n", 74 | "MASK_PERCENTAGE = .25\n", 75 | "\n", 76 | "EPSILON = 1e-9\n", 77 | "ALPHA = 0.0004\n", 78 | "\n", 79 | "CHECKPOINT = \"checkpoint/\"\n", 80 | "SAVED_IMAGES = \"saved_images/\"" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Models" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Discriminator" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def dcrm_loss(y_true, y_pred):\n", 104 | " return -tf.reduce_mean(tf.log(tf.maximum(y_true, EPSILON)) + tf.log(tf.maximum(1. 
- y_pred, EPSILON)))\n", 105 | "\n", 106 | "d_input_shape = (INPUT_SHAPE[0], int(INPUT_SHAPE[1] * (MASK_PERCENTAGE *2)), INPUT_SHAPE[2])\n", 107 | "d_dropout = 0.25\n", 108 | "DCRM_OPTIMIZER = Adam(0.0001, 0.5)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "def d_build_conv(layer_input, filter_size, kernel_size=4, strides=2, activation='leakyrelu', dropout_rate=d_dropout, norm=True):\n", 118 | " c = Conv2D(filter_size, kernel_size=kernel_size, strides=strides, padding='same')(layer_input)\n", 119 | " if activation == 'leakyrelu':\n", 120 | " c = LeakyReLU(alpha=0.2)(c)\n", 121 | " if dropout_rate:\n", 122 | " c = Dropout(dropout_rate)(c)\n", 123 | " if norm == 'inst':\n", 124 | " c = InstanceNormalization()(c)\n", 125 | " return c\n", 126 | "\n", 127 | "\n", 128 | "def build_discriminator():\n", 129 | " d_input = Input(shape=d_input_shape)\n", 130 | " d = d_build_conv(d_input, 32, 5,strides=2, norm=False)\n", 131 | "\n", 132 | " d = d_build_conv(d, 64, 5, strides=2)\n", 133 | " d = d_build_conv(d, 64, 5, strides=2)\n", 134 | " d = d_build_conv(d, 128, 5, strides=2)\n", 135 | " d = d_build_conv(d, 128, 5, strides=2)\n", 136 | " \n", 137 | " flat = Flatten()(d)\n", 138 | " fc1 = Dense(1024, activation='relu')(flat)\n", 139 | " d_output = Dense(1, activation='sigmoid')(fc1)\n", 140 | " \n", 141 | " return Model(d_input, d_output)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# Discriminator initialization\n", 151 | "DCRM = build_discriminator()\n", 152 | "DCRM.compile(loss=dcrm_loss, optimizer=DCRM_OPTIMIZER)\n", 153 | "if SHOW_SUMMARY:\n", 154 | " DCRM.summary()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "### Generator Model" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 
null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "def gen_loss(y_true, y_pred):\n", 171 | " G_MSE_loss = K.mean(K.square(y_pred - y_true))\n", 172 | " return G_MSE_loss - ALPHA * tf.reduce_mean(tf.log(tf.maximum(y_pred, EPSILON)))\n", 173 | "\n", 174 | "g_input_shape = (INPUT_SHAPE[0], int(INPUT_SHAPE[1] * (MASK_PERCENTAGE *2)), INPUT_SHAPE[2])\n", 175 | "g_dropout = 0.25\n", 176 | "GEN_OPTIMIZER = Adam(0.001, 0.5)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "def g_build_conv(layer_input, filter_size, kernel_size=4, strides=2, activation='leakyrelu', dropout_rate=g_dropout, norm='inst', dilation=1):\n", 186 | " c = AtrousConvolution2D(filter_size, kernel_size=kernel_size, strides=strides,atrous_rate=(dilation,dilation), padding='same')(layer_input)\n", 187 | " if activation == 'leakyrelu':\n", 188 | " c = ReLU()(c)\n", 189 | " if dropout_rate:\n", 190 | " c = Dropout(dropout_rate)(c)\n", 191 | " if norm == 'inst':\n", 192 | " c = InstanceNormalization()(c)\n", 193 | " return c\n", 194 | "\n", 195 | "\n", 196 | "def g_build_deconv(layer_input, filter_size, kernel_size=3, strides=2, activation='relu', dropout=0):\n", 197 | " d = Conv2DTranspose(filter_size, kernel_size=kernel_size, strides=strides, padding='same')(layer_input)\n", 198 | " if activation == 'relu':\n", 199 | " d = ReLU()(d)\n", 200 | " return d\n", 201 | "\n", 202 | "\n", 203 | "def build_generator():\n", 204 | " g_input = Input(shape=g_input_shape)\n", 205 | " \n", 206 | " g1 = g_build_conv(g_input, 64, 5, strides=1)\n", 207 | " g2 = g_build_conv(g1, 128, 4, strides=2)\n", 208 | " g3 = g_build_conv(g2, 256, 4, strides=2)\n", 209 | "\n", 210 | " g4 = g_build_conv(g3, 512, 4, strides=1)\n", 211 | " g5 = g_build_conv(g4, 512, 4, strides=1)\n", 212 | " \n", 213 | " g6 = g_build_conv(g5, 512, 4, strides=1, dilation=2)\n", 214 | " g7 = g_build_conv(g6, 512, 4, strides=1, 
dilation=4)\n", 215 | " g8 = g_build_conv(g7, 512, 4, strides=1, dilation=8)\n", 216 | " g9 = g_build_conv(g8, 512, 4, strides=1, dilation=16)\n", 217 | " \n", 218 | " g10 = g_build_conv(g9, 512, 4, strides=1)\n", 219 | " g11 = g_build_conv(g10, 512, 4, strides=1)\n", 220 | " \n", 221 | " g12 = g_build_deconv(g11, 256, 4, strides=2)\n", 222 | " g13 = g_build_deconv(g12, 128, 4, strides=2)\n", 223 | " \n", 224 | " g14 = g_build_conv(g13, 128, 4, strides=1)\n", 225 | " g15 = g_build_conv(g14, 64, 4, strides=1)\n", 226 | " \n", 227 | " g_output = AtrousConvolution2D(3, kernel_size=4, strides=(1,1), activation='tanh',padding='same', atrous_rate=(1,1))(g15)\n", 228 | " \n", 229 | " return Model(g_input, g_output)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# Generator Initialization\n", 239 | "GEN = build_generator()\n", 240 | "GEN.compile(loss=gen_loss, optimizer=GEN_OPTIMIZER)\n", 241 | "if SHOW_SUMMARY:\n", 242 | " GEN.summary()" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "### Combined Model" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "IMAGE = Input(shape=g_input_shape)\n", 259 | "DCRM.trainable = False\n", 260 | "GENERATED_IMAGE = GEN(IMAGE)\n", 261 | "CONF_GENERATED_IMAGE = DCRM(GENERATED_IMAGE)\n", 262 | "\n", 263 | "COMBINED = Model(IMAGE, [CONF_GENERATED_IMAGE, GENERATED_IMAGE])\n", 264 | "COMBINED.compile(loss=['mse', 'mse'], optimizer=GEN_OPTIMIZER)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": {}, 270 | "source": [ 271 | "### Masking and De-Masking" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "def mask_width(img):\n", 281 | " image = img.copy()\n", 282 | " 
height = image.shape[0]\n", 283 | " width = image.shape[1]\n", 284 | " new_width = int(width * MASK_PERCENTAGE)\n", 285 | " mask = np.ones([height, new_width, 3])\n", 286 | " missing_x = img[:, :new_width]\n", 287 | " missing_y = img[:, width - new_width:]\n", 288 | " missing_part = np.concatenate((missing_x, missing_y), axis=1)\n", 289 | " image = image[:, :width - new_width]\n", 290 | " image = image[:, new_width:]\n", 291 | " return image, missing_part\n", 292 | "\n", 293 | "\n", 294 | "def get_masked_images(images):\n", 295 | " mask_images = []\n", 296 | " missing_images = []\n", 297 | " for image in images:\n", 298 | " mask_image, missing_image = mask_width(image)\n", 299 | " mask_images.append(mask_image)\n", 300 | " missing_images.append(missing_image)\n", 301 | " return np.array(mask_images), np.array(missing_images)\n", 302 | "\n", 303 | "\n", 304 | "def get_demask_images(original_images, generated_images):\n", 305 | " demask_images = []\n", 306 | " for o_image, g_image in zip(original_images, generated_images):\n", 307 | " width = g_image.shape[1] // 2\n", 308 | " x_image = g_image[:, :width]\n", 309 | " y_image = g_image[:, width:]\n", 310 | " o_image = np.concatenate((x_image,o_image, y_image), axis=1)\n", 311 | " demask_images.append(o_image)\n", 312 | " return np.asarray(demask_images)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "# Masking, Demasking example\n", 322 | "# Note: IPython display gives false colors.\n", 323 | "x = data.get_data(1)\n", 324 | "\n", 325 | "# a will be the input and b will be the output for the model\n", 326 | "a, b = get_masked_images(x)\n", 327 | "border = np.ones([x[0].shape[0], 10, 3]).astype(np.uint8)\n", 328 | "print('After masking')\n", 329 | "print('\\tOriginal Image\\t\\t\\t a \\t\\t b')\n", 330 | "image = np.concatenate((border, x[0],border,a[0],border, b[0], border), axis=1)\n", 331 | 
"IPython.display.display(PIL.Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))\n", 332 | "\n", 333 | "print(\"After desmasking: 'b/2' + a + 'b/2' \")\n", 334 | "c = get_demask_images(a,b)\n", 335 | "IPython.display.display(PIL.Image.fromarray(cv2.cvtColor(c[0], cv2.COLOR_BGR2RGB)))" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": {}, 341 | "source": [ 342 | "### Utilities\n", 343 | "1. Save Model\n", 344 | "2. Load Model\n", 345 | "3. Save Image\n", 346 | "4. Save Log" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "def save_model():\n", 356 | " global DCRM, GEN\n", 357 | " models = [DCRM, GEN]\n", 358 | " model_names = ['DCRM','GEN']\n", 359 | "\n", 360 | " for model, model_name in zip(models, model_names):\n", 361 | " model_path = CHECKPOINT + \"%s.json\" % model_name\n", 362 | " weights_path = CHECKPOINT + \"/%s.hdf5\" % model_name\n", 363 | " options = {\"file_arch\": model_path, \n", 364 | " \"file_weight\": weights_path}\n", 365 | " json_string = model.to_json()\n", 366 | " open(options['file_arch'], 'w').write(json_string)\n", 367 | " model.save_weights(options['file_weight'])\n", 368 | " print(\"Saved Model\")\n", 369 | " \n", 370 | " \n", 371 | "def load_model():\n", 372 | " # Checking if all the model exists\n", 373 | " model_names = ['DCRM', 'GEN']\n", 374 | " files = os.listdir(CHECKPOINT)\n", 375 | " for model_name in model_names:\n", 376 | " if model_name+\".json\" not in files or\\\n", 377 | " model_name+\".hdf5\" not in files:\n", 378 | " print(\"Models not Found\")\n", 379 | " return\n", 380 | " global DCRM, GEN, COMBINED, IMAGE, GENERATED_IMAGE, CONF_GENERATED_IMAGE\n", 381 | " \n", 382 | " # load DCRM Model\n", 383 | " model_path = CHECKPOINT + \"%s.json\" % 'DCRM'\n", 384 | " weight_path = CHECKPOINT + \"%s.hdf5\" % 'DCRM'\n", 385 | " with open(model_path, 'r') as f:\n", 386 | " DCRM = 
model_from_json(f.read())\n", 387 | " DCRM.load_weights(weight_path)\n", 388 | " DCRM.compile(loss=dcrm_loss, optimizer=DCRM_OPTIMIZER)\n", 389 | " \n", 390 | " #load GEN Model\n", 391 | " model_path = CHECKPOINT + \"%s.json\" % 'GEN'\n", 392 | " weight_path = CHECKPOINT + \"%s.hdf5\" % 'GEN'\n", 393 | " with open(model_path, 'r') as f:\n", 394 | " GEN = model_from_json(f.read(), custom_objects={'InstanceNormalization': InstanceNormalization()})\n", 395 | " GEN.load_weights(weight_path)\n", 396 | " \n", 397 | " # Combined Model\n", 398 | " DCRM.trainable = False\n", 399 | " IMAGE = Input(shape=g_input_shape)\n", 400 | " GENERATED_IMAGE = GEN(IMAGE)\n", 401 | " CONF_GENERATED_IMAGE = DCRM(GENERATED_IMAGE)\n", 402 | "\n", 403 | " COMBINED = Model(IMAGE, [CONF_GENERATED_IMAGE, GENERATED_IMAGE])\n", 404 | " COMBINED.compile(loss=['mse', 'mse'], optimizer=GEN_OPTIMIZER)\n", 405 | " \n", 406 | " print(\"loaded model\")\n", 407 | " \n", 408 | " \n", 409 | "def save_image(epoch, steps):\n", 410 | " train_image = test_data.get_data(1)\n", 411 | " if train_image is None:\n", 412 | " train_image = test_data.get_data(1)\n", 413 | " \n", 414 | " test_image = data.get_data(1)\n", 415 | " if test_image is None:\n", 416 | " test_image = test_data.get_data(1)\n", 417 | " \n", 418 | " for nc, original in enumerate([train_image, test_image]):\n", 419 | " if nc:\n", 420 | " print(\"Predicting with train image\")\n", 421 | " else:\n", 422 | " print(\"Predicting with test image\")\n", 423 | " \n", 424 | " mask_image_original , missing_image = get_masked_images(original)\n", 425 | " mask_image = mask_image_original.copy()\n", 426 | " mask_image = mask_image / 127.5 - 1\n", 427 | " missing_image = missing_image / 127.5 - 1\n", 428 | " gen_missing = GEN.predict(mask_image)\n", 429 | " gen_missing = (gen_missing + 1) * 127.5\n", 430 | " gen_missing = gen_missing.astype(np.uint8)\n", 431 | " demask_image = get_demask_images(mask_image_original, gen_missing)\n", 432 | "\n", 433 | " mask_image 
= (mask_image + 1) * 127.5\n", 434 | " mask_image = mask_image.astype(np.uint8)\n", 435 | "\n", 436 | " border = np.ones([original[0].shape[0], 10, 3]).astype(np.uint8)\n", 437 | "\n", 438 | " file_name = str(epoch) + \"_\" + str(steps) + \".jpg\"\n", 439 | " final_image = np.concatenate((border, original[0],border,mask_image_original[0],border, demask_image[0], border), axis=1)\n", 440 | " if not nc:\n", 441 | " cv2.imwrite(os.path.join(SAVED_IMAGES, file_name), final_image)\n", 442 | " final_image = cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB)\n", 443 | " print(\"\\t1.Original image \\t 2.Input \\t\\t 3. Output\")\n", 444 | " IPython.display.display(PIL.Image.fromarray(final_image))\n", 445 | " print(\"image saved\")\n", 446 | "\n", 447 | "\n", 448 | "def save_log(log):\n", 449 | " with open('log.txt', 'a') as f:\n", 450 | " f.write(\"%s\\n\"%log)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": {}, 456 | "source": [ 457 | "## Train" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "def train():\n", 467 | " start_time = datetime.now()\n", 468 | " saved_time = start_time\n", 469 | " \n", 470 | " global MIN_D_LOSS, MIN_G_LOSS, CURRENT_D_LOSS, CURRENT_G_LOSS\n", 471 | " for epoch in range(1, EPOCHS):\n", 472 | " steps = 1\n", 473 | " test = None\n", 474 | " while True:\n", 475 | " original = data.get_data(BATCH)\n", 476 | " if original is None:\n", 477 | " break\n", 478 | " batch_size = original.shape[0]\n", 479 | "\n", 480 | " mask_image, missing_image = get_masked_images(original)\n", 481 | " mask_image = mask_image / 127.5 - 1\n", 482 | " missing_image = missing_image / 127.5 - 1\n", 483 | "\n", 484 | " # Train Discriminator\n", 485 | " gen_missing = GEN.predict(mask_image)\n", 486 | "\n", 487 | " real = np.ones([batch_size, 1])\n", 488 | " fake = np.zeros([batch_size, 1])\n", 489 | " \n", 490 | " d_loss_original = 
DCRM.train_on_batch(missing_image, real)\n", 491 | " d_loss_mask = DCRM.train_on_batch(gen_missing, fake)\n", 492 | " d_loss = 0.5 * np.add(d_loss_original, d_loss_mask)\n", 493 | "\n", 494 | " # Train Generator\n", 495 | " for i in range(2):\n", 496 | " g_loss = COMBINED.train_on_batch(mask_image, [real, missing_image])\n", 497 | " \n", 498 | " log = \"epoch: %d, steps: %d, DIS loss: %s, GEN loss: %s, Identity loss: %s\" \\\n", 499 | " %(epoch, steps, str(d_loss), str(g_loss[0]), str(g_loss[2]))\n", 500 | " print(log)\n", 501 | " save_log(log)\n", 502 | " steps += 1\n", 503 | " \n", 504 | " # Save model if time taken > TIME_INTERVALS\n", 505 | " current_time = datetime.now()\n", 506 | " difference_time = current_time - saved_time\n", 507 | " if difference_time.seconds >= (TIME_INTERVALS * 60):\n", 508 | " save_model()\n", 509 | " save_image(epoch, steps)\n", 510 | " saved_time = current_time\n", 511 | " clear_output()\n", 512 | " " 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "metadata": {}, 519 | "outputs": [], 520 | "source": [] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "load_model()" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "train()" 538 | ] 539 | }, 540 | { 541 | "cell_type": "markdown", 542 | "metadata": {}, 543 | "source": [ 544 | "## Recursive paint" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "load_model()" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "def recursive_paint(image, factor=3):\n", 563 | " final_image = None\n", 564 | " gen_missing = None\n", 565 | " for i in range(factor):\n", 566 | " 
demask_image = None\n", 567 | " if i == 0:\n", 568 | " x, y = get_masked_images([image])\n", 569 | " gen_missing = GEN.predict(x)\n", 570 | " final_image = get_demask_images(x, gen_missing)[0]\n", 571 | " else:\n", 572 | " gen_missing = GEN.predict(gen_missing)\n", 573 | " final_image = get_demask_images([final_image], gen_missing)[0]\n", 574 | " return final_image\n", 575 | " " 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [ 584 | "images = data.get_data(1)\n", 585 | "for i, image in enumerate(images):\n", 586 | " image = image / 127.5 - 1\n", 587 | " image = recursive_paint(image)\n", 588 | " image = (image + 1) * 127.5\n", 589 | " image = image.astype(np.uint8)\n", 590 | " path = 'recursive/'+str(i)+'.jpg'\n", 591 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 592 | " IPython.display.display(PIL.Image.fromarray(image))" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": null, 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": {}, 605 | "source": [ 606 | "## Test from URL" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": null, 612 | "metadata": {}, 613 | "outputs": [], 614 | "source": [ 615 | "url = 'https://upload.wikimedia.org/wikipedia/commons/3/33/A_beach_in_Maldives.jpg'\n", 616 | "\n", 617 | "file_name = os.path.basename(url)\n", 618 | "import urllib.request\n", 619 | "_ = urllib.request.urlretrieve(url, file_name)\n", 620 | "print(\"Downloaded image\")\n", 621 | "\n", 622 | "image = cv2.imread(file_name)\n", 623 | "image = cv2.resize(image, (256,256))\n", 624 | "cropped_image = image[:, 65:193]\n", 625 | "input_image = cropped_image / 127.5 - 1\n", 626 | "input_image = np.expand_dims(input_image, axis=0)\n", 627 | "print(input_image.shape)\n", 628 | "predicted_image = GEN.predict(input_image)\n", 629 | "predicted_image = 
get_demask_images(input_image, predicted_image)[0]\n", 630 | "predicted_image = (predicted_image + 1) * 127.5\n", 631 | "predicted_image = predicted_image.astype(np.uint8)\n", 632 | "\n", 633 | "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 634 | "predicted_image = cv2.cvtColor(predicted_image, cv2.COLOR_BGR2RGB)\n", 635 | "\n", 636 | "print('original image')\n", 637 | "IPython.display.display(PIL.Image.fromarray(image))\n", 638 | "print('predicted image')\n", 639 | "IPython.display.display(PIL.Image.fromarray(predicted_image))\n", 640 | "\n", 641 | "os.remove(file_name)" 642 | ] 643 | } 644 | ], 645 | "metadata": { 646 | "kernelspec": { 647 | "display_name": "Python 3", 648 | "language": "python", 649 | "name": "python3" 650 | }, 651 | "language_info": { 652 | "codemirror_mode": { 653 | "name": "ipython", 654 | "version": 3 655 | }, 656 | "file_extension": ".py", 657 | "mimetype": "text/x-python", 658 | "name": "python", 659 | "nbconvert_exporter": "python", 660 | "pygments_lexer": "ipython3", 661 | "version": "3.6.7" 662 | } 663 | }, 664 | "nbformat": 4, 665 | "nbformat_minor": 2 666 | } 667 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import random 5 | from augment_image import aug_image 6 | 7 | # raw_data_path: directory where the downloaded images are 8 | # save_path: directory where the numpy images will be 9 | raw_data_path = "data/raw_data/beach_image" 10 | train_save_path = "data/prepared_data/train" 11 | test_save_path = "data/prepared_data/test" 12 | 13 | # Train/Test Data split 14 | train_percen = 0.9 15 | 16 | files = os.listdir(raw_data_path) 17 | random.shuffle(files) 18 | train_files = files[: int(len(files) * train_percen)] 19 | test_files = files[int(len(files) * train_percen) + 1:] 20 | 21 | 22 | total_train_images = 0 23 | total_test_images = 0 24 | 
25 | # Augment both train and test dataset by N times 26 | augment_times = 2 27 | 28 | input_shape = (256, 256) 29 | 30 | # batch: each file will have N images 31 | batch = 2000 32 | 33 | # Dumping numpy batch images to save_path 34 | train_dump_counter = 0 35 | test_dump_counter = 0 36 | def dump_numpy(data, is_train_data=True): 37 | global train_dump_counter, test_dump_counter 38 | random.shuffle(data) 39 | if is_train_data: 40 | train_dump_counter += 1 41 | path = os.path.join(train_save_path, 'train_data_' + str(train_dump_counter)) 42 | else: 43 | test_dump_counter += 1 44 | path = os.path.join(test_save_path, 'test_data_' + str(test_dump_counter)) 45 | np.save(path, data) 46 | 47 | 48 | def create_data(files_path, is_train_data=True, augment_times=augment_times): 49 | global total_test_images, total_train_images 50 | bulk = [] 51 | image_counter = 0 52 | for i, file in enumerate(files_path, 1): 53 | image_path = os.path.join(raw_data_path, file) 54 | try: 55 | image = cv2.imread(image_path) 56 | image = cv2.resize(image, input_shape) 57 | bulk.append(image) 58 | image_counter += 1 59 | for _ in range(augment_times): 60 | new_image = aug_image(image) 61 | image_counter += 1 62 | bulk.append(new_image) 63 | except Exception as e: 64 | print("error: ", e) 65 | print("file name: ", image_path) 66 | 67 | print("Proccessed: ", image_counter) 68 | 69 | if len(bulk) >= batch or i == len(files_path): 70 | print("Dumping batch: ", len(bulk)) 71 | dump_numpy(bulk, is_train_data=is_train_data) 72 | bulk = [] 73 | 74 | if is_train_data: 75 | total_train_images += image_counter 76 | else: 77 | total_test_images += image_counter 78 | 79 | # Create Train Dataset 80 | print("CREATING TRAIN DATASET") 81 | create_data(train_files, is_train_data=True) 82 | 83 | # CREATE TEST DATASET 84 | print("CREATING TEST DATASET") 85 | create_data(test_files, is_train_data=False) 86 | 87 | print("*"*50) 88 | print("Data preparation completed") 89 | print("*"*50) 90 | print("Total train 
images: ", total_train_images) 91 | print("Total test images: ", total_test_images) -------------------------------------------------------------------------------- /prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | mkdir -p data/prepared_data 3 | mkdir -p data/prepared_data/train 4 | mkdir -p data/prepared_data/test 5 | cd data/raw_data || exit 1 6 | 7 | echo "Downloading Dataset:" 8 | fileid="1hKIn-Z8Uf3voESbJZVsapLHESPabjjrb" 9 | filename="scrap_beach_image.zip" 10 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 11 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o "${filename}" 12 | 13 | sudo apt-get install unzip 14 | unzip "${filename}" -d ./ 15 | rm "${filename}" 16 | cd ../../ 17 | echo "Preparing Data:" 18 | python3 prepare_data.py 19 | echo "completed" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | six==1.12.0 2 | numpy==1.15.4 3 | scipy==1.1.0 4 | matplotlib==3.0.2 5 | scikit-image==0.14.1 6 | imageio==2.4.1 7 | Shapely 8 | opencv-python==3.4.3.18 9 | Pillow==6.2.0 10 | imgaug==0.2.6 11 | tensorflow-gpu==1.10.0 12 | keras==2.2.4 13 | git+https://www.github.com/keras-team/keras-contrib.git -------------------------------------------------------------------------------- /saved_images/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | --------------------------------------------------------------------------------