├── .gitignore
├── LICENSE
├── README.md
├── code
├── alexnet.ipynb
├── attention.ipynb
├── cnn.ipynb
├── dataset.py
├── ddpm_1d_example.ipynb
├── ddpm_2d_example.ipynb
├── ddpm_cfg_1d_example.ipynb
├── diffusion.py
├── diffusion_constants.ipynb
├── diffusion_resblock.ipynb
├── diffusion_unet.ipynb
├── diffusion_unet_legacy.ipynb
├── googlenet.ipynb
├── gp.ipynb
├── hdm_1d_example.ipynb
├── hdm_constant.ipynb
├── mdn.py
├── mdn_reg.ipynb
├── mlp.ipynb
├── module.py
├── repaint_1d_example.ipynb
├── repaint_constant.ipynb
├── resnet.ipynb
├── updownsample.ipynb
├── util.py
└── vgg.ipynb
└── img
└── unet.jpg
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Mac
163 | .DS_Store
164 |
165 | # Data folder
166 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Sungjoon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Yet Another Pytorch Tutorial v2
2 |
3 | Minimal implementations of
4 | - [MLP](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/mlp.ipynb): MNIST classification using multi-layer perceptron (MLP)
5 | - [CNN](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/cnn.ipynb): MNIST classification using convolutional neural networks (CNN)
6 | - [AlexNet](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/alexnet.ipynb): MNIST classification using AlexNet
7 | - [VGG](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/vgg.ipynb): MNIST classification using VGG
8 | - [GoogLeNet](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/googlenet.ipynb): MNIST classification using GoogLeNet
9 | - [MDN](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/mdn_reg.ipynb): Regression using Mixture Density Networks
10 | - [ResNet](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/resnet.ipynb): MNIST classification using ResNet
11 | - [Attention](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/attention.ipynb): Attention block with legacy QKV attention mechanism
12 | - [ResBlock](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/diffusion_resblock.ipynb): Residual block for diffusion models
13 | - [Diffusion](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/diffusion_constants.ipynb): Diffusion constants
14 | - [DDPM 1D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/ddpm_1d_example.ipynb): Denoising Diffusion Probabilistic Model (DDPM) example on generating trajectories
15 | - [DDPM 2D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/ddpm_2d_example.ipynb): Denoising Diffusion Probabilistic Model (DDPM) example on generating images
16 | - [DDPM-CFG 1D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/ddpm_cfg_1d_example.ipynb): Conditional Generation using Classifier-Free Guidance (CFG)
17 | - [Repaint 1D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/repaint_1d_example.ipynb): Diffusion-based inpainting example on 1D data
18 |
19 | Contact
20 | - sungjoon dash choi at korea dot ac dot kr
21 |
--------------------------------------------------------------------------------
/code/attention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "71051e5d",
6 | "metadata": {},
7 | "source": [
8 | "### Attention block"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "7ae3d6f4",
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "PyTorch version:[2.0.1].\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "import numpy as np\n",
27 | "import matplotlib.pyplot as plt\n",
28 | "import torch as th\n",
29 | "from module import AttentionBlock\n",
30 | "from util import get_torch_size_string\n",
31 | "np.set_printoptions(precision=3)\n",
32 | "th.set_printoptions(precision=3)\n",
33 | "%matplotlib inline\n",
34 | "%config InlineBackend.figure_format='retina'\n",
35 | "print (\"PyTorch version:[%s].\"%(th.__version__))"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "id": "bfb7e801",
41 | "metadata": {},
42 | "source": [
43 | "### Let's see how `AttentionBlock` works\n",
44 | "- First, we assume that an input tensor has a shape of [B x C x W x H].\n",
45 | "- This can be thought of as a total of WH tokens, each token having C dimensions. \n",
46 | "- Multi-head attention (MHA) operates by first partitioning the channels across heads, running QKV attention within each head, and then merging the results. \n",
47 | "- Note that the number of channels should be divisible by the number of heads."
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "id": "dca0e4fa",
53 | "metadata": {},
54 | "source": [
55 | "### `dims=2`\n",
56 | "#### `x` has a shape of `[B x C x W x H]`"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 2,
62 | "id": "8c6e06cc",
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": [
69 | "input shape:[16x128x28x28] output shape:[16x128x28x28]\n",
70 | "[ x]:[ 16x128x28x28]\n",
71 | "[ x_rsh]:[ 16x128x784]\n",
72 | "[ x_nzd]:[ 16x128x784]\n",
73 | "[ qkv]:[ 16x384x784]\n",
74 | "[ h_att]:[ 16x128x784]\n",
75 | "[ h_proj]:[ 16x128x784]\n",
76 | "[ out]:[ 16x128x28x28]\n"
77 | ]
78 | }
79 | ],
80 | "source": [
81 | "layer = AttentionBlock(n_channels=128,n_heads=4,n_groups=32)\n",
82 | "x = th.randn(16,128,28,28)\n",
83 | "out,intermediate_output_dict = layer(x)\n",
84 | "print (\"input shape:[%s] output shape:[%s]\"%\n",
85 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
86 | "# Print intermediate values\n",
87 | "for key,value in intermediate_output_dict.items():\n",
88 | " print (\"[%10s]:[%15s]\"%(key,get_torch_size_string(value)))"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "id": "a91d9cec",
94 | "metadata": {},
95 | "source": [
96 | "### `dims=1`\n",
97 | "#### `x` has a shape of `[B x C x L]`"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 3,
103 | "id": "97d1bbdb",
104 | "metadata": {
105 | "scrolled": true
106 | },
107 | "outputs": [
108 | {
109 | "name": "stdout",
110 | "output_type": "stream",
111 | "text": [
112 | "input shape:[16x4x100] output shape:[16x4x100]\n",
113 | "[ x]:[ 16x4x100]\n",
114 | "[ x_rsh]:[ 16x4x100]\n",
115 | "[ x_nzd]:[ 16x4x100]\n",
116 | "[ qkv]:[ 16x12x100]\n",
117 | "[ h_att]:[ 16x4x100]\n",
118 | "[ h_proj]:[ 16x4x100]\n",
119 | "[ out]:[ 16x4x100]\n"
120 | ]
121 | }
122 | ],
123 | "source": [
124 | "layer = AttentionBlock(n_channels=4,n_heads=2,n_groups=1)\n",
125 | "x = th.randn(16,4,100)\n",
126 | "out,intermediate_output_dict = layer(x)\n",
127 | "print (\"input shape:[%s] output shape:[%s]\"%\n",
128 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
129 | "# Print intermediate values\n",
130 | "for key,value in intermediate_output_dict.items():\n",
131 | " print (\"[%10s]:[%15s]\"%(key,get_torch_size_string(value)))"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "id": "8e60d073",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": []
141 | }
142 | ],
143 | "metadata": {
144 | "kernelspec": {
145 | "display_name": "Python 3 (ipykernel)",
146 | "language": "python",
147 | "name": "python3"
148 | },
149 | "language_info": {
150 | "codemirror_mode": {
151 | "name": "ipython",
152 | "version": 3
153 | },
154 | "file_extension": ".py",
155 | "mimetype": "text/x-python",
156 | "name": "python",
157 | "nbconvert_exporter": "python",
158 | "pygments_lexer": "ipython3",
159 | "version": "3.9.16"
160 | }
161 | },
162 | "nbformat": 4,
163 | "nbformat_minor": 5
164 | }
165 |
--------------------------------------------------------------------------------
/code/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch as th
4 | from torchvision import datasets,transforms
5 | from util import gp_sampler,periodic_step,get_torch_size_string
6 |
def mnist(root_path='./data/',batch_size=128):
    """
    MNIST

    Build the MNIST train/test datasets (downloading them into `root_path`
    if needed) and wrap each split in a shuffling DataLoader.

    :param root_path: directory passed to `datasets.MNIST(root=...)`
    :param batch_size: mini-batch size shared by both loaders
    :return: (train_iter, test_iter, train_data, train_label, test_data, test_label)
        where the two iterators are DataLoaders (shuffle=True, num_workers=1)
        and the remaining values are each split's `.data` / `.targets`.
    """
    to_tensor = transforms.ToTensor()
    # Index 0 is the training split, index 1 the test split.
    split_sets = [
        datasets.MNIST(root=root_path,train=is_train,transform=to_tensor,download=True)
        for is_train in (True,False)
    ]
    split_loaders = [
        th.utils.data.DataLoader(ds,batch_size=batch_size,shuffle=True,num_workers=1)
        for ds in split_sets
    ]
    (train_set,test_set),(train_loader,test_loader) = split_sets,split_loaders
    return (
        train_loader,
        test_loader,
        train_set.data,
        train_set.targets,
        test_set.data,
        test_set.targets,
    )
19 |
def get_1d_training_data(
    traj_type = 'step', # {'gp','gp2','step','step2','triangle'}
    n_traj = 10,
    L = 100,
    device = 'cpu',
    seed = 1,
    plot_data = True,
    figsize = (6,2),
    ls = '-',
    lc = 'k',
    lw = 1,
    verbose = True,
):
    """
    1-D training data

    Generate `n_traj` synthetic 1-D trajectories sampled at `L` evenly spaced
    times in [0, 1].

    :param traj_type: one of
        'gp'       - `gp_sampler` draws with a shared length scale of 0.2
        'gp2'      - `gp_sampler` draws, each with its own uniform length scale
        'step'     - `periodic_step` signals with jittered period/offset/levels
        'step2'    - piecewise-constant signals; levels ~ U(-3, 3) and segment
                     lengths int(L * Exp(scale=1/5)), clamped to >= 10% of L
        'triangle' - sawtooth ramps with period 0.2 and jittered offset/levels
    :param n_traj: number of trajectories
    :param L: number of time samples per trajectory
    :param device: device of the returned tensor
    :param seed: if not None, passed to `np.random.seed` (global NumPy RNG)
    :param plot_data: if True, plot every trajectory with matplotlib
    :param figsize: figure size used when plotting
    :param ls: line style used when plotting
    :param lc: line color used when plotting
    :param lw: line width used when plotting
    :param verbose: if True, print the shape of the returned tensor
    :return: (times, x_0) where `times` is an [L x 1] numpy array and `x_0`
        is an [n_traj x 1 x L] float32 tensor on `device`
    :raises ValueError: if `traj_type` is not one of the supported types
    """
    if seed is not None:
        np.random.seed(seed=seed)
    times = np.linspace(start=0.0,stop=1.0,num=L).reshape((-1,1)) # [L x 1]
    if traj_type == 'gp':
        # All trajectories share one GP length scale
        traj_np = gp_sampler(
            times = times,
            hyp_gain = 2.0,
            hyp_len = 0.2,
            meas_std = 1e-8,
            n_traj = n_traj
        ) # expected [n_traj x L]
    elif traj_type == 'gp2':
        # Each trajectory draws its own GP length scale
        traj_np = np.zeros((n_traj,L))
        for i_idx in range(n_traj):
            traj_np[i_idx,:] = gp_sampler(
                times = times,
                hyp_gain = 2.0,
                hyp_len = np.random.uniform(1e-8,1.0),
                meas_std = 1e-8,
                n_traj = 1
            ).reshape(-1)
    elif traj_type == 'step':
        traj_np = np.zeros((n_traj,L))
        for i_idx in range(n_traj):
            period = np.random.uniform(low=0.38,high=0.42)
            time_offset = np.random.uniform(low=-0.02,high=0.02)
            y_min = np.random.uniform(low=-3.2,high=-2.8)
            y_max = np.random.uniform(low=2.8,high=3.2)
            traj_np[i_idx,:] = periodic_step(
                times = times,
                period = period,
                time_offset = time_offset,
                y_min = y_min,
                y_max = y_max
            ).reshape(-1)
    elif traj_type == 'step2':
        rate = 5 # rate of the exponential segment-length distribution
        min_dur = 0.1 # minimum segment length as a fraction of L
        # At least one tick so every segment advances (int(min_dur*L) is 0 for L<10)
        min_tick = max(1,int(min_dur*L))
        traj_np = np.zeros((n_traj,L))
        for i_idx in range(n_traj):
            tick_fr = 0
            while tick_fr < L:
                # Sample a level and a duration, then fill that segment
                val = np.random.uniform(low=-3.0,high=3.0)
                dur_tick = max(int(L*np.random.exponential(scale=1/rate)),min_tick)
                traj_np[i_idx,tick_fr:tick_fr+dur_tick] = val # slicing clips at L
                tick_fr += dur_tick
    elif traj_type == 'triangle':
        traj_np = np.zeros((n_traj,L))
        for i_idx in range(n_traj):
            period = 0.2
            time_offset = np.random.uniform(low=-0.02,high=0.02)
            y_min = np.random.uniform(low=-3.2,high=-2.8)
            y_max = np.random.uniform(low=2.8,high=3.2)
            times_mod = np.mod(times+time_offset,period)/period # in [0, 1)
            y = (y_max - y_min) * times_mod + y_min
            traj_np[i_idx,:] = y.reshape(-1)
    else:
        # Previously this only printed and then crashed on an undefined `traj`
        raise ValueError("Unknown traj_type:[%s]"%(traj_type))
    traj = th.from_numpy(traj_np).to(th.float32).to(device) # [n_traj x L]
    # Plot
    if plot_data:
        plt.figure(figsize=figsize)
        for i_idx in range(n_traj):
            plt.plot(times,traj[i_idx,:].cpu().numpy(),ls=ls,color=lc,lw=lw)
        plt.xlim([0.0,1.0])
        plt.ylim([-4,+4])
        plt.xlabel('Time',fontsize=10)
        plt.title('Trajectory type:[%s]'%(traj_type),fontsize=10)
        plt.show()
    # Print
    x_0 = traj[:,None,:] # [N x C x L]
    if verbose:
        print ("x_0:[%s]"%(get_torch_size_string(x_0)))
    # Out
    return times,x_0
140 |
def get_mdn_data(
    n_train = 1000,
    x_min = 0.0,
    x_max = 1.0,
    y_min = -1.0,
    y_max = 1.0,
    freq = 1.0,
    noise_rate = 1.0,
    seed = 0,
    FLIP_AUGMENT = True,
):
    """
    Synthetic 1-input / 2-output regression data for mixture density networks.

    Inputs are uniform in [x_min, x_max]. With x_rate = (x - x_min) / (x_max - x_min)
    and r = y_max - y_min, each target row is
        [y_min + r*sin(2*pi*freq*x_rate) + r*x_rate,
         y_min + r*cos(2*pi*freq*x_rate) + r*x_rate].
    With FLIP_AUGMENT, the second half of the samples has the target negated,
    which makes the conditional distribution multi-modal. Uniform noise in
    [-1, 1) scaled by noise_rate * r * x_rate**2 is then added to every target.

    :param n_train: total number of samples (odd values are supported)
    :param x_min: lower bound of the input range
    :param x_max: upper bound of the input range
    :param y_min: offset of the target curves
    :param y_max: with `y_min`, sets the target scale r = y_max - y_min
    :param freq: frequency of the sine/cosine components
    :param noise_rate: noise amplitude multiplier
    :param seed: passed to `np.random.seed` (global NumPy RNG)
    :param FLIP_AUGMENT: if True, negate the targets of the second half
    :return: (x_train [n_train x 1], y_train [n_train x 2])
    """
    np.random.seed(seed=seed)
    y_range = y_max - y_min

    def _sample_x(n):
        # Uniform inputs in [x_min, x_max)
        return x_min + (x_max-x_min)*np.random.rand(n,1) # [n x 1]

    def _curve(x):
        # Noise-free 2-D target for inputs x
        x_rate = (x-x_min)/(x_max-x_min) # [n x 1]
        sin_temp = y_min + y_range*np.sin(2*np.pi*freq*x_rate)
        cos_temp = y_min + y_range*np.cos(2*np.pi*freq*x_rate)
        return np.concatenate(
            (sin_temp+y_range*x_rate,
             cos_temp+y_range*x_rate),axis=1) # [n x 2]

    if FLIP_AUGMENT:
        n_half = n_train // 2
        x_train_a = _sample_x(n_half) # [n_half x 1]
        # Remaining samples, so odd n_train still yields exactly n_train rows
        x_train_b = _sample_x(n_train-n_half)
        x_train = np.concatenate((x_train_a,x_train_b),axis=0) # [n_train x 1]
        y_train = np.concatenate((_curve(x_train_a),-_curve(x_train_b)),axis=0) # [n_train x 2]
    else:
        x_train = _sample_x(n_train) # [n_train x 1]
        y_train = _curve(x_train) # [n_train x 2]

    # Add noise whose amplitude grows quadratically with x_rate
    x_rate = (x_train-x_min)/(x_max-x_min) # [n_train x 1]
    noise = noise_rate * y_range * (2*np.random.rand(n_train,2)-1) * ((x_rate)**2) # [n_train x 2]
    y_train = y_train + noise # [n_train x 2]
    return x_train,y_train
