├── .gitignore ├── .ipynb_checkpoints ├── deHaze-checkpoint.ipynb ├── purePytorch-checkpoint.ipynb ├── srcnn-checkpoint.ipynb └── test4-checkpoint.jpg ├── README.md ├── __pycache__ └── srcnn_model.cpython-37.pyc ├── deHaze.ipynb ├── intermediate ├── .ipynb_checkpoints │ └── 350-checkpoint.jpg ├── 0.jpg ├── 100.jpg ├── 1000.jpg ├── 1100.jpg ├── 1200.jpg ├── 1300.jpg ├── 1400.jpg ├── 1500.jpg ├── 1600.jpg ├── 1700.jpg ├── 1800.jpg ├── 1900.jpg ├── 200.jpg ├── 2000.jpg ├── 2100.jpg ├── 2200.jpg ├── 2300.jpg ├── 2400.jpg ├── 2500.jpg ├── 2600.jpg ├── 2700.jpg ├── 2800.jpg ├── 2900.jpg ├── 300.jpg ├── 3000.jpg ├── 400.jpg ├── 500.jpg ├── 600.jpg ├── 700.jpg ├── 800.jpg └── 900.jpg ├── output.png ├── output_m.png ├── purePytorch.ipynb ├── purePytorch.py ├── test1.jpeg ├── test2.png ├── test3.jpg └── test4.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/test4-checkpoint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/.ipynb_checkpoints/test4-checkpoint.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | - Removes Haze from images 4 | - deHaze.ipynb 5 | 6 | ## Example 7 | - ![output](output.png) 8 | -------------------------------------------------------------------------------- /__pycache__/srcnn_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/__pycache__/srcnn_model.cpython-37.pyc -------------------------------------------------------------------------------- 
/intermediate/.ipynb_checkpoints/350-checkpoint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/.ipynb_checkpoints/350-checkpoint.jpg -------------------------------------------------------------------------------- /intermediate/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/0.jpg -------------------------------------------------------------------------------- /intermediate/100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/100.jpg -------------------------------------------------------------------------------- /intermediate/1000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1000.jpg -------------------------------------------------------------------------------- /intermediate/1100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1100.jpg -------------------------------------------------------------------------------- /intermediate/1200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1200.jpg -------------------------------------------------------------------------------- /intermediate/1300.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1300.jpg -------------------------------------------------------------------------------- /intermediate/1400.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1400.jpg -------------------------------------------------------------------------------- /intermediate/1500.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1500.jpg -------------------------------------------------------------------------------- /intermediate/1600.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1600.jpg -------------------------------------------------------------------------------- /intermediate/1700.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1700.jpg -------------------------------------------------------------------------------- /intermediate/1800.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1800.jpg -------------------------------------------------------------------------------- /intermediate/1900.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/1900.jpg 
-------------------------------------------------------------------------------- /intermediate/200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/200.jpg -------------------------------------------------------------------------------- /intermediate/2000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2000.jpg -------------------------------------------------------------------------------- /intermediate/2100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2100.jpg -------------------------------------------------------------------------------- /intermediate/2200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2200.jpg -------------------------------------------------------------------------------- /intermediate/2300.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2300.jpg -------------------------------------------------------------------------------- /intermediate/2400.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2400.jpg -------------------------------------------------------------------------------- /intermediate/2500.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2500.jpg -------------------------------------------------------------------------------- /intermediate/2600.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2600.jpg -------------------------------------------------------------------------------- /intermediate/2700.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2700.jpg -------------------------------------------------------------------------------- /intermediate/2800.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2800.jpg -------------------------------------------------------------------------------- /intermediate/2900.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/2900.jpg -------------------------------------------------------------------------------- /intermediate/300.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/300.jpg -------------------------------------------------------------------------------- /intermediate/3000.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/3000.jpg -------------------------------------------------------------------------------- /intermediate/400.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/400.jpg -------------------------------------------------------------------------------- /intermediate/500.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/500.jpg -------------------------------------------------------------------------------- /intermediate/600.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/600.jpg -------------------------------------------------------------------------------- /intermediate/700.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/700.jpg -------------------------------------------------------------------------------- /intermediate/800.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/800.jpg -------------------------------------------------------------------------------- /intermediate/900.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/intermediate/900.jpg 
-------------------------------------------------------------------------------- /output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/output.png -------------------------------------------------------------------------------- /output_m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fast-AI-Code/deHazer/d53fc952f0eb3804dd2d74954ca724f2c63df53f/output_m.png -------------------------------------------------------------------------------- /purePytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# %%\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "import torch.nn.functional as F\n", 13 | "import torch.optim as optim\n", 14 | "from PIL import Image, ImageDraw, ImageFilter, ImageFont\n", 15 | "import os\n", 16 | "from pathlib import Path\n", 17 | "import mimetypes\n", 18 | "from glob import glob\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import numpy as np\n", 21 | "import itertools\n", 22 | "import logging\n", 23 | "from os.path import splitext\n", 24 | "from os import listdir\n", 25 | "from torch.utils.data import Dataset\n", 26 | "from torchvision.models import resnet34\n", 27 | "from torchvision.transforms import Compose\n", 28 | "\n", 29 | "os.environ[\"TORCH_HOME\"] = \"/media/subhaditya/DATA/COSMO/Datasets-Useful\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# %%\n", 39 | "path = \"/media/subhaditya/DATA/COSMO/Datasets/deHazer\"\n", 40 | "path_hr = path + \"/normal\"\n", 41 | "path_lr = path + \"/hazy\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | 
class BasicDataset(Dataset):
    """Paired image dataset: one file from ``imgs_dir`` plus its same-named
    counterpart from ``masks_dir``.

    Sample ids are the extension-less names of the non-hidden files found in
    ``imgs_dir``. Both files of a pair are loaded as ``<id>.png``, so any
    other extension in ``imgs_dir`` produces an id that cannot be opened.

    Args:
        imgs_dir: directory holding the first image of every pair.
        masks_dir: directory holding the second image of every pair.
        scale: factor in ``(0, 1]``; validated here and forwarded to
            :meth:`preprocess`.
    """

    def __init__(self, imgs_dir, masks_dir, scale=1):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        assert 0 < scale <= 1, "Scale must be between 0 and 1"

        # listdir() returns entries in arbitrary order; sorting makes sample
        # indices (and any index-based train/test split) reproducible.
        self.ids = sorted(
            splitext(file)[0] for file in listdir(imgs_dir) if not file.startswith(".")
        )

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale, size=(128, 128)):
        """Resize an image and convert it to a channel-first array.

        Args:
            pil_img: object exposing ``size`` as ``(width, height)`` and
                ``resize((width, height))``, convertible with ``np.array``.
            scale: factor applied to the original size; it must leave both
                sides at least one pixel long.
            size: ``(width, height)`` to resize to. The default keeps the
                fixed 128x128 output this class has always produced. Pass
                ``None`` to resize to the ``scale``-derived size instead.

        Returns:
            ``numpy.ndarray`` laid out as (channels, height, width). It is
            divided by 255 when its maximum exceeds 1.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, "Scale is too small"
        target = (newW, newH) if size is None else tuple(size)
        pil_img = pil_img.resize(target)

        img_nd = np.array(pil_img)

        # Single-channel images come back 2-D; give them a channel axis.
        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255

        return img_trans

    def __getitem__(self, i):
        """Return ``(image, mask)`` as two float tensors for sample ``i``."""
        idx = self.ids[i]
        mask_file = self.masks_dir + "/" + idx + ".png"
        img_file = self.imgs_dir + "/" + idx + ".png"

        # Image.open keeps the file handle open until the image is closed;
        # the context managers release both handles once the arrays exist.
        with Image.open(mask_file) as mask, Image.open(img_file) as img:
            assert (
                img.size == mask.size
            ), f"Image and mask {idx} should be the same size, but are {img.size} and {mask.size}"

            img_arr = self.preprocess(img, self.scale)
            mask_arr = self.preprocess(mask, self.scale)

        return (
            torch.from_numpy(img_arr).type(torch.FloatTensor),
            torch.from_numpy(mask_arr).type(torch.FloatTensor),
        )
" 185 | ] 186 | }, 187 | "metadata": { 188 | "needs_background": "light" 189 | }, 190 | "output_type": "display_data" 191 | } 192 | ], 193 | "source": [ 194 | "# %%\n", 195 | "plt.imshow(dataset.__getitem__(1)[0].permute(1, 2, 0))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "data": { 205 | "text/plain": [ 206 | "" 207 | ] 208 | }, 209 | "execution_count": 8, 210 | "metadata": {}, 211 | "output_type": "execute_result" 212 | }, 213 | { 214 | "data": { 215 | "image/png": "\n", 216 | "text/plain": [ 217 | "
" 218 | ] 219 | }, 220 | "metadata": { 221 | "needs_background": "light" 222 | }, 223 | "output_type": "display_data" 224 | } 225 | ], 226 | "source": [ 227 | "# %%\n", 228 | "plt.imshow(dataset.__getitem__(1)[1].permute(1, 2, 0))" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "# inits" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 9, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# %%\n", 245 | "def truncated_normal_(tensor, mean=0, std=1):\n", 246 | " size = tensor.shape\n", 247 | " tmp = tensor.new_empty(size + (4,)).normal_()\n", 248 | " valid = (tmp < 2) & (tmp > -2)\n", 249 | " ind = valid.max(-1, keepdim=True)[1]\n", 250 | " tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))\n", 251 | " tensor.data.mul_(std).add_(mean)\n", 252 | "\n", 253 | "\n", 254 | "def init_weights(m):\n", 255 | " if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:\n", 256 | " nn.init.kaiming_normal_(m.weight, mode=\"fan_in\", nonlinearity=\"relu\")\n", 257 | " # nn.init.normal_(m.weight, std=0.001)\n", 258 | " # nn.init.normal_(m.bias, std=0.001)\n", 259 | " truncated_normal_(m.bias, mean=0, std=0.001)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "# Blocks for Unet" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 10, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "class DownConvBlock(nn.Module):\n", 276 | " def __init__(self, input_dim, output_dim, initializers, padding, pool=True):\n", 277 | " super(DownConvBlock, self).__init__()\n", 278 | " layers = []\n", 279 | "\n", 280 | " if pool:\n", 281 | " layers.append(\n", 282 | " nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)\n", 283 | " )\n", 284 | "\n", 285 | " layers.append(\n", 286 | " nn.Conv2d(\n", 287 | " input_dim, output_dim, kernel_size=3, stride=1, padding=int(padding)\n", 288 | " )\n", 289 | " 
)\n", 290 | " layers.append(nn.ReLU(inplace=True))\n", 291 | " layers.append(\n", 292 | " nn.Conv2d(\n", 293 | " output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding)\n", 294 | " )\n", 295 | " )\n", 296 | " layers.append(nn.ReLU(inplace=True))\n", 297 | " layers.append(\n", 298 | " nn.Conv2d(\n", 299 | " output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding)\n", 300 | " )\n", 301 | " )\n", 302 | " layers.append(nn.ReLU(inplace=True))\n", 303 | "\n", 304 | " self.layers = nn.Sequential(*layers)\n", 305 | "\n", 306 | " self.layers.apply(init_weights)\n", 307 | "\n", 308 | " def forward(self, patch):\n", 309 | " return self.layers(patch)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 11, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "class UpConvBlock(nn.Module):\n", 319 | " def __init__(self, input_dim, output_dim, initializers, padding, bilinear=True):\n", 320 | " super(UpConvBlock, self).__init__()\n", 321 | " self.bilinear = bilinear\n", 322 | "\n", 323 | " if not self.bilinear:\n", 324 | " self.upconv_layer = nn.ConvTranspose2d(\n", 325 | " input_dim, output_dim, kernel_size=2, stride=2\n", 326 | " )\n", 327 | " self.upconv_layer.apply(init_weights)\n", 328 | "\n", 329 | " self.conv_block = DownConvBlock(\n", 330 | " input_dim, output_dim, initializers, padding, pool=False\n", 331 | " )\n", 332 | "\n", 333 | " def forward(self, x, bridge):\n", 334 | " if self.bilinear:\n", 335 | " up = nn.functional.interpolate(\n", 336 | " x, mode=\"bilinear\", scale_factor=2, align_corners=True\n", 337 | " )\n", 338 | " else:\n", 339 | " up = self.upconv_layer(x)\n", 340 | "\n", 341 | " assert up.shape[3] == bridge.shape[3]\n", 342 | " out = torch.cat([up, bridge], 1)\n", 343 | " out = self.conv_block(out)\n", 344 | "\n", 345 | " return out" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": {}, 351 | "source": [ 352 | "\n", 353 | "# Unet " 354 | ] 355 | }, 356 | { 357 | 
class Unet(nn.Module):
    """A U-Net (https://arxiv.org/abs/1505.04597) built from ``DownConvBlock``
    and ``UpConvBlock``.

    Args:
        input_channels: number of channels of the input image.
        num_classes: number of channels produced by the final 1x1 convolution.
        num_filters: non-empty sequence with the channel count of every
            level of the contracting path, shallowest level first.
        initializers: forwarded unchanged to every block.
        apply_last_layer: when false, the final 1x1 convolution is neither
            created nor applied, and ``forward`` returns the decoder features.
        padding: forwarded unchanged to every block.

    Raises:
        ValueError: if ``num_filters`` is empty.
    """

    def __init__(
        self,
        input_channels,
        num_classes,
        num_filters,
        initializers,
        apply_last_layer=True,
        padding=True,
    ):
        super(Unet, self).__init__()
        if len(num_filters) == 0:
            # Without any level there is no channel count to build from.
            raise ValueError("num_filters must contain at least one entry")

        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.padding = padding
        self.activation_maps = []
        self.apply_last_layer = apply_last_layer

        # Encoder: the first level works at full resolution, every deeper
        # level pools first.
        self.contracting_path = nn.ModuleList()
        in_channels = self.input_channels
        for depth, out_channels in enumerate(self.num_filters):
            self.contracting_path.append(
                DownConvBlock(
                    in_channels, out_channels, initializers, padding, pool=depth > 0
                )
            )
            in_channels = out_channels

        # Decoder: one block per encoder level except the deepest. Each block
        # receives the current features concatenated with that level's skip
        # connection, hence the summed channel count.
        self.upsampling_path = nn.ModuleList()
        out_channels = self.num_filters[-1]
        for skip_channels in reversed(self.num_filters[:-1]):
            self.upsampling_path.append(
                UpConvBlock(
                    out_channels + skip_channels, skip_channels, initializers, padding
                )
            )
            out_channels = skip_channels

        if self.apply_last_layer:
            self.last_layer = nn.Conv2d(out_channels, num_classes, kernel_size=1)

    def forward(self, x, val=False):
        """Run the network on ``x``.

        Args:
            x: input batch passed to the first contracting block.
            val: accepted for callers that pass it; it is not read. It now
                defaults to ``False`` so that ``net(x)`` works.

        Returns:
            The output of the last layer, or the decoder features when
            ``apply_last_layer`` is false.
        """
        blocks = []
        deepest = len(self.contracting_path) - 1
        for i, down in enumerate(self.contracting_path):
            x = down(x)
            # Every level but the deepest feeds a skip connection.
            if i != deepest:
                blocks.append(x)

        # Skip connections are consumed deepest-first.
        for up, bridge in zip(self.upsampling_path, reversed(blocks)):
            x = up(x, bridge)

        del blocks

        if self.apply_last_layer:
            x = self.last_layer(x)

        return x
import gc

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import trange


def l2_regularisation(m):
    """Return the sum of the 2-norms of every parameter tensor of ``m``.

    Returns ``None`` when ``m`` has no parameters.
    """
    l2_reg = None

    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg


# %%
# Hold out 10% of the samples (rounded down) for testing.
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]

# %%
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=16, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices), len(test_indices)))

# %%
# Drop any model left over from a previous run before building a new one.
net = None
gc.collect()

# %%
net = Unet(
    input_channels=3,
    num_classes=3,
    num_filters=[32, 64, 128, 192],
    initializers={"w": "he_uniform", "b": "normal"},
)
net = net.to("cuda")

# %%
# weight_decay is the same value previously spelled 10e-3.
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-5, weight_decay=1e-2)
# `reduce` is MSELoss's deprecated boolean flag; the string "mean" passed to it
# was merely truthy. `reduction` is the argument that actually takes "mean".
criterion = nn.MSELoss(reduction="mean").to("cuda")
def truncated_normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with normal samples kept within two standard
    deviations of ``mean``.

    Four standard-normal candidates are drawn per element and one that lies
    inside (-2, 2) is selected. If all four fall outside, the selected value
    is clamped to [-2, 2], so the bound holds for every element.

    Args:
        tensor: tensor to overwrite.
        mean: centre of the resulting distribution.
        std: scale of the resulting distribution.

    Returns:
        ``tensor`` itself, for call chaining.
    """
    # no_grad lets this run on parameters that require grad without
    # reaching through `.data`.
    with torch.no_grad():
        size = tensor.shape
        tmp = tensor.new_empty(size + (4,)).normal_()
        valid = (tmp < 2) & (tmp > -2)
        # Index of a valid candidate per element; taken over an integer mask
        # so the reduction does not depend on boolean-tensor support.
        ind = valid.to(torch.uint8).max(-1, keepdim=True)[1]
        tensor.copy_(tmp.gather(-1, ind).squeeze(-1))
        # Covers the case where none of the four candidates was valid.
        tensor.clamp_(-2, 2)
        tensor.mul_(std).add_(mean)
    return tensor


def init_weights(m):
    """Initialise ``nn.Conv2d`` / ``nn.ConvTranspose2d`` modules in place.

    Weights get Kaiming-normal initialisation (fan-in, ReLU). The bias, when
    the layer has one, gets a truncated normal with std 0.001. Modules of any
    other type are left untouched, so this can be passed to ``Module.apply``.
    """
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # Layers built with bias=False expose `bias` as None.
        if m.bias is not None:
            truncated_normal_(m.bias, mean=0, std=0.001)
| """ 129 | A block of three convolutional layers where each layer is followed by a non-linear activation function 130 | Between each block we add a pooling operation. 131 | """ 132 | def __init__(self, input_dim, output_dim, initializers, padding, pool=True): 133 | super(DownConvBlock, self).__init__() 134 | layers = [] 135 | 136 | if pool: 137 | layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)) 138 | 139 | layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=int(padding))) 140 | layers.append(nn.ReLU(inplace=True)) 141 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding))) 142 | layers.append(nn.ReLU(inplace=True)) 143 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding))) 144 | layers.append(nn.ReLU(inplace=True)) 145 | 146 | self.layers = nn.Sequential(*layers) 147 | 148 | self.layers.apply(init_weights) 149 | 150 | def forward(self, patch): 151 | return self.layers(patch) 152 | 153 | 154 | class UpConvBlock(nn.Module): 155 | """ 156 | A block consists of an upsampling layer followed by a convolutional layer to reduce the amount of channels and then a DownConvBlock 157 | If bilinear is set to false, we do a transposed convolution instead of upsampling 158 | """ 159 | def __init__(self, input_dim, output_dim, initializers, padding, bilinear=True): 160 | super(UpConvBlock, self).__init__() 161 | self.bilinear = bilinear 162 | 163 | if not self.bilinear: 164 | self.upconv_layer = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2) 165 | self.upconv_layer.apply(init_weights) 166 | 167 | self.conv_block = DownConvBlock(input_dim, output_dim, initializers, padding, pool=False) 168 | 169 | def forward(self, x, bridge): 170 | if self.bilinear: 171 | up = nn.functional.interpolate(x, mode='bilinear', scale_factor=2, align_corners=True) 172 | else: 173 | up = self.upconv_layer(x) 174 | 175 | assert up.shape[3] == 
bridge.shape[3] 176 | out = torch.cat([up, bridge], 1) 177 | out = self.conv_block(out) 178 | 179 | return out 180 | 181 | # %% [markdown] 182 | # Unet 183 | 184 | # %% 185 | class Unet(nn.Module): 186 | """ 187 | A UNet (https://arxiv.org/abs/1505.04597) implementation. 188 | input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) 189 | num_classes: the number of classes to predict 190 | num_filters: list with the amount of filters per layer 191 | apply_last_layer: boolean to apply last layer or not (not used in Probabilistic UNet) 192 | padidng: Boolean, if true we pad the images with 1 so that we keep the same dimensions 193 | """ 194 | 195 | def __init__(self, input_channels, num_classes, num_filters, initializers, apply_last_layer=True, padding=True): 196 | super(Unet, self).__init__() 197 | self.input_channels = input_channels 198 | self.num_classes = num_classes 199 | self.num_filters = num_filters 200 | self.padding = padding 201 | self.activation_maps = [] 202 | self.apply_last_layer = apply_last_layer 203 | self.contracting_path = nn.ModuleList() 204 | 205 | for i in range(len(self.num_filters)): 206 | input = self.input_channels if i == 0 else output 207 | output = self.num_filters[i] 208 | 209 | if i == 0: 210 | pool = False 211 | else: 212 | pool = True 213 | 214 | self.contracting_path.append(DownConvBlock(input, output, initializers, padding, pool=pool)) 215 | 216 | self.upsampling_path = nn.ModuleList() 217 | 218 | n = len(self.num_filters) - 2 219 | for i in range(n, -1, -1): 220 | input = output + self.num_filters[i] 221 | output = self.num_filters[i] 222 | self.upsampling_path.append(UpConvBlock(input, output, initializers, padding)) 223 | 224 | if self.apply_last_layer: 225 | self.last_layer = nn.Conv2d(output, num_classes, kernel_size=1) 226 | #nn.init.kaiming_normal_(self.last_layer.weight, mode='fan_in',nonlinearity='relu') 227 | #nn.init.normal_(self.last_layer.bias) 228 | 229 | 230 | def forward(self, x, val): 
# %%

