├── .gitignore ├── LICENSE ├── README.md ├── idbm ├── __init__.py ├── idbm.py └── unet │ ├── __init__.py │ ├── fp16_util.py │ ├── layers.py │ └── unet.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stefano Peluchetti 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IDBM - PyTorch 2 | 3 | This repository consists of a self-contained implementation (~500 lines of code, neural network model excluded) of the dataset transfer experiment of: 4 | 5 | [_Diffusion Bridge Mixture Transports, Schrödinger Bridge Problems and Generative Modeling_](https://arxiv.org/abs/2304.00917). 6 | 7 | The following assumptions are made (see the paper, specifically Section 5.4, for more details): 8 | 9 | - the reference process is given by $dX_t = σdW_t$ over $t ∈ [0,1]$ for some scalar $σ ≥ 0$ ; 10 | - the initial dataset is MNIST and the terminal dataset is a subset of EMNIST. 11 | 12 | ## Install 13 | 14 | Having cloned this repository, the recommended installation procedure is as follows: 15 | 16 | ### 1. Create Virtual Environment 17 | 18 | Create a new virtual environment and activate it. 
19 | 20 | For instance, using [(Mini)Conda](https://docs.conda.io/en/latest/miniconda.html): 21 | 22 | ```bash 23 | conda create -n idbm pip 24 | conda activate idbm 25 | ``` 26 | 27 | ### 2. Install PyTorch 28 | 29 | Install the latest appropriate version of PyTorch according to the [official instructions](https://pytorch.org/get-started/locally/). 30 | 31 | ### 3. Install Other Requirements 32 | 33 | Install the remaining requirements: 34 | 35 | ```bash 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Run 40 | 41 | The Python script [`idbm.py`](idbm/idbm.py) accepts the following options: 42 | 43 | ```bash 44 | python idbm.py [FLAGS] 45 | 46 | FLAGS: 47 | --method=METHOD 48 | Default: 'IDBM' 49 | --sigma=SIGMA 50 | Default: 1.0 51 | --iterations=ITERATIONS 52 | Default: 60 53 | --training_steps=TRAINING_STEPS 54 | Default: 5000 55 | --discretization_steps=DISCRETIZATION_STEPS 56 | Default: 30 57 | --batch_dim=BATCH_DIM 58 | Default: 128 59 | --learning_rate=LEARNING_RATE 60 | Default: 0.0001 61 | --grad_max_norm=GRAD_MAX_NORM 62 | Default: 1.0 63 | --ema_decay=EMA_DECAY 64 | Default: 0.999 65 | --cache_steps=CACHE_STEPS 66 | Default: 250 67 | --cache_batch_dim=CACHE_BATCH_DIM 68 | Default: 2560 69 | --test_steps=TEST_STEPS 70 | Default: 5000 71 | --test_batch_dim=TEST_BATCH_DIM 72 | Default: 500 73 | --loss_log_steps=LOSS_LOG_STEPS 74 | Default: 100 75 | --imge_log_steps=IMGE_LOG_STEPS 76 | Default: 1000 77 | ``` 78 | 79 | The findings of the paper are replicated by the following runs: 80 | 81 | ```bash 82 | # IDBM -- Iterated Diffusion Bridge Mixture Transport: 83 | python idbm.py --method=IDBM --sigma=1.0 84 | python idbm.py --method=IDBM --sigma=0.5 85 | python idbm.py --method=IDBM --sigma=0.2 86 | 87 | # BDBM -- Backward Diffusion Bridge Mixture Transport: 88 | python idbm.py --method=IDBM --sigma=1.0 --iterations=1 --training_steps=300000 89 | 90 | # DIPF -- Diffusion Iterated Proportional Fitting Transport: 91 | python idbm.py --method=DIPF --sigma=1.0 92 | python idbm.py --method=DIPF --sigma=0.5 93 | python idbm.py --method=DIPF --sigma=0.2 94 | ``` 95 | 96 | The runs' histories have been persisted on [Weights & Biases](https://wandb.ai/stepelu/pub-idbm-pytorch), to aid reproducibility, analysis and experimentation. 
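For a quick smoke test of the installation, the same flags can be used with reduced settings (the values below are illustrative only, far smaller than those used in the paper, and will not reproduce its results):

```bash
python idbm.py --method=IDBM --sigma=1.0 --iterations=2 --training_steps=500 --cache_steps=100 --test_steps=500 --test_batch_dim=100
```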
97 | 
--------------------------------------------------------------------------------
/idbm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepelu/idbm-pytorch/addae9837bee0b2161082b6bca3cd99f1a41b39b/idbm/__init__.py
--------------------------------------------------------------------------------
/idbm/idbm.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import fire
4 | import numpy as np
5 | import torch as th
6 | import wandb
7 | from rich.console import Console
8 | from rich.progress import (
9 |     BarColumn,
10 |     MofNCompleteColumn,
11 |     Progress,
12 |     TextColumn,
13 |     TimeRemainingColumn,
14 | )
15 | from torch.utils.data import DataLoader, Subset
16 | from torchmetrics.image.fid import FrechetInceptionDistance
17 | from torchvision import datasets, transforms
18 | from torchvision.utils import make_grid
19 | from unet.unet import UNetModel
20 | 
21 | data_path = os.path.expanduser("~/torch-data/")
22 | device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")
23 | 
24 | IDBM = "IDBM"
25 | DIPF = "DIPF"
26 | 
27 | # Routines -----------------------------------------------------------------------------
28 | 
29 | 
30 | # Reference SDE with law R: $dx_t = sigma dw_t$, $σ ≥ 0, t ∈ [0, 1].$
31 | # Shapes: x_t is [B, C, H, W], x is [T, B, C, H, W].
32 | # Fwd and bwd inferred SDEs are formulated on increasing timescales.
33 | # Implementation allows σ = 0.
34 | 
35 | 
36 | def sample_bridge(x_0, x_1, t, sigma):
37 |     t = t[:, None, None, None]
38 |     mean_t = (1.0 - t) * x_0 + t * x_1
39 |     var_t = sigma**2 * t * (1.0 - t)
40 |     z_t = th.randn_like(x_0)
41 |     x_t = mean_t + th.sqrt(var_t) * z_t
42 |     return x_t
43 | 
44 | 
45 | def idbm_target(x_t, x_1, t):
46 |     target_t = (x_1 - x_t) / (1.0 - t[:, None, None, None])
47 |     return target_t
48 | 
49 | 
50 | def drift_target(x_t, x_t_dt, dt):
51 |     dt = th.full(size=[x_t.shape[0], 1, 1, 1], fill_value=dt, device=device)
52 |     target_t = (x_t_dt - x_t) / dt
53 |     return target_t
54 | 
55 | 
56 | def euler_discretization(x, xp, nn, sigma):
57 |     # Assumes x has shape [T, B, C, H, W].
58 |     # Assumes x[0] already initialized.
59 |     # The drift squared norm is normalized by D = C * H * W, not by sigma.
60 |     # Fills x[1] to x[T] and xp[0] to xp[T - 1].
61 |     T = x.shape[0] - 1  # Discretization steps.
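# The loop below is an Euler–Maruyama discretization of dX_t = alpha_t dt + sigma dW_t
# on the uniform grid t_i = i / T:
#     x[i] = x[i - 1] + alpha_t * dt + sigma * sqrt(dt) * eps_t,   eps_t ~ N(0, I).
# xp[i - 1] stores the endpoint predicted from x[i - 1], i.e. x[i - 1] + alpha_t * (1 - t),
# which is what the test loop scores with FID for the IDBM method.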
62 | B = x.shape[1] 63 | dt = th.full(size=(x.shape[1],), fill_value=1.0 / T, device=device) 64 | drift_norms = 0.0 65 | for i in range(1, T + 1): 66 | t = dt * (i - 1) 67 | alpha_t = nn(x[i - 1], None, t) 68 | drift_norms = drift_norms + th.mean(alpha_t.view(B, -1) ** 2, dim=1) 69 | xp[i - 1] = x[i - 1] + alpha_t * (1 - t[:, None, None, None]) 70 | drift_t = alpha_t * dt[:, None, None, None] 71 | eps_t = th.randn_like(x[i - 1]) 72 | diffusion_t = sigma * th.sqrt(dt[:, None, None, None]) * eps_t 73 | x[i] = x[i - 1] + drift_t + diffusion_t 74 | drift_norms = drift_norms / T 75 | return drift_norms.cpu() 76 | 77 | 78 | # Data --------------------------------------------------------------------------------- 79 | 80 | 81 | class EMNIST(datasets.EMNIST): 82 | def __init__(self, **kwargs): 83 | super().__init__(split="letters", **kwargs) 84 | indices = (self.targets <= 5).nonzero(as_tuple=True)[0] 85 | self.data, self.targets = ( 86 | self.data[indices].transpose(1, 2), 87 | self.targets[indices], 88 | ) 89 | 90 | 91 | class InfiniteDataLoader(DataLoader): 92 | def __init__(self, *args, **kwargs): 93 | super().__init__(*args, **kwargs) 94 | self.epoch_iterator = super().__iter__() 95 | 96 | def __iter__(self): 97 | return self 98 | 99 | def __next__(self): 100 | try: 101 | batch = next(self.epoch_iterator) 102 | except StopIteration: 103 | self.epoch_iterator = super().__iter__() 104 | batch = next(self.epoch_iterator) 105 | return batch 106 | 107 | 108 | def train_iter(data, batch_dim): 109 | return iter( 110 | InfiniteDataLoader( 111 | dataset=data, 112 | batch_size=batch_dim, 113 | num_workers=2, 114 | pin_memory=True, 115 | shuffle=True, 116 | drop_last=True, 117 | ) 118 | ) 119 | 120 | 121 | def test_loader(data, batch_dim): 122 | return DataLoader( 123 | dataset=data, 124 | batch_size=batch_dim, 125 | num_workers=2, 126 | pin_memory=True, 127 | shuffle=False, 128 | drop_last=False, 129 | ) 130 | 131 | 132 | def resample_indices(from_n, to_n): 133 | # Equi spaced resampling, first and last element always included. 
134 | return np.round(np.linspace(0, from_n - 1, num=to_n)).astype(int) 135 | 136 | 137 | def image_grid(x, normalize=False, n=5): 138 | img = x[: n**2].cpu() 139 | img = make_grid(img, nrow=n, normalize=normalize, scale_each=normalize) 140 | img = wandb.Image(img) 141 | return img 142 | 143 | 144 | # For fixed permutations of test sets: 145 | rng = np.random.default_rng(seed=0x87351080E25CB0FAD77A44A3BE03B491) 146 | 147 | # Linear scaling to float [-1.0, 1.0]: 148 | transform = transforms.Compose( 149 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 150 | ) 151 | tr_data_0 = datasets.MNIST( 152 | root=data_path + "mnist", train=True, download=True, transform=transform 153 | ) 154 | te_data_0 = datasets.MNIST( 155 | root=data_path + "mnist", train=False, download=True, transform=transform 156 | ) 157 | te_data_0 = Subset(te_data_0, rng.permutation(len(te_data_0))) 158 | tr_data_1 = EMNIST( 159 | root=data_path + "emnist", train=True, download=True, transform=transform 160 | ) 161 | te_data_1 = EMNIST( 162 | root=data_path + "emnist", train=False, download=True, transform=transform 163 | ) 164 | te_data_1 = Subset(te_data_1, rng.permutation(len(te_data_1))) 165 | 166 | 167 | # NN Model ----------------------------------------------------------------------------- 168 | 169 | 170 | def init_nn(): 171 | # From https://github.com/openai/guided-diffusion/tree/main/guided_diffusion: 172 | return UNetModel( 173 | in_channels=1, 174 | model_channels=128, 175 | out_channels=1, 176 | num_res_blocks=2, 177 | attention_resolutions=(), 178 | dropout=0.1, 179 | channel_mult=(0.5, 1, 1), 180 | num_heads=4, 181 | use_scale_shift_norm=True, 182 | temb_scale=1000, 183 | ) 184 | 185 | 186 | class EMAHelper: 187 | # Simplified from https://github.com/ermongroup/ddim/blob/main/models/ema.py: 188 | def __init__(self, module, mu=0.999, device=None): 189 | self.module = module 190 | self.mu = mu 191 | self.device = device 192 | self.shadow = {} 193 | # Register: 194 | for name, param in self.module.named_parameters(): 195 | if param.requires_grad: 196 | self.shadow[name] = param.data.clone() 197 | 198 | def update(self): 199 | for name, param in self.module.named_parameters(): 200 | if param.requires_grad: 201 | self.shadow[name].data = ( 202 | 1.0 - self.mu 203 | ) * param.data + self.mu * self.shadow[name].data 204 | 205 | def ema(self, module): 206 | for name, param in module.named_parameters(): 207 | if param.requires_grad: 208 | param.data.copy_(self.shadow[name].data) 209 | 210 | def ema_copy(self): 211 | locs = self.module.locals 212 | module_copy = type(self.module)(*locs).to(self.device) 213 | module_copy.load_state_dict(self.module.state_dict()) 214 | self.ema(module_copy) 215 | return module_copy 216 | 217 | 218 | # Run ---------------------------------------------------------------------------------- 219 | 220 | 221 | def run( 222 | method=IDBM, 223 | sigma=1.0, 224 | iterations=60, 225 | training_steps=5000, 226 | discretization_steps=30, 227 | batch_dim=128, 228 | learning_rate=1e-4, 229 | grad_max_norm=1.0, 230 | ema_decay=0.999, 231 | cache_steps=250, 232 | cache_batch_dim=2560, 233 | test_steps=5000, 234 | test_batch_dim=500, 235 | loss_log_steps=100, 236 | imge_log_steps=1000, 237 | ): 238 | config = locals() 239 | assert isinstance(sigma, float) and sigma >= 0 240 | assert isinstance(learning_rate, float) and learning_rate > 0 241 | assert isinstance(grad_max_norm, float) and grad_max_norm >= 0 242 | assert method in [IDBM, DIPF] 243 | 244 | console = Console(log_path=False) 245 | 
progress = Progress( 246 | TextColumn("[progress.description]{task.description}"), 247 | BarColumn(), 248 | TimeRemainingColumn(), 249 | TextColumn("•"), 250 | MofNCompleteColumn(), 251 | console=console, 252 | speed_estimate_period=60 * 5, 253 | ) 254 | iteration_t = progress.add_task("iteration", total=iterations) 255 | step_t = progress.add_task("step", total=iterations * training_steps) 256 | 257 | wandb.init(project="idbm-x", config=config) 258 | console.log(wandb.config) 259 | 260 | tr_iter_0 = train_iter(tr_data_0, batch_dim) 261 | tr_iter_1 = train_iter(tr_data_1, batch_dim) 262 | tr_cache_iter_0 = train_iter(tr_data_0, cache_batch_dim) 263 | tr_cache_iter_1 = train_iter(tr_data_1, cache_batch_dim) 264 | te_loader_0 = test_loader(te_data_0, test_batch_dim) 265 | te_loader_1 = test_loader(te_data_1, test_batch_dim) 266 | 267 | bwd_nn = init_nn().to(device) 268 | fwd_nn = init_nn().to(device) 269 | 270 | bwd_ema = EMAHelper(bwd_nn, ema_decay, device) 271 | fwd_ema = EMAHelper(fwd_nn, ema_decay, device) 272 | bwd_sample_nn = bwd_ema.ema_copy() 273 | fwd_sample_nn = fwd_ema.ema_copy() 274 | 275 | bwd_nn.train() 276 | fwd_nn.train() 277 | bwd_sample_nn.eval() 278 | fwd_sample_nn.eval() 279 | 280 | bwd_optim = th.optim.Adam(bwd_nn.parameters(), lr=learning_rate) 281 | fwd_optim = th.optim.Adam(fwd_nn.parameters(), lr=learning_rate) 282 | 283 | dt = 1.0 / discretization_steps 284 | t_T = 1.0 - dt * 0.5 285 | 286 | s_path = th.zeros( 287 | size=(discretization_steps + 1,) + (cache_batch_dim, 1, 28, 28), device=device 288 | ) # i: 0, ..., discretization_steps; t: 0, dt, ..., 1.0. 289 | p_path = th.zeros( 290 | size=(discretization_steps,) + (cache_batch_dim, 1, 28, 28), device=device 291 | ) # i: 0, ..., discretization_steps - 1; t: 0, dt, ..., 1.0 - dt. 292 | 293 | progress.start() 294 | step = 0 295 | for iteration in range(1, iterations + 1): 296 | console.log(f"iteration {iteration}: {step}") 297 | progress.update(iteration_t, completed=iteration) 298 | # Setup: 299 | if (iteration % 2) != 0: 300 | # Odd iteration => bwd. 301 | direction = "bwd" 302 | nn = bwd_nn 303 | ema = bwd_ema 304 | sample_nn = bwd_sample_nn 305 | optim = bwd_optim 306 | te_loader_x_0 = te_loader_1 307 | te_loader_x_1 = te_loader_0 308 | 309 | def sample_idbm_coupling(step): 310 | if iteration == 1: 311 | # Independent coupling: 312 | x_0 = next(tr_iter_1)[0].to(device) 313 | x_1 = next(tr_iter_0)[0].to(device) 314 | else: 315 | with th.no_grad(): 316 | if (step - 1) % cache_steps == 0: 317 | console.log(f"cache update: {step}") 318 | # Simulate previously inferred SDE: 319 | s_path[0] = next(tr_cache_iter_0)[0].to(device) 320 | euler_discretization(s_path, p_path, fwd_sample_nn, sigma) 321 | # Random selection: 322 | idx = th.randperm(cache_batch_dim, device=device)[:batch_dim] 323 | # Reverse path: 324 | x_0, x_1 = s_path[-1, idx], s_path[0, idx] 325 | return x_0, x_1 326 | 327 | def sample_dipf_path(step): 328 | with th.no_grad(): 329 | if (step - 1) % cache_steps == 0: 330 | console.log(f"cache update: {step}") 331 | # Simulate previously inferred SDE: 332 | # NN initialized at 0.0 => first iteration == refence SDE. 333 | s_path[0] = next(tr_cache_iter_0)[0].to(device) 334 | euler_discretization(s_path, p_path, fwd_sample_nn, sigma) 335 | # Random selection: 336 | idx = th.randperm(cache_batch_dim, device=device)[:batch_dim] 337 | # Reverse path: 338 | x_path = th.flip(s_path[:, idx], [0]) 339 | return x_path 340 | 341 | else: 342 | # Even iteration => fwd. 
343 | direction = "fwd" 344 | nn = fwd_nn 345 | ema = fwd_ema 346 | sample_nn = fwd_sample_nn 347 | optim = fwd_optim 348 | te_loader_x_0 = te_loader_0 349 | te_loader_x_1 = te_loader_1 350 | 351 | def sample_idbm_coupling(step): 352 | with th.no_grad(): 353 | if (step - 1) % cache_steps == 0: 354 | console.log(f"cache update: {step}") 355 | # Simulate previously inferred SDE: 356 | s_path[0] = next(tr_cache_iter_1)[0].to(device) 357 | euler_discretization(s_path, p_path, bwd_sample_nn, sigma) 358 | # Random selection: 359 | idx = th.randperm(cache_batch_dim, device=device)[:batch_dim] 360 | # Reverse path: 361 | x_0, x_1 = s_path[-1, idx], s_path[0, idx] 362 | return x_0, x_1 363 | 364 | def sample_dipf_path(step): 365 | with th.no_grad(): 366 | if (step - 1) % cache_steps == 0: 367 | console.log(f"cache update: {step}") 368 | # Simulate previously inferred SDE: 369 | s_path[0] = next(tr_cache_iter_1)[0].to(device) 370 | euler_discretization(s_path, p_path, bwd_sample_nn, sigma) 371 | # Random selection: 372 | idx = th.randperm(cache_batch_dim, device=device)[:batch_dim] 373 | # Reverse path: 374 | x_path = th.flip(s_path[:, idx], [0]) 375 | return x_path 376 | 377 | for step in range(step + 1, step + training_steps + 1): 378 | progress.update(step_t, completed=step) 379 | optim.zero_grad() 380 | 381 | if method == IDBM: 382 | x_0, x_1 = sample_idbm_coupling(step) 383 | t = th.rand(size=(batch_dim,), device=device) * t_T 384 | x_t = sample_bridge(x_0, x_1, t, sigma) 385 | target_t = idbm_target(x_t, x_1, t) 386 | elif method == DIPF: 387 | x_path = sample_dipf_path(step) 388 | t_i = th.randint( 389 | 0, discretization_steps, size=(batch_dim,), device=device 390 | ) 391 | t = t_i.to(th.float32) * dt 392 | x_t = th.stack([x_path[ti, i] for i, ti in enumerate(t_i)]) 393 | x_t_dt = th.stack([x_path[ti, i] for i, ti in enumerate(t_i + 1)]) 394 | target_t = drift_target(x_t, x_t_dt, dt) 395 | 396 | alpha_t = nn(x_t, None, t) 397 | losses = (target_t - alpha_t) ** 2 398 | losses = th.mean(losses.view(losses.shape[0], -1), dim=1) 399 | if method == DIPF: 400 | losses = losses / sigma**2 401 | loss = th.mean(losses) 402 | 403 | loss.backward() 404 | if grad_max_norm > 0: 405 | grad_norm = th.nn.utils.clip_grad_norm_(nn.parameters(), grad_max_norm) 406 | optim.step() 407 | 408 | ema.update() 409 | 410 | if step % test_steps == 0: 411 | console.log(f"test: {step}") 412 | ema.ema(sample_nn) 413 | with th.no_grad(): 414 | te_s_path = th.zeros( 415 | size=(discretization_steps + 1,) + (test_batch_dim, 1, 28, 28), 416 | device=device, 417 | ) 418 | te_p_path = th.zeros( 419 | size=(discretization_steps,) + (test_batch_dim, 1, 28, 28), 420 | device=device, 421 | ) 422 | # Assumes data is in [0.0, 1.0], scale appropriately: 423 | fid_metric = FrechetInceptionDistance(normalize=True).to(device) 424 | drift_norm = [] 425 | for te_x_0, te_x_1 in zip(te_loader_x_0, te_loader_x_1): 426 | te_x_0, te_x_1 = te_x_0[0].to(device), te_x_1[0].to(device) 427 | te_x_1 = (te_x_1 + 1.0) / 2.0 428 | te_s_path[0] = te_x_0 429 | drift_norm.append( 430 | euler_discretization(te_s_path, te_p_path, sample_nn, sigma) 431 | ) 432 | te_s_path = th.clip((te_s_path + 1.0) / 2.0, 0.0, 1.0) 433 | te_p_path = th.clip((te_p_path + 1.0) / 2.0, 0.0, 1.0) 434 | fid_metric.update(te_x_1.expand(-1, 3, -1, -1), real=True) 435 | if method == IDBM: 436 | fid_idx = -2 437 | elif method == DIPF: 438 | fid_idx = -1 439 | fid_metric.update( 440 | te_p_path[fid_idx].expand(-1, 3, -1, -1), real=False 441 | ) 442 | drift_norm = 
th.mean(th.cat(drift_norm)).item() 443 | fid = fid_metric.compute().item() 444 | wandb.log({f"{direction}/test/drift_norm": drift_norm}, step=step) 445 | wandb.log({f"{direction}/test/fid": fid}, step=step) 446 | for i, ti in enumerate( 447 | resample_indices(discretization_steps + 1, 5) 448 | ): 449 | wandb.log( 450 | {f"{direction}/test/x[{i}-{5}]": image_grid(te_s_path[ti])}, 451 | step=step, 452 | ) 453 | for i, ti in enumerate(resample_indices(discretization_steps, 5)): 454 | wandb.log( 455 | {f"{direction}/test/p[{i}-{5}]": image_grid(te_p_path[ti])}, 456 | step=step, 457 | ) 458 | 459 | if step % loss_log_steps == 0: 460 | wandb.log({f"{direction}/train/loss": loss.item()}, step=step) 461 | wandb.log({f"{direction}/train/grad_norm": grad_norm}, step=step) 462 | 463 | if step % imge_log_steps == 0: 464 | if method == DIPF: 465 | x_0 = x_path[0] 466 | x_1 = x_path[-1] 467 | wandb.log({f"{direction}/train/x_0": image_grid(x_0, True)}, step=step) 468 | wandb.log({f"{direction}/train/x_1": image_grid(x_1, True)}, step=step) 469 | 470 | if step % training_steps == 0: 471 | console.log(f"EMA update: {step}") 472 | # Make sure EMA is updated at the end of each iteration: 473 | ema.ema(sample_nn) 474 | progress.stop() 475 | 476 | 477 | if __name__ == "__main__": 478 | fire.Fire(run) 479 | -------------------------------------------------------------------------------- /idbm/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import * 2 | 3 | -------------------------------------------------------------------------------- /idbm/unet/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 
56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() -------------------------------------------------------------------------------- /idbm/unet/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from abc import abstractmethod 4 | import torch as th 5 | import torch.nn as nn 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | class GroupNorm32(nn.GroupNorm): 11 | def forward(self, x): 12 | return super().forward(x.float()).type(x.dtype) 13 | 14 | 15 | def conv_nd(dims, *args, **kwargs): 16 | """ 17 | Create a 1D, 2D, or 3D convolution module. 18 | """ 19 | if dims == 1: 20 | return nn.Conv1d(*args, **kwargs) 21 | elif dims == 2: 22 | return nn.Conv2d(*args, **kwargs) 23 | elif dims == 3: 24 | return nn.Conv3d(*args, **kwargs) 25 | raise ValueError(f"unsupported dimensions: {dims}") 26 | 27 | 28 | def linear(*args, **kwargs): 29 | """ 30 | Create a linear module. 31 | """ 32 | return nn.Linear(*args, **kwargs) 33 | 34 | 35 | def avg_pool_nd(dims, *args, **kwargs): 36 | """ 37 | Create a 1D, 2D, or 3D average pooling module. 38 | """ 39 | if dims == 1: 40 | return nn.AvgPool1d(*args, **kwargs) 41 | elif dims == 2: 42 | return nn.AvgPool2d(*args, **kwargs) 43 | elif dims == 3: 44 | return nn.AvgPool3d(*args, **kwargs) 45 | raise ValueError(f"unsupported dimensions: {dims}") 46 | 47 | 48 | def update_ema(target_params, source_params, rate=0.99): 49 | """ 50 | Update target parameters to be closer to those of source parameters using 51 | an exponential moving average. 52 | :param target_params: the target parameter sequence. 53 | :param source_params: the source parameter sequence. 54 | :param rate: the EMA rate (closer to 1 means slower). 55 | """ 56 | for targ, src in zip(target_params, source_params): 57 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 58 | 59 | 60 | def zero_module(module, active=True): 61 | """ 62 | Zero out the parameters of a module and return it. 63 | """ 64 | if active: 65 | for p in module.parameters(): 66 | p.detach().zero_() 67 | return module 68 | 69 | 70 | def scale_module(module, scale): 71 | """ 72 | Scale the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().mul_(scale) 76 | return module 77 | 78 | 79 | def mean_flat(tensor): 80 | """ 81 | Take the mean over all non-batch dimensions. 82 | """ 83 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 84 | 85 | 86 | def normalization(channels, num_groups=32): 87 | """ 88 | Make a standard normalization layer. 89 | :param channels: number of input channels. 90 | :return: an nn.Module for normalization. 
91 | """ 92 | return GroupNorm32(num_groups, channels) 93 | # return nn.GroupNorm(32, channels) 94 | 95 | def normalization_act(channels, num_groups=32): 96 | return nn.Sequential(GroupNorm32(num_groups, channels), nn.SiLU(inplace=True)) 97 | 98 | def timestep_embedding(timesteps, dim, max_period=10000): 99 | """ 100 | Create sinusoidal timestep embeddings. 101 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 102 | These may be fractional. 103 | :param dim: the dimension of the output. 104 | :param max_period: controls the minimum frequency of the embeddings. 105 | :return: an [N x dim] Tensor of positional embeddings. 106 | """ 107 | half = dim // 2 108 | freqs = th.exp( 109 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / (half - 1) 110 | ).to(device=timesteps.device) 111 | args = timesteps[:, None].float() * freqs[None] 112 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 113 | if dim % 2: 114 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 115 | return embedding 116 | 117 | 118 | def checkpoint(func, inputs, params, flag): 119 | """ 120 | Evaluate a function without caching intermediate activations, allowing for 121 | reduced memory at the expense of extra compute in the backward pass. 122 | :param func: the function to evaluate. 123 | :param inputs: the argument sequence to pass to `func`. 124 | :param params: a sequence of parameters `func` depends on but does not 125 | explicitly take as arguments. 126 | :param flag: if False, disable gradient checkpointing. 127 | """ 128 | if flag: 129 | args = tuple(inputs) + tuple(params) 130 | return CheckpointFunction.apply(func, len(inputs), *args) 131 | else: 132 | return func(*inputs) 133 | 134 | 135 | class CheckpointFunction(th.autograd.Function): 136 | @staticmethod 137 | def forward(ctx, run_function, length, *args): 138 | ctx.run_function = run_function 139 | ctx.input_tensors = list(args[:length]) 140 | ctx.input_params = list(args[length:]) 141 | with th.no_grad(): 142 | output_tensors = ctx.run_function(*ctx.input_tensors) 143 | return output_tensors 144 | 145 | @staticmethod 146 | def backward(ctx, *output_grads): 147 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 148 | with th.enable_grad(): 149 | # Fixes a bug where the first op in run_function modifies the 150 | # Tensor storage in place, which is not allowed for detach()'d 151 | # Tensors. 152 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 153 | output_tensors = ctx.run_function(*shallow_copies) 154 | input_grads = th.autograd.grad( 155 | output_tensors, 156 | ctx.input_tensors + ctx.input_params, 157 | output_grads, 158 | allow_unused=True, 159 | ) 160 | del ctx.input_tensors 161 | del ctx.input_params 162 | del output_tensors 163 | return (None, None) + input_grads 164 | 165 | 166 | class TimestepBlock(nn.Module): 167 | """ 168 | Any module where forward() takes timestep embeddings as a second argument. 169 | """ 170 | 171 | @abstractmethod 172 | def forward(self, x, emb): 173 | """ 174 | Apply the module to `x` given `emb` timestep embeddings. 175 | """ 176 | 177 | 178 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 179 | """ 180 | A sequential module that passes timestep embeddings to the children that 181 | support it as an extra input. 
182 | """ 183 | 184 | def forward(self, x, emb): 185 | for layer in self: 186 | if isinstance(layer, TimestepBlock): 187 | x = layer(x, emb) 188 | else: 189 | x = layer(x) 190 | return x 191 | 192 | 193 | class Upsample(nn.Module): 194 | """ 195 | An upsampling layer with an optional convolution. 196 | :param channels: channels in the inputs and outputs. 197 | :param use_conv: a bool determining if a convolution is applied. 198 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 199 | upsampling occurs in the inner-two dimensions. 200 | """ 201 | 202 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 203 | super().__init__() 204 | self.channels = channels 205 | self.out_channels = out_channels or channels 206 | self.use_conv = use_conv 207 | self.dims = dims 208 | if use_conv: 209 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 210 | 211 | def forward(self, x): 212 | assert x.shape[1] == self.channels 213 | if self.dims == 3: 214 | x = F.interpolate( 215 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 216 | ) 217 | else: 218 | x = F.interpolate(x, scale_factor=2, mode="nearest") 219 | if self.use_conv: 220 | x = self.conv(x) 221 | return x 222 | 223 | 224 | class Downsample(nn.Module): 225 | """ 226 | A downsampling layer with an optional convolution. 227 | :param channels: channels in the inputs and outputs. 228 | :param use_conv: a bool determining if a convolution is applied. 229 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 230 | downsampling occurs in the inner-two dimensions. 231 | """ 232 | 233 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 234 | super().__init__() 235 | self.channels = channels 236 | self.out_channels = out_channels or channels 237 | self.use_conv = use_conv 238 | self.dims = dims 239 | stride = 2 if dims != 3 else (1, 2, 2) 240 | if use_conv: 241 | self.op = conv_nd( 242 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 243 | ) 244 | else: 245 | assert self.channels == self.out_channels 246 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 247 | 248 | def forward(self, x): 249 | assert x.shape[1] == self.channels 250 | return self.op(x) 251 | 252 | 253 | class ResBlock(TimestepBlock): 254 | """ 255 | A residual block that can optionally change the number of channels. 256 | :param channels: the number of input channels. 257 | :param emb_channels: the number of timestep embedding channels. 258 | :param dropout: the rate of dropout. 259 | :param out_channels: if specified, the number of out channels. 260 | :param use_conv: if True and out_channels is specified, use a spatial 261 | convolution instead of a smaller 1x1 convolution to change the 262 | channels in the skip connection. 263 | :param dims: determines if the signal is 1D, 2D, or 3D. 264 | :param use_checkpoint: if True, use gradient checkpointing on this module. 265 | :param up: if True, use this block for upsampling. 266 | :param down: if True, use this block for downsampling. 
267 | """ 268 | 269 | def __init__( 270 | self, 271 | channels, 272 | emb_channels, 273 | dropout, 274 | out_channels=None, 275 | use_conv=False, 276 | use_scale_shift_norm=False, 277 | dims=2, 278 | use_checkpoint=False, 279 | up=False, 280 | down=False, 281 | num_groups=32 282 | ): 283 | super().__init__() 284 | self.channels = channels 285 | self.emb_channels = emb_channels 286 | self.dropout = dropout 287 | self.out_channels = out_channels or channels 288 | self.use_conv = use_conv 289 | self.use_checkpoint = use_checkpoint 290 | self.use_scale_shift_norm = use_scale_shift_norm 291 | 292 | self.in_layers = nn.Sequential( 293 | normalization(channels, num_groups), 294 | nn.SiLU(inplace=True), 295 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 296 | ) 297 | 298 | self.updown = up or down 299 | 300 | if up: 301 | self.h_upd = Upsample(channels, False, dims) 302 | self.x_upd = Upsample(channels, False, dims) 303 | elif down: 304 | self.h_upd = Downsample(channels, False, dims) 305 | self.x_upd = Downsample(channels, False, dims) 306 | else: 307 | self.h_upd = self.x_upd = nn.Identity() 308 | 309 | self.emb_layers = nn.Sequential( 310 | nn.SiLU(), 311 | linear( 312 | emb_channels, 313 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 314 | ), 315 | ) 316 | self.out_layers = nn.Sequential( 317 | normalization(self.out_channels, num_groups), 318 | nn.SiLU(inplace=True), 319 | nn.Dropout(p=dropout, inplace=True), 320 | zero_module( 321 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 322 | ), 323 | ) 324 | 325 | if self.out_channels == channels: 326 | self.skip_connection = nn.Identity() 327 | elif use_conv: 328 | self.skip_connection = conv_nd( 329 | dims, channels, self.out_channels, 3, padding=1 330 | ) 331 | else: 332 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 333 | 334 | def forward(self, x, emb): 335 | """ 336 | Apply the block to a Tensor, conditioned on a timestep embedding. 337 | :param x: an [N x C x ...] Tensor of features. 338 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 339 | :return: an [N x C x ...] Tensor of outputs. 340 | """ 341 | return checkpoint( 342 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 343 | ) 344 | 345 | def _forward(self, x, emb): 346 | if self.updown: 347 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 348 | h = in_rest(x) 349 | h = self.h_upd(h) 350 | x = self.x_upd(x) 351 | h = in_conv(h) 352 | else: 353 | h = self.in_layers(x) 354 | emb_out = self.emb_layers(emb).type(h.dtype) 355 | while len(emb_out.shape) < len(h.shape): 356 | emb_out = emb_out[..., None] 357 | if self.use_scale_shift_norm: 358 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 359 | scale, shift = th.chunk(emb_out, 2, dim=1) 360 | h = out_norm(h) * (1 + scale) + shift 361 | h = out_rest(h) 362 | else: 363 | h = h + emb_out 364 | h = self.out_layers(h) 365 | return self.skip_connection(x) + h 366 | 367 | 368 | class AttentionBlock(nn.Module): 369 | """ 370 | An attention block that allows spatial positions to attend to each other. 371 | Originally ported from here, but adapted to the N-d case. 372 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
373 | """ 374 | 375 | def __init__(self, channels, num_heads=1, use_checkpoint=False, num_groups=32): 376 | super().__init__() 377 | self.channels = channels 378 | self.num_heads = num_heads 379 | self.use_checkpoint = use_checkpoint 380 | 381 | self.norm = normalization(channels, num_groups) 382 | self.qkv = conv_nd(1, channels, channels * 3, 1) 383 | self.attention = QKVAttention() 384 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 385 | 386 | def forward(self, x): 387 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 388 | 389 | def _forward(self, x): 390 | b, c, *spatial = x.shape 391 | x = x.reshape(b, c, -1) 392 | qkv = self.qkv(self.norm(x)) 393 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 394 | h = self.attention(qkv) 395 | h = h.reshape(b, -1, h.shape[-1]) 396 | h = self.proj_out(h) 397 | return (x + h).reshape(b, c, *spatial) 398 | 399 | 400 | class QKVAttention(nn.Module): 401 | """ 402 | A module which performs QKV attention. 403 | """ 404 | 405 | def forward(self, qkv): 406 | """ 407 | Apply QKV attention. 408 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 409 | :return: an [N x C x T] tensor after attention. 410 | """ 411 | ch = qkv.shape[1] // 3 412 | q, k, v = th.split(qkv, ch, dim=1) 413 | scale = 1 / math.sqrt(math.sqrt(ch)) 414 | weight = th.einsum( 415 | "bct,bcs->bts", q * scale, k * scale 416 | ) # More stable with f16 than dividing afterwards 417 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 418 | return th.einsum("bts,bcs->bct", weight, v) 419 | 420 | @staticmethod 421 | def count_flops(model, _x, y): 422 | """ 423 | A counter for the `thop` package to count the operations in an 424 | attention operation. 425 | Meant to be used like: 426 | macs, params = thop.profile( 427 | model, 428 | inputs=(inputs, timestamps), 429 | custom_ops={QKVAttention: QKVAttention.count_flops}, 430 | ) 431 | """ 432 | b, c, *spatial = y[0].shape 433 | num_spatial = int(np.prod(spatial)) 434 | # We perform two matmuls with the same number of ops. 435 | # The first computes the weight matrix, the second computes 436 | # the combination of the value vectors. 437 | matmul_ops = 2 * b * (num_spatial ** 2) * c 438 | model.total_ops += th.DoubleTensor([matmul_ops]) 439 | 440 | 441 | class BasicResBlock(nn.Module): 442 | """ 443 | A residual block that can optionally change the number of channels. 444 | :param channels: the number of input channels. 445 | :param dropout: the rate of dropout. 446 | :param out_channels: if specified, the number of out channels. 447 | :param use_conv: if True and out_channels is specified, use a spatial 448 | convolution instead of a smaller 1x1 convolution to change the 449 | channels in the skip connection. 450 | :param dims: determines if the signal is 1D, 2D, or 3D. 451 | :param use_checkpoint: if True, use gradient checkpointing on this module. 452 | :param up: if True, use this block for upsampling. 453 | :param down: if True, use this block for downsampling. 
454 | """ 455 | 456 | def __init__( 457 | self, 458 | channels, 459 | dropout, 460 | out_channels=None, 461 | use_conv=False, 462 | dims=2, 463 | use_checkpoint=False, 464 | up=False, 465 | down=False, 466 | ): 467 | super().__init__() 468 | self.channels = channels 469 | self.dropout = dropout 470 | self.out_channels = out_channels or channels 471 | self.use_conv = use_conv 472 | self.use_checkpoint = use_checkpoint 473 | 474 | self.in_layers = nn.Sequential( 475 | normalization(channels), 476 | nn.SiLU(inplace=True), 477 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 478 | ) 479 | 480 | self.updown = up or down 481 | 482 | if up: 483 | self.h_upd = Upsample(channels, False, dims) 484 | self.x_upd = Upsample(channels, False, dims) 485 | elif down: 486 | self.h_upd = Downsample(channels, False, dims) 487 | self.x_upd = Downsample(channels, False, dims) 488 | else: 489 | self.h_upd = self.x_upd = nn.Identity() 490 | 491 | self.out_layers = nn.Sequential( 492 | normalization(self.out_channels), 493 | nn.SiLU(inplace=True), 494 | nn.Dropout(p=dropout, inplace=True), 495 | zero_module( 496 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 497 | ), 498 | ) 499 | 500 | if self.out_channels == channels: 501 | self.skip_connection = nn.Identity() 502 | elif use_conv: 503 | self.skip_connection = conv_nd( 504 | dims, channels, self.out_channels, 3, padding=1 505 | ) 506 | else: 507 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 508 | 509 | def forward(self, x): 510 | """ 511 | Apply the block to a Tensor. 512 | :param x: an [N x C x ...] Tensor of features. 513 | :return: an [N x C x ...] Tensor of outputs. 514 | """ 515 | return checkpoint( 516 | self._forward, (x, ), self.parameters(), self.use_checkpoint 517 | ) 518 | 519 | def _forward(self, x): 520 | if self.updown: 521 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 522 | h = in_rest(x) 523 | h = self.h_upd(h) 524 | x = self.x_upd(x) 525 | h = in_conv(h) 526 | else: 527 | h = self.in_layers(x) 528 | 529 | h = self.out_layers(h) 530 | return self.skip_connection(x) + h 531 | 532 | def expand_dims(t, target_len): 533 | assert target_len >= len(t.shape) 534 | out = t[(..., ) + (None, ) * (target_len - len(t.shape))] 535 | return out -------------------------------------------------------------------------------- /idbm/unet/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .layers import * 12 | 13 | 14 | class UNetModel(nn.Module): 15 | """ 16 | The full UNet model with attention and timestep embedding. 17 | :param in_channels: channels in the input Tensor. 18 | :param model_channels: base channel count for the model. 19 | :param out_channels: channels in the output Tensor. 20 | :param num_res_blocks: number of residual blocks per downsample. 21 | :param attention_resolutions: a collection of downsample rates at which 22 | attention will take place. May be a set, list, or tuple. 23 | For example, if this contains 4, then at 4x downsampling, attention 24 | will be used. 25 | :param dropout: the dropout probability. 26 | :param channel_mult: channel multiplier for each level of the UNet. 27 | :param conv_resample: if True, use learned convolutions for upsampling and 28 | downsampling. 
29 | :param dims: determines if the signal is 1D, 2D, or 3D. 30 | :param num_classes: if specified (as an int), then this model will be 31 | class-conditional with `num_classes` classes. 32 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 33 | :param num_heads: the number of attention heads in each attention layer. 34 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 35 | :param resblock_updown: use residual blocks for up/downsampling. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | in_channels, 41 | model_channels, 42 | out_channels, 43 | num_res_blocks, 44 | attention_resolutions, 45 | dropout=0, 46 | channel_mult=(1, 2, 4, 8), 47 | conv_resample=True, 48 | dims=2, 49 | num_classes=None, 50 | use_checkpoint=False, 51 | num_heads=1, 52 | use_scale_shift_norm=False, 53 | resblock_updown=False, 54 | temb_scale=1 55 | ): 56 | super().__init__() 57 | 58 | self.locals = [ in_channels, 59 | model_channels, 60 | out_channels, 61 | num_res_blocks, 62 | attention_resolutions, 63 | dropout, 64 | channel_mult, 65 | conv_resample, 66 | dims, 67 | num_classes, 68 | use_checkpoint, 69 | num_heads, 70 | use_scale_shift_norm, 71 | resblock_updown, 72 | temb_scale 73 | ] 74 | self.in_channels = in_channels 75 | self.model_channels = model_channels 76 | self.out_channels = out_channels 77 | self.num_res_blocks = num_res_blocks 78 | self.attention_resolutions = attention_resolutions 79 | self.dropout = dropout 80 | self.channel_mult = channel_mult 81 | self.conv_resample = conv_resample 82 | self.num_classes = num_classes 83 | self.use_checkpoint = use_checkpoint 84 | self.num_heads = num_heads 85 | self.temb_scale = temb_scale 86 | 87 | # some hacky logic to allow small unets 88 | if self.model_channels <= 32: 89 | self.num_groups = 8 90 | else: 91 | self.num_groups = 32 92 | 93 | self.input_ch = int(channel_mult[0] * model_channels) 94 | ch = self.input_ch 95 | 96 | time_embed_dim = self.input_ch * 4 97 | self.time_embed = nn.Sequential( 98 | linear(self.input_ch, time_embed_dim), 99 | nn.SiLU(inplace=True), 100 | linear(time_embed_dim, time_embed_dim), 101 | ) 102 | 103 | if self.num_classes is not None: 104 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 105 | 106 | self.input_blocks = nn.ModuleList( 107 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 108 | ) 109 | self._feature_size = ch 110 | input_block_chans = [ch] 111 | ds = 1 112 | for level, mult in enumerate(channel_mult): 113 | for _ in range(num_res_blocks): 114 | layers = [ 115 | ResBlock( 116 | ch, 117 | time_embed_dim, 118 | dropout, 119 | out_channels=int(mult * model_channels), 120 | dims=dims, 121 | use_checkpoint=use_checkpoint, 122 | use_scale_shift_norm=use_scale_shift_norm, 123 | num_groups=self.num_groups 124 | ) 125 | ] 126 | ch = int(mult * model_channels) 127 | if ds in attention_resolutions: 128 | layers.append( 129 | AttentionBlock( 130 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_groups=self.num_groups 131 | ) 132 | ) 133 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 134 | self._feature_size += ch 135 | input_block_chans.append(ch) 136 | if level != len(channel_mult) - 1: 137 | out_ch = ch 138 | self.input_blocks.append( 139 | TimestepEmbedSequential( 140 | ResBlock( 141 | ch, 142 | time_embed_dim, 143 | dropout, 144 | out_channels=out_ch, 145 | dims=dims, 146 | use_checkpoint=use_checkpoint, 147 | use_scale_shift_norm=use_scale_shift_norm, 148 | down=True, 149 | num_groups=self.num_groups 150 | ) 151 | if 
resblock_updown 152 | else Downsample( 153 | ch, conv_resample, dims=dims, out_channels=out_ch 154 | ) 155 | ) 156 | ) 157 | input_block_chans.append(ch) 158 | ds *= 2 159 | self._feature_size += ch 160 | 161 | self.middle_block = TimestepEmbedSequential( 162 | ResBlock( 163 | ch, 164 | time_embed_dim, 165 | dropout, 166 | dims=dims, 167 | use_checkpoint=use_checkpoint, 168 | use_scale_shift_norm=use_scale_shift_norm, 169 | num_groups=self.num_groups 170 | ), 171 | AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_groups=self.num_groups), 172 | ResBlock( 173 | ch, 174 | time_embed_dim, 175 | dropout, 176 | dims=dims, 177 | use_checkpoint=use_checkpoint, 178 | use_scale_shift_norm=use_scale_shift_norm, 179 | num_groups=self.num_groups 180 | ), 181 | ) 182 | self._feature_size += ch 183 | 184 | self.output_blocks = nn.ModuleList([]) 185 | for level, mult in list(enumerate(channel_mult))[::-1]: 186 | for i in range(num_res_blocks + 1): 187 | ich = input_block_chans.pop() 188 | layers = [ 189 | ResBlock( 190 | ch + ich, 191 | time_embed_dim, 192 | dropout, 193 | out_channels=int(model_channels * mult), 194 | dims=dims, 195 | use_checkpoint=use_checkpoint, 196 | use_scale_shift_norm=use_scale_shift_norm, 197 | num_groups=self.num_groups 198 | ) 199 | ] 200 | ch = int(model_channels * mult) 201 | if ds in attention_resolutions: 202 | layers.append( 203 | AttentionBlock( 204 | ch, 205 | use_checkpoint=use_checkpoint, 206 | num_heads=num_heads, 207 | num_groups=self.num_groups 208 | ) 209 | ) 210 | if level and i == num_res_blocks: 211 | out_ch = ch 212 | layers.append( 213 | ResBlock( 214 | ch, 215 | time_embed_dim, 216 | dropout, 217 | out_channels=out_ch, 218 | dims=dims, 219 | use_checkpoint=use_checkpoint, 220 | use_scale_shift_norm=use_scale_shift_norm, 221 | up=True, 222 | num_groups=self.num_groups 223 | ) 224 | if resblock_updown 225 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 226 | ) 227 | ds //= 2 228 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 229 | self._feature_size += ch 230 | 231 | self.out = nn.Sequential( 232 | normalization(ch, self.num_groups), 233 | nn.SiLU(inplace=True), 234 | zero_module(conv_nd(dims, self.input_ch, out_channels, 3, padding=1)), 235 | ) 236 | 237 | 238 | def convert_to_fp16(self): 239 | """ 240 | Convert the torso of the model to float16. 241 | """ 242 | self.input_blocks.apply(convert_module_to_f16) 243 | self.middle_block.apply(convert_module_to_f16) 244 | self.output_blocks.apply(convert_module_to_f16) 245 | 246 | def convert_to_fp32(self): 247 | """ 248 | Convert the torso of the model to float32. 249 | """ 250 | self.input_blocks.apply(convert_module_to_f32) 251 | self.middle_block.apply(convert_module_to_f32) 252 | self.output_blocks.apply(convert_module_to_f32) 253 | 254 | 255 | def forward(self, x, y, timesteps): 256 | 257 | """ 258 | Apply the model to an input batch. 259 | :param x: an [N x C x ...] Tensor of inputs. 260 | :param timesteps: a 1-D batch of timesteps. 261 | :param y: an [N] Tensor of labels, if class-conditional. 262 | :return: an [N x C x ...] Tensor of outputs. 
263 | """ 264 | timesteps = timesteps.squeeze() 265 | assert (y is not None) == ( 266 | self.num_classes is not None 267 | ), "must specify y if and only if the model is class-conditional" 268 | 269 | hs = [] 270 | emb = self.time_embed(timestep_embedding(timesteps * self.temb_scale + 1, self.input_ch)) 271 | 272 | if self.num_classes is not None: 273 | assert y.shape == (x.shape[0],) 274 | emb = emb + self.label_emb(y) 275 | 276 | h = x # .type(self.dtype) 277 | for module in self.input_blocks: 278 | h = module(h, emb) 279 | hs.append(h) 280 | h = self.middle_block(h, emb) 281 | for module in self.output_blocks: 282 | h = th.cat([h, hs.pop()], dim=1) 283 | h = module(h, emb) 284 | h = h.type(x.dtype) 285 | return self.out(h) 286 | 287 | 288 | class SuperResModel(UNetModel): 289 | """ 290 | A UNetModel that performs super-resolution. 291 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 292 | """ 293 | 294 | def __init__(self, in_channels, cond_channels, *args, **kwargs): 295 | super().__init__(in_channels + cond_channels, *args, **kwargs) 296 | self.locals[0] = in_channels 297 | self.locals.insert(1, cond_channels) 298 | 299 | def forward(self, x, low_res, timesteps, **kwargs): 300 | _, _, new_height, new_width = x.shape 301 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 302 | x = th.cat([x, upsampled], dim=1) 303 | return super().forward(x, None, timesteps, **kwargs) 304 | 305 | 306 | class DownscalerUNetModel(nn.Module): 307 | def __init__( 308 | self, 309 | in_channels, 310 | cond_channels, 311 | model_channels, 312 | out_channels, 313 | num_res_blocks, 314 | dropout=0, 315 | channel_mult=(1, 2, 4, 8), 316 | dims=2, 317 | temb_scale=1, 318 | mean_bypass=False, 319 | scale_mean_bypass=False, 320 | shift_input=False, 321 | shift_output=False, 322 | **kwargs 323 | ): 324 | super().__init__() 325 | 326 | self.locals = [ in_channels, 327 | cond_channels, 328 | model_channels, 329 | out_channels, 330 | num_res_blocks, 331 | dropout, 332 | channel_mult, 333 | dims, 334 | temb_scale, 335 | mean_bypass, 336 | scale_mean_bypass, 337 | shift_input, 338 | shift_output 339 | ] 340 | 341 | in_channels = in_channels + cond_channels 342 | self.in_channels = in_channels 343 | self.model_channels = model_channels 344 | self.out_channels = out_channels 345 | self.num_res_blocks = num_res_blocks 346 | self.dropout = dropout 347 | self.channel_mult = channel_mult 348 | self.temb_scale = temb_scale 349 | 350 | self.mean_bypass = mean_bypass 351 | self.scale_mean_bypass = scale_mean_bypass 352 | self.shift_input = shift_input 353 | self.shift_output = shift_output 354 | 355 | assert len(channel_mult) == 4 356 | self.input_ch = int(channel_mult[0] * model_channels) 357 | ch = self.input_ch 358 | 359 | embed_dim = time_embed_dim = int(channel_mult[-1] * model_channels) 360 | self.time_embed_dim = time_embed_dim 361 | self.time_embed = nn.Sequential( 362 | linear(time_embed_dim, time_embed_dim), 363 | nn.SiLU(inplace=True), 364 | # linear(time_embed_dim, time_embed_dim), 365 | ) 366 | 367 | if self.mean_bypass: 368 | self.mean_skip_1 = conv_nd(dims, in_channels, embed_dim, 1) # Conv((1, 1), inchannels => embed_dim) 369 | self.mean_skip_2 = conv_nd(dims, embed_dim, embed_dim, 1) # Conv((1, 1), embed_dim => embed_dim) 370 | self.mean_skip_3 = conv_nd(dims, embed_dim, out_channels, 1) # Conv((1, 1), embed_dim => outchannels) 371 | self.mean_dense_1 = linear(embed_dim, embed_dim) 372 | self.mean_dense_2 = linear(embed_dim, embed_dim) 373 | self.mean_gnorm_1 = 
normalization_act(embed_dim, 32) # GroupNorm(embed_dim, 32, swish) 374 | self.mean_gnorm_2 = normalization_act(embed_dim, 32) # GroupNorm(embed_dim, 32, swish) 375 | 376 | self.conv1 = conv_nd(dims, in_channels, ch, 3, padding=1) # 3 -> 32 # Conv((3, 3), inchannels => channels[1], stride=1, pad=SamePad()) 377 | self.dense1 = linear(time_embed_dim, ch) # Dense(embed_dim, channels[1]), 378 | self.gnorm1 = normalization_act(ch, 4) # GroupNorm(channels[1], 4, swish), 379 | 380 | # Encoding 381 | out_ch = int(channel_mult[1] * model_channels) 382 | self.conv2 = Downsample(ch, use_conv=True, dims=dims, out_channels=out_ch) # 32 -> 64 383 | self.dense2 = linear(time_embed_dim, out_ch) 384 | self.gnorm2 = normalization_act(out_ch, 32) 385 | 386 | ch = out_ch 387 | out_ch = int(channel_mult[2] * model_channels) 388 | self.conv3 = Downsample(ch, use_conv=True, dims=dims, out_channels=out_ch) # 64 -> 128 389 | self.dense3 = linear(time_embed_dim, out_ch) 390 | self.gnorm3 = normalization_act(out_ch, 32) 391 | 392 | ch = out_ch 393 | out_ch = int(channel_mult[3] * model_channels) 394 | self.conv4 = Downsample(ch, use_conv=True, dims=dims, out_channels=out_ch) # 128 -> 256 395 | self.dense4 = linear(time_embed_dim, out_ch) 396 | 397 | self.middle_block = TimestepEmbedSequential( 398 | *[ 399 | ResBlock( 400 | out_ch, 401 | time_embed_dim, 402 | dropout, 403 | dims=dims, 404 | num_groups=min(out_ch//4, 32) 405 | ) for _ in range(num_res_blocks) 406 | ] 407 | ) 408 | 409 | # Decoding 410 | self.gnorm4 = normalization_act(out_ch, 32) 411 | self.tconv4 = Upsample(out_ch, use_conv=True, dims=dims, out_channels=ch) # 256 -> 128 412 | self.denset4 = linear(time_embed_dim, ch) 413 | self.tgnorm4 = normalization_act(ch, 32) 414 | 415 | out_ch = ch 416 | ch = int(channel_mult[1] * model_channels) 417 | self.tconv3 = Upsample(out_ch*2, use_conv=True, dims=dims, out_channels=ch) # 128 + 128 -> 64 418 | self.denset3 = linear(time_embed_dim, ch) 419 | self.tgnorm3 = normalization_act(ch, 32) 420 | 421 | out_ch = ch 422 | ch = int(channel_mult[0] * model_channels) 423 | self.tconv2 = Upsample(out_ch*2, use_conv=True, dims=dims, out_channels=ch) # 64 + 64 -> 32 424 | self.denset2 = linear(time_embed_dim, ch) 425 | self.tgnorm2 = normalization_act(ch, 32) 426 | 427 | self.tconv1 = zero_module(conv_nd(dims, self.input_ch*2, out_channels, 3, padding=1)) 428 | 429 | 430 | def forward(self, x, y, timesteps): 431 | timesteps = timesteps.squeeze() 432 | embed = self.time_embed(timestep_embedding(timesteps * self.temb_scale + 1, self.time_embed_dim)) 433 | 434 | # Encoder 435 | if self.shift_input: 436 | h1 = x - th.mean(x, dim=(-1,-2), keepdim=True) # remove mean of noised variables before input 437 | else: 438 | h1 = x 439 | 440 | h1 = th.cat([x, y], dim=1) 441 | h1 = self.conv1(h1) 442 | h1 = h1 + expand_dims(self.dense1(embed), len(h1.shape)) 443 | h1 = self.gnorm1(h1) 444 | h2 = self.conv2(h1) 445 | h2 = h2 + expand_dims(self.dense2(embed), len(h2.shape)) 446 | h2 = self.gnorm2(h2) 447 | h3 = self.conv3(h2) 448 | h3 = h3 + expand_dims(self.dense3(embed), len(h3.shape)) 449 | h3 = self.gnorm3(h3) 450 | h4 = self.conv4(h3) 451 | h4 = h4 + expand_dims(self.dense4(embed), len(h4.shape)) 452 | 453 | # middle 454 | h = h4 455 | h = self.middle_block(h, embed) 456 | 457 | # Decoder 458 | h = self.gnorm4(h) 459 | h = self.tconv4(h) 460 | h = h + expand_dims(self.denset4(embed), len(h.shape)) 461 | h = self.tgnorm4(h) 462 | h = self.tconv3(th.cat([h, h3], dim=1)) 463 | h = h + expand_dims(self.denset3(embed), len(h.shape)) 464 | 
        h = self.tgnorm3(h)
465 |         h = self.tconv2(th.cat([h, h2], dim=1))
466 |         h = h + expand_dims(self.denset2(embed), len(h.shape))
467 |         h = self.tgnorm2(h)
468 |         h = self.tconv1(th.cat([h, h1], dim=1))
469 | 
470 |         if self.shift_output:
471 |             h = h - th.mean(h, dim=(-1,-2), keepdim=True)  # remove mean after output
472 | 
473 |         # Mean processing of noised variable channels
474 |         if self.mean_bypass:
475 |             hm = self.mean_skip_1(th.mean(th.cat([x, y], dim=1), dim=(-1,-2), keepdim=True))
476 |             hm = hm + expand_dims(self.mean_dense_1(embed), len(hm.shape))
477 |             hm = self.mean_gnorm_1(hm)
478 |             hm = self.mean_skip_2(hm)
479 |             hm = hm + expand_dims(self.mean_dense_2(embed), len(hm.shape))
480 |             hm = self.mean_gnorm_2(hm)
481 |             hm = self.mean_skip_3(hm)
482 |             if self.scale_mean_bypass:
483 |                 scale = np.sqrt(np.prod(x.shape[2:]))
484 |                 hm = hm / scale
485 |             # Add back in noised channel mean to noised channel spatial variations
486 |             return h + hm
487 |         else:
488 |             return h
489 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchmetrics[image]
2 | fire
3 | rich
4 | wandb
5 | 
--------------------------------------------------------------------------------
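As a usage reference, the following minimal sketch instantiates the UNet exactly as `init_nn()` does in `idbm/idbm.py` and checks the input/output shapes. It is illustrative only and assumes it is run from within the `idbm/` directory, so that `unet.unet` resolves the same way it does for the script itself.

```python
# Minimal shape check for the UNet configured by init_nn() in idbm/idbm.py.
# Illustrative sketch only; assumes the working directory is idbm/ so that
# `from unet.unet import UNetModel` resolves, as it does for idbm.py.
import torch as th

from unet.unet import UNetModel

nn = UNetModel(
    in_channels=1,
    model_channels=128,
    out_channels=1,
    num_res_blocks=2,
    attention_resolutions=(),
    dropout=0.1,
    channel_mult=(0.5, 1, 1),
    num_heads=4,
    use_scale_shift_norm=True,
    temb_scale=1000,
)

x = th.randn(8, 1, 28, 28)  # stand-in batch with MNIST/EMNIST image shape (values arbitrary)
t = th.rand(8)              # continuous times in [0, 1], as sampled during training
alpha_t = nn(x, None, t)    # forward(x, y, timesteps) with y=None (unconditional)
assert alpha_t.shape == x.shape  # the network predicts a drift with the same shape as x
```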