--------------------------------------------------------------------------------
/code/diffusion.py:
--------------------------------------------------------------------------------
1 | import math,random
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import torch as th
5 | import torch.nn as nn
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 | from module import (
8 | conv_nd,
9 | ResBlock,
10 | AttentionBlock,
11 | TimestepEmbedSequential,
12 | )
13 |
def get_named_beta_schedule(
    schedule_name,
    num_diffusion_timesteps,
    scale_betas=1.0,
    np_type=np.float64
):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.

    :param schedule_name: "linear" or "cosine"
    :param num_diffusion_timesteps: number of betas to produce
    :param scale_betas: multiplier on the linear schedule's endpoints
        (not used by the cosine schedule)
    :param np_type: numpy dtype of the returned array (applied to both schedules)
    :return: 1-D numpy array of betas of length num_diffusion_timesteps
    :raises NotImplementedError: for an unknown schedule name
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = scale_betas * 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np_type
        )
    elif schedule_name == "cosine":
        # Previously `np_type` was silently ignored on this branch
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        ).astype(np_type)
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
44 |
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a cumulative alpha_bar(t), t in [0, 1], into per-step betas.

    For step i of T = num_diffusion_timesteps:
        beta_i = min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta)

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: callable mapping t in [0, 1] to the cumulative product
                      of (1 - beta) up to that point of the diffusion process.
    :param max_beta: upper clamp on each beta; values below 1 avoid
                     singularities near the end of the schedule.
    :return: 1-D numpy array of length num_diffusion_timesteps.
    """
    n_steps = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / n_steps) / alpha_bar(step / n_steps), max_beta)
        for step in range(n_steps)
    ])
63 |
def get_ddpm_constants(
    schedule_name = 'cosine',
    T = 1000,
    np_type = np.float64
):
    """
    Precompute DDPM schedule constants.

    :param schedule_name: beta schedule name forwarded to `get_named_beta_schedule`
    :param T: number of diffusion steps
    :param np_type: dtype the betas and the posterior variance are cast to
    :return: dict with keys 'schedule_name', 'T', 'timesteps' (1..T as floats),
        'betas', 'alphas', 'alphas_bar', 'alphas_bar_prev', 'sqrt_recip_alphas',
        'sqrt_alphas_bar', 'sqrt_one_minus_alphas_bar', 'posterior_variance'
    """
    step_axis = np.linspace(start=1,stop=T,num=T)
    beta = get_named_beta_schedule(
        schedule_name = schedule_name,
        num_diffusion_timesteps = T,
        scale_betas = 1.0,
    ).astype(np_type) # [T]
    alpha = 1.0 - beta
    alpha_bar = np.cumprod(alpha, axis=0) # running product of alphas
    # alpha_bar shifted by one step, with 1.0 at the first step
    alpha_bar_prev = np.append(1.0,alpha_bar[:-1])
    # Variance of q(x_{t-1} | x_t, x_0)
    post_var = (beta*(1.0-alpha_bar_prev)/(1.0-alpha_bar)).astype(np_type)
    return {
        'schedule_name': schedule_name,
        'T': T,
        'timesteps': step_axis,
        'betas': beta,
        'alphas': alpha,
        'alphas_bar': alpha_bar,
        'alphas_bar_prev': alpha_bar_prev,
        'sqrt_recip_alphas': np.sqrt(1.0/alpha),
        'sqrt_alphas_bar': np.sqrt(alpha_bar),
        'sqrt_one_minus_alphas_bar': np.sqrt(1.0-alpha_bar),
        'posterior_variance': post_var,
    }
99 |
def plot_ddpm_constants(dc):
    """
    Plot DDPM constants

    Two side-by-side panels against `dc['timesteps']`:
    the left shows beta, alpha, alpha_bar, sqrt(alpha_bar), sqrt(1 - alpha_bar)
    and the posterior variance; the right shows only beta and the posterior
    variance. Ends with `plt.show()`.

    :param dc: dictionary of diffusion constants (e.g. from `get_ddpm_constants`)
    """
    steps = dc['timesteps']
    palette = [plt.cm.gist_rainbow(v) for v in np.linspace(0,1,8)]
    width = 2
    var_label = r'$ Var[x_{t-1}|x_t,x_0] $'
    # (dict key, palette index, legend label)
    beta_curve = ('betas',0,r'$\beta_t$')
    left_curves = [
        beta_curve,
        ('alphas',1,r'$\alpha_t$'),
        ('alphas_bar',2,r'$\bar{\alpha}_t$'),
        ('sqrt_alphas_bar',5,r'$\sqrt{\bar{\alpha}_t}$'),
        ('sqrt_one_minus_alphas_bar',6,r'$\sqrt{1-\bar{\alpha}_t}$'),
    ]
    plt.figure(figsize=(10,3))
    for panel, curves in ((1,left_curves),(2,[beta_curve])):
        plt.subplot(1,2,panel)
        for key, color_idx, label in curves:
            plt.plot(steps,dc[key],color=palette[color_idx],label=label,lw=width)
        # Posterior variance is drawn dashed on both panels
        plt.plot(steps,dc['posterior_variance'],'--',color='k',label=var_label,lw=width)
        plt.xlabel('Diffusion steps',fontsize=8)
        plt.legend(fontsize=10,loc='center left',bbox_to_anchor=(1,0.5))
        plt.grid(lw=0.5)
        plt.title('DDPM Constants',fontsize=10)
    plt.subplots_adjust(wspace=0.7)
    plt.show()
138 |
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    With half = dim // 2, frequency k is exp(-ln(max_period) * k / half).
    The output columns are [cos(t * f_0..f_{half-1}), sin(t * f_0..f_{half-1})],
    plus a trailing zero column when `dim` is odd.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    k = th.arange(start=0, end=half, dtype=th.float32)
    freqs = th.exp(-math.log(max_period) * k / half).to(device=timesteps.device)
    phase = timesteps.float()[:, None] * freqs[None, :] # [N x half]
    pieces = [th.cos(phase), th.sin(phase)]
    if dim % 2:
        # Pad odd dimensions with a single zero column
        pieces.append(th.zeros_like(phase[:, :1]))
    return th.cat(pieces, dim=-1)
158 |
def forward_sample(x0_batch,t_batch,dc,M=None,noise_scale=1.0):
    """
    Forward diffusion sampling: draw x_t from x_0 in one jump,
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise_scale * noise

    :param x0_batch: [B x C x ...]
    :param t_batch: [B] int64 step indices into the schedule arrays of `dc`
    :param dc: dictionary of diffusion constants; reads the 1-D numpy arrays
        'sqrt_alphas_bar' and 'sqrt_one_minus_alphas_bar'
    :param M: optional [L x L] tensor, or a list of them (one is picked with
        `random.choice`), applied as noise <- M @ noise along the last axis;
        intended for [B x C x L] data
    :param noise_scale: multiplier on the noise term of x_t only
    :return: xt_batch of [B x C x ...] and noise of [B x C x ...]. `noise` is
        the (possibly correlated) noise before `noise_scale` is applied. Both
        share the dtype and device of `x0_batch`.
    """
    # Per-sample constants broadcastable against x0_batch: [B x 1 x ...]
    out_shape = (t_batch.shape[0],) + ((1,)*(x0_batch.dim()-1))

    def _gather(key):
        # gather needs input and index on the same device, so index on
        # t_batch's device; then match x0_batch so the float64 numpy constants
        # do not promote x_t to float64 for float32 inputs.
        consts = th.from_numpy(dc[key]).to(t_batch.device) # [T]
        picked = th.gather(input=consts,dim=-1,index=t_batch)
        return picked.reshape(out_shape).to(device=x0_batch.device,dtype=x0_batch.dtype)

    sqrt_alphas_bar_t = _gather('sqrt_alphas_bar')
    sqrt_one_minus_alphas_bar_t = _gather('sqrt_one_minus_alphas_bar')

    # Forward sample
    noise = th.randn_like(input=x0_batch) # [B x C x ...]

    # (optional) correlated noise
    if M is not None:
        M_use = random.choice(M) if isinstance(M, list) else M # [L x L]
        M_use = M_use.to(device=noise.device,dtype=noise.dtype)
        # Broadcast [L x L] @ [B x C x L x 1] -> [B x C x L x 1]
        noise = th.matmul(M_use,noise.unsqueeze(-1)).squeeze(-1) # [B x C x L]

    # Jump diffusion
    xt_batch = sqrt_alphas_bar_t*x0_batch + \
        sqrt_one_minus_alphas_bar_t*noise_scale*noise # [B x C x ...]
    return xt_batch,noise
203 |
204 | class DiffusionUNet(nn.Module):
205 | """
206 | U-Net for diffusion models
207 | """
208 | def __init__(
209 | self,
210 | name = 'unet',
211 | dims = 1, # spatial dimension, if dims==1, [B x C x L], if dims==2, [B x C x W x H]
212 | n_in_channels = 128, # input channels
213 | n_model_channels = 64, # base channel size
214 | n_emb_dim = 128, # time embedding size
215 | n_cond_dim = 0, # conditioning vector size (default is 0 indicating an unconditional model)
216 | n_enc_blocks = 2, # number of encoder blocks
217 | n_dec_blocks = 2, # number of decoder blocks
218 | n_groups = 16, # group norm paramter
219 | n_heads = 4, # number of heads
220 | actv = nn.SiLU(),
221 | kernel_size = 3, # kernel size
222 | padding = 1,
223 | use_resblock = True,
224 | use_attention = True,
225 | skip_connection = False,
226 | use_scale_shift_norm = True, # positional embedding handling
227 | device = 'cpu',
228 | ):
229 | super().__init__()
230 | self.name = name
231 | self.dims = dims
232 | self.n_in_channels = n_in_channels
233 | self.n_model_channels = n_model_channels
234 | self.n_emb_dim = n_emb_dim
235 | self.n_cond_dim = n_cond_dim
236 | self.n_enc_blocks = n_enc_blocks
237 | self.n_dec_blocks = n_dec_blocks
238 | self.n_groups = n_groups
239 | self.n_heads = n_heads
240 | self.actv = actv
241 | self.kernel_size = kernel_size
242 | self.padding = padding
243 | self.use_resblock = use_resblock
244 | self.use_attention = use_attention
245 | self.skip_connection = skip_connection
246 | self.use_scale_shift_norm = use_scale_shift_norm
247 | self.device = device
248 |
249 | # Time embedding
250 | self.time_embed = nn.Sequential(
251 | nn.Linear(in_features=self.n_model_channels,out_features=self.n_emb_dim),
252 | nn.SiLU(),
253 | nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim),
254 | ).to(self.device)
255 |
256 | # Conditional embedding
257 | if self.n_cond_dim > 0:
258 | self.cond_embed = nn.Sequential(
259 | nn.Linear(in_features=self.n_cond_dim,out_features=self.n_emb_dim),
260 | nn.SiLU(),
261 | nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim),
262 | ).to(self.device)
263 |
264 | # Lifting
265 | self.lift = conv_nd(
266 | dims = self.dims,
267 | in_channels = self.n_in_channels,
268 | out_channels = self.n_model_channels,
269 | kernel_size = 1,
270 | ).to(device)
271 |
272 | # Projection
273 | self.proj = conv_nd(
274 | dims = self.dims,
275 | in_channels = self.n_model_channels,
276 | out_channels = self.n_in_channels,
277 | kernel_size = 1,
278 | ).to(device)
279 |
280 | # Declare U-net
281 | # Encoder
282 | self.enc_layers = []
283 | for e_idx in range(self.n_enc_blocks):
284 | # Residual block in encoder
285 | if self.use_resblock:
286 | self.enc_layers.append(
287 | ResBlock(
288 | name = 'res',
289 | n_channels = self.n_model_channels,
290 | n_emb_channels = self.n_emb_dim,
291 | n_out_channels = self.n_model_channels,
292 | n_groups = self.n_groups,
293 | dims = self.dims,
294 | actv = self.actv,
295 | kernel_size = self.kernel_size,
296 | padding = self.padding,
297 | upsample = False,
298 | downsample = False,
299 | use_scale_shift_norm = self.use_scale_shift_norm,
300 | ).to(device)
301 | )
302 | # Attention block in encoder
303 | if self.use_attention:
304 | self.enc_layers.append(
305 | AttentionBlock(
306 | name = 'att',
307 | n_channels = self.n_model_channels,
308 | n_heads = self.n_heads,
309 | n_groups = self.n_groups,
310 | ).to(device)
311 | )
312 |
313 | # Decoder
314 | self.dec_layers = []
315 | for d_idx in range(self.n_dec_blocks):
316 | # Residual block in decoder
317 | if self.use_resblock:
318 | if d_idx == 0: n_channels = self.n_model_channels*self.n_enc_blocks
319 | else: n_channels = self.n_model_channels
320 | self.dec_layers.append(
321 | ResBlock(
322 | name = 'res',
323 | n_channels = n_channels,
324 | n_emb_channels = self.n_emb_dim,
325 | n_out_channels = self.n_model_channels,
326 | n_groups = self.n_groups,
327 | dims = self.dims,
328 | actv = self.actv,
329 | kernel_size = self.kernel_size,
330 | padding = self.padding,
331 | upsample = False,
332 | downsample = False,
333 | use_scale_shift_norm = self.use_scale_shift_norm,
334 | ).to(device)
335 | )
336 | # Attention block in decoder
337 | if self.use_attention:
338 | self.dec_layers.append(
339 | AttentionBlock(
340 | name = 'att',
341 | n_channels = self.n_model_channels,
342 | n_heads = self.n_heads,
343 | n_groups = self.n_groups,
344 | ).to(device)
345 | )
346 |
347 | # Define U-net
348 | self.enc_net = nn.Sequential()
349 | for l_idx,layer in enumerate(self.enc_layers):
350 | self.enc_net.add_module(
351 | name = 'enc_%02d'%(l_idx),
352 | module = TimestepEmbedSequential(layer)
353 | )
354 | self.dec_net = nn.Sequential()
355 | for l_idx,layer in enumerate(self.dec_layers):
356 | self.dec_net.add_module(
357 | name = 'dec_%02d'%(l_idx),
358 | module = TimestepEmbedSequential(layer)
359 | )
360 |
def forward(self,x,timesteps,c=None):
    """
    Run the (single-resolution) diffusion U-net.

    :param x: [B x n_in_channels x ...]
    :param timesteps: [B]
    :param c: conditioning input passed to self.cond_embed (only used when n_cond_dim > 0)
    :return: (out, intermediate_output_dict)
        out: [B x n_in_channels x ...], same shape as x
        intermediate_output_dict: name -> intermediate tensor ('x', 'x_lifted',
        'h_enc_<name>_<idx>', 'h_enc_stack', 'h_dec_<name>_<idx>', 'out')
    """
    intermediate_output_dict = {}
    intermediate_output_dict['x'] = x

    # time embedding
    emb = self.time_embed(
        timestep_embedding(timesteps,self.n_model_channels)
    ) # [B x n_emb_dim]

    # conditional embedding
    if self.n_cond_dim > 0:
        cond = self.cond_embed(c)
        emb = emb + cond

    # Lift input
    h = self.lift(x) # [B x n_model_channels x ...]
    intermediate_output_dict['x_lifted'] = h

    # Encoder
    self.h_enc_list = []
    for m_idx,module in enumerate(self.enc_net):
        h = module(h,emb)
        if isinstance(h,tuple): h = h[0] # in case of having tuple
        module_name = module[0].name
        intermediate_output_dict['h_enc_%s_%02d'%(module_name,m_idx)] = h
        # With both res and attention, each encoder block is the pair (res, att)
        # and only the block output (the attention layer, odd index) is kept.
        # Otherwise every layer forms its own block.
        if self.use_resblock and self.use_attention:
            if (m_idx%2) == 1:
                self.h_enc_list.append(h)
        else:
            self.h_enc_list.append(h)

    # Stack encoder outputs
    if not self.use_resblock and self.use_attention:
        # attention-only encoder: the decoder consumes just the last output
        h_enc_stack = h
    elif not self.h_enc_list:
        # no encoder layers at all (e.g. use_resblock and use_attention both False);
        # previously this left h_enc_stack unbound (UnboundLocalError)
        h_enc_stack = h
    else:
        # channel-wise concat of all block outputs in one call
        h_enc_stack = th.cat(self.h_enc_list,dim=1) # [B x n_enc_blocks*n_model_channels x ...]
    intermediate_output_dict['h_enc_stack'] = h_enc_stack

    # Decoder
    h = h_enc_stack
    for m_idx,module in enumerate(self.dec_net):
        h = module(h,emb) # [B x n_model_channels x ...]
        if isinstance(h,tuple): h = h[0] # in case of having tuple
        module_name = module[0].name
        intermediate_output_dict['h_dec_%s_%02d'%(module_name,m_idx)] = h

    # Projection (optionally with a residual skip back to the input)
    if self.skip_connection:
        out = self.proj(h) + x # [B x n_in_channels x ...]
    else:
        out = self.proj(h) # [B x n_in_channels x ...]

    intermediate_output_dict['out'] = out # [B x n_in_channels x ...]

    return out,intermediate_output_dict
432 |
class DiffusionUNetLegacy(nn.Module):
    """
    U-Net for diffusion models (legacy).

    Data flow:
        x -> lift (1x1 conv) -> n_enc_blocks encoder blocks
          -> mid (1x1 conv) -> n_dec_blocks decoder blocks -> proj (1x1 conv)

    - Encoder block: a ResBlock (downsampling when its updown rate != 1),
      optionally followed by an AttentionBlock.
    - Decoder block: a ResBlock (upsampling when its updown rate != 1),
      optionally followed by an AttentionBlock.
    - Skip connections: each decoder ResBlock input is the channel-wise
      concatenation of the running features with the matching encoder block
      output (consumed last-in, first-out). The projection input is
      concatenated with the lifted input.

    The skip concatenations only line up for a symmetric configuration
    (typically n_enc_blocks == n_dec_blocks == len(chnnel_multiples)).
    Configurations that can never be built raise ValueError at construction.
    """
    def __init__(
        self,
        name                 = 'unet',
        dims                 = 1, # spatial dimension, if dims==1, [B x C x L], if dims==2, [B x C x W x H]
        n_in_channels        = 128, # input channels
        n_base_channels      = 64, # base channel size
        n_emb_dim            = 128, # time embedding size
        n_cond_dim           = 0, # conditioning vector size (default is 0 indicating an unconditional model)
        n_enc_blocks         = 3, # number of encoder blocks
        n_dec_blocks         = 3, # number of decoder blocks
        n_groups             = 16, # group norm parameter
        n_heads              = 4, # number of heads
        actv                 = nn.SiLU(),
        kernel_size          = 3, # kernel size
        padding              = 1,
        use_attention        = True,
        skip_connection      = True, # (optional) additional final skip connection
        use_scale_shift_norm = True, # positional embedding handling
        chnnel_multiples     = (1,2,4), # per-encoder-block channel multiple of n_base_channels
        updown_rates         = (2,2,2), # per-encoder-block down rate (mirrored as up rate in the decoder)
        device               = 'cpu',
    ):
        super().__init__()
        self.name = name
        self.dims = dims
        self.n_in_channels = n_in_channels
        self.n_base_channels = n_base_channels
        self.n_emb_dim = n_emb_dim
        self.n_cond_dim = n_cond_dim
        self.n_enc_blocks = n_enc_blocks
        self.n_dec_blocks = n_dec_blocks
        self.n_groups = n_groups
        self.n_heads = n_heads
        self.actv = actv
        self.kernel_size = kernel_size
        self.padding = padding
        self.use_attention = use_attention
        self.skip_connection = skip_connection
        self.use_scale_shift_norm = use_scale_shift_norm
        self.chnnel_multiples = chnnel_multiples
        self.updown_rates = updown_rates
        self.device = device

        # Fail fast on configurations that cannot be built. These used to surface
        # as opaque IndexErrors deep inside the layer construction below.
        if len(self.chnnel_multiples) < self.n_enc_blocks:
            raise ValueError(
                "chnnel_multiples has %d entries but n_enc_blocks is %d"
                %(len(self.chnnel_multiples),self.n_enc_blocks))
        if len(self.updown_rates) < self.n_enc_blocks:
            raise ValueError(
                "updown_rates has %d entries but n_enc_blocks is %d"
                %(len(self.updown_rates),self.n_enc_blocks))
        if self.n_dec_blocks > self.n_enc_blocks:
            raise ValueError(
                "n_dec_blocks (%d) cannot exceed n_enc_blocks (%d): each decoder block needs an encoder skip"
                %(self.n_dec_blocks,self.n_enc_blocks))

        # Time embedding: sinusoidal features of size n_base_channels -> n_emb_dim
        self.time_embed = nn.Sequential(
            nn.Linear(in_features=self.n_base_channels,out_features=self.n_emb_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim),
        ).to(self.device)

        # Conditional embedding (added to the time embedding)
        if self.n_cond_dim > 0:
            self.cond_embed = nn.Sequential(
                nn.Linear(in_features=self.n_cond_dim,out_features=self.n_emb_dim),
                nn.SiLU(),
                nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim),
            ).to(self.device)

        # Lifting (1x1 conv)
        self.lift = conv_nd(
            dims         = self.dims,
            in_channels  = self.n_in_channels,
            out_channels = self.n_base_channels,
            kernel_size  = 1,
        ).to(device)

        # Encoder
        self.enc_layers = []
        n_channels2cat = [] # encoder output channels, consumed in reverse by the decoder via .pop()
        for e_idx in range(self.n_enc_blocks):
            if e_idx == 0:
                in_channel = self.n_base_channels
            else:
                in_channel = self.n_base_channels*self.chnnel_multiples[e_idx-1]
            out_channel = self.n_base_channels*self.chnnel_multiples[e_idx]
            n_channels2cat.append(out_channel)
            updown_rate = self.updown_rates[e_idx]

            # Residual block in encoder
            self.enc_layers.append(
                ResBlock(
                    name                 = 'res',
                    n_channels           = in_channel,
                    n_emb_channels       = self.n_emb_dim,
                    n_out_channels       = out_channel,
                    n_groups             = self.n_groups,
                    dims                 = self.dims,
                    actv                 = self.actv,
                    kernel_size          = self.kernel_size,
                    padding              = self.padding,
                    downsample           = updown_rate != 1,
                    down_rate            = updown_rate,
                    use_scale_shift_norm = self.use_scale_shift_norm,
                ).to(device)
            )
            # Attention block in encoder
            if self.use_attention:
                self.enc_layers.append(
                    AttentionBlock(
                        name       = 'att',
                        n_channels = out_channel,
                        n_heads    = self.n_heads,
                        n_groups   = self.n_groups,
                    ).to(device)
                )

        # Mid (1x1 conv at the deepest resolution)
        self.mid = conv_nd(
            dims         = self.dims,
            in_channels  = self.n_base_channels*self.chnnel_multiples[-1],
            out_channels = self.n_base_channels*self.chnnel_multiples[-1],
            kernel_size  = 1,
        ).to(device)

        # Decoder (mirrors the encoder)
        rev_multiples = self.chnnel_multiples[::-1]
        rev_updown_rates = self.updown_rates[::-1]
        self.dec_layers = []
        for d_idx in range(self.n_dec_blocks):
            if d_idx == 0:
                # input = mid output (deepest width) + deepest encoder skip
                in_channel = rev_multiples[d_idx]*self.n_base_channels + n_channels2cat.pop()
            else:
                # input = previous decoder output + matching encoder skip
                in_channel = rev_multiples[d_idx-1]*self.n_base_channels + n_channels2cat.pop()
            out_channel = rev_multiples[d_idx]*self.n_base_channels
            updown_rate = rev_updown_rates[d_idx]

            # Residual block in decoder
            self.dec_layers.append(
                ResBlock(
                    name                 = 'res',
                    n_channels           = in_channel,
                    n_emb_channels       = self.n_emb_dim,
                    n_out_channels       = out_channel,
                    n_groups             = self.n_groups,
                    dims                 = self.dims,
                    actv                 = self.actv,
                    kernel_size          = self.kernel_size,
                    padding              = self.padding,
                    upsample             = updown_rate != 1,
                    up_rate              = updown_rate,
                    use_scale_shift_norm = self.use_scale_shift_norm,
                ).to(device)
            )
            # Attention block in decoder
            if self.use_attention:
                self.dec_layers.append(
                    AttentionBlock(
                        name       = 'att',
                        n_channels = out_channel,
                        n_heads    = self.n_heads,
                        n_groups   = self.n_groups,
                    ).to(device)
                )

        # Projection: last decoder output concatenated with the lifted input
        self.proj = conv_nd(
            dims         = self.dims,
            in_channels  = (1+self.chnnel_multiples[0])*self.n_base_channels,
            out_channels = self.n_in_channels,
            kernel_size  = 1,
        ).to(device)

        # Define U-net (each layer wrapped so it can be called with (h, emb))
        self.enc_net = nn.Sequential()
        for l_idx,layer in enumerate(self.enc_layers):
            self.enc_net.add_module(
                name   = 'enc_%02d'%(l_idx),
                module = TimestepEmbedSequential(layer)
            )
        self.dec_net = nn.Sequential()
        for l_idx,layer in enumerate(self.dec_layers):
            self.dec_net.add_module(
                name   = 'dec_%02d'%(l_idx),
                module = TimestepEmbedSequential(layer)
            )

    def forward(self,x,timesteps,c=None):
        """
        :param x: [B x n_in_channels x ...]
        :param timesteps: [B]
        :param c: conditioning input passed to self.cond_embed (only used when n_cond_dim > 0)
        :return: (out, intermediate_output_dict)
            out: [B x n_in_channels x ...], same shape as x
            intermediate_output_dict: name -> intermediate tensor
        """
        intermediate_output_dict = {}
        intermediate_output_dict['x'] = x

        # time embedding
        emb = self.time_embed(
            timestep_embedding(timesteps,self.n_base_channels)
        ) # [B x n_emb_dim]

        # conditional embedding
        if self.n_cond_dim > 0:
            cond = self.cond_embed(c)
            emb = emb + cond # [B x n_emb_dim]

        # Lift input
        h = self.lift(x) # [B x n_base_channels x ...]
        if isinstance(h,tuple): h = h[0] # in case of having tuple
        intermediate_output_dict['x_lifted'] = h

        # Encoder; the skip stack starts with the lifted input (used by the projection)
        self.h_enc_list = [h]
        for m_idx,module in enumerate(self.enc_net):
            h = module(h,emb)
            if isinstance(h,tuple): h = h[0] # in case of having tuple
            module_name = module[0].name
            intermediate_output_dict['h_enc_%s_%02d'%(module_name,m_idx)] = h
            # With attention, each block is (res, att): keep only the block output (odd index)
            if (not self.use_attention) or (m_idx%2) == 1:
                self.h_enc_list.append(h)

        # Mid
        h = self.mid(h)
        if isinstance(h,tuple): h = h[0] # in case of having tuple

        # Decoder
        for m_idx,module in enumerate(self.dec_net):
            # With attention, each block is (res, att): concat the skip only before the res layer (even index)
            if (not self.use_attention) or (m_idx%2) == 0:
                h = th.cat((h,self.h_enc_list.pop()),dim=1)
            h = module(h,emb) # [B x (decoder block width) x ...]
            if isinstance(h,tuple): h = h[0] # in case of having tuple
            module_name = module[0].name
            intermediate_output_dict['h_dec_%s_%02d'%(module_name,m_idx)] = h

        # Projection: concat with the remaining skip (the lifted input)
        h = th.cat((h,self.h_enc_list.pop()),dim=1)

        if self.skip_connection:
            out = self.proj(h) + x # [B x n_in_channels x ...]
        else:
            out = self.proj(h) # [B x n_in_channels x ...]

        intermediate_output_dict['out'] = out # [B x n_in_channels x ...]

        return out,intermediate_output_dict
683 |
684 |
685 |
686 |
687 |
688 |
689 |
690 |
691 |
692 |
693 |
694 |
695 |
696 |
697 |
698 |
699 |
700 |
701 |
702 |
703 |
704 |
705 |
706 |
707 |
708 |
709 |
710 |
711 |
712 |
def get_param_groups_and_shapes(named_model_params):
    """
    Partition named parameters by rank for flattening into master parameters.

    :param named_model_params: iterable of (name, parameter) pairs,
        e.g. model.named_parameters()
    :return: two (group, view_shape) pairs:
        [(params with ndim <= 1, -1), (params with ndim > 1, (1, -1))],
        where each group is a list of (name, parameter) pairs in input order
    """
    pairs = list(named_model_params)
    scalars_and_vectors = [(name, prm) for name, prm in pairs if prm.ndim <= 1]
    matrices = [(name, prm) for name, prm in pairs if prm.ndim > 1]
    return [
        (scalars_and_vectors, -1),
        (matrices, (1, -1)),
    ]
724 |
def make_master_params(param_groups_and_shapes):
    """
    Build one full-precision master parameter per parameter group.

    Each group's parameters are detached, cast to float32, flattened into a
    single tensor, and viewed with the group's shape.

    :param param_groups_and_shapes: list of (param_group, shape) pairs, where
        param_group is a list of (name, parameter) pairs
    :return: list of nn.Parameter (requires_grad=True), one per group
    """
    def _flatten_group_fp32(param_group, shape):
        fp32_tensors = [tensor.detach().float() for _, tensor in param_group]
        return _flatten_dense_tensors(fp32_tensors).view(shape)

    masters = [
        nn.Parameter(_flatten_group_fp32(param_group, shape))
        for param_group, shape in param_groups_and_shapes
    ]
    for master in masters:
        master.requires_grad = True
    return masters
740 |
def unflatten_master_params(param_group, master_param):
    """
    Split a flat master tensor back into tensors shaped like the group's params.

    :param param_group: list of (name, parameter) pairs giving the target shapes
    :param master_param: flat tensor holding the group's values back to back
    :return: tuple of tensors, one per parameter, in group order
    """
    shape_templates = [tensor for _name, tensor in param_group]
    return _unflatten_dense_tensors(master_param, shape_templates)
743 |
def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    """
    Return model.state_dict() with its parameter entries replaced by master values.

    :param model: the nn.Module whose state_dict is taken as the template
    :param param_groups_and_shapes: groups as built by get_param_groups_and_shapes()
        (only read when use_fp16 is True)
    :param master_params: master parameters; flattened per group when use_fp16,
        otherwise one tensor per model parameter in named_parameters() order
    :param use_fp16: whether master_params are the flattened fp32 groups
    :return: the updated state dict
    """
    state_dict = model.state_dict()
    if not use_fp16:
        # One master tensor per model parameter, aligned by position
        for idx, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[idx]
        return state_dict

    # fp16: split each flat master group back into per-parameter tensors
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        pieces = unflatten_master_params(param_group, master_param.view(-1))
        for (name, _), piece in zip(param_group, pieces):
            assert name in state_dict
            state_dict[name] = piece
    return state_dict
763 |
def state_dict_to_master_params(model, state_dict, use_fp16):
    """
    Build master parameters from a state dict, ordered like model.named_parameters().

    :param model: module whose parameter names/order select entries of state_dict
    :param state_dict: mapping name -> tensor
    :param use_fp16: if True, return flattened fp32 master groups (see
        make_master_params); otherwise return the state_dict tensors directly
    :return: list of master tensors / parameters
    """
    param_names = [name for name, _ in model.named_parameters()]
    if not use_fp16:
        return [state_dict[name] for name in param_names]
    named_tensors = [(name, state_dict[name]) for name in param_names]
    return make_master_params(get_param_groups_and_shapes(named_tensors))
774 |
def zero_grad(model_params):
    """
    Zero the existing gradients of the given parameters in place.

    Parameters whose .grad is None are left untouched. Each gradient is first
    detached in place so it no longer references an autograd graph. This is the
    approach from torch.optim.Optimizer, see
    https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group

    :param model_params: iterable of parameters
    """
    for param in model_params:
        grad = param.grad
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
781 |
def param_grad_or_zeros(param):
    """
    Return param's gradient (detached), or zeros shaped like param when it has none.

    :param param: a tensor / parameter
    :return: tensor with the same shape as param
    """
    grad = param.grad
    return th.zeros_like(param) if grad is None else grad.data.detach()
787 |
def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy model-parameter gradients into the master parameters' .grad.

    For every group, the per-parameter gradients (zeros for parameters without a
    gradient) are flattened into one tensor, viewed with the group's shape, and
    assigned as the matching master parameter's gradient.

    :param param_groups_and_shapes: list of (param_group, shape) pairs
    :param master_params: master parameters from make_master_params(), one per group
    """
    for master_param, group_and_shape in zip(master_params, param_groups_and_shapes):
        param_group, shape = group_and_shape
        group_grads = [param_grad_or_zeros(tensor) for _, tensor in param_group]
        master_param.grad = _flatten_dense_tensors(group_grads).view(shape)
