├── DDPM_notebook.ipynb
├── Neural_ODE.ipynb
├── README.md
├── base
│   ├── mnist.py
│   ├── model.py
│   └── train.py
├── flowmatching.ipynb
├── pi_digit_estimation_GRU.ipynb
├── vae.ipynb
└── vqvae.ipynb
/Neural_ODE.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "gpuClass": "standard",
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {
24 | "colab": {
25 | "base_uri": "https://localhost:8080/"
26 | },
27 | "id": "ysYtGUmudhFw",
28 | "outputId": "ea41ca86-4ca0-4225-a0be-561e8d5ceda2"
29 | },
30 | "outputs": [
31 | {
32 | "output_type": "stream",
33 | "name": "stdout",
34 | "text": [
35 | "Collecting torchdiffeq\n",
36 | "  Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)\n",
37 | "Requirement already satisfied: torch>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq) (2.0.1+cu118)\n",
38 | "Requirement already satisfied: scipy>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq) (1.10.1)\n",
39 | "Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /usr/local/lib/python3.10/dist-packages (from scipy>=1.4.0->torchdiffeq) (1.22.4)\n",
40 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (3.12.2)\n",
41 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (4.6.3)\n",
42 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (1.11.1)\n",
43 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (3.1)\n",
44 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (3.1.2)\n",
45 | "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (2.0.0)\n",
46 | "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.3.0->torchdiffeq) (3.25.2)\n",
47 | "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.3.0->torchdiffeq) (16.0.6)\n",
48 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.3.0->torchdiffeq) (2.1.3)\n",
49 | "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.3.0->torchdiffeq) (1.3.0)\n",
50 | "Installing collected packages: torchdiffeq\n",
51 | "Successfully installed torchdiffeq-0.2.3\n"
52 | ]
53 | }
54 | ],
55 | "source": [
56 | "!pip install torchdiffeq"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "source": [
62 | "from torchdiffeq import odeint_adjoint as odeint\n",
63 | "import torch\n",
64 | "import torch.nn as nn\n",
65 | "import torch.nn.functional as F\n",
66 | "import torch.optim as optim\n",
67 | "import numpy as np\n",
68 | "import os\n",
69 | "\n",
70 | "device = torch.device(\"cuda\")"
71 | ],
72 | "metadata": {
73 | "id": "1mRK6nlmeHz2"
74 | },
75 | "execution_count": null,
76 | "outputs": []
77 | },
78 | {
79 | "cell_type": 
"code", 80 | "source": [ 81 | "class TargetFunction(nn.Module):\n", 82 | " def __init__(self):\n", 83 | " super().__init__()\n", 84 | " self.true_function = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)\n", 85 | "\n", 86 | " def forward(self, t, x):\n", 87 | " return torch.mm(x**3, self.true_function)\n", 88 | "\n", 89 | "starting_point = torch.tensor([[2.0, 0.]]).to(device)\n", 90 | "t = torch.linspace(0., 25., 1000).to(device)\n", 91 | "target_func = TargetFunction().to(device)\n", 92 | "\n", 93 | "with torch.no_grad():\n", 94 | " true_dxdt = odeint(target_func, starting_point, t)\n", 95 | "\n", 96 | "print(\"> Starting point = {}\".format(starting_point.squeeze()))\n", 97 | "print(\"> t.shape = {}\".format(t.shape))\n", 98 | "print(\"> true_dxdt.shape = {}\".format(true_dxdt.shape))" 99 | ], 100 | "metadata": { 101 | "colab": { 102 | "base_uri": "https://localhost:8080/" 103 | }, 104 | "id": "_8mN11OUe-Gl", 105 | "outputId": "36ac52d5-c8fc-4089-f9ae-b4b40281ccdc" 106 | }, 107 | "execution_count": null, 108 | "outputs": [ 109 | { 110 | "output_type": "stream", 111 | "name": "stdout", 112 | "text": [ 113 | "> Starting point = tensor([2., 0.], device='cuda:0')\n", 114 | "> t.shape = torch.Size([1000])\n", 115 | "> true_dxdt.shape = torch.Size([1000, 1, 2])\n" 116 | ] 117 | } 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "source": [ 123 | "def get_batch():\n", 124 | " random_points = np.random.choice(np.arange(990, dtype=np.int64), 768, replace=False)\n", 125 | " batch_starting_point = true_dxdt[random_points]\n", 126 | " batch_t = t[:10]\n", 127 | " batch_dxdt = torch.stack([true_dxdt[random_points + i] for i in range(10)], dim=0)\n", 128 | "\n", 129 | " return batch_starting_point, batch_t, batch_dxdt" 130 | ], 131 | "metadata": { 132 | "id": "gP88LQRLlFY8" 133 | }, 134 | "execution_count": null, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "source": [], 140 | "metadata": { 141 | "id": "-YpdMmXX0Q09" 142 | } 143 | }, 144 | { 145 | "cell_type": "code", 146 | "source": [ 147 | "class PredictFunction(nn.Module):\n", 148 | " def __init__(self):\n", 149 | " super().__init__()\n", 150 | "\n", 151 | " self.net = nn.Sequential(\n", 152 | " nn.Linear(2, 16),\n", 153 | " nn.ReLU(),\n", 154 | " nn.Linear(16, 50),\n", 155 | " nn.ReLU(),\n", 156 | " nn.Linear(50, 50),\n", 157 | " nn.ReLU(),\n", 158 | " nn.Linear(50, 16),\n", 159 | " nn.ReLU(),\n", 160 | " nn.Linear(16, 2)\n", 161 | " )\n", 162 | "\n", 163 | " def forward(self, t, x):\n", 164 | " return self.net(x)" 165 | ], 166 | "metadata": { 167 | "id": "i8eSU_94izdL" 168 | }, 169 | "execution_count": null, 170 | "outputs": [] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "source": [ 175 | "import matplotlib.pyplot as plt\n", 176 | "def visualize(true_y, pred_y, odefunc, itr):\n", 177 | " fig = plt.figure(figsize=(12, 4), facecolor='white')\n", 178 | " ax_traj = fig.add_subplot(131, frameon=False)\n", 179 | " ax_phase = fig.add_subplot(132, frameon=False)\n", 180 | " ax_vecfield = fig.add_subplot(133, frameon=False)\n", 181 | " ax_traj.cla()\n", 182 | " ax_traj.set_title('Trajectories')\n", 183 | " ax_traj.set_xlabel('t')\n", 184 | " ax_traj.set_ylabel('x,y')\n", 185 | " ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-')\n", 186 | " ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--')\n", 187 | " ax_traj.set_xlim(t.cpu().min(), t.cpu().max())\n", 188 | " 
ax_traj.set_ylim(-2, 2)\n",
189 | "\n",
190 | "    ax_phase.cla()\n",
191 | "    ax_phase.set_title('Phase Portrait')\n",
192 | "    ax_phase.set_xlabel('x')\n",
193 | "    ax_phase.set_ylabel('y')\n",
194 | "    ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-')\n",
195 | "    ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--')\n",
196 | "    ax_phase.set_xlim(-2, 2)\n",
197 | "    ax_phase.set_ylim(-2, 2)\n",
198 | "\n",
199 | "    ax_vecfield.cla()\n",
200 | "    ax_vecfield.set_title('Learned Vector Field')\n",
201 | "    ax_vecfield.set_xlabel('x')\n",
202 | "    ax_vecfield.set_ylabel('y')\n",
203 | "\n",
204 | "    y, x = np.mgrid[-2:2:21j, -2:2:21j]\n",
205 | "    dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()\n",
206 | "    mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)\n",
207 | "    dydt = (dydt / mag)\n",
208 | "    dydt = dydt.reshape(21, 21, 2)\n",
209 | "\n",
210 | "    ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color=\"black\")\n",
211 | "    ax_vecfield.set_xlim(-2, 2)\n",
212 | "    ax_vecfield.set_ylim(-2, 2)\n",
213 | "\n",
214 | "    fig.tight_layout()\n",
215 | "    plt.savefig('png/{:03d}.png'.format(itr))\n",
216 | "    #plt.draw()\n",
217 | "    #plt.pause(0.001)"
218 | ],
219 | "metadata": {
220 | "id": "e3-lMYe5uP5F"
221 | },
222 | "execution_count": null,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "source": [
228 | "net = PredictFunction().to(device)\n",
229 | "opt = optim.Adam(net.parameters(), lr=3e-4)\n",
230 | "os.makedirs(\"png\", exist_ok=True)\n",
231 | "\n",
232 | "for i in range(3000+1):\n",
233 | "    opt.zero_grad()\n",
234 | "    batch_starting, batch_t, batch_dxdt = get_batch()\n",
235 | "    pred_dxdt = odeint(net, batch_starting, batch_t).to(device)\n",
236 | "    loss = F.l1_loss(pred_dxdt, batch_dxdt)\n",
237 | "    loss.backward()\n",
238 | "    opt.step()\n",
239 | "    print(\"\\r> [{}/3000] loss = {:.3f}\".format(i, loss.item()), end='')\n",
240 | "\n",
241 | "    if i % 300 == 0:\n",
242 | "        with torch.no_grad():\n",
243 | "            pred_dxdt = odeint(net, starting_point, t)\n",
244 | "            visualize(true_dxdt, pred_dxdt, net, i)\n",
245 | "        print(\"\")"
246 | ],
247 | "metadata": {
248 | "colab": {
249 | "base_uri": "https://localhost:8080/"
250 | },
251 | "id": "CiA9tmiHpkx6",
252 | "outputId": "9dcc072c-117c-463f-9882-586f7832c099"
253 | },
254 | "execution_count": null,
255 | "outputs": [
256 | {
257 | "output_type": "stream",
258 | "name": "stdout",
259 | "text": [
260 | "> [0/3000] loss = 0.087\n",
261 | "> [300/3000] loss = 0.051\n",
262 | "> [600/3000] loss = 0.027\n",
263 | "> [900/3000] loss = 0.013\n",
264 | "> [1200/3000] loss = 0.007\n",
265 | "> [1500/3000] loss = 0.005\n",
266 | "> [1800/3000] loss = 0.003\n",
267 | "> [2100/3000] loss = 0.002\n",
268 | "> [2400/3000] loss = 0.001\n",
269 | "> [2687/3000] loss = 0.001"
270 | ]
271 | }
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "source": [],
277 | "metadata": {
278 | "id": "Jm60tmjN5X7i"
279 | },
280 | "execution_count": null,
281 | "outputs": []
282 | }
283 | ]
284 | }
285 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simple Deep Learning Projects
2 | Each project is implemented as a single Jupyter notebook / might cover each of them on the Imcommit channel
3 | 
4 | 
5 | ### 1. 
DDPM_notebook.ipynb : implementation of the algorithm from the DDPM paper
6 | Paper: [https://arxiv.org/pdf/2006.11239.pdf](https://arxiv.org/pdf/2006.11239.pdf)
7 | Youtube: [https://www.youtube.com/watch?v=svSQhYGKk0Q](https://www.youtube.com/watch?v=svSQhYGKk0Q)
8 | 
9 | Experimental details (Exponential Moving Average, etc.) are left out
10 | - The coefficient implementation is the core part (see the sketch below)
11 | Dataset: CIFAR-10
12 | 
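For reference, a minimal sketch of that coefficient bookkeeping — the standard DDPM linear schedule and the closed-form forward step it enables. This is the textbook formulation, not code copied from the notebook:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)      # linear beta schedule from the paper
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # \bar{alpha}_t = product of alphas up to t

def q_sample(x0, t, eps):
    # closed-form forward noising: x_t = sqrt(abar_t)*x0 + sqrt(1 - abar_t)*eps
    ab = alpha_bars[t].view(-1, 1, 1, 1)   # broadcast over (B, C, H, W)
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * eps
```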
13 | ### 2. pi_digit_estimation_GRU.ipynb : predicting patterns in the digits of π
14 | [https://youtu.be/kdmrlMAaCiA](https://youtu.be/kdmrlMAaCiA)
15 | 
16 | As if that could ever work
17 | 
18 | ### 3. Neural_ODE.ipynb : lightly modified version of the official code
19 | Paper: [https://arxiv.org/pdf/1806.07366.pdf](https://arxiv.org/pdf/1806.07366.pdf)
20 | Official github(Reference): [https://github.com/rtqichen/torchdiffeq](https://github.com/rtqichen/torchdiffeq)
21 | [https://www.youtube.com/watch?v=NS-C_QjjcT4](https://www.youtube.com/watch?v=NS-C_QjjcT4)
22 | 
23 | Cleaned up the output code (the basic odeint contract is sketched below)
24 | Changed the model configuration (to make training slightly harder than with the official code)
25 | Dataset: toy example
26 | 
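For orientation, a minimal sketch of that odeint contract: torchdiffeq integrates any `nn.Module` whose `forward(t, y)` returns dy/dt, and `odeint_adjoint` backpropagates through the solve via the adjoint method. The toy `ODEFunc` here is illustrative, not the notebook's model:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)  # placeholder dynamics; the notebook uses a deeper MLP

    def forward(self, t, y):
        return self.net(y)          # dy/dt at time t

y0 = torch.randn(1, 2)              # initial state
t = torch.linspace(0., 1., 20)      # times at which the solution is reported
ys = odeint(ODEFunc(), y0, t)       # shape [20, 1, 2], differentiable w.r.t. net's parameters
```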
27 | ### 4. vae.ipynb : VAE, AE, AE with z-regularization
28 | paper: [https://arxiv.org/pdf/1312.6114.pdf](https://arxiv.org/pdf/1312.6114.pdf)
29 | 
30 | Autoencoder: reconstruction loss only
31 | AE with z-regularization: loss = reconstruction loss + z.abs().mean()
32 | VAE: loss = reconstruction loss + KL divergence
33 | Dataset: MNIST dataset (the three objectives are sketched below)
34 | 
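A minimal sketch of the three objectives, assuming `x_hat` is the reconstruction and `z` (or `mu`/`logvar`) comes from the encoder — the notebook's module definitions are not reproduced here:

```python
import torch
import torch.nn.functional as F

def ae_loss(x_hat, x):
    return F.mse_loss(x_hat, x)                   # reconstruction only

def ae_zreg_loss(x_hat, x, z):
    return F.mse_loss(x_hat, x) + z.abs().mean()  # reconstruction + L1 pull on the latent

def vae_loss(x_hat, x, mu, logvar):
    recon = F.mse_loss(x_hat, x)
    # KL( N(mu, sigma^2) || N(0, I) ) in closed form
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```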
35 | ### 5. flowmatching.ipynb : flow matching for generative modeling
36 | Paper: [https://arxiv.org/pdf/2210.02747](https://arxiv.org/pdf/2210.02747)
37 | 
38 | Experimental details (Exponential Moving Average, etc.) are left out
39 | - The train and sample functions are the key parts (a condensed version is sketched below)
40 | - Dataset: MNIST
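A condensed, unconditional version of that train/sample pair — label conditioning and classifier-free guidance stripped out, and the `model(x_t, t)` signature simplified relative to the notebook:

```python
import torch
import torch.nn.functional as F

def fm_loss(model, x):
    t = torch.rand(())                  # t ~ U[0, 1]
    eps = torch.randn_like(x)           # eps ~ N(0, I)
    x_t = t * x + (1 - t) * eps         # straight path from noise to data
    return F.mse_loss(model(x_t, t), x - eps)    # regress the constant velocity

@torch.no_grad()
def fm_sample(model, z, num_step=10):
    x_t, dt = z, 1.0 / num_step
    for i in range(num_step):
        x_t = x_t + dt * model(x_t, i * dt)      # Euler step along the learned velocity
    return x_t
```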
41 | 
--------------------------------------------------------------------------------
/base/mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import MNIST
3 | from torchvision.transforms import ToTensor
4 | 
5 | class MNISTData:
6 |     def __init__(self):
7 |         dataset = MNIST(root='./data', train=True, download=True)
8 |         self.dataset = [(ToTensor()(img), target) for img, target in dataset]
9 |         val_dataset = MNIST(root='./data', train=False, download=True)
10 |         self.val_dataset = [(ToTensor()(img), target) for img, target in val_dataset]
11 | 
12 |     def __len__(self):
13 |         return len(self.dataset)
14 | 
15 |     def __getitem__(self, idx):
16 |         return self.dataset[idx]
--------------------------------------------------------------------------------
/base/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class Net(nn.Module):
5 |     def __init__(self):
6 |         super().__init__()
7 |         self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1, stride=2)
8 |         self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1, stride=2)
9 |         self.fc1 = nn.Linear(7*7*16, 10)
10 | 
11 |     def forward(self, x):
12 |         x = torch.relu(self.conv1(x))
13 |         x = torch.relu(self.conv2(x))
14 |         x = x.view(-1, 7*7*16)
15 |         x = self.fc1(x)
16 |         return x
--------------------------------------------------------------------------------
/base/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | 
4 | from mnist import MNISTData
5 | from model import Net
6 | 
7 | class Trainer:
8 |     def __init__(self):
9 |         self.model = Net()
10 |         dataset = MNISTData()
11 |         self.dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
12 |         self.val_dataloader = DataLoader(dataset.val_dataset, batch_size=32, shuffle=False)
13 |         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
14 |         self.criterion = torch.nn.CrossEntropyLoss()
15 | 
16 |     def train(self, epochs):
17 |         for epoch in range(epochs):
18 |             for i, (data, target) in enumerate(self.dataloader, 1):
19 |                 self.optimizer.zero_grad()
20 |                 output = self.model(data)
21 |                 loss = self.criterion(output, target)
22 |                 loss.backward()
23 |                 self.optimizer.step()
24 |                 if i % 100 == 0:
25 |                     print(f"\rEpoch {epoch}, Loss {loss.item()}", end='')
26 |             print()
27 | 
28 |     def test(self):
29 |         correct = 0
30 |         total = 0
31 |         with torch.no_grad():
32 |             for data, target in self.val_dataloader:
33 |                 output = self.model(data)
34 |                 _, predicted = torch.max(output.data, 1)
35 |                 total += target.size(0)
36 |                 correct += (predicted == target).sum().item()
37 |         print(f"Accuracy: {correct/total}, {correct=}, {total=}")
38 | 
39 | 
40 | def main():
41 |     trainer = Trainer()
42 |     trainer.train(5)
43 |     trainer.test()
44 | 
45 | if __name__ == "__main__":
46 |     main()
--------------------------------------------------------------------------------
/flowmatching.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "from torch.nn import init\n",
12 | "import torch.nn.functional as F\n",
13 | "from torchvision.datasets import MNIST\n",
14 
| "from torchvision import transforms\n", 15 | "\n", 16 | "from random import random\n", 17 | "from cv2 import imwrite\n", 18 | "import math\n", 19 | "import os" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "dataset = MNIST(\n", 29 | " root=\"./data\", train=True, download=True,\n", 30 | " transform=transforms.Compose([\n", 31 | " transforms.ToTensor()\n", 32 | " ])\n", 33 | ")\n", 34 | "dataloader = torch.utils.data.DataLoader(\n", 35 | " dataset, batch_size=128, shuffle=True, num_workers=4\n", 36 | ")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def train(model, x, label):\n", 46 | " t = random()\n", 47 | " eps = torch.randn_like(x) # eps ~ N(0, I)\n", 48 | " x_t = t * x + (1-t) * eps\n", 49 | " drop_label = random() < 0.3\n", 50 | " hat = model(x_t, t, label, drop_label)\n", 51 | " loss = F.mse_loss(hat, x-eps)\n", 52 | " return loss\n", 53 | "\n", 54 | "def write(x_0):\n", 55 | " x_0 = x_0.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy() * 255\n", 56 | " imwrite(\"temp.png\", x_0[0])\n", 57 | "\n", 58 | "def sample(model, z, label, num_step=10, w=2.0, cfg=True):\n", 59 | " x_t = z # ~ N(0, I)\n", 60 | " ts = torch.linspace(0, 1, num_step+1) # num_step = 10 -> ts : [.0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 10]\n", 61 | " dts = ts[1:] - ts[:-1] # [.1,] * 10\n", 62 | " for t, dt in zip(ts[:-1], dts): # [.0, .1], [.1, .1], [.2, .1], ..., [.9, .1]\n", 63 | " hat = model(x_t, t, label, drop_label=False)\n", 64 | " if cfg:\n", 65 | " hat_uncond = model(x_t, t, label, drop_label=True)\n", 66 | " hat = (1+w) * hat - w * hat_uncond\n", 67 | " x_t = x_t + dt * hat\n", 68 | " return x_t" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "class Swish(nn.Module):\n", 78 | " def forward(self, x):\n", 79 | " return x * torch.sigmoid(x)\n", 80 | "\n", 81 | "class TimeEmbedding(nn.Module):\n", 82 | " def __init__(self, T, d_model, dim):\n", 83 | " assert d_model % 2 == 0\n", 84 | " super().__init__()\n", 85 | " emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n", 86 | " emb = torch.exp(-emb)\n", 87 | " pos = torch.arange(T).float()\n", 88 | " emb = pos[:, None] * emb[None, :]\n", 89 | " emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n", 90 | " emb = emb.view(T, d_model)\n", 91 | "\n", 92 | " self.timembedding = nn.Sequential(\n", 93 | " nn.Linear(1, d_model),\n", 94 | " nn.Linear(d_model, dim),\n", 95 | " Swish(),\n", 96 | " nn.Linear(dim, dim),\n", 97 | " )\n", 98 | " self.initialize()\n", 99 | "\n", 100 | " def initialize(self):\n", 101 | " for module in self.modules():\n", 102 | " if isinstance(module, nn.Linear):\n", 103 | " init.xavier_uniform_(module.weight)\n", 104 | " init.zeros_(module.bias)\n", 105 | "\n", 106 | " def forward(self, t):\n", 107 | " emb = self.timembedding(t)\n", 108 | " return emb\n", 109 | "\n", 110 | "class DownSample(nn.Module):\n", 111 | " def __init__(self, in_ch):\n", 112 | " super().__init__()\n", 113 | " self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)\n", 114 | " self.initialize()\n", 115 | "\n", 116 | " def initialize(self):\n", 117 | " init.xavier_uniform_(self.main.weight)\n", 118 | " init.zeros_(self.main.bias)\n", 119 | "\n", 120 | " def forward(self, x, temb):\n", 121 | " x = self.main(x)\n", 122 | " return x\n", 123 | "\n", 124 | "class 
UpSample(nn.Module):\n", 125 | " def __init__(self, in_ch):\n", 126 | " super().__init__()\n", 127 | " self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n", 128 | " self.initialize()\n", 129 | "\n", 130 | " def initialize(self):\n", 131 | " init.xavier_uniform_(self.main.weight)\n", 132 | " init.zeros_(self.main.bias)\n", 133 | "\n", 134 | " def forward(self, x, temb):\n", 135 | " _, _, H, W = x.shape\n", 136 | " x = F.interpolate(\n", 137 | " x, scale_factor=2, mode='nearest')\n", 138 | " x = self.main(x)\n", 139 | " return x\n", 140 | "\n", 141 | "class AttnBlock(nn.Module):\n", 142 | " def __init__(self, in_ch):\n", 143 | " super().__init__()\n", 144 | " self.group_norm = nn.GroupNorm(32, in_ch)\n", 145 | " self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", 146 | " self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", 147 | " self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", 148 | " self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n", 149 | " self.initialize()\n", 150 | "\n", 151 | " def initialize(self):\n", 152 | " for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n", 153 | " init.xavier_uniform_(module.weight)\n", 154 | " init.zeros_(module.bias)\n", 155 | " init.xavier_uniform_(self.proj.weight, gain=1e-5)\n", 156 | "\n", 157 | " def forward(self, x):\n", 158 | " B, C, H, W = x.shape\n", 159 | " h = self.group_norm(x)\n", 160 | " q = self.proj_q(h)\n", 161 | " k = self.proj_k(h)\n", 162 | " v = self.proj_v(h)\n", 163 | "\n", 164 | " q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n", 165 | " k = k.view(B, C, H * W)\n", 166 | " w = torch.bmm(q, k) * (int(C) ** (-0.5))\n", 167 | " assert list(w.shape) == [B, H * W, H * W]\n", 168 | " w = F.softmax(w, dim=-1)\n", 169 | "\n", 170 | " v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n", 171 | " h = torch.bmm(w, v)\n", 172 | " assert list(h.shape) == [B, H * W, C]\n", 173 | " h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n", 174 | " h = self.proj(h)\n", 175 | "\n", 176 | " return x + h\n", 177 | "\n", 178 | "class ResBlock(nn.Module):\n", 179 | " def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):\n", 180 | " super().__init__()\n", 181 | " self.block1 = nn.Sequential(\n", 182 | " nn.GroupNorm(32, in_ch),\n", 183 | " Swish(),\n", 184 | " nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),\n", 185 | " )\n", 186 | " self.temb_proj = nn.Sequential(\n", 187 | " Swish(),\n", 188 | " nn.Linear(tdim, out_ch),\n", 189 | " )\n", 190 | " self.block2 = nn.Sequential(\n", 191 | " nn.GroupNorm(32, out_ch),\n", 192 | " Swish(),\n", 193 | " nn.Dropout(dropout),\n", 194 | " nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),\n", 195 | " )\n", 196 | " if in_ch != out_ch:\n", 197 | " self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)\n", 198 | " else:\n", 199 | " self.shortcut = nn.Identity()\n", 200 | " if attn:\n", 201 | " self.attn = AttnBlock(out_ch)\n", 202 | " else:\n", 203 | " self.attn = nn.Identity()\n", 204 | " self.initialize()\n", 205 | "\n", 206 | " def initialize(self):\n", 207 | " for module in self.modules():\n", 208 | " if isinstance(module, (nn.Conv2d, nn.Linear)):\n", 209 | " init.xavier_uniform_(module.weight)\n", 210 | " init.zeros_(module.bias)\n", 211 | " init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n", 212 | "\n", 213 | " def forward(self, x, temb):\n", 214 | " h = self.block1(x)\n", 215 | " h += self.temb_proj(temb)[:, :, None, None]\n", 216 | " h = self.block2(h)\n", 217 | "\n", 218 | " h = h + self.shortcut(x)\n", 219 | " h = 
self.attn(h)\n",
220 | "        return h\n",
221 | "\n",
222 | "class UNet(nn.Module):\n",
223 | "    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n",
224 | "        super().__init__()\n",
225 | "        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n",
226 | "        tdim = ch * 4\n",
227 | "        self.time_embedding = TimeEmbedding(T, ch, tdim)\n",
228 | "        self.label_embedding = nn.Embedding(10, tdim)\n",
229 | "\n",
230 | "        self.head = nn.Conv2d(1, ch, kernel_size=3, stride=1, padding=1)\n",
231 | "        self.downblocks = nn.ModuleList()\n",
232 | "        chs = [ch]  # record output channels during downsampling, for the upsampling path\n",
233 | "        now_ch = ch\n",
234 | "        for i, mult in enumerate(ch_mult):\n",
235 | "            out_ch = ch * mult\n",
236 | "            for _ in range(num_res_blocks):\n",
237 | "                self.downblocks.append(ResBlock(\n",
238 | "                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n",
239 | "                    dropout=dropout, attn=(i in attn)))\n",
240 | "                now_ch = out_ch\n",
241 | "                chs.append(now_ch)\n",
242 | "            if i != len(ch_mult) - 1:\n",
243 | "                self.downblocks.append(DownSample(now_ch))\n",
244 | "                chs.append(now_ch)\n",
245 | "\n",
246 | "        self.middleblocks = nn.ModuleList([\n",
247 | "            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n",
248 | "            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n",
249 | "        ])\n",
250 | "\n",
251 | "        self.upblocks = nn.ModuleList()\n",
252 | "        for i, mult in reversed(list(enumerate(ch_mult))):\n",
253 | "            out_ch = ch * mult\n",
254 | "            for _ in range(num_res_blocks + 1):\n",
255 | "                self.upblocks.append(ResBlock(\n",
256 | "                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n",
257 | "                    dropout=dropout, attn=(i in attn)))\n",
258 | "                now_ch = out_ch\n",
259 | "            if i != 0:\n",
260 | "                self.upblocks.append(UpSample(now_ch))\n",
261 | "        assert len(chs) == 0\n",
262 | "\n",
263 | "        self.tail = nn.Sequential(\n",
264 | "            nn.GroupNorm(32, now_ch),\n",
265 | "            Swish(),\n",
266 | "            nn.Conv2d(now_ch, 1, 3, stride=1, padding=1)\n",
267 | "        )\n",
268 | "        self.initialize()\n",
269 | "\n",
270 | "    def initialize(self):\n",
271 | "        init.xavier_uniform_(self.head.weight)\n",
272 | "        init.zeros_(self.head.bias)\n",
273 | "        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n",
274 | "        init.zeros_(self.tail[-1].bias)\n",
275 | "\n",
276 | "    def forward(self, x, t, label, drop_label):\n",
277 | "        # Timestep embedding\n",
278 | "        t = torch.tensor([t], device=x.device)\n",
279 | "        temb = self.time_embedding(t)\n",
280 | "        lemb = self.label_embedding(label)\n",
281 | "        if drop_label: lemb = torch.zeros_like(lemb)\n",
282 | "        temb = temb + lemb\n",
283 | "        # Downsampling\n",
284 | "        h = self.head(x)\n",
285 | "        hs = [h]\n",
286 | "        for layer in self.downblocks:\n",
287 | "            h = layer(h, temb)\n",
288 | "            hs.append(h)\n",
289 | "        # Middle\n",
290 | "        for layer in self.middleblocks:\n",
291 | "            h = layer(h, temb)\n",
292 | "        # Upsampling\n",
293 | "        for layer in self.upblocks:\n",
294 | "            if isinstance(layer, ResBlock):\n",
295 | "                h = torch.cat([h, hs.pop()], dim=1)\n",
296 | "            h = layer(h, temb)\n",
297 | "        h = self.tail(h)\n",
298 | "\n",
299 | "        assert len(hs) == 0\n",
300 | "        return h"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": null,
306 | "metadata": {},
307 | "outputs": [],
308 | "source": [
309 | "device = torch.device(\"cuda\")\n",
310 | "model = UNet(T=1, ch=128, ch_mult=[2, 2], attn=[1],\n",
311 | "             num_res_blocks=2, dropout=0.1).to(device)\n",
312 | "optim = torch.optim.Adam(model.parameters(), lr=2e-4)"
313 | ]
314 | },
315 | {
316 | "cell_type": "code", 
317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "for e in range(1, 10+1):\n", 322 | " model.train()\n", 323 | " for i, (x, l) in enumerate(dataloader, 1):\n", 324 | " optim.zero_grad()\n", 325 | " x, l = x.to(device), l.to(device)\n", 326 | " loss = train(model, x, l)\n", 327 | " loss.backward()\n", 328 | " optim.step()\n", 329 | " print(\"\\r[Epoch: {} , Iter: {}/{}] Loss: {:.3f}\".format(e, i, len(dataloader), loss.item()), end='')\n", 330 | " print(\"\\n> Eval at epoch {}\".format(e))\n", 331 | " model.eval()\n", 332 | " with torch.no_grad():\n", 333 | " os.makedirs(\"cfm\", exist_ok=True)\n", 334 | " x_T = torch.randn(1, 1, 28, 28).to(device)\n", 335 | " labels = [torch.zeros([1], device=device, dtype=torch.long) + i for i in range(10)]\n", 336 | " for i in range(10):\n", 337 | " x_0 = sample(model, x_T, labels[i])\n", 338 | " x_0 = x_0.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy() * 255\n", 339 | " imwrite(f\"cfm/gen_cfg_{e}_{i}.png\", x_0[0])" 340 | ] 341 | } 342 | ], 343 | "metadata": { 344 | "kernelspec": { 345 | "display_name": "foo", 346 | "language": "python", 347 | "name": "python3" 348 | }, 349 | "language_info": { 350 | "name": "python", 351 | "version": "3.10.14" 352 | } 353 | }, 354 | "nbformat": 4, 355 | "nbformat_minor": 2 356 | } 357 | -------------------------------------------------------------------------------- /pi_digit_estimation_GRU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "bDlRyKYTnjE-" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "import torch.nn.functional as F\n", 14 | "from torch.utils.data import DataLoader\n", 15 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "colab": { 23 | "base_uri": "https://localhost:8080/" 24 | }, 25 | "id": "eAwmGxvaooYc", 26 | "outputId": "b203b108-bb6f-4893-a872-c09d46be2e2a" 27 | }, 28 | "outputs": [ 29 | { 30 | "output_type": "stream", 31 | "name": "stdout", 32 | "text": [ 33 | "1000001\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "with open(\"pi_million.txt\") as f:\n", 39 | " pi = f.read()\n", 40 | " pi = pi[0] + pi[2:]\n", 41 | "print(len(pi)) # \"3.\" + 1,000,000 digits" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "id": "QyxkcuxnW6Yb" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "class Dataset:\n", 53 | " def __init__(self, pi_str, _start, _end, _len=512):\n", 54 | " self.data = self.fetch_data(pi_str, \n", 55 | " start=_start, \n", 56 | " end=_end, \n", 57 | " seq_len=_len)\n", 58 | "\n", 59 | " def fetch_data(self, pi_str, start, end, seq_len):\n", 60 | " sequences = [pi_str[i:i+seq_len] for i in range(start, end)]\n", 61 | " digit_data = [[int(c) for c in seq] for seq in sequences]\n", 62 | "\n", 63 | " return digit_data\n", 64 | " \n", 65 | " def __len__(self):\n", 66 | " return len(self.data)\n", 67 | " \n", 68 | " def __getitem__(self, idx):\n", 69 | " return self.data[idx]\n", 70 | "\n", 71 | "def col_fn(batch):\n", 72 | " return torch.stack([torch.LongTensor(b) for b in batch])\n", 73 | "\n", 74 | "train_set = Dataset(pi, _start=0, _end=100000)\n", 75 | "test_set = Dataset(pi, _start=100000, _end=103000)\n", 76 | "train_dl = 
DataLoader(train_set, batch_size=256, drop_last=False, collate_fn = col_fn)\n", 77 | "test_dl = DataLoader(test_set, batch_size=256, drop_last=False, collate_fn = col_fn)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "id": "JnHQ0BB2oqIO" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "class Model(nn.Module):\n", 89 | " def __init__(self, num_digit, in_dim, hidden_dim, out_dim):\n", 90 | " super().__init__()\n", 91 | " self.embedding = nn.Embedding(num_digit, in_dim)\n", 92 | " self.in_gru = nn.GRU(in_dim, hidden_dim, batch_first=True)\n", 93 | " self.latent_fc = nn.Linear(hidden_dim, hidden_dim)\n", 94 | " self.out_gru = nn.GRU(hidden_dim, out_dim, batch_first=True)\n", 95 | " self.out_fc = nn.Linear(out_dim, num_digit)\n", 96 | " \n", 97 | " def forward(self, x, return_loss=True):\n", 98 | " emb_out = self.embedding(x)\n", 99 | " hidden, _ = self.in_gru(emb_out)\n", 100 | " latent = self.latent_fc(hidden)\n", 101 | " out, _ = self.out_gru(latent)\n", 102 | " out_digit = self.out_fc(out[:, -2])\n", 103 | "\n", 104 | " if return_loss:\n", 105 | " loss = self.get_loss(out_digit, x)\n", 106 | " return loss\n", 107 | " else:\n", 108 | " return out_digit\n", 109 | "\n", 110 | "\n", 111 | " def get_loss(self, digit_logit, x):\n", 112 | " x_last = x[:, -1]\n", 113 | "\n", 114 | " loss = F.cross_entropy(digit_logit, x_last)\n", 115 | " return loss\n", 116 | "\n", 117 | " @torch.no_grad()\n", 118 | " def generate(self, x):\n", 119 | " digit_logit = self.forward(x, return_loss=False)\n", 120 | "\n", 121 | " prob = torch.softmax(digit_logit, -1)\n", 122 | " sample = torch.argmax(prob, -1)\n", 123 | "\n", 124 | " return sample" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "id": "9515oV3gRJgH" 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "model_config = {\n", 136 | " \"num_digit\": 10,\n", 137 | " \"in_dim\": 512,\n", 138 | " \"hidden_dim\": 1024,\n", 139 | " \"out_dim\": 512\n", 140 | "}\n", 141 | "\n", 142 | "model = Model(**model_config)\n", 143 | "model.to(device)\n", 144 | "optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.999])" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": { 151 | "colab": { 152 | "base_uri": "https://localhost:8080/" 153 | }, 154 | "id": "pQtA2imoulEA", 155 | "outputId": "cb6eec81-4932-4802-8a8b-c3a213c39e35" 156 | }, 157 | "outputs": [ 158 | { 159 | "output_type": "stream", 160 | "name": "stdout", 161 | "text": [ 162 | "idx: 0 [0 / 391], loss= 2.306\n", 163 | "idx: 500 [109 / 391], loss= 2.308\n", 164 | "idx: 1,000 [218 / 391], loss= 2.310\n", 165 | "idx: 1,500 [327 / 391], loss= 2.318\n", 166 | "idx: 2,000 [45 / 391], loss= 2.309\n", 167 | "idx: 2,500 [154 / 391], loss= 2.302\n", 168 | "idx: 3,000 [263 / 391], loss= 2.305\n", 169 | "idx: 3,500 [372 / 391], loss= 2.310\n", 170 | "idx: 4,000 [90 / 391], loss= 2.296\n", 171 | "idx: 4,500 [199 / 391], loss= 2.312\n", 172 | "idx: 5,000 [308 / 391], loss= 2.313\n", 173 | "idx: 5,500 [26 / 391], loss= 2.298\n", 174 | "idx: 6,000 [135 / 391], loss= 2.303\n", 175 | "idx: 6,500 [244 / 391], loss= 2.315\n", 176 | "idx: 7,000 [353 / 391], loss= 2.303\n", 177 | "idx: 7,500 [71 / 391], loss= 2.302\n", 178 | "idx: 8,000 [180 / 391], loss= 2.304\n", 179 | "idx: 8,500 [289 / 391], loss= 2.309\n", 180 | "idx: 9,000 [7 / 391], loss= 2.304\n", 181 | "idx: 9,500 [116 / 391], loss= 2.300\n", 182 | "idx: 10,000 [225 / 391], loss= 2.308\n", 183 | 
"idx: 10,500 [334 / 391], loss= 2.306\n", 184 | "idx: 11,000 [52 / 391], loss= 2.309\n", 185 | "idx: 11,500 [161 / 391], loss= 2.304\n", 186 | "idx: 12,000 [270 / 391], loss= 2.317\n", 187 | "idx: 12,500 [379 / 391], loss= 2.308\n", 188 | "idx: 13,000 [97 / 391], loss= 2.318\n", 189 | "idx: 13,500 [206 / 391], loss= 2.314\n", 190 | "idx: 14,000 [315 / 391], loss= 2.300\n", 191 | "idx: 14,500 [33 / 391], loss= 2.289\n", 192 | "idx: 15,000 [142 / 391], loss= 2.308\n", 193 | "idx: 15,500 [251 / 391], loss= 2.318\n", 194 | "idx: 16,000 [360 / 391], loss= 2.310\n", 195 | "idx: 16,500 [78 / 391], loss= 2.315\n", 196 | "idx: 17,000 [187 / 391], loss= 2.304\n", 197 | "idx: 17,500 [296 / 391], loss= 2.303\n", 198 | "idx: 18,000 [14 / 391], loss= 2.315\n", 199 | "idx: 18,500 [123 / 391], loss= 2.302\n", 200 | "idx: 19,000 [232 / 391], loss= 2.302\n", 201 | "idx: 19,500 [341 / 391], loss= 2.307\n", 202 | "idx: 19,549 [390 / 391], loss= 2.316" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "idx = 0\n", 208 | "epochs = 50\n", 209 | "for e in range(epochs):\n", 210 | " for i, x in enumerate(train_dl):\n", 211 | " x = x.to(device)\n", 212 | "\n", 213 | " loss = model(x)\n", 214 | " optim.zero_grad()\n", 215 | " loss.backward()\n", 216 | " optim.step()\n", 217 | "\n", 218 | " print(f\"\\ridx: {idx:,} [{i} / {len(train_dl)}], loss= {loss.item():.3f}\", end='')\n", 219 | " if idx % 500 == 0:\n", 220 | " print(\"\")\n", 221 | "\n", 222 | " idx += 1\n" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "source": [ 228 | "torch.save(model.state_dict(), \"512_1024_10k.ckpt\")" 229 | ], 230 | "metadata": { 231 | "id": "mkqNRrG3LKyX" 232 | }, 233 | "execution_count": null, 234 | "outputs": [] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "source": [ 239 | "# **Evaluation**\n", 240 | "Check whether pi estimator works well in validation dataset" 241 | ], 242 | "metadata": { 243 | "id": "3D5PFjHhC8J7" 244 | } 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "id": "ggDal1N6yUdB", 251 | "colab": { 252 | "base_uri": "https://localhost:8080/" 253 | }, 254 | "outputId": "d130904a-f279-4c77-aebf-af80d193b066" 255 | }, 256 | "outputs": [ 257 | { 258 | "output_type": "stream", 259 | "name": "stdout", 260 | "text": [ 261 | "total_correct = 9999, total_wrong = 1\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "def check_correct_wrong(sample, true):\n", 267 | " batch_len = len(sample)\n", 268 | " correct = sample == true\n", 269 | " correct_num = correct.sum()\n", 270 | " wrong_num = batch_len - correct_num\n", 271 | "\n", 272 | " return correct_num, wrong_num\n", 273 | "\n", 274 | "model_config = {\n", 275 | " \"num_digit\": 10,\n", 276 | " \"in_dim\": 512,\n", 277 | " \"hidden_dim\": 1024,\n", 278 | " \"out_dim\": 512\n", 279 | "}\n", 280 | "\n", 281 | "model = Model(**model_config)\n", 282 | "state_dict = torch.load(\"512_1024_10k.ckpt\")\n", 283 | "model.load_state_dict(state_dict)\n", 284 | "model.to(device)\n", 285 | "\n", 286 | "total_correct, total_wrong = 0, 0\n", 287 | "for x in train_dl:\n", 288 | " x = x.to(device)\n", 289 | " sample = model.generate(x)\n", 290 | " correct_num, wrong_num = check_correct_wrong(sample, x[:, -1])\n", 291 | " total_correct += correct_num\n", 292 | " total_wrong += wrong_num\n", 293 | "print(f\"total_correct = {total_correct}, total_wrong = {total_wrong}\")" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "source": [], 299 | "metadata": { 300 | "id": "3UbnOa4hbv4q" 301 | }, 302 | "execution_count": null, 
303 | "outputs": [] 304 | } 305 | ], 306 | "metadata": { 307 | "accelerator": "GPU", 308 | "colab": { 309 | "provenance": [] 310 | }, 311 | "gpuClass": "premium", 312 | "kernelspec": { 313 | "display_name": "Python 3", 314 | "name": "python3" 315 | }, 316 | "language_info": { 317 | "name": "python" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 0 322 | } -------------------------------------------------------------------------------- /vae.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "device = cuda\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import torch\n", 18 | "import torch.nn as nn\n", 19 | "from torch.utils.data import DataLoader\n", 20 | "\n", 21 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 22 | "print(f\"device = {device}\")" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", 35 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST\\raw\\train-images-idx3-ubyte.gz\n" 36 | ] 37 | }, 38 | { 39 | "name": "stderr", 40 | "output_type": "stream", 41 | "text": [ 42 | "100%|██████████| 9912422/9912422 [00:00<00:00, 30112030.95it/s]\n" 43 | ] 44 | }, 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "Extracting ./MNIST\\raw\\train-images-idx3-ubyte.gz to ./MNIST\\raw\n", 50 | "\n", 51 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", 52 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST\\raw\\train-labels-idx1-ubyte.gz\n" 53 | ] 54 | }, 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "100%|██████████| 28881/28881 [00:00<00:00, 1653458.73it/s]\n" 60 | ] 61 | }, 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Extracting ./MNIST\\raw\\train-labels-idx1-ubyte.gz to ./MNIST\\raw\n", 67 | "\n", 68 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", 69 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST\\raw\\t10k-images-idx3-ubyte.gz\n" 70 | ] 71 | }, 72 | { 73 | "name": "stderr", 74 | "output_type": "stream", 75 | "text": [ 76 | "100%|██████████| 1648877/1648877 [00:00<00:00, 10304400.89it/s]\n" 77 | ] 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "Extracting ./MNIST\\raw\\t10k-images-idx3-ubyte.gz to ./MNIST\\raw\n", 84 | "\n", 85 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", 86 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n" 87 | ] 88 | }, 89 | { 90 | "name": "stderr", 91 | "output_type": "stream", 92 | "text": [ 93 | "100%|██████████| 4542/4542 [00:00) tensor(-3.1758, device='cuda:0', grad_fn=) tensor(1.0295, device='cuda:0', grad_fn=)\n" 238 | ] 239 | }, 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "" 244 | ] 245 | }, 246 | "execution_count": 13, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | }, 250 | { 251 | "data": { 252 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgQUlEQVR4nO3de3DV9f3n8dfJ7XBLTgwhNwkYUMEKxEohpSrFkiXE1gVlO946A44/XGlwitTq0lVR25m0+Bvr6o/q/GZaqLvibVZg5Ke4CiasbaAFoZSq+RF+UYKQIKnJyYVcz2f/YE17FNTP8YR3Ep6Pme8MOef74vvhyze88s05vBNwzjkBAHCWJVgvAABwbqKAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYCLJegGfFolEdPToUaWmpioQCFgvBwDgyTmnlpYW5eXlKSHhzPc5A66Ajh49qvz8fOtlAAC+orq6Oo0dO/aMzw+4AkpNTZUkXalrlKRk49UAAHz1qFtv6ZW+f8/PpN8KaO3atXrkkUdUX1+vwsJCPfHEE5o5c+YX5j75tluSkpUUoIAAYND5/xNGv+hllH55E8Lzzz+vlStXavXq1Xr77bdVWFiokpISHT9+vD8OBwAYhPqlgB599FEtXbpUt956q772ta/pqaee0ogRI/Tb3/62Pw4HABiE4l5AXV1d2rNnj4qLi/9+kIQEFRcXq6qq6jP7d3Z2KhwOR20AgKEv7gV04sQJ9fb2Kjs7O+rx7Oxs1dfXf2b/8vJyhUKhvo13wAHAucH8P6KuWrVKzc3NfVtdXZ31kgAAZ0Hc3wWXmZmpxMRENTQ0RD3e0NCgnJycz+wfDAYVDAbjvQwAwAAX9zuglJQUTZ8+Xdu2bet7LBKJaNu2bZo1a1a8DwcAGKT65f8BrVy5UosXL9Y3vvENzZw5U4899pja2tp066239sfhAACDUL8U0A033KCPPvpIDzzwgOrr63XZZZdp69atn3ljAgDg3BVwzjnrRfyjcDisUCikOVrAJAQAGIR6XLcqtFnNzc1KS0s7437m74IDAJybKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgIkk6wUAA0kgKYZPicDA/TrO9fbGGIzEkHGxHQvnrIH7mQMAGNIoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBgpBr5AwDuSODojpkN1XF7gnfnoshTvTNtY/yGhSe3+Xy8GG/3PnSSlfeA/jHTU4XbvTFLdCe9M74lG74zr7PTOoP9xBwQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEw0hxdiUkekcSM9K9Mye+e7F3RpI+LjnpnZk69gP/TOiodyYY6PHOHO9O9c5I0n+0Znpn3vkwxzuT/uZ470xW5XDvTKS2zjsjSa67K6YcvhzugAAAJiggAICJuBfQgw8+qEAgELVNnjw53ocBAAxy/fIa0KWXXqo33njj7wdJ4qUmAEC0fmmGpKQk5eT4vyAJADh39MtrQAcPHlReXp4mTJigW265RYcPHz7jvp2dnQqHw1EbAGDoi3sBFRUVaf369dq6dauefPJJ1dbW6qqrrlJLS8tp9y8vL1coFOrb8vPz470kAMAAFPcCKi0t1fe//31NmzZNJSUleuWVV9TU1KQXXnjhtPuvWrVKzc3NfVtdXWzv1wcADC79/u6A9PR0XXzxxaqpqTnt88FgUMFgsL+XAQAYYPr9/wG1trbq0KFDys3N7e9DAQAGkbgX0N13363Kykq9//77+sMf/qDrrrtOiYmJuummm+J9KADAIBb3b8EdOXJEN910kxobGzVmzBhdeeWV2rlzp8aMGRPvQwEABrG4F9Bzzz0X798SA1Usg0VDad6Z4wv9B4sO+y8N3hlJ+qfz/+KdmTLc/40zLb3+AzVrOrO9M+cHP/bOSNKMUbXemaYxI7wzvxn1Le/M39r9v5gdHW71zkhSz/ET/qFIb0zHOhcxCw4AYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAICJfv+BdBgEAoGYYgnD/H+QYOflE7wzmTcf9s78MP9N74wkpSe2e2d2t/v/mf5lzxzvTKAxxT+T3emdkaSCHP8hnNPSP/TOXH/Bn70zv/vuN70zo46e752RpJT2k96Z3pYW/wM5558ZArgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBo2FEjxn7IsSYELxnpn/uMW/8nbzxds8s7kJcY2BXpL6yTvzNrX53lncnZ5RxSI+Gfasof7hyQdycj3zhwrTPPOfH/iXu/Mf5603zuz+bszvTOSdGHY/zwE9h/0zrjuLu/MUMAdEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMMIx1qAv7DPhPS/IdISlL97NHemX++8hnvzPSURO/MnzpjG8L5yGvXemcueLnbO5PU6j980iX7n4dhJ/wzktSS7z+gNpwQ8s5sG+k//HXxuCrvTMO33vPOSNI7hy7xzuQc8r/2epv9ryE5558ZYLgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIJhpENMIMV/iKTGnBfTsUZ//4h35nsjG70zH0f8B3f+1/3/5J2RpAu2+A+FDB4Le2cC7R3eGSX6DxZNGhbD9SApuSXonQmGR3hnjozO8s4czvEfgvu90X/2zkhS1bcKvDM5r2f4Hyjc6p9xvf6ZAYY7IACACQoIAGDCu4B27Niha6+9Vnl5eQoEAtq0aVPU8845PfDAA8rNzdXw4cNVXFysgwcPxmu9AIAhwruA2traVFhYqLVr1572+TVr1ujxxx/XU089pV27dmnkyJEqKSlRR0cM3/MGAAxZ3m9CKC0tVWlp6Wmfc87pscce03333acFCxZIkp5++mllZ2dr06ZNuvHGG7/aagEAQ0ZcXwOqra1VfX29iouL+x4LhUIqKipSVdXpf4xuZ2enwuFw1AYAGPriWkD19fWSpOzs7KjHs7Oz+577tPLycoVCob4tPz8/nksCAAxQ5u+CW7VqlZqbm/u2uro66yUBAM6CuBZQTk6OJKmhoSHq8YaGhr7nPi0YDCotLS1qAwAMfXEtoIKCAuXk5Gjbtm19j4XDYe3atUuzZs2K56EAAIOc97vgWltbVVNT0/dxbW2t9u3bp4yMDI0bN04rVqzQz3/+c1100UUqKCjQ/fffr7y8PC1cuD
Ce6wYADHLeBbR7925dffXVfR+vXLlSkrR48WKtX79e99xzj9ra2nT77berqalJV155pbZu3aphw4bFb9UAgEHPu4DmzJkj59wZnw8EAnr44Yf18MMPf6WFQVIg4B1JGO5f9A1X+g93lKR/nfA/vDPBgP9wzLXNk7wzKZvTvTOSFDzc8MU7fUpMg0UjEf+M/Aelqq09huNIiR/7f3d+VJf/UNu0sf6Zv0zL885cPepd74wkzZpY6505njXeO5Pwgf+gWRdhGCkAADGhgAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJjwnoaNsyjg//VBID3knem5psk7I0mXpfhfPod7Wr0zT/5biXdm4r6wd0aSAq3+06NddwxTqrv8M66nx/84MXKxTGKP4ThpdanemdqP/ae3p+ef9M5I0tTUD70zr4260DuTkuB/vocC7oAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBjpABZITPTOdEwc4525YcJb3hlJSoxhWOoTJ67yzuRW9XpnEusbvTOSFGlt8w/FMIzU9Ub8j+NiyJxF7qT/PyfBRv9zd7R1uHcmlBDDwFhJ2cnN3pmuVP/P22AMw1+dd2Lg4Q4IAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACYaRDmCBYUHvzMcXp3hnrkvb652RpPaI/+Xzv//yde/MpPdbvTPu5EnvjCS5jk7/TK//sNSYDPRhpDEMWA3EMpQ1hq+bUxP8h31K0rCA/xDThJ6hMCb07OAOCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAmGkQ5ggWHDvDNNX/Mf7jgmMbbhif/e7Z8b8a7/nynh44+8M70nO7wzkmIb+DnAh4TGJOD/tWkg0T/TM8L/n6AxGU3emVCC/3UnSfU9Ie9MQgyfF66nxzszFHAHBAAwQQEBAEx4F9COHTt07bXXKi8vT4FAQJs2bYp6fsmSJQoEAlHb/Pnz47VeAMAQ4V1AbW1tKiws1Nq1a8+4z/z583Xs2LG+7dlnn/1KiwQADD3erwCWlpaqtLT0c/cJBoPKycmJeVEAgKGvX14DqqioUFZWliZNmqRly5apsbHxjPt2dnYqHA5HbQCAoS/uBTR//nw9/fTT2rZtm375y1+qsrJSpaWl6u3tPe3+5eXlCoVCfVt+fn68lwQAGIDi/v+Abrzxxr5fT506VdOmTdPEiRNVUVGhuXPnfmb/VatWaeXKlX0fh8NhSggAzgH9/jbsCRMmKDMzUzU1Nad9PhgMKi0tLWoDAAx9/V5AR44cUWNjo3Jzc/v7UACAQcT7W3Ctra1RdzO1tbXat2+fMjIylJGRoYceekiLFi1STk6ODh06pHvuuUcXXnihSkpK4rpwAMDg5l1Au3fv1tVXX9338Sev3yxevFhPPvmk9u/fr9/97ndqampSXl6e5s2bp5/97GcKBoPxWzUAYNDzLqA5c+bIuTMP23vttde+0oKGrEDAP5KU6J0ZltvmnYnVX7vyvDOjjvgP7nRt7d4ZRWIbsOpiycUwuFOR078rNO5iuO4kKZDs//6kQGqqdyZ8QYp3ZvG43d6Zbhfb+d7w/gzvzOgjLd6ZyBneJTzUMQsOAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGAi7j+SG3GU5P/XE0zu8c50fM5088/zQWemdyYQy9DfXv8J2meVi2F9MU6p9j5Miv+0aUlKSA95ZzonZnln/jbV/9orHvmud+bfu/0ny0tSy5/GeGcyP3zP/0Axfg4OdtwBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMMEw0oEshgGFEec/5LI7xjmImckt3pnOdP/1BYYF/TOdnd6ZmMVwzpXoPxwzIRjDechI985IUnfeed6ZD2cP8878t3mbvDOpCf7DXx+q+553RpLOr+jwzkRa22I61rmIOyAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmGEZ6tsQwWNR1dnlnwidGemfanf9gTEkal9zonWm+2P88jL4w1zuTUpfinZEkd9J/+GQgOdn/QMn+n3pdYzO8My3j/AeYStLxb/r/Pf334pe8M/9pZI13ZnPrJd6Zd16a7J2RpLF/Peid6e3y/7w9V3EHBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwATDSAcw19bmnQnt9x/C+d7sbO+MJH0tpd47M32G/3DHP7df7J0Z9cH53hlJSmn1H8LZEwx4ZzrP88+0jo94Zy77+iHvjCStzf8370x+Urd35vnwpd6Zx7dc4525aMtx74wkRT7+OKYcvhzugAAAJiggAIAJrwIqLy/XjBkzlJqaqqysLC1cuFDV1dVR+3R0dKisrEyjR4/WqFGjtGjRIjU0NMR10QCAwc+rgCorK1VWVqadO3fq9ddfV3d3t+bNm6e2f3it4q677tLLL7+sF198UZWVlTp69Kiuv/76uC8cADC4eb0JYevWrVEfr1+/XllZWdqzZ49mz56t5uZm/eY3v9GGDRv0ne98R5K0bt06XXLJJdq5c6e++c1vxm/lAIBB7Su9BtTc3CxJysg49aOC9+zZo+7ubhUXF/ftM3nyZI0bN05VVVWn/T06OzsVDoejNgDA0BdzAUUiEa1YsUJXXHGFpkyZIkmqr69XSkqK0tPTo/bNzs5Wff3p37JbXl6uUCjUt+Xn58e6JADAIBJzAZWVlenAgQN67rnnvtICVq1apebm5r6trq7uK/1+AIDBIab/iLp8+XJt2bJFO3bs0NixY/sez8nJUVdXl5qamqLughoaGpSTk3Pa3ysYDCoYDMayDADAIOZ1B+Sc0/Lly7Vx40Zt375dBQUFUc9Pnz5dycnJ2rZtW99j1dXVOnz4sGbNmhWfFQMAhgSvO6CysjJt2LBBmzdvVmpqat/rOqFQSMOHD1coFNJtt92mlStXKiMjQ2lpabrzzjs1a9Ys3gEHAIjiVUBPPvmkJGnOnDlRj69bt05LliyRJP3qV79SQkKCFi1apM7OTpWUlOjXv/51XBYLABg6As45/+mL/SgcDisUCmmOFigpkGy9HFsJif6Rqf6DO4//vNc7I0n/OuV/eWeGBfyP9dvGK7wzexrHeWckqfnkMO9MdmqLd6bwvA+9M4vS/+R/HP/ZtJKkj3o7vTNPNF7pnXnp//h/a/7C/+k/INQdrPXOSJLr9D8PkHpctyq0Wc3NzUpLSzvjfsyCAwCYoIAAACYoIACACQoIAGCCAgIAm
KCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYiOknouIsicQwpbrmsHdk+Pop/seR9C8/meudeSjvVf9M9h+8M01jdnhnJKk54j+BPCPB/+8pN2mUd6Y50uOdefNkundGku5/7xbvTM+rmd6Zi1/xnwree+SYd8Z1d3ln0P+4AwIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCYaRDTORkh3cmreJgTMeq7Zjsnfn2Qv/M9Zfv8c5ck/5n74wkpSec9M683XWed+a141O9M1ve8c+MPDDMOyNJuf+3zTuT9O5fvTM94VbvTExDejEgcQcEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABMNIh5oYBjX2/u3jmA41/E3/4ZOXVGd7Z/Ze8HXvzJ9GfsM7I0mBiH8mkhTwzgSbur0zl7zf6J1xjbH93cYy1La3x//PJOf8MxgyuAMCAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABggmGkiHkgZKS93T9UU+sdSY4l450Y+HqsFwDEGXdAAAATFBAAwIRXAZWXl2vGjBlKTU1VVlaWFi5cqOrq6qh95syZo0AgELXdcccdcV00AGDw8yqgyspKlZWVaefOnXr99dfV3d2tefPmqa2tLWq/pUuX6tixY33bmjVr4rpoAMDg5/UmhK1bt0Z9vH79emVlZWnPnj2aPXt23+MjRoxQTk5OfFYIABiSvtJrQM3NzZKkjIyMqMefeeYZZWZmasqUKVq1apXaP+fdUp2dnQqHw1EbAGDoi/lt2JFIRCtWrNAVV1yhKVOm9D1+8803a/z48crLy9P+/ft17733qrq6Wi+99NJpf5/y8nI99NBDsS4DADBIBZyL7T+BLFu2TK+++qreeustjR079oz7bd++XXPnzlVNTY0mTpz4mec7OzvV2dnZ93E4HFZ+fr7maIGSAkPxf3MAwNDW47pVoc1qbm5WWlraGfeL6Q5o+fLl2rJli3bs2PG55SNJRUVFknTGAgoGgwoGg7EsAwAwiHkVkHNOd955pzZu3KiKigoVFBR8YWbfvn2SpNzc3JgWCAAYmrwKqKysTBs2bNDmzZuVmpqq+vp6SVIoFNLw4cN16NAhbdiwQddcc41Gjx6t/fv366677tLs2bM1bdq0fvkDAAAGJ6/XgAKBwGkfX7dunZYsWaK6ujr94Ac/0IEDB9TW1qb8/Hxdd911uu+++z73+4D/KBwOKxQK8RoQAAxS/fIa0Bd1VX5+viorK31+SwDAOYpZcAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAE0nWC/g055wkqUfdkjNeDADAW4+6Jf393/MzGXAF1NLSIkl6S68YrwQA8FW0tLQoFAqd8fmA+6KKOssikYiOHj2q1NRUBQKBqOfC4bDy8/NVV1entLQ0oxXa4zycwnk4hfNwCufhlIFwHpxzamlpUV5enhISzvxKz4C7A0pISNDYsWM/d5+0tLRz+gL7BOfhFM7DKZyHUzgPp1ifh8+78/kEb0IAAJiggAAAJgZVAQWDQa1evVrBYNB6KaY4D6dwHk7hPJzCeThlMJ2HAfcmBADAuWFQ3QEBAIYOCggAYIICAgCYoIAAACYGTQGtXbtWF1xwgYYNG6aioiL98Y9/tF7SWffggw8qEAhEbZMnT7ZeVr/bsWOHrr32WuXl5SkQCGjTpk1Rzzvn9MADDyg3N1fDhw9XcXGxDh48aLPYfvRF52HJkiWfuT7mz59vs9h+Ul5erhkzZig1NVVZWVlauHChqquro/bp6OhQWVmZRo8erVGjRmnRokVqaGgwWnH/+DLnYc6cOZ+5Hu644w6jFZ/eoCig559/XitXrtTq1av19ttvq7CwUCUlJTp+/Lj10s66Sy+9VMeOHevb3nrrLesl9bu2tjYVFhZq7dq1p31+zZo1evzxx/XUU09p165dGjlypEpKStTR0XGWV9q/vug8SNL8+fOjro9nn332LK6w/1VWVqqsrEw7d+7U66+/ru7ubs2bN09tbW19+9x11116+eWX9eKLL6qyslJHjx7V9ddfb7jq+Psy50GSli5dGnU9rFmzxmjFZ+AGgZkzZ7qysrK+j3t7e11eXp4rLy83XNXZt3r1aldYWGi9DFOS3MaNG/s+jkQiLicnxz3yyCN9jzU1NblgMOieffZZgxWeHZ8+D845t3jxYrdgwQKT9Vg5fvy4k+QqKyudc6f+7pOTk92LL77Yt8+7777rJLmqqiqrZfa7T58H55z79re/7X70ox/ZLepLGPB3QF1dXdqzZ4+Ki4v7HktISFBxcbGqqqoMV2bj4MGDysvL04QJE3TLLbfo8OHD1ksyVVtbq/r6+qjrIxQKqaio6Jy8PioqKpSVlaVJkyZp2bJlamxstF5Sv2pubpYkZWRkSJL27Nmj7u7uqOth8uTJGjdu3JC+Hj59Hj7xzDPPKDMzU1OmTNGqVavU3t5usbwzGnDDSD/txIkT6u3tVXZ2dtTj2dnZeu+994xWZaOoqEjr16/XpEmTdOzYMT300EO66qqrdODAAaWmplovz0R9fb0knfb6+OS5c8X8+fN1/fXXq6CgQIcOHdJPf/pTlZaWqqqqSomJidbLi7tIJKIVK1boiiuu0JQpUySduh5SUlKUnp4ete9Qvh5Odx4k6eabb9b48eOVl5en/fv3695771V1dbVeeuklw9VGG/AFhL8rLS3t+/W0adNUVFSk8ePH64UXXtBtt91muDIMBDfeeGPfr6dOnapp06Zp4sSJqqio0Ny5cw1X1j/Kysp04MCBc+J10M9zpvNw++239/166tSpys3N1dy5c3Xo0CFNnDjxbC/ztAb8t+AyMzOVmJj4mXexNDQ0KCcnx2hVA0N6erouvvhi1dTUWC/FzCfXANfHZ02YMEGZmZlD8vpYvny5tmzZojfffDPqx7fk5OSoq6tLTU1NUfsP1evhTOfhdIqKiiRpQF0PA76AUlJSNH36dG3btq3vsUgkom3btmnWrFmGK7PX2tqqQ4cOKTc313opZgoKCpSTkxN1fYTDYe3ateucvz6OHDmixsbGIXV9OOe0fPlybdy4Udu3b1dBQUHU89OnT1dycnLU9VBdXa3Dhw8Pqevhi87D6ezbt0+SBtb1YP0uiC/jueeec8Fg0K1fv96988477vbbb3fp6emuvr7eemln1Y9//GNXUVHhamtr3e9//3tXXFzsMjMz3fHjx62X1q9aWlrc3r173d69e50k9+ijj7q9e/e6Dz74wDnn3C9+8QuXnp7uNm/e7Pbv3+8WLFjgCgoK3MmTJ41XHl+f
dx5aWlrc3Xff7aqqqlxtba1744033OWXX+4uuugi19HRYb30uFm2bJkLhUKuoqLCHTt2rG9rb2/v2+eOO+5w48aNc9u3b3e7d+92s2bNcrNmzTJcdfx90XmoqalxDz/8sNu9e7erra11mzdvdhMmTHCzZ882Xnm0QVFAzjn3xBNPuHHjxrmUlBQ3c+ZMt3PnTuslnXU33HCDy83NdSkpKe788893N9xwg6upqbFeVr978803naTPbIsXL3bOnXor9v333++ys7NdMBh0c+fOddXV1baL7gefdx7a29vdvHnz3JgxY1xycrIbP368W7p06ZD7Iu10f35Jbt26dX37nDx50v3whz905513nhsxYoS77rrr3LFjx+wW3Q++6DwcPnzYzZ4922VkZLhgMOguvPBC95Of/MQ1NzfbLvxT+HEMAAATA/41IADA0EQBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMDE/wP2Gvqm9MbL2gAAAABJRU5ErkJggg==", 253 | "text/plain": [ 254 | "
" 255 | ] 256 | }, 257 | "metadata": {}, 258 | "output_type": "display_data" 259 | } 260 | ], 261 | "source": [ 262 | "import matplotlib.pyplot as plt\n", 263 | "print(z1.max(), z1.min(), z1.std())\n", 264 | "plt.imshow(decoder(torch.randn(1, z_size).cuda()).view(28, 28).detach().cpu().numpy())" 265 | ] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3", 271 | "language": "python", 272 | "name": "python3" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": "python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.11.4" 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 2 289 | } 290 | -------------------------------------------------------------------------------- /vqvae.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.optim as optim\n", 12 | "from torchvision import datasets, transforms\n", 13 | "from torch.utils.data import DataLoader\n", 14 | "import cv2\n", 15 | "import numpy as np\n", 16 | "torch.manual_seed(41)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "### 1. Load MNIST dataset" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# 1. Load MNIST dataset\n", 33 | "transform = transforms.Compose([transforms.ToTensor()])\n", 34 | "train_dataset = datasets.MNIST(\n", 35 | " root=\"./data\", train=True, download=True, transform=transform\n", 36 | ")\n", 37 | "train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=False)\n" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "### 2. 
Define VQVAE model"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "\n",
54 | "class VQVAE(nn.Module):\n",
55 | "    def __init__(self):\n",
56 | "        super(VQVAE, self).__init__()\n",
57 | "        self.encoder = nn.Sequential(\n",
58 | "            nn.Conv2d(1, 4, 4, stride=4, padding=0),\n",
59 | "            nn.BatchNorm2d(4),\n",
60 | "            nn.ReLU()\n",
61 | "        )\n",
62 | "\n",
63 | "        self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1)\n",
64 | "        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=2)\n",
65 | "        self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1)\n",
66 | "\n",
67 | "        # Commitment Loss Beta\n",
68 | "        self.beta = 0.2\n",
69 | "\n",
70 | "        self.decoder = nn.Sequential(\n",
71 | "            nn.ConvTranspose2d(4, 16, 4, stride=2, padding=1),\n",
72 | "            nn.BatchNorm2d(16),\n",
73 | "            nn.ReLU(),\n",
74 | "            nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),\n",
75 | "            nn.Sigmoid(),\n",
76 | "        )\n",
77 | "\n",
78 | "    def forward(self, x):\n",
79 | "        # B, C, H, W\n",
80 | "        encoded_output = self.encoder(x)\n",
81 | "        quant_input = self.pre_quant_conv(encoded_output)\n",
82 | "\n",
83 | "        ## Quantization\n",
84 | "        B, C, H, W = quant_input.shape\n",
85 | "        quant_input = quant_input.permute(0, 2, 3, 1)\n",
86 | "        quant_input = quant_input.reshape(\n",
87 | "            (quant_input.size(0), -1, quant_input.size(-1))\n",
88 | "        )\n",
89 | "\n",
90 | "        # Compute pairwise distances\n",
91 | "        dist = torch.cdist(\n",
92 | "            quant_input,\n",
93 | "            self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1)),\n",
94 | "        )\n",
95 | "\n",
96 | "        # Find index of nearest embedding\n",
97 | "        min_encoding_indices = torch.argmin(dist, dim=-1)\n",
98 | "\n",
99 | "        # Select the embedding weights\n",
100 | "        quant_out = torch.index_select(\n",
101 | "            self.embedding.weight, 0, min_encoding_indices.view(-1)\n",
102 | "        )\n",
103 | "        quant_input = quant_input.reshape((-1, quant_input.size(-1)))\n",
104 | "\n",
105 | "        # Compute losses\n",
106 | "        commitment_loss = torch.mean((quant_out.detach() - quant_input) ** 2)\n",
107 | "        codebook_loss = torch.mean((quant_out - quant_input.detach()) ** 2)\n",
108 | "        quantize_losses = codebook_loss + commitment_loss * self.beta  # use the declared beta rather than a stray magic number\n",
109 | "        quant_out = quant_input + (quant_out - quant_input).detach()  # straight-through estimator: gradients flow back to the encoder\n",
110 | "\n",
111 | "        # Reshaping back to original input shape\n",
112 | "        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)\n",
113 | "        min_encoding_indices = min_encoding_indices.reshape(\n",
114 | "            (-1, quant_out.size(-2), quant_out.size(-1))\n",
115 | "        )\n",
116 | "\n",
117 | "        ## Decoder part\n",
118 | "        decoder_input = self.post_quant_conv(quant_out)\n",
119 | "        output = self.decoder(decoder_input)\n",
120 | "        return output, quantize_losses"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "### 3. Initialize model, optimizer, scheduler"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "model = VQVAE().cuda()\n",
137 | "optimizer = optim.AdamW(model.parameters(), lr=2e-3)\n",
138 | "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {},
144 | "source": [
145 | "### 4. 
Training loop" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "model.train()\n", 155 | "for epoch in range(20):\n", 156 | " for batch_idx, (data, _) in enumerate(train_loader):\n", 157 | " data = data.cuda()\n", 158 | " optimizer.zero_grad()\n", 159 | " out, quantize_loss = model(data)\n", 160 | " recon_loss = torch.nn.functional.mse_loss(out, data)\n", 161 | " loss = recon_loss + quantize_loss\n", 162 | " loss.backward()\n", 163 | " nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", 164 | " optimizer.step()\n", 165 | " print(\n", 166 | " f\"\\rEpoch {epoch}, Batch {batch_idx:03d}, Loss: {loss.item():.4f} = {recon_loss.item():.4f} + {quantize_loss.item():.4f}\",\n", 167 | " end=\"\",\n", 168 | " )\n", 169 | " scheduler.step()\n", 170 | " print(\"\")\n", 171 | "\n", 172 | "# torch.save(model.state_dict(), \"vqvae.ckpt\")" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "### 5. Inference" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "def save(data, idx):\n", 189 | " data = (data * 255).astype(np.uint8)\n", 190 | " cv2.imwrite(f\"img_{idx}.png\", data)\n", 191 | "\n", 192 | "\n", 193 | "model.eval()\n", 194 | "with torch.no_grad():\n", 195 | " for i in range(10):\n", 196 | " data, _ = train_dataset[i]\n", 197 | " data = data.unsqueeze(0).cuda()\n", 198 | " x_recon, _ = model(data)\n", 199 | " recon_img = x_recon.cpu().squeeze().numpy()\n", 200 | " # Save or display recon_img\n", 201 | " save(x_recon, i)\n" 202 | ] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.11.4" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 2 226 | } 227 | --------------------------------------------------------------------------------