├── README.md ├── config ├── cifar10 │ ├── C.yaml │ └── E.yaml ├── cifar100 │ ├── C.yaml │ └── E.yaml └── tinyimagenet │ ├── C.yaml │ └── E.yaml ├── datasets.py ├── linear.py ├── model ├── EDM.py ├── __init__.py ├── blockC.py ├── blockE.py ├── blockG.py ├── models.py ├── unetC.py ├── unetE.py └── unetG.py ├── sample.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # EDM2: Analyzing and Improving the Training Dynamics of Diffusion Models 2 | 3 | This is a multi-gpu PyTorch implementation of the paper [Analyzing and Improving the Training Dynamics of Diffusion Models](https://arxiv.org/abs/2312.02696): 4 | ```bibtex 5 | @article{karras2023analyzing, 6 | title={Analyzing and Improving the Training Dynamics of Diffusion Models}, 7 | author={Karras, Tero and Aittala, Miika and Lehtinen, Jaakko and Hellsten, Janne and Aila, Timo and Laine, Samuli}, 8 | journal={arXiv preprint arXiv:2312.02696}, 9 | year={2023} 10 | } 11 | ``` 12 | :exclamation: This repo only contains configs and experiments on small or medium scale datasets such as CIFAR-10/100 and Tiny-ImageNet. Full re-implementation on ImageNet-1k would be extremely expensive. 13 | 14 | :fire: This repo contains implementations of `Config C`, `Config E` and the final `Config G` models. You can compare `block[C/E/G].py` and `unet[C/E/G].py` against each other to learn about the improvements proposed by the authors. :smile: 15 | 16 | ## Requirements 17 | In addition to PyTorch environments, please install: 18 | ```sh 19 | conda install pyyaml 20 | pip install ema-pytorch tensorboard 21 | ``` 22 | 23 | ## Usage 24 | Use 4 GPUs to train unconditional Config C/E/G models on the CIFAR-100 dataset: 25 | ```sh 26 | torchrun --nproc_per_node=4 27 | train.py --config config/cifar100/C.yaml --use_amp 28 | train.py --config config/cifar100/E.yaml --use_amp 29 | train.py --config config/cifar100/G.yaml --use_amp 30 | ``` 31 | 32 | To generate 50000 images with different checkpoints, for example, run: 33 | ```sh 34 | torchrun --nproc_per_node=4 35 | sample.py --config config/cifar100/C.yaml --use_amp --epoch 1000 36 | sample.py --config config/cifar100/E.yaml --use_amp --epoch 1600 37 | sample.py --config config/cifar100/G.yaml --use_amp --epoch 1999 38 | ``` 39 | 40 | ## Observations and takeaways :star: 41 | - Config C shows **consistent improvements** over the original EDM. The main contribution is from the multi-task weighting. 42 | - Config E could perform better than Config C, but the convergence is **significantly slower** (see the epoch counts below). 43 | - Config G seems to **favor latent-space modeling**, and it's inferior to Config E on pixel-space generation. 44 | - Post-hoc EMA (and the power function EMA) tends to **favor longer training** durations. For small-sized datasets like CIFAR, it won't help much. 45 | 46 | ## Results 47 | We report Config C and Config E results on CIFAR-10, CIFAR-100, and Tiny-ImageNet datasets. 48 | 49 | | Config | Model | Network size | Best FID (18 steps) | Best linear probe acc. | 50 | |:--------------------|:--------------|:-------------|:----------------------|:-----------------------| 51 | | cifar10/C.yaml | Uncond. EDM2C | 39.5M | 3.03 @ epoch 1000 | 91.85 @ epoch 500 | 52 | | cifar10/E.yaml | Uncond. EDM2E | 39.5M | 2.72 @ epoch 2000 | 93.46 @ epoch 1000 | 53 | | cifar100/C.yaml | Uncond. EDM2C | 39.5M | 5.06 @ epoch 1000 | 65.40 @ epoch 500 | 54 | | cifar100/E.yaml | Uncond. 
EDM2E | 39.5M | 4.33 @ epoch 2000 | 69.04 @ epoch 1100 | 55 | | tinyimagenet/C.yaml | Uncond. EDM2C | 62.4M | 15.96 @ epoch 1600* | 50.99 @ epoch 600 | 56 | | tinyimagenet/E.yaml | Uncond. EDM2E | 62.4M | 16.79 @ epoch 1500* | 52.07 @ epoch 1400 | 57 | 58 | *Note: Unfinished training (due to high computational cost). The FID has not saturated, and keep training can lead to lower FIDs. 59 | -------------------------------------------------------------------------------- /config/cifar10/C.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNetC' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 32, 32] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, True, False, False] 23 | dropout: 0.13 24 | n_blocks: 3 # equiv. to "n_blocks=2" in UNetHo + "use BigGAN up/down" 25 | 26 | # training params 27 | n_epoch: 1000 28 | batch_size: 512 29 | lrate: 4.0e-4 30 | warm_epoch: 200 31 | load_epoch: -1 32 | flip: True 33 | ema: 0.9993 34 | # optim: 'Adam' 35 | # optim_args: 36 | # betas: [0.9, 0.99] 37 | 38 | # testing params 39 | n_sample: 30 40 | save_dir: './output_C10' 41 | save_model: True 42 | 43 | # linear probe 44 | linear: 45 | n_epoch: 15 46 | batch_size: 128 47 | lrate: 1.0e-3 48 | timestep: 4 49 | blockname: 'out_6' 50 | -------------------------------------------------------------------------------- /config/cifar10/E.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar' 3 | classes: 10 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNetE' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 32, 32] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, True, False, False] 23 | dropout: 0.13 24 | n_blocks: 3 # equiv. to "n_blocks=2" in UNetHo + "use BigGAN up/down" 25 | 26 | # training params 27 | n_epoch: 2000 28 | batch_size: 512 29 | lrate: 2.0e-2 30 | warm_epoch: 200 31 | tref_epoch: 400 32 | load_epoch: -1 33 | flip: True 34 | ema: 0.9993 35 | # optim: 'Adam' 36 | # optim_args: 37 | # betas: [0.9, 0.99] 38 | 39 | # testing params 40 | n_sample: 30 41 | save_dir: './output_E10' 42 | save_model: True 43 | 44 | # linear probe 45 | linear: 46 | n_epoch: 15 47 | batch_size: 128 48 | lrate: 1.0e-3 49 | timestep: 4 50 | blockname: 'out_6' 51 | -------------------------------------------------------------------------------- /config/cifar100/C.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar100' 3 | classes: 100 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNetC' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 32, 32] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, True, False, False] 23 | dropout: 0.13 24 | n_blocks: 3 # equiv. 
to "n_blocks=2" in UNetHo + "use BigGAN up/down" 25 | 26 | # training params 27 | n_epoch: 1000 28 | batch_size: 512 29 | lrate: 4.0e-4 30 | warm_epoch: 200 31 | load_epoch: -1 32 | flip: True 33 | ema: 0.9993 34 | # optim: 'Adam' 35 | # optim_args: 36 | # betas: [0.9, 0.99] 37 | 38 | # testing params 39 | n_sample: 30 40 | save_dir: './output_C100' 41 | save_model: True 42 | 43 | # linear probe 44 | linear: 45 | n_epoch: 15 46 | batch_size: 128 47 | lrate: 1.0e-3 48 | timestep: 4 49 | blockname: 'out_6' 50 | -------------------------------------------------------------------------------- /config/cifar100/E.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'cifar100' 3 | classes: 100 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNetE' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 32, 32] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, True, False, False] 23 | dropout: 0.13 24 | n_blocks: 3 # equiv. to "n_blocks=2" in UNetHo + "use BigGAN up/down" 25 | 26 | # training params 27 | n_epoch: 2000 28 | batch_size: 512 29 | lrate: 2.0e-2 30 | warm_epoch: 200 31 | tref_epoch: 400 32 | load_epoch: -1 33 | flip: True 34 | ema: 0.9993 35 | # optim: 'Adam' 36 | # optim_args: 37 | # betas: [0.9, 0.99] 38 | 39 | # testing params 40 | n_sample: 30 41 | save_dir: './output_E100' 42 | save_model: True 43 | 44 | # linear probe 45 | linear: 46 | n_epoch: 15 47 | batch_size: 128 48 | lrate: 1.0e-3 49 | timestep: 4 50 | blockname: 'out_6' 51 | -------------------------------------------------------------------------------- /config/tinyimagenet/C.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'tiny' 3 | classes: 200 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNetC' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 64, 64] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, False, True, False] 23 | dropout: 0.1 24 | n_blocks: 5 # equiv. 
to "n_blocks=4" in UNetHo + "use BigGAN up/down" 25 | 26 | # training params 27 | n_epoch: 2000 28 | batch_size: 192 29 | lrate: 1.5e-4 30 | warm_epoch: 200 31 | load_epoch: -1 32 | flip: False 33 | ema: 0.9993 34 | # optim: 'Adam' 35 | # optim_args: 36 | # betas: [0.9, 0.99] 37 | 38 | # testing params 39 | n_sample: 30 40 | save_dir: './output_Ctiny' 41 | save_model: True 42 | 43 | # linear probe 44 | linear: 45 | n_epoch: 15 46 | batch_size: 128 47 | lrate: 1.0e-3 48 | timestep: 5 49 | blockname: 'out_6' 50 | -------------------------------------------------------------------------------- /config/tinyimagenet/E.yaml: -------------------------------------------------------------------------------- 1 | # dataset params 2 | dataset: 'tiny' 3 | classes: 200 4 | 5 | # model params 6 | model_type: 'EDM' 7 | net_type: 'UNetE' 8 | diffusion: 9 | sigma_data: 0.5 10 | p_mean: -1.2 11 | p_std: 1.2 12 | sigma_min: 0.002 13 | sigma_max: 80 14 | rho: 7 15 | S_min: 0.01 16 | S_max: 1 17 | S_noise: 1.007 18 | network: 19 | image_shape: [3, 64, 64] 20 | n_channels: 128 21 | ch_mults: [1, 2, 2, 2] 22 | is_attn: [False, False, True, False] 23 | dropout: 0.1 24 | n_blocks: 5 # equiv. to "n_blocks=4" in UNetHo + "use BigGAN up/down" 25 | 26 | # training params 27 | n_epoch: 2000 28 | batch_size: 192 29 | lrate: 7.5e-3 30 | warm_epoch: 200 31 | tref_epoch: 400 32 | load_epoch: -1 33 | flip: False 34 | ema: 0.9993 35 | # optim: 'Adam' 36 | # optim_args: 37 | # betas: [0.9, 0.99] 38 | 39 | # testing params 40 | n_sample: 30 41 | save_dir: './output_Etiny' 42 | save_model: True 43 | 44 | # linear probe 45 | linear: 46 | n_epoch: 15 47 | batch_size: 128 48 | lrate: 1.0e-3 49 | timestep: 5 50 | blockname: 'out_6' 51 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from torchvision.datasets import CIFAR10, CIFAR100 7 | 8 | 9 | class TinyImageNet(Dataset): 10 | def __init__(self, root, train=True, transform=None): 11 | if not root.endswith("tiny-imagenet-200"): 12 | root = os.path.join(root, "tiny-imagenet-200") 13 | self.train_dir = os.path.join(root, "train") 14 | self.val_dir = os.path.join(root, "val") 15 | self.transform = transform 16 | if train: 17 | self._scan_train() 18 | else: 19 | self._scan_val() 20 | 21 | def _scan_train(self): 22 | classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()] 23 | classes = sorted(classes) 24 | assert len(classes) == 200 25 | 26 | self.data = [] 27 | for idx, name in enumerate(classes): 28 | this_dir = os.path.join(self.train_dir, name) 29 | for root, _, files in sorted(os.walk(this_dir)): 30 | for fname in sorted(files): 31 | if fname.endswith(".JPEG"): 32 | path = os.path.join(root, fname) 33 | item = (path, idx) 34 | self.data.append(item) 35 | self.labels_dict = {i: classes[i] for i in range(len(classes))} 36 | 37 | def _scan_val(self): 38 | self.file_to_class = {} 39 | classes = set() 40 | with open(os.path.join(self.val_dir, "val_annotations.txt"), 'r') as f: 41 | lines = f.readlines() 42 | for line in lines: 43 | words = line.split("\t") 44 | self.file_to_class[words[0]] = words[1] 45 | classes.add(words[1]) 46 | classes = sorted(list(classes)) 47 | assert len(classes) == 200 48 | 49 | class_to_idx = {classes[i]: i for i in range(len(classes))} 50 | self.data = [] 51 | this_dir = 
os.path.join(self.val_dir, "images") 52 | for root, _, files in sorted(os.walk(this_dir)): 53 | for fname in sorted(files): 54 | if fname.endswith(".JPEG"): 55 | path = os.path.join(root, fname) 56 | idx = class_to_idx[self.file_to_class[fname]] 57 | item = (path, idx) 58 | self.data.append(item) 59 | self.labels_dict = {i: classes[i] for i in range(len(classes))} 60 | 61 | def __len__(self): 62 | return len(self.data) 63 | 64 | def __getitem__(self, idx): 65 | path, label = self.data[idx] 66 | image = Image.open(path) 67 | image = image.convert("RGB") 68 | 69 | if self.transform: 70 | image = self.transform(image) 71 | 72 | return image, label 73 | 74 | 75 | def get_dataset(name, root="./data", train=True, flip=False, crop=False, resize=None): 76 | if name == 'cifar': 77 | DATASET = CIFAR10 78 | RES = 32 79 | elif name == 'cifar100': 80 | DATASET = CIFAR100 81 | RES = 32 82 | elif name == 'tiny': 83 | DATASET = TinyImageNet 84 | RES = 64 85 | else: 86 | raise NotImplementedError 87 | 88 | tf = [transforms.ToTensor()] 89 | if resize is not None: 90 | tf = [transforms.Resize(resize)] + tf 91 | if train: 92 | if crop: 93 | tf = [transforms.RandomCrop(RES, 4)] + tf 94 | if flip: 95 | tf = [transforms.RandomHorizontalFlip()] + tf 96 | 97 | return DATASET(root=root, train=train, transform=transforms.Compose(tf)) 98 | -------------------------------------------------------------------------------- /linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from functools import partial 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import yaml 8 | import torch.nn as nn 9 | from datasets import get_dataset 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from tqdm import tqdm 12 | from ema_pytorch import EMA 13 | 14 | from model.models import get_models_class 15 | from utils import Config, init_seeds, gather_tensor, DataLoaderDDP, print0 16 | 17 | 18 | def get_model(opt, load_epoch): 19 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 20 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 21 | **opt.diffusion, 22 | device=device, 23 | ) 24 | diff.to(device) 25 | target = os.path.join(opt.save_dir, "ckpts", f"model_{load_epoch}.pth") 26 | print0("loading model at", target) 27 | checkpoint = torch.load(target, map_location=device) 28 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 29 | ema.to(device) 30 | ema.load_state_dict(checkpoint['EMA']) 31 | model = ema.ema_model 32 | model.eval() 33 | return model 34 | 35 | 36 | class ClassifierDict(nn.Module): 37 | def __init__(self, feat_func, time_list, name_list, base_lr, epoch, img_shape, local_rank, num_classes): 38 | super(ClassifierDict, self).__init__() 39 | self.feat_func = feat_func 40 | self.times = time_list 41 | self.names = name_list 42 | self.classifiers = nn.ModuleDict() 43 | self.optims = {} 44 | self.schedulers = {} 45 | self.loss_fn = nn.CrossEntropyLoss() 46 | 47 | for time in self.times: 48 | feats = self.feat_func(torch.zeros(1, *img_shape).to(device), time) 49 | if self.names is None: 50 | self.names = list(feats.keys()) # all available names 51 | 52 | for name in self.names: 53 | key = self.make_key(time, name) 54 | layers = nn.Linear(feats[name].shape[1], num_classes) 55 | layers = torch.nn.parallel.DistributedDataParallel( 56 | layers.to(device), device_ids=[local_rank], output_device=local_rank) 57 | optimizer = torch.optim.Adam(layers.parameters(), lr=base_lr) 58 | scheduler = 
CosineAnnealingLR(optimizer, epoch) 59 | self.classifiers[key] = layers 60 | self.optims[key] = optimizer 61 | self.schedulers[key] = scheduler 62 | 63 | def train(self, x, y): 64 | self.classifiers.train() 65 | for time in self.times: 66 | feats = self.feat_func(x, time) 67 | for name in self.names: 68 | key = self.make_key(time, name) 69 | representation = feats[name].detach() 70 | logit = self.classifiers[key](representation) 71 | loss = self.loss_fn(logit, y) 72 | 73 | self.optims[key].zero_grad() 74 | loss.backward() 75 | self.optims[key].step() 76 | 77 | def test(self, x): 78 | outputs = {} 79 | with torch.no_grad(): 80 | self.classifiers.eval() 81 | for time in self.times: 82 | feats = self.feat_func(x, time) 83 | for name in self.names: 84 | key = self.make_key(time, name) 85 | representation = feats[name].detach() 86 | logit = self.classifiers[key](representation) 87 | pred = logit.argmax(dim=-1) 88 | outputs[key] = pred 89 | return outputs 90 | 91 | def make_key(self, t, n): 92 | return str(t) + '/' + n 93 | 94 | def get_lr(self): 95 | key = self.make_key(self.times[0], self.names[0]) 96 | optim = self.optims[key] 97 | return optim.param_groups[0]['lr'] 98 | 99 | def schedule_step(self): 100 | for time in self.times: 101 | for name in self.names: 102 | key = self.make_key(time, name) 103 | self.schedulers[key].step() 104 | 105 | 106 | def train(opt): 107 | def test(): 108 | preds = {k: [] for k in classifiers.optims.keys()} 109 | accs = {} 110 | labels = [] 111 | for image, label in tqdm(valid_loader, disable=(local_rank!=0)): 112 | outputs = classifiers.test(image.to(device)) 113 | for key in outputs: 114 | preds[key].append(outputs[key]) 115 | labels.append(label.to(device)) 116 | 117 | for key in preds: 118 | preds[key] = torch.cat(preds[key]) 119 | label = torch.cat(labels) 120 | dist.barrier() 121 | label = gather_tensor(label) 122 | for key in preds: 123 | pred = gather_tensor(preds[key]) 124 | accs[key] = (pred == label).sum().item() / len(label) 125 | return accs 126 | 127 | yaml_path = opt.config 128 | ep = opt.epoch 129 | use_amp = opt.use_amp 130 | grid_search = opt.grid 131 | with open(yaml_path, 'r') as f: 132 | opt = yaml.full_load(f) 133 | print0(opt) 134 | opt = Config(opt) 135 | if ep == -1: 136 | ep = opt.n_epoch - 1 137 | model = get_model(opt, ep) 138 | 139 | epoch = opt.linear['n_epoch'] 140 | batch_size = opt.linear['batch_size'] 141 | base_lr = opt.linear['lrate'] 142 | 143 | if grid_search: 144 | time_list = [3, 4, 5] 145 | name_list = None 146 | else: 147 | time_list = [opt.linear['timestep']] 148 | name_list = [opt.linear['blockname']] 149 | 150 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=True) 151 | valid_set = get_dataset(name=opt.dataset, root="./data", train=False) 152 | train_loader, sampler = DataLoaderDDP( 153 | train_set, 154 | batch_size=batch_size, 155 | shuffle=True, 156 | ) 157 | valid_loader, _ = DataLoaderDDP( 158 | valid_set, 159 | batch_size=batch_size, 160 | shuffle=False, 161 | ) 162 | 163 | feat_func = partial(model.get_feature, norm=False, use_amp=use_amp) 164 | DDP_multiplier = dist.get_world_size() 165 | print0("Using DDP, lr = %f * %d" % (base_lr, DDP_multiplier)) 166 | base_lr *= DDP_multiplier 167 | classifiers = ClassifierDict(feat_func, time_list, name_list, 168 | base_lr, epoch, opt.network['image_shape'], local_rank, opt.classes).to(model.device) 169 | 170 | for e in range(epoch): 171 | sampler.set_epoch(e) 172 | pbar = tqdm(train_loader, disable=(local_rank!=0)) 173 | for i, (image, label) in 
enumerate(pbar): 174 | pbar.set_description("[epoch %d / iter %d]: lr: %.1e" % (e, i, classifiers.get_lr())) 175 | classifiers.train(image.to(device), label.to(device)) 176 | classifiers.schedule_step() 177 | 178 | accs = test() 179 | for key in accs: 180 | print0("[key %s]: Test acc: %.2f" % (key, accs[key] * 100)) 181 | 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("--config", type=str) 186 | parser.add_argument("--epoch", type=int, default=-1) 187 | parser.add_argument("--use_amp", action='store_true', default=False) 188 | parser.add_argument("--grid", action='store_true', default=False) 189 | opt = parser.parse_args() 190 | opt.local_rank = int(os.environ['LOCAL_RANK']) 191 | print0(opt) 192 | 193 | local_rank = opt.local_rank 194 | init_seeds(no=local_rank) 195 | dist.init_process_group(backend='nccl') 196 | torch.cuda.set_device(local_rank) 197 | device = "cuda:%d" % local_rank 198 | 199 | train(opt) 200 | -------------------------------------------------------------------------------- /model/EDM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from torch.cuda.amp import autocast as autocast 6 | 7 | 8 | def normalize_to_neg_one_to_one(img): 9 | # [0.0, 1.0] -> [-1.0, 1.0] 10 | return img * 2 - 1 11 | 12 | 13 | def unnormalize_to_zero_to_one(t): 14 | # [-1.0, 1.0] -> [0.0, 1.0] 15 | return (t + 1) * 0.5 16 | 17 | 18 | class EDM(nn.Module): 19 | def __init__(self, nn_model, 20 | sigma_data, p_mean, p_std, 21 | sigma_min, sigma_max, rho, 22 | S_min, S_max, S_noise, 23 | device): 24 | ''' EDM proposed by "Elucidating the Design Space of Diffusion-Based Generative Models". 25 | 26 | Args: 27 | nn_model: A network (e.g. UNet) which performs same-shape mapping. 28 | device: The CUDA device that tensors run on. 29 | Training parameters: 30 | sigma_data, p_mean, p_std 31 | Sampling parameters: 32 | sigma_min, sigma_max, rho 33 | S_min, S_max, S_noise 34 | ''' 35 | super(EDM, self).__init__() 36 | self.nn_model = nn_model.to(device) 37 | params = sum(p.numel() for p in nn_model.parameters() if p.requires_grad) / 1e6 38 | print(f"nn model # params: {params:.1f}") 39 | 40 | self.device = device 41 | 42 | def number_to_torch_device(value): 43 | return torch.tensor(value).to(device) 44 | 45 | self.sigma_data = number_to_torch_device(sigma_data) 46 | self.p_mean = number_to_torch_device(p_mean) 47 | self.p_std = number_to_torch_device(p_std) 48 | self.sigma_min = number_to_torch_device(sigma_min) 49 | self.sigma_max = number_to_torch_device(sigma_max) 50 | self.rho = number_to_torch_device(rho) 51 | self.S_min = number_to_torch_device(S_min) 52 | self.S_max = number_to_torch_device(S_max) 53 | self.S_noise = number_to_torch_device(S_noise) 54 | self.reweight_mlp = self.nn_model.get_reweighting() 55 | 56 | def perturb(self, x, t=None, steps=None): 57 | ''' Add noise to a clean image (diffusion process). 58 | 59 | Args: 60 | x: The normalized image tensor. 61 | t: The specified timestep ranged in `[1, steps]`. Type: int / torch.LongTensor / None. \ 62 | Random `ln(sigma) ~ N(P_mean, P_std)` is taken if t is None. 63 | Returns: 64 | The perturbed image, and the corresponding sigma. 
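The sigma broadcasts against the image batch: shape `(B, 1, 1, 1)` when `t` is None or a torch.LongTensor, a scalar tensor when `t` is an int.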
65 | ''' 66 | if t is None: 67 | rnd_normal = torch.randn((x.shape[0], 1, 1, 1)).to(self.device) 68 | sigma = (rnd_normal * self.p_std + self.p_mean).exp() 69 | else: 70 | times = reversed(self.sample_schedule(steps)) 71 | sigma = times[t] 72 | if len(sigma.shape) == 1: 73 | sigma = sigma[:, None, None, None] 74 | 75 | noise = torch.randn_like(x) 76 | x_noised = x + noise * sigma 77 | return x_noised, sigma 78 | 79 | def forward(self, x, use_amp=False): 80 | ''' Training with weighted denoising loss. 81 | 82 | Args: 83 | x: The clean image tensor ranged in `[0, 1]`. 84 | Returns: 85 | The weighted MSE loss. 86 | ''' 87 | x = normalize_to_neg_one_to_one(x) 88 | x_noised, sigma = self.perturb(x, t=None) 89 | 90 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 91 | u = self.nn_model.forward_reweighting(self.reweight_mlp, sigma) 92 | loss_4shape = weight * ((x - self.D_x(x_noised, sigma, use_amp)) ** 2) / u.exp() + u 93 | return loss_4shape.mean() 94 | 95 | def get_feature(self, x, t, steps=18, name=None, norm=False, use_amp=False): 96 | ''' Get network's intermediate activation in a forward pass. 97 | 98 | Args: 99 | x: The clean image tensor ranged in `[0, 1]`. 100 | t: The specified timestep ranged in `[1, steps]`. Type: int / torch.LongTensor. 101 | norm: to normalize features to the the unit hypersphere. 102 | Returns: 103 | A {name: tensor} dict which contains global average pooled features. 104 | ''' 105 | x = normalize_to_neg_one_to_one(x) 106 | x_noised, sigma = self.perturb(x, t, steps) 107 | 108 | def gap_and_norm(act, norm=False): 109 | if len(act.shape) == 4: 110 | # unet (B, C, H, W) 111 | act = act.view(act.shape[0], act.shape[1], -1).float() 112 | act = torch.mean(act, dim=2) 113 | else: 114 | raise NotImplementedError 115 | if norm: 116 | act = torch.nn.functional.normalize(act) 117 | return act 118 | 119 | _, acts = self.D_x(x_noised, sigma, use_amp, ret_activation=True) 120 | all_feats = {blockname: gap_and_norm(acts[blockname], norm) for blockname in acts} 121 | if name is not None: 122 | return all_feats[name] 123 | else: 124 | return all_feats 125 | 126 | def edm_sample(self, n_sample, size, steps=18, eta=0.0, notqdm=False, use_amp=False): 127 | ''' Sampling with EDM sampler. Actual NFE is `2 * steps - 1`. 128 | 129 | Args: 130 | n_sample: The batch size. 131 | size: The image shape (e.g. `(3, 32, 32)`). 132 | steps: The number of total timesteps. 133 | eta: controls stochasticity. Set `eta=0` for deterministic sampling. 134 | Returns: 135 | The sampled image tensor ranged in `[0, 1]`. 136 | ''' 137 | S_min, S_max, S_noise = self.S_min, self.S_max, self.S_noise 138 | gamma_stochasticity = torch.tensor(np.sqrt(2) - 1) * eta # S_churn = (sqrt(2) - 1) * eta * steps 139 | 140 | times = self.sample_schedule(steps) 141 | time_pairs = list(zip(times[:-1], times[1:])) 142 | 143 | x_next = torch.randn(n_sample, *size).to(self.device).to(torch.float64) * times[0] 144 | for i, (t_cur, t_next) in enumerate(tqdm(time_pairs, disable=notqdm)): # 0, ..., N-1 145 | x_cur = x_next 146 | 147 | # Increase noise temporarily. 148 | gamma = gamma_stochasticity if S_min <= t_cur <= S_max else 0 149 | t_hat = t_cur + gamma * t_cur 150 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) 151 | 152 | # Euler step. 153 | d_cur = self.pred_eps_(x_hat, t_hat, use_amp) 154 | x_next = x_hat + (t_next - t_hat) * d_cur 155 | 156 | # Apply 2nd order correction. 
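# (Heun's method: re-evaluate the slope at t_next and average it with d_cur.
# The correction is skipped on the final step, where t_next == 0, which is why
# the actual NFE is 2 * steps - 1 as noted in the docstring above.)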
157 | if i < steps - 1: 158 | d_prime = self.pred_eps_(x_next, t_next, use_amp) 159 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 160 | 161 | return unnormalize_to_zero_to_one(x_next) 162 | 163 | def pred_eps_(self, x, t, use_amp, clip_x=True): 164 | denoised = self.D_x(x, t, use_amp).to(torch.float64) 165 | # pixel-space clipping (optional) 166 | if clip_x: 167 | denoised = torch.clip(denoised, -1., 1.) 168 | eps = (x - denoised) / t 169 | return eps 170 | 171 | def D_x(self, x_noised, sigma, use_amp, ret_activation=False): 172 | ''' Denoising with network preconditioning. 173 | 174 | Args: 175 | x_noised: The perturbed image tensor. 176 | sigma: The variance (noise level) tensor. 177 | Returns: 178 | The estimated denoised image tensor. 179 | The {name: (B, C, H, W) tensor} activation dict (if ret_activation is True). 180 | ''' 181 | x_noised = x_noised.to(torch.float32) 182 | sigma = sigma.to(torch.float32) 183 | 184 | # Preconditioning 185 | c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) 186 | c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() 187 | c_in = 1 / (sigma ** 2 + self.sigma_data ** 2).sqrt() 188 | c_noise = sigma.log() / 4 189 | 190 | # Denoising 191 | with autocast(enabled=use_amp): 192 | F_x = self.nn_model(c_in * x_noised, c_noise.flatten(), ret_activation) 193 | 194 | if ret_activation: 195 | return c_skip * x_noised + c_out * F_x[0], F_x[1] 196 | else: 197 | return c_skip * x_noised + c_out * F_x 198 | 199 | def sample_schedule(self, steps): 200 | ''' Make the variance schedule for EDM sampling. 201 | 202 | Args: 203 | steps: The number of total timesteps. Typically 18, 50 or 100. 204 | Returns: 205 | times: A decreasing tensor list such that 206 | `times[0] == sigma_max`, 207 | `times[steps-1] == sigma_min`, and 208 | `times[steps] == 0`. 
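Intermediate levels follow the rho-spaced interpolation
`sigma_i = (sigma_max**(1/rho) + i/(steps-1) * (sigma_min**(1/rho) - sigma_max**(1/rho))) ** rho`.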
209 | ''' 210 | sigma_min, sigma_max, rho = self.sigma_min, self.sigma_max, self.rho 211 | times = torch.arange(steps, dtype=torch.float64, device=self.device) 212 | times = (sigma_max ** (1 / rho) + times / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 213 | times = torch.cat([times, torch.zeros_like(times[:1])]) # t_N = 0 214 | return times 215 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FutureXiang/edm2/cea49702a2fd92957f7412f8e0ea804e59aaaa5a/model/__init__.py -------------------------------------------------------------------------------- /model/blockC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | def GetConv2d(C_in, C_out, kernel_size, stride=1, padding=0): 8 | return nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False) 9 | 10 | 11 | def GetLinear(C_in, C_out): 12 | return nn.Linear(C_in, C_out, bias=False) 13 | 14 | # ===== Embedding ===== 15 | 16 | class Fourier(nn.Module): 17 | def __init__(self, embedding_size=256): 18 | super().__init__() 19 | self.frequencies = nn.Parameter(torch.randn(embedding_size), requires_grad=False) 20 | self.phases = nn.Parameter(torch.rand(embedding_size), requires_grad=False) 21 | 22 | def forward(self, a): 23 | b = (2 * np.pi) * (a[:, None] * self.frequencies[None, :] + self.phases[None, :]) 24 | b = torch.cos(b) 25 | return b 26 | 27 | 28 | class Embedding(nn.Module): 29 | def __init__(self, n_channels): 30 | super().__init__() 31 | self.fourier = Fourier(embedding_size=n_channels // 4) 32 | self.linear1 = GetLinear(n_channels // 4, n_channels) 33 | self.act = nn.SiLU() 34 | self.linear2 = GetLinear(n_channels, n_channels) 35 | 36 | def forward(self, c_noise): 37 | emb = self.fourier(c_noise) 38 | emb = self.act(self.linear1(emb)) 39 | emb = self.act(self.linear2(emb)) 40 | return emb 41 | 42 | 43 | class Reweighting(nn.Module): 44 | def __init__(self, n_channels=256): 45 | super().__init__() 46 | self.fourier = Fourier(embedding_size=n_channels) 47 | self.linear = GetLinear(n_channels, 1) 48 | 49 | def forward(self, c_noise): 50 | emb = self.fourier(c_noise) 51 | emb = self.linear(emb) 52 | return emb 53 | 54 | # ===== Residual blocks ===== 55 | 56 | def PixNorm(x, dim=1, eps=1e-8): 57 | return x / torch.sqrt(torch.mean(x ** 2, dim=dim, keepdim=True) + eps) 58 | 59 | 60 | class Downsample(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | self.pool = nn.AvgPool2d(2) 64 | 65 | def forward(self, x): 66 | return self.pool(x) 67 | 68 | 69 | class Upsample(nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | 73 | def forward(self, x): 74 | return nn.functional.interpolate(x, scale_factor=2, mode="nearest") 75 | 76 | 77 | def GroupNorm32(channels): 78 | return nn.GroupNorm(32, channels, affine=False) 79 | 80 | 81 | class EncoderResBlock(nn.Module): 82 | def __init__(self, in_channels, out_channels, emb_channels, dropout=0.1, down=False): 83 | super().__init__() 84 | self.linear = GetLinear(emb_channels, out_channels) 85 | 86 | self.down = Downsample() if down else nn.Identity() 87 | self.shortcut = GetConv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() 88 | 89 | self.norm1 = GroupNorm32(in_channels) 90 | self.conv1 = 
GetConv2d(in_channels, out_channels, kernel_size=3, padding=1) 91 | self.norm2 = GroupNorm32(out_channels) 92 | self.conv2 = nn.Sequential( 93 | nn.Dropout(dropout), 94 | GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 95 | ) 96 | self.act = nn.SiLU() 97 | 98 | def forward(self, x, emb): 99 | residual = self.down(self.act(self.norm1(x))) 100 | residual = self.norm2(self.conv1(residual)) 101 | 102 | emb = self.linear(emb) 103 | residual = residual * (1 + emb)[:, :, None, None] 104 | 105 | residual = self.conv2(self.act(residual)) 106 | 107 | main = self.shortcut(self.down(x)) 108 | return main + residual 109 | 110 | 111 | class DecoderResBlock(nn.Module): 112 | def __init__(self, in_channels, out_channels, emb_channels, dropout=0.1, up=False): 113 | super().__init__() 114 | self.linear = GetLinear(emb_channels, out_channels) 115 | 116 | self.up = Upsample() if up else nn.Identity() 117 | self.shortcut = GetConv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() 118 | 119 | self.norm1 = GroupNorm32(in_channels) 120 | self.conv1 = GetConv2d(in_channels, out_channels, kernel_size=3, padding=1) 121 | self.norm2 = GroupNorm32(out_channels) 122 | self.conv2 = nn.Sequential( 123 | nn.Dropout(dropout), 124 | GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 125 | ) 126 | self.act = nn.SiLU() 127 | 128 | def forward(self, x, emb): 129 | residual = self.up(self.act(self.norm1(x))) 130 | residual = self.norm2(self.conv1(residual)) 131 | 132 | emb = self.linear(emb) 133 | residual = residual * (1 + emb)[:, :, None, None] 134 | 135 | residual = self.conv2(self.act(residual)) 136 | 137 | main = self.shortcut(self.up(x)) 138 | return main + residual 139 | 140 | # ===== Attention block ===== 141 | 142 | class AttentionBlock(nn.Module): 143 | def __init__(self, n_channels, d_k): 144 | super().__init__() 145 | # Default `d_k` 146 | if d_k is None: 147 | d_k = n_channels 148 | n_heads = n_channels // d_k 149 | assert n_heads * d_k == n_channels 150 | 151 | self.projection = GetConv2d(n_channels, n_channels * 3, kernel_size=1) 152 | self.output = GetConv2d(n_channels, n_channels, kernel_size=1) 153 | 154 | self.scale = 1 / math.sqrt(math.sqrt(d_k)) 155 | self.n_heads = n_heads 156 | self.d_k = d_k 157 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 158 | print(f"{self.n_heads} heads, {self.d_k} channels per head") 159 | 160 | def forward(self, x): 161 | batch_size, n_channels, height, width = x.shape 162 | # (b, c, h, w) -> (b, 3c, h, w) 163 | h = self.projection(x) 164 | 165 | # (b, 3c, h, w) -> (b, 3c, l) -> (b, l, 3c) 166 | h = h.flatten(start_dim=2, end_dim=-1).permute(0, 2, 1).contiguous() 167 | 168 | # (b, l, 3c) -> (b, l, n_heads, d_k * 3) -> 3 * (b, l, n_heads, d_k) 169 | qkv = h.view(batch_size, -1, self.n_heads, 3 * self.d_k) 170 | q, k, v = torch.chunk(qkv, 3, dim=-1) 171 | 172 | q = PixNorm(q, dim=3) 173 | k = PixNorm(k, dim=3) 174 | v = PixNorm(v, dim=3) 175 | attn = torch.einsum('bihd,bjhd->bijh', q * self.scale, k * self.scale) # More stable with f16 than dividing afterwards 176 | attn = attn.softmax(dim=2) 177 | res = torch.einsum('bijh,bjhd->bihd', attn, v) 178 | 179 | # (b, l, n_heads, d_k) -> (b, l, n_heads * d_k) -> (b, n_heads * d_k, l) -> (b, n_heads * d_k, h, w) -> (b, c, h, w) 180 | res = res.reshape(batch_size, -1, n_channels).permute(0, 2, 1).contiguous() 181 | res = res.view(batch_size, n_channels, height, width) 182 | res = self.output(res) 183 | return x + res 184 | 
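'''
Illustrative shape-check sketch for the Config C blocks, in the same spirit as the
usage notes at the end of unetC.py. It assumes the default block arguments above;
compare against blockE.py / blockG.py to see what the later configs change.

import torch
from model.blockC import Embedding, EncoderResBlock, AttentionBlock

emb = Embedding(512)(torch.zeros(2))             # (2, 512) noise-level embedding
blk = EncoderResBlock(128, 256, 512, down=True)  # 128 -> 256 channels, downsamples 2x
x = blk(torch.zeros(2, 128, 32, 32), emb)        # -> (2, 256, 16, 16)
attn = AttentionBlock(256, d_k=None)             # one head of width 256 by default
y = attn(x)                                      # residual attention, same shape as x
'''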
-------------------------------------------------------------------------------- /model/blockE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | def sqrt(x): 8 | return np.sqrt(x, dtype=np.float32) 9 | 10 | # ===== Forced weight normalization ===== 11 | 12 | def normalize(x, eps=1e-4): 13 | dim = list(range(1, x.ndim)) 14 | n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) 15 | alpha = sqrt(n.numel() / x.numel()) 16 | return x / torch.add(eps, n, alpha=alpha) 17 | 18 | 19 | class Conv2d(nn.Module): 20 | def __init__(self, C_in, C_out, kernel_size, stride, padding): 21 | super().__init__() 22 | w = torch.randn(C_out, C_in, kernel_size, kernel_size) 23 | self.weight = nn.Parameter(w) 24 | self.stride = stride 25 | self.padding = padding 26 | 27 | def forward(self, x): 28 | if self.training: 29 | with torch.no_grad(): 30 | self.weight.copy_(normalize(self.weight)) 31 | fan_in = self.weight[0].numel() 32 | w = normalize(self.weight) / sqrt(fan_in) 33 | x = nn.functional.conv2d(x, w, bias=None, stride=self.stride, padding=self.padding) 34 | return x 35 | 36 | 37 | class Linear(nn.Module): 38 | def __init__(self, C_in, C_out): 39 | super().__init__() 40 | w = torch.randn(C_out, C_in) 41 | self.weight = nn.Parameter(w) 42 | 43 | def forward(self, x): 44 | if self.training: 45 | with torch.no_grad(): 46 | self.weight.copy_(normalize(self.weight)) 47 | fan_in = self.weight[0].numel() 48 | w = normalize(self.weight) / sqrt(fan_in) 49 | x = nn.functional.linear(x, w, bias=None) 50 | return x 51 | 52 | 53 | def GetConv2d(C_in, C_out, kernel_size, stride=1, padding=0): 54 | return Conv2d(C_in, C_out, kernel_size, stride, padding) 55 | 56 | 57 | def GetLinear(C_in, C_out): 58 | return Linear(C_in, C_out) 59 | 60 | # ===== Embedding ===== 61 | 62 | class Fourier(nn.Module): 63 | def __init__(self, embedding_size=256): 64 | super().__init__() 65 | self.frequencies = nn.Parameter(torch.randn(embedding_size), requires_grad=False) 66 | self.phases = nn.Parameter(torch.rand(embedding_size), requires_grad=False) 67 | 68 | def forward(self, a): 69 | b = (2 * np.pi) * (a[:, None] * self.frequencies[None, :] + self.phases[None, :]) 70 | b = torch.cos(b) 71 | return b 72 | 73 | 74 | class Embedding(nn.Module): 75 | def __init__(self, n_channels): 76 | super().__init__() 77 | self.fourier = Fourier(embedding_size=n_channels // 4) 78 | self.linear1 = GetLinear(n_channels // 4, n_channels) 79 | self.act = nn.SiLU() 80 | self.linear2 = GetLinear(n_channels, n_channels) 81 | 82 | def forward(self, c_noise): 83 | emb = self.fourier(c_noise) 84 | emb = self.act(self.linear1(emb)) 85 | emb = self.act(self.linear2(emb)) 86 | return emb 87 | 88 | 89 | class Reweighting(nn.Module): 90 | def __init__(self, n_channels=256): 91 | super().__init__() 92 | self.fourier = Fourier(embedding_size=n_channels) 93 | self.linear = GetLinear(n_channels, 1) 94 | 95 | def forward(self, c_noise): 96 | emb = self.fourier(c_noise) 97 | emb = self.linear(emb) 98 | return emb 99 | 100 | # ===== Residual blocks ===== 101 | 102 | def PixNorm(x, dim=1, eps=1e-8): 103 | return x / torch.sqrt(torch.mean(x ** 2, dim=dim, keepdim=True) + eps) 104 | 105 | 106 | class Downsample(nn.Module): 107 | def __init__(self): 108 | super().__init__() 109 | self.pool = nn.AvgPool2d(2) 110 | 111 | def forward(self, x): 112 | return self.pool(x) 113 | 114 | 115 | class Upsample(nn.Module): 116 | def __init__(self): 117 | 
super().__init__() 118 | 119 | def forward(self, x): 120 | return nn.functional.interpolate(x, scale_factor=2, mode="nearest") 121 | 122 | 123 | def GroupNorm32(channels): 124 | return nn.GroupNorm(32, channels, affine=False) 125 | 126 | 127 | class EncoderResBlock(nn.Module): 128 | def __init__(self, in_channels, out_channels, emb_channels, dropout=0.1, down=False): 129 | super().__init__() 130 | self.linear = GetLinear(emb_channels, out_channels) 131 | 132 | self.down = Downsample() if down else nn.Identity() 133 | self.shortcut = GetConv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() 134 | 135 | self.norm1 = GroupNorm32(in_channels) 136 | self.conv1 = GetConv2d(in_channels, out_channels, kernel_size=3, padding=1) 137 | self.norm2 = GroupNorm32(out_channels) 138 | self.conv2 = nn.Sequential( 139 | nn.Dropout(dropout), 140 | GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 141 | ) 142 | self.act = nn.SiLU() 143 | 144 | def forward(self, x, emb): 145 | residual = self.down(self.act(self.norm1(x))) 146 | residual = self.norm2(self.conv1(residual)) 147 | 148 | emb = self.linear(emb) 149 | residual = residual * (1 + emb)[:, :, None, None] 150 | 151 | residual = self.conv2(self.act(residual)) 152 | 153 | main = self.shortcut(self.down(x)) 154 | return main + residual 155 | 156 | 157 | class DecoderResBlock(nn.Module): 158 | def __init__(self, in_channels, out_channels, emb_channels, dropout=0.1, up=False): 159 | super().__init__() 160 | self.linear = GetLinear(emb_channels, out_channels) 161 | 162 | self.up = Upsample() if up else nn.Identity() 163 | self.shortcut = GetConv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() 164 | 165 | self.norm1 = GroupNorm32(in_channels) 166 | self.conv1 = GetConv2d(in_channels, out_channels, kernel_size=3, padding=1) 167 | self.norm2 = GroupNorm32(out_channels) 168 | self.conv2 = nn.Sequential( 169 | nn.Dropout(dropout), 170 | GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 171 | ) 172 | self.act = nn.SiLU() 173 | 174 | def forward(self, x, emb): 175 | residual = self.up(self.act(self.norm1(x))) 176 | residual = self.norm2(self.conv1(residual)) 177 | 178 | emb = self.linear(emb) 179 | residual = residual * (1 + emb)[:, :, None, None] 180 | 181 | residual = self.conv2(self.act(residual)) 182 | 183 | main = self.shortcut(self.up(x)) 184 | return main + residual 185 | 186 | # ===== Attention block ===== 187 | 188 | class AttentionBlock(nn.Module): 189 | def __init__(self, n_channels, d_k): 190 | super().__init__() 191 | # Default `d_k` 192 | if d_k is None: 193 | d_k = n_channels 194 | n_heads = n_channels // d_k 195 | assert n_heads * d_k == n_channels 196 | 197 | self.projection = GetConv2d(n_channels, n_channels * 3, kernel_size=1) 198 | self.output = GetConv2d(n_channels, n_channels, kernel_size=1) 199 | 200 | self.scale = 1 / math.sqrt(math.sqrt(d_k)) 201 | self.n_heads = n_heads 202 | self.d_k = d_k 203 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 204 | print(f"{self.n_heads} heads, {self.d_k} channels per head") 205 | 206 | def forward(self, x): 207 | batch_size, n_channels, height, width = x.shape 208 | # (b, c, h, w) -> (b, 3c, h, w) 209 | h = self.projection(x) 210 | 211 | # (b, 3c, h, w) -> (b, 3c, l) -> (b, l, 3c) 212 | h = h.flatten(start_dim=2, end_dim=-1).permute(0, 2, 1).contiguous() 213 | 214 | # (b, l, 3c) -> (b, l, n_heads, d_k * 3) -> 3 * (b, l, n_heads, d_k) 215 | qkv = h.view(batch_size, 
-1, self.n_heads, 3 * self.d_k) 216 | q, k, v = torch.chunk(qkv, 3, dim=-1) 217 | 218 | q = PixNorm(q, dim=3) 219 | k = PixNorm(k, dim=3) 220 | v = PixNorm(v, dim=3) 221 | attn = torch.einsum('bihd,bjhd->bijh', q * self.scale, k * self.scale) # More stable with f16 than dividing afterwards 222 | attn = attn.softmax(dim=2) 223 | res = torch.einsum('bijh,bjhd->bihd', attn, v) 224 | 225 | # (b, l, n_heads, d_k) -> (b, l, n_heads * d_k) -> (b, n_heads * d_k, l) -> (b, n_heads * d_k, h, w) -> (b, c, h, w) 226 | res = res.reshape(batch_size, -1, n_channels).permute(0, 2, 1).contiguous() 227 | res = res.view(batch_size, n_channels, height, width) 228 | res = self.output(res) 229 | return x + res 230 | -------------------------------------------------------------------------------- /model/blockG.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | def sqrt(x): 8 | return np.sqrt(x, dtype=np.float32) 9 | 10 | # ===== Forced weight normalization ===== 11 | 12 | def normalize(x, eps=1e-4): 13 | dim = list(range(1, x.ndim)) 14 | n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) 15 | alpha = sqrt(n.numel() / x.numel()) 16 | return x / torch.add(eps, n, alpha=alpha) 17 | 18 | 19 | class Conv2d(nn.Module): 20 | def __init__(self, C_in, C_out, kernel_size, stride, padding): 21 | super().__init__() 22 | w = torch.randn(C_out, C_in, kernel_size, kernel_size) 23 | self.weight = nn.Parameter(w) 24 | self.stride = stride 25 | self.padding = padding 26 | 27 | def forward(self, x): 28 | if self.training: 29 | with torch.no_grad(): 30 | self.weight.copy_(normalize(self.weight)) 31 | fan_in = self.weight[0].numel() 32 | w = normalize(self.weight) / sqrt(fan_in) 33 | x = nn.functional.conv2d(x, w, bias=None, stride=self.stride, padding=self.padding) 34 | return x 35 | 36 | 37 | class Linear(nn.Module): 38 | def __init__(self, C_in, C_out): 39 | super().__init__() 40 | w = torch.randn(C_out, C_in) 41 | self.weight = nn.Parameter(w) 42 | 43 | def forward(self, x): 44 | if self.training: 45 | with torch.no_grad(): 46 | self.weight.copy_(normalize(self.weight)) 47 | fan_in = self.weight[0].numel() 48 | w = normalize(self.weight) / sqrt(fan_in) 49 | x = nn.functional.linear(x, w, bias=None) 50 | return x 51 | 52 | 53 | def GetConv2d(C_in, C_out, kernel_size, stride=1, padding=0): 54 | return Conv2d(C_in, C_out, kernel_size, stride, padding) 55 | 56 | 57 | def GetLinear(C_in, C_out): 58 | return Linear(C_in, C_out) 59 | 60 | # ===== Magnitude-preserving fixed-function layers ===== 61 | 62 | class MP_Fourier(nn.Module): 63 | def __init__(self, embedding_size=256): 64 | super().__init__() 65 | self.frequencies = nn.Parameter(torch.randn(embedding_size), requires_grad=False) 66 | self.phases = nn.Parameter(torch.rand(embedding_size), requires_grad=False) 67 | 68 | def forward(self, a): 69 | b = (2 * np.pi) * (a[:, None] * self.frequencies[None, :] + self.phases[None, :]) 70 | b = torch.cos(b) * sqrt(2.) 
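# (cos of a uniformly random phase has variance 1/2, so the sqrt(2) factor above
# restores unit variance; this is the magnitude-preserving change relative to the
# plain Fourier features in blockC.py / blockE.py.)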
71 | return b 72 | 73 | 74 | class MP_SiLU(nn.Module): 75 | def __init__(self): 76 | super().__init__() 77 | self.act = nn.SiLU() 78 | 79 | def forward(self, x): 80 | return self.act(x) / 0.596 81 | 82 | 83 | class MP_Sum(nn.Module): 84 | def __init__(self, b_contribution=0.3): 85 | super().__init__() 86 | self.b_contribution = b_contribution 87 | self.a_contribution = 1.0 - b_contribution 88 | self.div = sqrt(self.a_contribution ** 2 + self.b_contribution ** 2) 89 | 90 | def forward(self, a, b): 91 | return (self.a_contribution * a + self.b_contribution * b) / self.div 92 | 93 | 94 | class MP_Cat(nn.Module): 95 | def __init__(self, b_contribution=0.5): 96 | super().__init__() 97 | self.b_contribution = b_contribution 98 | self.a_contribution = 1.0 - b_contribution 99 | self.div = sqrt(self.a_contribution ** 2 + self.b_contribution ** 2) 100 | 101 | def forward(self, a, b): 102 | Na = a.shape[1] 103 | Nb = b.shape[1] 104 | c = torch.cat([self.a_contribution * a / sqrt(Na), self.b_contribution * b / sqrt(Nb)], dim=1) 105 | return c * sqrt(Na + Nb) / self.div 106 | 107 | 108 | class Gain(nn.Module): 109 | def __init__(self, init='zero'): 110 | super().__init__() 111 | if init == 'one': 112 | self.g = nn.Parameter(torch.tensor(1.0)) 113 | else: 114 | assert init == 'zero' 115 | self.g = nn.Parameter(torch.tensor(0.0)) 116 | 117 | def forward(self, x): 118 | return x * self.g 119 | 120 | # ===== Embedding ===== 121 | 122 | class Embedding(nn.Module): 123 | def __init__(self, n_channels): 124 | super().__init__() 125 | self.fourier = MP_Fourier(embedding_size=n_channels // 4) 126 | self.linear1 = GetLinear(n_channels // 4, n_channels) 127 | self.act = MP_SiLU() 128 | 129 | def forward(self, c_noise): 130 | emb = self.fourier(c_noise) 131 | emb = self.act(self.linear1(emb)) 132 | return emb 133 | 134 | 135 | class Reweighting(nn.Module): 136 | def __init__(self, n_channels=256): 137 | super().__init__() 138 | self.fourier = MP_Fourier(embedding_size=n_channels) 139 | self.linear = GetLinear(n_channels, 1) 140 | 141 | def forward(self, c_noise): 142 | emb = self.fourier(c_noise) 143 | emb = self.linear(emb) 144 | return emb 145 | 146 | # ===== Residual blocks ===== 147 | 148 | def PixNorm(x, dim=1, eps=1e-8): 149 | return x / torch.sqrt(torch.mean(x ** 2, dim=dim, keepdim=True) + eps) 150 | 151 | 152 | class Downsample(nn.Module): 153 | def __init__(self): 154 | super().__init__() 155 | self.pool = nn.AvgPool2d(2) 156 | 157 | def forward(self, x): 158 | return self.pool(x) 159 | 160 | 161 | class Upsample(nn.Module): 162 | def __init__(self): 163 | super().__init__() 164 | 165 | def forward(self, x): 166 | return nn.functional.interpolate(x, scale_factor=2, mode="nearest") 167 | 168 | 169 | class EncoderResBlock(nn.Module): 170 | def __init__(self, in_channels, out_channels, emb_channels, dropout=0.1, down=False): 171 | super().__init__() 172 | self.linear = GetLinear(emb_channels, out_channels) 173 | self.gain = Gain(init='one') 174 | 175 | self.down = Downsample() if down else nn.Identity() 176 | self.shortcut = GetConv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() 177 | 178 | self.conv1 = GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 179 | self.conv2 = nn.Sequential( 180 | nn.Dropout(dropout), 181 | GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 182 | ) 183 | self.act = MP_SiLU() 184 | self.outadd = MP_Sum() 185 | 186 | def forward(self, x, emb): 187 | main = PixNorm(self.shortcut(self.down(x))) 188 | residual = 
self.conv1(self.act(main)) 189 | 190 | emb = self.gain(self.linear(emb)) 191 | residual = residual * (1 + emb)[:, :, None, None] 192 | 193 | residual = self.conv2(self.act(residual)) 194 | 195 | return self.outadd(main, residual) 196 | 197 | 198 | class DecoderResBlock(nn.Module): 199 | def __init__(self, in_channels, out_channels, emb_channels, dropout=0.1, up=False): 200 | super().__init__() 201 | self.linear = GetLinear(emb_channels, out_channels) 202 | self.gain = Gain(init='one') 203 | 204 | self.up = Upsample() if up else nn.Identity() 205 | self.shortcut = GetConv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() 206 | 207 | self.conv1 = GetConv2d(in_channels, out_channels, kernel_size=3, padding=1) 208 | self.conv2 = nn.Sequential( 209 | nn.Dropout(dropout), 210 | GetConv2d(out_channels, out_channels, kernel_size=3, padding=1) 211 | ) 212 | self.act = MP_SiLU() 213 | self.outadd = MP_Sum() 214 | 215 | def forward(self, x, emb): 216 | main = self.up(x) 217 | residual = self.conv1(self.act(main)) 218 | 219 | emb = self.gain(self.linear(emb)) 220 | residual = residual * (1 + emb)[:, :, None, None] 221 | 222 | residual = self.conv2(self.act(residual)) 223 | 224 | main = self.shortcut(main) 225 | return self.outadd(main, residual) 226 | 227 | # ===== Attention block ===== 228 | 229 | class AttentionBlock(nn.Module): 230 | def __init__(self, n_channels, d_k): 231 | super().__init__() 232 | # Default `d_k` 233 | if d_k is None: 234 | d_k = n_channels 235 | n_heads = n_channels // d_k 236 | assert n_heads * d_k == n_channels 237 | 238 | self.projection = GetConv2d(n_channels, n_channels * 3, kernel_size=1) 239 | self.output = GetConv2d(n_channels, n_channels, kernel_size=1) 240 | 241 | self.outadd = MP_Sum() 242 | self.scale = 1 / math.sqrt(math.sqrt(d_k)) 243 | self.n_heads = n_heads 244 | self.d_k = d_k 245 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 246 | print(f"{self.n_heads} heads, {self.d_k} channels per head") 247 | 248 | def forward(self, x): 249 | batch_size, n_channels, height, width = x.shape 250 | # (b, c, h, w) -> (b, 3c, h, w) 251 | h = self.projection(x) 252 | 253 | # (b, 3c, h, w) -> (b, 3c, l) -> (b, l, 3c) 254 | h = h.flatten(start_dim=2, end_dim=-1).permute(0, 2, 1).contiguous() 255 | 256 | # (b, l, 3c) -> (b, l, n_heads, d_k * 3) -> 3 * (b, l, n_heads, d_k) 257 | qkv = h.view(batch_size, -1, self.n_heads, 3 * self.d_k) 258 | q, k, v = torch.chunk(qkv, 3, dim=-1) 259 | 260 | q = PixNorm(q, dim=3) 261 | k = PixNorm(k, dim=3) 262 | v = PixNorm(v, dim=3) 263 | attn = torch.einsum('bihd,bjhd->bijh', q * self.scale, k * self.scale) # More stable with f16 than dividing afterwards 264 | attn = attn.softmax(dim=2) 265 | res = torch.einsum('bijh,bjhd->bihd', attn, v) 266 | 267 | # (b, l, n_heads, d_k) -> (b, l, n_heads * d_k) -> (b, n_heads * d_k, l) -> (b, n_heads * d_k, h, w) -> (b, c, h, w) 268 | res = res.reshape(batch_size, -1, n_channels).permute(0, 2, 1).contiguous() 269 | res = res.view(batch_size, n_channels, height, width) 270 | res = self.output(res) 271 | return self.outadd(x, res) 272 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | from .EDM import EDM 2 | from .unetC import UNetC 3 | from .unetE import UNetE 4 | from .unetG import UNetG 5 | 6 | 7 | CLASSES = { 8 | cls.__name__: cls 9 | for cls in [EDM, UNetC, UNetE, UNetG] 10 | } 11 | 12 | 13 | def 
get_models_class(model_type, net_type): 14 | return CLASSES[model_type], CLASSES[net_type] 15 | -------------------------------------------------------------------------------- /model/unetC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .blockC import EncoderResBlock, DecoderResBlock, AttentionBlock, Embedding, Reweighting, GetConv2d 4 | 5 | 6 | class EncoderBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, emb_channels, dropout, down, has_attn, attn_channels_per_head): 8 | super().__init__() 9 | self.res = EncoderResBlock(in_channels, out_channels, emb_channels, dropout, down) 10 | if has_attn: 11 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 12 | else: 13 | self.attn = nn.Identity() 14 | 15 | def forward(self, x, emb): 16 | x = self.res(x, emb) 17 | x = self.attn(x) 18 | return x 19 | 20 | 21 | class DecoderBlock(nn.Module): 22 | def __init__(self, in_channels, out_channels, emb_channels, dropout, up, has_attn, attn_channels_per_head): 23 | super().__init__() 24 | self.res = DecoderResBlock(in_channels, out_channels, emb_channels, dropout, up) 25 | if has_attn: 26 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 27 | else: 28 | self.attn = nn.Identity() 29 | self.up = up 30 | 31 | def forward(self, x, emb): 32 | x = self.res(x, emb) 33 | x = self.attn(x) 34 | return x 35 | 36 | 37 | class UNetC(nn.Module): 38 | def __init__(self, image_shape = [3, 32, 32], n_channels = 128, 39 | ch_mults = (1, 2, 2, 2), 40 | is_attn = (False, True, False, False), 41 | attn_channels_per_head = None, 42 | dropout = 0.1, 43 | n_blocks = 3): 44 | """ 45 | * `image_shape` is the (channel, height, width) size of images. 46 | * `n_channels` is number of channels in the initial feature map that we transform the image into 47 | * `ch_mults` is the list of channel numbers at each resolution. 
The number of channels is `n_channels * ch_mults[i]` 48 | * `is_attn` is a list of booleans that indicate whether to use attention at each resolution 49 | * `dropout` is the dropout rate 50 | * `n_blocks` is the number of `UpDownBlocks` at each resolution 51 | """ 52 | super().__init__() 53 | n_resolutions = len(ch_mults) 54 | 55 | self.image_proj = GetConv2d(image_shape[0] + 1, n_channels, kernel_size=3, padding=1) 56 | 57 | # Embedding layers 58 | emb_channels = n_channels * 4 59 | self.embedding = Embedding(emb_channels) 60 | 61 | # Down stages 62 | down = [] 63 | in_channels = n_channels 64 | h_channels = [n_channels] 65 | for i in range(n_resolutions): 66 | # Number of output channels at this resolution 67 | out_channels = n_channels * ch_mults[i] 68 | # `n_blocks` at the same resolution 69 | down.append(EncoderBlock(in_channels, out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 70 | h_channels.append(out_channels) 71 | for _ in range(n_blocks - 2): 72 | down.append(EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 73 | h_channels.append(out_channels) 74 | # Down sample at all resolutions except the last 75 | if i < n_resolutions - 1: 76 | down.append(EncoderBlock(out_channels, out_channels, emb_channels, dropout, True, False, 0)) 77 | h_channels.append(out_channels) 78 | in_channels = out_channels 79 | self.down = nn.ModuleList(down) 80 | 81 | # Middle block 82 | self.middle1 = EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, True, attn_channels_per_head) 83 | self.middle2 = EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, False, 0) 84 | 85 | # Up stages 86 | up = [] 87 | in_channels = out_channels 88 | for i in reversed(range(n_resolutions)): 89 | # Number of output channels at this resolution 90 | out_channels = n_channels * ch_mults[i] 91 | # `n_blocks + 1` at the same resolution 92 | for _ in range(n_blocks): 93 | up.append(DecoderBlock(in_channels + h_channels.pop(), out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 94 | in_channels = out_channels 95 | # Up sample at all resolutions except last 96 | if i > 0: 97 | up.append(DecoderBlock(out_channels, out_channels, emb_channels, dropout, True, False, 0)) 98 | assert not h_channels 99 | self.up = nn.ModuleList(up) 100 | 101 | # Final convolution layer 102 | self.final = nn.Sequential( 103 | nn.GroupNorm(8, out_channels, affine=False), 104 | nn.SiLU(), 105 | GetConv2d(out_channels, image_shape[0], kernel_size=3, padding=1), 106 | ) 107 | 108 | def forward(self, x, t, ret_activation=False): 109 | if not ret_activation: 110 | return self.forward_core(x, t) 111 | 112 | activation = {} 113 | def namedHook(name): 114 | def hook(module, input, output): 115 | activation[name] = output 116 | return hook 117 | hooks = {} 118 | no = 0 119 | for blk in self.up: 120 | if isinstance(blk, DecoderBlock): 121 | no += 1 122 | name = f'out_{no}' 123 | hooks[name] = blk.register_forward_hook(namedHook(name)) 124 | 125 | result = self.forward_core(x, t) 126 | for name in hooks: 127 | hooks[name].remove() 128 | return result, activation 129 | 130 | def forward_core(self, x, t): 131 | """ 132 | * `x` has shape `[batch_size, in_channels, height, width]` 133 | * `t` has shape `[batch_size]` 134 | """ 135 | ones_tensor = torch.ones(x.shape[0], 1, x.shape[2], x.shape[3], dtype=x.dtype, device=x.device) 136 | x = torch.cat([x, ones_tensor], dim=1) 137 | x = self.image_proj(x) 138 | emb = 
self.embedding(t) 139 | 140 | # `h` will store outputs at each resolution for skip connection 141 | h = [x] 142 | 143 | for m in self.down: 144 | x = m(x, emb) 145 | h.append(x) 146 | 147 | x = self.middle1(x, emb) 148 | x = self.middle2(x, emb) 149 | 150 | for m in self.up: 151 | if m.up: 152 | x = m(x, emb) 153 | else: 154 | # Get the skip connection from first half of U-Net and concatenate 155 | s = h.pop() 156 | x = torch.cat((x, s), dim=1) 157 | x = m(x, emb) 158 | 159 | return self.final(x) 160 | 161 | def get_reweighting(self): 162 | return Reweighting() 163 | 164 | def forward_reweighting(self, MLP, sigma): 165 | return MLP(sigma.flatten().log() / 4).unsqueeze(-1).unsqueeze(-1) 166 | 167 | ''' 168 | from model.unetC import UNetC 169 | import torch 170 | net = UNetC() 171 | x = torch.zeros(1, 3, 32, 32) 172 | t = torch.zeros(1,) 173 | 174 | net(x, t).shape 175 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 176 | 177 | >>> 39.509888 M parameters for CIFAR-10 model 178 | 179 | 180 | net = UNetC(image_shape=[3,64,64], n_channels=192, ch_mults=[1,2,3,4], is_attn=[False,False,True,True], n_blocks=4) 181 | x = torch.zeros(1, 3, 64, 64) 182 | t = torch.zeros(1,) 183 | 184 | net(x, t).shape 185 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 186 | 187 | >>> 277.045056 M parameters for ImageNet-64 model 188 | (becomes exactly 277.8M in Figure 17, after adding 768*1000 class embedding) 189 | ''' 190 | -------------------------------------------------------------------------------- /model/unetE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .blockE import EncoderResBlock, DecoderResBlock, AttentionBlock, Embedding, Reweighting, GetConv2d 4 | 5 | 6 | class EncoderBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, emb_channels, dropout, down, has_attn, attn_channels_per_head): 8 | super().__init__() 9 | self.res = EncoderResBlock(in_channels, out_channels, emb_channels, dropout, down) 10 | if has_attn: 11 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 12 | else: 13 | self.attn = nn.Identity() 14 | 15 | def forward(self, x, emb): 16 | x = self.res(x, emb) 17 | x = self.attn(x) 18 | x = torch.clip(x, -256, 256) 19 | return x 20 | 21 | 22 | class DecoderBlock(nn.Module): 23 | def __init__(self, in_channels, out_channels, emb_channels, dropout, up, has_attn, attn_channels_per_head): 24 | super().__init__() 25 | self.res = DecoderResBlock(in_channels, out_channels, emb_channels, dropout, up) 26 | if has_attn: 27 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 28 | else: 29 | self.attn = nn.Identity() 30 | self.up = up 31 | 32 | def forward(self, x, emb): 33 | x = self.res(x, emb) 34 | x = self.attn(x) 35 | x = torch.clip(x, -256, 256) 36 | return x 37 | 38 | 39 | class UNetE(nn.Module): 40 | def __init__(self, image_shape = [3, 32, 32], n_channels = 128, 41 | ch_mults = (1, 2, 2, 2), 42 | is_attn = (False, True, False, False), 43 | attn_channels_per_head = None, 44 | dropout = 0.1, 45 | n_blocks = 3): 46 | """ 47 | * `image_shape` is the (channel, height, width) size of images. 48 | * `n_channels` is number of channels in the initial feature map that we transform the image into 49 | * `ch_mults` is the list of channel numbers at each resolution. 
The number of channels is `n_channels * ch_mults[i]` 50 | * `is_attn` is a list of booleans that indicate whether to use attention at each resolution 51 | * `dropout` is the dropout rate 52 | * `n_blocks` is the number of `UpDownBlocks` at each resolution 53 | """ 54 | super().__init__() 55 | n_resolutions = len(ch_mults) 56 | 57 | self.image_proj = GetConv2d(image_shape[0] + 1, n_channels, kernel_size=3, padding=1) 58 | 59 | # Embedding layers 60 | emb_channels = n_channels * 4 61 | self.embedding = Embedding(emb_channels) 62 | 63 | # Down stages 64 | down = [] 65 | in_channels = n_channels 66 | h_channels = [n_channels] 67 | for i in range(n_resolutions): 68 | # Number of output channels at this resolution 69 | out_channels = n_channels * ch_mults[i] 70 | # `n_blocks` at the same resolution 71 | down.append(EncoderBlock(in_channels, out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 72 | h_channels.append(out_channels) 73 | for _ in range(n_blocks - 2): 74 | down.append(EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 75 | h_channels.append(out_channels) 76 | # Down sample at all resolutions except the last 77 | if i < n_resolutions - 1: 78 | down.append(EncoderBlock(out_channels, out_channels, emb_channels, dropout, True, False, 0)) 79 | h_channels.append(out_channels) 80 | in_channels = out_channels 81 | self.down = nn.ModuleList(down) 82 | 83 | # Middle block 84 | self.middle1 = EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, True, attn_channels_per_head) 85 | self.middle2 = EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, False, 0) 86 | 87 | # Up stages 88 | up = [] 89 | in_channels = out_channels 90 | for i in reversed(range(n_resolutions)): 91 | # Number of output channels at this resolution 92 | out_channels = n_channels * ch_mults[i] 93 | # `n_blocks + 1` at the same resolution 94 | for _ in range(n_blocks): 95 | up.append(DecoderBlock(in_channels + h_channels.pop(), out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 96 | in_channels = out_channels 97 | # Up sample at all resolutions except last 98 | if i > 0: 99 | up.append(DecoderBlock(out_channels, out_channels, emb_channels, dropout, True, False, 0)) 100 | assert not h_channels 101 | self.up = nn.ModuleList(up) 102 | 103 | # Final convolution layer 104 | self.final = nn.Sequential( 105 | nn.GroupNorm(8, out_channels, affine=False), 106 | nn.SiLU(), 107 | GetConv2d(out_channels, image_shape[0], kernel_size=3, padding=1), 108 | ) 109 | 110 | def forward(self, x, t, ret_activation=False): 111 | if not ret_activation: 112 | return self.forward_core(x, t) 113 | 114 | activation = {} 115 | def namedHook(name): 116 | def hook(module, input, output): 117 | activation[name] = output 118 | return hook 119 | hooks = {} 120 | no = 0 121 | for blk in self.up: 122 | if isinstance(blk, DecoderBlock): 123 | no += 1 124 | name = f'out_{no}' 125 | hooks[name] = blk.register_forward_hook(namedHook(name)) 126 | 127 | result = self.forward_core(x, t) 128 | for name in hooks: 129 | hooks[name].remove() 130 | return result, activation 131 | 132 | def forward_core(self, x, t): 133 | """ 134 | * `x` has shape `[batch_size, in_channels, height, width]` 135 | * `t` has shape `[batch_size]` 136 | """ 137 | ones_tensor = torch.ones(x.shape[0], 1, x.shape[2], x.shape[3], dtype=x.dtype, device=x.device) 138 | x = torch.cat([x, ones_tensor], dim=1) 139 | x = self.image_proj(x) 140 | emb = 
self.embedding(t) 141 | 142 | # `h` will store outputs at each resolution for skip connection 143 | h = [x] 144 | 145 | for m in self.down: 146 | x = m(x, emb) 147 | h.append(x) 148 | 149 | x = self.middle1(x, emb) 150 | x = self.middle2(x, emb) 151 | 152 | for m in self.up: 153 | if m.up: 154 | x = m(x, emb) 155 | else: 156 | # Get the skip connection from first half of U-Net and concatenate 157 | s = h.pop() 158 | x = torch.cat((x, s), dim=1) 159 | x = m(x, emb) 160 | 161 | return self.final(x) 162 | 163 | def get_reweighting(self): 164 | return Reweighting() 165 | 166 | def forward_reweighting(self, MLP, sigma): 167 | return MLP(sigma.flatten().log() / 4).unsqueeze(-1).unsqueeze(-1) 168 | 169 | ''' 170 | from model.unetE import UNetE 171 | import torch 172 | net = UNetE() 173 | x = torch.zeros(1, 3, 32, 32) 174 | t = torch.zeros(1,) 175 | 176 | net(x, t).shape 177 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 178 | 179 | >>> 39.509888 M parameters for CIFAR-10 model 180 | 181 | 182 | net = UNetE(image_shape=[3,64,64], n_channels=192, ch_mults=[1,2,3,4], is_attn=[False,False,True,True], n_blocks=4) 183 | x = torch.zeros(1, 3, 64, 64) 184 | t = torch.zeros(1,) 185 | 186 | net(x, t).shape 187 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 188 | 189 | >>> 277.045056 M parameters for ImageNet-64 model 190 | (becomes exactly 277.8M in Figure 19, after adding 768*1000 class embedding) 191 | ''' 192 | -------------------------------------------------------------------------------- /model/unetG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .blockG import EncoderResBlock, DecoderResBlock, AttentionBlock, Embedding, Reweighting, GetConv2d, Gain, MP_Cat 4 | 5 | 6 | class EncoderBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, emb_channels, dropout, down, has_attn, attn_channels_per_head): 8 | super().__init__() 9 | self.res = EncoderResBlock(in_channels, out_channels, emb_channels, dropout, down) 10 | if has_attn: 11 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 12 | else: 13 | self.attn = nn.Identity() 14 | 15 | def forward(self, x, emb): 16 | x = self.res(x, emb) 17 | x = self.attn(x) 18 | x = torch.clip(x, -256, 256) 19 | return x 20 | 21 | 22 | class DecoderBlock(nn.Module): 23 | def __init__(self, in_channels, out_channels, emb_channels, dropout, up, has_attn, attn_channels_per_head): 24 | super().__init__() 25 | self.res = DecoderResBlock(in_channels, out_channels, emb_channels, dropout, up) 26 | if has_attn: 27 | self.attn = AttentionBlock(out_channels, attn_channels_per_head) 28 | else: 29 | self.attn = nn.Identity() 30 | self.up = up 31 | 32 | def forward(self, x, emb): 33 | x = self.res(x, emb) 34 | x = self.attn(x) 35 | x = torch.clip(x, -256, 256) 36 | return x 37 | 38 | 39 | class UNetG(nn.Module): 40 | def __init__(self, image_shape = [3, 32, 32], n_channels = 128, 41 | ch_mults = (1, 2, 2, 2), 42 | is_attn = (False, True, False, False), 43 | attn_channels_per_head = None, 44 | dropout = 0.1, 45 | n_blocks = 3): 46 | """ 47 | * `image_shape` is the (channel, height, width) size of images. 48 | * `n_channels` is number of channels in the initial feature map that we transform the image into 49 | * `ch_mults` is the list of channel numbers at each resolution. 
The number of channels is `n_channels * ch_mults[i]` 50 | * `is_attn` is a list of booleans that indicate whether to use attention at each resolution 51 | * `dropout` is the dropout rate 52 | * `n_blocks` is the number of `UpDownBlocks` at each resolution 53 | """ 54 | super().__init__() 55 | n_resolutions = len(ch_mults) 56 | 57 | self.image_proj = GetConv2d(image_shape[0] + 1, n_channels, kernel_size=3, padding=1) 58 | 59 | # Embedding layers 60 | emb_channels = n_channels * 4 61 | self.embedding = Embedding(emb_channels) 62 | 63 | # Down stages 64 | down = [] 65 | in_channels = n_channels 66 | h_channels = [n_channels] 67 | for i in range(n_resolutions): 68 | # Number of output channels at this resolution 69 | out_channels = n_channels * ch_mults[i] 70 | # `n_blocks` at the same resolution 71 | down.append(EncoderBlock(in_channels, out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 72 | h_channels.append(out_channels) 73 | for _ in range(n_blocks - 2): 74 | down.append(EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 75 | h_channels.append(out_channels) 76 | # Down sample at all resolutions except the last 77 | if i < n_resolutions - 1: 78 | down.append(EncoderBlock(out_channels, out_channels, emb_channels, dropout, True, False, 0)) 79 | h_channels.append(out_channels) 80 | in_channels = out_channels 81 | self.down = nn.ModuleList(down) 82 | 83 | # Middle block 84 | self.middle1 = EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, True, attn_channels_per_head) 85 | self.middle2 = EncoderBlock(out_channels, out_channels, emb_channels, dropout, False, False, 0) 86 | 87 | # Up stages 88 | up = [] 89 | in_channels = out_channels 90 | for i in reversed(range(n_resolutions)): 91 | # Number of output channels at this resolution 92 | out_channels = n_channels * ch_mults[i] 93 | # `n_blocks + 1` at the same resolution 94 | for _ in range(n_blocks): 95 | up.append(DecoderBlock(in_channels + h_channels.pop(), out_channels, emb_channels, dropout, False, is_attn[i], attn_channels_per_head)) 96 | in_channels = out_channels 97 | # Up sample at all resolutions except last 98 | if i > 0: 99 | up.append(DecoderBlock(out_channels, out_channels, emb_channels, dropout, True, False, 0)) 100 | assert not h_channels 101 | self.up = nn.ModuleList(up) 102 | self.skipcat = MP_Cat() 103 | 104 | # Final convolution layer 105 | self.final = nn.Sequential( 106 | GetConv2d(out_channels, image_shape[0], kernel_size=3, padding=1), 107 | Gain() 108 | ) 109 | 110 | def forward(self, x, t, ret_activation=False): 111 | if not ret_activation: 112 | return self.forward_core(x, t) 113 | 114 | activation = {} 115 | def namedHook(name): 116 | def hook(module, input, output): 117 | activation[name] = output 118 | return hook 119 | hooks = {} 120 | no = 0 121 | for blk in self.up: 122 | if isinstance(blk, DecoderBlock): 123 | no += 1 124 | name = f'out_{no}' 125 | hooks[name] = blk.register_forward_hook(namedHook(name)) 126 | 127 | result = self.forward_core(x, t) 128 | for name in hooks: 129 | hooks[name].remove() 130 | return result, activation 131 | 132 | def forward_core(self, x, t): 133 | """ 134 | * `x` has shape `[batch_size, in_channels, height, width]` 135 | * `t` has shape `[batch_size]` 136 | """ 137 | ones_tensor = torch.ones(x.shape[0], 1, x.shape[2], x.shape[3], dtype=x.dtype, device=x.device) 138 | x = torch.cat([x, ones_tensor], dim=1) 139 | x = self.image_proj(x) 140 | emb = self.embedding(t) 141 | 142 | # 
`h` will store outputs at each resolution for skip connection 143 | h = [x] 144 | 145 | for m in self.down: 146 | x = m(x, emb) 147 | h.append(x) 148 | 149 | x = self.middle1(x, emb) 150 | x = self.middle2(x, emb) 151 | 152 | for m in self.up: 153 | if m.up: 154 | x = m(x, emb) 155 | else: 156 | # Get the skip connection from first half of U-Net and concatenate 157 | s = h.pop() 158 | x = self.skipcat(x, s) 159 | x = m(x, emb) 160 | 161 | return self.final(x) 162 | 163 | def get_reweighting(self): 164 | return Reweighting() 165 | 166 | def forward_reweighting(self, MLP, sigma): 167 | return MLP(sigma.flatten().log() / 4).unsqueeze(-1).unsqueeze(-1) 168 | 169 | ''' 170 | from model.unetG import UNetG 171 | import torch 172 | net = UNetG() 173 | x = torch.zeros(1, 3, 32, 32) 174 | t = torch.zeros(1,) 175 | 176 | net(x, t).shape 177 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 178 | 179 | >>> 39.542685 M parameters for CIFAR-10 model 180 | 181 | 182 | net = UNetG(image_shape=[3,64,64], n_channels=192, ch_mults=[1,2,3,4], is_attn=[False,False,True,True], n_blocks=4) 183 | x = torch.zeros(1, 3, 64, 64) 184 | t = torch.zeros(1,) 185 | 186 | net(x, t).shape 187 | sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 188 | 189 | >>> 279.441253 M parameters for ImageNet-64 model 190 | (becomes exactly 280.2M in Figure 21, after adding 768*1000 class embedding) 191 | ''' 192 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.distributed as dist 6 | import yaml 7 | from torchvision.utils import make_grid, save_image 8 | from ema_pytorch import EMA 9 | 10 | from model.models import get_models_class 11 | from utils import Config, init_seeds, gather_tensor, print0 12 | 13 | 14 | def get_default_steps(model_type, steps): 15 | if steps is not None: 16 | return steps 17 | else: 18 | return {'EDM': 18}[model_type] 19 | 20 | 21 | # ===== sampling ===== 22 | 23 | def sample(opt): 24 | yaml_path = opt.config 25 | local_rank = opt.local_rank 26 | use_amp = opt.use_amp 27 | steps = opt.steps 28 | eta = opt.eta 29 | batches = opt.batches 30 | ep = opt.epoch 31 | 32 | with open(yaml_path, 'r') as f: 33 | opt = yaml.full_load(f) 34 | print0(opt) 35 | opt = Config(opt) 36 | if ep == -1: 37 | ep = opt.n_epoch - 1 38 | 39 | device = "cuda:%d" % local_rank 40 | steps = get_default_steps(opt.model_type, steps) 41 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 42 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 43 | **opt.diffusion, 44 | device=device, 45 | ) 46 | diff.to(device) 47 | 48 | target = os.path.join(opt.save_dir, "ckpts", f"model_{ep}.pth") 49 | print0("loading model at", target) 50 | checkpoint = torch.load(target, map_location=device) 51 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 52 | ema.to(device) 53 | ema.load_state_dict(checkpoint['EMA']) 54 | model = ema.ema_model 55 | model.eval() 56 | 57 | if local_rank == 0: 58 | if opt.model_type == 'EDM': 59 | gen_dir = os.path.join(opt.save_dir, f"EMAgenerated_ep{ep}_edm_steps{steps}_eta{eta}") 60 | else: 61 | raise NotImplementedError 62 | os.makedirs(gen_dir) 63 | gen_dir_png = os.path.join(gen_dir, "pngs") 64 | os.makedirs(gen_dir_png) 65 | res = [] 66 | 67 | for batch in range(batches): 68 | with torch.no_grad(): 69 | assert 400 % dist.get_world_size() == 0 70 | 
samples_per_process = 400 // dist.get_world_size() 71 | args = dict(n_sample=samples_per_process, size=opt.network['image_shape'], notqdm=(local_rank != 0), use_amp=use_amp) 72 | if opt.model_type == 'EDM': 73 | x_gen = model.edm_sample(**args, steps=steps, eta=eta) 74 | else: 75 | raise NotImplementedError 76 | dist.barrier() 77 | x_gen = gather_tensor(x_gen).cpu() 78 | if local_rank == 0: 79 | res.append(x_gen) 80 | grid = make_grid(x_gen, nrow=20) 81 | png_path = os.path.join(gen_dir, f"grid_{batch}.png") 82 | save_image(grid, png_path) 83 | 84 | if local_rank == 0: 85 | res = torch.cat(res) 86 | for no, img in enumerate(res): 87 | png_path = os.path.join(gen_dir_png, f"{no}.png") 88 | save_image(img, png_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--config", type=str) 94 | parser.add_argument("--use_amp", action='store_true', default=False) 95 | parser.add_argument("--steps", type=int, default=None) 96 | parser.add_argument("--eta", type=float, default=0.0) 97 | parser.add_argument("--batches", type=int, default=125) 98 | parser.add_argument("--epoch", type=int, default=-1) 99 | opt = parser.parse_args() 100 | opt.local_rank = int(os.environ['LOCAL_RANK']) 101 | print0(opt) 102 | 103 | init_seeds(no=opt.local_rank) 104 | dist.init_process_group(backend='nccl') 105 | torch.cuda.set_device(opt.local_rank) 106 | sample(opt) 107 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import yaml 8 | from datasets import get_dataset 9 | from torchvision.utils import make_grid, save_image 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | from ema_pytorch import EMA 13 | 14 | from model.models import get_models_class 15 | from utils import Config, get_optimizer, init_seeds, reduce_tensor, DataLoaderDDP, print0, Meter 16 | 17 | # ===== training ===== 18 | 19 | def train(opt): 20 | yaml_path = opt.config 21 | local_rank = opt.local_rank 22 | use_amp = opt.use_amp 23 | 24 | with open(yaml_path, 'r') as f: 25 | opt = yaml.full_load(f) 26 | print0(opt) 27 | opt = Config(opt) 28 | model_dir = os.path.join(opt.save_dir, "ckpts") 29 | vis_dir = os.path.join(opt.save_dir, "visual") 30 | tsbd_dir = os.path.join(opt.save_dir, "tensorboard") 31 | if local_rank == 0: 32 | os.makedirs(model_dir, exist_ok=True) 33 | os.makedirs(vis_dir, exist_ok=True) 34 | 35 | device = "cuda:%d" % local_rank 36 | DIFFUSION, NETWORK = get_models_class(opt.model_type, opt.net_type) 37 | diff = DIFFUSION(nn_model=NETWORK(**opt.network), 38 | **opt.diffusion, 39 | device=device, 40 | ) 41 | diff.to(device) 42 | if local_rank == 0: 43 | ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1) 44 | ema.to(device) 45 | ema.eval() 46 | writer = SummaryWriter(log_dir=tsbd_dir) 47 | 48 | diff = torch.nn.SyncBatchNorm.convert_sync_batchnorm(diff) 49 | diff = torch.nn.parallel.DistributedDataParallel( 50 | diff, device_ids=[local_rank], output_device=local_rank) 51 | 52 | train_set = get_dataset(name=opt.dataset, root="./data", train=True, flip=opt.flip) 53 | print0("train dataset:", len(train_set)) 54 | 55 | total_bs = opt.batch_size 56 | total_gpus = dist.get_world_size() 57 | per_gpu_bs = total_bs // total_gpus 58 | train_loader, sampler = DataLoaderDDP(train_set, 59 | 
batch_size=per_gpu_bs, 60 | shuffle=True) 61 | 62 | lr = opt.lrate 63 | print0("Using DDP, effective batch size = %d = %d * %d, lr = %f" % (total_bs, total_gpus, per_gpu_bs, lr)) 64 | optim = get_optimizer(diff.parameters(), opt, lr=lr) 65 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 66 | 67 | if opt.load_epoch != -1: 68 | target = os.path.join(model_dir, f"model_{opt.load_epoch}.pth") 69 | print0("loading model at", target) 70 | checkpoint = torch.load(target, map_location=device) 71 | diff.load_state_dict(checkpoint['MODEL']) 72 | if local_rank == 0: 73 | ema.load_state_dict(checkpoint['EMA']) 74 | optim.load_state_dict(checkpoint['opt']) 75 | 76 | for ep in range(opt.load_epoch + 1, opt.n_epoch): 77 | for g in optim.param_groups: 78 | if ep < opt.warm_epoch: 79 | g['lr'] = lr * min((ep + 1.0) / opt.warm_epoch, 1.0) # warmup 80 | else: 81 | if not hasattr(opt, 'tref_epoch'): 82 | opt.tref_epoch = opt.n_epoch 83 | g['lr'] = lr / math.sqrt(max(ep / opt.tref_epoch, 1.0)) # inverse square root 84 | sampler.set_epoch(ep) 85 | dist.barrier() 86 | # training 87 | diff.train() 88 | if local_rank == 0: 89 | now_lr = optim.param_groups[0]['lr'] 90 | print(f'epoch {ep}, lr {now_lr:f}') 91 | meter = Meter(n_items=1) 92 | pbar = tqdm(train_loader) 93 | else: 94 | pbar = train_loader 95 | for x, c in pbar: 96 | optim.zero_grad() 97 | x = x.to(device) 98 | loss = diff(x, use_amp=use_amp) 99 | scaler.scale(loss).backward() 100 | scaler.step(optim) 101 | scaler.update() 102 | 103 | # logging 104 | dist.barrier() 105 | loss = reduce_tensor(loss) 106 | if local_rank == 0: 107 | ema.update() 108 | meter.update([loss.item()]) 109 | pbar.set_description(f"loss: {meter.get(0):.4f}") 110 | 111 | # testing 112 | if local_rank == 0: 113 | writer.add_scalar('lr', now_lr, ep) 114 | writer.add_scalar('loss', meter.get(0), ep) 115 | if ep % 100 == 0 or ep == opt.n_epoch - 1: 116 | pass 117 | else: 118 | continue 119 | 120 | if opt.model_type == 'EDM': 121 | ema_sample_method = ema.ema_model.edm_sample 122 | else: 123 | raise NotImplementedError 124 | 125 | ema.ema_model.eval() 126 | with torch.no_grad(): 127 | x_gen = ema_sample_method(opt.n_sample, x.shape[1:]) 128 | # save an image of currently generated samples (top rows) 129 | # followed by real images (bottom rows) 130 | x_real = x[:opt.n_sample] 131 | x_all = torch.cat([x_gen.cpu(), x_real.cpu()]) 132 | grid = make_grid(x_all, nrow=10) 133 | 134 | save_path = os.path.join(vis_dir, f"image_ep{ep}_ema.png") 135 | save_image(grid, save_path) 136 | print('saved image at', save_path) 137 | 138 | # optionally save model 139 | if opt.save_model: 140 | checkpoint = { 141 | 'MODEL': diff.state_dict(), 142 | 'EMA': ema.state_dict(), 143 | 'opt': optim.state_dict(), 144 | } 145 | save_path = os.path.join(model_dir, f"model_{ep}.pth") 146 | torch.save(checkpoint, save_path) 147 | print('saved model at', save_path) 148 | 149 | 150 | if __name__ == "__main__": 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument("--config", type=str) 153 | parser.add_argument("--use_amp", action='store_true', default=False) 154 | opt = parser.parse_args() 155 | opt.local_rank = int(os.environ['LOCAL_RANK']) 156 | print0(opt) 157 | 158 | init_seeds(no=opt.local_rank) 159 | dist.init_process_group(backend='nccl') 160 | torch.cuda.set_device(opt.local_rank) 161 | train(opt) 162 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import random 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | from torch.utils.data import DataLoader 7 | 8 | # ===== Meter ===== 9 | 10 | class Meter: 11 | def __init__(self, n_items, ema=0.95): 12 | self.data = [None for _ in range(n_items)] 13 | self.ema = ema 14 | 15 | def update(self, x): 16 | for i in range(len(x)): 17 | if self.data[i] is None: 18 | self.data[i] = x[i] 19 | else: 20 | self.data[i] = self.ema * self.data[i] + (1 - self.ema) * x[i] 21 | 22 | def get(self, i): 23 | return self.data[i] 24 | 25 | # ===== Configs ===== 26 | 27 | class Config(object): 28 | def __init__(self, dic): 29 | for key in dic: 30 | setattr(self, key, dic[key]) 31 | 32 | def get_optimizer(parameters, opt, lr): 33 | if not hasattr(opt, 'optim'): 34 | return torch.optim.Adam(parameters, lr=lr) 35 | elif opt.optim == 'Adam': 36 | return torch.optim.Adam(parameters, **opt.optim_args, lr=lr) 37 | elif opt.optim == 'AdamW': 38 | return torch.optim.AdamW(parameters, **opt.optim_args, lr=lr) 39 | else: 40 | raise NotImplementedError() 41 | 42 | # ===== Multi-GPU training ===== 43 | 44 | def init_seeds(RANDOM_SEED=1337, no=0): 45 | RANDOM_SEED += no 46 | print("local_rank = {}, seed = {}".format(no, RANDOM_SEED)) 47 | random.seed(RANDOM_SEED) 48 | np.random.seed(RANDOM_SEED) 49 | torch.manual_seed(RANDOM_SEED) 50 | torch.cuda.manual_seed_all(RANDOM_SEED) 51 | torch.backends.cudnn.deterministic = True 52 | torch.backends.cudnn.benchmark = False 53 | 54 | 55 | def reduce_tensor(tensor): 56 | rt = tensor.clone() 57 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 58 | rt /= dist.get_world_size() 59 | return rt 60 | 61 | 62 | def gather_tensor(tensor): 63 | tensor_list = [tensor.clone() for _ in range(dist.get_world_size())] 64 | dist.all_gather(tensor_list, tensor) 65 | tensor_list = torch.cat(tensor_list, dim=0) 66 | return tensor_list 67 | 68 | 69 | def DataLoaderDDP(dataset, batch_size, shuffle=True): 70 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) 71 | dataloader = DataLoader( 72 | dataset, 73 | batch_size=batch_size, 74 | sampler=sampler, 75 | num_workers=1, 76 | ) 77 | return dataloader, sampler 78 | 79 | 80 | def print0(*args, **kwargs): 81 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 82 | print(*args, **kwargs) 83 | --------------------------------------------------------------------------------
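The `ret_activation` path of `UNetC.forward` (identical in `UNetE` and `UNetG`) temporarily hooks every `DecoderBlock` and returns their feature maps alongside the denoiser output, keyed `out_1`, `out_2`, … in the order the decoder runs (deepest resolution first, upsampling blocks included). A minimal sketch of calling it on dummy inputs; the batch size and zero tensors are illustrative only and mirror the verification snippet at the bottom of `unetC.py`:

```python
import torch
from model.unetC import UNetC

net = UNetC()            # default CIFAR-sized arguments
net.eval()               # disable dropout for a deterministic pass

x = torch.zeros(4, 3, 32, 32)   # dummy image batch
t = torch.zeros(4,)             # one noise level per sample

out = net(x, t)                 # plain forward: denoiser output only
print(out.shape)                # torch.Size([4, 3, 32, 32])

# With ret_activation=True, forward hooks are registered on every DecoderBlock
# for the duration of the call and their outputs are returned as a dict.
out, acts = net(x, t, ret_activation=True)
for name, feat in acts.items():
    print(name, tuple(feat.shape))  # 'out_1' comes from the deepest decoder block
```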
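`get_reweighting()` and `forward_reweighting()` expose the learned loss-weighting MLP, but `EDM.py`, which consumes them, is not shown above. The sketch below therefore only assumes the uncertainty-based weighting from the paper: the MLP predicts a log-uncertainty u(σ) from log σ / 4, the per-sample loss is divided by exp(u(σ)), and u(σ) is added back so the network cannot lower the loss by inflating the uncertainty. The plain MSE here is a stand-in for the actual preconditioned EDM loss and its λ(σ) weighting, which live in `EDM.py`:

```python
import torch
from model.unetC import UNetC

net = UNetC()
mlp = net.get_reweighting()            # the Reweighting module from blockC.py

x0 = torch.randn(4, 3, 32, 32)         # dummy "clean" images
sigma = torch.rand(4, 1, 1, 1) + 0.1   # per-sample noise levels
noisy = x0 + sigma * torch.randn_like(x0)

denoised = net(noisy, sigma.flatten())
per_sample = ((denoised - x0) ** 2).mean(dim=(1, 2, 3))  # stand-in loss per sample

# Assumed combination (uncertainty weighting as in the paper; EDM.py may differ):
u = net.forward_reweighting(mlp, sigma).flatten()        # u(sigma), one scalar per sample
loss = (per_sample / u.exp() + u).mean()
loss.backward()
```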