799 |
def check_overflow(value):
    """
    Return True when value is +inf, -inf, or NaN.

    NaN is detected with the self-inequality test (NaN != NaN).

    :param value: a float-like scalar (e.g. a gradient norm)
    """
    return (value != value) or (abs(value) == float("inf"))
802 |
def zero_master_grads(master_params):
    """
    Release master-parameter gradients by setting each .grad to None.

    :param master_params: iterable of master parameters
    """
    for master in master_params:
        master.grad = None
806 |
def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters (in place).

    :param param_groups_and_shapes: list of (param_group, shape) pairs, where
        param_group holds the (name, model_parameter) pairs of the group
    :param master_params: master parameters from make_master_params(), one per group
    """
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        # The group is iterated twice: once inside unflatten_master_params and
        # once by the zip below. Materialize it so that a one-shot iterator
        # (e.g. a generator) is not exhausted by the first pass, which would
        # otherwise silently copy no parameters.
        param_group = list(param_group)
        unflat_master_params = unflatten_master_params(param_group, master_param.view(-1))
        for (_, param), unflat_master_param in zip(param_group, unflat_master_params):
            param.detach().copy_(unflat_master_param)
818 |
INITIAL_LOG_LOSS_SCALE = 20.0
class MixedPrecisionTrainer:
    """
    Optional fp16 training helper with dynamic loss scaling.

    With use_fp16=False, this is a thin wrapper around loss.backward() and
    opt.step(). In that case master_params are the model parameters themselves.

    With use_fp16=True:
    - master_params are flattened fp32 copies of the model parameters (one per
      parameter group). Build the optimizer over these.
    - The loss is multiplied by 2**lg_loss_scale before backward.
    - optimize() first copies the model gradients into the master parameters.
    - On overflow (inf/NaN gradient norm) the step is skipped and lg_loss_scale
      is decreased by 1.
    - Otherwise the master gradients are unscaled, the optimizer steps, the
      master values are copied back into the model, and lg_loss_scale grows by
      fp16_scale_growth.
    """
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale # log2 of the loss scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            # fp32 master copies are taken before the model is converted to fp16
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        """Zero the gradients of the model parameters."""
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        """Backpropagate loss (scaled by 2**lg_loss_scale when use_fp16)."""
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        """
        Take an optimizer step.

        :return: True if a step was taken, False if it was skipped (fp16 overflow)
        """
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            zero_master_grads(self.master_params)
            return False

        # Undo the loss scaling on EVERY master group. Previously only
        # master_params[0] was unscaled, so the matrix-parameter group was
        # stepped with gradients 2**lg_loss_scale times too large.
        inv_scale = 1.0 / (2 ** self.lg_loss_scale)
        for master_param in self.master_params:
            master_param.grad.mul_(inv_scale)
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        # (the norms were previously computed here but never used)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        """
        :param grad_scale: divisor applied to the gradient norm (the loss scale)
        :return: (L2 norm of all master gradients / grad_scale, L2 norm of all master params)
        """
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
898 |
def eval_ddpm_1d(
    model,
    dc,
    n_sample,
    x_0,
    step_list_to_append,
    device,
    cond        = None,
    M           = None,
    noise_scale = 1.0,
    cond_weight = 0.5,
):
    """
    Evaluate DDPM in 1D case by running the reverse process from t=T-1 down to 0.

    :param model: noise-prediction network, called as model(x_t, step) or
        model(x_t, step, cond) and returning a pair whose first element is eps
    :param dc: dictionary of diffusion coefficients; reads 'T' and the numpy arrays
        'betas', 'sqrt_one_minus_alphas_bar', 'sqrt_recip_alphas', 'posterior_variance'
    :param n_sample: integer of how many trajectories to sample
    :param x_0: [N x C x L] tensor (only its C and L sizes are used)
    :param step_list_to_append: an ndarray of diffusion steps whose x_t is stored
    :param device: torch device for sampling
    :param cond: (optional) conditioning vector, repeated n_sample times. When given,
        classifier-free guidance is used with an all-zero condition as the
        unconditional input.
    :param M: (optional) forwarded to forward_sample(), which draws x_T and the per-step noise
    :param noise_scale: multiplier on the posterior standard deviation
    :param cond_weight: guidance weight w in (1+w)*eps_cond - w*eps_uncond
    :return: list of length T; entry t is the [n_sample x C x L] x_t when t is in
        step_list_to_append, else the placeholder ''

    The model's previous train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    try:
        _,C,L = x_0.shape
        x_dummy = th.zeros(n_sample,C,L,device=device)
        step_dummy = th.zeros(n_sample).type(th.long).to(device)
        _,x_T = forward_sample(x_dummy,step_dummy,dc,M) # [n_sample x C x L]
        x_t = x_T.clone() # [n_sample x C x L]
        x_t_list = ['']*dc['T'] # placeholders; filled only at steps in step_list_to_append

        # Coefficient tables do not depend on t: convert and move them once
        betas = th.from_numpy(dc['betas']).to(device) # [T]
        sqrt_one_minus_alphas_bar = th.from_numpy(dc['sqrt_one_minus_alphas_bar']).to(device) # [T]
        sqrt_recip_alphas = th.from_numpy(dc['sqrt_recip_alphas']).to(device) # [T]
        posterior_variance = th.from_numpy(dc['posterior_variance']).to(device) # [T]

        def _at_step(table,step):
            # gather table[step] per sample, shaped for broadcasting: [n_sample x 1 x 1]
            return th.gather(input=table,dim=-1,index=step).reshape((-1,1,1))

        for t in range(0,dc['T'])[::-1]: # T-1 down to 0
            step = th.full(
                size       = (n_sample,),
                fill_value = t,
                device     = device,
                dtype      = th.long) # [n_sample]
            # Predict noise
            with th.no_grad():
                if cond is None: # unconditioned model
                    eps_t,_ = model(x_t,step) # [n_sample x C x L]
                else:
                    cond_rep = cond.repeat(n_sample,1)
                    eps_cond,_ = model(x_t,step,cond_rep)
                    eps_uncond,_ = model(x_t,step,0.0*cond_rep)
                    eps_t = (1+cond_weight)*eps_cond - cond_weight*eps_uncond # [n_sample x C x L]
            betas_t = _at_step(betas,step)
            sqrt_one_minus_alphas_bar_t = _at_step(sqrt_one_minus_alphas_bar,step)
            sqrt_recip_alphas_t = _at_step(sqrt_recip_alphas,step)
            # Posterior mean
            mean_t = sqrt_recip_alphas_t * (
                x_t - betas_t*eps_t/sqrt_one_minus_alphas_bar_t
            ) # [n_sample x C x L]
            # Sample
            if t == 0: # last sampling, use mean
                x_t = mean_t
            else:
                posterior_variance_t = _at_step(posterior_variance,step)
                _,noise_t = forward_sample(x_dummy,step_dummy,dc,M) # [n_sample x C x L]
                x_t = mean_t + noise_scale*th.sqrt(posterior_variance_t)*noise_t
            if t in step_list_to_append:
                x_t_list[t] = x_t
    finally:
        model.train(was_training)
    return x_t_list # list of [n_sample x C x L]
