├── DDPM_notebook.ipynb
├── Neural_ODE.ipynb
├── README.md
├── base
│   ├── mnist.py
│   ├── model.py
│   └── train.py
├── flowmatching.ipynb
├── pi_digit_estimation_GRU.ipynb
├── vae.ipynb
└── vqvae.ipynb
/Neural_ODE.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "gpuClass": "standard",
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {
24 | "colab": {
25 | "base_uri": "https://localhost:8080/"
26 | },
27 | "id": "ysYtGUmudhFw",
28 | "outputId": "ea41ca86-4ca0-4225-a0be-561e8d5ceda2"
29 | },
30 | "outputs": [
31 | {
32 | "output_type": "stream",
33 | "name": "stdout",
34 | "text": [
35 | "Collecting torchdiffeq\n",
36 | " Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)\n",
37 | "Requirement already satisfied: torch>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq) (2.0.1+cu118)\n",
38 | "Requirement already satisfied: scipy>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from torchdiffeq) (1.10.1)\n",
39 | "Requirement already satisfied: numpy<1.27.0,>=1.19.5 in /usr/local/lib/python3.10/dist-packages (from scipy>=1.4.0->torchdiffeq) (1.22.4)\n",
40 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (3.12.2)\n",
41 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (4.6.3)\n",
42 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (1.11.1)\n",
43 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (3.1)\n",
44 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (3.1.2)\n",
45 | "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.3.0->torchdiffeq) (2.0.0)\n",
46 | "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.3.0->torchdiffeq) (3.25.2)\n",
47 | "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.3.0->torchdiffeq) (16.0.6)\n",
48 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.3.0->torchdiffeq) (2.1.3)\n",
49 | "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.3.0->torchdiffeq) (1.3.0)\n",
50 | "Installing collected packages: torchdiffeq\n",
51 | "Successfully installed torchdiffeq-0.2.3\n"
52 | ]
53 | }
54 | ],
55 | "source": [
56 | "!pip install torchdiffeq"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "source": [
62 | "from torchdiffeq import odeint_adjoint as odeint\n",
63 | "import torch\n",
64 | "import torch.nn as nn\n",
65 | "import torch.nn.functional as F\n",
66 | "import torch.optim as optim\n",
67 | "import numpy as np\n",
68 | "import os\n",
69 | "\n",
70 | "device = torch.device(\"cuda\")"
71 | ],
72 | "metadata": {
73 | "id": "1mRK6nlmeHz2"
74 | },
75 | "execution_count": null,
76 | "outputs": []
77 | },
78 | {
79 | "cell_type": "code",
80 | "source": [
81 | "class TargetFunction(nn.Module):\n",
82 | " def __init__(self):\n",
83 | " super().__init__()\n",
84 | " self.true_function = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)\n",
85 | "\n",
86 | " def forward(self, t, x):\n",
87 | " return torch.mm(x**3, self.true_function)\n",
88 | "\n",
89 | "starting_point = torch.tensor([[2.0, 0.]]).to(device)\n",
90 | "t = torch.linspace(0., 25., 1000).to(device)\n",
91 | "target_func = TargetFunction().to(device)\n",
92 | "\n",
93 | "with torch.no_grad():\n",
94 | " true_dxdt = odeint(target_func, starting_point, t)\n",
95 | "\n",
96 | "print(\"> Starting point = {}\".format(starting_point.squeeze()))\n",
97 | "print(\"> t.shape = {}\".format(t.shape))\n",
98 | "print(\"> true_dxdt.shape = {}\".format(true_dxdt.shape))"
99 | ],
100 | "metadata": {
101 | "colab": {
102 | "base_uri": "https://localhost:8080/"
103 | },
104 | "id": "_8mN11OUe-Gl",
105 | "outputId": "36ac52d5-c8fc-4089-f9ae-b4b40281ccdc"
106 | },
107 | "execution_count": null,
108 | "outputs": [
109 | {
110 | "output_type": "stream",
111 | "name": "stdout",
112 | "text": [
113 | "> Starting point = tensor([2., 0.], device='cuda:0')\n",
114 | "> t.shape = torch.Size([1000])\n",
115 | "> true_dxdt.shape = torch.Size([1000, 1, 2])\n"
116 | ]
117 | }
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "source": [
123 | "def get_batch():\n",
124 | " random_points = np.random.choice(np.arange(990, dtype=np.int64), 768, replace=False)\n",
125 | " batch_starting_point = true_dxdt[random_points]\n",
126 | " batch_t = t[:10]\n",
127 | " batch_dxdt = torch.stack([true_dxdt[random_points + i] for i in range(10)], dim=0)\n",
128 | "\n",
129 | " return batch_starting_point, batch_t, batch_dxdt"
130 | ],
131 | "metadata": {
132 | "id": "gP88LQRLlFY8"
133 | },
134 | "execution_count": null,
135 | "outputs": []
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "source": [],
140 | "metadata": {
141 | "id": "-YpdMmXX0Q09"
142 | }
143 | },
144 | {
145 | "cell_type": "code",
146 | "source": [
147 | "class PredictFunction(nn.Module):\n",
148 | " def __init__(self):\n",
149 | " super().__init__()\n",
150 | "\n",
151 | " self.net = nn.Sequential(\n",
152 | " nn.Linear(2, 16),\n",
153 | " nn.ReLU(),\n",
154 | " nn.Linear(16, 50),\n",
155 | " nn.ReLU(),\n",
156 | " nn.Linear(50, 50),\n",
157 | " nn.ReLU(),\n",
158 | " nn.Linear(50, 16),\n",
159 | " nn.ReLU(),\n",
160 | " nn.Linear(16, 2)\n",
161 | " )\n",
162 | "\n",
163 | " def forward(self, t, x):\n",
164 | " return self.net(x)"
165 | ],
166 | "metadata": {
167 | "id": "i8eSU_94izdL"
168 | },
169 | "execution_count": null,
170 | "outputs": []
171 | },
172 | {
173 | "cell_type": "code",
174 | "source": [
175 | "import matplotlib.pyplot as plt\n",
176 | "def visualize(true_y, pred_y, odefunc, itr):\n",
177 | " fig = plt.figure(figsize=(12, 4), facecolor='white')\n",
178 | " ax_traj = fig.add_subplot(131, frameon=False)\n",
179 | " ax_phase = fig.add_subplot(132, frameon=False)\n",
180 | " ax_vecfield = fig.add_subplot(133, frameon=False)\n",
181 | " ax_traj.cla()\n",
182 | " ax_traj.set_title('Trajectories')\n",
183 | " ax_traj.set_xlabel('t')\n",
184 | " ax_traj.set_ylabel('x,y')\n",
185 | " ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-')\n",
186 | " ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--')\n",
187 | " ax_traj.set_xlim(t.cpu().min(), t.cpu().max())\n",
188 | " ax_traj.set_ylim(-2, 2)\n",
189 | "\n",
190 | " ax_phase.cla()\n",
191 | " ax_phase.set_title('Phase Portrait')\n",
192 | " ax_phase.set_xlabel('x')\n",
193 | " ax_phase.set_ylabel('y')\n",
194 | " ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-')\n",
195 | " ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--')\n",
196 | " ax_phase.set_xlim(-2, 2)\n",
197 | " ax_phase.set_ylim(-2, 2)\n",
198 | "\n",
199 | " ax_vecfield.cla()\n",
200 | " ax_vecfield.set_title('Learned Vector Field')\n",
201 | " ax_vecfield.set_xlabel('x')\n",
202 | " ax_vecfield.set_ylabel('y')\n",
203 | "\n",
204 | " y, x = np.mgrid[-2:2:21j, -2:2:21j]\n",
205 | " dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()\n",
206 | " mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)\n",
207 | " dydt = (dydt / mag)\n",
208 | " dydt = dydt.reshape(21, 21, 2)\n",
209 | "\n",
210 | " ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color=\"black\")\n",
211 | " ax_vecfield.set_xlim(-2, 2)\n",
212 | " ax_vecfield.set_ylim(-2, 2)\n",
213 | "\n",
214 | " fig.tight_layout()\n",
215 | " plt.savefig('png/{:03d}.png'.format(itr))\n",
216 | " #plt.draw()\n",
217 | " #plt.pause(0.001)"
218 | ],
219 | "metadata": {
220 | "id": "e3-lMYe5uP5F"
221 | },
222 | "execution_count": null,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "source": [
228 | "net = PredictFunction().to(device)\n",
229 | "opt = optim.Adam(net.parameters(), lr=3e-4)\n",
230 | "os.makedirs(\"png\", exist_ok=True)\n",
231 | "\n",
232 | "for i in range(3000+1):\n",
233 | " opt.zero_grad()https:\\\\\\\\colab.research.google.com\\\\a93e4ab1-c004-40da-8cea-f5bcbedcbb07\n",
234 | " batch_starting, batch_t, batch_dxdt = get_batch()\n",
235 | " pred_dxdt = odeint(net, batch_starting, batch_t).to(device)\n",
236 | " loss = F.l1_loss(pred_dxdt, batch_dxdt)\n",
237 | " loss.backward()\n",
238 | " opt.step()\n",
239 | " print(\"\\r> [{}/3000] loss = {:.3f}\".format(i, loss.item()), end='')\n",
240 | "\n",
241 | " if i % 300 == 0:\n",
242 | " with torch.no_grad():\n",
243 | " pred_dxdt = odeint(net, starting_point, t)\n",
244 | " visualize(true_dxdt, pred_dxdt, net, i)\n",
245 | " print(\"\")"
246 | ],
247 | "metadata": {
248 | "colab": {
249 | "base_uri": "https://localhost:8080/"
250 | },
251 | "id": "CiA9tmiHpkx6",
252 | "outputId": "9dcc072c-117c-463f-9882-586f7832c099"
253 | },
254 | "execution_count": null,
255 | "outputs": [
256 | {
257 | "output_type": "stream",
258 | "name": "stdout",
259 | "text": [
260 | "> [0/3000] loss = 0.087\n",
261 | "> [300/3000] loss = 0.051\n",
262 | "> [600/3000] loss = 0.027\n",
263 | "> [900/3000] loss = 0.013\n",
264 | "> [1200/3000] loss = 0.007\n",
265 | "> [1500/3000] loss = 0.005\n",
266 | "> [1800/3000] loss = 0.003\n",
267 | "> [2100/3000] loss = 0.002\n",
268 | "> [2400/3000] loss = 0.001\n",
269 | "> [2687/3000] loss = 0.001"
270 | ]
271 | }
272 | ]
273 | }
283 | ]
284 | }
285 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simple Deep Learning Projects
2 | Each project is implemented as a standalone Jupyter notebook / some may be covered on the Imcommit channel
3 |
4 |
5 | ### 1. DDPM_notebook.ipynb : implementation of the DDPM paper's algorithm
6 | Paper: [https://arxiv.org/pdf/2006.11239.pdf](https://arxiv.org/pdf/2006.11239.pdf)
7 | Youtube: [https://www.youtube.com/watch?v=svSQhYGKk0Q](https://www.youtube.com/watch?v=svSQhYGKk0Q)
8 |
9 | Experimental details (Exponential Moving Average, etc.) are omitted
10 | - implementing the schedule coefficients is the core part (see the sketch below)
11 | Dataset: CIFAR-10
12 |
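For orientation, a minimal sketch of those schedule coefficients, following the paper's linear beta schedule (variable names here are mine, not necessarily the notebook's):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear noise schedule from the paper
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # abar_t = prod_{s<=t} alpha_s

def q_sample(x0, t, eps):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    ab = alpha_bars[t].view(-1, 1, 1, 1)
    return ab.sqrt() * x0 + (1 - ab).sqrt() * eps
```
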
13 | ### 2. pi_digit_estimation_GRU.ipynb : predicting patterns in the digits of π
14 | [https://youtu.be/kdmrlMAaCiA](https://youtu.be/kdmrlMAaCiA)
15 |
16 | As if that could ever work (the sketch below shows the next-digit setup anyway)
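
The framing is plain next-digit prediction: embed a window of digits, run it through a GRU, and classify the digit that follows. A minimal sketch under my own names (the notebook's actual model stacks two GRUs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

digits = torch.randint(0, 10, (256, 512))   # stand-in batch of 512-digit windows of pi
emb = nn.Embedding(10, 16)
gru = nn.GRU(16, 64, batch_first=True)
head = nn.Linear(64, 10)

h, _ = gru(emb(digits[:, :-1]))             # encode every digit except the last
logits = head(h[:, -1])                     # predict the final digit of the window
loss = F.cross_entropy(logits, digits[:, -1])
```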
17 |
18 | ### 3. Neural_ODE.ipynb : slightly modified official code
19 | Paper: [https://arxiv.org/pdf/1806.07366.pdf](https://arxiv.org/pdf/1806.07366.pdf)
20 | Official github (reference): [https://github.com/rtqichen/torchdiffeq](https://github.com/rtqichen/torchdiffeq)
21 | [https://www.youtube.com/watch?v=NS-C_QjjcT4](https://www.youtube.com/watch?v=NS-C_QjjcT4)
22 |
23 | Cleaned up the output code
24 | Changed the model architecture (so that training is slightly harder than with the official code)
25 | Dataset: toy example (see the usage sketch below)
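
The torchdiffeq pattern the notebook builds on, as a minimal sketch (the notebook's actual dynamics net is a deeper MLP):

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint  # adjoint method: constant-memory backprop

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)   # f_theta(t, x), the learned dynamics

    def forward(self, t, x):
        return self.net(x)

func = ODEFunc()
x0 = torch.zeros(1, 2)               # initial state
ts = torch.linspace(0., 1., 10)      # times at which to report the solution
traj = odeint(func, x0, ts)          # (10, 1, 2); differentiable w.r.t. func's parameters
```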
26 |
27 | ### 4. vae.ipynb : VAE, AE, AE with z-regularization
28 | Paper: [https://arxiv.org/pdf/1312.6114.pdf](https://arxiv.org/pdf/1312.6114.pdf)
29 |
30 | Autoencoder: reconstruction loss only
31 | AE with z-regularization: loss = reconstruction loss + z.abs().mean()
32 | VAE: loss = reconstruction loss + KL divergence (see the sketch below)
33 | Dataset: MNIST
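
For a diagonal Gaussian posterior against a standard-normal prior, the KL term has a closed form; a sketch mirroring the notebook's `vae_loss`:

```python
import torch

def kl_to_standard_normal(mean, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ) = 0.5 * E[ mu^2 + sigma^2 - log sigma^2 - 1 ]
    return 0.5 * torch.mean(mean**2 + logvar.exp() - logvar - 1)
```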
34 |
35 | ### 5. flowmatching.ipynb : flow matching for generative modeling
36 | Paper: [https://arxiv.org/pdf/2210.02747](https://arxiv.org/pdf/2210.02747)
37 |
38 | Experimental details (Exponential Moving Average, etc.) are omitted
39 | - the train and sample functions are the core part (see the sketch below)
40 | - Dataset: MNIST
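
A minimal sketch of those two pieces, assuming the notebook's linear path x_t = t*x + (1-t)*eps with velocity target x - eps (label conditioning and classifier-free guidance omitted; `model(x_t, t)` is a simplified signature):

```python
import torch
import torch.nn.functional as F
from random import random

def fm_loss(model, x):
    t = random()                                # t ~ U(0, 1), shared across the batch
    eps = torch.randn_like(x)                   # noise endpoint at t = 0
    x_t = t * x + (1 - t) * eps                 # point on the straight path
    return F.mse_loss(model(x_t, t), x - eps)   # regress the constant velocity

@torch.no_grad()
def euler_sample(model, z, num_step=10):
    x_t, dt = z, 1.0 / num_step
    for i in range(num_step):
        x_t = x_t + dt * model(x_t, i * dt)     # Euler step along the learned field
    return x_t
```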
41 |
--------------------------------------------------------------------------------
/base/mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import MNIST
3 | from torchvision.transforms import ToTensor
4 |
5 | class MNISTData:
6 | def __init__(self):
7 | dataset = MNIST(root='./data', train=True, download=True)
8 | self.dataset = [(ToTensor()(img), target) for img, target in dataset]
9 | val_dataset = MNIST(root='./data', train=False, download=True)
10 | self.val_dataset = [(ToTensor()(img), target) for img, target in val_dataset]
11 |
12 | def __len__(self):
13 | return len(self.dataset)
14 |
15 | def __getitem__(self, idx):
16 | return self.dataset[idx]
--------------------------------------------------------------------------------
/base/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Net(nn.Module):
5 | def __init__(self):
6 | super().__init__()
7 | self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1, stride=2)
8 | self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1, stride=2)
9 | self.fc1 = nn.Linear(7*7*16, 10)
10 |
11 | def forward(self, x):
12 | x = torch.relu(self.conv1(x))
13 | x = torch.relu(self.conv2(x))
14 | x = x.view(-1, 7*7*16)
15 | x = self.fc1(x)
16 | return x
--------------------------------------------------------------------------------
/base/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | from mnist import MNISTData
5 | from model import Net
6 |
7 | class Trainer:
8 | def __init__(self):
9 | self.model = Net()
10 | dataset = MNISTData()
11 | self.dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
12 | self.val_dataloader = DataLoader(dataset.val_dataset, batch_size=32, shuffle=False)
13 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
14 | self.criterion = torch.nn.CrossEntropyLoss()
15 |
16 | def train(self, epochs):
17 | for epoch in range(epochs):
18 | for i, (data, target) in enumerate(self.dataloader, 1):
19 | self.optimizer.zero_grad()
20 | output = self.model(data)
21 | loss = self.criterion(output, target)
22 | loss.backward()
23 | self.optimizer.step()
24 | if i % 100 == 0:
25 | print(f"\rEpoch {epoch}, Loss {loss.item()}", end='')
26 | print()
27 |
28 | def test(self):
29 | correct = 0
30 | total = 0
31 | with torch.no_grad():
32 | for data, target in self.val_dataloader:
33 | output = self.model(data)
34 | _, predicted = torch.max(output.data, 1)
35 | total += target.size(0)
36 | correct += (predicted == target).sum().item()
37 | print(f"Accuracy: {correct/total}, {correct=}, {total=}")
38 |
39 |
40 | def main():
41 | trainer = Trainer()
42 | trainer.train(5)
43 | trainer.test()
44 |
45 | if __name__ == "__main__":
46 | main()
--------------------------------------------------------------------------------
/flowmatching.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "from torch.nn import init\n",
12 | "import torch.nn.functional as F\n",
13 | "from torchvision.datasets import MNIST\n",
14 | "from torchvision import transforms\n",
15 | "\n",
16 | "from random import random\n",
17 | "from cv2 import imwrite\n",
18 | "import math\n",
19 | "import os"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "dataset = MNIST(\n",
29 | " root=\"./data\", train=True, download=True,\n",
30 | " transform=transforms.Compose([\n",
31 | " transforms.ToTensor()\n",
32 | " ])\n",
33 | ")\n",
34 | "dataloader = torch.utils.data.DataLoader(\n",
35 | " dataset, batch_size=128, shuffle=True, num_workers=4\n",
36 | ")"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "def train(model, x, label):\n",
46 | " t = random()\n",
47 | " eps = torch.randn_like(x) # eps ~ N(0, I)\n",
48 | " x_t = t * x + (1-t) * eps\n",
49 | " drop_label = random() < 0.3\n",
50 | " hat = model(x_t, t, label, drop_label)\n",
51 | " loss = F.mse_loss(hat, x-eps)\n",
52 | " return loss\n",
53 | "\n",
54 | "def write(x_0):\n",
55 | " x_0 = x_0.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy() * 255\n",
56 | " imwrite(\"temp.png\", x_0[0])\n",
57 | "\n",
58 | "def sample(model, z, label, num_step=10, w=2.0, cfg=True):\n",
59 | " x_t = z # ~ N(0, I)\n",
60 | " ts = torch.linspace(0, 1, num_step+1) # num_step = 10 -> ts : [.0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 10]\n",
61 | " dts = ts[1:] - ts[:-1] # [.1,] * 10\n",
62 | " for t, dt in zip(ts[:-1], dts): # [.0, .1], [.1, .1], [.2, .1], ..., [.9, .1]\n",
63 | " hat = model(x_t, t, label, drop_label=False)\n",
64 | " if cfg:\n",
65 | " hat_uncond = model(x_t, t, label, drop_label=True)\n",
66 | " hat = (1+w) * hat - w * hat_uncond\n",
67 | " x_t = x_t + dt * hat\n",
68 | " return x_t"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "class Swish(nn.Module):\n",
78 | " def forward(self, x):\n",
79 | " return x * torch.sigmoid(x)\n",
80 | "\n",
81 | "class TimeEmbedding(nn.Module):\n",
82 | " def __init__(self, T, d_model, dim):\n",
83 | " assert d_model % 2 == 0\n",
84 | " super().__init__()\n",
85 | " emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n",
86 | " emb = torch.exp(-emb)\n",
87 | " pos = torch.arange(T).float()\n",
88 | " emb = pos[:, None] * emb[None, :]\n",
89 | " emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n",
90 | " emb = emb.view(T, d_model)\n",
91 | "\n",
92 | " self.timembedding = nn.Sequential(\n",
93 | " nn.Linear(1, d_model),\n",
94 | " nn.Linear(d_model, dim),\n",
95 | " Swish(),\n",
96 | " nn.Linear(dim, dim),\n",
97 | " )\n",
98 | " self.initialize()\n",
99 | "\n",
100 | " def initialize(self):\n",
101 | " for module in self.modules():\n",
102 | " if isinstance(module, nn.Linear):\n",
103 | " init.xavier_uniform_(module.weight)\n",
104 | " init.zeros_(module.bias)\n",
105 | "\n",
106 | " def forward(self, t):\n",
107 | " emb = self.timembedding(t)\n",
108 | " return emb\n",
109 | "\n",
110 | "class DownSample(nn.Module):\n",
111 | " def __init__(self, in_ch):\n",
112 | " super().__init__()\n",
113 | " self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)\n",
114 | " self.initialize()\n",
115 | "\n",
116 | " def initialize(self):\n",
117 | " init.xavier_uniform_(self.main.weight)\n",
118 | " init.zeros_(self.main.bias)\n",
119 | "\n",
120 | " def forward(self, x, temb):\n",
121 | " x = self.main(x)\n",
122 | " return x\n",
123 | "\n",
124 | "class UpSample(nn.Module):\n",
125 | " def __init__(self, in_ch):\n",
126 | " super().__init__()\n",
127 | " self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n",
128 | " self.initialize()\n",
129 | "\n",
130 | " def initialize(self):\n",
131 | " init.xavier_uniform_(self.main.weight)\n",
132 | " init.zeros_(self.main.bias)\n",
133 | "\n",
134 | " def forward(self, x, temb):\n",
135 | " _, _, H, W = x.shape\n",
136 | " x = F.interpolate(\n",
137 | " x, scale_factor=2, mode='nearest')\n",
138 | " x = self.main(x)\n",
139 | " return x\n",
140 | "\n",
141 | "class AttnBlock(nn.Module):\n",
142 | " def __init__(self, in_ch):\n",
143 | " super().__init__()\n",
144 | " self.group_norm = nn.GroupNorm(32, in_ch)\n",
145 | " self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n",
146 | " self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n",
147 | " self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n",
148 | " self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n",
149 | " self.initialize()\n",
150 | "\n",
151 | " def initialize(self):\n",
152 | " for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n",
153 | " init.xavier_uniform_(module.weight)\n",
154 | " init.zeros_(module.bias)\n",
155 | " init.xavier_uniform_(self.proj.weight, gain=1e-5)\n",
156 | "\n",
157 | " def forward(self, x):\n",
158 | " B, C, H, W = x.shape\n",
159 | " h = self.group_norm(x)\n",
160 | " q = self.proj_q(h)\n",
161 | " k = self.proj_k(h)\n",
162 | " v = self.proj_v(h)\n",
163 | "\n",
164 | " q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n",
165 | " k = k.view(B, C, H * W)\n",
166 | " w = torch.bmm(q, k) * (int(C) ** (-0.5))\n",
167 | " assert list(w.shape) == [B, H * W, H * W]\n",
168 | " w = F.softmax(w, dim=-1)\n",
169 | "\n",
170 | " v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n",
171 | " h = torch.bmm(w, v)\n",
172 | " assert list(h.shape) == [B, H * W, C]\n",
173 | " h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n",
174 | " h = self.proj(h)\n",
175 | "\n",
176 | " return x + h\n",
177 | "\n",
178 | "class ResBlock(nn.Module):\n",
179 | " def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):\n",
180 | " super().__init__()\n",
181 | " self.block1 = nn.Sequential(\n",
182 | " nn.GroupNorm(32, in_ch),\n",
183 | " Swish(),\n",
184 | " nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),\n",
185 | " )\n",
186 | " self.temb_proj = nn.Sequential(\n",
187 | " Swish(),\n",
188 | " nn.Linear(tdim, out_ch),\n",
189 | " )\n",
190 | " self.block2 = nn.Sequential(\n",
191 | " nn.GroupNorm(32, out_ch),\n",
192 | " Swish(),\n",
193 | " nn.Dropout(dropout),\n",
194 | " nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),\n",
195 | " )\n",
196 | " if in_ch != out_ch:\n",
197 | " self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)\n",
198 | " else:\n",
199 | " self.shortcut = nn.Identity()\n",
200 | " if attn:\n",
201 | " self.attn = AttnBlock(out_ch)\n",
202 | " else:\n",
203 | " self.attn = nn.Identity()\n",
204 | " self.initialize()\n",
205 | "\n",
206 | " def initialize(self):\n",
207 | " for module in self.modules():\n",
208 | " if isinstance(module, (nn.Conv2d, nn.Linear)):\n",
209 | " init.xavier_uniform_(module.weight)\n",
210 | " init.zeros_(module.bias)\n",
211 | " init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n",
212 | "\n",
213 | " def forward(self, x, temb):\n",
214 | " h = self.block1(x)\n",
215 | " h += self.temb_proj(temb)[:, :, None, None]\n",
216 | " h = self.block2(h)\n",
217 | "\n",
218 | " h = h + self.shortcut(x)\n",
219 | " h = self.attn(h)\n",
220 | " return h\n",
221 | "\n",
222 | "class UNet(nn.Module):\n",
223 | " def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n",
224 | " super().__init__()\n",
225 | " assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n",
226 | " tdim = ch * 4\n",
227 | " self.time_embedding = TimeEmbedding(T, ch, tdim)\n",
228 | " self.label_embedding = nn.Embedding(10, tdim)\n",
229 | "\n",
230 | " self.head = nn.Conv2d(1, ch, kernel_size=3, stride=1, padding=1)\n",
231 | " self.downblocks = nn.ModuleList()\n",
232 | " chs = [ch] # record output channel when dowmsample for upsample\n",
233 | " now_ch = ch\n",
234 | " for i, mult in enumerate(ch_mult):\n",
235 | " out_ch = ch * mult\n",
236 | " for _ in range(num_res_blocks):\n",
237 | " self.downblocks.append(ResBlock(\n",
238 | " in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n",
239 | " dropout=dropout, attn=(i in attn)))\n",
240 | " now_ch = out_ch\n",
241 | " chs.append(now_ch)\n",
242 | " if i != len(ch_mult) - 1:\n",
243 | " self.downblocks.append(DownSample(now_ch))\n",
244 | " chs.append(now_ch)\n",
245 | "\n",
246 | " self.middleblocks = nn.ModuleList([\n",
247 | " ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n",
248 | " ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n",
249 | " ])\n",
250 | "\n",
251 | " self.upblocks = nn.ModuleList()\n",
252 | " for i, mult in reversed(list(enumerate(ch_mult))):\n",
253 | " out_ch = ch * mult\n",
254 | " for _ in range(num_res_blocks + 1):\n",
255 | " self.upblocks.append(ResBlock(\n",
256 | " in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n",
257 | " dropout=dropout, attn=(i in attn)))\n",
258 | " now_ch = out_ch\n",
259 | " if i != 0:\n",
260 | " self.upblocks.append(UpSample(now_ch))\n",
261 | " assert len(chs) == 0\n",
262 | "\n",
263 | " self.tail = nn.Sequential(\n",
264 | " nn.GroupNorm(32, now_ch),\n",
265 | " Swish(),\n",
266 | " nn.Conv2d(now_ch, 1, 3, stride=1, padding=1)\n",
267 | " )\n",
268 | " self.initialize()\n",
269 | "\n",
270 | " def initialize(self):\n",
271 | " init.xavier_uniform_(self.head.weight)\n",
272 | " init.zeros_(self.head.bias)\n",
273 | " init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n",
274 | " init.zeros_(self.tail[-1].bias)\n",
275 | "\n",
276 | " def forward(self, x, t, label, drop_label):\n",
277 | " # Timestep embedding\n",
278 | " t = torch.tensor([t], device=x.device)\n",
279 | " temb = self.time_embedding(t)\n",
280 | " lemb = self.label_embedding(label)\n",
281 | " if drop_label: lemb = torch.zeros_like(lemb)\n",
282 | " temb = temb + lemb\n",
283 | " # Downsampling\n",
284 | " h = self.head(x)\n",
285 | " hs = [h]\n",
286 | " for layer in self.downblocks:\n",
287 | " h = layer(h, temb)\n",
288 | " hs.append(h)\n",
289 | " # Middle\n",
290 | " for layer in self.middleblocks:\n",
291 | " h = layer(h, temb)\n",
292 | " # Upsampling\n",
293 | " for layer in self.upblocks:\n",
294 | " if isinstance(layer, ResBlock):\n",
295 | " h = torch.cat([h, hs.pop()], dim=1)\n",
296 | " h = layer(h, temb)\n",
297 | " h = self.tail(h)\n",
298 | "\n",
299 | " assert len(hs) == 0\n",
300 | " return h"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": null,
306 | "metadata": {},
307 | "outputs": [],
308 | "source": [
309 | "device = torch.device(\"cuda\")\n",
310 | "model = UNet(T=1, ch=128, ch_mult=[2, 2], attn=[1],\n",
311 | " num_res_blocks=2, dropout=0.1).to(device)\n",
312 | "optim = torch.optim.Adam(model.parameters(), lr=2e-4)"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "for e in range(1, 10+1):\n",
322 | " model.train()\n",
323 | " for i, (x, l) in enumerate(dataloader, 1):\n",
324 | " optim.zero_grad()\n",
325 | " x, l = x.to(device), l.to(device)\n",
326 | " loss = train(model, x, l)\n",
327 | " loss.backward()\n",
328 | " optim.step()\n",
329 | " print(\"\\r[Epoch: {} , Iter: {}/{}] Loss: {:.3f}\".format(e, i, len(dataloader), loss.item()), end='')\n",
330 | " print(\"\\n> Eval at epoch {}\".format(e))\n",
331 | " model.eval()\n",
332 | " with torch.no_grad():\n",
333 | " os.makedirs(\"cfm\", exist_ok=True)\n",
334 | " x_T = torch.randn(1, 1, 28, 28).to(device)\n",
335 | " labels = [torch.zeros([1], device=device, dtype=torch.long) + i for i in range(10)]\n",
336 | " for i in range(10):\n",
337 | " x_0 = sample(model, x_T, labels[i])\n",
338 | " x_0 = x_0.permute(0, 2, 3, 1).clamp(0, 1).detach().cpu().numpy() * 255\n",
339 | " imwrite(f\"cfm/gen_cfg_{e}_{i}.png\", x_0[0])"
340 | ]
341 | }
342 | ],
343 | "metadata": {
344 | "kernelspec": {
345 | "display_name": "foo",
346 | "language": "python",
347 | "name": "python3"
348 | },
349 | "language_info": {
350 | "name": "python",
351 | "version": "3.10.14"
352 | }
353 | },
354 | "nbformat": 4,
355 | "nbformat_minor": 2
356 | }
357 |
--------------------------------------------------------------------------------
/pi_digit_estimation_GRU.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "bDlRyKYTnjE-"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torch.nn as nn\n",
13 | "import torch.nn.functional as F\n",
14 | "from torch.utils.data import DataLoader\n",
15 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {
22 | "colab": {
23 | "base_uri": "https://localhost:8080/"
24 | },
25 | "id": "eAwmGxvaooYc",
26 | "outputId": "b203b108-bb6f-4893-a872-c09d46be2e2a"
27 | },
28 | "outputs": [
29 | {
30 | "output_type": "stream",
31 | "name": "stdout",
32 | "text": [
33 | "1000001\n"
34 | ]
35 | }
36 | ],
37 | "source": [
38 | "with open(\"pi_million.txt\") as f:\n",
39 | " pi = f.read()\n",
40 | " pi = pi[0] + pi[2:]\n",
41 | "print(len(pi)) # \"3.\" + 1,000,000 digits"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {
48 | "id": "QyxkcuxnW6Yb"
49 | },
50 | "outputs": [],
51 | "source": [
52 | "class Dataset:\n",
53 | " def __init__(self, pi_str, _start, _end, _len=512):\n",
54 | " self.data = self.fetch_data(pi_str, \n",
55 | " start=_start, \n",
56 | " end=_end, \n",
57 | " seq_len=_len)\n",
58 | "\n",
59 | " def fetch_data(self, pi_str, start, end, seq_len):\n",
60 | " sequences = [pi_str[i:i+seq_len] for i in range(start, end)]\n",
61 | " digit_data = [[int(c) for c in seq] for seq in sequences]\n",
62 | "\n",
63 | " return digit_data\n",
64 | " \n",
65 | " def __len__(self):\n",
66 | " return len(self.data)\n",
67 | " \n",
68 | " def __getitem__(self, idx):\n",
69 | " return self.data[idx]\n",
70 | "\n",
71 | "def col_fn(batch):\n",
72 | " return torch.stack([torch.LongTensor(b) for b in batch])\n",
73 | "\n",
74 | "train_set = Dataset(pi, _start=0, _end=100000)\n",
75 | "test_set = Dataset(pi, _start=100000, _end=103000)\n",
76 | "train_dl = DataLoader(train_set, batch_size=256, drop_last=False, collate_fn = col_fn)\n",
77 | "test_dl = DataLoader(test_set, batch_size=256, drop_last=False, collate_fn = col_fn)"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {
84 | "id": "JnHQ0BB2oqIO"
85 | },
86 | "outputs": [],
87 | "source": [
88 | "class Model(nn.Module):\n",
89 | " def __init__(self, num_digit, in_dim, hidden_dim, out_dim):\n",
90 | " super().__init__()\n",
91 | " self.embedding = nn.Embedding(num_digit, in_dim)\n",
92 | " self.in_gru = nn.GRU(in_dim, hidden_dim, batch_first=True)\n",
93 | " self.latent_fc = nn.Linear(hidden_dim, hidden_dim)\n",
94 | " self.out_gru = nn.GRU(hidden_dim, out_dim, batch_first=True)\n",
95 | " self.out_fc = nn.Linear(out_dim, num_digit)\n",
96 | " \n",
97 | " def forward(self, x, return_loss=True):\n",
98 | " emb_out = self.embedding(x)\n",
99 | " hidden, _ = self.in_gru(emb_out)\n",
100 | " latent = self.latent_fc(hidden)\n",
101 | " out, _ = self.out_gru(latent)\n",
102 | " out_digit = self.out_fc(out[:, -2])\n",
103 | "\n",
104 | " if return_loss:\n",
105 | " loss = self.get_loss(out_digit, x)\n",
106 | " return loss\n",
107 | " else:\n",
108 | " return out_digit\n",
109 | "\n",
110 | "\n",
111 | " def get_loss(self, digit_logit, x):\n",
112 | " x_last = x[:, -1]\n",
113 | "\n",
114 | " loss = F.cross_entropy(digit_logit, x_last)\n",
115 | " return loss\n",
116 | "\n",
117 | " @torch.no_grad()\n",
118 | " def generate(self, x):\n",
119 | " digit_logit = self.forward(x, return_loss=False)\n",
120 | "\n",
121 | " prob = torch.softmax(digit_logit, -1)\n",
122 | " sample = torch.argmax(prob, -1)\n",
123 | "\n",
124 | " return sample"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {
131 | "id": "9515oV3gRJgH"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "model_config = {\n",
136 | " \"num_digit\": 10,\n",
137 | " \"in_dim\": 512,\n",
138 | " \"hidden_dim\": 1024,\n",
139 | " \"out_dim\": 512\n",
140 | "}\n",
141 | "\n",
142 | "model = Model(**model_config)\n",
143 | "model.to(device)\n",
144 | "optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=[0.9, 0.999])"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/"
153 | },
154 | "id": "pQtA2imoulEA",
155 | "outputId": "cb6eec81-4932-4802-8a8b-c3a213c39e35"
156 | },
157 | "outputs": [
158 | {
159 | "output_type": "stream",
160 | "name": "stdout",
161 | "text": [
162 | "idx: 0 [0 / 391], loss= 2.306\n",
163 | "idx: 500 [109 / 391], loss= 2.308\n",
164 | "idx: 1,000 [218 / 391], loss= 2.310\n",
165 | "idx: 1,500 [327 / 391], loss= 2.318\n",
166 | "idx: 2,000 [45 / 391], loss= 2.309\n",
167 | "idx: 2,500 [154 / 391], loss= 2.302\n",
168 | "idx: 3,000 [263 / 391], loss= 2.305\n",
169 | "idx: 3,500 [372 / 391], loss= 2.310\n",
170 | "idx: 4,000 [90 / 391], loss= 2.296\n",
171 | "idx: 4,500 [199 / 391], loss= 2.312\n",
172 | "idx: 5,000 [308 / 391], loss= 2.313\n",
173 | "idx: 5,500 [26 / 391], loss= 2.298\n",
174 | "idx: 6,000 [135 / 391], loss= 2.303\n",
175 | "idx: 6,500 [244 / 391], loss= 2.315\n",
176 | "idx: 7,000 [353 / 391], loss= 2.303\n",
177 | "idx: 7,500 [71 / 391], loss= 2.302\n",
178 | "idx: 8,000 [180 / 391], loss= 2.304\n",
179 | "idx: 8,500 [289 / 391], loss= 2.309\n",
180 | "idx: 9,000 [7 / 391], loss= 2.304\n",
181 | "idx: 9,500 [116 / 391], loss= 2.300\n",
182 | "idx: 10,000 [225 / 391], loss= 2.308\n",
183 | "idx: 10,500 [334 / 391], loss= 2.306\n",
184 | "idx: 11,000 [52 / 391], loss= 2.309\n",
185 | "idx: 11,500 [161 / 391], loss= 2.304\n",
186 | "idx: 12,000 [270 / 391], loss= 2.317\n",
187 | "idx: 12,500 [379 / 391], loss= 2.308\n",
188 | "idx: 13,000 [97 / 391], loss= 2.318\n",
189 | "idx: 13,500 [206 / 391], loss= 2.314\n",
190 | "idx: 14,000 [315 / 391], loss= 2.300\n",
191 | "idx: 14,500 [33 / 391], loss= 2.289\n",
192 | "idx: 15,000 [142 / 391], loss= 2.308\n",
193 | "idx: 15,500 [251 / 391], loss= 2.318\n",
194 | "idx: 16,000 [360 / 391], loss= 2.310\n",
195 | "idx: 16,500 [78 / 391], loss= 2.315\n",
196 | "idx: 17,000 [187 / 391], loss= 2.304\n",
197 | "idx: 17,500 [296 / 391], loss= 2.303\n",
198 | "idx: 18,000 [14 / 391], loss= 2.315\n",
199 | "idx: 18,500 [123 / 391], loss= 2.302\n",
200 | "idx: 19,000 [232 / 391], loss= 2.302\n",
201 | "idx: 19,500 [341 / 391], loss= 2.307\n",
202 | "idx: 19,549 [390 / 391], loss= 2.316"
203 | ]
204 | }
205 | ],
206 | "source": [
207 | "idx = 0\n",
208 | "epochs = 50\n",
209 | "for e in range(epochs):\n",
210 | " for i, x in enumerate(train_dl):\n",
211 | " x = x.to(device)\n",
212 | "\n",
213 | " loss = model(x)\n",
214 | " optim.zero_grad()\n",
215 | " loss.backward()\n",
216 | " optim.step()\n",
217 | "\n",
218 | " print(f\"\\ridx: {idx:,} [{i} / {len(train_dl)}], loss= {loss.item():.3f}\", end='')\n",
219 | " if idx % 500 == 0:\n",
220 | " print(\"\")\n",
221 | "\n",
222 | " idx += 1\n"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "source": [
228 | "torch.save(model.state_dict(), \"512_1024_10k.ckpt\")"
229 | ],
230 | "metadata": {
231 | "id": "mkqNRrG3LKyX"
232 | },
233 | "execution_count": null,
234 | "outputs": []
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "source": [
239 | "# **Evaluation**\n",
240 | "Check whether pi estimator works well in validation dataset"
241 | ],
242 | "metadata": {
243 | "id": "3D5PFjHhC8J7"
244 | }
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {
250 | "id": "ggDal1N6yUdB",
251 | "colab": {
252 | "base_uri": "https://localhost:8080/"
253 | },
254 | "outputId": "d130904a-f279-4c77-aebf-af80d193b066"
255 | },
256 | "outputs": [
257 | {
258 | "output_type": "stream",
259 | "name": "stdout",
260 | "text": [
261 | "total_correct = 9999, total_wrong = 1\n"
262 | ]
263 | }
264 | ],
265 | "source": [
266 | "def check_correct_wrong(sample, true):\n",
267 | " batch_len = len(sample)\n",
268 | " correct = sample == true\n",
269 | " correct_num = correct.sum()\n",
270 | " wrong_num = batch_len - correct_num\n",
271 | "\n",
272 | " return correct_num, wrong_num\n",
273 | "\n",
274 | "model_config = {\n",
275 | " \"num_digit\": 10,\n",
276 | " \"in_dim\": 512,\n",
277 | " \"hidden_dim\": 1024,\n",
278 | " \"out_dim\": 512\n",
279 | "}\n",
280 | "\n",
281 | "model = Model(**model_config)\n",
282 | "state_dict = torch.load(\"512_1024_10k.ckpt\")\n",
283 | "model.load_state_dict(state_dict)\n",
284 | "model.to(device)\n",
285 | "\n",
286 | "total_correct, total_wrong = 0, 0\n",
287 | "for x in train_dl:\n",
288 | " x = x.to(device)\n",
289 | " sample = model.generate(x)\n",
290 | " correct_num, wrong_num = check_correct_wrong(sample, x[:, -1])\n",
291 | " total_correct += correct_num\n",
292 | " total_wrong += wrong_num\n",
293 | "print(f\"total_correct = {total_correct}, total_wrong = {total_wrong}\")"
294 | ]
295 | }
305 | ],
306 | "metadata": {
307 | "accelerator": "GPU",
308 | "colab": {
309 | "provenance": []
310 | },
311 | "gpuClass": "premium",
312 | "kernelspec": {
313 | "display_name": "Python 3",
314 | "name": "python3"
315 | },
316 | "language_info": {
317 | "name": "python"
318 | }
319 | },
320 | "nbformat": 4,
321 | "nbformat_minor": 0
322 | }
--------------------------------------------------------------------------------
/vae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "device = cuda\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "import torch\n",
18 | "import torch.nn as nn\n",
19 | "from torch.utils.data import DataLoader\n",
20 | "\n",
21 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
22 | "print(f\"device = {device}\")"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
35 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST\\raw\\train-images-idx3-ubyte.gz\n"
36 | ]
37 | },
38 | {
39 | "name": "stderr",
40 | "output_type": "stream",
41 | "text": [
42 | "100%|██████████| 9912422/9912422 [00:00<00:00, 30112030.95it/s]\n"
43 | ]
44 | },
45 | {
46 | "name": "stdout",
47 | "output_type": "stream",
48 | "text": [
49 | "Extracting ./MNIST\\raw\\train-images-idx3-ubyte.gz to ./MNIST\\raw\n",
50 | "\n",
51 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
52 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
53 | ]
54 | },
55 | {
56 | "name": "stderr",
57 | "output_type": "stream",
58 | "text": [
59 | "100%|██████████| 28881/28881 [00:00<00:00, 1653458.73it/s]\n"
60 | ]
61 | },
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "Extracting ./MNIST\\raw\\train-labels-idx1-ubyte.gz to ./MNIST\\raw\n",
67 | "\n",
68 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
69 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
70 | ]
71 | },
72 | {
73 | "name": "stderr",
74 | "output_type": "stream",
75 | "text": [
76 | "100%|██████████| 1648877/1648877 [00:00<00:00, 10304400.89it/s]\n"
77 | ]
78 | },
79 | {
80 | "name": "stdout",
81 | "output_type": "stream",
82 | "text": [
83 | "Extracting ./MNIST\\raw\\t10k-images-idx3-ubyte.gz to ./MNIST\\raw\n",
84 | "\n",
85 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
86 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
87 | ]
88 | },
89 | {
90 | "name": "stderr",
91 | "output_type": "stream",
92 | "text": [
93 | "100%|██████████| 4542/4542 [00:00, ?it/s]"
94 | ]
95 | },
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "Extracting ./MNIST\\raw\\t10k-labels-idx1-ubyte.gz to ./MNIST\\raw\n",
101 | "\n",
102 | "data.shape = [60000, 28, 28]\n",
103 | "batch size = 60000\n",
104 | "data shape = [28, 28] \n"
105 | ]
106 | },
107 | {
108 | "name": "stderr",
109 | "output_type": "stream",
110 | "text": [
111 | "\n"
112 | ]
113 | }
114 | ],
115 | "source": [
116 | "import torchvision\n",
117 | "data = torchvision.datasets.MNIST(root=\"./\", download=True)\n",
118 | "train_data = data.data\n",
119 | "dl = DataLoader(train_data, batch_size=192)\n",
120 | "print(f\"data.shape = {list(train_data.shape)}\")\n",
121 | "print(f\"batch size = {train_data.shape[0]}\")\n",
122 | "print(f\"data shape = {list(train_data.shape[1:])} \")\n"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 3,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "h, z_size = 1024, 2\n",
132 | "encoder = nn.Sequential(\n",
133 | " nn.Linear(28*28, h),\n",
134 | " nn.ReLU(),\n",
135 | " nn.Linear(h, h),\n",
136 | " nn.ReLU(),\n",
137 | " nn.Linear(h, z_size*2)\n",
138 | ")\n",
139 | "decoder = nn.Sequential(\n",
140 | " nn.Linear(z_size, h),\n",
141 | " nn.ReLU(),\n",
142 | " nn.Linear(h, h),\n",
143 | " nn.ReLU(),\n",
144 | " nn.Linear(h, 28*28),\n",
145 | " nn.Sigmoid()\n",
146 | ")\n",
147 | "model = nn.Sequential(encoder, decoder).to(device)\n"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 4,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "opt = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))\n",
157 | "def vae_loss(x, x_hat, mean, logvar):\n",
158 | " recon_loss = torch.nn.functional.binary_cross_entropy(x_hat, x)\n",
159 | " \n",
160 | " var = torch.exp(logvar)\n",
161 | " kl_loss = 0.5 * torch.mean(mean**2 + var - logvar - 1)\n",
162 | " return recon_loss, kl_loss"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 5,
168 | "metadata": {},
169 | "outputs": [
170 | {
171 | "name": "stdout",
172 | "output_type": "stream",
173 | "text": [
174 | "Epoch 1: loss = 0.213 = 0.212 + 0.001\n",
175 | "Epoch 2: loss = 0.192 = 0.191 + 0.001\n",
176 | "Epoch 3: loss = 0.180 = 0.178 + 0.001\n",
177 | "Epoch 4: loss = 0.172 = 0.171 + 0.001\n",
178 | "Epoch 5: loss = 0.168 = 0.167 + 0.001\n",
179 | "Epoch 6: loss = 0.166 = 0.165 + 0.001\n",
180 | "Epoch 7: loss = 0.166 = 0.165 + 0.001\n",
181 | "Epoch 8: loss = 0.162 = 0.162 + 0.001\n",
182 | "Epoch 9: loss = 0.161 = 0.160 + 0.001\n",
183 | "Epoch 10: loss = 0.161 = 0.160 + 0.001\n",
184 | "Epoch 11: loss = 0.159 = 0.159 + 0.001\n",
185 | "Epoch 12: loss = 0.160 = 0.159 + 0.001\n",
186 | "Epoch 13: loss = 0.161 = 0.161 + 0.001\n",
187 | "Epoch 14: loss = 0.158 = 0.158 + 0.001\n",
188 | "Epoch 15: loss = 0.157 = 0.156 + 0.001\n",
189 | "Epoch 16: loss = 0.157 = 0.156 + 0.001\n",
190 | "Epoch 17: loss = 0.157 = 0.156 + 0.001\n",
191 | "Epoch 18: loss = 0.156 = 0.155 + 0.001\n",
192 | "Epoch 19: loss = 0.155 = 0.155 + 0.001\n",
193 | "Epoch 20: loss = 0.155 = 0.155 + 0.001\n"
194 | ]
195 | }
196 | ],
197 | "source": [
198 | "mode = \"AE\" # AE / z0AE / VAE\n",
199 | "\n",
200 | "for e in range(1, 20+1):\n",
201 | " for x in dl:\n",
202 | " opt.zero_grad()\n",
203 | " x = x.to(device=device, dtype=torch.float) / 255\n",
204 | " x = x.view(x.shape[0], -1)\n",
205 | " z = encoder(x)\n",
206 | " z1, z2 = z[:, :z.shape[1]//2], z[:, z.shape[1]//2:]\n",
207 | " if mode == \"VAE\":\n",
208 | " sig = torch.exp(0.5 * z2)\n",
209 | " z = z1 + sig*torch.randn_like(z1, device=device)\n",
210 | " if mode == \"z0AE\" or \"AE\":\n",
211 | " z = z1\n",
212 | " x_hat = decoder(z)\n",
213 | "\n",
214 | " if mode == \"VAE\":\n",
215 | " rl, kl = vae_loss(x, x_hat, z1, z2)\n",
216 | " loss = rl + kl\n",
217 | " if mode == \"z0AE\":\n",
218 | " loss = rl + z1.abs().mean()\n",
219 | " if mode == \"AE\":\n",
220 | " loss = rl\n",
221 | " loss.backward()\n",
222 | " opt.step()\n",
223 | " print(f\"\\rEpoch {e}: loss = {loss.item():.3f} = {rl.item():.3f} + {kl.item():.3f}\", end=\"\")\n",
224 | " print()\n",
225 | " "
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 13,
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "tensor(3.9756, device='cuda:0', grad_fn=) tensor(-3.1758, device='cuda:0', grad_fn=) tensor(1.0295, device='cuda:0', grad_fn=)\n"
238 | ]
239 | },
240 | {
241 | "data": {
242 | "text/plain": [
243 | ""
244 | ]
245 | },
246 | "execution_count": 13,
247 | "metadata": {},
248 | "output_type": "execute_result"
249 | },
250 | {
251 | "data": {
252 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgQUlEQVR4nO3de3DV9f3n8dfJ7XBLTgwhNwkYUMEKxEohpSrFkiXE1gVlO946A44/XGlwitTq0lVR25m0+Bvr6o/q/GZaqLvibVZg5Ke4CiasbaAFoZSq+RF+UYKQIKnJyYVcz2f/YE17FNTP8YR3Ep6Pme8MOef74vvhyze88s05vBNwzjkBAHCWJVgvAABwbqKAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYCLJegGfFolEdPToUaWmpioQCFgvBwDgyTmnlpYW5eXlKSHhzPc5A66Ajh49qvz8fOtlAAC+orq6Oo0dO/aMzw+4AkpNTZUkXalrlKRk49UAAHz1qFtv6ZW+f8/PpN8KaO3atXrkkUdUX1+vwsJCPfHEE5o5c+YX5j75tluSkpUUoIAAYND5/xNGv+hllH55E8Lzzz+vlStXavXq1Xr77bdVWFiokpISHT9+vD8OBwAYhPqlgB599FEtXbpUt956q772ta/pqaee0ogRI/Tb3/62Pw4HABiE4l5AXV1d2rNnj4qLi/9+kIQEFRcXq6qq6jP7d3Z2KhwOR20AgKEv7gV04sQJ9fb2Kjs7O+rx7Oxs1dfXf2b/8vJyhUKhvo13wAHAucH8P6KuWrVKzc3NfVtdXZ31kgAAZ0Hc3wWXmZmpxMRENTQ0RD3e0NCgnJycz+wfDAYVDAbjvQwAwAAX9zuglJQUTZ8+Xdu2bet7LBKJaNu2bZo1a1a8DwcAGKT65f8BrVy5UosXL9Y3vvENzZw5U4899pja2tp066239sfhAACDUL8U0A033KCPPvpIDzzwgOrr63XZZZdp69atn3ljAgDg3BVwzjnrRfyjcDisUCikOVrAJAQAGIR6XLcqtFnNzc1KS0s7437m74IDAJybKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgIkk6wUAA0kgKYZPicDA/TrO9fbGGIzEkHGxHQvnrIH7mQMAGNIoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBgpBr5AwDuSODojpkN1XF7gnfnoshTvTNtY/yGhSe3+Xy8GG/3PnSSlfeA/jHTU4XbvTFLdCe9M74lG74zr7PTOoP9xBwQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEw0hxdiUkekcSM9K9Mye+e7F3RpI+LjnpnZk69gP/TOiodyYY6PHOHO9O9c5I0n+0Znpn3vkwxzuT/uZ470xW5XDvTKS2zjsjSa67K6YcvhzugAAAJiggAICJuBfQgw8+qEAgELVNnjw53ocBAAxy/fIa0KWXXqo33njj7wdJ4qUmAEC0fmmGpKQk5eT4vyAJADh39MtrQAcPHlReXp4mTJigW265RYcPHz7jvp2dnQqHw1EbAGDoi3sBFRUVaf369dq6dauefPJJ1dbW6qqrrlJLS8tp9y8vL1coFOrb8vPz470kAMAAFPcCKi0t1fe//31NmzZNJSUleuWVV9TU1KQXXnjhtPuvWrVKzc3NfVtdXWzv1wcADC79/u6A9PR0XXzxxaqpqTnt88FgUMFgsL+XAQAYYPr9/wG1trbq0KFDys3N7e9DAQAGkbgX0N13363Kykq9//77+sMf/qDrrrtOiYmJuummm+J9KADAIBb3b8EdOXJEN910kxobGzVmzBhdeeWV2rlzp8aMGRPvQwEABrG4F9Bzzz0X798SA1Usg0VDad6Z4wv9B4sO+y8N3hlJ+qfz/+KdmTLc/40zLb3+AzVrOrO9M+cHP/bOSNKMUbXemaYxI7wzvxn1Le/M39r9v5gdHW71zkhSz/ET/qFIb0zHOhcxCw4AYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAICJfv+BdBgEAoGYYgnD/H+QYOflE7wzmTcf9s78MP9N74wkpSe2e2d2t/v/mf5lzxzvTKAxxT+T3emdkaSCHP8hnNPSP/TOXH/Bn70zv/vuN70zo46e752RpJT2k96Z3pYW/wM5558ZArgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBo2FEjxn7IsSYELxnpn/uMW/8nbzxds8s7kJcY2BXpL6yTvzNrX53lncnZ5RxSI+Gfasof7hyQdycj3zhwrTPPOfH/iXu/Mf5603zuz+bszvTOSdGHY/zwE9h/0zrjuLu/MUMAdEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMMIx1qAv7DPhPS/IdISlL97NHemX++8hnvzPSURO/MnzpjG8L5yGvXemcueLnbO5PU6j980iX7n4dhJ/wzktSS7z+gNpwQ8s5sG+k//HXxuCrvTMO33vPOSNI7hy7xzuQc8r/2epv9ryE5558ZYLgDAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIJhpENMIMV/iKTGnBfTsUZ//4h35nsjG70zH0f8B3f+1/3/5J2RpAu2+A+FDB4Le2cC7R3eGSX6DxZNGhbD9SApuSXonQmGR3hnjozO8s4czvEfgvu90X/2zkhS1bcKvDM5r2f4Hyjc6p9xvf6ZAYY7IACACQoIAGDCu4B27Niha6+9Vnl5eQoEAtq0aVPU8845PfDAA8rNzdXw4cNVXFysgwcPxmu9AIAhwruA2traVFhYqLVr1572+TVr1ujxxx/XU089pV27dmnkyJEqKSlRR0cM3/MGAAxZ3m9CKC0tVWlp6Wmfc87pscce03333acFCxZIkp5++mllZ2dr06ZNuvHGG7/aagEAQ0ZcXwOqra1VfX29iouL+x4LhUIqKipSVdXpf4xuZ2enwuFw1AYAGPriWkD19fWSpOzs7KjHs7Oz+577tPLycoVCob4tPz8/nksCAAxQ5u+CW7VqlZqbm/u2uro66yUBAM6CuBZQTk6OJKmhoSHq8YaGhr7nPi0YDCotLS1qAwAMfXEtoIKCAuXk5Gjbtm19j4XDYe3atUuzZs2K56EAAIOc97vgWltbVVNT0/dxbW2t9u3bp4yMDI0bN04rVqzQz3/+c1100UU
qKCjQ/fffr7y8PC1cuDCe6wYADHLeBbR7925dffXVfR+vXLlSkrR48WKtX79e99xzj9ra2nT77berqalJV155pbZu3aphw4bFb9UAgEHPu4DmzJkj59wZnw8EAnr44Yf18MMPf6WFQVIg4B1JGO5f9A1X+g93lKR/nfA/vDPBgP9wzLXNk7wzKZvTvTOSFDzc8MU7fUpMg0UjEf+M/Aelqq09huNIiR/7f3d+VJf/UNu0sf6Zv0zL885cPepd74wkzZpY6505njXeO5Pwgf+gWRdhGCkAADGhgAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJjwnoaNsyjg//VBID3knem5psk7I0mXpfhfPod7Wr0zT/5biXdm4r6wd0aSAq3+06NddwxTqrv8M66nx/84MXKxTGKP4ThpdanemdqP/ae3p+ef9M5I0tTUD70zr4260DuTkuB/vocC7oAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBjpABZITPTOdEwc4525YcJb3hlJSoxhWOoTJ67yzuRW9XpnEusbvTOSFGlt8w/FMIzU9Ub8j+NiyJxF7qT/PyfBRv9zd7R1uHcmlBDDwFhJ2cnN3pmuVP/P22AMw1+dd2Lg4Q4IAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACYaRDmCBYUHvzMcXp3hnrkvb652RpPaI/+Xzv//yde/MpPdbvTPu5EnvjCS5jk7/TK//sNSYDPRhpDEMWA3EMpQ1hq+bUxP8h31K0rCA/xDThJ6hMCb07OAOCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAmGkQ5ggWHDvDNNX/Mf7jgmMbbhif/e7Z8b8a7/nynh44+8M70nO7wzkmIb+DnAh4TGJOD/tWkg0T/TM8L/n6AxGU3emVCC/3UnSfU9Ie9MQgyfF66nxzszFHAHBAAwQQEBAEx4F9COHTt07bXXKi8vT4FAQJs2bYp6fsmSJQoEAlHb/Pnz47VeAMAQ4V1AbW1tKiws1Nq1a8+4z/z583Xs2LG+7dlnn/1KiwQADD3erwCWlpaqtLT0c/cJBoPKycmJeVEAgKGvX14DqqioUFZWliZNmqRly5apsbHxjPt2dnYqHA5HbQCAoS/uBTR//nw9/fTT2rZtm375y1+qsrJSpaWl6u3tPe3+5eXlCoVCfVt+fn68lwQAGIDi/v+Abrzxxr5fT506VdOmTdPEiRNVUVGhuXPnfmb/VatWaeXKlX0fh8NhSggAzgH9/jbsCRMmKDMzUzU1Nad9PhgMKi0tLWoDAAx9/V5AR44cUWNjo3Jzc/v7UACAQcT7W3Ctra1RdzO1tbXat2+fMjIylJGRoYceekiLFi1STk6ODh06pHvuuUcXXnihSkpK4rpwAMDg5l1Au3fv1tVXX9338Sev3yxevFhPPvmk9u/fr9/97ndqampSXl6e5s2bp5/97GcKBoPxWzUAYNDzLqA5c+bIuTMP23vttde+0oKGrEDAP5KU6J0ZltvmnYnVX7vyvDOjjvgP7nRt7d4ZRWIbsOpiycUwuFOR078rNO5iuO4kKZDs//6kQGqqdyZ8QYp3ZvG43d6Zbhfb+d7w/gzvzOgjLd6ZyBneJTzUMQsOAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGAi7j+SG3GU5P/XE0zu8c50fM5088/zQWemdyYQy9DfXv8J2meVi2F9MU6p9j5Miv+0aUlKSA95ZzonZnln/jbV/9orHvmud+bfu/0ny0tSy5/GeGcyP3zP/0Axfg4OdtwBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMMEw0oEshgGFEec/5LI7xjmImckt3pnOdP/1BYYF/TOdnd6ZmMVwzpXoPxwzIRjDechI985IUnfeed6ZD2cP8878t3mbvDOpCf7DXx+q+553RpLOr+jwzkRa22I61rmIOyAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmGEZ6tsQwWNR1dnlnwidGemfanf9gTEkal9zonWm+2P88jL4w1zuTUpfinZEkd9J/+GQgOdn/QMn+n3pdYzO8My3j/AeYStLxb/r/Pf334pe8M/9pZI13ZnPrJd6Zd16a7J2RpLF/Peid6e3y/7w9V3EHBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwATDSAcw19bmnQnt9x/C+d7sbO+MJH0tpd47M32G/3DHP7df7J0Z9cH53hlJSmn1H8LZEwx4ZzrP88+0jo94Zy77+iHvjCStzf8370x+Urd35vnwpd6Zx7dc4525aMtx74wkRT7+OKYcvhzugAAAJiggAIAJrwIqLy/XjBkzlJqaqqysLC1cuFDV1dVR+3R0dKisrEyjR4/WqFGjtGjRIjU0NMR10QCAwc+rgCorK1VWVqadO3fq9ddfV3d3t+bNm6e2f3it4q677tLLL7+sF198UZWVlTp69Kiuv/76uC8cADC4eb0JYevWrVEfr1+/XllZWdqzZ49mz56t5uZm/eY3v9GGDRv0ne98R5K0bt06XXLJJdq5c6e++c1vxm/lAIBB7Su9BtTc3CxJysg49aOC9+zZo+7ubhUXF/ftM3nyZI0bN05VVVWn/T06OzsVDoejNgDA0BdzAUUiEa1YsUJXXHGFpkyZIkmqr69XSkqK0tPTo/bNzs5Wff3p37JbXl6uUCjUt+Xn58e6JADAIBJzAZWVlenAgQN67rnnvtICVq1apebm5r6trq7uK/1+AIDBIab/iLp8+XJt2bJFO3bs0NixY/sez8nJUVdXl5qamqLughoaGpSTk3Pa3ysYDCoYDMayDADAIOZ1B+Sc0/Lly7Vx40Zt375dBQUFUc9Pnz5dycnJ2rZtW99j1dXVOnz4sGbNmhWfFQMAhgSvO6CysjJt2LBBmzdvVmpqat/rOqFQSMOHD1coFNJtt92mlStXKiMjQ2lpabrzzjs1a9Ys3gEHAIjiVUBPPvmkJGnOnDlRj69bt05LliyRJP3qV79SQkKCFi1apM7OTpWUlOjXv/51XBYLABg6As45/+mL/SgcDisUCmmOFigpkGy9HFsJif6Rqf6DO4//vNc7I0n/OuV/eWeGBfyP9dvGK7wzexrHeWckqfnkMO9MdmqLd6bwvA+9M4vS/+R/HP/ZtJKkj3o7vTNPNF7pnXnp//h/a/7C/+k/INQdrPXOSJLr9D8PkHpctyq0Wc3NzUpLSzvjfsyCAwCYoIAAAC
YoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYiOknouIsicQwpbrmsHdk+Pop/seR9C8/meudeSjvVf9M9h+8M01jdnhnJKk54j+BPCPB/+8pN2mUd6Y50uOdefNkundGku5/7xbvTM+rmd6Zi1/xnwree+SYd8Z1d3ln0P+4AwIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCYaRDTORkh3cmreJgTMeq7Zjsnfn2Qv/M9Zfv8c5ck/5n74wkpSec9M683XWed+a141O9M1ve8c+MPDDMOyNJuf+3zTuT9O5fvTM94VbvTExDejEgcQcEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABMNIh5oYBjX2/u3jmA41/E3/4ZOXVGd7Z/Ze8HXvzJ9GfsM7I0mBiH8mkhTwzgSbur0zl7zf6J1xjbH93cYy1La3x//PJOf8MxgyuAMCAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABggmGkiHkgZKS93T9UU+sdSY4l450Y+HqsFwDEGXdAAAATFBAAwIRXAZWXl2vGjBlKTU1VVlaWFi5cqOrq6qh95syZo0AgELXdcccdcV00AGDw8yqgyspKlZWVaefOnXr99dfV3d2tefPmqa2tLWq/pUuX6tixY33bmjVr4rpoAMDg5/UmhK1bt0Z9vH79emVlZWnPnj2aPXt23+MjRoxQTk5OfFYIABiSvtJrQM3NzZKkjIyMqMefeeYZZWZmasqUKVq1apXaP+fdUp2dnQqHw1EbAGDoi/lt2JFIRCtWrNAVV1yhKVOm9D1+8803a/z48crLy9P+/ft17733qrq6Wi+99NJpf5/y8nI99NBDsS4DADBIBZyL7T+BLFu2TK+++qreeustjR079oz7bd++XXPnzlVNTY0mTpz4mec7OzvV2dnZ93E4HFZ+fr7maIGSAkPxf3MAwNDW47pVoc1qbm5WWlraGfeL6Q5o+fLl2rJli3bs2PG55SNJRUVFknTGAgoGgwoGg7EsAwAwiHkVkHNOd955pzZu3KiKigoVFBR8YWbfvn2SpNzc3JgWCAAYmrwKqKysTBs2bNDmzZuVmpqq+vp6SVIoFNLw4cN16NAhbdiwQddcc41Gjx6t/fv366677tLs2bM1bdq0fvkDAAAGJ6/XgAKBwGkfX7dunZYsWaK6ujr94Ac/0IEDB9TW1qb8/Hxdd911uu+++z73+4D/KBwOKxQK8RoQAAxS/fIa0Bd1VX5+viorK31+SwDAOYpZcAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAE0nWC/g055wkqUfdkjNeDADAW4+6Jf393/MzGXAF1NLSIkl6S68YrwQA8FW0tLQoFAqd8fmA+6KKOssikYiOHj2q1NRUBQKBqOfC4bDy8/NVV1entLQ0oxXa4zycwnk4hfNwCufhlIFwHpxzamlpUV5enhISzvxKz4C7A0pISNDYsWM/d5+0tLRz+gL7BOfhFM7DKZyHUzgPp1ifh8+78/kEb0IAAJiggAAAJgZVAQWDQa1evVrBYNB6KaY4D6dwHk7hPJzCeThlMJ2HAfcmBADAuWFQ3QEBAIYOCggAYIICAgCYoIAAACYGTQGtXbtWF1xwgYYNG6aioiL98Y9/tF7SWffggw8qEAhEbZMnT7ZeVr/bsWOHrr32WuXl5SkQCGjTpk1Rzzvn9MADDyg3N1fDhw9XcXGxDh48aLPYfvRF52HJkiWfuT7mz59vs9h+Ul5erhkzZig1NVVZWVlauHChqquro/bp6OhQWVmZRo8erVGjRmnRokVqaGgwWnH/+DLnYc6cOZ+5Hu644w6jFZ/eoCig559/XitXrtTq1av19ttvq7CwUCUlJTp+/Lj10s66Sy+9VMeOHevb3nrrLesl9bu2tjYVFhZq7dq1p31+zZo1evzxx/XUU09p165dGjlypEpKStTR0XGWV9q/vug8SNL8+fOjro9nn332LK6w/1VWVqqsrEw7d+7U66+/ru7ubs2bN09tbW19+9x11116+eWX9eKLL6qyslJHjx7V9ddfb7jq+Psy50GSli5dGnU9rFmzxmjFZ+AGgZkzZ7qysrK+j3t7e11eXp4rLy83XNXZt3r1aldYWGi9DFOS3MaNG/s+jkQiLicnxz3yyCN9jzU1NblgMOieffZZgxWeHZ8+D845t3jxYrdgwQKT9Vg5fvy4k+QqKyudc6f+7pOTk92LL77Yt8+7777rJLmqqiqrZfa7T58H55z79re/7X70ox/ZLepLGPB3QF1dXdqzZ4+Ki4v7HktISFBxcbGqqqoMV2bj4MGDysvL04QJE3TLLbfo8OHD1ksyVVtbq/r6+qjrIxQKqaio6Jy8PioqKpSVlaVJkyZp2bJlamxstF5Sv2pubpYkZWRkSJL27Nmj7u7uqOth8uTJGjdu3JC+Hj59Hj7xzDPPKDMzU1OmTNGqVavU3t5usbwzGnDDSD/txIkT6u3tVXZ2dtTj2dnZeu+994xWZaOoqEjr16/XpEmTdOzYMT300EO66qqrdODAAaWmplovz0R9fb0knfb6+OS5c8X8+fN1/fXXq6CgQIcOHdJPf/pTlZaWqqqqSomJidbLi7tIJKIVK1boiiuu0JQpUySduh5SUlKUnp4ete9Qvh5Odx4k6eabb9b48eOVl5en/fv3695771V1dbVeeuklw9VGG/AFhL8rLS3t+/W0adNUVFSk8ePH64UXXtBtt91muDIMBDfeeGPfr6dOnapp06Zp4sSJqqio0Ny5cw1X1j/Kysp04MCBc+J10M9zpvNw++239/166tSpys3N1dy5c3Xo0CFNnDjxbC/ztAb8t+AyMzOVmJj4mXexNDQ0KCcnx2hVA0N6erouvvhi1dTUWC/FzCfXANfHZ02YMEGZmZlD8vpYvny5tmzZojfffDPqx7fk5OSoq6tLTU1NUfsP1evhTOfhdIqKiiRpQF0PA76AUlJSNH36dG3btq3vsUgkom3btmnWrFmGK7PX2tqqQ4cOKTc313opZgoKCpSTkxN1fYTDYe3ateucvz6OHDmixsbGIXV9OOe0fPlybdy4Udu3b1dBQUHU89OnT1dycnLU9VBdXa3Dhw8Pqevhi87D6ezbt0+SBtb1YP0uiC/jueeec8Fg0K1fv96988477vbbb3fp6emuvr7eemln1Y9//GNXUVHhamtr3e9//3tXXFzsMjMz3fHjx62X1q9aWlrc3r173d69e50k9+ijj7q9e/e6Dz74wDnn3C9+8QuXnp7uNm/e7Pbv3+8WL
FjgCgoK3MmTJ41XHl+fdx5aWlrc3Xff7aqqqlxtba1744033OWXX+4uuugi19HRYb30uFm2bJkLhUKuoqLCHTt2rG9rb2/v2+eOO+5w48aNc9u3b3e7d+92s2bNcrNmzTJcdfx90XmoqalxDz/8sNu9e7erra11mzdvdhMmTHCzZ882Xnm0QVFAzjn3xBNPuHHjxrmUlBQ3c+ZMt3PnTuslnXU33HCDy83NdSkpKe788893N9xwg6upqbFeVr978803naTPbIsXL3bOnXor9v333++ys7NdMBh0c+fOddXV1baL7gefdx7a29vdvHnz3JgxY1xycrIbP368W7p06ZD7Iu10f35Jbt26dX37nDx50v3whz905513nhsxYoS77rrr3LFjx+wW3Q++6DwcPnzYzZ4922VkZLhgMOguvPBC95Of/MQ1NzfbLvxT+HEMAAATA/41IADA0EQBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMDE/wP2Gvqm9MbL2gAAAABJRU5ErkJggg==",
253 | "text/plain": [
254 | ""
255 | ]
256 | },
257 | "metadata": {},
258 | "output_type": "display_data"
259 | }
260 | ],
261 | "source": [
262 | "import matplotlib.pyplot as plt\n",
263 | "print(z1.max(), z1.min(), z1.std())\n",
264 | "plt.imshow(decoder(torch.randn(1, z_size).cuda()).view(28, 28).detach().cpu().numpy())"
265 | ]
266 | }
267 | ],
268 | "metadata": {
269 | "kernelspec": {
270 | "display_name": "Python 3",
271 | "language": "python",
272 | "name": "python3"
273 | },
274 | "language_info": {
275 | "codemirror_mode": {
276 | "name": "ipython",
277 | "version": 3
278 | },
279 | "file_extension": ".py",
280 | "mimetype": "text/x-python",
281 | "name": "python",
282 | "nbconvert_exporter": "python",
283 | "pygments_lexer": "ipython3",
284 | "version": "3.11.4"
285 | }
286 | },
287 | "nbformat": 4,
288 | "nbformat_minor": 2
289 | }
290 |
--------------------------------------------------------------------------------
/vqvae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torch.nn as nn\n",
11 | "import torch.optim as optim\n",
12 | "from torchvision import datasets, transforms\n",
13 | "from torch.utils.data import DataLoader\n",
14 | "import cv2\n",
15 | "import numpy as np\n",
16 | "torch.manual_seed(41)"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "### 1. Load MNIST dataset"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# 1. Load MNIST dataset\n",
33 | "transform = transforms.Compose([transforms.ToTensor()])\n",
34 | "train_dataset = datasets.MNIST(\n",
35 | " root=\"./data\", train=True, download=True, transform=transform\n",
36 | ")\n",
37 | "train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=False)\n"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "### 2. Define VQVAE model"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "\n",
54 | "class VQVAE(nn.Module):\n",
55 | " def __init__(self):\n",
56 | " super(VQVAE, self).__init__()\n",
57 | " self.encoder = nn.Sequential(\n",
58 | " nn.Conv2d(1, 4, 4, stride=4, padding=0),\n",
59 | " nn.BatchNorm2d(4),\n",
60 | " nn.ReLU()\n",
61 | " )\n",
62 | "\n",
63 | " self.pre_quant_conv = nn.Conv2d(4, 2, kernel_size=1)\n",
64 | " self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=2)\n",
65 | " self.post_quant_conv = nn.Conv2d(2, 4, kernel_size=1)\n",
66 | "\n",
67 | " # Commitment Loss Beta\n",
68 | " self.beta = 0.2\n",
69 | "\n",
70 | " self.decoder = nn.Sequential(\n",
71 | " nn.ConvTranspose2d(4, 16, 4, stride=2, padding=1),\n",
72 | " nn.BatchNorm2d(16),\n",
73 | " nn.ReLU(),\n",
74 | " nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),\n",
75 | " nn.Sigmoid(),\n",
76 | " )\n",
77 | "\n",
78 | " def forward(self, x):\n",
79 | " # B, C, H, W\n",
80 | " encoded_output = self.encoder(x)\n",
81 | " quant_input = self.pre_quant_conv(encoded_output)\n",
82 | "\n",
83 | " ## Quantization\n",
84 | " B, C, H, W = quant_input.shape\n",
85 | " quant_input = quant_input.permute(0, 2, 3, 1)\n",
86 | " quant_input = quant_input.reshape(\n",
87 | " (quant_input.size(0), -1, quant_input.size(-1))\n",
88 | " )\n",
89 | "\n",
90 | " # Compute pairwise distances\n",
91 | " dist = torch.cdist(\n",
92 | " quant_input,\n",
93 | " self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1)),\n",
94 | " )\n",
95 | "\n",
96 | " # Find index of nearest embedding\n",
97 | " min_encoding_indices = torch.argmin(dist, dim=-1)\n",
98 | "\n",
99 | " # Select the embedding weights\n",
100 | " quant_out = torch.index_select(\n",
101 | " self.embedding.weight, 0, min_encoding_indices.view(-1)\n",
102 | " )\n",
103 | " quant_input = quant_input.reshape((-1, quant_input.size(-1)))\n",
104 | "\n",
105 | " # Compute losses\n",
106 | " commitment_loss = torch.mean((quant_out.detach() - quant_input) ** 2)\n",
107 | " codebook_loss = torch.mean((quant_out - quant_input.detach()) ** 2)\n",
108 | " quantize_losses = codebook_loss + commitment_loss * 0.1\n",
109 | " quant_out = quant_input + (quant_out - quant_input).detach()\n",
110 | "\n",
111 | " # Reshaping back to original input shape\n",
112 | " quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)\n",
113 | " min_encoding_indices = min_encoding_indices.reshape(\n",
114 | " (-1, quant_out.size(-2), quant_out.size(-1))\n",
115 | " )\n",
116 | "\n",
117 | " ## Decoder part\n",
118 | " decoder_input = self.post_quant_conv(quant_out)\n",
119 | " output = self.decoder(decoder_input)\n",
120 | " return output, quantize_losses"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "### 3. Initialize model, optimizer, scheduler"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "model = VQVAE().cuda()\n",
137 | "optimizer = optim.AdamW(model.parameters(), lr=2e-3)\n",
138 | "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {},
144 | "source": [
145 | "### 4. Training loop"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "model.train()\n",
155 | "for epoch in range(20):\n",
156 | " for batch_idx, (data, _) in enumerate(train_loader):\n",
157 | " data = data.cuda()\n",
158 | " optimizer.zero_grad()\n",
159 | " out, quantize_loss = model(data)\n",
160 | " recon_loss = torch.nn.functional.mse_loss(out, data)\n",
161 | " loss = recon_loss + quantize_loss\n",
162 | " loss.backward()\n",
163 | " nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
164 | " optimizer.step()\n",
165 | " print(\n",
166 | " f\"\\rEpoch {epoch}, Batch {batch_idx:03d}, Loss: {loss.item():.4f} = {recon_loss.item():.4f} + {quantize_loss.item():.4f}\",\n",
167 | " end=\"\",\n",
168 | " )\n",
169 | " scheduler.step()\n",
170 | " print(\"\")\n",
171 | "\n",
172 | "# torch.save(model.state_dict(), \"vqvae.ckpt\")"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": [
179 | "### 5. Inference"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "def save(data, idx):\n",
189 | " data = (data * 255).astype(np.uint8)\n",
190 | " cv2.imwrite(f\"img_{idx}.png\", data)\n",
191 | "\n",
192 | "\n",
193 | "model.eval()\n",
194 | "with torch.no_grad():\n",
195 | " for i in range(10):\n",
196 | " data, _ = train_dataset[i]\n",
197 | " data = data.unsqueeze(0).cuda()\n",
198 | " x_recon, _ = model(data)\n",
199 | " recon_img = x_recon.cpu().squeeze().numpy()\n",
200 | " # Save or display recon_img\n",
201 | " save(x_recon, i)\n"
202 | ]
203 | }
204 | ],
205 | "metadata": {
206 | "kernelspec": {
207 | "display_name": "Python 3",
208 | "language": "python",
209 | "name": "python3"
210 | },
211 | "language_info": {
212 | "codemirror_mode": {
213 | "name": "ipython",
214 | "version": 3
215 | },
216 | "file_extension": ".py",
217 | "mimetype": "text/x-python",
218 | "name": "python",
219 | "nbconvert_exporter": "python",
220 | "pygments_lexer": "ipython3",
221 | "version": "3.11.4"
222 | }
223 | },
224 | "nbformat": 4,
225 | "nbformat_minor": 2
226 | }
227 |
--------------------------------------------------------------------------------