├── HierarchicalDesign ├── Diffusion.py ├── VQVAE.py ├── __init__.py ├── utils.py └── version.py ├── HierarchicalDesignDIffusion_GetMicrostructure.ipynb ├── HierarchicalDesignDIffusion_GetStressStrain.ipynb ├── LICENSE ├── README.md ├── VQ_VAE_Microstructure.ipynb └── setup.py /HierarchicalDesign/Diffusion.py: -------------------------------------------------------------------------------- 1 | ######################################################### 2 | # Define Attention-Diffusion model 3 | ######################################################### 4 | 5 | #Code based on: 6 | #https://github.com/lucidrains/imagen-pytorch 7 | #https://github.com/lucidrains/denoising-diffusion-pytorch 8 | 9 | ################################################ 10 | # Tools, helpers, ec. 11 | ################################################ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | import numpy as np 18 | 19 | from torchvision.utils import save_image, make_grid 20 | import torch.nn.functional as F 21 | from torchvision import datasets, transforms, models 22 | from sklearn.metrics import r2_score 23 | 24 | import matplotlib.pyplot as plt 25 | 26 | import ast 27 | import pandas as pd 28 | import numpy as np 29 | from einops import rearrange 30 | 31 | from torch.utils.data import DataLoader,Dataset 32 | from torchvision.io import read_image 33 | import pandas as pd 34 | from sklearn.model_selection import train_test_split 35 | 36 | from PIL import Image 37 | import time 38 | to_pil = transforms.ToPILImage() 39 | 40 | from torchvision.utils import save_image, make_grid 41 | 42 | # 43 | def cycle(dl): 44 | while True: 45 | for data in dl: 46 | yield data 47 | 48 | def eval_decorator(fn): 49 | def inner(model, *args, **kwargs): 50 | was_training = model.training 51 | 52 | # norms and residuals 53 | 54 | class LayerNorm(nn.Module): 55 | def __init__(self, feats, stable = False, dim = -1): 56 | super().__init__() 57 | self.stable = stable 58 | self.dim = dim 59 | 60 | self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1)))) 61 | 62 | def forward(self, x): 63 | dtype, dim = x.dtype, self.dim 64 | 65 | if self.stable: 66 | x = x / x.amax(dim = dim, keepdim = True).detach() 67 | 68 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 69 | var = torch.var(x, dim = dim, unbiased = False, keepdim = True) 70 | mean = torch.mean(x, dim = dim, keepdim = True) 71 | 72 | return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype) 73 | 74 | import math 75 | import copy 76 | from random import random 77 | from typing import List, Union 78 | from tqdm.auto import tqdm 79 | from functools import partial, wraps 80 | from contextlib import contextmanager, nullcontext 81 | from collections import namedtuple 82 | from pathlib import Path 83 | 84 | import torch 85 | import torch.nn.functional as F 86 | from torch.nn.parallel import DistributedDataParallel 87 | from torch import nn, einsum 88 | from torch.cuda.amp import autocast 89 | from torch.special import expm1 90 | import torchvision.transforms as T 91 | 92 | import kornia.augmentation as K 93 | 94 | from einops import rearrange, repeat, reduce 95 | from einops.layers.torch import Rearrange, Reduce 96 | from einops_exts import rearrange_many, repeat_many, check_shape 97 | from einops_exts.torch import EinopsToAndFrom 98 | 99 | #from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME 100 | 101 | #from imagen_pytorch.imagen_video.imagen_video import Unet3D, resize_video_to 102 | 103 | # 
helper functions 104 | 105 | def exists(val): 106 | return val is not None 107 | 108 | def identity(t, *args, **kwargs): 109 | return t 110 | 111 | def first(arr, d = None): 112 | if len(arr) == 0: 113 | return d 114 | return arr[0] 115 | 116 | def maybe(fn): 117 | @wraps(fn) 118 | def inner(x): 119 | if not exists(x): 120 | return x 121 | return fn(x) 122 | return inner 123 | 124 | def once(fn): 125 | called = False 126 | @wraps(fn) 127 | def inner(x): 128 | nonlocal called 129 | if called: 130 | return 131 | called = True 132 | return fn(x) 133 | return inner 134 | 135 | print_once = once(print) 136 | 137 | def default(val, d): 138 | if exists(val): 139 | return val 140 | return d() if callable(d) else d 141 | 142 | def cast_tuple(val, length = None): 143 | if isinstance(val, list): 144 | val = tuple(val) 145 | 146 | output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) 147 | 148 | if exists(length): 149 | assert len(output) == length 150 | 151 | return output 152 | 153 | def is_float_dtype(dtype): 154 | return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)]) 155 | 156 | def cast_uint8_images_to_float(images): 157 | if not images.dtype == torch.uint8: 158 | return images 159 | return images / 255 160 | 161 | def module_device(module): 162 | return next(module.parameters()).device 163 | 164 | def zero_init_(m): 165 | nn.init.zeros_(m.weight) 166 | if exists(m.bias): 167 | nn.init.zeros_(m.bias) 168 | 169 | def eval_decorator(fn): 170 | def inner(model, *args, **kwargs): 171 | was_training = model.training 172 | model.eval() 173 | out = fn(model, *args, **kwargs) 174 | model.train(was_training) 175 | return out 176 | return inner 177 | 178 | def pad_tuple_to_length(t, length, fillvalue = None): 179 | remain_length = length - len(t) 180 | if remain_length <= 0: 181 | return t 182 | return (*t, *((fillvalue,) * remain_length)) 183 | 184 | # helper classes 185 | 186 | class Identity(nn.Module): 187 | def __init__(self, *args, **kwargs): 188 | super().__init__() 189 | 190 | def forward(self, x, *args, **kwargs): 191 | return x 192 | 193 | # tensor helpers 194 | 195 | def log(t, eps: float = 1e-12): 196 | return torch.log(t.clamp(min = eps)) 197 | 198 | def l2norm(t): 199 | return F.normalize(t, dim = -1) 200 | 201 | def right_pad_dims_to(x, t): 202 | padding_dims = x.ndim - t.ndim 203 | if padding_dims <= 0: 204 | return t 205 | return t.view(*t.shape, *((1,) * padding_dims)) 206 | 207 | def masked_mean(t, *, dim, mask = None): 208 | if not exists(mask): 209 | return t.mean(dim = dim) 210 | 211 | denom = mask.sum(dim = dim, keepdim = True) 212 | mask = rearrange(mask, 'b n -> b n 1') 213 | masked_t = t.masked_fill(~mask, 0.) 
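# note: masked positions were zeroed out just above, and the denominator below is the
# per-row count of kept (True) positions; clamping it avoids a division by zero when an
# entire row of the mask is False.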
214 | 215 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 216 | 217 | ################################################ 218 | # Image tools 219 | ################################################ 220 | def resize_image_to( 221 | image, 222 | target_image_size, 223 | clamp_range = None 224 | ): 225 | orig_image_size = image.shape[-1] 226 | 227 | if orig_image_size == target_image_size: 228 | return image 229 | 230 | out = F.interpolate(image.float(), target_image_size, mode = 'linear', align_corners = True) 231 | 232 | return out 233 | 234 | # image normalization functions 235 | # ddpms expect images to be in the range of -1 to 1 236 | 237 | def normalize_neg_one_to_one(img): 238 | return img * 2 - 1 239 | #return img #* 2 - 1 240 | 241 | def unnormalize_zero_to_one(normed_img): 242 | return (normed_img + 1) * 0.5 243 | 244 | # classifier free guidance functions 245 | 246 | def prob_mask_like(shape, prob, device): 247 | if prob == 1: 248 | return torch.ones(shape, device = device, dtype = torch.bool) 249 | elif prob == 0: 250 | return torch.zeros(shape, device = device, dtype = torch.bool) 251 | else: 252 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob 253 | 254 | # gaussian diffusion with continuous time helper functions and classes 255 | # large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py 256 | 257 | @torch.jit.script 258 | def beta_linear_log_snr(t): 259 | return -torch.log(expm1(1e-4 + 10 * (t ** 2))) 260 | 261 | @torch.jit.script 262 | def alpha_cosine_log_snr(t, s: float = 0.008): 263 | return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version 264 | 265 | def log_snr_to_alpha_sigma(log_snr): 266 | return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr)) 267 | 268 | class GaussianDiffusionContinuousTimes(nn.Module): 269 | def __init__(self, *, noise_schedule, timesteps = 1000): 270 | super().__init__() 271 | 272 | if noise_schedule == "linear": 273 | self.log_snr = beta_linear_log_snr 274 | elif noise_schedule == "cosine": 275 | self.log_snr = alpha_cosine_log_snr 276 | else: 277 | raise ValueError(f'invalid noise schedule {noise_schedule}') 278 | 279 | self.num_timesteps = timesteps 280 | 281 | def get_times(self, batch_size, noise_level, *, device): 282 | return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32) 283 | 284 | def sample_random_times(self, batch_size, max_thres = 0.999, *, device): 285 | return torch.zeros((batch_size,), device = device).float().uniform_(0, max_thres) 286 | 287 | def get_condition(self, times): 288 | return maybe(self.log_snr)(times) 289 | 290 | def get_sampling_timesteps(self, batch, *, device): 291 | times = torch.linspace(1., 0., self.num_timesteps + 1, device = device) 292 | times = repeat(times, 't -> b t', b = batch) 293 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) 294 | times = times.unbind(dim = -1) 295 | return times 296 | 297 | def q_posterior(self, x_start, x_t, t, *, t_next = None): 298 | t_next = default(t_next, lambda: (t - 1. 
/ self.num_timesteps).clamp(min = 0.)) 299 | 300 | """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """ 301 | log_snr = self.log_snr(t) 302 | log_snr_next = self.log_snr(t_next) 303 | log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next)) 304 | 305 | alpha, sigma = log_snr_to_alpha_sigma(log_snr) 306 | alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next) 307 | 308 | # c - as defined near eq 33 309 | c = -expm1(log_snr - log_snr_next) 310 | posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start) 311 | 312 | # following (eq. 33) 313 | posterior_variance = (sigma_next ** 2) * c 314 | posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20) 315 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 316 | 317 | def q_sample(self, x_start, t, noise = None): 318 | dtype = x_start.dtype 319 | 320 | if isinstance(t, float): 321 | batch = x_start.shape[0] 322 | t = torch.full((batch,), t, device = x_start.device, dtype = dtype) 323 | 324 | noise = default(noise, lambda: torch.randn_like(x_start)) 325 | log_snr = self.log_snr(t).type(dtype) 326 | log_snr_padded_dim = right_pad_dims_to(x_start, log_snr) 327 | alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) 328 | 329 | return alpha * x_start + sigma * noise, log_snr 330 | 331 | def q_sample_from_to(self, x_from, from_t, to_t, noise = None): 332 | shape, device, dtype = x_from.shape, x_from.device, x_from.dtype 333 | batch = shape[0] 334 | 335 | if isinstance(from_t, float): 336 | from_t = torch.full((batch,), from_t, device = device, dtype = dtype) 337 | 338 | if isinstance(to_t, float): 339 | to_t = torch.full((batch,), to_t, device = device, dtype = dtype) 340 | 341 | noise = default(noise, lambda: torch.randn_like(x_from)) 342 | 343 | log_snr = self.log_snr(from_t) 344 | log_snr_padded_dim = right_pad_dims_to(x_from, log_snr) 345 | alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) 346 | 347 | log_snr_to = self.log_snr(to_t) 348 | log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to) 349 | alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to) 350 | 351 | return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha 352 | 353 | def predict_start_from_noise(self, x_t, t, noise): 354 | log_snr = self.log_snr(t) 355 | log_snr = right_pad_dims_to(x_t, log_snr) 356 | alpha, sigma = log_snr_to_alpha_sigma(log_snr) 357 | return (x_t - sigma * noise) / alpha.clamp(min = 1e-8) 358 | 359 | # norms and residuals 360 | 361 | class LayerNorm(nn.Module): 362 | def __init__(self, feats, stable = False, dim = -1): 363 | super().__init__() 364 | self.stable = stable 365 | self.dim = dim 366 | 367 | self.g = nn.Parameter(torch.ones(feats, *((1,) * (-dim - 1)))) 368 | 369 | def forward(self, x): 370 | dtype, dim = x.dtype, self.dim 371 | 372 | if self.stable: 373 | x = x / x.amax(dim = dim, keepdim = True).detach() 374 | 375 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 376 | var = torch.var(x, dim = dim, unbiased = False, keepdim = True) 377 | mean = torch.mean(x, dim = dim, keepdim = True) 378 | 379 | return (x - mean) * (var + eps).rsqrt().type(dtype) * self.g.type(dtype) 380 | 381 | ChanLayerNorm = partial(LayerNorm, dim = -2) 382 | 383 | class Always(): 384 | def __init__(self, val): 385 | self.val = val 386 | 387 | def __call__(self, *args, **kwargs): 388 | return self.val 389 | 390 | class Residual(nn.Module): 391 | def __init__(self, fn): 392 | super().__init__() 393 | 
self.fn = fn 394 | 395 | def forward(self, x, **kwargs): 396 | return self.fn(x, **kwargs) + x 397 | 398 | class Parallel(nn.Module): 399 | def __init__(self, *fns): 400 | super().__init__() 401 | self.fns = nn.ModuleList(fns) 402 | 403 | def forward(self, x): 404 | outputs = [fn(x) for fn in self.fns] 405 | return sum(outputs) 406 | 407 | ##################### 408 | 409 | # attention pooling 410 | 411 | class PerceiverAttention(nn.Module): 412 | def __init__( 413 | self, 414 | *, 415 | dim, 416 | dim_head = 64, 417 | heads = 8, 418 | cosine_sim_attn = False 419 | ): 420 | super().__init__() 421 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1 422 | self.cosine_sim_attn = cosine_sim_attn 423 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 424 | 425 | self.heads = heads 426 | inner_dim = dim_head * heads 427 | 428 | self.norm = nn.LayerNorm(dim) 429 | self.norm_latents = nn.LayerNorm(dim) 430 | 431 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 432 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 433 | 434 | self.to_out = nn.Sequential( 435 | nn.Linear(inner_dim, dim, bias = False), 436 | nn.LayerNorm(dim) 437 | ) 438 | 439 | def forward(self, x, latents, mask = None): 440 | x = self.norm(x) 441 | latents = self.norm_latents(latents) 442 | 443 | b, h = x.shape[0], self.heads 444 | 445 | q = self.to_q(latents) 446 | 447 | # the paper differs from Perceiver in which they also concat the key / values 448 | #derived from the latents to be attended to 449 | kv_input = torch.cat((x, latents), dim = -2) 450 | k, v = self.to_kv(kv_input).chunk(2, dim = -1) 451 | 452 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h) 453 | 454 | q = q * self.scale 455 | 456 | # cosine sim attention 457 | 458 | if self.cosine_sim_attn: 459 | q, k = map(l2norm, (q, k)) 460 | 461 | # similarities and masking 462 | 463 | sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale 464 | 465 | if exists(mask): 466 | max_neg_value = -torch.finfo(sim.dtype).max 467 | mask = F.pad(mask, (0, latents.shape[-2]), value = True) 468 | 469 | mask = rearrange(mask, 'b j -> b 1 1 j') 470 | sim = sim.masked_fill(~mask, max_neg_value) 471 | 472 | # attention 473 | 474 | attn = sim.softmax(dim = -1, dtype = torch.float32) 475 | attn = attn.to(sim.dtype) 476 | 477 | out = einsum('... i j, ... j d -> ... 
i d', attn, v) 478 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 479 | return self.to_out(out) 480 | 481 | class PerceiverResampler(nn.Module): 482 | def __init__( 483 | self, 484 | *, 485 | dim, 486 | depth, 487 | dim_head = 64, 488 | heads = 8, 489 | num_latents = 64, 490 | num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence 491 | max_seq_len = 512, 492 | ff_mult = 4, 493 | cosine_sim_attn = False 494 | ): 495 | super().__init__() 496 | self.pos_emb = nn.Embedding(max_seq_len, dim) 497 | 498 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 499 | 500 | self.to_latents_from_mean_pooled_seq = None 501 | 502 | if num_latents_mean_pooled > 0: 503 | self.to_latents_from_mean_pooled_seq = nn.Sequential( 504 | LayerNorm(dim), 505 | nn.Linear(dim, dim * num_latents_mean_pooled), 506 | Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled) 507 | ) 508 | 509 | self.layers = nn.ModuleList([]) 510 | for _ in range(depth): 511 | self.layers.append(nn.ModuleList([ 512 | PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn), 513 | FeedForward(dim = dim, mult = ff_mult) 514 | ])) 515 | 516 | def forward(self, x, mask = None): 517 | n, device = x.shape[1], x.device 518 | pos_emb = self.pos_emb(torch.arange(n, device = device)) 519 | 520 | x_with_pos = x + pos_emb 521 | 522 | latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0]) 523 | 524 | if exists(self.to_latents_from_mean_pooled_seq): 525 | meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)) 526 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 527 | latents = torch.cat((meanpooled_latents, latents), dim = -2) 528 | 529 | for attn, ff in self.layers: 530 | latents = attn(x_with_pos, latents, mask = mask) + latents 531 | latents = ff(latents) + latents 532 | 533 | return latents 534 | 535 | # attention 536 | 537 | class Attention(nn.Module): 538 | def __init__( 539 | self, 540 | dim, 541 | *, 542 | dim_head = 64, 543 | heads = 8, 544 | context_dim = None, 545 | cosine_sim_attn = False 546 | ): 547 | super().__init__() 548 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. 
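# cosine-similarity attention variant: when cosine_sim_attn is True the usual
# dim_head ** -0.5 scaling above is disabled (scale = 1.); forward() instead
# l2-normalizes queries and keys and multiplies the similarity logits by the fixed
# cosine_sim_scale set just below, which keeps the attention logits bounded.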
549 | self.cosine_sim_attn = cosine_sim_attn 550 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 551 | 552 | self.heads = heads 553 | inner_dim = dim_head * heads 554 | 555 | self.norm = LayerNorm(dim) 556 | 557 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 558 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 559 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) 560 | 561 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None 562 | 563 | self.to_out = nn.Sequential( 564 | nn.Linear(inner_dim, dim, bias = False), 565 | LayerNorm(dim) 566 | ) 567 | 568 | def forward(self, x, context = None, mask = None, attn_bias = None): 569 | b, n, device = *x.shape[:2], x.device 570 | 571 | x = self.norm(x) 572 | 573 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 574 | 575 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) 576 | q = q * self.scale 577 | 578 | # add null key / value for classifier free guidance in prior net 579 | 580 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b) 581 | k = torch.cat((nk, k), dim = -2) 582 | v = torch.cat((nv, v), dim = -2) 583 | 584 | # add text conditioning, if present 585 | 586 | if exists(context): 587 | assert exists(self.to_context) 588 | ck, cv = self.to_context(context).chunk(2, dim = -1) 589 | k = torch.cat((ck, k), dim = -2) 590 | v = torch.cat((cv, v), dim = -2) 591 | 592 | # cosine sim attention 593 | 594 | if self.cosine_sim_attn: 595 | q, k = map(l2norm, (q, k)) 596 | 597 | # calculate query / key similarities 598 | 599 | sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale 600 | 601 | # relative positional encoding (T5 style) 602 | 603 | if exists(attn_bias): 604 | sim = sim + attn_bias 605 | 606 | # masking 607 | 608 | max_neg_value = -torch.finfo(sim.dtype).max 609 | 610 | if exists(mask): 611 | mask = F.pad(mask, (1, 0), value = True) 612 | 613 | mask = rearrange(mask, 'b j -> b 1 j') 614 | sim = sim.masked_fill(~mask, max_neg_value) 615 | 616 | # attention 617 | 618 | attn = sim.softmax(dim = -1, dtype = torch.float32) 619 | attn = attn.to(sim.dtype) 620 | 621 | # aggregate values 622 | 623 | out = einsum('b h i j, b j d -> b h i d', attn, v) 624 | 625 | out = rearrange(out, 'b h n d -> b n (h d)') 626 | return self.to_out(out) 627 | 628 | # decoder 629 | 630 | def Upsample(dim, dim_out = None): 631 | dim_out = default(dim_out, dim) 632 | 633 | return nn.Sequential( 634 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 635 | nn.Conv1d(dim, dim_out, 3, padding = 1) 636 | ) 637 | 638 | class PixelShuffleUpsample(nn.Module): 639 | """ 640 | code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts 641 | https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf 642 | """ 643 | def __init__(self, dim, dim_out = None): 644 | super().__init__() 645 | dim_out = default(dim_out, dim) 646 | conv = nn.Conv1d(dim, dim_out * 4, 1) 647 | 648 | self.net = nn.Sequential( 649 | conv, 650 | nn.SiLU(), 651 | nn.PixelShuffle(2) 652 | ) 653 | 654 | self.init_conv_(conv) 655 | 656 | def init_conv_(self, conv): 657 | 658 | o, i, h = conv.weight.shape 659 | conv_weight = torch.empty(o // 4, i, h ) 660 | nn.init.kaiming_uniform_(conv_weight) 661 | conv_weight = repeat(conv_weight, 'o ... 
-> (o 4) ...') 662 | 663 | conv.weight.data.copy_(conv_weight) 664 | nn.init.zeros_(conv.bias.data) 665 | 666 | def forward(self, x): 667 | return self.net(x) 668 | 669 | def Downsample(dim, dim_out = None): 670 | # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample 671 | # named SP-conv in the paper, but basically a pixel unshuffle 672 | dim_out = default(dim_out, dim) 673 | 674 | return nn.Sequential( 675 | 676 | Rearrange('b c (h s1) -> b (c s1) h', s1 = 2), 677 | nn.Conv1d(dim * 2, dim_out, 1) 678 | 679 | ) 680 | 681 | class SinusoidalPosEmb(nn.Module): 682 | def __init__(self, dim): 683 | super().__init__() 684 | self.dim = dim 685 | 686 | def forward(self, x): 687 | half_dim = self.dim // 2 688 | emb = math.log(10000) / (half_dim - 1) 689 | emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb) 690 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') 691 | return torch.cat((emb.sin(), emb.cos()), dim = -1) 692 | 693 | class LearnedSinusoidalPosEmb(nn.Module): 694 | """ following @crowsonkb 's lead with learned sinusoidal pos emb """ 695 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 696 | 697 | def __init__(self, dim): 698 | super().__init__() 699 | assert (dim % 2) == 0 700 | half_dim = dim // 2 701 | self.weights = nn.Parameter(torch.randn(half_dim)) 702 | 703 | def forward(self, x): 704 | x = rearrange(x, 'b -> b 1') 705 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 706 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 707 | fouriered = torch.cat((x, fouriered), dim = -1) 708 | return fouriered 709 | 710 | class Block(nn.Module): 711 | def __init__( 712 | self, 713 | dim, 714 | dim_out, 715 | groups = 8, 716 | norm = True 717 | ): 718 | super().__init__() 719 | self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity() 720 | self.activation = nn.SiLU() 721 | self.project = nn.Conv1d(dim, dim_out, 3, padding = 1) 722 | 723 | def forward(self, x, scale_shift = None): 724 | x = self.groupnorm(x) 725 | 726 | if exists(scale_shift): 727 | scale, shift = scale_shift 728 | x = x * (scale + 1) + shift 729 | 730 | x = self.activation(x) 731 | return self.project(x) 732 | 733 | class ResnetBlock(nn.Module): 734 | def __init__( 735 | self, 736 | dim, 737 | dim_out, 738 | *, 739 | cond_dim = None, 740 | time_cond_dim = None, 741 | groups = 8, 742 | linear_attn = False, 743 | use_gca = False, 744 | squeeze_excite = False, 745 | **attn_kwargs 746 | ): 747 | super().__init__() 748 | 749 | self.time_mlp = None 750 | 751 | if exists(time_cond_dim): 752 | self.time_mlp = nn.Sequential( 753 | nn.SiLU(), 754 | nn.Linear(time_cond_dim, dim_out * 2) 755 | ) 756 | 757 | self.cross_attn = None 758 | 759 | if exists(cond_dim): 760 | attn_klass = CrossAttention if not linear_attn else LinearCrossAttention 761 | 762 | self.cross_attn = EinopsToAndFrom( 763 | 764 | 'b c h ', 765 | 'b h c', 766 | attn_klass( 767 | dim = dim_out, 768 | context_dim = cond_dim, 769 | **attn_kwargs 770 | ) 771 | ) 772 | 773 | self.block1 = Block(dim, dim_out, groups = groups) 774 | self.block2 = Block(dim_out, dim_out, groups = groups) 775 | 776 | self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1) 777 | 778 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else Identity() 779 | 780 | 781 | def forward(self, x, time_emb = None, cond = None): 782 | 783 | scale_shift = None 784 | if exists(self.time_mlp) and exists(time_emb): 785 | time_emb 
= self.time_mlp(time_emb) 786 | 787 | time_emb = rearrange(time_emb, 'b c -> b c 1') 788 | scale_shift = time_emb.chunk(2, dim = 1) 789 | 790 | h = self.block1(x) 791 | 792 | if exists(self.cross_attn): 793 | assert exists(cond) 794 | h = self.cross_attn(h, context = cond) + h 795 | 796 | h = self.block2(h, scale_shift = scale_shift) 797 | 798 | h = h * self.gca(h) 799 | 800 | return h + self.res_conv(x) 801 | 802 | class CrossAttention(nn.Module): 803 | def __init__( 804 | self, 805 | dim, 806 | *, 807 | context_dim = None, 808 | dim_head = 64, 809 | heads = 8, 810 | norm_context = False, 811 | cosine_sim_attn = False 812 | ): 813 | super().__init__() 814 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. 815 | self.cosine_sim_attn = cosine_sim_attn 816 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 817 | 818 | self.heads = heads 819 | inner_dim = dim_head * heads 820 | 821 | context_dim = default(context_dim, dim) 822 | 823 | self.norm = LayerNorm(dim) 824 | self.norm_context = LayerNorm(context_dim) if norm_context else Identity() 825 | 826 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 827 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 828 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) 829 | 830 | self.to_out = nn.Sequential( 831 | nn.Linear(inner_dim, dim, bias = False), 832 | LayerNorm(dim) 833 | ) 834 | 835 | def forward(self, x, context, mask = None): 836 | b, n, device = *x.shape[:2], x.device 837 | 838 | x = self.norm(x) 839 | context = self.norm_context(context) 840 | 841 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 842 | 843 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads) 844 | 845 | # add null key / value for classifier free guidance in prior net 846 | 847 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b) 848 | 849 | k = torch.cat((nk, k), dim = -2) 850 | v = torch.cat((nv, v), dim = -2) 851 | 852 | q = q * self.scale 853 | 854 | # cosine sim attention 855 | 856 | if self.cosine_sim_attn: 857 | q, k = map(l2norm, (q, k)) 858 | 859 | # similarities 860 | 861 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale 862 | 863 | # masking 864 | 865 | max_neg_value = -torch.finfo(sim.dtype).max 866 | 867 | if exists(mask): 868 | mask = F.pad(mask, (1, 0), value = True) 869 | 870 | mask = rearrange(mask, 'b j -> b 1 j') 871 | sim = sim.masked_fill(~mask, max_neg_value) 872 | 873 | attn = sim.softmax(dim = -1, dtype = torch.float32) 874 | attn = attn.to(sim.dtype) 875 | 876 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 877 | out = rearrange(out, 'b h n d -> b n (h d)') 878 | return self.to_out(out) 879 | 880 | class LinearCrossAttention(CrossAttention): 881 | def forward(self, x, context, mask = None): 882 | b, n, device = *x.shape[:2], x.device 883 | 884 | x = self.norm(x) 885 | context = self.norm_context(context) 886 | 887 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 888 | 889 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads) 890 | 891 | # add null key / value for classifier free guidance in prior net 892 | 893 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b) 894 | 895 | k = torch.cat((nk, k), dim = -2) 896 | v = torch.cat((nv, v), dim = -2) 897 | 898 | # masking 899 | 900 | max_neg_value = -torch.finfo(x.dtype).max 901 | 902 | if exists(mask): 903 | mask = F.pad(mask, (1, 0), value = True) 904 | mask = 
rearrange(mask, 'b n -> b n 1') 905 | k = k.masked_fill(~mask, max_neg_value) 906 | v = v.masked_fill(~mask, 0.) 907 | 908 | # linear attention 909 | 910 | q = q.softmax(dim = -1) 911 | k = k.softmax(dim = -2) 912 | 913 | q = q * self.scale 914 | 915 | context = einsum('b n d, b n e -> b d e', k, v) 916 | out = einsum('b n d, b d e -> b n e', q, context) 917 | out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) 918 | return self.to_out(out) 919 | 920 | class LinearAttention(nn.Module): 921 | def __init__( 922 | self, 923 | dim, 924 | dim_head = 32, 925 | heads = 8, 926 | dropout = 0.05, 927 | context_dim = None, 928 | **kwargs 929 | ): 930 | super().__init__() 931 | self.scale = dim_head ** -0.5 932 | self.heads = heads 933 | inner_dim = dim_head * heads 934 | self.norm = ChanLayerNorm(dim) 935 | 936 | self.nonlin = nn.SiLU() 937 | 938 | self.to_q = nn.Sequential( 939 | nn.Dropout(dropout), 940 | nn.Conv1d(dim, inner_dim, 1, bias = False), 941 | nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) 942 | ) 943 | 944 | self.to_k = nn.Sequential( 945 | nn.Dropout(dropout), 946 | nn.Conv1d(dim, inner_dim, 1, bias = False), 947 | nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) 948 | ) 949 | 950 | self.to_v = nn.Sequential( 951 | nn.Dropout(dropout), 952 | nn.Conv1d(dim, inner_dim, 1, bias = False), 953 | nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) 954 | ) 955 | 956 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None 957 | 958 | self.to_out = nn.Sequential( 959 | nn.Conv1d(inner_dim, dim, 1, bias = False), 960 | ChanLayerNorm(dim) 961 | ) 962 | 963 | def forward(self, fmap, context = None): 964 | h, x, y = self.heads, *fmap.shape[-2:] 965 | 966 | fmap = self.norm(fmap) 967 | q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v)) 968 | q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h) 969 | 970 | if exists(context): 971 | assert exists(self.to_context) 972 | ck, cv = self.to_context(context).chunk(2, dim = -1) 973 | ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h) 974 | k = torch.cat((k, ck), dim = -2) 975 | v = torch.cat((v, cv), dim = -2) 976 | 977 | q = q.softmax(dim = -1) 978 | k = k.softmax(dim = -2) 979 | 980 | q = q * self.scale 981 | 982 | context = einsum('b n d, b n e -> b d e', k, v) 983 | out = einsum('b n d, b d e -> b n e', q, context) 984 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y) 985 | 986 | out = self.nonlin(out) 987 | return self.to_out(out) 988 | 989 | class GlobalContext(nn.Module): 990 | """ basically a superior form of squeeze-excitation that is attention-esque """ 991 | 992 | def __init__( 993 | self, 994 | *, 995 | dim_in, 996 | dim_out 997 | ): 998 | super().__init__() 999 | self.to_k = nn.Conv1d(dim_in, 1, 1) 1000 | hidden_dim = max(3, dim_out // 2) 1001 | 1002 | self.net = nn.Sequential( 1003 | nn.Conv1d(dim_in, hidden_dim, 1), 1004 | nn.SiLU(), 1005 | nn.Conv1d(hidden_dim, dim_out, 1), 1006 | nn.Sigmoid() 1007 | ) 1008 | 1009 | def forward(self, x): 1010 | context = self.to_k(x) 1011 | x, context = rearrange_many((x, context), 'b n ... 
-> b n (...)') 1012 | out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x) 1013 | 1014 | return self.net(out) 1015 | 1016 | def FeedForward(dim, mult = 2): 1017 | hidden_dim = int(dim * mult) 1018 | return nn.Sequential( 1019 | LayerNorm(dim), 1020 | nn.Linear(dim, hidden_dim, bias = False), 1021 | nn.GELU(), 1022 | LayerNorm(hidden_dim), 1023 | nn.Linear(hidden_dim, dim, bias = False) 1024 | ) 1025 | 1026 | def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width 1027 | hidden_dim = int(dim * mult) 1028 | return nn.Sequential( 1029 | ChanLayerNorm(dim), 1030 | nn.Conv1d(dim, hidden_dim, 1, bias = False), 1031 | nn.GELU(), 1032 | ChanLayerNorm(hidden_dim), 1033 | nn.Conv1d(hidden_dim, dim, 1, bias = False) 1034 | ) 1035 | 1036 | class TransformerBlock(nn.Module): 1037 | def __init__( 1038 | self, 1039 | dim, 1040 | *, 1041 | depth = 1, 1042 | heads = 8, 1043 | dim_head = 32, 1044 | ff_mult = 2, 1045 | context_dim = None, 1046 | cosine_sim_attn = False 1047 | ): 1048 | super().__init__() 1049 | self.layers = nn.ModuleList([]) 1050 | 1051 | for _ in range(depth): 1052 | self.layers.append(nn.ModuleList([ 1053 | 1054 | EinopsToAndFrom('b c h', 'b h c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)), 1055 | ChanFeedForward(dim = dim, mult = ff_mult) 1056 | ])) 1057 | 1058 | def forward(self, x, context = None): 1059 | for attn, ff in self.layers: 1060 | x = attn(x, context = context) + x 1061 | x = ff(x) + x 1062 | return x 1063 | 1064 | class LinearAttentionTransformerBlock(nn.Module): 1065 | def __init__( 1066 | self, 1067 | dim, 1068 | *, 1069 | depth = 1, 1070 | heads = 8, 1071 | dim_head = 32, 1072 | ff_mult = 2, 1073 | context_dim = None, 1074 | **kwargs 1075 | ): 1076 | super().__init__() 1077 | self.layers = nn.ModuleList([]) 1078 | 1079 | for _ in range(depth): 1080 | self.layers.append(nn.ModuleList([ 1081 | LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim), 1082 | ChanFeedForward(dim = dim, mult = ff_mult) 1083 | ])) 1084 | 1085 | def forward(self, x, context = None): 1086 | for attn, ff in self.layers: 1087 | x = attn(x, context = context) + x 1088 | x = ff(x) + x 1089 | return x 1090 | 1091 | class CrossEmbedLayer(nn.Module): 1092 | def __init__( 1093 | self, 1094 | dim_in, 1095 | kernel_sizes, 1096 | dim_out = None, 1097 | stride = 2 1098 | ): 1099 | super().__init__() 1100 | assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) 1101 | dim_out = default(dim_out, dim_in) 1102 | 1103 | kernel_sizes = sorted(kernel_sizes) 1104 | num_scales = len(kernel_sizes) 1105 | 1106 | # calculate the dimension at each scale 1107 | dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] 1108 | dim_scales = [*dim_scales, dim_out - sum(dim_scales)] 1109 | 1110 | self.convs = nn.ModuleList([]) 1111 | for kernel, dim_scale in zip(kernel_sizes, dim_scales): 1112 | self.convs.append(nn.Conv1d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) 1113 | 1114 | def forward(self, x): 1115 | fmaps = tuple(map(lambda conv: conv(x), self.convs)) 1116 | return torch.cat(fmaps, dim = 1) 1117 | 1118 | class UpsampleCombiner(nn.Module): 1119 | def __init__( 1120 | self, 1121 | dim, 1122 | *, 1123 | enabled = False, 1124 | dim_ins = tuple(), 1125 | dim_outs = tuple() 1126 | ): 1127 | super().__init__() 1128 | dim_outs = cast_tuple(dim_outs, len(dim_ins)) 1129 | 
assert len(dim_ins) == len(dim_outs) 1130 | 1131 | self.enabled = enabled 1132 | 1133 | if not self.enabled: 1134 | self.dim_out = dim 1135 | return 1136 | 1137 | self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]) 1138 | self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0) 1139 | 1140 | def forward(self, x, fmaps = None): 1141 | target_size = x.shape[-1] 1142 | 1143 | fmaps = default(fmaps, tuple()) 1144 | 1145 | if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0: 1146 | return x 1147 | 1148 | fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps] 1149 | outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)] 1150 | return torch.cat((x, *outs), dim = 1) 1151 | 1152 | ######################################################## 1153 | # 1D U-Net 1154 | ######################################################## 1155 | 1156 | class OneD_Unet(nn.Module): 1157 | def __init__( 1158 | self, 1159 | *, 1160 | dim, 1161 | image_embed_dim = 1024, 1162 | text_embed_dim = 768, #get_encoded_dim(DEFAULT_T5_NAME), 1163 | num_resnet_blocks = 1, 1164 | cond_dim = None, 1165 | num_image_tokens = 4, 1166 | num_time_tokens = 2, 1167 | learned_sinu_pos_emb_dim = 16, 1168 | out_dim = None, 1169 | dim_mults=(1, 2, 4, 8), 1170 | cond_images_channels = 0, 1171 | channels = 3, 1172 | channels_out = None, 1173 | attn_dim_head = 64, 1174 | attn_heads = 8, 1175 | ff_mult = 2., 1176 | lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ 1177 | layer_attns = True, 1178 | layer_attns_depth = 1, 1179 | layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 1180 | attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) 1181 | layer_cross_attns = True, 1182 | use_linear_attn = False, 1183 | use_linear_cross_attn = False, 1184 | cond_on_text = True, 1185 | max_text_len = 256, 1186 | init_dim = None, 1187 | resnet_groups = 8, 1188 | init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed 1189 | init_cross_embed = False, 1190 | init_cross_embed_kernel_sizes = (3, 7, 15), 1191 | cross_embed_downsample = False, 1192 | cross_embed_downsample_kernel_sizes = (2, 4), 1193 | attn_pool_text = True, 1194 | attn_pool_num_latents = 32, 1195 | dropout = 0., 1196 | memory_efficient = False, 1197 | init_conv_to_final_conv_residual = False, 1198 | use_global_context_attn = True, 1199 | scale_skip_connection = True, 1200 | final_resnet_block = True, 1201 | final_conv_kernel_size = 3, 1202 | cosine_sim_attn = False, 1203 | self_cond = False, 1204 | combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully 1205 | pixel_shuffle_upsample = False , # may address checkboard artifacts 1206 | beginning_and_final_conv_present = True , #TODO add cross-attn, doesnt work yet...whether or not to have final conv layer 1207 | 1208 | ): 1209 | super().__init__() 1210 | 1211 | # guide researchers 1212 | 1213 | assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' 1214 | 1215 | # save locals to take care of some hyperparameters for cascading DDPM 1216 | 1217 | self._locals = locals() 1218 | self._locals.pop('self', None) 1219 | self._locals.pop('__class__', None) 1220 | 1221 | # determine dimensions 1222 | 1223 
| self.channels = channels 1224 | self.channels_out = default(channels_out, channels) 1225 | 1226 | # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis 1227 | # (2) in self conditioning, one appends the predict x0 (x_start) 1228 | init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) 1229 | init_dim = default(init_dim, dim) 1230 | 1231 | self.self_cond = self_cond 1232 | 1233 | # optional image conditioning 1234 | 1235 | self.has_cond_image = cond_images_channels > 0 1236 | self.cond_images_channels = cond_images_channels 1237 | 1238 | init_channels += cond_images_channels 1239 | 1240 | # initial convolution 1241 | 1242 | self.beginning_and_final_conv_present=beginning_and_final_conv_present 1243 | 1244 | if self.beginning_and_final_conv_present: 1245 | 1246 | self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, 1247 | kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv1d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) 1248 | 1249 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 1250 | in_out = list(zip(dims[:-1], dims[1:])) 1251 | 1252 | # time conditioning 1253 | 1254 | cond_dim = default(cond_dim, dim) 1255 | time_cond_dim = dim * 4 * (2 if lowres_cond else 1) 1256 | 1257 | # embedding time for log(snr) noise from continuous version 1258 | 1259 | sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) 1260 | sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 1261 | 1262 | self.to_time_hiddens = nn.Sequential( 1263 | sinu_pos_emb, 1264 | nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), 1265 | nn.SiLU() 1266 | ) 1267 | 1268 | self.to_time_cond = nn.Sequential( 1269 | nn.Linear(time_cond_dim, time_cond_dim) 1270 | ) 1271 | 1272 | # project to time tokens as well as time hiddens 1273 | 1274 | self.to_time_tokens = nn.Sequential( 1275 | nn.Linear(time_cond_dim, cond_dim * num_time_tokens), 1276 | Rearrange('b (r d) -> b r d', r = num_time_tokens) 1277 | ) 1278 | 1279 | # low res aug noise conditioning 1280 | 1281 | self.lowres_cond = lowres_cond 1282 | 1283 | if lowres_cond: 1284 | self.to_lowres_time_hiddens = nn.Sequential( 1285 | LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), 1286 | nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), 1287 | nn.SiLU() 1288 | ) 1289 | 1290 | self.to_lowres_time_cond = nn.Sequential( 1291 | nn.Linear(time_cond_dim, time_cond_dim) 1292 | ) 1293 | 1294 | self.to_lowres_time_tokens = nn.Sequential( 1295 | nn.Linear(time_cond_dim, cond_dim * num_time_tokens), 1296 | Rearrange('b (r d) -> b r d', r = num_time_tokens) 1297 | ) 1298 | 1299 | # normalizations 1300 | 1301 | self.norm_cond = nn.LayerNorm(cond_dim) 1302 | 1303 | # text encoding conditioning (optional) 1304 | 1305 | self.text_to_cond = None 1306 | 1307 | if cond_on_text: 1308 | assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' 1309 | if text_embed_dim != cond_dim: 1310 | self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) 1311 | self.text_cond_linear=True 1312 | 1313 | else: 1314 | #Text conditioning is equal to cond_dim - no linear layer used 1315 | self.text_cond_linear=False 1316 | 1317 | 1318 | # finer control over whether to condition on text encodings 1319 | 1320 | self.cond_on_text = cond_on_text 1321 | 1322 | # attention pooling 1323 | 1324 | self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, 1325 | dim_head = attn_dim_head, 
heads = attn_heads, 1326 | num_latents = attn_pool_num_latents, 1327 | cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None 1328 | 1329 | # for classifier free guidance 1330 | 1331 | self.max_text_len = max_text_len 1332 | 1333 | self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) 1334 | self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) 1335 | 1336 | # for non-attention based text conditioning at all points in the network where time is also conditioned 1337 | 1338 | self.to_text_non_attn_cond = None 1339 | 1340 | if cond_on_text: 1341 | self.to_text_non_attn_cond = nn.Sequential( 1342 | nn.LayerNorm(cond_dim), 1343 | nn.Linear(cond_dim, time_cond_dim), 1344 | nn.SiLU(), 1345 | nn.Linear(time_cond_dim, time_cond_dim) 1346 | ) 1347 | 1348 | # attention related params 1349 | 1350 | attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) 1351 | 1352 | num_layers = len(in_out) 1353 | 1354 | # resnet block klass 1355 | 1356 | num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) 1357 | resnet_groups = cast_tuple(resnet_groups, num_layers) 1358 | 1359 | resnet_klass = partial(ResnetBlock, **attn_kwargs) 1360 | 1361 | layer_attns = cast_tuple(layer_attns, num_layers) 1362 | layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) 1363 | layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) 1364 | 1365 | use_linear_attn = cast_tuple(use_linear_attn, num_layers) 1366 | use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers) 1367 | 1368 | assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) 1369 | 1370 | # downsample klass 1371 | 1372 | downsample_klass = Downsample 1373 | 1374 | if cross_embed_downsample: 1375 | downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) 1376 | 1377 | # initial resnet block (for memory efficient unet) 1378 | 1379 | self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None 1380 | 1381 | # scale for resnet skip connections 1382 | 1383 | self.skip_connect_scale = 1. 
if not scale_skip_connection else (2 ** -0.5) 1384 | 1385 | # layers 1386 | 1387 | self.downs = nn.ModuleList([]) 1388 | self.ups = nn.ModuleList([]) 1389 | num_resolutions = len(in_out) 1390 | 1391 | layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] 1392 | reversed_layer_params = list(map(reversed, layer_params)) 1393 | 1394 | # downsampling layers 1395 | 1396 | skip_connect_dims = [] # keep track of skip connection dimensions 1397 | 1398 | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): 1399 | is_last = ind >= (num_resolutions - 1) 1400 | 1401 | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None 1402 | 1403 | if layer_attn: 1404 | transformer_block_klass = TransformerBlock 1405 | elif layer_use_linear_attn: 1406 | transformer_block_klass = LinearAttentionTransformerBlock 1407 | else: 1408 | transformer_block_klass = Identity 1409 | 1410 | current_dim = dim_in 1411 | 1412 | # whether to pre-downsample, from memory efficient unet 1413 | 1414 | pre_downsample = None 1415 | 1416 | if memory_efficient: 1417 | pre_downsample = downsample_klass(dim_in, dim_out) 1418 | current_dim = dim_out 1419 | 1420 | skip_connect_dims.append(current_dim) 1421 | 1422 | # whether to do post-downsample, for non-memory efficient unet 1423 | 1424 | post_downsample = None 1425 | if not memory_efficient: 1426 | post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1)) 1427 | 1428 | self.downs.append(nn.ModuleList([ 1429 | pre_downsample, 1430 | resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), 1431 | nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), 1432 | transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), 1433 | post_downsample 1434 | ])) 1435 | 1436 | # middle layers 1437 | 1438 | mid_dim = dims[-1] 1439 | 1440 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) 1441 | 1442 | self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None 1443 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) 1444 | 1445 | # upsample klass 1446 | 1447 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample 1448 | 1449 | # upsampling layers 1450 | 1451 | upsample_fmap_dims = [] 1452 | 1453 | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): 1454 | is_last = ind == (len(in_out) - 1) 1455 | 1456 | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None 1457 | 1458 | if layer_attn: 1459 | transformer_block_klass = TransformerBlock 1460 | elif layer_use_linear_attn: 1461 | transformer_block_klass = 
LinearAttentionTransformerBlock 1462 | else: 1463 | transformer_block_klass = Identity 1464 | 1465 | skip_connect_dim = skip_connect_dims.pop() 1466 | 1467 | upsample_fmap_dims.append(dim_out) 1468 | 1469 | self.ups.append(nn.ModuleList([ 1470 | resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), 1471 | nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), 1472 | transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), 1473 | upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() 1474 | ])) 1475 | 1476 | # whether to combine feature maps from all upsample blocks before final resnet block out 1477 | 1478 | self.upsample_combiner = UpsampleCombiner( 1479 | dim = dim, 1480 | enabled = combine_upsample_fmaps, 1481 | dim_ins = upsample_fmap_dims, 1482 | dim_outs = dim 1483 | ) 1484 | 1485 | # whether to do a final residual from initial conv to the final resnet block out 1486 | 1487 | self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual 1488 | final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) 1489 | 1490 | # final optional resnet block and convolution out 1491 | 1492 | self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None 1493 | 1494 | final_conv_dim_in = dim if final_resnet_block else final_conv_dim 1495 | final_conv_dim_in += (channels if lowres_cond else 0) 1496 | 1497 | if self.beginning_and_final_conv_present: 1498 | self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) 1499 | 1500 | if self.beginning_and_final_conv_present: 1501 | zero_init_(self.final_conv) 1502 | 1503 | # if the current settings for the unet are not correct 1504 | # for cascading DDPM, then reinit the unet with the right settings 1505 | def cast_model_parameters( 1506 | self, 1507 | *, 1508 | lowres_cond, 1509 | text_embed_dim, 1510 | channels, 1511 | channels_out, 1512 | cond_on_text 1513 | ): 1514 | if lowres_cond == self.lowres_cond and \ 1515 | channels == self.channels and \ 1516 | cond_on_text == self.cond_on_text and \ 1517 | text_embed_dim == self._locals['text_embed_dim'] and \ 1518 | channels_out == self.channels_out: 1519 | return self 1520 | 1521 | updated_kwargs = dict( 1522 | lowres_cond = lowres_cond, 1523 | text_embed_dim = text_embed_dim, 1524 | channels = channels, 1525 | channels_out = channels_out, 1526 | cond_on_text = cond_on_text 1527 | ) 1528 | 1529 | return self.__class__(**{**self._locals, **updated_kwargs}) 1530 | 1531 | # methods for returning the full unet config as well as its parameter state 1532 | 1533 | def to_config_and_state_dict(self): 1534 | return self._locals, self.state_dict() 1535 | 1536 | # class method for rehydrating the unet from its config and state dict 1537 | 1538 | @classmethod 1539 | def from_config_and_state_dict(klass, config, state_dict): 1540 | unet = klass(**config) 1541 | unet.load_state_dict(state_dict) 1542 | return unet 1543 | 1544 | # methods for persisting unet to disk 1545 | 1546 | def persist_to_file(self, path): 1547 | path = Path(path) 1548 | 
path.parents[0].mkdir(exist_ok = True, parents = True) 1549 | 1550 | config, state_dict = self.to_config_and_state_dict() 1551 | pkg = dict(config = config, state_dict = state_dict) 1552 | torch.save(pkg, str(path)) 1553 | 1554 | # class method for rehydrating the unet from file saved with `persist_to_file` 1555 | 1556 | @classmethod 1557 | def hydrate_from_file(klass, path): 1558 | path = Path(path) 1559 | assert path.exists() 1560 | pkg = torch.load(str(path)) 1561 | 1562 | assert 'config' in pkg and 'state_dict' in pkg 1563 | config, state_dict = pkg['config'], pkg['state_dict'] 1564 | 1565 | return Unet.from_config_and_state_dict(config, state_dict) 1566 | 1567 | # forward with classifier free guidance 1568 | 1569 | def forward_with_cond_scale( 1570 | self, 1571 | *args, 1572 | cond_scale = 1., 1573 | **kwargs 1574 | ): 1575 | logits = self.forward(*args, **kwargs) 1576 | 1577 | if cond_scale == 1: 1578 | return logits 1579 | 1580 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) 1581 | return null_logits + (logits - null_logits) * cond_scale 1582 | 1583 | def forward( 1584 | self, 1585 | x, 1586 | time, 1587 | *, 1588 | lowres_cond_img = None, 1589 | lowres_noise_times = None, 1590 | text_embeds = None, 1591 | text_mask = None, 1592 | cond_images = None, 1593 | self_cond = None, 1594 | cond_drop_prob = 0. 1595 | ): 1596 | batch_size, device = x.shape[0], x.device 1597 | 1598 | # condition on self 1599 | 1600 | if self.self_cond: 1601 | self_cond = default(self_cond, lambda: torch.zeros_like(x)) 1602 | x = torch.cat((x, self_cond), dim = 1) 1603 | 1604 | # add low resolution conditioning, if present 1605 | 1606 | assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' 1607 | assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' 1608 | 1609 | if exists(lowres_cond_img): 1610 | x = torch.cat((x, lowres_cond_img), dim = 1) 1611 | 1612 | # condition on input image 1613 | 1614 | assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' 1615 | 1616 | if exists(cond_images): 1617 | assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' 1618 | cond_images = resize_image_to(cond_images, x.shape[-1]) 1619 | x = torch.cat((cond_images, x), dim = 1) 1620 | 1621 | # initial convolution 1622 | 1623 | if self.beginning_and_final_conv_present: 1624 | x = self.init_conv(x) 1625 | 1626 | # init conv residual 1627 | 1628 | if self.init_conv_to_final_conv_residual: 1629 | init_conv_residual = x.clone() 1630 | 1631 | # time conditioning 1632 | 1633 | time_hiddens = self.to_time_hiddens(time) 1634 | 1635 | # derive time tokens 1636 | 1637 | time_tokens = self.to_time_tokens(time_hiddens) 1638 | t = self.to_time_cond(time_hiddens) 1639 | 1640 | # add lowres time conditioning to time hiddens 1641 | # and add lowres time tokens along sequence dimension for attention 1642 | 1643 | if self.lowres_cond: 1644 | lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) 1645 | lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) 1646 | lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) 1647 | 1648 | t = t + lowres_t 1649 | 1650 | time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim 
= -2) 1651 | 1652 | # text conditioning 1653 | 1654 | text_tokens = None 1655 | 1656 | if exists(text_embeds) and self.cond_on_text: 1657 | 1658 | # conditional dropout 1659 | 1660 | text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) 1661 | 1662 | 1663 | text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') 1664 | text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') 1665 | 1666 | # calculate text embeds 1667 | 1668 | if self.text_cond_linear: 1669 | text_tokens = self.text_to_cond(text_embeds) 1670 | else: 1671 | text_tokens=text_embeds 1672 | 1673 | text_tokens = text_tokens[:, :self.max_text_len] 1674 | 1675 | if exists(text_mask): 1676 | text_mask = text_mask[:, :self.max_text_len] 1677 | 1678 | text_tokens_len = text_tokens.shape[1] 1679 | remainder = self.max_text_len - text_tokens_len 1680 | 1681 | if remainder > 0: 1682 | 1683 | text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) 1684 | 1685 | if exists(text_mask): 1686 | if remainder > 0: 1687 | text_mask = F.pad(text_mask, (0, remainder), value = False) 1688 | 1689 | 1690 | text_mask = rearrange(text_mask, 'b n -> b n 1') 1691 | text_keep_mask_embed = text_mask & text_keep_mask_embed 1692 | 1693 | null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working 1694 | 1695 | text_tokens = torch.where( 1696 | text_keep_mask_embed, 1697 | text_tokens, 1698 | null_text_embed 1699 | ) 1700 | 1701 | if exists(self.attn_pool): 1702 | text_tokens = self.attn_pool(text_tokens) 1703 | 1704 | # extra non-attention conditioning by projecting and then summing text embeddings to time 1705 | # termed as text hiddens 1706 | 1707 | mean_pooled_text_tokens = text_tokens.mean(dim = -2) 1708 | 1709 | text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) 1710 | 1711 | null_text_hidden = self.null_text_hidden.to(t.dtype) 1712 | 1713 | text_hiddens = torch.where( 1714 | text_keep_mask_hidden, 1715 | text_hiddens, 1716 | null_text_hidden 1717 | ) 1718 | 1719 | t = t + text_hiddens 1720 | 1721 | # main conditioning tokens (c) 1722 | 1723 | c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) 1724 | 1725 | # normalize conditioning tokens 1726 | 1727 | c = self.norm_cond(c) 1728 | 1729 | # initial resnet block (for memory efficient unet) 1730 | 1731 | if exists(self.init_resnet_block): 1732 | x = self.init_resnet_block(x, t) 1733 | 1734 | 1735 | # go through the layers of the unet, down and up 1736 | 1737 | hiddens = [] 1738 | 1739 | for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: 1740 | if exists(pre_downsample): 1741 | x = pre_downsample(x) 1742 | 1743 | x = init_block(x, t, c) 1744 | 1745 | for resnet_block in resnet_blocks: 1746 | x = resnet_block(x, t) 1747 | hiddens.append(x) 1748 | 1749 | x = attn_block(x, c) 1750 | 1751 | hiddens.append(x) 1752 | 1753 | if exists(post_downsample): 1754 | 1755 | x = post_downsample(x) 1756 | 1757 | 1758 | 1759 | x = self.mid_block1(x, t, c) 1760 | 1761 | if exists(self.mid_attn): 1762 | x = self.mid_attn(x) 1763 | 1764 | x = self.mid_block2(x, t, c) 1765 | 1766 | add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) 1767 | 1768 | up_hiddens = [] 1769 | 1770 | 1771 | for init_block, resnet_blocks, attn_block, upsample in self.ups: 1772 | 1773 | x = add_skip_connection(x) 1774 | 1775 | x = init_block(x, t, c) 1776 | 1777 | 1778 | for resnet_block in resnet_blocks: 1779 | x = 
add_skip_connection(x) 1780 | x = resnet_block(x, t) 1781 | 1782 | x = attn_block(x, c) 1783 | up_hiddens.append(x.contiguous()) 1784 | x = upsample(x) 1785 | 1786 | # whether to combine all feature maps from upsample blocks 1787 | 1788 | x = self.upsample_combiner(x, up_hiddens) 1789 | 1790 | # final top-most residual if needed 1791 | 1792 | if self.init_conv_to_final_conv_residual: 1793 | x = torch.cat((x, init_conv_residual), dim = 1) 1794 | 1795 | if exists(self.final_res_block): 1796 | x = self.final_res_block(x, t) 1797 | 1798 | if exists(lowres_cond_img): 1799 | x = torch.cat((x, lowres_cond_img), dim = 1) 1800 | 1801 | if self.beginning_and_final_conv_present: 1802 | x=self.final_conv(x) 1803 | 1804 | return x 1805 | 1806 | # null unet 1807 | 1808 | class Unet(nn.Module): 1809 | def __init__(self, *args, **kwargs): 1810 | super().__init__() 1811 | self.lowres_cond = False 1812 | self.dummy_parameter = nn.Parameter(torch.tensor([0.])) 1813 | 1814 | def cast_model_parameters(self, *args, **kwargs): 1815 | return self 1816 | 1817 | def forward(self, x, *args, **kwargs): 1818 | return x 1819 | 1820 | class Unet3D(nn.Module): 1821 | def __init__(self, *args, **kwargs): 1822 | super().__init__() 1823 | self.lowres_cond = False 1824 | self.dummy_parameter = nn.Parameter(torch.tensor([0.])) 1825 | 1826 | def cast_model_parameters(self, *args, **kwargs): 1827 | return self 1828 | 1829 | def forward(self, x, *args, **kwargs): 1830 | return x 1831 | 1832 | 1833 | class NullUnet(nn.Module): 1834 | def __init__(self, *args, **kwargs): 1835 | super().__init__() 1836 | self.lowres_cond = False 1837 | self.dummy_parameter = nn.Parameter(torch.tensor([0.])) 1838 | 1839 | def cast_model_parameters(self, *args, **kwargs): 1840 | return self 1841 | 1842 | def forward(self, x, *args, **kwargs): 1843 | return x 1844 | 1845 | from math import sqrt 1846 | 1847 | # constants 1848 | 1849 | Hparams_fields = [ 1850 | 'num_sample_steps', 1851 | 'sigma_min', 1852 | 'sigma_max', 1853 | 'sigma_data', 1854 | 'rho', 1855 | 'P_mean', 1856 | 'P_std', 1857 | 'S_churn', 1858 | 'S_tmin', 1859 | 'S_tmax', 1860 | 'S_noise' 1861 | ] 1862 | 1863 | Hparams = namedtuple('Hparams', Hparams_fields) 1864 | 1865 | # helper functions 1866 | 1867 | def log(t, eps = 1e-20): 1868 | return torch.log(t.clamp(min = eps)) 1869 | 1870 | # main class 1871 | 1872 | class ElucidatedImagen(nn.Module): 1873 | def __init__( 1874 | self, 1875 | unets, 1876 | *, 1877 | image_sizes, # for cascading ddpm, image size at each stage 1878 | text_encoder_name = '',#DEFAULT_T5_NAME, 1879 | text_embed_dim = None, 1880 | channels = 3, 1881 | channels_out=3, 1882 | cond_drop_prob = 0.1, 1883 | random_crop_sizes = None, 1884 | lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level 1885 | per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find 1886 | condition_on_text = True, 1887 | auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader 1888 | dynamic_thresholding = True, 1889 | dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper 
1890 | only_train_unet_number = None, 1891 | lowres_noise_schedule = 'linear', 1892 | num_sample_steps = 32, # number of sampling steps 1893 | sigma_min = 0.002, # min noise level 1894 | sigma_max = 80, # max noise level 1895 | sigma_data = 0.5, # standard deviation of data distribution 1896 | rho = 7, # controls the sampling schedule 1897 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 1898 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 1899 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 1900 | S_tmin = 0.05, 1901 | S_tmax = 50, 1902 | S_noise = 1.003, 1903 | #categorical_loss = False, 1904 | loss_type=0, #0=MSE, 1=Cross entropy, 2=KLDIv Loss 1905 | categorical_loss_ignore=None, 1906 | add_z_loss = False, 1907 | loss_z_factor = 1., 1908 | VAE=None, 1909 | 1910 | ): 1911 | super().__init__() 1912 | self.only_train_unet_number = only_train_unet_number 1913 | self.add_z_loss=add_z_loss 1914 | self.loss_z_factor=loss_z_factor 1915 | 1916 | self.vit_vae=VAE 1917 | self.only_train_unet_number = only_train_unet_number 1918 | 1919 | # conditioning hparams 1920 | 1921 | self.condition_on_text = condition_on_text 1922 | self.unconditional = not condition_on_text 1923 | self.loss_type=loss_type 1924 | if self.loss_type>0: 1925 | self.categorical_loss=True 1926 | self.m = nn.LogSoftmax(dim=1) #used for some loss functins 1927 | else: 1928 | self.categorical_loss=False 1929 | 1930 | self.categorical_loss_ignore=categorical_loss_ignore 1931 | 1932 | # channels 1933 | 1934 | self.channels = channels 1935 | self.channels_out = channels_out 1936 | 1937 | # automatically take care of ensuring that first unet is unconditional 1938 | # while the rest of the unets are conditioned on the low resolution image produced by previous unet 1939 | 1940 | unets = cast_tuple(unets) 1941 | num_unets = len(unets) 1942 | 1943 | # randomly cropping for upsampler training 1944 | 1945 | self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets) 1946 | assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example' 1947 | 1948 | # lowres augmentation noise schedule 1949 | 1950 | self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule) 1951 | 1952 | # get text encoder 1953 | 1954 | self.text_embed_dim =text_embed_dim 1955 | 1956 | # construct unets 1957 | 1958 | self.unets = nn.ModuleList([]) 1959 | self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment 1960 | 1961 | for ind, one_unet in enumerate(unets): 1962 | 1963 | assert isinstance(one_unet, (Unet, Unet3D,OneD_Unet, NullUnet)) 1964 | is_first = ind == 0 1965 | 1966 | one_unet = one_unet.cast_model_parameters( 1967 | lowres_cond = not is_first, 1968 | cond_on_text = self.condition_on_text, 1969 | text_embed_dim = self.text_embed_dim if self.condition_on_text else None, 1970 | channels = self.channels, 1971 | #channels_out = self.channels 1972 | channels_out = self.channels_out 1973 | ) 1974 | 1975 | self.unets.append(one_unet) 1976 | 1977 | # determine whether we are training on images or video 1978 | 1979 | is_video = any([isinstance(unet, Unet3D) for unet in self.unets]) 1980 | self.is_video = is_video 1981 | 1982 | 1983 | self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1' 
if not is_video else 'b -> b 1 1 1')) 1984 | self.resize_to = resize_video_to if is_video else resize_image_to 1985 | 1986 | # unet image sizes 1987 | 1988 | #self.image_sizes = cast_tuple(self.image_sizes) 1989 | self.image_sizes = image_sizes 1990 | assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}' 1991 | 1992 | self.sample_channels = cast_tuple(self.channels, num_unets) 1993 | 1994 | # cascading ddpm related stuff 1995 | 1996 | lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) 1997 | assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' 1998 | 1999 | self.lowres_sample_noise_level = lowres_sample_noise_level 2000 | self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level 2001 | 2002 | # classifier free guidance 2003 | 2004 | self.cond_drop_prob = cond_drop_prob 2005 | self.can_classifier_guidance = cond_drop_prob > 0. 2006 | 2007 | # normalize and unnormalize image functions 2008 | 2009 | self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity 2010 | self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity 2011 | self.input_image_range = (0. if auto_normalize_img else -1., 1.) 2012 | 2013 | # dynamic thresholding 2014 | 2015 | self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets) 2016 | self.dynamic_thresholding_percentile = dynamic_thresholding_percentile 2017 | 2018 | # elucidating parameters 2019 | 2020 | hparams = [ 2021 | num_sample_steps, 2022 | sigma_min, 2023 | sigma_max, 2024 | sigma_data, 2025 | rho, 2026 | P_mean, 2027 | P_std, 2028 | S_churn, 2029 | S_tmin, 2030 | S_tmax, 2031 | S_noise, 2032 | ] 2033 | 2034 | hparams = [cast_tuple(hp, num_unets) for hp in hparams] 2035 | self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)] 2036 | 2037 | # one temp parameter for keeping track of device 2038 | 2039 | self.register_buffer('_temp', torch.tensor([0.]), persistent = False) 2040 | 2041 | # default to device of unets passed in 2042 | 2043 | self.to(next(self.unets.parameters()).device) 2044 | 2045 | def force_unconditional_(self): 2046 | self.condition_on_text = False 2047 | self.unconditional = True 2048 | 2049 | for unet in self.unets: 2050 | unet.cond_on_text = False 2051 | 2052 | @property 2053 | def device(self): 2054 | return self._temp.device 2055 | 2056 | def get_unet(self, unet_number): 2057 | assert 0 < unet_number <= len(self.unets) 2058 | index = unet_number - 1 2059 | 2060 | if isinstance(self.unets, nn.ModuleList): 2061 | unets_list = [unet for unet in self.unets] 2062 | delattr(self, 'unets') 2063 | self.unets = unets_list 2064 | 2065 | if index != self.unet_being_trained_index: 2066 | for unet_index, unet in enumerate(self.unets): 2067 | unet.to(self.device if unet_index == index else 'cpu') 2068 | 2069 | self.unet_being_trained_index = index 2070 | return self.unets[index] 2071 | 2072 | def reset_unets_all_one_device(self, device = None): 2073 | device = default(device, self.device) 2074 | self.unets = nn.ModuleList([*self.unets]) 2075 | self.unets.to(device) 2076 | 2077 | self.unet_being_trained_index = -1 2078 | 2079 | @contextmanager 2080 | def one_unet_in_gpu(self, unet_number = None, unet = None): 2081 | assert exists(unet_number) ^ exists(unet) 2082 | 2083 | if exists(unet_number): 2084 | unet = 
self.unets[unet_number - 1] 2085 | 2086 | devices = [module_device(unet) for unet in self.unets] 2087 | self.unets.cpu() 2088 | unet.to(self.device) 2089 | 2090 | yield 2091 | 2092 | for unet, device in zip(self.unets, devices): 2093 | unet.to(device) 2094 | 2095 | # overriding state dict functions 2096 | 2097 | def state_dict(self, *args, **kwargs): 2098 | self.reset_unets_all_one_device() 2099 | return super().state_dict(*args, **kwargs) 2100 | 2101 | def load_state_dict(self, *args, **kwargs): 2102 | self.reset_unets_all_one_device() 2103 | return super().load_state_dict(*args, **kwargs) 2104 | 2105 | # dynamic thresholding 2106 | 2107 | def threshold_x_start(self, x_start, dynamic_threshold = True): 2108 | if not dynamic_threshold: 2109 | return x_start.clamp(-1., 1.) 2110 | 2111 | s = torch.quantile( 2112 | rearrange(x_start, 'b ... -> b (...)').abs(), 2113 | self.dynamic_thresholding_percentile, 2114 | dim = -1 2115 | ) 2116 | 2117 | s.clamp_(min = 1.) 2118 | s = right_pad_dims_to(x_start, s) 2119 | return x_start.clamp(-s, s) / s 2120 | 2121 | # derived preconditioning params - Table 1 2122 | 2123 | def c_skip(self, sigma_data, sigma): 2124 | return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2) 2125 | 2126 | def c_out(self, sigma_data, sigma): 2127 | return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5 2128 | 2129 | def c_in(self, sigma_data, sigma): 2130 | return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5 2131 | 2132 | def c_noise(self, sigma): 2133 | return log(sigma) * 0.25 2134 | 2135 | # preconditioned network output 2136 | # equation (7) in the paper 2137 | 2138 | def preconditioned_network_forward( 2139 | self, 2140 | unet_forward, 2141 | noised_images, 2142 | sigma, 2143 | *, 2144 | sigma_data, 2145 | clamp = False, 2146 | dynamic_threshold = True, 2147 | **kwargs 2148 | ): 2149 | batch, device = noised_images.shape[0], noised_images.device 2150 | 2151 | if isinstance(sigma, float): 2152 | sigma = torch.full((batch,), sigma, device = device) 2153 | 2154 | padded_sigma = self.right_pad_dims_to_datatype(sigma) 2155 | 2156 | net_out = unet_forward( 2157 | self.c_in(sigma_data, padded_sigma) * noised_images, 2158 | self.c_noise(sigma), 2159 | **kwargs 2160 | ) 2161 | 2162 | out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out 2163 | 2164 | if not clamp: 2165 | return out 2166 | 2167 | return self.threshold_x_start(out, dynamic_threshold) 2168 | 2169 | # sampling 2170 | 2171 | # sample schedule 2172 | # equation (5) in the paper 2173 | 2174 | def sample_schedule( 2175 | self, 2176 | num_sample_steps, 2177 | rho, 2178 | sigma_min, 2179 | sigma_max 2180 | ): 2181 | N = num_sample_steps 2182 | inv_rho = 1 / rho 2183 | 2184 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) 2185 | sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho 2186 | 2187 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 
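        # the sigmas above follow eq. (5) of the elucidated diffusion (EDM) formulation:
        #   sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho,  i = 0 .. N-1
        # i.e. a rho-warped interpolation from sigma_max down to sigma_min; the appended final value of 0
        # lets the last sampler step land exactly on the denoised estimate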
2188 | return sigmas 2189 | 2190 | @torch.no_grad() 2191 | def one_unet_sample( 2192 | self, 2193 | unet, 2194 | shape, 2195 | *, 2196 | unet_number, 2197 | clamp = True, 2198 | dynamic_threshold = True, 2199 | cond_scale = 1., 2200 | use_tqdm = True, 2201 | inpaint_images = None, 2202 | inpaint_masks = None, 2203 | inpaint_resample_times = 5, 2204 | init_images = None, 2205 | skip_steps = None, 2206 | sigma_min = None, 2207 | sigma_max = None, 2208 | **kwargs 2209 | ): 2210 | # get specific sampling hyperparameters for unet 2211 | 2212 | hp = self.hparams[unet_number - 1] 2213 | 2214 | sigma_min = default(sigma_min, hp.sigma_min) 2215 | sigma_max = default(sigma_max, hp.sigma_max) 2216 | 2217 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 2218 | 2219 | sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max) 2220 | 2221 | gammas = torch.where( 2222 | (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax), 2223 | min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1), 2224 | 0. 2225 | ) 2226 | 2227 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 2228 | 2229 | # images is noise at the beginning 2230 | 2231 | init_sigma = sigmas[0] 2232 | 2233 | images = init_sigma * torch.randn(shape, device = self.device) 2234 | 2235 | # initializing with an image 2236 | 2237 | if exists(init_images): 2238 | images += init_images 2239 | 2240 | # keeping track of x0, for self conditioning if needed 2241 | 2242 | x_start = None 2243 | 2244 | # prepare inpainting images and mask 2245 | 2246 | has_inpainting = exists(inpaint_images) and exists(inpaint_masks) 2247 | resample_times = inpaint_resample_times if has_inpainting else 1 2248 | 2249 | if has_inpainting: 2250 | inpaint_images = self.normalize_img(inpaint_images) 2251 | inpaint_images = self.resize_to(inpaint_images, shape[-1]) 2252 | inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... 
-> b 1 ...').float(), shape[-1]).bool() 2253 | 2254 | # unet kwargs 2255 | 2256 | unet_kwargs = dict( 2257 | sigma_data = hp.sigma_data, 2258 | clamp = clamp, 2259 | dynamic_threshold = dynamic_threshold, 2260 | cond_scale = cond_scale, 2261 | **kwargs 2262 | ) 2263 | 2264 | # gradually denoise 2265 | 2266 | initial_step = default(skip_steps, 0) 2267 | sigmas_and_gammas = sigmas_and_gammas[initial_step:] 2268 | 2269 | total_steps = len(sigmas_and_gammas) 2270 | 2271 | for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm): 2272 | is_last_timestep = ind == (total_steps - 1) 2273 | 2274 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 2275 | 2276 | for r in reversed(range(resample_times)): 2277 | is_last_resample_step = r == 0 2278 | 2279 | eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 2280 | 2281 | sigma_hat = sigma + gamma * sigma 2282 | added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps 2283 | 2284 | images_hat = images + added_noise 2285 | 2286 | self_cond = x_start if unet.self_cond else None 2287 | 2288 | if has_inpainting: 2289 | images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks 2290 | 2291 | model_output = self.preconditioned_network_forward( 2292 | unet.forward_with_cond_scale, 2293 | images_hat, 2294 | sigma_hat, 2295 | self_cond = self_cond, 2296 | **unet_kwargs 2297 | ) 2298 | 2299 | denoised_over_sigma = (images_hat - model_output) / sigma_hat 2300 | 2301 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma 2302 | 2303 | # second order correction, if not the last timestep 2304 | 2305 | if sigma_next != 0: 2306 | self_cond = model_output if unet.self_cond else None 2307 | 2308 | model_output_next = self.preconditioned_network_forward( 2309 | unet.forward_with_cond_scale, 2310 | images_next, 2311 | sigma_next, 2312 | self_cond = self_cond, 2313 | **unet_kwargs 2314 | ) 2315 | 2316 | 2317 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next 2318 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 2319 | 2320 | images = images_next 2321 | 2322 | if has_inpainting and not (is_last_resample_step or is_last_timestep): 2323 | # renoise in repaint and then resample 2324 | repaint_noise = torch.randn(shape, device = self.device) 2325 | images = images + (sigma - sigma_next) * repaint_noise 2326 | 2327 | x_start = model_output # save model output for self conditioning 2328 | 2329 | 2330 | if has_inpainting: 2331 | images = images * ~inpaint_masks + inpaint_images * inpaint_masks 2332 | 2333 | 2334 | return images 2335 | 2336 | @torch.no_grad() 2337 | @eval_decorator 2338 | def sample( 2339 | self, 2340 | texts: List[str] = None, 2341 | text_masks = None, 2342 | text_embeds = None, 2343 | cond_images = None, 2344 | inpaint_images = None, 2345 | inpaint_masks = None, 2346 | inpaint_resample_times = 5, 2347 | init_images = None, 2348 | skip_steps = None, 2349 | sigma_min = None, 2350 | sigma_max = None, 2351 | video_frames = None, 2352 | batch_size = 1, 2353 | cond_scale = 1., 2354 | lowres_sample_noise_level = None, 2355 | start_at_unet_number = 1, 2356 | start_image_or_video = None, 2357 | stop_at_unet_number = None, 2358 | return_all_unet_outputs = False, 2359 | return_pil_images = False, 2360 | use_tqdm = True, 2361 | device = None, 2362 | 2363 | ): 2364 | device = default(device, 
self.device) 2365 | self.reset_unets_all_one_device(device = device) 2366 | 2367 | cond_images = maybe(cast_uint8_images_to_float)(cond_images) 2368 | 2369 | if exists(texts) and not exists(text_embeds) and not self.unconditional: 2370 | assert all([*map(len, texts)]), 'text cannot be empty' 2371 | 2372 | with autocast(enabled = False): 2373 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) 2374 | 2375 | text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks)) 2376 | 2377 | if not self.unconditional: 2378 | assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training' 2379 | 2380 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) 2381 | batch_size = text_embeds.shape[0] 2382 | 2383 | if exists(inpaint_images): 2384 | if self.unconditional: 2385 | if batch_size == 1: # assume researcher wants to broadcast along inpainted images 2386 | batch_size = inpaint_images.shape[0] 2387 | 2388 | assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=)``' 2389 | assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on' 2390 | 2391 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' 2392 | assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' 2393 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' 2394 | 2395 | assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting' 2396 | 2397 | outputs = [] 2398 | 2399 | is_cuda = next(self.parameters()).is_cuda 2400 | device = next(self.parameters()).device 2401 | 2402 | lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level) 2403 | 2404 | num_unets = len(self.unets) 2405 | cond_scale = cast_tuple(cond_scale, num_unets) 2406 | 2407 | # handle video and frame dimension 2408 | 2409 | assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video' 2410 | 2411 | frame_dims = (video_frames,) if self.is_video else tuple() 2412 | 2413 | # initializing with an image or video 2414 | 2415 | init_images = cast_tuple(init_images, num_unets) 2416 | init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images] 2417 | 2418 | skip_steps = cast_tuple(skip_steps, num_unets) 2419 | 2420 | sigma_min = cast_tuple(sigma_min, num_unets) 2421 | sigma_max = cast_tuple(sigma_max, num_unets) 2422 | 2423 | # handle starting at a unet greater than 1, for training only-upscaler training 2424 | 2425 | if start_at_unet_number > 1: 2426 | assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets' 2427 | assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number 2428 | assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling' 2429 | 2430 | prev_image_size = self.image_sizes[start_at_unet_number - 2] 2431 | img = 
self.resize_to(start_image_or_video, prev_image_size) 2432 | 2433 | # go through each unet in cascade 2434 | 2435 | for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm): 2436 | if unet_number < start_at_unet_number: 2437 | continue 2438 | 2439 | assert not isinstance(unet, NullUnet), 'cannot sample from null unet' 2440 | 2441 | context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext() 2442 | 2443 | with context: 2444 | lowres_cond_img = lowres_noise_times = None 2445 | 2446 | 2447 | shape = (batch_size, channel, *frame_dims, image_size ) 2448 | 2449 | if unet.lowres_cond: 2450 | lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device) 2451 | 2452 | lowres_cond_img = self.resize_to(img, image_size) 2453 | 2454 | 2455 | lowres_cond_img = self.normalize_img(lowres_cond_img.float()) 2456 | 2457 | 2458 | lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img.float(), 2459 | t = lowres_noise_times, 2460 | noise = torch.randn_like(lowres_cond_img.float())) 2461 | 2462 | if exists(unet_init_images): 2463 | unet_init_images = self.resize_to(unet_init_images, image_size) 2464 | 2465 | 2466 | shape = (batch_size, self.channels, *frame_dims, image_size) 2467 | 2468 | img = self.one_unet_sample( 2469 | unet, 2470 | shape, 2471 | unet_number = unet_number, 2472 | text_embeds = text_embeds, 2473 | text_mask =text_masks, 2474 | cond_images = cond_images, 2475 | inpaint_images = inpaint_images, 2476 | inpaint_masks = inpaint_masks, 2477 | inpaint_resample_times = inpaint_resample_times, 2478 | init_images = unet_init_images, 2479 | skip_steps = unet_skip_steps, 2480 | sigma_min = unet_sigma_min, 2481 | sigma_max = unet_sigma_max, 2482 | cond_scale = unet_cond_scale, 2483 | lowres_cond_img = lowres_cond_img, 2484 | lowres_noise_times = lowres_noise_times, 2485 | dynamic_threshold = dynamic_threshold, 2486 | use_tqdm = use_tqdm 2487 | ) 2488 | 2489 | if self.categorical_loss: 2490 | img=self.m(img) 2491 | outputs.append(img) 2492 | 2493 | if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: 2494 | break 2495 | 2496 | output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs 2497 | 2498 | if not return_all_unet_outputs: 2499 | outputs = outputs[-1:] 2500 | 2501 | assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet' 2502 | 2503 | 2504 | if self.categorical_loss: 2505 | return torch.argmax(outputs[output_index], dim=1).unsqueeze (1) 2506 | else: 2507 | return outputs[output_index] 2508 | 2509 | 2510 | def loss_weight(self, sigma_data, sigma): 2511 | return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2 2512 | 2513 | def noise_distribution(self, P_mean, P_std, batch_size): 2514 | return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp() 2515 | 2516 | def forward( 2517 | self, 2518 | images, 2519 | unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, 2520 | texts: List[str] = None, 2521 | text_embeds = None, 2522 | text_masks = None, 2523 | unet_number = None, 2524 | cond_images = None, 2525 | 2526 | ): 2527 | #assert 
images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}' 2528 | assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' 2529 | unet_number = default(unet_number, 1) 2530 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train on unet #{self.only_train_unet_number}' 2531 | 2532 | 2533 | cond_images = maybe(cast_uint8_images_to_float)(cond_images) 2534 | 2535 | if self.categorical_loss==False: 2536 | assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead' 2537 | 2538 | unet_index = unet_number - 1 2539 | 2540 | unet = default(unet, lambda: self.get_unet(unet_number)) 2541 | 2542 | assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained' 2543 | 2544 | target_image_size = self.image_sizes[unet_index] 2545 | random_crop_size = self.random_crop_sizes[unet_index] 2546 | prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None 2547 | hp = self.hparams[unet_index] 2548 | 2549 | batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4) 2550 | 2551 | frames = images.shape[2] if is_video else None 2552 | 2553 | check_shape(images, 'b c ...', c = self.channels) 2554 | 2555 | 2556 | assert h >= target_image_size 2557 | 2558 | if exists(texts) and not exists(text_embeds) and not self.unconditional: 2559 | assert all([*map(len, texts)]), 'text cannot be empty' 2560 | assert len(texts) == len(images), 'number of text captions does not match up with the number of images given' 2561 | 2562 | with autocast(enabled = False): 2563 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) 2564 | 2565 | text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks)) 2566 | 2567 | if not self.unconditional: 2568 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) 2569 | 2570 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified' 2571 | assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented' 2572 | 2573 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' 2574 | 2575 | lowres_cond_img = lowres_aug_times = None 2576 | if exists(prev_image_size): 2577 | lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range) 2578 | lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range) 2579 | 2580 | if self.per_sample_random_aug_noise_level: 2581 | lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device) 2582 | else: 2583 | lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device) 2584 | lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size) 2585 | 2586 | 2587 | if exists(random_crop_size): 2588 | aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) 
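            # p = 1. makes the crop fire on every call; the crop parameters sampled for `images` (aug._params)
            # are reused below so the low resolution conditioner receives exactly the same crop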
2589 | 
2590 |             if is_video:
2591 | 
2592 |                 images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h -> (b f) c h')
2593 | 
2594 |             # make sure low res conditioner and image both get augmented the same way
2595 |             # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
2596 |             images = aug(images)
2597 |             lowres_cond_img = aug(lowres_cond_img, params = aug._params)
2598 | 
2599 |             if is_video:
2600 | 
2601 |                 images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h -> b c f h', f = frames)
2602 | 
2603 |         # noise the lowres conditioning image
2604 |         # at sample time, the noise level is then fixed to a set value (0.1 - 0.3)
2605 | 
2606 |         lowres_cond_img_noisy = None
2607 |         if exists(lowres_cond_img):
2608 | 
2609 |             lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img,
2610 |                                                                            t = lowres_aug_times,
2611 |                                                                            noise = torch.randn_like(lowres_cond_img.float()))
2612 | 
2613 |         # get the sigmas
2614 | 
2615 |         sigmas = self.noise_distribution(hp.P_mean, hp.P_std, batch_size)
2616 |         padded_sigmas = self.right_pad_dims_to_datatype(sigmas)
2617 | 
2618 |         # noise
2619 | 
2620 |         noise = torch.randn_like(images.float())
2621 | 
2622 | 
2623 |         noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
2624 | 
2625 |         # unet kwargs
2626 | 
2627 |         unet_kwargs = dict(
2628 |             sigma_data = hp.sigma_data,
2629 |             text_embeds = text_embeds,
2630 |             text_mask = text_masks,
2631 |             cond_images = cond_images,
2632 |             lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
2633 |             lowres_cond_img = lowres_cond_img_noisy,
2634 |             cond_drop_prob = self.cond_drop_prob,
2635 |         )
2636 | 
2637 |         # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower
2638 | 
2639 |         self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
2640 | 
2641 |         if self_cond and random() < 0.5:
2642 |             with torch.no_grad():
2643 |                 pred_x0 = self.preconditioned_network_forward(
2644 |                     unet.forward,
2645 |                     noised_images,
2646 |                     sigmas,
2647 |                     **unet_kwargs
2648 |                 ).detach()
2649 | 
2650 |             unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}
2651 | 
2652 |         # get prediction
2653 | 
2654 |         denoised_images = self.preconditioned_network_forward(
2655 |             unet.forward,
2656 |             noised_images,
2657 |             sigmas,
2658 |             **unet_kwargs
2659 |         )
2660 | 
2661 | 
2662 |         # losses
2663 | 
2664 |         if self.loss_type==0: #self.categorical_loss == False:
2665 | 
2666 | 
2667 |             losses = F.mse_loss(denoised_images, images, reduction = 'none')
2668 |             losses = reduce(losses, 'b ... -> b', 'mean')
2669 | 
2670 |             # loss weighting
2671 | 
2672 |             losses = losses * self.loss_weight(hp.sigma_data, sigmas)
2673 |             losses = losses.mean() #+losses_z.mean()
2674 | 
2675 |             ############################################################
2676 |             # ENSURE zs are close to codebook vectors z
2677 |             # calculate difference between z predicted = denoised images and the snapped version
2678 | 
2679 |             if self.add_z_loss:
2680 |                 sqer_z = int(sqrt(denoised_images.shape[-1])) # side length of the (assumed) square latent grid expected by the VQ-VAE codebook
2681 |                 denoised_images_rearr = torch.reshape(denoised_images, (denoised_images.shape[0], denoised_images.shape[1], sqer_z, sqer_z))
2682 |                 with torch.no_grad():
2683 |                     z_quant, _, _ = self.vit_vae.codebook(denoised_images_rearr)
2684 | 
2685 |                 losses_z = F.mse_loss(denoised_images_rearr, z_quant, reduction = 'none')
2686 |                 losses_z = reduce(losses_z, 'b ...
-> b', 'mean') 2687 | 2688 | losses=losses +losses_z.mean() 2689 | 2690 | ############################################################ 2691 | 2692 | if self.loss_type==1: #self.categorical_loss: 2693 | 2694 | #channel is last 2695 | denoised_images=torch.permute(denoised_images, (0,2,1) ) 2696 | images=torch.permute(images, (0,2,1) ) 2697 | 2698 | if self.categorical_loss_ignore==None: 2699 | criterion_loss=nn.CrossEntropyLoss () 2700 | else: 2701 | print ("Cannot use with probability a") 2702 | 2703 | 2704 | denoised_images=denoised_images.transpose(0, 1).reshape(-1, denoised_images.shape[-1]) 2705 | 2706 | images=images.transpose(0, 1).reshape(-1, images.shape[-1]).softmax(dim=1) 2707 | 2708 | losses = criterion_loss(denoised_images, images) 2709 | 2710 | if self.loss_type==2: #torch.nn.KLDivLoss 2711 | 2712 | denoised_images=torch.permute(denoised_images, (0,2,1) ) 2713 | images=torch.permute(images, (0,2,1) ) 2714 | 2715 | 2716 | if self.categorical_loss_ignore==None: 2717 | criterion_loss=nn.KLDivLoss(reduction = 'batchmean') 2718 | else: 2719 | print ("Cannot use with probability") 2720 | 2721 | 2722 | denoised_images=denoised_images.transpose(0, 1).reshape(-1, denoised_images.shape[-1]) 2723 | 2724 | images=images.transpose(0, 1).reshape(-1, images.shape[-1]).softmax(dim=1) 2725 | 2726 | 2727 | losses = criterion_loss(denoised_images, images) 2728 | if self.loss_type==3: #NLLLOss 2729 | 2730 | denoised_images=torch.permute(denoised_images, (0,2,1) ) 2731 | 2732 | 2733 | if self.categorical_loss_ignore==None: 2734 | criterion_loss=nn.NLLLoss()#nn.KLDivLoss(reduction = 'batchmean') 2735 | else: 2736 | print ("Cannot use with probability") 2737 | 2738 | 2739 | denoised_images=denoised_images.transpose(0, 1).reshape(-1, denoised_images.shape[-1]) 2740 | images=torch.argmax(images, dim=1) #the input needs to have same dim as output. hence, 2741 | #we cannot use ordinal 2742 | images=images.transpose(0, 1).long().reshape(-1) 2743 | 2744 | 2745 | losses = criterion_loss(self.m(denoised_images), images) 2746 | # return average loss 2747 | 2748 | return losses 2749 | 2750 | ################################################################################## 2751 | ###################### Positional encoding ####################################### 2752 | ################################################################################## 2753 | 2754 | #https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/positional_encodings.py 2755 | 2756 | class PositionalEncoding1D(nn.Module): 2757 | def __init__(self, channels): 2758 | """ 2759 | :param channels: The last dimension of the tensor you want to apply pos emb to. 
2760 | """ 2761 | super(PositionalEncoding1D, self).__init__() 2762 | self.org_channels = channels 2763 | channels = int(np.ceil(channels / 2) * 2) 2764 | self.channels = channels 2765 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 2766 | self.register_buffer("inv_freq", inv_freq) 2767 | 2768 | def forward(self, tensor): 2769 | """ 2770 | :param tensor: A 3d tensor of size (batch_size, x, ch) 2771 | :return: Positional Encoding Matrix of size (batch_size, x, ch) 2772 | """ 2773 | if len(tensor.shape) != 3: 2774 | raise RuntimeError("The input tensor has to be 3d!") 2775 | batch_size, x, orig_ch = tensor.shape 2776 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 2777 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 2778 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) 2779 | emb = torch.zeros((x, self.channels), device=tensor.device).type(tensor.type()) 2780 | emb[:, : self.channels] = emb_x 2781 | 2782 | return emb[None, :, :orig_ch].repeat(batch_size, 1, 1) 2783 | 2784 | 2785 | class PositionalEncodingPermute1D(nn.Module): 2786 | def __init__(self, channels): 2787 | """ 2788 | Accepts (batchsize, ch, x) instead of (batchsize, x, ch) 2789 | """ 2790 | super(PositionalEncodingPermute1D, self).__init__() 2791 | self.penc = PositionalEncoding1D(channels) 2792 | 2793 | def forward(self, tensor): 2794 | tensor = tensor.permute(0, 2, 1) 2795 | enc = self.penc(tensor) 2796 | return enc.permute(0, 2, 1) 2797 | 2798 | @property 2799 | def org_channels(self): 2800 | return self.penc.org_channels 2801 | 2802 | 2803 | class PositionalEncoding2D(nn.Module): 2804 | def __init__(self, channels): 2805 | """ 2806 | :param channels: The last dimension of the tensor you want to apply pos emb to. 
2807 | """ 2808 | super(PositionalEncoding2D, self).__init__() 2809 | self.org_channels = channels 2810 | channels = int(np.ceil(channels / 4) * 2) 2811 | self.channels = channels 2812 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 2813 | self.register_buffer("inv_freq", inv_freq) 2814 | 2815 | def forward(self, tensor): 2816 | """ 2817 | :param tensor: A 4d tensor of size (batch_size, x, y, ch) 2818 | :return: Positional Encoding Matrix of size (batch_size, x, y, ch) 2819 | """ 2820 | if len(tensor.shape) != 4: 2821 | raise RuntimeError("The input tensor has to be 4d!") 2822 | batch_size, x, y, orig_ch = tensor.shape 2823 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 2824 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 2825 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 2826 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 2827 | emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1) 2828 | emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1) 2829 | emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( 2830 | tensor.type() 2831 | ) 2832 | emb[:, :, : self.channels] = emb_x 2833 | emb[:, :, self.channels : 2 * self.channels] = emb_y 2834 | 2835 | return emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1) 2836 | 2837 | 2838 | class PositionalEncodingPermute2D(nn.Module): 2839 | def __init__(self, channels): 2840 | """ 2841 | Accepts (batchsize, ch, x, y) instead of (batchsize, x, y, ch) 2842 | """ 2843 | super(PositionalEncodingPermute2D, self).__init__() 2844 | self.penc = PositionalEncoding2D(channels) 2845 | 2846 | def forward(self, tensor): 2847 | tensor = tensor.permute(0, 2, 3, 1) 2848 | enc = self.penc(tensor) 2849 | return enc.permute(0, 3, 1, 2) 2850 | 2851 | @property 2852 | def org_channels(self): 2853 | return self.penc.org_channels 2854 | 2855 | 2856 | class PositionalEncoding3D(nn.Module): 2857 | def __init__(self, channels): 2858 | """ 2859 | :param channels: The last dimension of the tensor you want to apply pos emb to. 
2860 | """ 2861 | super(PositionalEncoding3D, self).__init__() 2862 | self.org_channels = channels 2863 | channels = int(np.ceil(channels / 6) * 2) 2864 | if channels % 2: 2865 | channels += 1 2866 | self.channels = channels 2867 | inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) 2868 | self.register_buffer("inv_freq", inv_freq) 2869 | 2870 | def forward(self, tensor): 2871 | """ 2872 | :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) 2873 | :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) 2874 | """ 2875 | if len(tensor.shape) != 5: 2876 | raise RuntimeError("The input tensor has to be 5d!") 2877 | batch_size, x, y, z, orig_ch = tensor.shape 2878 | pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) 2879 | pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) 2880 | pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) 2881 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 2882 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 2883 | sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) 2884 | emb_x = ( 2885 | torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1) 2886 | .unsqueeze(1) 2887 | .unsqueeze(1) 2888 | ) 2889 | emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1) 2890 | emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1) 2891 | emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type( 2892 | tensor.type() 2893 | ) 2894 | emb[:, :, :, : self.channels] = emb_x 2895 | emb[:, :, :, self.channels : 2 * self.channels] = emb_y 2896 | emb[:, :, :, 2 * self.channels :] = emb_z 2897 | 2898 | return emb[None, :, :, :, :orig_ch].repeat(batch_size, 1, 1, 1, 1) 2899 | 2900 | 2901 | class PositionalEncodingPermute3D(nn.Module): 2902 | def __init__(self, channels): 2903 | """ 2904 | Accepts (batchsize, ch, x, y, z) instead of (batchsize, x, y, z, ch) 2905 | """ 2906 | super(PositionalEncodingPermute3D, self).__init__() 2907 | self.penc = PositionalEncoding3D(channels) 2908 | 2909 | def forward(self, tensor): 2910 | tensor = tensor.permute(0, 2, 3, 4, 1) 2911 | enc = self.penc(tensor) 2912 | return enc.permute(0, 4, 1, 2, 3) 2913 | 2914 | @property 2915 | def org_channels(self): 2916 | return self.penc.org_channels 2917 | 2918 | 2919 | class FixEncoding(nn.Module): 2920 | """ 2921 | :param pos_encoder: instance of PositionalEncoding1D, PositionalEncoding2D or PositionalEncoding3D 2922 | :param shape: shape of input, excluding batch and embedding size 2923 | Example: 2924 | p_enc_2d = FixEncoding(PositionalEncoding2D(32), (x, y)) # for where x and y are the dimensions of your image 2925 | inputs = torch.randn(64, 128, 128, 32) # where x and y are 128, and 64 is the batch size 2926 | p_enc_2d(inputs) 2927 | """ 2928 | 2929 | def __init__(self, pos_encoder, shape): 2930 | super(FixEncoding, self).__init__() 2931 | self.shape = shape 2932 | self.dim = len(shape) 2933 | self.pos_encoder = pos_encoder 2934 | self.pos_encoding = pos_encoder( 2935 | torch.ones(1, *shape, self.pos_encoder.org_channels) 2936 | ) 2937 | self.batch_size = 0 2938 | 2939 | def forward(self, tensor): 2940 | if self.batch_size != tensor.shape[0]: 2941 | self.repeated_pos_encoding = self.pos_encoding.to(tensor.device).repeat( 2942 | tensor.shape[0], *(self.dim + 1) * [1] 2943 | ) 2944 | self.batch_size = tensor.shape[0] 2945 | return self.repeated_pos_encoding 2946 | 2947 | 2948 | ########################## TRAINER ##################### 
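# The trainer below wraps the cascaded diffusion model with Accelerate, per-unet Adam optimizers,
# optional EMA weights and checkpointing. A minimal usage sketch (names such as `model`, `train_ds`
# and `num_train_steps` are placeholders, not part of this file): `model` is assumed to expose
# `.imagen` (an ElucidatedImagen) and `.is_elucidated`, and `train_ds` to yield tuples matching
# dl_tuple_output_keywords_names, i.e. (images, text_embeds, text_masks, cond_images):
#
#   trainer = HiearchicalDesignTrainer(model = model, lr = 1e-4, fp16 = True)
#   trainer.add_train_dataset(train_ds, batch_size = 8)
#   for _ in range(num_train_steps):
#       loss = trainer.train_step(unet_number = 1)   # one forward/backward pass followed by trainer.update(...)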
2949 | import os
2950 | import time
2951 | import copy
2952 | from pathlib import Path
2953 | from math import ceil
2954 | from contextlib import contextmanager, nullcontext
2955 | from functools import partial, wraps
2956 | from collections.abc import Iterable
2957 | 
2958 | import torch
2959 | from torch import nn
2960 | import torch.nn.functional as F
2961 | from torch.utils.data import random_split, DataLoader
2962 | from torch.optim import Adam
2963 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
2964 | from torch.cuda.amp import autocast, GradScaler
2965 | 
2966 | import pytorch_warmup as warmup
2967 | 
2968 | from packaging import version
2969 | 
2970 | import numpy as np
2971 | 
2972 | from ema_pytorch import EMA
2973 | 
2974 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
2975 | 
2976 | from fsspec.core import url_to_fs
2977 | from fsspec.implementations.local import LocalFileSystem
2978 | 
2979 | # helper functions
2980 | 
2981 | def exists(val):
2982 |     return val is not None
2983 | 
2984 | def default(val, d):
2985 |     if exists(val):
2986 |         return val
2987 |     return d() if callable(d) else d
2988 | 
2989 | def cast_tuple(val, length = 1):
2990 |     if isinstance(val, list):
2991 |         val = tuple(val)
2992 | 
2993 |     return val if isinstance(val, tuple) else ((val,) * length)
2994 | 
2995 | def find_first(fn, arr):
2996 |     for ind, el in enumerate(arr):
2997 |         if fn(el):
2998 |             return ind
2999 |     return -1
3000 | 
3001 | def pick_and_pop(keys, d):
3002 |     values = list(map(lambda key: d.pop(key), keys))
3003 |     return dict(zip(keys, values))
3004 | 
3005 | def group_dict_by_key(cond, d):
3006 |     return_val = [dict(),dict()]
3007 |     for key in d.keys():
3008 |         match = bool(cond(key))
3009 |         ind = int(not match)
3010 |         return_val[ind][key] = d[key]
3011 |     return (*return_val,)
3012 | 
3013 | def string_begins_with(prefix, str):
3014 |     return str.startswith(prefix)
3015 | 
3016 | def group_by_key_prefix(prefix, d):
3017 |     return group_dict_by_key(partial(string_begins_with, prefix), d)
3018 | 
3019 | def groupby_prefix_and_trim(prefix, d):
3020 |     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
3021 |     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
3022 |     return kwargs_without_prefix, kwargs
3023 | 
3024 | def num_to_groups(num, divisor):
3025 |     groups = num // divisor
3026 |     remainder = num % divisor
3027 |     arr = [divisor] * groups
3028 |     if remainder > 0:
3029 |         arr.append(remainder)
3030 |     return arr
3031 | 
3032 | # url to fs, bucket, path - for checkpointing to cloud
3033 | 
3034 | def url_to_bucket(url):
3035 |     if '://' not in url:
3036 |         return url
3037 | 
3038 |     prefix, suffix = url.split('://')
3039 | 
3040 |     if prefix in {'gs', 's3'}:
3041 |         return suffix.split('/')[0]
3042 |     else:
3043 |         raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
3044 | 
3045 | # decorators
3046 | 
3047 | def eval_decorator(fn):
3048 |     def inner(model, *args, **kwargs):
3049 |         was_training = model.training
3050 |         model.eval()
3051 |         out = fn(model, *args, **kwargs)
3052 |         model.train(was_training)
3053 |         return out
3054 |     return inner
3055 | 
3056 | def cast_torch_tensor(fn, cast_fp16 = False):
3057 |     @wraps(fn)
3058 |     def inner(model, *args, **kwargs):
3059 |         device = kwargs.pop('_device', model.device)
3060 |         cast_device = kwargs.pop('_cast_device', True)
3061 | 
3062 |         should_cast_fp16 = cast_fp16 and model.cast_half_at_training
3063 | 
3064 |         kwargs_keys =
kwargs.keys() 3065 | all_args = (*args, *kwargs.values()) 3066 | split_kwargs_index = len(all_args) - len(kwargs_keys) 3067 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args)) 3068 | 3069 | if cast_device: 3070 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args)) 3071 | 3072 | if should_cast_fp16: 3073 | all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args)) 3074 | 3075 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:] 3076 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values))) 3077 | 3078 | out = fn(model, *args, **kwargs) 3079 | return out 3080 | return inner 3081 | 3082 | # gradient accumulation functions 3083 | 3084 | def split_iterable(it, split_size): 3085 | accum = [] 3086 | for ind in range(ceil(len(it) / split_size)): 3087 | start_index = ind * split_size 3088 | accum.append(it[start_index: (start_index + split_size)]) 3089 | return accum 3090 | 3091 | def split(t, split_size = None): 3092 | if not exists(split_size): 3093 | return t 3094 | 3095 | if isinstance(t, torch.Tensor): 3096 | return t.split(split_size, dim = 0) 3097 | 3098 | if isinstance(t, Iterable): 3099 | return split_iterable(t, split_size) 3100 | 3101 | return TypeError 3102 | 3103 | def find_first(cond, arr): 3104 | for el in arr: 3105 | if cond(el): 3106 | return el 3107 | return None 3108 | 3109 | def split_args_and_kwargs(*args, split_size = None, **kwargs): 3110 | all_args = (*args, *kwargs.values()) 3111 | len_all_args = len(all_args) 3112 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args) 3113 | assert exists(first_tensor) 3114 | 3115 | batch_size = len(first_tensor) 3116 | split_size = default(split_size, batch_size) 3117 | num_chunks = ceil(batch_size / split_size) 3118 | 3119 | dict_len = len(kwargs) 3120 | dict_keys = kwargs.keys() 3121 | split_kwargs_index = len_all_args - dict_len 3122 | 3123 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] 3124 | chunk_sizes = tuple(map(len, split_all_args[0])) 3125 | 3126 | for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)): 3127 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] 3128 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values))) 3129 | chunk_size_frac = chunk_size / batch_size 3130 | yield chunk_size_frac, (chunked_args, chunked_kwargs) 3131 | 3132 | # imagen trainer 3133 | 3134 | def imagen_sample_in_chunks(fn): 3135 | @wraps(fn) 3136 | def inner(self, *args, max_batch_size = None, **kwargs): 3137 | if not exists(max_batch_size): 3138 | return fn(self, *args, **kwargs) 3139 | 3140 | if self.imagen.unconditional: 3141 | batch_size = kwargs.get('batch_size') 3142 | batch_sizes = num_to_groups(batch_size, max_batch_size) 3143 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes] 3144 | else: 3145 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)] 3146 | 3147 | if isinstance(outputs[0], torch.Tensor): 3148 | return torch.cat(outputs, dim = 0) 3149 | 3150 | return list(map(lambda t: torch.cat(t, dim = 0), 
list(zip(*outputs)))) 3151 | 3152 | return inner 3153 | 3154 | 3155 | def restore_parts(state_dict_target, state_dict_from): 3156 | for name, param in state_dict_from.items(): 3157 | 3158 | if name not in state_dict_target: 3159 | continue 3160 | 3161 | if param.size() == state_dict_target[name].size(): 3162 | state_dict_target[name].copy_(param) 3163 | else: 3164 | print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}") 3165 | 3166 | return state_dict_target 3167 | 3168 | 3169 | class HiearchicalDesignTrainer(nn.Module): 3170 | locked = False 3171 | 3172 | def __init__( 3173 | self, 3174 | #imagen = None, 3175 | model = None, 3176 | 3177 | imagen_checkpoint_path = None, 3178 | use_ema = True, 3179 | lr = 1e-4, 3180 | eps = 1e-8, 3181 | beta1 = 0.9, 3182 | beta2 = 0.99, 3183 | max_grad_norm = None, 3184 | group_wd_params = True, 3185 | warmup_steps = None, 3186 | cosine_decay_max_steps = None, 3187 | only_train_unet_number = None, 3188 | fp16 = False, 3189 | precision = None, 3190 | split_batches = True, 3191 | dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), 3192 | verbose = True, 3193 | split_valid_fraction = 0.025, 3194 | split_valid_from_train = False, 3195 | split_random_seed = 42, 3196 | checkpoint_path = None, 3197 | checkpoint_every = None, 3198 | checkpoint_fs = None, 3199 | fs_kwargs: dict = None, 3200 | max_checkpoints_keep = 20, 3201 | **kwargs 3202 | ): 3203 | super().__init__() 3204 | assert not HiearchicalDesignTrainer.locked, 'HiearchicalDesignTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' 3205 | assert exists(model.imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' 3206 | 3207 | # determine filesystem, using fsspec, for saving to local filesystem or cloud 3208 | 3209 | self.fs = checkpoint_fs 3210 | 3211 | if not exists(self.fs): 3212 | fs_kwargs = default(fs_kwargs, {}) 3213 | self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) 3214 | 3215 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) 3216 | 3217 | 3218 | #self.imagen = imagen 3219 | self.imagen = model.imagen 3220 | 3221 | # elucidated or not 3222 | 3223 | self.model=model 3224 | self.is_elucidated = self.model.is_elucidated#self.imagen.isinstance(imagen, ElucidatedImagen) 3225 | # create accelerator instance 3226 | 3227 | accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) 3228 | 3229 | assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' 3230 | accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') 3231 | 3232 | self.accelerator = Accelerator(**{ 3233 | 'split_batches': split_batches, 3234 | 'mixed_precision': accelerator_mixed_precision, 3235 | 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] 3236 | , **accelerate_kwargs}) 3237 | 3238 | HiearchicalDesignTrainer.locked = self.is_distributed 3239 | 3240 | # cast data to fp16 at training time if needed 3241 | 3242 | self.cast_half_at_training = accelerator_mixed_precision == 'fp16' 3243 | 3244 | # grad scaler must be managed outside of accelerator 3245 | 3246 | grad_scaler_enabled = fp16 3247 | 3248 | # imagen, unets and ema unets 3249 | 3250 | 3251 | 
self.num_unets = len(self.imagen.unets) 3252 | 3253 | self.use_ema = use_ema and self.is_main 3254 | self.ema_unets = nn.ModuleList([]) 3255 | 3256 | # keep track of what unet is being trained on 3257 | # only going to allow 1 unet training at a time 3258 | 3259 | self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on 3260 | 3261 | # data related functions 3262 | 3263 | self.train_dl_iter = None 3264 | self.train_dl = None 3265 | 3266 | self.valid_dl_iter = None 3267 | self.valid_dl = None 3268 | 3269 | self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names 3270 | 3271 | # auto splitting validation from training, if dataset is passed in 3272 | 3273 | self.split_valid_from_train = split_valid_from_train 3274 | 3275 | assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' 3276 | self.split_valid_fraction = split_valid_fraction 3277 | self.split_random_seed = split_random_seed 3278 | 3279 | # be able to finely customize learning rate, weight decay 3280 | # per unet 3281 | 3282 | lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) 3283 | 3284 | for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): 3285 | optimizer = Adam( 3286 | unet.parameters(), 3287 | lr = unet_lr, 3288 | eps = unet_eps, 3289 | betas = (beta1, beta2), 3290 | **kwargs 3291 | ) 3292 | 3293 | if self.use_ema: 3294 | self.ema_unets.append(EMA(unet, **ema_kwargs)) 3295 | 3296 | scaler = GradScaler(enabled = grad_scaler_enabled) 3297 | 3298 | scheduler = warmup_scheduler = None 3299 | 3300 | if exists(unet_cosine_decay_max_steps): 3301 | scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) 3302 | 3303 | if exists(unet_warmup_steps): 3304 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) 3305 | 3306 | if not exists(scheduler): 3307 | scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) 3308 | 3309 | # set on object 3310 | 3311 | setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers 3312 | setattr(self, f'scaler{ind}', scaler) 3313 | setattr(self, f'scheduler{ind}', scheduler) 3314 | setattr(self, f'warmup{ind}', warmup_scheduler) 3315 | 3316 | # gradient clipping if needed 3317 | 3318 | self.max_grad_norm = max_grad_norm 3319 | 3320 | # step tracker and misc 3321 | 3322 | self.register_buffer('steps', torch.tensor([0] * self.num_unets)) 3323 | 3324 | self.verbose = verbose 3325 | 3326 | # automatic set devices based on what accelerator decided 3327 | 3328 | self.imagen.to(self.device) 3329 | self.to(self.device) 3330 | 3331 | # checkpointing 3332 | 3333 | assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) 3334 | self.checkpoint_path = checkpoint_path 3335 | self.checkpoint_every = checkpoint_every 3336 | self.max_checkpoints_keep = max_checkpoints_keep 3337 | 3338 | self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main 3339 | 3340 | if exists(checkpoint_path) and self.can_checkpoint: 3341 | bucket = url_to_bucket(checkpoint_path) 3342 | 3343 | if not self.fs.exists(bucket): 3344 | self.fs.mkdir(bucket) 3345 | 3346 | self.load_from_checkpoint_folder() 3347 | 3348 | # only allowing training for unet 3349 | 3350 | self.only_train_unet_number = only_train_unet_number 
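        # validate_and_set_unet_being_trained() below also wraps the chosen unet, its optimizer and
        # scheduler with the accelerator (see wrap_unet), so only that single unet is prepared for training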
3351 | self.validate_and_set_unet_being_trained(only_train_unet_number) 3352 | 3353 | # computed values 3354 | 3355 | @property 3356 | def device(self): 3357 | return self.accelerator.device 3358 | 3359 | @property 3360 | def is_distributed(self): 3361 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 3362 | 3363 | @property 3364 | def is_main(self): 3365 | return self.accelerator.is_main_process 3366 | 3367 | @property 3368 | def is_local_main(self): 3369 | return self.accelerator.is_local_main_process 3370 | 3371 | @property 3372 | def unwrapped_unet(self): 3373 | return self.accelerator.unwrap_model(self.unet_being_trained) 3374 | 3375 | # optimizer helper functions 3376 | 3377 | def get_lr(self, unet_number): 3378 | self.validate_unet_number(unet_number) 3379 | unet_index = unet_number - 1 3380 | 3381 | optim = getattr(self, f'optim{unet_index}') 3382 | 3383 | return optim.param_groups[0]['lr'] 3384 | 3385 | # function for allowing only one unet from being trained at a time 3386 | 3387 | def validate_and_set_unet_being_trained(self, unet_number = None): 3388 | if exists(unet_number): 3389 | self.validate_unet_number(unet_number) 3390 | 3391 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet' 3392 | 3393 | self.only_train_unet_number = unet_number 3394 | self.imagen.only_train_unet_number = unet_number 3395 | 3396 | if not exists(unet_number): 3397 | return 3398 | 3399 | self.wrap_unet(unet_number) 3400 | 3401 | def wrap_unet(self, unet_number): 3402 | if hasattr(self, 'one_unet_wrapped'): 3403 | return 3404 | 3405 | unet = self.imagen.get_unet(unet_number) 3406 | self.unet_being_trained = self.accelerator.prepare(unet) 3407 | unet_index = unet_number - 1 3408 | 3409 | optimizer = getattr(self, f'optim{unet_index}') 3410 | scheduler = getattr(self, f'scheduler{unet_index}') 3411 | 3412 | optimizer = self.accelerator.prepare(optimizer) 3413 | 3414 | if exists(scheduler): 3415 | scheduler = self.accelerator.prepare(scheduler) 3416 | 3417 | setattr(self, f'optim{unet_index}', optimizer) 3418 | setattr(self, f'scheduler{unet_index}', scheduler) 3419 | 3420 | self.one_unet_wrapped = True 3421 | 3422 | # hacking accelerator due to not having separate gradscaler per optimizer 3423 | 3424 | def set_accelerator_scaler(self, unet_number): 3425 | unet_number = self.validate_unet_number(unet_number) 3426 | scaler = getattr(self, f'scaler{unet_number - 1}') 3427 | 3428 | self.accelerator.scaler = scaler 3429 | for optimizer in self.accelerator._optimizers: 3430 | optimizer.scaler = scaler 3431 | 3432 | # helper print 3433 | 3434 | def print(self, msg): 3435 | if not self.is_main: 3436 | return 3437 | 3438 | if not self.verbose: 3439 | return 3440 | 3441 | return self.accelerator.print(msg) 3442 | 3443 | # validating the unet number 3444 | 3445 | def validate_unet_number(self, unet_number = None): 3446 | if self.num_unets == 1: 3447 | unet_number = default(unet_number, 1) 3448 | 3449 | assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' 3450 | return unet_number 3451 | 3452 | # number of training steps taken 3453 | 3454 | def num_steps_taken(self, unet_number = None): 3455 | if self.num_unets == 1: 3456 | unet_number = default(unet_number, 1) 3457 | 3458 | return self.steps[unet_number - 1].item() 3459 | 3460 | def 
print_untrained_unets(self): 3461 | print_final_error = False 3462 | 3463 | for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): 3464 | if steps > 0 or isinstance(unet, NullUnet): 3465 | continue 3466 | 3467 | self.print(f'unet {ind + 1} has not been trained') 3468 | print_final_error = True 3469 | 3470 | if print_final_error: 3471 | self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') 3472 | 3473 | # data related functions 3474 | 3475 | def add_train_dataloader(self, dl = None): 3476 | if not exists(dl): 3477 | return 3478 | 3479 | assert not exists(self.train_dl), 'training dataloader was already added' 3480 | self.train_dl = self.accelerator.prepare(dl) 3481 | 3482 | def add_valid_dataloader(self, dl): 3483 | if not exists(dl): 3484 | return 3485 | 3486 | assert not exists(self.valid_dl), 'validation dataloader was already added' 3487 | self.valid_dl = self.accelerator.prepare(dl) 3488 | 3489 | def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): 3490 | if not exists(ds): 3491 | return 3492 | 3493 | assert not exists(self.train_dl), 'training dataloader was already added' 3494 | 3495 | valid_ds = None 3496 | if self.split_valid_from_train: 3497 | train_size = int((1 - self.split_valid_fraction) * len(ds)) 3498 | valid_size = len(ds) - train_size 3499 | 3500 | ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed)) 3501 | self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples') 3502 | 3503 | dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) 3504 | self.train_dl = self.accelerator.prepare(dl) 3505 | 3506 | if not self.split_valid_from_train: 3507 | return 3508 | 3509 | self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) 3510 | 3511 | def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): 3512 | if not exists(ds): 3513 | return 3514 | 3515 | assert not exists(self.valid_dl), 'validation dataloader was already added' 3516 | 3517 | dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) 3518 | self.valid_dl = self.accelerator.prepare(dl) 3519 | 3520 | def create_train_iter(self): 3521 | assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' 3522 | 3523 | if exists(self.train_dl_iter): 3524 | return 3525 | 3526 | self.train_dl_iter = cycle(self.train_dl) 3527 | 3528 | def create_valid_iter(self): 3529 | assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' 3530 | 3531 | if exists(self.valid_dl_iter): 3532 | return 3533 | 3534 | self.valid_dl_iter = cycle(self.valid_dl) 3535 | 3536 | def train_step(self, unet_number = None, **kwargs): 3537 | self.create_train_iter() 3538 | loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) 3539 | self.update(unet_number = unet_number) 3540 | return loss 3541 | 3542 | @torch.no_grad() 3543 | @eval_decorator 3544 | def valid_step(self, **kwargs): 3545 | self.create_valid_iter() 3546 | 3547 | context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext 3548 | 3549 | with context(): 3550 | loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) 3551 | return loss 3552 | 3553 | def step_with_dl_iter(self, dl_iter, **kwargs): 3554 | dl_tuple_output = cast_tuple(next(dl_iter)) 3555 | model_input = 
dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) 3556 | loss = self.forward(**{**kwargs, **model_input}) 3557 | return loss 3558 | 3559 | # checkpointing functions 3560 | 3561 | @property 3562 | def all_checkpoints_sorted(self): 3563 | glob_pattern = os.path.join(self.checkpoint_path, '*.pt') 3564 | checkpoints = self.fs.glob(glob_pattern) 3565 | sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) 3566 | return sorted_checkpoints 3567 | 3568 | def load_from_checkpoint_folder(self, last_total_steps = -1): 3569 | if last_total_steps != -1: 3570 | filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') 3571 | self.load(filepath) 3572 | return 3573 | 3574 | sorted_checkpoints = self.all_checkpoints_sorted 3575 | 3576 | if len(sorted_checkpoints) == 0: 3577 | self.print(f'no checkpoints found to load from at {self.checkpoint_path}') 3578 | return 3579 | 3580 | last_checkpoint = sorted_checkpoints[0] 3581 | self.load(last_checkpoint) 3582 | 3583 | def save_to_checkpoint_folder(self): 3584 | self.accelerator.wait_for_everyone() 3585 | 3586 | if not self.can_checkpoint: 3587 | return 3588 | 3589 | total_steps = int(self.steps.sum().item()) 3590 | filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') 3591 | 3592 | self.save(filepath) 3593 | 3594 | if self.max_checkpoints_keep <= 0: 3595 | return 3596 | 3597 | sorted_checkpoints = self.all_checkpoints_sorted 3598 | checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] 3599 | 3600 | for checkpoint in checkpoints_to_discard: 3601 | self.fs.rm(checkpoint) 3602 | 3603 | # saving and loading functions 3604 | 3605 | def save( 3606 | self, 3607 | path, 3608 | overwrite = True, 3609 | without_optim_and_sched = False, 3610 | **kwargs 3611 | ): 3612 | self.accelerator.wait_for_everyone() 3613 | 3614 | if not self.can_checkpoint: 3615 | return 3616 | 3617 | fs = self.fs 3618 | 3619 | assert not (fs.exists(path) and not overwrite) 3620 | 3621 | self.reset_ema_unets_all_one_device() 3622 | 3623 | save_obj = dict( 3624 | model = self.imagen.state_dict(), 3625 | version = __version__, 3626 | steps = self.steps.cpu(), 3627 | **kwargs 3628 | ) 3629 | 3630 | save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() 3631 | 3632 | for ind in save_optim_and_sched_iter: 3633 | scaler_key = f'scaler{ind}' 3634 | optimizer_key = f'optim{ind}' 3635 | scheduler_key = f'scheduler{ind}' 3636 | warmup_scheduler_key = f'warmup{ind}' 3637 | 3638 | scaler = getattr(self, scaler_key) 3639 | optimizer = getattr(self, optimizer_key) 3640 | scheduler = getattr(self, scheduler_key) 3641 | warmup_scheduler = getattr(self, warmup_scheduler_key) 3642 | 3643 | if exists(scheduler): 3644 | save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} 3645 | 3646 | if exists(warmup_scheduler): 3647 | save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} 3648 | 3649 | save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} 3650 | 3651 | if self.use_ema: 3652 | save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} 3653 | 3654 | # determine if imagen config is available 3655 | 3656 | if hasattr(self.imagen, '_config'): 3657 | self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""') 3658 | 3659 | save_obj = { 3660 | **save_obj, 3661 | 'imagen_type': 'elucidated' if self.is_elucidated else 'original', 3662 | 
'imagen_params': self.imagen._config 3663 | } 3664 | 3665 | #save to path 3666 | 3667 | with fs.open(path, 'wb') as f: 3668 | torch.save(save_obj, f) 3669 | 3670 | self.print(f'checkpoint saved to {path}') 3671 | 3672 | def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): 3673 | fs = self.fs 3674 | 3675 | if noop_if_not_exist and not fs.exists(path): 3676 | self.print(f'trainer checkpoint not found at {str(path)}') 3677 | return 3678 | 3679 | assert fs.exists(path), f'{path} does not exist' 3680 | 3681 | self.reset_ema_unets_all_one_device() 3682 | 3683 | # to avoid extra GPU memory usage in main process when using Accelerate 3684 | 3685 | with fs.open(path) as f: 3686 | loaded_obj = torch.load(f, map_location='cpu') 3687 | 3688 | if version.parse(__version__) != version.parse(loaded_obj['version']): 3689 | self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') 3690 | 3691 | try: 3692 | self.imagen.load_state_dict(loaded_obj['model'], strict = strict) 3693 | except RuntimeError: 3694 | print("Failed loading state dict. Trying partial load") 3695 | self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), 3696 | loaded_obj['model'])) 3697 | 3698 | if only_model: 3699 | return loaded_obj 3700 | 3701 | self.steps.copy_(loaded_obj['steps']) 3702 | 3703 | for ind in range(0, self.num_unets): 3704 | scaler_key = f'scaler{ind}' 3705 | optimizer_key = f'optim{ind}' 3706 | scheduler_key = f'scheduler{ind}' 3707 | warmup_scheduler_key = f'warmup{ind}' 3708 | 3709 | scaler = getattr(self, scaler_key) 3710 | optimizer = getattr(self, optimizer_key) 3711 | scheduler = getattr(self, scheduler_key) 3712 | warmup_scheduler = getattr(self, warmup_scheduler_key) 3713 | 3714 | if exists(scheduler) and scheduler_key in loaded_obj: 3715 | scheduler.load_state_dict(loaded_obj[scheduler_key]) 3716 | 3717 | if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: 3718 | warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) 3719 | 3720 | if exists(optimizer): 3721 | try: 3722 | optimizer.load_state_dict(loaded_obj[optimizer_key]) 3723 | scaler.load_state_dict(loaded_obj[scaler_key]) 3724 | except: 3725 | self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') 3726 | 3727 | if self.use_ema: 3728 | assert 'ema' in loaded_obj 3729 | try: 3730 | self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) 3731 | except RuntimeError: 3732 | print("Failed loading state dict. 
Trying partial load") 3733 | self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), 3734 | loaded_obj['ema'])) 3735 | 3736 | self.print(f'checkpoint loaded from {path}') 3737 | return loaded_obj 3738 | 3739 | # managing ema unets and their devices 3740 | 3741 | @property 3742 | def unets(self): 3743 | return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) 3744 | 3745 | def get_ema_unet(self, unet_number = None): 3746 | if not self.use_ema: 3747 | return 3748 | 3749 | unet_number = self.validate_unet_number(unet_number) 3750 | index = unet_number - 1 3751 | 3752 | if isinstance(self.unets, nn.ModuleList): 3753 | unets_list = [unet for unet in self.ema_unets] 3754 | delattr(self, 'ema_unets') 3755 | self.ema_unets = unets_list 3756 | 3757 | if index != self.ema_unet_being_trained_index: 3758 | for unet_index, unet in enumerate(self.ema_unets): 3759 | unet.to(self.device if unet_index == index else 'cpu') 3760 | 3761 | self.ema_unet_being_trained_index = index 3762 | return self.ema_unets[index] 3763 | 3764 | def reset_ema_unets_all_one_device(self, device = None): 3765 | if not self.use_ema: 3766 | return 3767 | 3768 | device = default(device, self.device) 3769 | self.ema_unets = nn.ModuleList([*self.ema_unets]) 3770 | self.ema_unets.to(device) 3771 | 3772 | self.ema_unet_being_trained_index = -1 3773 | 3774 | @torch.no_grad() 3775 | @contextmanager 3776 | def use_ema_unets(self): 3777 | if not self.use_ema: 3778 | output = yield 3779 | return output 3780 | 3781 | self.reset_ema_unets_all_one_device() 3782 | self.imagen.reset_unets_all_one_device() 3783 | 3784 | self.unets.eval() 3785 | 3786 | trainable_unets = self.imagen.unets 3787 | self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling 3788 | 3789 | output = yield 3790 | 3791 | self.imagen.unets = trainable_unets # restore original training unets 3792 | 3793 | # cast the ema_model unets back to original device 3794 | for ema in self.ema_unets: 3795 | ema.restore_ema_model_device() 3796 | 3797 | return output 3798 | 3799 | def print_unet_devices(self): 3800 | self.print('unet devices:') 3801 | for i, unet in enumerate(self.imagen.unets): 3802 | device = next(unet.parameters()).device 3803 | self.print(f'\tunet {i}: {device}') 3804 | 3805 | if not self.use_ema: 3806 | return 3807 | 3808 | self.print('\nema unet devices:') 3809 | for i, ema_unet in enumerate(self.ema_unets): 3810 | device = next(ema_unet.parameters()).device 3811 | self.print(f'\tema unet {i}: {device}') 3812 | 3813 | # overriding state dict functions 3814 | 3815 | def state_dict(self, *args, **kwargs): 3816 | self.reset_ema_unets_all_one_device() 3817 | return super().state_dict(*args, **kwargs) 3818 | 3819 | def load_state_dict(self, *args, **kwargs): 3820 | self.reset_ema_unets_all_one_device() 3821 | return super().load_state_dict(*args, **kwargs) 3822 | 3823 | # encoding text functions 3824 | 3825 | def encode_text(self, text, **kwargs): 3826 | return self.imagen.encode_text(text, **kwargs) 3827 | 3828 | # forwarding functions and gradient step updates 3829 | 3830 | def update(self, unet_number = None): 3831 | unet_number = self.validate_unet_number(unet_number) 3832 | self.validate_and_set_unet_being_trained(unet_number) 3833 | self.set_accelerator_scaler(unet_number) 3834 | 3835 | index = unet_number - 1 3836 | unet = self.unet_being_trained 3837 | 3838 | optimizer = getattr(self, f'optim{index}') 3839 | scaler = getattr(self, f'scaler{index}') 3840 | scheduler = getattr(self, f'scheduler{index}') 3841 | 
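# --- Editor's note (addition, not in the original source) -------------------
# update(), continued below, clips gradients (if max_grad_norm is set), steps
# the per-unet optimizer, updates the EMA copy, advances the optionally
# warmed-up LR schedule, and auto-saves a checkpoint every checkpoint_every
# total steps. Together with train_step()/valid_step() defined earlier this
# gives a simple loop; a hedged sketch with placeholder datasets, assuming the
# trainer was constructed with dl_tuple_output_keywords_names matching the
# model's forward() arguments:
#
#   trainer.add_train_dataset(train_ds, batch_size=16, shuffle=True)
#   trainer.add_valid_dataset(valid_ds, batch_size=16)
#   for step in range(10_000):
#       loss = trainer.train_step(unet_number=1)      # forward + backward + update
#       if step % 100 == 0:
#           val_loss = trainer.valid_step(use_ema_unets=True)
#           trainer.print(f'step {step}: train {float(loss):.4f} valid {float(val_loss):.4f}')
# -----------------------------------------------------------------------------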
warmup_scheduler = getattr(self, f'warmup{index}') 3842 | 3843 | # set the grad scaler on the accelerator, since we are managing one per u-net 3844 | 3845 | if exists(self.max_grad_norm): 3846 | self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) 3847 | 3848 | optimizer.step() 3849 | optimizer.zero_grad() 3850 | 3851 | if self.use_ema: 3852 | ema_unet = self.get_ema_unet(unet_number) 3853 | ema_unet.update() 3854 | 3855 | # scheduler, if needed 3856 | 3857 | maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() 3858 | 3859 | with maybe_warmup_context: 3860 | if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs 3861 | scheduler.step() 3862 | 3863 | self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) 3864 | 3865 | if not exists(self.checkpoint_path): 3866 | return 3867 | 3868 | total_steps = int(self.steps.sum().item()) 3869 | 3870 | if total_steps % self.checkpoint_every: 3871 | return 3872 | 3873 | self.save_to_checkpoint_folder() 3874 | 3875 | @torch.no_grad() 3876 | @cast_torch_tensor 3877 | @imagen_sample_in_chunks 3878 | def sample(self, *args, **kwargs): 3879 | context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets 3880 | 3881 | self.print_untrained_unets() 3882 | 3883 | if not self.is_main: 3884 | kwargs['use_tqdm'] = False 3885 | 3886 | with context(): 3887 | output = self.imagen.sample(*args, device = self.device, **kwargs) 3888 | 3889 | return output 3890 | 3891 | @partial(cast_torch_tensor, cast_fp16 = True) 3892 | def forward( 3893 | self, 3894 | *args, 3895 | unet_number = None, 3896 | max_batch_size = None, 3897 | **kwargs 3898 | ): 3899 | unet_number = self.validate_unet_number(unet_number) 3900 | self.validate_and_set_unet_being_trained(unet_number) 3901 | self.set_accelerator_scaler(unet_number) 3902 | 3903 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' 3904 | 3905 | total_loss = 0. 3906 | 3907 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): 3908 | with self.accelerator.autocast(): 3909 | 3910 | loss = self.model(*chunked_args, unet_number = unet_number, **chunked_kwargs) 3911 | loss = loss * chunk_size_frac 3912 | 3913 | total_loss += loss#.item() 3914 | 3915 | if self.training: 3916 | self.accelerator.backward(loss) 3917 | 3918 | return total_loss 3919 | 3920 | ##################################################################################### 3921 | ########################### Integrated Diffusion Model ############################## 3922 | ##################################################################################### 3923 | 3924 | class HierarchicalDesignDiffusion (nn.Module): 3925 | def __init__(self, timesteps=10 , dim=32,pred_dim=25,loss_type=0, 3926 | padding_idx=0, 3927 | cond_dim = 512, 3928 | text_embed_dim = 512, 3929 | input_tokens=25,#for non-BERT 3930 | sequence_embed=False, 3931 | embed_dim_position=32, 3932 | max_text_len=16, 3933 | device='cuda:0', pos_emb=False, 3934 | pos_emb_fourier=False, 3935 | pos_emb_fourier_add=True, #add vs cooncenatite foruer pos end 3936 | add_z_loss = False, 3937 | loss_z_factor = 1, 3938 | VAE=None, 3939 | max_length = 128, 3940 | unets_list=None, #can provide custom Unets if needed, e.g. 
to have multiple unets 3941 | ): 3942 | super(HierarchicalDesignDiffusion, self).__init__() 3943 | 3944 | self.device=device 3945 | 3946 | self.pred_dim=pred_dim 3947 | self.loss_type=loss_type 3948 | self.pos_emb=pos_emb 3949 | self.pos_emb_fourier=pos_emb_fourier 3950 | self.pos_emb_fourier_add=pos_emb_fourier_add 3951 | self.fc_embed1 = nn.Linear( 8, max_length) # INPUT DIM (last), OUTPUT DIM, last 3952 | self.fc_embed2 = nn.Linear( 1, text_embed_dim) # INPUT DIM (last), OUTPUT DIM, last 3953 | self.max_text_len=max_text_len 3954 | 3955 | if self.pos_emb: 3956 | self.pos_emb_x = nn.Embedding(max_length+1, embed_dim_position) 3957 | text_embed_dim=text_embed_dim+embed_dim_position 3958 | self.pos_matrix_i = torch.zeros (max_length, dtype=torch.long) 3959 | for i in range (max_length): 3960 | self.pos_matrix_i [i]=i +1 3961 | if self.pos_emb_fourier: 3962 | if self.pos_emb_fourier_add==False: 3963 | text_embed_dim=text_embed_dim+embed_dim_position 3964 | 3965 | self.p_enc_1d = PositionalEncoding1D(embed_dim_position) 3966 | 3967 | if unets_list==None: 3968 | 3969 | unet1 = OneD_Unet( 3970 | dim = dim, #int(dim/4), 3971 | text_embed_dim = text_embed_dim, 3972 | cond_dim = cond_dim, #this is where text embeddings are projected to... 3973 | dim_mults = (1, 2, 4, 8), 3974 | 3975 | num_resnet_blocks = 3, 3976 | layer_attns = (True, True, True, True), 3977 | layer_cross_attns = (True, True, True, True), 3978 | channels=self.pred_dim, 3979 | channels_out=self.pred_dim , 3980 | # 3981 | attn_dim_head = 64, 3982 | attn_heads = 8, 3983 | ff_mult = 2., 3984 | lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ 3985 | 3986 | layer_attns_depth =2, 3987 | layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 3988 | attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) 3989 | 3990 | use_linear_attn = False, 3991 | use_linear_cross_attn = False, 3992 | cond_on_text = True, 3993 | max_text_len = max_length, 3994 | init_dim = None, 3995 | resnet_groups = 8, 3996 | init_conv_kernel_size =7, # # kernel size of initial conv, if not using cross embed 3997 | init_cross_embed = False, #TODO - fix ouput size calcs for conv1d 3998 | init_cross_embed_kernel_sizes = (3, 7, 15), 3999 | cross_embed_downsample = False, 4000 | cross_embed_downsample_kernel_sizes = (2, 4), 4001 | attn_pool_text = True, 4002 | attn_pool_num_latents = 32,#32, #perceiver model latents 4003 | dropout = 0., 4004 | memory_efficient = False, 4005 | init_conv_to_final_conv_residual = False, 4006 | use_global_context_attn = True, 4007 | scale_skip_connection = True, 4008 | final_resnet_block = True, #True, 4009 | final_conv_kernel_size = 3,#3, 4010 | cosine_sim_attn = True, 4011 | self_cond = False, 4012 | combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully 4013 | pixel_shuffle_upsample = False , # may address checkboard artifacts 4014 | 4015 | 4016 | ).to (self.device) 4017 | else: 4018 | unets=unets_list 4019 | 4020 | 4021 | self.is_elucidated=True 4022 | 4023 | self.imagen = ElucidatedImagen( 4024 | unets = (unet1), 4025 | channels=self.pred_dim, 4026 | channels_out=self.pred_dim , 4027 | loss_type=loss_type, 4028 | 4029 | text_embed_dim = text_embed_dim, 4030 | image_sizes = [max_length], 4031 | 4032 | cond_drop_prob = 0.3, 4033 | 
auto_normalize_img = False, 4034 | num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) 4035 | sigma_min = 0.002, # min noise level 4036 | sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler 4037 | sigma_data = 0.5, # standard deviation of data distribution 4038 | rho = 7, # controls the sampling schedule 4039 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 4040 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 4041 | S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 4042 | S_tmin = 0.05, 4043 | S_tmax = 50, 4044 | S_noise = 1.003, 4045 | add_z_loss = add_z_loss, 4046 | loss_z_factor = loss_z_factor, 4047 | VAE=VAE, 4048 | ).to (self.device) 4049 | 4050 | def forward(self, x, output, unet_number=1): #sequences=conditioning, output=prediction 4051 | 4052 | x=x.unsqueeze (2) 4053 | 4054 | x= self.fc_embed2(x) 4055 | 4056 | if self.pos_emb: 4057 | pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1, 1).to(device=device) 4058 | pos_emb_x = self.pos_emb_x( pos_matrix_i_) 4059 | pos_emb_x = torch.squeeze(pos_emb_x, 1) 4060 | x= torch.cat( (x, pos_emb_x), 2) 4061 | 4062 | if self.pos_emb_fourier: 4063 | 4064 | pos_fourier_xy=self.p_enc_1d(x) 4065 | 4066 | 4067 | if self.pos_emb_fourier_add: 4068 | x=x+pos_fourier_xy 4069 | 4070 | else: 4071 | x= torch.cat( (x, pos_fourier_xy), 2) 4072 | 4073 | loss = self.imagen(output, text_embeds = x, unet_number = unet_number, ) 4074 | 4075 | return loss 4076 | 4077 | def sample (self, x, stop_at_unet_number=1 ,cond_scale=7.5,init_images=None): 4078 | 4079 | x=x.unsqueeze (2) 4080 | 4081 | x= self.fc_embed2(x) 4082 | 4083 | if self.pos_emb: 4084 | pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1, 1).to(device=device) 4085 | pos_emb_x = self.pos_emb_x( pos_matrix_i_) 4086 | pos_emb_x = torch.squeeze(pos_emb_x, 1) 4087 | x= torch.cat( (x, pos_emb_x), 2) 4088 | 4089 | if self.pos_emb_fourier: 4090 | 4091 | pos_fourier_xy=self.p_enc_1d(x) 4092 | 4093 | 4094 | if self.pos_emb_fourier_add: 4095 | x=x+pos_fourier_xy 4096 | 4097 | else: 4098 | x= torch.cat( (x, pos_fourier_xy), 2) 4099 | 4100 | 4101 | output = self.imagen.sample(text_embeds= x, cond_scale = cond_scale, 4102 | stop_at_unet_number=stop_at_unet_number,init_images=init_images) 4103 | 4104 | return output 4105 | 4106 | ################################################################ 4107 | # Diffusion model to predict stress/strain from microstructure encoding 4108 | ################################################################ 4109 | 4110 | 4111 | class HierarchicalDesignDiffusion_PredictStressStrain (nn.Module): 4112 | def __init__(self, timesteps=10 , dim=32,pred_dim=25,loss_type=0, 4113 | padding_idx=0, 4114 | cond_dim = 512, 4115 | text_embed_dim = 512, 4116 | input_tokens=25,#for non-BERT 4117 | sequence_embed=False, 4118 | embed_dim_position=32, 4119 | max_text_len=16, 4120 | device='cuda:0', pos_emb=False, 4121 | pos_emb_fourier=False, 4122 | pos_emb_fourier_add=True, #add vs cooncenatite foruer pos end 4123 | add_z_loss = False, 4124 | loss_z_factor = 1, 4125 | VAE=None, 4126 | max_length = 128, 4127 | ): 4128 | super(HierarchicalDesignDiffusion_PredictStressStrain, self).__init__() 4129 | 4130 | self.device=device 4131 | 4132 | self.pred_dim=pred_dim 4133 | self.loss_type=loss_type 4134 | 
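# --- Editor's note (addition, not in the original source) -------------------
# HierarchicalDesignDiffusion above embeds a 1D conditioning sequence with
# fc_embed2 (plus optional learned / Fourier positional encodings) and feeds
# it to ElucidatedImagen as text_embeds; forward() returns the diffusion
# training loss and sample() runs the EDM sampler with classifier-free
# guidance. A hedged usage sketch; shapes are purely illustrative:
#
#   import torch
#   from HierarchicalDesign import HierarchicalDesignDiffusion
#
#   model  = HierarchicalDesignDiffusion(pred_dim=25, max_length=128,
#                                        device='cuda:0')
#   cond   = torch.rand(4, 128, device='cuda:0')        # (batch, cond length)
#   target = torch.rand(4, 25, 128, device='cuda:0')    # (batch, channels, length)
#   loss   = model(cond, target)                        # training loss
#   loss.backward()
#   samples = model.sample(cond, cond_scale=7.5)        # (batch, 25, 128)
# -----------------------------------------------------------------------------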
self.pos_emb=pos_emb 4135 | self.pos_emb_fourier=pos_emb_fourier 4136 | self.pos_emb_fourier_add=pos_emb_fourier_add 4137 | 4138 | self.max_text_len=max_text_len 4139 | 4140 | if self.pos_emb: 4141 | self.pos_emb_x = nn.Embedding(max_length+1, embed_dim_position) 4142 | text_embed_dim=text_embed_dim+embed_dim_position 4143 | self.pos_matrix_i = torch.zeros (max_length, dtype=torch.long) 4144 | for i in range (max_length): 4145 | #for j in range (im_res): 4146 | self.pos_matrix_i [i]=i +1 4147 | if self.pos_emb_fourier: 4148 | if self.pos_emb_fourier_add==False: 4149 | text_embed_dim=text_embed_dim+embed_dim_position 4150 | self.p_enc_1d = PositionalEncoding1D(embed_dim_position) 4151 | 4152 | unet1 = OneD_Unet( 4153 | dim = dim, 4154 | text_embed_dim = text_embed_dim, 4155 | cond_dim = cond_dim, #this is where text embeddings are projected to... 4156 | dim_mults = (1, 2, 4, 8), 4157 | 4158 | num_resnet_blocks = 3, 4159 | layer_attns = (True, True, True, True), 4160 | layer_cross_attns = (True, True, True, True), 4161 | channels=self.pred_dim, 4162 | channels_out=self.pred_dim , 4163 | # 4164 | attn_dim_head = 64, 4165 | attn_heads = 8, 4166 | ff_mult = 2., 4167 | lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ 4168 | 4169 | layer_attns_depth =2,# 1, 4170 | layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 4171 | attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) 4172 | 4173 | use_linear_attn = False, 4174 | use_linear_cross_attn = False, 4175 | cond_on_text = True, 4176 | max_text_len = max_length, 4177 | init_dim = None, 4178 | resnet_groups = 8,#8, 4179 | init_conv_kernel_size =7, #7, # kernel size of initial conv, if not using cross embed 4180 | init_cross_embed = False, #TODO - fix ouput size calcs for conv1d 4181 | init_cross_embed_kernel_sizes = (3, 7, 15), 4182 | cross_embed_downsample = False, 4183 | cross_embed_downsample_kernel_sizes = (2, 4), 4184 | attn_pool_text = True, 4185 | attn_pool_num_latents = 32, #perceiver model latents 4186 | dropout = 0., 4187 | memory_efficient = False, 4188 | init_conv_to_final_conv_residual = False, 4189 | use_global_context_attn = True, 4190 | scale_skip_connection = True, 4191 | final_resnet_block = True, 4192 | final_conv_kernel_size = 3, 4193 | cosine_sim_attn = True, 4194 | self_cond = False, 4195 | combine_upsample_fmaps = True, # combine feature maps from all upsample blocks, used in unet squared successfully 4196 | pixel_shuffle_upsample = False , # may address checkboard artifacts 4197 | 4198 | ).to (self.device) 4199 | 4200 | self.is_elucidated=True 4201 | self.imagen = ElucidatedImagen( 4202 | unets = (unet1), 4203 | channels=self.pred_dim, 4204 | channels_out=self.pred_dim , 4205 | loss_type=loss_type, 4206 | 4207 | text_embed_dim = text_embed_dim, 4208 | image_sizes = [max_length], 4209 | 4210 | cond_drop_prob = 0., 4211 | auto_normalize_img = False, 4212 | num_sample_steps = timesteps,#(64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) 4213 | sigma_min = 0.002, # min noise level 4214 | sigma_max = 160,#(80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler 4215 | sigma_data = 0.5, # standard deviation of data distribution 4216 | rho = 7, # controls the sampling 
schedule 4217 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 4218 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 4219 | S_churn = 40,#80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 4220 | S_tmin = 0.05, 4221 | S_tmax = 50, 4222 | S_noise = 1.003, 4223 | add_z_loss = add_z_loss, 4224 | loss_z_factor = loss_z_factor, 4225 | VAE=VAE, 4226 | ).to (self.device) 4227 | 4228 | def forward(self, x, output, unet_number=1): #sequences=conditioning, output=prediction 4229 | 4230 | x=torch.permute(x, (0,2,1) ) 4231 | 4232 | if self.pos_emb: 4233 | pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1, 1).to(device=device) 4234 | pos_emb_x = self.pos_emb_x( pos_matrix_i_) 4235 | pos_emb_x = torch.squeeze(pos_emb_x, 1) 4236 | x= torch.cat( (x, pos_emb_x), 2) 4237 | if self.pos_emb_fourier: 4238 | 4239 | pos_fourier_xy=self.p_enc_1d(x) 4240 | 4241 | if self.pos_emb_fourier_add: 4242 | x=x+pos_fourier_xy 4243 | 4244 | else: 4245 | x= torch.cat( (x, pos_fourier_xy), 2) 4246 | loss = self.imagen(output, text_embeds = x, unet_number = unet_number, ) 4247 | 4248 | return loss 4249 | 4250 | def sample (self, x, stop_at_unet_number=1 ,cond_scale=7.5,): 4251 | 4252 | x=torch.permute(x, (0,2,1) ) 4253 | 4254 | if self.pos_emb: 4255 | pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1, 1).to(device=device) 4256 | pos_emb_x = self.pos_emb_x( pos_matrix_i_) 4257 | pos_emb_x = torch.squeeze(pos_emb_x, 1) 4258 | x= torch.cat( (x, pos_emb_x), 2) 4259 | if self.pos_emb_fourier: 4260 | 4261 | pos_fourier_xy=self.p_enc_1d(x) 4262 | 4263 | if self.pos_emb_fourier_add: 4264 | x=x+pos_fourier_xy 4265 | 4266 | else: 4267 | x= torch.cat( (x, pos_fourier_xy), 2) 4268 | 4269 | output = self.imagen.sample(text_embeds= x, cond_scale = cond_scale, stop_at_unet_number=stop_at_unet_number) 4270 | 4271 | return output 4272 | -------------------------------------------------------------------------------- /HierarchicalDesign/VQVAE.py: -------------------------------------------------------------------------------- 1 | ######################################################### 2 | # Define codebook model and VQ-VAE with attention 3 | ######################################################### 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import numpy as np 10 | 11 | from torchvision.utils import save_image, make_grid 12 | import torch.nn.functional as F 13 | from torchvision import datasets, transforms, models 14 | from sklearn.metrics import r2_score 15 | 16 | import matplotlib.pyplot as plt 17 | 18 | import ast 19 | import pandas as pd 20 | import numpy as np 21 | 22 | 23 | from torchvision.io import read_image 24 | import pandas as pd 25 | 26 | from PIL import Image 27 | import time 28 | to_pil = transforms.ToPILImage() 29 | 30 | #Codebook 31 | import torch 32 | from torch import nn, einsum 33 | import torch.nn.functional as F 34 | import torch.distributed as distributed 35 | from torch.cuda.amp import autocast 36 | 37 | from einops import rearrange, repeat 38 | from contextlib import contextmanager 39 | 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | 43 | ######################################################### 44 | # CODE BASE: Codebook 45 | ######################################################### 46 | def get_fmap_from_codebook(model, indices): 47 | codes = model.codebook.codebook[indices] 48 | fmap = model.codebook.project_out(codes) 49 
| return rearrange(fmap, 'b h w c -> b c h w') 50 | 51 | def exists(val): 52 | return val is not None 53 | 54 | def default(val, d): 55 | return val if exists(val) else d 56 | 57 | def noop(*args, **kwargs): 58 | pass 59 | 60 | def l2norm(t): 61 | return F.normalize(t, p = 2, dim = -1) 62 | 63 | def log(t, eps = 1e-20): 64 | return torch.log(t.clamp(min = eps)) 65 | 66 | def uniform_init(*shape): 67 | t = torch.empty(shape) 68 | nn.init.kaiming_uniform_(t) 69 | return t 70 | 71 | def gumbel_noise(t): 72 | noise = torch.zeros_like(t).uniform_(0, 1) 73 | return -log(-log(noise)) 74 | 75 | def gumbel_sample(t, temperature = 1., dim = -1): 76 | if temperature == 0: 77 | return t.argmax(dim = dim) 78 | 79 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 80 | 81 | def ema_inplace(moving_avg, new, decay): 82 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 83 | 84 | def laplace_smoothing(x, n_categories, eps = 1e-5): 85 | return (x + eps) / (x.sum() + n_categories * eps) 86 | 87 | def sample_vectors(samples, num): 88 | num_samples, device = samples.shape[0], samples.device 89 | if num_samples >= num: 90 | indices = torch.randperm(num_samples, device = device)[:num] 91 | else: 92 | indices = torch.randint(0, num_samples, (num,), device = device) 93 | 94 | return samples[indices] 95 | 96 | def batched_sample_vectors(samples, num): 97 | return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim = 0)], dim = 0) 98 | 99 | def pad_shape(shape, size, dim = 0): 100 | return [size if i == dim else s for i, s in enumerate(shape)] 101 | 102 | def sample_multinomial(total_count, probs): 103 | device = probs.device 104 | probs = probs.cpu() 105 | 106 | total_count = probs.new_full((), total_count) 107 | remainder = probs.new_ones(()) 108 | sample = torch.empty_like(probs, dtype = torch.long) 109 | 110 | for i, p in enumerate(probs): 111 | s = torch.binomial(total_count, p / remainder) 112 | sample[i] = s 113 | total_count -= s 114 | remainder -= p 115 | 116 | return sample.to(device) 117 | 118 | def all_gather_sizes(x, dim): 119 | size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device) 120 | all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] 121 | distributed.all_gather(all_sizes, size) 122 | return torch.stack(all_sizes) 123 | 124 | def all_gather_variably_sized(x, sizes, dim = 0): 125 | rank = distributed.get_rank() 126 | all_x = [] 127 | 128 | for i, size in enumerate(sizes): 129 | t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) 130 | distributed.broadcast(t, src = i, async_op = True) 131 | all_x.append(t) 132 | 133 | distributed.barrier() 134 | return all_x 135 | 136 | def sample_vectors_distributed(local_samples, num): 137 | local_samples = rearrange(local_samples, '1 ... -> ...') 138 | 139 | rank = distributed.get_rank() 140 | all_num_samples = all_gather_sizes(local_samples, dim = 0) 141 | 142 | if rank == 0: 143 | samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) 144 | else: 145 | samples_per_rank = torch.empty_like(all_num_samples) 146 | 147 | distributed.broadcast(samples_per_rank, src = 0) 148 | samples_per_rank = samples_per_rank.tolist() 149 | 150 | local_samples = sample_vectors(local_samples, samples_per_rank[rank]) 151 | all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0) 152 | out = torch.cat(all_samples, dim = 0) 153 | 154 | return rearrange(out, '... 
-> 1 ...') 155 | 156 | def batched_bincount(x, *, minlength): 157 | batch, dtype, device = x.shape[0], x.dtype, x.device 158 | target = torch.zeros(batch, minlength, dtype = dtype, device = device) 159 | values = torch.ones_like(x) 160 | target.scatter_add_(-1, x, values) 161 | return target 162 | 163 | def kmeans( 164 | samples, 165 | num_clusters, 166 | num_iters = 10, 167 | use_cosine_sim = False, 168 | sample_fn = batched_sample_vectors, 169 | all_reduce_fn = noop 170 | ): 171 | num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device 172 | 173 | means = sample_fn(samples, num_clusters) 174 | 175 | for _ in range(num_iters): 176 | if use_cosine_sim: 177 | dists = samples @ rearrange(means, 'h n d -> h d n') 178 | else: 179 | dists = -torch.cdist(samples, means, p = 2) 180 | 181 | buckets = torch.argmax(dists, dim = -1) 182 | bins = batched_bincount(buckets, minlength = num_clusters) 183 | all_reduce_fn(bins) 184 | 185 | zero_mask = bins == 0 186 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 187 | 188 | new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype) 189 | 190 | new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples) 191 | new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1') 192 | all_reduce_fn(new_means) 193 | 194 | if use_cosine_sim: 195 | new_means = l2norm(new_means) 196 | 197 | means = torch.where( 198 | rearrange(zero_mask, '... -> ... 1'), 199 | means, 200 | new_means 201 | ) 202 | 203 | return means, bins 204 | 205 | def batched_embedding(indices, embeds): 206 | batch, dim = indices.shape[1], embeds.shape[-1] 207 | indices = repeat(indices, 'h b n -> h b n d', d = dim) 208 | embeds = repeat(embeds, 'h c d -> h b c d', b = batch) 209 | return embeds.gather(2, indices) 210 | 211 | # regularization losses 212 | 213 | def orthogonal_loss_fn(t): 214 | # eq (2) from https://arxiv.org/abs/2112.00384 215 | h, n = t.shape[:2] 216 | normed_codes = l2norm(t) 217 | cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes) 218 | return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n) 219 | 220 | # distance types 221 | 222 | class EuclideanCodebook(nn.Module): 223 | def __init__( 224 | self, 225 | dim, 226 | codebook_size, 227 | num_codebooks = 1, 228 | kmeans_init = False, 229 | kmeans_iters = 10, 230 | sync_kmeans = True, 231 | decay = 0.8, 232 | eps = 1e-5, 233 | threshold_ema_dead_code = 2, 234 | use_ddp = False, 235 | learnable_codebook = False, 236 | sample_codebook_temp = 0 237 | ): 238 | super().__init__() 239 | self.decay = decay 240 | init_fn = uniform_init if not kmeans_init else torch.zeros 241 | embed = init_fn(num_codebooks, codebook_size, dim) 242 | 243 | self.codebook_size = codebook_size 244 | self.num_codebooks = num_codebooks 245 | 246 | self.kmeans_iters = kmeans_iters 247 | self.eps = eps 248 | self.threshold_ema_dead_code = threshold_ema_dead_code 249 | self.sample_codebook_temp = sample_codebook_temp 250 | 251 | assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now' 252 | 253 | self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors 254 | self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop 255 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 256 | 257 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 258 | 
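# --- Editor's note (addition, not in the original source) -------------------
# kmeans() above runs a fixed number of Lloyd iterations per codebook (the
# leading 'h' dimension), using either negative L2 distance or cosine
# similarity, and is what EuclideanCodebook / CosineSimCodebook use to
# optionally initialise the codebook from the first batch of encoder
# features. A hedged standalone sketch with illustrative sizes:
#
#   import torch
#   samples = torch.randn(1, 2048, 64)        # (codebooks, vectors, dim)
#   means, bins = kmeans(samples, num_clusters=512, num_iters=10)
#   # means: (1, 512, 64) cluster centres; bins: (1, 512) assignment counts
# -----------------------------------------------------------------------------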
self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 259 | self.register_buffer('embed_avg', embed.clone()) 260 | 261 | self.learnable_codebook = learnable_codebook 262 | if learnable_codebook: 263 | self.embed = nn.Parameter(embed) 264 | else: 265 | self.register_buffer('embed', embed) 266 | 267 | @torch.jit.ignore 268 | def init_embed_(self, data): 269 | if self.initted: 270 | return 271 | 272 | embed, cluster_size = kmeans( 273 | data, 274 | self.codebook_size, 275 | self.kmeans_iters, 276 | sample_fn = self.sample_fn, 277 | all_reduce_fn = self.kmeans_all_reduce_fn 278 | ) 279 | 280 | self.embed.data.copy_(embed) 281 | self.embed_avg.data.copy_(embed.clone()) 282 | self.cluster_size.data.copy_(cluster_size) 283 | self.initted.data.copy_(torch.Tensor([True])) 284 | 285 | def replace(self, batch_samples, batch_mask): 286 | batch_samples = l2norm(batch_samples) 287 | 288 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 289 | if not torch.any(mask): 290 | continue 291 | 292 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 293 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') 294 | 295 | def expire_codes_(self, batch_samples): 296 | if self.threshold_ema_dead_code == 0: 297 | return 298 | 299 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 300 | 301 | if not torch.any(expired_codes): 302 | return 303 | 304 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 305 | self.replace(batch_samples, batch_mask = expired_codes) 306 | 307 | @autocast(enabled = False) 308 | def forward(self, x): 309 | needs_codebook_dim = x.ndim < 4 310 | 311 | x = x.float() 312 | 313 | if needs_codebook_dim: 314 | x = rearrange(x, '... -> 1 ...') 315 | 316 | shape, dtype = x.shape, x.dtype 317 | flatten = rearrange(x, 'h ... d -> h (...) d') 318 | 319 | self.init_embed_(flatten) 320 | 321 | embed = self.embed if not self.learnable_codebook else self.embed.detach() 322 | 323 | dist = -torch.cdist(flatten, embed, p = 2) 324 | 325 | embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) 326 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 327 | embed_ind = embed_ind.view(*shape[:-1]) 328 | 329 | quantize = batched_embedding(embed_ind, self.embed) 330 | 331 | if self.training: 332 | cluster_size = embed_onehot.sum(dim = 1) 333 | 334 | self.all_reduce_fn(cluster_size) 335 | ema_inplace(self.cluster_size, cluster_size, self.decay) 336 | 337 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 338 | self.all_reduce_fn(embed_sum.contiguous()) 339 | ema_inplace(self.embed_avg, embed_sum, self.decay) 340 | 341 | cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum() 342 | 343 | embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1') 344 | self.embed.data.copy_(embed_normalized) 345 | self.expire_codes_(x) 346 | 347 | if needs_codebook_dim: 348 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... 
-> ...'), (quantize, embed_ind)) 349 | 350 | return quantize, embed_ind 351 | 352 | class CosineSimCodebook(nn.Module): 353 | def __init__( 354 | self, 355 | dim, 356 | codebook_size, 357 | num_codebooks = 1, 358 | kmeans_init = False, 359 | kmeans_iters = 10, 360 | sync_kmeans = True, 361 | decay = 0.8, 362 | eps = 1e-5, 363 | threshold_ema_dead_code = 2, 364 | use_ddp = False, 365 | learnable_codebook = False, 366 | sample_codebook_temp = 0. 367 | ): 368 | super().__init__() 369 | self.decay = decay 370 | 371 | if not kmeans_init: 372 | embed = l2norm(uniform_init(num_codebooks, codebook_size, dim)) 373 | else: 374 | embed = torch.zeros(num_codebooks, codebook_size, dim) 375 | 376 | self.codebook_size = codebook_size 377 | self.num_codebooks = num_codebooks 378 | 379 | self.kmeans_iters = kmeans_iters 380 | self.eps = eps 381 | self.threshold_ema_dead_code = threshold_ema_dead_code 382 | self.sample_codebook_temp = sample_codebook_temp 383 | 384 | self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors 385 | self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop 386 | self.all_reduce_fn = distributed.all_reduce if use_ddp else noop 387 | 388 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 389 | self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size)) 390 | 391 | self.learnable_codebook = learnable_codebook 392 | if learnable_codebook: 393 | self.embed = nn.Parameter(embed) 394 | else: 395 | self.register_buffer('embed', embed) 396 | 397 | @torch.jit.ignore 398 | def init_embed_(self, data): 399 | if self.initted: 400 | return 401 | 402 | embed, cluster_size = kmeans( 403 | data, 404 | self.codebook_size, 405 | self.kmeans_iters, 406 | use_cosine_sim = True, 407 | sample_fn = self.sample_fn, 408 | all_reduce_fn = self.kmeans_all_reduce_fn 409 | ) 410 | 411 | self.embed.data.copy_(embed) 412 | self.cluster_size.data.copy_(cluster_size) 413 | self.initted.data.copy_(torch.Tensor([True])) 414 | 415 | def replace(self, batch_samples, batch_mask): 416 | batch_samples = l2norm(batch_samples) 417 | 418 | for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))): 419 | if not torch.any(mask): 420 | continue 421 | 422 | sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item()) 423 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') 424 | 425 | def expire_codes_(self, batch_samples): 426 | if self.threshold_ema_dead_code == 0: 427 | return 428 | 429 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 430 | 431 | if not torch.any(expired_codes): 432 | return 433 | 434 | batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d') 435 | self.replace(batch_samples, batch_mask = expired_codes) 436 | 437 | @autocast(enabled = False) 438 | def forward(self, x): 439 | needs_codebook_dim = x.ndim < 4 440 | 441 | x = x.float() 442 | 443 | if needs_codebook_dim: 444 | x = rearrange(x, '... -> 1 ...') 445 | 446 | shape, dtype = x.shape, x.dtype 447 | 448 | flatten = rearrange(x, 'h ... d -> h (...) 
d') 449 | flatten = l2norm(flatten) 450 | 451 | self.init_embed_(flatten) 452 | 453 | embed = self.embed if not self.learnable_codebook else self.embed.detach() 454 | embed = l2norm(embed) 455 | 456 | dist = einsum('h n d, h c d -> h n c', flatten, embed) 457 | embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp) 458 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 459 | embed_ind = embed_ind.view(*shape[:-1]) 460 | 461 | quantize = batched_embedding(embed_ind, self.embed) 462 | 463 | if self.training: 464 | bins = embed_onehot.sum(dim = 1) 465 | self.all_reduce_fn(bins) 466 | 467 | ema_inplace(self.cluster_size, bins, self.decay) 468 | 469 | zero_mask = (bins == 0) 470 | bins = bins.masked_fill(zero_mask, 1.) 471 | 472 | embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot) 473 | self.all_reduce_fn(embed_sum) 474 | 475 | embed_normalized = embed_sum / rearrange(bins, '... -> ... 1') 476 | embed_normalized = l2norm(embed_normalized) 477 | 478 | embed_normalized = torch.where( 479 | rearrange(zero_mask, '... -> ... 1'), 480 | embed, 481 | embed_normalized 482 | ) 483 | 484 | ema_inplace(self.embed, embed_normalized, self.decay) 485 | self.expire_codes_(x) 486 | 487 | if needs_codebook_dim: 488 | quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind)) 489 | 490 | return quantize, embed_ind 491 | 492 | # main class 493 | 494 | class VectorQuantize(nn.Module): 495 | def __init__( 496 | self, 497 | dim, 498 | codebook_size, 499 | codebook_dim = None, 500 | heads = 1, 501 | separate_codebook_per_head = False, 502 | decay = 0.8, 503 | eps = 1e-5, 504 | kmeans_init = False, 505 | kmeans_iters = 10, 506 | sync_kmeans = True, 507 | use_cosine_sim = False, 508 | threshold_ema_dead_code = 0, 509 | channel_last = True, 510 | accept_image_fmap = False, 511 | commitment_weight = 1., 512 | orthogonal_reg_weight = 0., 513 | orthogonal_reg_active_codes_only = False, 514 | orthogonal_reg_max_codes = None, 515 | sample_codebook_temp = 0., 516 | sync_codebook = False 517 | ): 518 | super().__init__() 519 | self.heads = heads 520 | self.separate_codebook_per_head = separate_codebook_per_head 521 | 522 | codebook_dim = default(codebook_dim, dim) 523 | codebook_input_dim = codebook_dim * heads 524 | 525 | requires_projection = codebook_input_dim != dim 526 | self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() 527 | self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() 528 | 529 | self.eps = eps 530 | self.commitment_weight = commitment_weight 531 | 532 | has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 533 | self.orthogonal_reg_weight = orthogonal_reg_weight 534 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 535 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 536 | 537 | codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook 538 | 539 | self._codebook = codebook_class( 540 | dim = codebook_dim, 541 | num_codebooks = heads if separate_codebook_per_head else 1, 542 | codebook_size = codebook_size, 543 | kmeans_init = kmeans_init, 544 | kmeans_iters = kmeans_iters, 545 | sync_kmeans = sync_kmeans, 546 | decay = decay, 547 | eps = eps, 548 | threshold_ema_dead_code = threshold_ema_dead_code, 549 | use_ddp = sync_codebook, 550 | learnable_codebook = has_codebook_orthogonal_loss, 551 | sample_codebook_temp = sample_codebook_temp 552 | ) 553 | 554 | self.codebook_size = 
codebook_size 555 | 556 | self.accept_image_fmap = accept_image_fmap 557 | self.channel_last = channel_last 558 | 559 | @property 560 | def codebook(self): 561 | codebook = self._codebook.embed 562 | if self.separate_codebook_per_head: 563 | return codebook 564 | 565 | return rearrange(codebook, '1 ... -> ...') 566 | 567 | def forward( 568 | self, 569 | x, 570 | mask = None 571 | ): 572 | shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size 573 | 574 | need_transpose = not self.channel_last and not self.accept_image_fmap 575 | 576 | if self.accept_image_fmap: 577 | height, width = x.shape[-2:] 578 | x = rearrange(x, 'b c h w -> b (h w) c') 579 | 580 | if need_transpose: 581 | x = rearrange(x, 'b d n -> b n d') 582 | 583 | x = self.project_in(x) 584 | 585 | if is_multiheaded: 586 | ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d' 587 | x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads) 588 | 589 | quantize, embed_ind = self._codebook(x) 590 | 591 | if self.training: 592 | quantize = x + (quantize - x).detach() 593 | 594 | loss = torch.tensor([0.], device = device, requires_grad = self.training) 595 | 596 | if self.training: 597 | if self.commitment_weight > 0: 598 | detached_quantize = quantize.detach() 599 | 600 | if exists(mask): 601 | # with variable lengthed sequences 602 | commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none') 603 | 604 | if is_multiheaded: 605 | mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0]) 606 | 607 | commit_loss = commit_loss[mask].mean() 608 | else: 609 | commit_loss = F.mse_loss(detached_quantize, x) 610 | 611 | loss = loss + commit_loss * self.commitment_weight 612 | 613 | if self.orthogonal_reg_weight > 0: 614 | codebook = self._codebook.embed 615 | 616 | if self.orthogonal_reg_active_codes_only: 617 | # only calculate orthogonal loss for the activated codes for this batch 618 | unique_code_ids = torch.unique(embed_ind) 619 | codebook = codebook[unique_code_ids] 620 | 621 | num_codes = codebook.shape[0] 622 | if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: 623 | rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes] 624 | codebook = codebook[rand_ids] 625 | 626 | orthogonal_reg_loss = orthogonal_loss_fn(codebook) 627 | loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight 628 | 629 | if is_multiheaded: 630 | if self.separate_codebook_per_head: 631 | quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads) 632 | embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads) 633 | else: 634 | quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads) 635 | embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads) 636 | 637 | quantize = self.project_out(quantize) 638 | 639 | if need_transpose: 640 | quantize = rearrange(quantize, 'b n d -> b d n') 641 | 642 | if self.accept_image_fmap: 643 | quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width) 644 | embed_ind = rearrange(embed_ind, 'b (h w) ... 
-> b h w ...', h = height, w = width) 645 | 646 | return quantize, embed_ind, loss 647 | 648 | ######################################################### 649 | # CODE BASE: VQVAE Model 650 | ######################################################### 651 | 652 | class VQVAEModel(nn.Module): 653 | def __init__(self, Encoder, Codebook, Decoder): 654 | super(VQVAEModel, self).__init__() 655 | self.encoder = Encoder 656 | self.codebook = Codebook 657 | self.decoder = Decoder 658 | 659 | def forward(self, x): 660 | z = self.encoder(x) 661 | 662 | z_quantized, indices, commitment_loss = self.codebook(z) 663 | 664 | x_hat = self.decoder(z_quantized) 665 | 666 | return x_hat, indices, commitment_loss#, codebook_loss, perplexity 667 | 668 | def encode_z(self, x): 669 | z = self.encoder(x) 670 | 671 | z_quantized, indices, commitment_loss, = self.codebook(z) 672 | 673 | return z_quantized, indices 674 | 675 | def encode(self, x): 676 | z = self.encoder(x) 677 | 678 | return z 679 | 680 | def decode (self, z): #itake predicted z and snap it to the quantized ones, then calculate x_hat 681 | x_hat_nonquant = self.decoder(z) #non-snapped 682 | return x_hat_nonquant 683 | 684 | def decode_snapped (self, z): #take predicted z and snap it to the quantized ones, then calculate x_hat 685 | 686 | z_quantized, indices, commitment_loss, = self.codebook(z) 687 | 688 | x_hat = self.decoder(z_quantized) 689 | 690 | return x_hat, z_quantized 691 | 692 | ######################################################### 693 | # CODE BASE: CNN/Attention layers 694 | ######################################################### 695 | 696 | class GroupNorm(nn.Module): 697 | def __init__(self, channels ): 698 | super(GroupNorm, self).__init__() 699 | num_groups=8 700 | #print ("##", num_groups,channels) 701 | self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6, affine=True) 702 | 703 | def forward(self, x): 704 | return self.gn(x) 705 | 706 | 707 | class Swish(nn.Module): 708 | def forward(self, x): 709 | return x * torch.sigmoid(x) 710 | 711 | 712 | class ResidualBlock(nn.Module): 713 | def __init__(self, in_channels, out_channels): 714 | super(ResidualBlock, self).__init__() 715 | self.in_channels = in_channels 716 | self.out_channels = out_channels 717 | self.block = nn.Sequential( 718 | GroupNorm(in_channels), 719 | Swish(), 720 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 721 | GroupNorm(out_channels), 722 | Swish(), 723 | nn.Conv2d(out_channels, out_channels, 3, 1, 1) 724 | ) 725 | 726 | if in_channels != out_channels: 727 | self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 728 | 729 | def forward(self, x): 730 | if self.in_channels != self.out_channels: 731 | return self.channel_up(x) + self.block(x) 732 | else: 733 | return x + self.block(x) 734 | 735 | 736 | class UpSampleBlock(nn.Module): 737 | def __init__(self, channels): 738 | super(UpSampleBlock, self).__init__() 739 | self.conv = nn.Conv2d(channels, channels, 3, 1, 1) 740 | 741 | def forward(self, x): 742 | x = F.interpolate(x, scale_factor=2.0) 743 | return self.conv(x) 744 | 745 | 746 | class DownSampleBlock(nn.Module): 747 | def __init__(self, channels): 748 | super(DownSampleBlock, self).__init__() 749 | self.conv = nn.Conv2d(channels, channels, 3, 2, 0) 750 | 751 | def forward(self, x): 752 | pad = (0, 1, 0, 1) 753 | x = F.pad(x, pad, mode="constant", value=0) 754 | return self.conv(x) 755 | 756 | 757 | class NonLocalBlock(nn.Module): 758 | def __init__(self, channels): 759 | super(NonLocalBlock, self).__init__() 760 | 
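# --- Editor's note (addition, not in the original source) -------------------
# VQVAEModel above chains Encoder -> VectorQuantize codebook -> Decoder and
# returns the reconstruction, the codebook indices and the commitment loss.
# A hedged assembly sketch using the classes in this file; the dims and
# resolutions are illustrative, not the repository's training settings:
#
#   import torch
#   encoder  = Encoder_Attn(image_channels=3, latent_dim=128,
#                           start_resolution_VAE_encoder=512)
#   decoder  = Decoder_Attn(image_channels=3, latent_dim=128,
#                           start_resolution_VAE_decoder=32)
#   codebook = VectorQuantize(dim=128, codebook_size=512,
#                             accept_image_fmap=True)
#   vqvae    = VQVAEModel(encoder, codebook, decoder)
#
#   x = torch.randn(2, 3, 512, 512)
#   x_hat, indices, commit_loss = vqvae(x)    # x_hat: (2, 3, 512, 512)
# -----------------------------------------------------------------------------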
self.in_channels = channels 761 | 762 | self.gn = GroupNorm(channels) 763 | self.q = nn.Conv2d(channels, channels, 1, 1, 0) 764 | self.k = nn.Conv2d(channels, channels, 1, 1, 0) 765 | self.v = nn.Conv2d(channels, channels, 1, 1, 0) 766 | self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0) 767 | 768 | def forward(self, x): 769 | h_ = self.gn(x) 770 | q = self.q(h_) 771 | k = self.k(h_) 772 | v = self.v(h_) 773 | 774 | b, c, h, w = q.shape 775 | 776 | q = q.reshape(b, c, h*w) 777 | q = q.permute(0, 2, 1) 778 | k = k.reshape(b, c, h*w) 779 | v = v.reshape(b, c, h*w) 780 | 781 | attn = torch.bmm(q, k) 782 | attn = attn * (int(c)**(-0.5)) 783 | attn = F.softmax(attn, dim=2) 784 | attn = attn.permute(0, 2, 1) 785 | 786 | A = torch.bmm(v, attn) 787 | A = A.reshape(b, c, h, w) 788 | 789 | return x + A 790 | 791 | 792 | class Encoder_Attn(nn.Module): 793 | def __init__(self, image_channels=3, latent_dim=128, channels = [128, 128, 128, 256, 256, 512], 794 | start_resolution_VAE_encoder=512, 795 | attn_resolutions_VAE_encoder = [16], 796 | num_res_blocks_encoder=2, 797 | ): 798 | super(Encoder_Attn, self).__init__() 799 | 800 | attn_resolutions = attn_resolutions_VAE_encoder 801 | num_res_blocks_encoder = num_res_blocks_encoder 802 | resolution = start_resolution_VAE_encoder 803 | 804 | layers = [nn.Conv2d(image_channels, channels[0], 3, 1, 1)] 805 | for i in range(len(channels)-1): 806 | in_channels = channels[i] 807 | out_channels = channels[i + 1] 808 | for j in range(num_res_blocks_encoder): 809 | layers.append(ResidualBlock(in_channels, out_channels)) 810 | in_channels = out_channels 811 | if resolution in attn_resolutions: 812 | layers.append(NonLocalBlock(in_channels)) 813 | #print (f"Added attention layer (encoder) at {resolution}") 814 | if i != len(channels)-2: 815 | layers.append(DownSampleBlock(channels[i+1])) 816 | resolution //= 2 817 | 818 | layers.append(ResidualBlock(channels[-1], channels[-1])) 819 | layers.append(NonLocalBlock(channels[-1])) 820 | layers.append(ResidualBlock(channels[-1], channels[-1])) 821 | layers.append(GroupNorm(channels=channels[-1] )) 822 | layers.append(Swish()) 823 | layers.append(nn.Conv2d(channels[-1], latent_dim, 3, 1, 1)) 824 | self.model = nn.Sequential(*layers) 825 | 826 | print ("Final resolution of encoder (must be start resolution of decoder): ", 2*resolution) 827 | 828 | def forward(self, x): 829 | return self.model(x) 830 | 831 | class Decoder_Attn(nn.Module): 832 | def __init__(self, image_channels=3, latent_dim=128,channels = [512, 256, 256, 128, 128], 833 | attn_resolutions_VAE_decoder=[16], 834 | start_resolution_VAE_decoder= 16, 835 | num_res_blocks_VAE_decoder=3, 836 | 837 | ): 838 | super(Decoder_Attn, self).__init__() 839 | 840 | attn_resolutions = attn_resolutions_VAE_decoder 841 | num_res_blocks = num_res_blocks_VAE_decoder 842 | resolution = start_resolution_VAE_decoder 843 | 844 | in_channels = channels[0] 845 | layers = [nn.Conv2d(latent_dim, in_channels, 3, 1, 1), 846 | ResidualBlock(in_channels, in_channels), 847 | NonLocalBlock(in_channels), 848 | ResidualBlock(in_channels, in_channels)] 849 | 850 | for i in range(len(channels)): 851 | out_channels = channels[i] 852 | for j in range(num_res_blocks_VAE_decoder): 853 | layers.append(ResidualBlock(in_channels, out_channels)) 854 | in_channels = out_channels 855 | if resolution in attn_resolutions: 856 | layers.append(NonLocalBlock(in_channels)) 857 | #print (f"Added attention layer (decoder) at {resolution}") 858 | if i != 0: 859 | layers.append(UpSampleBlock(in_channels)) 860 | 
resolution *= 2 861 | 862 | layers.append(GroupNorm(channels=in_channels )) 863 | layers.append(Swish()) 864 | layers.append(nn.Conv2d(in_channels, image_channels, 3, 1, 1)) 865 | self.model = nn.Sequential(*layers) 866 | 867 | print ("Final resolution of decoder: ", resolution//2) 868 | 869 | def forward(self, x): 870 | return self.model(x) 871 | 872 | -------------------------------------------------------------------------------- /HierarchicalDesign/__init__.py: -------------------------------------------------------------------------------- 1 | #import HierarchicalDesign.VQVAE 2 | #import HierarchicalDesign.Diffusion 3 | 4 | from HierarchicalDesign.utils import count_parameters 5 | from HierarchicalDesign.VQVAE import VectorQuantize, VQVAEModel, Encoder_Attn ,Decoder_Attn, get_fmap_from_codebook 6 | from HierarchicalDesign.Diffusion import HierarchicalDesignDiffusion, HiearchicalDesignTrainer, OneD_Unet, ElucidatedImagen, HierarchicalDesignDiffusion_PredictStressStrain 7 | -------------------------------------------------------------------------------- /HierarchicalDesign/utils.py: -------------------------------------------------------------------------------- 1 | ######################################################### 2 | # Utility functions 3 | ######################################################### 4 | 5 | def count_parameters (imagen): 6 | 7 | pytorch_total_params = sum(p.numel() for p in imagen.parameters()) 8 | pytorch_total_params_trainable = sum(p.numel() for p in imagen.parameters() if p.requires_grad) 9 | 10 | print ("----------------------------------------------------------------------------------------------------") 11 | print ("Total parameters: ", pytorch_total_params," trainable parameters: ", pytorch_total_params_trainable) 12 | print ("----------------------------------------------------------------------------------------------------") 13 | return 14 | 15 | -------------------------------------------------------------------------------- /HierarchicalDesign/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.001.A' 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Markus J. Buehler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | This code builds on earlier work by Phil Wang. 
24 | 25 | MIT License 26 | 27 | Copyright (c) 2022 Phil Wang 28 | 29 | Permission is hereby granted, free of charge, to any person obtaining a copy 30 | of this software and associated documentation files (the "Software"), to deal 31 | in the Software without restriction, including without limitation the rights 32 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 33 | copies of the Software, and to permit persons to whom the Software is 34 | furnished to do so, subject to the following conditions: 35 | 36 | The above copyright notice and this permission notice shall be included in all 37 | copies or substantial portions of the Software. 38 | 39 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 40 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 41 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 42 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 43 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 44 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 45 | SOFTWARE. 46 | 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HierarchicalDesign: A computational building block approach towards multiscale architected materials analysis and design with application to hierarchical metal metamaterials 2 | ### Markus J. Buehler 3 | email: mbuehler@mit.edu 4 | 5 | We report a computational approach towards multiscale architected materials analysis and design. A particular challenge in modeling and simulation of materials, and especially the development of hierarchical design approaches, has been to identify ways by which complex multi-level material structures can be effectively modeled. One way to achieve this is to use coarse-graining approaches, where physical relationships can be effectively described with reduced dimensionality. In this paper we report an integrated deep neural network architecture that first learns coarse-grained representations of complex hierarchical microstructure data via a discrete variational autoencoder and then utilizes an attention-based diffusion model to solve both forward and inverse problems, including a capacity to solve degenerate design problems. As an application, we demonstrate the method in the analysis and design of hierarchical highly porous metamaterials within the context of nonlinear stress-strain responses to compressive deformation. We validate the mechanical behavior and mechanisms of deformation using embedded-atom molecular dynamics simulations carried out for copper and nickel, showing good agreement with the design objectives. 6 | 7 | ### Key steps 8 | 9 | This repository contains a VQ-VAE model to learn codebook representations of hierarchical structures, and generative attention-diffusion models to produce microstructural candidates from stress-strain conditioning, and stress-strain results from microstructural input. 
This code consists of three models: 10 | 11 | - Model 1: VQ-VAE to encode hierarchical architected microstructures 12 | - Model 2: Diffusion model to predict hierarchical architected microstructures from a stress-strain response used as conditioning 13 | - Model 3: Diffusion model to predict stress-strain response from a microstructure 14 | 15 | Users should first train the VQ-VAE model (Model 1), then the attention-diffusion models (Models 2 and/or 3). 16 | 17 | ##### Reference: 18 | 19 | [1] M. Buehler, A computational building block approach towards multiscale architected materials analysis and design with application to hierarchical metal metamaterials, Modelling and Simulation in Materials Science and Engineering, 2023, https://doi.org/10.1088/1361-651X/accfb5 20 | 21 | ### Overview of the problem solved 22 | 23 | A bioinspired hierarchical honeycomb material is considered in this study, featuring multiple hierarchical levels incorporated into a complex design space. Panel a shows an overview of the hierarchical makeup with four levels of hierarchy ranging from H1…H4. Panel b summarizes the mechanical boundary condition used to assess mechanical performance by applying compressive loading. By generating a large number of hierarchical designs and associated stress-strain responses, we construct a data set that consists of paired relationships between microstructure images and nonlinear mechanical properties. Panel c summarizes the two problems addressed here. The forward problem produces a stress-strain response based on the input microstructure. In the inverse problem, microstructure candidates are generated based on an input (desired) stress-strain response. 24 | 25 | ![image](https://user-images.githubusercontent.com/101393859/228824190-d5f5c5f5-babd-4d99-b802-08c4590ddfaa.png) 26 | 27 | 28 | ### How to install and use 29 | 30 | ``` 31 | conda create -n HierarchicalDesign python=3.8 32 | conda activate HierarchicalDesign 33 | ``` 34 | ``` 35 | git clone https://github.com/lamm-mit/HierarchicalDesign/ 36 | cd HierarchicalDesign 37 | ``` 38 | 39 | Then, install HierarchicalDesign: 40 | 41 | ``` 42 | pip install -e . 43 | ``` 44 | 45 | Start Jupyter Lab (or Jupyter Notebook): 46 | 47 | ``` 48 | jupyter-lab --no-browser 49 | ``` 50 | Then open the sample Jupyter notebooks and train the models and/or load the pretrained weights. 51 | 52 | - Model 1: [VQ-VAE model to learn codebook representations of hierarchical structures: VQ_VAE_Microstructure.ipynb](VQ_VAE_Microstructure.ipynb) 53 | - Model 2: [Generative attention-diffusion model: HierarchicalDesignDIffusion_GetMicrostructure.ipynb](HierarchicalDesignDIffusion_GetMicrostructure.ipynb) 54 | - Model 3: [Attention-diffusion model to predict stress-strain response from microstructure: HierarchicalDesignDIffusion_GetStressStrain.ipynb](HierarchicalDesignDIffusion_GetStressStrain.ipynb) 55 | 56 | ### Details on the architecture and approach 57 | 58 | The figure shows an overview of the neural network architecture used to solve this problem. The model consists of two parts. First (panel a), a vector quantized variational autoencoder (VQ-VAE) learns to encode microstructure images into a lower-dimensional latent space. We use a discrete approach that encodes data into a discrete codebook representation that consists of a one-dimensional vector of length N, where each entry is one of n_c possible “words” in the design language that defines the microstructures. 
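As a minimal illustration of this encoding step, the sketch below assembles the VQ-VAE from the classes exported in `HierarchicalDesign/__init__.py` and maps a dummy microstructure image to a sequence of discrete codebook indices. The latent dimension, codebook size, and image resolution are illustrative assumptions rather than the settings used in the paper, and the `VectorQuantize` arguments follow the vector-quantize-pytorch API that `VQVAE.py` adapts, so the exact argument names may differ.

```
import torch
from HierarchicalDesign import VectorQuantize, VQVAEModel, Encoder_Attn, Decoder_Attn

latent_dim    = 128   # channels of the latent feature map (assumed)
codebook_size = 256   # n_c possible "words" in the design language (assumed)

encoder  = Encoder_Attn(image_channels=3, latent_dim=latent_dim,
                        start_resolution_VAE_encoder=256)
decoder  = Decoder_Attn(image_channels=3, latent_dim=latent_dim)
codebook = VectorQuantize(dim=latent_dim,
                          codebook_size=codebook_size,
                          accept_image_fmap=True,   # quantize a (b, c, h, w) feature map
                          commitment_weight=1.0)

vqvae = VQVAEModel(Encoder=encoder, Codebook=codebook, Decoder=decoder)
vqvae.eval()

image = torch.randn(1, 3, 256, 256)   # dummy microstructure image (batch of 1)

with torch.no_grad():
    z_q, indices = vqvae.encode_z(image)   # quantized latent and per-pixel codebook indices
    tokens = indices.flatten(1)            # one-dimensional vector of length N = h*w
    x_hat, _, commit_loss = vqvae(image)   # full encode-quantize-decode pass

print(z_q.shape, tokens.shape, x_hat.shape)
```

Each entry of `tokens` is an integer in `[0, codebook_size)`, i.e. one discrete “word”; the diffusion model of panel b is trained to generate such sequences under stress-strain conditioning.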
59 | 60 | The encoder and decoder blocks each consist of a deep neural network featuring convolutional and attention layers. The VQ-VAE model is trained on unlabeled microstructure images. In the next step (panel b), the pre-trained VQ-VAE model is used as an encoding mechanism to train a diffusion model, which learns how to produce codebook representations that satisfy a certain conditioning. During training, pairs of conditioning and codebook representations of microstructures are used to minimize the reconstruction loss. Once trained (panel c), the model can be used to generate microstructure solutions based on a given conditioning stress-strain response. 61 | 62 | The stress-strain response is encoded as a series of normalized floating point numbers, concatenated with Fourier positional encoding. An identical model is also developed and trained for the forward problem, where the conditioning is the input microstructure and the diffusion model produces stress-strain responses. 63 | 64 | ![image](https://user-images.githubusercontent.com/101393859/228824011-86f1e866-5cce-4b90-9c9e-64ed88fcab68.png) 65 | 66 | ### Datasets, weights, and additional information 67 | 68 | - [Dataset CSV file](https://www.dropbox.com/s/tg4j25rmga4agu4/shc_09_raw_wimagename_5_thick.csv?dl=0) 69 | - [Microstructure data](https://www.dropbox.com/s/dz1hnhedjocacqu/gene_output_thick.zip?dl=0) 70 | - [Model 1 weights](https://www.dropbox.com/s/cluk4e4q95laqu2/model-epoch_128.pt?dl=0) 71 | - [Model 2 weights](https://www.dropbox.com/s/nnus6m2z5jbbshv/statedict_save-model-epoch_4000_FINAL.pt?dl=0) 72 | - [Model 3 weights](https://www.dropbox.com/s/u4mojfwp2uxjqkh/statedict_save-model-epoch_610_FINAL.pt?dl=0) 73 | 74 | ### Overview of the data and format 75 | 76 | You need two files to train the models. First, a CSV file that includes both references to the microstructure data (column "microstructure") and stress data (column "stresses"). There is no separate strain data stored; the list of strains per stress increment is identical for all samples and defined in the code. 77 | 78 | For example, a list of stress values as stored in the column "stresses": 79 | ``` 80 | [2.8133392e-04, 5.0026216e-02, 1.0911082e-01, 1.5260771e-01, 1.9775425e-01, 81 | 2.4043675e-01, 2.8037483e-01, 2.9301879e-01, 2.9600343e-01, 2.9962808e-01, 82 | 3.0461299e-01, 3.0993605e-01, 3.1720343e-01, 3.2225695e-01, 3.2850131e-01, 83 | 3.3622128e-01, 3.4194285e-01, 3.4944272e-01, 3.5820404e-01, 3.6438131e-01, 84 | 3.7371701e-01, 3.8643220e-01, 4.0058151e-01, 4.1546318e-01, 4.3120158e-01, 85 | 4.4801408e-01, 4.6275303e-01, 4.7932705e-01, 4.9467662e-01, 5.1254719e-01, 86 | 5.3236824e-01, 5.5691981e-01] 87 | ``` 88 | Second, the microstructure data consists of images, each of which is associated with a list of stress data as defined in the CSV file. For example, the associated image would be identified in the column "microstructures", e.g. as "microstructure_1.png", in the same row as the above list of stresses. The string of stresses is converted into a list of floating point numbers in the data loader (a minimal parsing sketch is shown below). 89 | 90 | ### Acknowledgements 91 | 92 | This code is based on [https://github.com/lucidrains/imagen-pytorch](https://github.com/lucidrains/imagen-pytorch) and [https://github.com/lamm-mit/DynaGen](https://github.com/lamm-mit/DynaGen). 
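Returning to the data format described above, the following is a minimal, hypothetical sketch of the parsing step for one row of the dataset CSV. The CSV file name, the column names ("microstructures", "stresses"), and the image preprocessing are assumptions based on this README; the notebooks implement their own data loaders.

```
import ast
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

# Read the dataset CSV (file name assumed; see the dataset link above)
df = pd.read_csv('shc_09_raw_wimagename_5_thick.csv')
row = df.iloc[0]

# The "stresses" column stores a Python-style list as a string; parse it into floats
stresses = torch.tensor(ast.literal_eval(row['stresses']), dtype=torch.float32)

# Load the microstructure image referenced in the same row (column name assumed)
to_tensor = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
image = to_tensor(Image.open(row['microstructures']).convert('RGB'))

print(stresses.shape)   # e.g. torch.Size([32]) for the 32 stress values listed above
print(image.shape)      # torch.Size([3, 512, 512])
```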
93 | 94 | ``` 95 | @article{BuehlerMSMSE_2023, 96 | title = {A computational building block approach towards multiscale architected materials analysis and design with application to hierarchical metal metamaterials}, 97 | author = {M.J. Buehler}, 98 | journal = {Modelling and Simulation in Materials Science and Engineering}, 99 | year = {2023}, 100 | volume = {}, 101 | pages = {}, 102 | url = {https://doi.org/10.1088/1361-651X/accfb5} 103 | } 104 | ``` 105 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | exec(open('HierarchicalDesign/version.py').read()) 3 | 4 | setup( 5 | name = 'HierarchicalDesign', 6 | packages = find_packages(exclude=[]), 7 | include_package_data = True, 8 | version = __version__, 9 | license='MIT', 10 | description = 'HierarchicalDesign - learn, model and design hierarchical materials and structures', 11 | author = 'Markus J Buehler', 12 | author_email = 'mbuehler@MIT.EDU', 13 | long_description_content_type = 'text/markdown', 14 | url = 'https://github.com/lamm-mit/HierarchicalDesign', 15 | keywords = [ 16 | 'artificial intelligence', 17 | 'deep learning', 18 | 'transformers', 19 | 'VQ-VAE', 20 | 'scientific machine learning', 21 | 'dynamical processes', 22 | 'denoising-diffusion', 23 | 'materials science' 24 | ], 25 | install_requires=[ 26 | 'accelerate', 27 | 'click', 28 | 'einops>=0.4', 29 | 'einops-exts', 30 | 'ema-pytorch>=0.0.3', 31 | 'fsspec', 32 | 'kornia', 33 | 'numpy', 34 | 'packaging', 35 | 'pillow', 36 | 'pydantic', 37 | 'pytorch-lightning', 38 | 'pytorch-warmup', 39 | 'sentencepiece', 40 | 'torch', 41 | 'torchvision', 42 | 'transformers', 43 | 'tqdm', 44 | 'jupyterlab', 45 | 'matplotlib', 46 | 'scikit-learn', 47 | 'pandas', 48 | 'IProgress', 49 | 'ipywidgets', 50 | 'opencv-python', 51 | 'seaborn', 52 | ], 53 | classifiers=[ 54 | 'Development Status :: 4 - Beta', 55 | 'Intended Audience :: Developers', 56 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 57 | 'License :: OSI Approved :: MIT License', 58 | 'Programming Language :: Python :: 3.8', 59 | ], 60 | ) 61 | --------------------------------------------------------------------------------