977 |
def eval_ddpm_2d(
    model,
    dc,
    n_sample,
    x_0,
    step_list_to_append,
    device,
    cond        = None,
    M           = None,
    noise_scale = 1.0,
    cond_weight = 0.5,
):
    """
    Evaluate DDPM in 2D case by running the reverse process from t=T-1 down to 0.

    :param model: noise-prediction network, called as model(x_t, step) or
        model(x_t, step, cond) and returning a pair whose first element is eps
    :param dc: dictionary of diffusion coefficients; reads 'T' and the numpy arrays
        'betas', 'sqrt_one_minus_alphas_bar', 'sqrt_recip_alphas', 'posterior_variance'
    :param n_sample: integer of how many trajectories to sample
    :param x_0: [N x C x W x H] tensor (only its C, W and H sizes are used)
    :param step_list_to_append: an ndarray of diffusion steps whose x_t is stored
    :param device: torch device for sampling
    :param cond: (optional) conditioning vector, repeated n_sample times. When given,
        classifier-free guidance is used with an all-zero condition as the
        unconditional input.
    :param M: (optional) forwarded to forward_sample(), which draws x_T and the per-step noise
    :param noise_scale: multiplier on the posterior standard deviation
    :param cond_weight: guidance weight w in (1+w)*eps_cond - w*eps_uncond
    :return: list of length T; entry t is the [n_sample x C x W x H] x_t when t is in
        step_list_to_append, else the placeholder ''

    The model's previous train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    try:
        _,C,W,H = x_0.shape
        x_dummy = th.zeros(n_sample,C,W,H,device=device)
        step_dummy = th.zeros(n_sample).type(th.long).to(device)
        _,x_T = forward_sample(x_dummy,step_dummy,dc,M) # [n_sample x C x W x H]
        x_t = x_T.clone() # [n_sample x C x W x H]
        x_t_list = ['']*dc['T'] # placeholders; filled only at steps in step_list_to_append

        # Coefficient tables do not depend on t: convert and move them once
        betas = th.from_numpy(dc['betas']).to(device) # [T]
        sqrt_one_minus_alphas_bar = th.from_numpy(dc['sqrt_one_minus_alphas_bar']).to(device) # [T]
        sqrt_recip_alphas = th.from_numpy(dc['sqrt_recip_alphas']).to(device) # [T]
        posterior_variance = th.from_numpy(dc['posterior_variance']).to(device) # [T]

        def _at_step(table,step):
            # gather table[step] per sample, shaped for broadcasting: [n_sample x 1 x 1 x 1]
            return th.gather(input=table,dim=-1,index=step).reshape((-1,1,1,1))

        for t in range(0,dc['T'])[::-1]: # T-1 down to 0
            step = th.full(
                size       = (n_sample,),
                fill_value = t,
                device     = device,
                dtype      = th.long) # [n_sample]
            # Predict noise
            with th.no_grad():
                if cond is None: # unconditioned model
                    eps_t,_ = model(x_t,step) # [n_sample x C x W x H]
                else:
                    cond_rep = cond.repeat(n_sample,1)
                    eps_cond,_ = model(x_t,step,cond_rep)
                    eps_uncond,_ = model(x_t,step,0.0*cond_rep)
                    eps_t = (1+cond_weight)*eps_cond - cond_weight*eps_uncond # [n_sample x C x W x H]
            betas_t = _at_step(betas,step)
            sqrt_one_minus_alphas_bar_t = _at_step(sqrt_one_minus_alphas_bar,step)
            sqrt_recip_alphas_t = _at_step(sqrt_recip_alphas,step)
            # Posterior mean
            mean_t = sqrt_recip_alphas_t * (
                x_t - betas_t*eps_t/sqrt_one_minus_alphas_bar_t
            ) # [n_sample x C x W x H]
            # Sample
            if t == 0: # last sampling, use mean
                x_t = mean_t
            else:
                posterior_variance_t = _at_step(posterior_variance,step)
                _,noise_t = forward_sample(x_dummy,step_dummy,dc,M) # [n_sample x C x W x H]
                x_t = mean_t + noise_scale*th.sqrt(posterior_variance_t)*noise_t
            if t in step_list_to_append:
                x_t_list[t] = x_t
    finally:
        model.train(was_training)
    return x_t_list # list of [n_sample x C x W x H]
--------------------------------------------------------------------------------
/code/diffusion_resblock.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "00d7ebaa",
6 | "metadata": {},
7 | "source": [
8 | "### Residual block for diffusion models"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "c37dccb5",
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "PyTorch version:[2.0.1].\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "import numpy as np\n",
27 | "import matplotlib.pyplot as plt\n",
28 | "import torch as th\n",
29 | "import torch.nn as nn\n",
30 | "import torch.nn.functional as F\n",
31 | "from module import (\n",
32 | " ResBlock\n",
33 | ")\n",
34 | "from dataset import mnist\n",
35 | "from util import get_torch_size_string,plot_4x4_torch_tensor\n",
36 | "np.set_printoptions(precision=3)\n",
37 | "th.set_printoptions(precision=3)\n",
38 | "%matplotlib inline\n",
39 | "%config InlineBackend.figure_format='retina'\n",
40 | "print (\"PyTorch version:[%s].\"%(th.__version__))"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "id": "257ff26d",
46 | "metadata": {},
47 | "source": [
48 | "### 1-D case `[B x C x L]`"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 5,
54 | "id": "abd8843c",
55 | "metadata": {
56 | "scrolled": true
57 | },
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "1. No upsample nor downsample\n",
64 | " Shape x:[16x32x200] => out:[16x32x400]\n",
65 | "2. Upsample\n",
66 | " Shape x:[16x32x200] => out:[16x32x400]\n",
67 | "3. Downsample\n",
68 | " Shape x:[16x32x200] => out:[16x32x100]\n"
69 | ]
70 | }
71 | ],
72 | "source": [
73 | "# Input\n",
74 | "x = th.randn(16,32,200) # [B x C x L]\n",
75 | "emb = th.randn(16,128) # [B x n_emb_channels]\n",
76 | "\n",
77 | "print (\"1. No upsample nor downsample\")\n",
78 | "resblock = ResBlock(\n",
79 | " n_channels = 32,\n",
80 | " n_emb_channels = 128,\n",
81 | " n_out_channels = 32,\n",
82 | " n_groups = 16,\n",
83 | " dims = 1,\n",
84 | " upsample = True,\n",
85 | " downsample = False,\n",
86 | " down_rate = 1\n",
87 | ")\n",
88 | "out = resblock(x,emb)\n",
89 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
90 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
91 | "\n",
92 | "print (\"2. Upsample\")\n",
93 | "resblock = ResBlock(\n",
94 | " n_channels = 32,\n",
95 | " n_emb_channels = 128,\n",
96 | " n_out_channels = 32,\n",
97 | " n_groups = 16,\n",
98 | " dims = 1,\n",
99 | " upsample = True,\n",
100 | " downsample = False,\n",
101 | " down_rate = 2\n",
102 | ")\n",
103 | "out = resblock(x,emb)\n",
104 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
105 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
106 | "\n",
107 | "print (\"3. Downsample\")\n",
108 | "resblock = ResBlock(\n",
109 | " n_channels = 32,\n",
110 | " n_emb_channels = 128,\n",
111 | " n_out_channels = 32,\n",
112 | " n_groups = 16,\n",
113 | " dims = 1,\n",
114 | " upsample = False,\n",
115 | " downsample = True,\n",
116 | " down_rate = 2\n",
117 | ")\n",
118 | "out = resblock(x,emb)\n",
119 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
120 | " (get_torch_size_string(x),get_torch_size_string(out)))"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "id": "01d77ea8",
126 | "metadata": {},
127 | "source": [
128 | "### 2-D case `[B x C x W x H]`"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 3,
134 | "id": "890df0a2",
135 | "metadata": {},
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "1. No upsample nor downsample\n",
142 | " Shape x:[16x32x28x28] => out:[16x32x28x28]\n",
143 | "2. Upsample\n",
144 | " Shape x:[16x32x28x28] => out:[16x32x56x56]\n",
145 | "3. Downsample\n",
146 | " Shape x:[16x32x28x28] => out:[16x32x14x14]\n",
147 | "4. (uneven) Upsample\n",
148 | " Shape x:[16x32x28x28] => out:[16x32x56x28]\n",
149 | "5. (uneven) Downsample\n",
150 | " Shape x:[16x32x28x28] => out:[16x32x14x28]\n",
151 | "6. (fake) Downsample\n",
152 | " Shape x:[16x32x28x28] => out:[16x32x28x28]\n"
153 | ]
154 | }
155 | ],
156 | "source": [
157 | "# Input\n",
158 | "x = th.randn(16,32,28,28) # [B x C x W x H]\n",
159 | "emb = th.randn(16,128) # [B x n_emb_channels]\n",
160 | "\n",
161 | "print (\"1. No upsample nor downsample\")\n",
162 | "resblock = ResBlock(\n",
163 | " n_channels = 32,\n",
164 | " n_emb_channels = 128,\n",
165 | " n_out_channels = 32,\n",
166 | " n_groups = 16,\n",
167 | " dims = 2,\n",
168 | " upsample = False,\n",
169 | " downsample = False\n",
170 | ")\n",
171 | "out = resblock(x,emb)\n",
172 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
173 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
174 | "\n",
175 | "print (\"2. Upsample\")\n",
176 | "resblock = ResBlock(\n",
177 | " n_channels = 32,\n",
178 | " n_emb_channels = 128,\n",
179 | " n_out_channels = 32,\n",
180 | " n_groups = 16,\n",
181 | " dims = 2,\n",
182 | " upsample = True,\n",
183 | " downsample = False\n",
184 | ")\n",
185 | "out = resblock(x,emb)\n",
186 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
187 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
188 | "\n",
189 | "print (\"3. Downsample\")\n",
190 | "resblock = ResBlock(\n",
191 | " n_channels = 32,\n",
192 | " n_emb_channels = 128,\n",
193 | " n_out_channels = 32,\n",
194 | " n_groups = 16,\n",
195 | " dims = 2,\n",
196 | " upsample = False,\n",
197 | " downsample = True\n",
198 | ")\n",
199 | "out = resblock(x,emb)\n",
200 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
201 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
202 | "\n",
203 | "print (\"4. (uneven) Upsample\")\n",
204 | "resblock = ResBlock(\n",
205 | " n_channels = 32,\n",
206 | " n_emb_channels = 128,\n",
207 | " n_out_channels = 32,\n",
208 | " n_groups = 16,\n",
209 | " dims = 2,\n",
210 | " upsample = True,\n",
211 | " downsample = False,\n",
212 | " up_rate = (2,1)\n",
213 | ")\n",
214 | "out = resblock(x,emb)\n",
215 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
216 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
217 | "\n",
218 | "print (\"5. (uneven) Downsample\")\n",
219 | "resblock = ResBlock(\n",
220 | " n_channels = 32,\n",
221 | " n_emb_channels = 128,\n",
222 | " n_out_channels = 32,\n",
223 | " n_groups = 16,\n",
224 | " dims = 2,\n",
225 | " upsample = False,\n",
226 | " downsample = True,\n",
227 | " down_rate = (2,1)\n",
228 | ")\n",
229 | "out = resblock(x,emb)\n",
230 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
231 | " (get_torch_size_string(x),get_torch_size_string(out)))\n",
232 | "\n",
233 | "print (\"6. (fake) Downsample\")\n",
234 | "resblock = ResBlock(\n",
235 | " n_channels = 32,\n",
236 | " n_emb_channels = 128,\n",
237 | " n_out_channels = 32,\n",
238 | " n_groups = 16,\n",
239 | " dims = 2,\n",
240 | " upsample = False,\n",
241 | " downsample = True,\n",
242 | " down_rate = (1,1)\n",
243 | ")\n",
244 | "out = resblock(x,emb)\n",
245 | "print (\" Shape x:[%s] => out:[%s]\"%\n",
246 | " (get_torch_size_string(x),get_torch_size_string(out)))"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "id": "184ebe32",
253 | "metadata": {},
254 | "outputs": [],
255 | "source": []
256 | }
257 | ],
258 | "metadata": {
259 | "kernelspec": {
260 | "display_name": "Python 3 (ipykernel)",
261 | "language": "python",
262 | "name": "python3"
263 | },
264 | "language_info": {
265 | "codemirror_mode": {
266 | "name": "ipython",
267 | "version": 3
268 | },
269 | "file_extension": ".py",
270 | "mimetype": "text/x-python",
271 | "name": "python",
272 | "nbconvert_exporter": "python",
273 | "pygments_lexer": "ipython3",
274 | "version": "3.9.16"
275 | }
276 | },
277 | "nbformat": 4,
278 | "nbformat_minor": 5
279 | }
280 |
--------------------------------------------------------------------------------
/code/diffusion_unet_legacy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e331df79-2400-4f60-aca2-15debac3e5de",
6 | "metadata": {},
7 | "source": [
8 | "### U-net Legacy"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "a1021c53-7ada-41a7-86fc-ce2ec7faada5",
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "PyTorch version:[2.0.1].\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "import numpy as np\n",
27 | "import matplotlib.pyplot as plt\n",
28 | "import torch as th\n",
29 | "import torch.nn as nn\n",
30 | "from util import (\n",
31 | " get_torch_size_string\n",
32 | ")\n",
33 | "from diffusion import (\n",
34 | " get_ddpm_constants,\n",
35 | " plot_ddpm_constants,\n",
36 | " DiffusionUNet,\n",
37 | " DiffusionUNetLegacy\n",
38 | ")\n",
39 | "from dataset import mnist\n",
40 | "np.set_printoptions(precision=3)\n",
41 | "th.set_printoptions(precision=3)\n",
42 | "%matplotlib inline\n",
43 | "%config InlineBackend.figure_format='retina'\n",
44 | "print (\"PyTorch version:[%s].\"%(th.__version__))"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "id": "bc376285-0fa0-4a0b-aacb-f5b49b0e2f5e",
51 | "metadata": {},
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "device:[mps]\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "device = 'mps'\n",
63 | "print (\"device:[%s]\"%(device))"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "id": "7c84b758-2535-439c-936b-1a20397d7598",
70 | "metadata": {},
71 | "outputs": [
72 | {
73 | "name": "stdout",
74 | "output_type": "stream",
75 | "text": [
76 | "Ready.\n"
77 | ]
78 | }
79 | ],
80 | "source": [
81 | "dc = get_ddpm_constants(\n",
82 | " schedule_name = 'cosine', # 'linear', 'cosine'\n",
83 | " T = 1000,\n",
84 | " np_type = np.float32)\n",
85 | "print(\"Ready.\") "
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "id": "d8641679-0ee9-4f4b-9248-23e0c2e0de47",
91 | "metadata": {},
92 | "source": [
93 | "### Guided U-net\n",
94 | "
"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "id": "08f6f417-81a6-49fe-817c-663acd47414a",
100 | "metadata": {},
101 | "source": [
102 | "### 1-D case: `[B x C x L]` with attention"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 4,
108 | "id": "57655fad-1bcb-4a5c-9626-ded7373e2b9e",
109 | "metadata": {},
110 | "outputs": [
111 | {
112 | "name": "stdout",
113 | "output_type": "stream",
114 | "text": [
115 | "Input: x:[2x3x256] timesteps:[2]\n",
116 | "Output: out:[2x3x256]\n",
117 | "[ 0] key:[ x] shape:[ 2x3x256]\n",
118 | "[ 1] key:[ x_lifted] shape:[ 2x32x256]\n",
119 | "[ 2] key:[h_enc_res_00] shape:[ 2x32x128]\n",
120 | "[ 3] key:[h_enc_att_01] shape:[ 2x32x128]\n",
121 | "[ 4] key:[h_enc_res_02] shape:[ 2x64x64]\n",
122 | "[ 5] key:[h_enc_att_03] shape:[ 2x64x64]\n",
123 | "[ 6] key:[h_enc_res_04] shape:[ 2x128x32]\n",
124 | "[ 7] key:[h_enc_att_05] shape:[ 2x128x32]\n",
125 | "[ 8] key:[h_enc_res_06] shape:[ 2x256x16]\n",
126 | "[ 9] key:[h_enc_att_07] shape:[ 2x256x16]\n",
127 | "[10] key:[h_dec_res_00] shape:[ 2x256x32]\n",
128 | "[11] key:[h_dec_att_01] shape:[ 2x256x32]\n",
129 | "[12] key:[h_dec_res_02] shape:[ 2x128x64]\n",
130 | "[13] key:[h_dec_att_03] shape:[ 2x128x64]\n",
131 | "[14] key:[h_dec_res_04] shape:[ 2x64x128]\n",
132 | "[15] key:[h_dec_att_05] shape:[ 2x64x128]\n",
133 | "[16] key:[h_dec_res_06] shape:[ 2x32x256]\n",
134 | "[17] key:[h_dec_att_07] shape:[ 2x32x256]\n",
135 | "[18] key:[ out] shape:[ 2x3x256]\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "unet = DiffusionUNetLegacy(\n",
141 | " name = 'unet',\n",
142 | " dims = 1,\n",
143 | " n_in_channels = 3,\n",
144 | " n_base_channels = 32,\n",
145 | " n_emb_dim = 128,\n",
146 | " n_enc_blocks = 4, # number of encoder blocks\n",
147 | " n_dec_blocks = 4, # number of decoder blocks\n",
148 | " n_groups = 16, # group norm paramter\n",
149 | " use_attention = True,\n",
150 | " skip_connection = True, # additional skip connection\n",
151 | " chnnel_multiples = (1,2,4,8),\n",
152 | " updown_rates = (2,2,2,2),\n",
153 | " device = device,\n",
154 | ")\n",
155 | "# Inputs, timesteps:[B] and x:[B x C x L]\n",
156 | "batch_size = 2\n",
157 | "x = th.randn(batch_size,3,256).to(device) # [B x C x L]\n",
158 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n",
159 | "out,intermediate_output_dict = unet(x,timesteps)\n",
160 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n",
161 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n",
162 | "))\n",
163 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n",
164 | "# Print intermediate layers\n",
165 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n",
166 | " z = intermediate_output_dict[key]\n",
167 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "id": "05572674-6d35-46a6-a880-d41e4b4e5111",
173 | "metadata": {},
174 | "source": [
175 | "### 1-D case: `[B x C x L]` without attention"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 5,
181 | "id": "6962a40a-0c14-4c4a-a024-06b64b076df5",
182 | "metadata": {},
183 | "outputs": [
184 | {
185 | "name": "stdout",
186 | "output_type": "stream",
187 | "text": [
188 | "Input: x:[2x3x256] timesteps:[2]\n",
189 | "Output: out:[2x3x256]\n",
190 | "[ 0] key:[ x] shape:[ 2x3x256]\n",
191 | "[ 1] key:[ x_lifted] shape:[ 2x32x256]\n",
192 | "[ 2] key:[h_enc_res_00] shape:[ 2x32x128]\n",
193 | "[ 3] key:[h_enc_res_01] shape:[ 2x64x64]\n",
194 | "[ 4] key:[h_enc_res_02] shape:[ 2x128x32]\n",
195 | "[ 5] key:[h_enc_res_03] shape:[ 2x256x16]\n",
196 | "[ 6] key:[h_dec_res_00] shape:[ 2x256x32]\n",
197 | "[ 7] key:[h_dec_res_01] shape:[ 2x128x64]\n",
198 | "[ 8] key:[h_dec_res_02] shape:[ 2x64x128]\n",
199 | "[ 9] key:[h_dec_res_03] shape:[ 2x32x256]\n",
200 | "[10] key:[ out] shape:[ 2x3x256]\n"
201 | ]
202 | }
203 | ],
204 | "source": [
205 | "unet = DiffusionUNetLegacy(\n",
206 | " name = 'unet',\n",
207 | " dims = 1,\n",
208 | " n_in_channels = 3,\n",
209 | " n_base_channels = 32,\n",
210 | " n_emb_dim = 128,\n",
211 | " n_enc_blocks = 4, # number of encoder blocks\n",
212 | " n_dec_blocks = 4, # number of decoder blocks\n",
213 | " n_groups = 16, # group norm paramter\n",
214 | " use_attention = False,\n",
215 | " skip_connection = True, # additional skip connection\n",
216 | " chnnel_multiples = (1,2,4,8),\n",
217 | " updown_rates = (2,2,2,2),\n",
218 | " device = device,\n",
219 | ")\n",
220 | "# Inputs, timesteps:[B] and x:[B x C x L]\n",
221 | "batch_size = 2\n",
222 | "x = th.randn(batch_size,3,256).to(device) # [B x C x L]\n",
223 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n",
224 | "out,intermediate_output_dict = unet(x,timesteps)\n",
225 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n",
226 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n",
227 | "))\n",
228 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n",
229 | "# Print intermediate layers\n",
230 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n",
231 | " z = intermediate_output_dict[key]\n",
232 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "1881fc4e-d713-4c24-955c-85e3c0ece8ff",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": []
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "id": "c7143a40-7d2f-46e6-a6b6-a089a1227bd7",
246 | "metadata": {},
247 | "source": [
248 | "### 2-D case: `[B x C x W x H]` without attention"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 6,
254 | "id": "0c3da514-b189-4c01-ba6f-1f2e8daa24b0",
255 | "metadata": {},
256 | "outputs": [
257 | {
258 | "name": "stdout",
259 | "output_type": "stream",
260 | "text": [
261 | "Input: x:[2x3x256x256] timesteps:[2]\n",
262 | "Output: out:[2x3x256x256]\n",
263 | "[ 0] key:[ x] shape:[ 2x3x256x256]\n",
264 | "[ 1] key:[ x_lifted] shape:[2x32x256x256]\n",
265 | "[ 2] key:[h_enc_res_00] shape:[2x32x256x256]\n",
266 | "[ 3] key:[h_enc_res_01] shape:[2x64x256x256]\n",
267 | "[ 4] key:[h_enc_res_02] shape:[2x128x256x256]\n",
268 | "[ 5] key:[h_dec_res_00] shape:[2x128x256x256]\n",
269 | "[ 6] key:[h_dec_res_01] shape:[2x64x256x256]\n",
270 | "[ 7] key:[h_dec_res_02] shape:[2x32x256x256]\n",
271 | "[ 8] key:[ out] shape:[ 2x3x256x256]\n"
272 | ]
273 | }
274 | ],
275 | "source": [
276 | "unet = DiffusionUNetLegacy(\n",
277 | " name = 'unet',\n",
278 | " dims = 2,\n",
279 | " n_in_channels = 3,\n",
280 | " n_base_channels = 32,\n",
281 | " n_emb_dim = 128,\n",
282 | " n_enc_blocks = 3, # number of encoder blocks\n",
283 | " n_dec_blocks = 3, # number of decoder blocks\n",
284 | " n_groups = 16, # group norm paramter\n",
285 | " use_attention = False,\n",
286 | " skip_connection = True, # additional skip connection\n",
287 | " chnnel_multiples = (1,2,4),\n",
288 | " updown_rates = (1,1,1),\n",
289 | " device = device,\n",
290 | ")\n",
291 | "# Inputs, timesteps:[B] and x:[B x C x W x H]\n",
292 | "batch_size = 2\n",
293 | "x = th.randn(batch_size,3,256,256).to(device) # [B x C x W x H]\n",
294 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n",
295 | "out,intermediate_output_dict = unet(x,timesteps)\n",
296 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n",
297 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n",
298 | "))\n",
299 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n",
300 | "# Print intermediate layers\n",
301 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n",
302 | " z = intermediate_output_dict[key]\n",
303 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))"
304 | ]
305 | },
306 | {
307 | "cell_type": "markdown",
308 | "id": "66e7d820-fc72-41c3-88da-09e70147c3e0",
309 | "metadata": {},
310 | "source": [
311 | "### 2-D case: `[B x C x W x H]` without attention + updown pooling"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 7,
317 | "id": "12680faa-9bf0-4ef9-b65f-08952641dfa3",
318 | "metadata": {},
319 | "outputs": [
320 | {
321 | "name": "stdout",
322 | "output_type": "stream",
323 | "text": [
324 | "Input: x:[2x3x256x256] timesteps:[2]\n",
325 | "Output: out:[2x3x256x256]\n",
326 | "[ 0] key:[ x] shape:[ 2x3x256x256]\n",
327 | "[ 1] key:[ x_lifted] shape:[2x32x256x256]\n",
328 | "[ 2] key:[h_enc_res_00] shape:[2x32x256x256]\n",
329 | "[ 3] key:[h_enc_res_01] shape:[2x64x128x128]\n",
330 | "[ 4] key:[h_enc_res_02] shape:[ 2x128x64x64]\n",
331 | "[ 5] key:[h_dec_res_00] shape:[2x128x128x128]\n",
332 | "[ 6] key:[h_dec_res_01] shape:[2x64x256x256]\n",
333 | "[ 7] key:[h_dec_res_02] shape:[2x32x256x256]\n",
334 | "[ 8] key:[ out] shape:[ 2x3x256x256]\n"
335 | ]
336 | }
337 | ],
338 | "source": [
339 | "unet = DiffusionUNetLegacy(\n",
340 | " name = 'unet',\n",
341 | " dims = 2,\n",
342 | " n_in_channels = 3,\n",
343 | " n_base_channels = 32,\n",
344 | " n_emb_dim = 128,\n",
345 | " n_enc_blocks = 3, # number of encoder blocks\n",
346 | " n_dec_blocks = 3, # number of decoder blocks\n",
347 | " n_groups = 16, # group norm paramter\n",
348 | " use_attention = False,\n",
349 | " skip_connection = True, # additional skip connection\n",
350 | " chnnel_multiples = (1,2,4),\n",
351 | " updown_rates = (1,2,2),\n",
352 | " device = device,\n",
353 | ")\n",
354 | "# Inputs, timesteps:[B] and x:[B x C x W x H]\n",
355 | "batch_size = 2\n",
356 | "x = th.randn(batch_size,3,256,256).to(device) # [B x C x W x H]\n",
357 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n",
358 | "out,intermediate_output_dict = unet(x,timesteps)\n",
359 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n",
360 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n",
361 | "))\n",
362 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n",
363 | "# Print intermediate layers\n",
364 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n",
365 | " z = intermediate_output_dict[key]\n",
366 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": null,
372 | "id": "b9bf9fc1-90f1-4700-a4b1-360f730d0e51",
373 | "metadata": {},
374 | "outputs": [],
375 | "source": []
376 | }
377 | ],
378 | "metadata": {
379 | "kernelspec": {
380 | "display_name": "Python 3 (ipykernel)",
381 | "language": "python",
382 | "name": "python3"
383 | },
384 | "language_info": {
385 | "codemirror_mode": {
386 | "name": "ipython",
387 | "version": 3
388 | },
389 | "file_extension": ".py",
390 | "mimetype": "text/x-python",
391 | "name": "python",
392 | "nbconvert_exporter": "python",
393 | "pygments_lexer": "ipython3",
394 | "version": "3.9.16"
395 | }
396 | },
397 | "nbformat": 4,
398 | "nbformat_minor": 5
399 | }
400 |
--------------------------------------------------------------------------------
/code/mdn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | from util import th2np
6 |
def get_argmax_mu(pi,mu):
    """
    Select, independently for each output dimension, the mean of the mixture
    component that carries the largest weight.

    :param pi: mixture weights [N x K x D]
    :param mu: component means [N x K x D]
    :return: argmax_mu [N x D]
    """
    # Index of the heaviest component, keeping the K axis so it can drive gather
    best_k = pi.argmax(dim=1,keepdim=True) # [N x 1 x D]
    picked = mu.gather(1,best_k) # [N x 1 x D]
    return picked.squeeze(1) # [N x D]
15 |
def gmm_forward(pi,mu,sigma,data):
    """
    Compute the Gaussian mixture model likelihood of `data`, evaluated in log space.

    Each output dimension d has its own K-component 1-D mixture. The per-sample
    log-likelihood is the sum of the per-dimension log mixture densities.

    :param pi: GMM mixture weights [N x K x D]
    :param mu: GMM means [N x K x D]
    :param sigma: GMM stds [N x K x D]
    :param data: data [N x D]
    :return: dict with
        'data_usq'  : data unsqueezed to [N x 1 x D]
        'data_exp'  : data expanded to [N x K x D]
        'probs'     : mixture density sum_k pi_k N(x | mu_k, sigma_k) [N x D]
        'log_probs' : summed log mixture density per sample [N]
        'nlls'      : negative of 'log_probs' [N]
        'argmax_mu' : mean of the highest-weight component [N x D]
    """
    data_usq = th.unsqueeze(data,1) # [N x 1 x D]
    data_exp = data_usq.expand_as(sigma) # [N x K x D]

    # log N(x | mu, sigma) = -0.5*z^2 - log(sigma) - 0.5*log(2*pi)
    z = (data_exp-mu)/sigma # [N x K x D]
    log_normal = -0.5*z**2 - th.log(sigma) - 0.5*np.log(2*np.pi) # [N x K x D]

    # Clamp so a weight that underflowed to exactly 0 yields a very negative
    # (finite) log instead of -inf, which would give NaN gradients.
    log_pi = th.log(th.clamp(pi,min=th.finfo(pi.dtype).tiny)) # [N x K x D]

    # log sum_k pi_k N_k via logsumexp: data far from every component no longer
    # underflows to log(0) = -inf as the linear-space sum did.
    log_mix = th.logsumexp(log_pi+log_normal,dim=1) # [N x D]
    probs = th.exp(log_mix) # [N x D]
    log_probs = th.sum(log_mix,dim=1) # [N]
    nlls = -log_probs # [N]

    # Most probable mean [N x D]
    argmax_mu = get_argmax_mu(pi,mu) # [N x D]

    out = {
        'data_usq':data_usq,'data_exp':data_exp,
        'probs':probs,'log_probs':log_probs,'nlls':nlls,'argmax_mu':argmax_mu
    }
    return out
43 |
def gmm_uncertainties(pi, mu, sigma):
    """
    Split the predictive spread of a Gaussian mixture into epistemic and
    aleatoric parts, using the law of total variance per output dimension:

        Var[y] = sum_k pi_k (mu_k - mu_bar)^2   (epistemic: spread of the means)
               + sum_k pi_k sigma_k^2          (aleatoric: expected noise variance)

    Both terms are returned as standard deviations (square roots), so they share
    units with y and satisfy epis**2 + alea**2 == total mixture variance.

    :param pi: mixture weights [N x K x D]
    :param mu: component means [N x K x D]
    :param sigma: component standard deviations [N x K x D]
    :return: (epis_unct, alea_unct), each [N x D]
    """
    # Compute Epistemic Uncertainty (weighted variance of the component means)
    mu_avg = th.sum(th.mul(pi,mu),dim=1).unsqueeze(1) # [N x 1 x D]
    mu_exp = mu_avg.expand_as(mu) # [N x K x D]
    mu_diff_sq = th.square(mu-mu_exp) # [N x K x D]
    epis_unct = th.sum(th.mul(pi,mu_diff_sq), dim=1) # [N x D]

    # Compute Aleatoric Uncertainty (weighted mean of the component variances).
    # sigma is a std, so it must be squared to be commensurate with the epistemic term.
    alea_unct = th.sum(th.mul(pi,th.square(sigma)), dim=1) # [N x D]

    # Variances -> standard deviations
    epis_unct,alea_unct = th.sqrt(epis_unct),th.sqrt(alea_unct)
    return epis_unct,alea_unct
62 |
class MixturesOfGaussianLayer(nn.Module):
    """
    Linear head producing the parameters of a K-component Gaussian mixture for
    each of `out_dim` output dimensions.

    forward(x) returns (pi, mu, sigma), each reshaped to [N x k x out_dim]:
      pi    : mixture weights, softmax over the k axis
      mu    : component means (unconstrained)
      sigma : component stds, exp(.) when sig_max is None, else sig_max*sigmoid(.)
    """
    def __init__(
        self,
        in_dim,
        out_dim,
        k,
        sig_max=None
    ):
        super(MixturesOfGaussianLayer,self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.k = k
        self.sig_max = sig_max

        # Networks: one linear head per mixture parameter, each emitting K*D values
        n_head_out = self.k*self.out_dim
        self.fc_pi = nn.Linear(self.in_dim,n_head_out)
        self.fc_mu = nn.Linear(self.in_dim,n_head_out)
        self.fc_sigma = nn.Linear(self.in_dim,n_head_out)

    def _to_mixture_shape(self,flat):
        """Reshape a flat [N x K*D] head output into [N x K x D]."""
        return flat.reshape(-1,self.k,self.out_dim)

    def forward(self,x):
        # Weights: normalize across the K components
        pi = th.softmax(self._to_mixture_shape(self.fc_pi(x)),dim=1) # [N x K x D]
        mu = self._to_mixture_shape(self.fc_mu(x)) # [N x K x D]
        raw_sigma = self._to_mixture_shape(self.fc_sigma(x)) # [N x K x D]
        # Positive stds: unbounded via exp, or bounded in (0, sig_max) via sigmoid
        sigma = th.exp(raw_sigma) if self.sig_max is None else self.sig_max*th.sigmoid(raw_sigma)
        return pi,mu,sigma
95 |
class MixtureDensityNetwork(nn.Module):
    """
    Mixture Density Network: an MLP trunk followed by a MixturesOfGaussianLayer head.

    For every entry of `h_dim_list` the trunk stacks
    Linear -> (optional) BatchNorm1d -> actv -> Dropout.
    forward(x) returns the head output (pi, mu, sigma).
    """
    def __init__(
        self,
        name = 'mdn',
        x_dim = 1,
        y_dim = 1,
        k = 5,
        h_dim_list = (32,32),
        actv = nn.ReLU(),
        sig_max = 1.0,
        mu_min = -3.0,
        mu_max = +3.0,
        p_drop = 0.1,
        use_bn = False,
    ):
        """
        :param name: model name
        :param x_dim: input dimension
        :param y_dim: output dimension passed to the mixture head
        :param k: number of mixture components
        :param h_dim_list: hidden layer sizes of the trunk (copied into a list)
        :param actv: activation module; the same instance is reused after every hidden layer
        :param sig_max: forwarded to the mixture head as its std bound
        :param mu_min: lower bound of the uniform init of the mean-head bias
        :param mu_max: upper bound of the uniform init of the mean-head bias
        :param p_drop: dropout probability after each hidden layer
        :param use_bn: insert BatchNorm1d after each hidden Linear
        """
        super(MixtureDensityNetwork,self).__init__()
        self.name = name
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.k = k
        # Copy: avoids sharing a default/caller-owned list across instances
        self.h_dim_list = list(h_dim_list)
        self.actv = actv
        self.sig_max = sig_max
        self.mu_min = mu_min
        self.mu_max = mu_max
        self.p_drop = p_drop
        self.use_bn = use_bn

        # Declare layers
        self.layer_list = []
        h_dim_prev = self.x_dim
        for h_dim in self.h_dim_list:
            # dense -> batchnorm -> actv -> dropout
            self.layer_list.append(nn.Linear(h_dim_prev,h_dim))
            if self.use_bn:
                self.layer_list.append(nn.BatchNorm1d(num_features=h_dim)) # (optional) batchnorm
            self.layer_list.append(self.actv)
            # Element-wise dropout. nn.Dropout1d would treat a 2-D [N x H] activation
            # as an unbatched (C x L) tensor and zero entire rows, i.e. whole samples.
            self.layer_list.append(nn.Dropout(p=self.p_drop))
            h_dim_prev = h_dim
        self.layer_list.append(
            MixturesOfGaussianLayer(
                in_dim = h_dim_prev,
                out_dim = self.y_dim,
                k = self.k,
                sig_max = self.sig_max
            )
        )

        # Define network: register layers under "<type>_<index>" names
        self.net = nn.Sequential()
        self.layer_names = []
        for l_idx,layer in enumerate(self.layer_list):
            layer_name = "%s_%02d"%(type(layer).__name__.lower(),l_idx)
            self.layer_names.append(layer_name)
            self.net.add_module(layer_name,layer)

        # Initialize parameters
        self.init_param(VERBOSE=False)

    def init_param(self,VERBOSE=False):
        """
        Initialize parameters: Kaiming-normal weights and zero biases for conv/linear
        layers, unit/zero affine parameters for BatchNorm1d, then a uniform
        [mu_min, mu_max] bias for the mixture-mean head.
        """
        for m_idx,m in enumerate(self.modules()):
            if VERBOSE:
                print ("[%02d]"%(m_idx))
            if isinstance(m,nn.Conv2d): # init conv
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m,nn.BatchNorm1d): # init BN
                nn.init.constant_(m.weight,1.0)
                nn.init.constant_(m.bias,0.0)
            elif isinstance(m,nn.Linear): # init dense
                nn.init.kaiming_normal_(m.weight,nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # (Heuristics) spread the initial component means over mu_min ~ mu_max
        self.layer_list[-1].fc_mu.bias.data.uniform_(self.mu_min,self.mu_max)

    def forward(self, x):
        """
        Forward propagate through the trunk and the mixture head.

        :param x: input batch fed to the first Linear (last dim = x_dim)
        :return: (pi, mu, sigma) as produced by the mixture head
        """
        return self.net(x)
184 |
def eval_mdn_1d(
    mdn,
    x_train_np,
    y_train_np,
    figsize=(12,3),
    device='cpu',
    x_margin=0.2,
    n_test=300,
    pi_th=0.1,
):
    """
    Visualize an MDN with a 1-D input on a dense grid around the training inputs.

    Draws two figures with one subplot per output dimension:
      1. the training data, the +/-2 sigma band and mean of every mixture component
         (colored where its weight exceeds `pi_th`, grey otherwise), and the mean of
         the highest-weight component ("Argmax Mu");
      2. epistemic and aleatoric uncertainty from `gmm_uncertainties`.

    :param mdn: model whose forward returns (pi, mu, sigma), each [N x K x D],
        and which exposes `k` (number of mixture components)
    :param x_train_np: training inputs (numpy); plotted against each target column
    :param y_train_np: training targets [N x D]
    :param figsize: matplotlib figure size for each of the two figures
    :param device: device the test grid is moved to before the forward pass
    :param x_margin: how far the grid extends beyond the training input range
    :param n_test: number of grid points
    :param pi_th: weight threshold above which a component is drawn in color
    """
    # Evaluate without dropout/batch statistics and without building a graph;
    # afterwards restore whatever mode the model was in, even if plotting fails.
    was_training = mdn.training
    mdn.eval()
    try:
        with th.no_grad():
            x_test = th.linspace(
                start = float(x_train_np.min())-x_margin,
                end = float(x_train_np.max())+x_margin,
                steps = n_test
            ).reshape((-1,1)).to(device)
            pi_test,mu_test,sigma_test = mdn.forward(x_test)
            # Get the most probable mu
            argmax_mu_test = get_argmax_mu(pi_test,mu_test) # [N x D]
            # Uncertainties
            epis_unct,alea_unct = gmm_uncertainties(pi_test,mu_test,sigma_test) # [N x D]

        # To numpy array
        x_test_np,pi_np,mu_np,sigma_np = th2np(x_test),th2np(pi_test),th2np(mu_test),th2np(sigma_test)
        argmax_mu_test_np = th2np(argmax_mu_test) # [N x D]
        epis_unct_np,alea_unct_np = th2np(epis_unct),th2np(alea_unct)
        x_grid = x_test_np[:,0] # [N]

        # Plot fitted results
        y_dim = y_train_np.shape[1]
        plt.figure(figsize=figsize)
        cmap = plt.get_cmap('gist_rainbow')
        colors = [cmap(ii) for ii in np.linspace(0, 1, mdn.k)] # one color per mixture
        for d_idx in range(y_dim): # for each output dimension
            plt.subplot(1,y_dim,d_idx+1)
            # Plot training data
            plt.plot(x_train_np,y_train_np[:,d_idx],'.',color=(0,0,0,0.2),markersize=3,
                     label="Training Data")
            # Plot mixture standard deviations (only where the mixture is active)
            for k_idx in range(mdn.k): # for each mixture
                pi_high_idx = np.where(pi_np[:,k_idx,d_idx] > pi_th)[0]
                mu_k = mu_np[:,k_idx,d_idx]
                sigma_k = sigma_np[:,k_idx,d_idx]
                upper_bound = mu_k + 2*sigma_k
                lower_bound = mu_k - 2*sigma_k
                # x_grid[...] is already 1-D; no squeeze (a single hit would become 0-d)
                plt.fill_between(x_grid[pi_high_idx],
                                 lower_bound[pi_high_idx],
                                 upper_bound[pi_high_idx],
                                 facecolor=colors[k_idx], interpolate=False, alpha=0.3)
            # Plot mixture means: colored where active, grey elsewhere
            for k_idx in range(mdn.k): # for each mixture
                pi_high_idx = np.where(pi_np[:,k_idx,d_idx] > pi_th)[0] # [?,]
                pi_low_idx = np.where(pi_np[:,k_idx,d_idx] <= pi_th)[0] # [?,]
                plt.plot(x_grid[pi_high_idx],mu_np[pi_high_idx,k_idx,d_idx],'-',
                         color=colors[k_idx],linewidth=1/2)
                plt.plot(x_grid[pi_low_idx],mu_np[pi_low_idx,k_idx,d_idx],'-',
                         color=(0,0,0,0.3),linewidth=1/2)

            # Plot most probable mu
            plt.plot(x_grid,argmax_mu_test_np[:,d_idx],'-',color='b',linewidth=2,
                     label="Argmax Mu")
            plt.xlim(x_test_np.min(),x_test_np.max())
            plt.legend(loc='lower right',fontsize=8)
            plt.title("y_dim:[%d]"%(d_idx),fontsize=10)
        plt.show()

        # Plot uncertainties
        plt.figure(figsize=figsize)
        for d_idx in range(y_dim): # for each output dimension
            plt.subplot(1,y_dim,d_idx+1)
            plt.plot(x_grid,epis_unct_np[:,d_idx],'-',color='r',linewidth=2,
                     label="Epistemic Uncertainty")
            plt.plot(x_grid,alea_unct_np[:,d_idx],'-',color='b',linewidth=2,
                     label="Aleatoric Uncertainty")
            plt.xlim(x_test_np.min(),x_test_np.max())
            plt.legend(loc='lower right',fontsize=8)
            plt.title("y_dim:[%d]"%(d_idx),fontsize=10)
        plt.show()
    finally:
        # Back to the caller's mode (train mode in the usual training-loop use)
        mdn.train(was_training)
267 |
--------------------------------------------------------------------------------
/code/mlp.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "95fb33ad",
6 | "metadata": {},
7 | "source": [
8 | "### Multi-Layer Perceptron"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "f1e30a24",
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "PyTorch version:[2.0.1].\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "import numpy as np\n",
27 | "import matplotlib.pyplot as plt\n",
28 | "import torch as th\n",
29 | "import torch.nn as nn\n",
30 | "from dataset import mnist\n",
31 | "from util import (\n",
32 | " get_torch_size_string,\n",
33 | " print_model_parameters,\n",
34 | " print_model_layers,\n",
35 | " model_train,\n",
36 | " model_eval,\n",
37 | " model_test\n",
38 | ")\n",
39 | "np.set_printoptions(precision=3)\n",
40 | "th.set_printoptions(precision=3)\n",
41 | "%matplotlib inline\n",
42 | "%config InlineBackend.figure_format='retina'\n",
43 | "print (\"PyTorch version:[%s].\"%(th.__version__))"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "id": "bd42312f",
49 | "metadata": {},
50 | "source": [
51 | "### Hyperparameters"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 2,
57 | "id": "339cebfd",
58 | "metadata": {},
59 | "outputs": [
60 | {
61 | "name": "stdout",
62 | "output_type": "stream",
63 | "text": [
64 | "Ready.\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "device = 'cpu'\n",
70 | "n_epoch = 20\n",
71 | "batch_size = 128\n",
72 | "print_every = 1\n",
73 | "print (\"Ready.\")"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "id": "2aac096e",
79 | "metadata": {},
80 | "source": [
81 | "### Dataset"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 3,
87 | "id": "95ebd370",
88 | "metadata": {
89 | "scrolled": true
90 | },
91 | "outputs": [
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "MNIST ready.\n"
97 | ]
98 | }
99 | ],
100 | "source": [
101 | "train_iter,test_iter,train_data,train_label,test_data,test_label = \\\n",
102 | " mnist(root_path='../data',batch_size=128)\n",
103 | "print (\"MNIST ready.\")"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "72dc751b",
109 | "metadata": {},
110 | "source": [
111 | "### Model"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 4,
117 | "id": "dfacec51",
118 | "metadata": {},
119 | "outputs": [
120 | {
121 | "name": "stdout",
122 | "output_type": "stream",
123 | "text": [
124 | "Ready.\n"
125 | ]
126 | }
127 | ],
128 | "source": [
129 | "class MultiLayerPerceptronClass(th.nn.Module):\n",
130 | " def __init__(\n",
131 | " self,\n",
132 | " name = 'mlp',\n",
133 | " x_dim = 784,\n",
134 | " h_dim_list = [256,256],\n",
135 | " y_dim = 10,\n",
136 | " actv = nn.ReLU(),\n",
137 | " p_drop = 0.2\n",
138 | " ):\n",
139 | " \"\"\"\n",
140 | " Initialize MLP\n",
141 | " \"\"\"\n",
142 | " super(MultiLayerPerceptronClass,self).__init__()\n",
143 | " self.name = name\n",
144 | " self.x_dim = x_dim\n",
145 | " self.h_dim_list = h_dim_list\n",
146 | " self.y_dim = y_dim\n",
147 | " self.actv = actv\n",
148 | " self.p_drop = p_drop\n",
149 | " \n",
150 | " # Declare layers\n",
151 | " self.layer_list = []\n",
152 | " h_dim_prev = self.x_dim\n",
153 | " for h_dim in self.h_dim_list:\n",
154 | " # dense -> batchnorm -> actv -> dropout\n",
155 | " self.layer_list.append(nn.Linear(h_dim_prev,h_dim))\n",
156 | " self.layer_list.append(nn.BatchNorm1d(num_features=h_dim))\n",
157 | " self.layer_list.append(self.actv)\n",
158 | " self.layer_list.append(nn.Dropout1d(p=self.p_drop))\n",
159 | " h_dim_prev = h_dim\n",
160 | " self.layer_list.append(nn.Linear(h_dim_prev,self.y_dim))\n",
161 | " \n",
162 | " # Define net\n",
163 | " self.net = nn.Sequential()\n",
164 | " self.layer_names = []\n",
165 | " for l_idx,layer in enumerate(self.layer_list):\n",
166 | " layer_name = \"%s_%02d\"%(type(layer).__name__.lower(),l_idx)\n",
167 | " self.layer_names.append(layer_name)\n",
168 | " self.net.add_module(layer_name,layer)\n",
169 | " \n",
170 | " # Initialize parameters\n",
171 | " self.init_param(VERBOSE=False)\n",
172 | " \n",
173 | " def init_param(self,VERBOSE=False):\n",
174 | " \"\"\"\n",
175 | " Initialize parameters\n",
176 | " \"\"\"\n",
177 | " for m_idx,m in enumerate(self.modules()):\n",
178 | " if VERBOSE:\n",
179 | " print (\"[%02d]\"%(m_idx))\n",
180 | " if isinstance(m,nn.Conv2d): # init conv\n",
181 | " nn.init.kaiming_normal_(m.weight)\n",
182 | " nn.init.zeros_(m.bias)\n",
183 | " elif isinstance(m,nn.BatchNorm1d): # init BN\n",
184 | " nn.init.constant_(m.weight,1.0)\n",
185 | " nn.init.constant_(m.bias,0.0)\n",
186 | " elif isinstance(m,nn.Linear): # lnit dense\n",
187 | " nn.init.kaiming_normal_(m.weight,nonlinearity='relu')\n",
188 | " nn.init.zeros_(m.bias)\n",
189 | " \n",
190 | " def forward(self,x):\n",
191 | " \"\"\"\n",
192 | " Forward propagate\n",
193 | " \"\"\"\n",
194 | " intermediate_output_list = []\n",
195 | " for idx,layer in enumerate(self.net):\n",
196 | " x = layer(x)\n",
197 | " intermediate_output_list.append(x)\n",
198 | " # Final output\n",
199 | " final_output = x\n",
200 | " return final_output,intermediate_output_list\n",
201 | " \n",
202 | "print (\"Ready.\") "
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 5,
208 | "id": "fd93b90d",
209 | "metadata": {},
210 | "outputs": [
211 | {
212 | "name": "stdout",
213 | "output_type": "stream",
214 | "text": [
215 | "Ready.\n"
216 | ]
217 | }
218 | ],
219 | "source": [
220 | "mlp = MultiLayerPerceptronClass(\n",
221 | " name = 'mlp',\n",
222 | " x_dim = 784,\n",
223 | " h_dim_list = [512,256],\n",
224 | " y_dim = 10,\n",
225 | " actv = nn.ReLU(),\n",
226 | " p_drop = 0.25\n",
227 | ").to(device)\n",
228 | "loss = nn.CrossEntropyLoss()\n",
229 | "optm = th.optim.Adam(mlp.parameters(),lr=1e-3)\n",
230 | "print (\"Ready.\")"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "id": "f6c04e91",
236 | "metadata": {},
237 | "source": [
238 | "### Print model parameters"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 6,
244 | "id": "2eafd659",
245 | "metadata": {},
246 | "outputs": [
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "[ 0] parameter:[ net.linear_00.weight] shape:[ 512x784] numel:[ 401408]\n",
252 | "[ 1] parameter:[ net.linear_00.bias] shape:[ 512] numel:[ 512]\n",
253 | "[ 2] parameter:[ net.batchnorm1d_01.weight] shape:[ 512] numel:[ 512]\n",
254 | "[ 3] parameter:[ net.batchnorm1d_01.bias] shape:[ 512] numel:[ 512]\n",
255 | "[ 4] parameter:[ net.linear_04.weight] shape:[ 256x512] numel:[ 131072]\n",
256 | "[ 5] parameter:[ net.linear_04.bias] shape:[ 256] numel:[ 256]\n",
257 | "[ 6] parameter:[ net.batchnorm1d_05.weight] shape:[ 256] numel:[ 256]\n",
258 | "[ 7] parameter:[ net.batchnorm1d_05.bias] shape:[ 256] numel:[ 256]\n",
259 | "[ 8] parameter:[ net.linear_08.weight] shape:[ 10x256] numel:[ 2560]\n",
260 | "[ 9] parameter:[ net.linear_08.bias] shape:[ 10] numel:[ 10]\n"
261 | ]
262 | }
263 | ],
264 | "source": [
265 | "print_model_parameters(mlp)"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "id": "96310c96",
271 | "metadata": {},
272 | "source": [
273 | "### Print model layers"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 7,
279 | "id": "9161988d",
280 | "metadata": {},
281 | "outputs": [
282 | {
283 | "name": "stdout",
284 | "output_type": "stream",
285 | "text": [
286 | "batch_size:[16]\n",
287 | "[ ] layer:[ input] size:[ 16x784]\n",
288 | "[ 0] layer:[ linear_00] size:[ 16x512] numel:[ 8192]\n",
289 | "[ 1] layer:[ batchnorm1d_01] size:[ 16x512] numel:[ 8192]\n",
290 | "[ 2] layer:[ relu_02] size:[ 16x512] numel:[ 8192]\n",
291 | "[ 3] layer:[ dropout1d_03] size:[ 16x512] numel:[ 8192]\n",
292 | "[ 4] layer:[ linear_04] size:[ 16x256] numel:[ 4096]\n",
293 | "[ 5] layer:[ batchnorm1d_05] size:[ 16x256] numel:[ 4096]\n",
294 | "[ 6] layer:[ relu_06] size:[ 16x256] numel:[ 4096]\n",
295 | "[ 7] layer:[ dropout1d_07] size:[ 16x256] numel:[ 4096]\n",
296 | "[ 8] layer:[ linear_08] size:[ 16x10] numel:[ 160]\n"
297 | ]
298 | }
299 | ],
300 | "source": [
301 | "x_torch = th.randn(16,mlp.x_dim).to(device)\n",
302 | "print_model_layers(mlp,x_torch)"
303 | ]
304 | },
305 | {
306 | "cell_type": "markdown",
307 | "id": "5da9aaf1",
308 | "metadata": {},
309 | "source": [
310 | "### Train MLP"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 8,
316 | "id": "c73c3242",
317 | "metadata": {},
318 | "outputs": [
319 | {
320 | "name": "stdout",
321 | "output_type": "stream",
322 | "text": [
323 | "epoch:[ 0/20] loss:[1.137] train_accr:[0.9704] test_accr:[0.9631].\n",
324 | "epoch:[ 1/20] loss:[1.073] train_accr:[0.9839] test_accr:[0.9727].\n",
325 | "epoch:[ 2/20] loss:[1.050] train_accr:[0.9869] test_accr:[0.9746].\n",
326 | "epoch:[ 3/20] loss:[1.041] train_accr:[0.9889] test_accr:[0.9783].\n",
327 | "epoch:[ 4/20] loss:[1.037] train_accr:[0.9880] test_accr:[0.9772].\n",
328 | "epoch:[ 5/20] loss:[1.032] train_accr:[0.9929] test_accr:[0.9798].\n",
329 | "epoch:[ 6/20] loss:[1.024] train_accr:[0.9918] test_accr:[0.9754].\n",
330 | "epoch:[ 7/20] loss:[1.031] train_accr:[0.9926] test_accr:[0.9779].\n",
331 | "epoch:[ 8/20] loss:[1.028] train_accr:[0.9950] test_accr:[0.9779].\n",
332 | "epoch:[ 9/20] loss:[1.030] train_accr:[0.9955] test_accr:[0.9807].\n",
333 | "epoch:[10/20] loss:[1.017] train_accr:[0.9961] test_accr:[0.9812].\n",
334 | "epoch:[11/20] loss:[1.014] train_accr:[0.9968] test_accr:[0.9817].\n",
335 | "epoch:[12/20] loss:[1.016] train_accr:[0.9935] test_accr:[0.9752].\n",
336 | "epoch:[13/20] loss:[1.023] train_accr:[0.9964] test_accr:[0.9790].\n",
337 | "epoch:[14/20] loss:[1.022] train_accr:[0.9974] test_accr:[0.9805].\n",
338 | "epoch:[15/20] loss:[1.018] train_accr:[0.9972] test_accr:[0.9800].\n",
339 | "epoch:[16/20] loss:[1.005] train_accr:[0.9965] test_accr:[0.9794].\n",
340 | "epoch:[17/20] loss:[1.014] train_accr:[0.9973] test_accr:[0.9792].\n",
341 | "epoch:[18/20] loss:[1.012] train_accr:[0.9970] test_accr:[0.9794].\n",
342 | "epoch:[19/20] loss:[1.015] train_accr:[0.9984] test_accr:[0.9821].\n"
343 | ]
344 | }
345 | ],
346 | "source": [
347 | "model_train(mlp,optm,loss,train_iter,test_iter,n_epoch,print_every,device)"
348 | ]
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "id": "0b8ec7e0",
353 | "metadata": {},
354 | "source": [
355 | "### Test MLP"
356 | ]
357 | },
358 | {
359 | "cell_type": "code",
360 | "execution_count": 9,
361 | "id": "619bd29c",
362 | "metadata": {},
363 | "outputs": [
364 | {
365 | "data": {
366 | "image/png": "",
367 | "text/plain": [
368 | ""
369 | ]
370 | },
371 | "metadata": {
372 | "image/png": {
373 | "height": 558,
374 | "width": 488
375 | }
376 | },
377 | "output_type": "display_data"
378 | }
379 | ],
380 | "source": [
381 | "model_test(mlp,test_data,test_label,device)"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "id": "b557e4ed",
388 | "metadata": {},
389 | "outputs": [],
390 | "source": []
391 | }
392 | ],
393 | "metadata": {
394 | "kernelspec": {
395 | "display_name": "Python 3 (ipykernel)",
396 | "language": "python",
397 | "name": "python3"
398 | },
399 | "language_info": {
400 | "codemirror_mode": {
401 | "name": "ipython",
402 | "version": 3
403 | },
404 | "file_extension": ".py",
405 | "mimetype": "text/x-python",
406 | "name": "python",
407 | "nbconvert_exporter": "python",
408 | "pygments_lexer": "ipython3",
409 | "version": "3.9.16"
410 | }
411 | },
412 | "nbformat": 4,
413 | "nbformat_minor": 5
414 | }
415 |
--------------------------------------------------------------------------------
/code/module.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch as th
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from abc import abstractmethod
6 |
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Subclasses must override `forward(x, emb)`. TimestepEmbedSequential checks
    `isinstance(layer, TimestepBlock)` to decide whether to pass `emb` along.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.

        :raises NotImplementedError: always; subclasses must override this.
        """
        # nn.Module does not use ABCMeta, so @abstractmethod alone neither
        # blocks instantiation nor calls; fail loudly instead of returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward(x, emb)")
