├── Jacobian_inner_product_noisevsimg.ipynb ├── Jacobian_multi_layer_deep_decoder.ipynb ├── README.md ├── denoising_MSE_curves.ipynb ├── denoising_bm3d_example.ipynb ├── denoising_bm3d_imagenet_selected100_paper.ipynb ├── denoising_imagenet_selected100_paper.ipynb ├── denoising_performance_example.ipynb ├── image_fitted_faster_than_noise_on_imgnet.ipynb ├── include ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── compression.cpython-36.pyc │ ├── compression.cpython-37.pyc │ ├── decoder.cpython-36.pyc │ ├── decoder.cpython-37.pyc │ ├── fit.cpython-36.pyc │ ├── fit.cpython-37.pyc │ ├── helpers.cpython-36.pyc │ ├── helpers.cpython-37.pyc │ ├── transforms.cpython-36.pyc │ ├── transforms.cpython-37.pyc │ ├── visualize.cpython-36.pyc │ ├── visualize.cpython-37.pyc │ ├── wavelet.cpython-36.pyc │ └── wavelet.cpython-37.pyc ├── compression.py ├── decoder.py ├── denoise.py ├── fit.py ├── helpers.py ├── onedim.py ├── transforms.py ├── visualize.py └── wavelet.py ├── kernels_and_associated_dual_kernels.ipynb ├── linear_least_squares_selective_fitting_warmup.ipynb ├── noise_vs_img_fitting_different_architectures.ipynb ├── test_data ├── astronaut.png └── phantom256.png └── visualization_linear_approximation.ipynb /Jacobian_inner_product_noisevsimg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Jacobian at initialization\n", 8 | "\n", 9 | "Here, we compute the norm of the product of the Jacobian at initialization with a signal and noise. For both deep decoder and deep image prior, this quantity is significantly smaller for noise than for a natural image. Thus, a natural image is better aligned with the leading singular vectors of the Jacobian than noise. 
This demonstrates that the Jacobian of the networks is approximately low-rank, with natural images lying in the space spanned by the leading singularvectors.\n", 10 | "\n", 11 | "Running the DIP part requires the models from [https://github.com/DmitryUlyanov/deep-image-prior](https://github.com/DmitryUlyanov/deep-image-prior)." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "num GPUs 1\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "from __future__ import print_function\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "#%matplotlib notebook\n", 31 | "\n", 32 | "import os\n", 33 | "\n", 34 | "import warnings\n", 35 | "warnings.filterwarnings('ignore')\n", 36 | "\n", 37 | "from include import *\n", 38 | "from PIL import Image\n", 39 | "import PIL\n", 40 | "\n", 41 | "import numpy as np\n", 42 | "import torch\n", 43 | "import torch.optim\n", 44 | "from torch.autograd import Variable\n", 45 | "from models import *\n", 46 | "\n", 47 | "GPU = True\n", 48 | "if GPU == True:\n", 49 | " torch.backends.cudnn.enabled = True\n", 50 | " torch.backends.cudnn.benchmark = True\n", 51 | " dtype = torch.cuda.FloatTensor\n", 52 | " os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", 53 | " print(\"num GPUs\",torch.cuda.device_count())\n", 54 | "else:\n", 55 | " dtype = torch.FloatTensor\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "## Load image" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": { 69 | "scrolled": true 70 | }, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "torch.Size([1, 1, 256, 256])\n", 77 | "1.0\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "path = './test_data/'\n", 83 | "#img_name = \"astronaut\"\n", 84 | "#img_name = \"mri\"\n", 85 | "img_name = \"phantom256\"\n", 86 | "\n", 
87 | "img_path = path + img_name + \".png\"\n", 88 | "img_pil = Image.open(img_path)\n", 89 | "\n", 90 | "#img_pil = load_and_crop(img_path,target_width=256,target_height=256)\n", 91 | "\n", 92 | "img_np = pil_to_np(img_pil)\n", 93 | "img_np = img_np / np.max(img_np)\n", 94 | "img_clean_var = np_to_var(img_np).type(dtype)\n", 95 | "print(img_clean_var.shape)\n", 96 | "print(np.max(img_np))" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## Functions to generate noisy image and noise" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 3, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "Image size: (1, 256, 256)\n" 116 | ] 117 | }, 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "(tensor(85.7806, device='cuda:0'), tensor(85.7806, device='cuda:0'))" 122 | ] 123 | }, 124 | "execution_count": 3, 125 | "metadata": {}, 126 | "output_type": "execute_result" 127 | } 128 | ], 129 | "source": [ 130 | "def get_noisy_img(sig=30,noise_same = False):\n", 131 | " sigma = sig/255.\n", 132 | " if noise_same: # add the same noise in each channel\n", 133 | " noise = np.random.normal(scale=sigma, size=img_np.shape[1:])\n", 134 | " noise = np.array( [noise]*img_np.shape[0] )\n", 135 | " else: # add independent noise in each channel\n", 136 | " noise = np.random.normal(scale=sigma, size=img_np.shape)\n", 137 | "\n", 138 | " img_noisy_np = np.clip( img_np + noise , 0, 1).astype(np.float32)\n", 139 | " img_noisy_var = np_to_var(img_noisy_np).type(dtype)\n", 140 | " return img_noisy_np,img_noisy_var\n", 141 | "\n", 142 | "def get_noise(sig=30,noise_same = False,sh=None):\n", 143 | " sigma = sig/255.\n", 144 | " if noise_same: # add the same noise in each channel\n", 145 | " if sh is None:\n", 146 | " sh = img_np.shape[1:]\n", 147 | " noise = np.random.rand(sh[0],sh[1]) #np.random.normal(scale=sigma, size=img_np.shape[1:])\n", 148 | 
" noise = np.array( [noise]*img_np.shape[0] )\n", 149 | " else: # add independent noise in each channel\n", 150 | " if sh is None:\n", 151 | " sh = img_np.shape\n", 152 | " noise = np.random.rand(sh[0],sh[1],sh[2]) # np.random.normal(scale=sigma, size=img_np.shape)\n", 153 | "\n", 154 | " img_noisy_np = np.clip( noise , 0, 1).astype(np.float32)\n", 155 | " img_noisy_var = np_to_var(img_noisy_np).type(dtype)\n", 156 | " return img_noisy_np,img_noisy_var\n", 157 | "\n", 158 | "img_noisy_np,img_noisy_var = get_noisy_img() \n", 159 | "output_depth = img_np.shape[0] \n", 160 | "print(\"Image size: \", img_np.shape)\n", 161 | "\n", 162 | "img_np, img_var = get_noise(sig=30,noise_same = False,sh=None)\n", 163 | "# make sure the norm is the same\n", 164 | "img_var *= torch.norm(img_clean_var)/torch.norm(img_var) \n", 165 | "\n", 166 | "torch.norm(img_var), torch.norm(img_clean_var)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 4, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "numit = 5000" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 5, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "def tikz_hist(res,bins=100,filename=\"data.csv\"):\n", 185 | " hist = plt.hist(res,normed=True, bins=bins)\n", 186 | " plt.show()\n", 187 | " x = np.array([ (hist[1][i] + hist[1][i+1])/2 for i in range(bins) ])\n", 188 | " y = np.array(hist[0])\n", 189 | " np.savetxt(filename, np.vstack([ x , y ]).T , delimiter=\"\\t\")" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "## Jacobian" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 6, 202 | "metadata": {}, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "input shape: [1, 256, 16, 16]\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "num_channels = [256]*4\n", 214 | "\n", 215 | "def 
get_net_input(num_channels=[256]*4,upsample=True):\n", 216 | " if upsample:\n", 217 | " totalupsample = 2**len(num_channels)\n", 218 | " else:\n", 219 | " totalupsample = 1\n", 220 | " width = int(img_clean_var.data.shape[2]/totalupsample)\n", 221 | " height = int(img_clean_var.data.shape[3]/totalupsample)\n", 222 | " shape = [1,num_channels[0], width, height]\n", 223 | " print(\"input shape: \", shape)\n", 224 | " net_input = Variable(torch.zeros(shape)).type(dtype)\n", 225 | " net_input.data.uniform_()\n", 226 | " net_input.data *= 1./10\n", 227 | " return net_input\n", 228 | "\n", 229 | "net_input = get_net_input(num_channels)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 7, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "def get_jacobian(net, x, noutputs):\n", 239 | " x = x.squeeze()\n", 240 | " n = x.size()[0]\n", 241 | " x = x.repeat(noutputs, 1)\n", 242 | " x.requires_grad_(True)\n", 243 | " y = net(x)\n", 244 | " y.backward(torch.eye(noutputs))\n", 245 | " return x.grad.data" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 8, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "def grad_norm(net): \n", 255 | " # returns the norm of the gradient corresponding to the convolutional parameters\n", 256 | " \n", 257 | " # count number of convolutional layers\n", 258 | " nconvnets = 0\n", 259 | " for p in list(filter(lambda p: len(p.data.shape)>2, net.parameters())):\n", 260 | " nconvnets += 1\n", 261 | " \n", 262 | " out_grads = np.zeros(nconvnets)\n", 263 | " p = [x for x in net.parameters() ]\n", 264 | " for ind,p in enumerate(list(filter(lambda p: p.grad is not None and len(p.data.shape)>2, net.parameters()))):\n", 265 | " out_grads[ind] = p.grad.data.norm(2).item()\n", 266 | "\n", 267 | " return np.linalg.norm( out_grads ) \n", 268 | "\n", 269 | "def get_jacobinanprodsDD(img_var,numit=100):\n", 270 | " res = []\n", 271 | " for i in range(numit):\n", 272 | " net = 
decodernw(1,num_channels_up=num_channels).type(dtype)\n", 273 | " out = net(net_input).type(dtype)\n", 274 | " out.backward( img_var )\n", 275 | " res += [grad_norm(net)]\n", 276 | " return res" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 9, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "image/png": "\n", 287 | "text/plain": [ 288 | "
" 289 | ] 290 | }, 291 | "metadata": { 292 | "needs_background": "light" 293 | }, 294 | "output_type": "display_data" 295 | } 296 | ], 297 | "source": [ 298 | "imgjprods = get_jacobinanprodsDD(img_clean_var,numit=numit)\n", 299 | "\n", 300 | "tikz_hist(imgjprods,bins=50,filename=\"JacobianCleanImg.csv\")" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 10, 306 | "metadata": {}, 307 | "outputs": [ 308 | { 309 | "data": { 310 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEolJREFUeJzt3X+sX3V9x/Hny1bqj22gcDXa4lpDdStm/ljTubksm6iUQej+gHiJ28jWpP9AdHOJa2eCStJEsmXoIroQYSJzFsZ0u9FORNGYJQq9TIcU7LwCkzvcqCuyuQWw9b0/zqfy5cv39p57e/vj3vt8JDc953M+n3PPORy+r3vO55zPN1WFJEnPOtEbIEk6ORgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUrDzRGzAXZ5xxRq1du/ZEb4YkLRp33XXX96tqrE/dRRUIa9euZXJy8kRvhiQtGkn+rW9dbxlJkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSgEX2pvJytnb7Z2dc9uD7zz+OWyJpqfIKQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSp6RUISTYn2ZdkKsn2EctXJbmpLb8jydqBZTta+b4k5w6UP5jkm0m+kWRyIXZGkjR/s452mmQFcA3wZmAa2JNkoqruHai2FXi0qs5KMg5cBbw1yQZgHDgbeCnwhSSvqKpDrd1vVNX3F3B/JEnz1OcKYRMwVVX3V9WTwC5gy1CdLcANbfoW4JwkaeW7quqJqnoAmGrrkySdZPoEwmrgoYH56VY2sk5VHQQeA06fpW0Bn09yV5Jtc990SdJC6vMFORlRVj3rHKntG6rq4SQvAm5L8q2q+sozfnkXFtsAXvayl/XY3OVnpi/P8YtzJM1FnyuEaeDMgfk1wMMz1UmyEjgVOHCktlV1+N9HgE8zw62kqrq2qjZW1caxsbEemytJmo8+gbAHWJ9kXZJT6DqJJ4bqTACXtumLgNurqlr5eHsKaR2wHrgzyfOT/DRAkucDbwHuOfrdkSTN16y3jKrqYJLLgVuBFcD1VbU3yZXAZFVNANcBNyaZorsyGG9t9ya5GbgXOAhcVlWHkrwY+HTX78xK4G+q6nPHYP8WnSN9d/JCrctbSZJG6dOHQFXtBnYPlV0xMP04cPEMbXcCO4fK7gdePdeNlSQdO76pLEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVLT6ys0tbT4XcuSRvEKQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkplcgJNmcZF+SqSTbRyxfleSmtvyOJGsHlu1o5fuSnDvUbkWSryf5zNHuiCTp6MwaCElWANcA5wEbgEuSbBiqthV4tKr
OAq4GrmptNwDjwNnAZuDDbX2HvQO472h3QpJ09PoMbrcJmKqq+wGS7AK2APcO1NkCvLdN3wJ8KEla+a6qegJ4IMlUW99Xk6wBzgd2Au9cgH1ZVGYaYE6STpQ+t4xWAw8NzE+3spF1quog8Bhw+ixtPwC8C/jxnLdakrTg+gRCRpRVzzojy5NcADxSVXfN+suTbUkmk0zu379/9q2VJM1Ln0CYBs4cmF8DPDxTnSQrgVOBA0do+wbgwiQPAruANyb561G/vKquraqNVbVxbGysx+ZKkuajTyDsAdYnWZfkFLpO4omhOhPApW36IuD2qqpWPt6eQloHrAfurKodVbWmqta29d1eVb+9APsjSZqnWTuVq+pgksuBW4EVwPVVtTfJlcBkVU0A1wE3tk7jA3Qf8rR6N9N1QB8ELquqQ8doXyRJR6HXV2hW1W5g91DZFQPTjwMXz9B2J92TRDOt+8vAl/tshyTp2PFNZUkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSUDP70PQ8rB2+2dHlj/4/vOP85ZIOhG8QpAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJahzcTrNy0Dtpeeh1hZBkc5J9SaaSbB+xfFWSm9ryO5KsHVi2o5XvS3JuK3tOkjuT/EuSvUnet1A7JEman1kDIckK4BrgPGADcEmSDUPVtgKPVtVZwNXAVa3tBmAcOBvYDHy4re8J4I1V9WrgNcDmJK9fmF2SJM1HnyuETcBUVd1fVU8Cu4AtQ3W2ADe06VuAc5Kkle+qqieq6gFgCthUnR+2+s9uP3WU+yJJOgp9AmE18NDA/HQrG1mnqg4CjwGnH6ltkhVJvgE8AtxWVXfMZwckSQujTyBkRNnwX/Mz1ZmxbVUdqqrXAGuATUleNfKXJ9uSTCaZ3L9/f4/NlSTNR59AmAbOHJhfAzw8U50kK4FTgQN92lbVD4Av0/UxPENVXVtVG6tq49jYWI/NlSTNR59A2AOsT7IuySl0ncQTQ3UmgEvb9EXA7VVVrXy8PYW0DlgP3JlkLMlpAEmeC7wJ+NbR744kab5mfQ+hqg4muRy4FVgBXF9Ve5NcCUxW1QRwHXBjkim6K4Px1nZvkpuBe4GDwGVVdSjJS4Ab2hNHzwJurqrPHIsdlCT10+vFtKraDeweKrtiYPpx4OIZ2u4Edg6V3Q28dq4bK0k6dhy6QpIEGAiSpMZAkCQBBoIkqTEQJEmAw1/rKDgstrS0GAjH2EwfmpJ0svGWkSQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUOPz1AnGYa0mLnVcIkiTAQJAkNQaCJAmwD0HHgN+1LC1OXiFIkgADQZLUGAiSJKBnICTZnGRfkqkk20csX5Xkprb8jiRrB5btaOX7kpzbys5M8qUk9yXZm+QdC7VDkqT5mTUQkqwArgHOAzYAlyTZMFRtK/BoVZ0FXA1c1dpuAMaBs4HNwIfb+g4Cf1RVPw+8HrhsxDolScdRnyuETcBUVd1fVU8Cu4AtQ3W2ADe06VuAc5Kkle+qqieq6gFgCthUVd+rqn8GqKr/Ae4DVh/97kiS5qtPIKwGHhqYn+aZH94/qVNVB4HHgNP7tG23l14L3NF/syVJC61PIGREWfWsc8S2SX4K+DvgD6rqv0f+8mRbkskkk/v37++xuZKk+egTCNPAmQPza4CHZ6qTZCVwKnDgSG2TPJsuDD5RVZ+a6ZdX1bVVtbGqNo6NjfXYXEnSfPQJhD3A+iTrkpxC10k8MVRnAri0TV8E3F5V1crH21NI64D1wJ2tf+E64L6q+vOF2BFJ0tGZdeiKqjqY5HLgVmAFcH1V7U1yJTBZVRN0H+43JpmiuzIYb233JrkZuJfuyaLLqup
Qkl8Ffgf4ZpJvtF/1J1W1e6F3UJLUT6+xjNoH9e6hsisGph8HLp6h7U5g51DZPzG6f0GSdIL4prIkCTAQJEmNgSBJAgwESVJjIEiSAANBktT4FZo6bvxqTenk5hWCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAEc71UnAUVClk4NXCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1PR6MS3JZuCDwArgo1X1/qHlq4CPA78I/Bfw1qp6sC3bAWwFDgFvr6pbW/n1wAXAI1X1qgXZGy0pvrAmHV+zBkKSFcA1wJuBaWBPkomquneg2lbg0ao6K8k4cBXw1iQbgHHgbOClwBeSvKKqDgEfAz5EFySLxkwfUpK02PW5ZbQJmKqq+6vqSWAXsGWozhbghjZ9C3BOkrTyXVX1RFU9AEy19VFVXwEOLMA+SJIWQJ9AWA08NDA/3cpG1qmqg8BjwOk920qSTgJ9AiEjyqpnnT5tj/zLk21JJpNM7t+/fy5NJUlz0CcQpoEzB+bXAA/PVCfJSuBUuttBfdoeUVVdW1Ubq2rj2NjYXJpKkuagTyDsAdYnWZfkFLpO4omhOhPApW36IuD2qqpWPp5kVZJ1wHrgzoXZdEnSQpo1EFqfwOXArcB9wM1VtTfJlUkubNWuA05PMgW8E9je2u4FbgbuBT4HXNaeMCLJJ4GvAq9MMp1k68LumiRpLnq9h1BVu4HdQ2VXDEw/Dlw8Q9udwM4R5ZfMaUslSceUbypLkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVLT68U06WRypO+k8MtzpPnzCkGSBBgIkqTGQJAkAfYhaImZqX/BvgVpdl4hSJIAA0GS1BgIkiTAQJAkNXYqz+BILz9J0lLkFYIkCfAKQRrJx1e1HBkIWhb8gJdm5y0jSRJgIEiSGm8ZaVnzaTLpKQaCNAf2RWgp85aRJAkwECRJjbeMpAUw174IbzHpZLTsA8FORZ0I9kXoZNTrllGSzUn2JZlKsn3E8lVJbmrL70iydmDZjla+L8m5fdcpSTq+Zr1CSLICuAZ4MzAN7EkyUVX3DlTbCjxaVWclGQeuAt6aZAMwDpwNvBT4QpJXtDazrXNBeSWgpcgrDS2kPreMNgFTVXU/QJJdwBZg8MN7C/DeNn0L8KEkaeW7quoJ4IEkU2199FintOws1B8uR1qPYaGZ9AmE1cBDA/PTwC/NVKeqDiZ5DDi9lX9tqO3qNj3bOiUdAyfb1fKJDKhjfYU11/Wf6Cu+PoGQEWXVs85M5aP6LobX2a042QZsa7M/TLJvhu08kjOA78+j3VLjceh4HDonxXHIVSd6C4ChY3Gst2mu6z/K7fnZvhX7BMI0cObA/Brg4RnqTCdZCZwKHJil7WzrBKCqrgWu7bGdM0oyWVUbj2YdS4HHoeNx6HgcnuKx6PR5ymgPsD7JuiSn0HUSTwzVmQAubdMXAbdXVbXy8fYU0jpgPXBnz3VKko6jWa8QWp/A5cCtwArg+qram+RKYLKqJoDrgBtbp/EBug94Wr2b6TqLDwKXVdUhgFHrXPjdkyT1le4P+aUtybZ262lZ8zh0PA4dj8NTPBadZREIkqTZObidJAlYQoGQZEWSryf5TJtf14bR+HYbVuOUVj7jMBuLXZLTktyS5FtJ7kvyy0lemOS2dhxuS/KCVjdJ/qIdh7uTvO5Eb/9CSvKHSfYmuSfJJ5M8ZzmcE0muT/JIknsGyuZ8DiS5tNX/dpJLR/2uk9kMx+FP2/8bdyf5dJLTBpY5xA5LKBCAdwD3DcxfBVxdVeuBR+mG14CBYTaAq1u9peKDwOeq6ueAV9Mdj+3AF9tx+GKbBziP7qmv9XTveXzk+G/usZFkNfB2YGNVvYruwYXDQ6os9XPiY8DmobI5nQNJXgi8h+5l0U3Aew6HyCLyMZ55HG4DXlVVvwD8K7ADYGi
Inc3Ah9sfmIeH7TkP2ABc0uouWUsiEJKsAc4HPtrmA7yRbhgNgBuA32rTW9o8bfk5rf6iluRngF+je+KLqnqyqn7A0/d3+Dh8vDpfA05L8pLjvNnH0krgue29mOcB32MZnBNV9RW6J/0GzfUcOBe4raoOVNWjdB+kwx+uJ7VRx6GqPl9VB9vs1+jef4KBIXaq6gHg8BA7Pxm2p6qeBA4PsbNkLYlAAD4AvAv4cZs/HfjBwH/8wSEznjbMBnB4mI3F7uXAfuCv2q2zjyZ5PvDiqvoeQPv3Ra3+qCFJVrMEVNW/A38GfJcuCB4D7mL5nROHzfUcWLLnxoDfB/6xTS/n4/A0iz4QklwAPFJVdw0Wj6haPZYtZiuB1wEfqarXAv/LU7cGRlmqx4F2e2MLsI5ulN3n0132D1vq58Rs5jrkzJKQ5N1070V94nDRiGpL/jiMsugDAXgDcGGSB+ku6d5Id8VwWrtdAE8fGuMnw2nk6cNsLHbTwHRV3dHmb6ELiP88fCuo/fvIQP1ew4csQm8CHqiq/VX1I+BTwK+w/M6Jw+Z6DizZc6N1kF8AvK2eeuZ+2R2HmSz6QKiqHVW1pqrW0nUM3V5VbwO+RDeMBnTDavxDm55pmI1Frar+A3goyStb0Tl0b4gP7u/wcfjd9qTJ64HHDt9WWAK+C7w+yfNaX8DhY7GszokBcz0HbgXekuQF7WrrLa1sUUuyGfhj4MKq+r+BRQ6xc1hVLZkf4NeBz7Tpl9P9R50C/hZY1cqf0+an2vKXn+jtXsD9fw0wCdwN/D3wArp74V8Evt3+fWGrG7onKL4DfJPuiZwTvg8LeCzeB3wLuAe4EVi1HM4J4JN0/SY/ovsLd+t8zgG6e+xT7ef3TvR+LdBxmKLrE/hG+/nLgfrvbsdhH3DeQPlv0j2R9B3g3Sd6v471j28qS5KAJXDLSJK0MAwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSQD8P3iSsOdoQrb1AAAAAElFTkSuQmCC\n", 311 | "text/plain": [ 312 | "
" 313 | ] 314 | }, 315 | "metadata": { 316 | "needs_background": "light" 317 | }, 318 | "output_type": "display_data" 319 | } 320 | ], 321 | "source": [ 322 | "imgjprods = get_jacobinanprodsDD(img_var,numit=numit)\n", 323 | "\n", 324 | "tikz_hist(imgjprods,bins=50,filename=\"JacobianNoiseImg.csv\")" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": {}, 330 | "source": [ 331 | "## DIP" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 11, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "input shape: [1, 32, 256, 256]\n" 344 | ] 345 | } 346 | ], 347 | "source": [ 348 | "num_channels = [32]*5\n", 349 | "net_input = get_net_input(num_channels,False)\n", 350 | "\n", 351 | "def get_jacobinanprodsDIP(img_var,numit=numit):\n", 352 | " res = []\n", 353 | " for i in range(numit): \n", 354 | " net = get_net(32, 'skip', 'reflection',n_channels=output_depth,skip_n33d=128,\n", 355 | " skip_n33u=128,skip_n11=4,num_scales=5,upsample_mode='bilinear').type(dtype) \n", 356 | " out = net(net_input).type(dtype)\n", 357 | " out.backward( img_var )\n", 358 | " res += [grad_norm(net)]\n", 359 | " return res" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 12, 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "data": { 369 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZAAAAD8CAYAAABZ/vJZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFhpJREFUeJzt3X+MZeV93/H3J7tdkjguP5ZNBQvOrsMm6pJatjPFdiO5EXbKEqcsUom8WK1wS4SaQN3UrQrIbWRTInntyriV8Q9qqAiSMxDUlCmJS21DJLeNgcH414K3jIGaNW68mB9RGwWy5Ns/7oO5O753586zMztzl/dLGu25z33O95xn78x85pzn3HNTVUiStFw/stY7IEmaTgaIJKmLASJJ6mKASJK6GCCSpC4GiCSpiwEiSepigEiSuhggkqQuG9d6B1bTqaeeWtu2bVvr3ZCkqfLAAw88VVVblup3XAfItm3bmJ+fX+vdkKSpkuR/T9LPU1iSpC4GiCSpiwEiSepigEiSuhggkqQuBogkqYsBIknqYoBIkroYIJKkLsf1O9GlpWy76g9Gtj/+wXcc4z2Rpo9HIJKkLgaIJKmLASJJ6mKASJK6GCCSpC4GiCSpi5fxaip5+a209jwCkSR1MUAkSV0MEElSFwNEktTFAJEkdTFAJEldDBBJUhcDRJLUxQCRJHUxQCRJXQwQSVKXiQIkya4k+5MsJLlqxPMnJLm1PX9vkm1Dz13d2vcnOW+pmkm2txqPtJqbWvu7kxxM8pX29WtHM3BJ0tFZMkCSbACuB84HdgIXJ9m5qNulwDNVdRZwHbC3rbsT2AOcDewCPp5kwxI19wLXVdUO4JlW+yW3VtXr29enu0YsSVoRkxyBnAMsVNWjVfUCMAvsXtRnN3BzW74deFuStPbZqnq+qh4DFlq9kTXbOue2GrSaF/YPT5K0WiYJkK3AE0OPD7S2kX2q6hDwHLD5COuOa98MPNtqjNrW30vytSS3Jzlzgn2XJK2SSQIkI9pqwj4r1Q7wX4BtVfU64PO8fMRz+I4klyWZTzJ/8ODBUV0kSStgkg+UOgAM/7V/BvDkmD4HkmwETgSeXmLdUe1PAScl2diOQn7Qv6q+P9T/P9DmWRarqhuAGwBmZmYWB53WKT8gSpo+kwTI/cCOJNuB7zCYFH/Xoj5zwCXAHwMXAXdXVSWZAz6T5CPA6cAO4D4GRxo/VLOtc0+rMdtq3gGQ5LSq+m7b3gXAw51j1ivQuICS1G/JAKmqQ0muAO4CNgA3VdW+JNcA81U1B9wI3JJkgcGRx5627r4ktwEPAYeAy6vqRYBRNdsmrwRmk1wLPNhqA7wnyQWtztPAu4969JKkbhN9JnpV/SHwh4vafmto+c+BXx2z7m8Dvz1Jzdb+KIOrtBa3Xw1cPcn+SpJWn+9ElyR1megIRFory527cK5DOnY8ApEkdTFAJEldDBBJUhcDRJLUxQCRJHUxQCRJXQwQSVIXA0SS1MUAkSR1MUAkSV0MEElSF++FJS2DH3wlvcwjEElSFwNEktTFAJEkdTFAJEldDBBJUhcDRJLUxQCRJHUxQCRJXQwQSVIXA0SS1MUAkSR18V5YOqbG3UtqvZmW/ZTWkkcgkqQuBogkqctEAZJkV5L9SRaSXDXi+ROS3NqevzfJtqHnrm7t+5Oct1TNJNtbjUdazU2LtnVRkkoy0zNgSdLKWDJAkmwArgfOB3YCFyfZuajbpcAzVXUWcB2wt627E9gDnA3sAj6eZMMSNfcC11XVDuCZVvulfXk18B7g3r7hSpJWyiRHIOcAC1X1aFW9AMwCuxf12Q3c3JZvB96WJK19tqqer6rHgIVWb2TNts65rQat5oVD2/k3wIeAP1/mOCVJK2ySq7C2Ak8MPT4AvGlcn6o6lOQ5YHNr/9Kidbe25VE1NwPPVtWhxf2TvAE4s6ruTPIvJthvrZFX4hVMflKhXokmOQLJiLaasM+KtCf
5EQanxv75EfZzsCPJZUnmk8wfPHhwqe6SpE6TBMgB4Myhx2cAT47rk2QjcCLw9BHWHdf+FHBSqzHc/mrg54A/SvI48GZgbtREelXdUFUzVTWzZcuWCYYnSeoxSYDcD+xoV0dtYjApPreozxxwSVu+CLi7qqq172lXaW0HdgD3javZ1rmn1aDVvKOqnquqU6tqW1VtY3Ba7IKqmu8ctyTpKC05B9LmNK4A7gI2ADdV1b4k1wDzVTUH3AjckmSBwZHHnrbuviS3AQ8Bh4DLq+pFgFE12yavBGaTXAs82GpLktaZDP7oPz7NzMzU/LwHKcfaK3ESfRwn0TWNkjxQVUu+1853okuSuhggkqQuBogkqYsBIknqYoBIkroYIJKkLgaIJKmLASJJ6mKASJK6GCCSpC4GiCSpiwEiSepigEiSuhggkqQuBogkqYsBIknqYoBIkroYIJKkLgaIJKmLASJJ6rJxrXdAOp5tu+oPRrY//sF3HOM9kVaeRyCSpC4GiCSpi6ew1G3c6RlJrwwGiLQGjhS+zo9oWngKS5LUxQCRJHUxQCRJXSYKkCS7kuxPspDkqhHPn5Dk1vb8vUm2DT13dWvfn+S8pWom2d5qPNJqbmrt/zjJ15N8Jcl/T7LzaAYuSTo6SwZIkg3A9cD5wE7g4hG/vC8Fnqmqs4DrgL1t3Z3AHuBsYBfw8SQblqi5F7iuqnYAz7TaAJ+pqr9RVa8HPgR8pHPMkqQVMMkRyDnAQlU9WlUvALPA7kV9dgM3t+XbgbclSWufrarnq+oxYKHVG1mzrXNuq0GreSFAVf3p0PZeBdTyhipJWkmTXMa7FXhi6PEB4E3j+lTVoSTPAZtb+5cWrbu1LY+quRl4tqoOjehPksuB9wKbGATND0lyGXAZwGte85oJhidJ6jHJEUhGtC3+639cn5VqHyxUXV9VPw1cCfyrUTtbVTdU1UxVzWzZsmVUF0nSCpgkQA4AZw49PgN4clyfJBuBE4Gnj7DuuPangJNajXHbgsEprwsn2HdJ0iqZJEDuB3a0q6M2MZgUn1vUZw64pC1fBNxdVdXa97SrtLYDO4D7xtVs69zTatBq3gGQZMfQ9t4BPLK8oUqSVtKScyBtTuMK4C5gA3BTVe1Lcg0wX1VzwI3ALUkWGBx57Gnr7ktyG/AQcAi4vKpeBBhVs23ySmA2ybXAg602wBVJ3g78BYOrs14KLEnSGsjgj/7j08zMTM3Pz6/1bhy3vJni6vBeWFprSR6oqpml+vlOdElSFwNEktTFAJEkdfHzQLQk5zokjeIRiCSpiwEiSepigEiSuhggkqQuBogkqYsBIknqYoBIkroYIJKkLgaIJKmL70SX1plx7/z3Lr1abzwCkSR1MUAkSV0MEElSFwNEktTFAJEkdTFAJEldDBBJUhffB6If8JMHJS2HRyCSpC4GiCSpiwEiSepigEiSuhggkqQuBogkqctEAZJkV5L9SRaSXDXi+ROS3NqevzfJtqHnrm7t+5Oct1TNJNtbjUdazU2t/b1JHkrytSRfSPJTRzNwSdLRWfJ9IEk2ANcDvwQcAO5PMldVDw11uxR4pqrOSrIH2Au8M8lOYA9wNnA68PkkP9PWGVdzL3BdVc0m+WSr/QngQWCmqv4sya8DHwLeebT/AdK08HNCtN5McgRyDrBQVY9W1QvALLB7UZ/dwM1t+XbgbUnS2mer6vmqegxYaPVG1mzrnNtq0GpeCFBV91TVn7X2LwFnLH+4kqSVMkmAbAWeGHp8oLWN7FNVh4DngM1HWHdc+2bg2VZj3LZgcFTy2VE7m+SyJPNJ5g8ePLjk4CRJfSYJkIxoqwn7rFT7yxtK/j4wA3x4RF+q6oaqmqmqmS1btozqIklaAZPcC+sAcObQ4zOAJ8f0OZBkI3Ai8PQS645qfwo4KcnGdhRy2LaSvB14H/C3q+r5CfZdkrRKJjkCuR/Y0a6O2sRgUnxuUZ854JK2fBFwd1VVa9/TrtLaDuwA7htXs61
zT6tBq3kHQJI3AJ8CLqiq7/UNV5K0UpY8AqmqQ0muAO4CNgA3VdW+JNcA81U1B9wI3JJkgcGRx5627r4ktwEPAYeAy6vqRYBRNdsmrwRmk1zL4MqrG1v7h4GfAH5vMNfOt6vqgqP+H5Akdcngj/7j08zMTM3Pz6/1bkwNb+c+nbyMVystyQNVNbNUP9+JLknqYoBIkroYIJKkLgaIJKmLASJJ6mKASJK6TPJOdEnrmHfp1VrxCESS1MUAkSR1MUAkSV0MEElSFyfRpeOUk+tabR6BSJK6GCCSpC4GiCSpiwEiSepigEiSuhggkqQuBogkqYsBIknqYoBIkroYIJKkLgaIJKmL98KSXmG8R5ZWikcgkqQuBogkqYsBIknqMtEcSJJdwL8DNgCfrqoPLnr+BOB3gJ8Hvg+8s6oeb89dDVwKvAi8p6ruOlLNJNuBWeAU4MvAP6iqF5K8Ffgo8DpgT1XdfhTjfkUbdw5ckpZjySOQJBuA64HzgZ3AxUl2Lup2KfBMVZ0FXAfsbevuBPYAZwO7gI8n2bBEzb3AdVW1A3im1Qb4NvBu4DN9Q5UkraRJjkDOARaq6lGAJLPAbuChoT67gfe35duBjyVJa5+tqueBx5IstHqMqpnkYeBc4F2tz82t7ieGjmj+cvnDlLQUr87Sck0yB7IVeGLo8YHWNrJPVR0CngM2H2Hdce2bgWdbjXHbkiStA5MESEa01YR9Vqp9YkkuSzKfZP7gwYPLWVWStAyTBMgB4Myhx2cAT47rk2QjcCLw9BHWHdf+FHBSqzFuW0dUVTdU1UxVzWzZsmU5q0qSlmGSOZD7gR3t6qjvMJgUf9eiPnPAJcAfAxcBd1dVJZkDPpPkI8DpwA7gPgZHGj9Us61zT6sx22recZRjlHQUnBvROEsGSFUdSnIFcBeDS25vqqp9Sa4B5qtqDrgRuKVNkj/NIBBo/W5jMOF+CLi8ql4EGFWzbfJKYDbJtcCDrTZJ/ibw+8DJwN9N8oGqOntF/heOU16uK2k1pWpZUwxTZWZmpubn59d6N9aMAaLV5BHI8SvJA1U1s1Q/34kuSepigEiSuhggkqQuBogkqYsBIknqYoBIkroYIJKkLgaIJKmLASJJ6mKASJK6TPSRtlrfvGWJpLXgEYgkqYtHIJLW1JGOoL1h4/pmgEjqstzPCfFU6/HHU1iSpC4GiCSpi6ewJK0oT1W9chggU8QfTEnriaewJEldDBBJUhcDRJLUxQCRJHVxEn0dcrJc0jTwCESS1MUAkSR18RSWpHVruffb0rFlgEiaOgbL+uApLElSl4kCJMmuJPuTLCS5asTzJyS5tT1/b5JtQ89d3dr3JzlvqZpJtrcaj7Sam5bahiTp2FvyFFaSDcD1wC8BB4D7k8xV1UND3S4Fnqmqs5LsAfYC70yyE9gDnA2cDnw+yc+0dcbV3AtcV1WzST7Zan9i3DaO9j/gWPCyXOnYWO7Pmqe8js4kcyDnAAtV9ShAkllgNzAcILuB97fl24GPJUlrn62q54HHkiy0eoyqmeRh4FzgXa3Pza3uJ8Zto6pqOQOWpJcYOEdnkgDZCjwx9PgA8KZxfarqUJLngM2t/UuL1t3alkfV3Aw8W1WHRvQft42nJhjDsnnUIGmx1f69sJKf5ngswm6SAMmItsV/9Y/rM6591NzLkfpPuh8kuQy4rD38v0n2j1hvKaeySsG0BhzL+nO8jAMcy4rK3hUrdWr2HtVYfmqSTpMEyAHgzKHHZwBPjulzIMlG4ETg6SXWHdX+FHBSko3tKGS4/7htHKaqbgBumGBcYyWZr6qZo6mxXjiW9ed4GQc4lvXqWI1lkquw7gd2tKujNjGYFJ9b1GcOuKQtXwTc3eYm5oA97Qqq7cAO4L5xNds697QatJp3LLENSdIaWPIIpM03XAHcBWwAbqqqfUmuAearag64EbilTZI/zSAQaP1uYzDhfgi4vKpeBBhVs23ySmA2ybXAg60247YhSVob8Y/4H5b
ksnYqbOo5lvXneBkHOJb16liNxQCRJHXxViaSpC7HbYAk+dEk9yX5apJ9ST7Q2pd9q5Tl3o5lFce0IcmDSe6c5rEkeTzJ15N8Jcl8azslyefaWD6X5OTWniT/vu3X15K8cajOJa3/I0kuGWr/+VZ/oa076hLwlRrLSUluT/LNJA8necu0jSXJz7bX4qWvP03ym9M2jqFt/bP2M/+NJL+bwe+Caf1Z+adtHPuS/GZrWz+vS1Udl18M3jfyE235rwD3Am8GbgP2tPZPAr/eln8D+GRb3gPc2pZ3Al8FTgC2A99iMPG/oS2/FtjU+uxc5TG9F/gMcGd7PJVjAR4HTl3U9iHgqrZ8FbC3Lf8y8Nn2er4ZuLe1nwI82v49uS2f3J67D3hLW+ezwPmrOJabgV9ry5uAk6Z1LG17G4D/w+B9AFM3DgZvOH4M+LGhn5F3T+PPCvBzwDeAH2dwwdPnGVzJum5el1X7RlxPX+0F+DKDd7s/BWxs7W8B7mrLdwFvacsbW78AVwNXD9W6q633g3Vb+2H9VmEMZwBfYHCrlzvbvk3rWB7nhwNkP3BaWz4N2N+WPwVcvLgfcDHwqaH2T7W204BvDrUf1m+Fx/FXGfyyyrSPZWgbfwf4H9M6Dl6+Y8Up7Xv/TuC8afxZAX4V+PTQ438N/Mv19Loct6ew4AenfL4CfA/4HIO/HCa6VQowfDuWxbdd2XqE9tXyUQbfPH/ZHk982xfW31gK+G9JHsjgzgEAf62qvtv2+bvAT7b25e7z1ra8uH01vBY4CPzHDE4tfjrJq5jOsbxkD/C7bXnqxlFV3wH+LfBt4LsMvvcfYDp/Vr4BvDXJ5iQ/zuAI40zW0etyXAdIVb1YVa9n8Nf7OcBfH9Wt/bvc27FMdGuVlZDkV4DvVdUDw81H2P66HUvzC1X1RuB84PIkbz1C3/U8lo3AG4FPVNUbgP/H4JTCOOt5LLR5gQuA31uq64i2dTGONh+wm8Fpp9OBVzH4Phu3/XU7lqp6mMFdxz8H/FcGp8sOHWGVYz6W4zpAXlJVzwJ/xOC84EkZ3AoFRt8qhUx2O5ZJbvGyUn4BuCDJ48Asg9NYH2U6x0JVPdn+/R7w+wzC/U+SnNb2+TQGR42HjWXCfT7Qlhe3r4YDwIGqurc9vp1BoEzjWGDwi/bLVfUn7fE0juPtwGNVdbCq/gL4T8DfYnp/Vm6sqjdW1Vvbfj3CenpdVuPc3Xr4ArYAJ7XlHwO+CPwKg7+uhifTfqMtX87hk2m3teWzOXwy7VEGE2kb2/J2Xp5MO/sYjOsXeXkSferGwuAvwlcPLf9PYBfwYQ6fGPxQW34Hh08M3tfaT2Ew/3By+3oMOKU9d3/r+9LE4C+v4uvxReBn2/L72zimdSyzwD8cejx142Awz7mPwbxnGFzk8E+m8Wel7cdPtn9fA3yz/b+um9dlVQa9Hr6A1zG4FcrXGJxL/K3W/loGVx4stG+qE1r7j7bHC+351w7Veh+D+ZP9DF2lwOCc5P9qz73vGI3rF3k5QKZuLG2fv9q+9r20LQbnnb/A4C+sLwx9g4fBh499C/g6MDNU6x+1MS5w+C++mfaafwv4GIsmuVd4PK8H5tv32X9uP6BTNxYGv3C/D5w41DZ142jb+gCDX7bfAG5hEAJT97PStvVFBreC+irwtvX2uvhOdElSl1fEHIgkaeUZIJKkLgaIJKmLASJJ6mKASJK6GCCSpC4GiCSpiwEiSery/wE3HYzTeTpHRQAAAABJRU5ErkJggg==\n", 370 | "text/plain": [ 371 | "
" 372 | ] 373 | }, 374 | "metadata": { 375 | "needs_background": "light" 376 | }, 377 | "output_type": "display_data" 378 | } 379 | ], 380 | "source": [ 381 | "imgjprods = get_jacobinanprodsDIP(img_clean_var,numit=numit)\n", 382 | "\n", 383 | "tikz_hist(imgjprods,bins=50,filename=\"JacobianCleanImgDIP.csv\")" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 13, 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "image/png": "\n", 394 | "text/plain": [ 395 | "
" 396 | ] 397 | }, 398 | "metadata": { 399 | "needs_background": "light" 400 | }, 401 | "output_type": "display_data" 402 | } 403 | ], 404 | "source": [ 405 | "imgjprods = get_jacobinanprodsDIP(img_var,numit=numit)\n", 406 | "\n", 407 | "tikz_hist(imgjprods,bins=50,filename=\"JacobianNoiseImgDIP.csv\")" 408 | ] 409 | } 410 | ], 411 | "metadata": { 412 | "kernelspec": { 413 | "display_name": "Python 3", 414 | "language": "python", 415 | "name": "python3" 416 | }, 417 | "language_info": { 418 | "codemirror_mode": { 419 | "name": "ipython", 420 | "version": 3 421 | }, 422 | "file_extension": ".py", 423 | "mimetype": "text/x-python", 424 | "name": "python", 425 | "nbconvert_exporter": "python", 426 | "pygments_lexer": "ipython3", 427 | "version": "3.6.7" 428 | } 429 | }, 430 | "nbformat": 4, 431 | "nbformat_minor": 2 432 | } 433 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Denoising and Regularization via Exploiting the Structural Bias of Convolutional Generators 2 | 3 | This repository provides code for reproducing the figures in the paper: 4 | 5 | **``Denoising and Regularization via Exploiting the Structural Bias of Convolutional Generators''**, by Reinhard Heckel and Mahdi Soltanolkotabi. Contact: [reinhard.heckel@gmail.com](reinhard.heckel@gmail.com) 6 | 7 | The paper is available online [here](http://www.reinhardheckel.com/papers/overparameterized_convolutional_generators.pdf). 
8 | 9 | ## Organization 10 | 11 | - Figure 1: denoising_MSE_curves.ipynb 12 | - Figure 2: denoising_performance_example.ipynb, denoising_bm3d_example.ipynb 13 | - Figures 4 and 8: noise_vs_img_fitting_different_architectures.ipynb 14 | - Figure 5: linear_least_squares_selective_fitting_warmup.ipynb 15 | - Figure 6: kernels_and_associated_dual_kernels.ipynb 16 | - Figure 7: Jacobian_multi_layer_deep_decoder.ipynb 17 | - Figure 10: image_fitted_faster_than_noise_on_imgnet.ipynb 18 | - Figure 12: Jacobian_inner_product_noisevsimg.ipynb 19 | - Table 1: denoising_imagenet_selected100_paper.ipynb, denoising_bm3d_imagenet_selected100_paper.ipynb 20 | 21 | ## Installation 22 | 23 | The code is written in Python and relies on PyTorch. The following libraries are required: 24 | - python 3 25 | - pytorch 26 | - numpy 27 | - scikit-image (imported as `skimage`) 28 | - matplotlib 29 | - pillow, and pybm3d for the BM3D comparison notebooks 30 | - jupyter 31 | 32 | The libraries can be installed with conda (or pip); for example, Jupyter is installed via: 33 | ``` 34 | conda install jupyter 35 | ``` 36 | 37 | A small part of the code compares performance to the deep image prior. This part requires downloading the models folder from [https://github.com/DmitryUlyanov/deep-image-prior](https://github.com/DmitryUlyanov/deep-image-prior) and placing it next to the notebooks, so that `from models import *` succeeds. 38 | 39 | 40 | ## Citation 41 | ``` 42 | @article{heckel_denoising_2019, 43 | author = {Reinhard Heckel and Mahdi Soltanolkotabi}, 44 | title = {Denoising and Regularization via Exploiting the Structural Bias of Convolutional Generators}, 45 | journal = {arXiv:1910.14634 [cs.LG]}, 46 | year = {2019} 47 | } 48 | ``` 49 | 50 | ## License 51 | 52 | All files are provided under the terms of the Apache License, Version 2.0.
53 | -------------------------------------------------------------------------------- /denoising_bm3d_imagenet_selected100_paper.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Denoising with BM3D\n", 8 | "\n", 9 | "The code below demonstrates the denoising performance on an example image." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "#from __future__ import print_function\n", 19 | "#import matplotlib.pyplot as plt\n", 20 | "#%matplotlib notebook\n", 21 | "\n", 22 | "import os\n", 23 | "\n", 24 | "from os import *\n", 25 | "from os.path import *\n", 26 | "\n", 27 | "import warnings\n", 28 | "warnings.filterwarnings('ignore')\n", 29 | "\n", 30 | "#from include import *\n", 31 | "from PIL import Image\n", 32 | "import PIL\n", 33 | "\n", 34 | "import numpy as np\n", 35 | "import pybm3d" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Load images" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "got 100 images\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "def pil_to_np(img_PIL):\n", 60 | " '''Converts image in PIL format to np.array.\n", 61 | " \n", 62 | " From W x H x C [0...255] to C x W x H [0..1]\n", 63 | " '''\n", 64 | " ar = np.array(img_PIL)\n", 65 | "\n", 66 | " if len(ar.shape) == 3:\n", 67 | " ar = ar.transpose(2,0,1)\n", 68 | " else:\n", 69 | " ar = ar[None, ...]\n", 70 | "\n", 71 | " return ar.astype(np.float32) / 255.\n", 72 | "\n", 73 | "def load_and_crop(imgname,target_width=512,target_height=512):\n", 74 | " '''\n", 75 | " imgname: string of image location\n", 76 | " load an image, and center-crop if the image is large enough, 
else return none\n", 77 | " '''\n", 78 | " img = Image.open(imgname)\n", 79 | " width, height = img.size\n", 80 | " if width <= target_width or height <= target_height:\n", 81 | " return None\t\n", 82 | " \n", 83 | " left = (width - target_width)/2\n", 84 | " top = (height - target_height)/2\n", 85 | " right = (width + target_width)/2\n", 86 | " bottom = (height + target_height)/2\n", 87 | " \n", 88 | " return img.crop((left, top, right, bottom))\n", 89 | "\n", 90 | "def get_imgnet_imgs(path = './imgs/'):\n", 91 | " siz = 512\n", 92 | " imgs = []\n", 93 | " imgnames = [f for f in listdir(path) if isfile(join(path, f))] \n", 94 | " for imgname in imgnames:\n", 95 | " # prepare and select image\n", 96 | " imgname = path + imgname\n", 97 | "\n", 98 | " img = load_and_crop(imgname,target_width=512,target_height=512)\n", 99 | " if img is None: # then the image could not be croped to 512x512\n", 100 | " continue\n", 101 | " \n", 102 | " img_np = pil_to_np(img)\n", 103 | "\n", 104 | " if img_np.shape[0] != 3: # we only want to consider color images\n", 105 | " continue\n", 106 | " imgs += [img_np]\n", 107 | " print(\"got \", len(imgs), \" images\")\n", 108 | " return imgs\n", 109 | "\n", 110 | "imgs = get_imgnet_imgs()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 1, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "## to greyscale\n", 120 | "def rgb2gray(rgb):\n", 121 | " return 0.2989*rgb[0] + 0.5870*rgb[1] + 0.1140*rgb[2]\n", 122 | "\n", 123 | "#gimg = np.array([rgb2gray(imgs[0])])\n", 124 | "#gimg = gimg.transpose(1,2,0)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 11, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "def get_noisy_img(img_np,sig=30,noise_same = False):\n", 134 | " sigma = sig/255.\n", 135 | " if noise_same: # add the same noise in each channel\n", 136 | " noise1 = np.random.normal(scale=sigma, size=img_np.shape[:2])\n", 137 | " noise = 
np.zeros(img_np.shape)\n", 138 | " noise[:,:,0] = noise1\n", 139 | " noise[:,:,1] = noise1\n", 140 | " noise[:,:,2] = noise1\n", 141 | " else: # add independent noise in each channel\n", 142 | " noise = np.random.normal(scale=sigma, size=img_np.shape)\n", 143 | "\n", 144 | " img_noisy_np = np.clip( img_np + noise , 0, 1).astype(np.float32)\n", 145 | " #img_noisy_var = np_to_var(img_noisy_np).type(dtype)\n", 146 | " return img_noisy_np #,img_noisy_var" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 6, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "def psnr(x_hat,x_true,maxv=1.):\n", 156 | " x_hat = x_hat.flatten()\n", 157 | " x_true = x_true.flatten()\n", 158 | " mse=np.mean(np.square(x_hat-x_true))\n", 159 | " psnr_ = 10.*np.log(maxv**2/mse)/np.log(10.)\n", 160 | " return psnr_" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Denoise noisy image" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 12, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "sig = 25.0\n", 177 | "noise_same = True" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 13, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "(512, 512, 3) 1.0 0.0\n", 190 | "Noise PSNR: 20.212181200072777\n", 191 | "Recov PSNR: 23.63727431083024\n", 192 | "(512, 512, 3) 0.972549 0.0\n", 193 | "Noise PSNR: 20.504885441401665\n", 194 | "Recov PSNR: 23.940315558364137\n", 195 | "(512, 512, 3) 1.0 0.0\n", 196 | "Noise PSNR: 20.301638870226878\n", 197 | "Recov PSNR: 23.696178870942163\n", 198 | "(512, 512, 3) 1.0 0.0\n", 199 | "Noise PSNR: 20.768846524713283\n", 200 | "Recov PSNR: 28.750837924533748\n", 201 | "(512, 512, 3) 1.0 0.0\n", 202 | "Noise PSNR: 20.594251079013212\n", 203 | "Recov PSNR: 25.36551000419403\n", 204 | "(512, 512, 3) 1.0 0.0\n", 205 | "Noise PSNR: 
20.464093268981927\n", 206 | "Recov PSNR: 23.997677276501193\n", 207 | "(512, 512, 3) 1.0 0.0\n", 208 | "Noise PSNR: 20.57955428958745\n", 209 | "Recov PSNR: 24.469921477217806\n", 210 | "(512, 512, 3) 1.0 0.0\n", 211 | "Noise PSNR: 20.54248374191298\n", 212 | "Recov PSNR: 26.705066584818994\n", 213 | "(512, 512, 3) 1.0 0.0\n", 214 | "Noise PSNR: 20.59146513458877\n", 215 | "Recov PSNR: 24.70622667488254\n", 216 | "(512, 512, 3) 0.99215686 0.0\n", 217 | "Noise PSNR: 20.480707008180083\n", 218 | "Recov PSNR: 25.4927768415609\n", 219 | "(512, 512, 3) 1.0 0.0\n", 220 | "Noise PSNR: 20.56513601853171\n", 221 | "Recov PSNR: 27.63120525150335\n", 222 | "(512, 512, 3) 1.0 0.0\n", 223 | "Noise PSNR: 20.85872056313074\n", 224 | "Recov PSNR: 27.272297774312243\n", 225 | "(512, 512, 3) 1.0 0.0\n", 226 | "Noise PSNR: 20.765993136637388\n", 227 | "Recov PSNR: 27.338659874856795\n", 228 | "(512, 512, 3) 1.0 0.0\n", 229 | "Noise PSNR: 21.306242258803405\n", 230 | "Recov PSNR: 25.446193654517028\n", 231 | "(512, 512, 3) 1.0 0.0\n", 232 | "Noise PSNR: 20.54850148581473\n", 233 | "Recov PSNR: 27.01530978394026\n", 234 | "(512, 512, 3) 1.0 0.0\n", 235 | "Noise PSNR: 20.398360202163094\n", 236 | "Recov PSNR: 26.389982539269145\n", 237 | "(512, 512, 3) 1.0 0.05882353\n", 238 | "Noise PSNR: 20.181425964893922\n", 239 | "Recov PSNR: 25.42297513517164\n", 240 | "(512, 512, 3) 1.0 0.0\n", 241 | "Noise PSNR: 20.236747114864038\n", 242 | "Recov PSNR: 26.119894603915196\n", 243 | "(512, 512, 3) 1.0 0.0\n", 244 | "Noise PSNR: 20.549344599038925\n", 245 | "Recov PSNR: 24.745598223981585\n", 246 | "(512, 512, 3) 1.0 0.0\n", 247 | "Noise PSNR: 20.406044023408143\n", 248 | "Recov PSNR: 23.02123809966171\n", 249 | "(512, 512, 3) 0.88235295 0.0\n", 250 | "Noise PSNR: 21.07100899874531\n", 251 | "Recov PSNR: 27.966267242110685\n", 252 | "(512, 512, 3) 1.0 0.0\n", 253 | "Noise PSNR: 20.266211966933977\n", 254 | "Recov PSNR: 26.026880683704405\n", 255 | "(512, 512, 3) 1.0 0.0\n", 256 | "Noise PSNR: 
20.28534942506627\n", 257 | "Recov PSNR: 24.028501285463516\n", 258 | "(512, 512, 3) 1.0 0.0\n", 259 | "Noise PSNR: 20.37162074821613\n", 260 | "Recov PSNR: 23.20894817892436\n", 261 | "(512, 512, 3) 1.0 0.0\n", 262 | "Noise PSNR: 20.819537047306984\n", 263 | "Recov PSNR: 24.955001645791075\n", 264 | "(512, 512, 3) 1.0 0.011764706\n", 265 | "Noise PSNR: 20.36794988194004\n", 266 | "Recov PSNR: 25.01484370885203\n", 267 | "(512, 512, 3) 1.0 0.0\n", 268 | "Noise PSNR: 20.859252924913807\n", 269 | "Recov PSNR: 25.41711012187948\n", 270 | "(512, 512, 3) 1.0 0.07058824\n", 271 | "Noise PSNR: 20.391379042510795\n", 272 | "Recov PSNR: 25.847262037706383\n", 273 | "(512, 512, 3) 1.0 0.0\n", 274 | "Noise PSNR: 20.33835437136609\n", 275 | "Recov PSNR: 23.321868702283826\n", 276 | "(512, 512, 3) 1.0 0.0\n", 277 | "Noise PSNR: 20.43070507108709\n", 278 | "Recov PSNR: 24.873822391942387\n", 279 | "(512, 512, 3) 1.0 0.0\n", 280 | "Noise PSNR: 20.481937911063778\n", 281 | "Recov PSNR: 23.8057189559377\n", 282 | "(512, 512, 3) 1.0 0.0\n", 283 | "Noise PSNR: 21.362199887623458\n", 284 | "Recov PSNR: 28.620436302966986\n", 285 | "(512, 512, 3) 1.0 0.0\n", 286 | "Noise PSNR: 21.195583777328547\n", 287 | "Recov PSNR: 28.785344603958727\n", 288 | "(512, 512, 3) 1.0 0.0\n", 289 | "Noise PSNR: 20.362385421825444\n", 290 | "Recov PSNR: 24.592646298984757\n", 291 | "(512, 512, 3) 1.0 0.0\n", 292 | "Noise PSNR: 20.260951503499502\n", 293 | "Recov PSNR: 26.27547029013389\n", 294 | "(512, 512, 3) 1.0 0.0\n", 295 | "Noise PSNR: 21.437002162190304\n", 296 | "Recov PSNR: 25.784020736052643\n", 297 | "(512, 512, 3) 1.0 0.0\n", 298 | "Noise PSNR: 20.448009174943806\n", 299 | "Recov PSNR: 28.338362719443378\n", 300 | "(512, 512, 3) 1.0 0.0\n", 301 | "Noise PSNR: 20.29782132626301\n", 302 | "Recov PSNR: 23.440980909381945\n", 303 | "(512, 512, 3) 1.0 0.0\n", 304 | "Noise PSNR: 20.394585615905946\n", 305 | "Recov PSNR: 25.44805373031625\n", 306 | "(512, 512, 3) 1.0 0.0\n", 307 | "Noise PSNR: 
20.39641311987079\n", 308 | "Recov PSNR: 24.20596256704499\n", 309 | "(512, 512, 3) 1.0 0.078431375\n", 310 | "Noise PSNR: 20.88329087810746\n", 311 | "Recov PSNR: 29.331612323118552\n", 312 | "(512, 512, 3) 1.0 0.0\n", 313 | "Noise PSNR: 20.701236714965223\n", 314 | "Recov PSNR: 24.088279746783332\n", 315 | "(512, 512, 3) 1.0 0.0\n", 316 | "Noise PSNR: 20.351789215955613\n", 317 | "Recov PSNR: 24.565497444521036\n", 318 | "(512, 512, 3) 1.0 0.0\n", 319 | "Noise PSNR: 20.365233633265184\n", 320 | "Recov PSNR: 25.345188304122615\n", 321 | "(512, 512, 3) 1.0 0.0\n", 322 | "Noise PSNR: 20.620843698072637\n", 323 | "Recov PSNR: 25.397858996489905\n", 324 | "(512, 512, 3) 1.0 0.0\n", 325 | "Noise PSNR: 20.772502391644636\n", 326 | "Recov PSNR: 26.214549685021296\n", 327 | "(512, 512, 3) 1.0 0.0\n", 328 | "Noise PSNR: 20.89921246786762\n", 329 | "Recov PSNR: 27.732481393578652\n", 330 | "(512, 512, 3) 1.0 0.0\n", 331 | "Noise PSNR: 20.503894657113488\n", 332 | "Recov PSNR: 25.033410148687754\n", 333 | "(512, 512, 3) 1.0 0.0\n", 334 | "Noise PSNR: 20.746347428212257\n", 335 | "Recov PSNR: 27.42911953759399\n", 336 | "(512, 512, 3) 1.0 0.0\n", 337 | "Noise PSNR: 20.60218659863753\n", 338 | "Recov PSNR: 27.84133661305435\n", 339 | "(512, 512, 3) 1.0 0.050980393\n", 340 | "Noise PSNR: 20.21157092052578\n", 341 | "Recov PSNR: 25.681543134704302\n", 342 | "(512, 512, 3) 1.0 0.007843138\n", 343 | "Noise PSNR: 20.371621188818207\n", 344 | "Recov PSNR: 24.442689973454986\n", 345 | "(512, 512, 3) 1.0 0.0\n", 346 | "Noise PSNR: 20.88919016991864\n", 347 | "Recov PSNR: 26.802588862825644\n", 348 | "(512, 512, 3) 1.0 0.0\n", 349 | "Noise PSNR: 21.505021422178842\n", 350 | "Recov PSNR: 26.751418943139534\n", 351 | "(512, 512, 3) 0.8784314 0.078431375\n", 352 | "Noise PSNR: 20.201139366298737\n", 353 | "Recov PSNR: 25.979236079533912\n", 354 | "(512, 512, 3) 1.0 0.0\n", 355 | "Noise PSNR: 20.229492836552055\n", 356 | "Recov PSNR: 26.556030467778225\n", 357 | "(512, 512, 3) 1.0 0.0\n", 
358 | "Noise PSNR: 22.663572833508585\n", 359 | "Recov PSNR: 26.00644769518003\n", 360 | "(512, 512, 3) 1.0 0.0\n", 361 | "Noise PSNR: 20.55470772886831\n", 362 | "Recov PSNR: 26.25484391889147\n", 363 | "(512, 512, 3) 1.0 0.0\n", 364 | "Noise PSNR: 20.791949108181264\n", 365 | "Recov PSNR: 26.210043053906254\n", 366 | "(512, 512, 3) 1.0 0.0\n", 367 | "Noise PSNR: 20.489994383272126\n", 368 | "Recov PSNR: 24.281132877707673\n", 369 | "(512, 512, 3) 1.0 0.0\n", 370 | "Noise PSNR: 20.381062090222596\n", 371 | "Recov PSNR: 24.325705124929762\n", 372 | "(512, 512, 3) 1.0 0.0\n", 373 | "Noise PSNR: 20.438834326532188\n", 374 | "Recov PSNR: 24.384656944078717\n", 375 | "(512, 512, 3) 1.0 0.003921569\n", 376 | "Noise PSNR: 20.360608654940275\n", 377 | "Recov PSNR: 25.71720661048374\n", 378 | "(512, 512, 3) 1.0 0.0\n", 379 | "Noise PSNR: 20.28177360376858\n", 380 | "Recov PSNR: 25.5587857585238\n", 381 | "(512, 512, 3) 1.0 0.0\n", 382 | "Noise PSNR: 20.732083177945896\n", 383 | "Recov PSNR: 24.246224953118094\n", 384 | "(512, 512, 3) 1.0 0.0\n", 385 | "Noise PSNR: 20.547831517988826\n", 386 | "Recov PSNR: 27.15494658403775\n", 387 | "(512, 512, 3) 0.85882354 0.0\n", 388 | "Noise PSNR: 20.253141334835455\n", 389 | "Recov PSNR: 24.61338703868632\n", 390 | "(512, 512, 3) 1.0 0.0\n", 391 | "Noise PSNR: 20.381955459886118\n", 392 | "Recov PSNR: 25.90159621176429\n", 393 | "(512, 512, 3) 1.0 0.0\n", 394 | "Noise PSNR: 20.601218891409445\n", 395 | "Recov PSNR: 26.883687277248278\n", 396 | "(512, 512, 3) 1.0 0.0\n", 397 | "Noise PSNR: 20.496299796191835\n", 398 | "Recov PSNR: 24.530463374261938\n", 399 | "(512, 512, 3) 1.0 0.0\n", 400 | "Noise PSNR: 20.669437860597373\n", 401 | "Recov PSNR: 22.3111797760357\n", 402 | "(512, 512, 3) 1.0 0.0\n", 403 | "Noise PSNR: 20.52946515058389\n", 404 | "Recov PSNR: 24.8521365367373\n", 405 | "(512, 512, 3) 1.0 0.0\n", 406 | "Noise PSNR: 20.418582850846928\n", 407 | "Recov PSNR: 25.872974047164654\n", 408 | "(512, 512, 3) 1.0 0.0\n", 409 | 
"Noise PSNR: 20.545457635330962\n", 410 | "Recov PSNR: 23.28932483116338\n", 411 | "(512, 512, 3) 0.93333334 0.0\n", 412 | "Noise PSNR: 20.400763235295166\n", 413 | "Recov PSNR: 27.486010211875254\n", 414 | "(512, 512, 3) 1.0 0.0\n", 415 | "Noise PSNR: 20.595904760121005\n", 416 | "Recov PSNR: 25.171128316539246\n", 417 | "(512, 512, 3) 1.0 0.0\n", 418 | "Noise PSNR: 20.649539190001782\n", 419 | "Recov PSNR: 24.285987221176573\n", 420 | "(512, 512, 3) 1.0 0.0\n", 421 | "Noise PSNR: 20.66221849343832\n", 422 | "Recov PSNR: 25.907687691067967\n", 423 | "(512, 512, 3) 1.0 0.0\n", 424 | "Noise PSNR: 20.44171253833379\n", 425 | "Recov PSNR: 24.372830839819954\n", 426 | "(512, 512, 3) 1.0 0.0\n", 427 | "Noise PSNR: 20.665001186594473\n", 428 | "Recov PSNR: 25.678229193479474\n", 429 | "(512, 512, 3) 1.0 0.0\n", 430 | "Noise PSNR: 20.194153633915025\n", 431 | "Recov PSNR: 25.16241731506566\n", 432 | "(512, 512, 3) 1.0 0.0\n", 433 | "Noise PSNR: 20.241951382784187\n", 434 | "Recov PSNR: 24.872739451710075\n", 435 | "(512, 512, 3) 1.0 0.0\n", 436 | "Noise PSNR: 21.15028872426424\n", 437 | "Recov PSNR: 26.21962158905324\n", 438 | "(512, 512, 3) 1.0 0.0\n", 439 | "Noise PSNR: 20.946377291946877\n", 440 | "Recov PSNR: 24.956474403494262\n", 441 | "(512, 512, 3) 1.0 0.0\n", 442 | "Noise PSNR: 20.42800509504013\n", 443 | "Recov PSNR: 26.085611432453454\n", 444 | "(512, 512, 3) 1.0 0.0\n", 445 | "Noise PSNR: 20.579051433603397\n", 446 | "Recov PSNR: 26.471694868140176\n", 447 | "(512, 512, 3) 1.0 0.0\n", 448 | "Noise PSNR: 20.273315961656948\n", 449 | "Recov PSNR: 23.70575463385746\n", 450 | "(512, 512, 3) 1.0 0.007843138\n", 451 | "Noise PSNR: 21.33651026880801\n", 452 | "Recov PSNR: 26.93139751777845\n", 453 | "(512, 512, 3) 1.0 0.0\n", 454 | "Noise PSNR: 20.89228910571425\n", 455 | "Recov PSNR: 24.759780544364784\n", 456 | "(512, 512, 3) 1.0 0.0\n", 457 | "Noise PSNR: 20.824624513111274\n", 458 | "Recov PSNR: 27.26841616666023\n", 459 | "25.522569088039877\n", 460 | 
"25.522569088039877\n" 461 | ] 462 | } 463 | ], 464 | "source": [ 465 | "psnrs = []\n", 466 | "for img in imgs:\n", 467 | " # get noisy img\n", 468 | " \n", 469 | " # make grayscale\n", 470 | " #gimg = np.array([rgb2gray(imgs[0])])\n", 471 | " #img = gimg.transpose(1,2,0)\n", 472 | " img = img.transpose(1,2,0)\n", 473 | " \n", 474 | " img_noisy_np = get_noisy_img(img,sig=sig,noise_same=noise_same)\n", 475 | " output_depth = img.shape[0] \n", 476 | " print(img.shape, np.max(img), np.min(img))\n", 477 | " \n", 478 | " # denoise\n", 479 | " sigma = sig/255.\n", 480 | " out_img_np = pybm3d.bm3d.bm3d(img_noisy_np, sigma)\n", 481 | " #img = img.transpose(2,1,0)\n", 482 | " \n", 483 | " print(\"Noise PSNR: \",psnr(img,img_noisy_np) )\n", 484 | " print(\"Recov PSNR: \",psnr(img,out_img_np) )\n", 485 | " psnrv = psnr(img,out_img_np)\n", 486 | " psnrs.append(psnrv)\n", 487 | "print(np.mean(psnrs))\n", 488 | "print(np.mean(psnrs))" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [] 497 | } 498 | ], 499 | "metadata": { 500 | "kernelspec": { 501 | "display_name": "Python 3", 502 | "language": "python", 503 | "name": "python3" 504 | }, 505 | "language_info": { 506 | "codemirror_mode": { 507 | "name": "ipython", 508 | "version": 3 509 | }, 510 | "file_extension": ".py", 511 | "mimetype": "text/x-python", 512 | "name": "python", 513 | "nbconvert_exporter": "python", 514 | "pygments_lexer": "ipython3", 515 | "version": "3.6.7" 516 | } 517 | }, 518 | "nbformat": 4, 519 | "nbformat_minor": 2 520 | } 521 | -------------------------------------------------------------------------------- /include/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | from .wavelet import * 3 | from .decoder import * 4 | from .visualize import * 5 | from .fit import * 6 | from .helpers import * 7 | from .compression import * 
-------------------------------------------------------------------------------- /include/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/compression.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/compression.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/compression.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/compression.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/decoder.cpython-36.pyc -------------------------------------------------------------------------------- 
/include/__pycache__/decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/decoder.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/fit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/fit.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/fit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/fit.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/helpers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/helpers.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/visualize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/visualize.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/visualize.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/visualize.cpython-37.pyc -------------------------------------------------------------------------------- /include/__pycache__/wavelet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/wavelet.cpython-36.pyc -------------------------------------------------------------------------------- /include/__pycache__/wavelet.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/include/__pycache__/wavelet.cpython-37.pyc -------------------------------------------------------------------------------- /include/compression.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | import copy 5 | import numpy as np 6 | 7 | from .helpers import * 8 | from .decoder import * 9 | from .fit import * 10 | from .wavelet import * 11 | 12 | def rep_error_deep_decoder(img_np,k=128,convert2ycbcr=False): 13 | ''' 14 | mse obtained by representing img_np with the deep decoder 15 | ''' 16 | output_depth = img_np.shape[0] 17 | if output_depth == 3 and convert2ycbcr: 18 | img = rgb2ycbcr(img_np) 19 | else: 20 | img = img_np 21 | img_var = np_to_var(img).type(dtype) 22 | 23 | num_channels = [k]*5 24 | net = decodernwv2(output_depth,num_channels_up=num_channels,bn_before_act=True).type(dtype) 25 | rnd = 500 26 | numit = 15000 27 | rn = 0.005 28 | mse_n, mse_t, ni, net = fit( num_channels=num_channels, 29 | reg_noise_std=rn, 30 | reg_noise_decayevery = rnd, 31 | num_iter=numit, 32 | LR=0.004, 33 | img_noisy_var=img_var, 34 | net=net, 35 | img_clean_var=img_var, 36 | find_best=True, 37 | ) 38 | out_img = net(ni.type(dtype)).data.cpu().numpy()[0] 39 | if output_depth == 3 and convert2ycbcr: 40 | out_img = ycbcr2rgb(out_img) 41 | return psnr(out_img,img_np), out_img, num_param(net) 42 | 43 | def rep_error_wavelet(img_np,ncoeff=300): 44 | ''' 45 | mse obtained by representing img_np with wavelet thresholding 46 | ncoff coefficients are retained per color channel 47 | ''' 48 | if img_np.shape[0] == 1: 49 | img_np = img_np[0,:,:] 50 | out_img_np = denoise_wavelet(img_np, ncoeff=ncoeff, multichannel=False, convert2ycbcr=True, mode='hard') 51 | else: 52 | img_np = np.transpose(img_np) 53 | out_img_np = denoise_wavelet(img_np, 
ncoeff=ncoeff, multichannel=True, convert2ycbcr=True, mode='hard') 54 | # img_np = np.array([img_np[:,:,0],img_np[:,:,1],img_np[:,:,2]]) 55 | return psnr(out_img_np,img_np), out_img_np 56 | 57 | def myimgshow(plt,img): 58 | if(img.shape[0] == 1): 59 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='none') 60 | else: 61 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1),interpolation='none') 62 | 63 | -------------------------------------------------------------------------------- /include/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def add_module(self, module): 6 | self.add_module(str(len(self) + 1), module) 7 | 8 | torch.nn.Module.add = add_module 9 | 10 | 11 | def conv(in_f, out_f, kernel_size, stride=1, pad='zero',bias=False): 12 | padder = None 13 | to_pad = int((kernel_size - 1) / 2) 14 | if pad == 'reflection': 15 | padder = nn.ReflectionPad2d(to_pad) 16 | to_pad = 0 17 | 18 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 19 | 20 | layers = filter(lambda x: x is not None, [padder, convolver]) 21 | return nn.Sequential(*layers) 22 | 23 | def decodernw( 24 | num_output_channels=3, 25 | num_channels_up=[128]*5, 26 | filter_size_up=1, 27 | need_sigmoid=True, 28 | pad ='reflection', 29 | upsample_mode='bilinear', 30 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 31 | bn_before_act = False, 32 | bn_affine = True, 33 | bn = True, 34 | upsample_first = True, 35 | bias=False 36 | ): 37 | 38 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 39 | n_scales = len(num_channels_up) 40 | 41 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 42 | filter_size_up = [filter_size_up]*n_scales 43 | model = nn.Sequential() 44 | 45 | 46 | for i in range(len(num_channels_up)-1): 47 | 48 | if upsample_first: 49 | model.add(conv( num_channels_up[i], 
num_channels_up[i+1], filter_size_up[i], 1, pad=pad, bias=bias)) 50 | if upsample_mode!='none' and i != len(num_channels_up)-2: 51 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 52 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 53 | else: 54 | if upsample_mode!='none' and i!=0: 55 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 56 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 57 | model.add(conv( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad,bias=bias)) 58 | 59 | if i != len(num_channels_up)-1: 60 | if(bn_before_act and bn): 61 | model.add(nn.BatchNorm2d( num_channels_up[i+1] ,affine=bn_affine)) 62 | if act_fun is not None: 63 | model.add(act_fun) 64 | if( (not bn_before_act) and bn): 65 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 66 | 67 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad,bias=bias)) 68 | if need_sigmoid: 69 | model.add(nn.Sigmoid()) 70 | 71 | return model 72 | 73 | 74 | 75 | # Residual block 76 | class ResidualBlock(nn.Module): 77 | def __init__(self, in_f, out_f): 78 | super(ResidualBlock, self).__init__() 79 | self.conv = nn.Conv2d(in_f, out_f, 1, 1, padding=0, bias=False) 80 | 81 | def forward(self, x): 82 | residual = x 83 | out = self.conv(x) 84 | out += residual 85 | return out 86 | 87 | def resdecoder( 88 | num_output_channels=3, 89 | num_channels_up=[128]*5, 90 | filter_size_up=1, 91 | need_sigmoid=True, 92 | pad='reflection', 93 | upsample_mode='bilinear', 94 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 95 | bn_before_act = False, 96 | bn_affine = True, 97 | ): 98 | 99 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 100 | n_scales = len(num_channels_up) 101 | 102 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 103 | filter_size_up = [filter_size_up]*n_scales 104 | 105 | model = 
nn.Sequential() 106 | 107 | for i in range(len(num_channels_up)-2): 108 | 109 | model.add( ResidualBlock( num_channels_up[i], num_channels_up[i+1]) ) 110 | 111 | if upsample_mode!='none': 112 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 113 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 114 | 115 | if i != len(num_channels_up)-1: 116 | model.add(act_fun) 117 | #model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 118 | 119 | # new 120 | model.add(ResidualBlock( num_channels_up[-1], num_channels_up[-1])) 121 | #model.add(nn.BatchNorm2d( num_channels_up[-1] ,affine=bn_affine)) 122 | model.add(act_fun) 123 | # end new 124 | 125 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 126 | 127 | if need_sigmoid: 128 | model.add(nn.Sigmoid()) 129 | 130 | return model 131 | 132 | ########################## 133 | 134 | 135 | def np_to_tensor(img_np): 136 | '''Converts image in numpy.array to torch.Tensor. 137 | 138 | From C x W x H [0..1] to C x W x H [0..1] 139 | ''' 140 | return torch.from_numpy(img_np) 141 | 142 | def set_to(tensor,mtx): 143 | if not len(tensor.shape)==4: 144 | raise Exception("assumes a 4D tensor") 145 | num_kernels = tensor.shape[0] 146 | for i in range(tensor.shape[0]): 147 | for j in range(tensor.shape[1]): 148 | if i == j: 149 | tensor[i,j] = np_to_tensor(mtx) 150 | else: 151 | tensor[i,j] = np_to_tensor(np.zeros(mtx.shape)) 152 | return tensor 153 | 154 | def conv2(in_f, out_f, kernel_size, stride=1, pad='zero',bias=False): 155 | padder = None 156 | to_pad = int((kernel_size - 1) / 2) 157 | 158 | if kernel_size != 4: 159 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 160 | else: 161 | padder = nn.ReflectionPad2d( (1,0,1,0) ) 162 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=1, bias=bias) 163 | layers = filter(lambda x: x is not None, [padder, convolver]) 164 | return nn.Sequential(*layers) 165 | 166 | def 
fixed_decodernw( 167 | num_output_channels=3, 168 | num_channels_up=[128]*5, 169 | need_sigmoid=True, 170 | pad ='reflection', 171 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 172 | bn_affine = True, 173 | bn = True, 174 | mtx = np.array( [[1,3,3,1] , [3,9,9,3], [3,9,9,3], [1,3,3,1] ] )*1/16., 175 | output_padding = 0,padding=1, 176 | ): 177 | 178 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 179 | n_scales = len(num_channels_up) 180 | 181 | model = nn.Sequential() 182 | 183 | for i in range(len(num_channels_up)-2): 184 | 185 | # those will be fixed 186 | model.add(conv2( num_channels_up[i], num_channels_up[i], 4, 1, pad=pad)) 187 | # those will be learned 188 | model.add(conv( num_channels_up[i], num_channels_up[i+1], 1, 1, pad=pad)) 189 | 190 | if i != len(num_channels_up)-1: 191 | if act_fun is not None: 192 | model.add(act_fun) 193 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 194 | 195 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 196 | if need_sigmoid: 197 | model.add(nn.Sigmoid()) 198 | 199 | ### 200 | # this is a Gaussian kernel 201 | 202 | # set filters to fixed and then set the gradients to zero 203 | for m in model.modules(): 204 | if isinstance(m, nn.Conv2d): 205 | if(m.kernel_size == mtx.shape): 206 | m.weight.data = set_to(m.weight.data,mtx) 207 | for param in m.parameters(): 208 | param.requires_grad = False 209 | ### 210 | 211 | return model 212 | 213 | 214 | #### 215 | 216 | def deconv_decoder( 217 | num_output_channels=3, 218 | num_channels_up=[128]*5, 219 | filter_size=1, 220 | pad ='reflection', 221 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 222 | bn_affine = True, 223 | stride=2, 224 | padding=0, 225 | output_padding=0, 226 | final_conv=False, 227 | ): 228 | 229 | n_scales = len(num_channels_up) 230 | 231 | model = nn.Sequential() 232 | 233 | for i in range(len(num_channels_up)-1): 234 | 235 | model.add( 236 | 
nn.ConvTranspose2d(num_channels_up[i], num_channels_up[i+1], filter_size, stride=stride, padding=padding, output_padding=output_padding, groups=1, bias=False, dilation=1) 237 | ) 238 | #model.add(deconv(num_channels_up[i], num_channels_up[i+1], filter_size, stride,pad)) 239 | 240 | if i != len(num_channels_up)-1: 241 | model.add(act_fun) 242 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 243 | 244 | if final_conv: 245 | model.add(conv( num_channels_up[-1], num_channels_up[-1], 1, 1, pad=pad)) 246 | model.add(act_fun) 247 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 248 | 249 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 250 | model.add(nn.Sigmoid()) 251 | 252 | return model 253 | 254 | 255 | ##### 256 | 257 | 258 | def fixed_deconv_decoder( 259 | num_output_channels=3, 260 | num_channels_up=[128]*5, 261 | filter_size=1, 262 | pad ='reflection', 263 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 264 | bn_affine = True, 265 | mtx = np.array( [[1,4,7,4,1] , [4,16,26,16,4], [7,26,41,26,7], [4,16,26,16,4], [1,4,7,4,1]] ), 266 | output_padding=1, 267 | padding=2, 268 | ): 269 | 270 | num_channels_up = num_channels_up + [num_channels_up[-1]] 271 | n_scales = len(num_channels_up) 272 | 273 | model = nn.Sequential() 274 | 275 | for i in range(len(num_channels_up)-1): 276 | 277 | # those will be learned - conv 278 | model.add(conv( num_channels_up[i], num_channels_up[i+1], 1, 1, pad=pad)) 279 | 280 | # those will be fixed - upsample 281 | model.add( nn.ConvTranspose2d( 282 | num_channels_up[i], 283 | num_channels_up[i+1], 284 | kernel_size=4, 285 | stride=2, 286 | padding=padding, 287 | output_padding=output_padding, groups=1, bias=False, dilation=1) ) 288 | 289 | if i != len(num_channels_up)-1: 290 | model.add(act_fun) 291 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 292 | 293 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 294 | 
model.add(nn.Sigmoid()) 295 | 296 | ### 297 | # this is a Gaussian kernel 298 | # set filters to fixed and then set the gradients to zero 299 | for m in model.modules(): 300 | if isinstance(m, nn.ConvTranspose2d): 301 | if(m.kernel_size == mtx.shape): 302 | m.weight.data = set_to(m.weight.data,mtx) 303 | for param in m.parameters(): 304 | param.requires_grad = False 305 | ### 306 | 307 | return model 308 | 309 | 310 | -------------------------------------------------------------------------------- /include/denoise.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | import copy 5 | import numpy as np 6 | from scipy.linalg import hadamard 7 | 8 | from .helpers import * 9 | 10 | dtype = torch.cuda.FloatTensor 11 | #dtype = torch.FloatTensor 12 | 13 | 14 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7): 15 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 16 | lr = init_lr * (0.5**(epoch // lr_decay_epoch)) 17 | 18 | if epoch % lr_decay_epoch == 0: 19 | print('LR is set to {}'.format(lr)) 20 | 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = lr 23 | 24 | return optimizer 25 | 26 | 27 | def DIPdenoise( net, 28 | img_noisy_var, 29 | num_channels, 30 | img_clean_var, 31 | net_type="decoder", 32 | num_iter = 5000, 33 | LR = 0.01, 34 | OPTIMIZER='adam', 35 | opt_input = False, 36 | reg_noise_std = 0, 37 | reg_noise_decayevery = 100000, 38 | mask_var = None, 39 | apply_f = None, 40 | decaylr = False, 41 | net_input = None, 42 | net_input_gen = "random", 43 | ): 44 | 45 | if net_input is not None: 46 | print("input provided") 47 | else: 48 | # feed noise into the network 49 | if net_type == "decoder": 50 | totalupsample = 2**len(num_channels) 51 | width = int(img_clean_var.data.shape[2]/totalupsample) 52 | height = int(img_clean_var.data.shape[3]/totalupsample) 53 | shape = 
[1,num_channels[0], width, height] 54 | print("shape: ", shape) 55 | if(net_input_gen == "random"): 56 | net_input = Variable(torch.zeros(shape)) 57 | net_input.data.uniform_() 58 | net_input.data *= 1./10 59 | elif(net_input_gen == "hadamard"): 60 | H = hadamard(width*height) 61 | ni = np.zeros(shape) 62 | for i in range(shape[1]): 63 | ni[0,i] = np.reshape(H[i],(width,height)) 64 | net_input = np_to_var(ni[0]) 65 | net_input.data *= 1./20 66 | elif(net_input_gen == "rademacher"): 67 | ni = np.random.randint(2, size = shape) - 0.5 68 | net_input = np_to_var(ni[0]) 69 | net_input.data *= 1./10 70 | print(net_input.data.cpu().numpy()[0]) 71 | 72 | elif net_type == "hourglass": 73 | print("hourglass mode") 74 | input_depth = 32 75 | shape = [1, input_depth, img_clean_var.data.shape[2], img_clean_var.data.shape[3]] 76 | net_input = Variable(torch.zeros(shape)) 77 | net_input.data.uniform_() 78 | net_input.data *= 1./10 79 | elif net_type == "noup": 80 | print("no upsampling mode") 81 | shape = [1, num_channels[0], img_clean_var.data.shape[2], img_clean_var.data.shape[3]] 82 | net_input = Variable(torch.zeros(shape)) 83 | net_input.data.uniform_() 84 | net_input.data *= 1./10 85 | #print(net_input.data.cpu().numpy() ) 86 | net_input_saved = net_input.data.clone() 87 | noise = net_input.data.clone() 88 | p = [x for x in net.parameters() ] 89 | 90 | if(opt_input == True): 91 | net_input.requires_grad = True 92 | p += [net_input] 93 | 94 | mse_wrt_noisy = np.zeros(num_iter) 95 | mse_wrt_truth = np.zeros(num_iter) 96 | if OPTIMIZER == 'SGD': 97 | print("optimize with SGD", LR) 98 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9) 99 | elif OPTIMIZER == 'adam': 100 | print("optimize with adam", LR) 101 | optimizer = torch.optim.Adam(p, lr=LR) 102 | elif OPTIMIZER == 'adadelta': 103 | print("optimize with adadelta", LR) 104 | optimizer = torch.optim.Adadelta(p, lr=LR, rho=0.9, eps=1e-06, weight_decay=0) 105 | 106 | mse = torch.nn.MSELoss().type(dtype) 107 | if apply_f is 
None: 108 | noise_energy = mse(img_noisy_var, img_clean_var) 109 | else: 110 | noise_energy = mse(img_noisy_var, img_noisy_var) 111 | 112 | for i in range(num_iter): 113 | if decaylr is True: 114 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=100) 115 | if reg_noise_std > 0: 116 | if i % reg_noise_decayevery == 0: 117 | reg_noise_std *= 0.7 118 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 119 | optimizer.zero_grad() 120 | out = net(net_input.type(dtype)) 121 | # training loss 122 | if mask_var is not None: 123 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 124 | elif apply_f: 125 | loss = mse( apply_f(out) , img_noisy_var ) 126 | else: 127 | loss = mse(out, img_noisy_var) 128 | loss.backward() 129 | mse_wrt_noisy[i] = var_to_np(loss) 130 | # the actual loss 131 | true_loss = mse(Variable(out.data, requires_grad=False), img_clean_var) 132 | mse_wrt_truth[i] = var_to_np(true_loss) 133 | if i % 10 == 0: 134 | out2 = net(Variable(net_input_saved).type(dtype)) 135 | loss2 = mse(out2, img_clean_var) 136 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f Noise Energy %f' % (i, loss.data[0],true_loss.data[0],loss2.data[0],noise_energy.data[0]), '\r', end='') 137 | optimizer.step() 138 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved 139 | 140 | 141 | 142 | ''' 143 | 144 | def fit(net, 145 | img_noisy_var, 146 | num_channels, 147 | img_clean_var, 148 | num_iter = 5000, 149 | LR = 0.01, 150 | OPTIMIZER='adam', 151 | opt_input = False, 152 | reg_noise_std = 0, 153 | reg_noise_decayevery = 100000, 154 | mask_var = None, 155 | apply_f = None, 156 | decaylr = False, 157 | net_input = None, 158 | net_input_gen = "random", 159 | ): 160 | 161 | if net_input is not None: 162 | print("input provided") 163 | else: 164 | # feed uniform noise into the network 165 | totalupsample = 2**len(num_channels) 166 | width = int(img_clean_var.data.shape[2]/totalupsample) 167 | height = 
int(img_clean_var.data.shape[3]/totalupsample) 168 | shape = [1,num_channels[0], width, height] 169 | print("shape: ", shape) 170 | net_input = Variable(torch.zeros(shape)) 171 | net_input.data.uniform_() 172 | net_input.data *= 1./10 173 | 174 | net_input_saved = net_input.data.clone() 175 | noise = net_input.data.clone() 176 | p = [x for x in net.parameters() ] 177 | 178 | if(opt_input == True): 179 | net_input.requires_grad = True 180 | p += [net_input] 181 | 182 | mse_wrt_noisy = np.zeros(num_iter) 183 | mse_wrt_truth = np.zeros(num_iter) 184 | 185 | if OPTIMIZER == 'SGD': 186 | print("optimize with SGD", LR) 187 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9) 188 | elif OPTIMIZER == 'adam': 189 | print("optimize with adam", LR) 190 | optimizer = torch.optim.Adam(p, lr=LR) 191 | 192 | mse = torch.nn.MSELoss() #.type(dtype) 193 | noise_energy = mse(img_noisy_var, img_clean_var) 194 | 195 | best_net = copy.deepcopy(net) 196 | best_mse = 1000000.0 197 | 198 | for i in range(num_iter): 199 | if decaylr is True: 200 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=100) 201 | if reg_noise_std > 0: 202 | if i % reg_noise_decayevery == 0: 203 | reg_noise_std *= 0.7 204 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 205 | optimizer.zero_grad() 206 | out = net(net_input.type(dtype)) 207 | 208 | # training loss 209 | if mask_var is not None: 210 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 211 | elif apply_f: 212 | loss = mse( apply_f(out) , img_noisy_var ) 213 | else: 214 | loss = mse(out, img_noisy_var) 215 | loss.backward() 216 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 217 | 218 | # the actual loss 219 | true_loss = mse(Variable(out.data, requires_grad=False), img_clean_var) 220 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 221 | if i % 10 == 0: 222 | out2 = net(Variable(net_input_saved).type(dtype)) 223 | loss2 = mse(out2, img_clean_var) 224 | #print ('Iteration %05d Train loss %f Actual loss 
%f Actual loss orig %f Noise Energy %f' 225 | # % (i, loss.data.item(),true_loss.data.item(),loss2.data.item(),noise_energy.data.item()), '\r', end='') 226 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f Noise Energy %f' % (i, loss.data[0],true_loss.data[0],loss2.data[0],noise_energy.data[0]), '\r', end='') 227 | 228 | # if training loss improves by at least one percent, we found a new best net 229 | if best_mse > 1.005*loss.data[0]: 230 | best_mse = loss.data[0] 231 | best_net = copy.deepcopy(net) 232 | 233 | optimizer.step() 234 | 235 | net = best_net 236 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved 237 | 238 | ''' 239 | 240 | -------------------------------------------------------------------------------- /include/fit.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | import copy 5 | import numpy as np 6 | from scipy.linalg import hadamard 7 | 8 | from .helpers import * 9 | 10 | dtype = torch.cuda.FloatTensor 11 | #dtype = torch.FloatTensor 12 | 13 | 14 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=500): 15 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 16 | lr = init_lr * (0.65**(epoch // lr_decay_epoch)) 17 | 18 | if epoch % lr_decay_epoch == 0: 19 | print('LR is set to {}'.format(lr)) 20 | 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = lr 23 | 24 | return optimizer 25 | 26 | def sqnorm(a): 27 | return np.sum( a*a ) 28 | 29 | def get_distances(initial_maps,final_maps): 30 | results = [] 31 | for a,b in zip(initial_maps,final_maps): 32 | res = sqnorm(a-b)/(sqnorm(a) + sqnorm(b)) 33 | results += [res] 34 | return(results) 35 | 36 | def get_weights(net): 37 | weights = [] 38 | for m in net.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | weights += [m.weight.data.cpu().numpy()] 41 | return weights 42 | 43 | def fit(net, 44 | 
img_noisy_var, 45 | num_channels, 46 | img_clean_var, 47 | num_iter = 5000, 48 | LR = 0.01, 49 | OPTIMIZER='adam', 50 | opt_input = False, 51 | reg_noise_std = 0, 52 | reg_noise_decayevery = 100000, 53 | mask_var = None, 54 | apply_f = None, 55 | lr_decay_epoch = 0, 56 | net_input = None, 57 | net_input_gen = "random", 58 | find_best=False, 59 | weight_decay=0, 60 | upsample_mode = "bilinear", 61 | totalupsample = 1, 62 | loss_type="MSE", 63 | output_gradients=False, 64 | output_weights=False, 65 | show_images=False, 66 | plot_after=None, 67 | ): 68 | 69 | if net_input is not None: 70 | print("input provided") 71 | else: 72 | if upsample_mode=="bilinear": 73 | # feed uniform noise into the network 74 | totalupsample = 2**len(num_channels) 75 | elif upsample_mode=="deconv": 76 | # feed uniform noise into the network 77 | totalupsample = 2**(len(num_channels)-1) 78 | width = int(img_clean_var.data.shape[2]/totalupsample) 79 | height = int(img_clean_var.data.shape[3]/totalupsample) 80 | shape = [1,num_channels[0], width, height] 81 | print("input shape: ", shape) 82 | net_input = Variable(torch.zeros(shape)).type(dtype) 83 | net_input.data.uniform_() 84 | net_input.data *= 1./10 85 | 86 | net_input = net_input.type(dtype) 87 | net_input_saved = net_input.data.clone() 88 | noise = net_input.data.clone() 89 | p = [x for x in net.parameters() ] 90 | 91 | if(opt_input == True): # optimizer over the input as well 92 | net_input.requires_grad = True 93 | p += [net_input] 94 | 95 | mse_wrt_noisy = np.zeros(num_iter) 96 | mse_wrt_truth = np.zeros(num_iter) 97 | 98 | 99 | if OPTIMIZER == 'SGD': 100 | print("optimize with SGD", LR) 101 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9,weight_decay=weight_decay) 102 | elif OPTIMIZER == 'adam': 103 | print("optimize with adam", LR) 104 | optimizer = torch.optim.Adam(p, lr=LR,weight_decay=weight_decay) 105 | elif OPTIMIZER == 'LBFGS': 106 | print("optimize with LBFGS", LR) 107 | optimizer = torch.optim.LBFGS(p, lr=LR) 108 | 109 
| if loss_type=="MSE": 110 | mse = torch.nn.MSELoss() #.type(dtype) 111 | if loss_type=="L1": 112 | mse = nn.L1Loss() 113 | 114 | if find_best: 115 | best_net = copy.deepcopy(net) 116 | best_mse = 1000000.0 117 | 118 | nconvnets = 0 119 | for p in list(filter(lambda p: len(p.data.shape)>2, net.parameters())): 120 | nconvnets += 1 121 | 122 | out_grads = np.zeros((nconvnets,num_iter)) 123 | 124 | init_weights = get_weights(net) 125 | out_weights = np.zeros(( len(init_weights) ,num_iter)) 126 | 127 | out_imgs = np.zeros((1,1)) 128 | 129 | if plot_after is not None: 130 | out_img_np = net( net_input_saved.type(dtype) ).data.cpu().numpy()[0] 131 | out_imgs = np.zeros( (len(plot_after),) + out_img_np.shape ) 132 | 133 | for i in range(num_iter): 134 | 135 | if lr_decay_epoch is not 0: 136 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=lr_decay_epoch) 137 | if reg_noise_std > 0: 138 | if i % reg_noise_decayevery == 0: 139 | reg_noise_std *= 0.7 140 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 141 | 142 | def closure(): 143 | optimizer.zero_grad() 144 | out = net(net_input.type(dtype)) 145 | 146 | # training loss 147 | if mask_var is not None: 148 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 149 | elif apply_f: 150 | loss = mse( apply_f(out) , img_noisy_var ) 151 | else: 152 | loss = mse(out, img_noisy_var) 153 | 154 | loss.backward() 155 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 156 | 157 | 158 | # the actual loss 159 | true_loss = mse( Variable(out.data, requires_grad=False).type(dtype), img_clean_var.type(dtype) ) 160 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 161 | 162 | if output_gradients: 163 | for ind,p in enumerate(list(filter(lambda p: p.grad is not None and len(p.data.shape)>2, net.parameters()))): 164 | out_grads[ind,i] = p.grad.data.norm(2).item() 165 | #print(p.grad.data.norm(2).item()) 166 | #su += p.grad.data.norm(2).item() 167 | #mse_wrt_noisy[i] = su 168 | 169 | if i % 10 == 
0: 170 | out2 = net(Variable(net_input_saved).type(dtype)) 171 | loss2 = mse(out2, img_clean_var) 172 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f' % (i, loss.data,true_loss.data,loss2.data), '\r', end='') 173 | 174 | if show_images: 175 | if i % 50 == 0: 176 | print(i) 177 | out_img_np = net( ni.type(dtype) ).data.cpu().numpy()[0] 178 | myimgshow(plt,out_img_np) 179 | plt.show() 180 | 181 | if plot_after is not None: 182 | if i in plot_after: 183 | out_imgs[ plot_after.index(i) ,:] = net( net_input_saved.type(dtype) ).data.cpu().numpy()[0] 184 | 185 | if output_weights: 186 | out_weights[:,i] = np.array( get_distances( init_weights, get_weights(net) ) ) 187 | 188 | return loss 189 | 190 | loss = optimizer.step(closure) 191 | 192 | if find_best: 193 | # if training loss improves by at least one percent, we found a new best net 194 | if best_mse > 1.005*loss.data: 195 | best_mse = loss.data 196 | best_net = copy.deepcopy(net) 197 | 198 | 199 | if find_best: 200 | net = best_net 201 | if output_gradients and output_weights: 202 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_grads 203 | elif output_gradients: 204 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_grads 205 | elif output_weights: 206 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_weights 207 | elif plot_after is not None: 208 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net, out_imgs 209 | else: 210 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net 211 | 212 | 213 | 214 | 215 | 216 | 217 | ### weight regularization 218 | #if orth_reg > 0: 219 | # for name, param in net.named_parameters(): 220 | # consider all the conv weights, but the last one which only combines colors 221 | # if '.1.weight' in name and str( len(net)-1 ) not in name: 222 | # param_flat = param.view(param.shape[0], -1) 223 | # sym = torch.mm(param_flat, torch.t(param_flat)) 224 | # sym -= Variable(torch.eye(param_flat.shape[0])).type(dtype) 
225 | # loss = loss + (orth_reg * sym.sum().type(dtype) ) 226 | ### 227 | 228 | def fit_multiple(net, 229 | imgs, # list of images [ [1, color channels, W, H] ] 230 | num_channels, 231 | num_iter = 5000, 232 | LR = 0.01, 233 | find_best=False, 234 | upsample_mode="bilinear", 235 | ): 236 | # generate netinputs 237 | # feed uniform noise into the network 238 | nis = [] 239 | for i in range(len(imgs)): 240 | if upsample_mode=="bilinear": 241 | # feed uniform noise into the network 242 | totalupsample = 2**len(num_channels) 243 | elif upsample_mode=="deconv": 244 | # feed uniform noise into the network 245 | totalupsample = 2**(len(num_channels)-1) 246 | #totalupsample = 2**len(num_channels) 247 | width = int(imgs[0].data.shape[2]/totalupsample) 248 | height = int(imgs[0].data.shape[3]/totalupsample) 249 | shape = [1 ,num_channels[0], width, height] 250 | print("shape: ", shape) 251 | net_input = Variable(torch.zeros(shape)) 252 | net_input.data.uniform_() 253 | net_input.data *= 1./10 254 | nis.append(net_input) 255 | 256 | # learnable parameters are the weights 257 | p = [x for x in net.parameters() ] 258 | 259 | mse_wrt_noisy = np.zeros(num_iter) 260 | 261 | optimizer = torch.optim.Adam(p, lr=LR) 262 | 263 | mse = torch.nn.MSELoss() #.type(dtype) 264 | 265 | if find_best: 266 | best_net = copy.deepcopy(net) 267 | best_mse = 1000000.0 268 | 269 | for i in range(num_iter): 270 | 271 | def closure(): 272 | optimizer.zero_grad() 273 | 274 | #loss = np_to_var(np.array([0.0])) 275 | out = net(nis[0].type(dtype)) 276 | loss = mse(out, imgs[0].type(dtype)) 277 | #for img,ni in zip(imgs,nis): 278 | for j in range(1,len(imgs)): 279 | #out = net(ni.type(dtype)) 280 | #loss += mse(out, img.type(dtype)) 281 | out = net(nis[j].type(dtype)) 282 | loss += mse(out, imgs[j].type(dtype)) 283 | 284 | #out = net(nis[0].type(dtype)) 285 | #out2 = net(nis[1].type(dtype)) 286 | #loss = mse(out, imgs[0].type(dtype)) + mse(out2, imgs[1].type(dtype)) 287 | 288 | loss.backward() 289 | 
mse_wrt_noisy[i] = loss.data.cpu().numpy() 290 | 291 | if i % 10 == 0: 292 | print ('Iteration %05d Train loss %f' % (i, loss.data), '\r', end='') 293 | return loss 294 | 295 | loss = optimizer.step(closure) 296 | 297 | if find_best: 298 | # if training loss improves by at least one percent, we found a new best net 299 | if best_mse > 1.005*loss.data: 300 | best_mse = loss.data 301 | best_net = copy.deepcopy(net) 302 | 303 | if find_best: 304 | net = best_net 305 | return mse_wrt_noisy, nis, net 306 | 307 | -------------------------------------------------------------------------------- /include/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import PIL 9 | import numpy as np 10 | 11 | from torch.autograd import Variable 12 | 13 | import random 14 | import numpy as np 15 | import torch 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image 19 | import PIL 20 | 21 | from torch.autograd import Variable 22 | 23 | def myimgshow(plt,img): 24 | if(img.shape[0] == 1): 25 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='none') 26 | else: 27 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1),interpolation='none') 28 | 29 | def load_and_crop(imgname,target_width=512,target_height=512): 30 | ''' 31 | imgname: string of image location 32 | load an image, and center-crop if the image is large enough, else return none 33 | ''' 34 | img = Image.open(imgname) 35 | width, height = img.size 36 | if width <= target_width or height <= target_height: 37 | return None 38 | 39 | left = (width - target_width)/2 40 | top = (height - target_height)/2 41 | right = (width + target_width)/2 42 | bottom = (height + target_height)/2 43 | 44 | return img.crop((left, top, right, bottom)) 45 | 46 | def save_np_img(img,filename): 47 | if(img.shape[0] == 1): 48 | 
plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='nearest') 49 | else: 50 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1)) 51 | plt.axis('off') 52 | plt.savefig(filename, bbox_inches='tight') 53 | plt.close() 54 | 55 | def np_to_tensor(img_np): 56 | '''Converts image in numpy.array to torch.Tensor. 57 | 58 | From C x W x H [0..1] to C x W x H [0..1] 59 | ''' 60 | return torch.from_numpy(img_np) 61 | 62 | def np_to_var(img_np, dtype = torch.cuda.FloatTensor): 63 | '''Converts image in numpy.array to torch.Variable. 64 | 65 | From C x W x H [0..1] to 1 x C x W x H [0..1] 66 | ''' 67 | return Variable(np_to_tensor(img_np)[None, :]) 68 | 69 | def var_to_np(img_var): 70 | '''Converts an image in torch.Variable format to np.array. 71 | 72 | From 1 x C x W x H [0..1] to C x W x H [0..1] 73 | ''' 74 | return img_var.data.cpu().numpy()[0] 75 | 76 | 77 | def pil_to_np(img_PIL): 78 | '''Converts image in PIL format to np.array. 79 | 80 | From W x H x C [0...255] to C x W x H [0..1] 81 | ''' 82 | ar = np.array(img_PIL) 83 | 84 | if len(ar.shape) == 3: 85 | ar = ar.transpose(2,0,1) 86 | else: 87 | ar = ar[None, ...] 88 | 89 | return ar.astype(np.float32) / 255. 90 | 91 | 92 | def rgb2ycbcr(img): 93 | #out = color.rgb2ycbcr( img.transpose(1, 2, 0) ) 94 | #return out.transpose(2,0,1)/256. 
95 | r,g,b = img[0],img[1],img[2] 96 | y = 0.299*r+0.587*g+0.114*b 97 | cb = 0.5 - 0.168736*r - 0.331264*g + 0.5*b 98 | cr = 0.5 + 0.5*r - 0.418588*g - 0.081312*b 99 | return np.array([y,cb,cr]) 100 | 101 | def ycbcr2rgb(img): 102 | #out = color.ycbcr2rgb( 256.*img.transpose(1, 2, 0) ) 103 | #return (out.transpose(2,0,1) - np.min(out))/(np.max(out)-np.min(out)) 104 | y,cb,cr = img[0],img[1],img[2] 105 | r = y + 1.402*(cr-0.5) 106 | g = y - 0.344136*(cb-0.5) - 0.714136*(cr-0.5) 107 | b = y + 1.772*(cb - 0.5) 108 | return np.array([r,g,b]) 109 | 110 | 111 | 112 | def mse(x_hat,x_true,maxv=1.): 113 | x_hat = x_hat.flatten() 114 | x_true = x_true.flatten() 115 | mse = np.mean(np.square(x_hat-x_true)) 116 | energy = np.mean(np.square(x_true)) 117 | return mse/energy 118 | 119 | def psnr(x_hat,x_true,maxv=1.): 120 | x_hat = x_hat.flatten() 121 | x_true = x_true.flatten() 122 | mse=np.mean(np.square(x_hat-x_true)) 123 | psnr_ = 10.*np.log(maxv**2/mse)/np.log(10.) 124 | return psnr_ 125 | 126 | def num_param(net): 127 | s = sum([np.prod(list(p.size())) for p in net.parameters()]); 128 | return s 129 | #print('Number of params: %d' % s) 130 | 131 | def num_trainable_param(net): 132 | s = sum([np.prod(list(p.size())) for p in net.parameters() if p.requires_grad==True]); 133 | return s 134 | 135 | def rgb2gray(rgb): 136 | r, g, b = rgb[0,:,:], rgb[1,:,:], rgb[2,:,:] 137 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 138 | return np.array([gray]) 139 | 140 | def savemtx_for_logplot(A,filename = "exp.dat"): 141 | ind = sorted(list(set([int(i) for i in np.geomspace(1, len(A[0])-1 ,num=700)]))) 142 | A = [ [a[i] for i in ind] for a in A] 143 | X = np.array([ind] + A) 144 | np.savetxt(filename, X.T, delimiter=' ') 145 | 146 | 147 | def get_imgnet_imgs(num_samples = 100, path = '../imagenet/',verbose=False): 148 | perm = [i for i in range(1,50000)] 149 | random.Random(4).shuffle(perm) 150 | siz = 512 151 | file = open("exp_imgnet_imgs.txt","w") 152 | 153 | imgs = [] 154 | sampled = 
0 155 | imgslist = [] 156 | for imgnr in perm: 157 | # prepare and select image 158 | # Format is: ILSVRC2012_val_00024995.JPEG 159 | imgnr_str = str(imgnr).zfill(8) 160 | imgname = path + 'ILSVRC2012_val_' + imgnr_str + ".JPEG" 161 | img = load_and_crop(imgname,target_width=512,target_height=512) 162 | if img is None: # then the image could not be croped to 512x512 163 | continue 164 | 165 | img_np = pil_to_np(img) 166 | 167 | if img_np.shape[0] != 3: # we only want to consider color images 168 | continue 169 | if verbose: 170 | imgslist += ['ILSVRC2012_val_' + imgnr_str + ".JPEG"] 171 | print("cp ", imgname, "./imgs") 172 | imgs += [img_np] 173 | sampled += 1 174 | if sampled >= num_samples: 175 | break 176 | if verbose: 177 | print(imgslist) 178 | return imgs 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /include/onedim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim 5 | from torch.autograd import Variable 6 | import matplotlib.pyplot as plt 7 | import copy 8 | 9 | 10 | 11 | dtype = torch.FloatTensor # This code is meant for CPU 12 | 13 | def add_module(self, module): 14 | self.add_module(str(len(self) + 1), module) 15 | 16 | torch.nn.Module.add = add_module 17 | 18 | 19 | def conv1(in_f, out_f, kernel_size, stride=1, pad='zero'): 20 | padder = None 21 | to_pad = int((kernel_size - 1) / 2) 22 | if pad == 'reflection': 23 | padder = nn.ReflectionPad2d(to_pad) 24 | to_pad = 0 25 | 26 | convolver = nn.Conv1d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=False) 27 | 28 | layers = filter(lambda x: x is not None, [padder, convolver]) 29 | return nn.Sequential(*layers) 30 | 31 | 32 | 33 | # Define the upsampling matrices 34 | def get_upsample_matrix(k, identity=False, upsample_mode='linear'): 35 | # Returns a 2*k-1 x k numpy array corresponding to an upsampling matrix 36 | 37 | 
if identity: 38 | return np.eye(k) 39 | U = np.zeros((2*k-1, k)) 40 | for i in range(k): 41 | U[2*i, i] = 1 42 | 43 | if i < k-1: 44 | if upsample_mode=='linear': 45 | U[2*i+1, [i, (i+1) % k]] = [1./2, 1./2] 46 | elif upsample_mode=='convex0.7-0.3': 47 | U[2*i+1, [i, (i+1) % k]] = [0.7, 0.3] 48 | elif upsample_mode=='convex0.75-0.25': 49 | U[2*i+1, [i, (i+1) % k]] = [0.75, 0.25] 50 | return U 51 | 52 | 53 | 54 | class Upsample_Module(nn.Module): 55 | # Only works for batch size 1. Works for any number of channels 56 | 57 | def __init__(self, upsample_mode='linear'): 58 | super(Upsample_Module,self).__init__() 59 | self.upsample_mode=upsample_mode 60 | 61 | def forward(self, x): 62 | n = x.shape[2] 63 | U = Variable(torch.Tensor(get_upsample_matrix(n, upsample_mode=self.upsample_mode))) 64 | return torch.stack([torch.t(U.matmul(torch.t(x[0,...])))], 0) 65 | 66 | 67 | 68 | 69 | def decoder_1d( 70 | num_output_channels=1, 71 | num_channels_up=[128]*5, 72 | filter_size_up=1, 73 | need_sigmoid=False, 74 | pad='zero', 75 | upsample_mode='linear', 76 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 77 | need_bn=True, 78 | ): 79 | 80 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 81 | n_scales = len(num_channels_up) 82 | #print('n_scales = %d' %n_scales) 83 | 84 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 85 | filter_size_up = [filter_size_up]*n_scales 86 | 87 | model = nn.Sequential() 88 | 89 | for i in range(len(num_channels_up)-1): 90 | 91 | if upsample_mode!='none' and i!=0: 92 | if upsample_mode=='MatrixUpsample': 93 | model.add(Upsample_Module()) 94 | elif upsample_mode=='MatrixUpsampleConvex0.7-0.3': 95 | model.add(Upsample_Module(upsample_mode='convex0.7-0.3')) 96 | elif upsample_mode=='MatrixUpsampleConvex0.75-0.25': 97 | model.add(Upsample_Module(upsample_mode='convex0.75-0.25')) 98 | elif upsample_mode=='nnUpsampleDouble': 99 | model.add(nn.Upsample(scale_factor=2.0, mode='linear', 
align_corners=False)) 100 | elif upsample_mode=='nearest': 101 | model.add(nn.Upsample(scale_factor=2.0, mode='nearest')) 102 | 103 | model.add(conv1( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 104 | if i != len(num_channels_up)-1: 105 | if need_bn: 106 | model.add(nn.BatchNorm1d( num_channels_up[i+1] )) 107 | model.add(act_fun) 108 | 109 | model.add(conv1( num_channels_up[-1], num_output_channels, 1, pad=pad)) 110 | 111 | if need_sigmoid: 112 | model.add(nn.Sigmoid()) 113 | 114 | return model 115 | 116 | 117 | def fit_1d(net, 118 | img_noisy_var, 119 | num_channels, 120 | img_clean_var, 121 | net_input, # Passing in the net_input is required 122 | num_iter = 5000, 123 | LR = 0.01, 124 | OPTIMIZER='adam', 125 | opt_input = False, 126 | reg_noise_std = 0, 127 | reg_noise_decayevery = 100000, 128 | mask_var = None, 129 | apply_f = None, 130 | decaylr = False, 131 | net_input_gen = "random", 132 | plot_output_every = None, 133 | ): 134 | 135 | net_input_saved = net_input.data.clone() 136 | noise = net_input.data.clone() 137 | p = [x for x in net.parameters() ] 138 | 139 | if(opt_input == True): 140 | net_input.requires_grad = True 141 | p += [net_input] 142 | 143 | mse_wrt_noisy = np.zeros(num_iter) 144 | mse_wrt_truth = np.zeros(num_iter) 145 | 146 | if OPTIMIZER == 'SGD': 147 | print("optimize with SGD", LR) 148 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9) 149 | elif OPTIMIZER == 'adam': 150 | print("optimize with adam", LR) 151 | optimizer = torch.optim.Adam(p, lr=LR) 152 | 153 | mse = torch.nn.MSELoss() #.type(dtype) 154 | noise_energy = mse(img_noisy_var, img_clean_var) 155 | 156 | 157 | 158 | for i in range(num_iter): 159 | if decaylr is True: 160 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=100) 161 | if reg_noise_std > 0: 162 | if i % reg_noise_decayevery == 0: 163 | reg_noise_std *= 0.7 164 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 165 | optimizer.zero_grad() 
166 | out = net(net_input.type(dtype)) 167 | 168 | 169 | 170 | # training loss 171 | if mask_var is not None: 172 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 173 | elif apply_f: 174 | loss = mse( apply_f(out) , img_noisy_var ) 175 | else: 176 | loss = mse(out, img_noisy_var) 177 | loss.backward() 178 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 179 | if mse_wrt_noisy[i] == np.min(mse_wrt_noisy[:i+1]): 180 | best_net = copy.deepcopy(net) 181 | best_mse_wrt_noisy = mse_wrt_noisy[i] 182 | 183 | # the actual loss 184 | true_loss = mse(Variable(out.data, requires_grad=False), img_clean_var) 185 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 186 | if i % 10 == 0: 187 | out2 = net(Variable(net_input_saved).type(dtype)) 188 | loss2 = mse(out2, img_clean_var) 189 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f Noise Energy %f' 190 | % (i, loss.data.item(),true_loss.data.item(),loss2.data.item(),noise_energy.data.item()), '\r', end='') 191 | if plot_output_every and (i % plot_output_every==1): 192 | out3 = net(Variable(net_input_saved).type(dtype)) 193 | ax = plt.figure(figsize=(12,5)) 194 | plt.plot(out3[0,0,:].data.numpy(), '.b') 195 | plt.plot(img_clean_var[0,0,:].data.numpy(), '-r') 196 | plt.show() 197 | optimizer.step() 198 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, best_net, best_mse_wrt_noisy # Didn't implement case wehere there is noise in signal 199 | 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /include/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def to_tensor(data): 13 | """ 14 | Convert numpy array to PyTorch tensor. 
For complex arrays, the real and imaginary parts 15 | are stacked along the last dimension. 16 | 17 | Args: 18 | data (np.array): Input numpy array 19 | 20 | Returns: 21 | torch.Tensor: PyTorch version of data 22 | """ 23 | if np.iscomplexobj(data): 24 | data = np.stack((data.real, data.imag), axis=-1) 25 | return torch.from_numpy(data) 26 | 27 | 28 | def apply_mask(data, mask_func, seed=None): 29 | """ 30 | Subsample given k-space by multiplying with a mask. 31 | 32 | Args: 33 | data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where 34 | dimensions -3 and -2 are the spatial dimensions, and the final dimension has size 35 | 2 (for complex values). 36 | mask_func (callable): A function that takes a shape (tuple of ints) and a random 37 | number seed and returns a mask. 38 | seed (int or 1-d array_like, optional): Seed for the random number generator. 39 | 40 | Returns: 41 | (tuple): tuple containing: 42 | masked data (torch.Tensor): Subsampled k-space data 43 | mask (torch.Tensor): The generated mask 44 | """ 45 | shape = np.array(data.shape) 46 | shape[:-3] = 1 47 | mask = mask_func(shape, seed) 48 | return data * mask, mask 49 | 50 | 51 | def fft2(data): 52 | """ 53 | Apply centered 2 dimensional Fast Fourier Transform. 54 | 55 | Args: 56 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 57 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 58 | assumed to be batch dimensions. 59 | 60 | Returns: 61 | torch.Tensor: The FFT of the input. 62 | """ 63 | assert data.size(-1) == 2 64 | data = ifftshift(data, dim=(-3, -2)) 65 | data = torch.fft(data, 2, normalized=True) 66 | data = fftshift(data, dim=(-3, -2)) 67 | return data 68 | 69 | 70 | def ifft2(data): 71 | """ 72 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 
73 | 74 | Args: 75 | data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions 76 | -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are 77 | assumed to be batch dimensions. 78 | 79 | Returns: 80 | torch.Tensor: The IFFT of the input. 81 | """ 82 | assert data.size(-1) == 2 83 | data = ifftshift(data, dim=(-3, -2)) 84 | data = torch.ifft(data, 2, normalized=True) 85 | data = fftshift(data, dim=(-3, -2)) 86 | return data 87 | 88 | 89 | def complex_abs(data): 90 | """ 91 | Compute the absolute value of a complex valued input tensor. 92 | 93 | Args: 94 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 95 | should be 2. 96 | 97 | Returns: 98 | torch.Tensor: Absolute value of data 99 | """ 100 | assert data.size(-1) == 2 101 | return (data ** 2).sum(dim=-1).sqrt() 102 | 103 | 104 | def root_sum_of_squares(data, dim=0): 105 | """ 106 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 107 | 108 | Args: 109 | data (torch.Tensor): The input tensor 110 | dim (int): The dimensions along which to apply the RSS transform 111 | 112 | Returns: 113 | torch.Tensor: The RSS value 114 | """ 115 | return torch.sqrt((data ** 2).sum(dim)) 116 | 117 | 118 | def center_crop(data, shape): 119 | """ 120 | Apply a center crop to the input real image or batch of real images. 121 | 122 | Args: 123 | data (torch.Tensor): The input tensor to be center cropped. It should have at 124 | least 2 dimensions and the cropping is applied along the last two dimensions. 125 | shape (int, int): The output shape. The shape should be smaller than the 126 | corresponding dimensions of data. 
127 | 128 | Returns: 129 | torch.Tensor: The center cropped image 130 | """ 131 | assert 0 < shape[0] <= data.shape[-2] 132 | assert 0 < shape[1] <= data.shape[-1] 133 | w_from = (data.shape[-2] - shape[0]) // 2 134 | h_from = (data.shape[-1] - shape[1]) // 2 135 | w_to = w_from + shape[0] 136 | h_to = h_from + shape[1] 137 | return data[..., w_from:w_to, h_from:h_to] 138 | 139 | 140 | def complex_center_crop(data, shape): 141 | """ 142 | Apply a center crop to the input image or batch of complex images. 143 | 144 | Args: 145 | data (torch.Tensor): The complex input tensor to be center cropped. It should 146 | have at least 3 dimensions and the cropping is applied along dimensions 147 | -3 and -2 and the last dimensions should have a size of 2. 148 | shape (int, int): The output shape. The shape should be smaller than the 149 | corresponding dimensions of data. 150 | 151 | Returns: 152 | torch.Tensor: The center cropped image 153 | """ 154 | assert 0 < shape[0] <= data.shape[-3] 155 | assert 0 < shape[1] <= data.shape[-2] 156 | w_from = (data.shape[-3] - shape[0]) // 2 157 | h_from = (data.shape[-2] - shape[1]) // 2 158 | w_to = w_from + shape[0] 159 | h_to = h_from + shape[1] 160 | return data[..., w_from:w_to, h_from:h_to, :] 161 | 162 | 163 | def normalize(data, mean, stddev, eps=0.): 164 | """ 165 | Normalize the given tensor using: 166 | (data - mean) / (stddev + eps) 167 | 168 | Args: 169 | data (torch.Tensor): Input data to be normalized 170 | mean (float): Mean value 171 | stddev (float): Standard deviation 172 | eps (float): Added to stddev to prevent dividing by zero 173 | 174 | Returns: 175 | torch.Tensor: Normalized tensor 176 | """ 177 | return (data - mean) / (stddev + eps) 178 | 179 | 180 | def normalize_instance(data, eps=0.): 181 | """ 182 | Normalize the given tensor using: 183 | (data - mean) / (stddev + eps) 184 | where mean and stddev are computed from the data itself. 
185 | 186 | Args: 187 | data (torch.Tensor): Input data to be normalized 188 | eps (float): Added to stddev to prevent dividing by zero 189 | 190 | Returns: 191 | torch.Tensor: Normalized tensor 192 | """ 193 | mean = data.mean() 194 | std = data.std() 195 | return normalize(data, mean, std, eps), mean, std 196 | 197 | 198 | # Helper functions 199 | 200 | def roll(x, shift, dim): 201 | """ 202 | Similar to np.roll but applies to PyTorch Tensors 203 | """ 204 | if isinstance(shift, (tuple, list)): 205 | assert len(shift) == len(dim) 206 | for s, d in zip(shift, dim): 207 | x = roll(x, s, d) 208 | return x 209 | shift = shift % x.size(dim) 210 | if shift == 0: 211 | return x 212 | left = x.narrow(dim, 0, x.size(dim) - shift) 213 | right = x.narrow(dim, x.size(dim) - shift, shift) 214 | return torch.cat((right, left), dim=dim) 215 | 216 | 217 | def fftshift(x, dim=None): 218 | """ 219 | Similar to np.fft.fftshift but applies to PyTorch Tensors 220 | """ 221 | if dim is None: 222 | dim = tuple(range(x.dim())) 223 | shift = [dim // 2 for dim in x.shape] 224 | elif isinstance(dim, int): 225 | shift = x.shape[dim] // 2 226 | else: 227 | shift = [x.shape[i] // 2 for i in dim] 228 | return roll(x, shift, dim) 229 | 230 | 231 | def ifftshift(x, dim=None): 232 | """ 233 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 234 | """ 235 | if dim is None: 236 | dim = tuple(range(x.dim())) 237 | shift = [(dim + 1) // 2 for dim in x.shape] 238 | elif isinstance(dim, int): 239 | shift = (x.shape[dim] + 1) // 2 240 | else: 241 | shift = [(x.shape[i] + 1) // 2 for i in dim] 242 | return roll(x, shift, dim) 243 | -------------------------------------------------------------------------------- /include/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.optim 5 | import numpy as np 6 | from collections import Iterable 7 | 8 | 
# Tensor type network inputs are cast to (see apply_until); switch to
# torch.FloatTensor for CPU-only runs.
dtype = torch.cuda.FloatTensor


def save_np_img(img, filename):
    """Save a channel-first numpy image, clipped to [0, 1], to ``filename``.

    A single-channel image (``img.shape[0] == 1``) is drawn from its only
    channel; otherwise the array is transposed to channel-last for
    ``plt.imshow``.
    """
    if img.shape[0] == 1:
        plt.imshow(np.clip(img[0], 0, 1))
    else:
        plt.imshow(np.clip(img.transpose(1, 2, 0), 0, 1))
    plt.axis('off')
    plt.savefig(filename, bbox_inches='tight')
    plt.close()


def apply_until(net_input, net, n=100):
    """Apply the first ``n`` layers of ``net`` one after another to ``net_input``.

    ``net`` must be iterable and indexable (presumably a
    ``torch.nn.Sequential`` -- confirm against callers). The raw input is cast
    to ``dtype`` before the first layer. When at least one layer runs, the
    number of applied layers and the last applied layer are printed.

    Returns:
        The output after ``min(n, len(net))`` layers, or ``net_input``
        unchanged when no layer is applied (``n == 0`` or an empty ``net``).
    """
    out = net_input
    applied = 0
    for fun in net:
        if applied >= n:
            break
        # only the raw input needs casting; later layers consume the previous output
        out = fun(out.type(dtype) if applied == 0 else out)
        applied += 1
    if applied:
        # net[applied - 1] is the layer that actually ran last, whether the
        # loop stopped at n or exhausted the network
        print(applied, "last func. applied:", net[applied - 1])
    return out


from math import ceil


def plot_image_grid(imgs, nrows=10):
    """Plot a list of 2-D numpy images on a grid with at most ``nrows`` rows.

    Images fill the grid column by column; cells past the last image are left
    blank. Returns the matplotlib figure.
    """
    ncols = ceil(len(imgs) / nrows)
    nrows = min(nrows, len(imgs))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                             figsize=(ncols, nrows), squeeze=False)
    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            k = j * nrows + i  # column-major position of this cell
            if k < len(imgs):
                ax.imshow(imgs[k], cmap='Greys_r', interpolation='none')
            else:
                # the last column is only partly filled when len(imgs) is
                # not a multiple of nrows
                ax.axis('off')
    fig.tight_layout(pad=0.1)
    return fig


def save_tensor(out, filename, nrows=8):
    """Save every channel of the first batch entry of ``out`` as an image grid."""
    imgs = list(out.data.cpu().numpy()[0])
    plot_image_grid(imgs, nrows=nrows)
    plt.savefig(filename)
    plt.close()


def plot_kernels(tensor):
    """Show every 2-D kernel ``tensor[i][j]`` of a 4-D tensor on a grid.

    Row ``i`` of the grid corresponds to ``tensor.shape[0]`` and column ``j``
    to ``tensor.shape[1]``.

    Raises:
        Exception: if ``tensor`` is not 4-dimensional.
    """
    if len(tensor.shape) != 4:
        raise Exception("assumes a 4D tensor")
    nrows, ncols = tensor.shape[0], tensor.shape[1]
    # figsize is (width, height), i.e. (columns, rows), as in plot_image_grid
    fig = plt.figure(figsize=(ncols, nrows))
    for i in range(nrows):
        for j in range(ncols):
            # subplot indices are 1-based and row-major over ncols columns
            ax1 = fig.add_subplot(nrows, ncols, 1 + i * ncols + j)
            ax1.imshow(tensor[i][j])
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()


def plot_tensor(out, nrows=8):
    """Display every channel of the first batch entry of ``out`` as an image grid."""
    imgs = list(out.data.cpu().numpy()[0])
    plot_image_grid(imgs, nrows=nrows)
    plt.show()

# --------------------------------------------------------------------------------
# /include/wavelet.py:
# --------------------------------------------------------------------------------
#import matplotlib.pyplot as plt
import numpy as np
import numbers
import pywt
import scipy
import skimage.color as color
from skimage.restoration import (denoise_wavelet, estimate_sigma)
from skimage import data, img_as_float
from skimage.util import random_noise
try:
    from skimage.measure import compare_psnr
except ImportError:  # removed from skimage.measure in scikit-image 0.18
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from include import *


def _ncoeff_threshold(approx, dcoeffs, ncoeff):
    """Return the detail threshold that keeps ``ncoeff`` coefficients in total.

    The approximation coefficients ``approx`` are always kept and count towards
    ``ncoeff``. The rest of the budget goes to the largest-magnitude detail
    coefficients in ``dcoeffs`` (a list of ``pywt.wavedecn`` detail dicts).
    Coefficients tied with the threshold are kept too.
    """
    magnitudes = np.sort(np.abs(np.concatenate(
        [level[key].ravel() for level in dcoeffs for key in level])))
    n_detail = ncoeff - approx.size
    if n_detail <= 0:
        # the approximation band uses the whole budget: drop every detail
        return np.inf
    # clamp: asking for more coefficients than exist keeps all of them
    return magnitudes[-min(n_detail, magnitudes.size)]


def _wavelet_threshold(image, wavelet, ncoeff=None, threshold=None, mode='soft',
                       wavelet_levels=None):
    """Threshold the detail wavelet coefficients of ``image`` and reconstruct it.

    Args:
        image: array to transform (the callers in this file pass 2-D arrays).
        wavelet: wavelet name accepted by ``pywt.Wavelet``.
        ncoeff: total number of coefficients to keep, approximation band
            included. When given, it determines the threshold.
        threshold: explicit threshold, used only when ``ncoeff`` is None.
        mode: thresholding mode passed to ``pywt.threshold``.
        wavelet_levels: number of decomposition levels. Defaults to the
            maximum possible level minus 3 (at least 1).

    Returns:
        The reconstruction, cropped to ``image.shape``.

    Raises:
        ValueError: if neither ``ncoeff`` nor ``threshold`` is given.
    """
    wavelet = pywt.Wavelet(wavelet)

    # original_extent works around PyWavelets issue #80: odd-sized input
    # results in an image with 1 extra sample after waverecn. It must be a
    # tuple because NumPy no longer accepts a list of slices as an index.
    original_extent = tuple(slice(s) for s in image.shape)

    # Determine the number of wavelet decomposition levels
    if wavelet_levels is None:
        # Determine the maximum number of possible levels for image
        dlen = wavelet.dec_len
        wavelet_levels = np.min(
            [pywt.dwt_max_level(s, dlen) for s in image.shape])

        # Skip the coarsest wavelet scales.
        wavelet_levels = max(wavelet_levels - 3, 1)

    coeffs = pywt.wavedecn(image, wavelet=wavelet, level=wavelet_levels)
    # Detail coefficients at each decomposition level
    dcoeffs = coeffs[1:]

    if ncoeff is not None:
        threshold = _ncoeff_threshold(coeffs[0], dcoeffs, ncoeff)
    elif threshold is None:
        raise ValueError("either ncoeff or threshold must be given")

    # A single threshold for all detail coefficient arrays
    denoised_detail = [{key: pywt.threshold(level[key], value=threshold,
                                            mode=mode) for key in level}
                       for level in dcoeffs]

    denoised_coeffs = [coeffs[0]] + denoised_detail
    return pywt.waverecn(denoised_coeffs, wavelet)[original_extent]


def denoise_wavelet(image, ncoeff=None, wavelet='db1', mode='hard',
                    wavelet_levels=None, multichannel=False,
                    convert2ycbcr=False):
    """Keep only the ``ncoeff`` largest wavelet coefficients of ``image``.

    This module-local function shadows the ``denoise_wavelet`` imported from
    ``skimage.restoration`` above. Here the number of retained coefficients
    is specified directly via ``ncoeff``.

    Args:
        image: input image, converted with ``img_as_float``.
        ncoeff: number of coefficients to keep (per channel).
        wavelet: wavelet name accepted by ``pywt.Wavelet``.
        mode: thresholding mode passed to ``pywt.threshold``.
        wavelet_levels: number of decomposition levels (see _wavelet_threshold).
        multichannel: treat the last axis as channels and process each one.
        convert2ycbcr: with ``multichannel``, threshold in YCbCr space
            (first three channels) instead of per RGB channel.

    Returns:
        The result clipped to [-1, 1] if the input has negative values,
        otherwise to [0, 1].
    """
    image = img_as_float(image)

    if multichannel:
        if convert2ycbcr:
            out = color.rgb2ycbcr(image)
            for i in range(3):
                lo, hi = out[..., i].min(), out[..., i].max()
                if hi == lo:
                    # constant channel: rescaling would divide by zero and
                    # thresholding cannot change it, so leave it as is
                    continue
                # renormalize this color channel to live in [0, 1]
                channel = (out[..., i] - lo) / (hi - lo)
                out[..., i] = denoise_wavelet(channel, wavelet=wavelet, ncoeff=ncoeff,
                                              mode=mode,
                                              wavelet_levels=wavelet_levels)
                # undo the renormalization
                out[..., i] = out[..., i] * (hi - lo) + lo
            out = color.ycbcr2rgb(out)
        else:
            out = np.empty_like(image)
            for c in range(image.shape[-1]):
                out[..., c] = _wavelet_threshold(image[..., c], ncoeff=ncoeff,
                                                 wavelet=wavelet, mode=mode,
                                                 wavelet_levels=wavelet_levels)
    else:
        out = _wavelet_threshold(image, wavelet=wavelet, mode=mode, ncoeff=ncoeff,
                                 wavelet_levels=wavelet_levels)

    clip_range = (-1, 1) if image.min() < 0 else (0, 1)
    return np.clip(out, *clip_range)
-------------------------------------------------------------------------------- /linear_least_squares_selective_fitting_warmup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Linear regression fitting noise versus structure\n", 8 | "\n", 9 | "Let \n", 10 | "$$\n", 11 | "y = y^\\ast + e, \\quad y^\\ast = X \\theta^\\ast\n", 12 | "$$\n", 13 | "where $X$ is a fat matrix. Consider optimizing the least squares objective \n", 14 | "$$\n", 15 | "f(\\theta) = \\frac{1}{2}\n", 16 | "\\| X \\theta - y \\|^2.\n", 17 | "$$\n", 18 | "The gradient is \n", 19 | "$$\n", 20 | "\\nabla f(\\theta) = X^T X \\theta - X^T y.\n", 21 | "$$\n", 22 | "Gradient descent with step size $\\alpha$, i.e., $\\theta_{k+1} = \\theta_k - \\alpha \\nabla f(\\theta_k)$, gives\n", 23 | "$$\n", 24 | "\\begin{align}\n", 25 | "\\theta_{k+1} - \\theta^\\ast\n", 26 | "&=\n", 27 | "\\theta_k - \\alpha(X^T X \\theta_k - X^T y) - \\theta^\\ast \\\\\n", 28 | "&=\n", 29 | "\\theta_k - \\alpha(X^T X \\theta_k - X^T X \\theta^\\ast - X^T e) - \\theta^\\ast \\\\\n", 30 | "&=\n", 31 | "(I - \\alpha X^T X) (\\theta_k - \\theta^\\ast) + \\alpha X^T e\n", 32 | "\\end{align}\n", 33 | "$$\n", 34 | "The difference in terms of residual therefore becomes\n", 35 | "$$\n", 36 | "X\\theta_{k+1} - X\\theta^\\ast\n", 37 | "=\n", 38 | "(I - \\alpha XX^T) (X\\theta_k - X\\theta^\\ast) + \\alpha X X^T e \\\\\n", 39 | "$$" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 1, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import matplotlib.pyplot as plt\n", 49 | "#%matplotlib notebook\n", 50 | "#import matplotlib.pyplot as plt\n", 51 | "from numpy import *\n", 52 | "import numpy as np" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "def gradient_descent(A,b,niter = 10000,ytarget=None,stepsize=None):\n", 62 | " \n", 63 | " def f(x):\n", 64 | " return linalg.norm(dot(A,x) - b)**2\n", 65 | 
"\n", 66 | " def gradf(x):\n", 67 | " return 0.5*(dot(Q,x) - dot(b,A))\n", 68 | "\n", 69 | " Q = dot(A.T,A)\n", 70 | " eigenvalues = linalg.eigvals(Q)\n", 71 | " M = max(eigenvalues)\n", 72 | " m = min(eigenvalues)\n", 73 | "\n", 74 | " xopt = dot( linalg.inv( dot(A.T,A) ), dot( A.T , b ) ) \n", 75 | " \n", 76 | " print(\"optimal errors: \", linalg.norm( y - dot(A,xopt) ), linalg.norm( ytarget - dot(A,xopt) ) )\n", 77 | " \n", 78 | " if stepsize==None:\n", 79 | " stepsize = 2/(M+m)\n", 80 | " \n", 81 | " print(\"minimal and maximal eigenvalues: \", m, M)\n", 82 | " print(\"stepsize: \", stepsize)\n", 83 | " \n", 84 | " residuals = []\n", 85 | " gradients = []\n", 86 | " residual_target = []\n", 87 | " distances = []\n", 88 | " xk = zeros(n) #random.randn(n) # random initializer\n", 89 | " for k in range(niter):\n", 90 | " xk = xk - stepsize*gradf(xk)\n", 91 | " residuals.append( linalg.norm( y - dot(A,xk) ) )\n", 92 | " gradients.append( linalg.norm(gradf(xk)) )\n", 93 | " residual_target.append( linalg.norm( ytarget - dot(A,xk) ) )\n", 94 | " distances.append( linalg.norm(xk) )\n", 95 | " return array(residuals), array(gradients), array(residual_target),array(distances)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "image/png": "\n", 106 | "text/plain": [ 107 | "" 108 | ] 109 | }, 110 | "metadata": {}, 111 | "output_type": "display_data" 112 | }, 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "1.5072614647736406 1.0000000000000009 1.0 0.13591856159578958\n", 118 | "optimal errors: 0.5198674927030573 1.052776280346125\n", 119 | "minimal and maximal eigenvalues: 3.3195518835241045e-18 1.0000000000000047\n", 120 | "stepsize: 0.25\n", 121 | "logarithmic\n" 122 | ] 123 | }, 124 | { 125 | "data": { 126 | "image/png": "\n", 127 | "text/plain": [ 128 | "" 129 | ] 130 | }, 131 | "metadata": {}, 132 | "output_type": "display_data" 133 | 
}, 134 | { 135 | "data": { 136 | "image/png": "\n", 137 | "text/plain": [ 138 | "" 139 | ] 140 | }, 141 | "metadata": {}, 142 | "output_type": "display_data" 143 | }, 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "non logarithmic\n" 149 | ] 150 | }, 151 | { 152 | "data": { 153 | "image/png": "\n", 154 | "text/plain": [ 155 | "" 156 | ] 157 | }, 158 | "metadata": {}, 159 | "output_type": "display_data" 160 | }, 161 | { 162 | "data": { 163 | "image/png": "\n", 164 | "text/plain": [ 165 | "" 166 | ] 167 | }, 168 | "metadata": {}, 169 | "output_type": "display_data" 170 | } 171 | ], 172 | "source": [ 173 | "# generate a problem instance\n", 174 | "n = 100\n", 175 | "A = random.randn(n,n)\n", 176 | "U,S,VT = linalg.svd(A)\n", 177 | "\n", 178 | "ytarget = U[:, int(0)]\n", 179 | "perturbation = random.randn(n) #U[:,n-10]\n", 180 | "perturbation = perturbation/linalg.norm(perturbation)\n", 181 | "#perturbation = U[:,n-10]\n", 182 | "\n", 183 | "newS = np.array([s**3 for s in S])\n", 184 | "newS = newS/np.max(newS)\n", 185 | "plt.plot(newS)\n", 186 | "plt.title(\"spectrum\")\n", 187 | "plt.show()\n", 188 | "\n", 189 | "S = np.diag(newS)\n", 190 | "A = U @ S @ VT\n", 191 | "\n", 192 | "y = ytarget + perturbation\n", 193 | "\n", 194 | "print(linalg.norm(y), linalg.norm(ytarget), linalg.norm(perturbation), dot(ytarget,perturbation) )\n", 195 | "\n", 196 | "steps = 1000\n", 197 | "residuals,gradients,residual_target,distances = gradient_descent(A,y,niter=steps,ytarget=ytarget,stepsize=0.25)\n", 198 | "\n", 199 | "print(\"logarithmic\")\n", 200 | "plt.plot( log(residuals) )\n", 201 | "plt.show()\n", 202 | "plt.plot( log(residual_target) )\n", 203 | "plt.show()\n", 204 | "\n", 205 | "print(\"non logarithmic\")\n", 206 | "plt.plot( residuals )\n", 207 | "plt.show()\n", 208 | "plt.plot( residual_target )\n", 209 | "plt.show()\n", 210 | "\n", 211 | "\n", 212 | "ks = np.array( [i for i in range(steps)] )\n", 213 | 
"np.savetxt(\"ls_residuals.dat\", np.vstack([ ks ,np.array(residuals),np.array(residual_target) ] ).T , delimiter=\"\\t\")\n", 214 | "\n", 215 | "ns = np.array( [i for i in range(n)] )\n", 216 | "np.savetxt(\"ls_spectrum.dat\", np.vstack([ ns ,np.array(newS)] ).T , delimiter=\"\\t\")\n" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.6.4" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | -------------------------------------------------------------------------------- /test_data/astronaut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/test_data/astronaut.png -------------------------------------------------------------------------------- /test_data/phantom256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLI-lab/overparameterized_convolutional_generators/ef2fae85768f1954dbd1ead75b9ba8e214c13230/test_data/phantom256.png -------------------------------------------------------------------------------- /visualization_linear_approximation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Approximation with a linear model\n", 8 | "\n", 9 | "Here, we visually demonstrate that an overparameterized network can be well approximated around a random initial point with a linearized model."
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import matplotlib.pyplot as plt\n", 19 | "#%matplotlib notebook\n", 20 | "#import matplotlib.pyplot as plt\n", 21 | "from numpy import *\n", 22 | "import numpy as np" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 46, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "image/png": "\n", 33 | "text/plain": [ 34 | "" 35 | ] 36 | }, 37 | "metadata": {}, 38 | "output_type": "display_data" 39 | } 40 | ], 41 | "source": [ 42 | "# generator network\n", 43 | "\n", 44 | "n = 10\n", 45 | "k = 100 \n", 46 | "v = np.ones(k)\n", 47 | "v[:int(k/2)] = -np.ones( int(k/2) )\n", 48 | "v = v/np.sqrt(k)\n", 49 | "U = np.eye(n)\n", 50 | "\n", 51 | "def G(C): \n", 52 | " return np.maximum( U @ C , 0 ) @ v\n", 53 | "\n", 54 | "# Jaccobian\n", 55 | "def J(C):\n", 56 | " return np.vstack( [ve * (U.T @ np.diag(c > 0)) for ve,c in zip(v,C.T)] ).T\n", 57 | " \n", 58 | "# original loss\n", 59 | "def loss(y,C):\n", 60 | " return np.linalg.norm( y - G(C) )**2\n", 61 | "\n", 62 | "# associated linearized loss\n", 63 | "def losslin(y,C,C0):\n", 64 | " return np.linalg.norm( G(C0) + J(C0) @ np.hstack((C-C0).T) - y )**2\n", 65 | "\n", 66 | "\n", 67 | "\n", 68 | "y = np.random.randn(n)\n", 69 | "\n", 70 | "# initial vector\n", 71 | "C0 = np.random.randn(n,k)\n", 72 | "\n", 73 | "# random direction\n", 74 | "Crand = np.random.randn(n,k)\n", 75 | "\n", 76 | "R = 3\n", 77 | "epsilons = np.linspace(-R,R,100)\n", 78 | "\n", 79 | "\n", 80 | "losses = [loss(y, C0+ep*Crand) for ep in epsilons]\n", 81 | "linlosses = [losslin(y, C0+ep*Crand,C0) for ep in epsilons]\n", 82 | "\n", 83 | "\n", 84 | "plt.plot(losses)\n", 85 | "plt.plot(linlosses)\n", 86 | "plt.show()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [] 95 | } 96 | ], 97 | "metadata": { 98 
| "kernelspec": { 99 | "display_name": "Python 3", 100 | "language": "python", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.6.4" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } 119 | --------------------------------------------------------------------------------