├── .gitignore
├── README.md
├── augment_image.py
├── checkpoint
│   └── .gitignore
├── data
│   ├── prepared_data
│   │   └── .gitignore
│   └── raw_data
│       └── .gitignore
├── dataloader.py
├── outpaint.ipynb
├── prepare_data.py
├── prepare_data.sh
├── requirements.txt
└── saved_images
    └── .gitignore
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints/
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keras implementation of Image OutPainting
2 |
3 | This is a Keras implementation of the [Painting Outside the Box: Image Outpainting](https://cs230.stanford.edu/projects_spring_2018/posters/8265861.pdf) paper from Stanford University.
4 | Some changes have been made to work with 256x256 images:
5 | - Added an identity loss, i.e. a reconstruction loss from the generated image back to the original image.
6 | - Removed patches from the training data. (training pipeline)
7 | - Replaced masking with cropping, as sketched below. (training pipeline)
8 | - Added convolution layers.
9 |
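The cropping change replaces the paper's masking step: the left and right quarters of each image are cut away and become the generator's target, while the centre half is the input. A condensed sketch of the logic used in 'outpaint.ipynb':

```python
import numpy as np

MASK_PERCENTAGE = 0.25  # 25% of the width is cropped from each side

def mask_width(image):
    # Crop MASK_PERCENTAGE of the width from both sides; the two cropped
    # strips, concatenated, become the training target for the generator.
    width = image.shape[1]
    new_width = int(width * MASK_PERCENTAGE)
    missing_left = image[:, :new_width]
    missing_right = image[:, width - new_width:]
    missing_part = np.concatenate((missing_left, missing_right), axis=1)
    cropped = image[:, new_width:width - new_width]
    return cropped, missing_part
```
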
10 | ## Results
11 | The model was trained for 25 epochs on [3,500 scraped beach images](https://drive.google.com/open?id=1hKIn-Z8Uf3voESbJZVsapLHESPabjjrb), augmented to a total of 10,500 images (each image plus two augmented copies, per 'augment_times = 2' in 'prepare_data.py').
12 | 
13 |
14 | ### Recursive painting
15 | 
16 |
17 | ### Install Requirements
18 | ```sh
19 | sudo apt-get install curl
20 | sudo pip3 install -r requirements.txt
21 | ```
22 |
23 | ## Get Started
24 |
25 | 1. Prepare Data:
26 | ```sh
27 | # Downloads the beach data and converts it to NumPy batch files,
28 | # saving them to 'data/prepared_data/'
29 | sh prepare_data.sh
30 | ```
31 | 2. Build Model
32 | * To build the model from scratch, you can directly run 'outpaint.ipynb'
33 | OR
34 | * You can [download](https://drive.google.com/open?id=1MfXsRwjx5CTRGBoLx154S0h-Q3rIUNH0) my trained model, move it to 'checkpoint/', and run it (see the loading sketch below).
35 |
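A minimal loading sketch, mirroring `load_model()` in 'outpaint.ipynb' (it assumes the checkpoint keeps the GEN.json/GEN.hdf5 names that the notebook's `save_model()` writes):

```python
from keras.models import model_from_json
from keras_contrib.layers.normalization import InstanceNormalization

# Rebuild the generator from its architecture file, then load the weights
with open('checkpoint/GEN.json') as f:
    GEN = model_from_json(f.read(),
                          custom_objects={'InstanceNormalization': InstanceNormalization})
GEN.load_weights('checkpoint/GEN.hdf5')
```
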
36 | ## References
37 | * [Painting Outside the Box: Image Outpainting](https://cs230.stanford.edu/projects_spring_2018/posters/8265861.pdf)
38 |
--------------------------------------------------------------------------------
/augment_image.py:
--------------------------------------------------------------------------------
1 | import imgaug as ia
2 | from imgaug import augmenters as iaa
3 | import numpy as np
4 | import random
5 |
6 |
7 | brightness = iaa.Add((-7, 7), per_channel=0.5)
8 | contrast = iaa.ContrastNormalization((0.8, 1.6), per_channel=0.5)
9 | perspective = iaa.PerspectiveTransform(scale=(0.025, 0.090))
10 | gaussian_noise = iaa.AdditiveGaussianNoise(loc=0, scale=(0.03*255, 0.04*255), per_channel=0.5)
11 | crop = iaa.Crop(px=(0, 25))
12 |
13 |
14 | def aug_image(my_image):
15 |     image = my_image.copy()
16 |     if random.choice([0, 0, 1]):  # each augmenter fires with probability 1/3
17 |         image = perspective.augment_image(image)
18 |     if random.choice([0, 0, 1]):
19 |         image = brightness.augment_image(image)
20 |     if random.choice([0, 0, 1]):
21 |         image = contrast.augment_image(image)
22 |     if random.choice([0, 0, 1]):
23 |         image = gaussian_noise.augment_image(image)
24 |     if random.choice([0, 0, 1]):
25 |         image = crop.augment_image(image)
26 |     return image
27 |
28 |
29 | if __name__ == "__main__":
30 |     import cv2
31 |     image = cv2.imread('/home/ben/work/compare_myntra/test_image/test_images/taken_15324282418.jpg')
32 |     aug_images = aug_image(image)
33 |     aug_images = [aug_images]
34 |     print(len(aug_images))
35 |     image = cv2.resize(image, (600, 600))
36 |     image_1 = cv2.resize(aug_images[0], (600, 600))
37 |     cv2.imshow('1', image)
38 |     cv2.waitKey(0)
39 |     cv2.imshow('2', image_1)
40 |     cv2.waitKey(0)
41 |
--------------------------------------------------------------------------------
/checkpoint/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/data/prepared_data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/data/raw_data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from random import shuffle
4 |
5 |
6 | DATA_PATH = "data/prepared_data/train"
7 | TEST_PATH = "data/prepared_data/test"
8 |
9 |
10 | class Data():
11 |
12 |     def __init__(self):
13 |         self.X_counter = 0
14 |         self.file_counter = 0
15 |         self.files = os.listdir(DATA_PATH)
16 |         self.files = [file for file in self.files if '.npy' in file]
17 |         shuffle(self.files)
18 |         self._load_data()
19 |
20 |     def _load_data(self):
21 |         datas = np.load(os.path.join(DATA_PATH, self.files[self.file_counter]))
22 |         self.X = []
23 |         for data in datas:
24 |             self.X.append(data)
25 |         shuffle(self.X)
26 |         self.X = np.asarray(self.X)
27 |         self.file_counter += 1
28 |
29 |     def get_data(self, batch_size):
30 |         if self.X_counter >= len(self.X):
31 |             if self.file_counter > len(self.files) - 1:
32 |                 print("Data exhausted, re-initializing")
33 |                 self.__init__()
34 |                 return None
35 |             else:
36 |                 self._load_data()
37 |                 self.X_counter = 0
38 |
39 |         if self.X_counter + batch_size <= len(self.X):
40 |             # A full batch is available in the current file
41 |             X = self.X[self.X_counter: self.X_counter + batch_size]
42 |         else:
43 |             X = self.X[self.X_counter:]  # return the remaining tail as a short batch
44 |
45 |         self.X_counter += batch_size
46 |         return X
47 |
48 |
49 | class TestData():
50 |
51 |     def __init__(self):
52 |         self.X_counter = 0
53 |         self.file_counter = 0
54 |         self.files = os.listdir(TEST_PATH)
55 |         self.files = [file for file in self.files if '.npy' in file]
56 |         shuffle(self.files)
57 |         self._load_data()
58 |
59 |     def _load_data(self):
60 |         datas = np.load(os.path.join(TEST_PATH, self.files[self.file_counter]))
61 |         self.X = []
62 |         for data in datas:
63 |             self.X.append(data)
64 |         shuffle(self.X)
65 |         self.X = np.asarray(self.X)
66 |         self.file_counter += 1
67 |
68 |     def get_data(self, batch_size):
69 |         if self.X_counter >= len(self.X):
70 |             if self.file_counter > len(self.files) - 1:
71 |                 print("Data exhausted, re-initializing")
72 |                 self.__init__()
73 |                 return None
74 |             else:
75 |                 self._load_data()
76 |                 self.X_counter = 0
77 |
78 |         if self.X_counter + batch_size <= len(self.X):
79 |             # A full batch is available in the current file
80 |             X = self.X[self.X_counter: self.X_counter + batch_size]
81 |         else:
82 |             X = self.X[self.X_counter:]  # return the remaining tail as a short batch
83 |
84 |         self.X_counter += batch_size
85 |         return X
86 |
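87 | if __name__ == "__main__":
88 |     # Minimal usage sketch: get_data() yields batches until every .npy
89 |     # file under DATA_PATH has been consumed, then it returns None once.
90 |     loader = Data()
91 |     while True:
92 |         batch = loader.get_data(16)
93 |         if batch is None:
94 |             break
95 |         print(batch.shape)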
--------------------------------------------------------------------------------
/outpaint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Out Paint"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from keras.layers.convolutional import Conv2D, AtrousConvolution2D\n",
17 | "from keras.layers import Activation, Dense, Input, Conv2DTranspose, Dense, Flatten\n",
18 | "from keras.layers import ReLU, Dropout, Concatenate, BatchNormalization, Reshape\n",
19 | "from keras.layers.advanced_activations import LeakyReLU\n",
20 | "from keras.models import Model, model_from_json\n",
21 | "from keras.optimizers import Adam\n",
22 | "from keras.layers.convolutional import UpSampling2D\n",
23 | "import keras.backend as K\n",
24 | "import tensorflow as tf\n",
25 | "\n",
26 | "import os\n",
27 | "import numpy as np\n",
28 | "import PIL\n",
29 | "import cv2\n",
30 | "import IPython.display\n",
31 | "from IPython.display import clear_output\n",
32 | "from datetime import datetime\n",
33 | "from dataloader import Data, TestData"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "try:\n",
43 | " from keras_contrib.layers.normalization import InstanceNormalization\n",
44 | "except Exception:\n",
45 | " from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# Initialize dataloader\n",
55 | "data = Data()\n",
56 | "test_data = Data()"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# Saves Model in every N minutes\n",
66 | "TIME_INTERVALS = 2\n",
67 | "SHOW_SUMMARY = True\n",
68 | "\n",
69 | "INPUT_SHAPE = (256, 256, 3)\n",
70 | "EPOCHS = 500\n",
71 | "BATCH = 1\n",
72 | "\n",
73 | "# 25% i.e 64 width size will be mask from both side\n",
74 | "MASK_PERCENTAGE = .25\n",
75 | "\n",
76 | "EPSILON = 1e-9\n",
77 | "ALPHA = 0.0004\n",
78 | "\n",
79 | "CHECKPOINT = \"checkpoint/\"\n",
80 | "SAVED_IMAGES = \"saved_images/\""
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "## Models"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {},
93 | "source": [
94 | "### Discriminator"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "def dcrm_loss(y_true, y_pred):\n",
104 | " return -tf.reduce_mean(tf.log(tf.maximum(y_true, EPSILON)) + tf.log(tf.maximum(1. - y_pred, EPSILON)))\n",
105 | "\n",
106 | "d_input_shape = (INPUT_SHAPE[0], int(INPUT_SHAPE[1] * (MASK_PERCENTAGE *2)), INPUT_SHAPE[2])\n",
107 | "d_dropout = 0.25\n",
108 | "DCRM_OPTIMIZER = Adam(0.0001, 0.5)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "def d_build_conv(layer_input, filter_size, kernel_size=4, strides=2, activation='leakyrelu', dropout_rate=d_dropout, norm=True):\n",
118 | " c = Conv2D(filter_size, kernel_size=kernel_size, strides=strides, padding='same')(layer_input)\n",
119 | " if activation == 'leakyrelu':\n",
120 | " c = LeakyReLU(alpha=0.2)(c)\n",
121 | " if dropout_rate:\n",
122 | " c = Dropout(dropout_rate)(c)\n",
123 | " if norm == 'inst':\n",
124 | " c = InstanceNormalization()(c)\n",
125 | " return c\n",
126 | "\n",
127 | "\n",
128 | "def build_discriminator():\n",
129 | " d_input = Input(shape=d_input_shape)\n",
130 | " d = d_build_conv(d_input, 32, 5,strides=2, norm=False)\n",
131 | "\n",
132 | " d = d_build_conv(d, 64, 5, strides=2)\n",
133 | " d = d_build_conv(d, 64, 5, strides=2)\n",
134 | " d = d_build_conv(d, 128, 5, strides=2)\n",
135 | " d = d_build_conv(d, 128, 5, strides=2)\n",
136 | " \n",
137 | " flat = Flatten()(d)\n",
138 | " fc1 = Dense(1024, activation='relu')(flat)\n",
139 | " d_output = Dense(1, activation='sigmoid')(fc1)\n",
140 | " \n",
141 | " return Model(d_input, d_output)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "# Discriminator initialization\n",
151 | "DCRM = build_discriminator()\n",
152 | "DCRM.compile(loss=dcrm_loss, optimizer=DCRM_OPTIMIZER)\n",
153 | "if SHOW_SUMMARY:\n",
154 | " DCRM.summary()"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {},
160 | "source": [
161 | "### Generator Model"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "def gen_loss(y_true, y_pred):\n",
171 | " G_MSE_loss = K.mean(K.square(y_pred - y_true))\n",
172 | " return G_MSE_loss - ALPHA * tf.reduce_mean(tf.log(tf.maximum(y_pred, EPSILON)))\n",
173 | "\n",
174 | "g_input_shape = (INPUT_SHAPE[0], int(INPUT_SHAPE[1] * (MASK_PERCENTAGE *2)), INPUT_SHAPE[2])\n",
175 | "g_dropout = 0.25\n",
176 | "GEN_OPTIMIZER = Adam(0.001, 0.5)"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "def g_build_conv(layer_input, filter_size, kernel_size=4, strides=2, activation='leakyrelu', dropout_rate=g_dropout, norm='inst', dilation=1):\n",
186 | " c = AtrousConvolution2D(filter_size, kernel_size=kernel_size, strides=strides,atrous_rate=(dilation,dilation), padding='same')(layer_input)\n",
187 | " if activation == 'leakyrelu':\n",
188 | " c = ReLU()(c)\n",
189 | " if dropout_rate:\n",
190 | " c = Dropout(dropout_rate)(c)\n",
191 | " if norm == 'inst':\n",
192 | " c = InstanceNormalization()(c)\n",
193 | " return c\n",
194 | "\n",
195 | "\n",
196 | "def g_build_deconv(layer_input, filter_size, kernel_size=3, strides=2, activation='relu', dropout=0):\n",
197 | " d = Conv2DTranspose(filter_size, kernel_size=kernel_size, strides=strides, padding='same')(layer_input)\n",
198 | " if activation == 'relu':\n",
199 | " d = ReLU()(d)\n",
200 | " return d\n",
201 | "\n",
202 | "\n",
203 | "def build_generator():\n",
204 | " g_input = Input(shape=g_input_shape)\n",
205 | " \n",
206 | " g1 = g_build_conv(g_input, 64, 5, strides=1)\n",
207 | " g2 = g_build_conv(g1, 128, 4, strides=2)\n",
208 | " g3 = g_build_conv(g2, 256, 4, strides=2)\n",
209 | "\n",
210 | " g4 = g_build_conv(g3, 512, 4, strides=1)\n",
211 | " g5 = g_build_conv(g4, 512, 4, strides=1)\n",
212 | " \n",
213 | " g6 = g_build_conv(g5, 512, 4, strides=1, dilation=2)\n",
214 | " g7 = g_build_conv(g6, 512, 4, strides=1, dilation=4)\n",
215 | " g8 = g_build_conv(g7, 512, 4, strides=1, dilation=8)\n",
216 | " g9 = g_build_conv(g8, 512, 4, strides=1, dilation=16)\n",
217 | " \n",
218 | " g10 = g_build_conv(g9, 512, 4, strides=1)\n",
219 | " g11 = g_build_conv(g10, 512, 4, strides=1)\n",
220 | " \n",
221 | " g12 = g_build_deconv(g11, 256, 4, strides=2)\n",
222 | " g13 = g_build_deconv(g12, 128, 4, strides=2)\n",
223 | " \n",
224 | " g14 = g_build_conv(g13, 128, 4, strides=1)\n",
225 | " g15 = g_build_conv(g14, 64, 4, strides=1)\n",
226 | " \n",
227 | " g_output = AtrousConvolution2D(3, kernel_size=4, strides=(1,1), activation='tanh',padding='same', atrous_rate=(1,1))(g15)\n",
228 | " \n",
229 | " return Model(g_input, g_output)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "# Generator Initialization\n",
239 | "GEN = build_generator()\n",
240 | "GEN.compile(loss=gen_loss, optimizer=GEN_OPTIMIZER)\n",
241 | "if SHOW_SUMMARY:\n",
242 | " GEN.summary()"
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "metadata": {},
248 | "source": [
249 | "### Combined Model"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": null,
255 | "metadata": {},
256 | "outputs": [],
257 | "source": [
258 | "IMAGE = Input(shape=g_input_shape)\n",
259 | "DCRM.trainable = False\n",
260 | "GENERATED_IMAGE = GEN(IMAGE)\n",
261 | "CONF_GENERATED_IMAGE = DCRM(GENERATED_IMAGE)\n",
262 | "\n",
263 | "COMBINED = Model(IMAGE, [CONF_GENERATED_IMAGE, GENERATED_IMAGE])\n",
264 | "COMBINED.compile(loss=['mse', 'mse'], optimizer=GEN_OPTIMIZER)"
265 | ]
266 | },
267 | {
268 | "cell_type": "markdown",
269 | "metadata": {},
270 | "source": [
271 | "### Masking and De-Masking"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "def mask_width(img):\n",
281 | " image = img.copy()\n",
282 | " height = image.shape[0]\n",
283 | " width = image.shape[1]\n",
284 | " new_width = int(width * MASK_PERCENTAGE)\n",
285 | " mask = np.ones([height, new_width, 3])\n",
286 | " missing_x = img[:, :new_width]\n",
287 | " missing_y = img[:, width - new_width:]\n",
288 | " missing_part = np.concatenate((missing_x, missing_y), axis=1)\n",
289 | " image = image[:, :width - new_width]\n",
290 | " image = image[:, new_width:]\n",
291 | " return image, missing_part\n",
292 | "\n",
293 | "\n",
294 | "def get_masked_images(images):\n",
295 | " mask_images = []\n",
296 | " missing_images = []\n",
297 | " for image in images:\n",
298 | " mask_image, missing_image = mask_width(image)\n",
299 | " mask_images.append(mask_image)\n",
300 | " missing_images.append(missing_image)\n",
301 | " return np.array(mask_images), np.array(missing_images)\n",
302 | "\n",
303 | "\n",
304 | "def get_demask_images(original_images, generated_images):\n",
305 | " demask_images = []\n",
306 | " for o_image, g_image in zip(original_images, generated_images):\n",
307 | " width = g_image.shape[1] // 2\n",
308 | " x_image = g_image[:, :width]\n",
309 | " y_image = g_image[:, width:]\n",
310 | " o_image = np.concatenate((x_image,o_image, y_image), axis=1)\n",
311 | " demask_images.append(o_image)\n",
312 | " return np.asarray(demask_images)"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "# Masking, Demasking example\n",
322 | "# Note: IPython display gives false colors.\n",
323 | "x = data.get_data(1)\n",
324 | "\n",
325 | "# a will be the input and b will be the output for the model\n",
326 | "a, b = get_masked_images(x)\n",
327 | "border = np.ones([x[0].shape[0], 10, 3]).astype(np.uint8)\n",
328 | "print('After masking')\n",
329 | "print('\\tOriginal Image\\t\\t\\t a \\t\\t b')\n",
330 | "image = np.concatenate((border, x[0],border,a[0],border, b[0], border), axis=1)\n",
331 | "IPython.display.display(PIL.Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))\n",
332 | "\n",
333 | "print(\"After desmasking: 'b/2' + a + 'b/2' \")\n",
334 | "c = get_demask_images(a,b)\n",
335 | "IPython.display.display(PIL.Image.fromarray(cv2.cvtColor(c[0], cv2.COLOR_BGR2RGB)))"
336 | ]
337 | },
338 | {
339 | "cell_type": "markdown",
340 | "metadata": {},
341 | "source": [
342 | "### Utilities\n",
343 | "1. Save Model\n",
344 | "2. Load Model\n",
345 | "3. Save Image\n",
346 | "4. Save Log"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {},
353 | "outputs": [],
354 | "source": [
355 | "def save_model():\n",
356 | " global DCRM, GEN\n",
357 | " models = [DCRM, GEN]\n",
358 | " model_names = ['DCRM','GEN']\n",
359 | "\n",
360 | " for model, model_name in zip(models, model_names):\n",
361 | " model_path = CHECKPOINT + \"%s.json\" % model_name\n",
362 | " weights_path = CHECKPOINT + \"/%s.hdf5\" % model_name\n",
363 | " options = {\"file_arch\": model_path, \n",
364 | " \"file_weight\": weights_path}\n",
365 | " json_string = model.to_json()\n",
366 | " open(options['file_arch'], 'w').write(json_string)\n",
367 | " model.save_weights(options['file_weight'])\n",
368 | " print(\"Saved Model\")\n",
369 | " \n",
370 | " \n",
371 | "def load_model():\n",
372 | " # Checking if all the model exists\n",
373 | " model_names = ['DCRM', 'GEN']\n",
374 | " files = os.listdir(CHECKPOINT)\n",
375 | " for model_name in model_names:\n",
376 | " if model_name+\".json\" not in files or\\\n",
377 | " model_name+\".hdf5\" not in files:\n",
378 | " print(\"Models not Found\")\n",
379 | " return\n",
380 | " global DCRM, GEN, COMBINED, IMAGE, GENERATED_IMAGE, CONF_GENERATED_IMAGE\n",
381 | " \n",
382 | " # load DCRM Model\n",
383 | " model_path = CHECKPOINT + \"%s.json\" % 'DCRM'\n",
384 | " weight_path = CHECKPOINT + \"%s.hdf5\" % 'DCRM'\n",
385 | " with open(model_path, 'r') as f:\n",
386 | " DCRM = model_from_json(f.read())\n",
387 | " DCRM.load_weights(weight_path)\n",
388 | " DCRM.compile(loss=dcrm_loss, optimizer=DCRM_OPTIMIZER)\n",
389 | " \n",
390 | " #load GEN Model\n",
391 | " model_path = CHECKPOINT + \"%s.json\" % 'GEN'\n",
392 | " weight_path = CHECKPOINT + \"%s.hdf5\" % 'GEN'\n",
393 | " with open(model_path, 'r') as f:\n",
394 | " GEN = model_from_json(f.read(), custom_objects={'InstanceNormalization': InstanceNormalization()})\n",
395 | " GEN.load_weights(weight_path)\n",
396 | " \n",
397 | " # Combined Model\n",
398 | " DCRM.trainable = False\n",
399 | " IMAGE = Input(shape=g_input_shape)\n",
400 | " GENERATED_IMAGE = GEN(IMAGE)\n",
401 | " CONF_GENERATED_IMAGE = DCRM(GENERATED_IMAGE)\n",
402 | "\n",
403 | " COMBINED = Model(IMAGE, [CONF_GENERATED_IMAGE, GENERATED_IMAGE])\n",
404 | " COMBINED.compile(loss=['mse', 'mse'], optimizer=GEN_OPTIMIZER)\n",
405 | " \n",
406 | " print(\"loaded model\")\n",
407 | " \n",
408 | " \n",
409 | "def save_image(epoch, steps):\n",
410 | " train_image = test_data.get_data(1)\n",
411 | " if train_image is None:\n",
412 | " train_image = test_data.get_data(1)\n",
413 | " \n",
414 | " test_image = data.get_data(1)\n",
415 | " if test_image is None:\n",
416 | " test_image = test_data.get_data(1)\n",
417 | " \n",
418 | " for nc, original in enumerate([train_image, test_image]):\n",
419 | " if nc:\n",
420 | " print(\"Predicting with train image\")\n",
421 | " else:\n",
422 | " print(\"Predicting with test image\")\n",
423 | " \n",
424 | " mask_image_original , missing_image = get_masked_images(original)\n",
425 | " mask_image = mask_image_original.copy()\n",
426 | " mask_image = mask_image / 127.5 - 1\n",
427 | " missing_image = missing_image / 127.5 - 1\n",
428 | " gen_missing = GEN.predict(mask_image)\n",
429 | " gen_missing = (gen_missing + 1) * 127.5\n",
430 | " gen_missing = gen_missing.astype(np.uint8)\n",
431 | " demask_image = get_demask_images(mask_image_original, gen_missing)\n",
432 | "\n",
433 | " mask_image = (mask_image + 1) * 127.5\n",
434 | " mask_image = mask_image.astype(np.uint8)\n",
435 | "\n",
436 | " border = np.ones([original[0].shape[0], 10, 3]).astype(np.uint8)\n",
437 | "\n",
438 | " file_name = str(epoch) + \"_\" + str(steps) + \".jpg\"\n",
439 | " final_image = np.concatenate((border, original[0],border,mask_image_original[0],border, demask_image[0], border), axis=1)\n",
440 | " if not nc:\n",
441 | " cv2.imwrite(os.path.join(SAVED_IMAGES, file_name), final_image)\n",
442 | " final_image = cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB)\n",
443 | " print(\"\\t1.Original image \\t 2.Input \\t\\t 3. Output\")\n",
444 | " IPython.display.display(PIL.Image.fromarray(final_image))\n",
445 | " print(\"image saved\")\n",
446 | "\n",
447 | "\n",
448 | "def save_log(log):\n",
449 | " with open('log.txt', 'a') as f:\n",
450 | " f.write(\"%s\\n\"%log)"
451 | ]
452 | },
453 | {
454 | "cell_type": "markdown",
455 | "metadata": {},
456 | "source": [
457 | "## Train"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": null,
463 | "metadata": {},
464 | "outputs": [],
465 | "source": [
466 | "def train():\n",
467 | " start_time = datetime.now()\n",
468 | " saved_time = start_time\n",
469 | " \n",
470 | " global MIN_D_LOSS, MIN_G_LOSS, CURRENT_D_LOSS, CURRENT_G_LOSS\n",
471 | " for epoch in range(1, EPOCHS):\n",
472 | " steps = 1\n",
473 | " test = None\n",
474 | " while True:\n",
475 | " original = data.get_data(BATCH)\n",
476 | " if original is None:\n",
477 | " break\n",
478 | " batch_size = original.shape[0]\n",
479 | "\n",
480 | " mask_image, missing_image = get_masked_images(original)\n",
481 | " mask_image = mask_image / 127.5 - 1\n",
482 | " missing_image = missing_image / 127.5 - 1\n",
483 | "\n",
484 | " # Train Discriminator\n",
485 | " gen_missing = GEN.predict(mask_image)\n",
486 | "\n",
487 | " real = np.ones([batch_size, 1])\n",
488 | " fake = np.zeros([batch_size, 1])\n",
489 | " \n",
490 | " d_loss_original = DCRM.train_on_batch(missing_image, real)\n",
491 | " d_loss_mask = DCRM.train_on_batch(gen_missing, fake)\n",
492 | " d_loss = 0.5 * np.add(d_loss_original, d_loss_mask)\n",
493 | "\n",
494 | " # Train Generator\n",
495 | " for i in range(2):\n",
496 | " g_loss = COMBINED.train_on_batch(mask_image, [real, missing_image])\n",
497 | " \n",
498 | " log = \"epoch: %d, steps: %d, DIS loss: %s, GEN loss: %s, Identity loss: %s\" \\\n",
499 | " %(epoch, steps, str(d_loss), str(g_loss[0]), str(g_loss[2]))\n",
500 | " print(log)\n",
501 | " save_log(log)\n",
502 | " steps += 1\n",
503 | " \n",
504 | " # Save model if time taken > TIME_INTERVALS\n",
505 | " current_time = datetime.now()\n",
506 | " difference_time = current_time - saved_time\n",
507 | " if difference_time.seconds >= (TIME_INTERVALS * 60):\n",
508 | " save_model()\n",
509 | " save_image(epoch, steps)\n",
510 | " saved_time = current_time\n",
511 | " clear_output()\n",
512 | " "
513 | ]
514 | },
515 | {
516 | "cell_type": "code",
517 | "execution_count": null,
518 | "metadata": {},
519 | "outputs": [],
520 | "source": []
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": null,
525 | "metadata": {},
526 | "outputs": [],
527 | "source": [
528 | "load_model()"
529 | ]
530 | },
531 | {
532 | "cell_type": "code",
533 | "execution_count": null,
534 | "metadata": {},
535 | "outputs": [],
536 | "source": [
537 | "train()"
538 | ]
539 | },
540 | {
541 | "cell_type": "markdown",
542 | "metadata": {},
543 | "source": [
544 | "## Recursive paint"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [],
552 | "source": [
553 | "load_model()"
554 | ]
555 | },
556 | {
557 | "cell_type": "code",
558 | "execution_count": null,
559 | "metadata": {},
560 | "outputs": [],
561 | "source": [
562 | "def recursive_paint(image, factor=3):\n",
563 | " final_image = None\n",
564 | " gen_missing = None\n",
565 | " for i in range(factor):\n",
566 | " demask_image = None\n",
567 | " if i == 0:\n",
568 | " x, y = get_masked_images([image])\n",
569 | " gen_missing = GEN.predict(x)\n",
570 | " final_image = get_demask_images(x, gen_missing)[0]\n",
571 | " else:\n",
572 | " gen_missing = GEN.predict(gen_missing)\n",
573 | " final_image = get_demask_images([final_image], gen_missing)[0]\n",
574 | " return final_image\n",
575 | " "
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": null,
581 | "metadata": {},
582 | "outputs": [],
583 | "source": [
584 | "images = data.get_data(1)\n",
585 | "for i, image in enumerate(images):\n",
586 | " image = image / 127.5 - 1\n",
587 | " image = recursive_paint(image)\n",
588 | " image = (image + 1) * 127.5\n",
589 | " image = image.astype(np.uint8)\n",
590 | " path = 'recursive/'+str(i)+'.jpg'\n",
591 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
592 | " IPython.display.display(PIL.Image.fromarray(image))"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": null,
598 | "metadata": {},
599 | "outputs": [],
600 | "source": []
601 | },
602 | {
603 | "cell_type": "markdown",
604 | "metadata": {},
605 | "source": [
606 | "## Test from URL"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": null,
612 | "metadata": {},
613 | "outputs": [],
614 | "source": [
615 | "url = 'https://upload.wikimedia.org/wikipedia/commons/3/33/A_beach_in_Maldives.jpg'\n",
616 | "\n",
617 | "file_name = os.path.basename(url)\n",
618 | "import urllib.request\n",
619 | "_ = urllib.request.urlretrieve(url, file_name)\n",
620 | "print(\"Downloaded image\")\n",
621 | "\n",
622 | "image = cv2.imread(file_name)\n",
623 | "image = cv2.resize(image, (256,256))\n",
624 | "cropped_image = image[:, 65:193]\n",
625 | "input_image = cropped_image / 127.5 - 1\n",
626 | "input_image = np.expand_dims(input_image, axis=0)\n",
627 | "print(input_image.shape)\n",
628 | "predicted_image = GEN.predict(input_image)\n",
629 | "predicted_image = get_demask_images(input_image, predicted_image)[0]\n",
630 | "predicted_image = (predicted_image + 1) * 127.5\n",
631 | "predicted_image = predicted_image.astype(np.uint8)\n",
632 | "\n",
633 | "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
634 | "predicted_image = cv2.cvtColor(predicted_image, cv2.COLOR_BGR2RGB)\n",
635 | "\n",
636 | "print('original image')\n",
637 | "IPython.display.display(PIL.Image.fromarray(image))\n",
638 | "print('predicted image')\n",
639 | "IPython.display.display(PIL.Image.fromarray(predicted_image))\n",
640 | "\n",
641 | "os.remove(file_name)"
642 | ]
643 | }
644 | ],
645 | "metadata": {
646 | "kernelspec": {
647 | "display_name": "Python 3",
648 | "language": "python",
649 | "name": "python3"
650 | },
651 | "language_info": {
652 | "codemirror_mode": {
653 | "name": "ipython",
654 | "version": 3
655 | },
656 | "file_extension": ".py",
657 | "mimetype": "text/x-python",
658 | "name": "python",
659 | "nbconvert_exporter": "python",
660 | "pygments_lexer": "ipython3",
661 | "version": "3.6.7"
662 | }
663 | },
664 | "nbformat": 4,
665 | "nbformat_minor": 2
666 | }
667 |
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import random
5 | from augment_image import aug_image
6 |
7 | # raw_data_path: directory containing the downloaded images
8 | # save_path: directory where the NumPy batch files will be saved
9 | raw_data_path = "data/raw_data/beach_image"
10 | train_save_path = "data/prepared_data/train"
11 | test_save_path = "data/prepared_data/test"
12 |
13 | # Train/Test Data split
14 | train_percen = 0.9
15 |
16 | files = os.listdir(raw_data_path)
17 | random.shuffle(files)
18 | train_files = files[: int(len(files) * train_percen)]
19 | test_files = files[int(len(files) * train_percen):]
20 |
21 |
22 | total_train_images = 0
23 | total_test_images = 0
24 |
25 | # Augment both train and test dataset by N times
26 | augment_times = 2
27 |
28 | input_shape = (256, 256)
29 |
30 | # batch: each file will have N images
31 | batch = 2000
32 |
33 | # Dumping numpy batch images to save_path
34 | train_dump_counter = 0
35 | test_dump_counter = 0
36 | def dump_numpy(data, is_train_data=True):
37 |     global train_dump_counter, test_dump_counter
38 |     random.shuffle(data)
39 |     if is_train_data:
40 |         train_dump_counter += 1
41 |         path = os.path.join(train_save_path, 'train_data_' + str(train_dump_counter))
42 |     else:
43 |         test_dump_counter += 1
44 |         path = os.path.join(test_save_path, 'test_data_' + str(test_dump_counter))
45 |     np.save(path, data)
46 |
47 |
48 | def create_data(files_path, is_train_data=True, augment_times=augment_times):
49 |     global total_test_images, total_train_images
50 |     bulk = []
51 |     image_counter = 0
52 |     for i, file in enumerate(files_path, 1):
53 |         image_path = os.path.join(raw_data_path, file)
54 |         try:
55 |             image = cv2.imread(image_path)
56 |             image = cv2.resize(image, input_shape)
57 |             bulk.append(image)
58 |             image_counter += 1
59 |             for _ in range(augment_times):
60 |                 new_image = aug_image(image)
61 |                 image_counter += 1
62 |                 bulk.append(new_image)
63 |         except Exception as e:
64 |             print("error: ", e)
65 |             print("file name: ", image_path)
66 |
67 |         print("Processed: ", image_counter)
68 |
69 |         if len(bulk) >= batch or i == len(files_path):
70 |             print("Dumping batch: ", len(bulk))
71 |             dump_numpy(bulk, is_train_data=is_train_data)
72 |             bulk = []
73 |
74 |     if is_train_data:
75 |         total_train_images += image_counter
76 |     else:
77 |         total_test_images += image_counter
78 |
79 | # Create Train Dataset
80 | print("CREATING TRAIN DATASET")
81 | create_data(train_files, is_train_data=True)
82 |
83 | # Create Test Dataset
84 | print("CREATING TEST DATASET")
85 | create_data(test_files, is_train_data=False)
86 |
87 | print("*"*50)
88 | print("Data preparation completed")
89 | print("*"*50)
90 | print("Total train images: ", total_train_images)
91 | print("Total test images: ", total_test_images)
--------------------------------------------------------------------------------
/prepare_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | mkdir -p data/raw_data
3 | mkdir -p data/prepared_data/train
4 | mkdir -p data/prepared_data/test
5 | cd data/raw_data
6 |
7 | echo "Downloading Dataset:"
8 | fileid="1hKIn-Z8Uf3voESbJZVsapLHESPabjjrb"
9 | filename="scrap_beach_image.zip"
10 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null
11 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename}
12 |
13 | sudo apt-get install -y unzip
14 | unzip scrap_beach_image.zip -d ./
15 | rm scrap_beach_image.zip ./cookie
16 | cd ../../
17 | echo "Preparing Data:"
18 | python3 prepare_data.py
19 | echo "completed"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | six==1.12.0
2 | numpy==1.15.4
3 | scipy==1.1.0
4 | matplotlib==3.0.2
5 | scikit-image==0.14.1
6 | imageio==2.4.1
7 | Shapely
8 | opencv-python==3.4.3.18
9 | Pillow==6.2.0
10 | imgaug==0.2.6
11 | tensorflow-gpu==1.10.0
12 | keras==2.2.4
13 | git+https://github.com/keras-team/keras-contrib.git
--------------------------------------------------------------------------------
/saved_images/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------