17 |
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that threads a timestep embedding through its children.

    Children that are TimestepBlocks receive `(h, emb)`; every other child
    receives only the running activation `h`.
    """

    def forward(self, x, emb):
        h = x
        for child in self:
            wants_emb = isinstance(child, TimestepBlock)
            h = child(h, emb) if wants_emb else child(h)
        return h
31 |
def zero_module(module):
    """
    Overwrite every parameter of `module` with zeros, in place.

    :param module: an nn.Module.
    :return: the very same module object, so it can be used inline.
    """
    for weight in module.parameters():
        detached = weight.detach()  # shares storage with `weight`
        detached.zero_()            # so zeroing it zeroes the parameter
    return module
39 |
def conv_nd(dims, *args, **kwargs):
    """
    Build a 1D, 2D, or 3D convolution module depending on `dims`.

    Remaining positional and keyword arguments are forwarded unchanged to
    nn.Conv1d / nn.Conv2d / nn.Conv3d.

    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    for n_dims, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == n_dims:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
51 |
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build a 1D, 2D, or 3D average-pooling module depending on `dims`.

    Remaining positional and keyword arguments are forwarded unchanged to
    nn.AvgPool1d / nn.AvgPool2d / nn.AvgPool3d.

    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    for n_dims, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == n_dims:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
