├── LICENSE
├── README.md
├── assets
│   ├── Screenshot from 2022-01-25 18-32-45.png
│   ├── paper.png
│   └── prev.jpeg
├── index.html
├── notebooks
│   ├── .ipynb_checkpoints
│   │   ├── comp_(1)-checkpoint.ipynb
│   │   ├── depth-checkpoint.ipynb
│   │   └── object-segmentation-checkpoint.ipynb
│   ├── depth.ipynb
│   ├── object-segmentation.ipynb
│   └── semantic-segmentation.ipynb
├── requirements.txt
├── results
│   ├── depth_perseption
│   │   ├── combine_images (14).jpg
│   │   ├── d1.png
│   │   ├── d2.png
│   │   ├── d3.png
│   │   ├── d4.png
│   │   ├── d5.png
│   │   └── d6.png
│   ├── object-segmentation
│   │   ├── combine_images (15).jpg
│   │   ├── os1.png
│   │   ├── os2.png
│   │   ├── os3.png
│   │   ├── os4.png
│   │   ├── os5.png
│   │   └── os6.png
│   └── semantic-segmentation
│       ├── combine_images (16).jpg
│       ├── f1.png
│       ├── f2.png
│       ├── f3.png
│       ├── f4.png
│       ├── f5.png
│       └── f6.png
└── src
    ├── evaluate.py
    ├── model.py
    ├── train.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yigit Gunduc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensor-to-image
2 | [Website](https://yigitgunduc.github.io/tensor2image/) | [Arxiv](https://arxiv.org/abs/2110.08037)
3 |
4 |
5 |
6 | ## Abstract
7 |
8 | Transformers have attracted huge attention since they were first introduced and have
9 | a wide range of applications. Transformers have started to take over many areas of
10 | deep learning, and the Vision Transformer paper proved that they can also
11 | be used for computer vision tasks. In this paper, we utilize a
12 | custom-designed, vision-transformer-based model, tensor-to-image,
13 | for image-to-image translation. With the help of self-attention,
14 | our model was able to generalize and be applied to different problems without
15 | a single modification.
16 |
17 | ## Setup
18 |
19 | Clone the repo
20 | ```bash
21 | git clone https://github.com/yigitgunduc/tensor-to-image/
22 | ```
23 |
24 | Install requirements
25 | ```bash
26 | pip3 install -r requirements.txt
27 | ```
28 |
29 | > For GPU support, set up `TensorFlow >= 2.4.0` with `CUDA v11.0` or above
30 | > - you can skip this step if you are going to train on the CPU
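
To confirm that TensorFlow can see the GPU before training, you can run a quick check with the standard TensorFlow API (this snippet is not part of the repository):

```python
import tensorflow as tf

# An empty list means no GPU was found and training will fall back to the CPU.
print(tf.config.list_physical_devices('GPU'))
```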
31 |
32 | ## Training
33 |
34 | Train the model
35 | ```bash
36 | python3 src/train.py
37 | ```
38 | Weights are saved after every epoch and can be found in `./weights/`
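
To run inference with a saved checkpoint, the weights can be loaded back into the model. The sketch below is only illustrative: it assumes `src/model.py` exposes the same `Generator()` constructor used in the notebooks, and the weight filename is a hypothetical example of what ends up in `./weights/`.

```python
import tensorflow as tf
from src.model import Generator  # assumed import path; adjust to match src/model.py

generator = Generator()
generator.load_weights('weights/epoch-050.h5')  # hypothetical checkpoint name

# Stand-in input batch shaped like a 256x256 RGB image.
image = tf.random.uniform((1, 256, 256, 3))
prediction = generator(image, training=False)
```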
39 |
40 | ## Evaluating
41 |
42 | After you have trained the model, you can test it against 3 different criteria
43 | (FID, structural similarity, Inception score).
44 |
45 | ```bash
46 | python3 src/evaluate.py path/to/weights
47 | ```
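
For a quick structural-similarity sanity check outside of `src/evaluate.py`, the metric can also be computed directly with TensorFlow. This is an illustrative sketch only; the image tensors below are random stand-ins and are assumed to be scaled to [0, 1].

```python
import tensorflow as tf

def ssim_score(generated, target):
    """Mean SSIM over a batch of images in the [0, 1] range."""
    return float(tf.reduce_mean(tf.image.ssim(generated, target, max_val=1.0)))

# Random stand-ins shaped like 256x256 RGB model outputs.
fake = tf.random.uniform((4, 256, 256, 3))
real = tf.random.uniform((4, 256, 256, 3))
print(ssim_score(fake, real))
```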
48 |
49 | ## Datasets
50 |
51 | The implementation supports 8 datasets for various tasks: 6 pix2pix datasets and two additional ones.
52 | The 6 pix2pix datasets can be used by changing the `DATASET` variable in `src/train.py`
53 | (see the sketch after the list below); for the additional datasets, please see `notebooks/object-segmentation.ipynb` and
54 | `notebooks/depth.ipynb`.
55 |
56 | Datasets available through `src/train.py`:
57 |
58 | - `cityscapes` 99 MB
59 | - `edges2handbags` 8.0 GB
60 | - `edges2shoes` 2.0 GB
61 | - `facades` 29 MB
62 | - `maps` 239 MB
63 | - `night2day` 1.9 GB
64 |
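A minimal sketch of selecting one of these datasets; the `DATASET` variable is the one mentioned above, but the exact assignment in `src/train.py` may look slightly different:

```python
# src/train.py (illustrative excerpt)
DATASET = 'cityscapes'  # one of: cityscapes, edges2handbags, edges2shoes, facades, maps, night2day
```
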
65 | Datasets available through the notebooks:
66 |
67 | - `Oxford-IIIT Pets`
68 | - `RGB+D DATABASE`
69 |
70 | ## Cite
71 | If you use this code for your research, please cite our paper [Tensor-to-Image: Image-to-Image Translation with Vision Transformers](https://arxiv.org/abs/2110.08037)
72 | ```
73 | @article{gunducc2021tensor,
74 | title={Tensor-to-Image: Image-to-Image Translation with Vision Transformers},
75 | author={G{\"u}nd{\"u}{\c{c}}, Yi{\u{g}}it},
76 | journal={arXiv preprint arXiv:2110.08037},
77 | year={2021}
78 | }
79 | ```
80 |
--------------------------------------------------------------------------------
/assets/Screenshot from 2022-01-25 18-32-45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/assets/Screenshot from 2022-01-25 18-32-45.png
--------------------------------------------------------------------------------
/assets/paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/assets/paper.png
--------------------------------------------------------------------------------
/assets/prev.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/assets/prev.jpeg
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
30 |
31 |
32 |
33 | Tensor-to-Image: Image-to-Image Translation with Vision Transformers
34 |
35 |
36 |
37 | Abstract
38 |
39 | Transformers have attracted huge attention since they were first introduced and have a wide
40 | range of applications. Transformers have started to take over many areas of deep
41 | learning, and the Vision Transformer paper proved that they
42 | can also be used for computer vision tasks. In this paper, we utilize a
43 | custom-designed, vision-transformer-based model, tensor-to-image, for image-to-image
44 | translation. With the help of self-attention, our model
45 | was able to generalize and be applied to different problems without a single
46 | modification.
47 |
48 |
49 | Code & Paper
50 |
51 |
52 |
53 |
54 |
58 |
59 |
60 | Cite
61 |
62 |
63 | If you use this code for your research, please cite our paper Tensor-to-Image: Image-to-Image Translation with Vision Transformers
64 |
65 |
66 |
67 | @article{gunducc2021tensor,
68 | title={Tensor-to-Image: Image-to-Image Translation with Vision Transformers},
69 | author={G{\"u}nd{\"u}{\c{c}}, Yi{\u{g}}it},
70 | journal={arXiv preprint arXiv:2110.08037},
71 | year={2021}
72 | }
73 |
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/notebooks/.ipynb_checkpoints/depth-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "colab": {
8 | "base_uri": "https://localhost:8080/"
9 | },
10 | "id": "cHaQkNvlc1ki",
11 | "outputId": "23afd6fb-521c-4b67-d765-22812a8a5bab"
12 | },
13 | "outputs": [],
14 | "source": [
15 | "from google.colab import drive\n",
16 | "import os\n",
17 | "import tensorflow as tf\n",
18 | "import glob\n",
19 | "drive.mount('/content/gdrive')"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {
26 | "colab": {
27 | "base_uri": "https://localhost:8080/"
28 | },
29 | "id": "I_R2JVOec2PY",
30 | "outputId": "5ffe6ee9-550c-4453-961c-727307f38bf6"
31 | },
32 | "outputs": [],
33 | "source": [
34 | "!unzip /content/gdrive/MyDrive/indoor_test.zip"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "colab": {
42 | "base_uri": "https://localhost:8080/"
43 | },
44 | "id": "qMZaGHOlc0SD",
45 | "outputId": "9f0ad718-d88a-4384-a315-fc8be1109a3e"
46 | },
47 | "outputs": [],
48 | "source": [
49 | "dataset_path = '../../../depth/dataset/test/LR'"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {
56 | "id": "8Mx-LXgLc0SM"
57 | },
58 | "outputs": [],
59 | "source": [
60 | "input_paths = glob.glob(dataset_path + '/**/color/*.png')\n",
61 | "target_paths = glob.glob(dataset_path + '/**/depth_vi/*.png')"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {
68 | "colab": {
69 | "base_uri": "https://localhost:8080/"
70 | },
71 | "id": "xFbmIbP_c0SN",
72 | "outputId": "56e6ada1-e53b-4791-ec22-97c9b3045379"
73 | },
74 | "outputs": [],
75 | "source": [
76 | "print(target_paths)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {
83 | "id": "4OVZgPjtc0SP"
84 | },
85 | "outputs": [],
86 | "source": [
87 | "BUFFER_SIZE = 400\n",
88 | "EPOCHS = 100\n",
89 | "LAMBDA = 100\n",
90 | "BATCH_SIZE = 8\n",
91 | "IMG_WIDTH = 256\n",
92 | "IMG_HEIGHT = 256\n",
93 | "patch_size = 8\n",
94 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n",
95 | "projection_dim = 64\n",
96 | "embed_dim = 64\n",
97 | "num_heads = 2 \n",
98 | "ff_dim = 32"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {
105 | "id": "g89DtSq_c0SQ"
106 | },
107 | "outputs": [],
108 | "source": [
109 | "real = []\n",
110 | "targets = []"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {
117 | "id": "AvaXbs4lc0SR"
118 | },
119 | "outputs": [],
120 | "source": [
121 | "def load(path):\n",
122 | "\n",
123 | " image_path = path[:-12] + 'c.png'\n",
124 | " image_path = image_path.replace(\"depth_vi\", \"color\")\n",
125 | " depth_path = path[:-12] + 'depth_vi.png'\n",
126 | "\n",
127 | "\n",
128 | " input_image = tf.io.read_file(image_path)\n",
129 | " input_image = tf.image.decode_jpeg(input_image)\n",
130 | " \n",
131 | " target_image = tf.io.read_file(depth_path)\n",
132 | " target_image = tf.image.decode_jpeg(target_image)\n",
133 | " \n",
134 | " input_image = tf.cast(input_image, tf.float32)\n",
135 | " target_image = tf.cast(target_image, tf.float32)\n",
136 | "\n",
137 | "\n",
138 | " return input_image, target_image"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {
145 | "id": "y8_ilo0Fc0SS"
146 | },
147 | "outputs": [],
148 | "source": [
149 | "def resize(input_image, real_image, height, width):\n",
150 | " input_image = tf.image.resize(input_image, [height, width],\n",
151 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
152 | " real_image = tf.image.resize(real_image, [height, width],\n",
153 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
154 | "\n",
155 | " return input_image, real_image"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {
162 | "id": "opCaZjvRc0ST"
163 | },
164 | "outputs": [],
165 | "source": [
166 | "def normalize(input_image, target_image):\n",
167 | " input_image = input_image / 255\n",
168 | " target_image = target_image / 255\n",
169 | "\n",
170 | " return input_image, target_image"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "metadata": {
177 | "id": "bvQKHD3-c0SU"
178 | },
179 | "outputs": [],
180 | "source": [
181 | "def load_image_train(depth_path):\n",
182 | " input_image, target = load(depth_path)\n",
183 | " input_image, target = resize(input_image, target,\n",
184 | " IMG_HEIGHT, IMG_WIDTH)\n",
185 | " input_image, target = normalize(input_image, target)\n",
186 | "\n",
187 | " return input_image, target"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {
194 | "id": "i48PBiP7c0SU"
195 | },
196 | "outputs": [],
197 | "source": [
198 | "real = []\n",
199 | "targets = []\n",
200 | "import numpy as np\n",
201 | "for i in range(len(target_paths)):\n",
202 | " #inputs, target = load(target_paths[i])\n",
203 | " inputs, target = load_image_train(target_paths[i])\n",
204 | " #inputs, target = normalize(inputs, target)\n",
205 | " real.append(inputs)\n",
206 | " targets.append(target)\n",
207 | "\n",
208 | "real = np.array(real)\n",
209 | "targets = np.array(targets)"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {
216 | "id": "4yfKd4i_c0SV"
217 | },
218 | "outputs": [],
219 | "source": [
220 | "from matplotlib import pyplot as plt\n",
221 | "import numpy as np"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "colab": {
229 | "base_uri": "https://localhost:8080/",
230 | "height": 286
231 | },
232 | "id": "1VCfus5zc0SV",
233 | "outputId": "c21309a4-f950-47d8-c27d-3520328f70bf"
234 | },
235 | "outputs": [],
236 | "source": [
237 | "plt.imshow(real[23])\n",
238 | "print(real[12].shape)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": null,
244 | "metadata": {
245 | "colab": {
246 | "base_uri": "https://localhost:8080/",
247 | "height": 286
248 | },
249 | "id": "vtOUxl5oc0SV",
250 | "outputId": "00f762f6-9e0f-46a7-d984-c3a1ec089b4f"
251 | },
252 | "outputs": [],
253 | "source": [
254 | "plt.imshow(targets[23].reshape(256, 256))\n",
255 | "print(targets[1].shape)"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "metadata": {
262 | "id": "DVWWuvVic0SW"
263 | },
264 | "outputs": [],
265 | "source": [
266 | "import tensorflow as tf\n",
267 | "\n",
268 | "import os\n",
269 | "import time\n",
270 | "\n",
271 | "from matplotlib import pyplot as plt\n",
272 | "from IPython import display"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "metadata": {
279 | "id": "hm6uhABxc0SW"
280 | },
281 | "outputs": [],
282 | "source": [
283 | "def downsample(filters, size, apply_batchnorm=True):\n",
284 | " initializer = tf.random_normal_initializer(0., 0.02)\n",
285 | "\n",
286 | " result = tf.keras.Sequential()\n",
287 | " result.add(\n",
288 | " tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n",
289 | " kernel_initializer=initializer, use_bias=False))\n",
290 | "\n",
291 | " if apply_batchnorm:\n",
292 | " result.add(tf.keras.layers.BatchNormalization())\n",
293 | "\n",
294 | " result.add(tf.keras.layers.LeakyReLU())\n",
295 | "\n",
296 | " return result"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "metadata": {
303 | "id": "VqAxP9ayc0SX"
304 | },
305 | "outputs": [],
306 | "source": [
307 | "class Patches(tf.keras.layers.Layer):\n",
308 | " def __init__(self, patch_size):\n",
309 | " super(Patches, self).__init__()\n",
310 | " self.patch_size = patch_size\n",
311 | "\n",
312 | " def call(self, images):\n",
313 | " batch_size = tf.shape(images)[0]\n",
314 | " patches = tf.image.extract_patches(\n",
315 | " images=images,\n",
316 | " sizes=[1, self.patch_size, self.patch_size, 1],\n",
317 | " strides=[1, self.patch_size, self.patch_size, 1],\n",
318 | " rates=[1, 1, 1, 1],\n",
319 | " padding=\"SAME\",\n",
320 | " )\n",
321 | " patch_dims = patches.shape[-1]\n",
322 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n",
323 | " return patches"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": null,
329 | "metadata": {
330 | "id": "n_mdkF59c0SX"
331 | },
332 | "outputs": [],
333 | "source": [
334 | "class PatchEncoder(tf.keras.layers.Layer):\n",
335 | " def __init__(self, num_patches, projection_dim):\n",
336 | " super(PatchEncoder, self).__init__()\n",
337 | " self.num_patches = num_patches\n",
338 | " self.projection = layers.Dense(units=projection_dim)\n",
339 | " self.position_embedding = layers.Embedding(\n",
340 | " input_dim=num_patches, output_dim=projection_dim\n",
341 | " )\n",
342 | "\n",
343 | " def call(self, patch):\n",
344 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
345 | " encoded = self.projection(patch) + self.position_embedding(positions)\n",
346 | " return encoded"
347 | ]
348 | },
349 | {
350 | "cell_type": "code",
351 | "execution_count": null,
352 | "metadata": {
353 | "id": "ZNDm9KXXc0SY"
354 | },
355 | "outputs": [],
356 | "source": [
357 | "class TransformerBlock(tf.keras.layers.Layer):\n",
358 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
359 | " super(TransformerBlock, self).__init__()\n",
360 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n",
361 | " self.ffn = tf.keras.Sequential(\n",
362 | " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
363 | " )\n",
364 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
365 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
366 | " self.dropout1 = layers.Dropout(rate)\n",
367 | " self.dropout2 = layers.Dropout(rate)\n",
368 | "\n",
369 | " def call(self, inputs, training):\n",
370 | " attn_output = self.att(inputs, inputs)\n",
371 | " attn_output = self.dropout1(attn_output, training=training)\n",
372 | " out1 = self.layernorm1(inputs + attn_output)\n",
373 | " ffn_output = self.ffn(out1)\n",
374 | " ffn_output = self.dropout2(ffn_output, training=training)\n",
375 | " return self.layernorm2(out1 + ffn_output)"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": null,
381 | "metadata": {
382 | "id": "RPZFwD5PP4af"
383 | },
384 | "outputs": [],
385 | "source": [
386 | "from tensorflow import Tensor\n",
387 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n",
388 | " Add, AveragePooling2D, Flatten, Dense\n",
389 | "from tensorflow.keras.models import Model\n",
390 | "\n",
391 | "def relu_bn(inputs: Tensor) -> Tensor:\n",
392 | " relu = ReLU()(inputs)\n",
393 | " bn = BatchNormalization()(relu)\n",
394 | " return bn\n",
395 | "\n",
396 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n",
397 | " y = Conv2D(kernel_size=kernel_size,\n",
398 | " strides= (1 if not downsample else 2),\n",
399 | " filters=filters,\n",
400 | " padding=\"same\")(x)\n",
401 | " y = relu_bn(y)\n",
402 | " y = Conv2D(kernel_size=kernel_size,\n",
403 | " strides=1,\n",
404 | " filters=filters,\n",
405 | " padding=\"same\")(y)\n",
406 | "\n",
407 | " if downsample:\n",
408 | " x = Conv2D(kernel_size=1,\n",
409 | " strides=2,\n",
410 | " filters=filters,\n",
411 | " padding=\"same\")(x)\n",
412 | " out = Add()([x, y])\n",
413 | " out = relu_bn(out)\n",
414 | " return out"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": null,
420 | "metadata": {
421 | "id": "lcQVzKBDc0SZ"
422 | },
423 | "outputs": [],
424 | "source": [
425 | "from tensorflow.keras import layers\n",
426 | "\n",
427 | "def Generator():\n",
428 | "\n",
429 | " inputs = layers.Input(shape=(256, 256, 3))\n",
430 | "\n",
431 | " patches = Patches(patch_size)(inputs)\n",
432 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n",
433 | "\n",
434 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n",
435 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
436 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
437 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
438 | "\n",
439 | " x = layers.Reshape((8, 8, 1024))(x)\n",
440 | "\n",
441 | " x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
442 | " x = layers.BatchNormalization()(x)\n",
443 | " x = layers.LeakyReLU()(x)\n",
444 | "\n",
445 | " x = residual_block(x, downsample=False, filters=512)\n",
446 | "\n",
447 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
448 | " x = layers.BatchNormalization()(x)\n",
449 | " x = layers.LeakyReLU()(x)\n",
450 | "\n",
451 | " x = residual_block(x, downsample=False, filters=256)\n",
452 | "\n",
453 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
454 | " x = layers.BatchNormalization()(x)\n",
455 | " x = layers.LeakyReLU()(x)\n",
456 | " \n",
457 | " x = residual_block(x, downsample=False, filters=64)\n",
458 | "\n",
459 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)\n",
460 | " x = layers.BatchNormalization()(x)\n",
461 | " x = layers.LeakyReLU()(x)\n",
462 | "\n",
463 | " x = residual_block(x, downsample=False, filters=32)\n",
464 | "\n",
465 | " x = layers.Conv2D(1, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n",
466 | "\n",
467 | " return tf.keras.Model(inputs=inputs, outputs=x)"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": null,
473 | "metadata": {
474 | "colab": {
475 | "base_uri": "https://localhost:8080/",
476 | "height": 1000
477 | },
478 | "id": "DBHxlKHvc0Sa",
479 | "outputId": "0b70c08f-2c2c-4d01-dd44-e340c0b088c0"
480 | },
481 | "outputs": [],
482 | "source": [
483 | "generator = Generator()\n",
484 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": null,
490 | "metadata": {
491 | "colab": {
492 | "base_uri": "https://localhost:8080/"
493 | },
494 | "id": "51J3xxeZRLEO",
495 | "outputId": "e3794664-9dcc-4d21-e38a-b08c27bdff4f"
496 | },
497 | "outputs": [],
498 | "source": [
499 | "generator.summary()"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": null,
505 | "metadata": {
506 | "id": "CxG6_fP1c0Sa"
507 | },
508 | "outputs": [],
509 | "source": [
510 | "tf.config.run_functions_eagerly(False)"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": null,
516 | "metadata": {
517 | "id": "TZn1NNgbc0Sb"
518 | },
519 | "outputs": [],
520 | "source": [
521 | "loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "metadata": {
528 | "id": "tYCaFUoGc0Sb"
529 | },
530 | "outputs": [],
531 | "source": [
532 | "def generator_loss(disc_generated_output, gen_output, target):\n",
533 | " gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)\n",
534 | "\n",
535 | " # mean absolute error\n",
536 | " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
537 | "\n",
538 | " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
539 | "\n",
540 | " return total_gen_loss, gan_loss, l1_loss"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": null,
546 | "metadata": {
547 | "id": "lw8T5T3Ac0Sd"
548 | },
549 | "outputs": [],
550 | "source": [
551 | "tf.config.run_functions_eagerly(True)"
552 | ]
553 | },
554 | {
555 | "cell_type": "code",
556 | "execution_count": null,
557 | "metadata": {
558 | "id": "_Qhap2DDc0Sd"
559 | },
560 | "outputs": [],
561 | "source": [
562 | "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
563 | ]
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": null,
568 | "metadata": {
569 | "id": "Gl9RqSOHc0Se"
570 | },
571 | "outputs": [],
572 | "source": [
573 | "def generate_images(model, test_input, tar):\n",
574 | " prediction = model(test_input, training=True)\n",
575 | " plt.figure(figsize=(15, 15))\n",
576 | "\n",
577 | " display_list = [test_input[0], np.array(tar[0]).reshape(256, 256), np.array(prediction[0]).reshape(256, 256)]\n",
578 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
579 | "\n",
580 | " for i in range(3):\n",
581 | " plt.subplot(1, 3, i+1)\n",
582 | " plt.title(title[i])\n",
583 | " # getting the pixel values between [0, 1] to plot it.\n",
584 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
585 | " plt.axis('off')\n",
586 | " plt.show()\n",
587 | "\n",
588 | "def generate_batch_images(model, test_input, tar):\n",
589 | " for i in range(len(test_input)):\n",
590 | " prediction = model(test_input, training=True)\n",
591 | " plt.figure(figsize=(15, 15))\n",
592 | "\n",
593 | " display_list = [test_input[i], tar[i], prediction[i]]\n",
594 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
595 | "\n",
596 | " for i in range(3):\n",
597 | " plt.subplot(1, 3, i+1)\n",
598 | " plt.title(title[i])\n",
599 | " # getting the pixel values between [0, 1] to plot it.\n",
600 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
601 | " plt.axis('off')\n",
602 | " plt.show()"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {
609 | "id": "N2M-Jbjvc0Se"
610 | },
611 | "outputs": [],
612 | "source": [
613 | "@tf.function\n",
614 | "def train_step(input_image, target):\n",
615 | " with tf.device('/device:GPU:0'):\n",
616 | " with tf.GradientTape() as gen_tape:\n",
617 | " gen_output = generator(input_image, training=True)\n",
618 | " gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
619 | " \n",
620 | "\n",
621 | " generator_gradients = gen_tape.gradient(gen_total_loss,\n",
622 | " generator.trainable_variables)\n",
623 | " generator_optimizer.apply_gradients(zip(generator_gradients,\n",
624 | " generator.trainable_variables))"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": null,
630 | "metadata": {
631 | "id": "5wOgyEJmc0Se"
632 | },
633 | "outputs": [],
634 | "source": [
635 | "def fit(train_ds, epochs, test_ds):\n",
636 | " for epoch in range(epochs):\n",
637 | " start = time.time()\n",
638 | "\n",
639 | " display.clear_output(wait=True)\n",
640 | "\n",
641 | " print(\"Epoch: \", epoch)\n",
642 | "\n",
643 | " # Train\n",
644 | " for n, (input_image, target) in train_ds.enumerate():\n",
645 | " print('.', end='')\n",
646 | " if (n+1) % 100 == 0:\n",
647 | " print()\n",
648 | " train_step(input_image, target)\n",
649 | " print()\n",
650 | "\n",
651 | " generator.save_weights(f'depth-gen-weights.h5')"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": null,
657 | "metadata": {
658 | "id": "z4Kq8t1kc0Se"
659 | },
660 | "outputs": [],
661 | "source": [
662 | "train_dataset = tf.data.Dataset.from_tensor_slices((real, targets))\n",
663 | "\n",
664 | "train_dataset = train_dataset.batch(BATCH_SIZE)"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": null,
670 | "metadata": {
671 | "colab": {
672 | "base_uri": "https://localhost:8080/"
673 | },
674 | "id": "B1SXMOPoc0Se",
675 | "outputId": "ee25b332-c08f-4ec4-eb15-1d59a4e896b2"
676 | },
677 | "outputs": [],
678 | "source": [
679 | "fit(train_dataset, 10000, train_dataset)"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": null,
685 | "metadata": {
686 | "id": "6H20taNNc0Sf"
687 | },
688 | "outputs": [],
689 | "source": [
690 | "generator.save_weights('gen-depth-weights.h5')"
691 | ]
692 | },
693 | {
694 | "cell_type": "code",
695 | "execution_count": null,
696 | "metadata": {
697 | "colab": {
698 | "base_uri": "https://localhost:8080/",
699 | "height": 1000
700 | },
701 | "id": "9mSLHL9Ac0Sf",
702 | "outputId": "44e8b2a6-eec6-4041-c7f9-87da233911ba"
703 | },
704 | "outputs": [],
705 | "source": [
706 | "for example_input, example_target in train_dataset.take(54):\n",
707 | " generate_images(generator, example_input, example_target)"
708 | ]
709 | }
710 | ],
711 | "metadata": {
712 | "accelerator": "GPU",
713 | "colab": {
714 | "name": "image2image_depth-res.ipynb",
715 | "provenance": []
716 | },
717 | "kernelspec": {
718 | "display_name": "Python 3",
719 | "language": "python",
720 | "name": "python3"
721 | },
722 | "language_info": {
723 | "codemirror_mode": {
724 | "name": "ipython",
725 | "version": 3
726 | },
727 | "file_extension": ".py",
728 | "mimetype": "text/x-python",
729 | "name": "python",
730 | "nbconvert_exporter": "python",
731 | "pygments_lexer": "ipython3",
732 | "version": "3.8.10"
733 | }
734 | },
735 | "nbformat": 4,
736 | "nbformat_minor": 1
737 | }
738 |
--------------------------------------------------------------------------------
/notebooks/depth.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# download the depth dataset from https://dimlrgbd.github.io/"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "colab": {
17 | "base_uri": "https://localhost:8080/"
18 | },
19 | "id": "cHaQkNvlc1ki",
20 | "outputId": "23afd6fb-521c-4b67-d765-22812a8a5bab"
21 | },
22 | "outputs": [],
23 | "source": [
24 | "import os\n",
25 | "import time\n",
26 | "import glob\n",
27 | "import numpy as np\n",
28 | "import tensorflow as tf\n",
29 | "from IPython import display\n",
30 | "from matplotlib import pyplot as plt"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {
37 | "colab": {
38 | "base_uri": "https://localhost:8080/"
39 | },
40 | "id": "qMZaGHOlc0SD",
41 | "outputId": "9f0ad718-d88a-4384-a315-fc8be1109a3e"
42 | },
43 | "outputs": [],
44 | "source": [
45 | "dataset_path = 'depth/dataset/train/LR'"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {
52 | "id": "8Mx-LXgLc0SM"
53 | },
54 | "outputs": [],
55 | "source": [
56 | "input_paths = glob.glob(dataset_path + '/**/color/*.png')\n",
57 | "target_paths = glob.glob(dataset_path + '/**/depth_vi/*.png')"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "id": "4OVZgPjtc0SP"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "BUFFER_SIZE = 400\n",
69 | "EPOCHS = 100\n",
70 | "LAMBDA = 100\n",
71 | "BATCH_SIZE = 8\n",
72 | "IMG_WIDTH = 256\n",
73 | "IMG_HEIGHT = 256\n",
74 | "patch_size = 8\n",
75 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n",
76 | "projection_dim = 64\n",
77 | "embed_dim = 64\n",
78 | "num_heads = 2 \n",
79 | "ff_dim = 32"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {
86 | "id": "g89DtSq_c0SQ"
87 | },
88 | "outputs": [],
89 | "source": [
90 | "real = []\n",
91 | "targets = []"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {
98 | "id": "AvaXbs4lc0SR"
99 | },
100 | "outputs": [],
101 | "source": [
102 | "def load(path):\n",
103 | "\n",
104 | " image_path = path[:-12] + 'c.png'\n",
105 | " image_path = image_path.replace(\"depth_vi\", \"color\")\n",
106 | " depth_path = path[:-12] + 'depth_vi.png'\n",
107 | "\n",
108 | "\n",
109 | " input_image = tf.io.read_file(image_path)\n",
110 | " input_image = tf.image.decode_jpeg(input_image)\n",
111 | " \n",
112 | " target_image = tf.io.read_file(depth_path)\n",
113 | " target_image = tf.image.decode_jpeg(target_image)\n",
114 | " \n",
115 | " input_image = tf.cast(input_image, tf.float32)\n",
116 | " target_image = tf.cast(target_image, tf.float32)\n",
117 | "\n",
118 | "\n",
119 | " return input_image, target_image"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {
126 | "id": "y8_ilo0Fc0SS"
127 | },
128 | "outputs": [],
129 | "source": [
130 | "def resize(input_image, real_image, height, width):\n",
131 | " input_image = tf.image.resize(input_image, [height, width],\n",
132 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
133 | " real_image = tf.image.resize(real_image, [height, width],\n",
134 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
135 | "\n",
136 | " return input_image, real_image"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {
143 | "id": "opCaZjvRc0ST"
144 | },
145 | "outputs": [],
146 | "source": [
147 | "def normalize(input_image, target_image):\n",
148 | " input_image = input_image / 255\n",
149 | " target_image = target_image / 255\n",
150 | "\n",
151 | " return input_image, target_image"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {
158 | "id": "bvQKHD3-c0SU"
159 | },
160 | "outputs": [],
161 | "source": [
162 | "def load_image_train(depth_path):\n",
163 | " input_image, target = load(depth_path)\n",
164 | " input_image, target = resize(input_image, target,\n",
165 | " IMG_HEIGHT, IMG_WIDTH)\n",
166 | " input_image, target = normalize(input_image, target)\n",
167 | "\n",
168 | " return input_image, target"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": null,
174 | "metadata": {
175 | "id": "i48PBiP7c0SU"
176 | },
177 | "outputs": [],
178 | "source": [
179 | "real = []\n",
180 | "targets = []\n",
181 | "\n",
182 | "for i in range(len(target_paths)):\n",
183 | " #inputs, target = load(target_paths[i])\n",
184 | " inputs, target = load_image_train(target_paths[i])\n",
185 | " #inputs, target = normalize(inputs, target)\n",
186 | " real.append(inputs)\n",
187 | " targets.append(target)\n",
188 | "\n",
189 | "real = np.array(real)\n",
190 | "targets = np.array(targets)"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": null,
196 | "metadata": {
197 | "id": "4yfKd4i_c0SV"
198 | },
199 | "outputs": [],
200 | "source": [
201 | "from matplotlib import pyplot as plt\n",
202 | "import numpy as np"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "metadata": {
209 | "colab": {
210 | "base_uri": "https://localhost:8080/",
211 | "height": 286
212 | },
213 | "id": "1VCfus5zc0SV",
214 | "outputId": "c21309a4-f950-47d8-c27d-3520328f70bf"
215 | },
216 | "outputs": [],
217 | "source": [
218 | "plt.imshow(real[23])\n",
219 | "print(real[12].shape)"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {
226 | "colab": {
227 | "base_uri": "https://localhost:8080/",
228 | "height": 286
229 | },
230 | "id": "vtOUxl5oc0SV",
231 | "outputId": "00f762f6-9e0f-46a7-d984-c3a1ec089b4f"
232 | },
233 | "outputs": [],
234 | "source": [
235 | "plt.imshow(targets[23].reshape(256, 256))\n",
236 | "print(targets[1].shape)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {
243 | "id": "hm6uhABxc0SW"
244 | },
245 | "outputs": [],
246 | "source": [
247 | "def downsample(filters, size, apply_batchnorm=True):\n",
248 | " initializer = tf.random_normal_initializer(0., 0.02)\n",
249 | "\n",
250 | " result = tf.keras.Sequential()\n",
251 | " result.add(\n",
252 | " tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n",
253 | " kernel_initializer=initializer, use_bias=False))\n",
254 | "\n",
255 | " if apply_batchnorm:\n",
256 | " result.add(tf.keras.layers.BatchNormalization())\n",
257 | "\n",
258 | " result.add(tf.keras.layers.LeakyReLU())\n",
259 | "\n",
260 | " return result"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {
267 | "id": "VqAxP9ayc0SX"
268 | },
269 | "outputs": [],
270 | "source": [
271 | "class Patches(tf.keras.layers.Layer):\n",
272 | " def __init__(self, patch_size):\n",
273 | " super(Patches, self).__init__()\n",
274 | " self.patch_size = patch_size\n",
275 | "\n",
276 | " def call(self, images):\n",
277 | " batch_size = tf.shape(images)[0]\n",
278 | " patches = tf.image.extract_patches(\n",
279 | " images=images,\n",
280 | " sizes=[1, self.patch_size, self.patch_size, 1],\n",
281 | " strides=[1, self.patch_size, self.patch_size, 1],\n",
282 | " rates=[1, 1, 1, 1],\n",
283 | " padding=\"SAME\",\n",
284 | " )\n",
285 | " patch_dims = patches.shape[-1]\n",
286 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n",
287 | " return patches"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {
294 | "id": "n_mdkF59c0SX"
295 | },
296 | "outputs": [],
297 | "source": [
298 | "class PatchEncoder(tf.keras.layers.Layer):\n",
299 | " def __init__(self, num_patches, projection_dim):\n",
300 | " super(PatchEncoder, self).__init__()\n",
301 | " self.num_patches = num_patches\n",
302 | " self.projection = layers.Dense(units=projection_dim)\n",
303 | " self.position_embedding = layers.Embedding(\n",
304 | " input_dim=num_patches, output_dim=projection_dim\n",
305 | " )\n",
306 | "\n",
307 | " def call(self, patch):\n",
308 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
309 | " encoded = self.projection(patch) + self.position_embedding(positions)\n",
310 | " return encoded"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {
317 | "id": "ZNDm9KXXc0SY"
318 | },
319 | "outputs": [],
320 | "source": [
321 | "class TransformerBlock(tf.keras.layers.Layer):\n",
322 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
323 | " super(TransformerBlock, self).__init__()\n",
324 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n",
325 | " self.ffn = tf.keras.Sequential(\n",
326 | " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
327 | " )\n",
328 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
329 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
330 | " self.dropout1 = layers.Dropout(rate)\n",
331 | " self.dropout2 = layers.Dropout(rate)\n",
332 | "\n",
333 | " def call(self, inputs, training):\n",
334 | " attn_output = self.att(inputs, inputs)\n",
335 | " attn_output = self.dropout1(attn_output, training=training)\n",
336 | " out1 = self.layernorm1(inputs + attn_output)\n",
337 | " ffn_output = self.ffn(out1)\n",
338 | " ffn_output = self.dropout2(ffn_output, training=training)\n",
339 | " return self.layernorm2(out1 + ffn_output)"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {
346 | "id": "RPZFwD5PP4af"
347 | },
348 | "outputs": [],
349 | "source": [
350 | "from tensorflow import Tensor\n",
351 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n",
352 | " Add, AveragePooling2D, Flatten, Dense\n",
353 | "from tensorflow.keras.models import Model\n",
354 | "\n",
355 | "def relu_bn(inputs: Tensor) -> Tensor:\n",
356 | " relu = ReLU()(inputs)\n",
357 | " bn = BatchNormalization()(relu)\n",
358 | " return bn\n",
359 | "\n",
360 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n",
361 | " y = Conv2D(kernel_size=kernel_size,\n",
362 | " strides= (1 if not downsample else 2),\n",
363 | " filters=filters,\n",
364 | " padding=\"same\")(x)\n",
365 | " y = relu_bn(y)\n",
366 | " y = Conv2D(kernel_size=kernel_size,\n",
367 | " strides=1,\n",
368 | " filters=filters,\n",
369 | " padding=\"same\")(y)\n",
370 | "\n",
371 | " if downsample:\n",
372 | " x = Conv2D(kernel_size=1,\n",
373 | " strides=2,\n",
374 | " filters=filters,\n",
375 | " padding=\"same\")(x)\n",
376 | " out = Add()([x, y])\n",
377 | " out = relu_bn(out)\n",
378 | " return out"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": null,
384 | "metadata": {
385 | "id": "lcQVzKBDc0SZ"
386 | },
387 | "outputs": [],
388 | "source": [
389 | "from tensorflow.keras import layers\n",
390 | "\n",
391 | "def Generator():\n",
392 | "\n",
393 | " inputs = layers.Input(shape=(256, 256, 3))\n",
394 | "\n",
395 | " patches = Patches(patch_size)(inputs)\n",
396 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n",
397 | "\n",
398 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n",
399 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
400 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
401 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
402 | "\n",
403 | " x = layers.Reshape((8, 8, 1024))(x)\n",
404 | "\n",
405 | " x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
406 | " x = layers.BatchNormalization()(x)\n",
407 | " x = layers.LeakyReLU()(x)\n",
408 | "\n",
409 | " x = residual_block(x, downsample=False, filters=512)\n",
410 | "\n",
411 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
412 | " x = layers.BatchNormalization()(x)\n",
413 | " x = layers.LeakyReLU()(x)\n",
414 | "\n",
415 | " x = residual_block(x, downsample=False, filters=256)\n",
416 | "\n",
417 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
418 | " x = layers.BatchNormalization()(x)\n",
419 | " x = layers.LeakyReLU()(x)\n",
420 | " \n",
421 | " x = residual_block(x, downsample=False, filters=64)\n",
422 | "\n",
423 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)\n",
424 | " x = layers.BatchNormalization()(x)\n",
425 | " x = layers.LeakyReLU()(x)\n",
426 | "\n",
427 | " x = residual_block(x, downsample=False, filters=32)\n",
428 | "\n",
429 | " x = layers.Conv2D(1, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n",
430 | "\n",
431 | " return tf.keras.Model(inputs=inputs, outputs=x)"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": null,
437 | "metadata": {
438 | "colab": {
439 | "base_uri": "https://localhost:8080/",
440 | "height": 1000
441 | },
442 | "id": "DBHxlKHvc0Sa",
443 | "outputId": "0b70c08f-2c2c-4d01-dd44-e340c0b088c0"
444 | },
445 | "outputs": [],
446 | "source": [
447 | "generator = Generator()\n",
448 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n",
449 | "generator.summary()\n",
450 | "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {
457 | "id": "lw8T5T3Ac0Sd"
458 | },
459 | "outputs": [],
460 | "source": [
461 | "tf.config.run_functions_eagerly(True)"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "metadata": {
468 | "id": "Gl9RqSOHc0Se"
469 | },
470 | "outputs": [],
471 | "source": [
472 | "def generate_images(model, test_input, tar):\n",
473 | " prediction = model(test_input, training=True)\n",
474 | " plt.figure(figsize=(15, 15))\n",
475 | "\n",
476 | " display_list = [test_input[0], np.array(tar[0]).reshape(256, 256), np.array(prediction[0]).reshape(256, 256)]\n",
477 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
478 | "\n",
479 | " for i in range(3):\n",
480 | " plt.subplot(1, 3, i+1)\n",
481 | " plt.title(title[i])\n",
482 | " # getting the pixel values between [0, 1] to plot it.\n",
483 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
484 | " plt.axis('off')\n",
485 | " plt.show()\n",
486 | "\n",
487 | "def generate_batch_images(model, test_input, tar):\n",
488 | " for i in range(len(test_input)):\n",
489 | " prediction = model(test_input, training=True)\n",
490 | " plt.figure(figsize=(15, 15))\n",
491 | "\n",
492 | " display_list = [test_input[i], tar[i], prediction[i]]\n",
493 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
494 | "\n",
495 | " for i in range(3):\n",
496 | " plt.subplot(1, 3, i+1)\n",
497 | " plt.title(title[i])\n",
498 | " # getting the pixel values between [0, 1] to plot it.\n",
499 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
500 | " plt.axis('off')\n",
501 | " plt.show()"
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": null,
507 | "metadata": {
508 | "id": "N2M-Jbjvc0Se"
509 | },
510 | "outputs": [],
511 | "source": [
512 | "@tf.function\n",
513 | "def train_step(input_image, target):\n",
514 | " with tf.device('/device:GPU:0'):\n",
515 | " with tf.GradientTape() as gen_tape:\n",
516 | " gen_output = generator(input_image, training=True)\n",
517 | " gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
518 | " \n",
519 | "\n",
520 | " generator_gradients = gen_tape.gradient(gen_total_loss,\n",
521 | " generator.trainable_variables)\n",
522 | " generator_optimizer.apply_gradients(zip(generator_gradients,\n",
523 | " generator.trainable_variables))"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": null,
529 | "metadata": {
530 | "id": "5wOgyEJmc0Se"
531 | },
532 | "outputs": [],
533 | "source": [
534 | "def fit(train_ds, epochs, test_ds):\n",
535 | " for epoch in range(epochs):\n",
536 | " start = time.time()\n",
537 | "\n",
538 | " display.clear_output(wait=True)\n",
539 | "\n",
540 | " print(\"Epoch: \", epoch)\n",
541 | "\n",
542 | " # Train\n",
543 | " for n, (input_image, target) in train_ds.enumerate():\n",
544 | " print('.', end='')\n",
545 | " if (n+1) % 100 == 0:\n",
546 | " print()\n",
547 | " train_step(input_image, target)\n",
548 | " print()\n",
549 | "\n",
550 | " generator.save_weights(f'depth-weights.h5')"
551 | ]
552 | },
553 | {
554 | "cell_type": "code",
555 | "execution_count": null,
556 | "metadata": {
557 | "id": "z4Kq8t1kc0Se"
558 | },
559 | "outputs": [],
560 | "source": [
561 | "train_dataset = tf.data.Dataset.from_tensor_slices((real, targets))\n",
562 | "\n",
563 | "train_dataset = train_dataset.batch(BATCH_SIZE)"
564 | ]
565 | },
566 | {
567 | "cell_type": "code",
568 | "execution_count": null,
569 | "metadata": {
570 | "colab": {
571 | "base_uri": "https://localhost:8080/"
572 | },
573 | "id": "B1SXMOPoc0Se",
574 | "outputId": "ee25b332-c08f-4ec4-eb15-1d59a4e896b2"
575 | },
576 | "outputs": [],
577 | "source": [
578 | "fit(train_dataset, EPOCHS, train_dataset)"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "metadata": {
585 | "id": "6H20taNNc0Sf"
586 | },
587 | "outputs": [],
588 | "source": [
589 | "generator.save_weights('gen-depth-weights.h5')"
590 | ]
591 | },
592 | {
593 | "cell_type": "code",
594 | "execution_count": null,
595 | "metadata": {
596 | "colab": {
597 | "base_uri": "https://localhost:8080/",
598 | "height": 1000
599 | },
600 | "id": "9mSLHL9Ac0Sf",
601 | "outputId": "44e8b2a6-eec6-4041-c7f9-87da233911ba"
602 | },
603 | "outputs": [],
604 | "source": [
605 | "for example_input, example_target in train_dataset.take(54):\n",
606 | " generate_images(generator, example_input, example_target)"
607 | ]
608 | }
609 | ],
610 | "metadata": {
611 | "accelerator": "GPU",
612 | "colab": {
613 | "name": "image2image_depth-res.ipynb",
614 | "provenance": []
615 | },
616 | "kernelspec": {
617 | "display_name": "Python 3",
618 | "language": "python",
619 | "name": "python3"
620 | },
621 | "language_info": {
622 | "codemirror_mode": {
623 | "name": "ipython",
624 | "version": 3
625 | },
626 | "file_extension": ".py",
627 | "mimetype": "text/x-python",
628 | "name": "python",
629 | "nbconvert_exporter": "python",
630 | "pygments_lexer": "ipython3",
631 | "version": "3.8.10"
632 | }
633 | },
634 | "nbformat": 4,
635 | "nbformat_minor": 1
636 | }
637 |
--------------------------------------------------------------------------------
/notebooks/object-segmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "YfIk2es3hJEd"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 | "\n",
13 | "import os\n",
14 | "import time\n",
15 | "from matplotlib import pyplot as plt\n",
16 | "from IPython import display\n",
17 | "import tensorflow_datasets as tfds"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {
24 | "id": "2CbTEt448b4R"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "BUFFER_SIZE = 400\n",
29 | "EPOCHS = 100\n",
30 | "LAMBDA = 100\n",
31 | "DATASET = 'seg'\n",
32 | "BATCH_SIZE = 32\n",
33 | "IMG_WIDTH = 128\n",
34 | "IMG_HEIGHT = 128\n",
35 | "patch_size = 8\n",
36 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n",
37 | "projection_dim = 64\n",
38 | "embed_dim = 64\n",
39 | "num_heads = 2 \n",
40 | "ff_dim = 32\n",
41 | "\n",
42 | "assert IMG_WIDTH == IMG_HEIGHT, \"image width and image height must have same dims\"\n",
43 | "\n",
44 | "tf.config.run_functions_eagerly(False)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {
51 | "colab": {
52 | "base_uri": "https://localhost:8080/",
53 | "height": 347,
54 | "referenced_widgets": [
55 | "e6d3b16d24cd4468b68af5be44eeaa46",
56 | "8fd55e55c24e48afb223ff8b7422a546",
57 | "954933d3187645c6bd191289c122a6b5",
58 | "1c7b1be085354daa90f7491366bf3f26",
59 | "84132b5022db442183e409973de11d67",
60 | "f60bacd17c604608a7fd80a06a305bdf",
61 | "112fde1fafb042c7abe870a471f69cb4",
62 | "2d886f8a4991438a97adb7f228c0b247",
63 | "2127f1922264401dbca0ac6365d283af",
64 | "5c6312228c08467186d56dba16c4028c",
65 | "642dffa8e9374234bbce51d29e75c7ea",
66 | "4f1f9e0926cd45e4aa4e813b32b1d7db",
67 | "e292c1f1791e4b349a931400f52e980c",
68 | "80d7be7130b8468c8ab3bcb095dc36bf",
69 | "5cf59228cd5844a8abbe62cb897ce431",
70 | "605b204e45bb401e97270eba6eceb351",
71 | "bbffba56f53b4b68bdcba5875e7c2f07",
72 | "74034286026f457892c06d6025286628",
73 | "0e398791e69f4c3c9f3b8f8928ffac94",
74 | "d8118a57a2814d19b2fde27d2452c84e",
75 | "7e80e3f85f4c4937b4e433f5d0cf8651",
76 | "4e313c0fc9a34434ad8828f9f7d51245",
77 | "a625ba1ffcc340c5a4f042be0b4877c3",
78 | "5695c6a28a5e47008227462f5ade5c9b",
79 | "83c9ad4ed8bb4aa9804a23a330da4d8c",
80 | "2d19d5c44a0348768a9ff242aa199119",
81 | "c94bf43a065b44c182d8af0c717e92ca",
82 | "cd600dd278a04f1fabc50fd0e9639fc4",
83 | "231b5248ecd746d0939739f6a42db75e",
84 | "bffbcc4727fd4a2fa5694aea1680af65",
85 | "a7bdbd085e5c43e3b4572f53175f61b6",
86 | "6abe70e78dce42a091dca068f5af4696",
87 | "2c75ae91013c455e8cedafddcd12052f",
88 | "a6aa784cb2a14bc29ed56a7e29857061",
89 | "c11d62582c59442e9cf388a6fbacaad7",
90 | "fc6c6d3b63634d1f9f58cebbd50f4793",
91 | "5979ebdad11f4f6f941b32ba4a509416",
92 | "2638e1b29d744f6b8a147f8bfa6f89d4",
93 | "7e7891071f7c4f5e87d876da643a3045",
94 | "12aae5a01c5a462e999735d848d3e354",
95 | "1612ca06675349d4b76e5378a129eefb",
96 | "81e3fa193c364c0f90b1dc7bca808eb9",
97 | "8aebd4d81c6a47f78c84602fcef1249c",
98 | "aa67ec0f9e06454b99dea3b324bfaeb2",
99 | "bd9ea3399660446897f60580a39588f2",
100 | "69b5d3f908c748509302621147b3517b",
101 | "3e2b935749cb42aeac6507c0f4697295",
102 | "c985393de22c4f97a7219f53642a38e6",
103 | "600a3e8becc84abeadc682bae8db52d5",
104 | "97c62988c2e1454db66b33ab880f08d1",
105 | "ca89e296533745f7824bd9f91006d162",
106 | "fad326679299483f9cfcc800bd4aa549",
107 | "faee640e94db490f8cd61481aa74a92a",
108 | "6ba3d098f19b4919a06ecca5b3763596",
109 | "f5b8dec6698746268a69122340b1f989",
110 | "1689a79e51304012976295a13bac17cf"
111 | ]
112 | },
113 | "id": "Kn-k8kTXuAlv",
114 | "outputId": "6cc83593-137e-429a-e2ef-060e8d07d41a"
115 | },
116 | "outputs": [],
117 | "source": [
118 | "dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {
125 | "id": "aO9ZAGH5K3SY"
126 | },
127 | "outputs": [],
128 | "source": [
129 | "def normalize(input_image, input_mask):\n",
130 | " input_image = tf.cast(input_image, tf.float32) / 255.0\n",
131 | " input_mask -= 1\n",
132 | " return input_image, input_mask"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {
139 | "id": "4OLHMpsQ5aOv"
140 | },
141 | "outputs": [],
142 | "source": [
143 | "@tf.function\n",
144 | "def load_image_train(datapoint):\n",
145 | " input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
146 | " input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
147 | "\n",
148 | " if tf.random.uniform(()) > 0.5:\n",
149 | " input_image = tf.image.flip_left_right(input_image)\n",
150 | " input_mask = tf.image.flip_left_right(input_mask)\n",
151 | "\n",
152 | " input_image, input_mask = normalize(input_image, input_mask)\n",
153 | "\n",
154 | " return input_image, input_mask"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "id": "rwwYQpu9FzDu"
162 | },
163 | "outputs": [],
164 | "source": [
165 | "def load_image_test(datapoint):\n",
166 | " input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
167 | " input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
168 | "\n",
169 | " input_image, input_mask = normalize(input_image, input_mask)\n",
170 | "\n",
171 | " return input_image, input_mask"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {
178 | "colab": {
179 | "base_uri": "https://localhost:8080/"
180 | },
181 | "id": "Yn3IwqhiIszt",
182 | "outputId": "e52589a1-2d3a-42c8-ec89-7a08e27b9538"
183 | },
184 | "outputs": [],
185 | "source": [
186 | "TRAIN_LENGTH = info.splits['train'].num_examples\n",
187 | "STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {
194 | "id": "muhR2cgbLKWW"
195 | },
196 | "outputs": [],
197 | "source": [
198 | "train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)\n",
199 | "test = dataset['test'].map(load_image_test)"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {
206 | "id": "fVQOjcPVLrUc"
207 | },
208 | "outputs": [],
209 | "source": [
210 | "train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()\n",
211 | "train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)\n",
212 | "test_dataset = test.batch(BATCH_SIZE)"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": null,
218 | "metadata": {
219 | "id": "n0OGdi6D92kM"
220 | },
221 | "outputs": [],
222 | "source": [
223 | "def display(display_list):\n",
224 | " plt.figure(figsize=(15, 15))\n",
225 | "\n",
226 | " title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
227 | "\n",
228 | " for i in range(len(display_list)):\n",
229 | " plt.subplot(1, len(display_list), i+1)\n",
230 | " plt.title(title[i])\n",
231 | " plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))\n",
232 | " plt.axis('off')\n",
233 | " plt.show()"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": null,
239 | "metadata": {
240 | "colab": {
241 | "base_uri": "https://localhost:8080/",
242 | "height": 427
243 | },
244 | "id": "tyaP4hLJ8b4W",
245 | "outputId": "ba14a1a1-ecc9-4fe1-f512-10e0793cd921"
246 | },
247 | "outputs": [],
248 | "source": [
249 | "for image, mask in train.take(1):\n",
250 | " sample_image, sample_mask = image, mask\n",
251 | "display([sample_image, sample_mask])"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "metadata": {
258 | "id": "VB3Z6D_zKSru"
259 | },
260 | "outputs": [],
261 | "source": [
262 | "def create_mask(pred_mask):\n",
263 | " pred_mask = tf.argmax(pred_mask, axis=-1)\n",
264 | " pred_mask = pred_mask[..., tf.newaxis]\n",
265 | " return pred_mask[0]"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": null,
271 | "metadata": {
272 | "id": "SQHmYSmk8b4b"
273 | },
274 | "outputs": [],
275 | "source": [
276 | "def show_predictions(dataset=None, num=1):\n",
277 | " if dataset:\n",
278 | " for image, mask in dataset.take(num):\n",
279 | " pred_mask = generator.predict(image)\n",
280 | " display([image[0], mask[0], create_mask(pred_mask)])\n",
281 | " else:\n",
282 | " display([sample_image, sample_mask,\n",
283 | " create_mask(generator.predict(sample_image[tf.newaxis, ...]))])"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "metadata": {
290 | "id": "AWSBM-ckAZZL"
291 | },
292 | "outputs": [],
293 | "source": [
294 | "class Patches(tf.keras.layers.Layer):\n",
295 | " def __init__(self, patch_size):\n",
296 | " super(Patches, self).__init__()\n",
297 | " self.patch_size = patch_size\n",
298 | "\n",
299 | " def call(self, images):\n",
300 | " batch_size = tf.shape(images)[0]\n",
301 | " patches = tf.image.extract_patches(\n",
302 | " images=images,\n",
303 | " sizes=[1, self.patch_size, self.patch_size, 1],\n",
304 | " strides=[1, self.patch_size, self.patch_size, 1],\n",
305 | " rates=[1, 1, 1, 1],\n",
306 | " padding=\"SAME\",\n",
307 | " )\n",
308 | " patch_dims = patches.shape[-1]\n",
309 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n",
310 | " return patches"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {
317 | "id": "mXT2GyxTAZWq"
318 | },
319 | "outputs": [],
320 | "source": [
321 | "class PatchEncoder(tf.keras.layers.Layer):\n",
322 | " def __init__(self, num_patches, projection_dim):\n",
323 | " super(PatchEncoder, self).__init__()\n",
324 | " self.num_patches = num_patches\n",
325 | " self.projection = layers.Dense(units=projection_dim)\n",
326 | " self.position_embedding = layers.Embedding(\n",
327 | " input_dim=num_patches, output_dim=projection_dim\n",
328 | " )\n",
329 | "\n",
330 | " def call(self, patch):\n",
331 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
332 | " encoded = self.projection(patch) + self.position_embedding(positions)\n",
333 | " return encoded"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {
340 | "id": "EsRN0b3qAdWz"
341 | },
342 | "outputs": [],
343 | "source": [
344 | "class TransformerBlock(tf.keras.layers.Layer):\n",
345 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
346 | " super(TransformerBlock, self).__init__()\n",
347 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n",
348 | " self.ffn = tf.keras.Sequential(\n",
349 | " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
350 | " )\n",
351 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
352 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
353 | " self.dropout1 = layers.Dropout(rate)\n",
354 | " self.dropout2 = layers.Dropout(rate)\n",
355 | "\n",
356 | " def call(self, inputs, training):\n",
357 | " attn_output = self.att(inputs, inputs)\n",
358 | " attn_output = self.dropout1(attn_output, training=training)\n",
359 | " out1 = self.layernorm1(inputs + attn_output)\n",
360 | " ffn_output = self.ffn(out1)\n",
361 | " ffn_output = self.dropout2(ffn_output, training=training)\n",
362 | " return self.layernorm2(out1 + ffn_output)"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "execution_count": null,
368 | "metadata": {
369 | "id": "h9GZYWlkAsBn"
370 | },
371 | "outputs": [],
372 | "source": [
373 | "from tensorflow import Tensor\n",
374 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n",
375 | " Add, AveragePooling2D, Flatten, Dense\n",
376 | "from tensorflow.keras.models import Model\n",
377 | "\n",
378 | "def relu_bn(inputs: Tensor) -> Tensor:\n",
379 | " relu = ReLU()(inputs)\n",
380 | " bn = BatchNormalization()(relu)\n",
381 | " return bn\n",
382 | "\n",
383 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n",
384 | " y = Conv2D(kernel_size=kernel_size,\n",
385 | " strides= (1 if not downsample else 2),\n",
386 | " filters=filters,\n",
387 | " padding=\"same\")(x)\n",
388 | " y = relu_bn(y)\n",
389 | " y = Conv2D(kernel_size=kernel_size,\n",
390 | " strides=1,\n",
391 | " filters=filters,\n",
392 | " padding=\"same\")(y)\n",
393 | "\n",
394 | " if downsample:\n",
395 | " x = Conv2D(kernel_size=1,\n",
396 | " strides=2,\n",
397 | " filters=filters,\n",
398 | " padding=\"same\")(x)\n",
399 | " out = Add()([x, y])\n",
400 | " out = relu_bn(out)\n",
401 | " return out"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {
408 | "id": "lFPI4Nu-8b4q"
409 | },
410 | "outputs": [],
411 | "source": [
412 | "from tensorflow.keras import layers\n",
413 | "\n",
414 | "def Generator():\n",
415 | "\n",
416 | " inputs = layers.Input(shape=(128, 128, 3))\n",
417 | "\n",
418 | " patches = Patches(patch_size)(inputs)\n",
419 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n",
420 | "\n",
421 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n",
422 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
423 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
424 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
425 | "\n",
426 | " x = layers.Reshape((8, 8, 256))(x)\n",
427 | "\n",
428 | " x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
429 | " x = layers.BatchNormalization()(x)\n",
430 | " x = layers.LeakyReLU()(x)\n",
431 | "\n",
432 | " x = residual_block(x, downsample=False, filters=512)\n",
433 | "\n",
434 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
435 | " x = layers.BatchNormalization()(x)\n",
436 | " x = layers.LeakyReLU()(x)\n",
437 | "\n",
438 | " x = residual_block(x, downsample=False, filters=256)\n",
439 | "\n",
440 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
441 | " x = layers.BatchNormalization()(x)\n",
442 | " x = layers.LeakyReLU()(x)\n",
443 | " \n",
444 | " x = residual_block(x, downsample=False, filters=64)\n",
445 | "\n",
446 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
447 | " x = layers.BatchNormalization()(x)\n",
448 | " x = layers.LeakyReLU()(x)\n",
449 | "\n",
450 | " x = residual_block(x, downsample=False, filters=32)\n",
451 | "\n",
452 | " x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n",
453 | "\n",
454 | " return tf.keras.Model(inputs=inputs, outputs=x)"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": null,
460 | "metadata": {
461 | "colab": {
462 | "base_uri": "https://localhost:8080/"
463 | },
464 | "id": "dIbRPFzjmV85",
465 | "outputId": "5216d85f-f401-4657-d41e-233f9be51233"
466 | },
467 | "outputs": [],
468 | "source": [
469 | "generator = Generator()\n",
470 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n",
471 | "generator.summary()"
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "execution_count": null,
477 | "metadata": {
478 | "id": "o58eGY46eiPQ"
479 | },
480 | "outputs": [],
481 | "source": [
482 | "class DisplayCallback(tf.keras.callbacks.Callback):\n",
483 | " def on_epoch_end(self, epoch, logs=None):\n",
484 | " clear_output(wait=True)\n",
485 | " show_predictions()\n",
486 | " print ('\\nSample Prediction after epoch {}\\n'.format(epoch+1))"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": null,
492 | "metadata": {
493 | "id": "5nfPDmCNemKf"
494 | },
495 | "outputs": [],
496 | "source": [
497 | "generator.compile(optimizer='adam',\n",
498 | " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
499 | " metrics=['accuracy'])"
500 | ]
501 | },
502 | {
503 | "cell_type": "code",
504 | "execution_count": null,
505 | "metadata": {
506 | "colab": {
507 | "base_uri": "https://localhost:8080/"
508 | },
509 | "id": "LyA03ie2dUAS",
510 | "outputId": "0639a42b-4a8a-4603-9998-7c1409f1c71c"
511 | },
512 | "outputs": [],
513 | "source": [
514 | "EPOCHS = 200\n",
515 | "VAL_SUBSPLITS = 5\n",
516 | "VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n",
517 | "\n",
518 | "model_history = generator.fit(train_dataset, epochs=EPOCHS,\n",
519 | " steps_per_epoch=STEPS_PER_EPOCH,\n",
520 | " validation_steps=VALIDATION_STEPS,\n",
521 | " validation_data=test_dataset)"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "metadata": {
528 | "colab": {
529 | "base_uri": "https://localhost:8080/",
530 | "height": 293
531 | },
532 | "id": "U1N1_obwtdQH",
533 | "outputId": "20004ed9-8789-4c37-962c-629d0bfd9946"
534 | },
535 | "outputs": [],
536 | "source": [
537 | "show_predictions(train_dataset)"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": null,
543 | "metadata": {
544 | "id": "NiTrkKItvZHE"
545 | },
546 | "outputs": [],
547 | "source": [
548 | "generator.save_weights('seg-gen-weights.h5')"
549 | ]
550 | },
551 | {
552 | "cell_type": "code",
553 | "execution_count": null,
554 | "metadata": {},
555 | "outputs": [],
556 | "source": [
557 | "generator.load_weights('weights/seg-gen-weights (5).h5')"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": null,
563 | "metadata": {},
564 | "outputs": [],
565 | "source": [
566 | "for inp, tar in train_dataset:\n",
567 | " break"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": null,
573 | "metadata": {},
574 | "outputs": [],
575 | "source": [
576 | "import numpy as np\n",
577 | "plt.imshow(generator(inp)[0])\n",
578 | "plt.imsave('pred1.png', np.array(create_mask(generator(inp))).astype(np.float32).reshape(128, 128))"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "metadata": {},
585 | "outputs": [],
586 | "source": [
587 | "import numpy as np\n",
588 | "plt.imshow(inp[0])\n",
589 | "plt.imsave('tar1.png', np.array(inp[0]).astype(np.float32).reshape(128, 128, 3))"
590 | ]
591 | }
599 | ],
600 | "metadata": {
601 | "accelerator": "GPU",
602 | "colab": {
603 | "collapsed_sections": [],
604 | "name": "image2image_seg.ipynb",
605 | "provenance": [],
606 | "toc_visible": true
607 | },
608 | "kernelspec": {
609 | "display_name": "Python 3",
610 | "language": "python",
611 | "name": "python3"
612 | },
613 | "language_info": {
614 | "codemirror_mode": {
615 | "name": "ipython",
616 | "version": 3
617 | },
618 | "file_extension": ".py",
619 | "mimetype": "text/x-python",
620 | "name": "python",
621 | "nbconvert_exporter": "python",
622 | "pygments_lexer": "ipython3",
623 | "version": "3.8.10"
624 | },
625 | "widgets": {
626 | "application/vnd.jupyter.widget-state+json": {
627 | "0e398791e69f4c3c9f3b8f8928ffac94": {
628 | "model_module": "@jupyter-widgets/controls",
629 | "model_name": "FloatProgressModel",
630 | "state": {
631 | "_dom_classes": [],
632 | "_model_module": "@jupyter-widgets/controls",
633 | "_model_module_version": "1.5.0",
634 | "_model_name": "FloatProgressModel",
635 | "_view_count": null,
636 | "_view_module": "@jupyter-widgets/controls",
637 | "_view_module_version": "1.5.0",
638 | "_view_name": "ProgressView",
639 | "bar_style": "success",
640 | "description": "Extraction completed...: 100%",
641 | "description_tooltip": null,
642 | "layout": "IPY_MODEL_4e313c0fc9a34434ad8828f9f7d51245",
643 | "max": 1,
644 | "min": 0,
645 | "orientation": "horizontal",
646 | "style": "IPY_MODEL_7e80e3f85f4c4937b4e433f5d0cf8651",
647 | "value": 1
648 | }
649 | },
650 | "112fde1fafb042c7abe870a471f69cb4": {
651 | "model_module": "@jupyter-widgets/controls",
652 | "model_name": "DescriptionStyleModel",
653 | "state": {
654 | "_model_module": "@jupyter-widgets/controls",
655 | "_model_module_version": "1.5.0",
656 | "_model_name": "DescriptionStyleModel",
657 | "_view_count": null,
658 | "_view_module": "@jupyter-widgets/base",
659 | "_view_module_version": "1.2.0",
660 | "_view_name": "StyleView",
661 | "description_width": ""
662 | }
663 | },
664 | "12aae5a01c5a462e999735d848d3e354": {
665 | "model_module": "@jupyter-widgets/base",
666 | "model_name": "LayoutModel",
667 | "state": {
668 | "_model_module": "@jupyter-widgets/base",
669 | "_model_module_version": "1.2.0",
670 | "_model_name": "LayoutModel",
671 | "_view_count": null,
672 | "_view_module": "@jupyter-widgets/base",
673 | "_view_module_version": "1.2.0",
674 | "_view_name": "LayoutView",
675 | "align_content": null,
676 | "align_items": null,
677 | "align_self": null,
678 | "border": null,
679 | "bottom": null,
680 | "display": null,
681 | "flex": null,
682 | "flex_flow": null,
683 | "grid_area": null,
684 | "grid_auto_columns": null,
685 | "grid_auto_flow": null,
686 | "grid_auto_rows": null,
687 | "grid_column": null,
688 | "grid_gap": null,
689 | "grid_row": null,
690 | "grid_template_areas": null,
691 | "grid_template_columns": null,
692 | "grid_template_rows": null,
693 | "height": null,
694 | "justify_content": null,
695 | "justify_items": null,
696 | "left": null,
697 | "margin": null,
698 | "max_height": null,
699 | "max_width": null,
700 | "min_height": null,
701 | "min_width": null,
702 | "object_fit": null,
703 | "object_position": null,
704 | "order": null,
705 | "overflow": null,
706 | "overflow_x": null,
707 | "overflow_y": null,
708 | "padding": null,
709 | "right": null,
710 | "top": null,
711 | "visibility": null,
712 | "width": null
713 | }
714 | },
715 | "1612ca06675349d4b76e5378a129eefb": {
716 | "model_module": "@jupyter-widgets/controls",
717 | "model_name": "HBoxModel",
718 | "state": {
719 | "_dom_classes": [],
720 | "_model_module": "@jupyter-widgets/controls",
721 | "_model_module_version": "1.5.0",
722 | "_model_name": "HBoxModel",
723 | "_view_count": null,
724 | "_view_module": "@jupyter-widgets/controls",
725 | "_view_module_version": "1.5.0",
726 | "_view_name": "HBoxView",
727 | "box_style": "",
728 | "children": [
729 | "IPY_MODEL_8aebd4d81c6a47f78c84602fcef1249c",
730 | "IPY_MODEL_aa67ec0f9e06454b99dea3b324bfaeb2"
731 | ],
732 | "layout": "IPY_MODEL_81e3fa193c364c0f90b1dc7bca808eb9"
733 | }
734 | },
735 | "1689a79e51304012976295a13bac17cf": {
736 | "model_module": "@jupyter-widgets/base",
737 | "model_name": "LayoutModel",
738 | "state": {
739 | "_model_module": "@jupyter-widgets/base",
740 | "_model_module_version": "1.2.0",
741 | "_model_name": "LayoutModel",
742 | "_view_count": null,
743 | "_view_module": "@jupyter-widgets/base",
744 | "_view_module_version": "1.2.0",
745 | "_view_name": "LayoutView",
746 | "align_content": null,
747 | "align_items": null,
748 | "align_self": null,
749 | "border": null,
750 | "bottom": null,
751 | "display": null,
752 | "flex": null,
753 | "flex_flow": null,
754 | "grid_area": null,
755 | "grid_auto_columns": null,
756 | "grid_auto_flow": null,
757 | "grid_auto_rows": null,
758 | "grid_column": null,
759 | "grid_gap": null,
760 | "grid_row": null,
761 | "grid_template_areas": null,
762 | "grid_template_columns": null,
763 | "grid_template_rows": null,
764 | "height": null,
765 | "justify_content": null,
766 | "justify_items": null,
767 | "left": null,
768 | "margin": null,
769 | "max_height": null,
770 | "max_width": null,
771 | "min_height": null,
772 | "min_width": null,
773 | "object_fit": null,
774 | "object_position": null,
775 | "order": null,
776 | "overflow": null,
777 | "overflow_x": null,
778 | "overflow_y": null,
779 | "padding": null,
780 | "right": null,
781 | "top": null,
782 | "visibility": null,
783 | "width": null
784 | }
785 | },
786 | "1c7b1be085354daa90f7491366bf3f26": {
787 | "model_module": "@jupyter-widgets/controls",
788 | "model_name": "HTMLModel",
789 | "state": {
790 | "_dom_classes": [],
791 | "_model_module": "@jupyter-widgets/controls",
792 | "_model_module_version": "1.5.0",
793 | "_model_name": "HTMLModel",
794 | "_view_count": null,
795 | "_view_module": "@jupyter-widgets/controls",
796 | "_view_module_version": "1.5.0",
797 | "_view_name": "HTMLView",
798 | "description": "",
799 | "description_tooltip": null,
800 | "layout": "IPY_MODEL_2d886f8a4991438a97adb7f228c0b247",
801 | "placeholder": "",
802 | "style": "IPY_MODEL_112fde1fafb042c7abe870a471f69cb4",
803 | "value": " 2/2 [00:37<00:00, 18.66s/ url]"
804 | }
805 | },
806 | "2127f1922264401dbca0ac6365d283af": {
807 | "model_module": "@jupyter-widgets/controls",
808 | "model_name": "HBoxModel",
809 | "state": {
810 | "_dom_classes": [],
811 | "_model_module": "@jupyter-widgets/controls",
812 | "_model_module_version": "1.5.0",
813 | "_model_name": "HBoxModel",
814 | "_view_count": null,
815 | "_view_module": "@jupyter-widgets/controls",
816 | "_view_module_version": "1.5.0",
817 | "_view_name": "HBoxView",
818 | "box_style": "",
819 | "children": [
820 | "IPY_MODEL_642dffa8e9374234bbce51d29e75c7ea",
821 | "IPY_MODEL_4f1f9e0926cd45e4aa4e813b32b1d7db"
822 | ],
823 | "layout": "IPY_MODEL_5c6312228c08467186d56dba16c4028c"
824 | }
825 | },
826 | "231b5248ecd746d0939739f6a42db75e": {
827 | "model_module": "@jupyter-widgets/controls",
828 | "model_name": "ProgressStyleModel",
829 | "state": {
830 | "_model_module": "@jupyter-widgets/controls",
831 | "_model_module_version": "1.5.0",
832 | "_model_name": "ProgressStyleModel",
833 | "_view_count": null,
834 | "_view_module": "@jupyter-widgets/base",
835 | "_view_module_version": "1.2.0",
836 | "_view_name": "StyleView",
837 | "bar_color": null,
838 | "description_width": "initial"
839 | }
840 | },
841 | "2638e1b29d744f6b8a147f8bfa6f89d4": {
842 | "model_module": "@jupyter-widgets/base",
843 | "model_name": "LayoutModel",
844 | "state": {
845 | "_model_module": "@jupyter-widgets/base",
846 | "_model_module_version": "1.2.0",
847 | "_model_name": "LayoutModel",
848 | "_view_count": null,
849 | "_view_module": "@jupyter-widgets/base",
850 | "_view_module_version": "1.2.0",
851 | "_view_name": "LayoutView",
852 | "align_content": null,
853 | "align_items": null,
854 | "align_self": null,
855 | "border": null,
856 | "bottom": null,
857 | "display": null,
858 | "flex": null,
859 | "flex_flow": null,
860 | "grid_area": null,
861 | "grid_auto_columns": null,
862 | "grid_auto_flow": null,
863 | "grid_auto_rows": null,
864 | "grid_column": null,
865 | "grid_gap": null,
866 | "grid_row": null,
867 | "grid_template_areas": null,
868 | "grid_template_columns": null,
869 | "grid_template_rows": null,
870 | "height": null,
871 | "justify_content": null,
872 | "justify_items": null,
873 | "left": null,
874 | "margin": null,
875 | "max_height": null,
876 | "max_width": null,
877 | "min_height": null,
878 | "min_width": null,
879 | "object_fit": null,
880 | "object_position": null,
881 | "order": null,
882 | "overflow": null,
883 | "overflow_x": null,
884 | "overflow_y": null,
885 | "padding": null,
886 | "right": null,
887 | "top": null,
888 | "visibility": null,
889 | "width": null
890 | }
891 | },
892 | "2c75ae91013c455e8cedafddcd12052f": {
893 | "model_module": "@jupyter-widgets/controls",
894 | "model_name": "HBoxModel",
895 | "state": {
896 | "_dom_classes": [],
897 | "_model_module": "@jupyter-widgets/controls",
898 | "_model_module_version": "1.5.0",
899 | "_model_name": "HBoxModel",
900 | "_view_count": null,
901 | "_view_module": "@jupyter-widgets/controls",
902 | "_view_module_version": "1.5.0",
903 | "_view_name": "HBoxView",
904 | "box_style": "",
905 | "children": [
906 | "IPY_MODEL_c11d62582c59442e9cf388a6fbacaad7",
907 | "IPY_MODEL_fc6c6d3b63634d1f9f58cebbd50f4793"
908 | ],
909 | "layout": "IPY_MODEL_a6aa784cb2a14bc29ed56a7e29857061"
910 | }
911 | },
912 | "2d19d5c44a0348768a9ff242aa199119": {
913 | "model_module": "@jupyter-widgets/base",
914 | "model_name": "LayoutModel",
915 | "state": {
916 | "_model_module": "@jupyter-widgets/base",
917 | "_model_module_version": "1.2.0",
918 | "_model_name": "LayoutModel",
919 | "_view_count": null,
920 | "_view_module": "@jupyter-widgets/base",
921 | "_view_module_version": "1.2.0",
922 | "_view_name": "LayoutView",
923 | "align_content": null,
924 | "align_items": null,
925 | "align_self": null,
926 | "border": null,
927 | "bottom": null,
928 | "display": null,
929 | "flex": null,
930 | "flex_flow": null,
931 | "grid_area": null,
932 | "grid_auto_columns": null,
933 | "grid_auto_flow": null,
934 | "grid_auto_rows": null,
935 | "grid_column": null,
936 | "grid_gap": null,
937 | "grid_row": null,
938 | "grid_template_areas": null,
939 | "grid_template_columns": null,
940 | "grid_template_rows": null,
941 | "height": null,
942 | "justify_content": null,
943 | "justify_items": null,
944 | "left": null,
945 | "margin": null,
946 | "max_height": null,
947 | "max_width": null,
948 | "min_height": null,
949 | "min_width": null,
950 | "object_fit": null,
951 | "object_position": null,
952 | "order": null,
953 | "overflow": null,
954 | "overflow_x": null,
955 | "overflow_y": null,
956 | "padding": null,
957 | "right": null,
958 | "top": null,
959 | "visibility": null,
960 | "width": null
961 | }
962 | },
963 | "2d886f8a4991438a97adb7f228c0b247": {
964 | "model_module": "@jupyter-widgets/base",
965 | "model_name": "LayoutModel",
966 | "state": {
967 | "_model_module": "@jupyter-widgets/base",
968 | "_model_module_version": "1.2.0",
969 | "_model_name": "LayoutModel",
970 | "_view_count": null,
971 | "_view_module": "@jupyter-widgets/base",
972 | "_view_module_version": "1.2.0",
973 | "_view_name": "LayoutView",
974 | "align_content": null,
975 | "align_items": null,
976 | "align_self": null,
977 | "border": null,
978 | "bottom": null,
979 | "display": null,
980 | "flex": null,
981 | "flex_flow": null,
982 | "grid_area": null,
983 | "grid_auto_columns": null,
984 | "grid_auto_flow": null,
985 | "grid_auto_rows": null,
986 | "grid_column": null,
987 | "grid_gap": null,
988 | "grid_row": null,
989 | "grid_template_areas": null,
990 | "grid_template_columns": null,
991 | "grid_template_rows": null,
992 | "height": null,
993 | "justify_content": null,
994 | "justify_items": null,
995 | "left": null,
996 | "margin": null,
997 | "max_height": null,
998 | "max_width": null,
999 | "min_height": null,
1000 | "min_width": null,
1001 | "object_fit": null,
1002 | "object_position": null,
1003 | "order": null,
1004 | "overflow": null,
1005 | "overflow_x": null,
1006 | "overflow_y": null,
1007 | "padding": null,
1008 | "right": null,
1009 | "top": null,
1010 | "visibility": null,
1011 | "width": null
1012 | }
1013 | },
1014 | "3e2b935749cb42aeac6507c0f4697295": {
1015 | "model_module": "@jupyter-widgets/controls",
1016 | "model_name": "DescriptionStyleModel",
1017 | "state": {
1018 | "_model_module": "@jupyter-widgets/controls",
1019 | "_model_module_version": "1.5.0",
1020 | "_model_name": "DescriptionStyleModel",
1021 | "_view_count": null,
1022 | "_view_module": "@jupyter-widgets/base",
1023 | "_view_module_version": "1.2.0",
1024 | "_view_name": "StyleView",
1025 | "description_width": ""
1026 | }
1027 | },
1028 | "4e313c0fc9a34434ad8828f9f7d51245": {
1029 | "model_module": "@jupyter-widgets/base",
1030 | "model_name": "LayoutModel",
1031 | "state": {
1032 | "_model_module": "@jupyter-widgets/base",
1033 | "_model_module_version": "1.2.0",
1034 | "_model_name": "LayoutModel",
1035 | "_view_count": null,
1036 | "_view_module": "@jupyter-widgets/base",
1037 | "_view_module_version": "1.2.0",
1038 | "_view_name": "LayoutView",
1039 | "align_content": null,
1040 | "align_items": null,
1041 | "align_self": null,
1042 | "border": null,
1043 | "bottom": null,
1044 | "display": null,
1045 | "flex": null,
1046 | "flex_flow": null,
1047 | "grid_area": null,
1048 | "grid_auto_columns": null,
1049 | "grid_auto_flow": null,
1050 | "grid_auto_rows": null,
1051 | "grid_column": null,
1052 | "grid_gap": null,
1053 | "grid_row": null,
1054 | "grid_template_areas": null,
1055 | "grid_template_columns": null,
1056 | "grid_template_rows": null,
1057 | "height": null,
1058 | "justify_content": null,
1059 | "justify_items": null,
1060 | "left": null,
1061 | "margin": null,
1062 | "max_height": null,
1063 | "max_width": null,
1064 | "min_height": null,
1065 | "min_width": null,
1066 | "object_fit": null,
1067 | "object_position": null,
1068 | "order": null,
1069 | "overflow": null,
1070 | "overflow_x": null,
1071 | "overflow_y": null,
1072 | "padding": null,
1073 | "right": null,
1074 | "top": null,
1075 | "visibility": null,
1076 | "width": null
1077 | }
1078 | },
1079 | "4f1f9e0926cd45e4aa4e813b32b1d7db": {
1080 | "model_module": "@jupyter-widgets/controls",
1081 | "model_name": "HTMLModel",
1082 | "state": {
1083 | "_dom_classes": [],
1084 | "_model_module": "@jupyter-widgets/controls",
1085 | "_model_module_version": "1.5.0",
1086 | "_model_name": "HTMLModel",
1087 | "_view_count": null,
1088 | "_view_module": "@jupyter-widgets/controls",
1089 | "_view_module_version": "1.5.0",
1090 | "_view_name": "HTMLView",
1091 | "description": "",
1092 | "description_tooltip": null,
1093 | "layout": "IPY_MODEL_605b204e45bb401e97270eba6eceb351",
1094 | "placeholder": "",
1095 | "style": "IPY_MODEL_5cf59228cd5844a8abbe62cb897ce431",
1096 | "value": " 773/773 [00:37<00:00, 20.74 MiB/s]"
1097 | }
1098 | },
1099 | "5695c6a28a5e47008227462f5ade5c9b": {
1100 | "model_module": "@jupyter-widgets/base",
1101 | "model_name": "LayoutModel",
1102 | "state": {
1103 | "_model_module": "@jupyter-widgets/base",
1104 | "_model_module_version": "1.2.0",
1105 | "_model_name": "LayoutModel",
1106 | "_view_count": null,
1107 | "_view_module": "@jupyter-widgets/base",
1108 | "_view_module_version": "1.2.0",
1109 | "_view_name": "LayoutView",
1110 | "align_content": null,
1111 | "align_items": null,
1112 | "align_self": null,
1113 | "border": null,
1114 | "bottom": null,
1115 | "display": null,
1116 | "flex": null,
1117 | "flex_flow": null,
1118 | "grid_area": null,
1119 | "grid_auto_columns": null,
1120 | "grid_auto_flow": null,
1121 | "grid_auto_rows": null,
1122 | "grid_column": null,
1123 | "grid_gap": null,
1124 | "grid_row": null,
1125 | "grid_template_areas": null,
1126 | "grid_template_columns": null,
1127 | "grid_template_rows": null,
1128 | "height": null,
1129 | "justify_content": null,
1130 | "justify_items": null,
1131 | "left": null,
1132 | "margin": null,
1133 | "max_height": null,
1134 | "max_width": null,
1135 | "min_height": null,
1136 | "min_width": null,
1137 | "object_fit": null,
1138 | "object_position": null,
1139 | "order": null,
1140 | "overflow": null,
1141 | "overflow_x": null,
1142 | "overflow_y": null,
1143 | "padding": null,
1144 | "right": null,
1145 | "top": null,
1146 | "visibility": null,
1147 | "width": null
1148 | }
1149 | },
1150 | "5979ebdad11f4f6f941b32ba4a509416": {
1151 | "model_module": "@jupyter-widgets/controls",
1152 | "model_name": "ProgressStyleModel",
1153 | "state": {
1154 | "_model_module": "@jupyter-widgets/controls",
1155 | "_model_module_version": "1.5.0",
1156 | "_model_name": "ProgressStyleModel",
1157 | "_view_count": null,
1158 | "_view_module": "@jupyter-widgets/base",
1159 | "_view_module_version": "1.2.0",
1160 | "_view_name": "StyleView",
1161 | "bar_color": null,
1162 | "description_width": "initial"
1163 | }
1164 | },
1165 | "5c6312228c08467186d56dba16c4028c": {
1166 | "model_module": "@jupyter-widgets/base",
1167 | "model_name": "LayoutModel",
1168 | "state": {
1169 | "_model_module": "@jupyter-widgets/base",
1170 | "_model_module_version": "1.2.0",
1171 | "_model_name": "LayoutModel",
1172 | "_view_count": null,
1173 | "_view_module": "@jupyter-widgets/base",
1174 | "_view_module_version": "1.2.0",
1175 | "_view_name": "LayoutView",
1176 | "align_content": null,
1177 | "align_items": null,
1178 | "align_self": null,
1179 | "border": null,
1180 | "bottom": null,
1181 | "display": null,
1182 | "flex": null,
1183 | "flex_flow": null,
1184 | "grid_area": null,
1185 | "grid_auto_columns": null,
1186 | "grid_auto_flow": null,
1187 | "grid_auto_rows": null,
1188 | "grid_column": null,
1189 | "grid_gap": null,
1190 | "grid_row": null,
1191 | "grid_template_areas": null,
1192 | "grid_template_columns": null,
1193 | "grid_template_rows": null,
1194 | "height": null,
1195 | "justify_content": null,
1196 | "justify_items": null,
1197 | "left": null,
1198 | "margin": null,
1199 | "max_height": null,
1200 | "max_width": null,
1201 | "min_height": null,
1202 | "min_width": null,
1203 | "object_fit": null,
1204 | "object_position": null,
1205 | "order": null,
1206 | "overflow": null,
1207 | "overflow_x": null,
1208 | "overflow_y": null,
1209 | "padding": null,
1210 | "right": null,
1211 | "top": null,
1212 | "visibility": null,
1213 | "width": null
1214 | }
1215 | },
1216 | "5cf59228cd5844a8abbe62cb897ce431": {
1217 | "model_module": "@jupyter-widgets/controls",
1218 | "model_name": "DescriptionStyleModel",
1219 | "state": {
1220 | "_model_module": "@jupyter-widgets/controls",
1221 | "_model_module_version": "1.5.0",
1222 | "_model_name": "DescriptionStyleModel",
1223 | "_view_count": null,
1224 | "_view_module": "@jupyter-widgets/base",
1225 | "_view_module_version": "1.2.0",
1226 | "_view_name": "StyleView",
1227 | "description_width": ""
1228 | }
1229 | },
1230 | "600a3e8becc84abeadc682bae8db52d5": {
1231 | "model_module": "@jupyter-widgets/controls",
1232 | "model_name": "HBoxModel",
1233 | "state": {
1234 | "_dom_classes": [],
1235 | "_model_module": "@jupyter-widgets/controls",
1236 | "_model_module_version": "1.5.0",
1237 | "_model_name": "HBoxModel",
1238 | "_view_count": null,
1239 | "_view_module": "@jupyter-widgets/controls",
1240 | "_view_module_version": "1.5.0",
1241 | "_view_name": "HBoxView",
1242 | "box_style": "",
1243 | "children": [
1244 | "IPY_MODEL_ca89e296533745f7824bd9f91006d162",
1245 | "IPY_MODEL_fad326679299483f9cfcc800bd4aa549"
1246 | ],
1247 | "layout": "IPY_MODEL_97c62988c2e1454db66b33ab880f08d1"
1248 | }
1249 | },
1250 | "605b204e45bb401e97270eba6eceb351": {
1251 | "model_module": "@jupyter-widgets/base",
1252 | "model_name": "LayoutModel",
1253 | "state": {
1254 | "_model_module": "@jupyter-widgets/base",
1255 | "_model_module_version": "1.2.0",
1256 | "_model_name": "LayoutModel",
1257 | "_view_count": null,
1258 | "_view_module": "@jupyter-widgets/base",
1259 | "_view_module_version": "1.2.0",
1260 | "_view_name": "LayoutView",
1261 | "align_content": null,
1262 | "align_items": null,
1263 | "align_self": null,
1264 | "border": null,
1265 | "bottom": null,
1266 | "display": null,
1267 | "flex": null,
1268 | "flex_flow": null,
1269 | "grid_area": null,
1270 | "grid_auto_columns": null,
1271 | "grid_auto_flow": null,
1272 | "grid_auto_rows": null,
1273 | "grid_column": null,
1274 | "grid_gap": null,
1275 | "grid_row": null,
1276 | "grid_template_areas": null,
1277 | "grid_template_columns": null,
1278 | "grid_template_rows": null,
1279 | "height": null,
1280 | "justify_content": null,
1281 | "justify_items": null,
1282 | "left": null,
1283 | "margin": null,
1284 | "max_height": null,
1285 | "max_width": null,
1286 | "min_height": null,
1287 | "min_width": null,
1288 | "object_fit": null,
1289 | "object_position": null,
1290 | "order": null,
1291 | "overflow": null,
1292 | "overflow_x": null,
1293 | "overflow_y": null,
1294 | "padding": null,
1295 | "right": null,
1296 | "top": null,
1297 | "visibility": null,
1298 | "width": null
1299 | }
1300 | },
1301 | "642dffa8e9374234bbce51d29e75c7ea": {
1302 | "model_module": "@jupyter-widgets/controls",
1303 | "model_name": "FloatProgressModel",
1304 | "state": {
1305 | "_dom_classes": [],
1306 | "_model_module": "@jupyter-widgets/controls",
1307 | "_model_module_version": "1.5.0",
1308 | "_model_name": "FloatProgressModel",
1309 | "_view_count": null,
1310 | "_view_module": "@jupyter-widgets/controls",
1311 | "_view_module_version": "1.5.0",
1312 | "_view_name": "ProgressView",
1313 | "bar_style": "success",
1314 | "description": "Dl Size...: 100%",
1315 | "description_tooltip": null,
1316 | "layout": "IPY_MODEL_80d7be7130b8468c8ab3bcb095dc36bf",
1317 | "max": 1,
1318 | "min": 0,
1319 | "orientation": "horizontal",
1320 | "style": "IPY_MODEL_e292c1f1791e4b349a931400f52e980c",
1321 | "value": 1
1322 | }
1323 | },
1324 | "69b5d3f908c748509302621147b3517b": {
1325 | "model_module": "@jupyter-widgets/base",
1326 | "model_name": "LayoutModel",
1327 | "state": {
1328 | "_model_module": "@jupyter-widgets/base",
1329 | "_model_module_version": "1.2.0",
1330 | "_model_name": "LayoutModel",
1331 | "_view_count": null,
1332 | "_view_module": "@jupyter-widgets/base",
1333 | "_view_module_version": "1.2.0",
1334 | "_view_name": "LayoutView",
1335 | "align_content": null,
1336 | "align_items": null,
1337 | "align_self": null,
1338 | "border": null,
1339 | "bottom": null,
1340 | "display": null,
1341 | "flex": null,
1342 | "flex_flow": null,
1343 | "grid_area": null,
1344 | "grid_auto_columns": null,
1345 | "grid_auto_flow": null,
1346 | "grid_auto_rows": null,
1347 | "grid_column": null,
1348 | "grid_gap": null,
1349 | "grid_row": null,
1350 | "grid_template_areas": null,
1351 | "grid_template_columns": null,
1352 | "grid_template_rows": null,
1353 | "height": null,
1354 | "justify_content": null,
1355 | "justify_items": null,
1356 | "left": null,
1357 | "margin": null,
1358 | "max_height": null,
1359 | "max_width": null,
1360 | "min_height": null,
1361 | "min_width": null,
1362 | "object_fit": null,
1363 | "object_position": null,
1364 | "order": null,
1365 | "overflow": null,
1366 | "overflow_x": null,
1367 | "overflow_y": null,
1368 | "padding": null,
1369 | "right": null,
1370 | "top": null,
1371 | "visibility": null,
1372 | "width": null
1373 | }
1374 | },
1375 | "6abe70e78dce42a091dca068f5af4696": {
1376 | "model_module": "@jupyter-widgets/base",
1377 | "model_name": "LayoutModel",
1378 | "state": {
1379 | "_model_module": "@jupyter-widgets/base",
1380 | "_model_module_version": "1.2.0",
1381 | "_model_name": "LayoutModel",
1382 | "_view_count": null,
1383 | "_view_module": "@jupyter-widgets/base",
1384 | "_view_module_version": "1.2.0",
1385 | "_view_name": "LayoutView",
1386 | "align_content": null,
1387 | "align_items": null,
1388 | "align_self": null,
1389 | "border": null,
1390 | "bottom": null,
1391 | "display": null,
1392 | "flex": null,
1393 | "flex_flow": null,
1394 | "grid_area": null,
1395 | "grid_auto_columns": null,
1396 | "grid_auto_flow": null,
1397 | "grid_auto_rows": null,
1398 | "grid_column": null,
1399 | "grid_gap": null,
1400 | "grid_row": null,
1401 | "grid_template_areas": null,
1402 | "grid_template_columns": null,
1403 | "grid_template_rows": null,
1404 | "height": null,
1405 | "justify_content": null,
1406 | "justify_items": null,
1407 | "left": null,
1408 | "margin": null,
1409 | "max_height": null,
1410 | "max_width": null,
1411 | "min_height": null,
1412 | "min_width": null,
1413 | "object_fit": null,
1414 | "object_position": null,
1415 | "order": null,
1416 | "overflow": null,
1417 | "overflow_x": null,
1418 | "overflow_y": null,
1419 | "padding": null,
1420 | "right": null,
1421 | "top": null,
1422 | "visibility": null,
1423 | "width": null
1424 | }
1425 | },
1426 | "6ba3d098f19b4919a06ecca5b3763596": {
1427 | "model_module": "@jupyter-widgets/base",
1428 | "model_name": "LayoutModel",
1429 | "state": {
1430 | "_model_module": "@jupyter-widgets/base",
1431 | "_model_module_version": "1.2.0",
1432 | "_model_name": "LayoutModel",
1433 | "_view_count": null,
1434 | "_view_module": "@jupyter-widgets/base",
1435 | "_view_module_version": "1.2.0",
1436 | "_view_name": "LayoutView",
1437 | "align_content": null,
1438 | "align_items": null,
1439 | "align_self": null,
1440 | "border": null,
1441 | "bottom": null,
1442 | "display": null,
1443 | "flex": null,
1444 | "flex_flow": null,
1445 | "grid_area": null,
1446 | "grid_auto_columns": null,
1447 | "grid_auto_flow": null,
1448 | "grid_auto_rows": null,
1449 | "grid_column": null,
1450 | "grid_gap": null,
1451 | "grid_row": null,
1452 | "grid_template_areas": null,
1453 | "grid_template_columns": null,
1454 | "grid_template_rows": null,
1455 | "height": null,
1456 | "justify_content": null,
1457 | "justify_items": null,
1458 | "left": null,
1459 | "margin": null,
1460 | "max_height": null,
1461 | "max_width": null,
1462 | "min_height": null,
1463 | "min_width": null,
1464 | "object_fit": null,
1465 | "object_position": null,
1466 | "order": null,
1467 | "overflow": null,
1468 | "overflow_x": null,
1469 | "overflow_y": null,
1470 | "padding": null,
1471 | "right": null,
1472 | "top": null,
1473 | "visibility": null,
1474 | "width": null
1475 | }
1476 | },
1477 | "74034286026f457892c06d6025286628": {
1478 | "model_module": "@jupyter-widgets/base",
1479 | "model_name": "LayoutModel",
1480 | "state": {
1481 | "_model_module": "@jupyter-widgets/base",
1482 | "_model_module_version": "1.2.0",
1483 | "_model_name": "LayoutModel",
1484 | "_view_count": null,
1485 | "_view_module": "@jupyter-widgets/base",
1486 | "_view_module_version": "1.2.0",
1487 | "_view_name": "LayoutView",
1488 | "align_content": null,
1489 | "align_items": null,
1490 | "align_self": null,
1491 | "border": null,
1492 | "bottom": null,
1493 | "display": null,
1494 | "flex": null,
1495 | "flex_flow": null,
1496 | "grid_area": null,
1497 | "grid_auto_columns": null,
1498 | "grid_auto_flow": null,
1499 | "grid_auto_rows": null,
1500 | "grid_column": null,
1501 | "grid_gap": null,
1502 | "grid_row": null,
1503 | "grid_template_areas": null,
1504 | "grid_template_columns": null,
1505 | "grid_template_rows": null,
1506 | "height": null,
1507 | "justify_content": null,
1508 | "justify_items": null,
1509 | "left": null,
1510 | "margin": null,
1511 | "max_height": null,
1512 | "max_width": null,
1513 | "min_height": null,
1514 | "min_width": null,
1515 | "object_fit": null,
1516 | "object_position": null,
1517 | "order": null,
1518 | "overflow": null,
1519 | "overflow_x": null,
1520 | "overflow_y": null,
1521 | "padding": null,
1522 | "right": null,
1523 | "top": null,
1524 | "visibility": null,
1525 | "width": null
1526 | }
1527 | },
1528 | "7e7891071f7c4f5e87d876da643a3045": {
1529 | "model_module": "@jupyter-widgets/controls",
1530 | "model_name": "DescriptionStyleModel",
1531 | "state": {
1532 | "_model_module": "@jupyter-widgets/controls",
1533 | "_model_module_version": "1.5.0",
1534 | "_model_name": "DescriptionStyleModel",
1535 | "_view_count": null,
1536 | "_view_module": "@jupyter-widgets/base",
1537 | "_view_module_version": "1.2.0",
1538 | "_view_name": "StyleView",
1539 | "description_width": ""
1540 | }
1541 | },
1542 | "7e80e3f85f4c4937b4e433f5d0cf8651": {
1543 | "model_module": "@jupyter-widgets/controls",
1544 | "model_name": "ProgressStyleModel",
1545 | "state": {
1546 | "_model_module": "@jupyter-widgets/controls",
1547 | "_model_module_version": "1.5.0",
1548 | "_model_name": "ProgressStyleModel",
1549 | "_view_count": null,
1550 | "_view_module": "@jupyter-widgets/base",
1551 | "_view_module_version": "1.2.0",
1552 | "_view_name": "StyleView",
1553 | "bar_color": null,
1554 | "description_width": "initial"
1555 | }
1556 | },
1557 | "80d7be7130b8468c8ab3bcb095dc36bf": {
1558 | "model_module": "@jupyter-widgets/base",
1559 | "model_name": "LayoutModel",
1560 | "state": {
1561 | "_model_module": "@jupyter-widgets/base",
1562 | "_model_module_version": "1.2.0",
1563 | "_model_name": "LayoutModel",
1564 | "_view_count": null,
1565 | "_view_module": "@jupyter-widgets/base",
1566 | "_view_module_version": "1.2.0",
1567 | "_view_name": "LayoutView",
1568 | "align_content": null,
1569 | "align_items": null,
1570 | "align_self": null,
1571 | "border": null,
1572 | "bottom": null,
1573 | "display": null,
1574 | "flex": null,
1575 | "flex_flow": null,
1576 | "grid_area": null,
1577 | "grid_auto_columns": null,
1578 | "grid_auto_flow": null,
1579 | "grid_auto_rows": null,
1580 | "grid_column": null,
1581 | "grid_gap": null,
1582 | "grid_row": null,
1583 | "grid_template_areas": null,
1584 | "grid_template_columns": null,
1585 | "grid_template_rows": null,
1586 | "height": null,
1587 | "justify_content": null,
1588 | "justify_items": null,
1589 | "left": null,
1590 | "margin": null,
1591 | "max_height": null,
1592 | "max_width": null,
1593 | "min_height": null,
1594 | "min_width": null,
1595 | "object_fit": null,
1596 | "object_position": null,
1597 | "order": null,
1598 | "overflow": null,
1599 | "overflow_x": null,
1600 | "overflow_y": null,
1601 | "padding": null,
1602 | "right": null,
1603 | "top": null,
1604 | "visibility": null,
1605 | "width": null
1606 | }
1607 | },
1608 | "81e3fa193c364c0f90b1dc7bca808eb9": {
1609 | "model_module": "@jupyter-widgets/base",
1610 | "model_name": "LayoutModel",
1611 | "state": {
1612 | "_model_module": "@jupyter-widgets/base",
1613 | "_model_module_version": "1.2.0",
1614 | "_model_name": "LayoutModel",
1615 | "_view_count": null,
1616 | "_view_module": "@jupyter-widgets/base",
1617 | "_view_module_version": "1.2.0",
1618 | "_view_name": "LayoutView",
1619 | "align_content": null,
1620 | "align_items": null,
1621 | "align_self": null,
1622 | "border": null,
1623 | "bottom": null,
1624 | "display": null,
1625 | "flex": null,
1626 | "flex_flow": null,
1627 | "grid_area": null,
1628 | "grid_auto_columns": null,
1629 | "grid_auto_flow": null,
1630 | "grid_auto_rows": null,
1631 | "grid_column": null,
1632 | "grid_gap": null,
1633 | "grid_row": null,
1634 | "grid_template_areas": null,
1635 | "grid_template_columns": null,
1636 | "grid_template_rows": null,
1637 | "height": null,
1638 | "justify_content": null,
1639 | "justify_items": null,
1640 | "left": null,
1641 | "margin": null,
1642 | "max_height": null,
1643 | "max_width": null,
1644 | "min_height": null,
1645 | "min_width": null,
1646 | "object_fit": null,
1647 | "object_position": null,
1648 | "order": null,
1649 | "overflow": null,
1650 | "overflow_x": null,
1651 | "overflow_y": null,
1652 | "padding": null,
1653 | "right": null,
1654 | "top": null,
1655 | "visibility": null,
1656 | "width": null
1657 | }
1658 | },
1659 | "83c9ad4ed8bb4aa9804a23a330da4d8c": {
1660 | "model_module": "@jupyter-widgets/controls",
1661 | "model_name": "HBoxModel",
1662 | "state": {
1663 | "_dom_classes": [],
1664 | "_model_module": "@jupyter-widgets/controls",
1665 | "_model_module_version": "1.5.0",
1666 | "_model_name": "HBoxModel",
1667 | "_view_count": null,
1668 | "_view_module": "@jupyter-widgets/controls",
1669 | "_view_module_version": "1.5.0",
1670 | "_view_name": "HBoxView",
1671 | "box_style": "",
1672 | "children": [
1673 | "IPY_MODEL_c94bf43a065b44c182d8af0c717e92ca",
1674 | "IPY_MODEL_cd600dd278a04f1fabc50fd0e9639fc4"
1675 | ],
1676 | "layout": "IPY_MODEL_2d19d5c44a0348768a9ff242aa199119"
1677 | }
1678 | },
1679 | "84132b5022db442183e409973de11d67": {
1680 | "model_module": "@jupyter-widgets/controls",
1681 | "model_name": "ProgressStyleModel",
1682 | "state": {
1683 | "_model_module": "@jupyter-widgets/controls",
1684 | "_model_module_version": "1.5.0",
1685 | "_model_name": "ProgressStyleModel",
1686 | "_view_count": null,
1687 | "_view_module": "@jupyter-widgets/base",
1688 | "_view_module_version": "1.2.0",
1689 | "_view_name": "StyleView",
1690 | "bar_color": null,
1691 | "description_width": "initial"
1692 | }
1693 | },
1694 | "8aebd4d81c6a47f78c84602fcef1249c": {
1695 | "model_module": "@jupyter-widgets/controls",
1696 | "model_name": "FloatProgressModel",
1697 | "state": {
1698 | "_dom_classes": [],
1699 | "_model_module": "@jupyter-widgets/controls",
1700 | "_model_module_version": "1.5.0",
1701 | "_model_name": "FloatProgressModel",
1702 | "_view_count": null,
1703 | "_view_module": "@jupyter-widgets/controls",
1704 | "_view_module_version": "1.5.0",
1705 | "_view_name": "ProgressView",
1706 | "bar_style": "info",
1707 | "description": "",
1708 | "description_tooltip": null,
1709 | "layout": "IPY_MODEL_69b5d3f908c748509302621147b3517b",
1710 | "max": 1,
1711 | "min": 0,
1712 | "orientation": "horizontal",
1713 | "style": "IPY_MODEL_bd9ea3399660446897f60580a39588f2",
1714 | "value": 1
1715 | }
1716 | },
1717 | "8fd55e55c24e48afb223ff8b7422a546": {
1718 | "model_module": "@jupyter-widgets/base",
1719 | "model_name": "LayoutModel",
1720 | "state": {
1721 | "_model_module": "@jupyter-widgets/base",
1722 | "_model_module_version": "1.2.0",
1723 | "_model_name": "LayoutModel",
1724 | "_view_count": null,
1725 | "_view_module": "@jupyter-widgets/base",
1726 | "_view_module_version": "1.2.0",
1727 | "_view_name": "LayoutView",
1728 | "align_content": null,
1729 | "align_items": null,
1730 | "align_self": null,
1731 | "border": null,
1732 | "bottom": null,
1733 | "display": null,
1734 | "flex": null,
1735 | "flex_flow": null,
1736 | "grid_area": null,
1737 | "grid_auto_columns": null,
1738 | "grid_auto_flow": null,
1739 | "grid_auto_rows": null,
1740 | "grid_column": null,
1741 | "grid_gap": null,
1742 | "grid_row": null,
1743 | "grid_template_areas": null,
1744 | "grid_template_columns": null,
1745 | "grid_template_rows": null,
1746 | "height": null,
1747 | "justify_content": null,
1748 | "justify_items": null,
1749 | "left": null,
1750 | "margin": null,
1751 | "max_height": null,
1752 | "max_width": null,
1753 | "min_height": null,
1754 | "min_width": null,
1755 | "object_fit": null,
1756 | "object_position": null,
1757 | "order": null,
1758 | "overflow": null,
1759 | "overflow_x": null,
1760 | "overflow_y": null,
1761 | "padding": null,
1762 | "right": null,
1763 | "top": null,
1764 | "visibility": null,
1765 | "width": null
1766 | }
1767 | },
1768 | "954933d3187645c6bd191289c122a6b5": {
1769 | "model_module": "@jupyter-widgets/controls",
1770 | "model_name": "FloatProgressModel",
1771 | "state": {
1772 | "_dom_classes": [],
1773 | "_model_module": "@jupyter-widgets/controls",
1774 | "_model_module_version": "1.5.0",
1775 | "_model_name": "FloatProgressModel",
1776 | "_view_count": null,
1777 | "_view_module": "@jupyter-widgets/controls",
1778 | "_view_module_version": "1.5.0",
1779 | "_view_name": "ProgressView",
1780 | "bar_style": "success",
1781 | "description": "Dl Completed...: 100%",
1782 | "description_tooltip": null,
1783 | "layout": "IPY_MODEL_f60bacd17c604608a7fd80a06a305bdf",
1784 | "max": 1,
1785 | "min": 0,
1786 | "orientation": "horizontal",
1787 | "style": "IPY_MODEL_84132b5022db442183e409973de11d67",
1788 | "value": 1
1789 | }
1790 | },
1791 | "97c62988c2e1454db66b33ab880f08d1": {
1792 | "model_module": "@jupyter-widgets/base",
1793 | "model_name": "LayoutModel",
1794 | "state": {
1795 | "_model_module": "@jupyter-widgets/base",
1796 | "_model_module_version": "1.2.0",
1797 | "_model_name": "LayoutModel",
1798 | "_view_count": null,
1799 | "_view_module": "@jupyter-widgets/base",
1800 | "_view_module_version": "1.2.0",
1801 | "_view_name": "LayoutView",
1802 | "align_content": null,
1803 | "align_items": null,
1804 | "align_self": null,
1805 | "border": null,
1806 | "bottom": null,
1807 | "display": null,
1808 | "flex": null,
1809 | "flex_flow": null,
1810 | "grid_area": null,
1811 | "grid_auto_columns": null,
1812 | "grid_auto_flow": null,
1813 | "grid_auto_rows": null,
1814 | "grid_column": null,
1815 | "grid_gap": null,
1816 | "grid_row": null,
1817 | "grid_template_areas": null,
1818 | "grid_template_columns": null,
1819 | "grid_template_rows": null,
1820 | "height": null,
1821 | "justify_content": null,
1822 | "justify_items": null,
1823 | "left": null,
1824 | "margin": null,
1825 | "max_height": null,
1826 | "max_width": null,
1827 | "min_height": null,
1828 | "min_width": null,
1829 | "object_fit": null,
1830 | "object_position": null,
1831 | "order": null,
1832 | "overflow": null,
1833 | "overflow_x": null,
1834 | "overflow_y": null,
1835 | "padding": null,
1836 | "right": null,
1837 | "top": null,
1838 | "visibility": null,
1839 | "width": null
1840 | }
1841 | },
1842 | "a625ba1ffcc340c5a4f042be0b4877c3": {
1843 | "model_module": "@jupyter-widgets/controls",
1844 | "model_name": "DescriptionStyleModel",
1845 | "state": {
1846 | "_model_module": "@jupyter-widgets/controls",
1847 | "_model_module_version": "1.5.0",
1848 | "_model_name": "DescriptionStyleModel",
1849 | "_view_count": null,
1850 | "_view_module": "@jupyter-widgets/base",
1851 | "_view_module_version": "1.2.0",
1852 | "_view_name": "StyleView",
1853 | "description_width": ""
1854 | }
1855 | },
1856 | "a6aa784cb2a14bc29ed56a7e29857061": {
1857 | "model_module": "@jupyter-widgets/base",
1858 | "model_name": "LayoutModel",
1859 | "state": {
1860 | "_model_module": "@jupyter-widgets/base",
1861 | "_model_module_version": "1.2.0",
1862 | "_model_name": "LayoutModel",
1863 | "_view_count": null,
1864 | "_view_module": "@jupyter-widgets/base",
1865 | "_view_module_version": "1.2.0",
1866 | "_view_name": "LayoutView",
1867 | "align_content": null,
1868 | "align_items": null,
1869 | "align_self": null,
1870 | "border": null,
1871 | "bottom": null,
1872 | "display": null,
1873 | "flex": null,
1874 | "flex_flow": null,
1875 | "grid_area": null,
1876 | "grid_auto_columns": null,
1877 | "grid_auto_flow": null,
1878 | "grid_auto_rows": null,
1879 | "grid_column": null,
1880 | "grid_gap": null,
1881 | "grid_row": null,
1882 | "grid_template_areas": null,
1883 | "grid_template_columns": null,
1884 | "grid_template_rows": null,
1885 | "height": null,
1886 | "justify_content": null,
1887 | "justify_items": null,
1888 | "left": null,
1889 | "margin": null,
1890 | "max_height": null,
1891 | "max_width": null,
1892 | "min_height": null,
1893 | "min_width": null,
1894 | "object_fit": null,
1895 | "object_position": null,
1896 | "order": null,
1897 | "overflow": null,
1898 | "overflow_x": null,
1899 | "overflow_y": null,
1900 | "padding": null,
1901 | "right": null,
1902 | "top": null,
1903 | "visibility": null,
1904 | "width": null
1905 | }
1906 | },
1907 | "a7bdbd085e5c43e3b4572f53175f61b6": {
1908 | "model_module": "@jupyter-widgets/controls",
1909 | "model_name": "DescriptionStyleModel",
1910 | "state": {
1911 | "_model_module": "@jupyter-widgets/controls",
1912 | "_model_module_version": "1.5.0",
1913 | "_model_name": "DescriptionStyleModel",
1914 | "_view_count": null,
1915 | "_view_module": "@jupyter-widgets/base",
1916 | "_view_module_version": "1.2.0",
1917 | "_view_name": "StyleView",
1918 | "description_width": ""
1919 | }
1920 | },
1921 | "aa67ec0f9e06454b99dea3b324bfaeb2": {
1922 | "model_module": "@jupyter-widgets/controls",
1923 | "model_name": "HTMLModel",
1924 | "state": {
1925 | "_dom_classes": [],
1926 | "_model_module": "@jupyter-widgets/controls",
1927 | "_model_module_version": "1.5.0",
1928 | "_model_name": "HTMLModel",
1929 | "_view_count": null,
1930 | "_view_module": "@jupyter-widgets/controls",
1931 | "_view_module_version": "1.5.0",
1932 | "_view_name": "HTMLView",
1933 | "description": "",
1934 | "description_tooltip": null,
1935 | "layout": "IPY_MODEL_c985393de22c4f97a7219f53642a38e6",
1936 | "placeholder": "",
1937 | "style": "IPY_MODEL_3e2b935749cb42aeac6507c0f4697295",
1938 | "value": " 3669/0 [00:02<00:00, 1372.29 examples/s]"
1939 | }
1940 | },
1941 | "bbffba56f53b4b68bdcba5875e7c2f07": {
1942 | "model_module": "@jupyter-widgets/controls",
1943 | "model_name": "HBoxModel",
1944 | "state": {
1945 | "_dom_classes": [],
1946 | "_model_module": "@jupyter-widgets/controls",
1947 | "_model_module_version": "1.5.0",
1948 | "_model_name": "HBoxModel",
1949 | "_view_count": null,
1950 | "_view_module": "@jupyter-widgets/controls",
1951 | "_view_module_version": "1.5.0",
1952 | "_view_name": "HBoxView",
1953 | "box_style": "",
1954 | "children": [
1955 | "IPY_MODEL_0e398791e69f4c3c9f3b8f8928ffac94",
1956 | "IPY_MODEL_d8118a57a2814d19b2fde27d2452c84e"
1957 | ],
1958 | "layout": "IPY_MODEL_74034286026f457892c06d6025286628"
1959 | }
1960 | },
1961 | "bd9ea3399660446897f60580a39588f2": {
1962 | "model_module": "@jupyter-widgets/controls",
1963 | "model_name": "ProgressStyleModel",
1964 | "state": {
1965 | "_model_module": "@jupyter-widgets/controls",
1966 | "_model_module_version": "1.5.0",
1967 | "_model_name": "ProgressStyleModel",
1968 | "_view_count": null,
1969 | "_view_module": "@jupyter-widgets/base",
1970 | "_view_module_version": "1.2.0",
1971 | "_view_name": "StyleView",
1972 | "bar_color": null,
1973 | "description_width": "initial"
1974 | }
1975 | },
1976 | "bffbcc4727fd4a2fa5694aea1680af65": {
1977 | "model_module": "@jupyter-widgets/base",
1978 | "model_name": "LayoutModel",
1979 | "state": {
1980 | "_model_module": "@jupyter-widgets/base",
1981 | "_model_module_version": "1.2.0",
1982 | "_model_name": "LayoutModel",
1983 | "_view_count": null,
1984 | "_view_module": "@jupyter-widgets/base",
1985 | "_view_module_version": "1.2.0",
1986 | "_view_name": "LayoutView",
1987 | "align_content": null,
1988 | "align_items": null,
1989 | "align_self": null,
1990 | "border": null,
1991 | "bottom": null,
1992 | "display": null,
1993 | "flex": null,
1994 | "flex_flow": null,
1995 | "grid_area": null,
1996 | "grid_auto_columns": null,
1997 | "grid_auto_flow": null,
1998 | "grid_auto_rows": null,
1999 | "grid_column": null,
2000 | "grid_gap": null,
2001 | "grid_row": null,
2002 | "grid_template_areas": null,
2003 | "grid_template_columns": null,
2004 | "grid_template_rows": null,
2005 | "height": null,
2006 | "justify_content": null,
2007 | "justify_items": null,
2008 | "left": null,
2009 | "margin": null,
2010 | "max_height": null,
2011 | "max_width": null,
2012 | "min_height": null,
2013 | "min_width": null,
2014 | "object_fit": null,
2015 | "object_position": null,
2016 | "order": null,
2017 | "overflow": null,
2018 | "overflow_x": null,
2019 | "overflow_y": null,
2020 | "padding": null,
2021 | "right": null,
2022 | "top": null,
2023 | "visibility": null,
2024 | "width": null
2025 | }
2026 | },
2027 | "c11d62582c59442e9cf388a6fbacaad7": {
2028 | "model_module": "@jupyter-widgets/controls",
2029 | "model_name": "FloatProgressModel",
2030 | "state": {
2031 | "_dom_classes": [],
2032 | "_model_module": "@jupyter-widgets/controls",
2033 | "_model_module_version": "1.5.0",
2034 | "_model_name": "FloatProgressModel",
2035 | "_view_count": null,
2036 | "_view_module": "@jupyter-widgets/controls",
2037 | "_view_module_version": "1.5.0",
2038 | "_view_name": "ProgressView",
2039 | "bar_style": "danger",
2040 | "description": " 94%",
2041 | "description_tooltip": null,
2042 | "layout": "IPY_MODEL_2638e1b29d744f6b8a147f8bfa6f89d4",
2043 | "max": 3680,
2044 | "min": 0,
2045 | "orientation": "horizontal",
2046 | "style": "IPY_MODEL_5979ebdad11f4f6f941b32ba4a509416",
2047 | "value": 3461
2048 | }
2049 | },
2050 | "c94bf43a065b44c182d8af0c717e92ca": {
2051 | "model_module": "@jupyter-widgets/controls",
2052 | "model_name": "FloatProgressModel",
2053 | "state": {
2054 | "_dom_classes": [],
2055 | "_model_module": "@jupyter-widgets/controls",
2056 | "_model_module_version": "1.5.0",
2057 | "_model_name": "FloatProgressModel",
2058 | "_view_count": null,
2059 | "_view_module": "@jupyter-widgets/controls",
2060 | "_view_module_version": "1.5.0",
2061 | "_view_name": "ProgressView",
2062 | "bar_style": "info",
2063 | "description": "",
2064 | "description_tooltip": null,
2065 | "layout": "IPY_MODEL_bffbcc4727fd4a2fa5694aea1680af65",
2066 | "max": 1,
2067 | "min": 0,
2068 | "orientation": "horizontal",
2069 | "style": "IPY_MODEL_231b5248ecd746d0939739f6a42db75e",
2070 | "value": 1
2071 | }
2072 | },
2073 | "c985393de22c4f97a7219f53642a38e6": {
2074 | "model_module": "@jupyter-widgets/base",
2075 | "model_name": "LayoutModel",
2076 | "state": {
2077 | "_model_module": "@jupyter-widgets/base",
2078 | "_model_module_version": "1.2.0",
2079 | "_model_name": "LayoutModel",
2080 | "_view_count": null,
2081 | "_view_module": "@jupyter-widgets/base",
2082 | "_view_module_version": "1.2.0",
2083 | "_view_name": "LayoutView",
2084 | "align_content": null,
2085 | "align_items": null,
2086 | "align_self": null,
2087 | "border": null,
2088 | "bottom": null,
2089 | "display": null,
2090 | "flex": null,
2091 | "flex_flow": null,
2092 | "grid_area": null,
2093 | "grid_auto_columns": null,
2094 | "grid_auto_flow": null,
2095 | "grid_auto_rows": null,
2096 | "grid_column": null,
2097 | "grid_gap": null,
2098 | "grid_row": null,
2099 | "grid_template_areas": null,
2100 | "grid_template_columns": null,
2101 | "grid_template_rows": null,
2102 | "height": null,
2103 | "justify_content": null,
2104 | "justify_items": null,
2105 | "left": null,
2106 | "margin": null,
2107 | "max_height": null,
2108 | "max_width": null,
2109 | "min_height": null,
2110 | "min_width": null,
2111 | "object_fit": null,
2112 | "object_position": null,
2113 | "order": null,
2114 | "overflow": null,
2115 | "overflow_x": null,
2116 | "overflow_y": null,
2117 | "padding": null,
2118 | "right": null,
2119 | "top": null,
2120 | "visibility": null,
2121 | "width": null
2122 | }
2123 | },
2124 | "ca89e296533745f7824bd9f91006d162": {
2125 | "model_module": "@jupyter-widgets/controls",
2126 | "model_name": "FloatProgressModel",
2127 | "state": {
2128 | "_dom_classes": [],
2129 | "_model_module": "@jupyter-widgets/controls",
2130 | "_model_module_version": "1.5.0",
2131 | "_model_name": "FloatProgressModel",
2132 | "_view_count": null,
2133 | "_view_module": "@jupyter-widgets/controls",
2134 | "_view_module_version": "1.5.0",
2135 | "_view_name": "ProgressView",
2136 | "bar_style": "danger",
2137 | "description": " 99%",
2138 | "description_tooltip": null,
2139 | "layout": "IPY_MODEL_6ba3d098f19b4919a06ecca5b3763596",
2140 | "max": 3669,
2141 | "min": 0,
2142 | "orientation": "horizontal",
2143 | "style": "IPY_MODEL_faee640e94db490f8cd61481aa74a92a",
2144 | "value": 3648
2145 | }
2146 | },
2147 | "cd600dd278a04f1fabc50fd0e9639fc4": {
2148 | "model_module": "@jupyter-widgets/controls",
2149 | "model_name": "HTMLModel",
2150 | "state": {
2151 | "_dom_classes": [],
2152 | "_model_module": "@jupyter-widgets/controls",
2153 | "_model_module_version": "1.5.0",
2154 | "_model_name": "HTMLModel",
2155 | "_view_count": null,
2156 | "_view_module": "@jupyter-widgets/controls",
2157 | "_view_module_version": "1.5.0",
2158 | "_view_name": "HTMLView",
2159 | "description": "",
2160 | "description_tooltip": null,
2161 | "layout": "IPY_MODEL_6abe70e78dce42a091dca068f5af4696",
2162 | "placeholder": "",
2163 | "style": "IPY_MODEL_a7bdbd085e5c43e3b4572f53175f61b6",
2164 | "value": " 3680/0 [00:02<00:00, 1421.50 examples/s]"
2165 | }
2166 | },
2167 | "d8118a57a2814d19b2fde27d2452c84e": {
2168 | "model_module": "@jupyter-widgets/controls",
2169 | "model_name": "HTMLModel",
2170 | "state": {
2171 | "_dom_classes": [],
2172 | "_model_module": "@jupyter-widgets/controls",
2173 | "_model_module_version": "1.5.0",
2174 | "_model_name": "HTMLModel",
2175 | "_view_count": null,
2176 | "_view_module": "@jupyter-widgets/controls",
2177 | "_view_module_version": "1.5.0",
2178 | "_view_name": "HTMLView",
2179 | "description": "",
2180 | "description_tooltip": null,
2181 | "layout": "IPY_MODEL_5695c6a28a5e47008227462f5ade5c9b",
2182 | "placeholder": "",
2183 | "style": "IPY_MODEL_a625ba1ffcc340c5a4f042be0b4877c3",
2184 | "value": " 2/2 [00:37<00:00, 18.60s/ file]"
2185 | }
2186 | },
2187 | "e292c1f1791e4b349a931400f52e980c": {
2188 | "model_module": "@jupyter-widgets/controls",
2189 | "model_name": "ProgressStyleModel",
2190 | "state": {
2191 | "_model_module": "@jupyter-widgets/controls",
2192 | "_model_module_version": "1.5.0",
2193 | "_model_name": "ProgressStyleModel",
2194 | "_view_count": null,
2195 | "_view_module": "@jupyter-widgets/base",
2196 | "_view_module_version": "1.2.0",
2197 | "_view_name": "StyleView",
2198 | "bar_color": null,
2199 | "description_width": "initial"
2200 | }
2201 | },
2202 | "e6d3b16d24cd4468b68af5be44eeaa46": {
2203 | "model_module": "@jupyter-widgets/controls",
2204 | "model_name": "HBoxModel",
2205 | "state": {
2206 | "_dom_classes": [],
2207 | "_model_module": "@jupyter-widgets/controls",
2208 | "_model_module_version": "1.5.0",
2209 | "_model_name": "HBoxModel",
2210 | "_view_count": null,
2211 | "_view_module": "@jupyter-widgets/controls",
2212 | "_view_module_version": "1.5.0",
2213 | "_view_name": "HBoxView",
2214 | "box_style": "",
2215 | "children": [
2216 | "IPY_MODEL_954933d3187645c6bd191289c122a6b5",
2217 | "IPY_MODEL_1c7b1be085354daa90f7491366bf3f26"
2218 | ],
2219 | "layout": "IPY_MODEL_8fd55e55c24e48afb223ff8b7422a546"
2220 | }
2221 | },
2222 | "f5b8dec6698746268a69122340b1f989": {
2223 | "model_module": "@jupyter-widgets/controls",
2224 | "model_name": "DescriptionStyleModel",
2225 | "state": {
2226 | "_model_module": "@jupyter-widgets/controls",
2227 | "_model_module_version": "1.5.0",
2228 | "_model_name": "DescriptionStyleModel",
2229 | "_view_count": null,
2230 | "_view_module": "@jupyter-widgets/base",
2231 | "_view_module_version": "1.2.0",
2232 | "_view_name": "StyleView",
2233 | "description_width": ""
2234 | }
2235 | },
2236 | "f60bacd17c604608a7fd80a06a305bdf": {
2237 | "model_module": "@jupyter-widgets/base",
2238 | "model_name": "LayoutModel",
2239 | "state": {
2240 | "_model_module": "@jupyter-widgets/base",
2241 | "_model_module_version": "1.2.0",
2242 | "_model_name": "LayoutModel",
2243 | "_view_count": null,
2244 | "_view_module": "@jupyter-widgets/base",
2245 | "_view_module_version": "1.2.0",
2246 | "_view_name": "LayoutView",
2247 | "align_content": null,
2248 | "align_items": null,
2249 | "align_self": null,
2250 | "border": null,
2251 | "bottom": null,
2252 | "display": null,
2253 | "flex": null,
2254 | "flex_flow": null,
2255 | "grid_area": null,
2256 | "grid_auto_columns": null,
2257 | "grid_auto_flow": null,
2258 | "grid_auto_rows": null,
2259 | "grid_column": null,
2260 | "grid_gap": null,
2261 | "grid_row": null,
2262 | "grid_template_areas": null,
2263 | "grid_template_columns": null,
2264 | "grid_template_rows": null,
2265 | "height": null,
2266 | "justify_content": null,
2267 | "justify_items": null,
2268 | "left": null,
2269 | "margin": null,
2270 | "max_height": null,
2271 | "max_width": null,
2272 | "min_height": null,
2273 | "min_width": null,
2274 | "object_fit": null,
2275 | "object_position": null,
2276 | "order": null,
2277 | "overflow": null,
2278 | "overflow_x": null,
2279 | "overflow_y": null,
2280 | "padding": null,
2281 | "right": null,
2282 | "top": null,
2283 | "visibility": null,
2284 | "width": null
2285 | }
2286 | },
2287 | "fad326679299483f9cfcc800bd4aa549": {
2288 | "model_module": "@jupyter-widgets/controls",
2289 | "model_name": "HTMLModel",
2290 | "state": {
2291 | "_dom_classes": [],
2292 | "_model_module": "@jupyter-widgets/controls",
2293 | "_model_module_version": "1.5.0",
2294 | "_model_name": "HTMLModel",
2295 | "_view_count": null,
2296 | "_view_module": "@jupyter-widgets/controls",
2297 | "_view_module_version": "1.5.0",
2298 | "_view_name": "HTMLView",
2299 | "description": "",
2300 | "description_tooltip": null,
2301 | "layout": "IPY_MODEL_1689a79e51304012976295a13bac17cf",
2302 | "placeholder": "",
2303 | "style": "IPY_MODEL_f5b8dec6698746268a69122340b1f989",
2304 | "value": " 3648/3669 [00:01<00:00, 2429.53 examples/s]"
2305 | }
2306 | },
2307 | "faee640e94db490f8cd61481aa74a92a": {
2308 | "model_module": "@jupyter-widgets/controls",
2309 | "model_name": "ProgressStyleModel",
2310 | "state": {
2311 | "_model_module": "@jupyter-widgets/controls",
2312 | "_model_module_version": "1.5.0",
2313 | "_model_name": "ProgressStyleModel",
2314 | "_view_count": null,
2315 | "_view_module": "@jupyter-widgets/base",
2316 | "_view_module_version": "1.2.0",
2317 | "_view_name": "StyleView",
2318 | "bar_color": null,
2319 | "description_width": "initial"
2320 | }
2321 | },
2322 | "fc6c6d3b63634d1f9f58cebbd50f4793": {
2323 | "model_module": "@jupyter-widgets/controls",
2324 | "model_name": "HTMLModel",
2325 | "state": {
2326 | "_dom_classes": [],
2327 | "_model_module": "@jupyter-widgets/controls",
2328 | "_model_module_version": "1.5.0",
2329 | "_model_name": "HTMLModel",
2330 | "_view_count": null,
2331 | "_view_module": "@jupyter-widgets/controls",
2332 | "_view_module_version": "1.5.0",
2333 | "_view_name": "HTMLView",
2334 | "description": "",
2335 | "description_tooltip": null,
2336 | "layout": "IPY_MODEL_12aae5a01c5a462e999735d848d3e354",
2337 | "placeholder": "",
2338 | "style": "IPY_MODEL_7e7891071f7c4f5e87d876da643a3045",
2339 | "value": " 3461/3680 [00:00<00:00, 3002.87 examples/s]"
2340 | }
2341 | }
2342 | }
2343 | }
2344 | },
2345 | "nbformat": 4,
2346 | "nbformat_minor": 1
2347 | }
2348 |
--------------------------------------------------------------------------------
/notebooks/semantic-segmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "YfIk2es3hJEd"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import tensorflow as tf\n",
12 |     "from tensorflow.keras import layers\n",
13 | "import os\n",
14 | "import time\n",
15 | "\n",
16 | "from matplotlib import pyplot as plt\n",
17 | "from IPython import display"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {
24 | "id": "2CbTEt448b4R"
25 | },
26 | "outputs": [],
27 | "source": [
28 | "BUFFER_SIZE = 400\n",
29 | "EPOCHS = 100\n",
30 | "LAMBDA = 100\n",
31 | "DATASET = 'cityscapes'\n",
32 | "BATCH_SIZE = 8\n",
33 | "IMG_WIDTH = 256\n",
34 | "IMG_HEIGHT = 256\n",
35 | "patch_size = 8\n",
36 | "num_patches = (IMG_HEIGHT // patch_size) ** 2\n",
37 | "projection_dim = 64\n",
38 | "embed_dim = 64\n",
39 | "num_heads = 2 \n",
40 | "ff_dim = 32\n",
41 | "\n",
42 | "assert IMG_WIDTH == IMG_HEIGHT, \"image width and image height must have same dims\"\n"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {
49 | "colab": {
50 | "base_uri": "https://localhost:8080/"
51 | },
52 | "id": "Kn-k8kTXuAlv",
53 | "outputId": "6322b63c-547d-4ae7-d1aa-5c5098e5fe3d"
54 | },
55 | "outputs": [],
56 | "source": [
57 | "_URL = f'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/{DATASET}.tar.gz'\n",
58 | "\n",
59 | "path_to_zip = tf.keras.utils.get_file(f'{DATASET}.tar.gz',\n",
60 | " origin=_URL,\n",
61 | " extract=True)\n",
62 | "\n",
63 | "PATH = os.path.join(os.path.dirname(path_to_zip), f'{DATASET}/')"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "id": "aO9ZAGH5K3SY"
71 | },
72 | "outputs": [],
73 | "source": [
74 | "def load(image_file):\n",
75 | " image = tf.io.read_file(image_file)\n",
76 | " image = tf.image.decode_jpeg(image)\n",
77 | "\n",
78 | " w = tf.shape(image)[1]\n",
79 | "\n",
80 | " w = w // 2\n",
81 | " real_image = image[:, :w, :]\n",
82 | " input_image = image[:, w:, :]\n",
83 | "\n",
84 | " input_image = tf.cast(input_image, tf.float32)\n",
85 | " real_image = tf.cast(real_image, tf.float32)\n",
86 | "\n",
87 | " return input_image, real_image"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {
94 | "colab": {
95 | "base_uri": "https://localhost:8080/",
96 | "height": 538
97 | },
98 | "id": "4OLHMpsQ5aOv",
99 | "outputId": "1242d6f1-c340-47bc-a716-e97a5e82acfd"
100 | },
101 | "outputs": [],
102 | "source": [
103 | "inp, re = load(PATH+'train/100.jpg')\n",
104 |     "# scaling the pixel values to [0, 1] for matplotlib to show the image\n",
105 | "plt.figure()\n",
106 | "plt.imshow(inp/255.0)\n",
107 | "plt.figure()\n",
108 | "plt.imshow(re/255.0)"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {
115 | "id": "rwwYQpu9FzDu"
116 | },
117 | "outputs": [],
118 | "source": [
119 | "def resize(input_image, real_image, height, width):\n",
120 | " input_image = tf.image.resize(input_image, [height, width],\n",
121 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
122 | " real_image = tf.image.resize(real_image, [height, width],\n",
123 | " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
124 | "\n",
125 | " return input_image, real_image"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {
132 | "id": "Yn3IwqhiIszt"
133 | },
134 | "outputs": [],
135 | "source": [
136 | "def random_crop(input_image, real_image):\n",
137 | " stacked_image = tf.stack([input_image, real_image], axis=0)\n",
138 | " cropped_image = tf.image.random_crop(\n",
139 | " stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
140 | "\n",
141 | " return cropped_image[0], cropped_image[1]"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {
148 | "id": "muhR2cgbLKWW"
149 | },
150 | "outputs": [],
151 | "source": [
152 |     "# normalize the images to [-1, 1]; note that the return order is swapped, so the original real_image ends up as the model input and input_image as the training target\n",
153 | "\n",
154 | "def normalize(input_image, real_image):\n",
155 | " input_image = (input_image / 127.5) - 1\n",
156 | " real_image = (real_image / 127.5) - 1\n",
157 | "\n",
158 | " return real_image, input_image"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {
165 | "id": "fVQOjcPVLrUc"
166 | },
167 | "outputs": [],
168 | "source": [
169 | "@tf.function()\n",
170 | "def random_jitter(input_image, real_image):\n",
171 | " # resizing to 286 x 286 x 3\n",
172 | " input_image, real_image = resize(input_image, real_image, 286, 286)\n",
173 | "\n",
174 | " # randomly cropping to 256 x 256 x 3\n",
175 | " input_image, real_image = random_crop(input_image, real_image)\n",
176 | "\n",
177 | " if tf.random.uniform(()) > 0.5:\n",
178 | " # random mirroring\n",
179 | " input_image = tf.image.flip_left_right(input_image)\n",
180 | " real_image = tf.image.flip_left_right(real_image)\n",
181 | "\n",
182 | " return input_image, real_image"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {
189 | "colab": {
190 | "base_uri": "https://localhost:8080/",
191 | "height": 357
192 | },
193 | "id": "n0OGdi6D92kM",
194 | "outputId": "aa3371d3-f764-4e11-affd-3b6640646491"
195 | },
196 | "outputs": [],
197 | "source": [
198 | "plt.figure(figsize=(6, 6))\n",
199 | "for i in range(4):\n",
200 | " rj_inp, rj_re = random_jitter(inp, re)\n",
201 | " plt.subplot(2, 2, i+1)\n",
202 | " plt.imshow(rj_inp/255.0)\n",
203 | " plt.axis('off')\n",
204 | "plt.show()"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "id": "tyaP4hLJ8b4W"
212 | },
213 | "outputs": [],
214 | "source": [
215 | "def load_image_train(image_file):\n",
216 | " input_image, real_image = load(image_file)\n",
217 | " input_image, real_image = random_jitter(input_image, real_image)\n",
218 | " input_image, real_image = normalize(input_image, real_image)\n",
219 | "\n",
220 | " return input_image, real_image"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {
227 | "id": "VB3Z6D_zKSru"
228 | },
229 | "outputs": [],
230 | "source": [
231 | "def load_image_test(image_file):\n",
232 | " input_image, real_image = load(image_file)\n",
233 | " input_image, real_image = resize(input_image, real_image,\n",
234 | " IMG_HEIGHT, IMG_WIDTH)\n",
235 | " input_image, real_image = normalize(input_image, real_image)\n",
236 | "\n",
237 | " return input_image, real_image"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {
244 | "id": "SQHmYSmk8b4b"
245 | },
246 | "outputs": [],
247 | "source": [
248 | "tf.config.run_functions_eagerly(False)\n",
249 | "\n",
250 | "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n",
251 | "train_dataset = train_dataset.map(load_image_train,\n",
252 | " num_parallel_calls=tf.data.AUTOTUNE)\n",
253 | "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
254 | "train_dataset = train_dataset.batch(BATCH_SIZE)"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {
261 | "id": "MS9J0yA58b4g"
262 | },
263 | "outputs": [],
264 | "source": [
265 | "try:\n",
266 | " test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n",
267 | " test_dataset = test_dataset.map(load_image_test)\n",
268 | " test_dataset = test_dataset.batch(BATCH_SIZE)\n",
269 | "except:\n",
270 | " test_dataset = train_dataset"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {
277 | "id": "AWSBM-ckAZZL"
278 | },
279 | "outputs": [],
280 | "source": [
281 | "class Patches(tf.keras.layers.Layer):\n",
282 | " def __init__(self, patch_size):\n",
283 | " super(Patches, self).__init__()\n",
284 | " self.patch_size = patch_size\n",
285 | "\n",
286 | " def call(self, images):\n",
287 | " batch_size = tf.shape(images)[0]\n",
288 | " patches = tf.image.extract_patches(\n",
289 | " images=images,\n",
290 | " sizes=[1, self.patch_size, self.patch_size, 1],\n",
291 | " strides=[1, self.patch_size, self.patch_size, 1],\n",
292 | " rates=[1, 1, 1, 1],\n",
293 | " padding=\"SAME\",\n",
294 | " )\n",
295 | " patch_dims = patches.shape[-1]\n",
296 | " patches = tf.reshape(patches, [batch_size, -1, patch_dims])\n",
297 | " return patches"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {
304 | "id": "mXT2GyxTAZWq"
305 | },
306 | "outputs": [],
307 | "source": [
308 | "class PatchEncoder(tf.keras.layers.Layer):\n",
309 | " def __init__(self, num_patches, projection_dim):\n",
310 | " super(PatchEncoder, self).__init__()\n",
311 | " self.num_patches = num_patches\n",
312 | " self.projection = layers.Dense(units=projection_dim)\n",
313 | " self.position_embedding = layers.Embedding(\n",
314 | " input_dim=num_patches, output_dim=projection_dim\n",
315 | " )\n",
316 | "\n",
317 | " def call(self, patch):\n",
318 | " positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
319 | " encoded = self.projection(patch) + self.position_embedding(positions)\n",
320 | " return encoded"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": null,
326 | "metadata": {
327 | "id": "EsRN0b3qAdWz"
328 | },
329 | "outputs": [],
330 | "source": [
331 | "class TransformerBlock(tf.keras.layers.Layer):\n",
332 | " def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
333 | " super(TransformerBlock, self).__init__()\n",
334 | " self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)\n",
335 | " self.ffn = tf.keras.Sequential(\n",
336 | " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
337 | " )\n",
338 | " self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
339 | " self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
340 | " self.dropout1 = layers.Dropout(rate)\n",
341 | " self.dropout2 = layers.Dropout(rate)\n",
342 | "\n",
343 | " def call(self, inputs, training):\n",
344 | " attn_output = self.att(inputs, inputs)\n",
345 | " attn_output = self.dropout1(attn_output, training=training)\n",
346 | " out1 = self.layernorm1(inputs + attn_output)\n",
347 | " ffn_output = self.ffn(out1)\n",
348 | " ffn_output = self.dropout2(ffn_output, training=training)\n",
349 | " return self.layernorm2(out1 + ffn_output)"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": null,
355 | "metadata": {
356 | "id": "BzdEEA95TzBE"
357 | },
358 | "outputs": [],
359 | "source": [
360 | "from tensorflow import Tensor\n",
361 | "from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,\\\n",
362 | " Add, AveragePooling2D, Flatten, Dense\n",
363 | "from tensorflow.keras.models import Model\n",
364 | "\n",
365 | "def relu_bn(inputs: Tensor) -> Tensor:\n",
366 | " relu = ReLU()(inputs)\n",
367 | " bn = BatchNormalization()(relu)\n",
368 | " return bn\n",
369 | "\n",
370 | "def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:\n",
371 | " y = Conv2D(kernel_size=kernel_size,\n",
372 | " strides= (1 if not downsample else 2),\n",
373 | " filters=filters,\n",
374 | " padding=\"same\")(x)\n",
375 | " y = relu_bn(y)\n",
376 | " y = Conv2D(kernel_size=kernel_size,\n",
377 | " strides=1,\n",
378 | " filters=filters,\n",
379 | " padding=\"same\")(y)\n",
380 | "\n",
381 | " if downsample:\n",
382 | " x = Conv2D(kernel_size=1,\n",
383 | " strides=2,\n",
384 | " filters=filters,\n",
385 | " padding=\"same\")(x)\n",
386 | " out = Add()([x, y])\n",
387 | " out = relu_bn(out)\n",
388 | " return out"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": null,
394 | "metadata": {
395 | "id": "lFPI4Nu-8b4q"
396 | },
397 | "outputs": [],
398 | "source": [
399 | "from tensorflow.keras import layers\n",
400 | "\n",
401 | "def Generator():\n",
402 | "\n",
403 | " inputs = layers.Input(shape=(256, 256, 3))\n",
404 | "\n",
405 | " patches = Patches(patch_size)(inputs)\n",
406 | " encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)\n",
407 | "\n",
408 | " x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)\n",
409 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
410 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
411 | " x = TransformerBlock(64, num_heads, ff_dim)(x)\n",
412 | "\n",
413 | " x = layers.Reshape((8, 8, 1024))(x)\n",
414 | "\n",
415 | " x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
416 | " x = layers.BatchNormalization()(x)\n",
417 | " x = layers.LeakyReLU()(x)\n",
418 | "\n",
419 | " x = residual_block(x, downsample=False, filters=512)\n",
420 | "\n",
421 | " x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
422 | " x = layers.BatchNormalization()(x)\n",
423 | " x = layers.LeakyReLU()(x)\n",
424 | "\n",
425 | " x = residual_block(x, downsample=False, filters=256)\n",
426 | "\n",
427 | " x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)\n",
428 | " x = layers.BatchNormalization()(x)\n",
429 | " x = layers.LeakyReLU()(x)\n",
430 | " \n",
431 | " x = residual_block(x, downsample=False, filters=64)\n",
432 | "\n",
433 | " x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)\n",
434 | " x = layers.BatchNormalization()(x)\n",
435 | " x = layers.LeakyReLU()(x)\n",
436 | "\n",
437 | " x = residual_block(x, downsample=False, filters=32)\n",
438 | "\n",
439 | " x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)\n",
440 | "\n",
441 | " return tf.keras.Model(inputs=inputs, outputs=x)"
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": null,
447 | "metadata": {
448 | "colab": {
449 | "base_uri": "https://localhost:8080/"
450 | },
451 | "id": "dIbRPFzjmV85",
452 | "outputId": "33b0d3d6-6588-4e3f-aee3-18a9aa09e150"
453 | },
454 | "outputs": [],
455 | "source": [
456 | "generator = Generator()\n",
457 | "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)\n",
458 | "generator.summary()"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": null,
464 | "metadata": {
465 | "colab": {
466 | "base_uri": "https://localhost:8080/",
467 | "height": 303
468 | },
469 | "id": "U1N1_obwtdQH",
470 | "outputId": "abf76049-d489-4635-8f9b-512e3935387c"
471 | },
472 | "outputs": [],
473 | "source": [
474 | "gen_output = generator(inp[tf.newaxis, ...], training=False)\n",
475 |     "plt.imshow(gen_output[0] * 0.5 + 0.5)  # map the tanh output from [-1, 1] to [0, 1] for display"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "metadata": {
482 | "id": "lbHFNexF0x6O"
483 | },
484 | "outputs": [],
485 | "source": [
486 | "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
487 | ]
488 | },
489 | {
490 | "cell_type": "code",
491 | "execution_count": null,
492 | "metadata": {
493 | "id": "RmdVsmvhPxyy"
494 | },
495 | "outputs": [],
496 | "source": [
497 | "def generate_images(model, test_input, tar):\n",
498 | " prediction = model(test_input, training=True)\n",
499 | " plt.figure(figsize=(15, 15))\n",
500 | "\n",
501 | " display_list = [test_input[0], tar[0], prediction[0]]\n",
502 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
503 | "\n",
504 | " for i in range(3):\n",
505 | " plt.subplot(1, 3, i+1)\n",
506 | " plt.title(title[i])\n",
507 | " # getting the pixel values between [0, 1] to plot it.\n",
508 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
509 | " plt.axis('off')\n",
510 | " plt.show()\n",
511 | "\n",
512 | "def generate_batch_images(model, test_input, tar):\n",
513 | " for i in range(len(test_input)):\n",
514 | " prediction = model(test_input, training=True)\n",
515 | " plt.figure(figsize=(15, 15))\n",
516 | "\n",
517 | " display_list = [test_input[i], tar[i], prediction[i]]\n",
518 | " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
519 | " \n",
520 | " for i in range(3):\n",
521 | " plt.subplot(1, 3, i+1)\n",
522 | " plt.title(title[i])\n",
523 | " # getting the pixel values between [0, 1] to plot it.\n",
524 | " plt.imshow(display_list[i] * 0.5 + 0.5)\n",
525 | " plt.axis('off')\n",
526 | " plt.show()"
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": null,
532 | "metadata": {
533 | "colab": {
534 | "base_uri": "https://localhost:8080/",
535 | "height": 293
536 | },
537 | "id": "8Fc4NzT-DgEx",
538 | "outputId": "6b5e738a-5851-4c3c-89e9-2defb4b32b88"
539 | },
540 | "outputs": [],
541 | "source": [
542 | "for example_input, example_target in test_dataset.take(1):\n",
543 | " generate_images(generator, example_input, example_target)"
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "execution_count": null,
549 | "metadata": {
550 | "id": "KBKUV2sKXDbY"
551 | },
552 | "outputs": [],
553 | "source": [
554 | "@tf.function\n",
555 | "def train_step(input_image, target, epoch):\n",
556 | " with tf.device('/device:GPU:0'):\n",
557 |     "    with tf.GradientTape() as gen_tape:\n",
558 | " gen_output = generator(input_image, training=True)\n",
559 | "\n",
560 | " gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
561 | " \n",
562 | " generator_gradients = gen_tape.gradient(gen_total_loss,\n",
563 | " generator.trainable_variables)\n",
564 | "\n",
565 | " generator_optimizer.apply_gradients(zip(generator_gradients,\n",
566 | " generator.trainable_variables))"
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": null,
572 | "metadata": {
573 | "id": "2M7LmLtGEMQJ"
574 | },
575 | "outputs": [],
576 | "source": [
577 | "def fit(train_ds, epochs, test_ds):\n",
578 | " for epoch in range(epochs):\n",
579 | " start = time.time()\n",
580 | "\n",
581 | " display.clear_output(wait=True)\n",
582 | "\n",
583 | " for example_input, example_target in test_ds.take(1):\n",
584 | " generate_images(generator, example_input, example_target)\n",
585 | " print(\"Epoch: \", epoch)\n",
586 | "\n",
587 | " # Train\n",
588 | " for n, (input_image, target) in train_ds.enumerate():\n",
589 | " print('.', end='')\n",
590 | " if (n+1) % 100 == 0:\n",
591 | " print()\n",
592 | " train_step(input_image, target, epoch)\n",
593 | " print()\n",
594 | "\n",
595 | " generator.save_weights(f'_{DATASET}-gen-weights.h5')\n",
596 |     "    # no discriminator is used in this notebook, so only the generator weights are saved"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": null,
602 | "metadata": {
603 | "colab": {
604 | "base_uri": "https://localhost:8080/",
605 | "height": 293
606 | },
607 | "id": "a1zZmKmvOH85",
608 | "outputId": "e90cbd9a-0860-4260-f928-c4609ace3d07",
609 | "scrolled": true
610 | },
611 | "outputs": [],
612 | "source": [
613 |     "fit(train_dataset, EPOCHS, test_dataset)"
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": null,
619 | "metadata": {
620 | "colab": {
621 | "base_uri": "https://localhost:8080/",
622 | "height": 1000
623 | },
624 | "id": "KUgSnmy2nqSP",
625 | "outputId": "65667797-7b67-4b07-e9ab-94c5ac5173d0"
626 | },
627 | "outputs": [],
628 | "source": [
629 | "for inp, tar in test_dataset.take(1):\n",
630 | " outs = generator(inp)\n",
631 | " generate_batch_images(generator, inp, tar)"
632 | ]
633 | }
634 | ],
635 | "metadata": {
636 | "accelerator": "GPU",
637 | "colab": {
638 | "collapsed_sections": [],
639 | "name": "image2image_res.ipynb",
640 | "provenance": []
641 | },
642 | "kernelspec": {
643 | "display_name": "Python 3",
644 | "language": "python",
645 | "name": "python3"
646 | },
647 | "language_info": {
648 | "codemirror_mode": {
649 | "name": "ipython",
650 | "version": 3
651 | },
652 | "file_extension": ".py",
653 | "mimetype": "text/x-python",
654 | "name": "python",
655 | "nbconvert_exporter": "python",
656 | "pygments_lexer": "ipython3",
657 | "version": "3.8.5"
658 | }
659 | },
660 | "nbformat": 4,
661 | "nbformat_minor": 1
662 | }
663 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython==8.0.1
2 | matplotlib==3.1.2
3 | opencv_python==4.5.5.62
4 | scikit_image==0.19.1
5 | scipy==1.4.1
6 | tensorflow==2.4.0
7 |
--------------------------------------------------------------------------------
/results/depth_perseption/combine_images (14).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/combine_images (14).jpg
--------------------------------------------------------------------------------
/results/depth_perseption/d1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d1.png
--------------------------------------------------------------------------------
/results/depth_perseption/d2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d2.png
--------------------------------------------------------------------------------
/results/depth_perseption/d3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d3.png
--------------------------------------------------------------------------------
/results/depth_perseption/d4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d4.png
--------------------------------------------------------------------------------
/results/depth_perseption/d5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d5.png
--------------------------------------------------------------------------------
/results/depth_perseption/d6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/depth_perseption/d6.png
--------------------------------------------------------------------------------
/results/object-segmentation/combine_images (15).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/combine_images (15).jpg
--------------------------------------------------------------------------------
/results/object-segmentation/os1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os1.png
--------------------------------------------------------------------------------
/results/object-segmentation/os2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os2.png
--------------------------------------------------------------------------------
/results/object-segmentation/os3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os3.png
--------------------------------------------------------------------------------
/results/object-segmentation/os4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os4.png
--------------------------------------------------------------------------------
/results/object-segmentation/os5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os5.png
--------------------------------------------------------------------------------
/results/object-segmentation/os6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/object-segmentation/os6.png
--------------------------------------------------------------------------------
/results/semantic-segmentation/combine_images (16).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/combine_images (16).jpg
--------------------------------------------------------------------------------
/results/semantic-segmentation/f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f1.png
--------------------------------------------------------------------------------
/results/semantic-segmentation/f2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f2.png
--------------------------------------------------------------------------------
/results/semantic-segmentation/f3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f3.png
--------------------------------------------------------------------------------
/results/semantic-segmentation/f4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f4.png
--------------------------------------------------------------------------------
/results/semantic-segmentation/f5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f5.png
--------------------------------------------------------------------------------
/results/semantic-segmentation/f6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YigitGunduc/tensor-to-image/eceb94026db239403c57144c4806410de372d1a3/results/semantic-segmentation/f6.png
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import utils # local import
4 | import skimage.transform  # import the transform submodule explicitly for scale_images
5 | import numpy as np
6 | from numpy import log
7 | from numpy import std
8 | from numpy import exp
9 | from math import floor
10 | from numpy import mean
11 | from numpy import cov
12 | from numpy import trace
13 | import tensorflow as tf
14 | from numpy import asarray
15 | from model import Generator # local import
16 | from numpy import expand_dims
17 | from numpy import iscomplexobj
18 | from scipy.linalg import sqrtm
19 | from skimage.metrics import structural_similarity as ssim
20 | from tensorflow.keras.applications.inception_v3 import InceptionV3
21 | from tensorflow.keras.applications.inception_v3 import preprocess_input
22 |
23 | EPOCHS = 100
24 | LAMBDA = 100
25 | BATCH_SIZE = 8
26 | IMG_WIDTH = 256
27 | IMG_HEIGHT = 256
28 | BUFFER_SIZE = 400
29 | DATASET = 'cityscapes'
30 |
31 | num_of_samples = 100  # number of batches to draw from the dataset when evaluating the model
32 |
33 | # model params
34 | ff_dim = 32
35 | num_heads = 2
36 | patch_size = 8
37 | embed_dim = 64
38 | projection_dim = 64
39 | input_shape = (IMG_HEIGHT, IMG_WIDTH, 3)
40 | num_patches = (IMG_HEIGHT // patch_size) ** 2
41 |
42 | path_to_weights = sys.argv[1]
43 | device = '/device:GPU:0' if utils.check_cuda() else '/cpu:0'
44 |
45 |
46 | _URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{DATASET}.tar.gz'
47 |
48 | path_to_zip = tf.keras.utils.get_file(f'{DATASET}.tar.gz',
49 | origin=_URL,
50 | extract=True)
51 |
52 | PATH = os.path.join(os.path.dirname(path_to_zip), f'{DATASET}/')
53 |
54 |
55 | def load(image_file):
56 | image = tf.io.read_file(image_file)
57 | image = tf.image.decode_jpeg(image)
58 |
59 | w = tf.shape(image)[1]
60 |
61 | w = w // 2
62 | real_image = image[:, :w, :]
63 | input_image = image[:, w:, :]
64 |
65 | input_image = tf.cast(input_image, tf.float32)
66 | real_image = tf.cast(real_image, tf.float32)
67 |
68 | return input_image, real_image
69 |
70 |
71 | inp, re = load(PATH+'train/100.jpg')
72 |
73 |
74 | def resize(input_image, real_image, height, width):
75 | input_image = tf.image.resize(input_image, [height, width],
76 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
77 | real_image = tf.image.resize(real_image, [height, width],
78 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
79 |
80 | return input_image, real_image
81 |
82 |
83 | def random_crop(input_image, real_image):
84 | stacked_image = tf.stack([input_image, real_image], axis=0)
85 | cropped_image = tf.image.random_crop(
86 | stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
87 |
88 | return cropped_image[0], cropped_image[1]
89 |
90 |
91 | # normalizing the images to [-1, 1]
92 | def normalize(input_image, real_image):
93 | input_image = (input_image / 127.5) - 1
94 | real_image = (real_image / 127.5) - 1
95 |
96 | return real_image, input_image
97 |
98 |
99 | @tf.function()
100 | def random_jitter(input_image, real_image):
101 | # resizing to 286 x 286 x 3
102 | input_image, real_image = resize(input_image, real_image, 286, 286)
103 |
104 | # randomly cropping to 256 x 256 x 3
105 | input_image, real_image = random_crop(input_image, real_image)
106 |
107 | if tf.random.uniform(()) > 0.5:
108 | # random mirroring
109 | input_image = tf.image.flip_left_right(input_image)
110 | real_image = tf.image.flip_left_right(real_image)
111 |
112 | return input_image, real_image
113 |
114 |
115 | def load_image_train(image_file):
116 | input_image, real_image = load(image_file)
117 | input_image, real_image = random_jitter(input_image, real_image)
118 | input_image, real_image = normalize(input_image, real_image)
119 |
120 | return input_image, real_image
121 |
122 |
123 | def load_image_test(image_file):
124 | input_image, real_image = load(image_file)
125 | input_image, real_image = resize(input_image, real_image,
126 | IMG_HEIGHT, IMG_WIDTH)
127 | input_image, real_image = normalize(input_image, real_image)
128 |
129 | return input_image, real_image
130 |
131 |
132 | train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
133 | train_dataset = train_dataset.map(load_image_train,
134 | num_parallel_calls=tf.data.AUTOTUNE)
135 | train_dataset = train_dataset.shuffle(BUFFER_SIZE)
136 | train_dataset = train_dataset.batch(BATCH_SIZE)
137 |
138 |
139 | def generate_samples(model, dataset, device, num_of_samples):
140 | with tf.device(device):
141 | outs = list()
142 | targets = list()
143 |
144 | for n, (input_image, target) in dataset.enumerate():
145 |
146 | target = np.array(target)
147 | targets.append(target)
148 |
149 | input_image = np.array(input_image)
150 | model_out = np.squeeze(np.array(model(input_image, training=False)).reshape((-1, 256, 256, 3)))
151 | outs.append(model_out)
152 |
153 | if (n + 1) % num_of_samples == 0:
154 | break
155 |
156 | return outs, targets
157 |
158 |
159 | def pre_process(outs, targets):
160 | outs = np.array(outs)
161 | targets = np.array(targets)
162 |
163 |     outs = outs.reshape((-1, 256, 256, 3))  # keep channels-last (NHWC) to match how the samples were generated
164 |     targets = targets.reshape((-1, 256, 256, 3))
165 |
166 | outs = outs * 0.5 + 0.5
167 | targets = targets * 0.5 + 0.5
168 |
169 | outs = outs * 255
170 | targets = targets * 255
171 |
172 | return outs, targets
173 |
174 | # assumes images can have any shape, with pixel values in [0, 255]
175 | def calculate_inception_score(images, n_split=10, eps=1E-16):
176 | # load inception v3 model
177 | model = InceptionV3()
178 | # enumerate splits of images/predictions
179 | scores = list()
180 | n_part = floor(images.shape[0] / n_split)
181 | for i in range(n_split):
182 | # retrieve images
183 | ix_start, ix_end = i * n_part, (i+1) * n_part
184 | subset = images[ix_start:ix_end]
185 | # convert from uint8 to float32
186 | subset = subset.astype('float32')
187 | # scale images to the required size
188 | subset = scale_images(subset, (299,299,3))
189 | # pre-process images, scale to [-1,1]
190 | subset = preprocess_input(subset)
191 | # predict p(y|x)
192 | p_yx = model.predict(subset)
193 | # calculate p(y)
194 | p_y = expand_dims(p_yx.mean(axis=0), 0)
195 | # calculate KL divergence using log probabilities
196 | kl_d = p_yx * (log(p_yx + eps) - log(p_y + eps))
197 | # sum over classes
198 | sum_kl_d = kl_d.sum(axis=1)
199 | # average over images
200 | avg_kl_d = mean(sum_kl_d)
201 | # undo the log
202 | is_score = exp(avg_kl_d)
203 | # store
204 | scores.append(is_score)
205 | # average across images
206 | is_avg, is_std = mean(scores), std(scores)
207 | return is_avg, is_std
208 |
209 |
210 | # scale an array of images to a new size
211 | def scale_images(images, new_shape):
212 | images_list = list()
213 | for image in images:
214 | # resize with nearest neighbor interpolation
215 | new_image = skimage.transform.resize(image, new_shape, 0)
216 | # store
217 | images_list.append(new_image)
218 | return asarray(images_list)
219 |
220 |
221 | # calculate frechet inception distance
222 | def calculate_fid(images1, images2):
223 | model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))
224 |
225 | images1 = scale_images(images1, (299,299,3))
226 | images2 = scale_images(images2, (299,299,3))
227 |
228 | # calculate activations
229 | act1 = model.predict(images1)
230 | act2 = model.predict(images2)
231 | # calculate mean and covariance statistics
232 | mu1, sigma1 = act1.mean(axis=0), cov(act1, rowvar=False)
233 | mu2, sigma2 = act2.mean(axis=0), cov(act2, rowvar=False)
234 | # calculate sum squared difference between means
235 | ssdiff = np.sum((mu1 - mu2)**2.0)
236 | # calculate sqrt of product between cov
237 | covmean = sqrtm(sigma1.dot(sigma2))
238 | # check and correct imaginary numbers from sqrt
239 | if iscomplexobj(covmean):
240 | covmean = covmean.real
241 | # calculate score
242 | fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
243 | return fid
244 |
245 |
246 | model = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)
247 | model.load_weights(path_to_weights)
248 |
249 |
250 | # generate and process samples from the model
251 | outs, targets = generate_samples(model, train_dataset, device, num_of_samples)
252 | outs, targets = pre_process(outs, targets)
253 |
254 |
255 | # calculate fid, ssim, inception score
256 | fid_score = calculate_fid(targets, outs)
257 | ssim_score = ssim(targets.reshape(-1, 256, 256, 3), outs.reshape(-1, 256, 256, 3), data_range=targets.max() - targets.min(), multichannel=True)
258 | inception_score = calculate_inception_score(outs)
259 |
260 | print('----------------|-------------')
261 | print(f'ssim score | {ssim_score}')
262 | print(f'FID | {fid_score}')
263 | print(f'Inception score | mean: {inception_score[0]} std: {inception_score[1]}')
264 | print('----------------|-------------')
265 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow import Tensor
3 | from tensorflow.keras import layers
4 | from tensorflow.keras.models import Model
5 | from tensorflow.keras.layers import (Input,
6 | Conv2D,
7 | ReLU,
8 | BatchNormalization,
9 | Add,
10 | AveragePooling2D,
11 | Flatten,
12 | Dense)
13 |
14 |
15 | class Patches(tf.keras.layers.Layer):
16 | def __init__(self, patch_size):
17 | super(Patches, self).__init__()
18 | self.patch_size = patch_size
19 |
20 | def call(self, images):
21 | batch_size = tf.shape(images)[0]
22 | patches = tf.image.extract_patches(
23 | images=images,
24 | sizes=[1, self.patch_size, self.patch_size, 1],
25 | strides=[1, self.patch_size, self.patch_size, 1],
26 | rates=[1, 1, 1, 1],
27 | padding="SAME",
28 | )
29 | patch_dims = patches.shape[-1]
30 | patches = tf.reshape(patches, [batch_size, -1, patch_dims])
31 | return patches
32 |
33 |
34 | class PatchEncoder(tf.keras.layers.Layer):
35 | def __init__(self, num_patches, projection_dim):
36 | super(PatchEncoder, self).__init__()
37 | self.num_patches = num_patches
38 | self.projection = layers.Dense(units=projection_dim)
39 | self.position_embedding = layers.Embedding(
40 | input_dim=num_patches, output_dim=projection_dim
41 | )
42 |
43 | def call(self, patch):
44 | positions = tf.range(start=0, limit=self.num_patches, delta=1)
45 | encoded = self.projection(patch) + self.position_embedding(positions)
46 | return encoded
47 |
48 |
49 | class TransformerBlock(tf.keras.layers.Layer):
50 | def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
51 | super(TransformerBlock, self).__init__()
52 | self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
53 | self.ffn = tf.keras.Sequential(
54 | [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
55 | )
56 | self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
57 | self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
58 | self.dropout1 = layers.Dropout(rate)
59 | self.dropout2 = layers.Dropout(rate)
60 |
61 | def call(self, inputs, training):
62 | attn_output = self.att(inputs, inputs)
63 | attn_output = self.dropout1(attn_output, training=training)
64 | out1 = self.layernorm1(inputs + attn_output)
65 | ffn_output = self.ffn(out1)
66 | ffn_output = self.dropout2(ffn_output, training=training)
67 | return self.layernorm2(out1 + ffn_output)
68 |
69 |
70 | def relu_bn(inputs: Tensor) -> Tensor:
71 | relu = ReLU()(inputs)
72 | bn = BatchNormalization()(relu)
73 | return bn
74 |
75 |
76 | def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
77 | y = Conv2D(kernel_size=kernel_size,
78 | strides= (1 if not downsample else 2),
79 | filters=filters,
80 | padding="same")(x)
81 | y = relu_bn(y)
82 | y = Conv2D(kernel_size=kernel_size,
83 | strides=1,
84 | filters=filters,
85 | padding="same")(y)
86 |
87 | if downsample:
88 | x = Conv2D(kernel_size=1,
89 | strides=2,
90 | filters=filters,
91 | padding="same")(x)
92 | out = Add()([x, y])
93 | out = relu_bn(out)
94 | return out
95 |
96 |
97 | def Generator(input_shape,
98 | patch_size,
99 | num_patches,
100 | projection_dim,
101 | num_heads,
102 | ff_dim):
103 |
104 |     inputs = layers.Input(shape=input_shape)
105 |
106 | patches = Patches(patch_size)(inputs)
107 | encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
108 |
109 | x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)
110 | x = TransformerBlock(64, num_heads, ff_dim)(x)
111 | x = TransformerBlock(64, num_heads, ff_dim)(x)
112 | x = TransformerBlock(64, num_heads, ff_dim)(x)
113 |
114 | x = layers.Reshape((8, 8, 1024))(x)
115 |
116 | x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
117 | x = layers.BatchNormalization()(x)
118 | x = layers.LeakyReLU()(x)
119 |
120 | x = residual_block(x, downsample=False, filters=512)
121 |
122 | x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
123 | x = layers.BatchNormalization()(x)
124 | x = layers.LeakyReLU()(x)
125 |
126 | x = residual_block(x, downsample=False, filters=256)
127 |
128 | x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
129 | x = layers.BatchNormalization()(x)
130 | x = layers.LeakyReLU()(x)
131 |
132 | x = residual_block(x, downsample=False, filters=64)
133 |
134 | x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)
135 | x = layers.BatchNormalization()(x)
136 | x = layers.LeakyReLU()(x)
137 |
138 | x = residual_block(x, downsample=False, filters=32)
139 |
140 | x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)
141 |
142 | return tf.keras.Model(inputs=inputs, outputs=x)
143 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import tensorflow as tf
4 | from IPython import display
5 | from tensorflow import Tensor
6 | from tensorflow.keras import layers
7 | from matplotlib import pyplot as plt
8 | from tensorflow.keras.models import Model
9 |
10 | import utils # local import
11 | from model import Generator # local import
12 |
13 | EPOCHS = 100
14 | LAMBDA = 100
15 | BATCH_SIZE = 8
16 | IMG_WIDTH = 256
17 | IMG_HEIGHT = 256
18 | BUFFER_SIZE = 400
19 | SAVE_PATH = 'weights'
20 | DATASET = 'cityscapes'
21 | ff_dim = 32
22 | num_heads = 2
23 | patch_size = 8
24 | embed_dim = 64
25 | projection_dim = 64
26 | input_shape = (IMG_HEIGHT, IMG_WIDTH, 3)
27 | num_patches = (IMG_HEIGHT // patch_size) ** 2
28 |
29 | if not os.path.exists(SAVE_PATH):
30 | os.makedirs(SAVE_PATH)
31 |
32 | available_datasets = [
33 | 'cityscapes',
34 | 'edges2handbags',
35 | 'edges2shoes',
36 | 'facades',
37 | 'maps',
38 | 'night2day'
39 | ]
40 |
41 | if DATASET not in available_datasets:
42 |     print(f'[ERROR] unknown dataset: {DATASET}')
43 |     print('[INFO] please use one of the following datasets:')
44 | for dataset in available_datasets:
45 | print(f' -> {dataset}')
46 |
47 | exit(1)
48 |
49 | assert IMG_WIDTH == IMG_HEIGHT, 'width and height must have same size'
50 | _URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{DATASET}.tar.gz'
51 | device = '/device:GPU:0' if utils.check_cuda() else '/cpu:0'
52 |
53 | path_to_zip = tf.keras.utils.get_file(f'{DATASET}.tar.gz',
54 | origin=_URL,
55 | extract=True)
56 |
57 | PATH = os.path.join(os.path.dirname(path_to_zip), f'{DATASET}/')
58 |
59 |
60 | def load(image_file):
61 | image = tf.io.read_file(image_file)
62 | image = tf.image.decode_jpeg(image)
63 |
64 | w = tf.shape(image)[1]
65 |
66 | w = w // 2
67 | real_image = image[:, :w, :]
68 | input_image = image[:, w:, :]
69 |
70 | input_image = tf.cast(input_image, tf.float32)
71 | real_image = tf.cast(real_image, tf.float32)
72 |
73 | return input_image, real_image
74 |
75 |
76 | def resize(input_image, real_image, height, width):
77 | input_image = tf.image.resize(input_image, [height, width],
78 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
79 | real_image = tf.image.resize(real_image, [height, width],
80 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
81 |
82 | return input_image, real_image
83 |
84 |
85 | def random_crop(input_image, real_image):
86 | stacked_image = tf.stack([input_image, real_image], axis=0)
87 | cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
88 |
89 | return cropped_image[0], cropped_image[1]
90 |
91 |
92 | # normalize the images to [-1, 1]; the return order is swapped, so the original real_image ends up as the model input and input_image as the training target
93 | def normalize(input_image, real_image):
94 | input_image = (input_image / 127.5) - 1
95 | real_image = (real_image / 127.5) - 1
96 |
97 | return real_image, input_image
98 |
99 |
100 | @tf.function()
101 | def random_jitter(input_image, real_image):
102 | # resizing to 286 x 286 x 3
103 | input_image, real_image = resize(input_image, real_image, 286, 286)
104 |
105 | # randomly cropping to 256 x 256 x 3
106 | input_image, real_image = random_crop(input_image, real_image)
107 |
108 | if tf.random.uniform(()) > 0.5:
109 | # random mirroring
110 | input_image = tf.image.flip_left_right(input_image)
111 | real_image = tf.image.flip_left_right(real_image)
112 |
113 | return input_image, real_image
114 |
115 |
116 | def load_image_train(image_file):
117 | input_image, real_image = load(image_file)
118 | input_image, real_image = random_jitter(input_image, real_image)
119 | input_image, real_image = normalize(input_image, real_image)
120 |
121 | return input_image, real_image
122 |
123 |
124 | def load_image_test(image_file):
125 | input_image, real_image = load(image_file)
126 | input_image, real_image = resize(input_image, real_image,
127 | IMG_HEIGHT, IMG_WIDTH)
128 | input_image, real_image = normalize(input_image, real_image)
129 |
130 | return input_image, real_image
131 |
132 |
133 | tf.config.run_functions_eagerly(False)
134 |
135 | train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
136 | train_dataset = train_dataset.map(load_image_train,
137 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
138 | train_dataset = train_dataset.shuffle(BUFFER_SIZE)
139 | train_dataset = train_dataset.batch(BATCH_SIZE)
140 |
141 | try:
142 | test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
143 | test_dataset = test_dataset.map(load_image_test)
144 | test_dataset = test_dataset.batch(BATCH_SIZE)
145 | except:
146 | test_dataset = train_dataset
147 |
148 |
149 | generator = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)
150 | tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
151 | generator.summary()
152 |
153 | optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
154 |
155 |
156 | def generate_images(model, test_input, tar):
157 | prediction = model(test_input, training=True)
158 | plt.figure(figsize=(15, 15))
159 |
160 | display_list = [test_input[0], tar[0], prediction[0]]
161 | title = ['Input Image', 'Ground Truth', 'Predicted Image']
162 |
163 | for i in range(3):
164 | plt.subplot(1, 3, i+1)
165 | plt.title(title[i])
166 | # getting the pixel values between [0, 1] to plot it.
167 | plt.imshow(display_list[i] * 0.5 + 0.5)
168 | plt.axis('off')
169 | plt.show()
170 |
171 |
172 | def generate_batch_images(model, test_input, tar):
173 | for i in range(len(test_input)):
174 | prediction = model(test_input, training=True)
175 | plt.figure(figsize=(15, 15))
176 |
177 | display_list = [test_input[i], tar[i], prediction[i]]
178 | title = ['Input Image', 'Ground Truth', 'Predicted Image']
179 |
180 | for i in range(3):
181 | plt.subplot(1, 3, i+1)
182 | plt.title(title[i])
183 | # converting the pixel values to [0, 1] to plot it.
184 | plt.imshow(display_list[i] * 0.5 + 0.5)
185 | plt.axis('off')
186 | plt.show()
187 |
188 |
189 | def train_step(input_image, target, epoch):
190 | with tf.device(device):
191 | with tf.GradientTape() as gen_tape:
192 | gen_output = generator(input_image, training=True)
193 |
194 | gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))
195 |
196 | generator_gradients = gen_tape.gradient(gen_total_loss,
197 | generator.trainable_variables)
198 |
199 | optimizer.apply_gradients(zip(generator_gradients,
200 | generator.trainable_variables))
201 |
202 |
203 | def fit(train_ds, epochs, test_ds):
204 | print(f"[INFO] will train on device: {device}")
205 | for epoch in range(epochs):
206 |
207 | if utils.is_notebook():
208 | display.clear_output(wait=True)
209 |
210 | for example_input, example_target in test_ds.take(1):
211 | generate_images(generator, example_input, example_target)
212 |
213 | print(f'Epoch: [{epoch}/{epochs}]')
214 |
215 | # Train
216 | for n, (input_image, target) in train_ds.enumerate():
217 | train_step(input_image, target, epoch)
218 |
219 | generator.save_weights(f'{SAVE_PATH}/tensor2image-{DATASET}-{epoch}-epochs-weights.h5')
220 |
221 |
222 | def test(test_dataset, generator):
223 | '''
224 |     a function to visually inspect the model outputs
225 | '''
226 | if utils.is_notebook():
227 | for inp, tar in test_dataset.take(1):
228 | generate_batch_images(generator, inp, tar)
229 |
230 |
231 | if __name__ == '__main__':
232 | fit(train_dataset, EPOCHS, test_dataset)
233 |
234 | test(test_dataset, generator)
235 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | def display_image(images:list, display=True, save=False, name=None):
2 | import cv2
3 | import numpy as np
4 | import tensorflow as tf
5 | from matplotlib import pyplot as plt
6 |
7 | img1, img2, img3, *_ = images
8 |
9 | img1 = np.array(img1).astype(np.float32)
10 | img2 = np.array(img2).astype(np.float32)
11 | img3 = np.array(img3).astype(np.float32)
12 |
13 | img2 = cv2.cvtColor(img2, cv2.COLOR_GRAY2BGR)
14 | img3 = cv2.cvtColor(img3, cv2.COLOR_GRAY2BGR)
15 | print(img1.shape, img2.shape, img3.shape)
16 |
17 | im_h = cv2.hconcat([img1, img2, img3])
18 |
19 | im_h = tf.nn.relu(im_h).numpy()
20 | im_h = np.clip(im_h, 0, 1)
21 |
22 | print(np.max(im_h))
23 | print(np.min(im_h))
24 |
25 | plt.xticks([])
26 | plt.yticks([])
27 |
28 | if display:
29 | plt.imshow(im_h)
30 |
31 | if save:
32 | if name is not None:
33 | plt.imsave(name, im_h.astype(np.float32))
34 | else:
35 |             raise ValueError('a file name must be provided when save=True')
36 |
37 | return im_h
38 |
39 |
40 | def is_notebook():
41 | try:
42 | shell = get_ipython().__class__.__name__
43 | if shell == 'ZMQInteractiveShell':
44 | return True # Jupyter notebook or qtconsole
45 | elif shell == 'TerminalInteractiveShell':
46 | return False # Terminal running IPython
47 | else:
48 | return False # Other type (?)
49 | except NameError:
50 | return False # Probably standard Python interpreter
51 |
52 |
53 | def check_cuda():
54 | import tensorflow as tf
55 | device_name = tf.test.gpu_device_name()
56 | if device_name != '/device:GPU:0':
57 | return False
58 | return True
59 |
--------------------------------------------------------------------------------