├── README.md
├── calc_fid.py
├── calc_fid.sh
├── data.py
├── disentangle.sh
├── eval_disentangle.sh
├── eval_disentanglement.py
├── eval_fid.sh
├── flowchart.drawio-1.png
├── gen_fid.sh
├── gen_fid_stats.py
├── graphicalabstract.drawio_v2-1.png
├── interpolate.sh
├── latent_quality.sh
├── models.py
├── modules.py
├── run.py
├── run.sh
├── sampling.py
├── save_latent.sh
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # [InfoDiffusion: Representation Learning Using Information Maximizing Diffusion Models](https://arxiv.org/abs/2306.08757) (ICML 2023)
2 | By [Yingheng Wang](https://isjakewong.github.io/), [Yair Schiff](https://yair-schiff.github.io), [Aaron Gokaslan](https://skylion007.github.io), [Weishen Pan](https://vivo.weill.cornell.edu/display/cwid-wep4001),
3 | [Fei Wang](https://wcm-wanglab.github.io/), [Chris De Sa](https://www.cs.cornell.edu/~cdesa/), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)
4 | 
5 | [![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](https://isjakewong.github.io/infodiffusion-page/)
6 | [![arXiv](https://img.shields.io/badge/arXiv-2306.08757-red.svg)](https://arxiv.org/abs/2306.08757)
7 | 
8 | 
18 | 
19 | We introduce *InfoDiffusion*, a principled probabilistic extension of diffusion models that supports low-dimensional latents with associated variational learning objectives that are regularized with a mutual information term. We show that these algorithms simultaneously yield high-quality samples and latent representations, achieving competitive performance with state-of-the-art methods on both fronts.
20 | 
21 | 
22 | In this repo, we release:
23 | * **The Auxiliary-Variable Diffusion Models (AVDM)**:
24 |     1. Diffusion decoder conditioned on the auxiliary variable using AdaNorm
25 |     2. Simplified loss calculation for auxiliary latent variables with a semantic prior
26 | * **Baseline implementations** [[Examples]](#baselines):
27 |     1. A set of model variants from the VAE family (VAE, $\beta$-VAE, InfoVAE) with different priors (Gaussian, Mixture of Gaussians, spiral).
28 |     2. A simplified version of Diffusion Autoencoder [DiffAE](https://arxiv.org/abs/2111.15640) within our AVDM framework.
29 |     3. A minimal and efficient implementation of vanilla diffusion models.
30 | * **Evaluation metrics**:
31 |     1. Generation quality: Fréchet inception distance (FID).
32 |     2. Latent quality: latent space interpolation, latent variables for classification.
33 |     3. Disentanglement: [DCI](https://openreview.net/forum?id=By-7dz-AZ) score, [TAD](https://link.springer.com/chapter/10.1007/978-3-031-19812-0_3) score.
34 | * **Samplers**:
35 |     1. [DDPM](https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html) sampling and [DDIM](https://arxiv.org/abs/2010.02502) sampling.
36 |     2. Two-phase sampling, where the two phases sample from a regular diffusion model and an AVDM consecutively.
37 |     3. Latent sampling, which uses an auxiliary latent diffusion model to sample $\mathbf{z}_t$ along with $\mathbf{x}_t$.
38 |     4. Reverse DDIM sampling to visualize the latent $\mathbf{x}_T$ from $\mathbf{x}_0$.
39 | 
40 | 
41 | ## Code Organization
42 | 1. ```run.py```: Routines for training and evaluation
43 | 2. ```models.py```: Diffusion models (InfoDiffusion, DiffAE, regular diffusion), VAEs (InfoVAE, $\beta$-VAE, VAE)
44 | 3. ```modules.py```: Neural network blocks
45 | 4. ```sampling.py```: DDPM/DDIM sampler, Reverse DDIM sampler, Two-phase sampler, Latent sampler
46 | 5. ```utils.py```: LR scheduler, logging, utils to calculate priors
47 | 6. ```gen_fid_stats.py```: Generate stats used for FID calculation
48 | 7. ```calc_fid.py```: Calculate FID scores
49 | 
50 | 
51 | 
52 | 
53 | ## Getting started in this repository
54 | 
55 | To get started, create a conda environment containing the required dependencies.
56 | 
57 | ```bash
58 | conda create -n infodiffusion
59 | conda activate infodiffusion
60 | pip install -r requirements.txt
61 | ```
62 | 
63 | Run the training using the bash script:
64 | ```bash
65 | bash run.sh
66 | ```
67 | or
68 | ```bash
69 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
70 | ```
71 | The arguments in this script train a diffusion model (`--model diff`) with a Maximum Mean Discrepancy (MMD) regularizer (`--mmd_weight 0.1`) and a regular Gaussian prior (`--prior regular`) on CelebA (`--dataset celeba`); a minimal sketch of this kind of MMD penalty is shown below.
72 | 
73 | 
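For intuition, the snippet below is a minimal sketch of an RBF-kernel MMD penalty between a batch of auxiliary latents and samples from the prior. It is illustrative only: the repository's actual implementation is `compute_mmd` in `utils.py`, and the kernel choice, bandwidth, and estimator used there may differ from this sketch.

```python
import torch

def rbf_mmd(x, y, bandwidth=1.0):
    """Illustrative squared MMD between latents x and prior samples y, both of shape [n, d]."""
    def kernel(a, b):
        # Gaussian (RBF) kernel evaluated on pairwise squared distances.
        sq_dists = torch.cdist(a, b).pow(2)
        return torch.exp(-sq_dists / (2 * bandwidth ** 2))

    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

# Usage sketch: weight the penalty with the CLI flag, e.g. for a standard Gaussian prior
#   loss = diffusion_loss + args.mmd_weight * rbf_mmd(a, torch.randn_like(a))
```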
74 | ## Evaluation
75 | 
76 | Below, we describe the steps required for evaluating the trained diffusion models.
77 | Throughout, the main entry point for running experiments is the [`run.py`](./run.py) script.
78 | We also provide sample `bash` scripts for launching these evaluation runs.
79 | In general, different evaluation runs can be switched using `--mode`, which takes one of the following values:
80 | * `eval`: sample images from the trained diffusion model.
81 | * `eval_fid`: sample images for FID score calculation.
82 | * `save_latent`: save the auxiliary variables.
83 | * `disentangle`: run evaluation on auxiliary variable disentanglement.
84 | * `interpolate`: run interpolation between two given input images.
85 | * `latent_quality`: save the auxiliary variables and latent variables for classification.
86 | * `train_latent_ddim`: train the latent diffusion model used in the latent sampler.
87 | * `plot_latent`: plot the latent space.
88 | However, the quantitative disentanglement evaluation, the latent classification, and the FID score calculation require multiple steps, described below.
89 | 
90 | ### Disentanglement evaluation
91 | To evaluate latent disentanglement, we need to conduct the following steps:
92 | 1. ```save_latent.sh```: save the auxiliary variables $\mathbf{z}$ and latent variables $\mathbf{x}_T$.
93 | 2. ```eval_disentangle.sh```: evaluate the latent disentanglement by computing DCI and TAD scores.
94 | 
95 | #### Save the latents
96 | ```bash
97 | python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
98 | ```
99 | 
100 | #### Eval the disentanglement metrics on latents
101 | ```bash
102 | python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
103 | ```
104 | 
105 | ### Latent classification
106 | To run latent classification, we need to conduct the following steps:
107 | 1. ```save_latent.sh```: save the auxiliary variables $\mathbf{z}$ and latent variables $\mathbf{x}_T$ used to train the classifier.
108 | 2. ```eval_disentangle.sh```: use the same evaluation script as for disentanglement to train the classifier and obtain the classification accuracy.
109 | 
110 | #### Save the latents
111 | ```bash
112 | python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
113 | ```
114 | 
115 | #### Train latent classifier and evaluate
116 | ```bash
117 | python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
118 | ```
119 | 
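If you want to probe the saved latents without going through the full evaluation script, the sketch below fits one linear classifier per attribute on the saved auxiliary variables. The `.npz` keys (`all_a`, `all_attr`) are the ones `eval_disentanglement.py` loads; the file name here is illustrative, so substitute the `<model>_<experiment-string>_latent.npz` file produced by your `save_latent` run.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Illustrative path; use the *_latent.npz file written by the save_latent step.
data = np.load("diff_celeba_latent.npz", allow_pickle=True)
a = data["all_a"]                  # auxiliary latents, shape [N, a_dim]
y = data["all_attr"].astype(int)   # attribute labels, e.g. an [N, 40] binary matrix for CelebA

tr_a, te_a, tr_y, te_y = train_test_split(a, y, test_size=0.2, random_state=0)
scaler = StandardScaler().fit(tr_a)
tr_a, te_a = scaler.transform(tr_a), scaler.transform(te_a)

# One linear probe per attribute; report the mean test accuracy.
accs = []
for j in range(tr_y.shape[1]):
    clf = LogisticRegression(max_iter=1000).fit(tr_a, tr_y[:, j])
    accs.append(clf.score(te_a, te_y[:, j]))
print(f"mean linear-probe accuracy: {np.mean(accs):.4f}")
```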
120 | ### FID calculation
121 | 
122 | To calculate the FID scores, we need to conduct the following steps:
123 | 1. ```eval_fid.sh```: train the diffusion model and the latent diffusion model, then generate samples from them.
124 | 2. ```gen_fid.sh```: generate FID stats given the dataset name and the folder storing the preprocessed images from this dataset.
125 | 3. ```calc_fid.sh```: calculate FID scores given the dataset name and the folder storing the generated samples.
126 | 
127 | We also provide the commands for the above steps:
128 | #### Train and sample:
129 | ```bash
130 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
131 | 
132 | python run.py --model diff --mode save_latent --disent_metric tad --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64
133 | 
134 | python run.py --model diff --mode train_latent_ddim --a_dim 256 --epochs 50 --mmd_weight 0.1 --dataset celeba --deterministic --save_epoch 10 --prior regular --r_seed 64
135 | 
136 | python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.1 --sampling_number 10000 --epochs 50 --dataset celeba --is_latent --prior regular --r_seed 64
137 | ```
138 | 
139 | #### Generate FID stats:
140 | ```bash
141 | python gen_fid_stats.py celeba ./celeba_imgs
142 | ```
143 | 
144 | #### Calculate FID scores:
145 | ```bash
146 | python calc_fid.py celeba ./imgs/celeba_32d_0.1mmd/eval-fid-latent
147 | ```
148 | 
149 | Note: please refer to [clean-fid](https://github.com/GaParmar/clean-fid) for more options to calculate FID.
150 | 
151 | ## Baselines
152 | 
153 | 
154 | The baselines can be switched using the argument `--model`, which takes one of the following values: `['diff', 'vae', 'vanilla']`, where `'diff'` is AVDM, `'vae'` is the VAE model family, and `'vanilla'` is the regular diffusion model. Below is an example that trains InfoVAE:
155 | ```bash
156 | python run.py --model vae --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --prior regular --r_seed 64
157 | ```
158 | 
159 | ## Notes and disclaimer
160 | The ```main``` branch provides code and implementations optimized for representation learning tasks, and the ```InfoDiffusion-dev``` branch provides code closer to the version used to reproduce the results reported in the paper.
161 | 
162 | This research code is provided as-is, without any support or guarantee of quality. However, if you identify any issues or areas for improvement, please feel free to raise an issue or submit a pull request. We will do our best to address them.
163 | 
164 | ## Citation
165 | ```
166 | @inproceedings{wang2023infodiffusion,
167 |   title={Infodiffusion: Representation learning using information maximizing diffusion models},
168 |   author={Wang, Yingheng and Schiff, Yair and Gokaslan, Aaron and Pan, Weishen and Wang, Fei and De Sa, Christopher and Kuleshov, Volodymyr},
169 |   booktitle={International Conference on Machine Learning},
170 |   pages={36336--36354},
171 |   year={2023},
172 |   organization={PMLR}
173 | }
174 | ```
175 | 
--------------------------------------------------------------------------------
/calc_fid.py:
--------------------------------------------------------------------------------
1 | from cleanfid import fid
2 | import sys
3 | if __name__ == '__main__':
4 |     dataset_name = sys.argv[1]
5 |     folder_1 = sys.argv[2]
6 |     cleanfid_args = dict(
7 |         dataset_name=dataset_name,
8 |         dataset_res=64,
9 |         num_gen=10000,
10 |         dataset_split="custom"
11 |     )
12 |     fid_score = fid.compute_fid(folder_1, **cleanfid_args)
13 |     print(f'FID score: {fid_score}')
14 |     kid_score = fid.compute_kid(folder_1, **cleanfid_args)
15 |     print(f'KID score: {kid_score}')
--------------------------------------------------------------------------------
/calc_fid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python calc_fid.py celeba ./imgs/celeba_32d_0.1mmd/eval-fid-latent
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import numpy as np
5 | from torch.utils.data import Dataset, DataLoader
6 | from torchvision.datasets import ImageFolder
7 | 
8 | class Crop:
9 |     def __init__(self, x1, x2, y1, y2):
10 |         self.x1 = x1
11 |         self.x2 = x2
12 |         self.y1 = y1
13 |         self.y2 = y2
14 | 
15 |     def __call__(self, img):
16 |         return torchvision.transforms.functional.crop(img, self.x1, self.y1, self.x2 - self.x1,
17 |                                                        self.y2 - self.y1)
18 | 
19 |     def __repr__(self):
20 |         return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
21 |             self.x1, self.x2, self.y1, self.y2)
22 | 
23 | 
24 | def d2c_crop():
25 |     # from D2C paper for CelebA dataset.
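    # The window below is 128x128, centered at (row, col) = (cy, cx) = (121, 89);
    # the CelebA transform pipeline resizes the cropped face afterwards.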
26 | cx = 89 27 | cy = 121 28 | x1 = cy - 64 29 | x2 = cy + 64 30 | y1 = cx - 64 31 | y2 = cx + 64 32 | return Crop(x1, x2, y1, y2) 33 | 34 | 35 | class CustomTensorDataset(Dataset): 36 | def __init__(self, data, latents_values, latents_classes): 37 | self.data = data 38 | self.latents_values = latents_values 39 | self.latents_classes = latents_classes 40 | 41 | def __getitem__(self, index): 42 | return (torch.from_numpy(self.data[index]).float(), 43 | torch.from_numpy(self.latents_values[index]).float(), 44 | torch.from_numpy(self.latents_classes[index]).int()) 45 | 46 | def __len__(self): 47 | return self.data.shape[0] 48 | 49 | 50 | class CustomImageFolder(ImageFolder): 51 | def __init__(self, root, transform=None): 52 | super(CustomImageFolder, self).__init__(root, transform) 53 | 54 | def __getitem__(self, index): 55 | path = self.imgs[index][0] 56 | img = self.loader(path) 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | 60 | return img 61 | 62 | 63 | def get_dataset_config(args): 64 | if args.dataset == 'fmnist': 65 | args.input_channels = 1 66 | args.unets_channels = 32 67 | args.encoder_channels = 32 68 | args.input_size = 32 69 | elif args.dataset == 'mnist': 70 | args.input_channels = 1 71 | args.unets_channels = 32 72 | args.encoder_channels = 32 73 | args.input_size = 32 74 | elif args.dataset == 'dsprites': 75 | args.input_channels = 1 76 | args.unets_channels = 32 77 | args.encoder_channels = 32 78 | args.input_size = 32 79 | elif args.dataset == 'celeba': 80 | args.input_channels = 3 81 | args.unets_channels = 64 82 | args.encoder_channels = 64 83 | args.input_size = 64 84 | elif args.dataset == 'cifar10': 85 | args.input_channels = 3 86 | args.unets_channels = 64 87 | args.encoder_channels = 64 88 | args.input_size = 32 89 | elif args.dataset == 'chairs': 90 | args.input_channels = 3 91 | args.unets_channels = 32 92 | args.encoder_channels = 32 93 | args.input_size = 64 94 | elif args.dataset == 'ffhq': 95 | args.input_channels = 3 96 | args.unets_channels = 64 97 | args.encoder_channels = 64 98 | args.input_size = 64 99 | 100 | shape = (args.input_channels, args.input_size, args.input_size) 101 | 102 | return shape 103 | 104 | 105 | def get_dataset(args): 106 | if args.dataset == 'fmnist': 107 | return get_fmnist(args) 108 | elif args.dataset == 'mnist': 109 | return get_mnist(args) 110 | elif args.dataset == 'celeba': 111 | return get_celeba(args) 112 | elif args.dataset == 'cifar10': 113 | return get_cifar10(args) 114 | elif args.dataset == 'dsprites': 115 | return get_dsprites(args) 116 | elif args.dataset == 'chairs': 117 | return get_chairs(args) 118 | elif args.dataset == 'ffhq': 119 | return get_ffhq(args) 120 | 121 | 122 | def get_mnist(args): 123 | transform = torchvision.transforms.Compose([ 124 | torchvision.transforms.Resize((args.input_size, args.input_size)), 125 | torchvision.transforms.ToTensor(), 126 | torchvision.transforms.Lambda(lambda t: (t * 2) - 1), 127 | ]) 128 | 129 | dataset = torchvision.datasets.MNIST(root = args.data_dir, train=True, download=True, transform = transform) 130 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, num_workers = 4) 131 | 132 | return dataloader 133 | 134 | 135 | def get_fmnist(args): 136 | transform = torchvision.transforms.Compose([ 137 | torchvision.transforms.Resize((args.input_size, args.input_size)), 138 | torchvision.transforms.RandomHorizontalFlip(), 139 | torchvision.transforms.ToTensor(), 140 | torchvision.transforms.Lambda(lambda t: (t 
* 2) - 1), 141 | ]) 142 | 143 | dataset = torchvision.datasets.FashionMNIST(root = args.data_dir, train=True, download=True, transform = transform) 144 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, num_workers = 4) 145 | 146 | return dataloader 147 | 148 | 149 | def get_celeba(args, 150 | as_tensor: bool = True, 151 | do_augment: bool = True, 152 | do_normalize: bool = True, 153 | crop_d2c: bool = False): 154 | if crop_d2c: 155 | transform = [ 156 | d2c_crop(), 157 | torchvision.transforms.Resize(args.input_size), 158 | ] 159 | else: 160 | transform = [ 161 | torchvision.transforms.Resize(args.input_size), 162 | torchvision.transforms.CenterCrop(args.input_size), 163 | ] 164 | 165 | if do_augment: 166 | transform.append(torchvision.transforms.RandomHorizontalFlip()) 167 | if as_tensor: 168 | transform.append(torchvision.transforms.ToTensor()) 169 | if do_normalize: 170 | transform.append( 171 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) 172 | transform = torchvision.transforms.Compose(transform) 173 | 174 | if args.mode in ['attr_classification', 'eval_fid', 'reconstruction']: 175 | train_set = torchvision.datasets.CelebA(root = args.data_dir, split = "train", download = True, transform = transform) 176 | valid_set = torchvision.datasets.CelebA(root = args.data_dir, split = "valid", download = True, transform = transform) 177 | test_set = torchvision.datasets.CelebA(root = args.data_dir, split = "test", download = True, transform = transform) 178 | train_loader = torch.utils.data.DataLoader(train_set, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4) 179 | valid_loader = torch.utils.data.DataLoader(valid_set, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4) 180 | test_loader = torch.utils.data.DataLoader(test_set, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4) 181 | return (train_loader, valid_loader, test_loader) 182 | else: 183 | dataset = torchvision.datasets.CelebA(root = args.data_dir, split = "train", download = True, transform = transform) 184 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = False, num_workers = 4) 185 | 186 | return dataloader 187 | 188 | 189 | def get_cifar10(args): 190 | transform = torchvision.transforms.Compose([ 191 | torchvision.transforms.RandomHorizontalFlip(), 192 | torchvision.transforms.ToTensor(), 193 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 194 | ]) 195 | 196 | dataset = torchvision.datasets.CIFAR10(root = args.data_dir, train = True, download = True, transform = transform) 197 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4) 198 | return dataloader 199 | 200 | 201 | def get_dsprites(args): 202 | root = os.path.join(args.data_dir+'/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz') 203 | file = np.load(root, encoding='latin1') 204 | data = file['imgs'][:, np.newaxis, :, :] 205 | latents_values = file['latents_values'] 206 | latents_classes = file['latents_classes'] 207 | train_kwargs = {'data':data, 'latents_values':latents_values, 'latents_classes':latents_classes} 208 | dset = CustomTensorDataset 209 | dataset = dset(**train_kwargs) 210 | 211 | dataloader = DataLoader(dataset, 212 | batch_size=args.batch_size, 213 | shuffle=True, 214 | num_workers=4, 215 | pin_memory=True, 216 | 
drop_last=True) 217 | 218 | return dataloader 219 | 220 | 221 | def get_chairs(args): 222 | transform = torchvision.transforms.Compose([ 223 | torchvision.transforms.Resize((args.input_size, args.input_size)), 224 | torchvision.transforms.RandomHorizontalFlip(), 225 | torchvision.transforms.ToTensor(), 226 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 227 | ]) 228 | 229 | dataset = CustomImageFolder(root = args.data_dir+'/3DChairs', transform = transform) 230 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4) 231 | return dataloader 232 | 233 | 234 | def get_ffhq(args): 235 | transform = torchvision.transforms.Compose([ 236 | torchvision.transforms.Resize((args.input_size, args.input_size)), 237 | torchvision.transforms.RandomHorizontalFlip(), 238 | torchvision.transforms.ToTensor(), 239 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 240 | ]) 241 | 242 | dataset = CustomImageFolder(root = args.data_dir+'/ffhq', transform = transform) 243 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = False, num_workers = 4) 244 | return dataloader -------------------------------------------------------------------------------- /disentangle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python run.py --model diff --mode disentangle --img_id 0 --mmd_weight 0.1 --a_dim 32 --epochs 20 --dataset celeba --deterministic --prior regular --r_seed 64 4 | -------------------------------------------------------------------------------- /eval_disentangle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64 3 | -------------------------------------------------------------------------------- /eval_disentanglement.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import scipy 5 | import torch 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.metrics import roc_auc_score, accuracy_score 8 | from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier 9 | from sklearn.preprocessing import StandardScaler 10 | from sklearn.model_selection import KFold 11 | from utils import generate_exp_string 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('--r_seed', type=int, default=0, 17 | help='the value of given random seed') 18 | parser.add_argument('--img_id', type=int, default=0, 19 | help='the id of given img') 20 | parser.add_argument('--model', required=True, 21 | choices=['diff', 'vae', 'vanilla'], help='which type of model to run') 22 | parser.add_argument('--mode', required=True, 23 | choices=['train', 'eval', 'eval_fid', 'save_latent', 'disentangle', 24 | 'interpolate', 'save_original_img', 'latent_quality', 25 | 'train_latent_ddim', 'plot_latent'], help='which mode to run') 26 | parser.add_argument('--prior', required=True, 27 | choices=['regular', '10mix', 'roll'], help='which type of prior to run') 28 | parser.add_argument('--kld_weight', type=float, default=0, 29 | help='weight of kld loss') 30 | parser.add_argument('--mmd_weight', 
type=float, default=0.1,
31 |                         help='weight of mmd loss')
32 |     parser.add_argument('--use_C', action='store_true',
33 |                         default=False, help='use control constant or not')
34 |     parser.add_argument('--C_max', type=float, default=25,
35 |                         help='control constant of kld loss (orig default: 25 for simple, 50 for complex)')
36 |     parser.add_argument('--dataset', required=True,
37 |                         choices=['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites', 'chairs', 'ffhq'],
38 |                         help='training dataset')
39 |     parser.add_argument('--img_folder', default='./imgs',
40 |                         help='path to save sampled images')
41 |     parser.add_argument('--log_folder', default='./logs',
42 |                         help='path to save logs')
43 |     parser.add_argument('-e', '--epochs', type=int, default=20,
44 |                         help='number of epochs to train')
45 |     parser.add_argument('--save_epochs', type=int, default=5,
46 |                         help='number of epochs to save model')
47 |     parser.add_argument('--batch_size', type=int, default=64,
48 |                         help='training batch size')
49 |     parser.add_argument('--learning_rate', type=float, default=0.0001,
50 |                         help='learning rate')
51 |     parser.add_argument('--optimizer', default='adam', choices=['adam'],
52 |                         help='optimization algorithm')
53 |     parser.add_argument('--model_folder', default='./models',
54 |                         help='folder where logs will be stored')
55 |     parser.add_argument('--deterministic', action='store_true',
56 |                         default=False, help='deterministic sampling')
57 |     parser.add_argument('--input_channels', type=int, default=1,
58 |                         help='number of input channels')
59 |     parser.add_argument('--unets_channels', type=int, default=64,
60 |                         help='number of UNet channels')
61 |     parser.add_argument('--encoder_channels', type=int, default=64,
62 |                         help='number of encoder channels')
63 |     parser.add_argument('--input_size', type=int, default=32,
64 |                         help='expected size of input')
65 |     parser.add_argument('--a_dim', type=int, default=32, required=True,
66 |                         help='dimensionality of auxiliary variable')
67 |     parser.add_argument('--beta1', type=float, default=1e-5,
68 |                         help='value of beta 1')
69 |     parser.add_argument('--betaT', type=float, default=1e-2,
70 |                         help='value of beta T')
71 |     parser.add_argument('--diffusion_steps', type=int, default=1000,
72 |                         help='number of diffusion steps')
73 |     parser.add_argument('--split_step', type=int, default=500,
74 |                         help='the step for splitting two phases')
75 |     parser.add_argument('--sampling_number', type=int, default=16,
76 |                         help='number of sampled images')
77 |     parser.add_argument('--data_dir', type=str, default='./data')
78 |     parser.add_argument('--tb_logger', action='store_true',
79 |                         help='use tensorboard logger.')
80 |     parser.add_argument('--is_latent', action='store_true',
81 |                         help='use latent diffusion for unconditional sampling.')
82 |     parser.add_argument('--is_bottleneck', action='store_true',
83 |                         help='only fuse aux variable in bottleneck layers.')
84 |     args = parser.parse_args()
85 | 
86 |     return args
87 | 
88 | """ Implementation of the DCI metric is from:
89 | https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/evaluation/metrics/dci.py
90 | """
91 | def compute_dci(mus_train, ys_train, mus_test, ys_test):
92 |     """Computes score based on both training and testing codes and factors."""
93 |     scores = {}
94 |     importance_matrix, train_err, test_err = compute_importance_gbt(
95 |         mus_train, ys_train, mus_test, ys_test)
96 |     assert importance_matrix.shape[0] == mus_train.shape[0]
97 |     assert importance_matrix.shape[1] == ys_train.shape[0]
98 |     scores["informativeness_train"] = train_err
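    # Note: despite the "err" names, compute_importance_gbt returns the mean prediction
    # accuracy of the per-factor GBT classifiers, so higher values mean more informative codes.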
99 | scores["informativeness_test"] = test_err 100 | scores['importance'] = importance_matrix 101 | scores["disentanglement"] = disentanglement(importance_matrix) 102 | scores["completeness"] = completeness(importance_matrix) 103 | return scores 104 | 105 | def compute_importance_gbt(x_train, y_train, x_test, y_test): 106 | """Compute importance based on gradient boosted trees.""" 107 | num_factors = y_train.shape[0] 108 | num_codes = x_train.shape[0] 109 | importance_matrix = np.zeros(shape=[num_codes, num_factors], dtype=np.float64) 110 | train_loss = [] 111 | test_loss = [] 112 | for i in range(num_factors): 113 | model = GradientBoostingClassifier() 114 | model.fit(x_train.T, y_train[i, :]) 115 | importance_matrix[:, i] = np.abs(model.feature_importances_) 116 | train_loss.append(np.mean(model.predict(x_train.T) == y_train[i, :])) 117 | test_loss.append(np.mean(model.predict(x_test.T) == y_test[i, :])) 118 | return importance_matrix, np.mean(train_loss), np.mean(test_loss) 119 | 120 | 121 | def disentanglement_per_code(importance_matrix): 122 | """Compute disentanglement score of each code.""" 123 | # importance_matrix is of shape [num_codes, num_factors]. 124 | return 1. - scipy.stats.entropy(importance_matrix.T + 1e-11, base=importance_matrix.shape[1]) 125 | 126 | 127 | def disentanglement(importance_matrix): 128 | """Compute the disentanglement score of the representation.""" 129 | per_code = disentanglement_per_code(importance_matrix) 130 | if importance_matrix.sum() == 0.: 131 | importance_matrix = np.ones_like(importance_matrix) 132 | code_importance = importance_matrix.sum(axis=1) / importance_matrix.sum() 133 | 134 | return np.sum(per_code*code_importance) 135 | 136 | 137 | def completeness_per_factor(importance_matrix): 138 | """Compute completeness of each factor.""" 139 | # importance_matrix is of shape [num_codes, num_factors]. 140 | return 1. 
- scipy.stats.entropy(importance_matrix + 1e-11, 141 | base=importance_matrix.shape[0]) 142 | 143 | 144 | def completeness(importance_matrix): 145 | """"Compute completeness of the representation.""" 146 | per_factor = completeness_per_factor(importance_matrix) 147 | if importance_matrix.sum() == 0.: 148 | importance_matrix = np.ones_like(importance_matrix) 149 | factor_importance = importance_matrix.sum(axis=0) / importance_matrix.sum() 150 | return np.sum(per_factor*factor_importance) 151 | 152 | 153 | class PredMetric(): 154 | """ Impementation to calculate the AUROC for predicting each attribute 155 | """ 156 | def __init__(self, predictor = "RandomForest", output_type = "b", attr_names = None, *args, **kwargs): 157 | super(PredMetric, self).__init__(*args, **kwargs) 158 | 159 | self.attr_names = attr_names 160 | self._predictor = predictor 161 | self.output_type = output_type 162 | if predictor == "Linear": 163 | self.predictor_class = LogisticRegression 164 | self.params = {} 165 | # weights 166 | self.importances_attr = "coef_" 167 | elif predictor == "RandomForest": 168 | self.predictor_class = RandomForestClassifier 169 | self.importances_attr = "feature_importances_" 170 | self.params = {"oob_score": True} 171 | else: 172 | raise NotImplementedError() 173 | 174 | self.TINY = 1e-12 175 | 176 | def evaluate(self, train_codes, train_attrs, test_codes, test_attrs): 177 | R = [] 178 | results = [] 179 | # train_codes, test_codes, train_attrs, test_attrs = train_test_split(codes, attrs, test_size=0.2) 180 | print("Calculate for attribute:") 181 | for j in range(train_attrs.shape[-1]): 182 | if isinstance(self.params, dict): 183 | predictor = self.predictor_class(**self.params) 184 | elif isinstance(self.params, list): 185 | predictor = self.predictor_class(**self.params[j]) 186 | else: 187 | raise NotImplementedError() 188 | predictor.fit(train_codes, train_attrs[:, j]) 189 | 190 | r = getattr(predictor, self.importances_attr)[:, None] 191 | R.append(np.abs(r)) 192 | # extract relative importance of each code variable in 193 | # predicting the j attribute 194 | if self.output_type == "b": 195 | test_pred_prob = predictor.predict_proba(test_codes)[:, 1] 196 | tmp_result = roc_auc_score(test_attrs[:, j], test_pred_prob) 197 | elif self.output_type == "c": 198 | test_pred = predictor.predict(test_codes) 199 | tmp_result = accuracy_score(test_attrs[:, j], test_pred) 200 | results.append(tmp_result) 201 | if self.attr_names is not None: 202 | print(j, self.attr_names[j], tmp_result) 203 | else: 204 | print(j, tmp_result) 205 | 206 | # R = np.hstack(R) #columnwise, predictions of each z 207 | results = np.array(results) 208 | 209 | return { 210 | "{}_avg_result".format(self._predictor): results.mean(), 211 | "{}_result".format(self._predictor): results 212 | } 213 | 214 | # function that takes a lists of latent indices, thresholds, and signs for classification 215 | class LatentClass(object): 216 | def __init__(self, targ_ind, lat_ind, is_pos, thresh, __max, __min): 217 | super(LatentClass, self).__init__() 218 | self.targ_ind = targ_ind 219 | self.lat_ind = lat_ind 220 | self.is_pos = is_pos 221 | self.thresh = thresh 222 | self._max = __max 223 | self._min = __min 224 | self.it = list(zip(self.targ_ind, self.lat_ind, self.is_pos, self.thresh)) 225 | 226 | def __call__(self, z, y_dim): 227 | # expect z to be [batch, z_dim] 228 | out = torch.ones((z.shape[0], y_dim)) 229 | for t_i, l_i, is_pos, t in self.it: 230 | ma, mi = self._max[l_i], self._min[l_i] 231 | thr = t * (ma - mi) + mi 232 | 
res = (z[:, l_i] >= thr if is_pos else z[:, l_i] < thr).type(torch.int) 233 | out[:, t_i] = res 234 | return out 235 | 236 | class TADMetric(): 237 | """ Impementation of the metric in: 238 | NashAE: Disentangling Representations Through Adversarial Covariance Minimization 239 | The code is from: 240 | https://github.com/ericyeats/nashae-beamsynthesis 241 | """ 242 | def __init__(self, y_dim, all_attrs): 243 | self.y_dim = y_dim 244 | self.all_attrs = all_attrs 245 | 246 | def calculate_auroc(self, targ, targ_ind, lat_ind, z, _ma, _mi, stepsize=0.01): 247 | thr = torch.arange(0.0, 1.0001, step=stepsize) 248 | total = targ.shape[0] 249 | pos_total = targ.sum(dim=0)[targ_ind].item() 250 | neg_total = total - pos_total 251 | p_fpr_tpr = torch.zeros((thr.shape[0], 2)) 252 | n_fpr_tpr = torch.zeros((thr.shape[0], 2)) 253 | for i, t in enumerate(thr): 254 | local_lc = LatentClass([targ_ind], [lat_ind], [True], [t], _ma, _mi) 255 | pred = local_lc(z.clone(), self.y_dim).to(targ.device) 256 | p_tp = torch.logical_and(pred == targ, pred).sum(dim=0)[targ_ind].item() 257 | p_fp = torch.logical_and(pred != targ, pred).sum(dim=0)[targ_ind].item() 258 | p_fpr_tpr[i][0] = p_fp / neg_total 259 | p_fpr_tpr[i][1] = p_tp / pos_total 260 | local_lc = LatentClass([targ_ind], [lat_ind], [False], [t], _ma, _mi) 261 | pred = local_lc(z.clone(), self.y_dim).to(targ.device) 262 | n_tp = torch.logical_and(pred == targ, pred).sum(dim=0)[targ_ind].item() 263 | n_fp = torch.logical_and(pred != targ, pred).sum(dim=0)[targ_ind].item() 264 | n_fpr_tpr[i][0] = n_fp / neg_total 265 | n_fpr_tpr[i][1] = n_tp / pos_total 266 | p_fpr_tpr = p_fpr_tpr.sort(dim=0)[0] 267 | n_fpr_tpr = n_fpr_tpr.sort(dim=0)[0] 268 | p_dists = p_fpr_tpr[1:, 0] - p_fpr_tpr[:-1, 0] 269 | p_area = (p_fpr_tpr[1:, 1] * p_dists).sum().item() 270 | n_dists = n_fpr_tpr[1:, 0] - n_fpr_tpr[:-1, 0] 271 | n_area = (n_fpr_tpr[1:, 1] * n_dists).sum().item() 272 | return p_area, n_area 273 | 274 | def aurocs(self, _z, targ, targ_ind, _ma, _mi): 275 | # perform a grid search of lat_ind to find the best classification metric 276 | aurocs = torch.ones(_z.shape[1]) * 0.5 # initialize as random guess 277 | for lat_ind in range(_z.shape[1]): 278 | if _ma[lat_ind] - _mi[lat_ind] > 0.2: 279 | p_auroc, n_auroc = self.calculate_auroc(targ, targ_ind, lat_ind, _z.clone(), _ma, _mi) 280 | m_auroc = max(p_auroc, n_auroc) 281 | aurocs[lat_ind] = m_auroc 282 | # print("{}\t{:1.3f}".format(lat_ind, m_auroc)) 283 | return aurocs 284 | 285 | def aurocs_search(self, a, y): 286 | aurocs_all = torch.ones((y.shape[1], a.shape[1])) * 0.5 287 | base_rates_all = y.sum(dim=0) 288 | base_rates_all = base_rates_all / y.shape[0] 289 | _ma = a.max(dim=0)[0] 290 | _mi = a.min(dim=0)[0] 291 | print("Calculate for attribute:") 292 | for i in range(y.shape[1]): 293 | print(i) 294 | for j in range(a.shape[1]): 295 | aurocs_all[i, j] = max(roc_auc_score(y.numpy()[:, i], a.numpy()[:, j]), roc_auc_score(y.numpy()[:, i], -a.numpy()[:, j])) 296 | # aurocs_all[i] = self.aurocs(a, y, i, _ma, _mi) 297 | return aurocs_all.cpu(), base_rates_all.cpu() 298 | 299 | def evaluate(self, a, y): 300 | auroc_result, base_rates_raw = self.aurocs_search(torch.FloatTensor(a), torch.IntTensor(y)) 301 | base_rates = base_rates_raw.where(base_rates_raw <= 0.5, 1. 
- base_rates_raw) 302 | targ = torch.IntTensor(y) 303 | dim_y = y.shape[1] 304 | 305 | thresh = 0.75 306 | ent_red_thresh = 0.2 307 | max_aur, argmax_aur = torch.max(auroc_result.clone(), dim=1) 308 | norm_diffs = torch.zeros(dim_y) 309 | aurs_diffs = torch.zeros(dim_y) 310 | for ind, tag, max_a, argmax_a, aurs in zip(range(dim_y), self.all_attrs, max_aur.clone(), argmax_aur.clone(), 311 | auroc_result.clone()): 312 | norm_aurs = (aurs.clone() - 0.5) / (aurs.clone()[argmax_a] - 0.5) 313 | aurs_next = aurs.clone() 314 | aurs_next[argmax_a] = 0.0 315 | aurs_diff = max_a - aurs_next.max() 316 | aurs_diffs[ind] = aurs_diff 317 | norm_aurs[argmax_a] = 0.0 318 | norm_diff = 1. - norm_aurs.max() 319 | norm_diffs[ind] = norm_diff 320 | 321 | # calculate mutual information shared between attributes 322 | # determine which share a lot of information with each other 323 | with torch.no_grad(): 324 | not_targ = 1 - targ 325 | j_prob = lambda x, y: torch.logical_and(x, y).sum() / x.numel() 326 | mi = lambda jp, px, py: 0. if jp == 0. or px == 0. or py == 0. else jp * torch.log(jp / (px * py)) 327 | 328 | # Compute the Mutual Information (MI) between the labels 329 | mi_mat = torch.zeros((dim_y, dim_y)) 330 | for i in range(dim_y): 331 | # get the marginal of i 332 | i_mp = targ[:, i].sum() / targ.shape[0] 333 | for j in range(dim_y): 334 | j_mp = targ[:, j].sum() / targ.shape[0] 335 | # get the joint probabilities of FF, FT, TF, TT 336 | # FF 337 | jp = j_prob(not_targ[:, i], not_targ[:, j]) 338 | pi = 1. - i_mp 339 | pj = 1. - j_mp 340 | mi_mat[i][j] += mi(jp, pi, pj) 341 | # FT 342 | jp = j_prob(not_targ[:, i], targ[:, j]) 343 | pi = 1. - i_mp 344 | pj = j_mp 345 | mi_mat[i][j] += mi(jp, pi, pj) 346 | # TF 347 | jp = j_prob(targ[:, i], not_targ[:, j]) 348 | pi = i_mp 349 | pj = 1. - j_mp 350 | mi_mat[i][j] += mi(jp, pi, pj) 351 | # TT 352 | jp = j_prob(targ[:, i], targ[:, j]) 353 | pi = i_mp 354 | pj = j_mp 355 | mi_mat[i][j] += mi(jp, pi, pj) 356 | 357 | mi_maxes, mi_inds = (mi_mat * (1 - torch.eye(dim_y))).max(dim=1) 358 | ent_red_prop = 1. 
- (mi_mat.diag() - mi_maxes) / mi_mat.diag() 359 | 360 | # calculate Average Norm AUROC Diff when best detector score is at a certain threshold 361 | filt = (max_aur >= thresh).logical_and(ent_red_prop <= ent_red_thresh) 362 | aurs_diffs_filt = aurs_diffs[filt] 363 | return aurs_diffs_filt.sum().item(), auroc_result.cpu().numpy(), len(aurs_diffs_filt) 364 | 365 | if __name__ == "__main__": 366 | dataset = "celeba" 367 | args = parse_args() 368 | data_dict = np.load("{}_{}_latent.npz".format(args.model, generate_exp_string(args).replace(".", "_")), allow_pickle=True) 369 | 370 | if dataset == "celeba": 371 | y_names = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 372 | 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 373 | 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 374 | 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 375 | 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 376 | 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 377 | 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 378 | 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young' 379 | ] 380 | output_type = "b" 381 | elif dataset == "fmnist": 382 | y_names = ["Class"] 383 | output_type = "c" 384 | elif dataset == "cifar10": 385 | y_names = ["Class"] 386 | output_type = "c" 387 | elif dataset == "ffhq": 388 | y_names = ["Age", "Gender", "Glass"] 389 | output_type = "c" 390 | elif dataset == "3dshapes": 391 | y_names = ['Floor hue', 'Wall hue', 'Object hue:', 'Scale', 'Shape', 'Orientation'] 392 | output_type = "c" 393 | 394 | if dataset == "celeba": 395 | a = data_dict["all_a"][:10000,:] 396 | y = data_dict["all_attr"][:10000,:].astype(np.int) 397 | elif dataset == "ffhq": 398 | a = data_dict["all_a"][:,:] 399 | y = pd.read_csv("ffhq_labels.csv") 400 | y = y.values[:,2:].astype(np.int) 401 | y = y[:69952, :] 402 | elif dataset == "3dshapes": 403 | a = data_dict["all_a"][:10000, :] 404 | y = data_dict["all_attr"][:10000, :] 405 | 406 | y[:, 0] = y[:, 0] * 10 407 | y[:, 1] = y[:, 1] * 10 408 | y[:, 2] = y[:, 2] * 10 409 | y[:, 3] = y[:, 3] * 14 - 10.5 410 | y[:, 5] = y[:, 5] * 14 / 60 + 7 411 | y = y.astype(np.int) 412 | else: 413 | a = data_dict["all_a"] 414 | if len(data_dict["all_attr"].shape) == 2: 415 | y = data_dict["all_attr"][:, :].astype(np.int) 416 | else: 417 | y = data_dict["all_attr"][:, np.newaxis].astype(np.int) 418 | 419 | kf = KFold(n_splits=5, shuffle=True, random_state=0) 420 | preds_rf, avg_preds_rf = [], [] 421 | preds_ln, avg_preds_ln = [], [] 422 | if dataset == "celeba": 423 | tad_scores, tad_attrs = [], [] 424 | if dataset == "3dshapes": 425 | dci_scores = [] 426 | 427 | for tr_idx, te_idx in kf.split(a): 428 | tr_a, te_a = a[tr_idx], a[te_idx] 429 | tr_y, te_y = y[tr_idx], y[te_idx] 430 | std = StandardScaler() 431 | std.fit(tr_a) 432 | tr_a = std.transform(tr_a) 433 | te_a = std.transform(te_a) 434 | 435 | if dataset == "celeba": 436 | tad_metric = TADMetric(y.shape[1], y_names) 437 | tad_score, auroc_result, num_attr = tad_metric.evaluate(tr_a, tr_y) 438 | # 439 | print("TAD SCORE: ", tad_score, "Attributes Captured: ", num_attr) 440 | tad_scores.append(tad_score) 441 | tad_attrs.append(num_attr) 442 | # sns.heatmap(auroc_result.transpose(), xticklabels = y_names) 443 | # plt.show() 444 | 445 | if dataset == "3dshapes": 446 | dci_result = compute_dci(tr_a.transpose(), tr_y.transpose(), te_a.transpose(), 
te_y.transpose()) 447 | R = dci_result['importance'] 448 | print("DCI Score", dci_result['disentanglement']) 449 | dci_scores.append(dci_result['disentanglement']) 450 | print(dci_result["informativeness_train"]) 451 | print(dci_result["informativeness_test"]) 452 | 453 | pred_metric = PredMetric("Linear", output_type, y_names) 454 | pred_result = pred_metric.evaluate(tr_a, tr_y, te_a, te_y) 455 | 456 | print("Avg Result", pred_result['Linear_avg_result']) 457 | avg_preds_ln.append(pred_result['Linear_avg_result']) 458 | preds_ln.append(pred_result['Linear_result']) 459 | 460 | if dataset == "3dshapes": 461 | dci_scores = np.array(dci_scores) 462 | print("DCI Score, {:.4f} \pm {:.4f}".format(np.array(dci_scores).mean(), np.array(dci_scores).std())) 463 | 464 | 465 | if dataset == "celeba": 466 | tad_scores = np.array(tad_scores) 467 | tad_attrs = np.array(tad_attrs) 468 | print("TAD Score, {:.4f} \pm {:.4f}".format(np.array(tad_scores).mean(), np.array(tad_scores).std())) 469 | print("TAD Attr, {:.4f} \pm {:.4f}".format(np.array(tad_attrs).mean(), np.array(tad_attrs).std())) 470 | 471 | avg_preds_ln = np.array(avg_preds_ln) 472 | print("Avg Acc (Linear), {:.4f} \pm {:.4f}".format(np.array(avg_preds_ln).mean(), np.array(avg_preds_ln).std())) 473 | 474 | preds_ln = np.vstack(preds_ln) 475 | for a_idx in range(preds_ln.shape[1]): 476 | print("Acc for {} (Linear), {:.4f} \pm {:.4f}".format(y_names[a_idx], preds_ln[:, a_idx].mean(), preds_ln[:, a_idx].std())) -------------------------------------------------------------------------------- /eval_fid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | # learned latent 5 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64 6 | 7 | python run.py --model diff --mode save_latent --disent_metric tad --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64 8 | 9 | python run.py --model diff --mode train_latent_ddim --a_dim 256 --epochs 50 --mmd_weight 0.1 --dataset celeba --deterministic --save_epoch 10 --prior regular --r_seed 64 10 | 11 | python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.1 --sampling_number 10000 --epochs 50 --dataset celeba --is_latent --prior regular --r_seed 64 12 | 13 | # without learned latent 14 | # python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64 15 | 16 | # python run.py --model vanilla --mode train --a_dim 256 --epochs 50 --dataset celeba --batch_size 128 --save_epochs 10 --prior regular --r_seed 64 17 | 18 | # python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.01 --sampling_number 10000 --epochs 50 --dataset celeba --prior regular --r_seed 64 -------------------------------------------------------------------------------- /flowchart.drawio-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isjakewong/InfoDiffusion/75c7b31953692af2f122e7fe8715902e160c9de5/flowchart.drawio-1.png -------------------------------------------------------------------------------- /gen_fid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 
export CUDA_VISIBLE_DEVICES=0 3 | python gen_fid_stats.py celeba ./celeba_imgs -------------------------------------------------------------------------------- /gen_fid_stats.py: -------------------------------------------------------------------------------- 1 | from cleanfid import fid 2 | import sys 3 | if __name__ == '__main__': 4 | custom_name = sys.argv[1] 5 | dataset_path = sys.argv[2] 6 | # TODO: remove saved stats if you want to reproduce. 7 | # fid.remove_custom_stats(custom_name, mode="clean") 8 | print(f'Generating clean-fid for dataset {custom_name} located at {dataset_path}') 9 | fid.make_custom_stats(custom_name, dataset_path, mode="clean") -------------------------------------------------------------------------------- /graphicalabstract.drawio_v2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isjakewong/InfoDiffusion/75c7b31953692af2f122e7fe8715902e160c9de5/graphicalabstract.drawio_v2-1.png -------------------------------------------------------------------------------- /interpolate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python run.py --model diff --mode interpolate --mmd_weight 0.1 --img_id 0 --a_dim 32 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64 4 | -------------------------------------------------------------------------------- /latent_quality.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python run.py --model diff --mode latent_quality --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64 4 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from modules import * 5 | from utils import compute_mmd, gaussian_mixture, swiss_roll 6 | 7 | class UNet(nn.Module): 8 | def __init__(self, T, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, shape=None): 9 | super().__init__() 10 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 11 | tdim = ch * 4 12 | self.time_embedding = TimeEmbedding(T, ch, tdim) 13 | 14 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1) 15 | 16 | self.downblocks = nn.ModuleList() 17 | chs = [ch] # record output channel when dowmsample for upsample 18 | now_ch = ch 19 | for i, mult in enumerate(ch_mult): 20 | out_ch = ch * mult 21 | for _ in range(num_res_blocks): 22 | self.downblocks.append(ResBlock( 23 | in_ch=now_ch, out_ch=out_ch, tdim=tdim, 24 | dropout=dropout, attn=(i in attn))) 25 | now_ch = out_ch 26 | chs.append(now_ch) 27 | if i != len(ch_mult) - 1: 28 | self.downblocks.append(DownSample(now_ch)) 29 | chs.append(now_ch) 30 | 31 | self.middleblocks = nn.ModuleList([ 32 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True, crossattn=False), 33 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False, crossattn=False), 34 | ]) 35 | 36 | self.upblocks = nn.ModuleList() 37 | for i, mult in reversed(list(enumerate(ch_mult))): 38 | out_ch = ch * mult 39 | for _ in range(num_res_blocks + 1): 40 | self.upblocks.append(ResBlock( 41 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, 42 | dropout=dropout, attn=(i in attn))) 43 | now_ch = 
out_ch 44 | if i != 0: 45 | self.upblocks.append(UpSample(now_ch)) 46 | assert len(chs) == 0 47 | 48 | self.tail = nn.Sequential( 49 | nn.GroupNorm(32, now_ch), 50 | nn.SiLU(), 51 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1) 52 | ) 53 | 54 | self.initialize() 55 | 56 | def initialize(self): 57 | init.xavier_uniform_(self.head.weight) 58 | init.zeros_(self.head.bias) 59 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 60 | init.zeros_(self.tail[-1].bias) 61 | 62 | def forward(self, x, t): 63 | # Timestep embedding 64 | temb = self.time_embedding(t) 65 | # Downsampling 66 | h = self.head(x) 67 | hs = [h] 68 | 69 | for layer in self.downblocks: 70 | h = layer(h, temb) 71 | hs.append(h) 72 | 73 | # Middle 74 | for layer in self.middleblocks: 75 | if isinstance(layer, ResBlock): 76 | h = layer(h, temb) 77 | else: 78 | h = layer(h) 79 | 80 | # Upsampling 81 | for layer in self.upblocks: 82 | if isinstance(layer, ResBlock): 83 | h = torch.cat([h, hs.pop()], dim=1) 84 | h = layer(h, temb) 85 | h = self.tail(h) 86 | 87 | assert len(hs) == 0 88 | return h 89 | 90 | 91 | class MLPLNAct(nn.Module): 92 | def __init__( 93 | self, 94 | in_channels: int, 95 | out_channels: int, 96 | norm: bool, 97 | use_cond: bool, 98 | activation: str = None, 99 | cond_channels: int = None, 100 | condition_bias: float = 0, 101 | dropout: float = 0, 102 | ): 103 | super().__init__() 104 | self.activation = activation 105 | if self.activation is not None: 106 | self.act = nn.SiLU() 107 | else: 108 | self.act = nn.Identity() 109 | self.condition_bias = condition_bias 110 | self.use_cond = use_cond 111 | 112 | self.linear = nn.Linear(in_channels, out_channels) 113 | if self.use_cond: 114 | self.linear_emb = nn.Linear(cond_channels, out_channels) 115 | self.cond_layers = nn.Sequential(self.act, self.linear_emb) 116 | if norm: 117 | self.norm = nn.LayerNorm(out_channels) 118 | else: 119 | self.norm = nn.Identity() 120 | 121 | if dropout > 0: 122 | self.dropout = nn.Dropout(dropout) 123 | else: 124 | self.dropout = nn.Identity() 125 | 126 | self.init_weights() 127 | 128 | def init_weights(self): 129 | for module in self.modules(): 130 | if isinstance(module, nn.Linear): 131 | if self.activation == 'relu': 132 | init.kaiming_normal_(module.weight, 133 | a=0, 134 | nonlinearity='relu') 135 | elif self.activation == 'leaky_relu': 136 | init.kaiming_normal_(module.weight, 137 | a=0.2, 138 | nonlinearity='leaky_relu') 139 | elif self.activation == 'silu': 140 | init.kaiming_normal_(module.weight, 141 | a=0, 142 | nonlinearity='relu') 143 | else: 144 | # leave it as default 145 | pass 146 | 147 | def forward(self, x, cond=None): 148 | x = self.linear(x) 149 | if self.use_cond: 150 | # (n, c) or (n, c * 2) 151 | cond = self.cond_layers(cond) 152 | 153 | # scale shift first 154 | x = x * (self.condition_bias + cond) 155 | 156 | # then norm 157 | x = self.norm(x) 158 | else: 159 | # no condition 160 | x = self.norm(x) 161 | x = self.act(x) 162 | x = self.dropout(x) 163 | return x 164 | 165 | 166 | class LatentUNet(nn.Module): 167 | def __init__(self, T, num_layers=10, dropout=0.1, shape=None, activation='silu', 168 | num_time_emb_channels: int = 64, num_time_layers: int = 2): 169 | super().__init__() 170 | self.num_time_emb_channels = num_time_emb_channels 171 | self.shape = shape 172 | 173 | layers = [] 174 | for i in range(num_time_layers): 175 | if i == 0: 176 | a = num_time_emb_channels 177 | b = shape[-1] 178 | else: 179 | a = shape[-1] 180 | b = shape[-1] 181 | layers.append(nn.Linear(a, b)) 182 | if i < 
num_time_layers - 1: 183 | layers.append(nn.SiLU()) 184 | self.time_embed = nn.Sequential(*layers) 185 | 186 | self.skip_layers = list(range(1, num_layers)) 187 | self.layers = nn.ModuleList([]) 188 | for i in range(num_layers): 189 | if i == 0: 190 | act = activation 191 | norm = True 192 | cond = True 193 | a, b = shape[-1], shape[-1] * 4 194 | dropout = dropout 195 | elif i == num_layers - 1: 196 | act = None 197 | norm = False 198 | cond = False 199 | a, b = shape[-1] * 4, shape[-1] 200 | dropout = 0 201 | else: 202 | act = 'silu' 203 | norm = True 204 | cond = True 205 | a, b = shape[-1] * 4, shape[-1] * 4 206 | dropout = dropout 207 | 208 | if i in self.skip_layers: 209 | a += shape[-1] 210 | 211 | self.layers.append( 212 | MLPLNAct( 213 | a, 214 | b, 215 | norm=norm, 216 | activation=act, 217 | cond_channels=shape[-1], 218 | use_cond=cond, 219 | condition_bias=1, 220 | dropout=dropout, 221 | )) 222 | 223 | def forward(self, x, t): 224 | # Timestep embedding 225 | t = timestep_embedding(t, self.num_time_emb_channels) 226 | temb = self.time_embed(t) 227 | 228 | h = x 229 | for i in range(len(self.layers)): 230 | if i in self.skip_layers: 231 | # injecting input into the hidden layers 232 | h = torch.cat([h, x], dim=1) 233 | h = self.layers[i].forward(x=h, cond=temb) 234 | return h 235 | 236 | 237 | class AuxiliaryUNet(nn.Module): 238 | def __init__(self, T, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=32, shape=None): 239 | super().__init__() 240 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 241 | tdim = ch * 4 242 | self.a_dim = a_dim 243 | self.time_embedding = TimeEmbedding(T, ch, tdim) 244 | self.fc_a = nn.Linear(self.a_dim, tdim) 245 | 246 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1) 247 | 248 | self.downblocks = nn.ModuleList() 249 | chs = [ch] # record output channel when dowmsample for upsample 250 | now_ch = ch 251 | for i, mult in enumerate(ch_mult): 252 | out_ch = ch * mult 253 | for _ in range(num_res_blocks): 254 | self.downblocks.append(AuxResBlock( 255 | in_ch=now_ch, out_ch=out_ch, tdim=tdim, 256 | dropout=dropout, attn=(i in attn))) 257 | now_ch = out_ch 258 | chs.append(now_ch) 259 | if i != len(ch_mult) - 1: 260 | self.downblocks.append(DownSample(now_ch)) 261 | chs.append(now_ch) 262 | 263 | self.middleblocks = nn.ModuleList([ 264 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=True, crossattn=False), 265 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=False, crossattn=False), 266 | ]) 267 | 268 | self.upblocks = nn.ModuleList() 269 | for i, mult in reversed(list(enumerate(ch_mult))): 270 | out_ch = ch * mult 271 | for _ in range(num_res_blocks + 1): 272 | self.upblocks.append(AuxResBlock( 273 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, 274 | dropout=dropout, attn=(i in attn))) 275 | now_ch = out_ch 276 | if i != 0: 277 | self.upblocks.append(UpSample(now_ch)) 278 | assert len(chs) == 0 279 | 280 | self.tail = nn.Sequential( 281 | nn.GroupNorm(32, now_ch), 282 | nn.SiLU(), 283 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1) 284 | ) 285 | 286 | self.initialize() 287 | 288 | def initialize(self): 289 | init.xavier_uniform_(self.head.weight) 290 | init.zeros_(self.head.bias) 291 | init.xavier_uniform_(self.fc_a.weight) 292 | init.zeros_(self.fc_a.bias) 293 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 294 | init.zeros_(self.tail[-1].bias) 295 | 296 | def forward(self, x, t, a): 297 | # Latent embedding 298 | aemb = self.fc_a(a) 299 | 300 | # Timestep 
embedding 301 | temb = self.time_embedding(t) 302 | 303 | # Downsampling 304 | h = self.head(x) 305 | hs = [h] 306 | 307 | for layer in self.downblocks: 308 | h = layer(h, temb, aemb) 309 | hs.append(h) 310 | 311 | # Middle 312 | for layer in self.middleblocks: 313 | if isinstance(layer, AuxResBlock): 314 | h = layer(h, temb, aemb) 315 | else: 316 | h = layer(h) 317 | 318 | # Upsampling 319 | for layer in self.upblocks: 320 | if isinstance(layer, AuxResBlock): 321 | h = torch.cat([h, hs.pop()], dim=1) 322 | h = layer(h, temb, aemb) 323 | h = self.tail(h) 324 | 325 | assert len(hs) == 0 326 | return h 327 | 328 | 329 | class BottleneckAuxUNet(nn.Module): 330 | def __init__(self, T, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=32, shape=None): 331 | super().__init__() 332 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 333 | tdim = ch * 4 334 | self.a_dim = a_dim 335 | self.time_embedding = TimeEmbedding(T, ch, tdim) 336 | self.fc_a = nn.Sequential( 337 | nn.SiLU(), 338 | nn.Linear(self.a_dim, tdim) 339 | ) 340 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1) 341 | 342 | self.downblocks = nn.ModuleList() 343 | chs = [ch] # record output channel when dowmsample for upsample 344 | now_ch = ch 345 | for i, mult in enumerate(ch_mult): 346 | out_ch = ch * mult 347 | for _ in range(num_res_blocks): 348 | self.downblocks.append(ResBlock( 349 | in_ch=now_ch, out_ch=out_ch, tdim=tdim, 350 | dropout=dropout, attn=(i in attn))) 351 | now_ch = out_ch 352 | chs.append(now_ch) 353 | if i != len(ch_mult) - 1: 354 | self.downblocks.append(DownSample(now_ch)) 355 | chs.append(now_ch) 356 | 357 | self.middleblocks = nn.ModuleList([ 358 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=True, crossattn=False), 359 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=False, crossattn=False), 360 | ]) 361 | 362 | self.upblocks = nn.ModuleList() 363 | for i, mult in reversed(list(enumerate(ch_mult))): 364 | out_ch = ch * mult 365 | for _ in range(num_res_blocks + 1): 366 | self.upblocks.append(ResBlock( 367 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, 368 | dropout=dropout, attn=(i in attn))) 369 | now_ch = out_ch 370 | if i != 0: 371 | self.upblocks.append(UpSample(now_ch)) 372 | assert len(chs) == 0 373 | 374 | self.tail = nn.Sequential( 375 | nn.GroupNorm(32, now_ch), 376 | nn.SiLU(), 377 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1) 378 | ) 379 | 380 | self.initialize() 381 | 382 | def initialize(self): 383 | init.xavier_uniform_(self.head.weight) 384 | init.zeros_(self.head.bias) 385 | init.kaiming_normal_(self.fc_a[1].weight, 386 | a=0, 387 | nonlinearity='relu') 388 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 389 | init.zeros_(self.tail[-1].bias) 390 | 391 | def forward(self, x, t, a): 392 | # Latent embedding 393 | aemb = self.fc_a(a) 394 | 395 | # Timestep embedding 396 | temb = self.time_embedding(t) 397 | 398 | # Downsampling 399 | h = self.head(x) 400 | hs = [h] 401 | 402 | for layer in self.downblocks: 403 | h = layer(h, temb) 404 | hs.append(h) 405 | 406 | # Middle 407 | for layer in self.middleblocks: 408 | if isinstance(layer, AuxResBlock): 409 | h = layer(h, temb, aemb) 410 | else: 411 | h = layer(h) 412 | 413 | # Upsampling 414 | for layer in self.upblocks: 415 | if isinstance(layer, ResBlock): 416 | h = torch.cat([h, hs.pop()], dim=1) 417 | h = layer(h, temb) 418 | h = self.tail(h) 419 | 420 | assert len(hs) == 0 421 | return h 422 | 423 | 424 | class Encoder(nn.Module): 425 | def 
__init__(self, ch=64, ch_mult=[1,2,4,8,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=32, shape=None): 426 | super().__init__() 427 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 428 | 429 | self.shape = shape 430 | self.a_dim = a_dim 431 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1) 432 | self.downblocks = nn.ModuleList() 433 | chs = [ch] # record output channel when dowmsample for upsample 434 | now_ch = ch 435 | for i, mult in enumerate(ch_mult): 436 | out_ch = ch * mult 437 | for _ in range(num_res_blocks): 438 | self.downblocks.append(ResBlock_encoder( 439 | in_ch=now_ch, out_ch=out_ch, 440 | dropout=dropout, attn=(i in attn))) 441 | now_ch = out_ch 442 | chs.append(now_ch) 443 | if i != len(ch_mult) - 1: 444 | self.downblocks.append(DownSample(now_ch)) 445 | chs.append(now_ch) 446 | 447 | self.middleblocks = nn.ModuleList([ 448 | ResBlock_encoder(now_ch, now_ch, dropout, attn=True), 449 | ResBlock_encoder(now_ch, now_ch, dropout, attn=False), 450 | ]) 451 | 452 | self.upblocks = nn.ModuleList() 453 | for i, mult in reversed(list(enumerate(ch_mult))): 454 | out_ch = ch * mult 455 | for _ in range(num_res_blocks + 1): 456 | self.upblocks.append(ResBlock_encoder( 457 | in_ch=chs.pop() + now_ch, out_ch=out_ch, 458 | dropout=dropout, attn=(i in attn))) 459 | now_ch = out_ch 460 | if i != 0: 461 | self.upblocks.append(UpSample(now_ch)) 462 | assert len(chs) == 0 463 | 464 | self.tail = nn.Sequential( 465 | nn.GroupNorm(32, now_ch), 466 | nn.SiLU(), 467 | nn.Conv2d(now_ch, 1, 3, stride=1, padding=1) 468 | ) 469 | 470 | self.fc_a = nn.Linear(self.shape[1]*self.shape[2], self.a_dim) 471 | self.fc_mu = nn.Linear(self.a_dim, self.a_dim) 472 | self.fc_var = nn.Linear(self.a_dim, self.a_dim) 473 | 474 | self.initialize() 475 | 476 | def initialize(self): 477 | init.xavier_uniform_(self.head.weight) 478 | init.zeros_(self.head.bias) 479 | init.xavier_uniform_(self.fc_a.weight) 480 | init.zeros_(self.fc_a.bias) 481 | init.xavier_uniform_(self.fc_mu.weight) 482 | init.zeros_(self.fc_mu.bias) 483 | init.xavier_uniform_(self.fc_var.weight) 484 | init.zeros_(self.fc_var.bias) 485 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 486 | init.zeros_(self.tail[-1].bias) 487 | 488 | def forward(self, x): 489 | # Downsampling 490 | h = self.head(x) 491 | 492 | hs = [h] 493 | for layer in self.downblocks: 494 | if isinstance(layer, ResBlock_encoder): 495 | h = layer(h) 496 | else: 497 | h = layer(h, None, None) # for downsample module, 0 is placeholder for temb and aemb 498 | hs.append(h) 499 | # Middle 500 | for layer in self.middleblocks: 501 | h = layer(h) 502 | # Upsampling 503 | for layer in self.upblocks: 504 | if isinstance(layer, ResBlock_encoder): 505 | h = torch.cat([h, hs.pop()], dim=1) 506 | h = layer(h) 507 | else: 508 | h = layer(h, None, None) # for upsample module, 0 is placeholder for temb and aemb 509 | 510 | h = torch.flatten(self.tail(h), start_dim=1) 511 | 512 | a = self.fc_a(h) 513 | mu = self.fc_mu(a) 514 | log_var = self.fc_var(a) 515 | a_q = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var) 516 | 517 | assert len(hs) == 0 518 | return a, a_q, mu, log_var 519 | 520 | 521 | class Decoder(nn.Module): 522 | def __init__(self, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=10, shape=None): 523 | super().__init__() 524 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 525 | self.a_dim = a_dim 526 | self.shape = shape 527 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, 
stride=1, padding=1) 528 | self.downblocks = nn.ModuleList() 529 | chs = [ch] # record output channel when dowmsample for upsample 530 | now_ch = ch 531 | for i, mult in enumerate(ch_mult): 532 | out_ch = ch * mult 533 | for _ in range(num_res_blocks): 534 | self.downblocks.append(ResBlock_encoder( 535 | in_ch=now_ch, out_ch=out_ch, 536 | dropout=dropout, attn=(i in attn))) 537 | now_ch = out_ch 538 | chs.append(now_ch) 539 | if i != len(ch_mult) - 1: 540 | self.downblocks.append(DownSample(now_ch)) 541 | chs.append(now_ch) 542 | 543 | self.middleblocks = nn.ModuleList([ 544 | ResBlock_encoder(now_ch, now_ch, dropout, attn=True), 545 | ResBlock_encoder(now_ch, now_ch, dropout, attn=False), 546 | ]) 547 | 548 | self.upblocks = nn.ModuleList() 549 | for i, mult in reversed(list(enumerate(ch_mult))): 550 | out_ch = ch * mult 551 | for _ in range(num_res_blocks + 1): 552 | self.upblocks.append(ResBlock_encoder( 553 | in_ch=chs.pop() + now_ch, out_ch=out_ch, 554 | dropout=dropout, attn=(i in attn))) 555 | now_ch = out_ch 556 | if i != 0: 557 | self.upblocks.append(UpSample(now_ch)) 558 | assert len(chs) == 0 559 | 560 | self.tail = nn.Sequential( 561 | nn.GroupNorm(32, now_ch), 562 | nn.SiLU(), 563 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1) 564 | ) 565 | 566 | self.fc_a = nn.Linear(self.a_dim, self.shape[0]*self.shape[1]*self.shape[2]) 567 | 568 | self.initialize() 569 | 570 | def initialize(self): 571 | init.xavier_uniform_(self.head.weight) 572 | init.zeros_(self.head.bias) 573 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) 574 | init.zeros_(self.tail[-1].bias) 575 | 576 | def forward(self, a): 577 | # Latent embedding 578 | aemb = self.fc_a(a) 579 | h = aemb.reshape(a.shape[0], self.shape[0], self.shape[1], self.shape[2]) 580 | 581 | # Downsampling 582 | h = self.head(h) 583 | hs = [h] 584 | for layer in self.downblocks: 585 | if isinstance(layer, ResBlock_encoder): 586 | h = layer(h) 587 | else: 588 | h = layer(h, None) # for downsample module, 0 is placeholder 589 | hs.append(h) 590 | # Middle 591 | for layer in self.middleblocks: 592 | h = layer(h) 593 | # Upsampling 594 | for layer in self.upblocks: 595 | if isinstance(layer, ResBlock_encoder): 596 | h = torch.cat([h, hs.pop()], dim=1) 597 | h = layer(h) 598 | else: 599 | h = layer(h, None) # for upsample module, 0 is placeholder 600 | rec = self.tail(h) 601 | 602 | assert len(hs) == 0 603 | return rec 604 | 605 | class InfoDiff(nn.Module): 606 | def __init__(self, args, device, shape): 607 | ''' 608 | beta_1 : beta_1 of diffusion process 609 | beta_T : beta_T of diffusion process 610 | T : Diffusion Steps 611 | ''' 612 | 613 | super().__init__() 614 | self.device = device 615 | self.alpha_bars = torch.cumprod(1 - torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps), dim=0).to(device=device) 616 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps).to(device = device) 617 | self.alphas = 1 - self.betas 618 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]]) 619 | if args.input_size == 28: 620 | ch_mult = [1,2,4,] 621 | else: 622 | ch_mult = [1,2,2,2] 623 | if args.is_bottleneck: 624 | self.backbone = BottleneckAuxUNet(ch_mult=ch_mult, T=args.diffusion_steps, ch=args.unets_channels, a_dim=args.a_dim, shape=shape) 625 | else: 626 | self.backbone = AuxiliaryUNet(ch_mult=ch_mult, T=args.diffusion_steps, ch=args.unets_channels, a_dim=args.a_dim, shape=shape) 627 | self.encoder = Encoder(ch_mult=ch_mult, 
ch=args.encoder_channels, a_dim=args.a_dim, shape=shape) 628 | self.mmd_weight : float = args.mmd_weight 629 | self.kld_weight : float = args.kld_weight 630 | self.to(device) 631 | 632 | def loss_fn(self, args, x, idx=None, curr_epoch=0): 633 | ''' 634 | x : real data if idx==None else perturbation data 635 | idx : if None (training phase), we perturbed random index. 636 | ''' 637 | output, epsilon, a, mu, log_var = self.forward(x, idx=idx, get_target=True) 638 | 639 | # denoising matching term 640 | loss = (output - epsilon).square().mean() 641 | print('denoising loss:', loss) 642 | 643 | # reconstruction term 644 | x_0 = torch.sqrt(1 / self.alphas[0]) * (x - self.betas[0] / torch.sqrt(1 - self.alpha_bars[0]) * output) 645 | loss_rec = (x_0 - x).square().mean() 646 | loss += loss_rec / args.diffusion_steps 647 | print('recon loss:', loss_rec / args.diffusion_steps) 648 | 649 | if self.mmd_weight != 0 and self.kld_weight != 0: 650 | # MMD term 651 | if args.prior == 'regular': 652 | true_samples = torch.randn_like(a, device=self.device) 653 | elif args.prior == '10mix': 654 | prior = gaussian_mixture(args.batch_size, args.a_dim) 655 | true_samples = torch.FloatTensor(prior).to(device=self.device) 656 | elif args.prior == 'roll': 657 | prior = swiss_roll(args.batch_size) 658 | true_samples = torch.FloatTensor(prior).to(device=self.device) 659 | loss_mmd = compute_mmd(true_samples, mu) 660 | print('mmd loss:', args.mmd_weight * loss_mmd) 661 | loss += args.mmd_weight * loss_mmd 662 | # KLD term 663 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) 664 | if args.use_C: 665 | # KLD term w/ control constant 666 | self.C_max = torch.FloatTensor([args.C_max]).to(device=self.device) 667 | C = torch.clamp(self.C_max/args.epochs*curr_epoch, torch.FloatTensor([0]).to(device=self.device), self.C_max) 668 | loss += args.kld_weight * (kld_loss - C.squeeze(dim=0)).abs() 669 | else: 670 | print('kld loss:', args.kld_weight * kld_loss) 671 | loss += args.kld_weight * kld_loss 672 | elif args.mmd_weight != 0: 673 | # MMD term 674 | if args.prior == 'regular': 675 | true_samples = torch.randn_like(a, device=self.device) 676 | elif args.prior == '10mix': 677 | prior = gaussian_mixture(args.batch_size, args.a_dim) 678 | true_samples = torch.FloatTensor(prior).to(device=self.device) 679 | elif args.prior == 'roll': 680 | prior = swiss_roll(args.batch_size) 681 | true_samples = torch.FloatTensor(prior).to(device=self.device) 682 | loss_mmd = compute_mmd(true_samples, a) 683 | print('mmd loss:', args.mmd_weight * loss_mmd) 684 | loss += args.mmd_weight * loss_mmd 685 | elif args.kld_weight != 0: 686 | # KLD term 687 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) 688 | if args.use_C: 689 | # KLD term w/ control constant 690 | self.C_max = torch.FloatTensor([args.C_max]).to(device=self.device) 691 | C = torch.clamp(self.C_max/args.epochs*curr_epoch, torch.FloatTensor([0]).to(device=self.device), self.C_max) 692 | loss += args.kld_weight * (kld_loss - C.squeeze(dim=0)).abs() 693 | else: 694 | print('kld loss:', args.kld_weight * kld_loss) 695 | loss += args.kld_weight * kld_loss 696 | return loss 697 | 698 | def forward(self, x, idx=None, a=None, get_target=False): 699 | 700 | if idx is None: 701 | idx = torch.randint(0, len(self.alpha_bars), (x.size(0), )).to(device = self.device) 702 | used_alpha_bars = self.alpha_bars[idx][:, None, None, None] 703 | epsilon = torch.randn_like(x) 704 | x_tilde = torch.sqrt(used_alpha_bars) * 
x + torch.sqrt(1 - used_alpha_bars) * epsilon 705 | else: 706 | idx = torch.Tensor([idx for _ in range(x.size(0))]).to(device = self.device).long() 707 | x_tilde = x 708 | 709 | if a is None: 710 | a, a_q, mu, log_var = self.encoder(x) 711 | else: 712 | a_q = a 713 | 714 | if self.mmd_weight != 0 and self.kld_weight != 0: 715 | output = self.backbone(x_tilde, idx, a_q) 716 | elif self.mmd_weight == 0 and self.kld_weight == 0: 717 | output = self.backbone(x_tilde, idx, a) 718 | elif self.mmd_weight != 0: 719 | output = self.backbone(x_tilde, idx, a) 720 | elif self.kld_weight != 0: 721 | output = self.backbone(x_tilde, idx, a_q) 722 | 723 | return (output, epsilon, a, mu, log_var) if get_target else output 724 | 725 | 726 | class Diff(nn.Module): 727 | def __init__(self, args, device, shape): 728 | ''' 729 | beta_1 : beta_1 of diffusion process 730 | beta_T : beta_T of diffusion process 731 | T : Diffusion Steps 732 | ''' 733 | 734 | super().__init__() 735 | self.device = device 736 | self.alpha_bars = torch.cumprod(1 - torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps), dim=0).to(device=device) 737 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps).to(device = device) 738 | self.alphas = 1 - self.betas 739 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]]) 740 | self.is_latent = args.is_latent 741 | if args.mode == "train_latent_ddim": 742 | self.is_latent = True 743 | if args.input_size == 28: 744 | ch_mult = [1,2,4,] 745 | else: 746 | ch_mult = [1,2,4,8] 747 | if self.is_latent: 748 | self.backbone = LatentUNet(T=args.diffusion_steps, num_layers=10, dropout=0.1, shape=shape, activation='silu') 749 | else: 750 | self.backbone = UNet(ch_mult=ch_mult, T=args.diffusion_steps, ch=args.unets_channels, shape=shape) 751 | self.to(device) 752 | 753 | def loss_fn(self, args, x, idx=None, curr_epoch=0): 754 | ''' 755 | x : real data if idx==None else perturbation data 756 | idx : if None (training phase), we perturbed random index. 
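returns : the plain noise-prediction MSE, mean((eps_theta(x_t, t) - eps)^2); unlike InfoDiff.loss_fn, no reconstruction, MMD, or KLD terms are added for this vanilla model.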
757 | ''' 758 | output, epsilon = self.forward(x, idx=idx, get_target=True) 759 | # denoising matching term 760 | loss = (output - epsilon).square().mean() 761 | # print('denoising loss:', loss) 762 | return loss 763 | 764 | def forward(self, x, idx=None, get_target=False): 765 | 766 | if idx is None: 767 | idx = torch.randint(0, len(self.alpha_bars), (x.size(0), )).to(device = self.device) 768 | if self.is_latent: 769 | used_alpha_bars = self.alpha_bars[idx][:, None] 770 | else: 771 | used_alpha_bars = self.alpha_bars[idx][:, None, None, None] 772 | epsilon = torch.randn_like(x) 773 | x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon 774 | else: 775 | idx = torch.Tensor([idx for _ in range(x.size(0))]).to(device = self.device).long() 776 | x_tilde = x 777 | output = self.backbone(x_tilde, idx) 778 | 779 | return (output, epsilon) if get_target else output 780 | 781 | class VAE(nn.Module): 782 | def __init__(self, args, device, shape): 783 | super().__init__() 784 | self.device = device 785 | if args.input_size == 28: 786 | ch_mult = [1,2,4,] 787 | else: 788 | ch_mult = [1,2,4,8] 789 | self.encoder = Encoder(ch_mult=ch_mult, ch=args.encoder_channels, a_dim=args.a_dim, shape=shape) 790 | self.decoder = Decoder(ch_mult=ch_mult, ch=args.encoder_channels, a_dim=args.a_dim, shape=shape) 791 | self.mmd_weight : float = args.mmd_weight 792 | self.kld_weight : float = args.kld_weight 793 | self.to(device) 794 | 795 | def loss_fn(self, args, x, curr_epoch=0): 796 | reconstruction, a_q, mu, log_var = self.forward(x, get_target=True) 797 | # denoising matching term 798 | loss = (reconstruction - x).square().mean() 799 | print('reconstruction loss:', loss) 800 | # prior matching term 801 | if args.mmd_weight != 0: 802 | # MMD term 803 | true_samples = torch.randn_like(a_q, device=self.device) 804 | loss_mmd = compute_mmd(true_samples, a_q) 805 | print('mmd loss:', args.mmd_weight * loss_mmd) 806 | loss += args.mmd_weight * loss_mmd 807 | elif args.kld_weight != 0: 808 | # KLD term 809 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) 810 | if args.use_C: 811 | # KLD term w/ control constant 812 | self.C_max = torch.FloatTensor([args.C_max]).to(device=self.device) 813 | C = torch.clamp(self.C_max/args.epochs*curr_epoch, torch.FloatTensor([0]).to(device=self.device), self.C_max) 814 | loss += args.kld_weight * (kld_loss - C.squeeze(dim=0)).abs() 815 | print('kld-c loss:', (kld_loss - C.squeeze(dim=0)).abs()) 816 | else: 817 | loss += args.kld_weight * kld_loss 818 | print('kld loss:', args.kld_weight * kld_loss) 819 | return loss 820 | 821 | def forward(self, x, get_target=False): 822 | a, a_q, mu, log_var = self.encoder(x) 823 | 824 | if self.mmd_weight != 0 and self.kld_weight != 0: 825 | reconstruction = self.decoder(a_q) 826 | elif self.mmd_weight == 0 and self.kld_weight == 0: 827 | reconstruction = self.decoder(a) 828 | elif self.mmd_weight != 0: 829 | reconstruction = self.decoder(a_q) 830 | elif self.kld_weight != 0: 831 | reconstruction = self.decoder(a_q) 832 | 833 | return (reconstruction, a_q, mu, log_var) if get_target else reconstruction 834 | 835 | 836 | class FeatureClassfier(nn.Module): 837 | def __init__(self, args, output_dim = 40): 838 | super(FeatureClassfier, self).__init__() 839 | 840 | """build full connect layers for every attribute""" 841 | self.fc_set = {} 842 | 843 | self.fc = nn.Sequential( 844 | nn.Linear(args.a_dim, 512), 845 | nn.ReLU(True), 846 | nn.Dropout(p=0.5), 847 | nn.Linear(512, 128), 
848 | nn.ReLU(True), 849 | nn.Dropout(p=0.5), 850 | nn.Linear(128, output_dim), 851 | ) 852 | 853 | self.sigmoid = nn.Sigmoid() 854 | 855 | def forward(self, x): 856 | x = x.view(x.size(0), -1) # flatten 857 | res = self.fc(x) 858 | res = self.sigmoid(res) 859 | return res -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import torch.nn.functional as F 6 | 7 | from typing import Union 8 | 9 | class TimeEmbedding(nn.Module): 10 | def __init__(self, T, d_model, dim): 11 | assert d_model % 2 == 0 12 | super().__init__() 13 | emb = torch.arange(0, d_model, step=2) / torch.Tensor([d_model]) * math.log(10000) 14 | emb = torch.exp(-emb) 15 | pos = torch.arange(T).float() 16 | emb = pos[:, None] * emb[None, :] 17 | assert list(emb.shape) == [T, d_model // 2] 18 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) 19 | assert list(emb.shape) == [T, d_model // 2, 2] 20 | emb = emb.view(T, d_model) 21 | 22 | self.timembedding = nn.Sequential( 23 | nn.Embedding.from_pretrained(emb), 24 | nn.Linear(d_model, dim), 25 | nn.SiLU(), 26 | nn.Linear(dim, dim), 27 | ) 28 | self.initialize() 29 | 30 | def initialize(self): 31 | for module in self.modules(): 32 | if isinstance(module, nn.Linear): 33 | init.xavier_uniform_(module.weight) 34 | init.zeros_(module.bias) 35 | 36 | def forward(self, t): 37 | emb = self.timembedding(t) 38 | return emb 39 | 40 | 41 | def timestep_embedding(timesteps, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | 45 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 46 | These may be fractional. 47 | :param dim: the dimension of the output. 48 | :param max_period: controls the minimum frequency of the embeddings. 49 | :return: an [N x dim] Tensor of positional embeddings. 
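:note: cosine components occupy the first half of the output channels and sine components the second half (concatenated, not interleaved); if dim is odd, a zero column is appended.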
50 | """ 51 | half = dim // 2 52 | freqs = torch.exp(-math.log(max_period) * 53 | torch.arange(start=0, end=half, dtype=torch.float32) / 54 | half).to(device=timesteps.device) 55 | args = timesteps[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat( 59 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 60 | return embedding 61 | 62 | 63 | class DownSample(nn.Module): 64 | def __init__(self, in_ch): 65 | super().__init__() 66 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) 67 | self.initialize() 68 | 69 | def initialize(self): 70 | init.xavier_uniform_(self.main.weight) 71 | init.zeros_(self.main.bias) 72 | 73 | def forward(self, x, temb=None, aemb=None): 74 | x = self.main(x) 75 | return x 76 | 77 | 78 | class UpSample(nn.Module): 79 | def __init__(self, in_ch): 80 | super().__init__() 81 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) 82 | self.initialize() 83 | 84 | def initialize(self): 85 | init.xavier_uniform_(self.main.weight) 86 | init.zeros_(self.main.bias) 87 | 88 | def forward(self, x, temb=None, aemb=None): 89 | _, _, H, W = x.shape 90 | x = F.interpolate( 91 | x, scale_factor=float(2.0), mode='nearest') 92 | x = self.main(x) 93 | return x 94 | 95 | 96 | class LatentDownSample(nn.Module): 97 | def __init__(self, in_ch): 98 | super().__init__() 99 | self.main = nn.Conv1d(in_ch, in_ch, 3, stride=2, padding=1) 100 | self.initialize() 101 | 102 | def initialize(self): 103 | init.xavier_uniform_(self.main.weight) 104 | init.zeros_(self.main.bias) 105 | 106 | def forward(self, x, temb=None, aemb=None): 107 | x = self.main(x) 108 | return x 109 | 110 | 111 | class LatentUpSample(nn.Module): 112 | def __init__(self, in_ch): 113 | super().__init__() 114 | self.main = nn.Conv1d(in_ch, in_ch, 3, stride=1, padding=1) 115 | self.initialize() 116 | 117 | def initialize(self): 118 | init.xavier_uniform_(self.main.weight) 119 | init.zeros_(self.main.bias) 120 | 121 | def forward(self, x, temb=None, aemb=None): 122 | _, _, L = x.shape 123 | x = F.interpolate( 124 | x, scale_factor=float(2.0), mode='nearest') 125 | x = self.main(x) 126 | return x 127 | 128 | 129 | class AttnBlock(nn.Module): 130 | def __init__(self, in_ch): 131 | super().__init__() 132 | self.group_norm = nn.GroupNorm(32, in_ch) 133 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 134 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 135 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 136 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 137 | self.initialize() 138 | 139 | def initialize(self): 140 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: 141 | init.xavier_uniform_(module.weight) 142 | init.zeros_(module.bias) 143 | init.xavier_uniform_(self.proj.weight, gain=1e-5) 144 | 145 | def forward(self, x): 146 | B, C, H, W = x.shape 147 | h = self.group_norm(x) 148 | q = self.proj_q(h) 149 | k = self.proj_k(h) 150 | v = self.proj_v(h) 151 | 152 | q = q.permute(0, 2, 3, 1).view(B, H * W, C) 153 | k = k.view(B, C, H * W) 154 | w = torch.bmm(q, k) * (int(C) ** (-0.5)) 155 | assert list(w.shape) == [B, H * W, H * W] 156 | w = F.softmax(w, dim=-1) 157 | 158 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 159 | h = torch.bmm(w, v) 160 | assert list(h.shape) == [B, H * W, C] 161 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 162 | h = self.proj(h) 163 | 164 | return x + h 165 | 166 | 167 | class CrossAttnBlock(nn.Module): 168 | def 
__init__(self, in_ch): 169 | super().__init__() 170 | self.group_norm = nn.GroupNorm(32, in_ch) 171 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 172 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 173 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 174 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) 175 | self.initialize() 176 | 177 | def initialize(self): 178 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: 179 | init.xavier_uniform_(module.weight) 180 | init.zeros_(module.bias) 181 | init.xavier_uniform_(self.proj.weight, gain=1e-5) 182 | 183 | def forward(self, x, a): 184 | B, C, H, W = x.shape 185 | h = self.group_norm(x) 186 | h_a = self.group_norm(a) 187 | q = self.proj_q(h_a) 188 | k = self.proj_k(h) 189 | v = self.proj_v(h) 190 | 191 | q = q.permute(0, 2, 3, 1).view(B, H * W, C) 192 | k = k.view(B, C, H * W) 193 | w = torch.bmm(q, k) * (int(C) ** (-0.5)) 194 | assert list(w.shape) == [B, H * W, H * W] 195 | w = F.softmax(w, dim=-1) 196 | 197 | v = v.permute(0, 2, 3, 1).view(B, H * W, C) 198 | h = torch.bmm(w, v) 199 | assert list(h.shape) == [B, H * W, C] 200 | h = h.view(B, H, W, C).permute(0, 3, 1, 2) 201 | h = self.proj(h) 202 | 203 | return x + h 204 | 205 | 206 | class ResBlock(nn.Module): 207 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): 208 | super().__init__() 209 | self.temb_proj = nn.Sequential( 210 | nn.SiLU(), 211 | nn.Linear(tdim, 2*out_ch), 212 | ) 213 | self.block1 = nn.Sequential( 214 | nn.GroupNorm(32, in_ch), 215 | nn.SiLU(), 216 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 217 | ) 218 | self.block2 = nn.Sequential( 219 | nn.GroupNorm(32, out_ch), 220 | nn.SiLU(), 221 | nn.Dropout(dropout), 222 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 223 | ) 224 | self.block3 = nn.Sequential( 225 | nn.GroupNorm(32, out_ch), 226 | nn.SiLU(), 227 | nn.Dropout(dropout), 228 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 229 | ) 230 | if in_ch != out_ch: 231 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 232 | else: 233 | self.shortcut = nn.Identity() 234 | self.use_attn = attn 235 | if self.use_attn: 236 | self.attn = AttnBlock(out_ch) 237 | else: 238 | self.attn = nn.Identity() 239 | self.initialize() 240 | 241 | def initialize(self): 242 | for module in self.modules(): 243 | if isinstance(module, (nn.Conv2d, nn.Linear)): 244 | init.xavier_uniform_(module.weight) 245 | init.zeros_(module.bias) 246 | 247 | def forward(self, x, temb): 248 | h = self.block1(x) 249 | temb_out = self.temb_proj(temb)[:, :, None, None] 250 | out_norm, out_rest = self.block2[0], self.block2[1:] 251 | scale, shift = torch.chunk(temb_out, 2, dim=1) 252 | h = out_norm(h) * (1 + scale) + shift 253 | h = out_rest(h) 254 | h = self.block3(h) 255 | h = h + self.shortcut(x) 256 | h = self.attn(h) 257 | 258 | return h 259 | 260 | 261 | class AuxResBlock(nn.Module): 262 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False, crossattn : Union[bool, nn.Module] =False): 263 | super().__init__() 264 | self.block1 = nn.Sequential( 265 | nn.GroupNorm(32, in_ch), 266 | nn.SiLU(), 267 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 268 | ) 269 | self.temb_proj = nn.Sequential( 270 | nn.SiLU(), 271 | nn.Linear(tdim, 2*out_ch), 272 | ) 273 | self.aemb_proj = nn.Sequential( 274 | nn.SiLU(), 275 | nn.Linear(tdim, 2*out_ch), 276 | ) 277 | self.block2 = nn.Sequential( 278 | nn.GroupNorm(32, out_ch), 279 | nn.SiLU(), 280 | nn.Dropout(dropout), 281 | nn.Conv2d(out_ch, out_ch, 
3, stride=1, padding=1), 282 | ) 283 | self.block3 = nn.Sequential( 284 | nn.GroupNorm(32, out_ch), 285 | nn.SiLU(), 286 | nn.Dropout(dropout), 287 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 288 | ) 289 | 290 | if in_ch != out_ch: 291 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 292 | else: 293 | self.shortcut = nn.Identity() 294 | self.use_attn = attn 295 | if self.use_attn: 296 | self.attn = AttnBlock(out_ch) 297 | else: 298 | self.attn = nn.Identity() 299 | self.use_crossattn = bool(crossattn) 300 | self.crossattn = CrossAttnBlock(out_ch) 301 | self.initialize() 302 | 303 | def initialize(self): 304 | for module in self.modules(): 305 | if isinstance(module, (nn.Conv2d, nn.Linear)): 306 | init.xavier_uniform_(module.weight) 307 | init.zeros_(module.bias) 308 | 309 | def forward(self, x, temb, aemb=None): 310 | h = self.block1(x) 311 | 312 | temb_out = self.temb_proj(temb)[:, :, None, None] 313 | out_norm, out_rest = self.block2[0], self.block2[1:] 314 | scale, shift = torch.chunk(temb_out, 2, dim=1) 315 | h = out_norm(h) * (1 + scale) + shift 316 | aemb_out = self.aemb_proj(aemb)[:, :, None, None] 317 | scale, shift = torch.chunk(aemb_out, 2, dim=1) 318 | h = h * (1 + scale) + shift 319 | h = out_rest(h) 320 | h = self.block3(h) 321 | 322 | h += self.shortcut(x) 323 | h = self.attn(h) 324 | 325 | if self.use_crossattn: 326 | h = self.crossattn(h, aemb) 327 | 328 | return h 329 | 330 | 331 | class ResBlock_encoder(nn.Module): 332 | def __init__(self, in_ch, out_ch, dropout, attn=False): 333 | super().__init__() 334 | self.block1 = nn.Sequential( 335 | nn.GroupNorm(32, in_ch), 336 | nn.SiLU(), 337 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), 338 | ) 339 | self.block2 = nn.Sequential( 340 | nn.GroupNorm(32, out_ch), 341 | nn.SiLU(), 342 | nn.Dropout(dropout), 343 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), 344 | ) 345 | if in_ch != out_ch: 346 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 347 | else: 348 | self.shortcut = nn.Identity() 349 | if attn: 350 | self.attn = AttnBlock(out_ch) 351 | else: 352 | self.attn = nn.Identity() 353 | self.initialize() 354 | 355 | def initialize(self): 356 | for module in self.modules(): 357 | if isinstance(module, (nn.Conv2d, nn.Linear)): 358 | init.xavier_uniform_(module.weight) 359 | init.zeros_(module.bias) 360 | 361 | def forward(self, x): 362 | h = self.block1(x) 363 | h = self.block2(h) 364 | h += self.shortcut(x) 365 | h = self.attn(h) 366 | return h 367 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from torchvision.utils import save_image 6 | from tqdm.auto import tqdm, trange 7 | from torch.utils.tensorboard import SummaryWriter 8 | from torch.utils.data import DataLoader 9 | from data import get_dataset, get_dataset_config 10 | from models import InfoDiff, Diff, VAE 11 | from sampling import DiffusionProcess, TwoPhaseDiffusionProcess, LatentDiffusionProcess 12 | from utils import ( 13 | AverageMeter, ProgressMeter, GradualWarmupScheduler, \ 14 | generate_exp_string, seed_everything, cos, LatentDataset 15 | ) 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | torch.backends.cudnn.benchmark = True 20 | torch.backends.cudnn.allow_tf32 = True 21 | 22 | 23 | # ---------------------------------------------------------------------------- 24 | 25 | def parse_args(): 26 | parser = 
argparse.ArgumentParser() 27 | 28 | parser.add_argument('--r_seed', type=int, default=0, 29 | help='random seed') 30 | parser.add_argument('--img_id', type=int, default=0, 31 | help='index of the image used in disentangle/interpolate modes') 32 | parser.add_argument('--model', required=True, 33 | choices=['diff', 'vae', 'vanilla'], help='which type of model to run') 34 | parser.add_argument('--mode', required=True, 35 | choices=['train', 'eval', 'eval_fid', 'save_latent', 'disentangle', 36 | 'interpolate', 'save_original_img', 'latent_quality', 37 | 'train_latent_ddim', 'plot_latent'], help='which mode to run') 38 | parser.add_argument('--prior', required=True, 39 | choices=['regular', '10mix', 'roll'], help='which type of prior to run') 40 | parser.add_argument('--kld_weight', type=float, default=0, 41 | help='weight of kld loss') 42 | parser.add_argument('--mmd_weight', type=float, default=0.1, 43 | help='weight of mmd loss') 44 | parser.add_argument('--use_C', action='store_true', 45 | default=False, help='use control constant or not') 46 | parser.add_argument('--C_max', type=float, default=25, 47 | help='control constant of kld loss (orig default: 25 for simple, 50 for complex)') 48 | parser.add_argument('--dataset', required=True, 49 | choices=['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites', 'chairs', 'ffhq'], help='training dataset') 50 | parser.add_argument('--img_folder', default='./imgs', 51 | help='path to save sampled images') 52 | parser.add_argument('--log_folder', default='./logs', 53 | help='path to save logs') 54 | parser.add_argument('-e', '--epochs', type=int, default=20, 55 | help='number of epochs to train') 56 | parser.add_argument('--save_epochs', type=int, default=5, 57 | help='number of epochs to save model') 58 | parser.add_argument('--batch_size', type=int, default=64, 59 | help='training batch size') 60 | parser.add_argument('--learning_rate', type=float, default=0.0001, 61 | help='learning rate') 62 | parser.add_argument('--optimizer', default='adam', choices=['adam'], 63 | help='optimization algorithm') 64 | parser.add_argument('--model_folder', default='./models', 65 | help='folder where model checkpoints will be stored') 66 | parser.add_argument('--deterministic', action='store_true', 67 | default=False, help='use deterministic (DDIM) sampling') 68 | parser.add_argument('--input_channels', type=int, default=1, 69 | help='number of input channels') 70 | parser.add_argument('--unets_channels', type=int, default=64, 71 | help='number of base channels in the U-Net backbone') 72 | parser.add_argument('--encoder_channels', type=int, default=64, 73 | help='number of base channels in the encoder') 74 | parser.add_argument('--input_size', type=int, default=32, 75 | help='expected size of input') 76 | parser.add_argument('--a_dim', type=int, default=32, required=True, 77 | help='dimensionality of auxiliary variable') 78 | parser.add_argument('--beta1', type=float, default=1e-5, 79 | help='value of beta 1') 80 | parser.add_argument('--betaT', type=float, default=1e-2, 81 | help='value of beta T') 82 | parser.add_argument('--diffusion_steps', type=int, default=1000, 83 | help='number of diffusion steps') 84 | parser.add_argument('--split_step', type=int, default=500, 85 | help='the step for splitting two phases') 86 | parser.add_argument('--sampling_number', type=int, default=16, 87 | help='number of sampled images') 88 | parser.add_argument('--data_dir', type=str, default='./data') 89 | parser.add_argument('--tb_logger', action='store_true', 90 | help='use tensorboard logger.') 91 | parser.add_argument('--is_latent', action='store_true', 92
| help='use latent diffusion for unconditional sampling.') 93 | parser.add_argument('--is_bottleneck', action='store_true', 94 | help='only fuse aux variable in bottleneck layers.') 95 | args = parser.parse_args() 96 | 97 | return args 98 | 99 | 100 | # ---------------------------------------------------------------------------- 101 | 102 | 103 | def save_images(args, sample=None, epoch=0, sample_num=0): 104 | root = f'{args.img_folder}' 105 | if args.model == 'vae': 106 | root = os.path.join(root, 'vae') 107 | else: 108 | if args.model == 'vanilla': 109 | root = os.path.join(root, 'diff') 110 | root = os.path.join(root, generate_exp_string(args)) 111 | if args.mode == 'eval': 112 | root = os.path.join(root, 'eval') 113 | elif args.mode == 'disentangle': 114 | root = os.path.join(root, f'disentangle-{args.img_id}') 115 | elif args.mode == 'interpolate': 116 | root = os.path.join(root, f'interpolate-{args.img_id}') 117 | elif args.mode == 'save_latent': 118 | root = os.path.join(root, 'save_latent') 119 | elif args.mode == 'attr_classification': 120 | root = os.path.join(root, 'attr_classification') 121 | elif args.mode == 'plot_latent': 122 | root = os.path.join(root, 'plot_latent') 123 | os.makedirs(root, exist_ok=True) 124 | path = os.path.join(root, f'sample-{epoch}.png') 125 | 126 | img_range = (-1, 1) 127 | if args.mode == 'train': 128 | save_image(sample, path, normalize=True, range=img_range, nrow=4) 129 | elif args.mode == 'eval': 130 | for _ in range(sample_num, sample_num + len(sample)): 131 | path = os.path.join(root, f"sample{sample_num:05d}.png") 132 | save_image(sample, path, normalize=True, range=img_range) 133 | elif args.mode == 'disentangle': 134 | path = os.path.join(root, f"sample{sample_num}.png") 135 | save_image(sample, path, normalize=True, range=img_range, nrow=sample.shape[0]) 136 | elif args.mode == 'interpolate': 137 | path = os.path.join(root, f"sample{sample_num}.png") 138 | save_image(sample, path, normalize=True, range=img_range, nrow=sample.shape[0]) 139 | elif args.mode == 'plot_latent': 140 | path = os.path.join(root, f"{args.mode}.png") 141 | return path 142 | elif args.mode == 'attr_classification': 143 | return root 144 | 145 | def save_model(args, epoch, model): 146 | root = f'{args.model_folder}' 147 | if args.model == 'vae': 148 | root = os.path.join(root, 'vae') 149 | else: 150 | if args.model == 'vanilla': 151 | root = os.path.join(root, 'diff') 152 | root = os.path.join(root, generate_exp_string(args)) 153 | if args.mode == "train_latent_ddim": 154 | root += '_latent' 155 | os.makedirs(root, exist_ok=True) 156 | path = os.path.join(root, f'model-{epoch}.pth') 157 | torch.save(model.state_dict(), path) 158 | print(f"Saved PyTorch model state to {path}") 159 | 160 | 161 | def train(args): 162 | seed_everything(args.r_seed) 163 | log_dir = f'{args.log_folder}' 164 | log_dir = os.path.join(log_dir, generate_exp_string(args)) 165 | tb_logger = SummaryWriter(log_dir=log_dir) if args.tb_logger else None 166 | device = "cuda" if torch.cuda.is_available() else "cpu" 167 | shape = get_dataset_config(args) 168 | print(dict(vars(args))) 169 | dataloader = get_dataset(args) 170 | 171 | if args.model == 'diff': 172 | model = InfoDiff(args, device, shape) 173 | elif args.model == 'vanilla': 174 | model = Diff(args, device, shape) 175 | elif args.model == 'vae': 176 | model = VAE(args, device, shape) 177 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5) 178 | 179 | losses = AverageMeter('Loss', ':.4f') 180 | progress = 
ProgressMeter(args.epochs, [losses], prefix='Epoch ') 181 | 182 | cosineScheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 183 | optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1) 184 | warmUpScheduler = GradualWarmupScheduler( 185 | optimizer=optimizer, multiplier=2., warm_epoch=1, after_scheduler=cosineScheduler) 186 | 187 | global_step = 0 188 | for curr_epoch in trange(0, args.epochs, desc="Epoch #"): 189 | total_loss = 0 190 | batch_bar = tqdm(dataloader, desc="Batch #") 191 | for idx, data in enumerate(batch_bar): 192 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10']: 193 | data = data[0] 194 | data = data.to(device=device) 195 | loss = model.loss_fn(args=args, x=data, curr_epoch=curr_epoch) 196 | batch_bar.set_postfix(loss=format(loss,'.4f')) 197 | optimizer.zero_grad() 198 | loss.backward() 199 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) 200 | optimizer.step() 201 | total_loss += loss.item() 202 | global_step += 1 203 | if tb_logger: 204 | tb_logger.add_scalar('train/loss', loss.item(), global_step) 205 | losses.update(total_loss / idx) 206 | current_epoch = curr_epoch 207 | progress.display(current_epoch) 208 | current_epoch += 1 209 | warmUpScheduler.step() 210 | losses.reset() 211 | if current_epoch % args.save_epochs == 0: 212 | save_model(args, current_epoch, model) 213 | 214 | 215 | def eval(args): 216 | if args.mode != 'train_latent_ddim': 217 | seed_everything(args.r_seed) 218 | device = "cuda" if torch.cuda.is_available() else "cpu" 219 | shape = get_dataset_config(args) 220 | print(dict(vars(args))) 221 | root = f'{args.model_folder}' 222 | if args.model == 'diff': 223 | model = InfoDiff(args, device, shape) 224 | elif args.model == 'vanilla': 225 | model = Diff(args, device, shape) 226 | root = os.path.join(root, 'diff') 227 | elif args.model == 'vae': 228 | model = VAE(args, device, shape) 229 | root = os.path.join(root, 'vae') 230 | root = os.path.join(root, generate_exp_string(args)) 231 | path = os.path.join(root, f'model-{args.epochs}.pth') 232 | print(f"Loading model from {path}") 233 | model.load_state_dict(torch.load(path, map_location=device), strict=False) 234 | if (args.dataset in ['celeba', 'cifar10', 'mnist', 'fmnist', 'ffhq'] and args.mode in ['eval_fid']): 235 | if args.is_latent: 236 | shape_latent = (1, args.a_dim, args.a_dim) 237 | model2 = Diff(args, device, shape_latent) 238 | path2 = f'./models/{generate_exp_string(args)}_latent/model-{args.epochs}.pth' 239 | if os.path.exists(path2): 240 | print(f"Loading model from {path2}") 241 | else: 242 | raise FileNotFoundError("The file path {} does not exist, please train the latent diffusion model first.".format(path2)) 243 | model2.load_state_dict(torch.load(path2, map_location=device), strict=True) 244 | else: 245 | model2 = Diff(args, device, shape) 246 | path2 = f'./models/diff/{args.dataset}_{args.a_dim}d/model-{args.epochs}.pth' 247 | if os.path.exists(path2): 248 | print(f"Loading model from {path2}") 249 | else: 250 | raise FileNotFoundError("The file path {} does not exist, please train the vanilla diffusion model first.".format(path2)) 251 | model2.load_state_dict(torch.load(path2, map_location=device), strict=True) 252 | model2.eval() 253 | model.eval() 254 | if args.model in ['diff', 'vanilla']: 255 | process = DiffusionProcess(args, model, device, shape) 256 | if args.mode == 'eval': 257 | if args.model in ['diff', 'vanilla']: 258 | for sample_num in trange(0, args.sampling_number, args.batch_size, desc="Generating eval images"): 259 | sample = 
process.sampling(sampling_number=16) 260 | save_images(args, sample, sample_num=sample_num) 261 | elif args.model == 'vae': 262 | a = torch.randn([args.sampling_number, args.a_dim]).to(device=device) 263 | sample = model.decoder(a) 264 | save_images(args, sample) 265 | elif args.mode == 'eval_fid': 266 | root = f'{args.img_folder}' 267 | if args.model == 'vae': 268 | root = os.path.join(root, 'vae') 269 | root = os.path.join(root, generate_exp_string(args)) 270 | if args.is_latent: 271 | root = os.path.join(root, 'eval-fid-latent') 272 | else: 273 | root = os.path.join(root, 'eval-fid-fast') 274 | os.makedirs(root, exist_ok=True) 275 | print(f"Saving images to {root}") 276 | if args.model == 'diff': 277 | if args.is_latent: 278 | process_latent = LatentDiffusionProcess(args, model2, device) 279 | else: 280 | process = TwoPhaseDiffusionProcess(args, model, model2, device, shape) 281 | 282 | for sample_num in trange(0, args.sampling_number, args.batch_size, desc="Generating eval images"): 283 | if args.is_latent: 284 | batch_a = process_latent.sampling(sampling_number=args.batch_size) 285 | batch = process.sampling(sampling_number=args.batch_size, a=batch_a) 286 | else: 287 | batch = process.sampling(sampling_number=args.batch_size) 288 | for batch_num, img in enumerate(batch): 289 | img = torch.clip(img, min=-1, max=1) 290 | img = ((img + 1)/2) # normalize to 0 - 1 291 | img_num = sample_num + batch_num 292 | if img_num >= args.sampling_number: 293 | return 294 | path = os.path.join(root, f'sample-{img_num:06d}.png') 295 | save_image(img, path) 296 | print("DONE") 297 | elif args.model == 'vae': 298 | for sample_num in trange(0, args.sampling_number, args.batch_size, desc="Generating eval images"): 299 | a = torch.randn([args.batch_size, args.a_dim]).to(device=device) 300 | batch = model.decoder(a) 301 | for batch_num, img in enumerate(batch): 302 | img = torch.clip(img, min=-1, max=1) 303 | img = ((img + 1)/2) # normalize to 0 - 1 304 | img_num = sample_num + batch_num 305 | if img_num >= args.sampling_number: 306 | return 307 | path = os.path.join(root, f'sample-{img_num:06d}.png') 308 | save_image(img, path) 309 | print("DONE") 310 | elif args.mode == 'latent_quality': 311 | process = DiffusionProcess(args, model, device, shape) 312 | dataloader = get_dataset(args) 313 | root = f'{args.img_folder}' 314 | root = os.path.join(root, generate_exp_string(args)) 315 | root = os.path.join(root, 'latent_quality') 316 | print(f"Saving images to {root}") 317 | for idx, data in enumerate(dataloader): 318 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites']: 319 | data_all = data 320 | data = data_all[0] 321 | if idx == 10: 322 | break 323 | data = data.to(device=device) 324 | if args.kld_weight != 0: 325 | with torch.no_grad(): 326 | _, _, mu, log_var = model.encoder(data) 327 | a = mu + torch.exp(0.5 * log_var) 328 | elif args.mmd_weight != 0: 329 | with torch.no_grad(): 330 | a, _, _, _ = model.encoder(data) 331 | xT = process.reverse_sampling(data, a) 332 | xT_original = xT.repeat(args.sampling_number, 1, 1, 1) 333 | a_original = a.repeat(args.sampling_number, 1) 334 | xT = torch.randn_like(xT_original) 335 | batch = process.sampling(xT=xT, a=a_original) 336 | os.makedirs(root, exist_ok=True) 337 | for batch_num, img in enumerate(batch): 338 | img = torch.clip(img, min=-1, max=1) 339 | img = ((img + 1)/2) # normalize to 0 - 1 340 | path = os.path.join(path, f'sample-{batch_num:06d}.png') 341 | save_image(img, path) 342 | elif args.mode == 'plot_latent': 343 | all_a, all_attr 
= [], [] 344 | dataloader = get_dataset(args) 345 | for idx, data in enumerate(dataloader): 346 | if args.dataset in ['fmnist', 'celeba', 'cifar10', 'dsprites', 'mnist']: 347 | data_all = data 348 | data = data_all[0] 349 | if args.dataset in ['celeba', 'fmnist', 'mnist']: 350 | latents_classes = data_all[1] 351 | elif args.dataset == 'dsprites': 352 | latents_classes = data_all[2] 353 | data = data.to(device=device) 354 | if (args.mmd_weight == 0 and args.kld_weight == 0): 355 | with torch.no_grad(): 356 | a, _, _, _ = model.encoder(data) 357 | elif (args.mmd_weight != 0): 358 | with torch.no_grad(): 359 | a, _, _, _ = model.encoder(data) 360 | else: 361 | with torch.no_grad(): 362 | _, _, mu, _ = model.encoder(data) 363 | a = mu 364 | all_a.append(a.cpu().numpy()) 365 | all_attr.append(latents_classes) 366 | all_a = np.concatenate(all_a) 367 | all_attr = np.concatenate(all_attr) 368 | plt.scatter(all_a[:, 0], all_a[:, 1], c = all_attr, cmap = 'tab10', s=5) 369 | path = save_images(args) 370 | plt.savefig(path) 371 | elif args.mode == 'disentangle': 372 | dataloader = get_dataset(args) 373 | for idx, data in enumerate(dataloader): 374 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites']: 375 | data_all = data 376 | data = data_all[0] 377 | if args.dataset == 'celeba': 378 | latents_classes = data_all[1] 379 | elif args.dataset == 'dsprites': 380 | latents_classes = data_all[2] 381 | if idx == args.img_id: 382 | break 383 | data = data.to(device=device) 384 | # eta = [-3, -2.4, -1.8, -1.2, -0.6, 0.0, 0.6, 1.2, 1.8, 2.4, 3.0] 385 | eta = [-1.5, -1.2, -0.9, -0.6, -0.3, 0.0, 0.3, 0.6, 0.9, 1.2, 1.5] 386 | if args.kld_weight != 0: 387 | with torch.no_grad(): 388 | _, _, mu, _ = model.encoder(data) 389 | a = mu 390 | elif args.mmd_weight != 0: 391 | with torch.no_grad(): 392 | a, _, _, _ = model.encoder(data) 393 | if args.model == 'diff': 394 | xT = process.reverse_sampling(data, a) 395 | xT = xT.repeat(len(eta), 1, 1, 1) 396 | for k in range(args.a_dim): 397 | a_list = [] 398 | for e in eta: 399 | if args.kld_weight != 0: 400 | with torch.no_grad(): 401 | _, _, mu, log_var = model.encoder(data) 402 | a = mu 403 | print(mu, log_var) 404 | elif args.mmd_weight != 0: 405 | with torch.no_grad(): 406 | a, _, _, _ = model.encoder(data) 407 | a[0][k] = e 408 | a_list.append(a) 409 | a = torch.stack(a_list).squeeze(dim=1) 410 | if args.model == 'diff': 411 | sample = process.sampling(xT=xT, a=a) 412 | elif args.model == 'vae': 413 | sample = model.decoder(a) 414 | save_images(args, sample, sample_num=k) 415 | elif args.mode == 'save_latent': 416 | all_a, all_attr = [], [] 417 | dataloader = get_dataset(args) 418 | for idx, data in enumerate(dataloader): 419 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites']: 420 | data_all = data 421 | data = data_all[0] 422 | if args.dataset in ['celeba', 'fmnist', 'mnist', 'cifar10']: 423 | latents_classes = data_all[1] 424 | elif args.dataset == 'dsprites': 425 | latents_classes = data_all[2] 426 | else: 427 | latents_classes = ['No Attributes'] 428 | data = data.to(device=device) 429 | if args.kld_weight != 0: 430 | with torch.no_grad(): 431 | _, _, mu, _ = model.encoder(data) 432 | a = mu 433 | elif args.mmd_weight != 0: 434 | with torch.no_grad(): 435 | a, _, _, _ = model.encoder(data) 436 | elif (args.mmd_weight == 0 and args.kld_weight == 0): 437 | with torch.no_grad(): 438 | a, _, _, _ = model.encoder(data) 439 | all_a.append(a.cpu().numpy()) 440 | all_attr.append(latents_classes) 441 | all_a = 
np.concatenate(all_a) 442 | all_attr = np.concatenate(all_attr) 443 | np.savez("{}_{}_latent".format(args.model, generate_exp_string(args).replace(".", "_")), all_a = all_a, all_attr = all_attr) 444 | elif args.mode == 'interpolate': 445 | dataloader = get_dataset(args) 446 | for idx, data in enumerate(dataloader): 447 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10']: 448 | data = data[0] 449 | if idx == args.img_id: 450 | break 451 | data = data.to(device=device) 452 | if args.kld_weight != 0: 453 | with torch.no_grad(): 454 | _, _, mu, _ = model.encoder(data) 455 | a = mu 456 | elif args.mmd_weight != 0: 457 | with torch.no_grad(): 458 | a, _, _, _ = model.encoder(data) 459 | elif (args.mmd_weight == 0 and args.kld_weight == 0): 460 | with torch.no_grad(): 461 | a, _, _, _ = model.encoder(data) 462 | if args.model in ['diff', 'vanilla']: 463 | xT = process.reverse_sampling(data, a) 464 | theta = torch.arccos(cos(xT[0], xT[1])) 465 | a1 = a[0] 466 | a2 = a[1] 467 | eta = [0.0, 0.11, 0.22, 0.33, 0.44, 0.55, 0.66, 0.77, 0.88, 1.0] 468 | intp_a_list = [] 469 | intp_x_list = [] 470 | for e in eta: 471 | intp_a_list.append(np.cos(e * np.pi / 2) * a1 + np.sin(e * np.pi / 2) * a2) 472 | if args.model in ['diff', 'vanilla']: 473 | intp_x = (torch.sin((1 - e) * theta) * xT[0] + torch.sin(e * theta) * xT[1]) / torch.sin(theta) 474 | intp_x_list.append(intp_x) 475 | intp_a = torch.stack(intp_a_list) 476 | if args.model in ['diff', 'vanilla']: 477 | intp_x = torch.stack(intp_x_list).squeeze(dim=1) 478 | sample = process.sampling(xT=intp_x, a=intp_a) 479 | elif args.model == 'vae': 480 | sample = model.decoder(intp_a) 481 | save_images(args, sample) 482 | elif args.mode == 'train_latent_ddim': 483 | dataset = LatentDataset("{}_{}_latent.npz".format(args.model, generate_exp_string(args).replace(".", "_"))) 484 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 485 | seed_everything(args.r_seed) 486 | log_dir = f'{args.log_folder}' 487 | log_dir = os.path.join(log_dir, generate_exp_string(args)) 488 | log_dir += '_latent' 489 | tb_logger = SummaryWriter(log_dir=log_dir) if args.tb_logger else None 490 | device = "cuda" if torch.cuda.is_available() else "cpu" 491 | shape = (1, args.a_dim, args.a_dim) 492 | model = Diff(args, device, shape) 493 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5) 494 | 495 | losses = AverageMeter('Loss', ':.4f') 496 | progress = ProgressMeter(args.epochs, [losses], prefix='Epoch ') 497 | 498 | cosineScheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 499 | optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1) 500 | warmUpScheduler = GradualWarmupScheduler( 501 | optimizer=optimizer, multiplier=2., warm_epoch=1, after_scheduler=cosineScheduler) 502 | 503 | global_step = 0 504 | for curr_epoch in trange(0, args.epochs, desc="Epoch #"): 505 | total_loss = 0 506 | batch_bar = tqdm(dataloader, desc="Batch #") 507 | for idx, data in enumerate(batch_bar): 508 | data = data.to(device=device) 509 | loss = model.loss_fn(args=args, x=data, curr_epoch=curr_epoch) 510 | batch_bar.set_postfix(loss=format(loss,'.4f')) 511 | optimizer.zero_grad() 512 | loss.backward() 513 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) 
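# gradient norms are clipped to 1 before each optimizer step, matching the image-space training loop in train()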
514 | optimizer.step() 515 | total_loss += loss.item() 516 | global_step += 1 517 | if tb_logger: 518 | tb_logger.add_scalar('train/loss', loss.item(), global_step) 519 | losses.update(total_loss / idx) 520 | current_epoch = curr_epoch 521 | progress.display(current_epoch) 522 | current_epoch += 1 523 | warmUpScheduler.step() 524 | losses.reset() 525 | if current_epoch % args.save_epochs == 0: 526 | save_model(args, current_epoch, model) 527 | 528 | 529 | if __name__ == '__main__': 530 | args = parse_args() 531 | if args.mode in ['train']: 532 | train(args) 533 | elif args.mode in ['eval', 'eval_fid', 'latent_quality', 'disentangle', 'interpolate', 534 | 'save_latent', 'train_latent_ddim', 'plot_latent']: 535 | if args.mode in ['disentangle', 'latent_quality']: 536 | args.batch_size = 1 537 | elif args.mode == 'interpolate': 538 | args.batch_size = 2 539 | eval(args) 540 | elif args.mode in ['save_original_img']: 541 | from tqdm import tqdm 542 | import os 543 | from torchvision.utils import save_image 544 | output_folder = f'./{args.dataset}_imgs/' 545 | os.makedirs(output_folder, exist_ok=True) 546 | dataloader = get_dataset(args) 547 | for i, img in enumerate(tqdm(dataloader)): 548 | img = ((img[0] + 1)/2) # normalize to 0 - 1 549 | save_image(img, f'{output_folder}/{i:06d}.png') -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64 4 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DiffusionProcess(): 4 | def __init__(self, args, diffusion_fn, device, shape): 5 | ''' 6 | beta_1 : beta_1 of diffusion process 7 | beta_T : beta_T of diffusion process 8 | T : step of diffusion process 9 | diffusion_fn : trained diffusion network 10 | shape : data shape 11 | ''' 12 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps) 13 | self.alphas = 1 - self.betas 14 | self.alpha_bars = torch.cumprod(1 - torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps), dim=0).to(device=device) 15 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]]) 16 | self.shape = shape 17 | self.deterministic = args.deterministic 18 | self.a_dim = args.a_dim 19 | self.model = args.model 20 | self.diffusion_fn = diffusion_fn.to(device=device) 21 | self.device = device 22 | 23 | def _ddpm_one_diffusion_step(self, x, a=None): 24 | ''' 25 | x : perturbated data 26 | ''' 27 | for idx in reversed(range(len(self.alpha_bars))): 28 | 29 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x) 30 | sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx]) 31 | if self.model == 'vanilla': 32 | predict_epsilon = self.diffusion_fn(x, idx) 33 | else: 34 | predict_epsilon = self.diffusion_fn(x, idx, a) 35 | mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * (x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon) 36 | 37 | x = mu_theta_xt + sqrt_tilde_beta * noise 38 | 39 | yield x 40 | 41 | def _ddim_one_diffusion_step(self, x, a=None): 42 | ''' 43 | x : perturbated 
data 44 | ''' 45 | eta = 0.01 46 | for idx in reversed(range(len(self.alpha_bars))): 47 | 48 | if self.model == 'vanilla': 49 | predict_epsilon = self.diffusion_fn(x, idx) 50 | else: 51 | predict_epsilon = self.diffusion_fn(x, idx, a) 52 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx]) 53 | if idx == 0: 54 | x = x_0 55 | else: 56 | noise = torch.randn_like(x) 57 | sigma = eta * torch.sqrt((1 - self.alpha_prev_bars[idx-1]) / (1 - self.alpha_bars[idx-1])) * torch.sqrt(self.betas[idx-1]) 58 | x = torch.sqrt(self.alpha_prev_bars[idx-1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx-1] - sigma**2) * predict_epsilon 59 | x += sigma * noise 60 | yield x 61 | 62 | def _ddim_one_reverse_diffusion_step(self, x, a=None): 63 | for idx in range(len(self.alpha_bars)-1): 64 | if idx == 0: 65 | yield x 66 | else: 67 | if self.model == 'vanilla': 68 | predict_epsilon = self.diffusion_fn(x, idx) 69 | else: 70 | predict_epsilon = self.diffusion_fn(x, idx, a) 71 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx]) 72 | x = torch.sqrt(self.alpha_prev_bars[idx+1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx+1]) * predict_epsilon 73 | yield x 74 | 75 | def _one_diffusion_step(self, sample, a=None, deterministic=False): 76 | if deterministic: 77 | return self._ddim_one_diffusion_step(sample, a) 78 | else: 79 | return self._ddpm_one_diffusion_step(sample, a) 80 | 81 | @torch.no_grad() 82 | def reverse_sampling(self, x0, a=None): 83 | sample = x0 84 | for sample in self._ddim_one_reverse_diffusion_step(sample): 85 | final = sample 86 | 87 | return final 88 | 89 | @torch.no_grad() 90 | def sampling(self, sampling_number=16, xT=None, a=None): 91 | if xT is None: 92 | xT = torch.randn([sampling_number, *self.shape]).to(device=self.device) 93 | if self.model != 'vanilla': 94 | if a is None: 95 | a = torch.randn([sampling_number, self.a_dim]).to(device=self.device) 96 | 97 | sample = xT 98 | for sample in self._one_diffusion_step(sample=sample, a=a, deterministic=self.deterministic): 99 | final = sample 100 | 101 | return final 102 | 103 | 104 | class TwoPhaseDiffusionProcess(): 105 | def __init__(self, args, diffusion_fn_1, diffusion_fn_2, device, shape): 106 | ''' 107 | beta_1 : beta_1 of diffusion process 108 | beta_T : beta_T of diffusion process 109 | T : step of diffusion process 110 | diffusion_fn : trained diffusion network 111 | shape : data shape 112 | ''' 113 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps) 114 | self.alphas = 1 - self.betas 115 | self.alpha_bars = torch.cumprod(1 - torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps), dim=0).to(device=device) 116 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]]) 117 | self.shape = shape 118 | self.deterministic = args.deterministic 119 | self.a_dim = args.a_dim 120 | self.model = args.model 121 | self.split_step = args.split_step 122 | self.mode = args.mode 123 | 124 | self.diffusion_fn_1 = diffusion_fn_1.to(device=device) 125 | self.diffusion_fn_2 = diffusion_fn_2.to(device=device) 126 | self.device = device 127 | 128 | def _ddpm_one_diffusion_step(self, x, a=None, t=None): 129 | ''' 130 | x : perturbated data 131 | ''' 132 | for idx in reversed(range(len(self.alpha_bars))): 133 | 134 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x) 135 | sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) 
class TwoPhaseDiffusionProcess():
    def __init__(self, args, diffusion_fn_1, diffusion_fn_2, device, shape):
        '''
        beta_1 : beta_1 of diffusion process
        beta_T : beta_T of diffusion process
        T : number of diffusion steps
        diffusion_fn_1, diffusion_fn_2 : trained diffusion networks for the two sampling phases
        shape : data shape
        '''
        self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(1 - torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps), dim=0).to(device=device)
        self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
        self.shape = shape
        self.deterministic = args.deterministic
        self.a_dim = args.a_dim
        self.model = args.model
        self.split_step = args.split_step
        self.mode = args.mode

        self.diffusion_fn_1 = diffusion_fn_1.to(device=device)
        self.diffusion_fn_2 = diffusion_fn_2.to(device=device)
        self.device = device

    def _ddpm_one_diffusion_step(self, x, a=None, t=None):
        '''
        x : perturbed data
        '''
        for idx in reversed(range(len(self.alpha_bars))):

            noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x)
            sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx])
            # Number of reverse steps taken so far; the first `split_step` steps use
            # the unconditional network, the remaining steps the auxiliary-conditioned one.
            t = len(self.alpha_bars) - 1 - idx
            if t <= self.split_step:
                predict_epsilon = self.diffusion_fn_2(x, idx)
            else:
                predict_epsilon = self.diffusion_fn_1(x, idx, a)
            mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * (x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon)

            x = mu_theta_xt + sqrt_tilde_beta * noise

            yield x

    def _ddim_one_diffusion_step(self, x, a=None, t=None):
        '''
        x : perturbed data
        '''
        eta = 0.01
        for idx in reversed(range(len(self.alpha_bars))):

            # Same two-phase switch as in the DDPM step above
            t = len(self.alpha_bars) - 1 - idx
            if t <= self.split_step:
                predict_epsilon = self.diffusion_fn_2(x, idx)
            else:
                predict_epsilon = self.diffusion_fn_1(x, idx, a)
            x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
            if idx == 0:
                x = x_0
            else:
                noise = torch.randn_like(x)
                sigma = eta * torch.sqrt((1 - self.alpha_prev_bars[idx-1]) / (1 - self.alpha_bars[idx-1])) * torch.sqrt(self.betas[idx-1])
                x = torch.sqrt(self.alpha_prev_bars[idx-1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx-1] - sigma**2) * predict_epsilon
                x += sigma * noise
            yield x

    def _ddim_one_reverse_diffusion_step(self, x, a=None):
        for idx in range(len(self.alpha_bars)-1):
            if idx == 0:
                yield x
            else:
                predict_epsilon = self.diffusion_fn_1(x, idx, a)
                x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
                x = torch.sqrt(self.alpha_prev_bars[idx+1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx+1]) * predict_epsilon
                yield x

    def _one_diffusion_step(self, sample, a=None, deterministic=False, t=None):
        if deterministic:
            return self._ddim_one_diffusion_step(sample, a, t)
        else:
            return self._ddpm_one_diffusion_step(sample, a, t)

    @torch.no_grad()
    def reverse_sampling(self, x0, a=None):
        sample = x0
        for sample in self._ddim_one_reverse_diffusion_step(sample, a):
            final = sample

        return final

    @torch.no_grad()
    def sampling(self, sampling_number=16, xT=None, a=None):
        if xT is None:
            xT = torch.randn([sampling_number, *self.shape]).to(device=self.device)
        if a is None:
            a = torch.randn([sampling_number, self.a_dim]).to(device=self.device)

        sample = xT
        for sample in self._one_diffusion_step(sample=sample, a=a, deterministic=self.deterministic):
            final = sample

        return final
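# Usage sketch (illustrative, not part of the original file). Two trained
# epsilon-prediction networks are assumed: `avdm_net(x, idx, a)` conditioned on
# the auxiliary latent and `plain_net(x, idx)` unconditional; args.split_step
# controls how many of the early reverse steps use the unconditional network
# before the sampler switches to the conditioned one.
def _example_two_phase_sampling(args, avdm_net, plain_net, device, shape, a):
    process = TwoPhaseDiffusionProcess(args, avdm_net, plain_net, device, shape)
    return process.sampling(sampling_number=a.shape[0], a=a)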
class LatentDiffusionProcess():
    def __init__(self, args, diffusion_fn, device):
        '''
        beta_1 : beta_1 of diffusion process
        beta_T : beta_T of diffusion process
        T : number of diffusion steps
        diffusion_fn : trained latent diffusion network (operates on the auxiliary variable)
        '''
        self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(1 - torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps), dim=0).to(device=device)
        self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
        self.deterministic = args.deterministic
        self.a_dim = args.a_dim
        self.model = args.model
        self.split_step = args.split_step
        self.mode = args.mode
        self.diffusion_fn = diffusion_fn.to(device=device)
        self.device = device

    def _ddpm_one_diffusion_step(self, x):
        '''
        x : perturbed data
        '''
        for idx in reversed(range(len(self.alpha_bars))):

            noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x)
            sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx])
            predict_epsilon = self.diffusion_fn(x, idx)
            mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * (x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon)

            x = mu_theta_xt + sqrt_tilde_beta * noise

            yield x

    def _ddim_one_diffusion_step(self, x):
        '''
        x : perturbed data
        '''
        eta = 0.01
        for idx in reversed(range(len(self.alpha_bars))):
            predict_epsilon = self.diffusion_fn(x, idx)
            x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
            if idx == 0:
                x = x_0
            else:
                noise = torch.randn_like(x)
                sigma = eta * torch.sqrt((1 - self.alpha_prev_bars[idx-1]) / (1 - self.alpha_bars[idx-1])) * torch.sqrt(self.betas[idx-1])
                x = torch.sqrt(self.alpha_prev_bars[idx-1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx-1] - sigma**2) * predict_epsilon
                x += sigma * noise
            yield x

    def _ddim_one_reverse_diffusion_step(self, x):
        for idx in range(len(self.alpha_bars)-1):
            if idx == 0:
                yield x
            else:
                predict_epsilon = self.diffusion_fn(x, idx)
                x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
                x = torch.sqrt(self.alpha_prev_bars[idx+1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx+1]) * predict_epsilon
                yield x

    def _one_diffusion_step(self, sample, deterministic=False):
        if deterministic:
            return self._ddim_one_diffusion_step(sample)
        else:
            return self._ddpm_one_diffusion_step(sample)

    @torch.no_grad()
    def reverse_sampling(self, x0):
        sample = x0
        for sample in self._ddim_one_reverse_diffusion_step(sample):
            final = sample

        return final

    @torch.no_grad()
    def sampling(self, sampling_number=16, xT=None):
        if xT is None:
            xT = torch.randn([sampling_number, self.a_dim]).to(device=self.device)

        sample = xT
        for sample in self._one_diffusion_step(sample=sample, deterministic=self.deterministic):
            final = sample

        return final
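# Usage sketch (illustrative, not part of the original file). Latent sampling
# first draws the auxiliary variable a with the small latent diffusion model,
# then uses that a to condition the image-space sampler; `latent_net` and
# `image_net` stand in for trained networks, and DiffusionProcess is assumed to
# take (args, diffusion_fn, device, shape) as the classes above do.
def _example_latent_sampling(args, latent_net, image_net, device, shape, n=16):
    latent_process = LatentDiffusionProcess(args, latent_net, device)
    a = latent_process.sampling(sampling_number=n)            # sample a ~ p(a)
    image_process = DiffusionProcess(args, image_net, device, shape)
    return image_process.sampling(sampling_number=n, a=a)     # sample x_0 given a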
Exception("n_dim must be a multiple of 2.") 15 | 16 | def sample(x, y, label, n_labels): 17 | shift = 1.4 18 | if label >= n_labels: 19 | label = np.random.randint(0, n_labels) 20 | r = 2.0 * np.pi / float(n_labels) * float(label) 21 | new_x = x * math.cos(r) - y * math.sin(r) 22 | new_y = x * math.sin(r) + y * math.cos(r) 23 | new_x += shift * math.cos(r) 24 | new_y += shift * math.sin(r) 25 | return np.array([new_x, new_y]).reshape((2,)) 26 | 27 | x = np.random.normal(0, x_var, (batch_size, n_dim // 2)) 28 | y = np.random.normal(0, y_var, (batch_size, n_dim // 2)) 29 | z = np.empty((batch_size, n_dim), dtype=np.float32) 30 | for batch in range(batch_size): 31 | for zi in range(n_dim // 2): 32 | if label_indices is not None: 33 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], label_indices[batch], n_labels) 34 | else: 35 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], np.random.randint(0, n_labels), n_labels) 36 | 37 | return z 38 | 39 | def swiss_roll(batch_size, noise=0.5): 40 | return make_swiss_roll(n_samples=batch_size, noise=noise)[0][:, [0, 2]] / 5. 41 | 42 | def cos(a, b): 43 | a = a.view(-1) 44 | b = b.view(-1) 45 | a = F.normalize(a, dim=0) 46 | b = F.normalize(b, dim=0) 47 | return (a * b).sum() 48 | 49 | def generate_exp_string(args) -> str: 50 | root = f'{args.dataset}_{args.a_dim}d' 51 | if args.kld_weight != 0: 52 | root += f'_{args.kld_weight}kld' 53 | if args.use_C: 54 | root += f'_{args.C_max}C' 55 | if args.mmd_weight != 0: 56 | root += f'_{args.mmd_weight}mmd' 57 | if args.prior != 'regular': 58 | root += f'_{args.prior}' 59 | if args.is_bottleneck: 60 | root += '_bottleneck' 61 | return root 62 | 63 | 64 | def seed_everything(r_seed): 65 | print("Set seed: ", r_seed) 66 | random.seed(r_seed) 67 | np.random.seed(r_seed) 68 | torch.manual_seed(r_seed) 69 | torch.cuda.manual_seed(r_seed) 70 | torch.cuda.manual_seed_all(r_seed) 71 | torch.backends.cudnn.deterministic = True 72 | 73 | 74 | @torch.jit.script 75 | def compute_kernel(x, y): 76 | x_size = x.shape[0] 77 | y_size = y.shape[0] 78 | dim = x.shape[1] 79 | 80 | tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1) 81 | tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1) 82 | 83 | return torch.exp(-torch.mean((tiled_x - tiled_y)**2, dim=2)/dim*1.0) 84 | 85 | @torch.jit.script 86 | def compute_mmd(x, y): 87 | x_kernel = compute_kernel(x, x) 88 | y_kernel = compute_kernel(y, y) 89 | xy_kernel = compute_kernel(x, y) 90 | return torch.mean(x_kernel) + torch.mean(y_kernel) - 2*torch.mean(xy_kernel) 91 | 92 | 93 | class AverageMeter(object): 94 | def __init__(self, name, fmt=':f'): 95 | self.name = name 96 | self.fmt = fmt 97 | self.reset() 98 | 99 | def reset(self): 100 | self.val = 0 101 | self.avg = 0 102 | self.sum = 0 103 | self.count = 0 104 | 105 | def update(self, val, n=1): 106 | self.val = val 107 | self.sum += val * n 108 | self. 
class AverageMeter(object):
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\r' + '\t'.join(entries), end='')

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


class GradualWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        self.multiplier = multiplier
        self.total_epoch = warm_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        self.last_epoch = None
        self.base_lrs = None
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step(self, epoch=None, metrics=None):
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)


class LatentDataset(Dataset):
    def __init__(self, data_path):
        data = np.load(data_path)
        self.x = torch.from_numpy(data['all_a']).float()

    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        return len(self.x)
--------------------------------------------------------------------------------