63 |
class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that always normalizes in float32 and casts the result back to
    the input's dtype (presumably for half-precision stability).
    """

    def forward(self, x):
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)
67 |
def normalization(n_channels,n_groups=1):
    """
    Build the standard normalization layer used by the blocks in this module.

    :param n_channels: number of input channels.
    :param n_groups: number of channel groups. With a single group the
        statistics are taken over all channels and positions together
        (layer-norm-like, with a per-channel affine transform).
    :return: a GroupNorm32 module.
    """
    layer = GroupNorm32(num_channels=n_channels,num_groups=n_groups)
    return layer
77 |
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param n_channels: number of channels in the inputs.
    :param up_rate: upsampling factor.
    :param up_mode: interpolation mode passed to F.interpolate ('nearest' or 'bilinear').
    :param use_conv: a bool determining if a kernel_size=3 convolution is
        applied after interpolation.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    :param n_out_channels: output channels of the optional conv (defaults to n_channels).
    :param padding_mode: padding mode of the optional conv.
    :param padding: padding of the optional conv.
    """

    def __init__(
        self,
        n_channels,
        up_rate = 2,              # upsample rate
        up_mode = 'nearest',      # upsample mode ('nearest' or 'bilinear')
        use_conv = False,         # (optional) use output conv
        dims = 2,                 # (optional) spatial dimension
        n_out_channels = None,    # (optional) in case output channels are different from the input
        padding_mode = 'zeros',
        padding = 1
    ):
        super().__init__()
        self.n_channels = n_channels
        self.up_rate = up_rate
        self.up_mode = up_mode
        self.use_conv = use_conv
        self.dims = dims
        self.n_out_channels = n_out_channels or n_channels
        self.padding_mode = padding_mode
        self.padding = padding

        if use_conv:
            self.conv = conv_nd(
                dims = dims,
                in_channels = self.n_channels,
                out_channels = self.n_out_channels,
                kernel_size = 3,
                padding = padding,
                padding_mode = padding_mode)

    def forward(self, x):
        """
        :param x: [B x C x ...] tensor with C == n_channels
        :return: upsampled tensor. Every spatial axis is scaled by up_rate,
            except that for dims == 3 only the inner two are. The channel
            count becomes n_out_channels when use_conv is set.
        """
        assert x.shape[1] == self.n_channels
        if self.dims == 3: # 3D: keep the leading spatial axis, scale the inner two
            x = F.interpolate(
                input = x,
                size = (
                    x.shape[2],
                    int(x.shape[3]*self.up_rate),
                    int(x.shape[4]*self.up_rate),
                ),
                mode = self.up_mode
            )
        else:
            x = F.interpolate(
                input = x,
                scale_factor = self.up_rate,
                mode = self.up_mode
            ) # 'nearest' or 'bilinear'

        # (optional) final convolution
        if self.use_conv:
            x = self.conv(x)
        return x
