├── .gitignore ├── LICENSE ├── README.md ├── code ├── alexnet.ipynb ├── attention.ipynb ├── cnn.ipynb ├── dataset.py ├── ddpm_1d_example.ipynb ├── ddpm_2d_example.ipynb ├── ddpm_cfg_1d_example.ipynb ├── diffusion.py ├── diffusion_constants.ipynb ├── diffusion_resblock.ipynb ├── diffusion_unet.ipynb ├── diffusion_unet_legacy.ipynb ├── googlenet.ipynb ├── gp.ipynb ├── hdm_1d_example.ipynb ├── hdm_constant.ipynb ├── mdn.py ├── mdn_reg.ipynb ├── mlp.ipynb ├── module.py ├── repaint_1d_example.ipynb ├── repaint_constant.ipynb ├── resnet.ipynb ├── updownsample.ipynb ├── util.py └── vgg.ipynb └── img └── unet.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # Mac 163 | .DS_Store 164 | 165 | # Data folder 166 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sungjoon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Yet Another Pytorch Tutorial v2 2 | 3 | Minimal implementations of 4 | - [MLP](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/mlp.ipynb): MNIST classification using multi-layer perceptron (MLP) 5 | - [CNN](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/cnn.ipynb): MNIST classification using convolutional neural networks (CNN) 6 | - [AlexNet](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/alexnet.ipynb): MNIST classification using AlexNet 7 | - [VGG](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/vgg.ipynb): MNIST classification using VGG 8 | - [GoogLeNet](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/googlenet.ipynb): MNIST classification using GoogLeNet 9 | - [MDN](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/mdn_reg.ipynb): Regression using Mixture Density Networks 10 | - [ResNet](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/resnet.ipynb): MNIST classification using ResNet 11 | - [Attention](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/attention.ipynb): Attention block with legacy QKV attention mechanism 12 | - [ResBlock](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/diffusion_resblock.ipynb): Residual block for diffusion models 13 | - [Diffusion](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/diffusion_constants.ipynb): Diffusion constants 14 | - [DDPM 1D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/ddpm_1d_example.ipynb): Denoising Diffusion Probabilistic Model (DDPM) example on generating trajectories 15 | - [DDPM 2D 
Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/ddpm_2d_example.ipynb): Denoising Diffusion Probabilistic Model (DDPM) example on generating images 16 | - [DDPM-CFG 1D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/ddpm_cfg_1d_example.ipynb): Conditional Generation using Classifier-Free Guidance (CFG) 17 | - [Repaint 1D Example](https://github.com/sjchoi86/yet-another-pytorch-tutorial-v2/blob/main/code/repaint_1d_example.ipynb): Diffusion-based inpainting example on 1D data 18 | 19 | Contact 20 | - sungjoon dash choi at korea dot ac dot kr 21 | -------------------------------------------------------------------------------- /code/attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "71051e5d", 6 | "metadata": {}, 7 | "source": [ 8 | "### Attention block" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "7ae3d6f4", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "PyTorch version:[2.0.1].\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "import numpy as np\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "import torch as th\n", 29 | "from module import AttentionBlock\n", 30 | "from util import get_torch_size_string\n", 31 | "np.set_printoptions(precision=3)\n", 32 | "th.set_printoptions(precision=3)\n", 33 | "%matplotlib inline\n", 34 | "%config InlineBackend.figure_format='retina'\n", 35 | "print (\"PyTorch version:[%s].\"%(th.__version__))" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "bfb7e801", 41 | "metadata": {}, 42 | "source": [ 43 | "### Let's see how `AttentionBlock` works\n", 44 | "- First, we assume that an input tensor has a shape of [B x C x W x H].\n", 45 | "- This can be thought of as having a total of WH tokens with each token having C dimensions. 
\n", 46 | "- The MHA operates by initially partitioning the channels, executing the qkv attention process, and then merging the results. \n", 47 | "- Note that the number of channels should be divisible by the number of heads." 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "dca0e4fa", 53 | "metadata": {}, 54 | "source": [ 55 | "### `dims=2`\n", 56 | "#### `x` has a shape of `[B x C x W x H]`" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "id": "8c6e06cc", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "input shape:[16x128x28x28] output shape:[16x128x28x28]\n", 70 | "[ x]:[ 16x128x28x28]\n", 71 | "[ x_rsh]:[ 16x128x784]\n", 72 | "[ x_nzd]:[ 16x128x784]\n", 73 | "[ qkv]:[ 16x384x784]\n", 74 | "[ h_att]:[ 16x128x784]\n", 75 | "[ h_proj]:[ 16x128x784]\n", 76 | "[ out]:[ 16x128x28x28]\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "layer = AttentionBlock(n_channels=128,n_heads=4,n_groups=32)\n", 82 | "x = th.randn(16,128,28,28)\n", 83 | "out,intermediate_output_dict = layer(x)\n", 84 | "print (\"input shape:[%s] output shape:[%s]\"%\n", 85 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 86 | "# Print intermediate values\n", 87 | "for key,value in intermediate_output_dict.items():\n", 88 | " print (\"[%10s]:[%15s]\"%(key,get_torch_size_string(value)))" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "a91d9cec", 94 | "metadata": {}, 95 | "source": [ 96 | "### `dims=1`\n", 97 | "#### `x` has a shape of `[B x C x L]`" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 3, 103 | "id": "97d1bbdb", 104 | "metadata": { 105 | "scrolled": true 106 | }, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "input shape:[16x4x100] output shape:[16x4x100]\n", 113 | "[ x]:[ 16x4x100]\n", 114 | "[ x_rsh]:[ 16x4x100]\n", 115 | "[ x_nzd]:[ 16x4x100]\n", 116 | "[ qkv]:[ 
16x12x100]\n", 117 | "[ h_att]:[ 16x4x100]\n", 118 | "[ h_proj]:[ 16x4x100]\n", 119 | "[ out]:[ 16x4x100]\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "layer = AttentionBlock(n_channels=4,n_heads=2,n_groups=1)\n", 125 | "x = th.randn(16,4,100)\n", 126 | "out,intermediate_output_dict = layer(x)\n", 127 | "print (\"input shape:[%s] output shape:[%s]\"%\n", 128 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 129 | "# Print intermediate values\n", 130 | "for key,value in intermediate_output_dict.items():\n", 131 | " print (\"[%10s]:[%15s]\"%(key,get_torch_size_string(value)))" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "8e60d073", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "Python 3 (ipykernel)", 146 | "language": "python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.9.16" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /code/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch as th 4 | from torchvision import datasets,transforms 5 | from util import gp_sampler,periodic_step,get_torch_size_string 6 | 7 | def mnist(root_path='./data/',batch_size=128): 8 | """ 9 | MNIST 10 | """ 11 | mnist_train = datasets.MNIST(root=root_path,train=True,transform=transforms.ToTensor(),download=True) 12 | mnist_test = datasets.MNIST(root=root_path,train=False,transform=transforms.ToTensor(),download=True) 13 | train_iter 
= th.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=1) 14 | test_iter = th.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=True,num_workers=1) 15 | # Data 16 | train_data,train_label = mnist_train.data,mnist_train.targets 17 | test_data,test_label = mnist_test.data,mnist_test.targets 18 | return train_iter,test_iter,train_data,train_label,test_data,test_label 19 | 20 | def get_1d_training_data( 21 | traj_type = 'step', # {'step','gp'} 22 | n_traj = 10, 23 | L = 100, 24 | device = 'cpu', 25 | seed = 1, 26 | plot_data = True, 27 | figsize = (6,2), 28 | ls = '-', 29 | lc = 'k', 30 | lw = 1, 31 | verbose = True, 32 | ): 33 | """ 34 | 1-D training data 35 | """ 36 | if seed is not None: 37 | np.random.seed(seed=seed) 38 | times = np.linspace(start=0.0,stop=1.0,num=L).reshape((-1,1)) # [L x 1] 39 | if traj_type == 'gp': 40 | traj = th.from_numpy( 41 | gp_sampler( 42 | times = times, 43 | hyp_gain = 2.0, 44 | hyp_len = 0.2, 45 | meas_std = 1e-8, 46 | n_traj = n_traj 47 | ) 48 | ).to(th.float32).to(device) # [n_traj x L] 49 | elif traj_type == 'gp2': 50 | traj_np = np.zeros((n_traj,L)) 51 | for i_idx in range(n_traj): 52 | traj_np[i_idx,:] = gp_sampler( 53 | times = times, 54 | hyp_gain = 2.0, 55 | hyp_len = np.random.uniform(1e-8,1.0), 56 | meas_std = 1e-8, 57 | n_traj = 1 58 | ).reshape(-1) 59 | traj = th.from_numpy( 60 | traj_np 61 | ).to(th.float32).to(device) # [n_traj x L] 62 | elif traj_type == 'step': 63 | traj_np = np.zeros((n_traj,L)) 64 | for i_idx in range(n_traj): 65 | period = np.random.uniform(low=0.38,high=0.42) 66 | time_offset = np.random.uniform(low=-0.02,high=0.02) 67 | y_min = np.random.uniform(low=-3.2,high=-2.8) 68 | y_max = np.random.uniform(low=2.8,high=3.2) 69 | traj_np[i_idx,:] = periodic_step( 70 | times = times, 71 | period = period, 72 | time_offset = time_offset, 73 | y_min = y_min, 74 | y_max = y_max 75 | ).reshape(-1) 76 | traj = th.from_numpy( 77 | traj_np 78 | ).to(th.float32).to(device) # 
[n_traj x L] 79 | elif traj_type == 'step2': 80 | traj_np = np.zeros((n_traj,L)) 81 | for i_idx in range(n_traj): # for each trajectory 82 | # First, sample value and duration 83 | rate = 5 84 | val = np.random.uniform(low=-3.0,high=3.0) 85 | dur_tick = int(L*np.random.exponential(scale=1/rate)) 86 | dim_dur = 0.1 # minimum duration in sec 87 | dur_tick = max(dur_tick,int(dim_dur*L)) 88 | 89 | tick_fr = 0 90 | tick_to = tick_fr+dur_tick 91 | while True: 92 | # Append 93 | traj_np[i_idx,tick_fr:min(L,tick_to)] = val 94 | 95 | # Termination condition 96 | if tick_to >= L: break 97 | 98 | # sample value and duration 99 | val = np.random.uniform(low=-3.0,high=3.0) 100 | dur_tick = int(L*np.random.exponential(scale=1/rate)) 101 | dur_tick = max(dur_tick,int(dim_dur*L)) 102 | 103 | # Update tick 104 | tick_fr = tick_to 105 | tick_to = tick_fr+dur_tick 106 | traj = th.from_numpy( 107 | traj_np 108 | ).to(th.float32).to(device) # [n_traj x L] 109 | elif traj_type == 'triangle': 110 | traj_np = np.zeros((n_traj,L)) 111 | for i_idx in range(n_traj): 112 | period = 0.2 113 | time_offset = np.random.uniform(low=-0.02,high=0.02) 114 | y_min = np.random.uniform(low=-3.2,high=-2.8) 115 | y_max = np.random.uniform(low=2.8,high=3.2) 116 | times_mod = np.mod(times+time_offset,period)/period 117 | y = (y_max - y_min) * times_mod + y_min 118 | traj_np[i_idx,:] = y.reshape(-1) 119 | traj = th.from_numpy( 120 | traj_np 121 | ).to(th.float32).to(device) # [n_traj x L] 122 | else: 123 | print ("Unknown traj_type:[%s]"%(traj_type)) 124 | # Plot 125 | if plot_data: 126 | plt.figure(figsize=figsize) 127 | for i_idx in range(n_traj): 128 | plt.plot(times,traj[i_idx,:].cpu().numpy(),ls=ls,color=lc,lw=lw) 129 | plt.xlim([0.0,1.0]) 130 | plt.ylim([-4,+4]) 131 | plt.xlabel('Time',fontsize=10) 132 | plt.title('Trajectory type:[%s]'%(traj_type),fontsize=10) 133 | plt.show() 134 | # Print 135 | x_0 = traj[:,None,:] # [N x C x L] 136 | if verbose: 137 | print ("x_0:[%s]"%(get_torch_size_string(x_0))) 
138 | # Out 139 | return times,x_0 140 | 141 | def get_mdn_data( 142 | n_train = 1000, 143 | x_min = 0.0, 144 | x_max = 1.0, 145 | y_min = -1.0, 146 | y_max = 1.0, 147 | freq = 1.0, 148 | noise_rate = 1.0, 149 | seed = 0, 150 | FLIP_AUGMENT = True, 151 | ): 152 | np.random.seed(seed=seed) 153 | 154 | if FLIP_AUGMENT: 155 | n_half = n_train // 2 156 | x_train_a = x_min + (x_max-x_min)*np.random.rand(n_half,1) # [n_half x 1] 157 | x_rate = (x_train_a-x_min)/(x_max-x_min) # [n_half x 1] 158 | sin_temp = y_min + (y_max-y_min)*np.sin(2*np.pi*freq*x_rate) 159 | cos_temp = y_min + (y_max-y_min)*np.cos(2*np.pi*freq*x_rate) 160 | y_train_a = np.concatenate( 161 | (sin_temp+1*(y_max-y_min)*x_rate, 162 | cos_temp+1*(y_max-y_min)*x_rate),axis=1) # [n_half x 2] 163 | x_train_b = x_min + (x_max-x_min)*np.random.rand(n_half,1) # [n_half x 1] 164 | x_rate = (x_train_b-x_min)/(x_max-x_min) # [n_half x 1] 165 | sin_temp = y_min + (y_max-y_min)*np.sin(2*np.pi*freq*x_rate) 166 | cos_temp = y_min + (y_max-y_min)*np.cos(2*np.pi*freq*x_rate) 167 | y_train_b = -np.concatenate( 168 | (sin_temp+1*(y_max-y_min)*x_rate, 169 | cos_temp+1*(y_max-y_min)*x_rate),axis=1) # [n_half x 2] 170 | # Concatenate 171 | x_train = np.concatenate((x_train_a,x_train_b),axis=0) # [n_train x 1] 172 | y_train = np.concatenate((y_train_a,y_train_b),axis=0) # [n_train x 2] 173 | else: 174 | x_train = x_min + (x_max-x_min)*np.random.rand(n_train,1) # [n_train x 1] 175 | x_rate = (x_train-x_min)/(x_max-x_min) # [n_train x 1] 176 | sin_temp = y_min + (y_max-y_min)*np.sin(2*np.pi*freq*x_rate) 177 | cos_temp = y_min + (y_max-y_min)*np.cos(2*np.pi*freq*x_rate) 178 | y_train = np.concatenate( 179 | (sin_temp+1*(y_max-y_min)*x_rate, 180 | cos_temp+1*(y_max-y_min)*x_rate),axis=1) # [n_train x 2] 181 | 182 | # Add noise 183 | x_rate = (x_train-x_min)/(x_max-x_min) # [n_train x 1] 184 | noise = noise_rate * (y_max-y_min) * (2*np.random.rand(n_train,2)-1) * ((x_rate)**2) # [n_train x 2] 185 | y_train = y_train + noise # 
[n_train x 2] 186 | return x_train,y_train -------------------------------------------------------------------------------- /code/diffusion.py: -------------------------------------------------------------------------------- 1 | import math,random 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch as th 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | from module import ( 8 | conv_nd, 9 | ResBlock, 10 | AttentionBlock, 11 | TimestepEmbedSequential, 12 | ) 13 | 14 | def get_named_beta_schedule( 15 | schedule_name, 16 | num_diffusion_timesteps, 17 | scale_betas=1.0, 18 | np_type=np.float64 19 | ): 20 | """ 21 | Get a pre-defined beta schedule for the given name. 22 | 23 | The beta schedule library consists of beta schedules which remain similar 24 | in the limit of num_diffusion_timesteps. 25 | Beta schedules may be added, but should not be removed or changed once 26 | they are committed to maintain backwards compatibility. 27 | """ 28 | if schedule_name == "linear": 29 | # Linear schedule from Ho et al, extended to work for any number of 30 | # diffusion steps. 31 | scale = scale_betas * 1000 / num_diffusion_timesteps 32 | beta_start = scale * 0.0001 33 | beta_end = scale * 0.02 34 | return np.linspace( 35 | beta_start, beta_end, num_diffusion_timesteps, dtype=np_type 36 | ) 37 | elif schedule_name == "cosine": 38 | return betas_for_alpha_bar( 39 | num_diffusion_timesteps, 40 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 41 | ) 42 | else: 43 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 44 | 45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 46 | """ 47 | Create a beta schedule that discretizes the given alpha_t_bar function, 48 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 49 | 50 | :param num_diffusion_timesteps: the number of betas to produce. 
51 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 52 | produces the cumulative product of (1-beta) up to that 53 | part of the diffusion process. 54 | :param max_beta: the maximum beta to use; use values lower than 1 to 55 | prevent singularities. 56 | """ 57 | betas = [] 58 | for i in range(num_diffusion_timesteps): 59 | t1 = i / num_diffusion_timesteps 60 | t2 = (i + 1) / num_diffusion_timesteps 61 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 62 | return np.array(betas) 63 | 64 | def get_ddpm_constants( 65 | schedule_name = 'cosine', 66 | T = 1000, 67 | np_type = np.float64 68 | ): 69 | timesteps = np.linspace(start=1,stop=T,num=T) 70 | betas = get_named_beta_schedule( 71 | schedule_name = schedule_name, 72 | num_diffusion_timesteps = T, 73 | scale_betas = 1.0, 74 | ).astype(np_type) # [1,000] 75 | alphas = 1.0 - betas 76 | alphas_bar = np.cumprod(alphas, axis=0) # cummulative product 77 | alphas_bar_prev = np.append(1.0,alphas_bar[:-1]) 78 | sqrt_recip_alphas = np.sqrt(1.0/alphas) 79 | sqrt_alphas_bar = np.sqrt(alphas_bar) 80 | sqrt_one_minus_alphas_bar = np.sqrt(1.0-alphas_bar) 81 | posterior_variance = betas*(1.0-alphas_bar_prev)/(1.0-alphas_bar) 82 | posterior_variance = posterior_variance.astype(np_type) 83 | 84 | # Append 85 | dc = {} 86 | dc['schedule_name'] = schedule_name 87 | dc['T'] = T 88 | dc['timesteps'] = timesteps 89 | dc['betas'] = betas 90 | dc['alphas'] = alphas 91 | dc['alphas_bar'] = alphas_bar 92 | dc['alphas_bar_prev'] = alphas_bar_prev 93 | dc['sqrt_recip_alphas'] = sqrt_recip_alphas 94 | dc['sqrt_alphas_bar'] = sqrt_alphas_bar 95 | dc['sqrt_one_minus_alphas_bar'] = sqrt_one_minus_alphas_bar 96 | dc['posterior_variance'] = posterior_variance 97 | 98 | return dc 99 | 100 | def plot_ddpm_constants(dc): 101 | """ 102 | Plot DDPM constants 103 | """ 104 | plt.figure(figsize=(10,3)) 105 | cs = [plt.cm.gist_rainbow(x) for x in np.linspace(0,1,8)] 106 | lw = 2 107 | plt.subplot(1,2,1) 108 | 
plt.plot(dc['timesteps'],dc['betas'], 109 | color=cs[0],label=r'$\beta_t$',lw=lw) 110 | plt.plot(dc['timesteps'],dc['alphas'], 111 | color=cs[1],label=r'$\alpha_t$',lw=lw) 112 | plt.plot(dc['timesteps'],dc['alphas_bar'], 113 | color=cs[2],label=r'$\bar{\alpha}_t$',lw=lw) 114 | plt.plot(dc['timesteps'],dc['sqrt_alphas_bar'], 115 | color=cs[5],label=r'$\sqrt{\bar{\alpha}_t}$',lw=lw) 116 | 117 | plt.plot(dc['timesteps'],dc['sqrt_one_minus_alphas_bar'], 118 | color=cs[6],label=r'$\sqrt{1-\bar{\alpha}_t}$',lw=lw) 119 | 120 | 121 | plt.plot(dc['timesteps'],dc['posterior_variance'],'--', 122 | color='k',label=r'$ Var[x_{t-1}|x_t,x_0] $',lw=lw) 123 | 124 | plt.xlabel('Diffusion steps',fontsize=8) 125 | plt.legend(fontsize=10,loc='center left',bbox_to_anchor=(1,0.5)) 126 | plt.grid(lw=0.5) 127 | plt.title('DDPM Constants',fontsize=10) 128 | plt.subplot(1,2,2) 129 | plt.plot(dc['timesteps'],dc['betas'],color=cs[0],label=r'$\beta_t$',lw=lw) 130 | plt.plot(dc['timesteps'],dc['posterior_variance'],'--', 131 | color='k',label=r'$ Var[x_{t-1}|x_t,x_0] $',lw=lw) 132 | plt.xlabel('Diffusion steps',fontsize=8) 133 | plt.legend(fontsize=10,loc='center left',bbox_to_anchor=(1,0.5)) 134 | plt.grid(lw=0.5) 135 | plt.title('DDPM Constants',fontsize=10) 136 | plt.subplots_adjust(wspace=0.7) 137 | plt.show() 138 | 139 | def timestep_embedding(timesteps, dim, max_period=10000): 140 | """ 141 | Create sinusoidal timestep embeddings. 142 | 143 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an [N x dim] Tensor of positional embeddings. 
148 | """ 149 | half = dim // 2 150 | freqs = th.exp( 151 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 152 | ).to(device=timesteps.device) 153 | args = timesteps[:, None].float() * freqs[None] 154 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 155 | if dim % 2: 156 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 157 | return embedding 158 | 159 | def forward_sample(x0_batch,t_batch,dc,M=None,noise_scale=1.0): 160 | """ 161 | Forward diffusion sampling 162 | :param x0_batch: [B x C x ...] 163 | :param t_batch: [B] 164 | :param dc: dictionary of diffusion constants 165 | :param M: a matrix of [L x L] for [B x C x L] data 166 | :return: xt_batch of [B x C x ...] and noise of [B x C x ...] 167 | """ 168 | # Gather diffusion constants with matching dimension 169 | out_shape = (t_batch.shape[0],) + ((1,)*(len(x0_batch.shape)-1)) 170 | device = t_batch.device 171 | sqrt_alphas_bar_t = th.gather( 172 | input = th.from_numpy(dc['sqrt_alphas_bar']).to(device), # [T] 173 | dim = -1, 174 | index = t_batch 175 | ).reshape(out_shape) # [B x 1 x 1 x 1] if (rank==4) and [B x 1 x 1] if (rank==3) 176 | sqrt_one_minus_alphas_bar = th.gather( 177 | input = th.from_numpy(dc['sqrt_one_minus_alphas_bar']).to(device), # [T] 178 | dim = -1, 179 | index = t_batch 180 | ).reshape(out_shape) # [B x 1 x 1 x 1] if (rank==4) and [B x 1 x 1] if (rank==3) 181 | 182 | # Forward sample 183 | noise = th.randn_like(input=x0_batch) # [B x C x ...] 
184 | 185 | # (optional) correlated noise 186 | if M is not None: 187 | B = x0_batch.shape[0] 188 | C = x0_batch.shape[1] 189 | L = x0_batch.shape[2] 190 | if isinstance(M, list): # if M is a list, 191 | M_use = random.choice(M) 192 | else: 193 | M_use = M # [L x L] 194 | M_exp = M_use[None,None,:,:].expand(B,C,L,L) # [L x L] => [B x C x L x L] 195 | noise_exp = noise[:,:,:,None] # [B x C x L x 1] 196 | noise_exp = M_exp @ noise_exp # [B x C x L x 1] 197 | noise = noise_exp.squeeze(dim=3) # [B x C x L] 198 | 199 | # Jump diffusion 200 | xt_batch = sqrt_alphas_bar_t*x0_batch + \ 201 | sqrt_one_minus_alphas_bar*noise_scale*noise # [B x C x ...] 202 | return xt_batch,noise 203 | 204 | class DiffusionUNet(nn.Module): 205 | """ 206 | U-Net for diffusion models 207 | """ 208 | def __init__( 209 | self, 210 | name = 'unet', 211 | dims = 1, # spatial dimension, if dims==1, [B x C x L], if dims==2, [B x C x W x H] 212 | n_in_channels = 128, # input channels 213 | n_model_channels = 64, # base channel size 214 | n_emb_dim = 128, # time embedding size 215 | n_cond_dim = 0, # conditioning vector size (default is 0 indicating an unconditional model) 216 | n_enc_blocks = 2, # number of encoder blocks 217 | n_dec_blocks = 2, # number of decoder blocks 218 | n_groups = 16, # group norm paramter 219 | n_heads = 4, # number of heads 220 | actv = nn.SiLU(), 221 | kernel_size = 3, # kernel size 222 | padding = 1, 223 | use_resblock = True, 224 | use_attention = True, 225 | skip_connection = False, 226 | use_scale_shift_norm = True, # positional embedding handling 227 | device = 'cpu', 228 | ): 229 | super().__init__() 230 | self.name = name 231 | self.dims = dims 232 | self.n_in_channels = n_in_channels 233 | self.n_model_channels = n_model_channels 234 | self.n_emb_dim = n_emb_dim 235 | self.n_cond_dim = n_cond_dim 236 | self.n_enc_blocks = n_enc_blocks 237 | self.n_dec_blocks = n_dec_blocks 238 | self.n_groups = n_groups 239 | self.n_heads = n_heads 240 | self.actv = actv 241 | 
self.kernel_size = kernel_size 242 | self.padding = padding 243 | self.use_resblock = use_resblock 244 | self.use_attention = use_attention 245 | self.skip_connection = skip_connection 246 | self.use_scale_shift_norm = use_scale_shift_norm 247 | self.device = device 248 | 249 | # Time embedding 250 | self.time_embed = nn.Sequential( 251 | nn.Linear(in_features=self.n_model_channels,out_features=self.n_emb_dim), 252 | nn.SiLU(), 253 | nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim), 254 | ).to(self.device) 255 | 256 | # Conditional embedding 257 | if self.n_cond_dim > 0: 258 | self.cond_embed = nn.Sequential( 259 | nn.Linear(in_features=self.n_cond_dim,out_features=self.n_emb_dim), 260 | nn.SiLU(), 261 | nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim), 262 | ).to(self.device) 263 | 264 | # Lifting 265 | self.lift = conv_nd( 266 | dims = self.dims, 267 | in_channels = self.n_in_channels, 268 | out_channels = self.n_model_channels, 269 | kernel_size = 1, 270 | ).to(device) 271 | 272 | # Projection 273 | self.proj = conv_nd( 274 | dims = self.dims, 275 | in_channels = self.n_model_channels, 276 | out_channels = self.n_in_channels, 277 | kernel_size = 1, 278 | ).to(device) 279 | 280 | # Declare U-net 281 | # Encoder 282 | self.enc_layers = [] 283 | for e_idx in range(self.n_enc_blocks): 284 | # Residual block in encoder 285 | if self.use_resblock: 286 | self.enc_layers.append( 287 | ResBlock( 288 | name = 'res', 289 | n_channels = self.n_model_channels, 290 | n_emb_channels = self.n_emb_dim, 291 | n_out_channels = self.n_model_channels, 292 | n_groups = self.n_groups, 293 | dims = self.dims, 294 | actv = self.actv, 295 | kernel_size = self.kernel_size, 296 | padding = self.padding, 297 | upsample = False, 298 | downsample = False, 299 | use_scale_shift_norm = self.use_scale_shift_norm, 300 | ).to(device) 301 | ) 302 | # Attention block in encoder 303 | if self.use_attention: 304 | self.enc_layers.append( 305 | AttentionBlock( 306 | 
def forward(self,x,timesteps,c=None):
    """
    Run the U-Net: embed the timestep (and the optional condition), lift the input,
    apply the encoder and decoder stacks, and project back to the input channels.

    :param x: [B x n_in_channels x ...]
    :param timesteps: [B]
    :param c: conditioning input fed to `self.cond_embed`; required when
        `self.n_cond_dim > 0` and ignored otherwise
        (presumably [B x n_cond_dim] -- TODO confirm against callers)
    :return: tuple `(out, intermediate_output_dict)`; `out` is [B x n_in_channels x ...],
        same shape as x, and the dict maps layer keys to the tensors recorded on the way
    :raises ValueError: if the model is conditional (`n_cond_dim > 0`) but `c` is None
    """
    intermediate_output_dict = {'x': x}

    # Time embedding
    emb = self.time_embed(
        timestep_embedding(timesteps,self.n_model_channels)
    ) # [B x n_emb_dim]

    # Conditional embedding (added onto the time embedding)
    if self.n_cond_dim > 0:
        if c is None:
            # Fail early with a readable message instead of an opaque error inside nn.Linear
            raise ValueError(
                "Conditional model (n_cond_dim=%d) requires the conditioning input `c`."
                %(self.n_cond_dim)
            )
        emb = emb + self.cond_embed(c)

    # Lift input
    h = self.lift(x) # [B x n_model_channels x ...]
    intermediate_output_dict['x_lifted'] = h

    # Encoder
    # When both residual and attention blocks are used, the layers alternate res/att and only
    # the odd-indexed (attention) output closes an encoder block; in every other configuration
    # each layer output is kept.
    keep_odd_only = self.use_resblock and self.use_attention
    self.h_enc_list = []
    for m_idx,module in enumerate(self.enc_net):
        h = module(h,emb)
        if isinstance(h,tuple): h = h[0] # in case of having tuple
        # Record the layer output
        module_name = module[0].name
        intermediate_output_dict['h_enc_%s_%02d'%(module_name,m_idx)] = h
        # Keep the encoder block output
        if (not keep_odd_only) or (m_idx%2) == 1:
            self.h_enc_list.append(h)

    # Stack encoder outputs along the channel axis
    if (not self.h_enc_list) or (self.use_attention and not self.use_resblock):
        # Attention-only models feed just the last activation to the decoder; an empty encoder
        # (no layers at all) falls back to the lifted input instead of leaving the name unbound
        h_enc_stack = h
    elif len(self.h_enc_list) == 1:
        h_enc_stack = self.h_enc_list[0]
    else:
        h_enc_stack = th.cat(self.h_enc_list,dim=1)
    intermediate_output_dict['h_enc_stack'] = h_enc_stack

    # Decoder
    h = h_enc_stack
    for m_idx,module in enumerate(self.dec_net):
        h = module(h,emb) # [B x n_model_channels x ...]
        if isinstance(h,tuple): h = h[0] # in case of having tuple
        # Record the layer output
        module_name = module[0].name
        intermediate_output_dict['h_dec_%s_%02d'%(module_name,m_idx)] = h

    # Projection (with the optional final skip connection)
    out = self.proj(h) # [B x n_in_channels x ...]
    if self.skip_connection:
        out = out + x

    intermediate_output_dict['out'] = out # [B x n_in_channels x ...]

    return out,intermediate_output_dict
self.use_attention = use_attention 474 | self.skip_connection = skip_connection 475 | self.use_scale_shift_norm = use_scale_shift_norm 476 | self.chnnel_multiples = chnnel_multiples 477 | self.updown_rates = updown_rates 478 | self.device = device 479 | 480 | # Time embedding 481 | self.time_embed = nn.Sequential( 482 | nn.Linear(in_features=self.n_base_channels,out_features=self.n_emb_dim), 483 | nn.SiLU(), 484 | nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim), 485 | ).to(self.device) 486 | 487 | # Conditional embedding 488 | if self.n_cond_dim > 0: 489 | self.cond_embed = nn.Sequential( 490 | nn.Linear(in_features=self.n_cond_dim,out_features=self.n_emb_dim), 491 | nn.SiLU(), 492 | nn.Linear(in_features=self.n_emb_dim,out_features=self.n_emb_dim), 493 | ).to(self.device) 494 | 495 | # Lifting (1x1 conv) 496 | self.lift = conv_nd( 497 | dims = self.dims, 498 | in_channels = self.n_in_channels, 499 | out_channels = self.n_base_channels, 500 | kernel_size = 1, 501 | ).to(device) 502 | 503 | # Encoder 504 | self.enc_layers = [] 505 | n_channels2cat = [] # channel size to concat for decoder (note that we should use .pop() ) 506 | for e_idx in range(self.n_enc_blocks): # for each encoder block 507 | if e_idx == 0: 508 | in_channel = self.n_base_channels 509 | out_channel = self.n_base_channels*self.chnnel_multiples[e_idx] 510 | else: 511 | in_channel = self.n_base_channels*self.chnnel_multiples[e_idx-1] 512 | out_channel = self.n_base_channels*self.chnnel_multiples[e_idx] 513 | n_channels2cat.append(out_channel) # append out_channel 514 | updown_rate = updown_rates[e_idx] 515 | 516 | # Residual block in encoder 517 | self.enc_layers.append( 518 | ResBlock( 519 | name = 'res', 520 | n_channels = in_channel, 521 | n_emb_channels = self.n_emb_dim, 522 | n_out_channels = out_channel, 523 | n_groups = self.n_groups, 524 | dims = self.dims, 525 | actv = self.actv, 526 | kernel_size = self.kernel_size, 527 | padding = self.padding, 528 | downsample = 
updown_rate != 1, 529 | down_rate = updown_rate, 530 | use_scale_shift_norm = self.use_scale_shift_norm, 531 | ).to(device) 532 | ) 533 | # Attention block in encoder 534 | if self.use_attention: 535 | self.enc_layers.append( 536 | AttentionBlock( 537 | name = 'att', 538 | n_channels = out_channel, 539 | n_heads = self.n_heads, 540 | n_groups = self.n_groups, 541 | ).to(device) 542 | ) 543 | 544 | # Mid 545 | self.mid = conv_nd( 546 | dims = self.dims, 547 | in_channels = self.n_base_channels*self.chnnel_multiples[-1], 548 | out_channels = self.n_base_channels*self.chnnel_multiples[-1], 549 | kernel_size = 1, 550 | ).to(device) 551 | 552 | # Decoder 553 | self.dec_layers = [] 554 | for d_idx in range(self.n_dec_blocks): 555 | # Residual block in decoder 556 | if d_idx == 0: # first decoder 557 | in_channel = self.chnnel_multiples[::-1][d_idx]*self.n_base_channels + n_channels2cat.pop() 558 | out_channel = self.chnnel_multiples[::-1][d_idx]*self.n_base_channels 559 | else: 560 | in_channel = self.chnnel_multiples[::-1][d_idx-1]*self.n_base_channels + n_channels2cat.pop() 561 | out_channel = self.chnnel_multiples[::-1][d_idx]*self.n_base_channels 562 | 563 | updown_rate = updown_rates[::-1][d_idx] 564 | 565 | self.dec_layers.append( 566 | ResBlock( 567 | name = 'res', 568 | n_channels = in_channel, 569 | n_emb_channels = self.n_emb_dim, 570 | n_out_channels = out_channel, 571 | n_groups = self.n_groups, 572 | dims = self.dims, 573 | actv = self.actv, 574 | kernel_size = self.kernel_size, 575 | padding = self.padding, 576 | upsample = updown_rate != 1, 577 | up_rate = updown_rate, 578 | use_scale_shift_norm = self.use_scale_shift_norm, 579 | ).to(device) 580 | ) 581 | # Attention block in decoder 582 | if self.use_attention: 583 | self.dec_layers.append( 584 | AttentionBlock( 585 | name = 'att', 586 | n_channels = out_channel, 587 | n_heads = self.n_heads, 588 | n_groups = self.n_groups, 589 | ).to(device) 590 | ) 591 | 592 | # Projection 593 | self.proj = conv_nd( 
594 | dims = self.dims, 595 | in_channels = (1+self.chnnel_multiples[0])*self.n_base_channels, 596 | out_channels = self.n_in_channels, 597 | kernel_size = 1, 598 | ).to(device) 599 | 600 | # Define U-net 601 | self.enc_net = nn.Sequential() 602 | for l_idx,layer in enumerate(self.enc_layers): 603 | self.enc_net.add_module( 604 | name = 'enc_%02d'%(l_idx), 605 | module = TimestepEmbedSequential(layer) 606 | ) 607 | self.dec_net = nn.Sequential() 608 | for l_idx,layer in enumerate(self.dec_layers): 609 | self.dec_net.add_module( 610 | name = 'dec_%02d'%(l_idx), 611 | module = TimestepEmbedSequential(layer) 612 | ) 613 | 614 | def forward(self,x,timesteps,c=None): 615 | """ 616 | :param x: [B x n_in_channels x ...] 617 | :param timesteps: [B] 618 | :param c: 619 | :return: [B x n_in_channels x ...], same shape as x 620 | """ 621 | intermediate_output_dict = {} 622 | intermediate_output_dict['x'] = x 623 | 624 | # time embedding 625 | emb = self.time_embed( 626 | timestep_embedding(timesteps,self.n_base_channels) 627 | ) # [B x n_emb_dim] 628 | 629 | # conditional embedding 630 | if self.n_cond_dim > 0: 631 | cond = self.cond_embed(c) 632 | emb = emb + cond # [B x n_base_channels] 633 | 634 | # Lift input 635 | h = self.lift(x) # [B x n_base_channels x ...] 
636 | if isinstance(h,tuple): h = h[0] # in case of having tuple 637 | intermediate_output_dict['x_lifted'] = h 638 | 639 | # Encoder 640 | self.h_enc_list = [h] # start with lifted input 641 | for m_idx,module in enumerate(self.enc_net): 642 | h = module(h,emb) 643 | if isinstance(h,tuple): h = h[0] # in case of having tuple 644 | # Append 645 | module_name = module[0].name 646 | intermediate_output_dict['h_enc_%s_%02d'%(module_name,m_idx)] = h 647 | # Append encoder output 648 | if self.use_attention: 649 | if (m_idx%2) == 1: 650 | self.h_enc_list.append(h) 651 | else: 652 | self.h_enc_list.append(h) 653 | 654 | # Mid 655 | h = self.mid(h) 656 | if isinstance(h,tuple): h = h[0] # in case of having tuple 657 | 658 | # Decoder 659 | for m_idx,module in enumerate(self.dec_net): 660 | if self.use_attention: 661 | if (m_idx%2) == 0: 662 | h = th.cat((h,self.h_enc_list.pop()),dim=1) 663 | else: 664 | h = th.cat((h,self.h_enc_list.pop()),dim=1) 665 | h = module(h,emb) # [B x n_base_channels x ...] 666 | if isinstance(h,tuple): h = h[0] # in cfase of having tuple 667 | # Append 668 | module_name = module[0].name 669 | intermediate_output_dict['h_dec_%s_%02d'%(module_name,m_idx)] = h 670 | 671 | # Projection 672 | h = th.cat((h,self.h_enc_list.pop()),dim=1) 673 | 674 | if self.skip_connection: 675 | out = self.proj(h) + x # [B x n_in_channels x ...] 676 | else: 677 | out = self.proj(h) # [B x n_in_channels x ...] 678 | 679 | # Append 680 | intermediate_output_dict['out'] = out # [B x n_in_channels x ...] 
def get_param_groups_and_shapes(named_model_params):
    """
    Partition named parameters by rank and pair each partition with a view shape.

    :param named_model_params: iterable of (name, param) pairs; each param exposes `.ndim`
    :return: list of two `(pairs, shape)` tuples:
        first the params with `ndim <= 1` paired with `-1`,
        then the params with `ndim > 1` paired with `(1, -1)`.
        Relative order within each partition follows the input order.
    """
    low_rank_pairs = []   # scalars and vectors
    high_rank_pairs = []  # matrices and anything of higher rank
    for name, param in named_model_params:
        bucket = high_rank_pairs if param.ndim > 1 else low_rank_pairs
        bucket.append((name, param))
    # NOTE: `(-1)` is the plain integer -1 (not a 1-tuple); the value is kept as-is
    flat_shape = -1
    row_shape = (1, -1)
    return [(low_rank_pairs, flat_shape), (high_rank_pairs, row_shape)]
def check_overflow(value):
    """
    Tell whether `value` is non-finite.

    :param value: a scalar that supports `==` / `!=` against floats
    :return: truthy if `value` is +inf, -inf or NaN, falsy otherwise
    """
    infinity = float("inf")
    is_pos_inf = value == infinity
    is_neg_inf = value == -infinity
    is_nan = value != value  # NaN is the only value that compares unequal to itself
    return is_pos_inf or is_neg_inf or is_nan
INITIAL_LOG_LOSS_SCALE = 20.0


class MixedPrecisionTrainer:
    """
    Helper that drives backward/optimizer steps with optional fp16 loss scaling.

    With `use_fp16=False` the master parameters are the model parameters themselves.
    With `use_fp16=True` flattened master copies are built via `make_master_params`
    and the model is switched with `model.convert_to_fp16()`; the loss is multiplied by
    `2 ** lg_loss_scale` before backward and the gradients are divided by the same
    factor before the optimizer step.
    """
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        """
        :param model: the model to train; must provide `convert_to_fp16()` when `use_fp16` is set
        :param use_fp16: enable loss scaling with separate master parameters
        :param fp16_scale_growth: amount added to `lg_loss_scale` after each successful fp16 step
        :param initial_lg_loss_scale: initial log2 of the loss scale
        """
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        """Zero the gradients of the model parameters (module-level `zero_grad`)."""
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        """
        Backpropagate `loss`, multiplied by `2 ** lg_loss_scale` in fp16 mode.
        """
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        """
        Take one optimizer step.

        :return: True if the step was applied, False if it was skipped (fp16 overflow)
        """
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        """
        fp16 step: copy model grads to the master params, skip the step and lower the
        loss scale on overflow, otherwise unscale, step and copy the result back.
        """
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, _ = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            zero_master_grads(self.master_params)
            return False

        # Undo the loss scaling on EVERY master parameter. There is one master parameter per
        # parameter group, so unscaling only the first one would step the remaining groups
        # with gradients still multiplied by 2 ** lg_loss_scale.
        inv_scale = 1.0 / (2 ** self.lg_loss_scale)
        for master_param in self.master_params:
            if master_param.grad is not None:
                master_param.grad.mul_(inv_scale)
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        """Full-precision step: just run the optimizer; always returns True."""
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        """
        Compute global L2 norms over the master parameters.

        :param grad_scale: divisor applied to the gradient norm (the loss scale in fp16 mode)
        :return: `(grad_norm / grad_scale, param_norm)`; parameters without a gradient
            contribute nothing to the gradient norm
        """
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        """Delegate to the module-level `master_params_to_state_dict` for this trainer's model."""
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        """Delegate to the module-level `state_dict_to_master_params` for this trainer's model."""
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
def eval_ddpm_2d(
    model,
    dc,
    n_sample,
    x_0,
    step_list_to_append,
    device,
    cond=None,
    M=None,
    noise_scale=1.0,
    cond_weight=0.5
):
    """
    Evaluate DDPM in 2D case (reverse sampling from step T-1 down to 0)
    :param model: score function, called as model(x_t,step) or model(x_t,step,cond)
        and expected to return a tuple whose first item is the predicted noise
    :param dc: dictionary of diffusion coefficients; reads 'T', 'betas',
        'sqrt_one_minus_alphas_bar', 'sqrt_recip_alphas' and 'posterior_variance'
    :param n_sample: integer of how many trajectories to sample
    :param x_0: [N x C x W x H] tensor (only its shape is used)
    :param step_list_to_append: an ndarray of diffusion steps at which x_t is stored
    :param device: torch device used for sampling
    :param cond: (optional) conditioning tensor, repeated n_sample times along dim 0;
        None selects the unconditional model call
    :param M: passed through to forward_sample
    :param noise_scale: multiplier on the sampling noise
    :param cond_weight: guidance weight w in (1+w)*eps_cond - w*eps_uncond
    :return: list of length dc['T']; entry t is x_t [n_sample x C x W x H] for the stored
        steps and '' for the others
    """
    # Remember the mode so that a model that was in eval mode is not switched to train mode
    was_training = getattr(model,'training',True)
    model.eval()
    try:
        _,C,W,H    = x_0.shape
        x_dummy    = th.zeros(n_sample,C,W,H,device=device)
        step_dummy = th.zeros(n_sample).type(th.long).to(device)

        # Schedule tables do not change across steps: move them to the device once
        betas                     = th.from_numpy(dc['betas']).to(device) # [T]
        sqrt_one_minus_alphas_bar = th.from_numpy(dc['sqrt_one_minus_alphas_bar']).to(device) # [T]
        sqrt_recip_alphas         = th.from_numpy(dc['sqrt_recip_alphas']).to(device) # [T]
        posterior_variance        = th.from_numpy(dc['posterior_variance']).to(device) # [T]

        def _at(table,step):
            # Per-sample coefficient, shaped to broadcast over [n_sample x C x W x H]
            return th.gather(input=table,dim=-1,index=step).reshape((-1,1,1,1))

        # Conditioning batches are also step-independent
        if cond is not None:
            cond_batch   = cond.repeat(n_sample,1)
            uncond_batch = 0.0*cond_batch

        # NOTE(review): the second output of forward_sample is used as the noise sample
        _,x_T = forward_sample(x_dummy,step_dummy,dc,M) # [n_sample x C x W x H]
        x_t = x_T.clone() # [n_sample x C x W x H]
        x_t_list = ['']*dc['T'] # placeholders, filled at the requested steps
        for t in reversed(range(dc['T'])): # T-1 to 0
            # Score function
            step = th.full(
                size       = (n_sample,),
                fill_value = t,
                device     = device,
                dtype      = th.long) # [n_sample]
            with th.no_grad():
                if cond is None: # unconditioned model
                    eps_t,_ = model(x_t,step) # [n_sample x C x W x H]
                else:
                    eps_cond_d,_   = model(x_t,step,cond_batch)
                    eps_uncond_d,_ = model(x_t,step,uncond_batch)
                    # Combine the conditional and unconditional predictions
                    eps_t = (1+cond_weight)*eps_cond_d - cond_weight*eps_uncond_d
            betas_t                     = _at(betas,step) # [n_sample x 1 x 1 x 1]
            sqrt_one_minus_alphas_bar_t = _at(sqrt_one_minus_alphas_bar,step)
            sqrt_recip_alphas_t         = _at(sqrt_recip_alphas,step)
            # Compute posterior mean
            mean_t = sqrt_recip_alphas_t * (
                x_t - betas_t*eps_t/sqrt_one_minus_alphas_bar_t
            ) # [n_sample x C x W x H]
            # Sample
            if t == 0: # last sampling, use mean
                x_t = mean_t
            else:
                posterior_variance_t = _at(posterior_variance,step) # [n_sample x 1 x 1 x 1]
                _,noise_t = forward_sample(x_dummy,step_dummy,dc,M) # [n_sample x C x W x H]
                x_t = mean_t + noise_scale*th.sqrt(posterior_variance_t)*noise_t
            # Append
            if t in step_list_to_append:
                x_t_list[t] = x_t
    finally:
        # Restore train mode only if the model was training, even when sampling raised
        if was_training:
            model.train()
    return x_t_list # list of [n_sample x C x W x H]
