├── LICENSE ├── README.md ├── assets ├── Screenshot from 2022-01-25 18-32-45.png ├── paper.png └── prev.jpeg ├── index.html ├── notebooks ├── .ipynb_checkpoints │ ├── comp_(1)-checkpoint.ipynb │ ├── depth-checkpoint.ipynb │ └── object-segmentation-checkpoint.ipynb ├── depth.ipynb ├── object-segmentation.ipynb └── semantic-segmentation.ipynb ├── requirements.txt ├── results ├── depth_perseption │ ├── combine_images (14).jpg │ ├── d1.png │ ├── d2.png │ ├── d3.png │ ├── d4.png │ ├── d5.png │ └── d6.png ├── object-segmentation │ ├── combine_images (15).jpg │ ├── os1.png │ ├── os2.png │ ├── os3.png │ ├── os4.png │ ├── os5.png │ └── os6.png └── semantic-segmentation │ ├── combine_images (16).jpg │ ├── f1.png │ ├── f2.png │ ├── f3.png │ ├── f4.png │ ├── f5.png │ └── f6.png └── src ├── evaluate.py ├── model.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yigit Gunduc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensor-to-image 2 | [Website](https://yigitgunduc.github.io/tensor2image/) | [Arxiv](https://arxiv.org/abs/2110.08037) 3 | 4 | 5 | 6 | ## Abstract 7 | 8 | Transformers have gained huge attention since they were first introduced and have 9 | a wide range of applications. Transformers have started to take over all areas of 10 | deep learning, and the Vision Transformer paper also proved that they can 11 | be used for computer vision tasks. In this paper, we utilized a 12 | custom-designed vision transformer-based model, tensor-to-image, 13 | for image-to-image translation. With the help of self-attention, 14 | our model was able to generalize and apply to different problems without 15 | a single modification. 16 | 17 | ## Setup 18 | 19 | Clone the repo 20 | ```bash 21 | git clone https://github.com/yigitgunduc/tensor-to-image/ 22 | ``` 23 | 24 | Install requirements 25 | ```bash 26 | pip3 install -r requirements.txt 27 | ``` 28 | 29 | > For GPU support, set up `TensorFlow >= 2.4.0` with `CUDA v11.0 or above` 30 | > - you can skip this step if you are going to train on the CPU 31 | 32 | ## Training 33 | 34 | Train the model 35 | ```bash 36 | python3 src/train.py 37 | ``` 38 | Weights are saved after every epoch and can be found in `./weights/` 39 | 40 | ## Evaluating 41 | 42 | After you have trained the model, you can evaluate it against 3 different criteria 43 | (FID, structural similarity (SSIM), and Inception score).
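The README does not describe how `src/evaluate.py` computes FID. As a rough reference only, the sketch below implements the standard Fréchet Inception Distance formula, FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)), in plain NumPy. It assumes you have already extracted one feature vector per image (for example Inception-v3 pooling activations) for the real and the generated set; `frechet_distance` and `fid_from_features` are hypothetical helper names, not functions from this repository, and this is not necessarily what `src/evaluate.py` does. Lower values mean the two feature distributions are closer.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Hypothetical helper: ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).
    """
    diff = mu1 - mu2
    # sigma1 @ sigma2 has the same eigenvalues as a PSD matrix, so they are real and
    # non-negative up to round-off; the trace of its square root is the sum of their roots.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sum(np.sqrt(np.maximum(eigvals.real, 0.0)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats_real, feats_fake):
    """Hypothetical helper: feats_* are (num_images, feature_dim) activation arrays."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma1 = np.cov(feats_real, rowvar=False)
    sigma2 = np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu1, sigma1, mu2, sigma2)


if __name__ == "__main__":
    # Random stand-ins for Inception features, used only to exercise the formula.
    rng = np.random.default_rng(0)
    real = rng.normal(size=(500, 8))
    print(fid_from_features(real, real))        # identical sets: approximately 0
    print(fid_from_features(real, real + 1.0))  # unit shift in 8 dims: approximately 8
```

Taking the eigenvalues of Sigma_r Sigma_g avoids an explicit matrix square root (such as `scipy.linalg.sqrtm`), which keeps the sketch NumPy-only. FID estimated from a small number of images is noisy and biased, so scores are only comparable when computed with the same feature extractor and the same number of samples.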
44 | 45 | ```bash 46 | python3 src/evaluate.py path/to/weights 47 | ``` 48 | 49 | ## Datasets 50 | 51 | The implementation supports 8 datasets for various tasks: 6 pix2pix datasets and 2 additional ones. 52 | The 6 pix2pix datasets can be selected by changing the `DATASET` variable in `src/train.py`; 53 | for the additional datasets, please see `notebooks/object-segmentation.ipynb` and 54 | `notebooks/depth.ipynb` 55 | 56 | Datasets available through `src/train.py` 57 | 58 | - `cityscapes` 99 MB 59 | - `edges2handbags` 8.0 GB 60 | - `edges2shoes` 2.0 GB 61 | - `facades` 29 MB 62 | - `maps` 239 MB 63 | - `night2day` 1.9 GB 64 | 65 | Datasets available through the notebooks 66 | 67 | - `Oxford-IIIT Pets` 68 | - `RGB+D DATABASE` 69 | 70 | ## Cite 71 | If you use this code for your research, please cite our paper [Tensor-to-Image: Image-to-Image Translation with Vision Transformers](https://arxiv.org/abs/2110.08037) 72 | ``` 73 | @article{gunducc2021tensor, 74 | title={Tensor-to-Image: Image-to-Image Translation with Vision Transformers}, 75 | author={G{\"u}nd{\"u}{\c{c}}, Yi{\u{g}}it}, 76 | journal={arXiv preprint arXiv:2110.08037}, 77 | year={2021} 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /assets/Screenshot from 2022-01-25 18-32-45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/assets/Screenshot from 2022-01-25 18-32-45.png -------------------------------------------------------------------------------- /assets/paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/assets/paper.png -------------------------------------------------------------------------------- /assets/prev.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/assets/prev.jpeg -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 30 | 31 |
32 |
33 |

Tensor-to-Image: Image-to-Image Translation with Vision Transformers

34 |

Yigit Gunduc

35 |

[GitHub] [Paper]

36 | 37 |

Abstract

38 |
39 |

Transformers have gained huge attention since they were first introduced and have a wide 40 | range of applications. Transformers have started to take over all areas of deep 41 | learning, and the Vision Transformer paper also proved that they 42 | can be used for computer vision tasks. In this paper, we utilized a 43 | custom-designed vision transformer-based model, tensor-to-image, for image-to-image 44 | translation. With the help of self-attention, our model 45 | was able to generalize and apply to different problems without a single 46 | modification.

47 |
48 |
49 |

Code & Paper

50 |
51 | 52 | 53 | 54 |
55 |

Tensor-to-Image: Image-to-Image Translation with Vision Transformers

56 |

For the full code, please see the GitHub repo.

57 |
58 |
59 |
60 |

Cite

61 |
62 |
63 |

If you use this code for your research, please cite our paper Tensor-to-Image: Image-to-Image Translation with Vision Transformers

64 | 65 |
66 |               
67 |   @article{gunducc2021tensor,
68 |     title={Tensor-to-Image: Image-to-Image Translation with Vision Transformers}, 
69 |     author={G{\"u}nd{\"u}{\c{c}}, Yi{\u{g}}it},
70 |     journal={arXiv preprint arXiv:2110.08037},
71 |     year={2021}
72 |   }
73 |             
74 |
75 |
76 |
77 | 78 | 79 | -------------------------------------------------------------------------------- /notebooks/.ipynb_checkpoints/depth-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/" 9 | }, 10 | "id": "cHaQkNvlc1ki", 11 | "outputId": "23afd6fb-521c-4b67-d765-22812a8a5bab" 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "from google.colab import drive\n", 16 | "import os\n", 17 | "import tensorflow as tf\n", 18 | "import glob\n", 19 | "drive.mount('/content/gdrive')" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "colab": { 27 | "base_uri": "https://localhost:8080/" 28 | }, 29 | "id": "I_R2JVOec2PY", 30 | "outputId": "5ffe6ee9-550c-4453-961c-727307f38bf6" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "!unzip /content/gdrive/MyDrive/indoor_test.zip" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "colab": { 42 | "base_uri": "https://localhost:8080/" 43 | }, 44 | "id": "qMZaGHOlc0SD", 45 | "outputId": "9f0ad718-d88a-4384-a315-fc8be1109a3e" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "dataset_path = '../../../depth/dataset/test/LR'" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": { 56 | "id": "8Mx-LXgLc0SM" 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "input_paths = glob.glob(dataset_path + '/**/color/*.png')\n", 61 | "target_paths = glob.glob(dataset_path + '/**/depth_vi/*.png')" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "colab": { 69 | "base_uri": "https://localhost:8080/" 70 | }, 71 | "id": "xFbmIbP_c0SN", 72 | "outputId": "56e6ada1-e53b-4791-ec22-97c9b3045379" 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "print(target_paths)" 77 | 
] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "id": "4OVZgPjtc0SP" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "BUFFER_SIZE = 400\n", 88 | "EPOCHS = 100\n", 89 | "LAMBDA = 100\n", 90 | "BATCH_SIZE = 8\n", 91 | "IMG_WIDTH = 256\n", 92 | "IMG_HEIGHT = 256\n", 93 | "patch_size = 8\n", 94 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n", 95 | "projection_dim = 64\n", 96 | "embed_dim = 64\n", 97 | "num_heads = 2 \n", 98 | "ff_dim = 32" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "id": "g89DtSq_c0SQ" 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "real = []\n", 110 | "targets = []" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": { 117 | "id": "AvaXbs4lc0SR" 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "def load(path):\n", 122 | "\n", 123 | " image_path = path[:-12] + 'c.png'\n", 124 | " image_path = image_path.replace(\"depth_vi\", \"color\")\n", 125 | " depth_path = path[:-12] + 'depth_vi.png'\n", 126 | "\n", 127 | "\n", 128 | " input_image = tf.io.read_file(image_path)\n", 129 | " input_image = tf.image.decode_jpeg(input_image)\n", 130 | " \n", 131 | " target_image = tf.io.read_file(depth_path)\n", 132 | " target_image = tf.image.decode_jpeg(target_image)\n", 133 | " \n", 134 | " input_image = tf.cast(input_image, tf.float32)\n", 135 | " target_image = tf.cast(target_image, tf.float32)\n", 136 | "\n", 137 | "\n", 138 | " return input_image, target_image" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "id": "y8_ilo0Fc0SS" 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "def resize(input_image, real_image, height, width):\n", 150 | " input_image = tf.image.resize(input_image, [height, width],\n", 151 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", 152 | " real_image = tf.image.resize(real_image, [height, 
width],\n", 153 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", 154 | "\n", 155 | " return input_image, real_image" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": { 162 | "id": "opCaZjvRc0ST" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "def normalize(input_image, target_image):\n", 167 | " input_image = input_image / 255\n", 168 | " target_image = target_image / 255\n", 169 | "\n", 170 | " return input_image, target_image" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "id": "bvQKHD3-c0SU" 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "def load_image_train(depth_path):\n", 182 | " input_image, target = load(depth_path)\n", 183 | " input_image, target = resize(input_image, target,\n", 184 | " IMG_HEIGHT, IMG_WIDTH)\n", 185 | " input_image, target = normalize(input_image, target)\n", 186 | "\n", 187 | " return input_image, target" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "id": "i48PBiP7c0SU" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "real = []\n", 199 | "targets = []\n", 200 | "import numpy as np\n", 201 | "for i in range(len(target_paths)):\n", 202 | " #inputs, target = load(target_paths[i])\n", 203 | " inputs, target = load_image_train(target_paths[i])\n", 204 | " #inputs, target = normalize(inputs, target)\n", 205 | " real.append(inputs)\n", 206 | " targets.append(target)\n", 207 | "\n", 208 | "real = np.array(real)\n", 209 | "targets = np.array(targets)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": { 216 | "id": "4yfKd4i_c0SV" 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "from matplotlib import pyplot as plt\n", 221 | "import numpy as np" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": { 228 | "colab": { 229 | "base_uri": 
"https://localhost:8080/", 230 | "height": 286 231 | }, 232 | "id": "1VCfus5zc0SV", 233 | "outputId": "c21309a4-f950-47d8-c27d-3520328f70bf" 234 | }, 235 | "outputs": [], 236 | "source": [ 237 | "plt.imshow(real[23])\n", 238 | "print(real[12].shape)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": { 245 | "colab": { 246 | "base_uri": "https://localhost:8080/", 247 | "height": 286 248 | }, 249 | "id": "vtOUxl5oc0SV", 250 | "outputId": "00f762f6-9e0f-46a7-d984-c3a1ec089b4f" 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "plt.imshow(targets[23].reshape(256, 256))\n", 255 | "print(targets[1].shape)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": { 262 | "id": "DVWWuvVic0SW" 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "import tensorflow as tf\n", 267 | "\n", 268 | "import os\n", 269 | "import time\n", 270 | "\n", 271 | "from matplotlib import pyplot as plt\n", 272 | "from IPython import display" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": { 279 | "id": "hm6uhABxc0SW" 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "def downsample(filters, size, apply_batchnorm=True):\n", 284 | " initializer = tf.random_normal_initializer(0., 0.02)\n", 285 | "\n", 286 | " result = tf.keras.Sequential()\n", 287 | " result.add(\n", 288 | " tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n", 289 | " kernel_initializer=initializer, use_bias=False))\n", 290 | "\n", 291 | " if apply_batchnorm:\n", 292 | " result.add(tf.keras.layers.BatchNormalization())\n", 293 | "\n", 294 | " result.add(tf.keras.layers.LeakyReLU())\n", 295 | "\n", 296 | " return result" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": { 303 | "id": "VqAxP9ayc0SX" 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "class Patches(tf.keras.layers.Layer):\n", 308 | " def 
__init__(self, patch_size):\n", 309 | " super(Patches, self).__init__()\n", 310 | " self.patch_size = patch_size\n", 311 | "\n", 312 | " def call(self, images):\n", 313 | " batch_size = tf.shape(images)[0]\n", 314 | " patches = tf.image.extract_patches(\n", 315 | " images=images,\n", 316 | " sizes=[1, self.patch_size, self.patch_size, 1],\n", 317 | " strides=[1, self.patch_size, self.patch_size, 1],\n", 318 | " rates=[1, 1, 1, 1],\n", 319 | " padding=\"SAME\",\n", 320 | " )\n", 321 | " patch_dims = patches.shape[-1]\n", 322 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n", 323 | " return patches" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": { 330 | "id": "n_mdkF59c0SX" 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "class PatchEncoder(tf.keras.layers.Layer):\n", 335 | " def __init__(self, num_patches, projection_dim):\n", 336 | " super(PatchEncoder, self).__init__()\n", 337 | " self.num_patches = num_patches\n", 338 | " self.projection = layers.Dense(units=projection_dim)\n", 339 | " self.position_embedding = layers.Embedding(\n", 340 | " input_dim=num_patches, output_dim=projection_dim\n", 341 | " )\n", 342 | "\n", 343 | " def call(self, patch):\n", 344 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n", 345 | " encoded = self.projection(patch) + self.position_embedding(positions)\n", 346 | " return encoded" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": { 353 | "id": "ZNDm9KXXc0SY" 354 | }, 355 | "outputs": [], 356 | "source": [ 357 | "class TransformerBlock(tf.keras.layers.Layer):\n", 358 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n", 359 | " super(TransformerBlock, self).__init__()\n", 360 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n", 361 | " self.ffn = tf.keras.Sequential(\n", 362 | " [layers.Dense(ff_dim, activation=\"relu\"), 
layers.Dense(embed_dim),]\n", 363 | " )\n", 364 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n", 365 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n", 366 | " self.dropout1 = layers.Dropout(rate)\n", 367 | " self.dropout2 = layers.Dropout(rate)\n", 368 | "\n", 369 | " def call(self, inputs, training):\n", 370 | " attn_output = self.att(inputs, inputs)\n", 371 | " attn_output = self.dropout1(attn_output, training=training)\n", 372 | " out1 = self.layernorm1(inputs + attn_output)\n", 373 | " ffn_output = self.ffn(out1)\n", 374 | " ffn_output = self.dropout2(ffn_output, training=training)\n", 375 | " return self.layernorm2(out1 + ffn_output)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": { 382 | "id": "RPZFwD5PP4af" 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "from tensorflow import Tensor\n", 387 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n", 388 | " Add, AveragePooling2D, Flatten, Dense\n", 389 | "from tensorflow.keras.models import Model\n", 390 | "\n", 391 | "def relu_bn(inputs: Tensor) -> Tensor:\n", 392 | " relu = ReLU()(inputs)\n", 393 | " bn = BatchNormalization()(relu)\n", 394 | " return bn\n", 395 | "\n", 396 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n", 397 | " y = Conv2D(kernel_size=kernel_size,\n", 398 | " strides= (1 if not downsample else 2),\n", 399 | " filters=filters,\n", 400 | " padding=\"same\")(x)\n", 401 | " y = relu_bn(y)\n", 402 | " y = Conv2D(kernel_size=kernel_size,\n", 403 | " strides=1,\n", 404 | " filters=filters,\n", 405 | " padding=\"same\")(y)\n", 406 | "\n", 407 | " if downsample:\n", 408 | " x = Conv2D(kernel_size=1,\n", 409 | " strides=2,\n", 410 | " filters=filters,\n", 411 | " padding=\"same\")(x)\n", 412 | " out = Add()([x, y])\n", 413 | " out = relu_bn(out)\n", 414 | " return out" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | 
"execution_count": null, 420 | "metadata": { 421 | "id": "lcQVzKBDc0SZ" 422 | }, 423 | "outputs": [], 424 | "source": [ 425 | "from tensorflow.keras import layers\n", 426 | "\n", 427 | "def Generator():\n", 428 | "\n", 429 | " inputs = layers.Input(shape=(256, 256, 3))\n", 430 | "\n", 431 | " patches = Patches(patch_size)(inputs)\n", 432 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n", 433 | "\n", 434 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n", 435 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 436 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 437 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 438 | "\n", 439 | " x = layers.Reshape((8, 8, 1024))(x)\n", 440 | "\n", 441 | " x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 442 | " x = layers.BatchNormalization()(x)\n", 443 | " x = layers.LeakyReLU()(x)\n", 444 | "\n", 445 | " x = residual_block(x, downsample=False, filters=512)\n", 446 | "\n", 447 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 448 | " x = layers.BatchNormalization()(x)\n", 449 | " x = layers.LeakyReLU()(x)\n", 450 | "\n", 451 | " x = residual_block(x, downsample=False, filters=256)\n", 452 | "\n", 453 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 454 | " x = layers.BatchNormalization()(x)\n", 455 | " x = layers.LeakyReLU()(x)\n", 456 | " \n", 457 | " x = residual_block(x, downsample=False, filters=64)\n", 458 | "\n", 459 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)\n", 460 | " x = layers.BatchNormalization()(x)\n", 461 | " x = layers.LeakyReLU()(x)\n", 462 | "\n", 463 | " x = residual_block(x, downsample=False, filters=32)\n", 464 | "\n", 465 | " x = layers.Conv2D(1, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n", 466 | "\n", 467 | " return 
tf.keras.Model(inputs=inputs, outputs=x)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": { 474 | "colab": { 475 | "base_uri": "https://localhost:8080/", 476 | "height": 1000 477 | }, 478 | "id": "DBHxlKHvc0Sa", 479 | "outputId": "0b70c08f-2c2c-4d01-dd44-e340c0b088c0" 480 | }, 481 | "outputs": [], 482 | "source": [ 483 | "generator = Generator()\n", 484 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": { 491 | "colab": { 492 | "base_uri": "https://localhost:8080/" 493 | }, 494 | "id": "51J3xxeZRLEO", 495 | "outputId": "e3794664-9dcc-4d21-e38a-b08c27bdff4f" 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "generator.summary()" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "id": "CxG6_fP1c0Sa" 507 | }, 508 | "outputs": [], 509 | "source": [ 510 | "tf.config.run_functions_eagerly(False)" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": { 517 | "id": "TZn1NNgbc0Sb" 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": { 528 | "id": "tYCaFUoGc0Sb" 529 | }, 530 | "outputs": [], 531 | "source": [ 532 | "def generator_loss(disc_generated_output, gen_output, target):\n", 533 | " gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)\n", 534 | "\n", 535 | " # mean absolute error\n", 536 | " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", 537 | "\n", 538 | " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", 539 | "\n", 540 | " return total_gen_loss, gan_loss, l1_loss" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": null, 546 | "metadata": { 547 | "id": 
"lw8T5T3Ac0Sd" 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "tf.config.run_functions_eagerly(True)" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": null, 557 | "metadata": { 558 | "id": "_Qhap2DDc0Sd" 559 | }, 560 | "outputs": [], 561 | "source": [ 562 | "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": { 569 | "id": "Gl9RqSOHc0Se" 570 | }, 571 | "outputs": [], 572 | "source": [ 573 | "def generate_images(model, test_input, tar):\n", 574 | "  prediction = model(test_input, training=True)\n", 575 | "  plt.figure(figsize=(15, 15))\n", 576 | "\n", 577 | "  display_list = [test_input[0], np.array(tar[0]).reshape(256, 256), np.array(prediction[0]).reshape(256, 256)]\n", 578 | "  title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", 579 | "\n", 580 | "  for i in range(3):\n", 581 | "    plt.subplot(1, 3, i+1)\n", 582 | "    plt.title(title[i])\n", 583 | "    # normalize() already scaled the images to [0, 1]; clip instead of shifting\n", 584 | "    plt.imshow(np.clip(display_list[i], 0, 1))\n", 585 | "    plt.axis('off')\n", 586 | "  plt.show()\n", 587 | "\n", 588 | "def generate_batch_images(model, test_input, tar):\n", 589 | "  prediction = model(test_input, training=True)  # run the generator once per batch\n", 590 | "  for i in range(len(test_input)):\n", 591 | "    plt.figure(figsize=(15, 15))\n", 592 | "\n", 593 | "    display_list = [test_input[i], np.squeeze(tar[i]), np.squeeze(prediction[i])]\n", 594 | "    title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", 595 | "\n", 596 | "    for j in range(3):\n", 597 | "      plt.subplot(1, 3, j+1)\n", 598 | "      plt.title(title[j])\n", 599 | "      # normalize() already scaled the images to [0, 1]; clip instead of shifting\n", 600 | "      plt.imshow(np.clip(display_list[j], 0, 1))\n", 601 | "      plt.axis('off')\n", 602 | "    plt.show()" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": { 609 | "id": "N2M-Jbjvc0Se" 610 | }, 611 | "outputs": [], 612 | "source": [
613 | "@tf.function\n", 614 | "def train_step(input_image, target):\n", 615 | " with tf.device('/device:GPU:0'):\n", 616 | " with tf.GradientTape() as gen_tape:\n", 617 | " gen_output = generator(input_image, training=True)\n", 618 | " gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", 619 | " \n", 620 | "\n", 621 | " generator_gradients = gen_tape.gradient(gen_total_loss,\n", 622 | " generator.trainable_variables)\n", 623 | " generator_optimizer.apply_gradients(zip(generator_gradients,\n", 624 | " generator.trainable_variables))" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": { 631 | "id": "5wOgyEJmc0Se" 632 | }, 633 | "outputs": [], 634 | "source": [ 635 | "def fit(train_ds, epochs, test_ds):\n", 636 | " for epoch in range(epochs):\n", 637 | " start = time.time()\n", 638 | "\n", 639 | " display.clear_output(wait=True)\n", 640 | "\n", 641 | " print(\"Epoch: \", epoch)\n", 642 | "\n", 643 | " # Train\n", 644 | " for n, (input_image, target) in train_ds.enumerate():\n", 645 | " print('.', end='')\n", 646 | " if (n+1) % 100 == 0:\n", 647 | " print()\n", 648 | " train_step(input_image, target)\n", 649 | " print()\n", 650 | "\n", 651 | " generator.save_weights(f'depth-gen-weights.h5')" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "metadata": { 658 | "id": "z4Kq8t1kc0Se" 659 | }, 660 | "outputs": [], 661 | "source": [ 662 | "train_dataset = tf.data.Dataset.from_tensor_slices((real, targets))\n", 663 | "\n", 664 | "train_dataset = train_dataset.batch(BATCH_SIZE)" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": null, 670 | "metadata": { 671 | "colab": { 672 | "base_uri": "https://localhost:8080/" 673 | }, 674 | "id": "B1SXMOPoc0Se", 675 | "outputId": "ee25b332-c08f-4ec4-eb15-1d59a4e896b2" 676 | }, 677 | "outputs": [], 678 | "source": [ 679 | "fit(train_dataset, 10000, train_dataset)" 680 | ] 681 | }, 682 | { 683 | "cell_type": 
"code", 684 | "execution_count": null, 685 | "metadata": { 686 | "id": "6H20taNNc0Sf" 687 | }, 688 | "outputs": [], 689 | "source": [ 690 | "generator.save_weights('gen-depth-weights.h5')" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "metadata": { 697 | "colab": { 698 | "base_uri": "https://localhost:8080/", 699 | "height": 1000 700 | }, 701 | "id": "9mSLHL9Ac0Sf", 702 | "outputId": "44e8b2a6-eec6-4041-c7f9-87da233911ba" 703 | }, 704 | "outputs": [], 705 | "source": [ 706 | "for example_input, example_target in train_dataset.take(54):\n", 707 | " generate_images(generator, example_input, example_target)" 708 | ] 709 | } 710 | ], 711 | "metadata": { 712 | "accelerator": "GPU", 713 | "colab": { 714 | "name": "image2image_depth-res.ipynb", 715 | "provenance": [] 716 | }, 717 | "kernelspec": { 718 | "display_name": "Python 3", 719 | "language": "python", 720 | "name": "python3" 721 | }, 722 | "language_info": { 723 | "codemirror_mode": { 724 | "name": "ipython", 725 | "version": 3 726 | }, 727 | "file_extension": ".py", 728 | "mimetype": "text/x-python", 729 | "name": "python", 730 | "nbconvert_exporter": "python", 731 | "pygments_lexer": "ipython3", 732 | "version": "3.8.10" 733 | } 734 | }, 735 | "nbformat": 4, 736 | "nbformat_minor": 1 737 | } 738 | -------------------------------------------------------------------------------- /notebooks/depth.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# download the depth dataset from https://dimlrgbd.github.io/" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "cHaQkNvlc1ki", 20 | "outputId": "23afd6fb-521c-4b67-d765-22812a8a5bab" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | 
"import os\n", 25 | "import time\n", 26 | "import glob\n", 27 | "import numpy as np\n", 28 | "import tensorflow as tf\n", 29 | "from IPython import display\n", 30 | "from matplotlib import pyplot as plt" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "colab": { 38 | "base_uri": "https://localhost:8080/" 39 | }, 40 | "id": "qMZaGHOlc0SD", 41 | "outputId": "9f0ad718-d88a-4384-a315-fc8be1109a3e" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "dataset_path = 'depth/dataset/train/LR'" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "id": "8Mx-LXgLc0SM" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "input_paths = glob.glob(dataset_path + '/**/color/*.png')\n", 57 | "target_paths = glob.glob(dataset_path + '/**/depth_vi/*.png')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "id": "4OVZgPjtc0SP" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "BUFFER_SIZE = 400\n", 69 | "EPOCHS = 100\n", 70 | "LAMBDA = 100\n", 71 | "BATCH_SIZE = 8\n", 72 | "IMG_WIDTH = 256\n", 73 | "IMG_HEIGHT = 256\n", 74 | "patch_size = 8\n", 75 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n", 76 | "projection_dim = 64\n", 77 | "embed_dim = 64\n", 78 | "num_heads = 2 \n", 79 | "ff_dim = 32" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "id": "g89DtSq_c0SQ" 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "real = []\n", 91 | "targets = []" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "id": "AvaXbs4lc0SR" 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "def load(path):\n", 103 | "\n", 104 | " image_path = path[:-12] + 'c.png'\n", 105 | " image_path = image_path.replace(\"depth_vi\", \"color\")\n", 106 | " depth_path = path[:-12] + 'depth_vi.png'\n", 107 | "\n", 108 | "\n", 109 | " input_image = 
tf.io.read_file(image_path)\n", 110 | " input_image = tf.image.decode_jpeg(input_image)\n", 111 | " \n", 112 | " target_image = tf.io.read_file(depth_path)\n", 113 | " target_image = tf.image.decode_jpeg(target_image)\n", 114 | " \n", 115 | " input_image = tf.cast(input_image, tf.float32)\n", 116 | " target_image = tf.cast(target_image, tf.float32)\n", 117 | "\n", 118 | "\n", 119 | " return input_image, target_image" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "id": "y8_ilo0Fc0SS" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "def resize(input_image, real_image, height, width):\n", 131 | " input_image = tf.image.resize(input_image, [height, width],\n", 132 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", 133 | " real_image = tf.image.resize(real_image, [height, width],\n", 134 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", 135 | "\n", 136 | " return input_image, real_image" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "id": "opCaZjvRc0ST" 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "def normalize(input_image, target_image):\n", 148 | " input_image = input_image / 255\n", 149 | " target_image = target_image / 255\n", 150 | "\n", 151 | " return input_image, target_image" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": { 158 | "id": "bvQKHD3-c0SU" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "def load_image_train(depth_path):\n", 163 | " input_image, target = load(depth_path)\n", 164 | " input_image, target = resize(input_image, target,\n", 165 | " IMG_HEIGHT, IMG_WIDTH)\n", 166 | " input_image, target = normalize(input_image, target)\n", 167 | "\n", 168 | " return input_image, target" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": { 175 | "id": "i48PBiP7c0SU" 176 | }, 177 | "outputs": [], 178 | 
"source": [ 179 | "real = []\n", 180 | "targets = []\n", 181 | "\n", 182 | "for i in range(len(target_paths)):\n", 183 | " #inputs, target = load(target_paths[i])\n", 184 | " inputs, target = load_image_train(target_paths[i])\n", 185 | " #inputs, target = normalize(inputs, target)\n", 186 | " real.append(inputs)\n", 187 | " targets.append(target)\n", 188 | "\n", 189 | "real = np.array(real)\n", 190 | "targets = np.array(targets)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "id": "4yfKd4i_c0SV" 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "from matplotlib import pyplot as plt\n", 202 | "import numpy as np" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "colab": { 210 | "base_uri": "https://localhost:8080/", 211 | "height": 286 212 | }, 213 | "id": "1VCfus5zc0SV", 214 | "outputId": "c21309a4-f950-47d8-c27d-3520328f70bf" 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "plt.imshow(real[23])\n", 219 | "print(real[12].shape)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": { 226 | "colab": { 227 | "base_uri": "https://localhost:8080/", 228 | "height": 286 229 | }, 230 | "id": "vtOUxl5oc0SV", 231 | "outputId": "00f762f6-9e0f-46a7-d984-c3a1ec089b4f" 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "plt.imshow(targets[23].reshape(256, 256))\n", 236 | "print(targets[1].shape)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": { 243 | "id": "hm6uhABxc0SW" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "def downsample(filters, size, apply_batchnorm=True):\n", 248 | " initializer = tf.random_normal_initializer(0., 0.02)\n", 249 | "\n", 250 | " result = tf.keras.Sequential()\n", 251 | " result.add(\n", 252 | " tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n", 253 | " kernel_initializer=initializer, 
use_bias=False))\n", 254 | "\n", 255 | " if apply_batchnorm:\n", 256 | " result.add(tf.keras.layers.BatchNormalization())\n", 257 | "\n", 258 | " result.add(tf.keras.layers.LeakyReLU())\n", 259 | "\n", 260 | " return result" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": { 267 | "id": "VqAxP9ayc0SX" 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "class Patches(tf.keras.layers.Layer):\n", 272 | " def __init__(self, patch_size):\n", 273 | " super(Patches, self).__init__()\n", 274 | " self.patch_size = patch_size\n", 275 | "\n", 276 | " def call(self, images):\n", 277 | " batch_size = tf.shape(images)[0]\n", 278 | " patches = tf.image.extract_patches(\n", 279 | " images=images,\n", 280 | " sizes=[1, self.patch_size, self.patch_size, 1],\n", 281 | " strides=[1, self.patch_size, self.patch_size, 1],\n", 282 | " rates=[1, 1, 1, 1],\n", 283 | " padding=\"SAME\",\n", 284 | " )\n", 285 | " patch_dims = patches.shape[-1]\n", 286 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n", 287 | " return patches" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": { 294 | "id": "n_mdkF59c0SX" 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "class PatchEncoder(tf.keras.layers.Layer):\n", 299 | " def __init__(self, num_patches, projection_dim):\n", 300 | " super(PatchEncoder, self).__init__()\n", 301 | " self.num_patches = num_patches\n", 302 | " self.projection = layers.Dense(units=projection_dim)\n", 303 | " self.position_embedding = layers.Embedding(\n", 304 | " input_dim=num_patches, output_dim=projection_dim\n", 305 | " )\n", 306 | "\n", 307 | " def call(self, patch):\n", 308 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n", 309 | " encoded = self.projection(patch) + self.position_embedding(positions)\n", 310 | " return encoded" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": { 
317 | "id": "ZNDm9KXXc0SY" 318 | }, 319 | "outputs": [], 320 | "source": [ 321 | "class TransformerBlock(tf.keras.layers.Layer):\n", 322 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n", 323 | " super(TransformerBlock, self).__init__()\n", 324 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n", 325 | " self.ffn = tf.keras.Sequential(\n", 326 | " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", 327 | " )\n", 328 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n", 329 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n", 330 | " self.dropout1 = layers.Dropout(rate)\n", 331 | " self.dropout2 = layers.Dropout(rate)\n", 332 | "\n", 333 | " def call(self, inputs, training):\n", 334 | " attn_output = self.att(inputs, inputs)\n", 335 | " attn_output = self.dropout1(attn_output, training=training)\n", 336 | " out1 = self.layernorm1(inputs + attn_output)\n", 337 | " ffn_output = self.ffn(out1)\n", 338 | " ffn_output = self.dropout2(ffn_output, training=training)\n", 339 | " return self.layernorm2(out1 + ffn_output)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": { 346 | "id": "RPZFwD5PP4af" 347 | }, 348 | "outputs": [], 349 | "source": [ 350 | "from tensorflow import Tensor\n", 351 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n", 352 | " Add, AveragePooling2D, Flatten, Dense\n", 353 | "from tensorflow.keras.models import Model\n", 354 | "\n", 355 | "def relu_bn(inputs: Tensor) -> Tensor:\n", 356 | " relu = ReLU()(inputs)\n", 357 | " bn = BatchNormalization()(relu)\n", 358 | " return bn\n", 359 | "\n", 360 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n", 361 | " y = Conv2D(kernel_size=kernel_size,\n", 362 | " strides= (1 if not downsample else 2),\n", 363 | " filters=filters,\n", 364 | " padding=\"same\")(x)\n", 365 | " y = 
relu_bn(y)\n", 366 | " y = Conv2D(kernel_size=kernel_size,\n", 367 | " strides=1,\n", 368 | " filters=filters,\n", 369 | " padding=\"same\")(y)\n", 370 | "\n", 371 | " if downsample:\n", 372 | " x = Conv2D(kernel_size=1,\n", 373 | " strides=2,\n", 374 | " filters=filters,\n", 375 | " padding=\"same\")(x)\n", 376 | " out = Add()([x, y])\n", 377 | " out = relu_bn(out)\n", 378 | " return out" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "id": "lcQVzKBDc0SZ" 386 | }, 387 | "outputs": [], 388 | "source": [ 389 | "from tensorflow.keras import layers\n", 390 | "\n", 391 | "def Generator():\n", 392 | "\n", 393 | " inputs = layers.Input(shape=(256, 256, 3))\n", 394 | "\n", 395 | " patches = Patches(patch_size)(inputs)\n", 396 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n", 397 | "\n", 398 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n", 399 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 400 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 401 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 402 | "\n", 403 | " x = layers.Reshape((8, 8, 1024))(x)\n", 404 | "\n", 405 | " x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 406 | " x = layers.BatchNormalization()(x)\n", 407 | " x = layers.LeakyReLU()(x)\n", 408 | "\n", 409 | " x = residual_block(x, downsample=False, filters=512)\n", 410 | "\n", 411 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 412 | " x = layers.BatchNormalization()(x)\n", 413 | " x = layers.LeakyReLU()(x)\n", 414 | "\n", 415 | " x = residual_block(x, downsample=False, filters=256)\n", 416 | "\n", 417 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 418 | " x = layers.BatchNormalization()(x)\n", 419 | " x = layers.LeakyReLU()(x)\n", 420 | " \n", 421 | " x = residual_block(x, 
downsample=False, filters=64)\n", 422 | "\n", 423 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)\n", 424 | " x = layers.BatchNormalization()(x)\n", 425 | " x = layers.LeakyReLU()(x)\n", 426 | "\n", 427 | " x = residual_block(x, downsample=False, filters=32)\n", 428 | "\n", 429 | " x = layers.Conv2D(1, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n", 430 | "\n", 431 | " return tf.keras.Model(inputs=inputs, outputs=x)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "metadata": { 438 | "colab": { 439 | "base_uri": "https://localhost:8080/", 440 | "height": 1000 441 | }, 442 | "id": "DBHxlKHvc0Sa", 443 | "outputId": "0b70c08f-2c2c-4d01-dd44-e340c0b088c0" 444 | }, 445 | "outputs": [], 446 | "source": [ 447 | "generator = Generator()\n", 448 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n", 449 | "generator.summary()\n", 450 | "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": { 457 | "id": "lw8T5T3Ac0Sd" 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "tf.config.run_functions_eagerly(True)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": { 468 | "id": "Gl9RqSOHc0Se" 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "def generate_images(model, test_input, tar):\n", 473 | " prediction = model(test_input, training=True)\n", 474 | " plt.figure(figsize=(15, 15))\n", 475 | "\n", 476 | " display_list = [test_input[0], np.array(tar[0]).reshape(256, 256), np.array(prediction[0]).reshape(256, 256)]\n", 477 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", 478 | "\n", 479 | " for i in range(3):\n", 480 | " plt.subplot(1, 3, i+1)\n", 481 | " plt.title(title[i])\n", 482 | " # getting the pixel values between [0, 1] to plot it.\n", 483 | " 
plt.imshow(display_list[i] * 0.5 + 0.5)\n", 484 | " plt.axis('off')\n", 485 | " plt.show()\n", 486 | "\n", 487 | "def generate_batch_images(model, test_input, tar):\n", 488 | " prediction = model(test_input, training=True)\n", 489 | " for i in range(len(test_input)):\n", 490 | " plt.figure(figsize=(15, 15))\n", 491 | "\n", 492 | " display_list = [test_input[i], np.array(tar[i]).reshape(256, 256), np.array(prediction[i]).reshape(256, 256)]\n", 493 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", 494 | "\n", 495 | " for j in range(3):\n", 496 | " plt.subplot(1, 3, j+1)\n", 497 | " plt.title(title[j])\n", 498 | " # getting the pixel values between [0, 1] to plot it.\n", 499 | " plt.imshow(display_list[j] * 0.5 + 0.5)\n", 500 | " plt.axis('off')\n", 501 | " plt.show()" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": { 508 | "id": "N2M-Jbjvc0Se" 509 | }, 510 | "outputs": [], 511 | "source": [ 512 | "@tf.function\n", 513 | "def train_step(input_image, target):\n", 514 | " with tf.device('/device:GPU:0'):\n", 515 | " with tf.GradientTape() as gen_tape:\n", 516 | " gen_output = generator(input_image, training=True)\n", 517 | " gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", 518 | " \n", 519 | "\n", 520 | " generator_gradients = gen_tape.gradient(gen_total_loss,\n", 521 | " generator.trainable_variables)\n", 522 | " generator_optimizer.apply_gradients(zip(generator_gradients,\n", 523 | " generator.trainable_variables))" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": { 530 | "id": "5wOgyEJmc0Se" 531 | }, 532 | "outputs": [], 533 | "source": [ 534 | "def fit(train_ds, epochs, test_ds):\n", 535 | " for epoch in range(epochs):\n", 536 | " start = time.time()\n", 537 | "\n", 538 | " display.clear_output(wait=True)\n", 539 | "\n", 540 | " print(\"Epoch: \", epoch)\n", 541 | "\n", 542 | " # Train\n", 543 | " for n, (input_image, target) in train_ds.enumerate():\n", 544 | " print('.', end='')\n",
545 | " if (n+1) % 100 == 0:\n", 546 | " print()\n", 547 | " train_step(input_image, target)\n", 548 | " print(f'\\nTime taken for epoch {epoch + 1}: {time.time() - start:.2f} sec')\n", 549 | "\n", 550 | " generator.save_weights('depth-weights.h5')" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "metadata": { 557 | "id": "z4Kq8t1kc0Se" 558 | }, 559 | "outputs": [], 560 | "source": [ 561 | "train_dataset = tf.data.Dataset.from_tensor_slices((real, targets))\n", 562 | "\n", 563 | "train_dataset = train_dataset.batch(BATCH_SIZE)" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": { 570 | "colab": { 571 | "base_uri": "https://localhost:8080/" 572 | }, 573 | "id": "B1SXMOPoc0Se", 574 | "outputId": "ee25b332-c08f-4ec4-eb15-1d59a4e896b2" 575 | }, 576 | "outputs": [], 577 | "source": [ 578 | "fit(train_dataset, EPOCHS, train_dataset)" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": { 585 | "id": "6H20taNNc0Sf" 586 | }, 587 | "outputs": [], 588 | "source": [ 589 | "generator.save_weights('gen-depth-weights.h5')" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": { 596 | "colab": { 597 | "base_uri": "https://localhost:8080/", 598 | "height": 1000 599 | }, 600 | "id": "9mSLHL9Ac0Sf", 601 | "outputId": "44e8b2a6-eec6-4041-c7f9-87da233911ba" 602 | }, 603 | "outputs": [], 604 | "source": [ 605 | "for example_input, example_target in train_dataset.take(54):\n", 606 | " generate_images(generator, example_input, example_target)" 607 | ] 608 | } 609 | ], 610 | "metadata": { 611 | "accelerator": "GPU", 612 | "colab": { 613 | "name": "image2image_depth-res.ipynb", 614 | "provenance": [] 615 | }, 616 | "kernelspec": { 617 | "display_name": "Python 3", 618 | "language": "python", 619 | "name": "python3" 620 | }, 621 | "language_info": { 622 | "codemirror_mode": { 623 | "name": "ipython", 624 | "version": 3 625 | }, 626 | "file_extension": ".py", 627 |
"mimetype": "text/x-python", 628 | "name": "python", 629 | "nbconvert_exporter": "python", 630 | "pygments_lexer": "ipython3", 631 | "version": "3.8.10" 632 | } 633 | }, 634 | "nbformat": 4, 635 | "nbformat_minor": 1 636 | } 637 | -------------------------------------------------------------------------------- /notebooks/object-segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "YfIk2es3hJEd" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "\n", 13 | "import os\n", 14 | "import time\n", 15 | "from matplotlib import pyplot as plt\n", 16 | "from IPython import display\n", 17 | "import tensorflow_datasets as tfds" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": { 24 | "id": "2CbTEt448b4R" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "BUFFER_SIZE = 400\n", 29 | "EPOCHS = 100\n", 30 | "LAMBDA = 100\n", 31 | "DATASET = 'seg'\n", 32 | "BATCH_SIZE = 32\n", 33 | "IMG_WIDTH = 128\n", 34 | "IMG_HEIGHT = 128\n", 35 | "patch_size = 8\n", 36 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n", 37 | "projection_dim = 64\n", 38 | "embed_dim = 64\n", 39 | "num_heads = 2 \n", 40 | "ff_dim = 32\n", 41 | "\n", 42 | "assert IMG_WIDTH == IMG_HEIGHT, \"image width and image height must have same dims\"\n", 43 | "\n", 44 | "tf.config.run_functions_eagerly(False)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "colab": { 52 | "base_uri": "https://localhost:8080/", 53 | "height": 347, 54 | "referenced_widgets": [ 55 | "e6d3b16d24cd4468b68af5be44eeaa46", 56 | "8fd55e55c24e48afb223ff8b7422a546", 57 | "954933d3187645c6bd191289c122a6b5", 58 | "1c7b1be085354daa90f7491366bf3f26", 59 | "84132b5022db442183e409973de11d67", 60 | "f60bacd17c604608a7fd80a06a305bdf", 61 | "112fde1fafb042c7abe870a471f69cb4", 
62 | "2d886f8a4991438a97adb7f228c0b247", 63 | "2127f1922264401dbca0ac6365d283af", 64 | "5c6312228c08467186d56dba16c4028c", 65 | "642dffa8e9374234bbce51d29e75c7ea", 66 | "4f1f9e0926cd45e4aa4e813b32b1d7db", 67 | "e292c1f1791e4b349a931400f52e980c", 68 | "80d7be7130b8468c8ab3bcb095dc36bf", 69 | "5cf59228cd5844a8abbe62cb897ce431", 70 | "605b204e45bb401e97270eba6eceb351", 71 | "bbffba56f53b4b68bdcba5875e7c2f07", 72 | "74034286026f457892c06d6025286628", 73 | "0e398791e69f4c3c9f3b8f8928ffac94", 74 | "d8118a57a2814d19b2fde27d2452c84e", 75 | "7e80e3f85f4c4937b4e433f5d0cf8651", 76 | "4e313c0fc9a34434ad8828f9f7d51245", 77 | "a625ba1ffcc340c5a4f042be0b4877c3", 78 | "5695c6a28a5e47008227462f5ade5c9b", 79 | "83c9ad4ed8bb4aa9804a23a330da4d8c", 80 | "2d19d5c44a0348768a9ff242aa199119", 81 | "c94bf43a065b44c182d8af0c717e92ca", 82 | "cd600dd278a04f1fabc50fd0e9639fc4", 83 | "231b5248ecd746d0939739f6a42db75e", 84 | "bffbcc4727fd4a2fa5694aea1680af65", 85 | "a7bdbd085e5c43e3b4572f53175f61b6", 86 | "6abe70e78dce42a091dca068f5af4696", 87 | "2c75ae91013c455e8cedafddcd12052f", 88 | "a6aa784cb2a14bc29ed56a7e29857061", 89 | "c11d62582c59442e9cf388a6fbacaad7", 90 | "fc6c6d3b63634d1f9f58cebbd50f4793", 91 | "5979ebdad11f4f6f941b32ba4a509416", 92 | "2638e1b29d744f6b8a147f8bfa6f89d4", 93 | "7e7891071f7c4f5e87d876da643a3045", 94 | "12aae5a01c5a462e999735d848d3e354", 95 | "1612ca06675349d4b76e5378a129eefb", 96 | "81e3fa193c364c0f90b1dc7bca808eb9", 97 | "8aebd4d81c6a47f78c84602fcef1249c", 98 | "aa67ec0f9e06454b99dea3b324bfaeb2", 99 | "bd9ea3399660446897f60580a39588f2", 100 | "69b5d3f908c748509302621147b3517b", 101 | "3e2b935749cb42aeac6507c0f4697295", 102 | "c985393de22c4f97a7219f53642a38e6", 103 | "600a3e8becc84abeadc682bae8db52d5", 104 | "97c62988c2e1454db66b33ab880f08d1", 105 | "ca89e296533745f7824bd9f91006d162", 106 | "fad326679299483f9cfcc800bd4aa549", 107 | "faee640e94db490f8cd61481aa74a92a", 108 | "6ba3d098f19b4919a06ecca5b3763596", 109 | "f5b8dec6698746268a69122340b1f989", 110 | 
"1689a79e51304012976295a13bac17cf" 111 | ] 112 | }, 113 | "id": "Kn-k8kTXuAlv", 114 | "outputId": "6cc83593-137e-429a-e2ef-060e8d07d41a" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "id": "aO9ZAGH5K3SY" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "def normalize(input_image, input_mask):\n", 130 | " input_image = tf.cast(input_image, tf.float32) / 255.0\n", 131 | " input_mask -= 1\n", 132 | " return input_image, input_mask" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "id": "4OLHMpsQ5aOv" 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "@tf.function\n", 144 | "def load_image_train(datapoint):\n", 145 | " input_image = tf.image.resize(datapoint['image'], (128, 128))\n", 146 | " input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128), method='nearest')\n", 147 | "\n", 148 | " if tf.random.uniform(()) > 0.5:\n", 149 | " input_image = tf.image.flip_left_right(input_image)\n", 150 | " input_mask = tf.image.flip_left_right(input_mask)\n", 151 | "\n", 152 | " input_image, input_mask = normalize(input_image, input_mask)\n", 153 | "\n", 154 | " return input_image, input_mask" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "id": "rwwYQpu9FzDu" 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "def load_image_test(datapoint):\n", 166 | " input_image = tf.image.resize(datapoint['image'], (128, 128))\n", 167 | " input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128), method='nearest')\n", 168 | "\n", 169 | " input_image, input_mask = normalize(input_image, input_mask)\n", 170 | "\n", 171 | " return input_image, input_mask" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "colab": { 179 |
"base_uri": "https://localhost:8080/" 180 | }, 181 | "id": "Yn3IwqhiIszt", 182 | "outputId": "e52589a1-2d3a-42c8-ec89-7a08e27b9538" 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "TRAIN_LENGTH = info.splits['train'].num_examples\n", 187 | "STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "id": "muhR2cgbLKWW" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)\n", 199 | "test = dataset['test'].map(load_image_test)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "id": "fVQOjcPVLrUc" 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()\n", 211 | "train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)\n", 212 | "test_dataset = test.batch(BATCH_SIZE)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "id": "n0OGdi6D92kM" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "def display(display_list):\n", 224 | " plt.figure(figsize=(15, 15))\n", 225 | "\n", 226 | " title = ['Input Image', 'True Mask', 'Predicted Mask']\n", 227 | "\n", 228 | " for i in range(len(display_list)):\n", 229 | " plt.subplot(1, len(display_list), i+1)\n", 230 | " plt.title(title[i])\n", 231 | " plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))\n", 232 | " plt.axis('off')\n", 233 | " plt.show()" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": { 240 | "colab": { 241 | "base_uri": "https://localhost:8080/", 242 | "height": 427 243 | }, 244 | "id": "tyaP4hLJ8b4W", 245 | "outputId": "ba14a1a1-ecc9-4fe1-f512-10e0793cd921" 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "for image, mask in train.take(1):\n", 250 | " 
sample_image, sample_mask = image, mask\n", 251 | "display([sample_image, sample_mask])" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": { 258 | "id": "VB3Z6D_zKSru" 259 | }, 260 | "outputs": [], 261 | "source": [ 262 | "def create_mask(pred_mask):\n", 263 | " pred_mask = tf.argmax(pred_mask, axis=-1)\n", 264 | " pred_mask = pred_mask[..., tf.newaxis]\n", 265 | " return pred_mask[0]" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": { 272 | "id": "SQHmYSmk8b4b" 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "def show_predictions(dataset=None, num=1):\n", 277 | " if dataset:\n", 278 | " for image, mask in dataset.take(num):\n", 279 | " pred_mask = generator.predict(image)\n", 280 | " display([image[0], mask[0], create_mask(pred_mask)])\n", 281 | " else:\n", 282 | " display([sample_image, sample_mask,\n", 283 | " create_mask(generator.predict(sample_image[tf.newaxis, ...]))])" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": { 290 | "id": "AWSBM-ckAZZL" 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "class Patches(tf.keras.layers.Layer):\n", 295 | " def __init__(self, patch_size):\n", 296 | " super(Patches, self).__init__()\n", 297 | " self.patch_size = patch_size\n", 298 | "\n", 299 | " def call(self, images):\n", 300 | " batch_size = tf.shape(images)[0]\n", 301 | " patches = tf.image.extract_patches(\n", 302 | " images=images,\n", 303 | " sizes=[1, self.patch_size, self.patch_size, 1],\n", 304 | " strides=[1, self.patch_size, self.patch_size, 1],\n", 305 | " rates=[1, 1, 1, 1],\n", 306 | " padding=\"SAME\",\n", 307 | " )\n", 308 | " patch_dims = patches.shape[-1]\n", 309 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n", 310 | " return patches" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": { 317 | "id": "mXT2GyxTAZWq" 318 | }, 
319 | "outputs": [], 320 | "source": [ 321 | "class PatchEncoder(tf.keras.layers.Layer):\n", 322 | " def __init__(self, num_patches, projection_dim):\n", 323 | " super(PatchEncoder, self).__init__()\n", 324 | " self.num_patches = num_patches\n", 325 | " self.projection = layers.Dense(units=projection_dim)\n", 326 | " self.position_embedding = layers.Embedding(\n", 327 | " input_dim=num_patches, output_dim=projection_dim\n", 328 | " )\n", 329 | "\n", 330 | " def call(self, patch):\n", 331 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n", 332 | " encoded = self.projection(patch) + self.position_embedding(positions)\n", 333 | " return encoded" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "id": "EsRN0b3qAdWz" 341 | }, 342 | "outputs": [], 343 | "source": [ 344 | "class TransformerBlock(tf.keras.layers.Layer):\n", 345 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n", 346 | " super(TransformerBlock, self).__init__()\n", 347 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n", 348 | " self.ffn = tf.keras.Sequential(\n", 349 | " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", 350 | " )\n", 351 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n", 352 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n", 353 | " self.dropout1 = layers.Dropout(rate)\n", 354 | " self.dropout2 = layers.Dropout(rate)\n", 355 | "\n", 356 | " def call(self, inputs, training):\n", 357 | " attn_output = self.att(inputs, inputs)\n", 358 | " attn_output = self.dropout1(attn_output, training=training)\n", 359 | " out1 = self.layernorm1(inputs + attn_output)\n", 360 | " ffn_output = self.ffn(out1)\n", 361 | " ffn_output = self.dropout2(ffn_output, training=training)\n", 362 | " return self.layernorm2(out1 + ffn_output)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": { 
369 | "id": "h9GZYWlkAsBn" 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "from tensorflow import Tensor\n", 374 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n", 375 | " Add, AveragePooling2D, Flatten, Dense\n", 376 | "from tensorflow.keras.models import Model\n", 377 | "\n", 378 | "def relu_bn(inputs: Tensor) -> Tensor:\n", 379 | " relu = ReLU()(inputs)\n", 380 | " bn = BatchNormalization()(relu)\n", 381 | " return bn\n", 382 | "\n", 383 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n", 384 | " y = Conv2D(kernel_size=kernel_size,\n", 385 | " strides= (1 if not downsample else 2),\n", 386 | " filters=filters,\n", 387 | " padding=\"same\")(x)\n", 388 | " y = relu_bn(y)\n", 389 | " y = Conv2D(kernel_size=kernel_size,\n", 390 | " strides=1,\n", 391 | " filters=filters,\n", 392 | " padding=\"same\")(y)\n", 393 | "\n", 394 | " if downsample:\n", 395 | " x = Conv2D(kernel_size=1,\n", 396 | " strides=2,\n", 397 | " filters=filters,\n", 398 | " padding=\"same\")(x)\n", 399 | " out = Add()([x, y])\n", 400 | " out = relu_bn(out)\n", 401 | " return out" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": { 408 | "id": "lFPI4Nu-8b4q" 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "from tensorflow.keras import layers\n", 413 | "\n", 414 | "def Generator():\n", 415 | "\n", 416 | " inputs = layers.Input(shape=(128, 128, 3))\n", 417 | "\n", 418 | " patches = Patches(patch_size)(inputs)\n", 419 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n", 420 | "\n", 421 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n", 422 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 423 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 424 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 425 | "\n", 426 | " x = layers.Reshape((8, 8, 256))(x)\n", 427 | "\n", 428 | " x = 
layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 429 | " x = layers.BatchNormalization()(x)\n", 430 | " x = layers.LeakyReLU()(x)\n", 431 | "\n", 432 | " x = residual_block(x, downsample=False, filters=512)\n", 433 | "\n", 434 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 435 | " x = layers.BatchNormalization()(x)\n", 436 | " x = layers.LeakyReLU()(x)\n", 437 | "\n", 438 | " x = residual_block(x, downsample=False, filters=256)\n", 439 | "\n", 440 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 441 | " x = layers.BatchNormalization()(x)\n", 442 | " x = layers.LeakyReLU()(x)\n", 443 | " \n", 444 | " x = residual_block(x, downsample=False, filters=64)\n", 445 | "\n", 446 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 447 | " x = layers.BatchNormalization()(x)\n", 448 | " x = layers.LeakyReLU()(x)\n", 449 | "\n", 450 | " x = residual_block(x, downsample=False, filters=32)\n", 451 | "\n", 452 | " x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n", 453 | "\n", 454 | " return tf.keras.Model(inputs=inputs, outputs=x)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": { 461 | "colab": { 462 | "base_uri": "https://localhost:8080/" 463 | }, 464 | "id": "dIbRPFzjmV85", 465 | "outputId": "5216d85f-f401-4657-d41e-233f9be51233" 466 | }, 467 | "outputs": [], 468 | "source": [ 469 | "generator = Generator()\n", 470 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n", 471 | "generator.summary()" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": null, 477 | "metadata": { 478 | "id": "o58eGY46eiPQ" 479 | }, 480 | "outputs": [], 481 | "source": [ 482 | "class DisplayCallback(tf.keras.callbacks.Callback):\n", 483 | " def on_epoch_end(self, epoch, 
logs=None):\n", 484 | " clear_output(wait=True)\n", 485 | " show_predictions()\n", 486 | " print ('\\nSample Prediction after epoch {}\\n'.format(epoch+1))" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "metadata": { 493 | "id": "5nfPDmCNemKf" 494 | }, 495 | "outputs": [], 496 | "source": [ 497 | "generator.compile(optimizer='adam',\n", 498 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", 499 | " metrics=['accuracy'])" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "colab": { 507 | "base_uri": "https://localhost:8080/" 508 | }, 509 | "id": "LyA03ie2dUAS", 510 | "outputId": "0639a42b-4a8a-4603-9998-7c1409f1c71c" 511 | }, 512 | "outputs": [], 513 | "source": [ 514 | "EPOCHS = 200\n", 515 | "VAL_SUBSPLITS = 5\n", 516 | "VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n", 517 | "\n", 518 | "model_history = generator.fit(train_dataset, epochs=EPOCHS,\n", 519 | " steps_per_epoch=STEPS_PER_EPOCH,\n", 520 | " validation_steps=VALIDATION_STEPS,\n", 521 | " validation_data=test_dataset)" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": { 528 | "colab": { 529 | "base_uri": "https://localhost:8080/", 530 | "height": 293 531 | }, 532 | "id": "U1N1_obwtdQH", 533 | "outputId": "20004ed9-8789-4c37-962c-629d0bfd9946" 534 | }, 535 | "outputs": [], 536 | "source": [ 537 | "show_predictions(train_dataset)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": null, 543 | "metadata": { 544 | "id": "NiTrkKItvZHE" 545 | }, 546 | "outputs": [], 547 | "source": [ 548 | "generator.save_weights('seg-gen-weights.h5')" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "metadata": {}, 555 | "outputs": [], 556 | "source": [ 557 | "generator.load_weights('weights/seg-gen-weights (5).h5')" 558 | ] 559 | }, 560 | { 561 | 
"cell_type": "code", 562 | "execution_count": null, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "for inp, tar in train_dataset:\n", 567 | " break" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": null, 573 | "metadata": {}, 574 | "outputs": [], 575 | "source": [ 576 | "plt.imshow(generator(inp)[0])\n", 577 | "import numpy as np\n", 578 | "plt.imsave('pred1.png', np.array(create_mask(generator(inp))).astype(np.float32).reshape(128, 128))" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "plt.imshow(inp[0])\n", 588 | "import numpy as np\n", 589 | "plt.imsave('tar1.png', np.array(inp[0]).astype(np.float32).reshape(128, 128, 3))" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [] 598 | } 599 | ], 600 | "metadata": { 601 | "accelerator": "GPU", 602 | "colab": { 603 | "collapsed_sections": [], 604 | "name": "image2image_seg.ipynb", 605 | "provenance": [], 606 | "toc_visible": true 607 | }, 608 | "kernelspec": { 609 | "display_name": "Python 3", 610 | "language": "python", 611 | "name": "python3" 612 | }, 613 | "language_info": { 614 | "codemirror_mode": { 615 | "name": "ipython", 616 | "version": 3 617 | }, 618 | "file_extension": ".py", 619 | "mimetype": "text/x-python", 620 | "name": "python", 621 | "nbconvert_exporter": "python", 622 | "pygments_lexer": "ipython3", 623 | "version": "3.8.10" 624 | }, 625 | "widgets": { 626 | "application/vnd.jupyter.widget-state+json": { 627 | "0e398791e69f4c3c9f3b8f8928ffac94": { 628 | "model_module": "@jupyter-widgets/controls", 629 | "model_name": "FloatProgressModel", 630 | "state": { 631 | "_dom_classes": [], 632 | "_model_module": "@jupyter-widgets/controls", 633 | "_model_module_version": "1.5.0", 634 | "_model_name": "FloatProgressModel", 635 | "_view_count": null, 636 | "_view_module": 
"@jupyter-widgets/controls", 637 | "_view_module_version": "1.5.0", 638 | "_view_name": "ProgressView", 639 | "bar_style": "success", 640 | "description": "Extraction completed...: 100%", 641 | "description_tooltip": null, 642 | "layout": "IPY_MODEL_4e313c0fc9a34434ad8828f9f7d51245", 643 | "max": 1, 644 | "min": 0, 645 | "orientation": "horizontal", 646 | "style": "IPY_MODEL_7e80e3f85f4c4937b4e433f5d0cf8651", 647 | "value": 1 648 | } 649 | }, 650 | "112fde1fafb042c7abe870a471f69cb4": { 651 | "model_module": "@jupyter-widgets/controls", 652 | "model_name": "DescriptionStyleModel", 653 | "state": { 654 | "_model_module": "@jupyter-widgets/controls", 655 | "_model_module_version": "1.5.0", 656 | "_model_name": "DescriptionStyleModel", 657 | "_view_count": null, 658 | "_view_module": "@jupyter-widgets/base", 659 | "_view_module_version": "1.2.0", 660 | "_view_name": "StyleView", 661 | "description_width": "" 662 | } 663 | }, 664 | "12aae5a01c5a462e999735d848d3e354": { 665 | "model_module": "@jupyter-widgets/base", 666 | "model_name": "LayoutModel", 667 | "state": { 668 | "_model_module": "@jupyter-widgets/base", 669 | "_model_module_version": "1.2.0", 670 | "_model_name": "LayoutModel", 671 | "_view_count": null, 672 | "_view_module": "@jupyter-widgets/base", 673 | "_view_module_version": "1.2.0", 674 | "_view_name": "LayoutView", 675 | "align_content": null, 676 | "align_items": null, 677 | "align_self": null, 678 | "border": null, 679 | "bottom": null, 680 | "display": null, 681 | "flex": null, 682 | "flex_flow": null, 683 | "grid_area": null, 684 | "grid_auto_columns": null, 685 | "grid_auto_flow": null, 686 | "grid_auto_rows": null, 687 | "grid_column": null, 688 | "grid_gap": null, 689 | "grid_row": null, 690 | "grid_template_areas": null, 691 | "grid_template_columns": null, 692 | "grid_template_rows": null, 693 | "height": null, 694 | "justify_content": null, 695 | "justify_items": null, 696 | "left": null, 697 | "margin": null, 698 | "max_height": null, 699 | 
"max_width": null, 700 | "min_height": null, 701 | "min_width": null, 702 | "object_fit": null, 703 | "object_position": null, 704 | "order": null, 705 | "overflow": null, 706 | "overflow_x": null, 707 | "overflow_y": null, 708 | "padding": null, 709 | "right": null, 710 | "top": null, 711 | "visibility": null, 712 | "width": null 713 | } 714 | }, 715 | "1612ca06675349d4b76e5378a129eefb": { 716 | "model_module": "@jupyter-widgets/controls", 717 | "model_name": "HBoxModel", 718 | "state": { 719 | "_dom_classes": [], 720 | "_model_module": "@jupyter-widgets/controls", 721 | "_model_module_version": "1.5.0", 722 | "_model_name": "HBoxModel", 723 | "_view_count": null, 724 | "_view_module": "@jupyter-widgets/controls", 725 | "_view_module_version": "1.5.0", 726 | "_view_name": "HBoxView", 727 | "box_style": "", 728 | "children": [ 729 | "IPY_MODEL_8aebd4d81c6a47f78c84602fcef1249c", 730 | "IPY_MODEL_aa67ec0f9e06454b99dea3b324bfaeb2" 731 | ], 732 | "layout": "IPY_MODEL_81e3fa193c364c0f90b1dc7bca808eb9" 733 | } 734 | }, 735 | "1689a79e51304012976295a13bac17cf": { 736 | "model_module": "@jupyter-widgets/base", 737 | "model_name": "LayoutModel", 738 | "state": { 739 | "_model_module": "@jupyter-widgets/base", 740 | "_model_module_version": "1.2.0", 741 | "_model_name": "LayoutModel", 742 | "_view_count": null, 743 | "_view_module": "@jupyter-widgets/base", 744 | "_view_module_version": "1.2.0", 745 | "_view_name": "LayoutView", 746 | "align_content": null, 747 | "align_items": null, 748 | "align_self": null, 749 | "border": null, 750 | "bottom": null, 751 | "display": null, 752 | "flex": null, 753 | "flex_flow": null, 754 | "grid_area": null, 755 | "grid_auto_columns": null, 756 | "grid_auto_flow": null, 757 | "grid_auto_rows": null, 758 | "grid_column": null, 759 | "grid_gap": null, 760 | "grid_row": null, 761 | "grid_template_areas": null, 762 | "grid_template_columns": null, 763 | "grid_template_rows": null, 764 | "height": null, 765 | "justify_content": null, 766 | 
"justify_items": null, 767 | "left": null, 768 | "margin": null, 769 | "max_height": null, 770 | "max_width": null, 771 | "min_height": null, 772 | "min_width": null, 773 | "object_fit": null, 774 | "object_position": null, 775 | "order": null, 776 | "overflow": null, 777 | "overflow_x": null, 778 | "overflow_y": null, 779 | "padding": null, 780 | "right": null, 781 | "top": null, 782 | "visibility": null, 783 | "width": null 784 | } 785 | }, 786 | "1c7b1be085354daa90f7491366bf3f26": { 787 | "model_module": "@jupyter-widgets/controls", 788 | "model_name": "HTMLModel", 789 | "state": { 790 | "_dom_classes": [], 791 | "_model_module": "@jupyter-widgets/controls", 792 | "_model_module_version": "1.5.0", 793 | "_model_name": "HTMLModel", 794 | "_view_count": null, 795 | "_view_module": "@jupyter-widgets/controls", 796 | "_view_module_version": "1.5.0", 797 | "_view_name": "HTMLView", 798 | "description": "", 799 | "description_tooltip": null, 800 | "layout": "IPY_MODEL_2d886f8a4991438a97adb7f228c0b247", 801 | "placeholder": "​", 802 | "style": "IPY_MODEL_112fde1fafb042c7abe870a471f69cb4", 803 | "value": " 2/2 [00:37<00:00, 18.66s/ url]" 804 | } 805 | }, 806 | "2127f1922264401dbca0ac6365d283af": { 807 | "model_module": "@jupyter-widgets/controls", 808 | "model_name": "HBoxModel", 809 | "state": { 810 | "_dom_classes": [], 811 | "_model_module": "@jupyter-widgets/controls", 812 | "_model_module_version": "1.5.0", 813 | "_model_name": "HBoxModel", 814 | "_view_count": null, 815 | "_view_module": "@jupyter-widgets/controls", 816 | "_view_module_version": "1.5.0", 817 | "_view_name": "HBoxView", 818 | "box_style": "", 819 | "children": [ 820 | "IPY_MODEL_642dffa8e9374234bbce51d29e75c7ea", 821 | "IPY_MODEL_4f1f9e0926cd45e4aa4e813b32b1d7db" 822 | ], 823 | "layout": "IPY_MODEL_5c6312228c08467186d56dba16c4028c" 824 | } 825 | }, 826 | "231b5248ecd746d0939739f6a42db75e": { 827 | "model_module": "@jupyter-widgets/controls", 828 | "model_name": "ProgressStyleModel", 829 | "state": 
{ 830 | "_model_module": "@jupyter-widgets/controls", 831 | "_model_module_version": "1.5.0", 832 | "_model_name": "ProgressStyleModel", 833 | "_view_count": null, 834 | "_view_module": "@jupyter-widgets/base", 835 | "_view_module_version": "1.2.0", 836 | "_view_name": "StyleView", 837 | "bar_color": null, 838 | "description_width": "initial" 839 | } 840 | }, 841 | "2638e1b29d744f6b8a147f8bfa6f89d4": { 842 | "model_module": "@jupyter-widgets/base", 843 | "model_name": "LayoutModel", 844 | "state": { 845 | "_model_module": "@jupyter-widgets/base", 846 | "_model_module_version": "1.2.0", 847 | "_model_name": "LayoutModel", 848 | "_view_count": null, 849 | "_view_module": "@jupyter-widgets/base", 850 | "_view_module_version": "1.2.0", 851 | "_view_name": "LayoutView", 852 | "align_content": null, 853 | "align_items": null, 854 | "align_self": null, 855 | "border": null, 856 | "bottom": null, 857 | "display": null, 858 | "flex": null, 859 | "flex_flow": null, 860 | "grid_area": null, 861 | "grid_auto_columns": null, 862 | "grid_auto_flow": null, 863 | "grid_auto_rows": null, 864 | "grid_column": null, 865 | "grid_gap": null, 866 | "grid_row": null, 867 | "grid_template_areas": null, 868 | "grid_template_columns": null, 869 | "grid_template_rows": null, 870 | "height": null, 871 | "justify_content": null, 872 | "justify_items": null, 873 | "left": null, 874 | "margin": null, 875 | "max_height": null, 876 | "max_width": null, 877 | "min_height": null, 878 | "min_width": null, 879 | "object_fit": null, 880 | "object_position": null, 881 | "order": null, 882 | "overflow": null, 883 | "overflow_x": null, 884 | "overflow_y": null, 885 | "padding": null, 886 | "right": null, 887 | "top": null, 888 | "visibility": null, 889 | "width": null 890 | } 891 | }, 892 | "2c75ae91013c455e8cedafddcd12052f": { 893 | "model_module": "@jupyter-widgets/controls", 894 | "model_name": "HBoxModel", 895 | "state": { 896 | "_dom_classes": [], 897 | "_model_module": "@jupyter-widgets/controls", 
898 | "_model_module_version": "1.5.0", 899 | "_model_name": "HBoxModel", 900 | "_view_count": null, 901 | "_view_module": "@jupyter-widgets/controls", 902 | "_view_module_version": "1.5.0", 903 | "_view_name": "HBoxView", 904 | "box_style": "", 905 | "children": [ 906 | "IPY_MODEL_c11d62582c59442e9cf388a6fbacaad7", 907 | "IPY_MODEL_fc6c6d3b63634d1f9f58cebbd50f4793" 908 | ], 909 | "layout": "IPY_MODEL_a6aa784cb2a14bc29ed56a7e29857061" 910 | } 911 | }, 912 | "2d19d5c44a0348768a9ff242aa199119": { 913 | "model_module": "@jupyter-widgets/base", 914 | "model_name": "LayoutModel", 915 | "state": { 916 | "_model_module": "@jupyter-widgets/base", 917 | "_model_module_version": "1.2.0", 918 | "_model_name": "LayoutModel", 919 | "_view_count": null, 920 | "_view_module": "@jupyter-widgets/base", 921 | "_view_module_version": "1.2.0", 922 | "_view_name": "LayoutView", 923 | "align_content": null, 924 | "align_items": null, 925 | "align_self": null, 926 | "border": null, 927 | "bottom": null, 928 | "display": null, 929 | "flex": null, 930 | "flex_flow": null, 931 | "grid_area": null, 932 | "grid_auto_columns": null, 933 | "grid_auto_flow": null, 934 | "grid_auto_rows": null, 935 | "grid_column": null, 936 | "grid_gap": null, 937 | "grid_row": null, 938 | "grid_template_areas": null, 939 | "grid_template_columns": null, 940 | "grid_template_rows": null, 941 | "height": null, 942 | "justify_content": null, 943 | "justify_items": null, 944 | "left": null, 945 | "margin": null, 946 | "max_height": null, 947 | "max_width": null, 948 | "min_height": null, 949 | "min_width": null, 950 | "object_fit": null, 951 | "object_position": null, 952 | "order": null, 953 | "overflow": null, 954 | "overflow_x": null, 955 | "overflow_y": null, 956 | "padding": null, 957 | "right": null, 958 | "top": null, 959 | "visibility": null, 960 | "width": null 961 | } 962 | }, 963 | "2d886f8a4991438a97adb7f228c0b247": { 964 | "model_module": "@jupyter-widgets/base", 965 | "model_name": "LayoutModel", 966 
| "state": { 967 | "_model_module": "@jupyter-widgets/base", 968 | "_model_module_version": "1.2.0", 969 | "_model_name": "LayoutModel", 970 | "_view_count": null, 971 | "_view_module": "@jupyter-widgets/base", 972 | "_view_module_version": "1.2.0", 973 | "_view_name": "LayoutView", 974 | "align_content": null, 975 | "align_items": null, 976 | "align_self": null, 977 | "border": null, 978 | "bottom": null, 979 | "display": null, 980 | "flex": null, 981 | "flex_flow": null, 982 | "grid_area": null, 983 | "grid_auto_columns": null, 984 | "grid_auto_flow": null, 985 | "grid_auto_rows": null, 986 | "grid_column": null, 987 | "grid_gap": null, 988 | "grid_row": null, 989 | "grid_template_areas": null, 990 | "grid_template_columns": null, 991 | "grid_template_rows": null, 992 | "height": null, 993 | "justify_content": null, 994 | "justify_items": null, 995 | "left": null, 996 | "margin": null, 997 | "max_height": null, 998 | "max_width": null, 999 | "min_height": null, 1000 | "min_width": null, 1001 | "object_fit": null, 1002 | "object_position": null, 1003 | "order": null, 1004 | "overflow": null, 1005 | "overflow_x": null, 1006 | "overflow_y": null, 1007 | "padding": null, 1008 | "right": null, 1009 | "top": null, 1010 | "visibility": null, 1011 | "width": null 1012 | } 1013 | }, 1014 | "3e2b935749cb42aeac6507c0f4697295": { 1015 | "model_module": "@jupyter-widgets/controls", 1016 | "model_name": "DescriptionStyleModel", 1017 | "state": { 1018 | "_model_module": "@jupyter-widgets/controls", 1019 | "_model_module_version": "1.5.0", 1020 | "_model_name": "DescriptionStyleModel", 1021 | "_view_count": null, 1022 | "_view_module": "@jupyter-widgets/base", 1023 | "_view_module_version": "1.2.0", 1024 | "_view_name": "StyleView", 1025 | "description_width": "" 1026 | } 1027 | }, 1028 | "4e313c0fc9a34434ad8828f9f7d51245": { 1029 | "model_module": "@jupyter-widgets/base", 1030 | "model_name": "LayoutModel", 1031 | "state": { 1032 | "_model_module": "@jupyter-widgets/base", 1033 
| "_model_module_version": "1.2.0", 1034 | "_model_name": "LayoutModel", 1035 | "_view_count": null, 1036 | "_view_module": "@jupyter-widgets/base", 1037 | "_view_module_version": "1.2.0", 1038 | "_view_name": "LayoutView", 1039 | "align_content": null, 1040 | "align_items": null, 1041 | "align_self": null, 1042 | "border": null, 1043 | "bottom": null, 1044 | "display": null, 1045 | "flex": null, 1046 | "flex_flow": null, 1047 | "grid_area": null, 1048 | "grid_auto_columns": null, 1049 | "grid_auto_flow": null, 1050 | "grid_auto_rows": null, 1051 | "grid_column": null, 1052 | "grid_gap": null, 1053 | "grid_row": null, 1054 | "grid_template_areas": null, 1055 | "grid_template_columns": null, 1056 | "grid_template_rows": null, 1057 | "height": null, 1058 | "justify_content": null, 1059 | "justify_items": null, 1060 | "left": null, 1061 | "margin": null, 1062 | "max_height": null, 1063 | "max_width": null, 1064 | "min_height": null, 1065 | "min_width": null, 1066 | "object_fit": null, 1067 | "object_position": null, 1068 | "order": null, 1069 | "overflow": null, 1070 | "overflow_x": null, 1071 | "overflow_y": null, 1072 | "padding": null, 1073 | "right": null, 1074 | "top": null, 1075 | "visibility": null, 1076 | "width": null 1077 | } 1078 | }, 1079 | "4f1f9e0926cd45e4aa4e813b32b1d7db": { 1080 | "model_module": "@jupyter-widgets/controls", 1081 | "model_name": "HTMLModel", 1082 | "state": { 1083 | "_dom_classes": [], 1084 | "_model_module": "@jupyter-widgets/controls", 1085 | "_model_module_version": "1.5.0", 1086 | "_model_name": "HTMLModel", 1087 | "_view_count": null, 1088 | "_view_module": "@jupyter-widgets/controls", 1089 | "_view_module_version": "1.5.0", 1090 | "_view_name": "HTMLView", 1091 | "description": "", 1092 | "description_tooltip": null, 1093 | "layout": "IPY_MODEL_605b204e45bb401e97270eba6eceb351", 1094 | "placeholder": "​", 1095 | "style": "IPY_MODEL_5cf59228cd5844a8abbe62cb897ce431", 1096 | "value": " 773/773 [00:37<00:00, 20.74 MiB/s]" 1097 | } 
1098 | }, 1099 | "5695c6a28a5e47008227462f5ade5c9b": { 1100 | "model_module": "@jupyter-widgets/base", 1101 | "model_name": "LayoutModel", 1102 | "state": { 1103 | "_model_module": "@jupyter-widgets/base", 1104 | "_model_module_version": "1.2.0", 1105 | "_model_name": "LayoutModel", 1106 | "_view_count": null, 1107 | "_view_module": "@jupyter-widgets/base", 1108 | "_view_module_version": "1.2.0", 1109 | "_view_name": "LayoutView", 1110 | "align_content": null, 1111 | "align_items": null, 1112 | "align_self": null, 1113 | "border": null, 1114 | "bottom": null, 1115 | "display": null, 1116 | "flex": null, 1117 | "flex_flow": null, 1118 | "grid_area": null, 1119 | "grid_auto_columns": null, 1120 | "grid_auto_flow": null, 1121 | "grid_auto_rows": null, 1122 | "grid_column": null, 1123 | "grid_gap": null, 1124 | "grid_row": null, 1125 | "grid_template_areas": null, 1126 | "grid_template_columns": null, 1127 | "grid_template_rows": null, 1128 | "height": null, 1129 | "justify_content": null, 1130 | "justify_items": null, 1131 | "left": null, 1132 | "margin": null, 1133 | "max_height": null, 1134 | "max_width": null, 1135 | "min_height": null, 1136 | "min_width": null, 1137 | "object_fit": null, 1138 | "object_position": null, 1139 | "order": null, 1140 | "overflow": null, 1141 | "overflow_x": null, 1142 | "overflow_y": null, 1143 | "padding": null, 1144 | "right": null, 1145 | "top": null, 1146 | "visibility": null, 1147 | "width": null 1148 | } 1149 | }, 1150 | "5979ebdad11f4f6f941b32ba4a509416": { 1151 | "model_module": "@jupyter-widgets/controls", 1152 | "model_name": "ProgressStyleModel", 1153 | "state": { 1154 | "_model_module": "@jupyter-widgets/controls", 1155 | "_model_module_version": "1.5.0", 1156 | "_model_name": "ProgressStyleModel", 1157 | "_view_count": null, 1158 | "_view_module": "@jupyter-widgets/base", 1159 | "_view_module_version": "1.2.0", 1160 | "_view_name": "StyleView", 1161 | "bar_color": null, 1162 | "description_width": "initial" 1163 | } 1164 | 
}, 1165 | "5c6312228c08467186d56dba16c4028c": { 1166 | "model_module": "@jupyter-widgets/base", 1167 | "model_name": "LayoutModel", 1168 | "state": { 1169 | "_model_module": "@jupyter-widgets/base", 1170 | "_model_module_version": "1.2.0", 1171 | "_model_name": "LayoutModel", 1172 | "_view_count": null, 1173 | "_view_module": "@jupyter-widgets/base", 1174 | "_view_module_version": "1.2.0", 1175 | "_view_name": "LayoutView", 1176 | "align_content": null, 1177 | "align_items": null, 1178 | "align_self": null, 1179 | "border": null, 1180 | "bottom": null, 1181 | "display": null, 1182 | "flex": null, 1183 | "flex_flow": null, 1184 | "grid_area": null, 1185 | "grid_auto_columns": null, 1186 | "grid_auto_flow": null, 1187 | "grid_auto_rows": null, 1188 | "grid_column": null, 1189 | "grid_gap": null, 1190 | "grid_row": null, 1191 | "grid_template_areas": null, 1192 | "grid_template_columns": null, 1193 | "grid_template_rows": null, 1194 | "height": null, 1195 | "justify_content": null, 1196 | "justify_items": null, 1197 | "left": null, 1198 | "margin": null, 1199 | "max_height": null, 1200 | "max_width": null, 1201 | "min_height": null, 1202 | "min_width": null, 1203 | "object_fit": null, 1204 | "object_position": null, 1205 | "order": null, 1206 | "overflow": null, 1207 | "overflow_x": null, 1208 | "overflow_y": null, 1209 | "padding": null, 1210 | "right": null, 1211 | "top": null, 1212 | "visibility": null, 1213 | "width": null 1214 | } 1215 | }, 1216 | "5cf59228cd5844a8abbe62cb897ce431": { 1217 | "model_module": "@jupyter-widgets/controls", 1218 | "model_name": "DescriptionStyleModel", 1219 | "state": { 1220 | "_model_module": "@jupyter-widgets/controls", 1221 | "_model_module_version": "1.5.0", 1222 | "_model_name": "DescriptionStyleModel", 1223 | "_view_count": null, 1224 | "_view_module": "@jupyter-widgets/base", 1225 | "_view_module_version": "1.2.0", 1226 | "_view_name": "StyleView", 1227 | "description_width": "" 1228 | } 1229 | }, 1230 | 
"600a3e8becc84abeadc682bae8db52d5": { 1231 | "model_module": "@jupyter-widgets/controls", 1232 | "model_name": "HBoxModel", 1233 | "state": { 1234 | "_dom_classes": [], 1235 | "_model_module": "@jupyter-widgets/controls", 1236 | "_model_module_version": "1.5.0", 1237 | "_model_name": "HBoxModel", 1238 | "_view_count": null, 1239 | "_view_module": "@jupyter-widgets/controls", 1240 | "_view_module_version": "1.5.0", 1241 | "_view_name": "HBoxView", 1242 | "box_style": "", 1243 | "children": [ 1244 | "IPY_MODEL_ca89e296533745f7824bd9f91006d162", 1245 | "IPY_MODEL_fad326679299483f9cfcc800bd4aa549" 1246 | ], 1247 | "layout": "IPY_MODEL_97c62988c2e1454db66b33ab880f08d1" 1248 | } 1249 | }, 1250 | "605b204e45bb401e97270eba6eceb351": { 1251 | "model_module": "@jupyter-widgets/base", 1252 | "model_name": "LayoutModel", 1253 | "state": { 1254 | "_model_module": "@jupyter-widgets/base", 1255 | "_model_module_version": "1.2.0", 1256 | "_model_name": "LayoutModel", 1257 | "_view_count": null, 1258 | "_view_module": "@jupyter-widgets/base", 1259 | "_view_module_version": "1.2.0", 1260 | "_view_name": "LayoutView", 1261 | "align_content": null, 1262 | "align_items": null, 1263 | "align_self": null, 1264 | "border": null, 1265 | "bottom": null, 1266 | "display": null, 1267 | "flex": null, 1268 | "flex_flow": null, 1269 | "grid_area": null, 1270 | "grid_auto_columns": null, 1271 | "grid_auto_flow": null, 1272 | "grid_auto_rows": null, 1273 | "grid_column": null, 1274 | "grid_gap": null, 1275 | "grid_row": null, 1276 | "grid_template_areas": null, 1277 | "grid_template_columns": null, 1278 | "grid_template_rows": null, 1279 | "height": null, 1280 | "justify_content": null, 1281 | "justify_items": null, 1282 | "left": null, 1283 | "margin": null, 1284 | "max_height": null, 1285 | "max_width": null, 1286 | "min_height": null, 1287 | "min_width": null, 1288 | "object_fit": null, 1289 | "object_position": null, 1290 | "order": null, 1291 | "overflow": null, 1292 | "overflow_x": null, 
1293 | "overflow_y": null, 1294 | "padding": null, 1295 | "right": null, 1296 | "top": null, 1297 | "visibility": null, 1298 | "width": null 1299 | } 1300 | }, 1301 | "642dffa8e9374234bbce51d29e75c7ea": { 1302 | "model_module": "@jupyter-widgets/controls", 1303 | "model_name": "FloatProgressModel", 1304 | "state": { 1305 | "_dom_classes": [], 1306 | "_model_module": "@jupyter-widgets/controls", 1307 | "_model_module_version": "1.5.0", 1308 | "_model_name": "FloatProgressModel", 1309 | "_view_count": null, 1310 | "_view_module": "@jupyter-widgets/controls", 1311 | "_view_module_version": "1.5.0", 1312 | "_view_name": "ProgressView", 1313 | "bar_style": "success", 1314 | "description": "Dl Size...: 100%", 1315 | "description_tooltip": null, 1316 | "layout": "IPY_MODEL_80d7be7130b8468c8ab3bcb095dc36bf", 1317 | "max": 1, 1318 | "min": 0, 1319 | "orientation": "horizontal", 1320 | "style": "IPY_MODEL_e292c1f1791e4b349a931400f52e980c", 1321 | "value": 1 1322 | } 1323 | }, 1324 | "69b5d3f908c748509302621147b3517b": { 1325 | "model_module": "@jupyter-widgets/base", 1326 | "model_name": "LayoutModel", 1327 | "state": { 1328 | "_model_module": "@jupyter-widgets/base", 1329 | "_model_module_version": "1.2.0", 1330 | "_model_name": "LayoutModel", 1331 | "_view_count": null, 1332 | "_view_module": "@jupyter-widgets/base", 1333 | "_view_module_version": "1.2.0", 1334 | "_view_name": "LayoutView", 1335 | "align_content": null, 1336 | "align_items": null, 1337 | "align_self": null, 1338 | "border": null, 1339 | "bottom": null, 1340 | "display": null, 1341 | "flex": null, 1342 | "flex_flow": null, 1343 | "grid_area": null, 1344 | "grid_auto_columns": null, 1345 | "grid_auto_flow": null, 1346 | "grid_auto_rows": null, 1347 | "grid_column": null, 1348 | "grid_gap": null, 1349 | "grid_row": null, 1350 | "grid_template_areas": null, 1351 | "grid_template_columns": null, 1352 | "grid_template_rows": null, 1353 | "height": null, 1354 | "justify_content": null, 1355 | "justify_items": 
null, 1356 | "left": null, 1357 | "margin": null, 1358 | "max_height": null, 1359 | "max_width": null, 1360 | "min_height": null, 1361 | "min_width": null, 1362 | "object_fit": null, 1363 | "object_position": null, 1364 | "order": null, 1365 | "overflow": null, 1366 | "overflow_x": null, 1367 | "overflow_y": null, 1368 | "padding": null, 1369 | "right": null, 1370 | "top": null, 1371 | "visibility": null, 1372 | "width": null 1373 | } 1374 | }, 1375 | "6abe70e78dce42a091dca068f5af4696": { 1376 | "model_module": "@jupyter-widgets/base", 1377 | "model_name": "LayoutModel", 1378 | "state": { 1379 | "_model_module": "@jupyter-widgets/base", 1380 | "_model_module_version": "1.2.0", 1381 | "_model_name": "LayoutModel", 1382 | "_view_count": null, 1383 | "_view_module": "@jupyter-widgets/base", 1384 | "_view_module_version": "1.2.0", 1385 | "_view_name": "LayoutView", 1386 | "align_content": null, 1387 | "align_items": null, 1388 | "align_self": null, 1389 | "border": null, 1390 | "bottom": null, 1391 | "display": null, 1392 | "flex": null, 1393 | "flex_flow": null, 1394 | "grid_area": null, 1395 | "grid_auto_columns": null, 1396 | "grid_auto_flow": null, 1397 | "grid_auto_rows": null, 1398 | "grid_column": null, 1399 | "grid_gap": null, 1400 | "grid_row": null, 1401 | "grid_template_areas": null, 1402 | "grid_template_columns": null, 1403 | "grid_template_rows": null, 1404 | "height": null, 1405 | "justify_content": null, 1406 | "justify_items": null, 1407 | "left": null, 1408 | "margin": null, 1409 | "max_height": null, 1410 | "max_width": null, 1411 | "min_height": null, 1412 | "min_width": null, 1413 | "object_fit": null, 1414 | "object_position": null, 1415 | "order": null, 1416 | "overflow": null, 1417 | "overflow_x": null, 1418 | "overflow_y": null, 1419 | "padding": null, 1420 | "right": null, 1421 | "top": null, 1422 | "visibility": null, 1423 | "width": null 1424 | } 1425 | }, 1426 | "6ba3d098f19b4919a06ecca5b3763596": { 1427 | "model_module": 
"@jupyter-widgets/base", 1428 | "model_name": "LayoutModel", 1429 | "state": { 1430 | "_model_module": "@jupyter-widgets/base", 1431 | "_model_module_version": "1.2.0", 1432 | "_model_name": "LayoutModel", 1433 | "_view_count": null, 1434 | "_view_module": "@jupyter-widgets/base", 1435 | "_view_module_version": "1.2.0", 1436 | "_view_name": "LayoutView", 1437 | "align_content": null, 1438 | "align_items": null, 1439 | "align_self": null, 1440 | "border": null, 1441 | "bottom": null, 1442 | "display": null, 1443 | "flex": null, 1444 | "flex_flow": null, 1445 | "grid_area": null, 1446 | "grid_auto_columns": null, 1447 | "grid_auto_flow": null, 1448 | "grid_auto_rows": null, 1449 | "grid_column": null, 1450 | "grid_gap": null, 1451 | "grid_row": null, 1452 | "grid_template_areas": null, 1453 | "grid_template_columns": null, 1454 | "grid_template_rows": null, 1455 | "height": null, 1456 | "justify_content": null, 1457 | "justify_items": null, 1458 | "left": null, 1459 | "margin": null, 1460 | "max_height": null, 1461 | "max_width": null, 1462 | "min_height": null, 1463 | "min_width": null, 1464 | "object_fit": null, 1465 | "object_position": null, 1466 | "order": null, 1467 | "overflow": null, 1468 | "overflow_x": null, 1469 | "overflow_y": null, 1470 | "padding": null, 1471 | "right": null, 1472 | "top": null, 1473 | "visibility": null, 1474 | "width": null 1475 | } 1476 | }, 1477 | "74034286026f457892c06d6025286628": { 1478 | "model_module": "@jupyter-widgets/base", 1479 | "model_name": "LayoutModel", 1480 | "state": { 1481 | "_model_module": "@jupyter-widgets/base", 1482 | "_model_module_version": "1.2.0", 1483 | "_model_name": "LayoutModel", 1484 | "_view_count": null, 1485 | "_view_module": "@jupyter-widgets/base", 1486 | "_view_module_version": "1.2.0", 1487 | "_view_name": "LayoutView", 1488 | "align_content": null, 1489 | "align_items": null, 1490 | "align_self": null, 1491 | "border": null, 1492 | "bottom": null, 1493 | "display": null, 1494 | "flex": null, 
1495 | "flex_flow": null, 1496 | "grid_area": null, 1497 | "grid_auto_columns": null, 1498 | "grid_auto_flow": null, 1499 | "grid_auto_rows": null, 1500 | "grid_column": null, 1501 | "grid_gap": null, 1502 | "grid_row": null, 1503 | "grid_template_areas": null, 1504 | "grid_template_columns": null, 1505 | "grid_template_rows": null, 1506 | "height": null, 1507 | "justify_content": null, 1508 | "justify_items": null, 1509 | "left": null, 1510 | "margin": null, 1511 | "max_height": null, 1512 | "max_width": null, 1513 | "min_height": null, 1514 | "min_width": null, 1515 | "object_fit": null, 1516 | "object_position": null, 1517 | "order": null, 1518 | "overflow": null, 1519 | "overflow_x": null, 1520 | "overflow_y": null, 1521 | "padding": null, 1522 | "right": null, 1523 | "top": null, 1524 | "visibility": null, 1525 | "width": null 1526 | } 1527 | }, 1528 | "7e7891071f7c4f5e87d876da643a3045": { 1529 | "model_module": "@jupyter-widgets/controls", 1530 | "model_name": "DescriptionStyleModel", 1531 | "state": { 1532 | "_model_module": "@jupyter-widgets/controls", 1533 | "_model_module_version": "1.5.0", 1534 | "_model_name": "DescriptionStyleModel", 1535 | "_view_count": null, 1536 | "_view_module": "@jupyter-widgets/base", 1537 | "_view_module_version": "1.2.0", 1538 | "_view_name": "StyleView", 1539 | "description_width": "" 1540 | } 1541 | }, 1542 | "7e80e3f85f4c4937b4e433f5d0cf8651": { 1543 | "model_module": "@jupyter-widgets/controls", 1544 | "model_name": "ProgressStyleModel", 1545 | "state": { 1546 | "_model_module": "@jupyter-widgets/controls", 1547 | "_model_module_version": "1.5.0", 1548 | "_model_name": "ProgressStyleModel", 1549 | "_view_count": null, 1550 | "_view_module": "@jupyter-widgets/base", 1551 | "_view_module_version": "1.2.0", 1552 | "_view_name": "StyleView", 1553 | "bar_color": null, 1554 | "description_width": "initial" 1555 | } 1556 | }, 1557 | "80d7be7130b8468c8ab3bcb095dc36bf": { 1558 | "model_module": "@jupyter-widgets/base", 1559 | 
"model_name": "LayoutModel", 1560 | "state": { 1561 | "_model_module": "@jupyter-widgets/base", 1562 | "_model_module_version": "1.2.0", 1563 | "_model_name": "LayoutModel", 1564 | "_view_count": null, 1565 | "_view_module": "@jupyter-widgets/base", 1566 | "_view_module_version": "1.2.0", 1567 | "_view_name": "LayoutView", 1568 | "align_content": null, 1569 | "align_items": null, 1570 | "align_self": null, 1571 | "border": null, 1572 | "bottom": null, 1573 | "display": null, 1574 | "flex": null, 1575 | "flex_flow": null, 1576 | "grid_area": null, 1577 | "grid_auto_columns": null, 1578 | "grid_auto_flow": null, 1579 | "grid_auto_rows": null, 1580 | "grid_column": null, 1581 | "grid_gap": null, 1582 | "grid_row": null, 1583 | "grid_template_areas": null, 1584 | "grid_template_columns": null, 1585 | "grid_template_rows": null, 1586 | "height": null, 1587 | "justify_content": null, 1588 | "justify_items": null, 1589 | "left": null, 1590 | "margin": null, 1591 | "max_height": null, 1592 | "max_width": null, 1593 | "min_height": null, 1594 | "min_width": null, 1595 | "object_fit": null, 1596 | "object_position": null, 1597 | "order": null, 1598 | "overflow": null, 1599 | "overflow_x": null, 1600 | "overflow_y": null, 1601 | "padding": null, 1602 | "right": null, 1603 | "top": null, 1604 | "visibility": null, 1605 | "width": null 1606 | } 1607 | }, 1608 | "81e3fa193c364c0f90b1dc7bca808eb9": { 1609 | "model_module": "@jupyter-widgets/base", 1610 | "model_name": "LayoutModel", 1611 | "state": { 1612 | "_model_module": "@jupyter-widgets/base", 1613 | "_model_module_version": "1.2.0", 1614 | "_model_name": "LayoutModel", 1615 | "_view_count": null, 1616 | "_view_module": "@jupyter-widgets/base", 1617 | "_view_module_version": "1.2.0", 1618 | "_view_name": "LayoutView", 1619 | "align_content": null, 1620 | "align_items": null, 1621 | "align_self": null, 1622 | "border": null, 1623 | "bottom": null, 1624 | "display": null, 1625 | "flex": null, 1626 | "flex_flow": null, 1627 | 
"grid_area": null, 1628 | "grid_auto_columns": null, 1629 | "grid_auto_flow": null, 1630 | "grid_auto_rows": null, 1631 | "grid_column": null, 1632 | "grid_gap": null, 1633 | "grid_row": null, 1634 | "grid_template_areas": null, 1635 | "grid_template_columns": null, 1636 | "grid_template_rows": null, 1637 | "height": null, 1638 | "justify_content": null, 1639 | "justify_items": null, 1640 | "left": null, 1641 | "margin": null, 1642 | "max_height": null, 1643 | "max_width": null, 1644 | "min_height": null, 1645 | "min_width": null, 1646 | "object_fit": null, 1647 | "object_position": null, 1648 | "order": null, 1649 | "overflow": null, 1650 | "overflow_x": null, 1651 | "overflow_y": null, 1652 | "padding": null, 1653 | "right": null, 1654 | "top": null, 1655 | "visibility": null, 1656 | "width": null 1657 | } 1658 | }, 1659 | "83c9ad4ed8bb4aa9804a23a330da4d8c": { 1660 | "model_module": "@jupyter-widgets/controls", 1661 | "model_name": "HBoxModel", 1662 | "state": { 1663 | "_dom_classes": [], 1664 | "_model_module": "@jupyter-widgets/controls", 1665 | "_model_module_version": "1.5.0", 1666 | "_model_name": "HBoxModel", 1667 | "_view_count": null, 1668 | "_view_module": "@jupyter-widgets/controls", 1669 | "_view_module_version": "1.5.0", 1670 | "_view_name": "HBoxView", 1671 | "box_style": "", 1672 | "children": [ 1673 | "IPY_MODEL_c94bf43a065b44c182d8af0c717e92ca", 1674 | "IPY_MODEL_cd600dd278a04f1fabc50fd0e9639fc4" 1675 | ], 1676 | "layout": "IPY_MODEL_2d19d5c44a0348768a9ff242aa199119" 1677 | } 1678 | }, 1679 | "84132b5022db442183e409973de11d67": { 1680 | "model_module": "@jupyter-widgets/controls", 1681 | "model_name": "ProgressStyleModel", 1682 | "state": { 1683 | "_model_module": "@jupyter-widgets/controls", 1684 | "_model_module_version": "1.5.0", 1685 | "_model_name": "ProgressStyleModel", 1686 | "_view_count": null, 1687 | "_view_module": "@jupyter-widgets/base", 1688 | "_view_module_version": "1.2.0", 1689 | "_view_name": "StyleView", 1690 | "bar_color": 
null, 1691 | "description_width": "initial" 1692 | } 1693 | }, 1694 | "8aebd4d81c6a47f78c84602fcef1249c": { 1695 | "model_module": "@jupyter-widgets/controls", 1696 | "model_name": "FloatProgressModel", 1697 | "state": { 1698 | "_dom_classes": [], 1699 | "_model_module": "@jupyter-widgets/controls", 1700 | "_model_module_version": "1.5.0", 1701 | "_model_name": "FloatProgressModel", 1702 | "_view_count": null, 1703 | "_view_module": "@jupyter-widgets/controls", 1704 | "_view_module_version": "1.5.0", 1705 | "_view_name": "ProgressView", 1706 | "bar_style": "info", 1707 | "description": "", 1708 | "description_tooltip": null, 1709 | "layout": "IPY_MODEL_69b5d3f908c748509302621147b3517b", 1710 | "max": 1, 1711 | "min": 0, 1712 | "orientation": "horizontal", 1713 | "style": "IPY_MODEL_bd9ea3399660446897f60580a39588f2", 1714 | "value": 1 1715 | } 1716 | }, 1717 | "8fd55e55c24e48afb223ff8b7422a546": { 1718 | "model_module": "@jupyter-widgets/base", 1719 | "model_name": "LayoutModel", 1720 | "state": { 1721 | "_model_module": "@jupyter-widgets/base", 1722 | "_model_module_version": "1.2.0", 1723 | "_model_name": "LayoutModel", 1724 | "_view_count": null, 1725 | "_view_module": "@jupyter-widgets/base", 1726 | "_view_module_version": "1.2.0", 1727 | "_view_name": "LayoutView", 1728 | "align_content": null, 1729 | "align_items": null, 1730 | "align_self": null, 1731 | "border": null, 1732 | "bottom": null, 1733 | "display": null, 1734 | "flex": null, 1735 | "flex_flow": null, 1736 | "grid_area": null, 1737 | "grid_auto_columns": null, 1738 | "grid_auto_flow": null, 1739 | "grid_auto_rows": null, 1740 | "grid_column": null, 1741 | "grid_gap": null, 1742 | "grid_row": null, 1743 | "grid_template_areas": null, 1744 | "grid_template_columns": null, 1745 | "grid_template_rows": null, 1746 | "height": null, 1747 | "justify_content": null, 1748 | "justify_items": null, 1749 | "left": null, 1750 | "margin": null, 1751 | "max_height": null, 1752 | "max_width": null, 1753 | 
"min_height": null, 1754 | "min_width": null, 1755 | "object_fit": null, 1756 | "object_position": null, 1757 | "order": null, 1758 | "overflow": null, 1759 | "overflow_x": null, 1760 | "overflow_y": null, 1761 | "padding": null, 1762 | "right": null, 1763 | "top": null, 1764 | "visibility": null, 1765 | "width": null 1766 | } 1767 | }, 1768 | "954933d3187645c6bd191289c122a6b5": { 1769 | "model_module": "@jupyter-widgets/controls", 1770 | "model_name": "FloatProgressModel", 1771 | "state": { 1772 | "_dom_classes": [], 1773 | "_model_module": "@jupyter-widgets/controls", 1774 | "_model_module_version": "1.5.0", 1775 | "_model_name": "FloatProgressModel", 1776 | "_view_count": null, 1777 | "_view_module": "@jupyter-widgets/controls", 1778 | "_view_module_version": "1.5.0", 1779 | "_view_name": "ProgressView", 1780 | "bar_style": "success", 1781 | "description": "Dl Completed...: 100%", 1782 | "description_tooltip": null, 1783 | "layout": "IPY_MODEL_f60bacd17c604608a7fd80a06a305bdf", 1784 | "max": 1, 1785 | "min": 0, 1786 | "orientation": "horizontal", 1787 | "style": "IPY_MODEL_84132b5022db442183e409973de11d67", 1788 | "value": 1 1789 | } 1790 | }, 1791 | "97c62988c2e1454db66b33ab880f08d1": { 1792 | "model_module": "@jupyter-widgets/base", 1793 | "model_name": "LayoutModel", 1794 | "state": { 1795 | "_model_module": "@jupyter-widgets/base", 1796 | "_model_module_version": "1.2.0", 1797 | "_model_name": "LayoutModel", 1798 | "_view_count": null, 1799 | "_view_module": "@jupyter-widgets/base", 1800 | "_view_module_version": "1.2.0", 1801 | "_view_name": "LayoutView", 1802 | "align_content": null, 1803 | "align_items": null, 1804 | "align_self": null, 1805 | "border": null, 1806 | "bottom": null, 1807 | "display": null, 1808 | "flex": null, 1809 | "flex_flow": null, 1810 | "grid_area": null, 1811 | "grid_auto_columns": null, 1812 | "grid_auto_flow": null, 1813 | "grid_auto_rows": null, 1814 | "grid_column": null, 1815 | "grid_gap": null, 1816 | "grid_row": null, 1817 | 
"grid_template_areas": null, 1818 | "grid_template_columns": null, 1819 | "grid_template_rows": null, 1820 | "height": null, 1821 | "justify_content": null, 1822 | "justify_items": null, 1823 | "left": null, 1824 | "margin": null, 1825 | "max_height": null, 1826 | "max_width": null, 1827 | "min_height": null, 1828 | "min_width": null, 1829 | "object_fit": null, 1830 | "object_position": null, 1831 | "order": null, 1832 | "overflow": null, 1833 | "overflow_x": null, 1834 | "overflow_y": null, 1835 | "padding": null, 1836 | "right": null, 1837 | "top": null, 1838 | "visibility": null, 1839 | "width": null 1840 | } 1841 | }, 1842 | "a625ba1ffcc340c5a4f042be0b4877c3": { 1843 | "model_module": "@jupyter-widgets/controls", 1844 | "model_name": "DescriptionStyleModel", 1845 | "state": { 1846 | "_model_module": "@jupyter-widgets/controls", 1847 | "_model_module_version": "1.5.0", 1848 | "_model_name": "DescriptionStyleModel", 1849 | "_view_count": null, 1850 | "_view_module": "@jupyter-widgets/base", 1851 | "_view_module_version": "1.2.0", 1852 | "_view_name": "StyleView", 1853 | "description_width": "" 1854 | } 1855 | }, 1856 | "a6aa784cb2a14bc29ed56a7e29857061": { 1857 | "model_module": "@jupyter-widgets/base", 1858 | "model_name": "LayoutModel", 1859 | "state": { 1860 | "_model_module": "@jupyter-widgets/base", 1861 | "_model_module_version": "1.2.0", 1862 | "_model_name": "LayoutModel", 1863 | "_view_count": null, 1864 | "_view_module": "@jupyter-widgets/base", 1865 | "_view_module_version": "1.2.0", 1866 | "_view_name": "LayoutView", 1867 | "align_content": null, 1868 | "align_items": null, 1869 | "align_self": null, 1870 | "border": null, 1871 | "bottom": null, 1872 | "display": null, 1873 | "flex": null, 1874 | "flex_flow": null, 1875 | "grid_area": null, 1876 | "grid_auto_columns": null, 1877 | "grid_auto_flow": null, 1878 | "grid_auto_rows": null, 1879 | "grid_column": null, 1880 | "grid_gap": null, 1881 | "grid_row": null, 1882 | "grid_template_areas": null, 1883 
| "grid_template_columns": null, 1884 | "grid_template_rows": null, 1885 | "height": null, 1886 | "justify_content": null, 1887 | "justify_items": null, 1888 | "left": null, 1889 | "margin": null, 1890 | "max_height": null, 1891 | "max_width": null, 1892 | "min_height": null, 1893 | "min_width": null, 1894 | "object_fit": null, 1895 | "object_position": null, 1896 | "order": null, 1897 | "overflow": null, 1898 | "overflow_x": null, 1899 | "overflow_y": null, 1900 | "padding": null, 1901 | "right": null, 1902 | "top": null, 1903 | "visibility": null, 1904 | "width": null 1905 | } 1906 | }, 1907 | "a7bdbd085e5c43e3b4572f53175f61b6": { 1908 | "model_module": "@jupyter-widgets/controls", 1909 | "model_name": "DescriptionStyleModel", 1910 | "state": { 1911 | "_model_module": "@jupyter-widgets/controls", 1912 | "_model_module_version": "1.5.0", 1913 | "_model_name": "DescriptionStyleModel", 1914 | "_view_count": null, 1915 | "_view_module": "@jupyter-widgets/base", 1916 | "_view_module_version": "1.2.0", 1917 | "_view_name": "StyleView", 1918 | "description_width": "" 1919 | } 1920 | }, 1921 | "aa67ec0f9e06454b99dea3b324bfaeb2": { 1922 | "model_module": "@jupyter-widgets/controls", 1923 | "model_name": "HTMLModel", 1924 | "state": { 1925 | "_dom_classes": [], 1926 | "_model_module": "@jupyter-widgets/controls", 1927 | "_model_module_version": "1.5.0", 1928 | "_model_name": "HTMLModel", 1929 | "_view_count": null, 1930 | "_view_module": "@jupyter-widgets/controls", 1931 | "_view_module_version": "1.5.0", 1932 | "_view_name": "HTMLView", 1933 | "description": "", 1934 | "description_tooltip": null, 1935 | "layout": "IPY_MODEL_c985393de22c4f97a7219f53642a38e6", 1936 | "placeholder": "​", 1937 | "style": "IPY_MODEL_3e2b935749cb42aeac6507c0f4697295", 1938 | "value": " 3669/0 [00:02<00:00, 1372.29 examples/s]" 1939 | } 1940 | }, 1941 | "bbffba56f53b4b68bdcba5875e7c2f07": { 1942 | "model_module": "@jupyter-widgets/controls", 1943 | "model_name": "HBoxModel", 1944 | "state": { 
1945 | "_dom_classes": [], 1946 | "_model_module": "@jupyter-widgets/controls", 1947 | "_model_module_version": "1.5.0", 1948 | "_model_name": "HBoxModel", 1949 | "_view_count": null, 1950 | "_view_module": "@jupyter-widgets/controls", 1951 | "_view_module_version": "1.5.0", 1952 | "_view_name": "HBoxView", 1953 | "box_style": "", 1954 | "children": [ 1955 | "IPY_MODEL_0e398791e69f4c3c9f3b8f8928ffac94", 1956 | "IPY_MODEL_d8118a57a2814d19b2fde27d2452c84e" 1957 | ], 1958 | "layout": "IPY_MODEL_74034286026f457892c06d6025286628" 1959 | } 1960 | }, 1961 | "bd9ea3399660446897f60580a39588f2": { 1962 | "model_module": "@jupyter-widgets/controls", 1963 | "model_name": "ProgressStyleModel", 1964 | "state": { 1965 | "_model_module": "@jupyter-widgets/controls", 1966 | "_model_module_version": "1.5.0", 1967 | "_model_name": "ProgressStyleModel", 1968 | "_view_count": null, 1969 | "_view_module": "@jupyter-widgets/base", 1970 | "_view_module_version": "1.2.0", 1971 | "_view_name": "StyleView", 1972 | "bar_color": null, 1973 | "description_width": "initial" 1974 | } 1975 | }, 1976 | "bffbcc4727fd4a2fa5694aea1680af65": { 1977 | "model_module": "@jupyter-widgets/base", 1978 | "model_name": "LayoutModel", 1979 | "state": { 1980 | "_model_module": "@jupyter-widgets/base", 1981 | "_model_module_version": "1.2.0", 1982 | "_model_name": "LayoutModel", 1983 | "_view_count": null, 1984 | "_view_module": "@jupyter-widgets/base", 1985 | "_view_module_version": "1.2.0", 1986 | "_view_name": "LayoutView", 1987 | "align_content": null, 1988 | "align_items": null, 1989 | "align_self": null, 1990 | "border": null, 1991 | "bottom": null, 1992 | "display": null, 1993 | "flex": null, 1994 | "flex_flow": null, 1995 | "grid_area": null, 1996 | "grid_auto_columns": null, 1997 | "grid_auto_flow": null, 1998 | "grid_auto_rows": null, 1999 | "grid_column": null, 2000 | "grid_gap": null, 2001 | "grid_row": null, 2002 | "grid_template_areas": null, 2003 | "grid_template_columns": null, 2004 | 
"grid_template_rows": null, 2005 | "height": null, 2006 | "justify_content": null, 2007 | "justify_items": null, 2008 | "left": null, 2009 | "margin": null, 2010 | "max_height": null, 2011 | "max_width": null, 2012 | "min_height": null, 2013 | "min_width": null, 2014 | "object_fit": null, 2015 | "object_position": null, 2016 | "order": null, 2017 | "overflow": null, 2018 | "overflow_x": null, 2019 | "overflow_y": null, 2020 | "padding": null, 2021 | "right": null, 2022 | "top": null, 2023 | "visibility": null, 2024 | "width": null 2025 | } 2026 | }, 2027 | "c11d62582c59442e9cf388a6fbacaad7": { 2028 | "model_module": "@jupyter-widgets/controls", 2029 | "model_name": "FloatProgressModel", 2030 | "state": { 2031 | "_dom_classes": [], 2032 | "_model_module": "@jupyter-widgets/controls", 2033 | "_model_module_version": "1.5.0", 2034 | "_model_name": "FloatProgressModel", 2035 | "_view_count": null, 2036 | "_view_module": "@jupyter-widgets/controls", 2037 | "_view_module_version": "1.5.0", 2038 | "_view_name": "ProgressView", 2039 | "bar_style": "danger", 2040 | "description": " 94%", 2041 | "description_tooltip": null, 2042 | "layout": "IPY_MODEL_2638e1b29d744f6b8a147f8bfa6f89d4", 2043 | "max": 3680, 2044 | "min": 0, 2045 | "orientation": "horizontal", 2046 | "style": "IPY_MODEL_5979ebdad11f4f6f941b32ba4a509416", 2047 | "value": 3461 2048 | } 2049 | }, 2050 | "c94bf43a065b44c182d8af0c717e92ca": { 2051 | "model_module": "@jupyter-widgets/controls", 2052 | "model_name": "FloatProgressModel", 2053 | "state": { 2054 | "_dom_classes": [], 2055 | "_model_module": "@jupyter-widgets/controls", 2056 | "_model_module_version": "1.5.0", 2057 | "_model_name": "FloatProgressModel", 2058 | "_view_count": null, 2059 | "_view_module": "@jupyter-widgets/controls", 2060 | "_view_module_version": "1.5.0", 2061 | "_view_name": "ProgressView", 2062 | "bar_style": "info", 2063 | "description": "", 2064 | "description_tooltip": null, 2065 | "layout": 
"IPY_MODEL_bffbcc4727fd4a2fa5694aea1680af65", 2066 | "max": 1, 2067 | "min": 0, 2068 | "orientation": "horizontal", 2069 | "style": "IPY_MODEL_231b5248ecd746d0939739f6a42db75e", 2070 | "value": 1 2071 | } 2072 | }, 2073 | "c985393de22c4f97a7219f53642a38e6": { 2074 | "model_module": "@jupyter-widgets/base", 2075 | "model_name": "LayoutModel", 2076 | "state": { 2077 | "_model_module": "@jupyter-widgets/base", 2078 | "_model_module_version": "1.2.0", 2079 | "_model_name": "LayoutModel", 2080 | "_view_count": null, 2081 | "_view_module": "@jupyter-widgets/base", 2082 | "_view_module_version": "1.2.0", 2083 | "_view_name": "LayoutView", 2084 | "align_content": null, 2085 | "align_items": null, 2086 | "align_self": null, 2087 | "border": null, 2088 | "bottom": null, 2089 | "display": null, 2090 | "flex": null, 2091 | "flex_flow": null, 2092 | "grid_area": null, 2093 | "grid_auto_columns": null, 2094 | "grid_auto_flow": null, 2095 | "grid_auto_rows": null, 2096 | "grid_column": null, 2097 | "grid_gap": null, 2098 | "grid_row": null, 2099 | "grid_template_areas": null, 2100 | "grid_template_columns": null, 2101 | "grid_template_rows": null, 2102 | "height": null, 2103 | "justify_content": null, 2104 | "justify_items": null, 2105 | "left": null, 2106 | "margin": null, 2107 | "max_height": null, 2108 | "max_width": null, 2109 | "min_height": null, 2110 | "min_width": null, 2111 | "object_fit": null, 2112 | "object_position": null, 2113 | "order": null, 2114 | "overflow": null, 2115 | "overflow_x": null, 2116 | "overflow_y": null, 2117 | "padding": null, 2118 | "right": null, 2119 | "top": null, 2120 | "visibility": null, 2121 | "width": null 2122 | } 2123 | }, 2124 | "ca89e296533745f7824bd9f91006d162": { 2125 | "model_module": "@jupyter-widgets/controls", 2126 | "model_name": "FloatProgressModel", 2127 | "state": { 2128 | "_dom_classes": [], 2129 | "_model_module": "@jupyter-widgets/controls", 2130 | "_model_module_version": "1.5.0", 2131 | "_model_name": 
"FloatProgressModel", 2132 | "_view_count": null, 2133 | "_view_module": "@jupyter-widgets/controls", 2134 | "_view_module_version": "1.5.0", 2135 | "_view_name": "ProgressView", 2136 | "bar_style": "danger", 2137 | "description": " 99%", 2138 | "description_tooltip": null, 2139 | "layout": "IPY_MODEL_6ba3d098f19b4919a06ecca5b3763596", 2140 | "max": 3669, 2141 | "min": 0, 2142 | "orientation": "horizontal", 2143 | "style": "IPY_MODEL_faee640e94db490f8cd61481aa74a92a", 2144 | "value": 3648 2145 | } 2146 | }, 2147 | "cd600dd278a04f1fabc50fd0e9639fc4": { 2148 | "model_module": "@jupyter-widgets/controls", 2149 | "model_name": "HTMLModel", 2150 | "state": { 2151 | "_dom_classes": [], 2152 | "_model_module": "@jupyter-widgets/controls", 2153 | "_model_module_version": "1.5.0", 2154 | "_model_name": "HTMLModel", 2155 | "_view_count": null, 2156 | "_view_module": "@jupyter-widgets/controls", 2157 | "_view_module_version": "1.5.0", 2158 | "_view_name": "HTMLView", 2159 | "description": "", 2160 | "description_tooltip": null, 2161 | "layout": "IPY_MODEL_6abe70e78dce42a091dca068f5af4696", 2162 | "placeholder": "​", 2163 | "style": "IPY_MODEL_a7bdbd085e5c43e3b4572f53175f61b6", 2164 | "value": " 3680/0 [00:02<00:00, 1421.50 examples/s]" 2165 | } 2166 | }, 2167 | "d8118a57a2814d19b2fde27d2452c84e": { 2168 | "model_module": "@jupyter-widgets/controls", 2169 | "model_name": "HTMLModel", 2170 | "state": { 2171 | "_dom_classes": [], 2172 | "_model_module": "@jupyter-widgets/controls", 2173 | "_model_module_version": "1.5.0", 2174 | "_model_name": "HTMLModel", 2175 | "_view_count": null, 2176 | "_view_module": "@jupyter-widgets/controls", 2177 | "_view_module_version": "1.5.0", 2178 | "_view_name": "HTMLView", 2179 | "description": "", 2180 | "description_tooltip": null, 2181 | "layout": "IPY_MODEL_5695c6a28a5e47008227462f5ade5c9b", 2182 | "placeholder": "​", 2183 | "style": "IPY_MODEL_a625ba1ffcc340c5a4f042be0b4877c3", 2184 | "value": " 2/2 [00:37<00:00, 18.60s/ file]" 2185 | } 
2186 | }, 2187 | "e292c1f1791e4b349a931400f52e980c": { 2188 | "model_module": "@jupyter-widgets/controls", 2189 | "model_name": "ProgressStyleModel", 2190 | "state": { 2191 | "_model_module": "@jupyter-widgets/controls", 2192 | "_model_module_version": "1.5.0", 2193 | "_model_name": "ProgressStyleModel", 2194 | "_view_count": null, 2195 | "_view_module": "@jupyter-widgets/base", 2196 | "_view_module_version": "1.2.0", 2197 | "_view_name": "StyleView", 2198 | "bar_color": null, 2199 | "description_width": "initial" 2200 | } 2201 | }, 2202 | "e6d3b16d24cd4468b68af5be44eeaa46": { 2203 | "model_module": "@jupyter-widgets/controls", 2204 | "model_name": "HBoxModel", 2205 | "state": { 2206 | "_dom_classes": [], 2207 | "_model_module": "@jupyter-widgets/controls", 2208 | "_model_module_version": "1.5.0", 2209 | "_model_name": "HBoxModel", 2210 | "_view_count": null, 2211 | "_view_module": "@jupyter-widgets/controls", 2212 | "_view_module_version": "1.5.0", 2213 | "_view_name": "HBoxView", 2214 | "box_style": "", 2215 | "children": [ 2216 | "IPY_MODEL_954933d3187645c6bd191289c122a6b5", 2217 | "IPY_MODEL_1c7b1be085354daa90f7491366bf3f26" 2218 | ], 2219 | "layout": "IPY_MODEL_8fd55e55c24e48afb223ff8b7422a546" 2220 | } 2221 | }, 2222 | "f5b8dec6698746268a69122340b1f989": { 2223 | "model_module": "@jupyter-widgets/controls", 2224 | "model_name": "DescriptionStyleModel", 2225 | "state": { 2226 | "_model_module": "@jupyter-widgets/controls", 2227 | "_model_module_version": "1.5.0", 2228 | "_model_name": "DescriptionStyleModel", 2229 | "_view_count": null, 2230 | "_view_module": "@jupyter-widgets/base", 2231 | "_view_module_version": "1.2.0", 2232 | "_view_name": "StyleView", 2233 | "description_width": "" 2234 | } 2235 | }, 2236 | "f60bacd17c604608a7fd80a06a305bdf": { 2237 | "model_module": "@jupyter-widgets/base", 2238 | "model_name": "LayoutModel", 2239 | "state": { 2240 | "_model_module": "@jupyter-widgets/base", 2241 | "_model_module_version": "1.2.0", 2242 | "_model_name": 
"LayoutModel", 2243 | "_view_count": null, 2244 | "_view_module": "@jupyter-widgets/base", 2245 | "_view_module_version": "1.2.0", 2246 | "_view_name": "LayoutView", 2247 | "align_content": null, 2248 | "align_items": null, 2249 | "align_self": null, 2250 | "border": null, 2251 | "bottom": null, 2252 | "display": null, 2253 | "flex": null, 2254 | "flex_flow": null, 2255 | "grid_area": null, 2256 | "grid_auto_columns": null, 2257 | "grid_auto_flow": null, 2258 | "grid_auto_rows": null, 2259 | "grid_column": null, 2260 | "grid_gap": null, 2261 | "grid_row": null, 2262 | "grid_template_areas": null, 2263 | "grid_template_columns": null, 2264 | "grid_template_rows": null, 2265 | "height": null, 2266 | "justify_content": null, 2267 | "justify_items": null, 2268 | "left": null, 2269 | "margin": null, 2270 | "max_height": null, 2271 | "max_width": null, 2272 | "min_height": null, 2273 | "min_width": null, 2274 | "object_fit": null, 2275 | "object_position": null, 2276 | "order": null, 2277 | "overflow": null, 2278 | "overflow_x": null, 2279 | "overflow_y": null, 2280 | "padding": null, 2281 | "right": null, 2282 | "top": null, 2283 | "visibility": null, 2284 | "width": null 2285 | } 2286 | }, 2287 | "fad326679299483f9cfcc800bd4aa549": { 2288 | "model_module": "@jupyter-widgets/controls", 2289 | "model_name": "HTMLModel", 2290 | "state": { 2291 | "_dom_classes": [], 2292 | "_model_module": "@jupyter-widgets/controls", 2293 | "_model_module_version": "1.5.0", 2294 | "_model_name": "HTMLModel", 2295 | "_view_count": null, 2296 | "_view_module": "@jupyter-widgets/controls", 2297 | "_view_module_version": "1.5.0", 2298 | "_view_name": "HTMLView", 2299 | "description": "", 2300 | "description_tooltip": null, 2301 | "layout": "IPY_MODEL_1689a79e51304012976295a13bac17cf", 2302 | "placeholder": "​", 2303 | "style": "IPY_MODEL_f5b8dec6698746268a69122340b1f989", 2304 | "value": " 3648/3669 [00:01<00:00, 2429.53 examples/s]" 2305 | } 2306 | }, 2307 | 
"faee640e94db490f8cd61481aa74a92a": { 2308 | "model_module": "@jupyter-widgets/controls", 2309 | "model_name": "ProgressStyleModel", 2310 | "state": { 2311 | "_model_module": "@jupyter-widgets/controls", 2312 | "_model_module_version": "1.5.0", 2313 | "_model_name": "ProgressStyleModel", 2314 | "_view_count": null, 2315 | "_view_module": "@jupyter-widgets/base", 2316 | "_view_module_version": "1.2.0", 2317 | "_view_name": "StyleView", 2318 | "bar_color": null, 2319 | "description_width": "initial" 2320 | } 2321 | }, 2322 | "fc6c6d3b63634d1f9f58cebbd50f4793": { 2323 | "model_module": "@jupyter-widgets/controls", 2324 | "model_name": "HTMLModel", 2325 | "state": { 2326 | "_dom_classes": [], 2327 | "_model_module": "@jupyter-widgets/controls", 2328 | "_model_module_version": "1.5.0", 2329 | "_model_name": "HTMLModel", 2330 | "_view_count": null, 2331 | "_view_module": "@jupyter-widgets/controls", 2332 | "_view_module_version": "1.5.0", 2333 | "_view_name": "HTMLView", 2334 | "description": "", 2335 | "description_tooltip": null, 2336 | "layout": "IPY_MODEL_12aae5a01c5a462e999735d848d3e354", 2337 | "placeholder": "​", 2338 | "style": "IPY_MODEL_7e7891071f7c4f5e87d876da643a3045", 2339 | "value": " 3461/3680 [00:00<00:00, 3002.87 examples/s]" 2340 | } 2341 | } 2342 | } 2343 | } 2344 | }, 2345 | "nbformat": 4, 2346 | "nbformat_minor": 1 2347 | } 2348 | -------------------------------------------------------------------------------- /notebooks/semantic-segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "YfIk2es3hJEd" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import tensorflow as tf\n", 12 | "\n", 13 | "import os\n", 14 | "import time\n", 15 | "\n", 16 | "from matplotlib import pyplot as plt\n", 17 | "from IPython import display" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | 
"metadata": { 24 | "id": "2CbTEt448b4R" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "BUFFER_SIZE = 400\n", 29 | "EPOCHS = 100\n", 30 | "LAMBDA = 100\n", 31 | "DATASET = 'cityscapes'\n", 32 | "BATCH_SIZE = 8\n", 33 | "IMG_WIDTH = 256\n", 34 | "IMG_HEIGHT = 256\n", 35 | "patch_size = 8\n", 36 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n", 37 | "projection_dim = 64\n", 38 | "embed_dim = 64\n", 39 | "num_heads = 2\n", 40 | "ff_dim = 32\n", 41 | "\n", 42 | "assert IMG_WIDTH == IMG_HEIGHT, \"image width and image height must have same dims\"\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "colab": { 50 | "base_uri": "https://localhost:8080/" 51 | }, 52 | "id": "Kn-k8kTXuAlv", 53 | "outputId": "6322b63c-547d-4ae7-d1aa-5c5098e5fe3d" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "_URL = f'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/{DATASET}.tar.gz'\n", 58 | "\n", 59 | "path_to_zip = tf.keras.utils.get_file(f'{DATASET}.tar.gz',\n", 60 | "                                      origin=_URL,\n", 61 | "                                      extract=True)\n", 62 | "\n", 63 | "PATH = os.path.join(os.path.dirname(path_to_zip), f'{DATASET}/')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "id": "aO9ZAGH5K3SY" 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "def load(image_file):\n", 75 | "    image = tf.io.read_file(image_file)\n", 76 | "    image = tf.image.decode_jpeg(image)\n", 77 | "    # each file holds two images side by side; split it at the middle column\n", 78 | "    w = tf.shape(image)[1]\n", 79 | "\n", 80 | "    w = w // 2\n", 81 | "    real_image = image[:, :w, :]\n", 82 | "    input_image = image[:, w:, :]\n", 83 | "\n", 84 | "    input_image = tf.cast(input_image, tf.float32)\n", 85 | "    real_image = tf.cast(real_image, tf.float32)\n", 86 | "\n", 87 | "    return input_image, real_image" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "colab": { 95 | "base_uri": "https://localhost:8080/", 96 | "height": 538 97 | }, 98 | "id":
"4OLHMpsQ5aOv", 99 | "outputId": "1242d6f1-c340-47bc-a716-e97a5e82acfd" 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "inp, re = load(PATH+'train/100.jpg')\n", 104 | "# scaling the float pixel values to [0, 1] so matplotlib can show the image\n", 105 | "plt.figure()\n", 106 | "plt.imshow(inp/255.0)\n", 107 | "plt.figure()\n", 108 | "plt.imshow(re/255.0)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "id": "rwwYQpu9FzDu" 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "def resize(input_image, real_image, height, width):\n", 120 | "    input_image = tf.image.resize(input_image, [height, width],\n", 121 | "                                  method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", 122 | "    real_image = tf.image.resize(real_image, [height, width],\n", 123 | "                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n", 124 | "\n", 125 | "    return input_image, real_image" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "id": "Yn3IwqhiIszt" 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "def random_crop(input_image, real_image):\n", 137 | "    stacked_image = tf.stack([input_image, real_image], axis=0)\n", 138 | "    cropped_image = tf.image.random_crop(\n", 139 | "        stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", 140 | "\n", 141 | "    return cropped_image[0], cropped_image[1]" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "id": "muhR2cgbLKWW" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "# normalizing the images to [-1, 1]\n", 153 | "\n", 154 | "def normalize(input_image, real_image):\n", 155 | "    input_image = (input_image / 127.5) - 1\n", 156 | "    real_image = (real_image / 127.5) - 1\n", 157 | "    # note: returned in swapped order, so load()'s real_image (left half) becomes the model input and input_image (right half) the target\n", 158 | "    return real_image, input_image" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "id": "fVQOjcPVLrUc" 166 | }, 167 | "outputs": [], 168 | "source": [ 169 |
"@tf.function()\n", 170 | "def random_jitter(input_image, real_image):\n", 171 | "    # resizing to 286 x 286 x 3\n", 172 | "    input_image, real_image = resize(input_image, real_image, 286, 286)\n", 173 | "\n", 174 | "    # randomly cropping to 256 x 256 x 3\n", 175 | "    input_image, real_image = random_crop(input_image, real_image)\n", 176 | "\n", 177 | "    if tf.random.uniform(()) > 0.5:\n", 178 | "        # random mirroring\n", 179 | "        input_image = tf.image.flip_left_right(input_image)\n", 180 | "        real_image = tf.image.flip_left_right(real_image)\n", 181 | "\n", 182 | "    return input_image, real_image" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "colab": { 190 | "base_uri": "https://localhost:8080/", 191 | "height": 357 192 | }, 193 | "id": "n0OGdi6D92kM", 194 | "outputId": "aa3371d3-f764-4e11-affd-3b6640646491" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "plt.figure(figsize=(6, 6))\n", 199 | "for i in range(4):\n", 200 | "    rj_inp, rj_re = random_jitter(inp, re)\n", 201 | "    plt.subplot(2, 2, i+1)\n", 202 | "    plt.imshow(rj_inp/255.0)\n", 203 | "    plt.axis('off')\n", 204 | "plt.show()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "id": "tyaP4hLJ8b4W" 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "def load_image_train(image_file):\n", 216 | "    input_image, real_image = load(image_file)\n", 217 | "    input_image, real_image = random_jitter(input_image, real_image)\n", 218 | "    input_image, real_image = normalize(input_image, real_image)\n", 219 | "\n", 220 | "    return input_image, real_image" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": { 227 | "id": "VB3Z6D_zKSru" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "def load_image_test(image_file):\n", 232 | "    input_image, real_image = load(image_file)\n", 233 | "    input_image, real_image = resize(input_image, real_image,\n", 234 | "                                     IMG_HEIGHT, IMG_WIDTH)\n",
235 | "    input_image, real_image = normalize(input_image, real_image)\n", 236 | "\n", 237 | "    return input_image, real_image" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": { 244 | "id": "SQHmYSmk8b4b" 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "tf.config.run_functions_eagerly(False)\n", 249 | "\n", 250 | "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n", 251 | "train_dataset = train_dataset.map(load_image_train,\n", 252 | "                                  num_parallel_calls=tf.data.AUTOTUNE)\n", 253 | "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", 254 | "train_dataset = train_dataset.batch(BATCH_SIZE)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": { 261 | "id": "MS9J0yA58b4g" 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "try:\n", 266 | "    test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n", 267 | "    test_dataset = test_dataset.map(load_image_test)\n", 268 | "    test_dataset = test_dataset.batch(BATCH_SIZE)\n", 269 | "except tf.errors.InvalidArgumentError:\n", 270 | "    test_dataset = train_dataset  # no files matched test/*.jpg: fall back to the training split" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": { 277 | "id": "AWSBM-ckAZZL" 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "class Patches(tf.keras.layers.Layer):\n", 282 | "    def __init__(self, patch_size):\n", 283 | "        super(Patches, self).__init__()\n", 284 | "        self.patch_size = patch_size\n", 285 | "\n", 286 | "    def call(self, images):\n", 287 | "        batch_size = tf.shape(images)[0]\n", 288 | "        patches = tf.image.extract_patches(\n", 289 | "            images=images,\n", 290 | "            sizes=[1, self.patch_size, self.patch_size, 1],\n", 291 | "            strides=[1, self.patch_size, self.patch_size, 1],\n", 292 | "            rates=[1, 1, 1, 1],\n", 293 | "            padding=\"SAME\",\n", 294 | "        )\n", 295 | "        patch_dims = patches.shape[-1]\n", 296 | "        patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n", 297 | "        return patches"
298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": { 304 | "id": "mXT2GyxTAZWq" 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "class PatchEncoder(tf.keras.layers.Layer):\n", 309 | "    def __init__(self, num_patches, projection_dim):\n", 310 | "        super(PatchEncoder, self).__init__()\n", 311 | "        self.num_patches = num_patches\n", 312 | "        self.projection = layers.Dense(units=projection_dim)\n", 313 | "        self.position_embedding = layers.Embedding(\n", 314 | "            input_dim=num_patches, output_dim=projection_dim\n", 315 | "        )\n", 316 | "\n", 317 | "    def call(self, patch):\n", 318 | "        positions = tf.range(start=0, limit=self.num_patches, delta=1)\n", 319 | "        encoded = self.projection(patch) + self.position_embedding(positions)\n", 320 | "        return encoded" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": { 327 | "id": "EsRN0b3qAdWz" 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "class TransformerBlock(tf.keras.layers.Layer):\n", 332 | "    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n", 333 | "        super(TransformerBlock, self).__init__()\n", 334 | "        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n", 335 | "        self.ffn = tf.keras.Sequential(\n", 336 | "            [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n", 337 | "        )\n", 338 | "        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n", 339 | "        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n", 340 | "        self.dropout1 = layers.Dropout(rate)\n", 341 | "        self.dropout2 = layers.Dropout(rate)\n", 342 | "\n", 343 | "    def call(self, inputs, training):\n", 344 | "        attn_output = self.att(inputs, inputs)\n", 345 | "        attn_output = self.dropout1(attn_output, training=training)\n", 346 | "        out1 = self.layernorm1(inputs + attn_output)\n", 347 | "        ffn_output = self.ffn(out1)\n", 348 | "        ffn_output = self.dropout2(ffn_output, training=training)\n", 349 | "        return self.layernorm2(out1 + ffn_output)"
350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "metadata": { 356 | "id": "BzdEEA95TzBE" 357 | }, 358 | "outputs": [], 359 | "source": [ 360 | "from tensorflow import Tensor\n", 361 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n", 362 | "    Add, AveragePooling2D, Flatten, Dense\n", 363 | "from tensorflow.keras.models import Model\n", 364 | "\n", 365 | "def relu_bn(inputs: Tensor) -> Tensor:\n", 366 | "    relu = ReLU()(inputs)\n", 367 | "    bn = BatchNormalization()(relu)\n", 368 | "    return bn\n", 369 | "\n", 370 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n", 371 | "    y = Conv2D(kernel_size=kernel_size,\n", 372 | "               strides=(1 if not downsample else 2),\n", 373 | "               filters=filters,\n", 374 | "               padding=\"same\")(x)\n", 375 | "    y = relu_bn(y)\n", 376 | "    y = Conv2D(kernel_size=kernel_size,\n", 377 | "               strides=1,\n", 378 | "               filters=filters,\n", 379 | "               padding=\"same\")(y)\n", 380 | "\n", 381 | "    if downsample:\n", 382 | "        x = Conv2D(kernel_size=1,\n", 383 | "                   strides=2,\n", 384 | "                   filters=filters,\n", 385 | "                   padding=\"same\")(x)\n", 386 | "    out = Add()([x, y])\n", 387 | "    out = relu_bn(out)\n", 388 | "    return out" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "id": "lFPI4Nu-8b4q" 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "from tensorflow.keras import layers\n", 400 | "\n", 401 | "def Generator():\n", 402 | "\n", 403 | "    inputs = layers.Input(shape=(256, 256, 3))\n", 404 | "\n", 405 | "    patches = Patches(patch_size)(inputs)\n", 406 | "    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n", 407 | "\n", 408 | "    x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n", 409 | "    x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 410 | "    x = TransformerBlock(64, num_heads, ff_dim)(x)\n", 411 | "    x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
412 | "    # 1024 patch embeddings x 64 dims = 65536 values, folded into an 8 x 8 x 1024 feature map\n", 413 | "    x = layers.Reshape((8, 8, 1024))(x)\n", 414 | "\n", 415 | "    x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 416 | "    x = layers.BatchNormalization()(x)\n", 417 | "    x = layers.LeakyReLU()(x)\n", 418 | "\n", 419 | "    x = residual_block(x, downsample=False, filters=512)\n", 420 | "\n", 421 | "    x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 422 | "    x = layers.BatchNormalization()(x)\n", 423 | "    x = layers.LeakyReLU()(x)\n", 424 | "\n", 425 | "    x = residual_block(x, downsample=False, filters=256)\n", 426 | "\n", 427 | "    x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n", 428 | "    x = layers.BatchNormalization()(x)\n", 429 | "    x = layers.LeakyReLU()(x)\n", 430 | "\n", 431 | "    x = residual_block(x, downsample=False, filters=64)\n", 432 | "    # spatial size so far: 8 -> 16 -> 32 -> 64; this stride-4 layer brings it to 256 x 256\n", 433 | "    x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)\n", 434 | "    x = layers.BatchNormalization()(x)\n", 435 | "    x = layers.LeakyReLU()(x)\n", 436 | "\n", 437 | "    x = residual_block(x, downsample=False, filters=32)\n", 438 | "\n", 439 | "    x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n", 440 | "\n", 441 | "    return tf.keras.Model(inputs=inputs, outputs=x)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": { 448 | "colab": { 449 | "base_uri": "https://localhost:8080/" 450 | }, 451 | "id": "dIbRPFzjmV85", 452 | "outputId": "33b0d3d6-6588-4e3f-aee3-18a9aa09e150" 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "generator = Generator()\n", 457 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n", 458 | "generator.summary()" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "colab": { 466 | "base_uri":
"https://localhost:8080/", 467 | "height": 303 468 | }, 469 | "id": "U1N1_obwtdQH", 470 | "outputId": "abf76049-d489-4635-8f9b-512e3935387c" 471 | }, 472 | "outputs": [], 473 | "source": [ 474 | "gen_output = generator(inp[tf.newaxis, ...], training=False)\n", 475 | "plt.imshow(gen_output[0, ...])" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": { 482 | "id": "lbHFNexF0x6O" 483 | }, 484 | "outputs": [], 485 | "source": [ 486 | "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "metadata": { 493 | "id": "RmdVsmvhPxyy" 494 | }, 495 | "outputs": [], 496 | "source": [ 497 | "def generate_images(model, test_input, tar):\n", 498 | " prediction = model(test_input, training=True)\n", 499 | " plt.figure(figsize=(15, 15))\n", 500 | "\n", 501 | " display_list = [test_input[0], tar[0], prediction[0]]\n", 502 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", 503 | "\n", 504 | " for i in range(3):\n", 505 | " plt.subplot(1, 3, i+1)\n", 506 | " plt.title(title[i])\n", 507 | " # getting the pixel values between [0, 1] to plot it.\n", 508 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n", 509 | " plt.axis('off')\n", 510 | " plt.show()\n", 511 | "\n", 512 | "def generate_batch_images(model, test_input, tar):\n", 513 | " for i in range(len(test_input)):\n", 514 | " prediction = model(test_input, training=True)\n", 515 | " plt.figure(figsize=(15, 15))\n", 516 | "\n", 517 | " display_list = [test_input[i], tar[i], prediction[i]]\n", 518 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", 519 | " \n", 520 | " for i in range(3):\n", 521 | " plt.subplot(1, 3, i+1)\n", 522 | " plt.title(title[i])\n", 523 | " # getting the pixel values between [0, 1] to plot it.\n", 524 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n", 525 | " plt.axis('off')\n", 526 | " plt.show()" 527 | ] 528 | }, 529 | { 530 | "cell_type": 
"code", 531 | "execution_count": null, 532 | "metadata": { 533 | "colab": { 534 | "base_uri": "https://localhost:8080/", 535 | "height": 293 536 | }, 537 | "id": "8Fc4NzT-DgEx", 538 | "outputId": "6b5e738a-5851-4c3c-89e9-2defb4b32b88" 539 | }, 540 | "outputs": [], 541 | "source": [ 542 | "for example_input, example_target in test_dataset.take(1):\n", 543 | " generate_images(generator, example_input, example_target)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": { 550 | "id": "KBKUV2sKXDbY" 551 | }, 552 | "outputs": [], 553 | "source": [ 554 | "@tf.function\n", 555 | "def train_step(input_image, target, epoch):\n", 556 | " with tf.device('/device:GPU:0'):\n", 557 | " with tf.GradientTape() as gen_tape:\n", 558 | " gen_output = generator(input_image, training=True)\n", 559 | "\n", 560 | " gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", 561 | " \n", 562 | " generator_gradients = gen_tape.gradient(gen_total_loss,\n", 563 | " generator.trainable_variables)\n", 564 | "\n", 565 | " generator_optimizer.apply_gradients(zip(generator_gradients,\n", 566 | " generator.trainable_variables))" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": { 573 | "id": "2M7LmLtGEMQJ" 574 | }, 575 | "outputs": [], 576 | "source": [ 577 | "def fit(train_ds, epochs, test_ds):\n", 578 | " for epoch in range(epochs):\n", 579 | " start = time.time()\n", 580 | "\n", 581 | " display.clear_output(wait=True)\n", 582 | "\n", 583 | " for example_input, example_target in test_ds.take(1):\n", 584 | " generate_images(generator, example_input, example_target)\n", 585 | " print(\"Epoch: \", epoch)\n", 586 | "\n", 587 | " # Train\n", 588 | " for n, (input_image, target) in train_ds.enumerate():\n", 589 | " print('.', end='')\n", 590 | " if (n+1) % 100 == 0:\n", 591 | " print()\n", 592 | " train_step(input_image, target, epoch)\n", 593 | " print()\n", 594 |
"\n", 595 | " generator.save_weights(f'_{DATASET}-gen-weights.h5')\n", 596 | " # no discriminator is defined in this notebook, so only the generator weights are saved" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": { 603 | "colab": { 604 | "base_uri": "https://localhost:8080/", 605 | "height": 293 606 | }, 607 | "id": "a1zZmKmvOH85", 608 | "outputId": "e90cbd9a-0860-4260-f928-c4609ace3d07", 609 | "scrolled": true 610 | }, 611 | "outputs": [], 612 | "source": [ 613 | "fit(train_dataset, 100000, test_dataset)" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": null, 619 | "metadata": { 620 | "colab": { 621 | "base_uri": "https://localhost:8080/", 622 | "height": 1000 623 | }, 624 | "id": "KUgSnmy2nqSP", 625 | "outputId": "65667797-7b67-4b07-e9ab-94c5ac5173d0" 626 | }, 627 | "outputs": [], 628 | "source": [ 629 | "for inp, tar in test_dataset.take(1):\n", 630 | " outs = generator(inp)\n", 631 | " generate_batch_images(generator, inp, tar)" 632 | ] 633 | } 634 | ], 635 | "metadata": { 636 | "accelerator": "GPU", 637 | "colab": { 638 | "collapsed_sections": [], 639 | "name": "image2image_res.ipynb", 640 | "provenance": [] 641 | }, 642 | "kernelspec": { 643 | "display_name": "Python 3", 644 | "language": "python", 645 | "name": "python3" 646 | }, 647 | "language_info": { 648 | "codemirror_mode": { 649 | "name": "ipython", 650 | "version": 3 651 | }, 652 | "file_extension": ".py", 653 | "mimetype": "text/x-python", 654 | "name": "python", 655 | "nbconvert_exporter": "python", 656 | "pygments_lexer": "ipython3", 657 | "version": "3.8.5" 658 | } 659 | }, 660 | "nbformat": 4, 661 | "nbformat_minor": 1 662 | } 663 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipython==8.0.1 2 | matplotlib==3.1.2 3 | opencv_python==4.5.5.62 4 | scikit_image==0.19.1 5 | scipy==1.4.1 6 | tensorflow==2.4.0 7 |
| -------------------------------------------------------------------------------- /results/depth_perseption/combine_images (14).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/combine_images (14).jpg -------------------------------------------------------------------------------- /results/depth_perseption/d1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d1.png -------------------------------------------------------------------------------- /results/depth_perseption/d2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d2.png -------------------------------------------------------------------------------- /results/depth_perseption/d3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d3.png -------------------------------------------------------------------------------- /results/depth_perseption/d4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d4.png -------------------------------------------------------------------------------- /results/depth_perseption/d5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d5.png -------------------------------------------------------------------------------- /results/depth_perseption/d6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d6.png -------------------------------------------------------------------------------- /results/object-segmentation/combine_images (15).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/combine_images (15).jpg -------------------------------------------------------------------------------- /results/object-segmentation/os1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os1.png -------------------------------------------------------------------------------- /results/object-segmentation/os2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os2.png -------------------------------------------------------------------------------- /results/object-segmentation/os3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os3.png -------------------------------------------------------------------------------- /results/object-segmentation/os4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os4.png -------------------------------------------------------------------------------- /results/object-segmentation/os5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os5.png -------------------------------------------------------------------------------- /results/object-segmentation/os6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os6.png -------------------------------------------------------------------------------- /results/semantic-segmentation/combine_images (16).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/combine_images (16).jpg -------------------------------------------------------------------------------- /results/semantic-segmentation/f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f1.png -------------------------------------------------------------------------------- /results/semantic-segmentation/f2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f2.png 
-------------------------------------------------------------------------------- /results/semantic-segmentation/f3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f3.png -------------------------------------------------------------------------------- /results/semantic-segmentation/f4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f4.png -------------------------------------------------------------------------------- /results/semantic-segmentation/f5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f5.png -------------------------------------------------------------------------------- /results/semantic-segmentation/f6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f6.png -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import utils # local import 4 | import skimage 5 | import numpy as np 6 | from numpy import log 7 | from numpy import std 8 | from numpy import exp 9 | from math import floor 10 | from numpy import mean 11 | from numpy import cov 12 | from numpy import trace 13 | import tensorflow as tf 14 | from numpy import asarray 15 | from model import Generator # local import 16 | from numpy import expand_dims 17 | from numpy 
import iscomplexobj 18 | from scipy.linalg import sqrtm 19 | from skimage.metrics import structural_similarity as ssim 20 | from tensorflow.keras.applications.inception_v3 import InceptionV3 21 | from tensorflow.keras.applications.inception_v3 import preprocess_input 22 | 23 | EPOCHS = 100 24 | LAMBDA = 100 25 | BATCH_SIZE = 8 26 | IMG_WIDTH = 256 27 | IMG_HEIGHT = 256 28 | BUFFER_SIZE = 400 29 | DATASET = 'cityscapes' 30 | 31 | num_of_samples = 100 # number of batches (of BATCH_SIZE images each) used to evaluate the model 32 | 33 | # model params 34 | ff_dim = 32 35 | num_heads = 2 36 | patch_size = 8 37 | embed_dim = 64 38 | projection_dim = 64 39 | input_shape = (IMG_HEIGHT, IMG_WIDTH, 3) 40 | num_patches = (IMG_HEIGHT // patch_size) ** 2 41 | 42 | path_to_weights = sys.argv[1] 43 | device = '/device:GPU:0' if utils.check_cuda() else '/cpu:0' 44 | 45 | 46 | _URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{DATASET}.tar.gz' 47 | 48 | path_to_zip = tf.keras.utils.get_file(f'{DATASET}.tar.gz', 49 | origin=_URL, 50 | extract=True) 51 | 52 | PATH = os.path.join(os.path.dirname(path_to_zip), f'{DATASET}/') 53 | 54 | 55 | def load(image_file): 56 | image = tf.io.read_file(image_file) 57 | image = tf.image.decode_jpeg(image) 58 | 59 | w = tf.shape(image)[1] 60 | 61 | w = w // 2 62 | real_image = image[:, :w, :] 63 | input_image = image[:, w:, :] 64 | 65 | input_image = tf.cast(input_image, tf.float32) 66 | real_image = tf.cast(real_image, tf.float32) 67 | 68 | return input_image, real_image 69 | 70 | 71 | inp, re = load(PATH+'train/100.jpg') 72 | 73 | 74 | def resize(input_image, real_image, height, width): 75 | input_image = tf.image.resize(input_image, [height, width], 76 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 77 | real_image = tf.image.resize(real_image, [height, width], 78 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 79 | 80 | return input_image, real_image 81 | 82 | 83 | def random_crop(input_image, real_image): 84 | stacked_image = tf.stack([input_image, 
real_image], axis=0)
85 | cropped_image = tf.image.random_crop( 86 | stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3]) 87 | 88 | return cropped_image[0], cropped_image[1] 89 | 90 | 91 | # normalizing the images to [-1, 1] 92 | def normalize(input_image, real_image): 93 | input_image = (input_image / 127.5) - 1 94 | real_image = (real_image / 127.5) - 1 95 | 96 | return real_image, input_image 97 | 98 | 99 | @tf.function() 100 | def random_jitter(input_image, real_image): 101 | # resizing to 286 x 286 x 3 102 | input_image, real_image = resize(input_image, real_image, 286, 286) 103 | 104 | # randomly cropping to 256 x 256 x 3 105 | input_image, real_image = random_crop(input_image, real_image) 106 | 107 | if tf.random.uniform(()) > 0.5: 108 | # random mirroring 109 | input_image = tf.image.flip_left_right(input_image) 110 | real_image = tf.image.flip_left_right(real_image) 111 | 112 | return input_image, real_image 113 | 114 | 115 | def load_image_train(image_file): 116 | input_image, real_image = load(image_file) 117 | input_image, real_image = random_jitter(input_image, real_image) 118 | input_image, real_image = normalize(input_image, real_image) 119 | 120 | return input_image, real_image 121 | 122 | 123 | def load_image_test(image_file): 124 | input_image, real_image = load(image_file) 125 | input_image, real_image = resize(input_image, real_image, 126 | IMG_HEIGHT, IMG_WIDTH) 127 | input_image, real_image = normalize(input_image, real_image) 128 | 129 | return input_image, real_image 130 | 131 | 132 | train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg') 133 | train_dataset = train_dataset.map(load_image_train, 134 | num_parallel_calls=tf.data.AUTOTUNE) 135 | train_dataset = train_dataset.shuffle(BUFFER_SIZE) 136 | train_dataset = train_dataset.batch(BATCH_SIZE) 137 | 138 | 139 | def generate_samples(model, dataset, device, num_of_samples): 140 | with tf.device(device): 141 | outs = list() 142 | targets = list() 143 | 144 | for n, (input_image, target) in 
dataset.enumerate(): 145 | 146 | target = np.array(target) 147 | targets.append(target) 148 | 149 | input_image = np.array(input_image) 150 | model_out = np.squeeze(np.array(model(input_image, training=False)).reshape((-1, 256, 256, 3))) 151 | outs.append(model_out) 152 | 153 | if (n + 1) % num_of_samples == 0: 154 | break 155 | 156 | return outs, targets 157 | 158 | 159 | def pre_process(outs, targets): 160 | outs = np.array(outs) 161 | targets = np.array(targets) 162 | 163 | outs = outs.reshape((-1, 256, 256, 3)) # merge the batch axes, keep NHWC layout 164 | targets = targets.reshape((-1, 256, 256, 3)) 165 | 166 | outs = outs * 0.5 + 0.5 167 | targets = targets * 0.5 + 0.5 168 | 169 | outs = outs * 255 170 | targets = targets * 255 171 | 172 | return outs, targets 173 | 174 | # assumes NHWC images of any height/width with pixels in [0, 255] 175 | def calculate_inception_score(images, n_split=10, eps=1E-16): 176 | # load inception v3 model 177 | model = InceptionV3() 178 | # enumerate splits of images/predictions 179 | scores = list() 180 | n_part = floor(images.shape[0] / n_split) 181 | for i in range(n_split): 182 | # retrieve images 183 | ix_start, ix_end = i * n_part, (i+1) * n_part 184 | subset = images[ix_start:ix_end] 185 | # cast to float32 186 | subset = subset.astype('float32') 187 | # scale images to the required size 188 | subset = scale_images(subset, (299,299,3)) 189 | # pre-process images, scale to [-1,1] 190 | subset = preprocess_input(subset) 191 | # predict p(y|x) 192 | p_yx = model.predict(subset) 193 | # calculate p(y) 194 | p_y = expand_dims(p_yx.mean(axis=0), 0) 195 | # calculate KL divergence using log probabilities 196 | kl_d = p_yx * (log(p_yx + eps) - log(p_y + eps)) 197 | # sum over classes 198 | sum_kl_d = kl_d.sum(axis=1) 199 | # average over images 200 | avg_kl_d = mean(sum_kl_d) 201 | # undo the log 202 | is_score = exp(avg_kl_d) 203 | # store 204 | scores.append(is_score) 205 | # 
average across splits 206 | is_avg, is_std = mean(scores), std(scores) 207 | return is_avg,
is_std 208 | 209 | 210 | # scale an array of images to a new size 211 | def scale_images(images, new_shape): 212 | images_list = list() 213 | for image in images: 214 | # resize with nearest neighbor interpolation 215 | new_image = skimage.transform.resize(image, new_shape, 0) 216 | # store 217 | images_list.append(new_image) 218 | return asarray(images_list) 219 | 220 | 221 | # calculate frechet inception distance 222 | def calculate_fid(images1, images2): 223 | model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3)) 224 | # resize to the 299x299 InceptionV3 input and scale pixels from [0, 255] to [-1, 1] 225 | images1 = preprocess_input(scale_images(images1, (299,299,3))) 226 | images2 = preprocess_input(scale_images(images2, (299,299,3))) 227 | 228 | # calculate activations 229 | act1 = model.predict(images1) 230 | act2 = model.predict(images2) 231 | # calculate mean and covariance statistics 232 | mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False) 233 | mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False) 234 | # calculate sum squared difference between means 235 | ssdiff = np.sum((mu1 - mu2)**2.0) 236 | # calculate sqrt of product between cov 237 | covmean = sqrtm(sigma1.dot(sigma2)) 238 | # check and correct imaginary numbers from sqrt 239 | if iscomplexobj(covmean): 240 | covmean = covmean.real 241 | # calculate score 242 | fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean) 243 | return fid 244 | 245 | 246 | model = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim) 247 | model.load_weights(path_to_weights) 248 | 249 | 250 | # generate and process samples from the model 251 | outs, targets = generate_samples(model, train_dataset, device, num_of_samples) 252 | outs, targets = pre_process(outs, targets) 253 | 254 | 255 | # calculate fid, ssim, inception score 256 | fid_score = calculate_fid(targets, outs) 257 | ssim_score = ssim(targets.reshape(-1, 256, 256, 3), 
outs.reshape(-1, 256, 256, 3), data_range=targets.max() - targets.min(), multichannel=True) 258 | inception_score = calculate_inception_score(outs)
259 | 260 | print('----------------|-------------') 261 | print(f'ssim score | {ssim_score}') 262 | print(f'FID | {fid_score}') 263 | print(f'Inception score | mean: {inception_score[0]} std: {inception_score[1]}') 264 | print('----------------|-------------') 265 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import Tensor 3 | from tensorflow.keras import layers 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.layers import (Input, 6 | Conv2D, 7 | ReLU, 8 | BatchNormalization, 9 | Add, 10 | AveragePooling2D, 11 | Flatten, 12 | Dense) 13 | 14 | 15 | class Patches(tf.keras.layers.Layer): 16 | def __init__(self, patch_size): 17 | super(Patches, self).__init__() 18 | self.patch_size = patch_size 19 | 20 | def call(self, images): 21 | batch_size = tf.shape(images)[0] 22 | patches = tf.image.extract_patches( 23 | images=images, 24 | sizes=[1, self.patch_size, self.patch_size, 1], 25 | strides=[1, self.patch_size, self.patch_size, 1], 26 | rates=[1, 1, 1, 1], 27 | padding="SAME", 28 | ) 29 | patch_dims = patches.shape[-1] 30 | patches = tf.reshape(patches, [batch_size, -1, patch_dims]) 31 | return patches 32 | 33 | 34 | class PatchEncoder(tf.keras.layers.Layer): 35 | def __init__(self, num_patches, projection_dim): 36 | super(PatchEncoder, self).__init__() 37 | self.num_patches = num_patches 38 | self.projection = layers.Dense(units=projection_dim) 39 | self.position_embedding = layers.Embedding( 40 | input_dim=num_patches, output_dim=projection_dim 41 | ) 42 | 43 | def call(self, patch): 44 | positions = tf.range(start=0, limit=self.num_patches, delta=1) 45 | encoded = self.projection(patch) + self.position_embedding(positions) 46 | return encoded 47 | 48 | 49 | class TransformerBlock(tf.keras.layers.Layer): 50 | def __init__(self, embed_dim, num_heads, ff_dim, 
rate=0.1): 51 | super(TransformerBlock, self).__init__() 52 | self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim) 53 | self.ffn = tf.keras.Sequential( 54 | [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),] 55 | ) 56 | self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) 57 | self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) 58 | self.dropout1 = layers.Dropout(rate) 59 | self.dropout2 = layers.Dropout(rate) 60 | 61 | def call(self, inputs, training): 62 | attn_output = self.att(inputs, inputs) 63 | attn_output = self.dropout1(attn_output, training=training) 64 | out1 = self.layernorm1(inputs + attn_output) 65 | ffn_output = self.ffn(out1) 66 | ffn_output = self.dropout2(ffn_output, training=training) 67 | return self.layernorm2(out1 + ffn_output) 68 | 69 | 70 | def relu_bn(inputs: Tensor) -> Tensor: 71 | relu = ReLU()(inputs) 72 | bn = BatchNormalization()(relu) 73 | return bn 74 | 75 | 76 | def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor: 77 | y = Conv2D(kernel_size=kernel_size, 78 | strides=(1 if not downsample else 2), 79 | filters=filters, 80 | padding="same")(x) 81 | y = relu_bn(y) 82 | y = Conv2D(kernel_size=kernel_size, 83 | strides=1, 84 | filters=filters, 85 | padding="same")(y) 86 | 87 | if downsample: 88 | x = Conv2D(kernel_size=1, 89 | strides=2, 90 | filters=filters, 91 | padding="same")(x) 92 | out = Add()([x, y]) 93 | out = relu_bn(out) 94 | return out 95 | 96 | 97 | def Generator(input_shape, 98 | patch_size, 99 | num_patches, 100 | projection_dim, 101 | num_heads, 102 | ff_dim): 103 | # note: Reshape((8, 8, 1024)) below assumes 256x256 inputs, patch_size=8 and projection_dim=64 104 | inputs = layers.Input(shape=input_shape) 105 | 106 | patches = Patches(patch_size)(inputs) 107 | encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) 108 | 109 | x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches) 110 | x = TransformerBlock(64, num_heads, 
ff_dim)(x) 111 | x = TransformerBlock(64, num_heads, ff_dim)(x) 112 | x =
TransformerBlock(64, num_heads, ff_dim)(x) 113 | 114 | x = layers.Reshape((8, 8, 1024))(x) 115 | 116 | x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x) 117 | x = layers.BatchNormalization()(x) 118 | x = layers.LeakyReLU()(x) 119 | 120 | x = residual_block(x, downsample=False, filters=512) 121 | 122 | x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x) 123 | x = layers.BatchNormalization()(x) 124 | x = layers.LeakyReLU()(x) 125 | 126 | x = residual_block(x, downsample=False, filters=256) 127 | 128 | x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x) 129 | x = layers.BatchNormalization()(x) 130 | x = layers.LeakyReLU()(x) 131 | 132 | x = residual_block(x, downsample=False, filters=64) 133 | 134 | x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x) 135 | x = layers.BatchNormalization()(x) 136 | x = layers.LeakyReLU()(x) 137 | 138 | x = residual_block(x, downsample=False, filters=32) 139 | 140 | x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x) 141 | 142 | return tf.keras.Model(inputs=inputs, outputs=x) 143 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import tensorflow as tf 4 | from IPython import display 5 | from tensorflow import Tensor 6 | from tensorflow.keras import layers 7 | from matplotlib import pyplot as plt 8 | from tensorflow.keras.models import Model 9 | 10 | import utils # local import 11 | from model import Generator # local import 12 | 13 | EPOCHS = 100 14 | LAMBDA = 100 15 | BATCH_SIZE = 8 16 | IMG_WIDTH = 256 17 | IMG_HEIGHT = 256 18 | BUFFER_SIZE = 400 19 | SAVE_PATH = 'weights' 20 | DATASET = 'cityscapes' 21 | ff_dim = 32 22 | num_heads = 2 23 | patch_size = 8 24 | embed_dim = 
64 25 | projection_dim = 64 26 | input_shape = (IMG_HEIGHT, IMG_WIDTH, 3) 27 | num_patches = (IMG_HEIGHT // patch_size) ** 2 28 | 29 | if not os.path.exists(SAVE_PATH): 30 | os.makedirs(SAVE_PATH) 31 | 32 | available_datasets = [ 33 | 'cityscapes', 34 | 'edges2handbags', 35 | 'edges2shoes', 36 | 'facades', 37 | 'maps', 38 | 'night2day' 39 | ] 40 | 41 | if DATASET not in available_datasets: 42 | print(f'[ERROR] dataset: {DATASET}') 43 | print('[INFO] please use one of the following datasets') 44 | for dataset in available_datasets: 45 | print(f' -> {dataset}') 46 | 47 | exit(1) 48 | 49 | assert IMG_WIDTH == IMG_HEIGHT, 'width and height must be equal' 50 | _URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{DATASET}.tar.gz' 51 | device = '/device:GPU:0' if utils.check_cuda() else '/cpu:0' 52 | 53 | path_to_zip = tf.keras.utils.get_file(f'{DATASET}.tar.gz', 54 | origin=_URL, 55 | extract=True) 56 | 57 | PATH = os.path.join(os.path.dirname(path_to_zip), f'{DATASET}/') 58 | 59 | 60 | def load(image_file): 61 | image = tf.io.read_file(image_file) 62 | image = tf.image.decode_jpeg(image) 63 | 64 | w = tf.shape(image)[1] 65 | 66 | w = w // 2 67 | real_image = image[:, :w, :] 68 | input_image = image[:, w:, :] 69 | 70 | input_image = tf.cast(input_image, tf.float32) 71 | real_image = tf.cast(real_image, tf.float32) 72 | 73 | return input_image, real_image 74 | 75 | 76 | def resize(input_image, real_image, height, width): 77 | input_image = tf.image.resize(input_image, [height, width], 78 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 79 | real_image = tf.image.resize(real_image, [height, width], 80 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 81 | 82 | return input_image, real_image 83 | 84 | 85 | def random_crop(input_image, real_image): 86 | stacked_image = tf.stack([input_image, real_image], axis=0) 87 | cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3]) 88 | 89 | return cropped_image[0], cropped_image[1] 90 | 91 |
92 | # normalizing the images to [-1, 1]; note: the pair is returned swapped, as (real_image, input_image) 93 | def normalize(input_image, real_image): 94 | input_image = (input_image / 127.5) - 1 95 | real_image = (real_image / 127.5) - 1 96 | 97 | return real_image, input_image 98 | 99 | 100 | @tf.function() 101 | def random_jitter(input_image, real_image): 102 | # resizing to 286 x 286 x 3 103 | input_image, real_image = resize(input_image, real_image, 286, 286) 104 | 105 | # randomly cropping to 256 x 256 x 3 106 | input_image, real_image = random_crop(input_image, real_image) 107 | 108 | if tf.random.uniform(()) > 0.5: 109 | # random mirroring 110 | input_image = tf.image.flip_left_right(input_image) 111 | real_image = tf.image.flip_left_right(real_image) 112 | 113 | return input_image, real_image 114 | 115 | 116 | def load_image_train(image_file): 117 | input_image, real_image = load(image_file) 118 | input_image, real_image = random_jitter(input_image, real_image) 119 | input_image, real_image = normalize(input_image, real_image) 120 | 121 | return input_image, real_image 122 | 123 | 124 | def load_image_test(image_file): 125 | input_image, real_image = load(image_file) 126 | input_image, real_image = resize(input_image, real_image, 127 | IMG_HEIGHT, IMG_WIDTH) 128 | input_image, real_image = normalize(input_image, real_image) 129 | 130 | return input_image, real_image 131 | 132 | 133 | tf.config.run_functions_eagerly(False) 134 | 135 | train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg') 136 | train_dataset = train_dataset.map(load_image_train, 137 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 138 | train_dataset = train_dataset.shuffle(BUFFER_SIZE) 139 | train_dataset = train_dataset.batch(BATCH_SIZE) 140 | 141 | try: 142 | test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg') 143 | test_dataset = test_dataset.map(load_image_test) 144 | test_dataset = test_dataset.batch(BATCH_SIZE) 145 | except Exception: # e.g. 
no test/ split in the downloaded archive 146 | test_dataset = train_dataset 147 | 148 | 149 | generator = Generator(input_shape,
patch_size, num_patches, projection_dim, num_heads, ff_dim) 150 | tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64) 151 | generator.summary() 152 | 153 | optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5) 154 | 155 | 156 | def generate_images(model, test_input, tar): 157 | prediction = model(test_input, training=True) 158 | plt.figure(figsize=(15, 15)) 159 | 160 | display_list = [test_input[0], tar[0], prediction[0]] 161 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 162 | 163 | for i in range(3): 164 | plt.subplot(1, 3, i+1) 165 | plt.title(title[i]) 166 | # getting the pixel values between [0, 1] to plot it. 167 | plt.imshow(display_list[i] * 0.5 + 0.5) 168 | plt.axis('off') 169 | plt.show() 170 | 171 | 172 | def generate_batch_images(model, test_input, tar): 173 | for i in range(len(test_input)): 174 | prediction = model(test_input, training=True) 175 | plt.figure(figsize=(15, 15)) 176 | 177 | display_list = [test_input[i], tar[i], prediction[i]] 178 | title = ['Input Image', 'Ground Truth', 'Predicted Image'] 179 | 180 | for i in range(3): 181 | plt.subplot(1, 3, i+1) 182 | plt.title(title[i]) 183 | # converting the pixel values to [0, 1] to plot it. 
            plt.imshow(display_list[i] * 0.5 + 0.5)
            plt.axis('off')
        plt.show()


def train_step(input_image, target, epoch):
    # one optimization step on the L1 (mean absolute error) loss
    with tf.device(device):
        with tf.GradientTape() as gen_tape:
            gen_output = generator(input_image, training=True)

            gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))

        generator_gradients = gen_tape.gradient(gen_total_loss,
                                                generator.trainable_variables)

        optimizer.apply_gradients(zip(generator_gradients,
                                      generator.trainable_variables))


def fit(train_ds, epochs, test_ds):
    print(f"[INFO] will train on device: {device}")
    for epoch in range(epochs):

        if utils.is_notebook():
            display.clear_output(wait=True)

            # preview the generator on one test batch before each epoch
            for example_input, example_target in test_ds.take(1):
                generate_images(generator, example_input, example_target)

        print(f'Epoch: [{epoch}/{epochs}]')

        # Train
        for input_image, target in train_ds:
            train_step(input_image, target, epoch)

        generator.save_weights(f'{SAVE_PATH}/tensor2image-{DATASET}-{epoch}-epochs-weights.h5')


def test(test_dataset, generator):
    '''
    Visually inspect the generator's outputs on one test batch
    (plots are only shown when running inside a notebook).
    '''
    if utils.is_notebook():
        for inp, tar in test_dataset.take(1):
            generate_batch_images(generator, inp, tar)


if __name__ == '__main__':
    fit(train_dataset, EPOCHS, test_dataset)

    test(test_dataset, generator)

--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
def display_image(images: list, display=True, save=False, name=None):
    import cv2
    import numpy as np
    import tensorflow as tf
    from matplotlib import pyplot as plt

    img1, img2, img3, *_ = images

    img1 = np.array(img1).astype(np.float32)
    img2 = np.array(img2).astype(np.float32)
    img3 = np.array(img3).astype(np.float32)

    # the second and third images are single-channel: expand them to 3 channels
    img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
    img3 = cv2.cvtColor(img3, cv2.COLOR_GRAY2BGR)
    print(img1.shape, img2.shape, img3.shape)

    # place the three images side by side
    im_h = cv2.hconcat([img1, img2, img3])

    im_h = tf.nn.relu(im_h).numpy()
    im_h = np.clip(im_h, 0, 1)

    print(np.max(im_h))
    print(np.min(im_h))

    plt.xticks([])
    plt.yticks([])

    if display:
        plt.imshow(im_h)

    if save:
        if name is not None:
            plt.imsave(name, im_h.astype(np.float32))
        else:
            raise ValueError('save=True requires a file name in the `name` argument')

    return im_h


def is_notebook():
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter


def check_cuda():
    import tensorflow as tf
    device_name = tf.test.gpu_device_name()
    if device_name != '/device:GPU:0':
        return False
    return True

--------------------------------------------------------------------------------
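The `normalize` step in `src/train.py` rescales 8-bit pixel values from [0, 255] to [-1, 1] with `x / 127.5 - 1`, and the plotting helpers undo this with `y * 0.5 + 0.5`, which lands in [0, 1] as `plt.imshow` expects for float images. A minimal pure-Python sketch of the two mappings, using plain floats instead of tensors; the helper names `to_model_range` and `to_display_range` are hypothetical and not part of the repo:

```python
def to_model_range(pixel: float) -> float:
    """Map an 8-bit pixel value in [0, 255] to [-1, 1], as normalize() does."""
    return (pixel / 127.5) - 1


def to_display_range(value: float) -> float:
    """Map a model-range value in [-1, 1] to [0, 1], as the plotting code does."""
    return value * 0.5 + 0.5


if __name__ == "__main__":
    for p in (0, 127.5, 255):
        v = to_model_range(p)
        print(p, "->", v, "->", to_display_range(v))
```

Composing the two gives `p / 255` (up to floating-point rounding), so the plotted image is simply the original scaled to [0, 1].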
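`random_jitter` only works as augmentation for paired data if the input and the target receive the same crop and the same left-right flip; otherwise the pixel-to-pixel correspondence that the L1 loss relies on is lost. The sketch below illustrates that idea on nested lists rather than tensors, under stated assumptions: images are lists of rows, the resize to 286 x 286 is omitted, and `paired_crop_and_flip` is a hypothetical helper, not the repo's `random_crop` (whose body is not shown in this section).

```python
import random
from typing import List, Tuple

Image = List[List[int]]


def paired_crop_and_flip(input_image: Image, real_image: Image,
                         crop_h: int, crop_w: int,
                         rng: random.Random) -> Tuple[Image, Image]:
    """Crop both images at one shared random offset, then flip both or neither."""
    height, width = len(input_image), len(input_image[0])
    if (len(real_image), len(real_image[0])) != (height, width):
        raise ValueError("input and target must have the same size")

    # draw the crop offset once and reuse it for both images
    top = rng.randint(0, height - crop_h)
    left = rng.randint(0, width - crop_w)

    def crop(img: Image) -> Image:
        return [row[left:left + crop_w] for row in img[top:top + crop_h]]

    inp, tar = crop(input_image), crop(real_image)

    # a single coin flip decides mirroring for the whole pair
    if rng.random() > 0.5:
        inp = [row[::-1] for row in inp]
        tar = [row[::-1] for row in tar]
    return inp, tar


if __name__ == "__main__":
    # a toy 4 x 4 "image" whose target is the same image shifted by +100
    img = [[r * 4 + c for c in range(4)] for r in range(4)]
    target = [[v + 100 for v in row] for row in img]
    a, b = paired_crop_and_flip(img, target, 2, 2, random.Random(0))
    print(b == [[v + 100 for v in row] for row in a])  # the pair stays aligned
```

Drawing the random values once per pair, rather than once per image, is the design choice that matters here; the same principle applies whatever tensor library performs the crop.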
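`train_step` trains the generator on a single objective: the mean absolute (L1) error between target and output pixels, `tf.reduce_mean(tf.abs(target - gen_output))`, which Adam (learning rate `2e-4`, `beta_1=0.5`) then minimizes. As a rough pure-Python illustration, with flat lists standing in for tensors and the hypothetical names `l1_loss` and `l1_grad`, this is the loss and the per-pixel gradient with respect to the output that `GradientTape` would compute:

```python
from typing import List


def l1_loss(target: List[float], output: List[float]) -> float:
    """Mean absolute error, matching reduce_mean(abs(target - output))."""
    if len(target) != len(output):
        raise ValueError("target and output must have the same length")
    return sum(abs(t - o) for t, o in zip(target, output)) / len(target)


def l1_grad(target: List[float], output: List[float]) -> List[float]:
    """d(loss)/d(output_i) = -sign(target_i - output_i) / N, and 0 where they match."""
    n = len(target)
    grads = []
    for t, o in zip(target, output):
        diff = t - o
        sign = (diff > 0) - (diff < 0)
        grads.append(-sign / n)
    return grads


if __name__ == "__main__":
    target = [1.0, -1.0, 0.5, 0.0]
    output = [0.5, -1.0, 1.0, 0.0]
    print(l1_loss(target, output))
    print(l1_grad(target, output))
```

Every mismatched pixel receives a gradient of the same magnitude, `1 / N`, no matter how large its error is; only the sign differs. An L2 loss would instead weight each pixel's gradient by the size of its error.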