141 |
class Downsample(nn.Module):
    """
    Downsampling layer.

    With `use_conv` it uses a strided kernel_size=3 convolution. Otherwise it
    uses average pooling, which requires n_out_channels == n_channels.

    :param n_channels: number of input channels.
    :param down_rate: downsampling factor (used as stride / pooling window).
    :param use_conv: if True, downsample with a strided convolution.
    :param dims: 1, 2 or 3. For 3D signals only the inner-two dimensions are
        downsampled.
    :param n_out_channels: output channels (defaults to n_channels).
    :param padding_mode: padding mode of the convolution.
    :param padding: padding of the convolution.
    """

    def __init__(
        self,
        n_channels,
        down_rate = 2,            # down rate
        use_conv = False,         # (optional) use output conv
        dims = 2,                 # (optional) spatial dimension
        n_out_channels = None,    # (optional) in case output channels are different from the input
        padding_mode = 'zeros',
        padding = 1
    ):
        super().__init__()
        self.n_channels = n_channels
        self.down_rate = down_rate
        self.n_out_channels = n_out_channels or n_channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep their leading spatial axis at full resolution
        if dims == 3:
            stride = (1, down_rate, down_rate)
        else:
            stride = down_rate
        if not use_conv:
            assert self.n_channels == self.n_out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims = dims,
                in_channels = self.n_channels,
                out_channels = self.n_out_channels,
                kernel_size = 3,
                stride = stride,
                padding = padding,
                padding_mode = padding_mode
            )

    def forward(self, x):
        """
        :param x: [B x C x ...] tensor with C == n_channels
        :return: the output of self.op (pooling or strided conv)
        """
        assert x.shape[1] == self.n_channels
        return self.op(x)
186 |
class QKVAttentionLegacy(nn.Module):
    """
    Multi-head QKV attention over a channel-stacked tensor.
    Matches legacy QKVAttention + input/output heads shaping.
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        (B:#batches, C:channel size, T:#tokens, H:#heads)

        :param qkv: an [B x (3*C) x T] tensor. The 3C channels are split into
            H consecutive chunks, each laid out as [q | k | v].
        :return: an [B x C x T] tensor after attention.
        """
        B, total_ch, T = qkv.shape
        H = self.n_heads
        assert total_ch % (3 * H) == 0
        head_ch = total_ch // (3 * H)
        per_head = qkv.reshape(B * H, 3 * head_ch, T)
        q, k, v = per_head.split(head_ch, dim=1)
        # Scale q and k by ch^(-1/4) each instead of dividing the logits by
        # sqrt(ch) afterwards (more stable with f16).
        factor = 1 / math.sqrt(math.sqrt(head_ch))
        logits = th.einsum("bct,bcs->bts", q * factor, k * factor)
        attn = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = th.einsum("bts,bcs->bct", attn, v)  # [(H*B) x (C//H) x T]
        return mixed.reshape(B, -1, T)               # [B x C x T]
215 |
class AttentionBlock(nn.Module):
    """
    Self-attention block letting every spatial position attend to every other.

    The input is flattened into tokens, group-normalized, and projected to
    QKV by a 1x1 conv. It then goes through multi-head attention, is projected
    back by a zero-initialized 1x1 conv, and is added to the input (residual).

    Input: [B x C x ...] tensor (e.g. [B x C x W x H])
    Output: (tensor of the input's shape, dict of intermediate tensors)
    """
    def __init__(
        self,
        name = 'attentionblock',
        n_channels = 1,
        n_heads = 1,
        n_groups = 32,
    ):
        super().__init__()
        self.name = name
        self.n_channels = n_channels
        self.n_heads = n_heads
        assert (
            n_channels % n_heads == 0
        ), "n_channels:[%d] should be divisible by n_heads:[%d]."%(n_channels,n_heads)

        # NOTE: GroupNorm requires n_channels to be divisible by n_groups, so
        # the defaults n_channels=1 and n_groups=32 cannot be used together.
        self.norm = normalization(n_channels=n_channels,n_groups=n_groups)
        # 1x1 conv producing the stacked [q | k | v] channels
        self.qkv = nn.Conv1d(
            in_channels = n_channels,
            out_channels = 3*n_channels,
            kernel_size = 1
        )
        self.attention = QKVAttentionLegacy(n_heads = n_heads)
        # Zero-initialized, so at initialization the block returns its input
        self.proj_out = zero_module(
            nn.Conv1d(
                in_channels = n_channels,
                out_channels = n_channels,
                kernel_size = 1
            )
        )

    def forward(self, x):
        """
        :param x: [B x C x ...] tensor
        :return: (out, intermediate_output_dict). out has x's shape; the dict
            holds 'x', 'x_rsh', 'x_nzd', 'qkv', 'h_att', 'h_proj' and 'out'.
        """
        batch, chans, *spatial = x.shape
        tokens = x.reshape(batch, chans, -1)   # [B x C x N]
        normed = self.norm(tokens)             # [B x C x N]
        qkv = self.qkv(normed)                 # [B x 3C x N]
        attended = self.attention(qkv)         # [B x C x N]
        projected = self.proj_out(attended)    # [B x C x N]
        out = (tokens + projected).reshape(batch, chans, *spatial)
        intermediate_output_dict = {
            'x': x,
            'x_rsh': tokens,
            'x_nzd': normed,
            'qkv': qkv,
            'h_att': attended,
            'h_proj': projected,
            'out': out,
        }
        return out, intermediate_output_dict