from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook,tqdm

def l2_regularisation(m):
    """Return the sum of the L2 norms of all parameters of module ``m``.

    Returns 0 when ``m`` has no parameters, so the result can always be
    used in arithmetic.
    """
    return sum(W.norm(2) for W in m.parameters())

# %%
# Hold out 10% of the samples (chosen at random) for testing.
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
#%%
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices),len(test_indices)))

# %%
# Fall back to the CPU when no GPU is present instead of failing on .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The network regresses the target image directly, so its input/output channel
# counts are read from one sample rather than hard-coded.
_sample = dataset[0]
net = Unet(input_channels=_sample['image'].shape[0], num_classes=_sample['mask'].shape[0], num_filters=[32,64,128,192],initializers = {'w':'he_normal', 'b':'normal'}).to(device)

# %%
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 10
#%%
for epoch in tqdm(range(epochs)):
    net.train()
    for step, batch in enumerate(train_loader):
        # The dataset yields dicts with 'image' and 'mask'; preprocessing can
        # produce float64 or uint8 arrays, so cast to the weights' dtype.
        patch = batch['image'].to(device, dtype=torch.float32)
        mask = batch['mask'].to(device, dtype=torch.float32)

        # Second argument is Unet.forward's ``val`` flag: do not record activations.
        prediction = net(patch, False)

        # NOTE(review): the previous loop used the ELBO of a probabilistic
        # U-Net, which the plain Unet in this file does not provide; pixel-wise
        # MSE is used here as the reconstruction loss - confirm this is the
        # intended objective.
        reg_loss = l2_regularisation(net)
        loss = nn.functional.mse_loss(prediction, mask) + 1e-5 * reg_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()