├── .gitignore ├── Day01 ├── 0_MNIST_probNN.ipynb ├── 1_Linear_Regression.ipynb ├── 2_Logistic_Regression.ipynb └── 3_MLE_with_Bernoulli.ipynb ├── Day02 ├── 1_Naive_Bayes_Classifier.ipynb ├── 2_Gaussian_Discriminant_Analysis.ipynb ├── 3_GMM_EM.ipynb ├── data │ └── spam.csv └── images │ ├── after.png │ ├── before.png │ └── token.png ├── Day03 ├── 0_Conv_Transposed_Exercise.ipynb ├── 0_variational_cointoss.ipynb ├── 1_Simple_Autoencoder.ipynb ├── 2-2_VAE_2d_example.ipynb ├── 2_Simple_VAE.ipynb ├── 3_Convolutional_Autoencoder.ipynb └── 4_Convolutional_VAE.ipynb ├── Day04 ├── 0_Conv_Transposed_Exercise.ipynb ├── DCGAN │ ├── dcgan.ipynb │ ├── dcgan.png │ ├── discriminator.pkl │ ├── download.sh │ └── generator.pkl ├── GAN │ ├── GAN.png │ ├── discriminator.pkl │ ├── gan.ipynb │ └── generator.pkl ├── GAN_2d_example.ipynb └── toy_data.ipynb ├── Day05 ├── CVAE │ └── CVAE.ipynb ├── WGAN.ipynb ├── infogan.ipynb └── visdom_utils.py ├── Day06 ├── AAE │ └── AAE.ipynb ├── CVAE │ ├── CVAE.ipynb │ └── complements │ │ ├── CVAE.png │ │ ├── KLD_analytic.JPG │ │ ├── KLD_analytic2.JPG │ │ ├── reconstruction_loss.JPG │ │ └── total_loss.JPG └── CycleGAN │ ├── CycleGAN_test.ipynb │ ├── CycleGAN_train.ipynb │ ├── README.md │ ├── complements │ ├── concept.jpg │ ├── cover.jpg │ ├── cycle_consistency.JPG │ ├── full_objectives.JPG │ └── gan_loss.JPG │ ├── dataset.py │ ├── logger.py │ ├── model.py │ └── utils.py ├── Day07 ├── MDN │ ├── .ipynb_checkpoints │ │ └── mixture_density_networks-checkpoint.ipynb │ └── mixture_density_networks.ipynb └── pixelCNN │ ├── .ipynb_checkpoints │ ├── 201708_IconPixelCNN-checkpoint.ipynb │ ├── IconPixelCNN-checkpoint.ipynb │ └── ToyPixelCNN-checkpoint.ipynb │ ├── 201708_IconPixelCNN.ipynb │ ├── IconPixelCNN.ipynb │ └── ToyPixelCNN.ipynb └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | Untitled.* 3 | *.png 4 | *.pkl 5 | *zip* 6 | 7 | data2 8 | output 9 | temp 10 | tmp 11 | dataset 12 | datasets 13 | __pycache__ 14 | 15 | # 1Konny 16 | workroom 17 | *_model 18 | *_results 19 | 20 | horse2zebra 21 | horse2zebra/testA 22 | horse2zebra/testB 23 | horse2zebra/trainA 24 | horse2zebra/trainB 25 | download_horse2zebra.sh 26 | -------------------------------------------------------------------------------- /Day01/3_MLE_with_Bernoulli.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 88, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math\n", 10 | "import torch\n", 11 | "from torch.distributions import normal\n", 12 | "%matplotlib inline" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "# MLE with Bernoulli" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 89, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import torch\n", 29 | "import numpy as np\n", 30 | "\n", 31 | "sample = np.array([ 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", 32 | " 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 33 | " 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 34 | " 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1.,\n", 35 | " 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", 36 | " 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1.,\n", 37 | " 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1.,\n", 38 | " 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 0.,\n", 39 | " 0., 1., 1., 1., 1., 0., 1., 1., 
1., 1., 1., 0., 1.,\n", 40 | " 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1.,\n", 41 | " 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1.,\n", 42 | " 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 1.,\n", 43 | " 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 0., 0., 1.,\n", 44 | " 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,\n", 45 | " 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1.,\n", 46 | " 1., 0., 1., 0., 1.])" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 90, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "0.725" 58 | ] 59 | }, 60 | "execution_count": 90, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "np.mean(sample)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "Let's now define the probability p of generating 1, and put the sample into a PyTorch Variable:" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 91, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "x = torch.from_numpy(sample).float()\n", 83 | "p = torch.rand(1).float().requires_grad_()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "We are ready to learn the model using maximum likelihood:" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 92, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "loglikelihood= 120.99613 p= [0.8012794] dL/dp= [95.80995]\n", 103 | "loglikelihood= 117.67018 p= [0.7334692] dL/dp= [8.664444]\n", 104 | "loglikelihood= 117.634346 p= [0.726097] dL/dp= [1.1031647]\n", 105 | "loglikelihood= 117.63379 p= [0.72514415] dL/dp= [0.14459229]\n", 106 | "loglikelihood= 117.63375 p= [0.7250189] dL/dp= [0.01899719]\n", 107 | "loglikelihood= 117.63374 p= [0.7250027] dL/dp= [0.00271606]\n", 108 | "loglikelihood= 117.63374 p= [0.72500145] dL/dp= [0.0014801]\n", 109 | "loglikelihood= 117.63374 p= [0.72500145] dL/dp= [0.0014801]\n", 110 | "loglikelihood= 117.63374 p= [0.72500145] dL/dp= [0.0014801]\n", 111 | "loglikelihood= 117.63374 p= [0.72500145] dL/dp= [0.0014801]\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "learning_rate = 0.00002\n", 117 | "for t in range(1000):\n", 118 | " NLL = -torch.sum(torch.log(x*p + (1-x)*(1-p)) )\n", 119 | " NLL.backward()\n", 120 | "\n", 121 | " if t % 100 == 0:\n", 122 | " print(\"loglikelihood=\", NLL.data.numpy(), \"p=\", p.data.numpy(), \"dL/dp= \", p.grad.data.numpy())\n", 123 | "\n", 124 | " \n", 125 | " p.data -= learning_rate * p.grad.data\n", 126 | " p.grad.data.zero_()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.6.8" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 2 158 | } 159 | -------------------------------------------------------------------------------- /Day02/data/spam.csv: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day02/data/spam.csv -------------------------------------------------------------------------------- /Day02/images/after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day02/images/after.png -------------------------------------------------------------------------------- /Day02/images/before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day02/images/before.png -------------------------------------------------------------------------------- /Day02/images/token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day02/images/token.png -------------------------------------------------------------------------------- /Day03/0_Conv_Transposed_Exercise.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Convolution Transposed Exercise\n", 8 | "\n", 9 | "torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)\n", 10 | "\n", 11 | "check out https://github.com/vdumoulin/conv_arithmetic\n", 12 | "\n", 13 | "## 1. Import Required Libraries" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import torch\n", 25 | "import torch.nn as nn\n", 26 | "import torch.nn.init as init" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## 2. Input Data" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "tensor([[[[1., 1., 1.],\n", 46 | " [1., 1., 1.],\n", 47 | " [1., 1., 1.]]]])\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "img = torch.ones(1,1,3,3)\n", 53 | "print(img)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## 3. 
Set All Weights to One" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Parameter containing:\n", 73 | "tensor([[[[-0.2049, -0.3024, -0.3156],\n", 74 | " [-0.1755, 0.1763, 0.1756],\n", 75 | " [-0.2891, -0.3212, -0.1272]]]], requires_grad=True)\n" 76 | ] 77 | }, 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "Parameter containing:\n", 82 | "tensor([[[[1., 1., 1.],\n", 83 | " [1., 1., 1.],\n", 84 | " [1., 1., 1.]]]], requires_grad=True)" 85 | ] 86 | }, 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0, output_padding=0, bias=False)\n", 94 | "print(transpose.weight)\n", 95 | "\n", 96 | "init.constant_(transpose.weight,1)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## Kernel Size=3, stride=1, padding=0, output_padding=0" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "tensor([[[[1., 2., 3., 2., 1.],\n", 115 | " [2., 4., 6., 4., 2.],\n", 116 | " [3., 6., 9., 6., 3.],\n", 117 | " [2., 4., 6., 4., 2.],\n", 118 | " [1., 2., 3., 2., 1.]]]], grad_fn=)" 119 | ] 120 | }, 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "transpose(img)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Kernel Size=3, stride=2, padding=0, output_padding=0" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 5, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "text/plain": [ 145 | "tensor([[[[1., 1., 2., 1., 2., 1., 1.],\n", 146 | " [1., 1., 2., 1., 2., 1., 1.],\n", 147 | " [2., 2., 4., 2., 4., 2., 2.],\n", 148 | " [1., 1., 2., 1., 2., 1., 1.],\n", 149 | " [2., 2., 4., 2., 4., 2., 2.],\n", 150 | " [1., 1., 2., 1., 2., 1., 1.],\n", 151 | " [1., 1., 2., 1., 2., 1., 1.]]]],\n", 152 | " grad_fn=)" 153 | ] 154 | }, 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=0, output_padding=0, bias=False)\n", 162 | "init.constant_(transpose.weight,1)\n", 163 | "transpose(img)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "## Kernel Size=3, stride=2, padding=1, output_padding=0" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 6, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "tensor([[[[1., 2., 1., 2., 1.],\n", 182 | " [2., 4., 2., 4., 2.],\n", 183 | " [1., 2., 1., 2., 1.],\n", 184 | " [2., 4., 2., 4., 2.],\n", 185 | " [1., 2., 1., 2., 1.]]]], grad_fn=)" 186 | ] 187 | }, 188 | "execution_count": 6, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=0, bias=False)\n", 195 | "init.constant_(transpose.weight.data,1)\n", 196 | "transpose(img)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | 
"metadata": {}, 202 | "source": [ 203 | "## Kernel Size=3, stride=2, padding=0, output_padding=1" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 7, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "data": { 213 | "text/plain": [ 214 | "tensor([[[[1., 1., 2., 1., 2., 1., 1., 0.],\n", 215 | " [1., 1., 2., 1., 2., 1., 1., 0.],\n", 216 | " [2., 2., 4., 2., 4., 2., 2., 0.],\n", 217 | " [1., 1., 2., 1., 2., 1., 1., 0.],\n", 218 | " [2., 2., 4., 2., 4., 2., 2., 0.],\n", 219 | " [1., 1., 2., 1., 2., 1., 1., 0.],\n", 220 | " [1., 1., 2., 1., 2., 1., 1., 0.],\n", 221 | " [0., 0., 0., 0., 0., 0., 0., 0.]]]],\n", 222 | " grad_fn=)" 223 | ] 224 | }, 225 | "execution_count": 7, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=0, output_padding=1, bias=False)\n", 232 | "init.constant_(transpose.weight.data,1)\n", 233 | "transpose(img)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "## Kernel Size=3, stride=2, padding=1, output_padding=1" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 8, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "data": { 250 | "text/plain": [ 251 | "tensor([[[[1., 2., 1., 2., 1., 1.],\n", 252 | " [2., 4., 2., 4., 2., 2.],\n", 253 | " [1., 2., 1., 2., 1., 1.],\n", 254 | " [2., 4., 2., 4., 2., 2.],\n", 255 | " [1., 2., 1., 2., 1., 1.],\n", 256 | " [1., 2., 1., 2., 1., 1.]]]], grad_fn=)" 257 | ] 258 | }, 259 | "execution_count": 8, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | } 263 | ], 264 | "source": [ 265 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)\n", 266 | "init.constant_(transpose.weight.data,1)\n", 267 | "transpose(img)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [] 276 | } 277 | ], 278 | "metadata": { 279 | "kernelspec": { 280 | "display_name": "Python 3", 281 | "language": "python", 282 | "name": "python3" 283 | }, 284 | "language_info": { 285 | "codemirror_mode": { 286 | "name": "ipython", 287 | "version": 3 288 | }, 289 | "file_extension": ".py", 290 | "mimetype": "text/x-python", 291 | "name": "python", 292 | "nbconvert_exporter": "python", 293 | "pygments_lexer": "ipython3", 294 | "version": "3.6.8" 295 | } 296 | }, 297 | "nbformat": 4, 298 | "nbformat_minor": 2 299 | } 300 | -------------------------------------------------------------------------------- /Day04/0_Conv_Transposed_Exercise.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Convolution Transposed Exercise\n", 8 | "\n", 9 | "torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)\n", 10 | "\n", 11 | "check out https://github.com/vdumoulin/conv_arithmetic\n", 12 | "\n", 13 | "## 1. 
Import Required Libraries" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 6, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import torch\n", 25 | "import torch.nn as nn\n", 26 | "import torch.nn.init as init" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## 2. Input Data" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 7, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "tensor([[[[ 1., 1., 1.],\n", 46 | " [ 1., 1., 1.],\n", 47 | " [ 1., 1., 1.]]]])\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "img = torch.ones(1,1,3,3)\n", 53 | "print(img)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## 3. Set All Weights to One" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 8, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Parameter containing:\n", 73 | "tensor([[[[ 0.0206, -0.1985, -0.3008],\n", 74 | " [ 0.1443, 0.1907, -0.2587],\n", 75 | " [-0.2740, 0.0397, -0.1210]]]])\n" 76 | ] 77 | }, 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "Parameter containing:\n", 82 | "tensor([[[[ 1., 1., 1.],\n", 83 | " [ 1., 1., 1.],\n", 84 | " [ 1., 1., 1.]]]])" 85 | ] 86 | }, 87 | "execution_count": 8, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=0, output_padding=0, bias=False)\n", 94 | "print(transpose.weight)\n", 95 | "\n", 96 | "init.constant_(transpose.weight,1)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "## Kernel Size=3, stride=1, padding=0, output_padding=0" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 9, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "tensor([[[[ 1., 2., 3., 2., 1.],\n", 115 | " [ 2., 4., 6., 4., 2.],\n", 116 | " [ 3., 6., 9., 6., 3.],\n", 117 | " [ 2., 4., 6., 4., 2.],\n", 118 | " [ 1., 2., 3., 2., 1.]]]])" 119 | ] 120 | }, 121 | "execution_count": 9, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "transpose(img)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Kernel Size=3, stride=2, padding=0, output_padding=0" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 10, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "text/plain": [ 145 | "tensor([[[[ 1., 1., 2., 1., 2., 1., 1.],\n", 146 | " [ 1., 1., 2., 1., 2., 1., 1.],\n", 147 | " [ 2., 2., 4., 2., 4., 2., 2.],\n", 148 | " [ 1., 1., 2., 1., 2., 1., 1.],\n", 149 | " [ 2., 2., 4., 2., 4., 2., 2.],\n", 150 | " [ 1., 1., 2., 1., 2., 1., 1.],\n", 151 | " [ 1., 1., 2., 1., 2., 1., 1.]]]])" 152 | ] 153 | }, 154 | "execution_count": 10, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=0, output_padding=0, bias=False)\n", 161 | "init.constant_(transpose.weight,1)\n", 162 | "transpose(img)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Kernel Size=3, 
stride=2, padding=1, output_padding=0" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 11, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "tensor([[[[ 1., 2., 1., 2., 1.],\n", 181 | " [ 2., 4., 2., 4., 2.],\n", 182 | " [ 1., 2., 1., 2., 1.],\n", 183 | " [ 2., 4., 2., 4., 2.],\n", 184 | " [ 1., 2., 1., 2., 1.]]]])" 185 | ] 186 | }, 187 | "execution_count": 11, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=0, bias=False)\n", 194 | "init.constant_(transpose.weight.data,1)\n", 195 | "transpose(img)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## Kernel Size=3, stride=2, padding=0, output_padding=1" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 12, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "data": { 212 | "text/plain": [ 213 | "tensor([[[[ 1., 1., 2., 1., 2., 1., 1., 0.],\n", 214 | " [ 1., 1., 2., 1., 2., 1., 1., 0.],\n", 215 | " [ 2., 2., 4., 2., 4., 2., 2., 0.],\n", 216 | " [ 1., 1., 2., 1., 2., 1., 1., 0.],\n", 217 | " [ 2., 2., 4., 2., 4., 2., 2., 0.],\n", 218 | " [ 1., 1., 2., 1., 2., 1., 1., 0.],\n", 219 | " [ 1., 1., 2., 1., 2., 1., 1., 0.],\n", 220 | " [ 0., 0., 0., 0., 0., 0., 0., 0.]]]])" 221 | ] 222 | }, 223 | "execution_count": 12, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=0, output_padding=1, bias=False)\n", 230 | "init.constant_(transpose.weight.data,1)\n", 231 | "transpose(img)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "## Kernel Size=3, stride=2, padding=1, output_padding=1" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 13, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "tensor([[[[ 1., 2., 1., 2., 1., 1.],\n", 250 | " [ 2., 4., 2., 4., 2., 2.],\n", 251 | " [ 1., 2., 1., 2., 1., 1.],\n", 252 | " [ 2., 4., 2., 4., 2., 2.],\n", 253 | " [ 1., 2., 1., 2., 1., 1.],\n", 254 | " [ 1., 2., 1., 2., 1., 1.]]]])" 255 | ] 256 | }, 257 | "execution_count": 13, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "transpose = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)\n", 264 | "init.constant_(transpose.weight.data,1)\n", 265 | "transpose(img)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "Python 3", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.6.8" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 2 297 | } 298 | -------------------------------------------------------------------------------- /Day04/DCGAN/dcgan.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day04/DCGAN/dcgan.png -------------------------------------------------------------------------------- /Day04/DCGAN/discriminator.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day04/DCGAN/discriminator.pkl -------------------------------------------------------------------------------- /Day04/DCGAN/download.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/e0ig4nf1v94hyj8/CelebA_128crop_FD.zip?dl=0 -P ./ 2 | unzip CelebA_128crop_FD.zip -d ./ -------------------------------------------------------------------------------- /Day04/DCGAN/generator.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day04/DCGAN/generator.pkl -------------------------------------------------------------------------------- /Day04/GAN/GAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day04/GAN/GAN.png -------------------------------------------------------------------------------- /Day04/GAN/discriminator.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day04/GAN/discriminator.pkl -------------------------------------------------------------------------------- /Day04/GAN/generator.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day04/GAN/generator.pkl -------------------------------------------------------------------------------- /Day04/toy_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [ 10 | { 11 | "ename": "ModuleNotFoundError", 12 | "evalue": "No module named 'data'", 13 | "output_type": "error", 14 | "traceback": [ 15 | "\u001b[0;31m----------------------------------------------------------\u001b[0m", 16 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 17 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mIPython\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m 
\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'notebooks'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 18 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'data'" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "import torch\n", 24 | "import torch.nn as nn\n", 25 | "import torch.nn.functional as F\n", 26 | "import torchvision\n", 27 | "import numpy as np\n", 28 | "import matplotlib\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import matplotlib.gridspec as gridspec\n", 31 | "%matplotlib inline\n", 32 | "import os\n", 33 | "from IPython import display\n", 34 | "os.chdir('..')\n", 35 | "import data\n", 36 | "os.chdir('notebooks')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 14, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import shutil\n", 46 | "#shutil.rmtree('octa/pred_mean//')" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "z_dim = 128\n", 56 | "batch = 128\n", 57 | "max_step = 50001\n", 58 | "summary_step = 100" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "class Generator(nn.Module):\n", 68 | " def __init__(self):\n", 69 | " super().__init__()\n", 70 | " self._net = nn.Sequential(\n", 71 | " nn.Linear(z_dim, 128),\n", 72 | " nn.ReLU(),\n", 73 | " nn.Linear(128, 128),\n", 74 | " nn.ReLU(),\n", 75 | " )\n", 76 | " self.mean_fc = nn.Linear(128, 2)\n", 77 | " self.var_fc = nn.Linear(128, 2)\n", 78 | " \n", 79 | " def forward(self, z):\n", 80 | " h = self._net(z)\n", 81 | " return self.mean_fc(h), self.var_fc(h)\n", 82 | "\n", 83 | "\n", 84 | "class Discriminator(nn.Module):\n", 85 | " def __init__(self):\n", 86 | " super().__init__()\n", 87 | " self._net = nn.Sequential(\n", 88 | " nn.Linear(2, 128),\n", 89 | " nn.ReLU(),\n", 90 | " nn.Linear(128, 128),\n", 91 | " nn.ReLU(),\n", 92 | " nn.Linear(128, 1),\n", 93 | " nn.Sigmoid()\n", 94 | " )\n", 95 | " \n", 96 | " def forward(self, x):\n", 97 | " return self._net(x)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "device = torch.device(\"cuda:0\")" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "# Read Data" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 15, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "def generate_data():\n", 123 | " num_data = 50000\n", 124 | " noise_std = 0.1\n", 125 | " radius = 1\n", 126 | " rad = np.random.choice([\n", 127 | " 0, np.pi / 4, np.pi / 2, 3 * np.pi / 4,\n", 128 | " np.pi, 5 * np.pi / 4, 3 * np.pi / 2, 7 * np.pi / 4\n", 129 | " ], [num_data])\n", 130 | " data = np.stack([radius * np.cos(rad), radius * np.sin(rad)], axis=1)\n", 131 | " noise = np.random.normal(0, noise_std, data.shape)\n", 132 | " return data + noise\n", 133 | " \n", 134 | "raw_data = generate_data()\n", 135 | "dataset = torch.utils.data.TensorDataset(torch.Tensor(raw_data))\n", 136 | "dataset.vector_preprocess = lambda x: x\n", 137 | "#data_iterator = data.DataIterator(dataset, batch_size=batch, sampler=data.InfiniteRandomSampler(dataset))" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 20, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 
148 | "array([[ 0.12963152, -1.1354435 ],\n", 149 | " [-0.0690444 , 0.87052675],\n", 150 | " [-0.05483682, 1.04465169],\n", 151 | " ...,\n", 152 | " [ 0.87101206, -0.1739673 ],\n", 153 | " [-0.7568259 , 0.72366955],\n", 154 | " [ 1.09713994, 0.14930306]])" 155 | ] 156 | }, 157 | "execution_count": 20, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "raw_data" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "# Visualize the Data" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 18, 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAALYAAACyCAYAAADvYTUSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFZ1JREFUeJztnW1TE1kWx09IgjGQRCAEAwIGGVzBUXRn3KlRa2Z33aeqrdpPsB9sv8C+2Zdb+2KfnRJ3dtwZQUZEEXBAITw0gSQQEjoh++Jfp27TBoUkYPr2+VVReaDz1P3v0/9777nnesrlMgmCbjR96C8gCCeBCFvQEhG2oCUibEFLRNiCloiwBS0RYQtaIsIWtESELWiJCFvQEhG2oCUibEFLfNW8yOPxeIiom4iy9f06glCREBEtl4+RsVeVsAmiflPlawWhGi4Q0dJRN65W2FkiotevX1M4HK7yLQTh/WQyGert7SU6pjuoVthERBQOh0XYQkMijUdBS0TYgpaIsAUtEWELWiLCFrREhC1oiQhb0BIRtqAlImxBS0TYgpaIsAUtEWELWiLCFrREhC1oiQhb0BIRtqAlImxBS0TYgpaIsAUtEWELWiLCFrREhC1oiQj7hDBNomQSt9b7wulQU10R4XAMg2jJUreI78fjH+b7uA0Rdp0wTYg5GiXy+3FLpG7t94WTRaxIneAIbRh47PcjOvv9lbcXe3KyiLDrRDRKFIspT23FLnp+bnaW6P59olzuVL+qKxBh1wm/H39rawcFTATR9/S8bUvyeaLXr4lmZk73u7oB8dh1gP11JILHkQhsBvtt9tyLi0SpFNHICFEwSHT7NkQ9NPRhv7+OiLDrgLUHJB6HgKemIOBoFOINhYj++1+ijQ0IfXQUt11dh/twoXrEitSBSITI51MRm4jI48Ht1BTRv/9NNDlJ1NdH9OmnKkLbvbc0KOuHROw6kE4TFYu4JSJaXSX66CNE64UFPP/sGVFzM9GPfoRtTBONxnJZnRD2yC9Ujwi7BuzeOhpFhJ6ehrDn52FLZmdhRfb2iFZWiM6fhwX55huis2dhRdi28PsItSHCroFKEbariyiTIfrhB6KtLViQZBIR/fZtos1NogcPiO7eJdreJlpfJzp3Du8Vj0ukrhci7BqoFGEnJiDmaJSou5vo2jWiuTlEa58PYp6aItrdVVG8WHy7J0WoDdcK2z4EXg08umh9z9ZWRO1ymaijg2h5GdutryOKj4xA1Ht78NttbXgunRZ/XU9cK2y2Eaap+plrjZSGgfe4dQvRN5cjKhSIAgH8P50mevECfdmpFNFnn6GXhEh9dj38dT1OWqfjWmGzgEyzdoFbG5H9/USdnfDZLOr+fojaNGFFensxEpnP43Wmic+tV6Q+iZPWabhK2FYBptMHo+PqKlGphPvHFZhhoFsvEICtME30VQcCEHJrK6zIq1ewIfv7ROEwvHihQDQ4WN3nHvYbTRMWZ2kJv6urCyfSu5KydMNVwuZItrqKaPnmDQ46EUQdCFRnBaJR9Z482OLx4HE6rfJEgkH0ipRKeM7nwwkQjdZuQfikzWSIxsbw3i0t6F5Mp/FnbxPojCuEzQc9GMQBv3CB6OuviV6+JBoYIPr4YzT2dnfhi3O5g1H9fVGO+6HZ1zK5nBJ6JELk9eLzzp7Fc+3tSJqKRmv3xHzSJpNEjx4RnTlD9NOfolfG40EPjZv6x10h7MVF5GnE48ime/0akWxxET0Xpkn03XfKA7e1IQIXi3j9UaKc3w/xTk0hCsfjEHY2C0Gl08gTMQzkidhFVm2PiPWkNU1Ym0QCo5wdHbA/hQJ+l1gRjcjlcGl+8QLiWVnBAR4cxAjg7CyiW3Mz7EFvL0TCYjxOlJuZgW8mgngNA5/X3o7PHB7G/zgy9/XhsbWRdxjWng4idT+ZxGeWy0TPn6N78eZNoqYmePtsFjnf16/jO4yOuqPXRGthmybRw4cQ9NYWPGc4DKEND0MIY2PwwjduwKJsbyPyBoM46McRACc3WdNQy2Xcsr+152pb//cukkmVMej3H5xPmU7DW3//PbbzevHb83k0Wre38f9EAs9PTeF/RPp6bq2FbRjws/39EKfPhxyOchnRdX0dB/zcOaI7d96O0lYxcXR9F8EgIiLDl35rlK1lEIYzBvn9IhG85/XrRH/+M3JTTBNRemcHOStnziByj4yo9kM+X31D2SloLexoFAeaRcQR+A9/gEhiMYi+r++gqK3RmcVUDfZIXEuSk/Uk4VuOvNks2gTBoBJyuYyTyOPB1WppCa/L5dCQ1N1va52PzbNXNjdx0Pv6cLmenkbDsbMTBzifRwTnXgXOiY7HlTet1/epVlD21xoGLMbmJhqJn3yC32eaRE+fQvh9feh1WVlB2uz4OB7z63XO+9Y6YhMdvGwvLuLx4CBuL1+G725rw6U5GMS2CwuIgCMjjetBue98awsnaiqF79/eDvt18yYazD4fTs7f/AbtCB7Asf5GHSO39sLmSLe4iIbizg56P5qa4K2TSfSMlEoQfCKBS3yhAJEcxVt/CPx+NFJNE3ZqdhaCHhyEmItFWJTNTfT+7O2h64/bAKmUGlBq1JO3FrQXtpVQCH/ZLP5evkS0e/4cwi6VcBkPBCDsRiedhoculSDORAIReGICjcnlZVyNOjthW1ZWlJArDSjphGuEzf40EkFX3/w80khbW3EJ39mBz97bQ3bejRt4HfcxNxKVZsVbG77t7WpIfXgYNqS/XzU6pR9bI9iSmCaEvLxM9NVXmMny+99jYCaVwuV6eBjbcE9Co12qK3UbWvvc43GiX/wCuTDZLHpBeNDJ74f90
j332xXCtkaoZBI9BJ2dGJC5dYvo6lX8f2BATcJlGvFSXanbkEcgOzoQrT0e/L75efjs7W31GtNEV2cj/rZ64Qph2yuflsvwnl98ASEcNnDSqNHssKtIOo2rTlsbrkChEJ6LxVTENgx0+fX06GtDiFwibHuEY69tzwVxcgSLx2GruE3AwrdaECL3zIT3lDmZ4Tgv8njCRJROp9MUDofr/61OCDc0mnQjk8lQBK3kSLlczhz1dVqPPNqpVPXUqUjVqHfjCivC6HQZlqpR78ZVwm7Errtq0ekkPQm0tyK6XrJrSahyA9pHbClF8DZuaERrL2x7/RAifexItbjhZNfOitith3UGi+6jbe/DNJHlmMupuiPffYf9Zd1GB+vm+Iidy6nlLoLBg70FPNLG2+iae3xUDAPpBKaJDMadHfzZt9EhV9txwrb7Q54ZbprIjeDaITwfcGkJQ+i1TPHShWgUCV5c3Ke1Ve2ruTkMxycSSJwqFtV+dqIfd5yw7f23PCM8FEKkyWZxnwdh2tpw29Ul3ppLPsTjsBss8EePVL2Vixch9HPnlKid2DZxnLAr5X1wmbLxcUQcr1fV4gsEEH3sST9u6BmoBP9uIuyffB4zb3gyQqEA68aLPjm1v9xRwq4kRk7XnJtDimY4jKjz6BFSUIeGkOxkL6zu1EhUDfa03clJ7KezZ5G2ywuoZjKwKh0datt6VoE9TRwlbKsYg0Gi//wH/nl8HAcqFiO6cgUt/dVV3Pf7cWCfPMFzPT0HS4w5LRJVgz1tl8utcTskkcA+e/EC+/DGDUyX+9vfcP+zz5x3VXOUsCMRzApJp7HTx8bQAGpqwmSBH/8YPntqCoItl4n++EdM9/J6EZl8Pvet92I9iXmSwcQEIvTCAiYiRCKY9BuN4ur34AHRP/6B+93dRJcufdjfcFwcJWyONDMz8IK9vbAauZw6eMUitvF6if76VzSILlzAVKlYDNtY1zvX3Wdbu0OJcH9sDKXfQiHVI1IqEf3857j6LSygSlZTE56fmUGj00n7yFHC5u4qFmY2iwP26hUEvLGBv1gM0ScSwYTWL79EqWDrfD8+SLr7bGuhzFCI6PFj9BSFQrBzly/jyra+jqAwOYmTYW+P6PPP0U4JhbDfnDRK6Shhc3cVF1bc2UEj8dYt7PyxMTQgAwEIvVjENrxSVyJxcNHQeq770qhYC2U+eYIBmkAA0ZkI9m15GaUZnj5FL0kigYWfOjpQXzsYxH6bnERvCS+X3cg4StjcuudKoisrKNbe2Ykdzn771StEmN1dVEFKp1FqgaOSx4Pn3OCzrYUyYzHYN54a19MDG9LUhP3Q1QWLd/UqAsL+Pl7HBYcyGQQFJxTZcZSwrTOxYzEMImSz8NDT0weXydjags++cgWX0/FxtPC5a0vnKG2H8z+IiH77W0zmXVqCly6Xsf+uXkWkbmnBVc8wsP+IVLffnTt47IR95yhhE6n1ys+ehcXIZhG5Hz9GROnqwoEqlXDQfD612JBpNm7JspOCbdubNxDqyAieX1pCsfuFBURsrqkdi2HftbcrGzMxcfRSyo2Co4TNM7GJEDUePoQvJIKYOzthQ65cQQTa3ETk5j5uJ0SaemMYsG1dXRDr6qpqMJ8/j322swOLViggmp85g8Y219N2Yp6No4RtXd6CCGuTF4to/HR1IYJvbOD+pUuINH//O9Evf4lLcKM3eE4Cax+2YaCd0dKC/fPRR2iHJBK4Eq6u4s/rxb4qFhEQrl93XlBwlLDtBINE9+4p/8jCvXdPjUwmk7jkOu3A1AvrPE/eB9euQci5HPYPV5Xl5Cje9qirpjUijhW2Nf+Bo3gyiTK6PDPkV7+CT/z8c2cenGp4V3KXVeR+P9G33yKCs8Ct+5JI1Qt3Io4VdqUkJmuRd17GYnRUVSV1A0dN7jIMNUKbzarafo3ejXdUHCvsSklMHJGSSXcsIFSJoyZ3RaOwJHz/uEv/NTqOFbb1smq//FoPrlssCHPU2in2hriTbUcltJjMay9dJjU3BMdGbCtuyq0WjoYWwtapdJlQH7SwIoJgR4QtaIkIW9ASEbYGvK8smS5ly46Da4RtP7i5HJKkOD/bybxvpQadVnI4Kq4Rtv3gzswgh/vhQzWHstGj2mHfMRrFbBiehc7bWItQui1tV4vuvqNg7+vmNdPX1lBmYGgI+dtEjdt1+K48EOssmbU1PE6l1AQDJ8xTrCeuidj20UjDwN/LlygOk0qpqNeocGTmqlbWEhKTk5jMbJrYJp0m+te/kG/d2wuB62C7jop2wrZfris9XlyEEJJJJN2PjiInmSsjNaod4ZMznVa2ii1HRwfyPdbWIODHj3FFmp7GjP3vv4ft4hTVRv2N9UI7K2Kv72y9fPv9RH/6E+Y/ZjJq/uTyMuZQZrPICmz02tCRCL4jn4iLi/juLS2wVbEY/nfhAmxIJoM2RXu7KpRD1LiWqx5oF7GjUaSr5vMq4y8WQ4T6y1+I7t/HQR4cRHbbzo4qzdDVhftbW43Vg2C/6qTTmLbFqabZLOZ+fv01ovOLF5jE3N2Nbe7fV6XfdnYg9ka2XPVAu4jt96MhODWlPKVpEv3zn4jcHg+i89wcBFEoQBDJJC7j+/uIdrkcImEjZAnaG432hnBrK74vl6NYWcFv+vWvcVL7fGhDfPwxIj2XCNYZ7YRNpGr8GQYO4LNnRF99hcdtbUSzs4h4pRK23dzEDJLOTnjVQADe1OdrjASrd2UvGgZ+QzAI8Q4Pq9/+6pWamMs1Qbj+daOctCeFlsLmGn9MqUQ0MEDU34/a2aUSxDw/D9tx/TrEvL6OiLexQfSzn2F7e13t43DU4vLv285+clnbEUNDKDfBdQu//Raeem8PkXt/H+UnlpchdJ4Olkg0xkl7UmgpbOvsENMkunlTFaXkguZjY2hozc9DGPv7iOzhME6KQADvwUUsiY4vgkr9zpVEfNwFjaJRtcyGYeBk7OnBe8zNQdAdHXgfnw8WxefD9q9fo+wCD+bkcs6ejX4YWgrbCkclw1C9CNwl1t2NKki8NBwRqo/evasqJtUyicE6uZijfjIJ/2+trGQX6vtOIL8fr+d1ZAoFVL0aGkJkzuXwHsEgrkwXL+I3xuOI5hcv4kRYW4N9KRbxvjpFb+2FzSW+8nncf/AAUTmVQsGdUkn1Z9+9i5LDP/mJmgNYy+XaOrl4YQGDJPk8Hp87pzwuC5VPPrv1qRTl+XX5PP5u3UJvz9QUxN7Zic9YW0NDMptFGQquKc6vj0T0m8hL5AJhc4kvrxfdYCsrEPG1a4haQ0OwIIYBX7q8jPv1rFMXjULUz54hOnJfOUduFhmfBHb7cthQOkd6jwf+eWQEjWCvFydtPI42RDIJYQeDKA/M8HvpNpGXyAXCti45vb2NunT37uFybZqIVp98gm0mJiAU3r7W1Q6s+Rs8sun1Imo+far6y+1itd4e9hyRivR8ReKyE6OjiNa8mkNHhyolzMWEdEd7YXMk5AN6545a2WByEsK+excRenT04KKd1TQarSeEYWBoe34e77+7C7uwvAxvb5pvW49K1udd
doj77fl9nz9HVGZrw8sCFouwJTr3hFjRXthElaMvR8+WFrVdpTp3x/Gedj+/u4skq//9D7ka+TxE/umnKL82MABRT0/XtlIAj0T6fBiQ4fdNpdQaPbp66cNwhbArRd90Gj0FgUDlCHacyGZfaYEXUB0fxzJz6+uwQa2tyNNobYVVmJmBNXryBL0xFy5UF01ZrMEgRlV9PqJvvsHn+v3KhujopQ/DFcJ+n2+t1XPyiROLYVCHMwi3ttC9tr2N6F0soqttYwP/b2mBfcjnIfBqoyn3z/PVor0dPTtraziR3OKrrWiXBFWJSpWh6lUtitNGrTNUolEIdX8f0butTa0Y3NKCRt30NLYplQ6OlFabUsq9P3wFunQJn7e52VgJXaeFKyL2SWIYiLqBAB6vreG2vR11uS9fRnRubkZG4cWLEODt2xjWXl/HqGhfX23LYFe6Arm5QpYIu0aso4ZEaobLmzewJV4vLMncHGzB2hqsCUfY3d2D72W9PQ7H7U3RHRF2HWhvxy1bm2QSFiMYhNBTKTV8nU5DzOPjOAF8PpWN54bl+U4LEXaNcO5JTw9EzYlFnEH3ww9Ev/sduuB8Pnje5mb0gBQKsCeRiDvtwkkiwq4Ru31IJtHVFg4jkp8/D//d0QHRFwrI49jbQ49JPo8BFbf1Wpw0IuwaqeRjQyEIOZVSqaIvX8KeNDfDohgGekVWV52x0q3TEGHXGfbZuRy69LxezDNcWkKkvnkTjcetLYi8o8Nda+ScFq7oxz5NrP3jXi+GtCMRNSrY30/0xRfoZ+Zhbp4tL9QPidgnBIs7GlX+e2BAPf/llwfXVBTqiwj7hLB671wONiQcPjhRwElrkzsNEfYp4OYRwA+FCPsUcPMI4IdCGo+CloiwBS0RYQtaIsIWtESELWiJCFvQEhG2oCUibEFLRNiCloiwBS0RYQtaIsIWtESELWiJCFvQEhG2oCUibEFLRNiCloiwBS0RYQtaIsIWtESELWiJCFvQEhG2oCUibEFLRNiCltRUCSqTydTrewhCRarVmKdcLh//RR5PDxG9qeoTBaE6LpTL5aWjblytsD1E1E1E2WO/WBCOT4iIlsvHEGtVwhaERkcaj4KWiLAFLRFhC1oiwha0RIQtaIkIW9ASEbagJSJsQUtE2IKWiLAFLfk/4onwV/acAxoAAAAASUVORK5CYII=\n", 181 | "text/plain": [ 182 | "
" 183 | ] 184 | }, 185 | "metadata": {}, 186 | "output_type": "display_data" 187 | } 188 | ], 189 | "source": [ 190 | "vis_limit = 2.0\n", 191 | "\n", 192 | "def visualize_samples(samples, save_path=None):\n", 193 | " fig = plt.figure(figsize=(2,2), dpi=100)\n", 194 | " plt.plot(samples[:,0], samples[:,1], 'b.', markersize=1, alpha=0.2)\n", 195 | " axes = plt.gca()\n", 196 | " axes.set_xlim([-vis_limit, vis_limit])\n", 197 | " axes.set_ylim([-vis_limit, vis_limit])\n", 198 | " plt.xticks([])\n", 199 | " plt.yticks([])\n", 200 | " plt.show()\n", 201 | " \n", 202 | " if save_path:\n", 203 | " os.makedirs(os.path.dirname(save_path), exist_ok=True)\n", 204 | " fig.savefig(save_path, bbox_inches='tight', pad_inches=0)\n", 205 | "\n", 206 | " return fig\n", 207 | "\n", 208 | "grid_size = 64\n", 209 | "grid = np.tile(np.reshape(np.linspace(-vis_limit, vis_limit, grid_size), (grid_size, 1)), (1, grid_size))\n", 210 | "coord = np.stack([np.transpose(grid), grid], axis=2)\n", 211 | "coord = coord[::-1, :]\n", 212 | "coord = np.reshape(coord, [grid_size ** 2, 2])\n", 213 | "coord = torch.Tensor(coord).to(device)\n", 214 | "\n", 215 | "def visualize_discriminator(d):\n", 216 | " with torch.no_grad():\n", 217 | " verdicts = d(coord).cpu().numpy()\n", 218 | " verdicts = np.reshape(verdicts, [grid_size, grid_size])\n", 219 | " fig = plt.figure(figsize=(2,2), dpi=100)\n", 220 | " plt.imshow(verdicts)\n", 221 | " plt.xticks([])\n", 222 | " plt.yticks([])\n", 223 | " plt.show()\n", 224 | " return fig\n", 225 | "\n", 226 | "fig = visualize_samples(raw_data[:1000], save_path='octa/data.png')" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "# Train GAN" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 19, 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "ename": "NameError", 243 | "evalue": "name 'data_iterator' is not defined", 244 | "output_type": "error", 245 | "traceback": [ 246 | "\u001b[0;31m----------------------------------------------------------\u001b[0m", 247 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 248 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary_step\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m200\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;31m# Train D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m z = torch.normal(torch.zeros((x.size(0), z_dim)),\n\u001b[1;32m 16\u001b[0m torch.ones((x.size(0), z_dim))).to(device)\n", 249 | "\u001b[0;31mNameError\u001b[0m: name 'data_iterator' is not defined" 250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "G = Generator().to(device)\n", 255 | "D = Discriminator().to(device)\n", 256 | "G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001)\n", 257 | "D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0001)\n", 258 | "criterion = nn.BCELoss()\n", 259 | "real_label = 1\n", 260 | 
"fake_label = 0\n", 261 | "summary_step = 10\n", 262 | "\n", 263 | "# Train\n", 264 | "training_loss_sum = 0.0\n", 265 | "for step in range(summary_step * 200):\n", 266 | " # Train D\n", 267 | " x = next(data_iterator).to(device)\n", 268 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 269 | " torch.ones((x.size(0), z_dim))).to(device)\n", 270 | "\n", 271 | " D.zero_grad()\n", 272 | " label = torch.full((x.size(0),), real_label, device=device)\n", 273 | " fake, _ = G(z)\n", 274 | " real_loss = criterion(D(x), label)\n", 275 | " real_loss.backward()\n", 276 | "\n", 277 | " label.fill_(fake_label)\n", 278 | " fake_loss = criterion(D(fake.detach()), label)\n", 279 | " fake_loss.backward()\n", 280 | " D_optimizer.step()\n", 281 | "\n", 282 | " # Train G\n", 283 | " G.zero_grad()\n", 284 | " label.fill_(real_label)\n", 285 | " G_loss = criterion(D(fake), label)\n", 286 | " G_loss.backward()\n", 287 | " G_optimizer.step()\n", 288 | "\n", 289 | " # Summary\n", 290 | " if step % 10 == 0:\n", 291 | " with torch.no_grad():\n", 292 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 293 | " samples, _ = G(z)\n", 294 | " samples = samples.cpu().numpy()\n", 295 | " display.clear_output(wait=True)\n", 296 | " print(\"Step %d\" % step)\n", 297 | " fig = visualize_samples(samples, 'octa/gan/%07d.png' % step)\n", 298 | "\n", 299 | " # Visualize D's view\n", 300 | " visualize_discriminator(D)\n" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": {}, 306 | "source": [ 307 | "# Train Gaussian Predictor" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "P_gaussian = Generator().to(device)\n", 317 | "P_optimizer = torch.optim.Adam(P_gaussian.parameters(), lr=0.0001)\n", 318 | "mse_loss = nn.MSELoss()\n", 319 | "summary_step = 50 # for MM\n", 320 | "summary_step = 5\n", 321 | "\n", 322 | "# Train\n", 323 | "training_loss_sum = 0.0\n", 324 | "for step in range(summary_step * 200):\n", 325 | " # Train G\n", 326 | " x = next(data_iterator).to(device)\n", 327 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 328 | " torch.ones((x.size(0), z_dim))).to(device)\n", 329 | " P_gaussian.zero_grad()\n", 330 | " mean, log_var = P_gaussian(z)\n", 331 | " P_loss = ((mean - x) ** 2) / (2 * torch.exp(log_var)) + log_var / 2\n", 332 | " P_loss.sum(dim=1).mean(dim=0).backward()\n", 333 | " P_optimizer.step()\n", 334 | "\n", 335 | " # Summary\n", 336 | " if step % summary_step == 0:\n", 337 | " with torch.no_grad():\n", 338 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 339 | " mean, log_var = P_gaussian(z)\n", 340 | " samples = torch.normal(mean.repeat(2, 1), log_var.exp().sqrt().repeat(2, 1))\n", 341 | " mean = mean.cpu().numpy()\n", 342 | " samples = samples.cpu().numpy()\n", 343 | " display.clear_output(wait=True)\n", 344 | " print(\"Step %d\" % step)\n", 345 | " fig = visualize_samples(mean, 'octa/gaussian_pred_mean/%07d.png' % step)\n", 346 | " fig = visualize_samples(samples, 'octa/gaussian_pred/%07d.png' % step)\n" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "# Train Laplace Predictor" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "P_laplace = Generator().to(device)\n", 363 | "P_optimizer = torch.optim.Adam(P_laplace.parameters(), lr=0.0001)\n", 364 | "mse_loss = nn.MSELoss()\n", 365 | 
"summary_step = 50 # for MM\n", 366 | "summary_step = 5\n", 367 | "\n", 368 | "# Train\n", 369 | "training_loss_sum = 0.0\n", 370 | "for step in range(summary_step * 200):\n", 371 | " # Train G\n", 372 | " x = next(data_iterator).to(device)\n", 373 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 374 | " torch.ones((x.size(0), z_dim))).to(device)\n", 375 | " P_laplace.zero_grad()\n", 376 | " mean, log_var = P_laplace(z)\n", 377 | " P_loss = (mean - x).abs() / torch.exp(log_var) + log_var\n", 378 | " P_loss.sum(dim=1).mean(dim=0).backward()\n", 379 | " P_optimizer.step()\n", 380 | "\n", 381 | " # Summary\n", 382 | " if step % summary_step == 0:\n", 383 | " with torch.no_grad():\n", 384 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 385 | " mean, log_var = P_laplace(z)\n", 386 | " laplace = torch.distributions.laplace.Laplace(mean.repeat(2, 1), log_var.exp().repeat(2, 1))\n", 387 | " samples = laplace.sample()\n", 388 | " mean = mean.cpu().numpy()\n", 389 | " samples = samples.cpu().numpy()\n", 390 | " display.clear_output(wait=True)\n", 391 | " print(\"Step %d\" % step)\n", 392 | " fig = visualize_samples(mean, 'octa/laplace_pred_median/%07d.png' % step)\n", 393 | " fig = visualize_samples(samples, 'octa/laplace_pred/%07d.png' % step)\n" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "# Train GAN with L2 Loss" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "G = Generator().to(device)\n", 410 | "D = Discriminator().to(device)\n", 411 | "G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001)\n", 412 | "D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0001)\n", 413 | "criterion = nn.BCELoss()\n", 414 | "real_label = 1\n", 415 | "fake_label = 0\n", 416 | "summary_step = 10\n", 417 | "\n", 418 | "# Train\n", 419 | "training_loss_sum = 0.0\n", 420 | "for step in range(summary_step * 200):\n", 421 | " # Train D\n", 422 | " x = next(data_iterator).to(device)\n", 423 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 424 | " torch.ones((x.size(0), z_dim))).to(device)\n", 425 | "\n", 426 | " D.zero_grad()\n", 427 | " label = torch.full((x.size(0),), real_label, device=device)\n", 428 | " fake, _ = G(z)\n", 429 | " real_loss = criterion(D(x), label)\n", 430 | " real_loss.backward()\n", 431 | "\n", 432 | " label.fill_(fake_label)\n", 433 | " fake_loss = criterion(D(fake.detach()), label)\n", 434 | " fake_loss.backward()\n", 435 | " D_optimizer.step()\n", 436 | "\n", 437 | " # Train G\n", 438 | " G.zero_grad()\n", 439 | " label.fill_(real_label)\n", 440 | " G_loss = criterion(D(fake), label)\n", 441 | " l2_loss = mse_loss(fake, x)\n", 442 | " (G_loss + l2_loss).backward()\n", 443 | " G_optimizer.step()\n", 444 | "\n", 445 | " # Summary\n", 446 | " if step % summary_step == 0:\n", 447 | " with torch.no_grad():\n", 448 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 449 | " samples, _ = G(z)\n", 450 | " samples = samples.cpu().numpy()\n", 451 | " display.clear_output(wait=True)\n", 452 | " print(\"[Step %d] L2 Loss: %.3f | GAN Loss: %.3f\" % (step, l2_loss, G_loss))\n", 453 | " fig = visualize_samples(samples, 'octa/gan+l2/%07d.png' % step)\n", 454 | " fig = visualize_discriminator(D)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "metadata": {}, 460 | "source": [ 461 | "# Train GAN with Gaussian Moment Matching" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | 
"execution_count": 10, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "def train_harmony(method, dist, order):\n", 471 | " G = Generator().to(device)\n", 472 | " D = Discriminator().to(device)\n", 473 | " G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001)\n", 474 | " D_optimizer = torch.optim.Adam(D.parameters(), lr=0.0001)\n", 475 | " criterion = nn.BCELoss()\n", 476 | " real_label = 1\n", 477 | " fake_label = 0\n", 478 | " summary_step = 75\n", 479 | "\n", 480 | " # Train\n", 481 | " training_loss_sum = 0.0\n", 482 | " for step in range(summary_step * 200):\n", 483 | " # Train D\n", 484 | " x = next(data_iterator).to(device)\n", 485 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 486 | " torch.ones((x.size(0), z_dim))).to(device)\n", 487 | "\n", 488 | " D.zero_grad()\n", 489 | " label = torch.full((x.size(0),), real_label, device=device)\n", 490 | " fake, _ = G(z)\n", 491 | " real_loss = criterion(D(x), label)\n", 492 | " real_loss.backward()\n", 493 | "\n", 494 | " label.fill_(fake_label)\n", 495 | " fake_loss = criterion(D(fake.detach()), label)\n", 496 | " fake_loss.backward()\n", 497 | " D_optimizer.step()\n", 498 | "\n", 499 | " # Train G\n", 500 | " with torch.no_grad():\n", 501 | " p_mean, p_log_var = P(z)\n", 502 | " p_var = torch.exp(p_log_var)\n", 503 | " G.zero_grad()\n", 504 | " label.fill_(real_label)\n", 505 | " G_loss = criterion(D(fake), label)\n", 506 | " gen_mean = fake.mean(dim=0)\n", 507 | " gen_var = fake.std(dim=0) ** 2\n", 508 | " mm_loss = (p_mean[0] - gen_mean) ** 2 + (p_var[0] - gen_var) ** 2\n", 509 | " mm_loss = mm_loss.mean(dim=0)\n", 510 | " (G_loss + mm_loss).backward()\n", 511 | " G_optimizer.step()\n", 512 | "\n", 513 | " # Summary\n", 514 | " if step % summary_step == 0:\n", 515 | " with torch.no_grad():\n", 516 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 517 | " samples, _ = G(z)\n", 518 | " samples = samples.cpu().numpy()\n", 519 | " display.clear_output(wait=True)\n", 520 | " print(\"[Step %d] MM Loss: %.3f | GAN Loss: %.3f\" % (step, mm_loss, G_loss))\n", 521 | " fig = visualize_samples(samples, 'octa/mm/%07d.png' % step)\n", 522 | " fig = visualize_discriminator(D)\n" 523 | ] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "metadata": {}, 528 | "source": [ 529 | "# Train GAN with L1 Loss" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 11, 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "ename": "NameError", 539 | "evalue": "name 'data_iterator' is not defined", 540 | "output_type": "error", 541 | "traceback": [ 542 | "\u001b[0;31m----------------------------------------------------------\u001b[0m", 543 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 544 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_step\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# Train D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m z = torch.normal(torch.zeros((x.size(0), z_dim)),\n\u001b[1;32m 14\u001b[0m torch.ones((x.size(0), z_dim))).to(device)\n", 545 | "\u001b[0;31mNameError\u001b[0m: name 'data_iterator' is not defined" 546 | ] 547 | } 548 | ], 549 | "source": [ 550 | "G = Generator().to(device)\n", 551 | "D = Discriminator().to(device)\n", 552 | "G_optimizer = torch.optim.SGD(G.parameters(), lr=0.001)\n", 553 | "D_optimizer = torch.optim.SGD(D.parameters(), lr=0.001)\n", 554 | "criterion = nn.BCELoss()\n", 555 | "real_label = 1\n", 556 | "fake_label = 0\n", 557 | "# Train\n", 558 | "training_loss_sum = 0.0\n", 559 | "for step in range(max_step):\n", 560 | " # Train D\n", 561 | " x = next(data_iterator).to(device)\n", 562 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 563 | " torch.ones((x.size(0), z_dim))).to(device)\n", 564 | "\n", 565 | " D.zero_grad()\n", 566 | " label = torch.full((x.size(0),), real_label, device=device)\n", 567 | " fake, _ = G(z)\n", 568 | " real_loss = criterion(D(x), label)\n", 569 | " real_loss.backward()\n", 570 | "\n", 571 | " label.fill_(fake_label)\n", 572 | " fake_loss = criterion(D(fake.detach()), label)\n", 573 | " fake_loss.backward()\n", 574 | " D_optimizer.step()\n", 575 | "\n", 576 | " # Train G\n", 577 | " G.zero_grad()\n", 578 | " label.fill_(real_label)\n", 579 | " G_loss = criterion(D(fake), label)\n", 580 | " l1_loss = (fake - x).abs().mean()\n", 581 | " (G_loss + l1_loss).backward()\n", 582 | " G_optimizer.step()\n", 583 | "\n", 584 | " # Summary\n", 585 | " if step % summary_step == 0:\n", 586 | " with torch.no_grad():\n", 587 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 588 | " samples, _ = G(z)\n", 589 | " samples = samples.cpu().numpy()\n", 590 | " display.clear_output(wait=True)\n", 591 | " print(\"[Step %d] L1 Loss: %.3f | GAN Loss: %.3f\" % (step, l1_loss, G_loss))\n", 592 | " os.makedirs('notebooks/toy_data_summary/gan+l1', exist_ok=True)\n", 593 | " fig = visualize_samples(samples)\n", 594 | " fig.savefig('notebooks/toy_data_summary/gan+l1/%07d.pdf' % step, bbox_inches='tight', pad_inches=0)\n", 595 | "\n", 596 | " # Visualize D's view\n", 597 | " fig = visualize_discriminator(D)\n", 598 | " fig.savefig('notebooks/toy_data_summary/gan+l1/d_%07d.pdf' % step, bbox_inches='tight', pad_inches=0)\n" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "# Train Laplace Predictor" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 12, 611 | "metadata": {}, 612 | "outputs": [ 613 | { 614 | "ename": "NameError", 615 | "evalue": "name 'data_iterator' is not defined", 616 | "output_type": "error", 617 | "traceback": [ 618 | "\u001b[0;31m----------------------------------------------------------\u001b[0m", 619 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 620 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_step\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;31m# Train 
G\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m z = torch.normal(torch.zeros((x.size(0), z_dim)),\n\u001b[1;32m 10\u001b[0m torch.ones((x.size(0), z_dim))).to(device)\n", 621 | "\u001b[0;31mNameError\u001b[0m: name 'data_iterator' is not defined" 622 | ] 623 | } 624 | ], 625 | "source": [ 626 | "P_laplace = Generator().to(device)\n", 627 | "P_optimizer = torch.optim.SGD(P_laplace.parameters(), lr=0.001)\n", 628 | "mse_loss = nn.MSELoss()\n", 629 | "# Train\n", 630 | "training_loss_sum = 0.0\n", 631 | "for step in range(max_step):\n", 632 | " # Train G\n", 633 | " x = next(data_iterator).to(device)\n", 634 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 635 | " torch.ones((x.size(0), z_dim))).to(device)\n", 636 | " P_laplace.zero_grad()\n", 637 | " mean, log_var = P_laplace(z)\n", 638 | " P_loss = (mean - x).abs() / torch.exp(log_var) + log_var\n", 639 | " P_loss.sum(dim=1).mean(dim=0).backward()\n", 640 | " P_optimizer.step()\n", 641 | "\n", 642 | " # Summary\n", 643 | " if step % summary_step == 0:\n", 644 | " with torch.no_grad():\n", 645 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 646 | " mean, log_var = P_laplace(z)\n", 647 | " dist = torch.distributions.Laplace(mean, log_var.exp())\n", 648 | " samples = dist.sample()\n", 649 | " mean = mean.cpu().numpy()\n", 650 | " samples = samples.cpu().numpy()\n", 651 | " display.clear_output(wait=True)\n", 652 | " print(\"Step %d\" % step)\n", 653 | " os.makedirs('notebooks/toy_data_summary/predictor_laplace_median', exist_ok=True)\n", 654 | " fig = visualize_samples(mean)\n", 655 | " fig.savefig('notebooks/toy_data_summary/predictor_laplace_median/%07d.pdf' % step, bbox_inches='tight', pad_inches=0)\n", 656 | " fig = visualize_samples(samples)\n", 657 | " fig.savefig('notebooks/toy_data_summary/predictor_laplace_sample/%07d.pdf' % step, bbox_inches='tight', pad_inches=0)\n" 658 | ] 659 | }, 660 | { 661 | "cell_type": "markdown", 662 | "metadata": {}, 663 | "source": [ 664 | "# Train GAN with Laplace Moment Matching" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": 13, 670 | "metadata": {}, 671 | "outputs": [ 672 | { 673 | "ename": "NameError", 674 | "evalue": "name 'data_iterator' is not defined", 675 | "output_type": "error", 676 | "traceback": [ 677 | "\u001b[0;31m----------------------------------------------------------\u001b[0m", 678 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 679 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_step\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;31m# Train D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m z = torch.normal(torch.zeros((x.size(0), z_dim)),\n\u001b[1;32m 15\u001b[0m torch.ones((x.size(0), z_dim))).to(device)\n", 680 | "\u001b[0;31mNameError\u001b[0m: name 'data_iterator' is not defined" 681 | ] 682 | } 683 | ], 684 | "source": [ 685 | "G = Generator().to(device)\n", 686 | "D = Discriminator().to(device)\n", 687 | "G_optimizer = torch.optim.SGD(G.parameters(), lr=0.001)\n", 688 | "D_optimizer = torch.optim.SGD(D.parameters(), lr=0.001)\n", 689 | "criterion = nn.BCELoss()\n", 690 | "real_label = 1\n", 691 | "fake_label = 0\n", 692 | "\n", 693 | "# Train\n", 694 | "training_loss_sum = 0.0\n", 695 | "for step in range(max_step * 5):\n", 696 | " # Train D\n", 697 | " x = next(data_iterator).to(device)\n", 698 | " z = torch.normal(torch.zeros((x.size(0), z_dim)),\n", 699 | " torch.ones((x.size(0), z_dim))).to(device)\n", 700 | "\n", 701 | " D.zero_grad()\n", 702 | " label = torch.full((x.size(0),), real_label, device=device)\n", 703 | " fake, _ = G(z)\n", 704 | " real_loss = criterion(D(x), label)\n", 705 | " real_loss.backward()\n", 706 | "\n", 707 | " label.fill_(fake_label)\n", 708 | " fake_loss = criterion(D(fake.detach()), label)\n", 709 | " fake_loss.backward()\n", 710 | " D_optimizer.step()\n", 711 | "\n", 712 | " # Train G\n", 713 | " with torch.no_grad():\n", 714 | " p_mean, p_log_var = P_laplace(z)\n", 715 | " p_var = torch.exp(p_log_var)\n", 716 | " G.zero_grad()\n", 717 | " label.fill_(real_label)\n", 718 | " G_loss = criterion(D(fake), label)\n", 719 | " gen_median = fake.median(dim=0, keepdim=True)[0]\n", 720 | " gen_mad = (fake - gen_median).abs().mean(0)\n", 721 | " gen_median = gen_median.squeeze(0)\n", 722 | " mm_loss = (p_mean[0] - gen_median) ** 2 + (p_var[0] - gen_mad) ** 2\n", 723 | " mm_loss = mm_loss.mean(dim=0)\n", 724 | " (G_loss + mm_loss).backward()\n", 725 | " G_optimizer.step()\n", 726 | "\n", 727 | " # Summary\n", 728 | " if step % summary_step == 0:\n", 729 | " with torch.no_grad():\n", 730 | " z = torch.normal(torch.zeros((1000, z_dim)), 1.0).to(device)\n", 731 | " samples, _ = G(z)\n", 732 | " samples = samples.cpu().numpy()\n", 733 | " display.clear_output(wait=True)\n", 734 | " print(\"[Step %d] MM Loss: %.3f | GAN Loss: %.3f\" % (step, mm_loss, G_loss))\n", 735 | " fig = visualize_samples(samples)\n", 736 | " os.makedirs('notebooks/toy_data_summary/gan+laplace_mm', exist_ok=True)\n", 737 | " fig.savefig('notebooks/toy_data_summary/gan+laplace_mm/%07d.pdf' % step, bbox_inches='tight', pad_inches=0)\n", 738 | "\n", 739 | " # Visualize D's view\n", 740 | " fig = visualize_discriminator(D)\n", 741 | " fig.savefig('notebooks/toy_data_summary/gan+laplace_mm/d_%07d.pdf' % step, bbox_inches='tight', pad_inches=0)\n" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": null, 747 | "metadata": {}, 748 | "outputs": [], 749 | "source": [] 750 | }, 751 | { 752 | "cell_type": "code", 753 | "execution_count": null, 754 | "metadata": {}, 755 | "outputs": [], 756 | "source": [] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": null, 768 | "metadata": {}, 769 | "outputs": [], 770 | "source": [] 771 
| }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": null, 775 | "metadata": {}, 776 | "outputs": [], 777 | "source": [] 778 | } 779 | ], 780 | "metadata": { 781 | "kernelspec": { 782 | "display_name": "Python 3", 783 | "language": "python", 784 | "name": "python3" 785 | }, 786 | "language_info": { 787 | "codemirror_mode": { 788 | "name": "ipython", 789 | "version": 3 790 | }, 791 | "file_extension": ".py", 792 | "mimetype": "text/x-python", 793 | "name": "python", 794 | "nbconvert_exporter": "python", 795 | "pygments_lexer": "ipython3", 796 | "version": "3.6.8" 797 | } 798 | }, 799 | "nbformat": 4, 800 | "nbformat_minor": 2 801 | } 802 | -------------------------------------------------------------------------------- /Day05/visdom_utils.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | # import matplotlib.pyploy as plt 3 | # import scipy.ndimage 4 | from scipy.misc import imresize 5 | import numpy as np 6 | from torchvision.utils import make_grid 7 | 8 | class VisFunc(object): 9 | 10 | def __init__(self, config=None, vis=None, enval='main',port=8097): 11 | self.config = config 12 | self.vis = visdom.Visdom(env=enval, port=port) 13 | self.win = None 14 | self.win2 = None 15 | self.epoch_list = [] 16 | self.train_loss_list = [] 17 | self.val_loss_list = [] 18 | self.epoch_list2 = [] 19 | self.train_acc_list = [] 20 | self.val_acc_list = [] 21 | 22 | 23 | 24 | def imshow(self, img, title=' ', caption=' ', factor=1): 25 | 26 | img = img / 2 + 0.5 # Unnormalize 27 | npimg = img.numpy() 28 | obj = np.transpose(npimg, (1,2,0)) 29 | obj = np.swapaxes(obj,0,2) 30 | obj = np.swapaxes(obj,1,2) 31 | 32 | imgsize = tuple((np.array(obj.shape[1:])*factor).astype(int)) 33 | rgbArray = np.zeros(tuple([3])+imgsize,'float32') 34 | rgbArray[0,...] = imresize(obj[0,:,:],imgsize,'cubic') 35 | rgbArray[1,...] = imresize(obj[1,:,:],imgsize,'cubic') 36 | rgbArray[2,...] = imresize(obj[2,:,:],imgsize,'cubic') 37 | 38 | """ 39 | rgbArray[0,...] = scipy.ndimage.zoom(obj[0,:,:],2,order=3) 40 | rgbArray[1,...] = scipy.ndimage.zoom(obj[1,:,:],2,order=0) 41 | rgbArray[2,...] 
= scipy.ndimage.zoom(obj[2,:,:],2,order=0) 42 | """ 43 | 44 | self.vis.image( rgbArray, 45 | opts=dict(title=title, caption=caption), 46 | ) 47 | 48 | 49 | def imshow_multi(self, imgs, nrow=10, title=' ', caption=' ', factor=1): 50 | self.imshow( make_grid(imgs,nrow), title, caption, factor) 51 | 52 | 53 | def imshow_one_batch(self, loader, classes=None, factor=1): 54 | dataiter = iter(loader) 55 | images, labels = dataiter.next() 56 | self.imshow(make_grid(images,padding)) 57 | 58 | if classes: 59 | print(' '.join('%5s' % classes[labels[j]] 60 | for j in range(loader.batch_size))) 61 | else: 62 | print(' '.join('%5s' % labels[j] 63 | for j in range(loader.batch_size))) 64 | 65 | 66 | 67 | 68 | def plot(self, epoch, train_loss, val_loss,Des): 69 | ''' plot learning curve interactively with visdom ''' 70 | self.epoch_list.append(epoch) 71 | self.train_loss_list.append(train_loss) 72 | self.val_loss_list.append(val_loss) 73 | 74 | if not self.win: 75 | # send line plot 76 | # embed() 77 | self.win = self.vis.line( 78 | X=np.array(self.epoch_list), 79 | Y=np.array([[self.train_loss_list[-1], self.val_loss_list[-1]]]), 80 | opts=dict( 81 | title='Learning Curve (' + Des +')', 82 | xlabel='Epoch', 83 | ylabel='Loss', 84 | legend=['train_loss', 'val_loss'], 85 | #caption=Des 86 | )) 87 | # send text memo (configuration) 88 | # self.vis.text(str(Des)) 89 | else: 90 | self.vis.updateTrace( 91 | X=np.array(self.epoch_list[-2:]), 92 | Y=np.array(self.train_loss_list[-2:]), 93 | win=self.win, 94 | name='train_loss', 95 | ) 96 | self.vis.updateTrace( 97 | X=np.array(self.epoch_list[-2:]), 98 | Y=np.array(self.val_loss_list[-2:]), 99 | win=self.win, 100 | name='val_loss', 101 | ) 102 | 103 | 104 | def acc_plot(self, epoch, train_acc, val_acc, Des): 105 | ''' plot learning curve interactively with visdom ''' 106 | self.epoch_list2.append(epoch) 107 | self.train_acc_list.append(train_acc) 108 | self.val_acc_list.append(val_acc) 109 | 110 | if not self.win2: 111 | # send line plot 112 | # embed() 113 | self.win2 = self.vis.line( 114 | X=np.array(self.epoch_list2), 115 | Y=np.array([[self.train_acc_list[-1], self.val_acc_list[-1]]]), 116 | opts=dict( 117 | title='Accuracy Curve (' + Des +')', 118 | xlabel='Epoch', 119 | ylabel='Accuracy', 120 | legend=['train_accuracy', 'val_accuracy'] 121 | )) 122 | # send text memo (configuration) 123 | # self.vis.text(str(self.config)) 124 | else: 125 | self.vis.updateTrace( 126 | X=np.array(self.epoch_list2[-2:]), 127 | Y=np.array(self.train_acc_list[-2:]), 128 | win=self.win2, 129 | name='train_accuracy', 130 | ) 131 | self.vis.updateTrace( 132 | X=np.array(self.epoch_list2[-2:]), 133 | Y=np.array(self.val_acc_list[-2:]), 134 | win=self.win2, 135 | name='val_accuracy', 136 | ) 137 | 138 | 139 | def plot2(self, epoch, train_loss, val_loss,Des, win): 140 | ''' plot learning curve interactively with visdom ''' 141 | self.epoch_list.append(epoch) 142 | self.train_loss_list.append(train_loss) 143 | self.val_loss_list.append(val_loss) 144 | 145 | if not self.win: 146 | self.win = win 147 | # send line plot 148 | # embed() 149 | #self.win = self.vis.line( 150 | # X=np.array(self.epoch_list), 151 | # Y=np.array([[self.train_loss_list[-1], self.val_loss_list[-1]]]), 152 | # opts=dict( 153 | # title='Learning Curve (' + Des +')', 154 | # xlabel='Epoch', 155 | # ylabel='Loss', 156 | # legend=['train_loss', 'val_loss'], 157 | # #caption=Des 158 | # )) 159 | ## send text memo (configuration) 160 | # self.vis.text(str(Des)) 161 | else: 162 | self.vis.updateTrace( 163 | 
X=np.array(self.epoch_list[-2:]), 164 | Y=np.array(self.train_loss_list[-2:]), 165 | win=self.win, 166 | name='train_loss2', 167 | ) 168 | self.vis.updateTrace( 169 | X=np.array(self.epoch_list[-2:]), 170 | Y=np.array(self.val_loss_list[-2:]), 171 | win=self.win, 172 | name='val_lossi2', 173 | ) 174 | 175 | -------------------------------------------------------------------------------- /Day06/CVAE/complements/CVAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CVAE/complements/CVAE.png -------------------------------------------------------------------------------- /Day06/CVAE/complements/KLD_analytic.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CVAE/complements/KLD_analytic.JPG -------------------------------------------------------------------------------- /Day06/CVAE/complements/KLD_analytic2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CVAE/complements/KLD_analytic2.JPG -------------------------------------------------------------------------------- /Day06/CVAE/complements/reconstruction_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CVAE/complements/reconstruction_loss.JPG -------------------------------------------------------------------------------- /Day06/CVAE/complements/total_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CVAE/complements/total_loss.JPG -------------------------------------------------------------------------------- /Day06/CycleGAN/CycleGAN_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2018-01-21T03:25:05.006127Z", 9 | "start_time": "2018-01-21T03:25:04.304912Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stdout", 15 | "output_type": "stream", 16 | "text": [ 17 | "Namespace(batch_size=1, dataset='horse2zebra', input_size=256, ngf=32, num_resnet=6)\n" 18 | ] 19 | } 20 | ], 21 | "source": [ 22 | "import torch\n", 23 | "from torchvision import transforms\n", 24 | "from dataset import DatasetFromFolder\n", 25 | "from model import Generator\n", 26 | "import utils\n", 27 | "import argparse\n", 28 | "import os\n", 29 | "\n", 30 | "parser = argparse.ArgumentParser()\n", 31 | "parser.add_argument('--dataset', required=False, default='horse2zebra', help='input dataset')\n", 32 | "parser.add_argument('--batch_size', type=int, default=1, help='test batch size')\n", 33 | "parser.add_argument('--ngf', type=int, default=32)\n", 34 | "parser.add_argument('--num_resnet', type=int, default=6, help='number of resnet blocks in generator')\n", 35 | "parser.add_argument('--input_size', type=int, default=256, help='input size')\n", 36 | "params = parser.parse_args([])\n", 37 | "print(params)\n", 38 | "\n", 39 | "# Directories for loading 
data and saving results\n", 40 | "data_dir = 'data/' + params.dataset + '/'\n", 41 | "save_dir = params.dataset + '_test_results/'\n", 42 | "model_dir = params.dataset + '_model/'\n", 43 | "\n", 44 | "if not os.path.exists(save_dir):\n", 45 | " os.mkdir(save_dir)\n", 46 | "if not os.path.exists(model_dir):\n", 47 | " os.mkdir(model_dir)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": { 54 | "ExecuteTime": { 55 | "end_time": "2018-01-21T03:25:22.753947Z", 56 | "start_time": "2018-01-21T03:25:22.734976Z" 57 | } 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "# Data pre-processing\n", 62 | "transform = transforms.Compose([transforms.Resize(params.input_size),\n", 63 | " transforms.ToTensor(),\n", 64 | " transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])\n", 65 | "\n", 66 | "# Test data\n", 67 | "test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform)\n", 68 | "test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A,\n", 69 | " batch_size=params.batch_size,\n", 70 | " shuffle=False)\n", 71 | "test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform)\n", 72 | "test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B,\n", 73 | " batch_size=params.batch_size,\n", 74 | " shuffle=False)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": { 81 | "ExecuteTime": { 82 | "end_time": "2018-01-21T03:25:23.007239Z", 83 | "start_time": "2018-01-21T03:25:23.002727Z" 84 | } 85 | }, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "'horse2zebra_model/'" 91 | ] 92 | }, 93 | "execution_count": 3, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "model_dir" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 5, 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2018-01-21T03:25:23.886357Z", 108 | "start_time": "2018-01-21T03:25:23.779878Z" 109 | } 110 | }, 111 | "outputs": [ 112 | { 113 | "ename": "FileNotFoundError", 114 | "evalue": "[Errno 2] No such file or directory: 'horse2zebra_model/generator_A_param.pkl'", 115 | "output_type": "error", 116 | "traceback": [ 117 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 118 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 119 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mG_A\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mG_B\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mG_A\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_dir\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'generator_A_param.pkl'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mG_B\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_dir\u001b[0m 
\u001b[0;34m+\u001b[0m \u001b[0;34m'generator_B_param.pkl'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 120 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module)\u001b[0m\n\u001b[1;32m 364\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mversion_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpathlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0mnew_fd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 366\u001b[0;31m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 367\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 121 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'horse2zebra_model/generator_A_param.pkl'" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "# Load model\n", 127 | "G_A = Generator(3, params.ngf, 3, params.num_resnet)\n", 128 | "G_B = Generator(3, params.ngf, 3, params.num_resnet)\n", 129 | "G_A.cuda()\n", 130 | "G_B.cuda()\n", 131 | "G_A.load_state_dict(torch.load(model_dir + 'generator_A_param.pkl'))\n", 132 | "G_B.load_state_dict(torch.load(model_dir + 'generator_B_param.pkl'))" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "ExecuteTime": { 140 | "end_time": "2018-01-21T03:26:17.150207Z", 141 | "start_time": "2018-01-21T03:25:23.997942Z" 142 | } 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "# Test\n", 147 | "for i, real_A in enumerate(test_data_loader_A):\n", 148 | "\n", 149 | " # input image data\n", 150 | " real_A = Variable(real_A.cuda())\n", 151 | "\n", 152 | " # A -> B -> A\n", 153 | " fake_B = G_A(real_A)\n", 154 | " recon_A = G_B(fake_B)\n", 155 | "\n", 156 | " # Show result for test data\n", 157 | " utils.plot_test_result(real_A, fake_B, recon_A, i, save=True, save_dir=save_dir + 'AtoB/')\n", 158 | "\n", 159 | " print('%d images are generated.' % (i + 1))\n", 160 | "\n", 161 | "for i, real_B in enumerate(test_data_loader_B):\n", 162 | "\n", 163 | " # input image data\n", 164 | " real_B = Variable(real_B.cuda())\n", 165 | "\n", 166 | " # B -> A -> B\n", 167 | " fake_A = G_B(real_B)\n", 168 | " recon_B = G_A(fake_A)\n", 169 | "\n", 170 | " # Show result for test data\n", 171 | " utils.plot_test_result(real_B, fake_A, recon_B, i, save=True, save_dir=save_dir + 'BtoA/')\n", 172 | "\n", 173 | " print('%d images are generated.' 
% (i + 1))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "Python 3", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.6.8" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 2 205 | } 206 | -------------------------------------------------------------------------------- /Day06/CycleGAN/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### List of Datasets 3 | - ae_photos 4 | - apple2orange 5 | - cezanne2photo 6 | - cityscapes 7 | - facades 8 | - horse2zebra 9 | - iphone2dslr_flower 10 | - maps 11 | - monet2photo 12 | - summer2winter_yosemi 13 | - ukiyoe2photo 14 | - vangogh2photo 15 | -------------------------------------------------------------------------------- /Day06/CycleGAN/complements/concept.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CycleGAN/complements/concept.jpg -------------------------------------------------------------------------------- /Day06/CycleGAN/complements/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CycleGAN/complements/cover.jpg -------------------------------------------------------------------------------- /Day06/CycleGAN/complements/cycle_consistency.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CycleGAN/complements/cycle_consistency.JPG -------------------------------------------------------------------------------- /Day06/CycleGAN/complements/full_objectives.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CycleGAN/complements/full_objectives.JPG -------------------------------------------------------------------------------- /Day06/CycleGAN/complements/gan_loss.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/insujeon/Hello-Generative-Model/77a6740e34a02ab0515fb218c0167e326a4f7cd4/Day06/CycleGAN/complements/gan_loss.JPG -------------------------------------------------------------------------------- /Day06/CycleGAN/dataset.py: -------------------------------------------------------------------------------- 1 | # Custom dataset 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import os 5 | import random 6 | 7 | 8 | class DatasetFromFolder(data.Dataset): 9 | def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False): 10 | super(DatasetFromFolder, self).__init__() 11 | self.input_path = os.path.join(image_dir, subfolder) 12 | self.image_filenames = [x for x in 
sorted(os.listdir(self.input_path))] 13 | self.transform = transform 14 | 15 | self.resize_scale = resize_scale 16 | self.crop_size = crop_size 17 | self.fliplr = fliplr 18 | 19 | def __getitem__(self, index): 20 | # Load Image 21 | img_fn = os.path.join(self.input_path, self.image_filenames[index]) 22 | img = Image.open(img_fn).convert('RGB') 23 | 24 | # preprocessing 25 | if self.resize_scale: 26 | img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR) 27 | 28 | if self.crop_size: 29 | x = random.randint(0, self.resize_scale - self.crop_size + 1) 30 | y = random.randint(0, self.resize_scale - self.crop_size + 1) 31 | img = img.crop((x, y, x + self.crop_size, y + self.crop_size)) 32 | if self.fliplr: 33 | if random.random() < 0.5: 34 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 35 | 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | return img 40 | 41 | def __len__(self): 42 | return len(self.image_filenames) 43 | -------------------------------------------------------------------------------- /Day06/CycleGAN/logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | 6 | try: 7 | from StringIO import StringIO # Python 2.7 8 | except ImportError: 9 | from io import BytesIO # Python 3.x 10 | 11 | 12 | class Logger(object): 13 | def __init__(self, log_dir): 14 | """Create a summary writer logging to log_dir.""" 15 | self.writer = tf.summary.FileWriter(log_dir) 16 | 17 | def scalar_summary(self, tag, value, step): 18 | """Log a scalar variable.""" 19 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 20 | self.writer.add_summary(summary, step) 21 | self.writer.flush() 22 | 23 | def image_summary(self, tag, images, step): 24 | """Log a list of images.""" 25 | 26 | img_summaries = [] 27 | for i, img in enumerate(images): 28 | # Write the image to a string 29 | try: 30 | s = StringIO() 31 | except: 32 | s = BytesIO() 33 | scipy.misc.toimage(img).save(s, format="png") 34 | 35 | # Create an Image object 36 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 37 | height=img.shape[0], 38 | width=img.shape[1]) 39 | # Create a Summary value 40 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 41 | 42 | # Create and write Summary 43 | summary = tf.Summary(value=img_summaries) 44 | self.writer.add_summary(summary, step) 45 | self.writer.flush() 46 | 47 | def histo_summary(self, tag, values, step, bins=1000): 48 | """Log a histogram of the tensor of values.""" 49 | 50 | # Create a histogram using numpy 51 | counts, bin_edges = np.histogram(values, bins=bins) 52 | 53 | # Fill the fields of the histogram proto 54 | hist = tf.HistogramProto() 55 | hist.min = float(np.min(values)) 56 | hist.max = float(np.max(values)) 57 | hist.num = int(np.prod(values.shape)) 58 | hist.sum = float(np.sum(values)) 59 | hist.sum_squares = float(np.sum(values ** 2)) 60 | 61 | # Drop the start of the first bin 62 | bin_edges = bin_edges[1:] 63 | 64 | # Add bin edges and counts 65 | for edge in bin_edges: 66 | hist.bucket_limit.append(edge) 67 | for c in counts: 68 | hist.bucket.append(c) 69 | 70 | # Create and write Summary 71 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 72 | self.writer.add_summary(summary, step) 73 | self.writer.flush() 
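
A minimal usage sketch for the Logger above (a TensorBoard-style writer built on the TensorFlow 1.x summary API). It assumes TensorFlow 1.x and an older SciPy (one that still ships scipy.misc.toimage) are installed; the log directory and tag names below are illustrative, not taken from the training notebooks.

# Hypothetical example; all values and tags are made up.
import numpy as np
from logger import Logger

logger = Logger('./logs/cyclegan')                        # wraps tf.summary.FileWriter('./logs/cyclegan')
for step in range(3):
    logger.scalar_summary('loss/G_A', 1.0 / (step + 1), step)           # one scalar per training step
    dummy_images = np.random.rand(2, 64, 64, 3)                         # two H x W x C images in [0, 1]
    logger.image_summary('fake_B', dummy_images, step)                  # logged as image summaries
    logger.histo_summary('G_A/weights', np.random.randn(1000), step)    # histogram of 1000 values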
-------------------------------------------------------------------------------- /Day06/CycleGAN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ConvBlock(torch.nn.Module): 4 | def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, activation='relu', batch_norm=True): 5 | super(ConvBlock, self).__init__() 6 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding) 7 | self.batch_norm = batch_norm 8 | self.bn = torch.nn.InstanceNorm2d(output_size) 9 | self.activation = activation 10 | self.relu = torch.nn.ReLU(True) 11 | self.lrelu = torch.nn.LeakyReLU(0.2, True) 12 | self.tanh = torch.nn.Tanh() 13 | 14 | def forward(self, x): 15 | if self.batch_norm: 16 | out = self.bn(self.conv(x)) 17 | else: 18 | out = self.conv(x) 19 | 20 | if self.activation == 'relu': 21 | return self.relu(out) 22 | elif self.activation == 'lrelu': 23 | return self.lrelu(out) 24 | elif self.activation == 'tanh': 25 | return self.tanh(out) 26 | elif self.activation == 'no_act': 27 | return out 28 | 29 | 30 | class DeconvBlock(torch.nn.Module): 31 | def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, output_padding=1, activation='relu', batch_norm=True): 32 | super(DeconvBlock, self).__init__() 33 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, output_padding) 34 | self.batch_norm = batch_norm 35 | self.bn = torch.nn.InstanceNorm2d(output_size) 36 | self.activation = activation 37 | self.relu = torch.nn.ReLU(True) 38 | 39 | def forward(self, x): 40 | if self.batch_norm: 41 | out = self.bn(self.deconv(x)) 42 | else: 43 | out = self.deconv(x) 44 | 45 | if self.activation == 'relu': 46 | return self.relu(out) 47 | elif self.activation == 'lrelu': 48 | return self.lrelu(out) 49 | elif self.activation == 'tanh': 50 | return self.tanh(out) 51 | elif self.activation == 'no_act': 52 | return out 53 | 54 | 55 | class ResnetBlock(torch.nn.Module): 56 | def __init__(self, num_filter, kernel_size=3, stride=1, padding=0): 57 | super(ResnetBlock, self).__init__() 58 | conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding) 59 | conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding) 60 | bn = torch.nn.InstanceNorm2d(num_filter) 61 | relu = torch.nn.ReLU(True) 62 | pad = torch.nn.ReflectionPad2d(1) 63 | 64 | self.resnet_block = torch.nn.Sequential( 65 | pad, 66 | conv1, 67 | bn, 68 | relu, 69 | pad, 70 | conv2, 71 | bn 72 | ) 73 | 74 | def forward(self, x): 75 | out = self.resnet_block(x) 76 | return out 77 | 78 | 79 | class Generator(torch.nn.Module): 80 | def __init__(self, input_dim, num_filter, output_dim, num_resnet): 81 | super(Generator, self).__init__() 82 | 83 | # Reflection padding 84 | self.pad = torch.nn.ReflectionPad2d(3) 85 | # Encoder 86 | self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0) 87 | self.conv2 = ConvBlock(num_filter, num_filter * 2) 88 | self.conv3 = ConvBlock(num_filter * 2, num_filter * 4) 89 | # Resnet blocks 90 | self.resnet_blocks = [] 91 | for i in range(num_resnet): 92 | self.resnet_blocks.append(ResnetBlock(num_filter * 4)) 93 | self.resnet_blocks = torch.nn.Sequential(*self.resnet_blocks) 94 | # Decoder 95 | self.deconv1 = DeconvBlock(num_filter * 4, num_filter * 2) 96 | self.deconv2 = DeconvBlock(num_filter * 2, num_filter) 97 | self.deconv3 = ConvBlock(num_filter, output_dim, kernel_size=7, stride=1, padding=0, 
activation='tanh', batch_norm=False) 98 | 99 | def forward(self, x): 100 | # Encoder 101 | enc1 = self.conv1(self.pad(x)) 102 | enc2 = self.conv2(enc1) 103 | enc3 = self.conv3(enc2) 104 | # Resnet blocks 105 | res = self.resnet_blocks(enc3) 106 | # Decoder 107 | dec1 = self.deconv1(res) 108 | dec2 = self.deconv2(dec1) 109 | out = self.deconv3(self.pad(dec2)) 110 | return out 111 | 112 | def normal_weight_init(self, mean=0.0, std=0.02): 113 | for m in self.children(): 114 | if isinstance(m, ConvBlock): 115 | torch.nn.init.normal_(m.conv.weight, mean, std) 116 | if isinstance(m, DeconvBlock): 117 | torch.nn.init.normal_(m.deconv.weight, mean, std) 118 | if isinstance(m, ResnetBlock): 119 | torch.nn.init.normal_(m.conv.weight, mean, std) 120 | torch.nn.init.constant(m.conv.bias, 0) 121 | 122 | 123 | class Discriminator(torch.nn.Module): 124 | def __init__(self, input_dim, num_filter, output_dim): 125 | super(Discriminator, self).__init__() 126 | 127 | conv1 = ConvBlock(input_dim, num_filter, kernel_size=4, stride=2, padding=1, activation='lrelu', batch_norm=False) 128 | conv2 = ConvBlock(num_filter, num_filter * 2, kernel_size=4, stride=2, padding=1, activation='lrelu') 129 | conv3 = ConvBlock(num_filter * 2, num_filter * 4, kernel_size=4, stride=2, padding=1, activation='lrelu') 130 | conv4 = ConvBlock(num_filter * 4, num_filter * 8, kernel_size=4, stride=1, padding=1, activation='lrelu') 131 | conv5 = ConvBlock(num_filter * 8, output_dim, kernel_size=4, stride=1, padding=1, activation='no_act', batch_norm=False) 132 | 133 | self.conv_blocks = torch.nn.Sequential( 134 | conv1, 135 | conv2, 136 | conv3, 137 | conv4, 138 | conv5 139 | ) 140 | 141 | def forward(self, x): 142 | out = self.conv_blocks(x) 143 | return out 144 | 145 | def normal_weight_init(self, mean=0.0, std=0.02): 146 | for m in self.children(): 147 | if isinstance(m, ConvBlock): 148 | torch.nn.init.normal_(m.conv.weight, mean, std) 149 | -------------------------------------------------------------------------------- /Day06/CycleGAN/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import os 6 | import imageio 7 | import random 8 | 9 | 10 | # For logger 11 | def to_np(x): 12 | return x.data.cpu().numpy() 13 | 14 | 15 | def to_var(x): 16 | if torch.cuda.is_available(): 17 | x = x.cuda() 18 | return Variable(x) 19 | 20 | 21 | # De-normalization 22 | def denorm(x): 23 | out = (x + 1) / 2 24 | return out.clamp(0, 1) 25 | 26 | 27 | # Plot losses 28 | def plot_loss(avg_losses, num_epochs, save=False, save_dir='results/', show=False): 29 | fig, ax = plt.subplots() 30 | ax.set_xlim(0, num_epochs) 31 | temp = 0.0 32 | for i in range(len(avg_losses)): 33 | temp = max(np.max(avg_losses[i]), temp) 34 | ax.set_ylim(0, temp*1.1) 35 | plt.xlabel('# of Epochs') 36 | plt.ylabel('Loss values') 37 | 38 | plt.plot(avg_losses[0], label='D_A') 39 | plt.plot(avg_losses[1], label='D_B') 40 | plt.plot(avg_losses[2], label='G_A') 41 | plt.plot(avg_losses[3], label='G_B') 42 | plt.plot(avg_losses[4], label='cycle_A') 43 | plt.plot(avg_losses[5], label='cycle_B') 44 | plt.legend() 45 | 46 | # save figure 47 | if save: 48 | if not os.path.exists(save_dir): 49 | os.mkdir(save_dir) 50 | save_fn = save_dir + 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png' 51 | plt.savefig(save_fn) 52 | 53 | if show: 54 | plt.show() 55 | else: 56 | plt.close() 57 | 58 | 59 | def plot_train_result(real_image, 
gen_image, recon_image, epoch, save=False, save_dir='results/', show=False, fig_size=(5, 5)): 60 | fig, axes = plt.subplots(2, 3, figsize=fig_size) 61 | 62 | imgs = [to_np(real_image[0]), to_np(gen_image[0]), to_np(recon_image[0]), 63 | to_np(real_image[1]), to_np(gen_image[1]), to_np(recon_image[1])] 64 | for ax, img in zip(axes.flatten(), imgs): 65 | ax.axis('off') 66 | ax.set_adjustable('box-forced') 67 | # Scale to 0-255 68 | img = img.squeeze() 69 | img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8) 70 | ax.imshow(img, cmap=None, aspect='equal') 71 | plt.subplots_adjust(wspace=0, hspace=0) 72 | 73 | title = 'Epoch {0}'.format(epoch + 1) 74 | fig.text(0.5, 0.04, title, ha='center') 75 | 76 | # save figure 77 | if save: 78 | if not os.path.exists(save_dir): 79 | os.mkdir(save_dir) 80 | 81 | save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch+1) + '.png' 82 | plt.savefig(save_fn) 83 | 84 | if show: 85 | plt.show() 86 | else: 87 | plt.close() 88 | 89 | 90 | def plot_test_result(real_image, gen_image, recon_image, index, save=False, save_dir='results/', show=False): 91 | fig_size = (real_image.size(2) * 3 / 100, real_image.size(3) / 100) 92 | fig, axes = plt.subplots(1, 3, figsize=fig_size) 93 | 94 | imgs = [to_np(real_image), to_np(gen_image), to_np(recon_image)] 95 | for ax, img in zip(axes.flatten(), imgs): 96 | ax.axis('off') 97 | ax.set_adjustable('box-forced') 98 | # Scale to 0-255 99 | img = img.squeeze() 100 | img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8) 101 | ax.imshow(img, cmap=None, aspect='equal') 102 | plt.subplots_adjust(wspace=0, hspace=0) 103 | 104 | # save figure 105 | if save: 106 | if not os.path.exists(save_dir): 107 | os.mkdir(save_dir) 108 | 109 | save_fn = save_dir + 'Test_result_{:d}'.format(index + 1) + '.png' 110 | fig.subplots_adjust(bottom=0) 111 | fig.subplots_adjust(top=1) 112 | fig.subplots_adjust(right=1) 113 | fig.subplots_adjust(left=0) 114 | plt.savefig(save_fn) 115 | 116 | if show: 117 | plt.show() 118 | else: 119 | plt.close() 120 | 121 | 122 | # Make gif 123 | def make_gif(dataset, num_epochs, save_dir='results/'): 124 | gen_image_plots = [] 125 | for epoch in range(num_epochs): 126 | # plot for generating gif 127 | save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png' 128 | gen_image_plots.append(imageio.imread(save_fn)) 129 | 130 | imageio.mimsave(save_dir + dataset + '_CycleGAN_epochs_{:d}'.format(num_epochs) + '.gif', gen_image_plots, fps=5) 131 | 132 | 133 | class ImagePool(): 134 | def __init__(self, pool_size): 135 | self.pool_size = pool_size 136 | if self.pool_size > 0: 137 | self.num_imgs = 0 138 | self.images = [] 139 | 140 | def query(self, images): 141 | if self.pool_size == 0: 142 | return images 143 | return_images = [] 144 | for image in images.data: 145 | image = torch.unsqueeze(image, 0) 146 | if self.num_imgs < self.pool_size: 147 | self.num_imgs = self.num_imgs + 1 148 | self.images.append(image) 149 | return_images.append(image) 150 | else: 151 | p = random.uniform(0, 1) 152 | if p > 0.5: 153 | random_id = random.randint(0, self.pool_size-1) 154 | tmp = self.images[random_id].clone() 155 | self.images[random_id] = image 156 | return_images.append(tmp) 157 | else: 158 | return_images.append(image) 159 | return_images = Variable(torch.cat(return_images, 0)) 160 | return return_images 161 | -------------------------------------------------------------------------------- 
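
The ImagePool class at the end of utils.py above implements the generated-image history buffer used when training the CycleGAN discriminators: up to pool_size fakes are stored, and each query returns, per image and with probability 0.5, a previously stored fake in place of the newest one, so the discriminator also sees older generator outputs. A rough sketch of how it is typically wired into a discriminator update follows; the generator, discriminator, data and criterion names are assumptions for illustration, not code taken from the training notebook.

# Hypothetical discriminator-B update using the image history buffer.
import torch

fake_B_pool = ImagePool(pool_size=50)          # buffer of up to 50 generated images

fake_B = G_A(real_A)                           # newest A -> B translation (G_A, real_A assumed defined)
fake_B_hist = fake_B_pool.query(fake_B)        # may substitute older stored fakes

pred_fake = D_B(fake_B_hist.detach())          # no gradient into G_A through the D update
loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
pred_real = D_B(real_B)
loss_real = criterion(pred_real, torch.ones_like(pred_real))
D_B_loss = 0.5 * (loss_real + loss_fake)       # average of real and fake terms, as in the CycleGAN objective
D_B_loss.backward()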
/Day07/pixelCNN/.ipynb_checkpoints/IconPixelCNN-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 16, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "from matplotlib import pyplot as plt\n", 13 | "\n", 14 | "def show_as_image(binary_image, figsize=(10, 5)):\n", 15 | " plt.figure(figsize=figsize)\n", 16 | " plt.imshow(binary_image, cmap='gray')\n", 17 | " plt.xticks([]); plt.yticks([])\n", 18 | "\n", 19 | "%matplotlib inline" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 17, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "mkdir: cannot create directory ‘./data’: File exists\n", 32 | "mkdir: cannot create directory ‘./data/all’: File exists\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "# ! mkdir ./data\n", 38 | "# ! mkdir ./data/all\n", 39 | "# git clone git@github.com:encharm/Font-Awesome-SVG-PNG.git\n", 40 | "# ! cp ./Font-Awesome-SVG-PNG/black/png/16/* ./data/all" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 18, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "data": { 50 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA5IAAAOSCAYAAAAGce8SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3d2xNDeQINamgibIk5URclk+jIzQ+/pw9cDYmNlmcYjETQCZqHPeyPhuNwp/1RmozPrj5+fnAwAAAKP+j9MNAAAAoBeBJAAAACECSQAAAEIEkgAAAIQIJAEAAAgRSAIAABAikAQAACBEIAkAAECIQBIAAICQPyP/+I8//vhZ1RAAAACO+39+fn7+73/7R04kAQAA+F/+z5F/JJAEAAAgRCAJAABAiEASAACAkFCxHd7p5+ffayz98ccfG1oC+zzN+5XzfPf3wW1G7lWfj3XM7xhjfuO239ROJAEAAAgRSAIAABAikAQAACBEjiT/m9ue3V6lQi7OSqPXN6JrHzzJzI3J7OOdurb789k7F/XTerN9/P13XdawObVf5z5f2fau47lb1h71+dTtcyeSAAAAhAgkAQAACBFIAgAAECKQBAAAIESxnYOyEv4r6JQYvFPFMa5YIKJCv3x7atPTtWT1Z8U+6Fxo4snK9XhTX1Xcz1e2afSzd+9lN8+p0/MJyOFEEgAAgBCBJAAAACECSQAAAEIEkgAAAIQotnNQl2TzkXZWvJaKhQp2F7GY7YPdBSO6FPd4MtvO78+v2Adwys7iLBUK68BvVCyil6nrWnvD3uJEEgAAgBCBJAAAACECSQAAAEKO5Uhm5W6tfE676jPKWc9XV8whHNW57d9O58at/q6da3a1zL66qV/YLys3l/X50Kdl5mndNH8yr6VCPQI4wYkkAAAAIQJJAAAAQgSSAAAAhAgkAQAACNlSbKdrcnZmEZTTBVVOyHqhdNf5c8JMX43Ozd0v0d35QvIT3/dtpM9v3zNud3sRtC5ueyE4f3e6oOPoPVSRnjG7f0N33SNOFA5zIgkAAECIQBIAAIAQgSQAAAAhAkkAAABC0ovtVC0UcFORg9lk8A7Jw13GINPpPn/SYa5QR8W50HUvqdiXXShc8my2UEmXNVThfjXy+bNF7Ha7fT18mx2DlWO3el5kFvJzIgkAAECIQBIAAIAQgSQAAAAh6TmSo7Je+soz/Xn+ZbVPOucezLbdXIS/7H6p9u1W7sPGqqYTL1xf6XQ7b8qVfXK6f9/AiSQAAAAhAkkAAABCBJIAAACECCQBAAAI+XWxnYov8qzwfazVJYG6Szuz3LTOdhfb6Nx3ndt+sy7jMrLWdr90fuTl8Z339y5zgzwjc7qz2fU42y8rCxCuLm6YuXc5kQQAACBEIAkAAECIQBIAAIAQgSQAAAAhoWI7/+N//I/Pf/zHf4S/ZCSps2LCb5dE+t2JwtTtu5kCFVXt7uORvssswGON8vnkFXDpvNafWB8wzlrI1bU/T7TbiSQAAAAhAkkAAABCBJIAAACE/BHJq/jjjz/uSsK4WIX8kpncnxMvU71ZZt7U7lznimO8Mg/tbddbQWaf66sxs/3UZX2YU2N291OnF7yvctt86tDnTyr8Ph/0//78/Pxf//aPnEgCAAAQIpAEAAAgRCAJAABAiEASAACAkD9PN4B9iiTv/ree2nhbgvhOmf35tkIzT2YLDnW5Psi2s7iOe8U7jdzn3jg3brrmrvfQzN9bVfvAiSQAAAAhAkkAAABCBJIAAACEyJEsZPSZ6Kzn3m96fp5xs3l+K7+vs5G+6pTv8F/dlrO8ss+/P1s/rf++rD6uuhbNqbXftzN/94Qu7aQ3J5IAAACECCQBAAAIEUgCAAAQIpAEAAAgRLGdQjITo0c+q0sidpd23kSfjxspiHFTf950LSvpp2cr73OjxVO6jk3XdlelP1lttmBdp7npRBIAAIAQgSQAAAAhAkkAAABCBJIAAACEKLbTUKckXHgb6xP+0871YO0B1d22TzmRBAAAIEQgCQAAQIhAEgAAgBCBJAAAACECSQAAAEIEkgAAAIQIJAEAAAgRSAIAABDy5+kG8Hs/Pz9Tf3fbS1GBuuxTfD7z86Aic5NOblp7T6zHM5xIAgAAECKQBAAAIEQgCQAAQIhAEgAAgJDSxXYyE4NvSsL97pfZa3vq39P9NDrmI+0c+azM6939faylOMy8zL2
[... base64-encoded PNG output truncated; the cell output is a grid of 144 randomly sampled 16x16 binary icon images rendered by show_as_image ...]\n", 51 | "text/plain": [ 52 | "
" 53 | ] 54 | }, 55 | "metadata": {}, 56 | "output_type": "display_data" 57 | } 58 | ], 59 | "source": [ 60 | "from torchvision import datasets\n", 61 | "from torch.utils.data import Dataset, DataLoader\n", 62 | "from torch.utils.data.sampler import RandomSampler, SubsetRandomSampler\n", 63 | "from torchvision import transforms\n", 64 | "from PIL import Image\n", 65 | "\n", 66 | "la_loader = lambda _: Image.open(_).convert('LA')\n", 67 | "\n", 68 | "image_dataset = datasets.folder.ImageFolder(\n", 69 | " './data', \n", 70 | " transform=transforms.Lambda(lambda image: transforms.ToTensor()(image.crop(box=(0, 0, 16, 16)))[1:, :, :] > .5),\n", 71 | " loader=la_loader)\n", 72 | "\n", 73 | "batch_images, labels = next(b for b in DataLoader(image_dataset, sampler=RandomSampler(image_dataset), batch_size=144))\n", 74 | "\n", 75 | "def batch_images_to_one(batches_images):\n", 76 | " n_square_elements = int(np.sqrt(batches_images.shape[0]))\n", 77 | " rows_images = np.split(np.squeeze(batches_images), n_square_elements)\n", 78 | " return np.vstack([np.hstack(row_images) for row_images in rows_images])\n", 79 | "\n", 80 | "show_as_image(batch_images_to_one(batch_images.numpy()), figsize=(16, 17))\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 20, 86 | "metadata": { 87 | "collapsed": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "import torch\n", 92 | "import torch.nn as nn\n", 93 | "\n", 94 | "class MaskedConv2d(nn.Conv2d):\n", 95 | " def __init__(self, mask_type, *args, **kwargs):\n", 96 | " super(MaskedConv2d, self).__init__(*args, **kwargs)\n", 97 | " assert mask_type in {'A', 'B'}\n", 98 | " self.register_buffer('mask', self.weight.data.clone())\n", 99 | " _, _, kH, kW = self.weight.size()\n", 100 | " self.mask.fill_(1)\n", 101 | " self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0\n", 102 | " self.mask[:, :, kH // 2 + 1:] = 0\n", 103 | "\n", 104 | " def forward(self, x):\n", 105 | " self.weight.data *= self.mask\n", 106 | " return super(MaskedConv2d, self).forward(x)\n", 107 | " \n", 108 | "class PixelCNN(nn.Module):\n", 109 | " n_channels = 16\n", 110 | " kernel_size = 7\n", 111 | " padding = 3\n", 112 | " n_pixels_out = 2 # binary 0/1 pixels\n", 113 | "\n", 114 | " def __init__(self):\n", 115 | " super(PixelCNN, self).__init__()\n", 116 | " self.layers = nn.Sequential(\n", 117 | " MaskedConv2d('A', in_channels=1, out_channels=self.n_channels, kernel_size=self.kernel_size, padding=self.padding, bias=False), nn.BatchNorm2d(self.n_channels), nn.ReLU(True),\n", 118 | " MaskedConv2d('B', self.n_channels, self.n_channels, kernel_size=self.kernel_size, padding=self.padding, bias=False), nn.BatchNorm2d(self.n_channels), nn.ReLU(True),\n", 119 | " MaskedConv2d('B', self.n_channels, self.n_channels, kernel_size=self.kernel_size, padding=self.padding, bias=False), nn.BatchNorm2d(self.n_channels), nn.ReLU(True),\n", 120 | " MaskedConv2d('B', self.n_channels, self.n_channels, kernel_size=self.kernel_size, padding=self.padding, bias=False), nn.BatchNorm2d(self.n_channels), nn.ReLU(True),\n", 121 | " MaskedConv2d('B', self.n_channels, self.n_channels, kernel_size=self.kernel_size, padding=self.padding, bias=False), nn.BatchNorm2d(self.n_channels), nn.ReLU(True),\n", 122 | " nn.Conv2d(in_channels=self.n_channels, out_channels=self.n_pixels_out, kernel_size=1)\n", 123 | " )\n", 124 | "\n", 125 | " def forward(self, x):\n", 126 | " pixel_logits = self.layers(x)\n", 127 | " return pixel_logits" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | 
"execution_count": null, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "Epoch [1/80], Train loss: 0.5941, Test loss: 0.5875\n", 140 | "Epoch [2/80], Train loss: 0.5597, Test loss: 0.5559\n", 141 | "Epoch [3/80], Train loss: 0.5411, Test loss: 0.5302\n", 142 | "Epoch [4/80], Train loss: 0.5127, Test loss: 0.5079\n", 143 | "Epoch [5/80], Train loss: 0.4920, Test loss: 0.4839\n", 144 | "Epoch [6/80], Train loss: 0.4706, Test loss: 0.4740\n", 145 | "Epoch [7/80], Train loss: 0.4574, Test loss: 0.4425\n", 146 | "Epoch [8/80], Train loss: 0.4457, Test loss: 0.4219\n", 147 | "Epoch [9/80], Train loss: 0.4023, Test loss: 0.3869\n", 148 | "Epoch [10/80], Train loss: 0.3768, Test loss: 0.3583\n", 149 | "Epoch [11/80], Train loss: 0.3607, Test loss: 0.3310\n", 150 | "Epoch [12/80], Train loss: 0.3307, Test loss: 0.3168\n", 151 | "Epoch [13/80], Train loss: 0.3174, Test loss: 0.2945\n", 152 | "Epoch [14/80], Train loss: 0.2911, Test loss: 0.2847\n", 153 | "Epoch [15/80], Train loss: 0.2865, Test loss: 0.2738\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "import torch.nn.functional as F\n", 159 | "from torch.utils.data.sampler import SubsetRandomSampler\n", 160 | "\n", 161 | "\n", 162 | "N_EPOCHS = 80\n", 163 | "BATCH_SIZE = 128\n", 164 | "LR = 0.005\n", 165 | "TEST_RATIO = .2\n", 166 | "\n", 167 | "cnn = PixelCNN()\n", 168 | "optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)\n", 169 | "\n", 170 | "test_indices = np.random.choice(len(image_dataset), size=int(len(image_dataset) * TEST_RATIO), replace=False)\n", 171 | "train_indices = np.setdiff1d(np.arange(len(image_dataset)), test_indices)\n", 172 | "train_loader = DataLoader(image_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(indices=train_indices))\n", 173 | "test_loader = DataLoader(image_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(indices=test_indices))\n", 174 | "\n", 175 | "for epoch in range(N_EPOCHS):\n", 176 | " for i, (images, _) in enumerate(train_loader):\n", 177 | " images = images.float()\n", 178 | " optimizer.zero_grad()\n", 179 | " loss = F.cross_entropy(input=cnn(images), target=torch.squeeze(images).long())\n", 180 | " loss.backward()\n", 181 | " optimizer.step()\n", 182 | " \n", 183 | " test_images = next(i for i, _ in test_loader).float()\n", 184 | " test_loss =F .cross_entropy(input=cnn(test_images), target=torch.squeeze(test_images).long())\n", 185 | "\n", 186 | " print ('Epoch [%d/%d], Train loss: %.4f, Test loss: %.4f'\n", 187 | " %(epoch+1, N_EPOCHS, loss.item(), test_loss.item()))" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Completing images" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "# pixelCNN로 새로운 샘플 만들기\n", 204 | "# 어떤 특정한 이미지를 pixelCNN의 input으로 줄 수도 있고 아니면 그냥 zeros를 넣어 줄 수도 있다.\n", 205 | "# pixelCNN은 input으로 주어진 이미지를 Masked Convolution으로 흝고 지나가면서 자신의 기억속에 존재하는, 가장 그럴듯한 픽셀 값을 예측해 낸다.\n", 206 | "def generate_samples(n_samples, starting_point=(0, 0), starting_image=None):\n", 207 | " samples = torch.from_numpy(\n", 208 | " starting_image if starting_image is not None else np.zeros((n_samples * n_samples, 1, IMAGE_WIDTH, IMAGE_HEIGHT))).float()\n", 209 | "\n", 210 | " cnn.train(False)\n", 211 | "\n", 212 | " # pixelCNN이 예측한 pixel-level distribution으로 부터 픽셀 값(0 또는 1)을 샘플링해서 이미지를 만들어 낸다.\n", 213 | " for i in range(IMAGE_WIDTH):\n", 214 | " 
for j in range(IMAGE_HEIGHT):\n", 215 | " if i < starting_point[0] or (i == starting_point[0] and j < starting_point[1]):\n", 216 | " continue\n", 217 | " out = cnn(samples)\n", 218 | " probs = F.softmax(out[:, :, i, j],1).data\n", 219 | " samples[:, :, i, j] = torch.multinomial(probs, 1).float()\n", 220 | " return samples.numpy()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "n_images = 12\n", 230 | "batch_images, labels = next(b for b in train_loader)\n", 231 | "batch_images = batch_images[:n_images, :, :, :]\n", 232 | "IMAGE_WIDTH=16\n", 233 | "IMAGE_HEIGHT=16\n", 234 | "starting_point = (8, 8)\n", 235 | "\n", 236 | "row_grid, col_grid = np.meshgrid(np.arange(16), np.arange(16), indexing='ij')\n", 237 | "mask = np.logical_or(row_grid < starting_point[0], np.logical_and(row_grid == starting_point[0], col_grid <= starting_point[1]))\n", 238 | "\n", 239 | "starting_images = batch_images.numpy().squeeze()\n", 240 | "batch_starting_images = np.expand_dims(np.stack([i * mask for i in starting_images] * n_images), axis=1)\n", 241 | "\n", 242 | "samples = generate_samples(10, starting_image=batch_starting_images, starting_point=starting_point)\n", 243 | "\n", 244 | "show_as_image(np.hstack([(1 + mask) * i for i in starting_images]), figsize=(16, 17))\n", 245 | "\n", 246 | "show_as_image(\n", 247 | " batch_images_to_one((samples * (1 + mask))),\n", 248 | " figsize=(16, 17))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "n_images = 12\n", 258 | "batch_images, labels = next(b for b in test_loader)\n", 259 | "batch_images = batch_images[:n_images, :, :, :]\n", 260 | "\n", 261 | "starting_point = (8, 8)\n", 262 | "\n", 263 | "row_grid, col_grid = np.meshgrid(np.arange(IMAGE_WIDTH), np.arange(IMAGE_HEIGHT), indexing='ij')\n", 264 | "mask = np.logical_or(row_grid < starting_point[0], np.logical_and(row_grid == starting_point[0], col_grid <= starting_point[1]))\n", 265 | "\n", 266 | "starting_images = batch_images.numpy().squeeze()\n", 267 | "batch_starting_images = np.expand_dims(np.stack([i * mask for i in starting_images] * n_images), axis=1)\n", 268 | "\n", 269 | "samples = generate_samples(10,starting_image=batch_starting_images, starting_point=starting_point)\n", 270 | "\n", 271 | "show_as_image(np.hstack([(1 + mask) * i for i in starting_images]), figsize=(16, 17))\n", 272 | "\n", 273 | "show_as_image(\n", 274 | " batch_images_to_one((samples * (1 + mask))),\n", 275 | " figsize=(16, 17))" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "## Generating new samples" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "IMAGE_WIDTH, IMAGE_HEIGHT = 16, 16\n", 292 | "\n", 293 | "def generate_samples(n_samples=12, starting_point=(0, 0), starting_image=None):\n", 294 | "\n", 295 | " samples = torch.from_numpy(\n", 296 | " starting_image if starting_image is not None else np.zeros((n_samples * n_samples, 1, IMAGE_WIDTH, IMAGE_HEIGHT))).float()\n", 297 | "\n", 298 | " cnn.train(False)\n", 299 | "\n", 300 | " for i in range(IMAGE_WIDTH):\n", 301 | " for j in range(IMAGE_HEIGHT):\n", 302 | " if i < starting_point[0] or (i == starting_point[0] and j < starting_point[1]):\n", 303 | " continue\n", 304 | " out = cnn(samples)\n", 305 | " probs = 
F.softmax(out[:, :, i, j],1).data\n", 306 | " samples[:, :, i, j] = torch.multinomial(probs, 1).float()\n", 307 | " return samples.numpy()\n", 308 | "\n", 309 | "samples = generate_samples(12)\n", 310 | "\n", 311 | "show_as_image(\n", 312 | " batch_images_to_one(samples),\n", 313 | " figsize=(16, 17))\n" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [] 322 | } 323 | ], 324 | "metadata": { 325 | "kernelspec": { 326 | "display_name": "Python 3", 327 | "language": "python", 328 | "name": "python3" 329 | }, 330 | "language_info": { 331 | "codemirror_mode": { 332 | "name": "ipython", 333 | "version": 3 334 | }, 335 | "file_extension": ".py", 336 | "mimetype": "text/x-python", 337 | "name": "python", 338 | "nbconvert_exporter": "python", 339 | "pygments_lexer": "ipython3", 340 | "version": "3.6.8" 341 | } 342 | }, 343 | "nbformat": 4, 344 | "nbformat_minor": 2 345 | } 346 | -------------------------------------------------------------------------------- /Day07/pixelCNN/.ipynb_checkpoints/ToyPixelCNN-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "from matplotlib import pyplot as plt\n", 13 | "\n", 14 | "def show_as_image(binary_image, figsize=(10, 5)):\n", 15 | " plt.figure(figsize=figsize)\n", 16 | " plt.imshow(binary_image, cmap='gray')\n", 17 | " plt.xticks([]); plt.yticks([])\n", 18 | "\n", 19 | "%matplotlib inline" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "collapsed": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "from torchvision import datasets, utils" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Pixel CNN\n", 40 | "\n", 41 | "Alternative to Pixel RNN from [Pixel Recurrent Neural Networks](https://arxiv.org/pdf/1601.06759.pdf). \n", 42 | "\n", 43 | "On-line resources:\n", 44 | " * See for an existing PyTorch implementation https://github.com/jzbontar/pixelcnn-pytorch/blob/master/main.py\n", 45 | " * http://sergeiturukin.com/2017/02/22/pixelcnn.html for a nice walk-through\n", 46 | " * http://tinyclouds.org/residency/\n", 47 | " * https://tensorflow.blog/2016/11/29/pixelcnn-1601-06759-summary/ (in korean ;) ) \n", 48 | "\n", 49 | "The core ideas are the following:\n", 50 | "\n", 51 | "### Joint distribution of an image $\\mathbf{x}$ modelled as an autoregressive process\n", 52 | "\n", 53 | "Same model for PixelRNN and PixelCNN:\n", 54 | "\n", 55 | "$$p(\\mathbf{x}) = \\prod_{i=1}^{n^2} p(x_i|x_{1}, \\dots, x_{i-1})$$\n", 56 | " \n", 57 | "![](http://sergeiturukin.com/assets/2017-02-22-183010_479x494_scrot.png)\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## 위 사진처럼 생긴 Mask를 생성.\n", 65 | "이 Mask를 convolution filter적용하면 convolution filter를 causal하게 만든다." 
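To connect this factorization to the training code further below: because the masked convolutions make the two-way logits at pixel $i$ depend only on the pixels that precede $i$ in raster-scan order, a single forward pass evaluates every conditional at once, and minimizing the per-pixel cross-entropy is (up to the mean-over-pixels reduction) the same as maximizing the likelihood, since

$$-\log p(\mathbf{x}) = -\sum_{i=1}^{n^2} \log p(x_i|x_{1}, \dots, x_{i-1}),$$

which is exactly what F.cross_entropy(input=cnn(images), target=torch.squeeze(images).long()) computes in the training cells of these notebooks.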
66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Mask made up of True/False values\n", 75 | "def causal_mask(width, height, starting_point):\n", 76 | "    row_grid, col_grid = np.meshgrid(np.arange(width), np.arange(height), indexing='ij')\n", 77 | "# print(row_grid)\n", 78 | "# print()\n", 79 | "# print(col_grid)\n", 80 | "# print('Mask making')\n", 81 | "# print(row_grid\n", 206 | "Application on a simple generative model of LCD digits\n", 207 | "
\n", 208 | "From https://gist.github.com/benjaminwilson/b25a321f292f98d74269b83d4ed2b9a8#file-lcd-digits-dataset-nmf-ipynb" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "CELL_LENGTH = 4 # 숫자의 한 획의 길이\n", 218 | "IMAGE_WIDTH, IMAGE_HEIGHT = 2 * CELL_LENGTH + 5, CELL_LENGTH + 4 # 획의 길이에 따른 LCD 이미지 규격\n", 219 | "\n", 220 | "# 세로 획 그리기\n", 221 | "def vertical_stroke(rightness, downness):\n", 222 | " \"\"\"\n", 223 | " Return a 2d numpy array representing an image with a single vertical stroke in it.\n", 224 | " `rightness` and `downness` are values from [0, 1] and define the position of the vertical stroke.\n", 225 | " \"\"\"\n", 226 | " i = (downness * (CELL_LENGTH + 1)) + 2\n", 227 | " j = rightness * (CELL_LENGTH + 1) + 1\n", 228 | " x = np.zeros(shape=(IMAGE_WIDTH, IMAGE_HEIGHT), dtype=np.float64)\n", 229 | " x[i + np.arange(CELL_LENGTH), j] = 1.\n", 230 | " return x\n", 231 | "\n", 232 | "# 가로 획 그리기\n", 233 | "def horizontal_stroke(downness):\n", 234 | " \"\"\"\n", 235 | " Analogue to vertical_stroke, but it returns horizontal strokes.\n", 236 | " `downness` is here a value in [0, 1, 2].\n", 237 | " \"\"\"\n", 238 | " i = (downness * (CELL_LENGTH + 1)) + 1\n", 239 | " x = np.zeros(shape=(IMAGE_WIDTH, IMAGE_HEIGHT), dtype=np.float64)\n", 240 | " x[i, 2 + np.arange(CELL_LENGTH)] = 1.\n", 241 | " return x\n", 242 | "\n", 243 | "show_as_image(horizontal_stroke(0))\n", 244 | "# show_as_image(horizontal_stroke(1))\n", 245 | "# show_as_image(horizontal_stroke(2))\n", 246 | "# show_as_image(vertical_stroke(0,0))\n", 247 | "# show_as_image(vertical_stroke(0,1))\n", 248 | "# show_as_image(vertical_stroke(1,0))\n", 249 | "# show_as_image(vertical_stroke(1,1))" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "# 0 ~ 9 사이의 숫자를 만들 수 있는 기본 획(가로, 세로)들 총 집합.\n", 259 | "# 가로 획 3개 + 세로 획 4개 = 총 7개 획\n", 260 | "BASE_STROKES = np.asarray(\n", 261 | " [horizontal_stroke(k) for k in range(3)] + [vertical_stroke(k, l) for k in range(2) for l in range(2)])\n", 262 | "\n", 263 | "# 기본 획들 총 집합 중에서 각 숫자를 만들기 위해 실제로 필요한 획 구성 \n", 264 | "DIGITS_STROKES = np.array([[0, 2, 3, 4, 5, 6], [5, 6], [0, 1, 2, 4, 5], [0, 1, 2, 5, 6], [1, 3, 5, 6], [0, 1, 2, 3, 6], [0, 1, 2, 3, 4, 6], [0, 5, 6], np.arange(7), [0, 1, 2, 3, 5, 6]])\n", 265 | "\n", 266 | "# 저 기본 획들과 각 숫자를 만들기 위한 조합을 이용해서 랜덤 숫자 이미지를 만들어 낸다.\n", 267 | "def random_digits(strokes=BASE_STROKES, digit_as_strokes=DIGITS_STROKES, fixed_label=None):\n", 268 | " label = fixed_label if fixed_label is not None else np.random.choice(len(digit_as_strokes))\n", 269 | " combined_strokes = strokes[digit_as_strokes[label], :, :].sum(axis=0)\n", 270 | " return combined_strokes, label\n", 271 | "\n", 272 | "def batch_images_to_one(batches_images):\n", 273 | " n_square_elements = int(np.sqrt(batches_images.shape[0]))\n", 274 | " rows_images = np.split(np.squeeze(batches_images), n_square_elements)\n", 275 | " return np.vstack([np.hstack(row_images) for row_images in rows_images])\n", 276 | "\n", 277 | "print(random_digits()[0])\n", 278 | "show_as_image(random_digits()[0])\n", 279 | "# show_as_image(random_digits()[0])\n", 280 | "# show_as_image(random_digits()[0])\n", 281 | "# show_as_image(random_digits(fixed_label=3)[0])\n", 282 | "# show_as_image(random_digits(fixed_label=4)[0])\n", 283 | "# show_as_image(random_digits(fixed_label=5)[0])\n", 284 | "# 
show_as_image(batch_images_to_one(np.stack([random_digits()[0] for _ in range(25)])), figsize=(9, 9))" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "from torch.utils.data import Dataset, DataLoader\n", 294 | "\n", 295 | "# 위에서 만든, LCD Digits 만들기 함수를 이용해서 Custom Dataset을 만든다.\n", 296 | "class LcdDigits(Dataset):\n", 297 | "\n", 298 | " def __init__(self, n_examples):\n", 299 | " digits, labels = zip(*[random_digits() for _ in range(n_examples)])\n", 300 | " self.digits = np.asarray(digits, dtype=np.float64)\n", 301 | " self.labels = np.asarray(labels)\n", 302 | " \n", 303 | " def __len__(self):\n", 304 | " return len(self.labels)\n", 305 | " \n", 306 | " def __getitem__(self, idx):\n", 307 | " digit_with_channel = self.digits[idx][np.newaxis, :, :]\n", 308 | " \n", 309 | " return torch.from_numpy(digit_with_channel).float(), torch.from_numpy(np.array([self.labels[idx]]))\n", 310 | "\n", 311 | "# next(b for b in DataLoader(LcdDigits(128), batch_size=3))" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": {}, 317 | "source": [ 318 | "## Training" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "import torch.nn.functional as F\n", 328 | "\n", 329 | "N_EPOCHS = 25\n", 330 | "BATCH_SIZE = 128\n", 331 | "LR = 0.005\n", 332 | "\n", 333 | "cnn = PixelCNN()\n", 334 | "optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)\n", 335 | "\n", 336 | "train_dataset = LcdDigits(BATCH_SIZE * 50)\n", 337 | "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)\n", 338 | "\n", 339 | "for epoch in range(N_EPOCHS):\n", 340 | " for i, (images, _) in enumerate(train_loader):\n", 341 | " images = images # BATCH_SIZE x 1 x 13 x 8\n", 342 | " pixelCNN_out = cnn(images) # BATCH_SIZE x 2 x 13 x 8\n", 343 | " pixelCNN_target = torch.squeeze(images).long() # BATCH_SIZE x 13 x 8\n", 344 | " optimizer.zero_grad()\n", 345 | " loss = F.cross_entropy(input=pixelCNN_out, target=pixelCNN_target)\n", 346 | " loss.backward()\n", 347 | " optimizer.step()\n", 348 | " \n", 349 | " if i % 100 == 0:\n", 350 | " print ('Epoch [%d/%d], Loss: %.4f' \n", 351 | " %(epoch+1, N_EPOCHS, loss.item()))" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## input 이미지를 pixel-by-pixel 흝고 지나가면서 이미지 생성해보기\n", 359 | "
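Summed over pixels, the cross-entropy used in the training loop above is the negative log-likelihood $-\log p(\mathbf{x})$ of an image under the trained model, so it can also be used to score individual images. A minimal sketch (not in the original notebook), assuming cnn, LcdDigits, DataLoader, torch and F from the cells above are in scope:

cnn.train(False)
images, _ = next(iter(DataLoader(LcdDigits(16), batch_size=16)))
with torch.no_grad():
    logits = cnn(images)                                                           # 16 x 2 x 13 x 8
    nll_map = F.cross_entropy(logits, images.squeeze(1).long(), reduction='none')  # 16 x 13 x 8
    per_image_nll = nll_map.view(nll_map.size(0), -1).sum(dim=1)                   # sum over pixels
print(per_image_nll)  # lower values = digits the model considers more plausible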
\n", 360 | "Sequentially generating new samples " 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "# pixelCNN로 새로운 샘플 만들기\n", 370 | "# 어떤 특정한 이미지를 pixelCNN의 input으로 줄 수도 있고 아니면 그냥 zeros를 넣어 줄 수도 있다.\n", 371 | "# pixelCNN은 input으로 주어진 이미지를 Masked Convolution으로 흝고 지나가면서 자신의 기억속에 존재하는, 가장 그럴듯한 픽셀 값을 예측해 낸다.\n", 372 | "def generate_samples(n_samples, starting_point=(0, 0), starting_image=None):\n", 373 | " samples = torch.from_numpy(\n", 374 | " starting_image if starting_image is not None else np.zeros((n_samples * n_samples, 1, IMAGE_WIDTH, IMAGE_HEIGHT))).float()\n", 375 | "\n", 376 | " cnn.train(False)\n", 377 | "\n", 378 | " # pixelCNN이 예측한 pixel-level distribution으로 부터 픽셀 값(0 또는 1)을 샘플링해서 이미지를 만들어 낸다.\n", 379 | " for i in range(IMAGE_WIDTH):\n", 380 | " for j in range(IMAGE_HEIGHT):\n", 381 | " if i < starting_point[0] or (i == starting_point[0] and j < starting_point[1]):\n", 382 | " continue\n", 383 | " out = cnn(samples)\n", 384 | " probs = F.softmax(out[:, :, i, j],1).data\n", 385 | " samples[:, :, i, j] = torch.multinomial(probs, 1).float()\n", 386 | " return samples.numpy()\n", 387 | "\n", 388 | "show_as_image(batch_images_to_one(generate_samples(n_samples=10)), figsize=(10, 20))" 389 | ] 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "metadata": {}, 394 | "source": [ 395 | "## 이미지의 일부분을 input으로 주고, 나머지 부분을 완성하기\n", 396 | "
\n", 397 | "Or completing existing cropped image\n", 398 | "\n", 399 | " * $0, 8, 9$ and $2, 3, 7$ undistinguishable early one\n", 400 | " * Very small amount of noise (jitter) in samples\n", 401 | " * The last horizontal bar is hard to predict as it depends on the 1st horizontal bar\n", 402 | " * ($4, 9$) sometimes lead to incomplete or erroneous images because of the long term dependency between the upper and lower horizontal bars (could be improved by increasing the receptive field with more layers or \"à trous\" convolutions)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "n_images = 10\n", 412 | "starting_point = (4, 3)\n", 413 | "\n", 414 | "mask = causal_mask(IMAGE_WIDTH, IMAGE_HEIGHT, starting_point)\n", 415 | "\n", 416 | "starting_images = digits_list = [random_digits(fixed_label=d)[0] for d in range(10)]\n", 417 | "batch_starting_images = np.expand_dims(np.stack([i * mask for i in starting_images] * n_images), axis=1)\n", 418 | "\n", 419 | "samples = generate_samples(n_images, starting_image=batch_starting_images, starting_point=starting_point)\n", 420 | "\n", 421 | "show_as_image(np.hstack([(1 + mask) * i for i in starting_images]), figsize=(10, 10))\n", 422 | "\n", 423 | "show_as_image(\n", 424 | " batch_images_to_one((samples * (1 + mask))),\n", 425 | " figsize=(10, 20))" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [] 441 | } 442 | ], 443 | "metadata": { 444 | "kernelspec": { 445 | "display_name": "Python 3", 446 | "language": "python", 447 | "name": "python3" 448 | }, 449 | "language_info": { 450 | "codemirror_mode": { 451 | "name": "ipython", 452 | "version": 3 453 | }, 454 | "file_extension": ".py", 455 | "mimetype": "text/x-python", 456 | "name": "python", 457 | "nbconvert_exporter": "python", 458 | "pygments_lexer": "ipython3", 459 | "version": "3.6.8" 460 | } 461 | }, 462 | "nbformat": 4, 463 | "nbformat_minor": 2 464 | } 465 | -------------------------------------------------------------------------------- /Day07/pixelCNN/ToyPixelCNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "from matplotlib import pyplot as plt\n", 13 | "\n", 14 | "def show_as_image(binary_image, figsize=(10, 5)):\n", 15 | " plt.figure(figsize=figsize)\n", 16 | " plt.imshow(binary_image, cmap='gray')\n", 17 | " plt.xticks([]); plt.yticks([])\n", 18 | "\n", 19 | "%matplotlib inline" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": { 26 | "collapsed": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "from torchvision import datasets, utils" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Pixel CNN\n", 40 | "\n", 41 | "Alternative to Pixel RNN from [Pixel Recurrent Neural Networks](https://arxiv.org/pdf/1601.06759.pdf). 
\n", 42 | "\n", 43 | "On-line resources:\n", 44 | " * See for an existing PyTorch implementation https://github.com/jzbontar/pixelcnn-pytorch/blob/master/main.py\n", 45 | " * http://sergeiturukin.com/2017/02/22/pixelcnn.html for a nice walk-through\n", 46 | " * http://tinyclouds.org/residency/\n", 47 | " * https://tensorflow.blog/2016/11/29/pixelcnn-1601-06759-summary/ (in korean ;) ) \n", 48 | "\n", 49 | "The core ideas are the following:\n", 50 | "\n", 51 | "### Joint distribution of an image $\\mathbf{x}$ modelled as an autoregressive process\n", 52 | "\n", 53 | "Same model for PixelRNN and PixelCNN:\n", 54 | "\n", 55 | "$$p(\\mathbf{x}) = \\prod_{i=1}^{n^2} p(x_i|x_{1}, \\dots, x_{i-1})$$\n", 56 | " \n", 57 | "![](http://sergeiturukin.com/assets/2017-02-22-183010_479x494_scrot.png)\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## 위 사진처럼 생긴 Mask를 생성.\n", 65 | "이 Mask를 convolution filter적용하면 convolution filter를 causal하게 만든다." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAAElCAYAAACiZ/R3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABDFJREFUeJzt17FtAkEURdEdixJw7C3C/VfAFkFOD+PUjliQVlzhc+IX/OhqZsw5F4CKj1cfAPCbKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmnR8bn83mu63rQKcA727btNuf8vLd7KErrui6Xy+X5q4B/a4xx3bPzfQNSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUg5vfoAeMYY49UncBAvJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgZcw594/H2D8G+Gubc37fG3kpASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQMrpwf1tWZbrEYcAb+9rz2jMOY8+BGA33zcgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSDlB+3gHd7ORa3KAAAAAElFTkSuQmCC\n", 76 | "text/plain": [ 77 | "
" 78 | ] 79 | }, 80 | "metadata": {}, 81 | "output_type": "display_data" 82 | }, 83 | { 84 | "data": { 85 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAAElCAYAAACiZ/R3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABDFJREFUeJzt17Ftw0AUBUGeoRLk2CzC/VcgFqHcPZxTOxJpgNBCnolf8KPF3ZhzLgAVb88+AOAnUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlIuR8bX63Wu63rSKcAr27bta875/mh3KErrui632+3vVwH/1hjjvmfn+wakiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpByefYBnGuM8ewT4BAvJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSBFlIAUUQJSRAlIESUgZcw594/H2D8G+G2bc34+GnkpASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQIooASmiBKSIEpAiSkCKKAEpogSkiBKQIkpAiigBKaIEpIgSkCJKQMrl4P5rWZb7GYcAL+9jz2jMOc8+BGA33zcgRZSAFFECUkQJSBElIEWUgBRRAlJECUgRJSDlG/BoHd7iaQCCAAAAAElFTkSuQmCC\n", 86 | "text/plain": [ 87 | "
" 88 | ] 89 | }, 90 | "metadata": {}, 91 | "output_type": "display_data" 92 | } 93 | ], 94 | "source": [ 95 | "# True/False로 구성된 Mask\n", 96 | "def causal_mask(width, height, starting_point):\n", 97 | " row_grid, col_grid = np.meshgrid(np.arange(width), np.arange(height), indexing='ij')\n", 98 | "# print(row_grid)\n", 99 | "# print()\n", 100 | "# print(col_grid)\n", 101 | "# print('Mask making')\n", 102 | "# print(row_grid\n", 253 | "Application on a simple generative model of LCD digits\n", 254 | "
\n", 255 | "From https://gist.github.com/benjaminwilson/b25a321f292f98d74269b83d4ed2b9a8#file-lcd-digits-dataset-nmf-ipynb" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 7, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAALwAAAElCAYAAABeV4iUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAA2pJREFUeJzt3LGRAjEQAMFbijw+/ww+HXxy0IcAV3DAM922jDWmVDIkzVprg4rTuweAVxI8KYInRfCkCJ4UwZMieFIET4rgSdkV/Mz8HjUIPOLeNmfP1YKZcQ+Bj7XWmltrHGlIETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgifl/O4BjuAL8OeauflU9N+ww5MieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSfnKR9zf9OiY57LDkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwp553rr9u2XY4YBB70c8+iWWsdPQh8DEcaUgRPiuBJETwpgidF8KQInhTBkyJ4Uv4AmrcVqoHJrzkAAAAASUVORK5CYII=\n", 266 | "text/plain": [ 267 | "
" 268 | ] 269 | }, 270 | "metadata": {}, 271 | "output_type": "display_data" 272 | } 273 | ], 274 | "source": [ 275 | "CELL_LENGTH = 4 # 숫자의 한 획의 길이\n", 276 | "IMAGE_WIDTH, IMAGE_HEIGHT = 2 * CELL_LENGTH + 5, CELL_LENGTH + 4 # 획의 길이에 따른 LCD 이미지 규격\n", 277 | "\n", 278 | "# 세로 획 그리기\n", 279 | "def vertical_stroke(rightness, downness):\n", 280 | " \"\"\"\n", 281 | " Return a 2d numpy array representing an image with a single vertical stroke in it.\n", 282 | " `rightness` and `downness` are values from [0, 1] and define the position of the vertical stroke.\n", 283 | " \"\"\"\n", 284 | " i = (downness * (CELL_LENGTH + 1)) + 2\n", 285 | " j = rightness * (CELL_LENGTH + 1) + 1\n", 286 | " x = np.zeros(shape=(IMAGE_WIDTH, IMAGE_HEIGHT), dtype=np.float64)\n", 287 | " x[i + np.arange(CELL_LENGTH), j] = 1.\n", 288 | " return x\n", 289 | "\n", 290 | "# 가로 획 그리기\n", 291 | "def horizontal_stroke(downness):\n", 292 | " \"\"\"\n", 293 | " Analogue to vertical_stroke, but it returns horizontal strokes.\n", 294 | " `downness` is here a value in [0, 1, 2].\n", 295 | " \"\"\"\n", 296 | " i = (downness * (CELL_LENGTH + 1)) + 1\n", 297 | " x = np.zeros(shape=(IMAGE_WIDTH, IMAGE_HEIGHT), dtype=np.float64)\n", 298 | " x[i, 2 + np.arange(CELL_LENGTH)] = 1.\n", 299 | " return x\n", 300 | "\n", 301 | "show_as_image(horizontal_stroke(0))\n", 302 | "# show_as_image(horizontal_stroke(1))\n", 303 | "# show_as_image(horizontal_stroke(2))\n", 304 | "# show_as_image(vertical_stroke(0,0))\n", 305 | "# show_as_image(vertical_stroke(0,1))\n", 306 | "# show_as_image(vertical_stroke(1,0))\n", 307 | "# show_as_image(vertical_stroke(1,1))" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 8, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stdout", 317 | "output_type": "stream", 318 | "text": [ 319 | "[[0. 0. 0. 0. 0. 0. 0. 0.]\n", 320 | " [0. 0. 1. 1. 1. 1. 0. 0.]\n", 321 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 322 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 323 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 324 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 325 | " [0. 0. 0. 0. 0. 0. 0. 0.]\n", 326 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 327 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 328 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 329 | " [0. 0. 0. 0. 0. 0. 1. 0.]\n", 330 | " [0. 0. 0. 0. 0. 0. 0. 0.]\n", 331 | " [0. 0. 0. 0. 0. 0. 0. 
0.]]\n" 332 | ] 333 | }, 334 | { 335 | "data": { 336 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAALwAAAElCAYAAABeV4iUAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAA51JREFUeJzt3bFtw0AQAEGdoT7cfwdux7l7eJcgCRL9MncmZnAAFw8GPHLWWheo+Ng9APwlwZMieFIET4rgSRE8KYInRfCkCJ6Uh4Kfma+jBoFn3NvmPPJqwcx4D4G3tdaaW9d4pCFF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInpTr7gGO4BPgrzVzc1X033DCkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJOeUS95mWjt/BEUvxu+6RE54UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZNy3T3AEdZau0c4lZnZPcLLOOFJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KSccon7TEvH7+CIpfhd98gJT4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KYInRfCkCJ4UwZMieFIET4rgSRE8KdfdAxxhrbV7hFOZmd0jvIwTnhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPyimXuM+0dMxrOeFJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQInhTBkyJ4UgRPiuBJETwpgidF8KQ8usT9c7lcvo8YBJ70ec9F4zftlHikIUXwpAieFMGTInhSBE+K4EkRPCmCJ+UX2gMhuHpdw20AAAAASUVORK5CYII=\n", 337 | "text/plain": [ 338 | "
" 339 | ] 340 | }, 341 | "metadata": {}, 342 | "output_type": "display_data" 343 | } 344 | ], 345 | "source": [ 346 | "# 0 ~ 9 사이의 숫자를 만들 수 있는 기본 획(가로, 세로)들 총 집합.\n", 347 | "# 가로 획 3개 + 세로 획 4개 = 총 7개 획\n", 348 | "BASE_STROKES = np.asarray(\n", 349 | " [horizontal_stroke(k) for k in range(3)] + [vertical_stroke(k, l) for k in range(2) for l in range(2)])\n", 350 | "\n", 351 | "# 기본 획들 총 집합 중에서 각 숫자를 만들기 위해 실제로 필요한 획 구성 \n", 352 | "DIGITS_STROKES = np.array([[0, 2, 3, 4, 5, 6], [5, 6], [0, 1, 2, 4, 5], [0, 1, 2, 5, 6], [1, 3, 5, 6], [0, 1, 2, 3, 6], [0, 1, 2, 3, 4, 6], [0, 5, 6], np.arange(7), [0, 1, 2, 3, 5, 6]])\n", 353 | "\n", 354 | "# 저 기본 획들과 각 숫자를 만들기 위한 조합을 이용해서 랜덤 숫자 이미지를 만들어 낸다.\n", 355 | "def random_digits(strokes=BASE_STROKES, digit_as_strokes=DIGITS_STROKES, fixed_label=None):\n", 356 | " label = fixed_label if fixed_label is not None else np.random.choice(len(digit_as_strokes))\n", 357 | " combined_strokes = strokes[digit_as_strokes[label], :, :].sum(axis=0)\n", 358 | " return combined_strokes, label\n", 359 | "\n", 360 | "def batch_images_to_one(batches_images):\n", 361 | " n_square_elements = int(np.sqrt(batches_images.shape[0]))\n", 362 | " rows_images = np.split(np.squeeze(batches_images), n_square_elements)\n", 363 | " return np.vstack([np.hstack(row_images) for row_images in rows_images])\n", 364 | "\n", 365 | "print(random_digits()[0])\n", 366 | "show_as_image(random_digits()[0])\n", 367 | "# show_as_image(random_digits()[0])\n", 368 | "# show_as_image(random_digits()[0])\n", 369 | "# show_as_image(random_digits(fixed_label=3)[0])\n", 370 | "# show_as_image(random_digits(fixed_label=4)[0])\n", 371 | "# show_as_image(random_digits(fixed_label=5)[0])\n", 372 | "# show_as_image(batch_images_to_one(np.stack([random_digits()[0] for _ in range(25)])), figsize=(9, 9))" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 9, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "from torch.utils.data import Dataset, DataLoader\n", 382 | "\n", 383 | "# 위에서 만든, LCD Digits 만들기 함수를 이용해서 Custom Dataset을 만든다.\n", 384 | "class LcdDigits(Dataset):\n", 385 | "\n", 386 | " def __init__(self, n_examples):\n", 387 | " digits, labels = zip(*[random_digits() for _ in range(n_examples)])\n", 388 | " self.digits = np.asarray(digits, dtype=np.float64)\n", 389 | " self.labels = np.asarray(labels)\n", 390 | " \n", 391 | " def __len__(self):\n", 392 | " return len(self.labels)\n", 393 | " \n", 394 | " def __getitem__(self, idx):\n", 395 | " digit_with_channel = self.digits[idx][np.newaxis, :, :]\n", 396 | " \n", 397 | " return torch.from_numpy(digit_with_channel).float(), torch.from_numpy(np.array([self.labels[idx]]))\n", 398 | "\n", 399 | "# next(b for b in DataLoader(LcdDigits(128), batch_size=3))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "## Training" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 10, 412 | "metadata": {}, 413 | "outputs": [ 414 | { 415 | "name": "stdout", 416 | "output_type": "stream", 417 | "text": [ 418 | "Epoch [1/25], Loss: 0.6708\n", 419 | "Epoch [2/25], Loss: 0.1917\n", 420 | "Epoch [3/25], Loss: 0.0633\n", 421 | "Epoch [4/25], Loss: 0.0360\n", 422 | "Epoch [5/25], Loss: 0.0307\n", 423 | "Epoch [6/25], Loss: 0.0290\n", 424 | "Epoch [7/25], Loss: 0.0280\n", 425 | "Epoch [8/25], Loss: 0.0271\n", 426 | "Epoch [9/25], Loss: 0.0260\n", 427 | "Epoch [10/25], Loss: 0.0253\n", 428 | "Epoch [11/25], Loss: 0.0254\n", 429 | "Epoch [12/25], Loss: 0.0252\n", 
430 | "Epoch [13/25], Loss: 0.0252\n", 431 | "Epoch [14/25], Loss: 0.0250\n", 432 | "Epoch [15/25], Loss: 0.0250\n", 433 | "Epoch [16/25], Loss: 0.0252\n", 434 | "Epoch [17/25], Loss: 0.0253\n", 435 | "Epoch [18/25], Loss: 0.0244\n", 436 | "Epoch [19/25], Loss: 0.0244\n", 437 | "Epoch [20/25], Loss: 0.0247\n", 438 | "Epoch [21/25], Loss: 0.0242\n", 439 | "Epoch [22/25], Loss: 0.0249\n", 440 | "Epoch [23/25], Loss: 0.0242\n", 441 | "Epoch [24/25], Loss: 0.0242\n", 442 | "Epoch [25/25], Loss: 0.0243\n" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "import torch.nn.functional as F\n", 448 | "\n", 449 | "N_EPOCHS = 25\n", 450 | "BATCH_SIZE = 128\n", 451 | "LR = 0.005\n", 452 | "\n", 453 | "cnn = PixelCNN()\n", 454 | "optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)\n", 455 | "\n", 456 | "train_dataset = LcdDigits(BATCH_SIZE * 50)\n", 457 | "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)\n", 458 | "\n", 459 | "for epoch in range(N_EPOCHS):\n", 460 | " for i, (images, _) in enumerate(train_loader):\n", 461 | " images = images # BATCH_SIZE x 1 x 13 x 8\n", 462 | " pixelCNN_out = cnn(images) # BATCH_SIZE x 2 x 13 x 8\n", 463 | " pixelCNN_target = torch.squeeze(images).long() # BATCH_SIZE x 13 x 8\n", 464 | " optimizer.zero_grad()\n", 465 | " loss = F.cross_entropy(input=pixelCNN_out, target=pixelCNN_target)\n", 466 | " loss.backward()\n", 467 | " optimizer.step()\n", 468 | " \n", 469 | " if i % 100 == 0:\n", 470 | " print ('Epoch [%d/%d], Loss: %.4f' \n", 471 | " %(epoch+1, N_EPOCHS, loss.item()))" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "## input 이미지를 pixel-by-pixel 흝고 지나가면서 이미지 생성해보기\n", 479 | "
\n", 480 | "Sequentially generating new samples " 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 11, 486 | "metadata": {}, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAAOgCAYAAAA+nnIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGE1JREFUeJzt3UGO7Eq1QNEwqinQZv7Dos8c/BtIpSv4L7m3yo5ynr1W9z1snwpnshUoyOM8zwUAUPW3n34AAICfJIYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCAtI8/+ZeP4/DbHQDAu/jXeZ5//1//kp0hAGCqf/7OvySGAIA0MQQApIkhACBNDAEAaX90moz/33nuPWR3HMfW++2cb/ds03k3r2W+a03+brF217p7PjtDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAg7bKj9a+O2T3pyN8dz2K+6zzpPbrDHfP5m+0z+bO31uz5fPb+3PT5fmVnCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSPq660HEcf/nPzvO86ja/5dWz3MF819k92253zDf9szd9vlfMdx3v5rXXnDDfr+wMAQBpYggASBNDAECaGAIA0sQQAJAmhgCAtMuO1r/ypOOau48f3mHyfLtn223y2q1lvqdc80mmz/dXvJvXuvtZ7AwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNrHTz/AHY7j2Hq/8zy33m/yfLtn223y2q1l/a6+5oT1e8p8PnvXXnPCfL+yMwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANJGHq3fzXFinsravbfp6zd5vsmzrTVvPjtDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApG351frzPHfc5tPuX9M133Umz7aW+a5mvmtNnm/ybGuZ77vsDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLQtR+ufdOTvq89yxzW/avez7Jxv8mxrzX83d5uwfq+Y7zrTP3uT126t++ezMwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaR87bnKe547bfDqOY+s1J8z3ys75Js+2lnfz3Xk/35fP3nu7ez47QwBAmhgCANLEEACQJoYAgDQxBACkiSEAIG3L0fonHUf96rPccc2vmj7fThPW7pXp8+1m/d6XtXtvd89nZwgASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAEDax46bnOe54zafjuPYek3z7XmOO0xYu1emz7eb9Xtf1u693T2fnSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJC25Wi9I3/vbfJ8k2dba/5801m/92Xt3oudIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANK2/Gr9eZ47bvPJrwVfa+f6Wbtr+ey9t+nrN/m7ZfraTWNnCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApG05Wv+kI41ffZY7rvlVk+d7ynPcZfLarfWsZ7mD9XvGNd/hOazdte5+FjtDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJD28dMPcIfjOLZe8zzPy+/3yuT57pjtSXav3W5PepY7TP7srTV7Pu/mtdec8G7+ys4QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIG3m0/tWRvwlHmyfPt3u23e6Y70l/syc9yx0mf/bWmj2fd/Na095NO0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkPbx0w9wh+M4fvoRbjV5vsmzrXXPfK+ueZ7n5fd7xfq9t8nzTZ5trf3zTftusTMEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSRh6tB/5t+nFi4GdM+26xMwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAEDall+tP89zx20+7f41XfNdZ/Jsa837peefNn39zHedybOtZb7vsjMEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSthytf9KRv68+yx3XfBc755u+dk96ljvsnm/C+nknZj6Hd/O92BkCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLSPn36AOxzHsfWa53lefr9X7pjvlZ3z7V673Z70LHeY/G6u5bvl6mu++3fLK9buWnevn50hACBNDAEAaWIIAEgTQwBAmhgCANLEEACQNvJo/asjfxOObt8x3+6/2U5Pmu1Jz3KH6e+m75Zr7Zxv+ru527R3084QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQ
QAJAmhgCANDEEAKR9/PQD3OE4jq33O89z6/3umO/VNXfON3m2tfa/m7tZv2uZ7zrezfd293x2hgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNrIo/W7OdL4vibPVjB9/cz3vibPNpGdIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkOZo/c1e/XKxo5cA8PPsDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDRH6y/w6vj8Hf+53Ufyv/qcXzF5trXMdzXzXWvyfJNnW8t832VnCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApI08Wn/HL8U/6dfndz/Lzvms3TOu+VXme8Y1v8p3y3UmrF3p3bQzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0kYerS970lFIAHgHdoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECao/XDOD4PAH/GzhAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpI3+1/o5fbn91zfM8t15z8i/Tv8vafZX5rr3m9Pl22/0sO9dv8mxr+ex9l50hACBNDAEAaWIIAEgTQwBAmhgCANLEEACQNvJo/R3H0r96zTuOH04+dr97ticdtzXfn5s+326T12/ybGuZ77vsDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLSRR+ufZMJxWwCYzM4QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIc7T+Zl/9pd1X/+zVNaf/qjYAXM3OEACQJoYAgDQxBACkiSEAIE0MAQBpYggASHO0/gJPOs7u+DwA/Bk7QwBAmhgCANLEEACQJoYAgDQxBACkiSEAIG3k0fo7jpff8SvyX73fZLvn/ur6fJX5rjV9vt0mr9/k2dYy33fZGQIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGkjj9bv5pfp39f0v6X5eLLJ6zd5trXmzWdnCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQtuWHWs/z3HGbT7t/QM5815k821rzftzwp01fP/NdZ/Jsa5nvu+wMAQBpYggASBNDAECaGAIA0sQQAJAmhgCAtMuO1r86ZvekI39ffRbz7bH7OSas3ZNMfjfX8n5e7d2/W6a/m0+a7252hgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAg7eOqCx3HcdWlvu2OZ3l1zfM8L7/fK5Pn2/0eTVi7J5n8bq7l/bzau3+3POnz/KT57ljXu//WdoYAgDQxBACkiSEAIE0MAQBpYggASBNDAEDaZUfrXx2le9JxzTuOCk6Y75Wd802eba1nvUd38Nl7xjWfZPp8T/Gkd/OOdb37PbIzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0i47Wg8AzFH6vz6wMwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANIcrQcA/svdvxT/JHaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASLvsV+tf/YLtq1++vcMdv6Y7fb5Xds43eba15v3S83/y2bv2mubb8xwTvMvaffWad6+fnSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJB22dH6VxxpfG+T55s8W8H09TMfT7X7KP/d7AwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQtuVX68/z3HGbT34J+Vo712/32k1/N813LfNdy3fLdcz3PXaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2paj9U868vfVZ7njml81fb6rTZ5tLfO9uyfN57vlOtbuvdgZAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIC0jx03Oc9zx20+Hcex9Zrm2/Mcd1zT2l3LfM+45ldNXr/df+fp7+aE+X5lZwgASBNDAECaGAIA0sQQAJAmhgCANDEEAKRtOVr/pCONE47G3jHf7r/ZU55jwtq9Yr5rTf7srTV7/SbPtpZ387vsDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2seOm5znueM2n47j2Hq/CfO9uubO+azdtcx3rcmfvbVmr9/k2dbybn6XnSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJC25Wj97iONu5nvfU2ebS3zvTvzva/Js601bz47QwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKRt+dX68zx33ObT7l/TnT7fZNPXbvp8001fv53zeTevNe3dtDMEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSthytf9JxzTueZcJ8u/9mO5/jKbP9BO/mtcx3Ld8t15m8dmvdP5+dIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABI+9hxk/M8d9zm03EcW++32x3zvbrmzvWbPNta+9/NCfNNX7/p873iu+U6/nvve+wMAQBpYggASBNDAECaGAIA0sQQAJAmhgCAtMuO1r86Rjj9yN9uu//WO9dv8mxr3TNf+bM3Yf1eMd91Js+21vzvgbvnszMEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGkfV13oOI6//GfneV51m9/y6lkm2D3fzvWbPNta98zns7eP9/Navluu47P3PXaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2mVH61+ZfuRvusnrN3m2tebPN9309Zs83+TZJrIzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNqWX60/z3PHbT7t/rVg811n8mxrme9q5rvW5Pkmz7aW+b7LzhAAkC
aGAIA0MQQApIkhACBNDAEAaWIIAEi77Gj9q2N2Tzryd8ezmO86k2db6575nvTZ223C+r0yYb7J7+fk2f6XCe/mr+wMAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASPu46kLHcfzlPzvP86rb/JZXz3IH811n8mxr3TNf+bO3m/fz2mu++3fLU2Zby7v5XXaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2mVH61950pG/3ccr7zB5vsmzrWW+d2f9ruW75Tp3zDf98/wrO0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkPax4ybnee64zafjOLbez3zXmTzbWuZ7d9bvWr5brnPHfK+uOWG+X9kZAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaVuO1jtu+94mzzd5trXmzzfd9PWbPN/k2daaN5+dIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANK2/Gr9eZ47bvNp96/pmu86k2dby3xXM9+1Js83eba15v2K/G52hgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNqWo/VPOtJ4x7OY7zp3zLb77/XK5LVby3xXmzDfkz5/V5s821rPmu/uZ7EzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpHz/9AHc4juOnH+FWk+e7Y7Yn/b12P8t5nlvv96S/9R2s37XX3Dnf5NnWMt932RkCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpI4/WvzryN+Ho7+T57pjtSX+v6fNNZ/2utXO+3X/nCd8tr0ybz84QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKR9/PQD3OE4jp9+hFtNnu+O2V5d8zzPy+/3ivnem/W71s75Js+2lvm+y84QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIG3m0Hn6Xo+A82fT1mzzf5NnWmjefnSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJD2p0fr/7XW+ucdDwIAcLF//M6/dJznefeDAAA8lv+ZDABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEj7PwxJzmr6HvY2AAAAAElFTkSuQmCC\n", 491 | "text/plain": [ 492 | "
" 493 | ] 494 | }, 495 | "metadata": {}, 496 | "output_type": "display_data" 497 | } 498 | ], 499 | "source": [ 500 | "# pixelCNN로 새로운 샘플 만들기\n", 501 | "# 어떤 특정한 이미지를 pixelCNN의 input으로 줄 수도 있고 아니면 그냥 zeros를 넣어 줄 수도 있다.\n", 502 | "# pixelCNN은 input으로 주어진 이미지를 Masked Convolution으로 흝고 지나가면서 자신의 기억속에 존재하는, 가장 그럴듯한 픽셀 값을 예측해 낸다.\n", 503 | "def generate_samples(n_samples, starting_point=(0, 0), starting_image=None):\n", 504 | " samples = torch.from_numpy(\n", 505 | " starting_image if starting_image is not None else np.zeros((n_samples * n_samples, 1, IMAGE_WIDTH, IMAGE_HEIGHT))).float()\n", 506 | "\n", 507 | " cnn.train(False)\n", 508 | "\n", 509 | " # pixelCNN이 예측한 pixel-level distribution으로 부터 픽셀 값(0 또는 1)을 샘플링해서 이미지를 만들어 낸다.\n", 510 | " for i in range(IMAGE_WIDTH):\n", 511 | " for j in range(IMAGE_HEIGHT):\n", 512 | " if i < starting_point[0] or (i == starting_point[0] and j < starting_point[1]):\n", 513 | " continue\n", 514 | " out = cnn(samples)\n", 515 | " probs = F.softmax(out[:, :, i, j],1).data\n", 516 | " samples[:, :, i, j] = torch.multinomial(probs, 1).float()\n", 517 | " return samples.numpy()\n", 518 | "\n", 519 | "show_as_image(batch_images_to_one(generate_samples(n_samples=10)), figsize=(10, 20))" 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "metadata": {}, 525 | "source": [ 526 | "## 이미지의 일부분을 input으로 주고, 나머지 부분을 완성하기\n", 527 | "
\n", 528 | "Or completing existing cropped image\n", 529 | "\n", 530 | " * $0, 8, 9$ and $2, 3, 7$ undistinguishable early one\n", 531 | " * Very small amount of noise (jitter) in samples\n", 532 | " * The last horizontal bar is hard to predict as it depends on the 1st horizontal bar\n", 533 | " * ($4, 9$) sometimes lead to incomplete or erroneous images because of the long term dependency between the upper and lower horizontal bars (could be improved by increasing the receptive field with more layers or \"à trous\" convolutions)" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 12, 539 | "metadata": {}, 540 | "outputs": [ 541 | { 542 | "data": { 543 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAABwCAYAAAAdSHSxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAAy9JREFUeJzt3cFx2zAURVEg4xayThMukCrQRXjvHpCtVyIlUST43zlrJNAXmMwdZhD3MUYDAEj15+wPAABwJjEEAEQTQwBANDEEAEQTQwBANDEEAEQTQwBANDEEAEQTQwBANDEEAET7eGRx793P7gAAruJnjPF3bZE3QwBAVd9bFokhACCaGAIAookhACCaGAIAoj10m+xZYxx7Ca33fuh+1R15fkefnWfz2qqfn/n2U3m21sz3Km+GAIBoYggAiCaGAIBoYggAiCaGAIBoYggAiLbb1fp71+xmuvJX4Wrz0fMd+Z1Vnq2198w30/Nefb53qD7fPVf/u2WmszPfa7wZAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCi7fZT62dyu93O/ghvVXm+yrO1Zj44S/Vn03yv8WYIAIgmhgCAaGIIAIgmhgCAaGIIAIgmhgCAaLtdra9+rQ9mVf3PnvmYVfWzqz7fb94MAQDRxBAAEE0MAQDRxBAAEE0MAQDRxBAAEK2PMbYv7n37YspYlqXkXvazn/3O3a+y6md3of2+xhifa4u8GQIAookhACCaGAIAookhACCaGAIAookhACDablfr7117m+kKXoWro++Yr/p3Novq37P55vg9nzXTZ9mbs7u2F+ZztR4AYI0YAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIJoYAgCiiSEAIFofY2xf3Pv2xb8sy/LML3va0ftVd+T3Wf1Z8Wzuq/r5Vd+vsupnd6H9vsYYn2uLvBkCAKKJIQAgmhgCAKKJIQAgmhgCAKKJIQAg2iFX6wEATuBqPQDAGjEEAEQTQwBANDEEAEQTQwBANDEEAET7eHD9T2vt+x0fBABgZ/+2LHro/xkCAKjGP5MBANHEEAAQTQwBANHEEAAQTQwBANHEEAAQTQwBANHEEAAQTQwBANH+AyEU1tXo+nQXAAAAAElFTkSuQmCC\n", 544 | "text/plain": [ 545 | "
" 546 | ] 547 | }, 548 | "metadata": {}, 549 | "output_type": "display_data" 550 | }, 551 | { 552 | "data": { 553 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAAOgCAYAAAA+nnIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGQlJREFUeJzt3c1x6zqbqFHy1k6hx52EA6QCdBA97xxwp560pE8/EIVnrTG8oZekXU/xFI72McYGAFD1/z79AQAAPkkMAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGn//pPF+7777g4A4Fv87xjjv24t8mYIAFjV/9yzSAwBAGliCABIE0MAQJoYAgDS/qPTZI8aY+4htH3fp+63upn3b/a982x+t9Xvn/leZ+XZts18z/JmCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApL3saP21Y3ZnOvK3wtHm2fPNvGYrz7Zt75nvTM/76vO9w+rzXfPtf1vOdO/M9xxvhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEh72bfWn8nlcvn0R3irledbebZtMx98yurPpvme480QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIe9nR+tWP9cFZrf67Zz7OavV7t/p8f3kzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0vYxxv2L9/3+xfAFjuOwn/3sF9hv9mycxu8Y4+fWIm+GAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2suO1l87tnim45orHK9ceb53zLby9Tqb1a/16s/nmT7Lq7l35/g3H/XEZ3G0HgDgFjEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgbR9j3L943+9f/MdxHI/82MNm78f3Wv3ZtJ/97Dd/L/udar/fMcbPrUXeDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLQpR+sBAD7A0XoAgFvEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAg7d+MTcaY+2X3+75P3W91M+/f7Hvn2fxuq98/873OyrNtm/me5c0QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIe9nR+mvH7M505G+Fo82z55t5zVaebdveM9+ZnvfV53uH1ee75tv/tpzp3pnvOd4MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkPayb60/k8vl8umP8FYrz7fybNtmPviU1Z9N8z3HmyEAIE0MAQBpYggASBNDAECaGAIA0sQQAJD2sqP1qx/rg7Na/XfPfJzV6vdu9fn+8mYIAEgTQwBAmhgCANLEEACQJoYAgDQxBACk7WOM+xfv+/2LWcZxHEvuZT/72a+z38qz2e+q3zHGz61F3gwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIC0lx2tv3bs7UxH8GZ/lnd4x3wrX7OVZzub1a+1+b7XyrNtm/mucLQeAOAWMQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApO1jjPsX7/v9i/84juORH3vY7P1WN/N6rv6s2M9+9vvMfivPZr+rfscYP7cWeTMEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSphytBwD4AEfrAQBuEUMAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLR/MzYZY+6X3e/7PnW/1c28f7PvnWfzu61+/8z3OivPtm3me5Y3QwBAmhgCANLEEACQJoYAgDQxBACkiSEAIO1lR+uvHbM705G/FY42z55v5jVbebZte898Z3reV5/vHVaf75pv/9typntnvud4MwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAEDay761/kwul8unP8JbrTzfyrNtm/ngU1Z/Ns33HG+GAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2suO1q9+rA/OavXfPfNxVqvfu9Xn+8ubIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkLaPMe5fvO/3L/7jOI5Hfuxhs/eDs1r9d89+9jvjXgVf9Kz8jjF+bi3yZggASBNDAECaGAIA0sQQAJAmhgCANDEEAKS97Gj9tWNvZzqCt8LxypXnW3m2bTPftzPf93rHbGe6Xmf6LO/wxHyO1gMA3CKGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApO1jjPsX7/v9i/84juORH3vY7P2Az1j9b4v9vnMv+51qv98xxs+tRd4MAQBpYggASBNDAECaGAIA0sQQAJAmhgCAtClH6wEAPsDRegCAW8QQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACDt34xNxpj7Zff7vk/db3Uz79/se+fZ/G6r3z/zvc7Ks22b+Z7lzRAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEh72dH6a8fsznTkb4WjzbPnm3nNVp5t294z35me99Xne4fV57vm2/+2nOneme853gwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQ9rJvrT+Ty+Xy6Y/wVivPt/Js22Y++JTVn03zPcebIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkPayo/WrH+uDs1r9d898nNXq9271+f7yZggASBNDAECaGAIA0sQQAJAmhgCANDEEAKTtY4z7F+/7/Yv/OI7jkR972Oz9+F6rP5v2s5/95u9V2O+L/I4xfm4t8mYIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkTTlaP9u1I4YrHD9ceb6VZ9s28307832vlWfbtvfMd6Zr9sRncbQeAOAWMQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAg
DQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBtH2Pcv3jf71/8x3Ecj/zYw2bvx/fybH631e+f/b5zr4IvelZ+xxg/txZ5MwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANKmHK0HAPgAR+sBAG4RQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCAtH8zNhlj7pfd7/s+db/Vzbx/s++dZ/O7rX7/zPc6K8+2beZ7ljdDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAg7WVH668dszvTkb8VjjbPnm/mNVt5tm17z3xnet5Xn+8dVp/vmm//23Kme2e+53gzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNrLvrX+TC6Xy6c/wlutPN/Ks22b+eBTVn82zfccb4YAgDQxBACkiSEAIE0MAQBpYggASBNDAEDay47Wr36sD85q9d8983FWq9+71ef7y5shACBNDAEAaWIIAEgTQwBAmhgCANLEEACQto8x7l+87/cv/uM4jkd+7GGz91vdzOu5+rNiP/vZ7zP7rTyb/a76HWP83FrkzRAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEibcrR+tmtH8FY4dv+O+Va/Zmex+nU23zn+zUed6bO82sqzbZv5rnC0HgDgFjEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgbR9j3L943+9fzDKO41hyL/vZz36d/WbPxmn8jjF+bi3yZggASBNDAECaGAIA0sQQAJAmhgCANDEEAKQ5Wg8ArMrRegCAW8QQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApP2bscl/8mWwr7Dv+9T9Vjfz/s2+d57N77b6/TPf66w827aZ71neDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLSXHa2/dszuTEf+VjjaPHu+mdds5dm27T3znel5X32+d1h9vmu+/W/Lme6d+Z7jzRAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpL/vW+jO5XC6f/ghvtfJ8K8+2beaDT1n92TTfc7wZAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaS87Wr/6sT44q9V/98zHWa1+71af7y9vhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNo+xrh/8b7fv5hlHMex5F72s5/9Prvfyla/d1+03+8Y4+fWIm+GAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2suO1l879namI3grHB1deb6VZ9s28307832vd8x2put1ps/yDk/M52g9AMAtYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAEDaPsa4f/G+37/4j+M4Hvmxh83eb3Uzr+fqz4r97Ge/htXv3Rft9zvG+Lm1yJshACBNDAEAaWIIAEgTQwBAmhgCANLEEACQNuVoPQDABzhaDwBwixgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKT9m7HJGHO/7H7f96n7rW7m/Zt97zyb3231+2e+11l5tm0z37O8GQIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGkvO1p/7ZjdmY78rXC0efZ8M6/ZyrNt23vmO9Pzvvp877D6fNd8+9+WM9078z3HmyEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSXvat9WdyuVw+/RHeauX5Vp5t28wHn7L6s2m+53gzBACkiSEAIE0MAQBpYggASBNDAECaGAIA0l52tH71Y31wVqv/7pmPs1r93q0+31/eDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLR9jHH/4n2/fzHLOI5jyb3sZz/7dfZbeTb7XfU7xvi5tcibIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkLbk0fprR/BmHwd8h3fMt/I1W3m2s1n9WpvvHP/mWZxpttXv3ROfxdF6AIBbxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIC0fYxx/+J9v38xyziOY8m97Gc/+3X2W3k2+131O8b4ubXImyEAIE0MAQBpYggASBNDAECaGAIA0sQQAJDmaD0AsCpH6wEAbhFDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIC0fzM2GWPul93v+z51v9XNvH+z751n87utfv/M9zorz7Zt5nuWN0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACDtZUfrrx2zO9ORvxWONs+eb+Y1W3m2bXvPfGd63lef7x1Wn++ab//bcqZ7Z77neDMEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBA2su+tf5MLpfLpz/CW60838qzbZv54FNWfzbN9xxvhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNrLjtavfqwPzmr13z3zcVar37vV5/vLmyEAIE0MAQBpYggASBNDAECaGAIA0sQQAJC2jzHuX7zv9y/+4ziOR37sYbP3W93M67n6s2I/+9nvM/utPJv9rvodY/zcWuTNEACQJoYAgDQxBACkiSEAIE0MAQBpYggASJtytH62a0fwVjh2v/J8K8+2beY7y7/5qDN9lndYeb6VZ9s2813haD0AwC1iCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNo+xrh/8b7fv/iP4zge+bGHzd5vdTOv5+rPiv3sZ7/P7LfybPa76neM8XNrkTdDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgbcrRegCAD3C0HgDgFjEEAKSJIQAgTQwBAGliCABIE0MAQJoY
AgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEj7N2OTMeZ+2f2+71P3W93M+zf73nk2v9vq9898r7PybNtmvmd5MwQApIkhACBNDAEAaWIIAEgTQwBAmhgCANJedrT+2jG7Mx35W+Fo8+z5Zl6zlWfbtvfMd6bnffX53mH1+a759r8tZ7p35nuON0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLSXfVHrmVwul09/hLdaeb6VZ9s288GnrP5smu853gwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIC0lx2tX/1YH5zV6r975uOsVr93q8/3lzdDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgbR9j3L943+9f/MdxHI/82MNm7wdntfrvnv2+ez+Y4HeM8XNrkTdDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgbcrR+tmuHQ9d4ejoyvOtPNu2me/bme97vWO2M10v8/2fHK0HALhFDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEjbxxj3L973+xf/cRzHIz/2sNn7AZ+x+t+W1fdb2er37ov2+x1j/Nxa5M0QAJAmhgCANDEEAKSJIQAgTQwBAGliCABIm3K0HgDgAxytBwC4RQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANL+zdhkjLlfdr/v+9T9Vjfz/s2+d57N77b6/TPf66w827aZ71neDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgLSXHa2/dszuTEf+VjjaPHu+mdds5dm27T3znel5X32+d1h9vmu+/W/Lme6d+Z7jzRAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIE0MAQBpL/vW+jO5XC6f/ghvtfJ8K8+2beaDT1n92TTfc7wZAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaS87Wr/6sT44q9V/98zHWa1+71af7y9vhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQNo+xrh/8b7fv/iP4zge+bGHzd5vdTOv5+rPiv3sZ7/P7LfybJ/Y74v8jjF+bi3yZggASBNDAECaGAIA0sQQAJAmhgCANDEEAKRNOVo/27UjhiscP1x5vpVn2zbzfTvzfa+VZ9u298x3pmv2xGdxtB4A4BYxBACkiSEAIE0MAQBpYggASBNDAECaGAIA0sQQAJAmhgCANDEEAKSJIQAgTQwBAGliCABIE0MAQJoYAgDSxBAAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkiSEAIG0fY9y/eN/vX/zHcRyP/NjDZu+3upnXc/VnxbP5WqvfP/t9514FX/Ss/I4xfm4t8mYIAEgTQwBAmhgCANLEEACQJoYAgDQxBACkTTlaDwDwAY7WAwDcIoYAgDQxBACkiSEAIE0MAQBpYggASPv3H67/323b/ucdHwQA4MX++55F/9H/ZwgAYDX+MxkAkCaGAIA0MQQApIkhACBNDAEAaWIIAEgTQwBAmhgCANLEEACQ9v8BqKrJIRuq+QMAAAAASUVORK5CYII=\n", 554 | "text/plain": [ 555 | "
" 556 | ] 557 | }, 558 | "metadata": {}, 559 | "output_type": "display_data" 560 | } 561 | ], 562 | "source": [ 563 | "n_images = 10\n", 564 | "starting_point = (4, 3)\n", 565 | "\n", 566 | "mask = causal_mask(IMAGE_WIDTH, IMAGE_HEIGHT, starting_point)\n", 567 | "\n", 568 | "starting_images = digits_list = [random_digits(fixed_label=d)[0] for d in range(10)]\n", 569 | "batch_starting_images = np.expand_dims(np.stack([i * mask for i in starting_images] * n_images), axis=1)\n", 570 | "\n", 571 | "samples = generate_samples(n_images, starting_image=batch_starting_images, starting_point=starting_point)\n", 572 | "\n", 573 | "show_as_image(np.hstack([(1 + mask) * i for i in starting_images]), figsize=(10, 10))\n", 574 | "\n", 575 | "show_as_image(\n", 576 | " batch_images_to_one((samples * (1 + mask))),\n", 577 | " figsize=(10, 20))" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [] 593 | } 594 | ], 595 | "metadata": { 596 | "kernelspec": { 597 | "display_name": "Python 3", 598 | "language": "python", 599 | "name": "python3" 600 | }, 601 | "language_info": { 602 | "codemirror_mode": { 603 | "name": "ipython", 604 | "version": 3 605 | }, 606 | "file_extension": ".py", 607 | "mimetype": "text/x-python", 608 | "name": "python", 609 | "nbconvert_exporter": "python", 610 | "pygments_lexer": "ipython3", 611 | "version": "3.6.6" 612 | } 613 | }, 614 | "nbformat": 4, 615 | "nbformat_minor": 2 616 | } 617 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Hello-Generative-Model 3 | 4 | ## Day1. Introduction to (Classic) Generative Model 5 | 6 | #### 0) Pytorch Tutorial - [HelloPytorch](https://github.com/InsuJeon/HelloPyTorch) 7 | #### 1) Linear Regression 8 | #### 2) Logistic Regression 9 | #### 3) Gaussian Discriminant Analysis 10 | #### 4) GMM(Gaussian Mixture Model) with EM algorithm 11 | 12 | 13 | ## Day2. Introduction to Varitional Inference / Probabilistic Neural Network 14 | 15 | #### 1) Variational Coin Toss - [related blog](http://www.openias.org/variational-coin-toss) 16 | #### 2) MNIST classification with Probabilistic (Layer) Neural Network 17 | 18 | 19 | ## Day3. Introduction to Varitional Auto-Encoder(VAE) 20 | #### 1) AutoEncoder 21 | #### 2) Varitional AutoEncoder - [code](https://github.com/GunhoChoi/PyTorch-FastCampus/tree/master/08_Autoencoder) 22 | #### 3) CVAE 23 | 24 | 25 | ## Day4. Introduction to Generative Adversarial Networks(GAN) 26 | #### 1) GAN 27 | #### 2) DCGAN 28 | 29 | 30 | ## Day5. Improved GAN 31 | #### 1) infoGAN 32 | #### 2) WGAN 33 | #### Checkout Other Generative Model Collections [Here](https://github.com/znxlwm/pytorch-generative-model-collections) 34 | 35 | 36 | ## Day6. Application of Deep Generative Model 37 | #### 1) CVAE 38 | #### 2) AAE 39 | #### 3) CycleGAN [Original code](https://github.com/togheppi/CycleGAN) 40 | 41 | ## Day7. Other Important Deep Generative Model 42 | #### 1) PixelCNN [original code](https://github.com/pilipolio/learn-pytorch) 43 | #### 2) Mixture Density Network [original code](https://github.com/hardmaru/pytorch_notebooks) 44 | 45 | --------------------------------------------------------------------------------