"### 1-D case `[B x C x L]`" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "id": "abd8843c", 55 | "metadata": { 56 | "scrolled": true 57 | }, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "1. No upsample nor downsample\n", 64 | " Shape x:[16x32x200] => out:[16x32x400]\n", 65 | "2. Upsample\n", 66 | " Shape x:[16x32x200] => out:[16x32x400]\n", 67 | "3. Downsample\n", 68 | " Shape x:[16x32x200] => out:[16x32x100]\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "# Input\n", 74 | "x = th.randn(16,32,200) # [B x C x L]\n", 75 | "emb = th.randn(16,128) # [B x n_emb_channels]\n", 76 | "\n", 77 | "print (\"1. No upsample nor downsample\")\n", 78 | "resblock = ResBlock(\n", 79 | " n_channels = 32,\n", 80 | " n_emb_channels = 128,\n", 81 | " n_out_channels = 32,\n", 82 | " n_groups = 16,\n", 83 | " dims = 1,\n", 84 | " upsample = True,\n", 85 | " downsample = False,\n", 86 | " down_rate = 1\n", 87 | ")\n", 88 | "out = resblock(x,emb)\n", 89 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 90 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 91 | "\n", 92 | "print (\"2. Upsample\")\n", 93 | "resblock = ResBlock(\n", 94 | " n_channels = 32,\n", 95 | " n_emb_channels = 128,\n", 96 | " n_out_channels = 32,\n", 97 | " n_groups = 16,\n", 98 | " dims = 1,\n", 99 | " upsample = True,\n", 100 | " downsample = False,\n", 101 | " down_rate = 2\n", 102 | ")\n", 103 | "out = resblock(x,emb)\n", 104 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 105 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 106 | "\n", 107 | "print (\"3. 
Downsample\")\n", 108 | "resblock = ResBlock(\n", 109 | " n_channels = 32,\n", 110 | " n_emb_channels = 128,\n", 111 | " n_out_channels = 32,\n", 112 | " n_groups = 16,\n", 113 | " dims = 1,\n", 114 | " upsample = False,\n", 115 | " downsample = True,\n", 116 | " down_rate = 2\n", 117 | ")\n", 118 | "out = resblock(x,emb)\n", 119 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 120 | " (get_torch_size_string(x),get_torch_size_string(out)))" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "01d77ea8", 126 | "metadata": {}, 127 | "source": [ 128 | "### 2-D case `[B x C x W x H]`" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 3, 134 | "id": "890df0a2", 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "1. No upsample nor downsample\n", 142 | " Shape x:[16x32x28x28] => out:[16x32x28x28]\n", 143 | "2. Upsample\n", 144 | " Shape x:[16x32x28x28] => out:[16x32x56x56]\n", 145 | "3. Downsample\n", 146 | " Shape x:[16x32x28x28] => out:[16x32x14x14]\n", 147 | "4. (uneven) Upsample\n", 148 | " Shape x:[16x32x28x28] => out:[16x32x56x28]\n", 149 | "5. (uneven) Downsample\n", 150 | " Shape x:[16x32x28x28] => out:[16x32x14x28]\n", 151 | "6. (fake) Downsample\n", 152 | " Shape x:[16x32x28x28] => out:[16x32x28x28]\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "# Input\n", 158 | "x = th.randn(16,32,28,28) # [B x C x W x H]\n", 159 | "emb = th.randn(16,128) # [B x n_emb_channels]\n", 160 | "\n", 161 | "print (\"1. 
No upsample nor downsample\")\n", 162 | "resblock = ResBlock(\n", 163 | " n_channels = 32,\n", 164 | " n_emb_channels = 128,\n", 165 | " n_out_channels = 32,\n", 166 | " n_groups = 16,\n", 167 | " dims = 2,\n", 168 | " upsample = False,\n", 169 | " downsample = False\n", 170 | ")\n", 171 | "out = resblock(x,emb)\n", 172 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 173 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 174 | "\n", 175 | "print (\"2. Upsample\")\n", 176 | "resblock = ResBlock(\n", 177 | " n_channels = 32,\n", 178 | " n_emb_channels = 128,\n", 179 | " n_out_channels = 32,\n", 180 | " n_groups = 16,\n", 181 | " dims = 2,\n", 182 | " upsample = True,\n", 183 | " downsample = False\n", 184 | ")\n", 185 | "out = resblock(x,emb)\n", 186 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 187 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 188 | "\n", 189 | "print (\"3. Downsample\")\n", 190 | "resblock = ResBlock(\n", 191 | " n_channels = 32,\n", 192 | " n_emb_channels = 128,\n", 193 | " n_out_channels = 32,\n", 194 | " n_groups = 16,\n", 195 | " dims = 2,\n", 196 | " upsample = False,\n", 197 | " downsample = True\n", 198 | ")\n", 199 | "out = resblock(x,emb)\n", 200 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 201 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 202 | "\n", 203 | "print (\"4. (uneven) Upsample\")\n", 204 | "resblock = ResBlock(\n", 205 | " n_channels = 32,\n", 206 | " n_emb_channels = 128,\n", 207 | " n_out_channels = 32,\n", 208 | " n_groups = 16,\n", 209 | " dims = 2,\n", 210 | " upsample = True,\n", 211 | " downsample = False,\n", 212 | " up_rate = (2,1)\n", 213 | ")\n", 214 | "out = resblock(x,emb)\n", 215 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 216 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 217 | "\n", 218 | "print (\"5. 
(uneven) Downsample\")\n", 219 | "resblock = ResBlock(\n", 220 | " n_channels = 32,\n", 221 | " n_emb_channels = 128,\n", 222 | " n_out_channels = 32,\n", 223 | " n_groups = 16,\n", 224 | " dims = 2,\n", 225 | " upsample = False,\n", 226 | " downsample = True,\n", 227 | " down_rate = (2,1)\n", 228 | ")\n", 229 | "out = resblock(x,emb)\n", 230 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 231 | " (get_torch_size_string(x),get_torch_size_string(out)))\n", 232 | "\n", 233 | "print (\"6. (fake) Downsample\")\n", 234 | "resblock = ResBlock(\n", 235 | " n_channels = 32,\n", 236 | " n_emb_channels = 128,\n", 237 | " n_out_channels = 32,\n", 238 | " n_groups = 16,\n", 239 | " dims = 2,\n", 240 | " upsample = False,\n", 241 | " downsample = True,\n", 242 | " down_rate = (1,1)\n", 243 | ")\n", 244 | "out = resblock(x,emb)\n", 245 | "print (\" Shape x:[%s] => out:[%s]\"%\n", 246 | " (get_torch_size_string(x),get_torch_size_string(out)))" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "id": "184ebe32", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [] 256 | } 257 | ], 258 | "metadata": { 259 | "kernelspec": { 260 | "display_name": "Python 3 (ipykernel)", 261 | "language": "python", 262 | "name": "python3" 263 | }, 264 | "language_info": { 265 | "codemirror_mode": { 266 | "name": "ipython", 267 | "version": 3 268 | }, 269 | "file_extension": ".py", 270 | "mimetype": "text/x-python", 271 | "name": "python", 272 | "nbconvert_exporter": "python", 273 | "pygments_lexer": "ipython3", 274 | "version": "3.9.16" 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 5 279 | } 280 | -------------------------------------------------------------------------------- /code/diffusion_unet_legacy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e331df79-2400-4f60-aca2-15debac3e5de", 6 | "metadata": {}, 7 | "source": [ 8 | "### 
U-net Lagacy" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "a1021c53-7ada-41a7-86fc-ce2ec7faada5", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "PyTorch version:[2.0.1].\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "import numpy as np\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "import torch as th\n", 29 | "import torch.nn as nn\n", 30 | "from util import (\n", 31 | " get_torch_size_string\n", 32 | ")\n", 33 | "from diffusion import (\n", 34 | " get_ddpm_constants,\n", 35 | " plot_ddpm_constants,\n", 36 | " DiffusionUNet,\n", 37 | " DiffusionUNetLegacy\n", 38 | ")\n", 39 | "from dataset import mnist\n", 40 | "np.set_printoptions(precision=3)\n", 41 | "th.set_printoptions(precision=3)\n", 42 | "%matplotlib inline\n", 43 | "%config InlineBackend.figure_format='retina'\n", 44 | "print (\"PyTorch version:[%s].\"%(th.__version__))" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "bc376285-0fa0-4a0b-aacb-f5b49b0e2f5e", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "device:[mps]\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "device = 'mps'\n", 63 | "print (\"device:[%s]\"%(device))" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "id": "7c84b758-2535-439c-936b-1a20397d7598", 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "Ready.\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "dc = get_ddpm_constants(\n", 82 | " schedule_name = 'cosine', # 'linear', 'cosine'\n", 83 | " T = 1000,\n", 84 | " np_type = np.float32)\n", 85 | "print(\"Ready.\") " 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "d8641679-0ee9-4f4b-9248-23e0c2e0de47", 91 | "metadata": {}, 92 | "source": [ 93 | "### Guided U-net\n", 94 | "" 95 | ] 96 | }, 97 | { 98 | 
"cell_type": "markdown", 99 | "id": "08f6f417-81a6-49fe-817c-663acd47414a", 100 | "metadata": {}, 101 | "source": [ 102 | "### 1-D case: `[B x C x L]` with attention" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "id": "57655fad-1bcb-4a5c-9626-ded7373e2b9e", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "Input: x:[2x3x256] timesteps:[2]\n", 116 | "Output: out:[2x3x256]\n", 117 | "[ 0] key:[ x] shape:[ 2x3x256]\n", 118 | "[ 1] key:[ x_lifted] shape:[ 2x32x256]\n", 119 | "[ 2] key:[h_enc_res_00] shape:[ 2x32x128]\n", 120 | "[ 3] key:[h_enc_att_01] shape:[ 2x32x128]\n", 121 | "[ 4] key:[h_enc_res_02] shape:[ 2x64x64]\n", 122 | "[ 5] key:[h_enc_att_03] shape:[ 2x64x64]\n", 123 | "[ 6] key:[h_enc_res_04] shape:[ 2x128x32]\n", 124 | "[ 7] key:[h_enc_att_05] shape:[ 2x128x32]\n", 125 | "[ 8] key:[h_enc_res_06] shape:[ 2x256x16]\n", 126 | "[ 9] key:[h_enc_att_07] shape:[ 2x256x16]\n", 127 | "[10] key:[h_dec_res_00] shape:[ 2x256x32]\n", 128 | "[11] key:[h_dec_att_01] shape:[ 2x256x32]\n", 129 | "[12] key:[h_dec_res_02] shape:[ 2x128x64]\n", 130 | "[13] key:[h_dec_att_03] shape:[ 2x128x64]\n", 131 | "[14] key:[h_dec_res_04] shape:[ 2x64x128]\n", 132 | "[15] key:[h_dec_att_05] shape:[ 2x64x128]\n", 133 | "[16] key:[h_dec_res_06] shape:[ 2x32x256]\n", 134 | "[17] key:[h_dec_att_07] shape:[ 2x32x256]\n", 135 | "[18] key:[ out] shape:[ 2x3x256]\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "unet = DiffusionUNetLegacy(\n", 141 | " name = 'unet',\n", 142 | " dims = 1,\n", 143 | " n_in_channels = 3,\n", 144 | " n_base_channels = 32,\n", 145 | " n_emb_dim = 128,\n", 146 | " n_enc_blocks = 4, # number of encoder blocks\n", 147 | " n_dec_blocks = 4, # number of decoder blocks\n", 148 | " n_groups = 16, # group norm paramter\n", 149 | " use_attention = True,\n", 150 | " skip_connection = True, # additional skip connection\n", 151 | " chnnel_multiples = 
(1,2,4,8),\n", 152 | " updown_rates = (2,2,2,2),\n", 153 | " device = device,\n", 154 | ")\n", 155 | "# Inputs, timesteps:[B] and x:[B x C x L]\n", 156 | "batch_size = 2\n", 157 | "x = th.randn(batch_size,3,256).to(device) # [B x C x L]\n", 158 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n", 159 | "out,intermediate_output_dict = unet(x,timesteps)\n", 160 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n", 161 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n", 162 | "))\n", 163 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n", 164 | "# Print intermediate layers\n", 165 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n", 166 | " z = intermediate_output_dict[key]\n", 167 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "05572674-6d35-46a6-a880-d41e4b4e5111", 173 | "metadata": {}, 174 | "source": [ 175 | "### 1-D case: `[B x C x L]` without attention" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 5, 181 | "id": "6962a40a-0c14-4c4a-a024-06b64b076df5", 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "Input: x:[2x3x256] timesteps:[2]\n", 189 | "Output: out:[2x3x256]\n", 190 | "[ 0] key:[ x] shape:[ 2x3x256]\n", 191 | "[ 1] key:[ x_lifted] shape:[ 2x32x256]\n", 192 | "[ 2] key:[h_enc_res_00] shape:[ 2x32x128]\n", 193 | "[ 3] key:[h_enc_res_01] shape:[ 2x64x64]\n", 194 | "[ 4] key:[h_enc_res_02] shape:[ 2x128x32]\n", 195 | "[ 5] key:[h_enc_res_03] shape:[ 2x256x16]\n", 196 | "[ 6] key:[h_dec_res_00] shape:[ 2x256x32]\n", 197 | "[ 7] key:[h_dec_res_01] shape:[ 2x128x64]\n", 198 | "[ 8] key:[h_dec_res_02] shape:[ 2x64x128]\n", 199 | "[ 9] key:[h_dec_res_03] shape:[ 2x32x256]\n", 200 | "[10] key:[ out] shape:[ 2x3x256]\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "unet = 
DiffusionUNetLegacy(\n", 206 | " name = 'unet',\n", 207 | " dims = 1,\n", 208 | " n_in_channels = 3,\n", 209 | " n_base_channels = 32,\n", 210 | " n_emb_dim = 128,\n", 211 | " n_enc_blocks = 4, # number of encoder blocks\n", 212 | " n_dec_blocks = 4, # number of decoder blocks\n", 213 | " n_groups = 16, # group norm paramter\n", 214 | " use_attention = False,\n", 215 | " skip_connection = True, # additional skip connection\n", 216 | " chnnel_multiples = (1,2,4,8),\n", 217 | " updown_rates = (2,2,2,2),\n", 218 | " device = device,\n", 219 | ")\n", 220 | "# Inputs, timesteps:[B] and x:[B x C x L]\n", 221 | "batch_size = 2\n", 222 | "x = th.randn(batch_size,3,256).to(device) # [B x C x L]\n", 223 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n", 224 | "out,intermediate_output_dict = unet(x,timesteps)\n", 225 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n", 226 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n", 227 | "))\n", 228 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n", 229 | "# Print intermediate layers\n", 230 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n", 231 | " z = intermediate_output_dict[key]\n", 232 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "1881fc4e-d713-4c24-955c-85e3c0ece8ff", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "id": "c7143a40-7d2f-46e6-a6b6-a089a1227bd7", 246 | "metadata": {}, 247 | "source": [ 248 | "### 2-D case: `[B x C x W x H]` without attention" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 6, 254 | "id": "0c3da514-b189-4c01-ba6f-1f2e8daa24b0", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "Input: x:[2x3x256x256] timesteps:[2]\n", 262 | 
"Output: out:[2x3x256x256]\n", 263 | "[ 0] key:[ x] shape:[ 2x3x256x256]\n", 264 | "[ 1] key:[ x_lifted] shape:[2x32x256x256]\n", 265 | "[ 2] key:[h_enc_res_00] shape:[2x32x256x256]\n", 266 | "[ 3] key:[h_enc_res_01] shape:[2x64x256x256]\n", 267 | "[ 4] key:[h_enc_res_02] shape:[2x128x256x256]\n", 268 | "[ 5] key:[h_dec_res_00] shape:[2x128x256x256]\n", 269 | "[ 6] key:[h_dec_res_01] shape:[2x64x256x256]\n", 270 | "[ 7] key:[h_dec_res_02] shape:[2x32x256x256]\n", 271 | "[ 8] key:[ out] shape:[ 2x3x256x256]\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "unet = DiffusionUNetLegacy(\n", 277 | " name = 'unet',\n", 278 | " dims = 2,\n", 279 | " n_in_channels = 3,\n", 280 | " n_base_channels = 32,\n", 281 | " n_emb_dim = 128,\n", 282 | " n_enc_blocks = 3, # number of encoder blocks\n", 283 | " n_dec_blocks = 3, # number of decoder blocks\n", 284 | " n_groups = 16, # group norm paramter\n", 285 | " use_attention = False,\n", 286 | " skip_connection = True, # additional skip connection\n", 287 | " chnnel_multiples = (1,2,4),\n", 288 | " updown_rates = (1,1,1),\n", 289 | " device = device,\n", 290 | ")\n", 291 | "# Inputs, timesteps:[B] and x:[B x C x W x H]\n", 292 | "batch_size = 2\n", 293 | "x = th.randn(batch_size,3,256,256).to(device) # [B x C x W x H]\n", 294 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n", 295 | "out,intermediate_output_dict = unet(x,timesteps)\n", 296 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n", 297 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n", 298 | "))\n", 299 | "print (\"Output: out:[%s]\"%(get_torch_size_string(out)))\n", 300 | "# Print intermediate layers\n", 301 | "for k_idx,key in enumerate(intermediate_output_dict.keys()):\n", 302 | " z = intermediate_output_dict[key]\n", 303 | " print (\"[%2d] key:[%12s] shape:[%12s]\"%(k_idx,key,get_torch_size_string(z)))" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "id": "66e7d820-fc72-41c3-88da-09e70147c3e0", 309 | 
"metadata": {}, 310 | "source": [ 311 | "### 2-D case: `[B x C x W x H]` without attention + updown pooling" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 7, 317 | "id": "12680faa-9bf0-4ef9-b65f-08952641dfa3", 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "name": "stdout", 322 | "output_type": "stream", 323 | "text": [ 324 | "Input: x:[2x3x256x256] timesteps:[2]\n", 325 | "Output: out:[2x3x256x256]\n", 326 | "[ 0] key:[ x] shape:[ 2x3x256x256]\n", 327 | "[ 1] key:[ x_lifted] shape:[2x32x256x256]\n", 328 | "[ 2] key:[h_enc_res_00] shape:[2x32x256x256]\n", 329 | "[ 3] key:[h_enc_res_01] shape:[2x64x128x128]\n", 330 | "[ 4] key:[h_enc_res_02] shape:[ 2x128x64x64]\n", 331 | "[ 5] key:[h_dec_res_00] shape:[2x128x128x128]\n", 332 | "[ 6] key:[h_dec_res_01] shape:[2x64x256x256]\n", 333 | "[ 7] key:[h_dec_res_02] shape:[2x32x256x256]\n", 334 | "[ 8] key:[ out] shape:[ 2x3x256x256]\n" 335 | ] 336 | } 337 | ], 338 | "source": [ 339 | "unet = DiffusionUNetLegacy(\n", 340 | " name = 'unet',\n", 341 | " dims = 2,\n", 342 | " n_in_channels = 3,\n", 343 | " n_base_channels = 32,\n", 344 | " n_emb_dim = 128,\n", 345 | " n_enc_blocks = 3, # number of encoder blocks\n", 346 | " n_dec_blocks = 3, # number of decoder blocks\n", 347 | " n_groups = 16, # group norm paramter\n", 348 | " use_attention = False,\n", 349 | " skip_connection = True, # additional skip connection\n", 350 | " chnnel_multiples = (1,2,4),\n", 351 | " updown_rates = (1,2,2),\n", 352 | " device = device,\n", 353 | ")\n", 354 | "# Inputs, timesteps:[B] and x:[B x C x W x H]\n", 355 | "batch_size = 2\n", 356 | "x = th.randn(batch_size,3,256,256).to(device) # [B x C x W x H]\n", 357 | "timesteps = th.linspace(1,dc['T'],batch_size).to(th.int64).to(device) # [B]\n", 358 | "out,intermediate_output_dict = unet(x,timesteps)\n", 359 | "print (\"Input: x:[%s] timesteps:[%s]\"%(\n", 360 | " get_torch_size_string(x),get_torch_size_string(timesteps)\n", 361 | "))\n", 362 | "print 
def get_argmax_mu(pi,mu):
    """
    Pick, for every sample and output dimension, the mean of the mixture
    component that has the largest mixture weight.

    :param pi: mixture weights [N x K x D]
    :param mu: mixture means [N x K x D]
    :return: mean of the highest-weight component [N x D]
    """
    max_idx = th.argmax(pi,dim=1) # [N x D]
    argmax_mu = th.gather(input=mu,dim=1,index=max_idx.unsqueeze(dim=1)).squeeze(dim=1) # [N x D]
    return argmax_mu

def gmm_forward(pi,mu,sigma,data):
    """
    Compute Gaussian mixture model probability

    The mixture likelihood is evaluated in log-space (log-sum-exp over the
    K components) so that samples far from every component yield a large
    but finite negative log-likelihood instead of log(0) = -inf.

    :param pi: GMM mixture weights [N x K x D]
    :param mu: GMM means [N x K x D]
    :param sigma: GMM stds [N x K x D]
    :param data: data [N x D]
    :return: dict with keys
        'data_usq'  data with a singleton mixture axis [N x 1 x D]
        'data_exp'  data expanded to the shape of sigma [N x K x D]
        'probs'     per-dimension mixture density [N x D]
        'log_probs' log-likelihood summed over output dimensions [N]
        'nlls'      negative log-likelihood [N]
        'argmax_mu' mean of the highest-weight component [N x D]
    """
    data_usq = th.unsqueeze(data,1) # [N x 1 x D]
    data_exp = data_usq.expand_as(sigma) # [N x K x D]
    LOG_SQRT2PI = float(0.5*np.log(2*np.pi))
    # Per-component Gaussian log-density
    z = (data_exp-mu)/sigma # [N x K x D]
    log_comp = -0.5*z**2 - th.log(sigma) - LOG_SQRT2PI # [N x K x D]
    # Clamp the weights so that an exactly-zero weight gives a finite log
    log_pi = th.log(pi.clamp_min(th.finfo(pi.dtype).tiny)) # [N x K x D]
    log_mix = th.logsumexp(log_pi+log_comp,dim=1) # [N x D]
    probs = th.exp(log_mix) # [N x D]
    log_probs = th.sum(log_mix,dim=1) # [N]
    nlls = -log_probs # [N]

    # Most probable mean [N x D]
    argmax_mu = get_argmax_mu(pi,mu) # [N x D]

    out = {
        'data_usq':data_usq,'data_exp':data_exp,
        'probs':probs,'log_probs':log_probs,'nlls':nlls,'argmax_mu':argmax_mu
    }
    return out

def gmm_uncertainties(pi, mu, sigma):
    """
    Compute per-dimension epistemic and aleatoric uncertainty measures of a GMM.

    :param pi: mixture weights [N x K x D]
    :param mu: mixture means [N x K x D]
    :param sigma: mixture stds [N x K x D]
    :return: (epis_unct, alea_unct), each [N x D]; both are returned after a sqrt
    """
    # Compute Epistemic Uncertainty: pi-weighted spread of the means around their average
    mu_avg = th.sum(th.mul(pi,mu),dim=1).unsqueeze(1) # [N x 1 x D]
    mu_exp = mu_avg.expand_as(mu) # [N x K x D]
    mu_diff_sq = th.square(mu-mu_exp) # [N x K x D]
    epis_unct = th.sum(th.mul(pi,mu_diff_sq), dim=1) # [N x D]

    # Compute Aleatoric Uncertainty: pi-weighted average of sigma
    # NOTE(review): gmm_forward treats sigma as a std, so a variance-based term
    # would weight sigma**2 here; kept as-is to preserve existing outputs - confirm intent.
    alea_unct = th.sum(th.mul(pi,sigma), dim=1) # [N x D]

    # (Optional) sqrt operation helps scaling
    epis_unct,alea_unct = th.sqrt(epis_unct),th.sqrt(alea_unct)
    return epis_unct,alea_unct

class MixturesOfGaussianLayer(nn.Module):
    """
    Output head that maps features [N x in_dim] to GMM parameters
    (pi, mu, sigma), each of shape [N x K x D] with D = out_dim.

    :param in_dim: input feature dimension
    :param out_dim: output (target) dimension D
    :param k: number of mixture components K
    :param sig_max: if None, sigma = exp(.); otherwise sigma = sig_max * sigmoid(.)
    """
    def __init__(
        self,
        in_dim,
        out_dim,
        k,
        sig_max=None
    ):
        super(MixturesOfGaussianLayer,self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.k = k
        self.sig_max = sig_max

        # Networks (one linear head per GMM parameter)
        self.fc_pi = nn.Linear(self.in_dim,self.k*self.out_dim)
        self.fc_mu = nn.Linear(self.in_dim,self.k*self.out_dim)
        self.fc_sigma = nn.Linear(self.in_dim,self.k*self.out_dim)

    def forward(self,x):
        """
        :param x: features [N x in_dim]
        :return: (pi, mu, sigma), each [N x K x D]; pi sums to one over the K axis
        """
        pi_logit = self.fc_pi(x) # [N x KD]
        pi_logit = th.reshape(pi_logit,(-1,self.k,self.out_dim)) # [N x K x D]
        pi = th.softmax(pi_logit,dim=1) # [N x K x D]
        mu = self.fc_mu(x) # [N x KD]
        mu = th.reshape(mu,(-1,self.k,self.out_dim)) # [N x K x D]
        sigma = self.fc_sigma(x) # [N x KD]
        sigma = th.reshape(sigma,(-1,self.k,self.out_dim)) # [N x K x D]
        if self.sig_max is None:
            sigma = th.exp(sigma) # [N x K x D]
        else:
            sigma = self.sig_max * th.sigmoid(sigma) # [N x K x D]
        return pi,mu,sigma

class MixtureDensityNetwork(nn.Module):
    """
    MLP trunk (dense -> optional batchnorm -> activation -> dropout per hidden
    layer) followed by a MixturesOfGaussianLayer head.

    :param name: model name
    :param x_dim: input dimension
    :param y_dim: output dimension D
    :param k: number of mixture components K
    :param h_dim_list: hidden layer widths
    :param actv: activation module inserted after every hidden layer
    :param sig_max: upper bound of sigma passed to the GMM head (None: exp)
    :param mu_min: lower bound used to initialize the bias of the mu head
    :param mu_max: upper bound used to initialize the bias of the mu head
    :param p_drop: dropout probability
    :param use_bn: insert BatchNorm1d after each dense layer if True
    """
    def __init__(
        self,
        name       = 'mdn',
        x_dim      = 1,
        y_dim      = 1,
        k          = 5,
        h_dim_list = (32,32),
        actv       = nn.ReLU(),
        sig_max    = 1.0,
        mu_min     = -3.0,
        mu_max     = +3.0,
        p_drop     = 0.1,
        use_bn     = False,
    ):
        super(MixtureDensityNetwork,self).__init__()
        self.name       = name
        self.x_dim      = x_dim
        self.y_dim      = y_dim
        self.k          = k
        self.h_dim_list = list(h_dim_list) # private copy (no shared mutable default)
        self.actv       = actv
        self.sig_max    = sig_max
        self.mu_min     = mu_min
        self.mu_max     = mu_max
        self.p_drop     = p_drop
        self.use_bn     = use_bn

        # Declare layers
        self.layer_list = []
        h_dim_prev = self.x_dim
        for h_dim in self.h_dim_list:
            # dense -> batchnorm -> actv -> dropout
            self.layer_list.append(nn.Linear(h_dim_prev,h_dim))
            if self.use_bn: self.layer_list.append(nn.BatchNorm1d(num_features=h_dim)) # (optional) batchnorm
            self.layer_list.append(self.actv)
            # Element-wise dropout: nn.Dropout1d would read a [N x H] input as
            # an unbatched (C, L) signal and zero out entire samples.
            self.layer_list.append(nn.Dropout(p=self.p_drop))
            h_dim_prev = h_dim
        self.layer_list.append(
            MixturesOfGaussianLayer(
                in_dim  = h_dim_prev,
                out_dim = self.y_dim,
                k       = self.k,
                sig_max = self.sig_max
            )
        )

        # Define network
        self.net = nn.Sequential()
        self.layer_names = []
        for l_idx,layer in enumerate(self.layer_list):
            layer_name = "%s_%02d"%(type(layer).__name__.lower(),l_idx)
            self.layer_names.append(layer_name)
            self.net.add_module(layer_name,layer)

        # Initialize parameters
        self.init_param(VERBOSE=False)

    def init_param(self,VERBOSE=False):
        """
        Initialize parameters

        :param VERBOSE: print the index of every visited sub-module if True
        """
        for m_idx,m in enumerate(self.modules()):
            if VERBOSE:
                print ("[%02d]"%(m_idx))
            if isinstance(m,nn.Conv2d): # init conv
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m,nn.BatchNorm1d): # init BN
                nn.init.constant_(m.weight,1.0)
                nn.init.constant_(m.bias,0.0)
            elif isinstance(m,nn.Linear): # init dense
                nn.init.kaiming_normal_(m.weight,nonlinearity='relu')
                nn.init.zeros_(m.bias)
        # (Heuristics) mu bias between mu_min ~ mu_max
        self.layer_list[-1].fc_mu.bias.data.uniform_(self.mu_min,self.mu_max)

    def forward(self, x):
        """
        Forward propagate

        :param x: input [N x x_dim]
        :return: (pi, mu, sigma) from the GMM head, each [N x K x D]
        """
        return self.net(x)

def eval_mdn_1d(
    mdn,
    x_train_np,
    y_train_np,
    figsize=(12,3),
    device='cpu',
    ):
    """
    Evaluate an MDN on a 1-D input grid and plot the fitted mixtures and
    the epistemic / aleatoric uncertainties for each output dimension.

    The model is switched to eval mode for inference and its previous
    train/eval mode is restored afterwards.

    :param mdn: model returning (pi, mu, sigma) and exposing the attribute `k`
    :param x_train_np: training inputs (numpy); min/max define the test grid
    :param y_train_np: training targets (numpy), indexed as [:, d]
    :param figsize: figure size passed to matplotlib
    :param device: device on which the test grid is created
    """
    # Eval (remember the current mode so that it can be restored)
    was_training = mdn.training
    mdn.eval()
    try:
        x_margin = 0.2
        x_test = th.linspace(
            start = float(x_train_np.min())-x_margin,
            end   = float(x_train_np.max())+x_margin,
            steps = 300
        ).reshape((-1,1)).to(device)
        with th.no_grad(): # inference only, no autograd graph needed
            pi_test,mu_test,sigma_test = mdn(x_test)

            # Get the most probable mu
            argmax_mu_test = get_argmax_mu(pi_test,mu_test) # [N x D]

            # Uncertainties
            epis_unct,alea_unct = gmm_uncertainties(pi_test,mu_test,sigma_test) # [N x D]
    finally:
        # Back to the previous mode
        mdn.train(was_training)

    # To numpy array
    x_test_np,pi_np,mu_np,sigma_np = th2np(x_test),th2np(pi_test),th2np(mu_test),th2np(sigma_test)
    argmax_mu_test_np = th2np(argmax_mu_test) # [N x D]
    epis_unct_np,alea_unct_np = th2np(epis_unct),th2np(alea_unct)

    # Plot fitted results
    y_dim = y_train_np.shape[1]
    plt.figure(figsize=figsize)
    cmap = plt.get_cmap('gist_rainbow')
    colors = [cmap(ii) for ii in np.linspace(0, 1, mdn.k)] # colors
    pi_th = 0.1 # mixtures with weight above this threshold are highlighted
    for d_idx in range(y_dim): # for each output dimension
        plt.subplot(1,y_dim,d_idx+1)
        # Plot training data
        plt.plot(x_train_np,y_train_np[:,d_idx],'.',color=(0,0,0,0.2),markersize=3,
                 label="Training Data")
        # Plot mixture standard deviations
        for k_idx in range(mdn.k): # for each mixture
            pi_high_idx = np.where(pi_np[:,k_idx,d_idx] > pi_th)[0]
            mu_k = mu_np[:,k_idx,d_idx]
            sigma_k = sigma_np[:,k_idx,d_idx]
            upper_bound = mu_k + 2*sigma_k
            lower_bound = mu_k - 2*sigma_k
            plt.fill_between(x_test_np[pi_high_idx,0].squeeze(),
                             lower_bound[pi_high_idx],
                             upper_bound[pi_high_idx],
                             facecolor=colors[k_idx], interpolate=False, alpha=0.3)
        # Plot mixture means
        for k_idx in range(mdn.k): # for each mixture
            pi_high_idx = np.where(pi_np[:,k_idx,d_idx] > pi_th)[0] # [?,]
            pi_low_idx = np.where(pi_np[:,k_idx,d_idx] <= pi_th)[0] # [?,]
            plt.plot(x_test_np[pi_high_idx,0],mu_np[pi_high_idx,k_idx,d_idx],'-',
                     color=colors[k_idx],linewidth=1/2)
            plt.plot(x_test_np[pi_low_idx,0],mu_np[pi_low_idx,k_idx,d_idx],'-',
                     color=(0,0,0,0.3),linewidth=1/2)

        # Plot most probable mu
        plt.plot(x_test_np[:,0],argmax_mu_test_np[:,d_idx],'-',color='b',linewidth=2,
                 label="Argmax Mu")
        plt.xlim(x_test_np.min(),x_test_np.max())
        plt.legend(loc='lower right',fontsize=8)
        plt.title("y_dim:[%d]"%(d_idx),fontsize=10)
    plt.show()

    # Plot uncertainties
    plt.figure(figsize=figsize)
    for d_idx in range(y_dim): # for each output dimension
        plt.subplot(1,y_dim,d_idx+1)
        plt.plot(x_test_np[:,0],epis_unct_np[:,d_idx],'-',color='r',linewidth=2,
                 label="Epistemic Uncertainty")
        plt.plot(x_test_np[:,0],alea_unct_np[:,d_idx],'-',color='b',linewidth=2,
                 label="Aleatoric Uncertainty")
        plt.xlim(x_test_np.min(),x_test_np.max())
        plt.legend(loc='lower right',fontsize=8)
        plt.title("y_dim:[%d]"%(d_idx),fontsize=10)
    plt.show()
"metadata": {}, 80 | "source": [ 81 | "### Dataset" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "id": "95ebd370", 88 | "metadata": { 89 | "scrolled": true 90 | }, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "MNIST ready.\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "train_iter,test_iter,train_data,train_label,test_data,test_label = \\\n", 102 | " mnist(root_path='../data',batch_size=128)\n", 103 | "print (\"MNIST ready.\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "72dc751b", 109 | "metadata": {}, 110 | "source": [ 111 | "### Model" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 4, 117 | "id": "dfacec51", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "Ready.\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "class MultiLayerPerceptronClass(th.nn.Module):\n", 130 | " def __init__(\n", 131 | " self,\n", 132 | " name = 'mlp',\n", 133 | " x_dim = 784,\n", 134 | " h_dim_list = [256,256],\n", 135 | " y_dim = 10,\n", 136 | " actv = nn.ReLU(),\n", 137 | " p_drop = 0.2\n", 138 | " ):\n", 139 | " \"\"\"\n", 140 | " Initialize MLP\n", 141 | " \"\"\"\n", 142 | " super(MultiLayerPerceptronClass,self).__init__()\n", 143 | " self.name = name\n", 144 | " self.x_dim = x_dim\n", 145 | " self.h_dim_list = h_dim_list\n", 146 | " self.y_dim = y_dim\n", 147 | " self.actv = actv\n", 148 | " self.p_drop = p_drop\n", 149 | " \n", 150 | " # Declare layers\n", 151 | " self.layer_list = []\n", 152 | " h_dim_prev = self.x_dim\n", 153 | " for h_dim in self.h_dim_list:\n", 154 | " # dense -> batchnorm -> actv -> dropout\n", 155 | " self.layer_list.append(nn.Linear(h_dim_prev,h_dim))\n", 156 | " self.layer_list.append(nn.BatchNorm1d(num_features=h_dim))\n", 157 | " self.layer_list.append(self.actv)\n", 158 | " self.layer_list.append(nn.Dropout1d(p=self.p_drop))\n", 
159 | " h_dim_prev = h_dim\n", 160 | " self.layer_list.append(nn.Linear(h_dim_prev,self.y_dim))\n", 161 | " \n", 162 | " # Define net\n", 163 | " self.net = nn.Sequential()\n", 164 | " self.layer_names = []\n", 165 | " for l_idx,layer in enumerate(self.layer_list):\n", 166 | " layer_name = \"%s_%02d\"%(type(layer).__name__.lower(),l_idx)\n", 167 | " self.layer_names.append(layer_name)\n", 168 | " self.net.add_module(layer_name,layer)\n", 169 | " \n", 170 | " # Initialize parameters\n", 171 | " self.init_param(VERBOSE=False)\n", 172 | " \n", 173 | " def init_param(self,VERBOSE=False):\n", 174 | " \"\"\"\n", 175 | " Initialize parameters\n", 176 | " \"\"\"\n", 177 | " for m_idx,m in enumerate(self.modules()):\n", 178 | " if VERBOSE:\n", 179 | " print (\"[%02d]\"%(m_idx))\n", 180 | " if isinstance(m,nn.Conv2d): # init conv\n", 181 | " nn.init.kaiming_normal_(m.weight)\n", 182 | " nn.init.zeros_(m.bias)\n", 183 | " elif isinstance(m,nn.BatchNorm1d): # init BN\n", 184 | " nn.init.constant_(m.weight,1.0)\n", 185 | " nn.init.constant_(m.bias,0.0)\n", 186 | " elif isinstance(m,nn.Linear): # lnit dense\n", 187 | " nn.init.kaiming_normal_(m.weight,nonlinearity='relu')\n", 188 | " nn.init.zeros_(m.bias)\n", 189 | " \n", 190 | " def forward(self,x):\n", 191 | " \"\"\"\n", 192 | " Forward propagate\n", 193 | " \"\"\"\n", 194 | " intermediate_output_list = []\n", 195 | " for idx,layer in enumerate(self.net):\n", 196 | " x = layer(x)\n", 197 | " intermediate_output_list.append(x)\n", 198 | " # Final output\n", 199 | " final_output = x\n", 200 | " return final_output,intermediate_output_list\n", 201 | " \n", 202 | "print (\"Ready.\") " 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 5, 208 | "id": "fd93b90d", 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "name": "stdout", 213 | "output_type": "stream", 214 | "text": [ 215 | "Ready.\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "mlp = MultiLayerPerceptronClass(\n", 221 | " name = 
'mlp',\n", 222 | " x_dim = 784,\n", 223 | " h_dim_list = [512,256],\n", 224 | " y_dim = 10,\n", 225 | " actv = nn.ReLU(),\n", 226 | " p_drop = 0.25\n", 227 | ").to(device)\n", 228 | "loss = nn.CrossEntropyLoss()\n", 229 | "optm = th.optim.Adam(mlp.parameters(),lr=1e-3)\n", 230 | "print (\"Ready.\")" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "id": "f6c04e91", 236 | "metadata": {}, 237 | "source": [ 238 | "### Print model parameters" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 6, 244 | "id": "2eafd659", 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "[ 0] parameter:[ net.linear_00.weight] shape:[ 512x784] numel:[ 401408]\n", 252 | "[ 1] parameter:[ net.linear_00.bias] shape:[ 512] numel:[ 512]\n", 253 | "[ 2] parameter:[ net.batchnorm1d_01.weight] shape:[ 512] numel:[ 512]\n", 254 | "[ 3] parameter:[ net.batchnorm1d_01.bias] shape:[ 512] numel:[ 512]\n", 255 | "[ 4] parameter:[ net.linear_04.weight] shape:[ 256x512] numel:[ 131072]\n", 256 | "[ 5] parameter:[ net.linear_04.bias] shape:[ 256] numel:[ 256]\n", 257 | "[ 6] parameter:[ net.batchnorm1d_05.weight] shape:[ 256] numel:[ 256]\n", 258 | "[ 7] parameter:[ net.batchnorm1d_05.bias] shape:[ 256] numel:[ 256]\n", 259 | "[ 8] parameter:[ net.linear_08.weight] shape:[ 10x256] numel:[ 2560]\n", 260 | "[ 9] parameter:[ net.linear_08.bias] shape:[ 10] numel:[ 10]\n" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "print_model_parameters(mlp)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "id": "96310c96", 271 | "metadata": {}, 272 | "source": [ 273 | "### Print model layers" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 7, 279 | "id": "9161988d", 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "batch_size:[16]\n", 287 | "[ ] layer:[ input] size:[ 16x784]\n", 288 | "[ 
0] layer:[ linear_00] size:[ 16x512] numel:[ 8192]\n", 289 | "[ 1] layer:[ batchnorm1d_01] size:[ 16x512] numel:[ 8192]\n", 290 | "[ 2] layer:[ relu_02] size:[ 16x512] numel:[ 8192]\n", 291 | "[ 3] layer:[ dropout1d_03] size:[ 16x512] numel:[ 8192]\n", 292 | "[ 4] layer:[ linear_04] size:[ 16x256] numel:[ 4096]\n", 293 | "[ 5] layer:[ batchnorm1d_05] size:[ 16x256] numel:[ 4096]\n", 294 | "[ 6] layer:[ relu_06] size:[ 16x256] numel:[ 4096]\n", 295 | "[ 7] layer:[ dropout1d_07] size:[ 16x256] numel:[ 4096]\n", 296 | "[ 8] layer:[ linear_08] size:[ 16x10] numel:[ 160]\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "x_torch = th.randn(16,mlp.x_dim).to(device)\n", 302 | "print_model_layers(mlp,x_torch)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "5da9aaf1", 308 | "metadata": {}, 309 | "source": [ 310 | "### Train MLP" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 8, 316 | "id": "c73c3242", 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "epoch:[ 0/20] loss:[1.137] train_accr:[0.9704] test_accr:[0.9631].\n", 324 | "epoch:[ 1/20] loss:[1.073] train_accr:[0.9839] test_accr:[0.9727].\n", 325 | "epoch:[ 2/20] loss:[1.050] train_accr:[0.9869] test_accr:[0.9746].\n", 326 | "epoch:[ 3/20] loss:[1.041] train_accr:[0.9889] test_accr:[0.9783].\n", 327 | "epoch:[ 4/20] loss:[1.037] train_accr:[0.9880] test_accr:[0.9772].\n", 328 | "epoch:[ 5/20] loss:[1.032] train_accr:[0.9929] test_accr:[0.9798].\n", 329 | "epoch:[ 6/20] loss:[1.024] train_accr:[0.9918] test_accr:[0.9754].\n", 330 | "epoch:[ 7/20] loss:[1.031] train_accr:[0.9926] test_accr:[0.9779].\n", 331 | "epoch:[ 8/20] loss:[1.028] train_accr:[0.9950] test_accr:[0.9779].\n", 332 | "epoch:[ 9/20] loss:[1.030] train_accr:[0.9955] test_accr:[0.9807].\n", 333 | "epoch:[10/20] loss:[1.017] train_accr:[0.9961] test_accr:[0.9812].\n", 334 | "epoch:[11/20] loss:[1.014] train_accr:[0.9968] 
test_accr:[0.9817].\n", 335 | "epoch:[12/20] loss:[1.016] train_accr:[0.9935] test_accr:[0.9752].\n", 336 | "epoch:[13/20] loss:[1.023] train_accr:[0.9964] test_accr:[0.9790].\n", 337 | "epoch:[14/20] loss:[1.022] train_accr:[0.9974] test_accr:[0.9805].\n", 338 | "epoch:[15/20] loss:[1.018] train_accr:[0.9972] test_accr:[0.9800].\n", 339 | "epoch:[16/20] loss:[1.005] train_accr:[0.9965] test_accr:[0.9794].\n", 340 | "epoch:[17/20] loss:[1.014] train_accr:[0.9973] test_accr:[0.9792].\n", 341 | "epoch:[18/20] loss:[1.012] train_accr:[0.9970] test_accr:[0.9794].\n", 342 | "epoch:[19/20] loss:[1.015] train_accr:[0.9984] test_accr:[0.9821].\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "model_train(mlp,optm,loss,train_iter,test_iter,n_epoch,print_every,device)" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "id": "0b8ec7e0", 353 | "metadata": {}, 354 | "source": [ 355 | "### Test MLP" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 9, 361 | "id": "619bd29c", 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "image/png": "", 367 | "text/plain": [ 368 | "
" 369 | ] 370 | }, 371 | "metadata": { 372 | "image/png": { 373 | "height": 558, 374 | "width": 488 375 | } 376 | }, 377 | "output_type": "display_data" 378 | } 379 | ], 380 | "source": [ 381 | "model_test(mlp,test_data,test_label,device)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "id": "b557e4ed", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [] 391 | } 392 | ], 393 | "metadata": { 394 | "kernelspec": { 395 | "display_name": "Python 3 (ipykernel)", 396 | "language": "python", 397 | "name": "python3" 398 | }, 399 | "language_info": { 400 | "codemirror_mode": { 401 | "name": "ipython", 402 | "version": 3 403 | }, 404 | "file_extension": ".py", 405 | "mimetype": "text/x-python", 406 | "name": "python", 407 | "nbconvert_exporter": "python", 408 | "pygments_lexer": "ipython3", 409 | "version": "3.9.16" 410 | } 411 | }, 412 | "nbformat": 4, 413 | "nbformat_minor": 5 414 | } 415 | -------------------------------------------------------------------------------- /code/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from abc import abstractmethod 6 | 7 | class TimestepBlock(nn.Module): 8 | """ 9 | Any module where forward() takes timestep embeddings as a second argument. 10 | """ 11 | 12 | @abstractmethod 13 | def forward(self, x, emb): 14 | """ 15 | Apply the module to `x` given `emb` timestep embeddings. 16 | """ 17 | 18 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 19 | """ 20 | A sequential module that passes timestep embeddings to the children that 21 | support it as an extra input. 
def zero_module(module):
    """
    Zero out the parameters of a module and return it.

    :param module: module whose parameters are overwritten in place.
    :return: the same module object that was passed in.
    """
    for param in module.parameters():
        # zero the values through a detached alias of the parameter
        param.detach().zero_()
    return module

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: 1, 2, or 3; selects nn.Conv1d, nn.Conv2d, or nn.Conv3d.
    :param args: positional arguments forwarded to the selected class.
    :param kwargs: keyword arguments forwarded to the selected class.
    :return: the constructed convolution module.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    if dims == 2:
        return nn.Conv2d(*args, **kwargs)
    if dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :param dims: 1, 2, or 3; selects nn.AvgPool1d, nn.AvgPool2d, or nn.AvgPool3d.
    :param args: positional arguments forwarded to the selected class.
    :param kwargs: keyword arguments forwarded to the selected class.
    :return: the constructed pooling module.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    if dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    if dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that casts its input to float32 for the normalization and casts
    the result back to the dtype of the input.
    """
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

def normalization(n_channels,n_groups=1):
    """
    Make a standard normalization layer.

    :param n_channels: number of input channels.
    :param n_groups: number of groups. if this is 1, then it is identical to layernorm.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_groups=n_groups,num_channels=n_channels)

class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param n_channels: number of channels in the inputs.
    :param up_rate: upsampling factor of the spatial axes.
    :param up_mode: interpolation mode handed to F.interpolate
                    ('nearest' or 'bilinear').
    :param use_conv: a bool determining if a convolution is applied
                     after the interpolation.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param n_out_channels: number of output channels of the optional
                           convolution (defaults to n_channels). It has no
                           effect on the output when use_conv is False.
    :param padding_mode: padding mode of the optional convolution.
    :param padding: padding of the optional convolution.
    """

    def __init__(
        self,
        n_channels,
        up_rate        = 2, # upsample rate
        up_mode        = 'nearest', # upsample mode ('nearest' or 'bilinear')
        use_conv       = False, # (optional) use output conv
        dims           = 2, # (optional) spatial dimension
        n_out_channels = None, # (optional) in case output channels are different from the input
        padding_mode   = 'zeros',
        padding        = 1
    ):
        super().__init__()
        self.n_channels     = n_channels
        self.up_rate        = up_rate
        self.up_mode        = up_mode
        self.use_conv       = use_conv
        self.dims           = dims
        self.n_out_channels = n_out_channels or n_channels
        self.padding_mode   = padding_mode
        self.padding        = padding

        if use_conv:
            self.conv = conv_nd(
                dims         = dims,
                in_channels  = self.n_channels,
                out_channels = self.n_out_channels,
                kernel_size  = 3,
                padding      = padding,
                padding_mode = padding_mode)

    def forward(self, x):
        """
        :param x: [B x C x ...] tensor with C equal to n_channels
        :return: the interpolated tensor, passed through the convolution
                 when use_conv is True. For dims==3 only the last two axes
                 are enlarged; otherwise every spatial axis is scaled by
                 up_rate.
        """
        assert x.shape[1] == self.n_channels
        if self.dims == 3: # 3D signal: keep the first spatial axis as is
            x = F.interpolate(
                input = x,
                size  = (
                    x.shape[2],
                    int(x.shape[3]*self.up_rate), # honor up_rate (was a hard-coded 2)
                    int(x.shape[4]*self.up_rate),
                ),
                mode  = self.up_mode
            )
        else:
            x = F.interpolate(
                input        = x,
                scale_factor = self.up_rate,
                mode         = self.up_mode
            ) # 'nearest' or 'bilinear'

        # (optional) final convolution
        if self.use_conv:
            x = self.conv(x)
        return x
class QKVAttentionLegacy(nn.Module):
    """
    Multi-head QKV attention on a packed tensor, including the reshaping
    between the packed layout and the per-head layout.
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        (B:#batches, C:channel size, T:#tokens, H:#heads)

        :param qkv: an [B x (3*C) x T] tensor of Qs, Ks, and Vs.
        :return: an [B x C x T] tensor after attention.
        """
        batch_size, packed_width, seq_len = qkv.shape
        heads = self.n_heads
        assert packed_width % (3 * heads) == 0
        head_dim = packed_width // (3 * heads)

        # Fold the heads into the batch axis, then cut each head into q, k, v
        per_head = qkv.reshape(batch_size * heads, 3 * head_dim, seq_len)
        queries, keys, values = per_head.split(head_dim, dim=1)

        # Scaling q and k separately (instead of dividing the product
        # afterwards) is more stable with f16
        scale  = 1 / math.sqrt(math.sqrt(head_dim))
        logits = th.einsum("bct,bcs->bts", queries * scale, keys * scale)

        # Softmax is evaluated in float32 and cast back to the logits' dtype
        probs    = logits.float().softmax(dim=-1).to(logits.dtype)
        attended = th.einsum("bts,bcs->bct", probs, values) # [(H*B) x (C//H) x T]
        return attended.reshape(batch_size, -1, seq_len) # [B x C x T]
channel 269 | x_rsh = x.reshape(b, c, -1) # [B x C x WH] 270 | x_nzd = self.norm(x_rsh) # [B x C x WH] 271 | qkv = self.qkv(x_nzd) # [B x 3C x WH] 272 | # QKV attention 273 | h_att = self.attention(qkv) # [B x C x WH] 274 | h_proj = self.proj_out(h_att) # [B x C x WH] 275 | # Residual connection 276 | out = (x_rsh + h_proj).reshape(b, c, *spatial) # [B x C x W x H] 277 | # Append 278 | intermediate_output_dict['x'] = x 279 | intermediate_output_dict['x_rsh'] = x_rsh 280 | intermediate_output_dict['x_nzd'] = x_nzd 281 | intermediate_output_dict['qkv'] = qkv 282 | intermediate_output_dict['h_att'] = h_att 283 | intermediate_output_dict['h_proj'] = h_proj 284 | intermediate_output_dict['out'] = out 285 | return out,intermediate_output_dict 286 | 287 | class ResBlock(TimestepBlock): 288 | """ 289 | A residual block that can optionally change the number of channels and resolution 290 | 291 | :param n_channels: the number of input channels 292 | :param n_emb_channels: the number of timestep embedding channels 293 | :param n_out_channels: (if specified) the number of output channels 294 | :param n_groups: the number of groups in group normalization layer 295 | :param dims: spatial dimension 296 | :param p_dropout: the rate of dropout 297 | :param actv: activation 298 | :param use_conv: if True, and n_out_channels is specified, 299 | use 3x3 conv instead of 1x1 conv 300 | :param use_scale_shift_norm: if True, use scale_shift_norm for handling emb 301 | :param upsample: if True, upsample 302 | :param downsample: if True, downsample 303 | :param sample_mode: upsample, downsample mode ('nearest' or 'bilinear') 304 | :param padding_mode: str 305 | :param padding: int 306 | """ 307 | def __init__( 308 | self, 309 | name = 'resblock', 310 | n_channels = 128, 311 | n_emb_channels = 128, 312 | n_out_channels = None, 313 | n_groups = 16, 314 | dims = 2, 315 | p_dropout = 0.5, 316 | kernel_size = 3, 317 | actv = nn.SiLU(), 318 | use_conv = False, 319 | use_scale_shift_norm = True, 
320 | upsample = False, 321 | downsample = False, 322 | up_rate = 2, 323 | down_rate = 2, 324 | sample_mode = 'nearest', 325 | padding_mode = 'zeros', 326 | padding = 1, 327 | ): 328 | super().__init__() 329 | self.name = name 330 | self.n_channels = n_channels 331 | self.n_emb_channels = n_emb_channels 332 | self.n_groups = n_groups 333 | self.dims = dims 334 | self.n_out_channels = n_out_channels or self.n_channels 335 | self.kernel_size = kernel_size 336 | self.p_dropout = p_dropout 337 | self.actv = actv 338 | self.use_conv = use_conv 339 | self.use_scale_shift_norm = use_scale_shift_norm 340 | self.upsample = upsample 341 | self.downsample = downsample 342 | self.up_rate = up_rate 343 | self.down_rate = down_rate 344 | self.sample_mode = sample_mode 345 | self.padding_mode = padding_mode 346 | self.padding = padding 347 | 348 | # Input layers 349 | self.in_layers = nn.Sequential( 350 | normalization(n_channels=self.n_channels,n_groups=self.n_groups), 351 | self.actv, 352 | conv_nd( 353 | dims = self.dims, 354 | in_channels = self.n_channels, 355 | out_channels = self.n_out_channels, 356 | kernel_size = self.kernel_size, 357 | padding = self.padding, 358 | padding_mode = self.padding_mode 359 | ) 360 | ) 361 | 362 | # Upsample or downsample 363 | self.updown = self.upsample or self.downsample 364 | if self.upsample: 365 | self.h_upd = Upsample( 366 | n_channels = self.n_channels, 367 | up_rate = self.up_rate, 368 | up_mode = self.sample_mode, 369 | dims = self.dims) 370 | self.x_upd = Upsample( 371 | n_channels = self.n_channels, 372 | up_rate = self.up_rate, 373 | up_mode = self.sample_mode, 374 | dims = self.dims) 375 | elif self.downsample: 376 | self.h_upd = Downsample( 377 | n_channels = self.n_channels, 378 | down_rate = self.down_rate, 379 | dims = self.dims) 380 | self.x_upd = Downsample( 381 | n_channels = self.n_channels, 382 | down_rate = self.down_rate, 383 | dims = self.dims) 384 | else: 385 | self.h_upd = nn.Identity() 386 | self.x_upd = 
nn.Identity() 387 | 388 | # Embedding layers 389 | self.emb_layers = nn.Sequential( 390 | self.actv, 391 | nn.Linear( 392 | in_features = self.n_emb_channels, 393 | out_features = 2*self.n_out_channels if self.use_scale_shift_norm 394 | else self.n_out_channels, 395 | ), 396 | ) 397 | 398 | # Output layers 399 | self.out_layers = nn.Sequential( 400 | normalization(n_channels=self.n_out_channels,n_groups=self.n_groups), 401 | self.actv, 402 | nn.Dropout(p=self.p_dropout), 403 | zero_module( 404 | conv_nd( 405 | dims = self.dims, 406 | in_channels = self.n_out_channels, 407 | out_channels = self.n_out_channels, 408 | kernel_size = self.kernel_size, 409 | padding = self.padding, 410 | padding_mode = self.padding_mode 411 | ) 412 | ), 413 | ) 414 | if self.n_channels == self.n_out_channels: 415 | self.skip_connection = nn.Identity() 416 | elif use_conv: 417 | self.skip_connection = conv_nd( 418 | dims = self.dims, 419 | in_channels = self.n_channels, 420 | out_channels = self.n_out_channels, 421 | kernel_size = self.kernel_size, 422 | padding = self.padding, 423 | padding_mode = self.padding_mode 424 | ) 425 | else: 426 | self.skip_connection = conv_nd( 427 | dims = self.dims, 428 | in_channels = self.n_channels, 429 | out_channels = self.n_out_channels, 430 | kernel_size = 1 431 | ) 432 | 433 | def forward(self,x,emb): 434 | """ 435 | :param x: [B x C x ...] 436 | :param emb: [B x n_emb_channels] 437 | :return: [B x C x ...] 438 | """ 439 | # Input layer (groupnorm -> actv -> conv) 440 | if self.updown: # upsample or downsample 441 | in_norm_actv = self.in_layers[:-1] 442 | in_conv = self.in_layers[-1] 443 | h = in_norm_actv(x) 444 | h = self.h_upd(h) 445 | h = in_conv(h) 446 | x = self.x_upd(x) 447 | else: 448 | h = self.in_layers(x) # [B x C x ...] 449 | 450 | # Embedding layer 451 | emb_out = self.emb_layers(emb).type(h.dtype) 452 | while len(emb_out.shape) < len(h.shape): 453 | emb_out = emb_out[..., None] # match 'emb_out' with 'h': [B x C x ...] 
def th2np(x):
    """
    Convert a torch tensor to a numpy array.

    :param x: torch tensor; it is detached and moved to the CPU first
    :return: numpy array
    """
    return x.detach().cpu().numpy()

def get_torch_size_string(x):
    """
    Format the shape of `x` as a string.

    :param x: object with a `shape` attribute
    :return: the entries of `x.shape` joined by 'x', e.g., '2x3x4'
    """
    return "x".join(map(str,x.shape))

def plot_4x4_torch_tensor(
    x_torch,
    figsize  = (4,4),
    cmap     = 'gray',
    info_str = '',
    top      = 0.92,
    hspace   = 0.1
    ):
    """
    Plot a batch of images on a 4x4 grid of subplots.

    :param x_torch: [B x C x W x H]; only the first 16 images are plotted
                    since the grid has 16 cells
    :param figsize: figure size
    :param cmap: colormap handed to imshow
    :param info_str: prefix of the figure title
    :param top: `top` argument of subplots_adjust
    :param hspace: `hspace` argument of subplots_adjust
    """
    n_plot  = min(x_torch.shape[0],16) # the 4x4 grid holds at most 16 images
    # Channel-last layout for imshow; converted once, and via the CPU so that
    # tensors living on other devices can be plotted as well
    imgs_np = x_torch.permute(0,2,3,1).detach().cpu().numpy() # [B x W x H x C]
    plt.figure(figsize=figsize)
    for i in range(n_plot):
        plt.subplot(4,4,i+1)
        plt.imshow(imgs_np[i,:,:,:], cmap=cmap)
        plt.axis('off')
    plt.subplots_adjust(
        left=0.0,right=1.0,bottom=0.0,top=top,wspace=0.0,hspace=hspace)
    plt.suptitle('%s[%d] Images of [%dx%d] sizes'%
                 (info_str,n_plot,x_torch.shape[2],x_torch.shape[3]),fontsize=10)
    plt.show()
def plot_1xN_torch_traj_tensor(
    times,
    x_torch,
    title_str_list = None,
    title_fontsize = 20,
    ylim           = None,
    figsize        = None,
    ):
    """
    Plot the first channel of every trajectory in a 1xN row of subplots.

    :param times: x-axis values handed to plt.plot
    :param x_torch: [B x C x ...] torch tensor; x_torch[b,0,:] is plotted
    :param title_str_list: (optional) one title per trajectory
    :param title_fontsize: font size of the titles
    :param ylim: (optional) y-axis limits applied to every subplot
    :param figsize: (optional) figure size, (2*B,3) if not given
    """
    # Convert once; detach so that tensors requiring grad can be plotted too
    xt_np   = x_torch.detach().cpu().numpy() # [B x C x ...]
    n_trajs = xt_np.shape[0]
    if figsize is None: figsize = (n_trajs*2,3)
    plt.figure(figsize=figsize)
    for traj_idx in range(n_trajs):
        plt.subplot(1,n_trajs,traj_idx+1)
        plt.plot(times,xt_np[traj_idx,0,:],'-',color='k')
        if title_str_list:
            plt.title(title_str_list[traj_idx],fontsize=title_fontsize)
        if ylim:
            plt.ylim(ylim)
    plt.tight_layout()
    plt.show()

def print_model_parameters(model):
    """
    Print the name, shape, and number of elements of every named parameter.

    :param model: object providing named_parameters()
    """
    for p_idx,(param_name,param) in enumerate(model.named_parameters()):
        print ("[%2d] parameter:[%27s] shape:[%12s] numel:[%10d]"%
               (p_idx,
                param_name,
                get_torch_size_string(param),
                param.numel()
               )
              )

def print_model_layers(model,x_torch):
    """
    Run `model` on `x_torch` and print the size of every intermediate output.

    :param model: callable returning (output, intermediate outputs); it must
                  also provide `layer_names`, whose i-th entry names the i-th
                  intermediate output
    :param x_torch: input handed to the model; its first axis is reported as
                    the batch size
    """
    _,intermediate_output_list = model(x_torch)
    batch_size = x_torch.shape[0]
    print ("batch_size:[%d]"%(batch_size))
    print ("[ ] layer:[%15s] size:[%14s]"
           %('input',get_torch_size_string(x_torch))
          )
    for idx,layer_name in enumerate(model.layer_names):
        intermediate_output = intermediate_output_list[idx]
        print ("[%2d] layer:[%15s] size:[%14s] numel:[%10d]"%
               (idx,
                layer_name,
                get_torch_size_string(intermediate_output),
                intermediate_output.numel()
               ))

def _reshape_for_model(model,batch_in):
    """
    Reshape a batch to [-1 x x_dim] if model.x_dim is an int, and to
    [-1 x *x_dim] otherwise (x_dim is then concatenated to a tuple).
    """
    if isinstance(model.x_dim,int):
        return batch_in.view(-1,model.x_dim)
    return batch_in.view((-1,)+model.x_dim)

def model_train(model,optm,loss,train_iter,test_iter,n_epoch,print_every,device):
    """
    Train model

    :param model: model providing init_param(VERBOSE=...), x_dim, and a
                  forward pass returning (prediction, extra)
    :param optm: optimizer (zero_grad() and step() are called)
    :param loss: callable loss(prediction, target) returning a scalar tensor
    :param train_iter: sized iterable of (input, target) training batches
    :param test_iter: iterable of (input, target) batches used for the test accuracy
    :param n_epoch: number of epochs
    :param print_every: print every `print_every` epochs (and at the last epoch)
    :param device: device the batches are moved to
    """
    model.init_param(VERBOSE=False)
    model.train()
    for epoch in range(n_epoch):
        loss_val_sum = 0.0
        for batch_in,batch_out in train_iter:
            # Forward path
            y_pred,_ = model(_reshape_for_model(model,batch_in).to(device))
            loss_out = loss(y_pred,batch_out.to(device))
            # Update
            optm.zero_grad() # reset gradient
            loss_out.backward() # back-propagate loss
            optm.step() # optimizer update
            # Accumulate the Python number, not the tensor: summing the tensor
            # would keep the autograd graph of every batch alive
            loss_val_sum += loss_out.item()
        loss_val_avg = loss_val_sum/len(train_iter)
        # Print
        if ((epoch%print_every)==0) or (epoch==(n_epoch-1)):
            train_accr = model_eval(model,train_iter,device)
            test_accr  = model_eval(model,test_iter,device)
            print ("epoch:[%2d/%d] loss:[%.3f] train_accr:[%.4f] test_accr:[%.4f]."%
                   (epoch,n_epoch,loss_val_avg,train_accr,test_accr))

def model_eval(model,data_iter,device):
    """
    Evaluate model

    :param model: model providing x_dim and a forward pass returning
                  (prediction, extra); the predicted class is the argmax of
                  the prediction along axis 1
    :param data_iter: iterable of (input, target) batches
    :param device: device the batches are moved to
    :return: fraction of correctly classified samples
    Note that the model is left in train mode when this function returns.
    """
    with th.no_grad():
        n_total,n_correct = 0,0
        model.eval() # evaluate (affects DropOut and BN)
        for batch_in,batch_out in data_iter:
            y_trgt       = batch_out.to(device)
            model_pred,_ = model(_reshape_for_model(model,batch_in).to(device))
            _,y_pred     = th.max(model_pred.data,1)
            n_correct   += (y_pred==y_trgt).sum().item()
            n_total     += batch_in.size(0)
        val_accr = (n_correct/n_total)
        model.train() # back to train mode
    return val_accr
def kernel_se(x1,x2,hyp=None):
    """
    Squared-exponential kernel function

    :param x1: [N1 x D] ndarray
    :param x2: [N2 x D] ndarray
    :param hyp: dict with 'gain' and 'len'; {'gain':1.0,'len':1.0} if not given
    :return: [N1 x N2] ndarray, gain*exp(-d) where d is the squared Euclidean
             distance between the inputs divided by 'len'
    """
    if hyp is None: # avoid a mutable default argument
        hyp = {'gain':1.0,'len':1.0}
    D = distance.cdist(x1/hyp['len'],x2/hyp['len'],'sqeuclidean')
    K = hyp['gain']*np.exp(-D)
    return K

def gp_sampler(
    times    = np.linspace(start=0.0,stop=1.0,num=100).reshape((-1,1)), # [L x 1]
    hyp_gain = 1.0,
    hyp_len  = 1.0,
    meas_std = 0.0,
    n_traj   = 1
    ):
    """
    Gaussian process sampling

    :param times: [L x 1] (or [L]) ndarray
    :param hyp_gain: gain of the squared-exponential kernel
    :param hyp_len: length parameter of the squared-exponential kernel
    :param meas_std: standard deviation of the Gaussian noise added to the samples
    :param n_traj: number of trajectories
    :return: [n_traj x L] ndarray
    """
    if len(times.shape) == 1: times = times.reshape((-1,1))
    L = times.shape[0]
    K = kernel_se(times,times,hyp={'gain':hyp_gain,'len':hyp_len}) # [L x L]
    K_chol = np.linalg.cholesky(K+1e-8*np.eye(L,L)) # [L x L], small jitter on the diagonal
    traj = K_chol @ np.random.randn(L,n_traj) # [L x n_traj]
    traj = traj + meas_std*np.random.randn(*traj.shape)
    return traj.T

def hbm_sampler(
    times    = np.linspace(start=0.0,stop=1.0,num=100).reshape((-1,1)), # [L x 1]
    hyp_gain = 1.0,
    hyp_len  = 1.0,
    meas_std = 0.0,
    n_traj   = 1
    ):
    """
    Hilbert Brownian motion sampling

    :param times: [L x 1] (or [L]) ndarray
    :param hyp_gain: gain of the squared-exponential kernel
    :param hyp_len: length parameter of the squared-exponential kernel
    :param meas_std: standard deviation of the Gaussian noise added to the samples
    :param n_traj: number of trajectories
    :return: [n_traj x L] ndarray
    """
    if len(times.shape) == 1: times = times.reshape((-1,1))
    L = times.shape[0]
    K = kernel_se(times,times,hyp={'gain':hyp_gain,'len':hyp_len}) # [L x L]
    K = K + 1e-8*np.eye(L,L)
    U,V = np.linalg.eigh(K,UPLO='L') # eigenvalues U:[L], eigenvectors V:[L x L]
    # Eigenvalues that round-off pushed below zero are clipped before the sqrt
    # (they would turn into NaN); scaling the columns of V equals V @ diag(.)
    M = V * np.sqrt(np.clip(U,0.0,None)) # [L x L]
    traj = M @ np.random.randn(L,n_traj) # [L x n_traj]
    traj = traj + meas_std*np.random.randn(*traj.shape)
    return traj.T

def get_colors(n):
    """
    Get `n` colors by evaluating the Set1 colormap at `n` evenly spaced
    points in [0,1].
    """
    return [plt.cm.Set1(x) for x in np.linspace(0,1,n)]

def periodic_step(times,period,time_offset=0.0,y_min=0.0,y_max=1.0):
    """
    Periodic step signal: `y_max` where mod(times+time_offset,period) is below
    period/2, `y_min` elsewhere.

    :param times: array-like of time stamps
    :param period: period of the signal
    :param time_offset: offset added to `times`
    :param y_min: low level
    :param y_max: high level
    :return: ndarray with the shape of `times`; the dtype of `times` is kept
             when it is a floating-point dtype and float is used otherwise
    """
    times     = np.asarray(times)
    times_mod = np.mod(times+time_offset,period)
    # An integer buffer (e.g., from np.arange) cannot be scaled in place by a
    # float, hence non-float inputs get a float output buffer
    out_dtype = times.dtype if np.issubdtype(times.dtype,np.floating) else float
    y = np.zeros(times.shape,dtype=out_dtype)
    y[times_mod < (period/2)] = 1
    y *= (y_max-y_min)
    y += y_min
    return y
def plot_ddpm_2d_result(
    x_data,step_list,x_t_list,n_plot=1,
    tfs=10
    ):
    """
    Plot the images stored at the given diffusion steps, one figure per sample.

    :param x_data: [N x C x W x H] torch tensor, training data
                   (only its channel size is read)
    :param step_list: [M] ndarray, diffusion steps to plot; each entry is used
                      as an index into x_t_list
    :param x_t_list: list of [n_sample x C x W x H] torch tensors
    :param n_plot: number of samples to plot
    :param tfs: title font size
    """
    is_gray = (x_data.shape[1]==1)
    n_step  = len(step_list)
    for sample_idx in range(n_plot):
        plt.figure(figsize=(15,2))
        for i_idx,t in enumerate(step_list):
            x_t_np = x_t_list[t].detach().cpu().numpy() # [n_sample x C x W x H]
            plt.subplot(1,n_step,i_idx+1)
            if is_gray: # gray image
                plt.imshow(x_t_np[sample_idx,0,:,:], cmap='gray')
            else:
                plt.imshow(x_t_np[sample_idx,:,:,:].transpose(1,2,0))
            plt.axis('off')
            plt.title('Step:[%d]'%(t),fontsize=tfs)
        plt.tight_layout()
        plt.show()


def get_hbm_M(times,hyp_gain=1.0,hyp_len=0.1,device='cpu'):
    """
    Get a matrix M for Hilbert Brownian motion
    :param times: [L x 1] ndarray
    :param hyp_gain: gain of the squared-exponential kernel
    :param hyp_len: length parameter of the squared-exponential kernel
    :param device: device of the returned tensor
    :return: [L x L] torch tensor (float32)
    """
    L = times.shape[0]
    K = kernel_se(times,times,hyp={'gain':hyp_gain,'len':hyp_len}) # [L x L]
    K = K + 1e-8*np.eye(L,L)
    U,V = np.linalg.eigh(K,UPLO='L') # eigenvalues U:[L], eigenvectors V:[L x L]
    # Eigenvalues that round-off pushed below zero are clipped before the sqrt
    # (they would turn into NaN); scaling the columns of V equals V @ diag(.)
    M = V * np.sqrt(np.clip(U,0.0,None))
    M = th.from_numpy(M).to(th.float32).to(device) # [L x L]
    return M

def get_resampling_steps(t_T, j, r,plot_steps=False,figsize=(15,4)):
    """
    Get resampling steps for repaint, inpainting method using diffusion models
    :param t_T: maximum time steps for inpainting
    :param j: jump length
    :param r: the number of resampling
    :param plot_steps: if True, plot the resulting schedule
    :param figsize: figure size of the (optional) plot
    :return: list of time steps that starts at t_T and ends at 0
    """
    # Remaining number of jumps at each time step
    jumps = np.zeros(t_T+1)
    for i in range(1, t_T-j, j):
        jumps[i] = r-1

    # Walk down from t_T; whenever a step has jumps left, go `j` steps back up
    t = t_T+1
    resampling_steps = []
    while t > 1:
        t -= 1
        resampling_steps.append(t)
        if jumps[t] > 0:
            jumps[t] -= 1
            for _ in range(j):
                t += 1
                resampling_steps.append(t)
    resampling_steps.append(0)

    # (optional) plot
    if plot_steps:
        plt.figure(figsize=figsize)
        plt.plot(resampling_steps,'-',color='k',lw=1)
        plt.xlabel('Number of Transitions')
        plt.ylabel('Diffusion time step')
        plt.show()

    # Return
    return resampling_steps
--------------------------------------------------------------------------------