├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── Pix2Pix-DepthEstimation.ipynb
├── README.md
└── assets
    ├── input_depth.png
    ├── mono_depth_estimator_icon_web.png
    ├── result1.gif
    ├── result2.gif
    ├── result3.gif
    ├── result4.gif
    ├── training.gif
    └── training_example.png

/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | patreon: gsurma
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Greg
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Pix2Pix-DepthEstimation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {
7 |     "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
8 |     "_kg_hide-output": true,
9 |     "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5"
10 |    },
11 |    "outputs": [],
12 |    "source": [
13 |     "import os\n",
14 |     "import imageio\n",
15 |     "import numpy as np\n",
16 |     "import warnings\n",
17 |     "warnings.filterwarnings('ignore', category=FutureWarning)\n",
18 |     "import tensorflow as tf\n",
19 |     "import matplotlib.pyplot as plt\n",
20 |     "from glob import glob\n",
21 |     "import cv2\n",
22 |     "import shutil\n",
23 |     "tf.logging.set_verbosity(tf.logging.ERROR)"
24 |    ]
25 |   },
26 |   {
27 |    "cell_type": "code",
28 |    "execution_count": null,
29 |    "metadata": {},
30 |    "outputs": [],
31 |    "source": [
32 |     "class Helpers():\n",
33 |     "    \n",
34 |     "    @staticmethod\n",
35 |     "    def normalize(images):\n",
36 |     "        return np.array(images)/127.5-1.0\n",
37 |     "    \n",
38 |     "    @staticmethod\n",
39 |     "    def unnormalize(images):\n",
40 |     "        return (0.5*np.array(images)+0.5)*255\n",
41 |     "    \n",
42 |     "    @staticmethod\n",
43 |     "    def resize(image, size):\n",
44 |     "        return np.array(cv2.resize(image, size))\n",
45 |     "    \n",
46 |     "    @staticmethod\n",
47 |     "    def split_images(image, is_testing):\n",
48 |     "        image = imageio.imread(image).astype(float)\n",
49 |     "        _, width, _ = image.shape\n",
50 |     "        half_width = int(width/2)\n",
51 |     "        source_image = image[:, half_width:, :]\n",
52 |     "        destination_image = image[:, :half_width, :]\n",
53 |     "        source_image = Helpers.resize(source_image, (IMAGE_SIZE, IMAGE_SIZE))\n",
54 |     "        destination_image = Helpers.resize(destination_image, (IMAGE_SIZE, IMAGE_SIZE))\n",
55 |     "        if not is_testing and np.random.random() > 0.5:\n",
56 |     "            source_image = np.fliplr(source_image)\n",
57 |     "            destination_image = np.fliplr(destination_image)\n",
58 |     "        return source_image, destination_image\n",
59 |     "    \n",
60 |     "    @staticmethod\n",
61 |     "    def new_dir(path):\n",
62 |     "        shutil.rmtree(path, ignore_errors=True)\n",
63 |     "        os.makedirs(path, exist_ok=True)\n",
64 |     "    \n",
65 |     "    @staticmethod\n",
66 |     "    def archive_output():\n",
67 |     "        shutil.make_archive(\"output\", \"zip\", \"./output\")\n",
68 |     "    \n",
69 |     "    @staticmethod\n",
70 |     "    def image_pairs(batch, is_testing):\n",
71 |     "        source_images, destination_images = [], []\n",
72 |     "        for image_path in batch:\n",
73 |     "            source_image, destination_image = Helpers.split_images(image_path, is_testing)\n",
74 |     "            source_images.append(source_image)\n",
75 |     "            destination_images.append(destination_image)\n",
76 |     "        return source_images, destination_images"
77 |    ]
78 |   },
79 |   {
80 |    "cell_type": "code",
81 |    "execution_count": null,
82 |    "metadata": {
83 |     "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
84 |     "_kg_hide-output": false,
85 |     "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
86 |    },
87 |    "outputs": [],
88 |    "source": [
89 |     "# Requires the following dataset structure:\n",
90 |     "# dataset_name\n",
91 |     "# └── dataset_name\n",
92 |     "#     ├── testing\n",
93 |     "#     │   └── ... (image files)\n",
94 |     "#     ├── testing_raw\n",
95 |     "#     │   ├── ... (image files)\n",
96 |     "#     ├── training\n",
97 |     "#     │   └── ... (image files)\n",
98 |     "#     └── validation (optional)\n",
99 |     "#         └── ... (image files)\n",
100 |     "class DataLoader():\n",
101 |     "    \n",
102 |     "    def __init__(self, dataset_name=\"pix2pix-depth\"):\n",
103 |     "        self.dataset_name = dataset_name\n",
104 |     "        base_path = BASE_INPUT_PATH + self.dataset_name + \"/\" + self.dataset_name + \"/\"\n",
105 |     "        self.training_path = base_path + \"training/\"\n",
106 |     "        self.validation_path = base_path + \"validation/\"\n",
107 |     "        self.testing_path = base_path + \"testing/\"\n",
108 |     "        self.testing_raw_path = base_path + \"testing_raw/\"\n",
109 |     "\n",
110 |     "    def load_random_data(self, data_size, is_testing=False):\n",
111 |     "        paths = glob(self.testing_path+\"*\") if is_testing else glob(self.training_path+\"*\")\n",
112 |     "        source_images, destination_images = Helpers.image_pairs(np.random.choice(paths, size=data_size), is_testing)\n",
113 |     "        return Helpers.normalize(source_images), Helpers.normalize(destination_images)\n",
114 |     "\n",
115 |     "    def yield_batch(self, batch_size, is_testing=False):\n",
116 |     "        paths = glob(self.validation_path+\"*\") if is_testing else glob(self.training_path+\"*\")\n",
117 |     "        for i in range(len(paths)//batch_size):\n",
118 |     "            batch = paths[i*batch_size:(i+1)*batch_size]\n",
119 |     "            source_images, destination_images = Helpers.image_pairs(batch, is_testing)\n",
120 |     "            yield Helpers.normalize(source_images), Helpers.normalize(destination_images)"
121 |    ]
122 |   },
123 |   {
124 |    "cell_type": "code",
125 |    "execution_count": null,
126 |    "metadata": {
127 |     "_uuid": "8ff0cb940babed054508cf32b3ef5d383302fa12"
128 |    },
129 |    "outputs": [],
130 |    "source": [
131 |     "# Model architecture from: https://phillipi.github.io/pix2pix/\n",
132 |     "class Pix2Pix():\n",
133 |     "    \n",
134 |     "    def __init__(self):\n",
135 |     "        Helpers.new_dir(BASE_OUTPUT_PATH + \"training/\")\n",
136 |     "        Helpers.new_dir(BASE_OUTPUT_PATH + \"training/losses/\")\n",
137 |     "\n",
138 |     "        self.image_shape = (IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)\n",
139 |     "        self.data_loader = DataLoader()\n",
140 |     "\n",
141 |     "        patch = int(IMAGE_SIZE / 2**4)\n",
142 |     "        self.disc_patch = (patch, patch, 1)\n",
143 |     "\n",
144 |     "        self.generator_filters = 64\n",
145 |     "        self.discriminator_filters = 64\n",
146 |     "    \n",
147 |     "        optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, BETA_1)\n",
148 |     "\n",
149 |     "        self.discriminator = self.discriminator()\n",
150 |     "        self.discriminator.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"accuracy\"])\n",
151 |     "        self.generator = self.generator()\n",
152 |     "\n",
153 |     "        source_image = tf.keras.layers.Input(shape=self.image_shape)\n",
154 |     "        destination_image = tf.keras.layers.Input(shape=self.image_shape)\n",
155 |     "        generated_image = self.generator(destination_image)\n",
156 |     "\n",
157 |     "        self.discriminator.trainable = False\n",
158 |     "        valid = self.discriminator([generated_image, destination_image])\n",
159 |     "        self.combined = tf.keras.models.Model(inputs=[source_image, destination_image], outputs=[valid, generated_image])\n",
160 |     "        self.combined.compile(loss=[\"mse\", \"mae\"], loss_weights=[1, 100], optimizer=optimizer)\n",
161 |     "\n",
162 |     "    def generator(self):\n",
163 |     "        def conv2d(layer_input, filters, bn=True):\n",
164 |     "            downsample = tf.keras.layers.Conv2D(filters, kernel_size=4, strides=2, padding=\"same\")(layer_input)\n",
165 |     "            downsample = tf.keras.layers.LeakyReLU(alpha=LEAKY_RELU_ALPHA)(downsample)\n",
166 |     "            if bn:\n",
167 |     "                downsample = tf.keras.layers.BatchNormalization(momentum=BN_MOMENTUM)(downsample)\n",
168 |     "            return downsample\n",
169 |     "\n",
170 |     "        def deconv2d(layer_input, skip_input, filters, dropout_rate=0):\n",
171 |     "            upsample = tf.keras.layers.UpSampling2D(size=2)(layer_input)\n",
172 |     "            upsample = tf.keras.layers.Conv2D(filters, kernel_size=4, strides=1, padding=\"same\", activation=\"relu\")(upsample)\n",
173 |     "            if dropout_rate:\n",
174 |     "                upsample = tf.keras.layers.Dropout(dropout_rate)(upsample)\n",
175 |     "            upsample = tf.keras.layers.BatchNormalization(momentum=BN_MOMENTUM)(upsample)\n",
176 |     "            upsample = tf.keras.layers.Concatenate()([upsample, skip_input])\n",
177 |     "            return upsample\n",
178 |     "\n",
179 |     "        downsample_0 = tf.keras.layers.Input(shape=self.image_shape)\n",
180 |     "        downsample_1 = conv2d(downsample_0, self.generator_filters, bn=False)\n",
181 |     "        downsample_2 = conv2d(downsample_1, self.generator_filters*2)\n",
182 |     "        downsample_3 = conv2d(downsample_2, self.generator_filters*4)\n",
183 |     "        downsample_4 = conv2d(downsample_3, self.generator_filters*8)\n",
184 |     "        downsample_5 = conv2d(downsample_4, self.generator_filters*8)\n",
185 |     "        downsample_6 = conv2d(downsample_5, self.generator_filters*8)\n",
186 |     "        downsample_7 = conv2d(downsample_6, self.generator_filters*8)\n",
187 |     "\n",
188 |     "        upsample_1 = deconv2d(downsample_7, downsample_6, self.generator_filters*8)\n",
189 |     "        upsample_2 = deconv2d(upsample_1, downsample_5, self.generator_filters*8)\n",
190 |     "        upsample_3 = deconv2d(upsample_2, downsample_4, self.generator_filters*8)\n",
191 |     "        upsample_4 = deconv2d(upsample_3, downsample_3, self.generator_filters*4)\n",
192 |     "        upsample_5 = deconv2d(upsample_4, downsample_2, self.generator_filters*2)\n",
193 |     "        upsample_6 = deconv2d(upsample_5, downsample_1, self.generator_filters)\n",
194 |     "        upsample_7 = tf.keras.layers.UpSampling2D(size=2)(upsample_6)\n",
195 |     "    \n",
196 |     "        output_image = tf.keras.layers.Conv2D(IMAGE_CHANNELS, kernel_size=4, strides=1, padding=\"same\", activation=\"tanh\")(upsample_7)\n",
197 |     "        return tf.keras.models.Model(downsample_0, output_image)\n",
198 |     "\n",
199 |     "    def discriminator(self):\n",
200 |     "        def discriminator_layer(layer_input, filters, bn=True):\n",
201 |     "            discriminator_layer = tf.keras.layers.Conv2D(filters, kernel_size=4, strides=2, padding=\"same\")(layer_input)\n",
202 |     "            discriminator_layer = tf.keras.layers.LeakyReLU(alpha=LEAKY_RELU_ALPHA)(discriminator_layer)\n",
203 |     "            if bn:\n",
204 |     "                discriminator_layer = tf.keras.layers.BatchNormalization(momentum=BN_MOMENTUM)(discriminator_layer)\n",
205 |     "            return discriminator_layer\n",
206 |     "\n",
207 |     "        source_image = tf.keras.layers.Input(shape=self.image_shape)\n",
208 |     "        destination_image = tf.keras.layers.Input(shape=self.image_shape)\n",
209 |     "        combined_images = tf.keras.layers.Concatenate(axis=-1)([source_image, destination_image])\n",
210 |     "        discriminator_layer_1 = discriminator_layer(combined_images, self.discriminator_filters, bn=False)\n",
211 |     "        discriminator_layer_2 = discriminator_layer(discriminator_layer_1, self.discriminator_filters*2)\n",
212 |     "        discriminator_layer_3 = discriminator_layer(discriminator_layer_2, self.discriminator_filters*4)\n",
213 |     "        discriminator_layer_4 = discriminator_layer(discriminator_layer_3, self.discriminator_filters*8)\n",
214 |     "        validity = tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding=\"same\")(discriminator_layer_4)\n",
padding=\"same\")(discriminator_layer_4)\n", 215 | " return tf.keras.models.Model([source_image, destination_image], validity)\n", 216 | " \n", 217 | " def preview_training_progress(self, epoch, size=3):\n", 218 | " def preview_outputs(epoch, size):\n", 219 | " source_images, destination_images = self.data_loader.load_random_data(size, is_testing=True)\n", 220 | " generated_images = self.generator.predict(destination_images)\n", 221 | " grid_image = None\n", 222 | " for i in range(size):\n", 223 | " row = Helpers.unnormalize(np.concatenate([destination_images[i], generated_images[i], source_images[i]], axis=1))\n", 224 | " if grid_image is None:\n", 225 | " grid_image = row\n", 226 | " else:\n", 227 | " grid_image = np.concatenate([grid_image, row], axis=0)\n", 228 | " plt.imshow(grid_image/255.0)\n", 229 | " plt.show()\n", 230 | " plt.close()\n", 231 | " grid_image = cv2.cvtColor(np.float32(grid_image), cv2.COLOR_RGB2BGR)\n", 232 | " cv2.imwrite(BASE_OUTPUT_PATH + \"training/ \" + str(epoch) + \".png\", grid_image)\n", 233 | " \n", 234 | " def preview_losses():\n", 235 | " def plot(title, data):\n", 236 | " plt.plot(data, alpha=0.6)\n", 237 | " plt.title(title + \"_\" + str(i))\n", 238 | " plt.savefig(BASE_OUTPUT_PATH + \"training/losses/\" + title + \"_\" + str(i) + \".png\")\n", 239 | " plt.close()\n", 240 | " for i, d in enumerate(self.d_losses):\n", 241 | " plot(\"discriminator\", d)\n", 242 | " for i, g in enumerate(self.g_losses):\n", 243 | " plot(\"generator\", g)\n", 244 | " \n", 245 | " preview_outputs(epoch, size)\n", 246 | " #preview_losses()\n", 247 | "\n", 248 | " def train(self):\n", 249 | " valid = np.ones((BATCH_SIZE,) + self.disc_patch)\n", 250 | " fake = np.zeros((BATCH_SIZE,) + self.disc_patch)\n", 251 | " self.d_losses = []\n", 252 | " self.g_losses = []\n", 253 | " self.preview_training_progress(0)\n", 254 | " for epoch in range(EPOCHS):\n", 255 | " epoch_d_losses = []\n", 256 | " epoch_g_losses = []\n", 257 | " for iteration, (source_images, destination_images) in enumerate(self.data_loader.yield_batch(BATCH_SIZE)):\n", 258 | " generated_images = self.generator.predict(destination_images)\n", 259 | " d_loss_real = self.discriminator.train_on_batch([source_images, destination_images], valid)\n", 260 | " d_loss_fake = self.discriminator.train_on_batch([generated_images, destination_images], fake)\n", 261 | " d_losses = 0.5 * np.add(d_loss_real, d_loss_fake)\n", 262 | " g_losses = self.combined.train_on_batch([source_images, destination_images], [valid, source_images])\n", 263 | " epoch_d_losses.append(d_losses)\n", 264 | " epoch_g_losses.append(g_losses)\n", 265 | " print(\"\\repoch: \" + str(epoch) \n", 266 | " +\", iteration: \"+ str(iteration) \n", 267 | " + \", d_losses: \" + str(d_losses) \n", 268 | " + \", g_losses: \" + str(g_losses)\n", 269 | " , sep=\" \", end=\" \", flush=True)\n", 270 | " self.d_losses.append(np.average(epoch_d_losses, axis=0))\n", 271 | " self.g_losses.append(np.average(epoch_g_losses, axis=0))\n", 272 | " self.preview_training_progress(epoch)\n", 273 | " \n", 274 | " def test(self):\n", 275 | " image_paths = glob(self.data_loader.testing_raw_path+\"*\")\n", 276 | " for image_path in image_paths:\n", 277 | " image = np.array(imageio.imread(image_path))\n", 278 | " image_normalized = Helpers.normalize(image)\n", 279 | " generated_batch = self.generator.predict(np.array([image_normalized]))\n", 280 | " concat = Helpers.unnormalize(np.concatenate([image_normalized, generated_batch[0]], axis=1))\n", 281 | " 
281 |     "            cv2.imwrite(BASE_OUTPUT_PATH+os.path.basename(image_path), cv2.cvtColor(np.float32(concat), cv2.COLOR_RGB2BGR))\n",
282 |     "    "
283 |    ]
284 |   },
285 |   {
286 |    "cell_type": "code",
287 |    "execution_count": null,
288 |    "metadata": {
289 |     "_uuid": "abbde5b9129e01502039e3a4638b678385bfbf77"
290 |    },
291 |    "outputs": [],
292 |    "source": [
293 |     "BASE_INPUT_PATH = \"\" # Kaggle: \"../input/pix2pix-depth/\"\n",
294 |     "BASE_OUTPUT_PATH = \"./output/\"\n",
295 |     "\n",
296 |     "IMAGE_SIZE = 256\n",
297 |     "IMAGE_CHANNELS = 3\n",
298 |     "LEARNING_RATE = 0.00015\n",
299 |     "BETA_1 = 0.5\n",
300 |     "LEAKY_RELU_ALPHA = 0.2\n",
301 |     "BN_MOMENTUM = 0.8\n",
302 |     "EPOCHS = 50\n",
303 |     "BATCH_SIZE = 32\n",
304 |     "\n",
305 |     "gan = Pix2Pix()\n",
306 |     "gan.train()\n",
307 |     "gan.test()"
308 |    ]
309 |   }
310 |  ],
311 |  "metadata": {
312 |   "kernelspec": {
313 |    "display_name": "Python 3",
314 |    "language": "python",
315 |    "name": "python3"
316 |   },
317 |   "language_info": {
318 |    "codemirror_mode": {
319 |     "name": "ipython",
320 |     "version": 3
321 |    },
322 |    "file_extension": ".py",
323 |    "mimetype": "text/x-python",
324 |    "name": "python",
325 |    "nbconvert_exporter": "python",
326 |    "pygments_lexer": "ipython3",
327 |    "version": "3.7.4"
328 |   }
329 |  },
330 |  "nbformat": 4,
331 |  "nbformat_minor": 1
332 | }
333 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 