├── Dockerfile ├── LICENSE ├── README.md ├── activation_maximization.ipynb ├── data ├── denoising │ ├── F16_GT.png │ └── snail.jpg ├── feature_inversion │ ├── building.jpg │ └── monkey.jpg ├── flash_no_flash │ ├── cave01_00_flash.jpg │ └── cave01_01_noflash.jpg ├── imagenet1000_clsid_to_human.txt ├── inpainting │ ├── kate.png │ ├── kate_mask.png │ ├── library.png │ ├── library_mask.png │ ├── vase.png │ └── vase_mask.png ├── restoration │ ├── barbara.png │ └── kate.png ├── sr │ ├── zebra_GT.png │ └── zebra_crop.png └── teaser_compiled.jpg ├── denoising.ipynb ├── environment.yml ├── feature_inversion.ipynb ├── flash-no-flash.ipynb ├── inpainting.ipynb ├── models ├── __init__.py ├── common.py ├── dcgan.py ├── downsampler.py ├── resnet.py ├── skip.py ├── texture_nets.py └── unet.py ├── restoration.ipynb ├── sr_prior_effect.ipynb ├── super-resolution.ipynb ├── super-resolution_eval_script.py └── utils ├── __init__.py ├── common_utils.py ├── denoising_utils.py ├── feature_inversion_utils.py ├── inpainting_utils.py ├── matcher.py ├── perceptual_loss ├── __init__.py ├── matcher.py ├── perceptual_loss.py └── vgg_modified.py └── sr_utils.py /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-cudnn7-devel 2 | 3 | # Install system dependencies 4 | RUN apt-get update \ 5 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 6 | build-essential \ 7 | curl \ 8 | git \ 9 | && apt-get clean 10 | 11 | # Install python miniconda3 + requirements 12 | ENV MINICONDA_HOME="/opt/miniconda" 13 | ENV PATH="${MINICONDA_HOME}/bin:${PATH}" 14 | RUN curl -o Miniconda3-latest-Linux-x86_64.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 15 | && chmod +x Miniconda3-latest-Linux-x86_64.sh \ 16 | && ./Miniconda3-latest-Linux-x86_64.sh -b -p "${MINICONDA_HOME}" \ 17 | && rm Miniconda3-latest-Linux-x86_64.sh 18 | COPY environment.yml environment.yml 19 | RUN conda env update -n=root --file=environment.yml 
20 | RUN conda clean -y -i -l -p -t && \ 21 | rm environment.yml 22 | 23 | # Clone deep image prior repository 24 | RUN git clone https://github.com/DmitryUlyanov/deep-image-prior.git 25 | WORKDIR /deep-image-prior 26 | 27 | # Start container in notebook mode 28 | CMD jupyter notebook --ip="*" --no-browser --allow-root 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2.0 2 | 3 | But please contact me if you want to use this software in a commercial application. 4 | 5 | Note that Apache License 2.0 asks to include a copyright notice if you use this software. 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Warning!** The optimization may not converge on some GPUs. We've personally experienced issues on Tesla V100 and P40 GPUs. When running the code, make sure you get similar results to the paper first. Easiest to check using text inpainting notebook. Try to set double precision mode or turn off cudnn. 2 | 3 | # Deep image prior 4 | 5 | In this repository we provide *Jupyter Notebooks* to reproduce each figure from the paper: 6 | 7 | > **Deep Image Prior** 8 | 9 | > CVPR 2018 10 | 11 | > Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky 12 | 13 | 14 | [[paper]](https://sites.skoltech.ru/app/data/uploads/sites/25/2018/04/deep_image_prior.pdf) [[supmat]](https://box.skoltech.ru/index.php/s/ib52BOoV58ztuPM) [[project page]](https://dmitryulyanov.github.io/deep_image_prior) 15 | 16 | ![](data/teaser_compiled.jpg) 17 | 18 | Here we provide hyperparameters and architectures, that were used to generate the figures. Most of them are far from optimal. Do not hesitate to change them and see the effect. 19 | 20 | We will expand this README with a list of hyperparameters and options shortly. 
21 | 22 | # Install 23 | 24 | Here is the list of libraries you need to install to execute the code: 25 | - python = 3.6 26 | - [pytorch](http://pytorch.org/) = 0.4 27 | - numpy 28 | - scipy 29 | - matplotlib 30 | - scikit-image 31 | - jupyter 32 | 33 | All of them can be installed via `conda` (`anaconda`), e.g. 34 | ``` 35 | conda install jupyter 36 | ``` 37 | 38 | 39 | or you can create a conda env with all dependencies via the environment file: 40 | 41 | ``` 42 | conda env create -f environment.yml 43 | ``` 44 | 45 | ## Docker image 46 | 47 | Alternatively, you can use a Docker image that exposes a Jupyter Notebook with all required dependencies. To build this image, ensure you have both [docker](https://www.docker.com/) and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) installed, then run 48 | 49 | ``` 50 | nvidia-docker build -t deep-image-prior . 51 | ``` 52 | 53 | After the build you can start the container as 54 | 55 | ``` 56 | nvidia-docker run --rm -it --ipc=host -p 8888:8888 deep-image-prior 57 | ``` 58 | 59 | You will be provided a URL through which you can connect to the Jupyter notebook. 60 | 61 | ## Google Colab 62 | 63 | To run it using Google Colab, click [here](https://colab.research.google.com/github/DmitryUlyanov/deep-image-prior) and select the notebook to run. Remember to uncomment the first cell to clone the repository into Colab's environment. 64 | 65 | 66 | # Citation 67 | ``` 68 | @article{UlyanovVL17, 69 | author = {Ulyanov, Dmitry and Vedaldi, Andrea and Lempitsky, Victor}, 70 | title = {Deep Image Prior}, 71 | journal = {arXiv:1711.10925}, 72 | year = {2017} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /activation_maximization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for **\"Activation maximization\"** figure."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"\n", 17 | "*Uncomment if running on colab* \n", 18 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 19 | "\"\"\"\n", 20 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 21 | "# !mv deep-image-prior/* ./" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "You can select the net type (`alexnet_caffe`, `vgg19_caffe`, `vgg16_caffe`) and a layer. For reference, the layer names for each network type are shown below." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "vgg_19_names=['conv1_1','relu1_1','conv1_2','relu1_2','pool1',\n", 38 | " 'conv2_1','relu2_1','conv2_2','relu2_2','pool2',\n", 39 | " 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','conv3_4','relu3_4','pool3',\n", 40 | " 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','conv4_4','relu4_4','pool4',\n", 41 | " 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','conv5_4','relu5_4','pool5',\n", 42 | " 'torch_view','fc6','relu6','drop6','fc7','relu7','drop7','fc8']\n", 43 | "\n", 44 | "vgg_16_names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1',\n", 45 | " 'conv2_1','relu2_1','conv2_2','relu2_2','pool2',\n", 46 | " 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','pool3',\n", 47 | " 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4',\n", 48 | " 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5',\n", 49 | " 'torch_view','fc6','relu6','drop6','fc7','relu7','fc8']\n", 50 | "\n", 51 | "alexnet_names = ['conv1', 'relu1', 'norm1', 'pool1',\n", 52 | " 'conv2', 'relu2', 'norm2', 'pool2',\n", 53 | " 'conv3', 'relu3', 'conv4', 'relu4',\n", 54 | " 'conv5', 'relu5', 'pool5', 'torch_view',\n", 55 | "
'fc6', 'relu6', 'drop6',\n", 56 | " 'fc7', 'relu7', 'drop7',\n", 57 | " 'fc8', 'softmax']" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "The actual code starts here." 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "# Import libs" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "from __future__ import print_function\n", 81 | "import matplotlib.pyplot as plt\n", 82 | "%matplotlib inline\n", 83 | "\n", 84 | "import argparse\n", 85 | "import os\n", 86 | "# os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n", 87 | "\n", 88 | "import numpy as np\n", 89 | "from models import *\n", 90 | "\n", 91 | "import torch\n", 92 | "import torch.optim\n", 93 | "\n", 94 | "from utils.perceptual_loss.perceptual_loss import *\n", 95 | "from utils.common_utils import *\n", 96 | "\n", 97 | "torch.backends.cudnn.enabled = True\n", 98 | "torch.backends.cudnn.benchmark =True\n", 99 | "dtype = torch.cuda.FloatTensor\n", 100 | "\n", 101 | "PLOT = True\n", 102 | "fname = './data/feature_inversion/building.jpg'\n", 103 | "\n", 104 | "# Choose net type\n", 105 | "pretrained_net = 'alexnet_caffe' \n", 106 | "assert pretrained_net in ['alexnet_caffe', 'vgg19_caffe', 'vgg16_caffe']\n", 107 | "\n", 108 | "# Choose layers\n", 109 | "layer_to_use = 'conv4'" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "import json\n", 119 | "with open('data/imagenet1000_clsid_to_human.txt', 'r') as f:\n", 120 | " corresp = json.load(f)\n", 121 | " \n", 122 | "\n", 123 | "if layer_to_use == 'fc8':\n", 124 | " # Choose class\n", 125 | " name = 'black swan'\n", 126 | " # name = 'cheeseburger'\n", 127 | "\n", 128 | " map_idx = None\n", 129 | " for k,v in corresp.items():\n", 130 | " if name in v:\n", 131 | " map_idx = int(k)\n", 132 | " 
break\n", 133 | "else:\n", 134 | " map_idx = 2 # Choose here" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "# Setup pretrained net" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "# Target imsize \n", 151 | "imsize = 227 if pretrained_net == 'alexnet_caffe' else 224\n", 152 | "\n", 153 | "# Something divisible by a power of two\n", 154 | "imsize_net = 256\n", 155 | "\n", 156 | "# VGG and Alexnet need input to be correctly normalized\n", 157 | "preprocess, deprocess = get_preprocessor(imsize), get_deprocessor()\n", 158 | "\n", 159 | "\n", 160 | "img_content_pil, img_content_np = get_image(fname, -1)\n", 161 | "img_content_prerocessed = preprocess(img_content_pil)[None,:].type(dtype)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "scrolled": true 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "opt_content = {'layers': [layer_to_use], 'what':'features', 'map_idx': map_idx}\n", 173 | "\n", 174 | "cnn = get_pretrained_net(pretrained_net).type(dtype)\n", 175 | "cnn.add_module('softmax', nn.Softmax())\n", 176 | "\n", 177 | "# Remove the layers we don't need \n", 178 | "keys = [x for x in cnn._modules.keys()]\n", 179 | "max_idx = max(keys.index(x) for x in opt_content['layers'])\n", 180 | "for k in keys[max_idx+1:]:\n", 181 | " cnn._modules.pop(k)\n", 182 | " \n", 183 | "print(cnn)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "scrolled": true 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "matcher_content = get_matcher(cnn, opt_content)\n", 195 | "matcher_content.mode = 'match'\n", 196 | "\n", 197 | "if layer_to_use == 'fc8':\n", 198 | " matcher_content.mode = 'match'\n", 199 | " LR = 0.01\n", 200 | "else:\n", 201 | " \n", 202 | " # Choose here\n", 203 | " # Window size 
controls the width of the region where the activations are maximized\n", 204 | " matcher_content.window_size = 20 # if = 1 then it is neuron maximization\n", 205 | " matcher_content.method = 'maximize'\n", 206 | " LR = 0.001" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "# Setup matcher and net" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "INPUT = 'noise'\n", 223 | "input_depth = 32\n", 224 | "OPTIMIZER = 'adam'\n", 225 | "net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()\n", 226 | "OPT_OVER = 'net' #'net,input'\n", 227 | "pad='reflection'\n", 228 | "\n", 229 | "tv_weight=0.0\n", 230 | "reg_noise_std = 0.03\n", 231 | "param_noise = True\n", 232 | "num_iter = 3100" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],\n", 242 | " num_channels_up = [16, 32, 64, 128, 128, 128],\n", 243 | " num_channels_skip = [0, 4, 4, 4, 4, 4], \n", 244 | " filter_size_down = [5, 3, 5, 5, 3, 5], filter_size_up = [5, 3, 5, 3, 5, 3], \n", 245 | " upsample_mode='bilinear', downsample_mode='avg',\n", 246 | " need_sigmoid=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n", 247 | "\n", 248 | "\n", 249 | "\n", 250 | "\n", 251 | "\n", 252 | "# Compute number of parameters\n", 253 | "s = sum(np.prod(list(p.size())) for p in net.parameters())\n", 254 | "print ('Number of params: %d' % s)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "net(net_input).shape" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "### TV" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": 
{}, 276 | "source": [ 277 | "Uncomment this section if you want to optimize directly over pixels with only a TV prior (no generator network)." 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "# INPUT = 'noise'\n", 287 | "# input_depth = 3\n", 288 | "# net_input = (get_noise(input_depth, INPUT, imsize_net).type(dtype)+0.5).detach()\n", 289 | "\n", 290 | "# OPT_OVER = 'input' #'net,input'\n", 291 | "# net = nn.Sequential()\n", 292 | "# reg_noise_std =0.0\n", 293 | "# OPTIMIZER = 'adam'# 'LBFGS'\n", 294 | "# LR = 0.01\n", 295 | "# tv_weight=1e-6" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "# Optimize" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "mask = net_input.clone()[:,:3,:imsize,:imsize] * 0\n", 312 | "for i in range(imsize):\n", 313 | " for j in range(imsize):\n", 314 | " d = np.sqrt((i - imsize//2)**2 + (j - imsize//2)**2)\n", 315 | "# if d > 75:\n", 316 | " mask[:,:, i, j] = 1 - min(100./d, 1)\n", 317 | " \n", 318 | "plot_image_grid([torch_to_np(mask)]);\n", 319 | "use_mask = False" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "from utils.sr_utils import tv_loss\n", 329 | "\n", 330 | "net_input_saved = net_input.detach().clone()\n", 331 | "noise = net_input.detach().clone()\n", 332 | "\n", 333 | "\n", 334 | "outs = [] \n", 335 | "\n", 336 | "def closure():\n", 337 | " \n", 338 | " global i, net_input\n", 339 | " \n", 340 | " if param_noise:\n", 341 | " for n in [x for x in net.parameters() if len(x.size()) == 4]:\n", 342 | " n = n + n.detach().clone().normal_() * n.std()/50\n", 343 | " \n", 344 | " net_input = net_input_saved\n", 345 | " if reg_noise_std > 0:\n", 346 | " net_input = net_input_saved + (noise.normal_() *
reg_noise_std)\n", 347 | "\n", 348 | " out = net(net_input)[:, :, :imsize, :imsize]\n", 349 | " \n", 350 | "# out = out* (1-mask)\n", 351 | " \n", 352 | " \n", 353 | " cnn(vgg_preprocess_caffe(out))\n", 354 | " total_loss = sum(matcher_content.losses.values()) * 5\n", 355 | " \n", 356 | " if tv_weight > 0:\n", 357 | " total_loss += tv_weight * tv_loss(vgg_preprocess_caffe(out), beta=2)\n", 358 | " \n", 359 | " \n", 360 | " if use_mask:\n", 361 | " total_loss += nn.functional.mse_loss(out * mask, mask * 0, size_average=False) * 1e1\n", 362 | " \n", 363 | " total_loss.backward()\n", 364 | "\n", 365 | " print ('Iteration %05d Loss %.3f' % (i, total_loss.item()), '\\r', end='')\n", 366 | " if PLOT and i % 100==0:\n", 367 | " out_np = np.clip(torch_to_np(out), 0, 1)\n", 368 | " plot_image_grid([out_np], 3, 3, interpolation='lanczos');\n", 369 | " \n", 370 | " outs.append(out_np)\n", 371 | " i += 1\n", 372 | " \n", 373 | " return total_loss" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "i=0\n", 383 | "\n", 384 | "p = get_params(OPT_OVER, net, net_input)\n", 385 | "\n", 386 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "# Result" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "out = net(net_input)[:, :, :imsize, :imsize]\n", 403 | "plot_image_grid([torch_to_np(out)], 3, 3);" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [] 412 | } 413 | ], 414 | "metadata": { 415 | "kernelspec": { 416 | "display_name": "Python 3", 417 | "language": "python", 418 | "name": "python3" 419 | }, 420 | "language_info": { 421 | "codemirror_mode": { 422 | "name": "ipython", 423 | "version": 3 
424 | }, 425 | "file_extension": ".py", 426 | "mimetype": "text/x-python", 427 | "name": "python", 428 | "nbconvert_exporter": "python", 429 | "pygments_lexer": "ipython3", 430 | "version": "3.6.9" 431 | } 432 | }, 433 | "nbformat": 4, 434 | "nbformat_minor": 2 435 | } 436 | -------------------------------------------------------------------------------- /data/denoising/F16_GT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/denoising/F16_GT.png -------------------------------------------------------------------------------- /data/denoising/snail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/denoising/snail.jpg -------------------------------------------------------------------------------- /data/feature_inversion/building.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/feature_inversion/building.jpg -------------------------------------------------------------------------------- /data/feature_inversion/monkey.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/feature_inversion/monkey.jpg -------------------------------------------------------------------------------- /data/flash_no_flash/cave01_00_flash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/flash_no_flash/cave01_00_flash.jpg 
-------------------------------------------------------------------------------- /data/flash_no_flash/cave01_01_noflash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/flash_no_flash/cave01_01_noflash.jpg -------------------------------------------------------------------------------- /data/imagenet1000_clsid_to_human.txt: -------------------------------------------------------------------------------- 1 | {"0": "tench, Tinca tinca", 2 | "1": "goldfish, Carassius auratus", 3 | "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", 4 | "3": "tiger shark, Galeocerdo cuvieri", 5 | "4": "hammerhead, hammerhead shark", 6 | "5": "electric ray, crampfish, numbfish, torpedo", 7 | "6": "stingray", 8 | "7": "cock", 9 | "8": "hen", 10 | "9": "ostrich, Struthio camelus", 11 | "10": "brambling, Fringilla montifringilla", 12 | "11": "goldfinch, Carduelis carduelis", 13 | "12": "house finch, linnet, Carpodacus mexicanus", 14 | "13": "junco, snowbird", 15 | "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", 16 | "15": "robin, American robin, Turdus migratorius", 17 | "16": "bulbul", 18 | "17": "jay", 19 | "18": "magpie", 20 | "19": "chickadee", 21 | "20": "water ouzel, dipper", 22 | "21": "kite", 23 | "22": "bald eagle, American eagle, Haliaeetus leucocephalus", 24 | "23": "vulture", 25 | "24": "great grey owl, great gray owl, Strix nebulosa", 26 | "25": "European fire salamander, Salamandra salamandra", 27 | "26": "common newt, Triturus vulgaris", 28 | "27": "eft", 29 | "28": "spotted salamander, Ambystoma maculatum", 30 | "29": "axolotl, mud puppy, Ambystoma mexicanum", 31 | "30": "bullfrog, Rana catesbeiana", 32 | "31": "tree frog, tree-frog", 33 | "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", 34 | "33": "loggerhead, loggerhead turtle, Caretta 
caretta", 35 | "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", 36 | "35": "mud turtle", 37 | "36": "terrapin", 38 | "37": "box turtle, box tortoise", 39 | "38": "banded gecko", 40 | "39": "common iguana, iguana, Iguana iguana", 41 | "40": "American chameleon, anole, Anolis carolinensis", 42 | "41": "whiptail, whiptail lizard", 43 | "42": "agama", 44 | "43": "frilled lizard, Chlamydosaurus kingi", 45 | "44": "alligator lizard", 46 | "45": "Gila monster, Heloderma suspectum", 47 | "46": "green lizard, Lacerta viridis", 48 | "47": "African chameleon, Chamaeleo chamaeleon", 49 | "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", 50 | "49": "African crocodile, Nile crocodile, Crocodylus niloticus", 51 | "50": "American alligator, Alligator mississipiensis", 52 | "51": "triceratops", 53 | "52": "thunder snake, worm snake, Carphophis amoenus", 54 | "53": "ringneck snake, ring-necked snake, ring snake", 55 | "54": "hognose snake, puff adder, sand viper", 56 | "55": "green snake, grass snake", 57 | "56": "king snake, kingsnake", 58 | "57": "garter snake, grass snake", 59 | "58": "water snake", 60 | "59": "vine snake", 61 | "60": "night snake, Hypsiglena torquata", 62 | "61": "boa constrictor, Constrictor constrictor", 63 | "62": "rock python, rock snake, Python sebae", 64 | "63": "Indian cobra, Naja naja", 65 | "64": "green mamba", 66 | "65": "sea snake", 67 | "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", 68 | "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", 69 | "68": "sidewinder, horned rattlesnake, Crotalus cerastes", 70 | "69": "trilobite", 71 | "70": "harvestman, daddy longlegs, Phalangium opilio", 72 | "71": "scorpion", 73 | "72": "black and gold garden spider, Argiope aurantia", 74 | "73": "barn spider, Araneus cavaticus", 75 | "74": "garden spider, Aranea diademata", 76 | "75": "black widow, Latrodectus mactans", 77 | "76": "tarantula", 78 | "77": 
"wolf spider, hunting spider", 79 | "78": "tick", 80 | "79": "centipede", 81 | "80": "black grouse", 82 | "81": "ptarmigan", 83 | "82": "ruffed grouse, partridge, Bonasa umbellus", 84 | "83": "prairie chicken, prairie grouse, prairie fowl", 85 | "84": "peacock", 86 | "85": "quail", 87 | "86": "partridge", 88 | "87": "African grey, African gray, Psittacus erithacus", 89 | "88": "macaw", 90 | "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", 91 | "90": "lorikeet", 92 | "91": "coucal", 93 | "92": "bee eater", 94 | "93": "hornbill", 95 | "94": "hummingbird", 96 | "95": "jacamar", 97 | "96": "toucan", 98 | "97": "drake", 99 | "98": "red-breasted merganser, Mergus serrator", 100 | "99": "goose", 101 | "100": "black swan, Cygnus atratus", 102 | "101": "tusker", 103 | "102": "echidna, spiny anteater, anteater", 104 | "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", 105 | "104": "wallaby, brush kangaroo", 106 | "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", 107 | "106": "wombat", 108 | "107": "jellyfish", 109 | "108": "sea anemone, anemone", 110 | "109": "brain coral", 111 | "110": "flatworm, platyhelminth", 112 | "111": "nematode, nematode worm, roundworm", 113 | "112": "conch", 114 | "113": "snail", 115 | "114": "slug", 116 | "115": "sea slug, nudibranch", 117 | "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", 118 | "117": "chambered nautilus, pearly nautilus, nautilus", 119 | "118": "Dungeness crab, Cancer magister", 120 | "119": "rock crab, Cancer irroratus", 121 | "120": "fiddler crab", 122 | "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", 123 | "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", 124 | "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", 125 | "124": "crayfish, crawfish, crawdad, crawdaddy", 126 | "125": "hermit crab", 127 | "126": 
"isopod", 128 | "127": "white stork, Ciconia ciconia", 129 | "128": "black stork, Ciconia nigra", 130 | "129": "spoonbill", 131 | "130": "flamingo", 132 | "131": "little blue heron, Egretta caerulea", 133 | "132": "American egret, great white heron, Egretta albus", 134 | "133": "bittern", 135 | "134": "crane", 136 | "135": "limpkin, Aramus pictus", 137 | "136": "European gallinule, Porphyrio porphyrio", 138 | "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", 139 | "138": "bustard", 140 | "139": "ruddy turnstone, Arenaria interpres", 141 | "140": "red-backed sandpiper, dunlin, Erolia alpina", 142 | "141": "redshank, Tringa totanus", 143 | "142": "dowitcher", 144 | "143": "oystercatcher, oyster catcher", 145 | "144": "pelican", 146 | "145": "king penguin, Aptenodytes patagonica", 147 | "146": "albatross, mollymawk", 148 | "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", 149 | "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", 150 | "149": "dugong, Dugong dugon", 151 | "150": "sea lion", 152 | "151": "Chihuahua", 153 | "152": "Japanese spaniel", 154 | "153": "Maltese dog, Maltese terrier, Maltese", 155 | "154": "Pekinese, Pekingese, Peke", 156 | "155": "Shih-Tzu", 157 | "156": "Blenheim spaniel", 158 | "157": "papillon", 159 | "158": "toy terrier", 160 | "159": "Rhodesian ridgeback", 161 | "160": "Afghan hound, Afghan", 162 | "161": "basset, basset hound", 163 | "162": "beagle", 164 | "163": "bloodhound, sleuthhound", 165 | "164": "bluetick", 166 | "165": "black-and-tan coonhound", 167 | "166": "Walker hound, Walker foxhound", 168 | "167": "English foxhound", 169 | "168": "redbone", 170 | "169": "borzoi, Russian wolfhound", 171 | "170": "Irish wolfhound", 172 | "171": "Italian greyhound", 173 | "172": "whippet", 174 | "173": "Ibizan hound, Ibizan Podenco", 175 | "174": "Norwegian elkhound, elkhound", 176 | "175": "otterhound, otter hound", 177 | "176": "Saluki, gazelle hound", 178 | 
"177": "Scottish deerhound, deerhound", 179 | "178": "Weimaraner", 180 | "179": "Staffordshire bullterrier, Staffordshire bull terrier", 181 | "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", 182 | "181": "Bedlington terrier", 183 | "182": "Border terrier", 184 | "183": "Kerry blue terrier", 185 | "184": "Irish terrier", 186 | "185": "Norfolk terrier", 187 | "186": "Norwich terrier", 188 | "187": "Yorkshire terrier", 189 | "188": "wire-haired fox terrier", 190 | "189": "Lakeland terrier", 191 | "190": "Sealyham terrier, Sealyham", 192 | "191": "Airedale, Airedale terrier", 193 | "192": "cairn, cairn terrier", 194 | "193": "Australian terrier", 195 | "194": "Dandie Dinmont, Dandie Dinmont terrier", 196 | "195": "Boston bull, Boston terrier", 197 | "196": "miniature schnauzer", 198 | "197": "giant schnauzer", 199 | "198": "standard schnauzer", 200 | "199": "Scotch terrier, Scottish terrier, Scottie", 201 | "200": "Tibetan terrier, chrysanthemum dog", 202 | "201": "silky terrier, Sydney silky", 203 | "202": "soft-coated wheaten terrier", 204 | "203": "West Highland white terrier", 205 | "204": "Lhasa, Lhasa apso", 206 | "205": "flat-coated retriever", 207 | "206": "curly-coated retriever", 208 | "207": "golden retriever", 209 | "208": "Labrador retriever", 210 | "209": "Chesapeake Bay retriever", 211 | "210": "German short-haired pointer", 212 | "211": "vizsla, Hungarian pointer", 213 | "212": "English setter", 214 | "213": "Irish setter, red setter", 215 | "214": "Gordon setter", 216 | "215": "Brittany spaniel", 217 | "216": "clumber, clumber spaniel", 218 | "217": "English springer, English springer spaniel", 219 | "218": "Welsh springer spaniel", 220 | "219": "cocker spaniel, English cocker spaniel, cocker", 221 | "220": "Sussex spaniel", 222 | "221": "Irish water spaniel", 223 | "222": "kuvasz", 224 | "223": "schipperke", 225 | "224": "groenendael", 226 | "225": "malinois", 227 | "226": "briard", 228 | 
"227": "kelpie", 229 | "228": "komondor", 230 | "229": "Old English sheepdog, bobtail", 231 | "230": "Shetland sheepdog, Shetland sheep dog, Shetland", 232 | "231": "collie", 233 | "232": "Border collie", 234 | "233": "Bouvier des Flandres, Bouviers des Flandres", 235 | "234": "Rottweiler", 236 | "235": "German shepherd, German shepherd dog, German police dog, alsatian", 237 | "236": "Doberman, Doberman pinscher", 238 | "237": "miniature pinscher", 239 | "238": "Greater Swiss Mountain dog", 240 | "239": "Bernese mountain dog", 241 | "240": "Appenzeller", 242 | "241": "EntleBucher", 243 | "242": "boxer", 244 | "243": "bull mastiff", 245 | "244": "Tibetan mastiff", 246 | "245": "French bulldog", 247 | "246": "Great Dane", 248 | "247": "Saint Bernard, St Bernard", 249 | "248": "Eskimo dog, husky", 250 | "249": "malamute, malemute, Alaskan malamute", 251 | "250": "Siberian husky", 252 | "251": "dalmatian, coach dog, carriage dog", 253 | "252": "affenpinscher, monkey pinscher, monkey dog", 254 | "253": "basenji", 255 | "254": "pug, pug-dog", 256 | "255": "Leonberg", 257 | "256": "Newfoundland, Newfoundland dog", 258 | "257": "Great Pyrenees", 259 | "258": "Samoyed, Samoyede", 260 | "259": "Pomeranian", 261 | "260": "chow, chow chow", 262 | "261": "keeshond", 263 | "262": "Brabancon griffon", 264 | "263": "Pembroke, Pembroke Welsh corgi", 265 | "264": "Cardigan, Cardigan Welsh corgi", 266 | "265": "toy poodle", 267 | "266": "miniature poodle", 268 | "267": "standard poodle", 269 | "268": "Mexican hairless", 270 | "269": "timber wolf, grey wolf, gray wolf, Canis lupus", 271 | "270": "white wolf, Arctic wolf, Canis lupus tundrarum", 272 | "271": "red wolf, maned wolf, Canis rufus, Canis niger", 273 | "272": "coyote, prairie wolf, brush wolf, Canis latrans", 274 | "273": "dingo, warrigal, warragal, Canis dingo", 275 | "274": "dhole, Cuon alpinus", 276 | "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", 277 | "276": "hyena, hyaena", 278 | "277": "red 
fox, Vulpes vulpes", 279 | "278": "kit fox, Vulpes macrotis", 280 | "279": "Arctic fox, white fox, Alopex lagopus", 281 | "280": "grey fox, gray fox, Urocyon cinereoargenteus", 282 | "281": "tabby, tabby cat", 283 | "282": "tiger cat", 284 | "283": "Persian cat", 285 | "284": "Siamese cat, Siamese", 286 | "285": "Egyptian cat", 287 | "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", 288 | "287": "lynx, catamount", 289 | "288": "leopard, Panthera pardus", 290 | "289": "snow leopard, ounce, Panthera uncia", 291 | "290": "jaguar, panther, Panthera onca, Felis onca", 292 | "291": "lion, king of beasts, Panthera leo", 293 | "292": "tiger, Panthera tigris", 294 | "293": "cheetah, chetah, Acinonyx jubatus", 295 | "294": "brown bear, bruin, Ursus arctos", 296 | "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", 297 | "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", 298 | "297": "sloth bear, Melursus ursinus, Ursus ursinus", 299 | "298": "mongoose", 300 | "299": "meerkat, mierkat", 301 | "300": "tiger beetle", 302 | "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", 303 | "302": "ground beetle, carabid beetle", 304 | "303": "long-horned beetle, longicorn, longicorn beetle", 305 | "304": "leaf beetle, chrysomelid", 306 | "305": "dung beetle", 307 | "306": "rhinoceros beetle", 308 | "307": "weevil", 309 | "308": "fly", 310 | "309": "bee", 311 | "310": "ant, emmet, pismire", 312 | "311": "grasshopper, hopper", 313 | "312": "cricket", 314 | "313": "walking stick, walkingstick, stick insect", 315 | "314": "cockroach, roach", 316 | "315": "mantis, mantid", 317 | "316": "cicada, cicala", 318 | "317": "leafhopper", 319 | "318": "lacewing, lacewing fly", 320 | "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", 321 | "320": "damselfly", 322 | "321": "admiral", 323 | "322": "ringlet, ringlet butterfly", 
324 | "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", 325 | "324": "cabbage butterfly", 326 | "325": "sulphur butterfly, sulfur butterfly", 327 | "326": "lycaenid, lycaenid butterfly", 328 | "327": "starfish, sea star", 329 | "328": "sea urchin", 330 | "329": "sea cucumber, holothurian", 331 | "330": "wood rabbit, cottontail, cottontail rabbit", 332 | "331": "hare", 333 | "332": "Angora, Angora rabbit", 334 | "333": "hamster", 335 | "334": "porcupine, hedgehog", 336 | "335": "fox squirrel, eastern fox squirrel, Sciurus niger", 337 | "336": "marmot", 338 | "337": "beaver", 339 | "338": "guinea pig, Cavia cobaya", 340 | "339": "sorrel", 341 | "340": "zebra", 342 | "341": "hog, pig, grunter, squealer, Sus scrofa", 343 | "342": "wild boar, boar, Sus scrofa", 344 | "343": "warthog", 345 | "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", 346 | "345": "ox", 347 | "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", 348 | "347": "bison", 349 | "348": "ram, tup", 350 | "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", 351 | "350": "ibex, Capra ibex", 352 | "351": "hartebeest", 353 | "352": "impala, Aepyceros melampus", 354 | "353": "gazelle", 355 | "354": "Arabian camel, dromedary, Camelus dromedarius", 356 | "355": "llama", 357 | "356": "weasel", 358 | "357": "mink", 359 | "358": "polecat, fitch, foulmart, foumart, Mustela putorius", 360 | "359": "black-footed ferret, ferret, Mustela nigripes", 361 | "360": "otter", 362 | "361": "skunk, polecat, wood pussy", 363 | "362": "badger", 364 | "363": "armadillo", 365 | "364": "three-toed sloth, ai, Bradypus tridactylus", 366 | "365": "orangutan, orang, orangutang, Pongo pygmaeus", 367 | "366": "gorilla, Gorilla gorilla", 368 | "367": "chimpanzee, chimp, Pan troglodytes", 369 | "368": "gibbon, Hylobates lar", 370 | "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", 371 | "370": "guenon, guenon 
monkey", 372 | "371": "patas, hussar monkey, Erythrocebus patas", 373 | "372": "baboon", 374 | "373": "macaque", 375 | "374": "langur", 376 | "375": "colobus, colobus monkey", 377 | "376": "proboscis monkey, Nasalis larvatus", 378 | "377": "marmoset", 379 | "378": "capuchin, ringtail, Cebus capucinus", 380 | "379": "howler monkey, howler", 381 | "380": "titi, titi monkey", 382 | "381": "spider monkey, Ateles geoffroyi", 383 | "382": "squirrel monkey, Saimiri sciureus", 384 | "383": "Madagascar cat, ring-tailed lemur, Lemur catta", 385 | "384": "indri, indris, Indri indri, Indri brevicaudatus", 386 | "385": "Indian elephant, Elephas maximus", 387 | "386": "African elephant, Loxodonta africana", 388 | "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", 389 | "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", 390 | "389": "barracouta, snoek", 391 | "390": "eel", 392 | "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", 393 | "392": "rock beauty, Holocanthus tricolor", 394 | "393": "anemone fish", 395 | "394": "sturgeon", 396 | "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", 397 | "396": "lionfish", 398 | "397": "puffer, pufferfish, blowfish, globefish", 399 | "398": "abacus", 400 | "399": "abaya", 401 | "400": "academic gown, academic robe, judge's robe", 402 | "401": "accordion, piano accordion, squeeze box", 403 | "402": "acoustic guitar", 404 | "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", 405 | "404": "airliner", 406 | "405": "airship, dirigible", 407 | "406": "altar", 408 | "407": "ambulance", 409 | "408": "amphibian, amphibious vehicle", 410 | "409": "analog clock", 411 | "410": "apiary, bee house", 412 | "411": "apron", 413 | "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", 414 | "413": "assault rifle, assault gun", 415 | "414": "backpack, back pack, knapsack, packsack, rucksack, 
haversack", 416 | "415": "bakery, bakeshop, bakehouse", 417 | "416": "balance beam, beam", 418 | "417": "balloon", 419 | "418": "ballpoint, ballpoint pen, ballpen, Biro", 420 | "419": "Band Aid", 421 | "420": "banjo", 422 | "421": "bannister, banister, balustrade, balusters, handrail", 423 | "422": "barbell", 424 | "423": "barber chair", 425 | "424": "barbershop", 426 | "425": "barn", 427 | "426": "barometer", 428 | "427": "barrel, cask", 429 | "428": "barrow, garden cart, lawn cart, wheelbarrow", 430 | "429": "baseball", 431 | "430": "basketball", 432 | "431": "bassinet", 433 | "432": "bassoon", 434 | "433": "bathing cap, swimming cap", 435 | "434": "bath towel", 436 | "435": "bathtub, bathing tub, bath, tub", 437 | "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", 438 | "437": "beacon, lighthouse, beacon light, pharos", 439 | "438": "beaker", 440 | "439": "bearskin, busby, shako", 441 | "440": "beer bottle", 442 | "441": "beer glass", 443 | "442": "bell cote, bell cot", 444 | "443": "bib", 445 | "444": "bicycle-built-for-two, tandem bicycle, tandem", 446 | "445": "bikini, two-piece", 447 | "446": "binder, ring-binder", 448 | "447": "binoculars, field glasses, opera glasses", 449 | "448": "birdhouse", 450 | "449": "boathouse", 451 | "450": "bobsled, bobsleigh, bob", 452 | "451": "bolo tie, bolo, bola tie, bola", 453 | "452": "bonnet, poke bonnet", 454 | "453": "bookcase", 455 | "454": "bookshop, bookstore, bookstall", 456 | "455": "bottlecap", 457 | "456": "bow", 458 | "457": "bow tie, bow-tie, bowtie", 459 | "458": "brass, memorial tablet, plaque", 460 | "459": "brassiere, bra, bandeau", 461 | "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", 462 | "461": "breastplate, aegis, egis", 463 | "462": "broom", 464 | "463": "bucket, pail", 465 | "464": "buckle", 466 | "465": "bulletproof vest", 467 | "466": "bullet train, bullet", 468 | "467": "butcher shop, meat market", 469 | "468": "cab, hack, taxi, 
taxicab", 470 | "469": "caldron, cauldron", 471 | "470": "candle, taper, wax light", 472 | "471": "cannon", 473 | "472": "canoe", 474 | "473": "can opener, tin opener", 475 | "474": "cardigan", 476 | "475": "car mirror", 477 | "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", 478 | "477": "carpenter's kit, tool kit", 479 | "478": "carton", 480 | "479": "car wheel", 481 | "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", 482 | "481": "cassette", 483 | "482": "cassette player", 484 | "483": "castle", 485 | "484": "catamaran", 486 | "485": "CD player", 487 | "486": "cello, violoncello", 488 | "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", 489 | "488": "chain", 490 | "489": "chainlink fence", 491 | "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", 492 | "491": "chain saw, chainsaw", 493 | "492": "chest", 494 | "493": "chiffonier, commode", 495 | "494": "chime, bell, gong", 496 | "495": "china cabinet, china closet", 497 | "496": "Christmas stocking", 498 | "497": "church, church building", 499 | "498": "cinema, movie theater, movie theatre, movie house, picture palace", 500 | "499": "cleaver, meat cleaver, chopper", 501 | "500": "cliff dwelling", 502 | "501": "cloak", 503 | "502": "clog, geta, patten, sabot", 504 | "503": "cocktail shaker", 505 | "504": "coffee mug", 506 | "505": "coffeepot", 507 | "506": "coil, spiral, volute, whorl, helix", 508 | "507": "combination lock", 509 | "508": "computer keyboard, keypad", 510 | "509": "confectionery, confectionary, candy store", 511 | "510": "container ship, containership, container vessel", 512 | "511": "convertible", 513 | "512": "corkscrew, bottle screw", 514 | "513": "cornet, horn, trumpet, trump", 515 | "514": "cowboy boot", 516 | "515": "cowboy hat, ten-gallon hat", 517 | "516": "cradle", 518 | "517": "crane", 519 | "518": "crash helmet", 520 | "519": 
"crate", 521 | "520": "crib, cot", 522 | "521": "Crock Pot", 523 | "522": "croquet ball", 524 | "523": "crutch", 525 | "524": "cuirass", 526 | "525": "dam, dike, dyke", 527 | "526": "desk", 528 | "527": "desktop computer", 529 | "528": "dial telephone, dial phone", 530 | "529": "diaper, nappy, napkin", 531 | "530": "digital clock", 532 | "531": "digital watch", 533 | "532": "dining table, board", 534 | "533": "dishrag, dishcloth", 535 | "534": "dishwasher, dish washer, dishwashing machine", 536 | "535": "disk brake, disc brake", 537 | "536": "dock, dockage, docking facility", 538 | "537": "dogsled, dog sled, dog sleigh", 539 | "538": "dome", 540 | "539": "doormat, welcome mat", 541 | "540": "drilling platform, offshore rig", 542 | "541": "drum, membranophone, tympan", 543 | "542": "drumstick", 544 | "543": "dumbbell", 545 | "544": "Dutch oven", 546 | "545": "electric fan, blower", 547 | "546": "electric guitar", 548 | "547": "electric locomotive", 549 | "548": "entertainment center", 550 | "549": "envelope", 551 | "550": "espresso maker", 552 | "551": "face powder", 553 | "552": "feather boa, boa", 554 | "553": "file, file cabinet, filing cabinet", 555 | "554": "fireboat", 556 | "555": "fire engine, fire truck", 557 | "556": "fire screen, fireguard", 558 | "557": "flagpole, flagstaff", 559 | "558": "flute, transverse flute", 560 | "559": "folding chair", 561 | "560": "football helmet", 562 | "561": "forklift", 563 | "562": "fountain", 564 | "563": "fountain pen", 565 | "564": "four-poster", 566 | "565": "freight car", 567 | "566": "French horn, horn", 568 | "567": "frying pan, frypan, skillet", 569 | "568": "fur coat", 570 | "569": "garbage truck, dustcart", 571 | "570": "gasmask, respirator, gas helmet", 572 | "571": "gas pump, gasoline pump, petrol pump, island dispenser", 573 | "572": "goblet", 574 | "573": "go-kart", 575 | "574": "golf ball", 576 | "575": "golfcart, golf cart", 577 | "576": "gondola", 578 | "577": "gong, tam-tam", 579 | "578": "gown", 580 | 
"579": "grand piano, grand", 581 | "580": "greenhouse, nursery, glasshouse", 582 | "581": "grille, radiator grille", 583 | "582": "grocery store, grocery, food market, market", 584 | "583": "guillotine", 585 | "584": "hair slide", 586 | "585": "hair spray", 587 | "586": "half track", 588 | "587": "hammer", 589 | "588": "hamper", 590 | "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", 591 | "590": "hand-held computer, hand-held microcomputer", 592 | "591": "handkerchief, hankie, hanky, hankey", 593 | "592": "hard disc, hard disk, fixed disk", 594 | "593": "harmonica, mouth organ, harp, mouth harp", 595 | "594": "harp", 596 | "595": "harvester, reaper", 597 | "596": "hatchet", 598 | "597": "holster", 599 | "598": "home theater, home theatre", 600 | "599": "honeycomb", 601 | "600": "hook, claw", 602 | "601": "hoopskirt, crinoline", 603 | "602": "horizontal bar, high bar", 604 | "603": "horse cart, horse-cart", 605 | "604": "hourglass", 606 | "605": "iPod", 607 | "606": "iron, smoothing iron", 608 | "607": "jack-o'-lantern", 609 | "608": "jean, blue jean, denim", 610 | "609": "jeep, landrover", 611 | "610": "jersey, T-shirt, tee shirt", 612 | "611": "jigsaw puzzle", 613 | "612": "jinrikisha, ricksha, rickshaw", 614 | "613": "joystick", 615 | "614": "kimono", 616 | "615": "knee pad", 617 | "616": "knot", 618 | "617": "lab coat, laboratory coat", 619 | "618": "ladle", 620 | "619": "lampshade, lamp shade", 621 | "620": "laptop, laptop computer", 622 | "621": "lawn mower, mower", 623 | "622": "lens cap, lens cover", 624 | "623": "letter opener, paper knife, paperknife", 625 | "624": "library", 626 | "625": "lifeboat", 627 | "626": "lighter, light, igniter, ignitor", 628 | "627": "limousine, limo", 629 | "628": "liner, ocean liner", 630 | "629": "lipstick, lip rouge", 631 | "630": "Loafer", 632 | "631": "lotion", 633 | "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", 634 | "633": "loupe, jeweler's loupe", 635 | "634": 
"lumbermill, sawmill", 636 | "635": "magnetic compass", 637 | "636": "mailbag, postbag", 638 | "637": "mailbox, letter box", 639 | "638": "maillot", 640 | "639": "maillot, tank suit", 641 | "640": "manhole cover", 642 | "641": "maraca", 643 | "642": "marimba, xylophone", 644 | "643": "mask", 645 | "644": "matchstick", 646 | "645": "maypole", 647 | "646": "maze, labyrinth", 648 | "647": "measuring cup", 649 | "648": "medicine chest, medicine cabinet", 650 | "649": "megalith, megalithic structure", 651 | "650": "microphone, mike", 652 | "651": "microwave, microwave oven", 653 | "652": "military uniform", 654 | "653": "milk can", 655 | "654": "minibus", 656 | "655": "miniskirt, mini", 657 | "656": "minivan", 658 | "657": "missile", 659 | "658": "mitten", 660 | "659": "mixing bowl", 661 | "660": "mobile home, manufactured home", 662 | "661": "Model T", 663 | "662": "modem", 664 | "663": "monastery", 665 | "664": "monitor", 666 | "665": "moped", 667 | "666": "mortar", 668 | "667": "mortarboard", 669 | "668": "mosque", 670 | "669": "mosquito net", 671 | "670": "motor scooter, scooter", 672 | "671": "mountain bike, all-terrain bike, off-roader", 673 | "672": "mountain tent", 674 | "673": "mouse, computer mouse", 675 | "674": "mousetrap", 676 | "675": "moving van", 677 | "676": "muzzle", 678 | "677": "nail", 679 | "678": "neck brace", 680 | "679": "necklace", 681 | "680": "nipple", 682 | "681": "notebook, notebook computer", 683 | "682": "obelisk", 684 | "683": "oboe, hautboy, hautbois", 685 | "684": "ocarina, sweet potato", 686 | "685": "odometer, hodometer, mileometer, milometer", 687 | "686": "oil filter", 688 | "687": "organ, pipe organ", 689 | "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", 690 | "689": "overskirt", 691 | "690": "oxcart", 692 | "691": "oxygen mask", 693 | "692": "packet", 694 | "693": "paddle, boat paddle", 695 | "694": "paddlewheel, paddle wheel", 696 | "695": "padlock", 697 | "696": "paintbrush", 698 | "697": "pajama, pyjama, pj's, 
jammies", 699 | "698": "palace", 700 | "699": "panpipe, pandean pipe, syrinx", 701 | "700": "paper towel", 702 | "701": "parachute, chute", 703 | "702": "parallel bars, bars", 704 | "703": "park bench", 705 | "704": "parking meter", 706 | "705": "passenger car, coach, carriage", 707 | "706": "patio, terrace", 708 | "707": "pay-phone, pay-station", 709 | "708": "pedestal, plinth, footstall", 710 | "709": "pencil box, pencil case", 711 | "710": "pencil sharpener", 712 | "711": "perfume, essence", 713 | "712": "Petri dish", 714 | "713": "photocopier", 715 | "714": "pick, plectrum, plectron", 716 | "715": "pickelhaube", 717 | "716": "picket fence, paling", 718 | "717": "pickup, pickup truck", 719 | "718": "pier", 720 | "719": "piggy bank, penny bank", 721 | "720": "pill bottle", 722 | "721": "pillow", 723 | "722": "ping-pong ball", 724 | "723": "pinwheel", 725 | "724": "pirate, pirate ship", 726 | "725": "pitcher, ewer", 727 | "726": "plane, carpenter's plane, woodworking plane", 728 | "727": "planetarium", 729 | "728": "plastic bag", 730 | "729": "plate rack", 731 | "730": "plow, plough", 732 | "731": "plunger, plumber's helper", 733 | "732": "Polaroid camera, Polaroid Land camera", 734 | "733": "pole", 735 | "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", 736 | "735": "poncho", 737 | "736": "pool table, billiard table, snooker table", 738 | "737": "pop bottle, soda bottle", 739 | "738": "pot, flowerpot", 740 | "739": "potter's wheel", 741 | "740": "power drill", 742 | "741": "prayer rug, prayer mat", 743 | "742": "printer", 744 | "743": "prison, prison house", 745 | "744": "projectile, missile", 746 | "745": "projector", 747 | "746": "puck, hockey puck", 748 | "747": "punching bag, punch bag, punching ball, punchball", 749 | "748": "purse", 750 | "749": "quill, quill pen", 751 | "750": "quilt, comforter, comfort, puff", 752 | "751": "racer, race car, racing car", 753 | "752": "racket, racquet", 754 | "753": "radiator", 755 | "754": 
"radio, wireless", 756 | "755": "radio telescope, radio reflector", 757 | "756": "rain barrel", 758 | "757": "recreational vehicle, RV, R.V.", 759 | "758": "reel", 760 | "759": "reflex camera", 761 | "760": "refrigerator, icebox", 762 | "761": "remote control, remote", 763 | "762": "restaurant, eating house, eating place, eatery", 764 | "763": "revolver, six-gun, six-shooter", 765 | "764": "rifle", 766 | "765": "rocking chair, rocker", 767 | "766": "rotisserie", 768 | "767": "rubber eraser, rubber, pencil eraser", 769 | "768": "rugby ball", 770 | "769": "rule, ruler", 771 | "770": "running shoe", 772 | "771": "safe", 773 | "772": "safety pin", 774 | "773": "saltshaker, salt shaker", 775 | "774": "sandal", 776 | "775": "sarong", 777 | "776": "sax, saxophone", 778 | "777": "scabbard", 779 | "778": "scale, weighing machine", 780 | "779": "school bus", 781 | "780": "schooner", 782 | "781": "scoreboard", 783 | "782": "screen, CRT screen", 784 | "783": "screw", 785 | "784": "screwdriver", 786 | "785": "seat belt, seatbelt", 787 | "786": "sewing machine", 788 | "787": "shield, buckler", 789 | "788": "shoe shop, shoe-shop, shoe store", 790 | "789": "shoji", 791 | "790": "shopping basket", 792 | "791": "shopping cart", 793 | "792": "shovel", 794 | "793": "shower cap", 795 | "794": "shower curtain", 796 | "795": "ski", 797 | "796": "ski mask", 798 | "797": "sleeping bag", 799 | "798": "slide rule, slipstick", 800 | "799": "sliding door", 801 | "800": "slot, one-armed bandit", 802 | "801": "snorkel", 803 | "802": "snowmobile", 804 | "803": "snowplow, snowplough", 805 | "804": "soap dispenser", 806 | "805": "soccer ball", 807 | "806": "sock", 808 | "807": "solar dish, solar collector, solar furnace", 809 | "808": "sombrero", 810 | "809": "soup bowl", 811 | "810": "space bar", 812 | "811": "space heater", 813 | "812": "space shuttle", 814 | "813": "spatula", 815 | "814": "speedboat", 816 | "815": "spider web, spider's web", 817 | "816": "spindle", 818 | "817": "sports car, 
sport car", 819 | "818": "spotlight, spot", 820 | "819": "stage", 821 | "820": "steam locomotive", 822 | "821": "steel arch bridge", 823 | "822": "steel drum", 824 | "823": "stethoscope", 825 | "824": "stole", 826 | "825": "stone wall", 827 | "826": "stopwatch, stop watch", 828 | "827": "stove", 829 | "828": "strainer", 830 | "829": "streetcar, tram, tramcar, trolley, trolley car", 831 | "830": "stretcher", 832 | "831": "studio couch, day bed", 833 | "832": "stupa, tope", 834 | "833": "submarine, pigboat, sub, U-boat", 835 | "834": "suit, suit of clothes", 836 | "835": "sundial", 837 | "836": "sunglass", 838 | "837": "sunglasses, dark glasses, shades", 839 | "838": "sunscreen, sunblock, sun blocker", 840 | "839": "suspension bridge", 841 | "840": "swab, swob, mop", 842 | "841": "sweatshirt", 843 | "842": "swimming trunks, bathing trunks", 844 | "843": "swing", 845 | "844": "switch, electric switch, electrical switch", 846 | "845": "syringe", 847 | "846": "table lamp", 848 | "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", 849 | "848": "tape player", 850 | "849": "teapot", 851 | "850": "teddy, teddy bear", 852 | "851": "television, television system", 853 | "852": "tennis ball", 854 | "853": "thatch, thatched roof", 855 | "854": "theater curtain, theatre curtain", 856 | "855": "thimble", 857 | "856": "thresher, thrasher, threshing machine", 858 | "857": "throne", 859 | "858": "tile roof", 860 | "859": "toaster", 861 | "860": "tobacco shop, tobacconist shop, tobacconist", 862 | "861": "toilet seat", 863 | "862": "torch", 864 | "863": "totem pole", 865 | "864": "tow truck, tow car, wrecker", 866 | "865": "toyshop", 867 | "866": "tractor", 868 | "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", 869 | "868": "tray", 870 | "869": "trench coat", 871 | "870": "tricycle, trike, velocipede", 872 | "871": "trimaran", 873 | "872": "tripod", 874 | "873": "triumphal arch", 875 | "874": "trolleybus, trolley coach, 
trackless trolley", 876 | "875": "trombone", 877 | "876": "tub, vat", 878 | "877": "turnstile", 879 | "878": "typewriter keyboard", 880 | "879": "umbrella", 881 | "880": "unicycle, monocycle", 882 | "881": "upright, upright piano", 883 | "882": "vacuum, vacuum cleaner", 884 | "883": "vase", 885 | "884": "vault", 886 | "885": "velvet", 887 | "886": "vending machine", 888 | "887": "vestment", 889 | "888": "viaduct", 890 | "889": "violin, fiddle", 891 | "890": "volleyball", 892 | "891": "waffle iron", 893 | "892": "wall clock", 894 | "893": "wallet, billfold, notecase, pocketbook", 895 | "894": "wardrobe, closet, press", 896 | "895": "warplane, military plane", 897 | "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", 898 | "897": "washer, automatic washer, washing machine", 899 | "898": "water bottle", 900 | "899": "water jug", 901 | "900": "water tower", 902 | "901": "whiskey jug", 903 | "902": "whistle", 904 | "903": "wig", 905 | "904": "window screen", 906 | "905": "window shade", 907 | "906": "Windsor tie", 908 | "907": "wine bottle", 909 | "908": "wing", 910 | "909": "wok", 911 | "910": "wooden spoon", 912 | "911": "wool, woolen, woollen", 913 | "912": "worm fence, snake fence, snake-rail fence, Virginia fence", 914 | "913": "wreck", 915 | "914": "yawl", 916 | "915": "yurt", 917 | "916": "web site, website, internet site, site", 918 | "917": "comic book", 919 | "918": "crossword puzzle, crossword", 920 | "919": "street sign", 921 | "920": "traffic light, traffic signal, stoplight", 922 | "921": "book jacket, dust cover, dust jacket, dust wrapper", 923 | "922": "menu", 924 | "923": "plate", 925 | "924": "guacamole", 926 | "925": "consomme", 927 | "926": "hot pot, hotpot", 928 | "927": "trifle", 929 | "928": "ice cream, icecream", 930 | "929": "ice lolly, lolly, lollipop, popsicle", 931 | "930": "French loaf", 932 | "931": "bagel, beigel", 933 | "932": "pretzel", 934 | "933": "cheeseburger", 935 | "934": "hotdog, hot dog, red hot", 936 | "935": 
"mashed potato", 937 | "936": "head cabbage", 938 | "937": "broccoli", 939 | "938": "cauliflower", 940 | "939": "zucchini, courgette", 941 | "940": "spaghetti squash", 942 | "941": "acorn squash", 943 | "942": "butternut squash", 944 | "943": "cucumber, cuke", 945 | "944": "artichoke, globe artichoke", 946 | "945": "bell pepper", 947 | "946": "cardoon", 948 | "947": "mushroom", 949 | "948": "Granny Smith", 950 | "949": "strawberry", 951 | "950": "orange", 952 | "951": "lemon", 953 | "952": "fig", 954 | "953": "pineapple, ananas", 955 | "954": "banana", 956 | "955": "jackfruit, jak, jack", 957 | "956": "custard apple", 958 | "957": "pomegranate", 959 | "958": "hay", 960 | "959": "carbonara", 961 | "960": "chocolate sauce, chocolate syrup", 962 | "961": "dough", 963 | "962": "meat loaf, meatloaf", 964 | "963": "pizza, pizza pie", 965 | "964": "potpie", 966 | "965": "burrito", 967 | "966": "red wine", 968 | "967": "espresso", 969 | "968": "cup", 970 | "969": "eggnog", 971 | "970": "alp", 972 | "971": "bubble", 973 | "972": "cliff, drop, drop-off", 974 | "973": "coral reef", 975 | "974": "geyser", 976 | "975": "lakeside, lakeshore", 977 | "976": "promontory, headland, head, foreland", 978 | "977": "sandbar, sand bar", 979 | "978": "seashore, coast, seacoast, sea-coast", 980 | "979": "valley, vale", 981 | "980": "volcano", 982 | "981": "ballplayer, baseball player", 983 | "982": "groom, bridegroom", 984 | "983": "scuba diver", 985 | "984": "rapeseed", 986 | "985": "daisy", 987 | "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 988 | "987": "corn", 989 | "988": "acorn", 990 | "989": "hip, rose hip, rosehip", 991 | "990": "buckeye, horse chestnut, conker", 992 | "991": "coral fungus", 993 | "992": "agaric", 994 | "993": "gyromitra", 995 | "994": "stinkhorn, carrion fungus", 996 | "995": "earthstar", 997 | "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", 998 | "997": "bolete", 999 | 
"998": "ear, spike, capitulum", 1000 | "999": "toilet tissue, toilet paper, bathroom tissue"} -------------------------------------------------------------------------------- /data/inpainting/kate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/inpainting/kate.png -------------------------------------------------------------------------------- /data/inpainting/kate_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/inpainting/kate_mask.png -------------------------------------------------------------------------------- /data/inpainting/library.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/inpainting/library.png -------------------------------------------------------------------------------- /data/inpainting/library_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/inpainting/library_mask.png -------------------------------------------------------------------------------- /data/inpainting/vase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/inpainting/vase.png -------------------------------------------------------------------------------- /data/inpainting/vase_mask.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/inpainting/vase_mask.png -------------------------------------------------------------------------------- /data/restoration/barbara.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/restoration/barbara.png -------------------------------------------------------------------------------- /data/restoration/kate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/restoration/kate.png -------------------------------------------------------------------------------- /data/sr/zebra_GT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/sr/zebra_GT.png -------------------------------------------------------------------------------- /data/sr/zebra_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/sr/zebra_crop.png -------------------------------------------------------------------------------- /data/teaser_compiled.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/data/teaser_compiled.jpg -------------------------------------------------------------------------------- /denoising.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | 
"source": [ 7 | "Code for **\"Blind restoration of a JPEG-compressed image\"** and **\"Blind image denoising\"** figures. Select `fname` below to switch between the two.\n", 8 | "\n", 9 | "- To see overfitting set `num_iter` to a large value." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "\"\"\"\n", 19 | "*Uncomment if running on colab* \n", 20 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 21 | "\"\"\"\n", 22 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 23 | "# !mv deep-image-prior/* ./" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Import libs" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "from __future__ import print_function\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "%matplotlib inline\n", 42 | "\n", 43 | "import os\n", 44 | "#os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n", 45 | "\n", 46 | "import numpy as np\n", 47 | "from models import *\n", 48 | "\n", 49 | "import torch\n", 50 | "import torch.optim\n", 51 | "\n", 52 | "from skimage.measure import compare_psnr\n", 53 | "from utils.denoising_utils import *\n", 54 | "\n", 55 | "torch.backends.cudnn.enabled = True\n", 56 | "torch.backends.cudnn.benchmark =True\n", 57 | "dtype = torch.cuda.FloatTensor\n", 58 | "\n", 59 | "imsize =-1\n", 60 | "PLOT = True\n", 61 | "sigma = 25\n", 62 | "sigma_ = sigma/255." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "# deJPEG \n", 72 | "# fname = 'data/denoising/snail.jpg'\n", 73 | "\n", 74 | "## denoising\n", 75 | "fname = 'data/denoising/F16_GT.png'" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "# Load image" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "if fname == 'data/denoising/snail.jpg':\n", 92 | " img_noisy_pil = crop_image(get_image(fname, imsize)[0], d=32)\n", 93 | " img_noisy_np = pil_to_np(img_noisy_pil)\n", 94 | " \n", 95 | " # As we don't have ground truth\n", 96 | " img_pil = img_noisy_pil\n", 97 | " img_np = img_noisy_np\n", 98 | " \n", 99 | " if PLOT:\n", 100 | " plot_image_grid([img_np], 4, 5);\n", 101 | " \n", 102 | "elif fname == 'data/denoising/F16_GT.png':\n", 103 | " # Add synthetic noise\n", 104 | " img_pil = crop_image(get_image(fname, imsize)[0], d=32)\n", 105 | " img_np = pil_to_np(img_pil)\n", 106 | " \n", 107 | " img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)\n", 108 | " \n", 109 | " if PLOT:\n", 110 | " plot_image_grid([img_np, img_noisy_np], 4, 6);\n", 111 | "else:\n", 112 | " assert False" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "# Setup" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "INPUT = 'noise' # 'meshgrid'\n", 129 | "pad = 'reflection'\n", 130 | "OPT_OVER = 'net' # 'net,input'\n", 131 | "\n", 132 | "reg_noise_std = 1./30. # set to 1./20. 
for sigma=50\n", 133 | "LR = 0.01\n", 134 | "\n", 135 | "OPTIMIZER='adam' # 'LBFGS'\n", 136 | "show_every = 100\n", 137 | "exp_weight=0.99\n", 138 | "\n", 139 | "if fname == 'data/denoising/snail.jpg':\n", 140 | " num_iter = 2400\n", 141 | " input_depth = 3\n", 142 | " figsize = 5 \n", 143 | " \n", 144 | " net = skip(\n", 145 | " input_depth, 3, \n", 146 | " num_channels_down = [8, 16, 32, 64, 128], \n", 147 | " num_channels_up = [8, 16, 32, 64, 128],\n", 148 | " num_channels_skip = [0, 0, 0, 4, 4], \n", 149 | " upsample_mode='bilinear',\n", 150 | " need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU')\n", 151 | "\n", 152 | " net = net.type(dtype)\n", 153 | "\n", 154 | "elif fname == 'data/denoising/F16_GT.png':\n", 155 | " num_iter = 3000\n", 156 | " input_depth = 32 \n", 157 | " figsize = 4 \n", 158 | " \n", 159 | " \n", 160 | " net = get_net(input_depth, 'skip', pad,\n", 161 | " skip_n33d=128, \n", 162 | " skip_n33u=128, \n", 163 | " skip_n11=4, \n", 164 | " num_scales=5,\n", 165 | " upsample_mode='bilinear').type(dtype)\n", 166 | "\n", 167 | "else:\n", 168 | " assert False\n", 169 | " \n", 170 | "net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()\n", 171 | "\n", 172 | "# Compute number of parameters\n", 173 | "s = sum([np.prod(list(p.size())) for p in net.parameters()]); \n", 174 | "print ('Number of params: %d' % s)\n", 175 | "\n", 176 | "# Loss\n", 177 | "mse = torch.nn.MSELoss().type(dtype)\n", 178 | "\n", 179 | "img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "# Optimize" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "scrolled": true 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "net_input_saved = net_input.detach().clone()\n", 198 | "noise = net_input.detach().clone()\n", 199 | "out_avg = None\n", 200 | "last_net = 
None\n", 201 | "psrn_noisy_last = 0\n", 202 | "\n", 203 | "i = 0\n", 204 | "def closure():\n", 205 | " \n", 206 | " global i, out_avg, psrn_noisy_last, last_net, net_input\n", 207 | " \n", 208 | " if reg_noise_std > 0:\n", 209 | " net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n", 210 | " \n", 211 | " out = net(net_input)\n", 212 | " \n", 213 | " # Smoothing\n", 214 | " if out_avg is None:\n", 215 | " out_avg = out.detach()\n", 216 | " else:\n", 217 | " out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)\n", 218 | " \n", 219 | " total_loss = mse(out, img_noisy_torch)\n", 220 | " total_loss.backward()\n", 221 | " \n", 222 | " \n", 223 | " psrn_noisy = compare_psnr(img_noisy_np, out.detach().cpu().numpy()[0]) \n", 224 | " psrn_gt = compare_psnr(img_np, out.detach().cpu().numpy()[0]) \n", 225 | " psrn_gt_sm = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0]) \n", 226 | " \n", 227 | " # Note that we do not have GT for the \"snail\" example\n", 228 | " # So 'PSRN_gt', 'PSNR_gt_sm' make no sense\n", 229 | " print ('Iteration %05d Loss %f PSNR_noisy: %f PSRN_gt: %f PSNR_gt_sm: %f' % (i, total_loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\\r', end='')\n", 230 | " if PLOT and i % show_every == 0:\n", 231 | " out_np = torch_to_np(out)\n", 232 | " plot_image_grid([np.clip(out_np, 0, 1), \n", 233 | " np.clip(torch_to_np(out_avg), 0, 1)], factor=figsize, nrow=1)\n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " # Backtracking\n", 238 | " if i % show_every:\n", 239 | " if psrn_noisy - psrn_noisy_last < -5: \n", 240 | " print('Falling back to previous checkpoint.')\n", 241 | "\n", 242 | " for new_param, net_param in zip(last_net, net.parameters()):\n", 243 | " net_param.data.copy_(new_param.cuda())\n", 244 | "\n", 245 | " return total_loss*0\n", 246 | " else:\n", 247 | " last_net = [x.detach().cpu() for x in net.parameters()]\n", 248 | " psrn_noisy_last = psrn_noisy\n", 249 | " \n", 250 | " i += 1\n", 251 | "\n", 252 | " return 
total_loss\n", 253 | "\n", 254 | "p = get_params(OPT_OVER, net, net_input)\n", 255 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "out_np = torch_to_np(net(net_input))\n", 265 | "q = plot_image_grid([np.clip(out_np, 0, 1), img_np], factor=13);" 266 | ] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "Python 3", 272 | "language": "python", 273 | "name": "python3" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.6.9" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 2 290 | } 291 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deep-image-prior 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - jupyter 7 | - nb_conda 8 | - numpy 9 | - pyyaml 10 | - mkl 11 | - setuptools 12 | - cmake 13 | - cffi 14 | - pytorch=0.4 15 | - matplotlib 16 | - scikit-image 17 | - torchvision 18 | -------------------------------------------------------------------------------- /feature_inversion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for **\"AlexNet inversion\"** figure from the main paper and **\"VGG inversion\"** from supmat." 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"\n", 17 | "*Uncomment if running on colab* \n", 18 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 19 | "\"\"\"\n", 20 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 21 | "# !mv deep-image-prior/* ./" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Import libs" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from __future__ import print_function\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "import argparse\n", 42 | "import os\n", 43 | "#os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", 44 | "\n", 45 | "import numpy as np\n", 46 | "from models import *\n", 47 | "\n", 48 | "import torch\n", 49 | "import torch.optim\n", 50 | "\n", 51 | "from utils.feature_inversion_utils import *\n", 52 | "from utils.perceptual_loss.perceptual_loss import get_pretrained_net\n", 53 | "from utils.common_utils import *\n", 54 | "\n", 55 | "torch.backends.cudnn.enabled = True\n", 56 | "torch.backends.cudnn.benchmark =True\n", 57 | "dtype = torch.cuda.FloatTensor\n", 58 | "\n", 59 | "PLOT = True\n", 60 | "fname = './data/feature_inversion/building.jpg'\n", 61 | "\n", 62 | "pretrained_net = 'alexnet_caffe' # 'vgg19_caffe'\n", 63 | "layers_to_use = 'fc6' # comma-separated string of layer names e.g. 
'fc6,fc7'" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "# Setup pretrained net" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "cnn = get_pretrained_net(pretrained_net).type(dtype)\n", 80 | "\n", 81 | "opt_content = {'layers': layers_to_use, 'what':'features'}\n", 82 | "\n", 83 | "# Remove the layers we don't need \n", 84 | "keys = [x for x in cnn._modules.keys()]\n", 85 | "max_idx = max(keys.index(x) for x in opt_content['layers'].split(','))\n", 86 | "for k in keys[max_idx+1:]:\n", 87 | " cnn._modules.pop(k)\n", 88 | " \n", 89 | "print(cnn)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "# Load image" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# Target imsize \n", 106 | "imsize = 227 if pretrained_net == 'alexnet' else 224\n", 107 | "\n", 108 | "# Something divisible by a power of two\n", 109 | "imsize_net = 256\n", 110 | "\n", 111 | "# VGG and Alexnet need input to be correctly normalized\n", 112 | "preprocess, deprocess = get_preprocessor(imsize), get_deprocessor()\n", 113 | "\n", 114 | "\n", 115 | "img_content_pil, img_content_np = get_image(fname, imsize)\n", 116 | "img_content_prerocessed = preprocess(img_content_pil)[None,:].type(dtype)\n", 117 | "\n", 118 | "img_content_pil" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "# Setup matcher and net" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "scrolled": false 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "matcher_content = get_matcher(cnn, opt_content)\n", 137 | "\n", 138 | "matcher_content.mode = 'store'\n", 139 | "cnn(img_content_prerocessed);" 140 | ] 141 | }, 142 | { 143 | "cell_type": 
"code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "INPUT = 'noise'\n", 149 | "pad = 'zero' # 'refection'\n", 150 | "OPT_OVER = 'net' #'net,input'\n", 151 | "OPTIMIZER = 'adam' # 'LBFGS'\n", 152 | "LR = 0.001\n", 153 | "\n", 154 | "num_iter = 3100\n", 155 | "\n", 156 | "input_depth = 32\n", 157 | "net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],\n", 167 | " num_channels_up = [16, 32, 64, 128, 128, 128],\n", 168 | " num_channels_skip = [4, 4, 4, 4, 4, 4], \n", 169 | " filter_size_down = [7, 7, 5, 5, 3, 3], filter_size_up = [7, 7, 5, 5, 3, 3], \n", 170 | " upsample_mode='nearest', downsample_mode='avg',\n", 171 | " need_sigmoid=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n", 172 | "\n", 173 | "# Compute number of parameters\n", 174 | "s = sum(np.prod(list(p.size())) for p in net.parameters())\n", 175 | "print ('Number of params: %d' % s)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "# Optimize" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "def closure():\n", 192 | " \n", 193 | " global i\n", 194 | " \n", 195 | " out = net(net_input)[:, :, :imsize, :imsize]\n", 196 | " \n", 197 | " cnn(vgg_preprocess_var(out))\n", 198 | " total_loss = sum(matcher_content.losses.values())\n", 199 | " total_loss.backward()\n", 200 | " \n", 201 | " print ('Iteration %05d Loss %.3f' % (i, total_loss.item()), '\\r', end='')\n", 202 | " if PLOT and i % 200 == 0:\n", 203 | " out_np = np.clip(torch_to_np(out), 0, 1)\n", 204 | " plot_image_grid([out_np], 3, 3);\n", 205 | "\n", 206 | " i += 1\n", 207 | " \n", 208 | " return 
total_loss" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "i=0\n", 218 | "matcher_content.mode = 'match'\n", 219 | "p = get_params(OPT_OVER, net, net_input)\n", 220 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "# Result" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "out = net(net_input)[:, :, :imsize, :imsize]\n", 237 | "plot_image_grid([torch_to_np(out)], 3, 3);" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "The code above was used to produce the images from the paper." 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "# Appendix: more noise" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "We also found that adding heavy noise sometimes improves the results (see below). Interestingly, the network manages to adapt to very heavy noise."
259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "input_depth = 2\n", 268 | "net_input = get_noise(input_depth, INPUT, imsize_net).type(dtype).detach()\n", 269 | "\n", 270 | "net = skip(input_depth, 3, num_channels_down = [16, 32, 64, 128, 128, 128],\n", 271 | " num_channels_up = [16, 32, 64, 128, 128, 128],\n", 272 | " num_channels_skip = [4, 4, 4, 4, 4, 4], \n", 273 | " filter_size_up = [7, 7, 5, 5, 3, 3], filter_size_down = [7, 7, 5, 5, 3, 3],\n", 274 | " upsample_mode='nearest', downsample_mode='avg',\n", 275 | " need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "def closure():\n", 285 | " \n", 286 | " global i \n", 287 | " if i < 10000:\n", 288 | " # Weight noise\n", 289 | " for n in [x for x in net.parameters() if len(x) == 4]:\n", 290 | " n = n + n.detach().clone().normal_()*n.std()/50\n", 291 | " \n", 292 | " # Input noise\n", 293 | " net_input = net_input_saved + (noise.normal_() * 10)\n", 294 | "\n", 295 | " elif i < 15000:\n", 296 | " # Weight noise\n", 297 | " for n in [x for x in net.parameters() if len(x) == 4]:\n", 298 | " n = n + n.detach().clone().normal_()*n.std()/100\n", 299 | " \n", 300 | " # Input noise\n", 301 | " net_input = net_input_saved + (noise.normal_() * 2)\n", 302 | " \n", 303 | " elif i < 20000:\n", 304 | " # Input noise\n", 305 | " net_input = net_input_saved + (noise.normal_() / 2)\n", 306 | " \n", 307 | " \n", 308 | " out = net(net_input)[:, :, :imsize, :imsize]\n", 309 | " \n", 310 | " cnn(vgg_preprocess_var(out))\n", 311 | " total_loss = sum(matcher_content.losses.values())\n", 312 | " total_loss.backward()\n", 313 | " \n", 314 | " print ('Iteration %05d Loss %.3f' % (i, total_loss.item()), '\\r', end='')\n", 315 | " if PLOT and i % 1000==0:\n", 316 | " 
out_np = np.clip(torch_to_np(out), 0, 1)\n", 317 | " plot_image_grid([out_np], 3, 3);\n", 318 | "\n", 319 | " i += 1\n", 320 | " \n", 321 | " return total_loss" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "num_iter = 20000\n", 331 | "LR = 0.01\n", 332 | "\n", 333 | "net_input_saved = net_input.detach().clone()\n", 334 | "noise = net_input.detach().clone()\n", 335 | "i=0\n", 336 | "\n", 337 | "matcher_content.mode = 'match'\n", 338 | "p = get_params(OPT_OVER, net, net_input)\n", 339 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [] 348 | } 349 | ], 350 | "metadata": { 351 | "kernelspec": { 352 | "display_name": "Python 3", 353 | "language": "python", 354 | "name": "python3" 355 | }, 356 | "language_info": { 357 | "codemirror_mode": { 358 | "name": "ipython", 359 | "version": 3 360 | }, 361 | "file_extension": ".py", 362 | "mimetype": "text/x-python", 363 | "name": "python", 364 | "nbconvert_exporter": "python", 365 | "pygments_lexer": "ipython3", 366 | "version": "3.6.9" 367 | } 368 | }, 369 | "nbformat": 4, 370 | "nbformat_minor": 2 371 | } 372 | -------------------------------------------------------------------------------- /flash-no-flash.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for **\"Flash/No Flash\"** figure. 
" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"\n", 17 | "*Uncomment if running on colab* \n", 18 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 19 | "\"\"\"\n", 20 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 21 | "# !mv deep-image-prior/* ./" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Import libs" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from __future__ import print_function\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "import os\n", 42 | "#os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n", 43 | "\n", 44 | "import numpy as np\n", 45 | "from models import *\n", 46 | "\n", 47 | "import torch\n", 48 | "import torch.optim\n", 49 | "\n", 50 | "from utils.denoising_utils import *\n", 51 | "from utils.sr_utils import load_LR_HR_imgs_sr\n", 52 | "torch.backends.cudnn.enabled = True\n", 53 | "torch.backends.cudnn.benchmark =True\n", 54 | "dtype = torch.cuda.FloatTensor\n", 55 | "\n", 56 | "imsize =-1\n", 57 | "PLOT = True" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "# Load image" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "imgs = load_LR_HR_imgs_sr('data/flash_no_flash/cave01_00_flash.jpg', -1, 1, enforse_div32='CROP')\n", 74 | "img_flash = load_LR_HR_imgs_sr('data/flash_no_flash/cave01_00_flash.jpg', -1, 1, enforse_div32='CROP')['HR_pil']\n", 75 | "img_flash_np = pil_to_np(img_flash)\n", 76 | "\n", 77 | "img_noflash = load_LR_HR_imgs_sr('data/flash_no_flash/cave01_01_noflash.jpg', -1, 1, enforse_div32='CROP')['HR_pil']\n", 78 | "img_noflash_np = 
pil_to_np(img_noflash)\n", 79 | "\n", 80 | "g = plot_image_grid([img_flash_np, img_noflash_np],3,12)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Setup" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "pad = 'reflection'\n", 97 | "OPT_OVER = 'net'\n", 98 | "\n", 99 | "num_iter = 601\n", 100 | "LR = 0.1 \n", 101 | "OPTIMIZER = 'adam'\n", 102 | "reg_noise_std = 0.0\n", 103 | "show_every = 50\n", 104 | "figsize = 6\n", 105 | "\n", 106 | "# We will use flash image as input\n", 107 | "input_depth = 3\n", 108 | "net_input =np_to_torch(img_flash_np).type(dtype)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "net = skip(input_depth, 3, num_channels_down = [128, 128, 128, 128, 128], \n", 118 | " num_channels_up = [128, 128, 128, 128, 128],\n", 119 | " num_channels_skip = [4, 4, 4, 4, 4], \n", 120 | " upsample_mode=['nearest', 'nearest', 'bilinear', 'bilinear', 'bilinear'], \n", 121 | " need_sigmoid=True, need_bias=True, pad=pad).type(dtype)\n", 122 | "\n", 123 | "mse = torch.nn.MSELoss().type(dtype)\n", 124 | "\n", 125 | "img_flash_var = np_to_torch(img_flash_np).type(dtype)\n", 126 | "img_noflash_var = np_to_torch(img_noflash_np).type(dtype)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "# Optimize" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "scrolled": false 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "net_input_saved = net_input.detach().clone()\n", 145 | "noise = net_input.detach().clone()\n", 146 | "\n", 147 | "\n", 148 | "i = 0\n", 149 | "def closure():\n", 150 | " \n", 151 | " global i, net_input\n", 152 | " \n", 153 | " if reg_noise_std > 0:\n", 154 | " net_input = net_input_saved + 
(noise.normal_() * reg_noise_std)\n", 155 | " \n", 156 | " out = net(net_input)\n", 157 | " \n", 158 | " total_loss = mse(out, img_noflash_var)\n", 159 | " total_loss.backward()\n", 160 | " \n", 161 | " print ('Iteration %05d Loss %f' % (i, total_loss.item()), '\\r', end='')\n", 162 | " if PLOT and i % show_every == 0:\n", 163 | " out_np = torch_to_np(out)\n", 164 | " plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)\n", 165 | " \n", 166 | " i += 1\n", 167 | "\n", 168 | " return total_loss\n", 169 | "\n", 170 | "p = get_params(OPT_OVER, net, net_input)\n", 171 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "Sometimes the process gets stuck on a reddish image; if that happens, just run the code from the top again. " 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "out_np = torch_to_np(net(net_input))\n", 188 | "q = plot_image_grid([np.clip(out_np, 0, 1), img_noflash_np], factor=13);" 189 | ] 190 | } 191 | ], 192 | "metadata": { 193 | "kernelspec": { 194 | "display_name": "Python 3", 195 | "language": "python", 196 | "name": "python3" 197 | }, 198 | "language_info": { 199 | "codemirror_mode": { 200 | "name": "ipython", 201 | "version": 3 202 | }, 203 | "file_extension": ".py", 204 | "mimetype": "text/x-python", 205 | "name": "python", 206 | "nbconvert_exporter": "python", 207 | "pygments_lexer": "ipython3", 208 | "version": "3.6.9" 209 | } 210 | }, 211 | "nbformat": 4, 212 | "nbformat_minor": 2 213 | } 214 | -------------------------------------------------------------------------------- /inpainting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for **\"Inpainting\"** figures $6$, $8$ and 7 (top) from the main paper.
" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"\n", 17 | "*Uncomment if running on colab* \n", 18 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 19 | "\"\"\"\n", 20 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 21 | "# !mv deep-image-prior/* ./" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Import libs" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from __future__ import print_function\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "import os\n", 42 | "# os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", 43 | "\n", 44 | "import numpy as np\n", 45 | "from models.resnet import ResNet\n", 46 | "from models.unet import UNet\n", 47 | "from models.skip import skip\n", 48 | "import torch\n", 49 | "import torch.optim\n", 50 | "\n", 51 | "from utils.inpainting_utils import *\n", 52 | "\n", 53 | "torch.backends.cudnn.enabled = True\n", 54 | "torch.backends.cudnn.benchmark =True\n", 55 | "dtype = torch.cuda.FloatTensor\n", 56 | "\n", 57 | "PLOT = True\n", 58 | "imsize = -1\n", 59 | "dim_div_by = 64" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Choose figure" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "## Fig 6\n", 76 | "# img_path = 'data/inpainting/vase.png'\n", 77 | "# mask_path = 'data/inpainting/vase_mask.png'\n", 78 | "\n", 79 | "## Fig 8\n", 80 | "# img_path = 'data/inpainting/library.png'\n", 81 | "# mask_path = 'data/inpainting/library_mask.png'\n", 82 | "\n", 83 | "## Fig 7 (top)\n", 84 | "img_path = 'data/inpainting/kate.png'\n", 85 | "mask_path = 
'data/inpainting/kate_mask.png'\n", 86 | "\n", 87 | "# Another text inpainting example\n", 88 | "# img_path = 'data/inpainting/peppers.png'\n", 89 | "# mask_path = 'data/inpainting/peppers_mask.png'\n", 90 | "\n", 91 | "NET_TYPE = 'skip_depth6' # one of skip_depth6|skip_depth4|skip_depth2|UNET|ResNet" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "# Load mask" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "img_pil, img_np = get_image(img_path, imsize)\n", 108 | "img_mask_pil, img_mask_np = get_image(mask_path, imsize)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "### Center crop" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "img_mask_pil = crop_image(img_mask_pil, dim_div_by)\n", 125 | "img_pil = crop_image(img_pil, dim_div_by)\n", 126 | "\n", 127 | "img_np = pil_to_np(img_pil)\n", 128 | "img_mask_np = pil_to_np(img_mask_pil)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### Visualize" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "scrolled": true 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "img_mask_var = np_to_torch(img_mask_np).type(dtype)\n", 147 | "\n", 148 | "plot_image_grid([img_np, img_mask_np, img_mask_np*img_np], 3,11);" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "# Setup" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "pad = 'reflection' # 'zero'\n", 165 | "OPT_OVER = 'net'\n", 166 | "OPTIMIZER = 'adam'" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 |
"execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "if 'vase.png' in img_path:\n", 176 | " INPUT = 'meshgrid'\n", 177 | " input_depth = 2\n", 178 | " LR = 0.01 \n", 179 | " num_iter = 5001\n", 180 | " param_noise = False\n", 181 | " show_every = 50\n", 182 | " figsize = 5\n", 183 | " reg_noise_std = 0.03\n", 184 | " \n", 185 | " net = skip(input_depth, img_np.shape[0], \n", 186 | " num_channels_down = [128] * 5,\n", 187 | " num_channels_up = [128] * 5,\n", 188 | " num_channels_skip = [0] * 5, \n", 189 | " upsample_mode='nearest', filter_skip_size=1, filter_size_up=3, filter_size_down=3,\n", 190 | " need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n", 191 | " \n", 192 | "elif ('kate.png' in img_path) or ('peppers.png' in img_path):\n", 193 | " # Same params and net as in super-resolution and denoising\n", 194 | " INPUT = 'noise'\n", 195 | " input_depth = 32\n", 196 | " LR = 0.01 \n", 197 | " num_iter = 6001\n", 198 | " param_noise = False\n", 199 | " show_every = 50\n", 200 | " figsize = 5\n", 201 | " reg_noise_std = 0.03\n", 202 | " \n", 203 | " net = skip(input_depth, img_np.shape[0], \n", 204 | " num_channels_down = [128] * 5,\n", 205 | " num_channels_up = [128] * 5,\n", 206 | " num_channels_skip = [128] * 5, \n", 207 | " filter_size_up = 3, filter_size_down = 3, \n", 208 | " upsample_mode='nearest', filter_skip_size=1,\n", 209 | " need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n", 210 | " \n", 211 | "elif 'library.png' in img_path:\n", 212 | " \n", 213 | " INPUT = 'noise'\n", 214 | " input_depth = 1\n", 215 | " \n", 216 | " num_iter = 3001\n", 217 | " show_every = 50\n", 218 | " figsize = 8\n", 219 | " reg_noise_std = 0.00\n", 220 | " param_noise = True\n", 221 | " \n", 222 | " if 'skip' in NET_TYPE:\n", 223 | " \n", 224 | " depth = int(NET_TYPE[-1])\n", 225 | " net = skip(input_depth, img_np.shape[0], \n", 226 | " num_channels_down = [16, 32, 64, 128, 128, 
128][:depth],\n", 227 | " num_channels_up = [16, 32, 64, 128, 128, 128][:depth],\n", 228 | " num_channels_skip = [0, 0, 0, 0, 0, 0][:depth], \n", 229 | " filter_size_up = 3,filter_size_down = 5, filter_skip_size=1,\n", 230 | " upsample_mode='nearest', # downsample_mode='avg',\n", 231 | " need1x1_up=False,\n", 232 | " need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n", 233 | " \n", 234 | " LR = 0.01 \n", 235 | " \n", 236 | " elif NET_TYPE == 'UNET':\n", 237 | " \n", 238 | " net = UNet(num_input_channels=input_depth, num_output_channels=3, \n", 239 | " feature_scale=8, more_layers=1, \n", 240 | " concat_x=False, upsample_mode='deconv', \n", 241 | " pad='zero', norm_layer=torch.nn.InstanceNorm2d, need_sigmoid=True, need_bias=True)\n", 242 | " \n", 243 | " LR = 0.001\n", 244 | " param_noise = False\n", 245 | " \n", 246 | " elif NET_TYPE == 'ResNet':\n", 247 | " \n", 248 | " net = ResNet(input_depth, img_np.shape[0], 8, 32, need_sigmoid=True, act_fun='LeakyReLU')\n", 249 | " \n", 250 | " LR = 0.001\n", 251 | " param_noise = False\n", 252 | " \n", 253 | " else:\n", 254 | " assert False\n", 255 | "else:\n", 256 | " assert False\n", 257 | "\n", 258 | "net = net.type(dtype)\n", 259 | "net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# Compute number of parameters\n", 269 | "s = sum(np.prod(list(p.size())) for p in net.parameters())\n", 270 | "print ('Number of params: %d' % s)\n", 271 | "\n", 272 | "# Loss\n", 273 | "mse = torch.nn.MSELoss().type(dtype)\n", 274 | "\n", 275 | "img_var = np_to_torch(img_np).type(dtype)\n", 276 | "mask_var = np_to_torch(img_mask_np).type(dtype)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "# Main loop" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 
289 | "metadata": { 290 | "scrolled": true 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "i = 0\n", 295 | "def closure():\n", 296 | " \n", 297 | " global i\n", 298 | " \n", 299 | " if param_noise:\n", 300 | " for n in [x for x in net.parameters() if len(x.size()) == 4]:\n", 301 | " n = n + n.detach().clone().normal_() * n.std() / 50\n", 302 | " \n", 303 | " net_input = net_input_saved\n", 304 | " if reg_noise_std > 0:\n", 305 | " net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n", 306 | " \n", 307 | " \n", 308 | " out = net(net_input)\n", 309 | " \n", 310 | " total_loss = mse(out * mask_var, img_var * mask_var)\n", 311 | " total_loss.backward()\n", 312 | " \n", 313 | " print ('Iteration %05d Loss %f' % (i, total_loss.item()), '\\r', end='')\n", 314 | " if PLOT and i % show_every == 0:\n", 315 | " out_np = torch_to_np(out)\n", 316 | " plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)\n", 317 | " \n", 318 | " i += 1\n", 319 | "\n", 320 | " return total_loss\n", 321 | "\n", 322 | "net_input_saved = net_input.detach().clone()\n", 323 | "noise = net_input.detach().clone()\n", 324 | "\n", 325 | "p = get_params(OPT_OVER, net, net_input)\n", 326 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "out_np = torch_to_np(net(net_input))\n", 336 | "plot_image_grid([out_np], factor=5);" 337 | ] 338 | } 339 | ], 340 | "metadata": { 341 | "kernelspec": { 342 | "display_name": "Python 3", 343 | "language": "python", 344 | "name": "python3" 345 | }, 346 | "language_info": { 347 | "codemirror_mode": { 348 | "name": "ipython", 349 | "version": 3 350 | }, 351 | "file_extension": ".py", 352 | "mimetype": "text/x-python", 353 | "name": "python", 354 | "nbconvert_exporter": "python", 355 | "pygments_lexer": "ipython3", 356 | "version": "3.6.9" 357 | } 358 | }, 359 | "nbformat": 4, 360 | "nbformat_minor": 
2 361 | } 362 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .skip import skip 2 | from .texture_nets import get_texture_nets 3 | from .resnet import ResNet 4 | from .unet import UNet 5 | 6 | import torch.nn as nn 7 | 8 | def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, downsample_mode='stride'): 9 | if NET_TYPE == 'ResNet': 10 | # TODO 11 | net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False) 12 | elif NET_TYPE == 'skip': 13 | net = skip(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d, 14 | num_channels_up = [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u, 15 | num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 16 | upsample_mode=upsample_mode, downsample_mode=downsample_mode, 17 | need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun) 18 | 19 | elif NET_TYPE == 'texture_nets': 20 | net = get_texture_nets(inp=input_depth, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False,pad=pad) 21 | 22 | elif NET_TYPE =='UNet': 23 | net = UNet(num_input_channels=input_depth, num_output_channels=3, 24 | feature_scale=4, more_layers=0, concat_x=False, 25 | upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True) 26 | elif NET_TYPE == 'identity': 27 | assert input_depth == 3 28 | net = nn.Sequential() 29 | else: 30 | assert False 31 | 32 | return net -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from .downsampler import Downsampler 5 | 6 | def add_module(self, 
module): 7 | self.add_module(str(len(self) + 1), module) 8 | 9 | torch.nn.Module.add = add_module 10 | 11 | class Concat(nn.Module): 12 | def __init__(self, dim, *args): 13 | super(Concat, self).__init__() 14 | self.dim = dim 15 | 16 | for idx, module in enumerate(args): 17 | self.add_module(str(idx), module) 18 | 19 | def forward(self, input): 20 | inputs = [] 21 | for module in self._modules.values(): 22 | inputs.append(module(input)) 23 | 24 | inputs_shapes2 = [x.shape[2] for x in inputs] 25 | inputs_shapes3 = [x.shape[3] for x in inputs] 26 | 27 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)): 28 | inputs_ = inputs 29 | else: 30 | target_shape2 = min(inputs_shapes2) 31 | target_shape3 = min(inputs_shapes3) 32 | 33 | inputs_ = [] 34 | for inp in inputs: 35 | diff2 = (inp.size(2) - target_shape2) // 2 36 | diff3 = (inp.size(3) - target_shape3) // 2 37 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 38 | 39 | return torch.cat(inputs_, dim=self.dim) 40 | 41 | def __len__(self): 42 | return len(self._modules) 43 | 44 | 45 | class GenNoise(nn.Module): 46 | def __init__(self, dim2): 47 | super(GenNoise, self).__init__() 48 | self.dim2 = dim2 49 | 50 | def forward(self, input): 51 | a = list(input.size()) 52 | a[1] = self.dim2 53 | # print (input.data.type()) 54 | 55 | b = torch.zeros(a).type_as(input.data) 56 | b.normal_() 57 | 58 | x = torch.autograd.Variable(b) 59 | 60 | return x 61 | 62 | 63 | class Swish(nn.Module): 64 | """ 65 | https://arxiv.org/abs/1710.05941 66 | The hype was so huge that I could not help but try it 67 | """ 68 | def __init__(self): 69 | super(Swish, self).__init__() 70 | self.s = nn.Sigmoid() 71 | 72 | def forward(self, x): 73 | return x * self.s(x) 74 | 75 | 76 | def act(act_fun = 'LeakyReLU'): 77 | ''' 78 | Either string defining an activation function or module (e.g. 
nn.ReLU) 79 | ''' 80 | if isinstance(act_fun, str): 81 | if act_fun == 'LeakyReLU': 82 | return nn.LeakyReLU(0.2, inplace=True) 83 | elif act_fun == 'Swish': 84 | return Swish() 85 | elif act_fun == 'ELU': 86 | return nn.ELU() 87 | elif act_fun == 'none': 88 | return nn.Sequential() 89 | else: 90 | assert False 91 | else: 92 | return act_fun() 93 | 94 | 95 | def bn(num_features): 96 | return nn.BatchNorm2d(num_features) 97 | 98 | 99 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 100 | downsampler = None 101 | if stride != 1 and downsample_mode != 'stride': 102 | 103 | if downsample_mode == 'avg': 104 | downsampler = nn.AvgPool2d(stride, stride) 105 | elif downsample_mode == 'max': 106 | downsampler = nn.MaxPool2d(stride, stride) 107 | elif downsample_mode in ['lanczos2', 'lanczos3']: 108 | downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True) 109 | else: 110 | assert False 111 | 112 | stride = 1 113 | 114 | padder = None 115 | to_pad = int((kernel_size - 1) / 2) 116 | if pad == 'reflection': 117 | padder = nn.ReflectionPad2d(to_pad) 118 | to_pad = 0 119 | 120 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 121 | 122 | 123 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 124 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /models/dcgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def dcgan(inp=2, 5 | ndf=32, 6 | num_ups=4, need_sigmoid=True, need_bias=True, pad='zero', upsample_mode='nearest', need_convT = True): 7 | 8 | layers= [nn.ConvTranspose2d(inp, ndf, kernel_size=3, stride=1, padding=0, bias=False), 9 | nn.BatchNorm2d(ndf), 10 | nn.LeakyReLU(True)] 11 | 12 | for i in range(num_ups-3): 13 | if need_convT: 14 | layers += [ 
nn.ConvTranspose2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=False), 15 | nn.BatchNorm2d(ndf), 16 | nn.LeakyReLU(True)] 17 | else: 18 | layers += [ nn.Upsample(scale_factor=2, mode=upsample_mode), 19 | nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=False), 20 | nn.BatchNorm2d(ndf), 21 | nn.LeakyReLU(True)] 22 | 23 | if need_convT: 24 | layers += [nn.ConvTranspose2d(ndf, 3, 4, 2, 1, bias=False),] 25 | else: 26 | layers += [nn.Upsample(scale_factor=2, mode='bilinear'), 27 | nn.Conv2d(ndf, 3, kernel_size=3, stride=1, padding=1, bias=False)] 28 | 29 | 30 | if need_sigmoid: 31 | layers += [nn.Sigmoid()] 32 | 33 | model =nn.Sequential(*layers) 34 | return model -------------------------------------------------------------------------------- /models/downsampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Downsampler(nn.Module): 6 | ''' 7 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 8 | ''' 9 | def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False): 10 | super(Downsampler, self).__init__() 11 | 12 | assert phase in [0, 0.5], 'phase should be 0 or 0.5' 13 | 14 | if kernel_type == 'lanczos2': 15 | support = 2 16 | kernel_width = 4 * factor + 1 17 | kernel_type_ = 'lanczos' 18 | 19 | elif kernel_type == 'lanczos3': 20 | support = 3 21 | kernel_width = 6 * factor + 1 22 | kernel_type_ = 'lanczos' 23 | 24 | elif kernel_type == 'gauss12': 25 | kernel_width = 7 26 | sigma = 1/2 27 | kernel_type_ = 'gauss' 28 | 29 | elif kernel_type == 'gauss1sq2': 30 | kernel_width = 9 31 | sigma = 1./np.sqrt(2) 32 | kernel_type_ = 'gauss' 33 | 34 | elif kernel_type in ['lanczos', 'gauss', 'box']: 35 | kernel_type_ = kernel_type 36 | 37 | else: 38 | assert False, 'wrong name kernel' 39 | 40 | 41 | # note that `kernel width` will be different to actual size 
for phase = 1/2 42 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) 43 | 44 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) 45 | downsampler.weight.data[:] = 0 46 | downsampler.bias.data[:] = 0 47 | 48 | kernel_torch = torch.from_numpy(self.kernel) 49 | for i in range(n_planes): 50 | downsampler.weight.data[i, i] = kernel_torch 51 | 52 | self.downsampler_ = downsampler 53 | 54 | if preserve_size: 55 | 56 | if self.kernel.shape[0] % 2 == 1: 57 | pad = int((self.kernel.shape[0] - 1) / 2.) 58 | else: 59 | pad = int((self.kernel.shape[0] - factor) / 2.) 60 | 61 | self.padding = nn.ReplicationPad2d(pad) 62 | 63 | self.preserve_size = preserve_size 64 | 65 | def forward(self, input): 66 | if self.preserve_size: 67 | x = self.padding(input) 68 | else: 69 | x= input 70 | self.x = x 71 | return self.downsampler_(x) 72 | 73 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): 74 | assert kernel_type in ['lanczos', 'gauss', 'box'] 75 | 76 | # factor = float(factor) 77 | if phase == 0.5 and kernel_type != 'box': 78 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 79 | else: 80 | kernel = np.zeros([kernel_width, kernel_width]) 81 | 82 | 83 | if kernel_type == 'box': 84 | assert phase == 0.5, 'Box filter is always half-phased' 85 | kernel[:] = 1./(kernel_width * kernel_width) 86 | 87 | elif kernel_type == 'gauss': 88 | assert sigma, 'sigma is not specified' 89 | assert phase != 0.5, 'phase 1/2 for gauss not implemented' 90 | 91 | center = (kernel_width + 1.)/2. 92 | print(center, kernel_width) 93 | sigma_sq = sigma * sigma 94 | 95 | for i in range(1, kernel.shape[0] + 1): 96 | for j in range(1, kernel.shape[1] + 1): 97 | di = (i - center)/2. 98 | dj = (j - center)/2. 99 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq)) 100 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. 
* np.pi * sigma_sq) 101 | elif kernel_type == 'lanczos': 102 | assert support, 'support is not specified' 103 | center = (kernel_width + 1) / 2. 104 | 105 | for i in range(1, kernel.shape[0] + 1): 106 | for j in range(1, kernel.shape[1] + 1): 107 | 108 | if phase == 0.5: 109 | di = abs(i + 0.5 - center) / factor 110 | dj = abs(j + 0.5 - center) / factor 111 | else: 112 | di = abs(i - center) / factor 113 | dj = abs(j - center) / factor 114 | 115 | 116 | pi_sq = np.pi * np.pi 117 | 118 | val = 1 119 | if di != 0: 120 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 121 | val = val / (np.pi * np.pi * di * di) 122 | 123 | if dj != 0: 124 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 125 | val = val / (np.pi * np.pi * dj * dj) 126 | 127 | kernel[i - 1][j - 1] = val 128 | 129 | 130 | else: 131 | assert False, 'wrong method name' 132 | 133 | kernel /= kernel.sum() 134 | 135 | return kernel 136 | 137 | #a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True) 138 | 139 | 140 | 141 | 142 | 143 | 144 | ################# 145 | # Learnable downsampler 146 | 147 | # KS = 32 148 | # dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor)) 149 | 150 | # class Apply(nn.Module): 151 | # def __init__(self, what, dim, *args): 152 | # super(Apply, self).__init__() 153 | # self.dim = dim 154 | 155 | # self.what = what 156 | 157 | # def forward(self, input): 158 | # inputs = [] 159 | # for i in range(input.size(self.dim)): 160 | # inputs.append(self.what(input.narrow(self.dim, i, 1))) 161 | 162 | # return torch.cat(inputs, dim=self.dim) 163 | 164 | # def __len__(self): 165 | # return len(self._modules) 166 | 167 | # downs = Apply(dow, 1) 168 | # downs.type(dtype)(net_input.type(dtype)).size() 169 | -------------------------------------------------------------------------------- /models/resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy.random import normal 4 | from numpy.linalg import svd 5 | from math import sqrt 6 | import torch.nn.init 7 | from .common import * 8 | 9 | class ResidualSequential(nn.Sequential): 10 | def __init__(self, *args): 11 | super(ResidualSequential, self).__init__(*args) 12 | 13 | def forward(self, x): 14 | out = super(ResidualSequential, self).forward(x) 15 | # print(x.size(), out.size()) 16 | x_ = None 17 | if out.size(2) != x.size(2) or out.size(3) != x.size(3): 18 | diff2 = x.size(2) - out.size(2) 19 | diff3 = x.size(3) - out.size(3) 20 | # print(1) 21 | x_ = x[:, :, diff2 /2:out.size(2) + diff2 / 2, diff3 / 2:out.size(3) + diff3 / 2] 22 | else: 23 | x_ = x 24 | return out + x_ 25 | 26 | def eval(self): 27 | print(2) 28 | for m in self.modules(): 29 | m.eval() 30 | exit() 31 | 32 | 33 | def get_block(num_channels, norm_layer, act_fun): 34 | layers = [ 35 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False), 36 | norm_layer(num_channels, affine=True), 37 | act(act_fun), 38 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False), 39 | norm_layer(num_channels, affine=True), 40 | ] 41 | return layers 42 | 43 | 44 | class ResNet(nn.Module): 45 | def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm2d, pad='reflection'): 46 | ''' 47 | pad = 'start|zero|replication' 48 | ''' 49 | super(ResNet, self).__init__() 50 | 51 | if need_residual: 52 | s = ResidualSequential 53 | else: 54 | s = nn.Sequential 55 | 56 | stride = 1 57 | # First layers 58 | layers = [ 59 | # nn.ReplicationPad2d(num_blocks * 2 * stride + 3), 60 | conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad), 61 | act(act_fun) 62 | ] 63 | # Residual blocks 64 | # layers_residual = [] 65 | for i in range(num_blocks): 66 | layers += 
[s(*get_block(num_channels, norm_layer, act_fun))] 67 | 68 | layers += [ 69 | nn.Conv2d(num_channels, num_channels, 3, 1, 1), 70 | norm_layer(num_channels, affine=True) 71 | ] 72 | 73 | # if need_residual: 74 | # layers += [ResidualSequential(*layers_residual)] 75 | # else: 76 | # layers += [Sequential(*layers_residual)] 77 | 78 | # if factor >= 2: 79 | # # Do upsampling if needed 80 | # layers += [ 81 | # nn.Conv2d(num_channels, num_channels * 82 | # factor ** 2, 3, 1), 83 | # nn.PixelShuffle(factor), 84 | # act(act_fun) 85 | # ] 86 | layers += [ 87 | conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad), 88 | nn.Sigmoid() 89 | ] 90 | self.model = nn.Sequential(*layers) 91 | 92 | def forward(self, input): 93 | return self.model(input) 94 | 95 | def eval(self): 96 | self.model.eval() 97 | -------------------------------------------------------------------------------- /models/skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | def skip( 6 | num_input_channels=2, num_output_channels=3, 7 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 8 | filter_size_down=3, filter_size_up=3, filter_skip_size=1, 9 | need_sigmoid=True, need_bias=True, 10 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 11 | need1x1_up=True): 12 | """Assembles encoder-decoder with skip connections. 13 | 14 | Arguments: 15 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. 
nn.ReLU) 16 | pad (string): zero|reflection (default: 'zero') 17 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 18 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 19 | 20 | """ 21 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 22 | 23 | n_scales = len(num_channels_down) 24 | 25 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 26 | upsample_mode = [upsample_mode]*n_scales 27 | 28 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 29 | downsample_mode = [downsample_mode]*n_scales 30 | 31 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 32 | filter_size_down = [filter_size_down]*n_scales 33 | 34 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 35 | filter_size_up = [filter_size_up]*n_scales 36 | 37 | last_scale = n_scales - 1 38 | 39 | cur_depth = None 40 | 41 | model = nn.Sequential() 42 | model_tmp = model 43 | 44 | input_depth = num_input_channels 45 | for i in range(len(num_channels_down)): 46 | 47 | deeper = nn.Sequential() 48 | skip = nn.Sequential() 49 | 50 | if num_channels_skip[i] != 0: 51 | model_tmp.add(Concat(1, skip, deeper)) 52 | else: 53 | model_tmp.add(deeper) 54 | 55 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 56 | 57 | if num_channels_skip[i] != 0: 58 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 59 | skip.add(bn(num_channels_skip[i])) 60 | skip.add(act(act_fun)) 61 | 62 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 63 | 64 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 65 | deeper.add(bn(num_channels_down[i])) 66 | deeper.add(act(act_fun)) 67 | 68 | deeper.add(conv(num_channels_down[i], num_channels_down[i], 
filter_size_down[i], bias=need_bias, pad=pad)) 69 | deeper.add(bn(num_channels_down[i])) 70 | deeper.add(act(act_fun)) 71 | 72 | deeper_main = nn.Sequential() 73 | 74 | if i == len(num_channels_down) - 1: 75 | # The deepest 76 | k = num_channels_down[i] 77 | else: 78 | deeper.add(deeper_main) 79 | k = num_channels_up[i + 1] 80 | 81 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 82 | 83 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 84 | model_tmp.add(bn(num_channels_up[i])) 85 | model_tmp.add(act(act_fun)) 86 | 87 | 88 | if need1x1_up: 89 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 90 | model_tmp.add(bn(num_channels_up[i])) 91 | model_tmp.add(act(act_fun)) 92 | 93 | input_depth = num_channels_down[i] 94 | model_tmp = deeper_main 95 | 96 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 97 | if need_sigmoid: 98 | model.add(nn.Sigmoid()) 99 | 100 | return model 101 | -------------------------------------------------------------------------------- /models/texture_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | 6 | normalization = nn.BatchNorm2d 7 | 8 | 9 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero'): 10 | if pad == 'zero': 11 | return nn.Conv2d(in_f, out_f, kernel_size, stride, padding=(kernel_size - 1) / 2, bias=bias) 12 | elif pad == 'reflection': 13 | layers = [nn.ReflectionPad2d((kernel_size - 1) / 2), 14 | nn.Conv2d(in_f, out_f, kernel_size, stride, padding=0, bias=bias)] 15 | return nn.Sequential(*layers) 16 | 17 | def get_texture_nets(inp=3, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False, pad='zero', need_sigmoid=False, conv_num=8, upsample_mode='nearest'): 18 | 19 | 20 | for i in range(len(ratios)): 21 | j = i + 1 22 | 23 | seq = nn.Sequential() 24 | 25 
| tmp = nn.AvgPool2d(ratios[i], ratios[i]) 26 | 27 | seq.add(tmp) 28 | if fill_noise: 29 | seq.add(GenNoise(inp)) 30 | 31 | seq.add(conv(inp, conv_num, 3, pad=pad)) 32 | seq.add(normalization(conv_num)) 33 | seq.add(act()) 34 | 35 | seq.add(conv(conv_num, conv_num, 3, pad=pad)) 36 | seq.add(normalization(conv_num)) 37 | seq.add(act()) 38 | 39 | seq.add(conv(conv_num, conv_num, 1, pad=pad)) 40 | seq.add(normalization(conv_num)) 41 | seq.add(act()) 42 | 43 | if i == 0: 44 | seq.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 45 | cur = seq 46 | else: 47 | 48 | cur_temp = cur 49 | 50 | cur = nn.Sequential() 51 | 52 | # Batch norm before merging 53 | seq.add(normalization(conv_num)) 54 | cur_temp.add(normalization(conv_num * (j - 1))) 55 | 56 | cur.add(Concat(1, cur_temp, seq)) 57 | 58 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad)) 59 | cur.add(normalization(conv_num * j)) 60 | cur.add(act()) 61 | 62 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad)) 63 | cur.add(normalization(conv_num * j)) 64 | cur.add(act()) 65 | 66 | cur.add(conv(conv_num * j, conv_num * j, 1, pad=pad)) 67 | cur.add(normalization(conv_num * j)) 68 | cur.add(act()) 69 | 70 | if i == len(ratios) - 1: 71 | cur.add(conv(conv_num * j, 3, 1, pad=pad)) 72 | else: 73 | cur.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 74 | 75 | model = cur 76 | if need_sigmoid: 77 | model.add(nn.Sigmoid()) 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .common import * 6 | 7 | class ListModule(nn.Module): 8 | def __init__(self, *args): 9 | super(ListModule, self).__init__() 10 | idx = 0 11 | for module in args: 12 | self.add_module(str(idx), module) 13 | idx += 1 14 | 15 | def __getitem__(self, idx): 16 | if idx >= 
len(self._modules): 17 | raise IndexError('index {} is out of range'.format(idx)) 18 | if idx < 0: 19 | idx = len(self) + idx 20 | 21 | it = iter(self._modules.values()) 22 | for i in range(idx): 23 | next(it) 24 | return next(it) 25 | 26 | def __iter__(self): 27 | return iter(self._modules.values()) 28 | 29 | def __len__(self): 30 | return len(self._modules) 31 | 32 | class UNet(nn.Module): 33 | ''' 34 | upsample_mode in ['deconv', 'nearest', 'bilinear'] 35 | pad in ['zero', 'replication', 'none'] 36 | ''' 37 | def __init__(self, num_input_channels=3, num_output_channels=3, 38 | feature_scale=4, more_layers=0, concat_x=False, 39 | upsample_mode='deconv', pad='zero', norm_layer=nn.InstanceNorm2d, need_sigmoid=True, need_bias=True): 40 | super(UNet, self).__init__() 41 | 42 | self.feature_scale = feature_scale 43 | self.more_layers = more_layers 44 | self.concat_x = concat_x 45 | 46 | 47 | filters = [64, 128, 256, 512, 1024] 48 | filters = [x // self.feature_scale for x in filters] 49 | 50 | self.start = unetConv2(num_input_channels, filters[0] if not concat_x else filters[0] - num_input_channels, norm_layer, need_bias, pad) 51 | 52 | self.down1 = unetDown(filters[0], filters[1] if not concat_x else filters[1] - num_input_channels, norm_layer, need_bias, pad) 53 | self.down2 = unetDown(filters[1], filters[2] if not concat_x else filters[2] - num_input_channels, norm_layer, need_bias, pad) 54 | self.down3 = unetDown(filters[2], filters[3] if not concat_x else filters[3] - num_input_channels, norm_layer, need_bias, pad) 55 | self.down4 = unetDown(filters[3], filters[4] if not concat_x else filters[4] - num_input_channels, norm_layer, need_bias, pad) 56 | 57 | # more downsampling layers 58 | if self.more_layers > 0: 59 | self.more_downs = [ 60 | unetDown(filters[4], filters[4] if not concat_x else filters[4] - num_input_channels , norm_layer, need_bias, pad) for i in range(self.more_layers)] 61 | self.more_ups = [unetUp(filters[4], upsample_mode, need_bias, pad, 
same_num_filt =True) for i in range(self.more_layers)] 62 | 63 | self.more_downs = ListModule(*self.more_downs) 64 | self.more_ups = ListModule(*self.more_ups) 65 | 66 | self.up4 = unetUp(filters[3], upsample_mode, need_bias, pad) 67 | self.up3 = unetUp(filters[2], upsample_mode, need_bias, pad) 68 | self.up2 = unetUp(filters[1], upsample_mode, need_bias, pad) 69 | self.up1 = unetUp(filters[0], upsample_mode, need_bias, pad) 70 | 71 | self.final = conv(filters[0], num_output_channels, 1, bias=need_bias, pad=pad) 72 | 73 | if need_sigmoid: 74 | self.final = nn.Sequential(self.final, nn.Sigmoid()) 75 | 76 | def forward(self, inputs): 77 | 78 | # Downsample 79 | downs = [inputs] 80 | down = nn.AvgPool2d(2, 2) 81 | for i in range(4 + self.more_layers): 82 | downs.append(down(downs[-1])) 83 | 84 | in64 = self.start(inputs) 85 | if self.concat_x: 86 | in64 = torch.cat([in64, downs[0]], 1) 87 | 88 | down1 = self.down1(in64) 89 | if self.concat_x: 90 | down1 = torch.cat([down1, downs[1]], 1) 91 | 92 | down2 = self.down2(down1) 93 | if self.concat_x: 94 | down2 = torch.cat([down2, downs[2]], 1) 95 | 96 | down3 = self.down3(down2) 97 | if self.concat_x: 98 | down3 = torch.cat([down3, downs[3]], 1) 99 | 100 | down4 = self.down4(down3) 101 | if self.concat_x: 102 | down4 = torch.cat([down4, downs[4]], 1) 103 | 104 | if self.more_layers > 0: 105 | prevs = [down4] 106 | for kk, d in enumerate(self.more_downs): 107 | # print(prevs[-1].size()) 108 | out = d(prevs[-1]) 109 | if self.concat_x: 110 | out = torch.cat([out, downs[kk + 5]], 1) 111 | 112 | prevs.append(out) 113 | 114 | up_ = self.more_ups[-1](prevs[-1], prevs[-2]) 115 | for idx in range(self.more_layers - 1): 116 | l = self.more_ups[self.more - idx - 2] 117 | up_= l(up_, prevs[self.more - idx - 2]) 118 | else: 119 | up_= down4 120 | 121 | up4= self.up4(up_, down3) 122 | up3= self.up3(up4, down2) 123 | up2= self.up2(up3, down1) 124 | up1= self.up1(up2, in64) 125 | 126 | return self.final(up1) 127 | 128 | 129 | 130 | class 
unetConv2(nn.Module): 131 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 132 | super(unetConv2, self).__init__() 133 | 134 | print(pad) 135 | if norm_layer is not None: 136 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 137 | norm_layer(out_size), 138 | nn.ReLU(),) 139 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 140 | norm_layer(out_size), 141 | nn.ReLU(),) 142 | else: 143 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 144 | nn.ReLU(),) 145 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 146 | nn.ReLU(),) 147 | def forward(self, inputs): 148 | outputs= self.conv1(inputs) 149 | outputs= self.conv2(outputs) 150 | return outputs 151 | 152 | 153 | class unetDown(nn.Module): 154 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 155 | super(unetDown, self).__init__() 156 | self.conv= unetConv2(in_size, out_size, norm_layer, need_bias, pad) 157 | self.down= nn.MaxPool2d(2, 2) 158 | 159 | def forward(self, inputs): 160 | outputs= self.down(inputs) 161 | outputs= self.conv(outputs) 162 | return outputs 163 | 164 | 165 | class unetUp(nn.Module): 166 | def __init__(self, out_size, upsample_mode, need_bias, pad, same_num_filt=False): 167 | super(unetUp, self).__init__() 168 | 169 | num_filt = out_size if same_num_filt else out_size * 2 170 | if upsample_mode == 'deconv': 171 | self.up= nn.ConvTranspose2d(num_filt, out_size, 4, stride=2, padding=1) 172 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 173 | elif upsample_mode=='bilinear' or upsample_mode=='nearest': 174 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode=upsample_mode), 175 | conv(num_filt, out_size, 3, bias=need_bias, pad=pad)) 176 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 177 | else: 178 | assert False 179 | 180 | def forward(self, inputs1, inputs2): 181 | in1_up= 
self.up(inputs1) 182 | 183 | if (inputs2.size(2) != in1_up.size(2)) or (inputs2.size(3) != in1_up.size(3)): 184 | diff2 = (inputs2.size(2) - in1_up.size(2)) // 2 185 | diff3 = (inputs2.size(3) - in1_up.size(3)) // 2 186 | inputs2_ = inputs2[:, :, diff2 : diff2 + in1_up.size(2), diff3 : diff3 + in1_up.size(3)] 187 | else: 188 | inputs2_ = inputs2 189 | 190 | output= self.conv(torch.cat([in1_up, inputs2_], 1)) 191 | 192 | return output 193 | -------------------------------------------------------------------------------- /restoration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for the figures, where an image is restored from a fraction of pixels (fig. 7 bottom, fig. 14 of supmat)." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"\n", 17 | "*Uncomment if running on colab* \n", 18 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 19 | "\"\"\"\n", 20 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 21 | "# !mv deep-image-prior/* ./" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Import libs" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from __future__ import print_function\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "import os\n", 42 | "#os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", 43 | "\n", 44 | "import numpy as np\n", 45 | "from models.resnet import ResNet\n", 46 | "from models.unet import UNet\n", 47 | "from models.skip import skip\n", 48 | "from models import get_net\n", 49 | "import torch\n", 50 | "import torch.optim\n", 51 | "from skimage.measure import 
compare_psnr\n", 52 | "\n", 53 | "from utils.inpainting_utils import *\n", 54 | "\n", 55 | "torch.backends.cudnn.enabled = True\n", 56 | "torch.backends.cudnn.benchmark =True\n", 57 | "dtype = torch.cuda.FloatTensor\n", 58 | "\n", 59 | "PLOT = True\n", 60 | "imsize=-1\n", 61 | "dim_div_by = 64\n", 62 | "dtype = torch.cuda.FloatTensor" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "# Choose figure" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# fig. 7 (bottom)\n", 79 | "f = './data/restoration/barbara.png'\n", 80 | "\n", 81 | "# fig. 14 of supmat\n", 82 | "# f = './data/restoration/kate.png'\n", 83 | "\n", 84 | "\n", 85 | "img_pil, img_np = get_image(f, imsize)\n", 86 | "\n", 87 | "if 'barbara' in f:\n", 88 | " img_np = nn.ReflectionPad2d(1)(np_to_torch(img_np))[0].numpy()\n", 89 | " img_pil = np_to_pil(img_np)\n", 90 | " \n", 91 | " img_mask = get_bernoulli_mask(img_pil, 0.50)\n", 92 | " img_mask_np = pil_to_np(img_mask)\n", 93 | "elif 'kate' in f:\n", 94 | " img_mask = get_bernoulli_mask(img_pil, 0.98)\n", 95 | "\n", 96 | " img_mask_np = pil_to_np(img_mask)\n", 97 | " img_mask_np[1] = img_mask_np[0]\n", 98 | " img_mask_np[2] = img_mask_np[0]\n", 99 | "else:\n", 100 | " assert False\n", 101 | " \n", 102 | "\n", 103 | "img_masked = img_np * img_mask_np\n", 104 | "\n", 105 | "mask_var = np_to_torch(img_mask_np).type(dtype)\n", 106 | "\n", 107 | "plot_image_grid([img_np, img_mask_np, img_mask_np * img_np], 3,11);" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "# Set up everything" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "show_every=50\n", 124 | "figsize=5\n", 125 | "pad = 'reflection' # 'zero'\n", 126 | "INPUT = 'noise'\n", 127 | "input_depth = 32\n", 128 | 
"OPTIMIZER = 'adam'\n", 129 | "OPT_OVER = 'net'\n", 130 | "if 'barbara' in f:\n", 131 | " OPTIMIZER = 'adam'\n", 132 | " \n", 133 | " LR = 0.001\n", 134 | " num_iter = 11000\n", 135 | " reg_noise_std = 0.03\n", 136 | " \n", 137 | " NET_TYPE = 'skip'\n", 138 | " net = get_net(input_depth, 'skip', pad, n_channels=1,\n", 139 | " skip_n33d=128, \n", 140 | " skip_n33u=128, \n", 141 | " skip_n11=4, \n", 142 | " num_scales=5,\n", 143 | " upsample_mode='bilinear').type(dtype)\n", 144 | "elif 'kate' in f:\n", 145 | " OPT_OVER = 'net'\n", 146 | " num_iter = 1000\n", 147 | " LR = 0.01\n", 148 | " reg_noise_std = 0.00\n", 149 | " \n", 150 | " net = skip(input_depth, \n", 151 | " img_np.shape[0], \n", 152 | " num_channels_down = [16, 32, 64, 128, 128],\n", 153 | " num_channels_up = [16, 32, 64, 128, 128],\n", 154 | " num_channels_skip = [0, 0, 0, 0, 0], \n", 155 | " filter_size_down = 3, filter_size_up = 3, filter_skip_size=1,\n", 156 | " upsample_mode='bilinear', \n", 157 | " downsample_mode='avg',\n", 158 | " need_sigmoid=True, need_bias=True, pad=pad).type(dtype)\n", 159 | " \n", 160 | "# Loss\n", 161 | "mse = torch.nn.MSELoss().type(dtype)\n", 162 | "img_var = np_to_torch(img_np).type(dtype)\n", 163 | "\n", 164 | "net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype).detach()" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "# Main loop" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "def closure():\n", 181 | "\n", 182 | " global i, psrn_masked_last, last_net, net_input\n", 183 | " \n", 184 | " if reg_noise_std > 0:\n", 185 | " net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n", 186 | " \n", 187 | " out = net(net_input)\n", 188 | "\n", 189 | " total_loss = mse(out * mask_var, img_var * mask_var)\n", 190 | " total_loss.backward()\n", 191 | " \n", 192 | " psrn_masked = 
compare_psnr(img_masked, out.detach().cpu().numpy()[0] * img_mask_np) \n", 193 | " psrn = compare_psnr(img_np, out.detach().cpu().numpy()[0]) \n", 194 | "\n", 195 | " print ('Iteration %05d Loss %f PSNR_masked %f PSNR %f' % (i, total_loss.item(), psrn_masked, psrn),'\\r', end='')\n", 196 | " \n", 197 | " \n", 198 | " if PLOT and i % show_every == 0:\n", 199 | " out_np = torch_to_np(out)\n", 200 | " \n", 201 | " # Backtracking\n", 202 | " if psrn_masked - psrn_masked_last < -5: \n", 203 | " print('Falling back to previous checkpoint.')\n", 204 | "\n", 205 | " for new_param, net_param in zip(last_net, net.parameters()):\n", 206 | " net_param.data.copy_(new_param.cuda())\n", 207 | "\n", 208 | " return total_loss*0\n", 209 | " else:\n", 210 | " last_net = [x.cpu() for x in net.parameters()]\n", 211 | " psrn_masked_last = psrn_masked\n", 212 | "\n", 213 | "\n", 214 | "\n", 215 | " plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)\n", 216 | "\n", 217 | " i += 1\n", 218 | "\n", 219 | " return total_loss\n", 220 | "\n", 221 | "# Init globals \n", 222 | "last_net = None\n", 223 | "psrn_masked_last = 0\n", 224 | "i = 0\n", 225 | "\n", 226 | "net_input_saved = net_input.detach().clone()\n", 227 | "noise = net_input.detach().clone()\n", 228 | "\n", 229 | "# Run\n", 230 | "p = get_params(OPT_OVER, net, net_input)\n", 231 | "optimize(OPTIMIZER, p, closure, LR=LR, num_iter=num_iter)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "out_np = torch_to_np(net(net_input))\n", 241 | "q = plot_image_grid([np.clip(out_np, 0, 1), img_np], factor=13);" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "Python 3", 255 | "language": "python", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 
| "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.6.9" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 2 273 | } 274 | -------------------------------------------------------------------------------- /sr_prior_effect.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for **\"Prior effect\"** figure from supmat." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\"\"\"\n", 17 | "*Uncomment if running on colab* \n", 18 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 19 | "\"\"\"\n", 20 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 21 | "# !mv deep-image-prior/* ./" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Import libs" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from __future__ import print_function\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "import argparse\n", 42 | "import os\n", 43 | "# os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", 44 | "\n", 45 | "import numpy as np\n", 46 | "from models import *\n", 47 | "\n", 48 | "import torch\n", 49 | "import torch.optim\n", 50 | "\n", 51 | "from skimage.measure import compare_psnr\n", 52 | "from models.downsampler import Downsampler\n", 53 | "\n", 54 | "from utils.sr_utils import *\n", 55 | "\n", 56 | "torch.backends.cudnn.enabled = True\n", 57 | "torch.backends.cudnn.benchmark =True\n", 58 | "dtype = 
torch.cuda.FloatTensor\n", 59 | "\n", 60 | "imsize =-1 \n", 61 | "factor = 4\n", 62 | "enforse_div32 = 'CROP' # we usually need the dimensions to be divisible by a power of two\n", 63 | "\n", 64 | "PLOT = True" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "# Load image" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "fname = 'data/sr/zebra_crop.png'\n", 81 | "\n", 82 | "imgs = load_LR_HR_imgs_sr(fname, imsize, factor, enforse_div32)\n", 83 | "\n", 84 | "if PLOT:\n", 85 | " imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(imgs['LR_pil'], imgs['HR_pil'])\n", 86 | " plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np']], 4,12);\n", 87 | " print ('PSNR bicubic: %.4f PSNR nearest: %.4f' % (\n", 88 | " compare_psnr(imgs['HR_np'], imgs['bicubic_np']), \n", 89 | " compare_psnr(imgs['HR_np'], imgs['nearest_np'])))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def closure():\n", 99 | " \n", 100 | " global i, net_input\n", 101 | " \n", 102 | " \n", 103 | " if reg_noise_std > 0:\n", 104 | " net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n", 105 | " \n", 106 | " out_HR = net(net_input)\n", 107 | " out_LR = downsampler(out_HR)\n", 108 | "\n", 109 | " total_loss = mse(out_LR, img_LR_var) + tv_weight * tv_loss(out_HR)\n", 110 | " total_loss.backward()\n", 111 | "\n", 112 | " # Log\n", 113 | " psnr_LR = compare_psnr(imgs['LR_np'], torch_to_np(out_LR))\n", 114 | " psnr_HR = compare_psnr(imgs['HR_np'], torch_to_np(out_HR))\n", 115 | " print ('Iteration %05d PSNR_LR %.3f PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\\r', end='')\n", 116 | " \n", 117 | " # History\n", 118 | " psnr_history.append([psnr_LR, psnr_HR])\n", 119 | " \n", 120 | " if PLOT and i % 500 == 0:\n", 121 | " 
out_HR_np = torch_to_np(out_HR)\n", 122 | " plot_image_grid([imgs['HR_np'], np.clip(out_HR_np, 0, 1)], factor=8, nrow=2, interpolation='lanczos')\n", 123 | "\n", 124 | " i += 1\n", 125 | " \n", 126 | " return total_loss" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "# Experiment 1: no prior, optimize over pixels" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "input_depth = 3\n", 143 | " \n", 144 | "INPUT = 'noise'\n", 145 | "pad = 'reflection'\n", 146 | "OPT_OVER = 'input'\n", 147 | "KERNEL_TYPE='lanczos2'\n", 148 | "\n", 149 | "LR = 0.01\n", 150 | "tv_weight = 0.0\n", 151 | "\n", 152 | "OPTIMIZER = 'adam'\n", 153 | "\n", 154 | "num_iter = 2000\n", 155 | "reg_noise_std = 0.0" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "# Identity mapping network, optimize over `net_input`\n", 165 | "net = nn.Sequential()\n", 166 | "net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()\n", 167 | "\n", 168 | "downsampler = Downsampler(n_planes=3, factor=factor, kernel_type='lanczos2', phase=0.5, preserve_size=True).type(dtype)\n", 169 | "\n", 170 | "# Loss\n", 171 | "mse = torch.nn.MSELoss().type(dtype)\n", 172 | "\n", 173 | "img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "scrolled": false 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "psnr_history = [] \n", 185 | "i = 0\n", 186 | "net_input_saved = net_input.detach().clone()\n", 187 | "noise = net_input.detach().clone()\n", 188 | " \n", 189 | "p = get_params(OPT_OVER, net, net_input)\n", 190 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 
195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)\n", 200 | "\n", 201 | "result_no_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])\n", 202 | "psnr_history_direct = psnr_history" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "# Experiment 2: using TV loss" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "tv_weight = 1e-7\n", 219 | "net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()\n", 220 | "\n", 221 | "psnr_history = [] \n", 222 | "i = 0\n", 223 | " \n", 224 | "p = get_params(OPT_OVER, net, net_input)\n", 225 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)\n", 235 | "\n", 236 | "result_tv_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])\n", 237 | "psnr_history_tv = psnr_history" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "# Experiment 3: using deep prior" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "Same setting, but use parametrization." 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "OPT_OVER = 'net'\n", 261 | "reg_noise_std = 1./30. 
# This parameter probably should be set to a lower value for this example\n", 262 | "tv_weight = 0.0\n", 263 | "\n", 264 | "net = skip(input_depth, 3, num_channels_down = [128, 128, 128, 128, 128], \n", 265 | " num_channels_up = [128, 128, 128, 128, 128],\n", 266 | " num_channels_skip = [4, 4, 4, 4, 4], \n", 267 | " upsample_mode='bilinear',\n", 268 | " need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n", 269 | "\n", 270 | "net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()\n", 271 | "\n", 272 | "# Compute number of parameters\n", 273 | "s = sum([np.prod(list(p.size())) for p in net.parameters()]); \n", 274 | "print ('Number of params: %d' % s)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "scrolled": false 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "psnr_history = [] \n", 286 | "net_input_saved = net_input.detach().clone()\n", 287 | "noise = net_input.detach().clone()\n", 288 | "\n", 289 | "i = 0\n", 290 | "p = get_params(OPT_OVER, net, net_input)\n", 291 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)\n", 301 | "\n", 302 | "result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])\n", 303 | "psnr_history_deep_prior = psnr_history" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "# Comparison" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "plot_image_grid([imgs['HR_np'], \n", 320 | " result_no_prior, \n", 321 | " result_tv_prior, \n", 322 | " result_deep_prior], factor=8, nrow=2);" 323 | ] 324 | } 325 | ], 326 | "metadata": { 327 
| "kernelspec": { 328 | "display_name": "Python 3", 329 | "language": "python", 330 | "name": "python3" 331 | }, 332 | "language_info": { 333 | "codemirror_mode": { 334 | "name": "ipython", 335 | "version": 3 336 | }, 337 | "file_extension": ".py", 338 | "mimetype": "text/x-python", 339 | "name": "python", 340 | "nbconvert_exporter": "python", 341 | "pygments_lexer": "ipython3", 342 | "version": "3.6.9" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 2 347 | } 348 | -------------------------------------------------------------------------------- /super-resolution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Code for **super-resolution** (figures $1$ and $5$ from main paper).. Change `factor` to $8$ to reproduce images from fig. $9$ from supmat.\n", 8 | "\n", 9 | "You can play with parameters and see how they affect the result. " 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "\"\"\"\n", 19 | "*Uncomment if running on colab* \n", 20 | "Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab \n", 21 | "\"\"\"\n", 22 | "# !git clone https://github.com/DmitryUlyanov/deep-image-prior\n", 23 | "# !mv deep-image-prior/* ./" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Import libs" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "from __future__ import print_function\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "%matplotlib inline\n", 42 | "\n", 43 | "import argparse\n", 44 | "import os\n", 45 | "# os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", 46 | "\n", 47 | "import numpy as np\n", 48 | "from models import *\n", 49 | "\n", 50 | "import torch\n", 
51 | "import torch.optim\n", 52 | "\n", 53 | "from skimage.measure import compare_psnr\n", 54 | "from models.downsampler import Downsampler\n", 55 | "\n", 56 | "from utils.sr_utils import *\n", 57 | "\n", 58 | "torch.backends.cudnn.enabled = True\n", 59 | "torch.backends.cudnn.benchmark =True\n", 60 | "dtype = torch.cuda.FloatTensor\n", 61 | "\n", 62 | "imsize = -1 \n", 63 | "factor = 4 # 8\n", 64 | "enforse_div32 = 'CROP' # we usually need the dimensions to be divisible by a power of two (32 in this case)\n", 65 | "PLOT = True\n", 66 | "\n", 67 | "# To produce images from the paper we took *_GT.png images from LapSRN viewer for corresponding factor,\n", 68 | "# e.g. x4/zebra_GT.png for factor=4, and x8/zebra_GT.png for factor=8 \n", 69 | "path_to_image = 'data/sr/zebra_GT.png'" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# Load image and baselines" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Starts here\n", 86 | "imgs = load_LR_HR_imgs_sr(path_to_image , imsize, factor, enforse_div32)\n", 87 | "\n", 88 | "imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(imgs['LR_pil'], imgs['HR_pil'])\n", 89 | "\n", 90 | "if PLOT:\n", 91 | " plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np']], 4,12);\n", 92 | " print ('PSNR bicubic: %.4f PSNR nearest: %.4f' % (\n", 93 | " compare_psnr(imgs['HR_np'], imgs['bicubic_np']), \n", 94 | " compare_psnr(imgs['HR_np'], imgs['nearest_np'])))" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "# Set up parameters and net" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "input_depth = 32\n", 111 | " \n", 112 | "INPUT = 'noise'\n", 113 | "pad = 'reflection'\n", 114 | "OPT_OVER = 'net'\n", 
115 | "KERNEL_TYPE='lanczos2'\n", 116 | "\n", 117 | "LR = 0.01\n", 118 | "tv_weight = 0.0\n", 119 | "\n", 120 | "OPTIMIZER = 'adam'\n", 121 | "\n", 122 | "if factor == 4: \n", 123 | " num_iter = 2000\n", 124 | " reg_noise_std = 0.03\n", 125 | "elif factor == 8:\n", 126 | " num_iter = 4000\n", 127 | " reg_noise_std = 0.05\n", 128 | "else:\n", 129 | " assert False, 'We did not experiment with other factors'" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()\n", 139 | "\n", 140 | "NET_TYPE = 'skip' # UNet, ResNet\n", 141 | "net = get_net(input_depth, 'skip', pad,\n", 142 | " skip_n33d=128, \n", 143 | " skip_n33u=128, \n", 144 | " skip_n11=4, \n", 145 | " num_scales=5,\n", 146 | " upsample_mode='bilinear').type(dtype)\n", 147 | "\n", 148 | "# Losses\n", 149 | "mse = torch.nn.MSELoss().type(dtype)\n", 150 | "\n", 151 | "img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)\n", 152 | "\n", 153 | "downsampler = Downsampler(n_planes=3, factor=factor, kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "# Define closure and optimize" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "def closure():\n", 170 | " global i, net_input\n", 171 | " \n", 172 | " if reg_noise_std > 0:\n", 173 | " net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n", 174 | "\n", 175 | " out_HR = net(net_input)\n", 176 | " out_LR = downsampler(out_HR)\n", 177 | "\n", 178 | " total_loss = mse(out_LR, img_LR_var) \n", 179 | " \n", 180 | " if tv_weight > 0:\n", 181 | " total_loss += tv_weight * tv_loss(out_HR)\n", 182 | " \n", 183 | " total_loss.backward()\n", 184 | 
"\n", 185 | " # Log\n", 186 | " psnr_LR = compare_psnr(imgs['LR_np'], torch_to_np(out_LR))\n", 187 | " psnr_HR = compare_psnr(imgs['HR_np'], torch_to_np(out_HR))\n", 188 | " print ('Iteration %05d PSNR_LR %.3f PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\\r', end='')\n", 189 | " \n", 190 | " # History\n", 191 | " psnr_history.append([psnr_LR, psnr_HR])\n", 192 | " \n", 193 | " if PLOT and i % 100 == 0:\n", 194 | " out_HR_np = torch_to_np(out_HR)\n", 195 | " plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], np.clip(out_HR_np, 0, 1)], factor=13, nrow=3)\n", 196 | "\n", 197 | " i += 1\n", 198 | " \n", 199 | " return total_loss" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "psnr_history = [] \n", 209 | "net_input_saved = net_input.detach().clone()\n", 210 | "noise = net_input.detach().clone()\n", 211 | "\n", 212 | "i = 0\n", 213 | "p = get_params(OPT_OVER, net, net_input)\n", 214 | "optimize(OPTIMIZER, p, closure, LR, num_iter)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)\n", 224 | "result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])\n", 225 | "\n", 226 | "# For the paper we acually took `_bicubic.png` files from LapSRN viewer and used `result_deep_prior` as our result\n", 227 | "plot_image_grid([imgs['HR_np'],\n", 228 | " imgs['bicubic_np'],\n", 229 | " out_HR_np], factor=4, nrow=1);" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [] 238 | } 239 | ], 240 | "metadata": { 241 | "kernelspec": { 242 | "display_name": "Python 3", 243 | "language": "python", 244 | "name": "python3" 245 | }, 246 | "language_info": { 247 | "codemirror_mode": { 248 | "name": "ipython", 249 | "version": 3 250 | }, 251 | 
"file_extension": ".py", 252 | "mimetype": "text/x-python", 253 | "name": "python", 254 | "nbconvert_exporter": "python", 255 | "pygments_lexer": "ipython3", 256 | "version": "3.6.9" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 2 261 | } 262 | -------------------------------------------------------------------------------- /super-resolution_eval_script.py: -------------------------------------------------------------------------------- 1 | # This script had been used to get the numbers in the paper 2 | from utils.common_utils import get_image, plot_image_grid 3 | import cv2 4 | def rgb2ycbcr(im_rgb): 5 | im_rgb = im_rgb.astype(np.float32) 6 | im_ycrcb = cv2.cvtColor(im_rgb, cv2.COLOR_RGB2YCR_CB) 7 | im_ycbcr = im_ycrcb[:,:,(0,2,1)].astype(np.float32) 8 | im_ycbcr[:,:,0] = (im_ycbcr[:,:,0]*(235-16)+16)/255.0 #to [16/255, 235/255] 9 | im_ycbcr[:,:,1:] = (im_ycbcr[:,:,1:]*(240-16)+16)/255.0 #to [16/255, 240/255] 10 | return im_ycbcr 11 | 12 | def compare_psnr_y(x, y): 13 | return compare_psnr(rgb2ycbcr(x.transpose(1,2,0))[:,:,0], rgb2ycbcr(y.transpose(1,2,0))[:,:,0]) 14 | 15 | from collections import defaultdict 16 | datasets = { 17 | 'Set14': ["baboon", "barbara", "bridge", "coastguard", "comic", "face", "flowers", "foreman", "lenna", "man", "monarch", "pepper", "ppt3", "zebra"], 18 | # 'Set5': ['baby', 'bird', 'butterfly', 'head', 'woman'] 19 | } 20 | from glob import glob 21 | # g = sorted(glob('../image_compare/data/sr/Set5/x4/*')) 22 | 23 | from skimage.measure import compare_psnr 24 | # our 25 | stats = {} 26 | imsize = -1 27 | dct = defaultdict(lambda : 0) 28 | for cur_dataset in datasets.keys(): 29 | 30 | for method_name in postfixes: 31 | psnrs = [] 32 | for name in datasets[cur_dataset]: 33 | img_HR = f'/home/dulyanov/dmitryulyanov.github.io/assets/deep-image-prior/SR/{cur_dataset}/x4/{name}_GT.png' 34 | ours = f'/home/dulyanov/dmitryulyanov.github.io/assets/deep-image-prior/SR/{cur_dataset}/x4/{name}_deep_prior.png' 35 | method = 
f'/home/dulyanov/dmitryulyanov.github.io/assets/deep-image-prior/SR/{cur_dataset}/x4/{name}_{method_name}.png' 36 | 37 | gt_pil, gt = get_image(img_HR, imsize) 38 | ours_pil, ours = get_image(ours, imsize) 39 | method_pil, methods = get_image(method, imsize) 40 | 41 | if methods.shape[0] == 1: 42 | methods = np.concatenate([methods, methods, methods], 0) 43 | 44 | q1 = ours[:3].sum(0) 45 | t1 = np.where(q1.sum(0) > 0)[0] 46 | t2 = np.where(q1.sum(1) > 0)[0] 47 | 48 | 49 | 50 | psnr = compare_psnr_y(gt [:3,t2[0] + 4:t2[-1]-4,t1[0] + 4:t1[-1] - 4], 51 | methods[:3,t2[0] + 4:t2[-1]-4,t1[0] + 4:t1[-1] - 4]) 52 | 53 | # psnr = compare_psnr(gt [:3], 54 | # ours[:3]) 55 | 56 | psnrs.append(psnr) 57 | 58 | print(name, psnr) 59 | 60 | 61 | header = f'\small{{{method_name}}} & ' + ' & '.join([f'${x:.4}$' for x in psnrs]) 62 | 63 | stats[method_name] = [header, np.mean(psnrs)] 64 | 65 | print (header) 66 | 67 | names = datasets[cur_dataset] 68 | header = ' & ' + ' & '.join([f'\small{{{x.title()}}}' for x in names]) 69 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/utils/__init__.py -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import PIL 9 | import numpy as np 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | def crop_image(img, d=32): 14 | '''Make dimensions divisible by `d`''' 15 | 16 | new_size = (img.size[0] - img.size[0] % d, 17 | img.size[1] - img.size[1] % d) 18 | 19 | bbox = [ 20 | int((img.size[0] - new_size[0])/2), 21 | int((img.size[1] - 
new_size[1])/2), 22 | int((img.size[0] + new_size[0])/2), 23 | int((img.size[1] + new_size[1])/2), 24 | ] 25 | 26 | img_cropped = img.crop(bbox) 27 | return img_cropped 28 | 29 | def get_params(opt_over, net, net_input, downsampler=None): 30 | '''Returns parameters that we want to optimize over. 31 | 32 | Args: 33 | opt_over: comma separated list, e.g. "net,input" or "net" 34 | net: network 35 | net_input: torch.Tensor that stores input `z` 36 | ''' 37 | opt_over_list = opt_over.split(',') 38 | params = [] 39 | 40 | for opt in opt_over_list: 41 | 42 | if opt == 'net': 43 | params += [x for x in net.parameters() ] 44 | elif opt=='down': 45 | assert downsampler is not None 46 | params = [x for x in downsampler.parameters()] 47 | elif opt == 'input': 48 | net_input.requires_grad = True 49 | params += [net_input] 50 | else: 51 | assert False, 'what is it?' 52 | 53 | return params 54 | 55 | def get_image_grid(images_np, nrow=8): 56 | '''Creates a grid from a list of images by concatenating them.''' 57 | images_torch = [torch.from_numpy(x) for x in images_np] 58 | torch_grid = torchvision.utils.make_grid(images_torch, nrow) 59 | 60 | return torch_grid.numpy() 61 | 62 | def plot_image_grid(images_np, nrow =8, factor=1, interpolation='lanczos'): 63 | """Draws images in a grid 64 | 65 | Args: 66 | images_np: list of images, each image is np.array of size 3xHxW of 1xHxW 67 | nrow: how many images will be in one row 68 | factor: size if the plt.figure 69 | interpolation: interpolation used in plt.imshow 70 | """ 71 | n_channels = max(x.shape[0] for x in images_np) 72 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels" 73 | 74 | images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np] 75 | 76 | grid = get_image_grid(images_np, nrow) 77 | 78 | plt.figure(figsize=(len(images_np) + factor, 12 + factor)) 79 | 80 | if images_np[0].shape[0] == 1: 81 | plt.imshow(grid[0], cmap='gray', 
interpolation=interpolation) 82 | else: 83 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation) 84 | 85 | plt.show() 86 | 87 | return grid 88 | 89 | def load(path): 90 | """Load PIL image.""" 91 | img = Image.open(path) 92 | return img 93 | 94 | def get_image(path, imsize=-1): 95 | """Load an image and resize to a cpecific size. 96 | 97 | Args: 98 | path: path to image 99 | imsize: tuple or scalar with dimensions; -1 for `no resize` 100 | """ 101 | img = load(path) 102 | 103 | if isinstance(imsize, int): 104 | imsize = (imsize, imsize) 105 | 106 | if imsize[0]!= -1 and img.size != imsize: 107 | if imsize[0] > img.size[0]: 108 | img = img.resize(imsize, Image.BICUBIC) 109 | else: 110 | img = img.resize(imsize, Image.ANTIALIAS) 111 | 112 | img_np = pil_to_np(img) 113 | 114 | return img, img_np 115 | 116 | 117 | 118 | def fill_noise(x, noise_type): 119 | """Fills tensor `x` with noise of type `noise_type`.""" 120 | if noise_type == 'u': 121 | x.uniform_() 122 | elif noise_type == 'n': 123 | x.normal_() 124 | else: 125 | assert False 126 | 127 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10): 128 | """Returns a pytorch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 129 | initialized in a specific way. 130 | Args: 131 | input_depth: number of channels in the tensor 132 | method: `noise` for fillting tensor with noise; `meshgrid` for np.meshgrid 133 | spatial_size: spatial size of the tensor to initialize 134 | noise_type: 'u' for uniform; 'n' for normal 135 | var: a factor, a noise will be multiplicated by. Basically it is standard deviation scaler. 
136 | """ 137 | if isinstance(spatial_size, int): 138 | spatial_size = (spatial_size, spatial_size) 139 | if method == 'noise': 140 | shape = [1, input_depth, spatial_size[0], spatial_size[1]] 141 | net_input = torch.zeros(shape) 142 | 143 | fill_noise(net_input, noise_type) 144 | net_input *= var 145 | elif method == 'meshgrid': 146 | assert input_depth == 2 147 | X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1)) 148 | meshgrid = np.concatenate([X[None,:], Y[None,:]]) 149 | net_input= np_to_torch(meshgrid) 150 | else: 151 | assert False 152 | 153 | return net_input 154 | 155 | def pil_to_np(img_PIL): 156 | '''Converts image in PIL format to np.array. 157 | 158 | From H x W x C [0...255] to C x H x W [0..1]; a 2-D H x W array (e.g. grayscale) becomes 1 x H x W 159 | ''' 160 | ar = np.array(img_PIL) 161 | 162 | if len(ar.shape) == 3: 163 | ar = ar.transpose(2,0,1) 164 | else: 165 | ar = ar[None, ...] 166 | 167 | return ar.astype(np.float32) / 255. 168 | 169 | def np_to_pil(img_np): 170 | '''Converts image in np.array format to PIL image. 171 | 172 | From C x H x W [0..1] to H x W x C [0...255] (values are clipped); a 1 x H x W array becomes a 2-D H x W image 173 | ''' 174 | ar = np.clip(img_np*255,0,255).astype(np.uint8) 175 | 176 | if img_np.shape[0] == 1: 177 | ar = ar[0] 178 | else: 179 | ar = ar.transpose(1, 2, 0) 180 | 181 | return Image.fromarray(ar) 182 | 183 | def np_to_torch(img_np): 184 | '''Converts image in numpy.array to torch.Tensor. 185 | 186 | From C x H x W [0..1] to 1 x C x H x W [0..1] (adds a leading axis of size 1) 187 | ''' 188 | return torch.from_numpy(img_np)[None, :] 189 | 190 | def torch_to_np(img_var): 191 | '''Converts an image in torch.Tensor format to np.array. 192 | 193 | From 1 x C x H x W [0..1] to C x H x W [0..1] (keeps only index 0 of the leading axis) 194 | ''' 195 | return img_var.detach().cpu().numpy()[0] 196 | 197 | 198 | def optimize(optimizer_type, parameters, closure, LR, num_iter): 199 | """Runs optimization loop. `closure` must compute gradients itself (e.g. call .backward()); this function only zeroes grads and steps.
200 | 201 | Args: 202 | optimizer_type: 'LBFGS' of 'adam' 203 | parameters: list of Tensors to optimize over 204 | closure: function, that returns loss variable 205 | LR: learning rate 206 | num_iter: number of iterations 207 | """ 208 | if optimizer_type == 'LBFGS': 209 | # Do several steps with adam first 210 | optimizer = torch.optim.Adam(parameters, lr=0.001) 211 | for j in range(100): 212 | optimizer.zero_grad() 213 | closure() 214 | optimizer.step() 215 | 216 | print('Starting optimization with LBFGS') 217 | def closure2(): 218 | optimizer.zero_grad() 219 | return closure() 220 | optimizer = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1) 221 | optimizer.step(closure2) 222 | 223 | elif optimizer_type == 'adam': 224 | print('Starting optimization with ADAM') 225 | optimizer = torch.optim.Adam(parameters, lr=LR) 226 | 227 | for j in range(num_iter): 228 | optimizer.zero_grad() 229 | closure() 230 | optimizer.step() 231 | else: 232 | assert False -------------------------------------------------------------------------------- /utils/denoising_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .common_utils import * 3 | 4 | 5 | 6 | def get_noisy_image(img_np, sigma): 7 | """Adds Gaussian noise to an image. 
8 | 9 | Args: 10 | img_np: image, np.array with values from 0 to 1 11 | sigma: std of the noise 12 | """ 13 | img_noisy_np = np.clip(img_np + np.random.normal(scale=sigma, size=img_np.shape), 0, 1).astype(np.float32) 14 | img_noisy_pil = np_to_pil(img_noisy_np) 15 | 16 | return img_noisy_pil, img_noisy_np -------------------------------------------------------------------------------- /utils/feature_inversion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.models as models 5 | from .matcher import Matcher 6 | import os 7 | from collections import OrderedDict 8 | 9 | class View(nn.Module): 10 | def __init__(self): 11 | super(View, self).__init__() 12 | 13 | def forward(self, x): 14 | return x.view(-1) 15 | 16 | def get_vanilla_vgg_features(cut_idx=-1): 17 | if not os.path.exists('vgg_features.pth'): 18 | os.system( 19 | 'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth') 20 | vgg_weights = torch.load('vgg19-d01eb7cb.pth') 21 | # fix compatibility issues 22 | map = {'classifier.6.weight':u'classifier.7.weight', 'classifier.6.bias':u'classifier.7.bias'} 23 | vgg_weights = OrderedDict([(map[k] if k in map else k,v) for k,v in vgg_weights.iteritems()]) 24 | 25 | 26 | 27 | model = models.vgg19() 28 | model.classifier = nn.Sequential(View(), *model.classifier._modules.values()) 29 | 30 | 31 | model.load_state_dict(vgg_weights) 32 | 33 | torch.save(model.features, 'vgg_features.pth') 34 | torch.save(model.classifier, 'vgg_classifier.pth') 35 | 36 | vgg = torch.load('vgg_features.pth') 37 | if cut_idx > 36: 38 | vgg_classifier = torch.load('vgg_classifier.pth') 39 | vgg = nn.Sequential(*(vgg._modules.values() + vgg_classifier._modules.values())) 40 | 41 | vgg.eval() 42 | 43 | return vgg 44 | 45 | 46 | def get_matcher(net, opt): 47 | idxs = [x for x in 
opt['layers'].split(',')] 48 | matcher = Matcher(opt['what']) 49 | 50 | def hook(module, input, output): 51 | matcher(module, output) 52 | 53 | for i in idxs: 54 | net._modules[i].register_forward_hook(hook) 55 | 56 | return matcher 57 | 58 | 59 | 60 | def get_vgg(cut_idx=-1): 61 | f = get_vanilla_vgg_features(cut_idx) 62 | 63 | if cut_idx > 0: 64 | num_modules = len(f._modules) 65 | keys_to_delete = [f._modules.keys()[x] for x in range(cut_idx, num_modules)] 66 | for k in keys_to_delete: 67 | del f._modules[k] 68 | 69 | return f 70 | 71 | def vgg_preprocess_var(var): 72 | (r, g, b) = torch.chunk(var, 3, dim=1) 73 | bgr = torch.cat((b, g, r), 1) 74 | out = bgr * 255 - torch.autograd.Variable(vgg_mean[None, ...]).type(var.type()).expand_as(bgr) 75 | return out 76 | 77 | vgg_mean = torch.FloatTensor([103.939, 116.779, 123.680]).view(3, 1, 1) 78 | 79 | 80 | 81 | def get_preprocessor(imsize): 82 | def vgg_preprocess(tensor): 83 | (r, g, b) = torch.chunk(tensor, 3, dim=0) 84 | bgr = torch.cat((b, g, r), 0) 85 | out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr) 86 | return out 87 | preprocess = transforms.Compose([ 88 | transforms.Resize(imsize), 89 | transforms.ToTensor(), 90 | transforms.Lambda(vgg_preprocess) 91 | ]) 92 | 93 | return preprocess 94 | 95 | 96 | def get_deprocessor(): 97 | def vgg_deprocess(tensor): 98 | bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0 99 | (b, g, r) = torch.chunk(bgr, 3, dim=0) 100 | rgb = torch.cat((r, g, b), 0) 101 | return rgb 102 | deprocess = transforms.Compose([ 103 | transforms.Lambda(vgg_deprocess), 104 | transforms.Lambda(lambda x: torch.clamp(x, 0, 1)), 105 | transforms.ToPILImage() 106 | ]) 107 | return deprocess 108 | -------------------------------------------------------------------------------- /utils/inpainting_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import PIL.ImageDraw as ImageDraw 4 | import 
PIL.ImageFont as ImageFont 5 | from .common_utils import * 6 | 7 | def get_text_mask(for_image, sz=20): 8 | font_fname = '/usr/share/fonts/truetype/freefont/FreeSansBold.ttf' 9 | font_size = sz 10 | font = ImageFont.truetype(font_fname, font_size) 11 | 12 | img_mask = Image.fromarray(np.array(for_image)*0+255) 13 | draw = ImageDraw.Draw(img_mask) 14 | draw.text((128, 128), "hello world", font=font, fill='rgb(0, 0, 0)') 15 | 16 | return img_mask 17 | 18 | def get_bernoulli_mask(for_image, zero_fraction=0.95): 19 | img_mask_np=(np.random.random_sample(size=pil_to_np(for_image).shape) > zero_fraction).astype(int) 20 | img_mask = np_to_pil(img_mask_np) 21 | 22 | return img_mask 23 | -------------------------------------------------------------------------------- /utils/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Matcher: 5 | def __init__(self, how='gram_matrix', loss='mse'): 6 | self.mode = 'store' 7 | self.stored = {} 8 | self.losses = {} 9 | 10 | if how in all_features.keys(): 11 | self.get_statistics = all_features[how] 12 | else: 13 | assert False 14 | pass 15 | 16 | if loss in all_losses.keys(): 17 | self.loss = all_losses[loss] 18 | else: 19 | assert False 20 | 21 | def __call__(self, module, features): 22 | statistics = self.get_statistics(features) 23 | 24 | self.statistics = statistics 25 | if self.mode == 'store': 26 | self.stored[module] = statistics.detach().clone() 27 | elif self.mode == 'match': 28 | self.losses[module] = self.loss(statistics, self.stored[module]) 29 | 30 | def clean(self): 31 | self.losses = {} 32 | 33 | def gram_matrix(x): 34 | (b, ch, h, w) = x.size() 35 | features = x.view(b, ch, w * h) 36 | features_t = features.transpose(1, 2) 37 | gram = features.bmm(features_t) / (ch * h * w) 38 | return gram 39 | 40 | 41 | def features(x): 42 | return x 43 | 44 | 45 | all_features = { 46 | 'gram_matrix': gram_matrix, 47 | 'features': 
features, 48 | } 49 | 50 | all_losses = { 51 | 'mse': nn.MSELoss(), 52 | 'smoothL1': nn.SmoothL1Loss(), 53 | 'L1': nn.L1Loss(), 54 | } 55 | -------------------------------------------------------------------------------- /utils/perceptual_loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DmitryUlyanov/deep-image-prior/042e0d4c1e93f4b1eb0932781de55b8cff5e0f40/utils/perceptual_loss/__init__.py -------------------------------------------------------------------------------- /utils/perceptual_loss/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Matcher: 6 | def __init__(self, how='gram_matrix', loss='mse', map_index=933): 7 | self.mode = 'store' 8 | self.stored = {} 9 | self.losses = {} 10 | 11 | if how in all_features.keys(): 12 | self.get_statistics = all_features[how] 13 | else: 14 | assert False 15 | pass 16 | 17 | if loss in all_losses.keys(): 18 | self.loss = all_losses[loss] 19 | else: 20 | assert False 21 | 22 | self.map_index = map_index 23 | self.method = 'match' 24 | 25 | 26 | def __call__(self, module, features): 27 | statistics = self.get_statistics(features) 28 | 29 | self.statistics = statistics 30 | if self.mode == 'store': 31 | self.stored[module] = statistics.detach() 32 | 33 | elif self.mode == 'match': 34 | 35 | if statistics.ndimension() == 2: 36 | 37 | if self.method == 'maximize': 38 | self.losses[module] = - statistics[0, self.map_index] 39 | else: 40 | self.losses[module] = torch.abs(300 - statistics[0, self.map_index]) 41 | 42 | else: 43 | ws = self.window_size 44 | 45 | t = statistics.detach() * 0 46 | 47 | s_cc = statistics[:1, :, t.shape[2] // 2 - ws:t.shape[2] // 2 + ws, t.shape[3] // 2 - ws:t.shape[3] // 2 + ws] #* 1.0 48 | t_cc = t[:1, :, t.shape[2] // 2 - ws:t.shape[2] // 2 + ws, t.shape[3] // 2 - ws:t.shape[3] // 2 + ws] #* 1.0 49 | t_cc[:, 
self.map_index,...] = 1 50 | 51 | if self.method == 'maximize': 52 | self.losses[module] = -(s_cc * t_cc.contiguous()).sum() 53 | else: 54 | self.losses[module] = torch.abs(200 -(s_cc * t_cc.contiguous())).sum() 55 | 56 | 57 | def clean(self): 58 | self.losses = {} 59 | 60 | def gram_matrix(x): 61 | (b, ch, h, w) = x.size() 62 | features = x.view(b, ch, w * h) 63 | features_t = features.transpose(1, 2) 64 | gram = features.bmm(features_t) / (ch * h * w) 65 | return gram 66 | 67 | 68 | def features(x): 69 | return x 70 | 71 | 72 | all_features = { 73 | 'gram_matrix': gram_matrix, 74 | 'features': features, 75 | } 76 | 77 | all_losses = { 78 | 'mse': nn.MSELoss(), 79 | 'smoothL1': nn.SmoothL1Loss(), 80 | 'L1': nn.L1Loss(), 81 | } 82 | -------------------------------------------------------------------------------- /utils/perceptual_loss/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | import torchvision.models as models 6 | from .matcher import Matcher 7 | from collections import OrderedDict 8 | 9 | from torchvision.models.vgg import model_urls 10 | from torchvision.models import vgg19 11 | from torch.autograd import Variable 12 | 13 | from .vgg_modified import VGGModified 14 | 15 | def get_pretrained_net(name): 16 | """Loads pretrained network""" 17 | if name == 'alexnet_caffe': 18 | if not os.path.exists('alexnet-torch_py3.pth'): 19 | print('Downloading AlexNet') 20 | os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download') 21 | return torch.load('alexnet-torch_py3.pth') 22 | elif name == 'vgg19_caffe': 23 | if not os.path.exists('vgg19-caffe-py3.pth'): 24 | print('Downloading VGG-19') 25 | os.system('wget -O vgg19-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download') 26 | 27 | vgg = 
get_vgg19_caffe() 28 | 29 | return vgg 30 | elif name == 'vgg16_caffe': 31 | if not os.path.exists('vgg16-caffe-py3.pth'): 32 | print('Downloading VGG-16') 33 | os.system('wget -O vgg16-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download') 34 | 35 | vgg = get_vgg16_caffe() 36 | 37 | return vgg 38 | elif name == 'vgg19_pytorch_modified': 39 | # os.system('wget -O data/feature_inversion/vgg19-caffe.pth --no-check-certificate -nc https://www.dropbox.com/s/xlbdo688dy4keyk/vgg19-caffe.pth?dl=1') 40 | 41 | model = VGGModified(vgg19(pretrained=False), 0.2) 42 | model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict']) 43 | 44 | return model 45 | else: 46 | assert False 47 | 48 | 49 | class PerceputalLoss(nn.modules.loss._Loss): 50 | """ 51 | Assumes input image is in range [0,1] if `input_range` is 'sigmoid', [-1, 1] if 'tanh' 52 | """ 53 | def __init__(self, input_range='sigmoid', 54 | net_type = 'vgg_torch', 55 | input_preprocessing='corresponding', 56 | match=[{'layers':[11,20,29],'what':'features'}]): 57 | 58 | if input_range not in ['sigmoid', 'tanh']: 59 | assert False 60 | 61 | self.net = get_pretrained_net(net_type).cuda() 62 | 63 | self.matchers = [get_matcher(self.net, match_opts) for match_opts in match] 64 | 65 | preprocessing_correspondence = { 66 | 'vgg19_torch': vgg_preprocess_caffe, 67 | 'vgg16_torch': vgg_preprocess_caffe, 68 | 'vgg19_pytorch': vgg_preprocess_pytorch, 69 | 'vgg19_pytorch_modified': vgg_preprocess_pytorch, 70 | } 71 | 72 | if input_preprocessing == 'corresponding': 73 | self.preprocess_input = preprocessing_correspondence[net_type] 74 | else: 75 | self.preprocessing = preprocessing_correspondence[input_preprocessing] 76 | 77 | def preprocess_input(self, x): 78 | if self.input_range == 'tanh': 79 | x = (x + 1.) / 2. 
80 | 81 | return self.preprocess(x) 82 | 83 | def __call__(self, x, y): 84 | 85 | # for 86 | self.matcher_content.mode = 'store' 87 | self.net(self.preprocess_input(y)); 88 | 89 | self.matcher_content.mode = 'match' 90 | self.net(self.preprocess_input(x)); 91 | 92 | return sum([sum(matcher.losses.values()) for matcher in self.matchers]) 93 | 94 | 95 | def get_vgg19_caffe(): 96 | model = vgg19() 97 | model.classifier = nn.Sequential(View(), *model.classifier._modules.values()) 98 | vgg = model.features 99 | vgg_classifier = model.classifier 100 | 101 | names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1', 102 | 'conv2_1','relu2_1','conv2_2','relu2_2','pool2', 103 | 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','conv3_4','relu3_4','pool3', 104 | 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','conv4_4','relu4_4','pool4', 105 | 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','conv5_4','relu5_4','pool5', 106 | 'torch_view','fc6','relu6','drop6','fc7','relu7','drop7','fc8'] 107 | 108 | model = nn.Sequential() 109 | for n, m in zip(names, list(vgg) + list(vgg_classifier)): 110 | model.add_module(n, m) 111 | 112 | model.load_state_dict(torch.load('vgg19-caffe-py3.pth')) 113 | 114 | return model 115 | 116 | def get_vgg16_caffe(): 117 | vgg = torch.load('vgg16-caffe-py3.pth') 118 | 119 | names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1', 120 | 'conv2_1','relu2_1','conv2_2','relu2_2','pool2', 121 | 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','pool3', 122 | 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4', 123 | 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5', 124 | 'torch_view','fc6','relu6','drop6','fc7','relu7','fc8'] 125 | 126 | model = nn.Sequential() 127 | for n, m in zip(names, list(vgg)): 128 | model.add_module(n, m) 129 | 130 | # model.load_state_dict(torch.load('vgg19-caffe-py3.pth')) 131 | 132 | return model 133 | 134 | 135 | class View(nn.Module): 136 | def 
__init__(self): 137 | super(View, self).__init__() 138 | 139 | def forward(self, x): 140 | return x.view(x.size(0), -1) 141 | 142 | 143 | def get_matcher(vgg, opt): 144 | # idxs = [int(x) for x in opt['layers'].split(',')] 145 | matcher = Matcher(opt['what'], 'mse', opt['map_idx']) 146 | 147 | def hook(module, input, output): 148 | matcher(module, output) 149 | 150 | for layer_name in opt['layers']: 151 | vgg._modules[layer_name].register_forward_hook(hook) 152 | 153 | return matcher 154 | 155 | 156 | def get_vgg(cut_idx=-1, vgg_type='pytorch'): 157 | f = get_vanilla_vgg_features(cut_idx, vgg_type) 158 | 159 | keys = [x for x in cnn._modules.keys()] 160 | max_idx = max(keys.index(x) for x in opt_content['layers'].split(',')) 161 | for k in keys[max_idx+1:]: 162 | cnn._modules.pop(k) 163 | 164 | return f 165 | 166 | vgg_mean = torch.FloatTensor([103.939, 116.779, 123.680]).view(3, 1, 1) 167 | def vgg_preprocess_caffe(var): 168 | (r, g, b) = torch.chunk(var, 3, dim=1) 169 | bgr = torch.cat((b, g, r), 1) 170 | out = bgr * 255 - torch.autograd.Variable(vgg_mean).type(var.type()) 171 | return out 172 | 173 | 174 | 175 | mean_pytorch = Variable(torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)) 176 | std_pytorch = Variable(torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)) 177 | 178 | def vgg_preprocess_pytorch(var): 179 | return (var - mean_pytorch.type_as(var))/std_pytorch.type_as(var) 180 | 181 | 182 | 183 | def get_preprocessor(imsize): 184 | def vgg_preprocess(tensor): 185 | (r, g, b) = torch.chunk(tensor, 3, dim=0) 186 | bgr = torch.cat((b, g, r), 0) 187 | out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr) 188 | return out 189 | preprocess = transforms.Compose([ 190 | transforms.Resize(imsize), 191 | transforms.ToTensor(), 192 | transforms.Lambda(vgg_preprocess) 193 | ]) 194 | 195 | return preprocess 196 | 197 | 198 | def get_deprocessor(): 199 | def vgg_deprocess(tensor): 200 | bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0 201 | (b, g, r) = 
torch.chunk(bgr, 3, dim=0) 202 | rgb = torch.cat((r, g, b), 0) 203 | return rgb 204 | deprocess = transforms.Compose([ 205 | transforms.Lambda(vgg_deprocess), 206 | transforms.Lambda(lambda x: torch.clamp(x, 0, 1)), 207 | transforms.ToPILImage() 208 | ]) 209 | return deprocess 210 | 211 | -------------------------------------------------------------------------------- /utils/perceptual_loss/vgg_modified.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class VGGModified(nn.Module): 4 | def __init__(self, vgg19_orig, slope=0.01): 5 | super(VGGModified, self).__init__() 6 | self.features = nn.Sequential() 7 | 8 | self.features.add_module(str(0), vgg19_orig.features[0]) 9 | self.features.add_module(str(1), nn.LeakyReLU(slope, True)) 10 | self.features.add_module(str(2), vgg19_orig.features[2]) 11 | self.features.add_module(str(3), nn.LeakyReLU(slope, True)) 12 | self.features.add_module(str(4), nn.AvgPool2d((2,2), (2,2))) 13 | 14 | self.features.add_module(str(5), vgg19_orig.features[5]) 15 | self.features.add_module(str(6), nn.LeakyReLU(slope, True)) 16 | self.features.add_module(str(7), vgg19_orig.features[7]) 17 | self.features.add_module(str(8), nn.LeakyReLU(slope, True)) 18 | self.features.add_module(str(9), nn.AvgPool2d((2,2), (2,2))) 19 | 20 | self.features.add_module(str(10), vgg19_orig.features[10]) 21 | self.features.add_module(str(11), nn.LeakyReLU(slope, True)) 22 | self.features.add_module(str(12), vgg19_orig.features[12]) 23 | self.features.add_module(str(13), nn.LeakyReLU(slope, True)) 24 | self.features.add_module(str(14), vgg19_orig.features[14]) 25 | self.features.add_module(str(15), nn.LeakyReLU(slope, True)) 26 | self.features.add_module(str(16), vgg19_orig.features[16]) 27 | self.features.add_module(str(17), nn.LeakyReLU(slope, True)) 28 | self.features.add_module(str(18), nn.AvgPool2d((2,2), (2,2))) 29 | 30 | self.features.add_module(str(19), vgg19_orig.features[19]) 31 | 
self.features.add_module(str(20), nn.LeakyReLU(slope, True)) 32 | self.features.add_module(str(21), vgg19_orig.features[21]) 33 | self.features.add_module(str(22), nn.LeakyReLU(slope, True)) 34 | self.features.add_module(str(23), vgg19_orig.features[23]) 35 | self.features.add_module(str(24), nn.LeakyReLU(slope, True)) 36 | self.features.add_module(str(25), vgg19_orig.features[25]) 37 | self.features.add_module(str(26), nn.LeakyReLU(slope, True)) 38 | self.features.add_module(str(27), nn.AvgPool2d((2,2), (2,2))) 39 | 40 | self.features.add_module(str(28), vgg19_orig.features[28]) 41 | self.features.add_module(str(29), nn.LeakyReLU(slope, True)) 42 | self.features.add_module(str(30), vgg19_orig.features[30]) 43 | self.features.add_module(str(31), nn.LeakyReLU(slope, True)) 44 | self.features.add_module(str(32), vgg19_orig.features[32]) 45 | self.features.add_module(str(33), nn.LeakyReLU(slope, True)) 46 | self.features.add_module(str(34), vgg19_orig.features[34]) 47 | self.features.add_module(str(35), nn.LeakyReLU(slope, True)) 48 | self.features.add_module(str(36), nn.AvgPool2d((2,2), (2,2))) 49 | 50 | self.classifier = nn.Sequential() 51 | 52 | self.classifier.add_module(str(0), vgg19_orig.classifier[0]) 53 | self.classifier.add_module(str(1), nn.LeakyReLU(slope, True)) 54 | self.classifier.add_module(str(2), nn.Dropout2d(p = 0.5)) 55 | self.classifier.add_module(str(3), vgg19_orig.classifier[3]) 56 | self.classifier.add_module(str(4), nn.LeakyReLU(slope, True)) 57 | self.classifier.add_module(str(5), nn.Dropout2d(p = 0.5)) 58 | self.classifier.add_module(str(6), vgg19_orig.classifier[6]) 59 | 60 | def forward(self, x): 61 | return self.classifier(self.features.forward(x)) -------------------------------------------------------------------------------- /utils/sr_utils.py: -------------------------------------------------------------------------------- 1 | from .common_utils import * 2 | 3 | def put_in_center(img_np, target_size): 4 | img_out = np.zeros([3, 
target_size[0], target_size[1]]) 5 | 6 | bbox = [ 7 | int((target_size[0] - img_np.shape[1]) / 2), 8 | int((target_size[1] - img_np.shape[2]) / 2), 9 | int((target_size[0] + img_np.shape[1]) / 2), 10 | int((target_size[1] + img_np.shape[2]) / 2), 11 | ] 12 | 13 | img_out[:, bbox[0]:bbox[2], bbox[1]:bbox[3]] = img_np 14 | 15 | return img_out 16 | 17 | 18 | def load_LR_HR_imgs_sr(fname, imsize, factor, enforse_div32=None): 19 | '''Loads an image, resizes it, center crops and downscales. 20 | 21 | Args: 22 | fname: path to the image 23 | imsize: new size for the image, -1 for no resizing 24 | factor: downscaling factor 25 | enforse_div32: if 'CROP' center crops an image, so that its dimensions are divisible by 32. 26 | ''' 27 | img_orig_pil, img_orig_np = get_image(fname, -1) 28 | 29 | if imsize != -1: 30 | img_orig_pil, img_orig_np = get_image(fname, imsize) 31 | 32 | # For comparison with GT 33 | if enforse_div32 == 'CROP': 34 | new_size = (img_orig_pil.size[0] - img_orig_pil.size[0] % 32, 35 | img_orig_pil.size[1] - img_orig_pil.size[1] % 32) 36 | 37 | bbox = [ 38 | (img_orig_pil.size[0] - new_size[0])/2, 39 | (img_orig_pil.size[1] - new_size[1])/2, 40 | (img_orig_pil.size[0] + new_size[0])/2, 41 | (img_orig_pil.size[1] + new_size[1])/2, 42 | ] 43 | 44 | img_HR_pil = img_orig_pil.crop(bbox) 45 | img_HR_np = pil_to_np(img_HR_pil) 46 | else: 47 | img_HR_pil, img_HR_np = img_orig_pil, img_orig_np 48 | 49 | LR_size = [ 50 | img_HR_pil.size[0] // factor, 51 | img_HR_pil.size[1] // factor 52 | ] 53 | 54 | img_LR_pil = img_HR_pil.resize(LR_size, Image.ANTIALIAS) 55 | img_LR_np = pil_to_np(img_LR_pil) 56 | 57 | print('HR and LR resolutions: %s, %s' % (str(img_HR_pil.size), str (img_LR_pil.size))) 58 | 59 | return { 60 | 'orig_pil': img_orig_pil, 61 | 'orig_np': img_orig_np, 62 | 'LR_pil': img_LR_pil, 63 | 'LR_np': img_LR_np, 64 | 'HR_pil': img_HR_pil, 65 | 'HR_np': img_HR_np 66 | } 67 | 68 | 69 | def get_baselines(img_LR_pil, img_HR_pil): 70 | '''Gets `bicubic`, sharpened 
bicubic and `nearest` baselines.''' 71 | img_bicubic_pil = img_LR_pil.resize(img_HR_pil.size, Image.BICUBIC) 72 | img_bicubic_np = pil_to_np(img_bicubic_pil) 73 | 74 | img_nearest_pil = img_LR_pil.resize(img_HR_pil.size, Image.NEAREST) 75 | img_nearest_np = pil_to_np(img_nearest_pil) 76 | 77 | img_bic_sharp_pil = img_bicubic_pil.filter(PIL.ImageFilter.UnsharpMask()) 78 | img_bic_sharp_np = pil_to_np(img_bic_sharp_pil) 79 | 80 | return img_bicubic_np, img_bic_sharp_np, img_nearest_np 81 | 82 | 83 | 84 | def tv_loss(x, beta = 0.5): 85 | '''Calculates TV loss for an image `x`. 86 | 87 | Args: 88 | x: image, torch.Variable of torch.Tensor 89 | beta: See https://arxiv.org/abs/1412.0035 (fig. 2) to see effect of `beta` 90 | ''' 91 | dh = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2) 92 | dw = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2) 93 | 94 | return torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1], beta)) 95 | --------------------------------------------------------------------------------