286 |
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels and
    resolution, conditioned on a timestep embedding.

    Data path (see forward):
        h   = in_layers(x)      groupnorm -> actv -> conv (with optional
                                up/downsampling between actv and conv)
        h   = h combined with emb_layers(emb), by scale/shift or by addition
        h   = out_layers(h)     groupnorm -> actv -> dropout -> zero-init conv
        out = h + skip_connection(x)

    :param name: name of the block (stored as self.name)
    :param n_channels: the number of input channels
    :param n_emb_channels: the number of timestep embedding channels
    :param n_out_channels: (if specified) the number of output channels
    :param n_groups: the number of groups in group normalization layer
    :param dims: spatial dimension
    :param p_dropout: the rate of dropout
    :param kernel_size: kernel size of the main convolutions
    :param actv: activation module. The default nn.SiLU() instance is created
        once, when the function is defined, and is shared by every ResBlock
        that uses the default. It is also reused in in/emb/out layers.
    :param use_conv: if True, and n_out_channels differs from n_channels,
        the skip connection uses a kernel_size conv instead of a 1x1 conv
    :param use_scale_shift_norm: if True, use scale_shift_norm for handling emb
    :param upsample: if True, upsample (takes precedence over downsample)
    :param downsample: if True, downsample
    :param up_rate: upsampling factor
    :param down_rate: downsampling factor
    :param sample_mode: upsample mode ('nearest' or 'bilinear'); only passed to Upsample
    :param padding_mode: str, padding mode of the convolutions
    :param padding: int, padding of the convolutions
    """
    def __init__(
        self,
        name = 'resblock',
        n_channels = 128,
        n_emb_channels = 128,
        n_out_channels = None,
        n_groups = 16,
        dims = 2,
        p_dropout = 0.5,
        kernel_size = 3,
        actv = nn.SiLU(),
        use_conv = False,
        use_scale_shift_norm = True,
        upsample = False,
        downsample = False,
        up_rate = 2,
        down_rate = 2,
        sample_mode = 'nearest',
        padding_mode = 'zeros',
        padding = 1,
    ):
        super().__init__()
        self.name = name
        self.n_channels = n_channels
        self.n_emb_channels = n_emb_channels
        self.n_groups = n_groups
        self.dims = dims
        self.n_out_channels = n_out_channels or self.n_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.actv = actv
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.upsample = upsample
        self.downsample = downsample
        self.up_rate = up_rate
        self.down_rate = down_rate
        self.sample_mode = sample_mode
        self.padding_mode = padding_mode
        self.padding = padding

        # Input layers: groupnorm -> actv -> conv (n_channels -> n_out_channels).
        # forward() indexes this Sequential ([:-1] and [-1]), so keep the order.
        self.in_layers = nn.Sequential(
            normalization(n_channels=self.n_channels,n_groups=self.n_groups),
            self.actv,
            conv_nd(
                dims = self.dims,
                in_channels = self.n_channels,
                out_channels = self.n_out_channels,
                kernel_size = self.kernel_size,
                padding = self.padding,
                padding_mode = self.padding_mode
            )
        )

        # Upsample or downsample. h_upd acts on the hidden path before the input
        # conv, so it still sees n_channels; x_upd resamples the skip input x.
        self.updown = self.upsample or self.downsample
        if self.upsample:
            self.h_upd = Upsample(
                n_channels = self.n_channels,
                up_rate = self.up_rate,
                up_mode = self.sample_mode,
                dims = self.dims)
            self.x_upd = Upsample(
                n_channels = self.n_channels,
                up_rate = self.up_rate,
                up_mode = self.sample_mode,
                dims = self.dims)
        elif self.downsample:
            self.h_upd = Downsample(
                n_channels = self.n_channels,
                down_rate = self.down_rate,
                dims = self.dims)
            self.x_upd = Downsample(
                n_channels = self.n_channels,
                down_rate = self.down_rate,
                dims = self.dims)
        else:
            self.h_upd = nn.Identity()
            self.x_upd = nn.Identity()

        # Embedding layers: actv -> linear. The output width is doubled when
        # use_scale_shift_norm, to provide both a scale and a shift.
        self.emb_layers = nn.Sequential(
            self.actv,
            nn.Linear(
                in_features = self.n_emb_channels,
                out_features = 2*self.n_out_channels if self.use_scale_shift_norm
                    else self.n_out_channels,
            ),
        )

        # Output layers: groupnorm -> actv -> dropout -> zero-initialized conv.
        # The zeroed conv makes the residual branch output zero at init.
        # forward() indexes this Sequential ([0] and [1:]), so keep the order.
        self.out_layers = nn.Sequential(
            normalization(n_channels=self.n_out_channels,n_groups=self.n_groups),
            self.actv,
            nn.Dropout(p=self.p_dropout),
            zero_module(
                conv_nd(
                    dims = self.dims,
                    in_channels = self.n_out_channels,
                    out_channels = self.n_out_channels,
                    kernel_size = self.kernel_size,
                    padding = self.padding,
                    padding_mode = self.padding_mode
                )
            ),
        )
        # Skip path: identity when the channel count is unchanged, otherwise
        # a conv that maps n_channels -> n_out_channels.
        if self.n_channels == self.n_out_channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims = self.dims,
                in_channels = self.n_channels,
                out_channels = self.n_out_channels,
                kernel_size = self.kernel_size,
                padding = self.padding,
                padding_mode = self.padding_mode
            )
        else:
            self.skip_connection = conv_nd(
                dims = self.dims,
                in_channels = self.n_channels,
                out_channels = self.n_out_channels,
                kernel_size = 1
            )

    def forward(self,x,emb):
        """
        :param x: [B x C x ...]
        :param emb: [B x n_emb_channels]
        :return: [B x C_out x ...] where C_out == n_out_channels
        """
        # Input layer (groupnorm -> actv -> conv)
        if self.updown: # upsample or downsample
            # Resample between the activation and the conv, and resample the
            # skip input x the same way so both paths keep matching sizes.
            in_norm_actv = self.in_layers[:-1]
            in_conv = self.in_layers[-1]
            h = in_norm_actv(x)
            h = self.h_upd(h)
            h = in_conv(h)
            x = self.x_upd(x)
        else:
            h = self.in_layers(x) # [B x C_out x ...]

        # Embedding layer
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None] # append singleton spatial axes so emb_out broadcasts over h

        # Combine input with embedding
        if self.use_scale_shift_norm:
            out_norm = self.out_layers[0] # group norm
            out_actv_dr_conv = self.out_layers[1:] # activation -> dropout -> conv
            # emb_out: [B x 2C_out x ...]
            scale,shift = th.chunk(emb_out, 2, dim=1) # [B x C_out x ...] each
            h = out_norm(h) * (1.0 + scale) + shift # [B x C_out x ...]
            h = out_actv_dr_conv(h) # [B x C_out x ...]
        else:
            # emb_out: [B x C_out x ...]
            h = h + emb_out
            h = self.out_layers(h) # group norm -> activation -> dropout -> conv

        # Skip connection
        out = h + self.skip_connection(x) # [B x C_out x ...]
        return out # [B x C_out x ...]
471 |
472 |
--------------------------------------------------------------------------------
/code/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch as th
4 | from scipy.spatial import distance
5 |
def th2np(x):
    """
    Convert a torch tensor into a numpy array: detached from autograd and
    copied to the CPU first.
    """
    detached = x.detach()
    return detached.cpu().numpy()
8 |
def get_torch_size_string(x):
    """
    Format the shape of `x` as 'D0xD1x...', e.g. '16x784'
    (an empty string for a 0-dim tensor).
    """
    return "x".join(str(dim) for dim in x.shape)
11 |
def plot_4x4_torch_tensor(
    x_torch,
    figsize = (4,4),
    cmap = 'gray',
    info_str = '',
    top = 0.92,
    hspace = 0.1
):
    """
    Plot a batch of (at most 16) images on a 4x4 grid.

    :param x_torch: [B x C x W x H] tensor with B <= 16; may require grad and
        may live on any device.
    :param figsize: figure size
    :param cmap: colormap (used by imshow for single-channel images)
    :param info_str: prefix of the figure title
    :param top: top margin passed to plt.subplots_adjust
    :param hspace: vertical spacing passed to plt.subplots_adjust
    """
    batch_size = x_torch.shape[0]
    # Convert once (not once per image), going through the CPU so that CUDA
    # tensors work as well.
    x_np = x_torch.detach().cpu().permute(0,2,3,1).numpy() # [B x W x H x C]
    plt.figure(figsize=figsize)
    for i in range(batch_size):
        plt.subplot(4,4,i+1)
        plt.imshow(x_np[i,:,:,:], cmap=cmap)
        plt.axis('off')
    plt.subplots_adjust(
        left=0.0,right=1.0,bottom=0.0,top=top,wspace=0.0,hspace=hspace)
    plt.suptitle('%s[%d] Images of [%dx%d] sizes'%
                 (info_str,batch_size,x_torch.shape[2],x_torch.shape[3]),fontsize=10)
    plt.show()
34 |
def plot_1xN_torch_img_tensor(
    x_torch,
    title_str_list = None,
    title_fontsize = 20):
    """
    Show a batch of images side by side in a single row.

    :param x_torch: [B x C x W x H] tensor. Single-channel images are drawn
        in grayscale; otherwise channels are moved last for imshow (RGB/RGBA).
        May require grad and may live on any device.
    :param title_str_list: (optional) one title per image
    :param title_fontsize: font size of the titles
    """
    # detach() first: .numpy() refuses tensors that require grad
    xt_np = x_torch.detach().cpu().numpy() # [B x C x W x H]
    n_imgs = xt_np.shape[0]
    plt.figure(figsize=(n_imgs*2,3))
    for img_idx in range(n_imgs):
        plt.subplot(1,n_imgs,img_idx+1)
        if xt_np.shape[1]==1:
            plt.imshow(xt_np[img_idx,0,:,:], cmap='gray')
        else:
            plt.imshow(xt_np[img_idx,:,:,:].transpose(1,2,0))
        if title_str_list:
            plt.title(title_str_list[img_idx],fontsize=title_fontsize)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
56 |
def plot_1xN_torch_traj_tensor(
    times,
    x_torch,
    title_str_list = None,
    title_fontsize = 20,
    ylim = None,
    figsize = None,
):
    """
    Plot channel 0 of every trajectory in a batch, side by side.

    :param times: [L] or [L x 1] time stamps (x-axis)
    :param x_torch: [B x C x L] tensor; channel 0 of each trajectory is
        plotted against `times`. May require grad and may live on any device.
    :param title_str_list: (optional) one title per trajectory
    :param title_fontsize: font size of the titles
    :param ylim: (optional) y-limits applied to every subplot
    :param figsize: figure size (defaults to (2*B, 3))
    """
    # Convert once; detach() so tensors that require grad are accepted
    xt_np = x_torch.detach().cpu().numpy() # [B x C x L]
    n_trajs = xt_np.shape[0]
    if figsize is None: figsize = (n_trajs*2,3)
    plt.figure(figsize=figsize)
    for traj_idx in range(n_trajs):
        plt.subplot(1,n_trajs,traj_idx+1)
        plt.plot(times,xt_np[traj_idx,0,:],'-',color='k')
        if title_str_list:
            plt.title(title_str_list[traj_idx],fontsize=title_fontsize)
        if ylim:
            plt.ylim(ylim)
    plt.tight_layout()
    plt.show()
82 |
def print_model_parameters(model):
    """
    Print index, name, shape and element count of every named parameter.
    """
    row_fmt = "[%2d] parameter:[%27s] shape:[%12s] numel:[%10d]"
    for p_idx, (param_name, param) in enumerate(model.named_parameters()):
        shape_str = get_torch_size_string(param)
        print(row_fmt % (p_idx, param_name, shape_str, param.numel()))
95 |
def print_model_layers(model,x_torch):
    """
    Print model layers

    Runs one forward pass and prints the size of every intermediate output.
    `model(x_torch)` must return (output, intermediate_outputs), where
    intermediate_outputs[idx] belongs to model.layer_names[idx].

    The pass runs under no_grad and in eval mode, and the model's previous
    mode is restored afterwards. Printing therefore neither builds an autograd
    graph nor updates BatchNorm running statistics.
    """
    was_training = model.training
    model.eval()
    try:
        with th.no_grad():
            _,intermediate_output_list = model(x_torch)
    finally:
        model.train(was_training) # restore the caller's train/eval mode
    batch_size = x_torch.shape[0]
    print ("batch_size:[%d]"%(batch_size))
    print ("[ ] layer:[%15s] size:[%14s]"
           %('input',get_torch_size_string(x_torch))
    )
    for idx,layer_name in enumerate(model.layer_names):
        intermediate_output = intermediate_output_list[idx]
        print ("[%2d] layer:[%15s] size:[%14s] numel:[%10d]"%
               (idx,
                layer_name,
                get_torch_size_string(intermediate_output),
                intermediate_output.numel()
               ))
114 |
def model_train(model,optm,loss,train_iter,test_iter,n_epoch,print_every,device):
    """
    Train model

    Re-initializes the model with model.init_param(VERBOSE=False), then runs
    n_epoch epochs of minibatch training over train_iter. At every
    print_every-th epoch, and at the final epoch, it prints the mean training
    loss and the train/test accuracy from model_eval.

    :param model: module exposing x_dim (int or tuple), init_param(), and a
        forward returning (prediction, extras)
    :param optm: optimizer over the model's parameters
    :param loss: callable loss(y_pred, y_true) returning a scalar tensor
    :param train_iter: sized iterable of (batch_in, batch_out)
    :param test_iter: iterable of (batch_in, batch_out)
    :param n_epoch: number of epochs
    :param print_every: print period in epochs (must be > 0)
    :param device: device the batches are moved to
    """
    model.init_param(VERBOSE=False)
    model.train()
    for epoch in range(n_epoch):
        loss_val_sum = 0.0
        for batch_in,batch_out in train_iter:
            # Forward path (reshape the batch to the model's input shape)
            if isinstance(model.x_dim,int):
                x_in = batch_in.view(-1,model.x_dim)
            else:
                x_in = batch_in.view((-1,)+model.x_dim)
            y_pred,_ = model(x_in.to(device))
            loss_out = loss(y_pred,batch_out.to(device))
            # Update
            optm.zero_grad()    # reset gradient
            loss_out.backward() # back-propagate loss
            optm.step()         # optimizer update
            # Accumulate a Python float. Summing the loss tensor itself would
            # chain every batch's autograd history into one growing graph.
            loss_val_sum += loss_out.item()
        loss_val_avg = loss_val_sum/len(train_iter)
        # Print
        if ((epoch%print_every)==0) or (epoch==(n_epoch-1)):
            train_accr = model_eval(model,train_iter,device)
            test_accr = model_eval(model,test_iter,device)
            print ("epoch:[%2d/%d] loss:[%.3f] train_accr:[%.4f] test_accr:[%.4f]."%
                   (epoch,n_epoch,loss_val_avg,train_accr,test_accr))
142 |
def model_eval(model,data_iter,device):
    """
    Evaluate model

    Computes the top-1 classification accuracy of `model` over `data_iter`,
    with gradients disabled and the model in eval mode. The model's previous
    train/eval mode is restored before returning.

    :param model: module with an `x_dim` attribute (int or tuple) whose
        forward returns (logits, extras)
    :param data_iter: iterable of (batch_in, batch_out) with integer labels
    :param device: device the batches are moved to
    :return: accuracy as a float in [0, 1]
    """
    was_training = model.training
    model.eval() # evaluate (affects DropOut and BN)
    try:
        with th.no_grad():
            n_total,n_correct = 0,0
            for batch_in,batch_out in data_iter:
                y_trgt = batch_out.to(device)
                if isinstance(model.x_dim,int):
                    x_in = batch_in.view(-1,model.x_dim)
                else:
                    x_in = batch_in.view((-1,)+model.x_dim)
                model_pred,_ = model(x_in.to(device))
                _,y_pred = th.max(model_pred,1)
                n_correct += (y_pred==y_trgt).sum().item()
                n_total += batch_in.size(0)
    finally:
        model.train(was_training) # back to the caller's mode
    val_accr = (n_correct/n_total)
    return val_accr
162 |
def model_test(model,test_data,test_label,device,n_sample=25):
    """
    Test model

    Classifies randomly chosen test images and plots them on a roughly square
    grid (5x5 for the default 25 samples). Each title shows the prediction and
    the label, in red when they disagree.

    :param model: module with x_dim (int or tuple) whose forward returns
        (logits, extras)
    :param test_data: indexable tensor of raw images; values are divided by
        255 (presumably uint8 pixels - confirm against the dataset)
    :param test_label: tensor of integer labels aligned with test_data
    :param device: device used for inference
    :param n_sample: number of images to show (capped at len(test_data))
    """
    n_sample = min(n_sample,len(test_data))
    sample_indices = np.random.choice(len(test_data),n_sample,replace=False)
    test_data_samples = test_data[sample_indices]
    test_label_samples = test_label[sample_indices]
    was_training = model.training
    model.eval()
    try:
        with th.no_grad():
            if isinstance(model.x_dim,int):
                x_in = test_data_samples.view(-1,model.x_dim)
            else:
                x_in = test_data_samples.view((-1,)+model.x_dim)
            x_in = x_in.type(th.float).to(device)/255.
            y_pred,_ = model(x_in)
            y_pred = y_pred.argmax(axis=1)
    finally:
        model.train(was_training) # restore the caller's train/eval mode
    # Plot
    n_cols = max(1,int(np.ceil(np.sqrt(n_sample))))
    n_rows = max(1,int(np.ceil(n_sample/n_cols)))
    plt.figure(figsize=(6,6))
    plt.subplots_adjust(top=1.0)
    for idx in range(n_sample):
        plt.subplot(n_rows,n_cols,idx+1)
        plt.imshow(test_data_samples[idx],cmap='gray')
        plt.axis('off')
        fontcolor = 'k' if (y_pred[idx] == test_label_samples[idx]) else 'r'
        plt.title("Pred:%d, Label:%d"%(y_pred[idx],test_label_samples[idx]),
                  fontsize=8,color=fontcolor)
    plt.show()
190 |
def kernel_se(x1,x2,hyp=None):
    """
    Squared-exponential kernel function

    K[i,j] = gain * exp(-||x1[i] - x2[j]||^2 / len^2)

    :param x1: [N x D] ndarray
    :param x2: [M x D] ndarray
    :param hyp: dict with keys 'gain' and 'len'; defaults to
        {'gain':1.0,'len':1.0} (built per call, not a shared mutable default)
    :return: [N x M] ndarray
    """
    if hyp is None:
        hyp = {'gain':1.0,'len':1.0}
    D = distance.cdist(x1/hyp['len'],x2/hyp['len'],'sqeuclidean')
    K = hyp['gain']*np.exp(-D)
    return K
196 |
def gp_sampler(
    times = np.linspace(start=0.0,stop=1.0,num=100).reshape((-1,1)), # [L x 1]
    hyp_gain = 1.0,
    hyp_len = 1.0,
    meas_std = 0e-8,
    n_traj = 1
):
    """
    Gaussian process sampling

    Draws `n_traj` sample paths from a zero-mean GP with a squared-exponential
    kernel. It uses the Cholesky factor of the jittered kernel matrix, then
    adds i.i.d. Gaussian noise of std `meas_std`.

    :param times: [L] or [L x 1] ndarray of inputs
    :return: [n_traj x L] ndarray
    """
    times_2d = times.reshape((-1,1)) if times.ndim == 1 else times
    n_times = times_2d.shape[0]
    kernel = kernel_se(times_2d,times_2d,hyp={'gain':hyp_gain,'len':hyp_len}) # [L x L]
    chol = np.linalg.cholesky(kernel + 1e-8*np.eye(n_times,n_times))          # [L x L]
    samples = chol @ np.random.randn(n_times,n_traj)                          # [L x n_traj]
    samples = samples + meas_std*np.random.randn(*samples.shape)
    return samples.T
214 |
def hbm_sampler(
    times = np.linspace(start=0.0,stop=1.0,num=100).reshape((-1,1)), # [L x 1]
    hyp_gain = 1.0,
    hyp_len = 1.0,
    meas_std = 0e-8,
    n_traj = 1
):
    """
    Hilbert Brownian motion sampling

    Factorizes the jittered squared-exponential kernel as
    K = V diag(U) V^T and draws n_traj paths as V sqrt(diag(U)) z with
    z ~ N(0, I). It then adds i.i.d. Gaussian noise of std meas_std.

    :param times: [L] or [L x 1] ndarray of inputs
    :return: [n_traj x L] ndarray
    """
    if len(times.shape) == 1: times = times.reshape((-1,1))
    L = times.shape[0]
    K = kernel_se(times,times,hyp={'gain':hyp_gain,'len':hyp_len}) # [L x L]
    K = K + 1e-8*np.eye(L,L)
    U,V = np.linalg.eigh(K,UPLO='L')
    # The kernel matrix is nearly singular; round-off can push its smallest
    # eigenvalues slightly below zero, where sqrt() would produce NaNs.
    U = np.clip(U,0.0,None)
    # V*sqrt(U) scales column j by sqrt(U[j]) == V @ np.diag(np.sqrt(U))
    traj = (V*np.sqrt(U)) @ np.random.randn(L,n_traj) # [L x n_traj]
    traj = traj + meas_std*np.random.randn(*traj.shape)
    return traj.T
233 |
def get_colors(n):
    """
    Return `n` RGBA colors taken at evenly spaced points of matplotlib's
    Set1 colormap.
    """
    cmap = plt.cm.Set1
    return list(map(cmap, np.linspace(0,1,n)))
236 |
def periodic_step(times,period,time_offset=0.0,y_min=0.0,y_max=1.0):
    """
    Square wave: y_max during the first half of each period, y_min during the
    second half.

    :param times: ndarray of time stamps
    :param period: period of the wave
    :param time_offset: shift added to times before wrapping into [0, period)
    :param y_min: low level
    :param y_max: high level
    :return: ndarray with the shape of times. The dtype matches times when
        times is floating point; integer times are promoted as needed to hold
        the levels.
    """
    times_mod = np.mod(times+time_offset,period)
    y = np.zeros_like(times)
    if not np.issubdtype(y.dtype,np.inexact):
        # An integer buffer cannot be scaled in place by float levels (numpy
        # raises a casting error), so promote it to a dtype that can hold them.
        y = y.astype(np.result_type(y.dtype,y_max-y_min,y_min))
    y[times_mod < (period/2)] = 1
    y*=(y_max-y_min)
    y+=y_min
    return y
244 |
def plot_ddpm_1d_result(
    times,x_data,step_list,x_t_list,
    plot_ancestral_sampling=True,
    plot_one_sample=False,
    lw_gt=1,lw_sample=1/2,
    ls_gt='-',ls_sample='-',
    lc_gt='b',lc_sample='k',
    ylim=(-4,+4),figsize=(6,3),title_str=None
):
    """
    Plot ground-truth 1D trajectories together with diffusion samples.

    :param times: [L x 1] ndarray
    :param x_data: [N x C x L] torch tensor, training data
    :param step_list: [M] ndarray, diffusion steps to visualize
    :param x_t_list: list, indexed by diffusion step, of [n_sample x C x L]
        torch tensors; x_t_list[0] is plotted as the generated result
    :param plot_ancestral_sampling: if True, first plot x_t for every step in
        step_list (one figure per channel)
    :param plot_one_sample: if True, plot only one random generated sample
    :param lw_gt, ls_gt, lc_gt: line width/style/color of ground truth
    :param lw_sample, ls_sample, lc_sample: line width/style/color of samples
    :param ylim: y-limits of every plot
    :param figsize: figure size of the generated-data plots
    :param title_str: (optional) title of the generated-data plots
    """
    x_data_np = x_data.detach().cpu().numpy() # [n_data x C x L]
    n_data = x_data_np.shape[0] # number of GT data
    C = x_data_np.shape[1] # number of channels
    t_flat = times.flatten()

    # Plot a sequence of the ancestral sampling procedure
    if plot_ancestral_sampling:
        # Small tick labels only for these figures; rc_context restores the
        # global rcParams afterwards instead of leaking into later plots.
        with plt.rc_context({'xtick.labelsize':6,'ytick.labelsize':6}):
            for c_idx in range(C):
                plt.figure(figsize=(15,2))
                for s_idx,t in enumerate(step_list):
                    plt.subplot(1,len(step_list),s_idx+1)
                    x_t_np = x_t_list[t].detach().cpu().numpy() # [n_sample x C x L]
                    for d_idx in range(n_data): # GT
                        plt.plot(t_flat,x_data_np[d_idx,c_idx,:],ls=ls_gt,color=lc_gt,lw=lw_gt)
                    for n_idx in range(x_t_np.shape[0]): # sampled trajectories
                        plt.plot(t_flat,x_t_np[n_idx,c_idx,:],ls=ls_sample,color=lc_sample,lw=lw_sample)
                    plt.xlim([0.0,1.0]); plt.ylim(ylim)
                    plt.xlabel('Time',fontsize=8); plt.title('Step:[%d]'%(t),fontsize=8)
                plt.tight_layout(); plt.show()

    # Plot generated data
    x_0_np = x_t_list[0].detach().cpu().numpy() # [n_sample x C x L]
    n_sample = x_0_np.shape[0]
    for c_idx in range(C):
        plt.figure(figsize=figsize)
        for d_idx in range(n_data): # GT
            plt.plot(t_flat,x_data_np[d_idx,c_idx,:],ls=ls_gt,color=lc_gt,lw=lw_gt)
        if plot_one_sample:
            i_idx = np.random.randint(low=0,high=n_sample)
            plt.plot(t_flat,x_0_np[i_idx,c_idx,:],ls=ls_sample,color=lc_sample,lw=lw_sample)
        else:
            for n_idx in range(n_sample): # sampled trajectories
                plt.plot(t_flat,x_0_np[n_idx,c_idx,:],ls=ls_sample,color=lc_sample,lw=lw_sample)
        plt.xlim([0.0,1.0]); plt.ylim(ylim)
        plt.xlabel('Time',fontsize=8)
        if title_str is None:
            plt.title('[%d] Groundtruth and Generated trajectories'%(c_idx),fontsize=10)
        else:
            plt.title('[%d] %s'%(c_idx,title_str),fontsize=10)
        plt.tight_layout(); plt.show()
302 |
303 |
def plot_ddpm_2d_result(
    x_data,step_list,x_t_list,n_plot=1,
    tfs=10
):
    """
    For each of the first `n_plot` samples, show its image at every diffusion
    step in `step_list` side by side (one figure per sample).

    :param x_data: [N x C x W x H] torch tensor, training data (only its
        channel count is used, to pick grayscale vs. color display)
    :param step_list: [M] ndarray, diffusion steps to show
    :param x_t_list: list, indexed by diffusion step, of [n_sample x C x W x H] torch tensors
    :param n_plot: number of samples to plot
    :param tfs: title font size
    """
    is_gray = x_data.shape[1] == 1
    n_cols = len(step_list)
    for sample_idx in range(n_plot):
        plt.figure(figsize=(15,2))
        for col, t in enumerate(step_list):
            img = x_t_list[t].detach().cpu().numpy()[sample_idx] # [C x W x H]
            plt.subplot(1,n_cols,col+1)
            if is_gray:
                plt.imshow(img[0,:,:], cmap='gray')
            else:
                plt.imshow(img.transpose(1,2,0))
            plt.axis('off')
            plt.title('Step:[%d]'%(t),fontsize=tfs)
        plt.tight_layout()
        plt.show()
327 |
328 |
def get_hbm_M(times,hyp_gain=1.0,hyp_len=0.1,device='cpu'):
    """
    Get a matrix M for Hilbert Brownian motion

    M = V diag(sqrt(U)), where K = V diag(U) V^T is the eigendecomposition of
    the jittered squared-exponential kernel. Hence M @ z with z ~ N(0, I) has
    covariance K.

    :param times: [L x 1] (or [L]) ndarray
    :param hyp_gain: kernel gain
    :param hyp_len: kernel length scale
    :param device: torch device of the result
    :return: [L x L] float32 torch tensor
    """
    if len(times.shape) == 1: times = times.reshape((-1,1)) # cdist needs 2-D inputs
    L = times.shape[0]
    K = kernel_se(times,times,hyp={'gain':hyp_gain,'len':hyp_len}) # [L x L]
    K = K + 1e-8*np.eye(L,L)
    U,V = np.linalg.eigh(K,UPLO='L')
    # Clip round-off negatives so sqrt() cannot produce NaNs
    U = np.clip(U,0.0,None)
    M = V*np.sqrt(U) # == V @ np.diag(np.sqrt(U)), without the O(L^3) product
    M = th.from_numpy(M).to(th.float32).to(device) # [L x L]
    return M
342 |
def get_resampling_steps(t_T, j, r,plot_steps=False,figsize=(15,4)):
    """
    Get resampling steps for repaint, inpainting method using diffusion models

    Builds the sequence of diffusion time steps visited during resampling.
    Starting at t_T, the walk descends one step at a time. At each scheduled
    step (t = 1, 1+j, 1+2j, ... below t_T-j) it climbs back up j steps
    before descending again, and does this r-1 times per scheduled step.
    The sequence ends with ..., 1, 0.

    :param t_T: maximum time steps for inpainting
    :param j: jump length
    :param r: the number of resampling
    :param plot_steps: if True, plot the sequence of steps
    :param figsize: figure size of the optional plot
    :return: list of int time steps
    """
    jumps = np.zeros(t_T+1) # remaining jump count for each time step
    for i in range(1, t_T-j, j):
        jumps[i] = r-1
    t = t_T+1
    resampling_steps = []
    while t > 1:
        t -= 1
        resampling_steps.append(t)
        if jumps[t] > 0:
            jumps[t] -= 1
            # climb back up j steps; the walk then descends through t again
            for _ in range(j):
                t += 1
                resampling_steps.append(t)
    resampling_steps.append(0)

    # (optional) plot
    if plot_steps:
        plt.figure(figsize=figsize)
        plt.plot(resampling_steps,'-',color='k',lw=1)
        plt.xlabel('Number of Transitions')
        plt.ylabel('Diffusion time step')
        plt.show()

    # Return
    return resampling_steps
--------------------------------------------------------------------------------
/img/unet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjchoi86/yet-another-pytorch-tutorial-v2/1b2fdabfc11586fcabc9d3aa1123c90a026f1c1f/img/unet.jpg
--------------------------------------------------------------------------------