├── README.md
├── calc_fid.py
├── calc_fid.sh
├── data.py
├── disentangle.sh
├── eval_disentangle.sh
├── eval_disentanglement.py
├── eval_fid.sh
├── flowchart.drawio-1.png
├── gen_fid.sh
├── gen_fid_stats.py
├── graphicalabstract.drawio_v2-1.png
├── interpolate.sh
├── latent_quality.sh
├── models.py
├── modules.py
├── run.py
├── run.sh
├── sampling.py
├── save_latent.sh
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # [InfoDiffusion: Representation Learning Using Information Maximizing Diffusion Models](https://arxiv.org/abs/2306.08757) (ICML 2023)
2 | By [Yingheng Wang](https://isjakewong.github.io/), [Yair Schiff](https://yair-schiff.github.io), [Aaron Gokaslan](https://skylion007.github.io), [Weishen Pan](https://vivo.weill.cornell.edu/display/cwid-wep4001),
3 | [Fei Wang](https://wcm-wanglab.github.io/), [Chris De Sa](https://www.cs.cornell.edu/~cdesa/), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)
4 |
5 | [[Project Page]](https://isjakewong.github.io/infodiffusion-page/)
6 | [[arXiv]](https://arxiv.org/abs/2306.08757)
7 |
8 | ![Graphical abstract](graphicalabstract.drawio_v2-1.png)
9 |
10 | ![Flowchart](flowchart.drawio-1.png)
11 |
19 | We introduce *InfoDiffusion*, a principled probabilistic extension of diffusion models that supports low-dimensional latents with associated variational learning objectives that are regularized with a mutual information term. We show that these algorithms simultaneously yield high-quality samples and latent representations, achieving competitive performance with state-of-the-art methods on both fronts.
20 |
21 |
22 | In this repo, we release:
23 | * **The Auxiliary-Variable Diffusion Models (AVDM)**:
24 | 1. Diffusion decoder conditioned on auxiliary variable using AdaNorm
25 |     2. Simplified loss calculation for auxiliary latent variables with semantic prior
26 | * **Baseline implementations** [[Examples]](#baselines):
27 | 1. A set of model variants from the VAE family (VAE, $\beta$-VAE, InfoVAE) with different priors (Gaussian, Mixture of Gaussians, spiral).
28 | 2. A simplified version of Diffusion Autoencoder [DiffAE](https://arxiv.org/abs/2111.15640) within our AVDM framework.
29 | 3. A minimal and efficient implementation of vanilla diffusion models.
30 | * **Evaluation metrics**:
31 | 1. Generation quality: Fréchet inception distance (FID).
32 | 2. Latent quality: latent space interpolation, latent variables for classification.
33 | 3. Disentanglement: [DCI](https://openreview.net/forum?id=By-7dz-AZ) score, [TAD](https://link.springer.com/chapter/10.1007/978-3-031-19812-0_3) score.
34 | * **Samplers**:
35 | 1. [DDPM](https://proceedings.neurips.cc/paper/2020/hash/4c5bcfec8584af0d967f1ab10179ca4b-Abstract.html) sampling and [DDIM](https://arxiv.org/abs/2010.02502) sampling.
36 |     2. Two-phase sampling, where the two phases sample from a regular diffusion model and an AVDM consecutively (see the sketch below).
37 |     3. Latent sampling, which uses an auxiliary latent diffusion model to sample $\mathbf{z}_t$ along with $\mathbf{x}_t$.
38 | 4. Reverse DDIM sampling to visualize the latent $\mathbf{x}_T$ from $\mathbf{x}_0$.
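 |
 | The two-phase sampler can be pictured as follows. This is a minimal sketch (device placement omitted), assuming a simple DDPM ancestral step and the default `--diffusion_steps 1000` / `--split_step 500`; the actual samplers live in [`sampling.py`](./sampling.py).
 | ```python
 | import torch
 |
 | @torch.no_grad()
 | def ddpm_step(model, x, t, betas, cond=None):
 |     # one ancestral DDPM step x_t -> x_{t-1} (illustrative helper, not the repo's API)
 |     beta, alpha = betas[t], 1.0 - betas[t]
 |     alpha_bar = torch.prod(1.0 - betas[: t + 1])
 |     t_batch = torch.full((x.shape[0],), t, dtype=torch.long)
 |     eps = model(x, t_batch) if cond is None else model(x, t_batch, cond)
 |     mean = (x - beta / torch.sqrt(1.0 - alpha_bar) * eps) / torch.sqrt(alpha)
 |     noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
 |     return mean + torch.sqrt(beta) * noise
 |
 | @torch.no_grad()
 | def two_phase_sample(vanilla_model, avdm_model, z, shape, T=1000, split_step=500):
 |     betas = torch.linspace(1e-5, 1e-2, T)      # matches the --beta1/--betaT defaults
 |     x = torch.randn(shape)                     # x_T ~ N(0, I)
 |     for t in reversed(range(split_step, T)):   # phase 1: regular diffusion model
 |         x = ddpm_step(vanilla_model, x, t, betas)
 |     for t in reversed(range(split_step)):      # phase 2: AVDM conditioned on z
 |         x = ddpm_step(avdm_model, x, t, betas, cond=z)
 |     return x
 | ```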
39 |
40 |
41 | ## Code Organization
42 | 1. ```run.py```: Routines for training and evaluation
43 | 2. ```models.py```: Diffusion models (InfoDiffusion, DiffAE, regular diffusion), VAEs (InfoVAE, $\beta$-VAE, VAE)
44 | 3. ```modules.py```: Neural network blocks
45 | 4. ```sampling.py```: DDPM/DDIM sampler, Reverse DDIM sampler, Two-phase sampler, Latent sampler
46 | 5. ```utils.py```: LR scheduler, logging, utils to calculate priors
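 | 6. ```data.py```: Dataset configurations and data loaders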
47 | 7. ```gen_fid_stats.py```: Generate stats used for FID calculation
48 | 8. ```calc_fid.py```: Calculate FID scores
49 |
50 |
51 |
52 |
53 | ## Getting started in this repository
54 |
55 | To get started, create a conda environment containing the required dependencies.
56 |
57 | ```bash
58 | conda create -n infodiffusion python  # include python so the env gets its own pip
59 | conda activate infodiffusion
60 | pip install -r requirements.txt
61 | ```
62 |
63 | Run the training using the bash script:
64 | ```bash
65 | bash run.sh
66 | ```
67 | or
68 | ```bash
69 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
70 | ```
71 | The arguments in this script train a diffusion model (`--model diff`) with Maximum Mean Discrepancy (MMD) regularization (`--mmd_weight 0.1`) and a regular Gaussian prior (`--prior regular`) on CelebA (`--dataset celeba`).
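 |
 | For intuition, the MMD term compares the encoded auxiliary variables against samples from the chosen prior. Below is a generic RBF-kernel sketch (a simple biased estimator with an assumed bandwidth), not the exact `compute_mmd` from [`utils.py`](./utils.py):
 | ```python
 | import torch
 |
 | def rbf_kernel(x, y, scale=1.0):
 |     # k(x, y) = exp(-||x - y||^2 / (2 * scale * dim)); the bandwidth choice is an assumption
 |     dim = x.shape[1]
 |     return torch.exp(-torch.cdist(x, y) ** 2 / (2.0 * scale * dim))
 |
 | def mmd(z_enc, z_prior):
 |     # MMD^2 ~= E[k(p, p)] + E[k(q, q)] - 2 E[k(p, q)]
 |     return (rbf_kernel(z_enc, z_enc).mean()
 |             + rbf_kernel(z_prior, z_prior).mean()
 |             - 2.0 * rbf_kernel(z_enc, z_prior).mean())
 |
 | # for --prior regular, z_prior would simply be torch.randn_like(z_enc)
 | ```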
72 |
73 |
74 | ## Evaluation
75 |
76 | Below, we describe the steps required for evaluating the trained diffusion models.
77 | Throughout, the main entry point for running experiments is the [`run.py`](./run.py) script.
78 | We also provide sample `bash` scripts for launching these evaluation runs.
79 | In general, different evaluation runs can be switched using `--mode`, which takes one of the following values:
80 | * `eval`: sampling images from the trained diffusion model.
81 | * `eval_fid`: sampling images for FID score calculation.
82 | * `save_latent`: save the auxiliary variables.
83 | * `disentangle`: run evaluation on auxiliary variable disentanglement.
84 | * `interpolate`: run interpolation between two given input images.
85 | * `latent_quality`: save the auxiliary variables and latent variables for classification.
86 | * `train_latent_ddim`: train the latent diffusion models used in latent sampler.
87 | * `plot_latent`: plot the latent space.
88 | Note that the quantitative disentanglement evaluation, the latent classification, and the FID score calculation each require multiple steps, as described below.
89 |
90 | ### Disentanglement evaluation
91 | To evaluate latent disentanglement, we need to conduct the following steps:
92 | 1. ```save_latent.sh```: save the auxiliary variables $\mathbf{z}$ and latent variables $\mathbf{x}_T$.
93 | 2. ```eval_disentangle.sh```: evaluate the latent disentanglement by computing DCI and TAD scores.
94 |
95 | #### Save the latents
96 | ```bash
97 | python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
98 | ```
99 |
100 | #### Eval the disentanglement metrics on latents
101 | ```bash
102 | python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
103 | ```
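 |
 | The resulting `.npz` file can also be inspected directly. A minimal sketch, assuming the `all_a`/`all_attr` keys and the file-naming convention used in [`eval_disentanglement.py`](./eval_disentanglement.py), where `<exp_string>` comes from `utils.generate_exp_string` with `.` replaced by `_`:
 | ```python
 | import numpy as np
 |
 | data = np.load("diff_<exp_string>_latent.npz", allow_pickle=True)
 | a = data["all_a"]        # auxiliary variables z, shape [N, a_dim]
 | attr = data["all_attr"]  # ground-truth attributes, shape [N, num_attrs]
 | ```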
104 |
105 | ### Latent classification
106 | To run latent classification, we need to conduct the following steps:
107 | 1. ```save_latent.sh```: save the auxiliary variables $\mathbf{z}$ and latent variables $\mathbf{x}_T$ used to train the classifier.
108 | 2. ```eval_disentangle.sh```: use the same evaluation script for disentanglement to train the classifier and obtain the classification accuracy.
109 |
110 | #### Save the latents
111 | ```bash
112 | python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
113 | ```
114 |
115 | #### Train latent classifier and evaluate
116 | ```bash
117 | python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
118 | ```
119 |
120 | ### FID calculation
121 |
122 | To calculate the FID scores, we need to conduct the following steps:
123 | 1. ```eval_fid.sh```: train the diffusion model and the latent diffusion model, then generate samples from them.
124 | 2. ```gen_fid.sh```: generate FID stats given the dataset name and the folder storing the preprocessed images from this dataset.
125 | 3. ```calc_fid.sh```: calculate FID scores given the dataset name and the folder storing the generated samples.
126 |
127 | We also provide the commands for the above steps:
128 | #### Train and sample:
129 | ```bash
130 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
131 |
132 | python run.py --model diff --mode save_latent --disent_metric tad --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64
133 |
134 | python run.py --model diff --mode train_latent_ddim --a_dim 256 --epochs 50 --mmd_weight 0.1 --dataset celeba --deterministic --save_epoch 10 --prior regular --r_seed 64
135 |
136 | python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.1 --sampling_number 10000 --epochs 50 --dataset celeba --is_latent --prior regular --r_seed 64
137 | ```
138 |
139 | #### Generate FID stats:
140 | ```bash
141 | python gen_fid_stats.py celeba ./celeba_imgs
142 | ```
143 |
144 | #### Calculate FID scores:
145 | ```bash
146 | python calc_fid.py celeba ./imgs/celeba_32d_0.1mmd/eval-fid-latent
147 | ```
148 |
149 | Note: please refer to [clean-fid](https://github.com/GaParmar/clean-fid) for more options to calculate FID.
150 |
151 | ## Baselines
152 |
153 |
154 | The baselines can be easily switched using the argument `--model`, which takes one of the following values: `['diff', 'vae', 'vanilla']`, where `'diff'` is for AVDM, `'vae'` is for the VAE model family, and `'vanilla'` is for regular diffusion models. Below is an example to train InfoVAE:
155 | ```bash
156 | python run.py --model vae --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --prior regular --r_seed 64
157 | ```
158 |
159 | ## Notes and disclaimer
160 | The ```main``` branch provides code optimized for representation learning tasks, while the ```InfoDiffusion-dev``` branch provides code closer to the version used to reproduce the results reported in the paper.
161 |
162 | This research code is provided as-is, without any support or guarantee of quality. However, if you identify any issues or areas for improvement, please feel free to raise an issue or submit a pull request. We will do our best to address them.
163 |
164 | ## Citation
165 | ```
166 | @inproceedings{wang2023infodiffusion,
167 | title={Infodiffusion: Representation learning using information maximizing diffusion models},
168 | author={Wang, Yingheng and Schiff, Yair and Gokaslan, Aaron and Pan, Weishen and Wang, Fei and De Sa, Christopher and Kuleshov, Volodymyr},
169 | booktitle={International Conference on Machine Learning},
170 | pages={36336--36354},
171 | year={2023},
172 | organization={PMLR}
173 | }
174 | ```
175 |
--------------------------------------------------------------------------------
/calc_fid.py:
--------------------------------------------------------------------------------
1 | from cleanfid import fid
2 | import sys
3 | if __name__ == '__main__':
4 | dataset_name = sys.argv[1]
5 | folder_1 = sys.argv[2]
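 |     # dataset_name must match the custom stats created by gen_fid_stats.py;
 |     # dataset_split="custom" makes clean-fid look up those cached stats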
6 | cleanfid_args = dict(
7 | dataset_name=dataset_name,
8 | dataset_res=64,
9 | num_gen=10000,
10 | dataset_split="custom"
11 | )
12 | fid_score = fid.compute_fid(folder_1, **cleanfid_args)
13 |     print(f'fid score: {fid_score}')
14 |     kid_score = fid.compute_kid(folder_1, **cleanfid_args)
15 |     print(f'kid score: {kid_score}')
--------------------------------------------------------------------------------
/calc_fid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python calc_fid.py celeba ./imgs/celeba_32d_0.1mmd/eval-fid-latent
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import numpy as np
5 | from torch.utils.data import Dataset, DataLoader
6 | from torchvision.datasets import ImageFolder
7 |
8 | class Crop:
9 | def __init__(self, x1, x2, y1, y2):
10 | self.x1 = x1
11 | self.x2 = x2
12 | self.y1 = y1
13 | self.y2 = y2
14 |
15 | def __call__(self, img):
16 |         return torchvision.transforms.functional.crop(img, self.x1, self.y1,
17 |                                                        self.x2 - self.x1, self.y2 - self.y1)
18 |
19 | def __repr__(self):
20 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
21 | self.x1, self.x2, self.y1, self.y2)
22 |
23 |
24 | def d2c_crop():
25 | # from D2C paper for CelebA dataset.
26 | cx = 89
27 | cy = 121
28 | x1 = cy - 64
29 | x2 = cy + 64
30 | y1 = cx - 64
31 | y2 = cx + 64
32 | return Crop(x1, x2, y1, y2)
33 |
34 |
35 | class CustomTensorDataset(Dataset):
36 | def __init__(self, data, latents_values, latents_classes):
37 | self.data = data
38 | self.latents_values = latents_values
39 | self.latents_classes = latents_classes
40 |
41 | def __getitem__(self, index):
42 | return (torch.from_numpy(self.data[index]).float(),
43 | torch.from_numpy(self.latents_values[index]).float(),
44 | torch.from_numpy(self.latents_classes[index]).int())
45 |
46 | def __len__(self):
47 | return self.data.shape[0]
48 |
49 |
50 | class CustomImageFolder(ImageFolder):
51 | def __init__(self, root, transform=None):
52 | super(CustomImageFolder, self).__init__(root, transform)
53 |
54 | def __getitem__(self, index):
55 | path = self.imgs[index][0]
56 | img = self.loader(path)
57 | if self.transform is not None:
58 | img = self.transform(img)
59 |
60 | return img
61 |
62 |
63 | def get_dataset_config(args):
64 | if args.dataset == 'fmnist':
65 | args.input_channels = 1
66 | args.unets_channels = 32
67 | args.encoder_channels = 32
68 | args.input_size = 32
69 | elif args.dataset == 'mnist':
70 | args.input_channels = 1
71 | args.unets_channels = 32
72 | args.encoder_channels = 32
73 | args.input_size = 32
74 | elif args.dataset == 'dsprites':
75 | args.input_channels = 1
76 | args.unets_channels = 32
77 | args.encoder_channels = 32
78 | args.input_size = 32
79 | elif args.dataset == 'celeba':
80 | args.input_channels = 3
81 | args.unets_channels = 64
82 | args.encoder_channels = 64
83 | args.input_size = 64
84 | elif args.dataset == 'cifar10':
85 | args.input_channels = 3
86 | args.unets_channels = 64
87 | args.encoder_channels = 64
88 | args.input_size = 32
89 | elif args.dataset == 'chairs':
90 | args.input_channels = 3
91 | args.unets_channels = 32
92 | args.encoder_channels = 32
93 | args.input_size = 64
94 | elif args.dataset == 'ffhq':
95 | args.input_channels = 3
96 | args.unets_channels = 64
97 | args.encoder_channels = 64
98 | args.input_size = 64
99 |
100 | shape = (args.input_channels, args.input_size, args.input_size)
101 |
102 | return shape
103 |
104 |
105 | def get_dataset(args):
106 | if args.dataset == 'fmnist':
107 | return get_fmnist(args)
108 | elif args.dataset == 'mnist':
109 | return get_mnist(args)
110 | elif args.dataset == 'celeba':
111 | return get_celeba(args)
112 | elif args.dataset == 'cifar10':
113 | return get_cifar10(args)
114 | elif args.dataset == 'dsprites':
115 | return get_dsprites(args)
116 | elif args.dataset == 'chairs':
117 | return get_chairs(args)
118 | elif args.dataset == 'ffhq':
119 | return get_ffhq(args)
120 |
121 |
122 | def get_mnist(args):
123 | transform = torchvision.transforms.Compose([
124 | torchvision.transforms.Resize((args.input_size, args.input_size)),
125 | torchvision.transforms.ToTensor(),
126 | torchvision.transforms.Lambda(lambda t: (t * 2) - 1),
127 | ])
128 |
129 | dataset = torchvision.datasets.MNIST(root = args.data_dir, train=True, download=True, transform = transform)
130 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, num_workers = 4)
131 |
132 | return dataloader
133 |
134 |
135 | def get_fmnist(args):
136 | transform = torchvision.transforms.Compose([
137 | torchvision.transforms.Resize((args.input_size, args.input_size)),
138 | torchvision.transforms.RandomHorizontalFlip(),
139 | torchvision.transforms.ToTensor(),
140 | torchvision.transforms.Lambda(lambda t: (t * 2) - 1),
141 | ])
142 |
143 | dataset = torchvision.datasets.FashionMNIST(root = args.data_dir, train=True, download=True, transform = transform)
144 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, num_workers = 4)
145 |
146 | return dataloader
147 |
148 |
149 | def get_celeba(args,
150 | as_tensor: bool = True,
151 | do_augment: bool = True,
152 | do_normalize: bool = True,
153 | crop_d2c: bool = False):
154 | if crop_d2c:
155 | transform = [
156 | d2c_crop(),
157 | torchvision.transforms.Resize(args.input_size),
158 | ]
159 | else:
160 | transform = [
161 | torchvision.transforms.Resize(args.input_size),
162 | torchvision.transforms.CenterCrop(args.input_size),
163 | ]
164 |
165 | if do_augment:
166 | transform.append(torchvision.transforms.RandomHorizontalFlip())
167 | if as_tensor:
168 | transform.append(torchvision.transforms.ToTensor())
169 | if do_normalize:
170 | transform.append(
171 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
172 | transform = torchvision.transforms.Compose(transform)
173 |
174 | if args.mode in ['attr_classification', 'eval_fid', 'reconstruction']:
175 | train_set = torchvision.datasets.CelebA(root = args.data_dir, split = "train", download = True, transform = transform)
176 | valid_set = torchvision.datasets.CelebA(root = args.data_dir, split = "valid", download = True, transform = transform)
177 | test_set = torchvision.datasets.CelebA(root = args.data_dir, split = "test", download = True, transform = transform)
178 | train_loader = torch.utils.data.DataLoader(train_set, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4)
179 | valid_loader = torch.utils.data.DataLoader(valid_set, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4)
180 | test_loader = torch.utils.data.DataLoader(test_set, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4)
181 | return (train_loader, valid_loader, test_loader)
182 | else:
183 | dataset = torchvision.datasets.CelebA(root = args.data_dir, split = "train", download = True, transform = transform)
184 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = False, num_workers = 4)
185 |
186 | return dataloader
187 |
188 |
189 | def get_cifar10(args):
190 | transform = torchvision.transforms.Compose([
191 | torchvision.transforms.RandomHorizontalFlip(),
192 | torchvision.transforms.ToTensor(),
193 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
194 | ])
195 |
196 | dataset = torchvision.datasets.CIFAR10(root = args.data_dir, train = True, download = True, transform = transform)
197 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4)
198 | return dataloader
199 |
200 |
201 | def get_dsprites(args):
202 | root = os.path.join(args.data_dir+'/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
203 | file = np.load(root, encoding='latin1')
204 | data = file['imgs'][:, np.newaxis, :, :]
205 | latents_values = file['latents_values']
206 | latents_classes = file['latents_classes']
207 | train_kwargs = {'data':data, 'latents_values':latents_values, 'latents_classes':latents_classes}
208 | dset = CustomTensorDataset
209 | dataset = dset(**train_kwargs)
210 |
211 | dataloader = DataLoader(dataset,
212 | batch_size=args.batch_size,
213 | shuffle=True,
214 | num_workers=4,
215 | pin_memory=True,
216 | drop_last=True)
217 |
218 | return dataloader
219 |
220 |
221 | def get_chairs(args):
222 | transform = torchvision.transforms.Compose([
223 | torchvision.transforms.Resize((args.input_size, args.input_size)),
224 | torchvision.transforms.RandomHorizontalFlip(),
225 | torchvision.transforms.ToTensor(),
226 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
227 | ])
228 |
229 | dataset = CustomImageFolder(root = args.data_dir+'/3DChairs', transform = transform)
230 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = True, num_workers = 4)
231 | return dataloader
232 |
233 |
234 | def get_ffhq(args):
235 | transform = torchvision.transforms.Compose([
236 | torchvision.transforms.Resize((args.input_size, args.input_size)),
237 | torchvision.transforms.RandomHorizontalFlip(),
238 | torchvision.transforms.ToTensor(),
239 | torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
240 | ])
241 |
242 | dataset = CustomImageFolder(root = args.data_dir+'/ffhq', transform = transform)
243 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, drop_last = True, shuffle = False, num_workers = 4)
244 | return dataloader
--------------------------------------------------------------------------------
/disentangle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run.py --model diff --mode disentangle --img_id 0 --mmd_weight 0.1 --a_dim 32 --epochs 20 --dataset celeba --deterministic --prior regular --r_seed 64
4 |
--------------------------------------------------------------------------------
/eval_disentangle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python eval_disentanglement.py --model diff --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
3 |
--------------------------------------------------------------------------------
/eval_disentanglement.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import pandas as pd
4 | import scipy
5 | import torch
6 | from sklearn.linear_model import LogisticRegression
7 | from sklearn.metrics import roc_auc_score, accuracy_score
8 | from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
9 | from sklearn.preprocessing import StandardScaler
10 | from sklearn.model_selection import KFold
11 | from utils import generate_exp_string
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument('--r_seed', type=int, default=0,
17 | help='the value of given random seed')
18 | parser.add_argument('--img_id', type=int, default=0,
19 | help='the id of given img')
20 | parser.add_argument('--model', required=True,
21 | choices=['diff', 'vae', 'vanilla'], help='which type of model to run')
22 |     parser.add_argument('--mode', required=False,
23 | choices=['train', 'eval', 'eval_fid', 'save_latent', 'disentangle',
24 | 'interpolate', 'save_original_img', 'latent_quality',
25 | 'train_latent_ddim', 'plot_latent'], help='which mode to run')
26 | parser.add_argument('--prior', required=True,
27 | choices=['regular', '10mix', 'roll'], help='which type of prior to run')
28 | parser.add_argument('--kld_weight', type=float, default=0,
29 | help='weight of kld loss')
30 | parser.add_argument('--mmd_weight', type=float, default=0.1,
31 | help='weight of mmd loss')
32 | parser.add_argument('--use_C', action='store_true',
33 | default=False, help='use control constant or not')
34 | parser.add_argument('--C_max', type=float, default=25,
35 |                         help='control constant of kld loss (orig default: 25 for simple, 50 for complex)')
36 | parser.add_argument('--dataset', required=True,
37 | choices=['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites', 'chairs', 'ffhq'],
38 | help='training dataset')
39 | parser.add_argument('--img_folder', default='./imgs',
40 | help='path to save sampled images')
41 | parser.add_argument('--log_folder', default='./logs',
42 | help='path to save logs')
43 | parser.add_argument('-e', '--epochs', type=int, default=20,
44 | help='number of epochs to train')
45 | parser.add_argument('--save_epochs', type=int, default=5,
46 | help='number of epochs to save model')
47 | parser.add_argument('--batch_size', type=int, default=64,
48 | help='training batch size')
49 | parser.add_argument('--learning_rate', type=float, default=0.0001,
50 | help='learning rate')
51 | parser.add_argument('--optimizer', default='adam', choices=['adam'],
52 | help='optimization algorithm')
53 | parser.add_argument('--model_folder', default='./models',
54 | help='folder where logs will be stored')
55 | parser.add_argument('--deterministic', action='store_true',
56 |                         default=False, help='deterministic sampling')
57 | parser.add_argument('--input_channels', type=int, default=1,
58 | help='number of input channels')
59 |     parser.add_argument('--unets_channels', type=int, default=64,
60 |                         help='number of UNet channels')
61 |     parser.add_argument('--encoder_channels', type=int, default=64,
62 |                         help='number of encoder channels')
63 | parser.add_argument('--input_size', type=int, default=32,
64 | help='expected size of input')
65 | parser.add_argument('--a_dim', type=int, default=32, required=True,
66 | help='dimensionality of auxiliary variable')
67 | parser.add_argument('--beta1', type=float, default=1e-5,
68 | help='value of beta 1')
69 | parser.add_argument('--betaT', type=float, default=1e-2,
70 | help='value of beta T')
71 | parser.add_argument('--diffusion_steps', type=int, default=1000,
72 | help='number of diffusion steps')
73 | parser.add_argument('--split_step', type=int, default=500,
74 | help='the step for splitting two phases')
75 | parser.add_argument('--sampling_number', type=int, default=16,
76 | help='number of sampled images')
77 | parser.add_argument('--data_dir', type=str, default='./data')
78 | parser.add_argument('--tb_logger', action='store_true',
79 | help='use tensorboard logger.')
80 | parser.add_argument('--is_latent', action='store_true',
81 | help='use latent diffusion for unconditional sampling.')
82 | parser.add_argument('--is_bottleneck', action='store_true',
83 | help='only fuse aux variable in bottleneck layers.')
84 | args = parser.parse_args()
85 |
86 | return args
87 |
88 | """ Impementation of the DCI metric is from:
89 | https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/evaluation/metrics/dci.py
90 | """
91 | def compute_dci(mus_train, ys_train, mus_test, ys_test):
92 | """Computes score based on both training and testing codes and factors."""
93 | scores = {}
94 | importance_matrix, train_err, test_err = compute_importance_gbt(
95 | mus_train, ys_train, mus_test, ys_test)
96 | assert importance_matrix.shape[0] == mus_train.shape[0]
97 | assert importance_matrix.shape[1] == ys_train.shape[0]
98 | scores["informativeness_train"] = train_err
99 | scores["informativeness_test"] = test_err
100 | scores['importance'] = importance_matrix
101 | scores["disentanglement"] = disentanglement(importance_matrix)
102 | scores["completeness"] = completeness(importance_matrix)
103 | return scores
104 |
105 | def compute_importance_gbt(x_train, y_train, x_test, y_test):
106 | """Compute importance based on gradient boosted trees."""
107 | num_factors = y_train.shape[0]
108 | num_codes = x_train.shape[0]
109 | importance_matrix = np.zeros(shape=[num_codes, num_factors], dtype=np.float64)
110 | train_loss = []
111 | test_loss = []
112 | for i in range(num_factors):
113 | model = GradientBoostingClassifier()
114 | model.fit(x_train.T, y_train[i, :])
115 | importance_matrix[:, i] = np.abs(model.feature_importances_)
116 | train_loss.append(np.mean(model.predict(x_train.T) == y_train[i, :]))
117 | test_loss.append(np.mean(model.predict(x_test.T) == y_test[i, :]))
118 | return importance_matrix, np.mean(train_loss), np.mean(test_loss)
119 |
120 |
121 | def disentanglement_per_code(importance_matrix):
122 | """Compute disentanglement score of each code."""
123 | # importance_matrix is of shape [num_codes, num_factors].
124 | return 1. - scipy.stats.entropy(importance_matrix.T + 1e-11, base=importance_matrix.shape[1])
125 |
126 |
127 | def disentanglement(importance_matrix):
128 | """Compute the disentanglement score of the representation."""
129 | per_code = disentanglement_per_code(importance_matrix)
130 | if importance_matrix.sum() == 0.:
131 | importance_matrix = np.ones_like(importance_matrix)
132 | code_importance = importance_matrix.sum(axis=1) / importance_matrix.sum()
133 |
134 | return np.sum(per_code*code_importance)
135 |
136 |
137 | def completeness_per_factor(importance_matrix):
138 | """Compute completeness of each factor."""
139 | # importance_matrix is of shape [num_codes, num_factors].
140 | return 1. - scipy.stats.entropy(importance_matrix + 1e-11,
141 | base=importance_matrix.shape[0])
142 |
143 |
144 | def completeness(importance_matrix):
145 | """"Compute completeness of the representation."""
146 | per_factor = completeness_per_factor(importance_matrix)
147 | if importance_matrix.sum() == 0.:
148 | importance_matrix = np.ones_like(importance_matrix)
149 | factor_importance = importance_matrix.sum(axis=0) / importance_matrix.sum()
150 | return np.sum(per_factor*factor_importance)
151 |
152 |
153 | class PredMetric():
154 | """ Impementation to calculate the AUROC for predicting each attribute
155 | """
156 | def __init__(self, predictor = "RandomForest", output_type = "b", attr_names = None, *args, **kwargs):
157 | super(PredMetric, self).__init__(*args, **kwargs)
158 |
159 | self.attr_names = attr_names
160 | self._predictor = predictor
161 | self.output_type = output_type
162 | if predictor == "Linear":
163 | self.predictor_class = LogisticRegression
164 | self.params = {}
165 | # weights
166 | self.importances_attr = "coef_"
167 | elif predictor == "RandomForest":
168 | self.predictor_class = RandomForestClassifier
169 | self.importances_attr = "feature_importances_"
170 | self.params = {"oob_score": True}
171 | else:
172 | raise NotImplementedError()
173 |
174 | self.TINY = 1e-12
175 |
176 | def evaluate(self, train_codes, train_attrs, test_codes, test_attrs):
177 | R = []
178 | results = []
179 | # train_codes, test_codes, train_attrs, test_attrs = train_test_split(codes, attrs, test_size=0.2)
180 | print("Calculate for attribute:")
181 | for j in range(train_attrs.shape[-1]):
182 | if isinstance(self.params, dict):
183 | predictor = self.predictor_class(**self.params)
184 | elif isinstance(self.params, list):
185 | predictor = self.predictor_class(**self.params[j])
186 | else:
187 | raise NotImplementedError()
188 | predictor.fit(train_codes, train_attrs[:, j])
189 |
190 | r = getattr(predictor, self.importances_attr)[:, None]
191 | R.append(np.abs(r))
192 | # extract relative importance of each code variable in
193 | # predicting the j attribute
194 | if self.output_type == "b":
195 | test_pred_prob = predictor.predict_proba(test_codes)[:, 1]
196 | tmp_result = roc_auc_score(test_attrs[:, j], test_pred_prob)
197 | elif self.output_type == "c":
198 | test_pred = predictor.predict(test_codes)
199 | tmp_result = accuracy_score(test_attrs[:, j], test_pred)
200 | results.append(tmp_result)
201 | if self.attr_names is not None:
202 | print(j, self.attr_names[j], tmp_result)
203 | else:
204 | print(j, tmp_result)
205 |
206 | # R = np.hstack(R) #columnwise, predictions of each z
207 | results = np.array(results)
208 |
209 | return {
210 | "{}_avg_result".format(self._predictor): results.mean(),
211 | "{}_result".format(self._predictor): results
212 | }
213 |
214 | # function that takes a lists of latent indices, thresholds, and signs for classification
215 | class LatentClass(object):
216 | def __init__(self, targ_ind, lat_ind, is_pos, thresh, __max, __min):
217 | super(LatentClass, self).__init__()
218 | self.targ_ind = targ_ind
219 | self.lat_ind = lat_ind
220 | self.is_pos = is_pos
221 | self.thresh = thresh
222 | self._max = __max
223 | self._min = __min
224 | self.it = list(zip(self.targ_ind, self.lat_ind, self.is_pos, self.thresh))
225 |
226 | def __call__(self, z, y_dim):
227 | # expect z to be [batch, z_dim]
228 | out = torch.ones((z.shape[0], y_dim))
229 | for t_i, l_i, is_pos, t in self.it:
230 | ma, mi = self._max[l_i], self._min[l_i]
231 | thr = t * (ma - mi) + mi
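 |             # map the relative threshold t in [0, 1] onto the observed range [mi, ma] of latent l_i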
232 | res = (z[:, l_i] >= thr if is_pos else z[:, l_i] < thr).type(torch.int)
233 | out[:, t_i] = res
234 | return out
235 |
236 | class TADMetric():
237 | """ Impementation of the metric in:
238 | NashAE: Disentangling Representations Through Adversarial Covariance Minimization
239 | The code is from:
240 | https://github.com/ericyeats/nashae-beamsynthesis
241 | """
242 | def __init__(self, y_dim, all_attrs):
243 | self.y_dim = y_dim
244 | self.all_attrs = all_attrs
245 |
246 | def calculate_auroc(self, targ, targ_ind, lat_ind, z, _ma, _mi, stepsize=0.01):
247 | thr = torch.arange(0.0, 1.0001, step=stepsize)
248 | total = targ.shape[0]
249 | pos_total = targ.sum(dim=0)[targ_ind].item()
250 | neg_total = total - pos_total
251 | p_fpr_tpr = torch.zeros((thr.shape[0], 2))
252 | n_fpr_tpr = torch.zeros((thr.shape[0], 2))
253 | for i, t in enumerate(thr):
254 | local_lc = LatentClass([targ_ind], [lat_ind], [True], [t], _ma, _mi)
255 | pred = local_lc(z.clone(), self.y_dim).to(targ.device)
256 | p_tp = torch.logical_and(pred == targ, pred).sum(dim=0)[targ_ind].item()
257 | p_fp = torch.logical_and(pred != targ, pred).sum(dim=0)[targ_ind].item()
258 | p_fpr_tpr[i][0] = p_fp / neg_total
259 | p_fpr_tpr[i][1] = p_tp / pos_total
260 | local_lc = LatentClass([targ_ind], [lat_ind], [False], [t], _ma, _mi)
261 | pred = local_lc(z.clone(), self.y_dim).to(targ.device)
262 | n_tp = torch.logical_and(pred == targ, pred).sum(dim=0)[targ_ind].item()
263 | n_fp = torch.logical_and(pred != targ, pred).sum(dim=0)[targ_ind].item()
264 | n_fpr_tpr[i][0] = n_fp / neg_total
265 | n_fpr_tpr[i][1] = n_tp / pos_total
266 | p_fpr_tpr = p_fpr_tpr.sort(dim=0)[0]
267 | n_fpr_tpr = n_fpr_tpr.sort(dim=0)[0]
268 | p_dists = p_fpr_tpr[1:, 0] - p_fpr_tpr[:-1, 0]
269 | p_area = (p_fpr_tpr[1:, 1] * p_dists).sum().item()
270 | n_dists = n_fpr_tpr[1:, 0] - n_fpr_tpr[:-1, 0]
271 | n_area = (n_fpr_tpr[1:, 1] * n_dists).sum().item()
272 | return p_area, n_area
273 |
274 | def aurocs(self, _z, targ, targ_ind, _ma, _mi):
275 | # perform a grid search of lat_ind to find the best classification metric
276 | aurocs = torch.ones(_z.shape[1]) * 0.5 # initialize as random guess
277 | for lat_ind in range(_z.shape[1]):
278 | if _ma[lat_ind] - _mi[lat_ind] > 0.2:
279 | p_auroc, n_auroc = self.calculate_auroc(targ, targ_ind, lat_ind, _z.clone(), _ma, _mi)
280 | m_auroc = max(p_auroc, n_auroc)
281 | aurocs[lat_ind] = m_auroc
282 | # print("{}\t{:1.3f}".format(lat_ind, m_auroc))
283 | return aurocs
284 |
285 | def aurocs_search(self, a, y):
286 | aurocs_all = torch.ones((y.shape[1], a.shape[1])) * 0.5
287 | base_rates_all = y.sum(dim=0)
288 | base_rates_all = base_rates_all / y.shape[0]
289 | _ma = a.max(dim=0)[0]
290 | _mi = a.min(dim=0)[0]
291 | print("Calculate for attribute:")
292 | for i in range(y.shape[1]):
293 | print(i)
294 | for j in range(a.shape[1]):
295 | aurocs_all[i, j] = max(roc_auc_score(y.numpy()[:, i], a.numpy()[:, j]), roc_auc_score(y.numpy()[:, i], -a.numpy()[:, j]))
296 | # aurocs_all[i] = self.aurocs(a, y, i, _ma, _mi)
297 | return aurocs_all.cpu(), base_rates_all.cpu()
298 |
299 | def evaluate(self, a, y):
300 | auroc_result, base_rates_raw = self.aurocs_search(torch.FloatTensor(a), torch.IntTensor(y))
301 | base_rates = base_rates_raw.where(base_rates_raw <= 0.5, 1. - base_rates_raw)
302 | targ = torch.IntTensor(y)
303 | dim_y = y.shape[1]
304 |
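 |         # TAD sums, over attributes, the gap between the best and second-best AUROC,
 |         # counting only attributes whose best detector AUROC clears `thresh` and whose
 |         # entropy-reduction proportion w.r.t. other attributes stays below `ent_red_thresh`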
305 | thresh = 0.75
306 | ent_red_thresh = 0.2
307 | max_aur, argmax_aur = torch.max(auroc_result.clone(), dim=1)
308 | norm_diffs = torch.zeros(dim_y)
309 | aurs_diffs = torch.zeros(dim_y)
310 | for ind, tag, max_a, argmax_a, aurs in zip(range(dim_y), self.all_attrs, max_aur.clone(), argmax_aur.clone(),
311 | auroc_result.clone()):
312 | norm_aurs = (aurs.clone() - 0.5) / (aurs.clone()[argmax_a] - 0.5)
313 | aurs_next = aurs.clone()
314 | aurs_next[argmax_a] = 0.0
315 | aurs_diff = max_a - aurs_next.max()
316 | aurs_diffs[ind] = aurs_diff
317 | norm_aurs[argmax_a] = 0.0
318 | norm_diff = 1. - norm_aurs.max()
319 | norm_diffs[ind] = norm_diff
320 |
321 | # calculate mutual information shared between attributes
322 | # determine which share a lot of information with each other
323 | with torch.no_grad():
324 | not_targ = 1 - targ
325 | j_prob = lambda x, y: torch.logical_and(x, y).sum() / x.numel()
326 | mi = lambda jp, px, py: 0. if jp == 0. or px == 0. or py == 0. else jp * torch.log(jp / (px * py))
327 |
328 | # Compute the Mutual Information (MI) between the labels
329 | mi_mat = torch.zeros((dim_y, dim_y))
330 | for i in range(dim_y):
331 | # get the marginal of i
332 | i_mp = targ[:, i].sum() / targ.shape[0]
333 | for j in range(dim_y):
334 | j_mp = targ[:, j].sum() / targ.shape[0]
335 | # get the joint probabilities of FF, FT, TF, TT
336 | # FF
337 | jp = j_prob(not_targ[:, i], not_targ[:, j])
338 | pi = 1. - i_mp
339 | pj = 1. - j_mp
340 | mi_mat[i][j] += mi(jp, pi, pj)
341 | # FT
342 | jp = j_prob(not_targ[:, i], targ[:, j])
343 | pi = 1. - i_mp
344 | pj = j_mp
345 | mi_mat[i][j] += mi(jp, pi, pj)
346 | # TF
347 | jp = j_prob(targ[:, i], not_targ[:, j])
348 | pi = i_mp
349 | pj = 1. - j_mp
350 | mi_mat[i][j] += mi(jp, pi, pj)
351 | # TT
352 | jp = j_prob(targ[:, i], targ[:, j])
353 | pi = i_mp
354 | pj = j_mp
355 | mi_mat[i][j] += mi(jp, pi, pj)
356 |
357 | mi_maxes, mi_inds = (mi_mat * (1 - torch.eye(dim_y))).max(dim=1)
358 | ent_red_prop = 1. - (mi_mat.diag() - mi_maxes) / mi_mat.diag()
359 |
360 | # calculate Average Norm AUROC Diff when best detector score is at a certain threshold
361 | filt = (max_aur >= thresh).logical_and(ent_red_prop <= ent_red_thresh)
362 | aurs_diffs_filt = aurs_diffs[filt]
363 | return aurs_diffs_filt.sum().item(), auroc_result.cpu().numpy(), len(aurs_diffs_filt)
364 |
365 | if __name__ == "__main__":
366 | dataset = "celeba"
367 | args = parse_args()
368 | data_dict = np.load("{}_{}_latent.npz".format(args.model, generate_exp_string(args).replace(".", "_")), allow_pickle=True)
369 |
370 | if dataset == "celeba":
371 | y_names = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald',
372 | 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
373 | 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
374 | 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
375 | 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
376 | 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
377 | 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
378 | 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
379 | ]
380 | output_type = "b"
381 | elif dataset == "fmnist":
382 | y_names = ["Class"]
383 | output_type = "c"
384 | elif dataset == "cifar10":
385 | y_names = ["Class"]
386 | output_type = "c"
387 | elif dataset == "ffhq":
388 | y_names = ["Age", "Gender", "Glass"]
389 | output_type = "c"
390 | elif dataset == "3dshapes":
391 | y_names = ['Floor hue', 'Wall hue', 'Object hue:', 'Scale', 'Shape', 'Orientation']
392 | output_type = "c"
393 |
394 | if dataset == "celeba":
395 | a = data_dict["all_a"][:10000,:]
396 |         y = data_dict["all_attr"][:10000,:].astype(int)
397 | elif dataset == "ffhq":
398 | a = data_dict["all_a"][:,:]
399 | y = pd.read_csv("ffhq_labels.csv")
400 |         y = y.values[:,2:].astype(int)
401 | y = y[:69952, :]
402 | elif dataset == "3dshapes":
403 | a = data_dict["all_a"][:10000, :]
404 | y = data_dict["all_attr"][:10000, :]
405 |
406 | y[:, 0] = y[:, 0] * 10
407 | y[:, 1] = y[:, 1] * 10
408 | y[:, 2] = y[:, 2] * 10
409 | y[:, 3] = y[:, 3] * 14 - 10.5
410 | y[:, 5] = y[:, 5] * 14 / 60 + 7
411 |         y = y.astype(int)
412 | else:
413 | a = data_dict["all_a"]
414 | if len(data_dict["all_attr"].shape) == 2:
415 |             y = data_dict["all_attr"][:, :].astype(int)
416 | else:
417 |             y = data_dict["all_attr"][:, np.newaxis].astype(int)
418 |
419 | kf = KFold(n_splits=5, shuffle=True, random_state=0)
420 | preds_rf, avg_preds_rf = [], []
421 | preds_ln, avg_preds_ln = [], []
422 | if dataset == "celeba":
423 | tad_scores, tad_attrs = [], []
424 | if dataset == "3dshapes":
425 | dci_scores = []
426 |
427 | for tr_idx, te_idx in kf.split(a):
428 | tr_a, te_a = a[tr_idx], a[te_idx]
429 | tr_y, te_y = y[tr_idx], y[te_idx]
430 | std = StandardScaler()
431 | std.fit(tr_a)
432 | tr_a = std.transform(tr_a)
433 | te_a = std.transform(te_a)
434 |
435 | if dataset == "celeba":
436 | tad_metric = TADMetric(y.shape[1], y_names)
437 | tad_score, auroc_result, num_attr = tad_metric.evaluate(tr_a, tr_y)
438 | #
439 | print("TAD SCORE: ", tad_score, "Attributes Captured: ", num_attr)
440 | tad_scores.append(tad_score)
441 | tad_attrs.append(num_attr)
442 | # sns.heatmap(auroc_result.transpose(), xticklabels = y_names)
443 | # plt.show()
444 |
445 | if dataset == "3dshapes":
446 | dci_result = compute_dci(tr_a.transpose(), tr_y.transpose(), te_a.transpose(), te_y.transpose())
447 | R = dci_result['importance']
448 | print("DCI Score", dci_result['disentanglement'])
449 | dci_scores.append(dci_result['disentanglement'])
450 | print(dci_result["informativeness_train"])
451 | print(dci_result["informativeness_test"])
452 |
453 | pred_metric = PredMetric("Linear", output_type, y_names)
454 | pred_result = pred_metric.evaluate(tr_a, tr_y, te_a, te_y)
455 |
456 | print("Avg Result", pred_result['Linear_avg_result'])
457 | avg_preds_ln.append(pred_result['Linear_avg_result'])
458 | preds_ln.append(pred_result['Linear_result'])
459 |
460 | if dataset == "3dshapes":
461 | dci_scores = np.array(dci_scores)
462 | print("DCI Score, {:.4f} \pm {:.4f}".format(np.array(dci_scores).mean(), np.array(dci_scores).std()))
463 |
464 |
465 | if dataset == "celeba":
466 | tad_scores = np.array(tad_scores)
467 | tad_attrs = np.array(tad_attrs)
468 | print("TAD Score, {:.4f} \pm {:.4f}".format(np.array(tad_scores).mean(), np.array(tad_scores).std()))
469 | print("TAD Attr, {:.4f} \pm {:.4f}".format(np.array(tad_attrs).mean(), np.array(tad_attrs).std()))
470 |
471 | avg_preds_ln = np.array(avg_preds_ln)
472 | print("Avg Acc (Linear), {:.4f} \pm {:.4f}".format(np.array(avg_preds_ln).mean(), np.array(avg_preds_ln).std()))
473 |
474 | preds_ln = np.vstack(preds_ln)
475 | for a_idx in range(preds_ln.shape[1]):
476 | print("Acc for {} (Linear), {:.4f} \pm {:.4f}".format(y_names[a_idx], preds_ln[:, a_idx].mean(), preds_ln[:, a_idx].std()))
--------------------------------------------------------------------------------
/eval_fid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 |
4 | # learned latent
5 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
6 |
7 | python run.py --model diff --mode save_latent --disent_metric tad --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64
8 |
9 | python run.py --model diff --mode train_latent_ddim --a_dim 256 --epochs 50 --mmd_weight 0.1 --dataset celeba --deterministic --save_epoch 10 --prior regular --r_seed 64
10 |
11 | python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.1 --sampling_number 10000 --epochs 50 --dataset celeba --is_latent --prior regular --r_seed 64
12 |
13 | # without learned latent
14 | # python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 256 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
15 |
16 | # python run.py --model vanilla --mode train --a_dim 256 --epochs 50 --dataset celeba --batch_size 128 --save_epochs 10 --prior regular --r_seed 64
17 |
18 | # python run.py --model diff --mode eval_fid --split_step 500 --a_dim 256 --batch_size 256 --mmd_weight 0.01 --sampling_number 10000 --epochs 50 --dataset celeba --prior regular --r_seed 64
--------------------------------------------------------------------------------
/flowchart.drawio-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isjakewong/InfoDiffusion/75c7b31953692af2f122e7fe8715902e160c9de5/flowchart.drawio-1.png
--------------------------------------------------------------------------------
/gen_fid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python gen_fid_stats.py celeba ./celeba_imgs
--------------------------------------------------------------------------------
/gen_fid_stats.py:
--------------------------------------------------------------------------------
1 | from cleanfid import fid
2 | import sys
3 | if __name__ == '__main__':
4 | custom_name = sys.argv[1]
5 | dataset_path = sys.argv[2]
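 |     # make_custom_stats computes and caches Inception statistics under custom_name;
 |     # calc_fid.py later resolves them via dataset_name with dataset_split="custom"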
6 | # TODO: remove saved stats if you want to reproduce.
7 | # fid.remove_custom_stats(custom_name, mode="clean")
8 | print(f'Generating clean-fid for dataset {custom_name} located at {dataset_path}')
9 | fid.make_custom_stats(custom_name, dataset_path, mode="clean")
--------------------------------------------------------------------------------
/graphicalabstract.drawio_v2-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isjakewong/InfoDiffusion/75c7b31953692af2f122e7fe8715902e160c9de5/graphicalabstract.drawio_v2-1.png
--------------------------------------------------------------------------------
/interpolate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run.py --model diff --mode interpolate --mmd_weight 0.1 --img_id 0 --a_dim 32 --epochs 50 --dataset celeba --deterministic --prior regular --r_seed 64
4 |
--------------------------------------------------------------------------------
/latent_quality.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run.py --model diff --mode latent_quality --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
4 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from modules import *
5 | from utils import compute_mmd, gaussian_mixture, swiss_roll
6 |
7 | class UNet(nn.Module):
8 | def __init__(self, T, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, shape=None):
9 | super().__init__()
10 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
11 | tdim = ch * 4
12 | self.time_embedding = TimeEmbedding(T, ch, tdim)
13 |
14 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1)
15 |
16 | self.downblocks = nn.ModuleList()
17 |         chs = [ch]  # record output channels during downsampling for the upsampling skips
18 | now_ch = ch
19 | for i, mult in enumerate(ch_mult):
20 | out_ch = ch * mult
21 | for _ in range(num_res_blocks):
22 | self.downblocks.append(ResBlock(
23 | in_ch=now_ch, out_ch=out_ch, tdim=tdim,
24 | dropout=dropout, attn=(i in attn)))
25 | now_ch = out_ch
26 | chs.append(now_ch)
27 | if i != len(ch_mult) - 1:
28 | self.downblocks.append(DownSample(now_ch))
29 | chs.append(now_ch)
30 |
31 | self.middleblocks = nn.ModuleList([
32 | ResBlock(now_ch, now_ch, tdim, dropout, attn=True, crossattn=False),
33 | ResBlock(now_ch, now_ch, tdim, dropout, attn=False, crossattn=False),
34 | ])
35 |
36 | self.upblocks = nn.ModuleList()
37 | for i, mult in reversed(list(enumerate(ch_mult))):
38 | out_ch = ch * mult
39 | for _ in range(num_res_blocks + 1):
40 | self.upblocks.append(ResBlock(
41 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
42 | dropout=dropout, attn=(i in attn)))
43 | now_ch = out_ch
44 | if i != 0:
45 | self.upblocks.append(UpSample(now_ch))
46 | assert len(chs) == 0
47 |
48 | self.tail = nn.Sequential(
49 | nn.GroupNorm(32, now_ch),
50 | nn.SiLU(),
51 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1)
52 | )
53 |
54 | self.initialize()
55 |
56 | def initialize(self):
57 | init.xavier_uniform_(self.head.weight)
58 | init.zeros_(self.head.bias)
59 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
60 | init.zeros_(self.tail[-1].bias)
61 |
62 | def forward(self, x, t):
63 | # Timestep embedding
64 | temb = self.time_embedding(t)
65 | # Downsampling
66 | h = self.head(x)
67 | hs = [h]
68 |
69 | for layer in self.downblocks:
70 | h = layer(h, temb)
71 | hs.append(h)
72 |
73 | # Middle
74 | for layer in self.middleblocks:
75 | if isinstance(layer, ResBlock):
76 | h = layer(h, temb)
77 | else:
78 | h = layer(h)
79 |
80 | # Upsampling
81 | for layer in self.upblocks:
82 | if isinstance(layer, ResBlock):
83 | h = torch.cat([h, hs.pop()], dim=1)
84 | h = layer(h, temb)
85 | h = self.tail(h)
86 |
87 | assert len(hs) == 0
88 | return h
89 |
90 |
91 | class MLPLNAct(nn.Module):
92 | def __init__(
93 | self,
94 | in_channels: int,
95 | out_channels: int,
96 | norm: bool,
97 | use_cond: bool,
98 | activation: str = None,
99 | cond_channels: int = None,
100 | condition_bias: float = 0,
101 | dropout: float = 0,
102 | ):
103 | super().__init__()
104 | self.activation = activation
105 | if self.activation is not None:
106 | self.act = nn.SiLU()
107 | else:
108 | self.act = nn.Identity()
109 | self.condition_bias = condition_bias
110 | self.use_cond = use_cond
111 |
112 | self.linear = nn.Linear(in_channels, out_channels)
113 | if self.use_cond:
114 | self.linear_emb = nn.Linear(cond_channels, out_channels)
115 | self.cond_layers = nn.Sequential(self.act, self.linear_emb)
116 | if norm:
117 | self.norm = nn.LayerNorm(out_channels)
118 | else:
119 | self.norm = nn.Identity()
120 |
121 | if dropout > 0:
122 | self.dropout = nn.Dropout(dropout)
123 | else:
124 | self.dropout = nn.Identity()
125 |
126 | self.init_weights()
127 |
128 | def init_weights(self):
129 | for module in self.modules():
130 | if isinstance(module, nn.Linear):
131 | if self.activation == 'relu':
132 | init.kaiming_normal_(module.weight,
133 | a=0,
134 | nonlinearity='relu')
135 | elif self.activation == 'leaky_relu':
136 | init.kaiming_normal_(module.weight,
137 | a=0.2,
138 | nonlinearity='leaky_relu')
139 | elif self.activation == 'silu':
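 |                     # kaiming_normal_ has no SiLU option, so the ReLU gain is used as a close approximation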
140 | init.kaiming_normal_(module.weight,
141 | a=0,
142 | nonlinearity='relu')
143 | else:
144 | # leave it as default
145 | pass
146 |
147 | def forward(self, x, cond=None):
148 | x = self.linear(x)
149 | if self.use_cond:
150 | # (n, c) or (n, c * 2)
151 | cond = self.cond_layers(cond)
152 |
153 |             # scale the activation by the conditioning signal first (condition_bias sets the base scale)
154 | x = x * (self.condition_bias + cond)
155 |
156 | # then norm
157 | x = self.norm(x)
158 | else:
159 | # no condition
160 | x = self.norm(x)
161 | x = self.act(x)
162 | x = self.dropout(x)
163 | return x
164 |
165 |
166 | class LatentUNet(nn.Module):
167 | def __init__(self, T, num_layers=10, dropout=0.1, shape=None, activation='silu',
168 | num_time_emb_channels: int = 64, num_time_layers: int = 2):
169 | super().__init__()
170 | self.num_time_emb_channels = num_time_emb_channels
171 | self.shape = shape
172 |
173 | layers = []
174 | for i in range(num_time_layers):
175 | if i == 0:
176 | a = num_time_emb_channels
177 | b = shape[-1]
178 | else:
179 | a = shape[-1]
180 | b = shape[-1]
181 | layers.append(nn.Linear(a, b))
182 | if i < num_time_layers - 1:
183 | layers.append(nn.SiLU())
184 | self.time_embed = nn.Sequential(*layers)
185 |
186 | self.skip_layers = list(range(1, num_layers))
187 | self.layers = nn.ModuleList([])
188 | for i in range(num_layers):
189 | if i == 0:
190 | act = activation
191 | norm = True
192 | cond = True
193 | a, b = shape[-1], shape[-1] * 4
194 | dropout = dropout
195 | elif i == num_layers - 1:
196 | act = None
197 | norm = False
198 | cond = False
199 | a, b = shape[-1] * 4, shape[-1]
200 | dropout = 0
201 | else:
202 | act = 'silu'
203 | norm = True
204 | cond = True
205 | a, b = shape[-1] * 4, shape[-1] * 4
206 | dropout = dropout
207 |
208 | if i in self.skip_layers:
209 | a += shape[-1]
210 |
211 | self.layers.append(
212 | MLPLNAct(
213 | a,
214 | b,
215 | norm=norm,
216 | activation=act,
217 | cond_channels=shape[-1],
218 | use_cond=cond,
219 | condition_bias=1,
220 | dropout=dropout,
221 | ))
222 |
223 | def forward(self, x, t):
224 | # Timestep embedding
225 | t = timestep_embedding(t, self.num_time_emb_channels)
226 | temb = self.time_embed(t)
227 |
228 | h = x
229 | for i in range(len(self.layers)):
230 | if i in self.skip_layers:
231 | # injecting input into the hidden layers
232 | h = torch.cat([h, x], dim=1)
233 | h = self.layers[i].forward(x=h, cond=temb)
234 | return h
235 |
236 |
237 | class AuxiliaryUNet(nn.Module):
238 | def __init__(self, T, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=32, shape=None):
239 | super().__init__()
240 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
241 | tdim = ch * 4
242 | self.a_dim = a_dim
243 | self.time_embedding = TimeEmbedding(T, ch, tdim)
244 | self.fc_a = nn.Linear(self.a_dim, tdim)
245 |
246 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1)
247 |
248 | self.downblocks = nn.ModuleList()
249 |         chs = [ch]  # record output channels during downsampling for the upsampling skips
250 | now_ch = ch
251 | for i, mult in enumerate(ch_mult):
252 | out_ch = ch * mult
253 | for _ in range(num_res_blocks):
254 | self.downblocks.append(AuxResBlock(
255 | in_ch=now_ch, out_ch=out_ch, tdim=tdim,
256 | dropout=dropout, attn=(i in attn)))
257 | now_ch = out_ch
258 | chs.append(now_ch)
259 | if i != len(ch_mult) - 1:
260 | self.downblocks.append(DownSample(now_ch))
261 | chs.append(now_ch)
262 |
263 | self.middleblocks = nn.ModuleList([
264 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=True, crossattn=False),
265 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=False, crossattn=False),
266 | ])
267 |
268 | self.upblocks = nn.ModuleList()
269 | for i, mult in reversed(list(enumerate(ch_mult))):
270 | out_ch = ch * mult
271 | for _ in range(num_res_blocks + 1):
272 | self.upblocks.append(AuxResBlock(
273 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
274 | dropout=dropout, attn=(i in attn)))
275 | now_ch = out_ch
276 | if i != 0:
277 | self.upblocks.append(UpSample(now_ch))
278 | assert len(chs) == 0
279 |
280 | self.tail = nn.Sequential(
281 | nn.GroupNorm(32, now_ch),
282 | nn.SiLU(),
283 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1)
284 | )
285 |
286 | self.initialize()
287 |
288 | def initialize(self):
289 | init.xavier_uniform_(self.head.weight)
290 | init.zeros_(self.head.bias)
291 | init.xavier_uniform_(self.fc_a.weight)
292 | init.zeros_(self.fc_a.bias)
293 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
294 | init.zeros_(self.tail[-1].bias)
295 |
296 | def forward(self, x, t, a):
297 | # Latent embedding
298 | aemb = self.fc_a(a)
299 |
300 | # Timestep embedding
301 | temb = self.time_embedding(t)
302 |
303 | # Downsampling
304 | h = self.head(x)
305 | hs = [h]
306 |
307 | for layer in self.downblocks:
308 | h = layer(h, temb, aemb)
309 | hs.append(h)
310 |
311 | # Middle
312 | for layer in self.middleblocks:
313 | if isinstance(layer, AuxResBlock):
314 | h = layer(h, temb, aemb)
315 | else:
316 | h = layer(h)
317 |
318 | # Upsampling
319 | for layer in self.upblocks:
320 | if isinstance(layer, AuxResBlock):
321 | h = torch.cat([h, hs.pop()], dim=1)
322 | h = layer(h, temb, aemb)
323 | h = self.tail(h)
324 |
325 | assert len(hs) == 0
326 | return h
327 |
328 |
329 | class BottleneckAuxUNet(nn.Module):
330 | def __init__(self, T, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=32, shape=None):
331 | super().__init__()
332 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
333 | tdim = ch * 4
334 | self.a_dim = a_dim
335 | self.time_embedding = TimeEmbedding(T, ch, tdim)
336 | self.fc_a = nn.Sequential(
337 | nn.SiLU(),
338 | nn.Linear(self.a_dim, tdim)
339 | )
340 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1)
341 |
342 | self.downblocks = nn.ModuleList()
343 |         chs = [ch]  # record output channels during downsampling for the upsampling skips
344 | now_ch = ch
345 | for i, mult in enumerate(ch_mult):
346 | out_ch = ch * mult
347 | for _ in range(num_res_blocks):
348 | self.downblocks.append(ResBlock(
349 | in_ch=now_ch, out_ch=out_ch, tdim=tdim,
350 | dropout=dropout, attn=(i in attn)))
351 | now_ch = out_ch
352 | chs.append(now_ch)
353 | if i != len(ch_mult) - 1:
354 | self.downblocks.append(DownSample(now_ch))
355 | chs.append(now_ch)
356 |
357 | self.middleblocks = nn.ModuleList([
358 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=True, crossattn=False),
359 | AuxResBlock(now_ch, now_ch, tdim, dropout, attn=False, crossattn=False),
360 | ])
361 |
362 | self.upblocks = nn.ModuleList()
363 | for i, mult in reversed(list(enumerate(ch_mult))):
364 | out_ch = ch * mult
365 | for _ in range(num_res_blocks + 1):
366 | self.upblocks.append(ResBlock(
367 | in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
368 | dropout=dropout, attn=(i in attn)))
369 | now_ch = out_ch
370 | if i != 0:
371 | self.upblocks.append(UpSample(now_ch))
372 | assert len(chs) == 0
373 |
374 | self.tail = nn.Sequential(
375 | nn.GroupNorm(32, now_ch),
376 | nn.SiLU(),
377 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1)
378 | )
379 |
380 | self.initialize()
381 |
382 | def initialize(self):
383 | init.xavier_uniform_(self.head.weight)
384 | init.zeros_(self.head.bias)
385 | init.kaiming_normal_(self.fc_a[1].weight,
386 | a=0,
387 | nonlinearity='relu')
388 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
389 | init.zeros_(self.tail[-1].bias)
390 |
391 | def forward(self, x, t, a):
392 | # Latent embedding
393 | aemb = self.fc_a(a)
394 |
395 | # Timestep embedding
396 | temb = self.time_embedding(t)
397 |
398 | # Downsampling
399 | h = self.head(x)
400 | hs = [h]
401 |
402 | for layer in self.downblocks:
403 | h = layer(h, temb)
404 | hs.append(h)
405 |
406 | # Middle
407 | for layer in self.middleblocks:
408 | if isinstance(layer, AuxResBlock):
409 | h = layer(h, temb, aemb)
410 | else:
411 | h = layer(h)
412 |
413 | # Upsampling
414 | for layer in self.upblocks:
415 | if isinstance(layer, ResBlock):
416 | h = torch.cat([h, hs.pop()], dim=1)
417 | h = layer(h, temb)
418 | h = self.tail(h)
419 |
420 | assert len(hs) == 0
421 | return h
422 |
423 |
424 | class Encoder(nn.Module):
425 | def __init__(self, ch=64, ch_mult=[1,2,4,8,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=32, shape=None):
426 | super().__init__()
427 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
428 |
429 | self.shape = shape
430 | self.a_dim = a_dim
431 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1)
432 | self.downblocks = nn.ModuleList()
433 | chs = [ch]  # record output channels at each downsampling stage for the upsampling skip connections
434 | now_ch = ch
435 | for i, mult in enumerate(ch_mult):
436 | out_ch = ch * mult
437 | for _ in range(num_res_blocks):
438 | self.downblocks.append(ResBlock_encoder(
439 | in_ch=now_ch, out_ch=out_ch,
440 | dropout=dropout, attn=(i in attn)))
441 | now_ch = out_ch
442 | chs.append(now_ch)
443 | if i != len(ch_mult) - 1:
444 | self.downblocks.append(DownSample(now_ch))
445 | chs.append(now_ch)
446 |
447 | self.middleblocks = nn.ModuleList([
448 | ResBlock_encoder(now_ch, now_ch, dropout, attn=True),
449 | ResBlock_encoder(now_ch, now_ch, dropout, attn=False),
450 | ])
451 |
452 | self.upblocks = nn.ModuleList()
453 | for i, mult in reversed(list(enumerate(ch_mult))):
454 | out_ch = ch * mult
455 | for _ in range(num_res_blocks + 1):
456 | self.upblocks.append(ResBlock_encoder(
457 | in_ch=chs.pop() + now_ch, out_ch=out_ch,
458 | dropout=dropout, attn=(i in attn)))
459 | now_ch = out_ch
460 | if i != 0:
461 | self.upblocks.append(UpSample(now_ch))
462 | assert len(chs) == 0
463 |
464 | self.tail = nn.Sequential(
465 | nn.GroupNorm(32, now_ch),
466 | nn.SiLU(),
467 | nn.Conv2d(now_ch, 1, 3, stride=1, padding=1)
468 | )
469 |
470 | self.fc_a = nn.Linear(self.shape[1]*self.shape[2], self.a_dim)
471 | self.fc_mu = nn.Linear(self.a_dim, self.a_dim)
472 | self.fc_var = nn.Linear(self.a_dim, self.a_dim)
473 |
474 | self.initialize()
475 |
476 | def initialize(self):
477 | init.xavier_uniform_(self.head.weight)
478 | init.zeros_(self.head.bias)
479 | init.xavier_uniform_(self.fc_a.weight)
480 | init.zeros_(self.fc_a.bias)
481 | init.xavier_uniform_(self.fc_mu.weight)
482 | init.zeros_(self.fc_mu.bias)
483 | init.xavier_uniform_(self.fc_var.weight)
484 | init.zeros_(self.fc_var.bias)
485 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
486 | init.zeros_(self.tail[-1].bias)
487 |
488 | def forward(self, x):
489 | # Downsampling
490 | h = self.head(x)
491 |
492 | hs = [h]
493 | for layer in self.downblocks:
494 | if isinstance(layer, ResBlock_encoder):
495 | h = layer(h)
496 | else:
497 | h = layer(h, None, None) # DownSample ignores conditioning; None placeholders keep the call signature uniform
498 | hs.append(h)
499 | # Middle
500 | for layer in self.middleblocks:
501 | h = layer(h)
502 | # Upsampling
503 | for layer in self.upblocks:
504 | if isinstance(layer, ResBlock_encoder):
505 | h = torch.cat([h, hs.pop()], dim=1)
506 | h = layer(h)
507 | else:
508 | h = layer(h, None, None) # UpSample likewise ignores the temb/aemb placeholders
509 |
510 | h = torch.flatten(self.tail(h), start_dim=1)
511 |
512 | a = self.fc_a(h)
513 | mu = self.fc_mu(a)
514 | log_var = self.fc_var(a)
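    | # Reparameterization trick: a_q = mu + sigma * eps with sigma = exp(0.5 * log_var),
    | # keeping the sample differentiable w.r.t. mu and log_var.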
515 | a_q = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)
516 |
517 | assert len(hs) == 0
518 | return a, a_q, mu, log_var
519 |
520 |
521 | class Decoder(nn.Module):
522 | def __init__(self, ch=64, ch_mult=[1,2,4,8], attn=[2], num_res_blocks=2, dropout=0.1, a_dim=10, shape=None):
523 | super().__init__()
524 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
525 | self.a_dim = a_dim
526 | self.shape = shape
527 | self.head = nn.Conv2d(shape[0], ch, kernel_size=3, stride=1, padding=1)
528 | self.downblocks = nn.ModuleList()
529 | chs = [ch]  # record output channels at each downsampling stage for the upsampling skip connections
530 | now_ch = ch
531 | for i, mult in enumerate(ch_mult):
532 | out_ch = ch * mult
533 | for _ in range(num_res_blocks):
534 | self.downblocks.append(ResBlock_encoder(
535 | in_ch=now_ch, out_ch=out_ch,
536 | dropout=dropout, attn=(i in attn)))
537 | now_ch = out_ch
538 | chs.append(now_ch)
539 | if i != len(ch_mult) - 1:
540 | self.downblocks.append(DownSample(now_ch))
541 | chs.append(now_ch)
542 |
543 | self.middleblocks = nn.ModuleList([
544 | ResBlock_encoder(now_ch, now_ch, dropout, attn=True),
545 | ResBlock_encoder(now_ch, now_ch, dropout, attn=False),
546 | ])
547 |
548 | self.upblocks = nn.ModuleList()
549 | for i, mult in reversed(list(enumerate(ch_mult))):
550 | out_ch = ch * mult
551 | for _ in range(num_res_blocks + 1):
552 | self.upblocks.append(ResBlock_encoder(
553 | in_ch=chs.pop() + now_ch, out_ch=out_ch,
554 | dropout=dropout, attn=(i in attn)))
555 | now_ch = out_ch
556 | if i != 0:
557 | self.upblocks.append(UpSample(now_ch))
558 | assert len(chs) == 0
559 |
560 | self.tail = nn.Sequential(
561 | nn.GroupNorm(32, now_ch),
562 | nn.SiLU(),
563 | nn.Conv2d(now_ch, shape[0], 3, stride=1, padding=1)
564 | )
565 |
566 | self.fc_a = nn.Linear(self.a_dim, self.shape[0]*self.shape[1]*self.shape[2])
567 |
568 | self.initialize()
569 |
570 | def initialize(self):
571 | init.xavier_uniform_(self.head.weight)
572 | init.zeros_(self.head.bias)
573 | init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
574 | init.zeros_(self.tail[-1].bias)
575 |
576 | def forward(self, a):
577 | # Latent embedding
578 | aemb = self.fc_a(a)
579 | h = aemb.reshape(a.shape[0], self.shape[0], self.shape[1], self.shape[2])
580 |
581 | # Downsampling
582 | h = self.head(h)
583 | hs = [h]
584 | for layer in self.downblocks:
585 | if isinstance(layer, ResBlock_encoder):
586 | h = layer(h)
587 | else:
588 | h = layer(h, None) # DownSample ignores conditioning; None is a placeholder
589 | hs.append(h)
590 | # Middle
591 | for layer in self.middleblocks:
592 | h = layer(h)
593 | # Upsampling
594 | for layer in self.upblocks:
595 | if isinstance(layer, ResBlock_encoder):
596 | h = torch.cat([h, hs.pop()], dim=1)
597 | h = layer(h)
598 | else:
599 | h = layer(h, None) # UpSample likewise ignores the placeholder
600 | rec = self.tail(h)
601 |
602 | assert len(hs) == 0
603 | return rec
604 |
605 | class InfoDiff(nn.Module):
606 | def __init__(self, args, device, shape):
607 | '''
608 | beta_1 : initial value of the noise schedule
609 | beta_T : final value of the noise schedule
610 | T : number of diffusion steps
611 | '''
612 |
613 | super().__init__()
614 | self.device = device
615 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps).to(device=device)
616 | self.alpha_bars = torch.cumprod(1 - self.betas, dim=0)
617 | self.alphas = 1 - self.betas
618 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
619 | if args.input_size == 28:
620 | ch_mult = [1,2,4,]
621 | else:
622 | ch_mult = [1,2,2,2]
623 | if args.is_bottleneck:
624 | self.backbone = BottleneckAuxUNet(ch_mult=ch_mult, T=args.diffusion_steps, ch=args.unets_channels, a_dim=args.a_dim, shape=shape)
625 | else:
626 | self.backbone = AuxiliaryUNet(ch_mult=ch_mult, T=args.diffusion_steps, ch=args.unets_channels, a_dim=args.a_dim, shape=shape)
627 | self.encoder = Encoder(ch_mult=ch_mult, ch=args.encoder_channels, a_dim=args.a_dim, shape=shape)
628 | self.mmd_weight: float = args.mmd_weight
629 | self.kld_weight: float = args.kld_weight
630 | self.to(device)
631 |
632 | def loss_fn(self, args, x, idx=None, curr_epoch=0):
633 | '''
634 | x : real data if idx is None, else already-perturbed data
635 | idx : if None (training phase), a random timestep index is sampled and x is perturbed
636 | '''
637 | output, epsilon, a, mu, log_var = self.forward(x, idx=idx, get_target=True)
638 |
639 | # denoising matching term
640 | loss = (output - epsilon).square().mean()
641 | print('denoising loss:', loss)
642 |
643 | # reconstruction term
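    | # One-step x_0 estimate from the predicted noise, using the first-step schedule values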
644 | x_0 = torch.sqrt(1 / self.alphas[0]) * (x - self.betas[0] / torch.sqrt(1 - self.alpha_bars[0]) * output)
645 | loss_rec = (x_0 - x).square().mean()
646 | loss += loss_rec / args.diffusion_steps
647 | print('recon loss:', loss_rec / args.diffusion_steps)
648 |
649 | if self.mmd_weight != 0 and self.kld_weight != 0:
650 | # MMD term
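    | # MMD between encoder outputs and prior samples (InfoVAE-style regularizer);
    | # supports Gaussian ('regular'), 10-component mixture, and swiss-roll priors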
651 | if args.prior == 'regular':
652 | true_samples = torch.randn_like(a, device=self.device)
653 | elif args.prior == '10mix':
654 | prior = gaussian_mixture(args.batch_size, args.a_dim)
655 | true_samples = torch.FloatTensor(prior).to(device=self.device)
656 | elif args.prior == 'roll':
657 | prior = swiss_roll(args.batch_size)
658 | true_samples = torch.FloatTensor(prior).to(device=self.device)
659 | loss_mmd = compute_mmd(true_samples, mu)
660 | print('mmd loss:', args.mmd_weight * loss_mmd)
661 | loss += args.mmd_weight * loss_mmd
662 | # KLD term
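    | # Closed-form KL(q(a|x) || N(0, I)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), summed over the batch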
663 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
664 | if args.use_C:
665 | # KLD term w/ control constant
666 | self.C_max = torch.FloatTensor([args.C_max]).to(device=self.device)
667 | C = torch.clamp(self.C_max/args.epochs*curr_epoch, torch.FloatTensor([0]).to(device=self.device), self.C_max)
668 | loss += args.kld_weight * (kld_loss - C.squeeze(dim=0)).abs()
669 | else:
670 | print('kld loss:', args.kld_weight * kld_loss)
671 | loss += args.kld_weight * kld_loss
672 | elif args.mmd_weight != 0:
673 | # MMD term
674 | if args.prior == 'regular':
675 | true_samples = torch.randn_like(a, device=self.device)
676 | elif args.prior == '10mix':
677 | prior = gaussian_mixture(args.batch_size, args.a_dim)
678 | true_samples = torch.FloatTensor(prior).to(device=self.device)
679 | elif args.prior == 'roll':
680 | prior = swiss_roll(args.batch_size)
681 | true_samples = torch.FloatTensor(prior).to(device=self.device)
682 | loss_mmd = compute_mmd(true_samples, a)
683 | print('mmd loss:', args.mmd_weight * loss_mmd)
684 | loss += args.mmd_weight * loss_mmd
685 | elif args.kld_weight != 0:
686 | # KLD term
687 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
688 | if args.use_C:
689 | # KLD term w/ control constant
690 | self.C_max = torch.FloatTensor([args.C_max]).to(device=self.device)
691 | C = torch.clamp(self.C_max/args.epochs*curr_epoch, torch.FloatTensor([0]).to(device=self.device), self.C_max)
692 | loss += args.kld_weight * (kld_loss - C.squeeze(dim=0)).abs()
693 | else:
694 | print('kld loss:', args.kld_weight * kld_loss)
695 | loss += args.kld_weight * kld_loss
696 | return loss
697 |
698 | def forward(self, x, idx=None, a=None, get_target=False):
699 |
700 | if idx is None:
701 | idx = torch.randint(0, len(self.alpha_bars), (x.size(0), )).to(device = self.device)
702 | used_alpha_bars = self.alpha_bars[idx][:, None, None, None]
703 | epsilon = torch.randn_like(x)
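    | # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps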
704 | x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon
705 | else:
706 | idx = torch.Tensor([idx for _ in range(x.size(0))]).to(device = self.device).long()
707 | x_tilde = x
708 |
709 | if a is None:
710 | a, a_q, mu, log_var = self.encoder(x)
711 | else:
712 | a_q = a
713 |
714 | if self.mmd_weight != 0 and self.kld_weight != 0:
715 | output = self.backbone(x_tilde, idx, a_q)
716 | elif self.mmd_weight == 0 and self.kld_weight == 0:
717 | output = self.backbone(x_tilde, idx, a)
718 | elif self.mmd_weight != 0:
719 | output = self.backbone(x_tilde, idx, a)
720 | elif self.kld_weight != 0:
721 | output = self.backbone(x_tilde, idx, a_q)
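    | # Note: get_target assumes the training path (idx=None, a=None), where epsilon,
    | # mu, and log_var are all defined.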
722 |
723 | return (output, epsilon, a, mu, log_var) if get_target else output
724 |
725 |
726 | class Diff(nn.Module):
727 | def __init__(self, args, device, shape):
728 | '''
729 | beta_1 : initial value of the noise schedule
730 | beta_T : final value of the noise schedule
731 | T : number of diffusion steps
732 | '''
733 |
734 | super().__init__()
735 | self.device = device
736 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps).to(device=device)
737 | self.alpha_bars = torch.cumprod(1 - self.betas, dim=0)
738 | self.alphas = 1 - self.betas
739 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
740 | self.is_latent = args.is_latent
741 | if args.mode == "train_latent_ddim":
742 | self.is_latent = True
743 | if args.input_size == 28:
744 | ch_mult = [1,2,4,]
745 | else:
746 | ch_mult = [1,2,4,8]
747 | if self.is_latent:
748 | self.backbone = LatentUNet(T=args.diffusion_steps, num_layers=10, dropout=0.1, shape=shape, activation='silu')
749 | else:
750 | self.backbone = UNet(ch_mult=ch_mult, T=args.diffusion_steps, ch=args.unets_channels, shape=shape)
751 | self.to(device)
752 |
753 | def loss_fn(self, args, x, idx=None, curr_epoch=0):
754 | '''
755 | x : real data if idx is None, else already-perturbed data
756 | idx : if None (training phase), a random timestep index is sampled and x is perturbed
757 | '''
758 | output, epsilon = self.forward(x, idx=idx, get_target=True)
759 | # denoising matching term
760 | loss = (output - epsilon).square().mean()
761 | # print('denoising loss:', loss)
762 | return loss
763 |
764 | def forward(self, x, idx=None, get_target=False):
765 |
766 | if idx is None:
767 | idx = torch.randint(0, len(self.alpha_bars), (x.size(0), )).to(device = self.device)
768 | if self.is_latent:
769 | used_alpha_bars = self.alpha_bars[idx][:, None]
770 | else:
771 | used_alpha_bars = self.alpha_bars[idx][:, None, None, None]
772 | epsilon = torch.randn_like(x)
773 | x_tilde = torch.sqrt(used_alpha_bars) * x + torch.sqrt(1 - used_alpha_bars) * epsilon
774 | else:
775 | idx = torch.Tensor([idx for _ in range(x.size(0))]).to(device = self.device).long()
776 | x_tilde = x
777 | output = self.backbone(x_tilde, idx)
778 |
779 | return (output, epsilon) if get_target else output
780 |
781 | class VAE(nn.Module):
782 | def __init__(self, args, device, shape):
783 | super().__init__()
784 | self.device = device
785 | if args.input_size == 28:
786 | ch_mult = [1,2,4,]
787 | else:
788 | ch_mult = [1,2,4,8]
789 | self.encoder = Encoder(ch_mult=ch_mult, ch=args.encoder_channels, a_dim=args.a_dim, shape=shape)
790 | self.decoder = Decoder(ch_mult=ch_mult, ch=args.encoder_channels, a_dim=args.a_dim, shape=shape)
791 | self.mmd_weight: float = args.mmd_weight
792 | self.kld_weight: float = args.kld_weight
793 | self.to(device)
794 |
795 | def loss_fn(self, args, x, curr_epoch=0):
796 | reconstruction, a_q, mu, log_var = self.forward(x, get_target=True)
797 | # denoising matching term
798 | loss = (reconstruction - x).square().mean()
799 | print('reconstruction loss:', loss)
800 | # prior matching term
801 | if args.mmd_weight != 0:
802 | # MMD term
803 | true_samples = torch.randn_like(a_q, device=self.device)
804 | loss_mmd = compute_mmd(true_samples, a_q)
805 | print('mmd loss:', args.mmd_weight * loss_mmd)
806 | loss += args.mmd_weight * loss_mmd
807 | elif args.kld_weight != 0:
808 | # KLD term
809 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
810 | if args.use_C:
811 | # KLD term w/ control constant
812 | self.C_max = torch.FloatTensor([args.C_max]).to(device=self.device)
813 | C = torch.clamp(self.C_max/args.epochs*curr_epoch, torch.FloatTensor([0]).to(device=self.device), self.C_max)
814 | loss += args.kld_weight * (kld_loss - C.squeeze(dim=0)).abs()
815 | print('kld-c loss:', (kld_loss - C.squeeze(dim=0)).abs())
816 | else:
817 | loss += args.kld_weight * kld_loss
818 | print('kld loss:', args.kld_weight * kld_loss)
819 | return loss
820 |
821 | def forward(self, x, get_target=False):
822 | a, a_q, mu, log_var = self.encoder(x)
823 |
824 | if self.mmd_weight != 0 and self.kld_weight != 0:
825 | reconstruction = self.decoder(a_q)
826 | elif self.mmd_weight == 0 and self.kld_weight == 0:
827 | reconstruction = self.decoder(a)
828 | elif self.mmd_weight != 0:
829 | reconstruction = self.decoder(a_q)
830 | elif self.kld_weight != 0:
831 | reconstruction = self.decoder(a_q)
832 |
833 | return (reconstruction, a_q, mu, log_var) if get_target else reconstruction
834 |
835 |
836 | class FeatureClassfier(nn.Module):
837 | def __init__(self, args, output_dim = 40):
838 | super(FeatureClassfier, self).__init__()
839 |
840 | """build full connect layers for every attribute"""
841 | self.fc_set = {}
842 |
843 | self.fc = nn.Sequential(
844 | nn.Linear(args.a_dim, 512),
845 | nn.ReLU(True),
846 | nn.Dropout(p=0.5),
847 | nn.Linear(512, 128),
848 | nn.ReLU(True),
849 | nn.Dropout(p=0.5),
850 | nn.Linear(128, output_dim),
851 | )
852 |
853 | self.sigmoid = nn.Sigmoid()
854 |
855 | def forward(self, x):
856 | x = x.view(x.size(0), -1) # flatten
857 | res = self.fc(x)
858 | res = self.sigmoid(res)
859 | return res
--------------------------------------------------------------------------------
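For orientation, here is a minimal training-step sketch against the classes above. This is a sketch under assumptions, not part of the repo: it assumes `models.py` and its utilities are importable, and the `SimpleNamespace` merely mimics the fields of `run.py`'s argparse args with illustrative values.

```python
import torch
from types import SimpleNamespace

from models import InfoDiff

# Stand-in for run.py's parsed args (field names mirror the parser; values are illustrative).
args = SimpleNamespace(
    beta1=1e-5, betaT=1e-2, diffusion_steps=1000, input_size=32,
    is_bottleneck=False, unets_channels=64, encoder_channels=64,
    a_dim=32, mmd_weight=0.1, kld_weight=0.0, prior='regular',
    batch_size=4, epochs=1, use_C=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = InfoDiff(args, device, shape=(3, 32, 32))

x = torch.randn(4, 3, 32, 32, device=device)  # stand-in image batch
loss = model.loss_fn(args=args, x=x)          # denoising + reconstruction + MMD terms
loss.backward()
```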
/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import torch.nn.functional as F
6 |
7 | from typing import Union
8 |
9 | class TimeEmbedding(nn.Module):
10 | def __init__(self, T, d_model, dim):
11 | assert d_model % 2 == 0
12 | super().__init__()
13 | emb = torch.arange(0, d_model, step=2) / torch.Tensor([d_model]) * math.log(10000)
14 | emb = torch.exp(-emb)
15 | pos = torch.arange(T).float()
16 | emb = pos[:, None] * emb[None, :]
17 | assert list(emb.shape) == [T, d_model // 2]
18 | emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
19 | assert list(emb.shape) == [T, d_model // 2, 2]
20 | emb = emb.view(T, d_model)
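    | # Rows are standard sin/cos positional encodings; they are frozen as an embedding
    | # table below and refined by a small MLP.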
21 |
22 | self.timembedding = nn.Sequential(
23 | nn.Embedding.from_pretrained(emb),
24 | nn.Linear(d_model, dim),
25 | nn.SiLU(),
26 | nn.Linear(dim, dim),
27 | )
28 | self.initialize()
29 |
30 | def initialize(self):
31 | for module in self.modules():
32 | if isinstance(module, nn.Linear):
33 | init.xavier_uniform_(module.weight)
34 | init.zeros_(module.bias)
35 |
36 | def forward(self, t):
37 | emb = self.timembedding(t)
38 | return emb
39 |
40 |
41 | def timestep_embedding(timesteps, dim, max_period=10000):
42 | """
43 | Create sinusoidal timestep embeddings.
44 |
45 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
46 | These may be fractional.
47 | :param dim: the dimension of the output.
48 | :param max_period: controls the minimum frequency of the embeddings.
49 | :return: an [N x dim] Tensor of positional embeddings.
50 | """
51 | half = dim // 2
52 | freqs = torch.exp(-math.log(max_period) *
53 | torch.arange(start=0, end=half, dtype=torch.float32) /
54 | half).to(device=timesteps.device)
55 | args = timesteps[:, None].float() * freqs[None]
56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57 | if dim % 2:
58 | embedding = torch.cat(
59 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
60 | return embedding
61 |
62 |
63 | class DownSample(nn.Module):
64 | def __init__(self, in_ch):
65 | super().__init__()
66 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
67 | self.initialize()
68 |
69 | def initialize(self):
70 | init.xavier_uniform_(self.main.weight)
71 | init.zeros_(self.main.bias)
72 |
73 | def forward(self, x, temb=None, aemb=None):
74 | x = self.main(x)
75 | return x
76 |
77 |
78 | class UpSample(nn.Module):
79 | def __init__(self, in_ch):
80 | super().__init__()
81 | self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
82 | self.initialize()
83 |
84 | def initialize(self):
85 | init.xavier_uniform_(self.main.weight)
86 | init.zeros_(self.main.bias)
87 |
88 | def forward(self, x, temb=None, aemb=None):
89 | _, _, H, W = x.shape
90 | x = F.interpolate(
91 | x, scale_factor=float(2.0), mode='nearest')
92 | x = self.main(x)
93 | return x
94 |
95 |
96 | class LatentDownSample(nn.Module):
97 | def __init__(self, in_ch):
98 | super().__init__()
99 | self.main = nn.Conv1d(in_ch, in_ch, 3, stride=2, padding=1)
100 | self.initialize()
101 |
102 | def initialize(self):
103 | init.xavier_uniform_(self.main.weight)
104 | init.zeros_(self.main.bias)
105 |
106 | def forward(self, x, temb=None, aemb=None):
107 | x = self.main(x)
108 | return x
109 |
110 |
111 | class LatentUpSample(nn.Module):
112 | def __init__(self, in_ch):
113 | super().__init__()
114 | self.main = nn.Conv1d(in_ch, in_ch, 3, stride=1, padding=1)
115 | self.initialize()
116 |
117 | def initialize(self):
118 | init.xavier_uniform_(self.main.weight)
119 | init.zeros_(self.main.bias)
120 |
121 | def forward(self, x, temb=None, aemb=None):
122 | _, _, L = x.shape
123 | x = F.interpolate(
124 | x, scale_factor=float(2.0), mode='nearest')
125 | x = self.main(x)
126 | return x
127 |
128 |
129 | class AttnBlock(nn.Module):
130 | def __init__(self, in_ch):
131 | super().__init__()
132 | self.group_norm = nn.GroupNorm(32, in_ch)
133 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
134 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
135 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
136 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
137 | self.initialize()
138 |
139 | def initialize(self):
140 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
141 | init.xavier_uniform_(module.weight)
142 | init.zeros_(module.bias)
143 | init.xavier_uniform_(self.proj.weight, gain=1e-5)
144 |
145 | def forward(self, x):
146 | B, C, H, W = x.shape
147 | h = self.group_norm(x)
148 | q = self.proj_q(h)
149 | k = self.proj_k(h)
150 | v = self.proj_v(h)
151 |
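    | # Scaled dot-product self-attention over all H*W spatial positions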
152 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
153 | k = k.view(B, C, H * W)
154 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
155 | assert list(w.shape) == [B, H * W, H * W]
156 | w = F.softmax(w, dim=-1)
157 |
158 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
159 | h = torch.bmm(w, v)
160 | assert list(h.shape) == [B, H * W, C]
161 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
162 | h = self.proj(h)
163 |
164 | return x + h
165 |
166 |
167 | class CrossAttnBlock(nn.Module):
168 | def __init__(self, in_ch):
169 | super().__init__()
170 | self.group_norm = nn.GroupNorm(32, in_ch)
171 | self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
172 | self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
173 | self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
174 | self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
175 | self.initialize()
176 |
177 | def initialize(self):
178 | for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
179 | init.xavier_uniform_(module.weight)
180 | init.zeros_(module.bias)
181 | init.xavier_uniform_(self.proj.weight, gain=1e-5)
182 |
183 | def forward(self, x, a):
184 | B, C, H, W = x.shape
185 | h = self.group_norm(x)
186 | h_a = self.group_norm(a)
187 | q = self.proj_q(h_a)
188 | k = self.proj_k(h)
189 | v = self.proj_v(h)
190 |
191 | q = q.permute(0, 2, 3, 1).view(B, H * W, C)
192 | k = k.view(B, C, H * W)
193 | w = torch.bmm(q, k) * (int(C) ** (-0.5))
194 | assert list(w.shape) == [B, H * W, H * W]
195 | w = F.softmax(w, dim=-1)
196 |
197 | v = v.permute(0, 2, 3, 1).view(B, H * W, C)
198 | h = torch.bmm(w, v)
199 | assert list(h.shape) == [B, H * W, C]
200 | h = h.view(B, H, W, C).permute(0, 3, 1, 2)
201 | h = self.proj(h)
202 |
203 | return x + h
204 |
205 |
206 | class ResBlock(nn.Module):
207 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
208 | super().__init__()
209 | self.temb_proj = nn.Sequential(
210 | nn.SiLU(),
211 | nn.Linear(tdim, 2*out_ch),
212 | )
213 | self.block1 = nn.Sequential(
214 | nn.GroupNorm(32, in_ch),
215 | nn.SiLU(),
216 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
217 | )
218 | self.block2 = nn.Sequential(
219 | nn.GroupNorm(32, out_ch),
220 | nn.SiLU(),
221 | nn.Dropout(dropout),
222 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
223 | )
224 | self.block3 = nn.Sequential(
225 | nn.GroupNorm(32, out_ch),
226 | nn.SiLU(),
227 | nn.Dropout(dropout),
228 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
229 | )
230 | if in_ch != out_ch:
231 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
232 | else:
233 | self.shortcut = nn.Identity()
234 | self.use_attn = attn
235 | if self.use_attn:
236 | self.attn = AttnBlock(out_ch)
237 | else:
238 | self.attn = nn.Identity()
239 | self.initialize()
240 |
241 | def initialize(self):
242 | for module in self.modules():
243 | if isinstance(module, (nn.Conv2d, nn.Linear)):
244 | init.xavier_uniform_(module.weight)
245 | init.zeros_(module.bias)
246 |
247 | def forward(self, x, temb):
248 | h = self.block1(x)
249 | temb_out = self.temb_proj(temb)[:, :, None, None]
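    | # Split block2 into norm + rest so the timestep embedding can scale/shift the
    | # normalized activations (AdaGN/FiLM-style conditioning)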
250 | out_norm, out_rest = self.block2[0], self.block2[1:]
251 | scale, shift = torch.chunk(temb_out, 2, dim=1)
252 | h = out_norm(h) * (1 + scale) + shift
253 | h = out_rest(h)
254 | h = self.block3(h)
255 | h = h + self.shortcut(x)
256 | h = self.attn(h)
257 |
258 | return h
259 |
260 |
261 | class AuxResBlock(nn.Module):
262 | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False, crossattn: Union[bool, nn.Module] = False):
263 | super().__init__()
264 | self.block1 = nn.Sequential(
265 | nn.GroupNorm(32, in_ch),
266 | nn.SiLU(),
267 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
268 | )
269 | self.temb_proj = nn.Sequential(
270 | nn.SiLU(),
271 | nn.Linear(tdim, 2*out_ch),
272 | )
273 | self.aemb_proj = nn.Sequential(
274 | nn.SiLU(),
275 | nn.Linear(tdim, 2*out_ch),
276 | )
277 | self.block2 = nn.Sequential(
278 | nn.GroupNorm(32, out_ch),
279 | nn.SiLU(),
280 | nn.Dropout(dropout),
281 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
282 | )
283 | self.block3 = nn.Sequential(
284 | nn.GroupNorm(32, out_ch),
285 | nn.SiLU(),
286 | nn.Dropout(dropout),
287 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
288 | )
289 |
290 | if in_ch != out_ch:
291 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
292 | else:
293 | self.shortcut = nn.Identity()
294 | self.use_attn = attn
295 | if self.use_attn:
296 | self.attn = AttnBlock(out_ch)
297 | else:
298 | self.attn = nn.Identity()
299 | self.use_crossattn = bool(crossattn)
300 | self.crossattn = CrossAttnBlock(out_ch)
301 | self.initialize()
302 |
303 | def initialize(self):
304 | for module in self.modules():
305 | if isinstance(module, (nn.Conv2d, nn.Linear)):
306 | init.xavier_uniform_(module.weight)
307 | init.zeros_(module.bias)
308 |
309 | def forward(self, x, temb, aemb=None):
310 | h = self.block1(x)
311 |
312 | temb_out = self.temb_proj(temb)[:, :, None, None]
313 | out_norm, out_rest = self.block2[0], self.block2[1:]
314 | scale, shift = torch.chunk(temb_out, 2, dim=1)
315 | h = out_norm(h) * (1 + scale) + shift
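    | # Second conditioning pass: scale/shift by the auxiliary latent embedding
    | # (the AdaNorm fusion used to inject the auxiliary variable)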
316 | aemb_out = self.aemb_proj(aemb)[:, :, None, None]
317 | scale, shift = torch.chunk(aemb_out, 2, dim=1)
318 | h = h * (1 + scale) + shift
319 | h = out_rest(h)
320 | h = self.block3(h)
321 |
322 | h += self.shortcut(x)
323 | h = self.attn(h)
324 |
325 | if self.use_crossattn:
326 | h = self.crossattn(h, aemb)
327 |
328 | return h
329 |
330 |
331 | class ResBlock_encoder(nn.Module):
332 | def __init__(self, in_ch, out_ch, dropout, attn=False):
333 | super().__init__()
334 | self.block1 = nn.Sequential(
335 | nn.GroupNorm(32, in_ch),
336 | nn.SiLU(),
337 | nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
338 | )
339 | self.block2 = nn.Sequential(
340 | nn.GroupNorm(32, out_ch),
341 | nn.SiLU(),
342 | nn.Dropout(dropout),
343 | nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
344 | )
345 | if in_ch != out_ch:
346 | self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
347 | else:
348 | self.shortcut = nn.Identity()
349 | if attn:
350 | self.attn = AttnBlock(out_ch)
351 | else:
352 | self.attn = nn.Identity()
353 | self.initialize()
354 |
355 | def initialize(self):
356 | for module in self.modules():
357 | if isinstance(module, (nn.Conv2d, nn.Linear)):
358 | init.xavier_uniform_(module.weight)
359 | init.zeros_(module.bias)
360 |
361 | def forward(self, x):
362 | h = self.block1(x)
363 | h = self.block2(h)
364 | h += self.shortcut(x)
365 | h = self.attn(h)
366 | return h
367 |
--------------------------------------------------------------------------------
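As a quick sanity check on the sinusoidal embedding above (a minimal sketch, assuming `modules.py` is importable):

```python
import torch
from modules import TimeEmbedding

temb = TimeEmbedding(T=1000, d_model=64, dim=256)  # table of 1000 precomputed sin/cos rows
t = torch.randint(0, 1000, (8,))                   # one integer timestep per batch element
print(temb(t).shape)                               # torch.Size([8, 256])
```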
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from torchvision.utils import save_image
6 | from tqdm.auto import tqdm, trange
7 | from torch.utils.tensorboard import SummaryWriter
8 | from torch.utils.data import DataLoader
9 | from data import get_dataset, get_dataset_config
10 | from models import InfoDiff, Diff, VAE
11 | from sampling import DiffusionProcess, TwoPhaseDiffusionProcess, LatentDiffusionProcess
12 | from utils import (
13 | AverageMeter, ProgressMeter, GradualWarmupScheduler, \
14 | generate_exp_string, seed_everything, cos, LatentDataset
15 | )
16 | import matplotlib.pyplot as plt
17 |
18 |
19 | torch.backends.cudnn.benchmark = True
20 | torch.backends.cudnn.allow_tf32 = True
21 |
22 |
23 | # ----------------------------------------------------------------------------
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 |
28 | parser.add_argument('--r_seed', type=int, default=0,
29 | help='value of the random seed')
30 | parser.add_argument('--img_id', type=int, default=0,
31 | help='id of the given image')
32 | parser.add_argument('--model', required=True,
33 | choices=['diff', 'vae', 'vanilla'], help='which type of model to run')
34 | parser.add_argument('--mode', required=True,
35 | choices=['train', 'eval', 'eval_fid', 'save_latent', 'disentangle',
36 | 'interpolate', 'save_original_img', 'latent_quality',
37 | 'train_latent_ddim', 'plot_latent'], help='which mode to run')
38 | parser.add_argument('--prior', required=True,
39 | choices=['regular', '10mix', 'roll'], help='which type of prior to run')
40 | parser.add_argument('--kld_weight', type=float, default=0,
41 | help='weight of kld loss')
42 | parser.add_argument('--mmd_weight', type=float, default=0.1,
43 | help='weight of mmd loss')
44 | parser.add_argument('--use_C', action='store_true',
45 | default=False, help='use control constant or not')
46 | parser.add_argument('--C_max', type=float, default=25,
47 | help='control constant of kld loss (orig default: 25 for simple, 50 for complex)')
48 | parser.add_argument('--dataset', required=True,
49 | choices=['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites', 'chairs', 'ffhq'], help='training dataset')
50 | parser.add_argument('--img_folder', default='./imgs',
51 | help='path to save sampled images')
52 | parser.add_argument('--log_folder', default='./logs',
53 | help='path to save logs')
54 | parser.add_argument('-e', '--epochs', type=int, default=20,
55 | help='number of epochs to train')
56 | parser.add_argument('--save_epochs', type=int, default=5,
57 | help='number of epochs to save model')
58 | parser.add_argument('--batch_size', type=int, default=64,
59 | help='training batch size')
60 | parser.add_argument('--learning_rate', type=float, default=0.0001,
61 | help='learning rate')
62 | parser.add_argument('--optimizer', default='adam', choices=['adam'],
63 | help='optimization algorithm')
64 | parser.add_argument('--model_folder', default='./models',
65 | help='folder where logs will be stored')
66 | parser.add_argument('--deterministic', action='store_true',
67 | default=False, help='deterministic sampling')
68 | parser.add_argument('--input_channels', type=int, default=1,
69 | help='number of input channels')
70 | parser.add_argument('--unets_channels', type=int, default=64,
71 | help='number of base channels in the UNets')
72 | parser.add_argument('--encoder_channels', type=int, default=64,
73 | help='number of base channels in the encoder')
74 | parser.add_argument('--input_size', type=int, default=32,
75 | help='expected size of input')
76 | parser.add_argument('--a_dim', type=int, default=32, required=True,
77 | help='dimensionality of auxiliary variable')
78 | parser.add_argument('--beta1', type=float, default=1e-5,
79 | help='value of beta 1')
80 | parser.add_argument('--betaT', type=float, default=1e-2,
81 | help='value of beta T')
82 | parser.add_argument('--diffusion_steps', type=int, default=1000,
83 | help='number of diffusion steps')
84 | parser.add_argument('--split_step', type=int, default=500,
85 | help='timestep at which two-phase sampling switches between the two models')
86 | parser.add_argument('--sampling_number', type=int, default=16,
87 | help='number of sampled images')
88 | parser.add_argument('--data_dir', type=str, default='./data')
89 | parser.add_argument('--tb_logger', action='store_true',
90 | help='use tensorboard logger.')
91 | parser.add_argument('--is_latent', action='store_true',
92 | help='use latent diffusion for unconditional sampling.')
93 | parser.add_argument('--is_bottleneck', action='store_true',
94 | help='only fuse aux variable in bottleneck layers.')
95 | args = parser.parse_args()
96 |
97 | return args
98 |
99 |
100 | # ----------------------------------------------------------------------------
101 |
102 |
103 | def save_images(args, sample=None, epoch=0, sample_num=0):
104 | root = f'{args.img_folder}'
105 | if args.model == 'vae':
106 | root = os.path.join(root, 'vae')
107 | else:
108 | if args.model == 'vanilla':
109 | root = os.path.join(root, 'diff')
110 | root = os.path.join(root, generate_exp_string(args))
111 | if args.mode == 'eval':
112 | root = os.path.join(root, 'eval')
113 | elif args.mode == 'disentangle':
114 | root = os.path.join(root, f'disentangle-{args.img_id}')
115 | elif args.mode == 'interpolate':
116 | root = os.path.join(root, f'interpolate-{args.img_id}')
117 | elif args.mode == 'save_latent':
118 | root = os.path.join(root, 'save_latent')
119 | elif args.mode == 'attr_classification':
120 | root = os.path.join(root, 'attr_classification')
121 | elif args.mode == 'plot_latent':
122 | root = os.path.join(root, 'plot_latent')
123 | os.makedirs(root, exist_ok=True)
124 | path = os.path.join(root, f'sample-{epoch}.png')
125 |
126 | img_range = (-1, 1)
127 | if args.mode == 'train':
128 | save_image(sample, path, normalize=True, range=img_range, nrow=4)
129 | elif args.mode == 'eval':
130 | for i, img in enumerate(sample):
131 | path = os.path.join(root, f"sample{sample_num + i:05d}.png")
132 | save_image(img, path, normalize=True, range=img_range)
133 | elif args.mode == 'disentangle':
134 | path = os.path.join(root, f"sample{sample_num}.png")
135 | save_image(sample, path, normalize=True, range=img_range, nrow=sample.shape[0])
136 | elif args.mode == 'interpolate':
137 | path = os.path.join(root, f"sample{sample_num}.png")
138 | save_image(sample, path, normalize=True, range=img_range, nrow=sample.shape[0])
139 | elif args.mode == 'plot_latent':
140 | path = os.path.join(root, f"{args.mode}.png")
141 | return path
142 | elif args.mode == 'attr_classification':
143 | return root
144 |
145 | def save_model(args, epoch, model):
146 | root = f'{args.model_folder}'
147 | if args.model == 'vae':
148 | root = os.path.join(root, 'vae')
149 | else:
150 | if args.model == 'vanilla':
151 | root = os.path.join(root, 'diff')
152 | root = os.path.join(root, generate_exp_string(args))
153 | if args.mode == "train_latent_ddim":
154 | root += '_latent'
155 | os.makedirs(root, exist_ok=True)
156 | path = os.path.join(root, f'model-{epoch}.pth')
157 | torch.save(model.state_dict(), path)
158 | print(f"Saved PyTorch model state to {path}")
159 |
160 |
161 | def train(args):
162 | seed_everything(args.r_seed)
163 | log_dir = f'{args.log_folder}'
164 | log_dir = os.path.join(log_dir, generate_exp_string(args))
165 | tb_logger = SummaryWriter(log_dir=log_dir) if args.tb_logger else None
166 | device = "cuda" if torch.cuda.is_available() else "cpu"
167 | shape = get_dataset_config(args)
168 | print(dict(vars(args)))
169 | dataloader = get_dataset(args)
170 |
171 | if args.model == 'diff':
172 | model = InfoDiff(args, device, shape)
173 | elif args.model == 'vanilla':
174 | model = Diff(args, device, shape)
175 | elif args.model == 'vae':
176 | model = VAE(args, device, shape)
177 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
178 |
179 | losses = AverageMeter('Loss', ':.4f')
180 | progress = ProgressMeter(args.epochs, [losses], prefix='Epoch ')
181 |
182 | cosineScheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
183 | optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1)
184 | warmUpScheduler = GradualWarmupScheduler(
185 | optimizer=optimizer, multiplier=2., warm_epoch=1, after_scheduler=cosineScheduler)
186 |
187 | global_step = 0
188 | for curr_epoch in trange(0, args.epochs, desc="Epoch #"):
189 | total_loss = 0
190 | batch_bar = tqdm(dataloader, desc="Batch #")
191 | for idx, data in enumerate(batch_bar):
192 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10']:
193 | data = data[0]
194 | data = data.to(device=device)
195 | loss = model.loss_fn(args=args, x=data, curr_epoch=curr_epoch)
196 | batch_bar.set_postfix(loss=format(loss,'.4f'))
197 | optimizer.zero_grad()
198 | loss.backward()
199 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
200 | optimizer.step()
201 | total_loss += loss.item()
202 | global_step += 1
203 | if tb_logger:
204 | tb_logger.add_scalar('train/loss', loss.item(), global_step)
205 | losses.update(total_loss / (idx + 1))
206 | current_epoch = curr_epoch
207 | progress.display(current_epoch)
208 | current_epoch += 1
209 | warmUpScheduler.step()
210 | losses.reset()
211 | if current_epoch % args.save_epochs == 0:
212 | save_model(args, current_epoch, model)
213 |
214 |
215 | def eval(args):
216 | if args.mode != 'train_latent_ddim':
217 | seed_everything(args.r_seed)
218 | device = "cuda" if torch.cuda.is_available() else "cpu"
219 | shape = get_dataset_config(args)
220 | print(dict(vars(args)))
221 | root = f'{args.model_folder}'
222 | if args.model == 'diff':
223 | model = InfoDiff(args, device, shape)
224 | elif args.model == 'vanilla':
225 | model = Diff(args, device, shape)
226 | root = os.path.join(root, 'diff')
227 | elif args.model == 'vae':
228 | model = VAE(args, device, shape)
229 | root = os.path.join(root, 'vae')
230 | root = os.path.join(root, generate_exp_string(args))
231 | path = os.path.join(root, f'model-{args.epochs}.pth')
232 | print(f"Loading model from {path}")
233 | model.load_state_dict(torch.load(path, map_location=device), strict=False)
234 | if (args.dataset in ['celeba', 'cifar10', 'mnist', 'fmnist', 'ffhq'] and args.mode in ['eval_fid']):
235 | if args.is_latent:
236 | shape_latent = (1, args.a_dim, args.a_dim)
237 | model2 = Diff(args, device, shape_latent)
238 | path2 = f'./models/{generate_exp_string(args)}_latent/model-{args.epochs}.pth'
239 | if os.path.exists(path2):
240 | print(f"Loading model from {path2}")
241 | else:
242 | raise FileNotFoundError("The file path {} does not exist; please train the latent diffusion model first.".format(path2))
243 | model2.load_state_dict(torch.load(path2, map_location=device), strict=True)
244 | else:
245 | model2 = Diff(args, device, shape)
246 | path2 = f'./models/diff/{args.dataset}_{args.a_dim}d/model-{args.epochs}.pth'
247 | if os.path.exists(path2):
248 | print(f"Loading model from {path2}")
249 | else:
250 | raise FileNotFoundError("The file path {} does not exist; please train the vanilla diffusion model first.".format(path2))
251 | model2.load_state_dict(torch.load(path2, map_location=device), strict=True)
252 | model2.eval()
253 | model.eval()
254 | if args.model in ['diff', 'vanilla']:
255 | process = DiffusionProcess(args, model, device, shape)
256 | if args.mode == 'eval':
257 | if args.model in ['diff', 'vanilla']:
258 | for sample_num in trange(0, args.sampling_number, args.batch_size, desc="Generating eval images"):
259 | sample = process.sampling(sampling_number=args.batch_size)
260 | save_images(args, sample, sample_num=sample_num)
261 | elif args.model == 'vae':
262 | a = torch.randn([args.sampling_number, args.a_dim]).to(device=device)
263 | sample = model.decoder(a)
264 | save_images(args, sample)
265 | elif args.mode == 'eval_fid':
266 | root = f'{args.img_folder}'
267 | if args.model == 'vae':
268 | root = os.path.join(root, 'vae')
269 | root = os.path.join(root, generate_exp_string(args))
270 | if args.is_latent:
271 | root = os.path.join(root, 'eval-fid-latent')
272 | else:
273 | root = os.path.join(root, 'eval-fid-fast')
274 | os.makedirs(root, exist_ok=True)
275 | print(f"Saving images to {root}")
276 | if args.model == 'diff':
277 | if args.is_latent:
278 | process_latent = LatentDiffusionProcess(args, model2, device)
279 | else:
280 | process = TwoPhaseDiffusionProcess(args, model, model2, device, shape)
281 |
282 | for sample_num in trange(0, args.sampling_number, args.batch_size, desc="Generating eval images"):
283 | if args.is_latent:
284 | batch_a = process_latent.sampling(sampling_number=args.batch_size)
285 | batch = process.sampling(sampling_number=args.batch_size, a=batch_a)
286 | else:
287 | batch = process.sampling(sampling_number=args.batch_size)
288 | for batch_num, img in enumerate(batch):
289 | img = torch.clip(img, min=-1, max=1)
290 | img = ((img + 1)/2) # normalize to 0 - 1
291 | img_num = sample_num + batch_num
292 | if img_num >= args.sampling_number:
293 | return
294 | path = os.path.join(root, f'sample-{img_num:06d}.png')
295 | save_image(img, path)
296 | print("DONE")
297 | elif args.model == 'vae':
298 | for sample_num in trange(0, args.sampling_number, args.batch_size, desc="Generating eval images"):
299 | a = torch.randn([args.batch_size, args.a_dim]).to(device=device)
300 | batch = model.decoder(a)
301 | for batch_num, img in enumerate(batch):
302 | img = torch.clip(img, min=-1, max=1)
303 | img = ((img + 1)/2) # normalize to 0 - 1
304 | img_num = sample_num + batch_num
305 | if img_num >= args.sampling_number:
306 | return
307 | path = os.path.join(root, f'sample-{img_num:06d}.png')
308 | save_image(img, path)
309 | print("DONE")
310 | elif args.mode == 'latent_quality':
311 | process = DiffusionProcess(args, model, device, shape)
312 | dataloader = get_dataset(args)
313 | root = f'{args.img_folder}'
314 | root = os.path.join(root, generate_exp_string(args))
315 | root = os.path.join(root, 'latent_quality')
316 | print(f"Saving images to {root}")
317 | for idx, data in enumerate(dataloader):
318 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites']:
319 | data_all = data
320 | data = data_all[0]
321 | if idx == 10:
322 | break
323 | data = data.to(device=device)
324 | if args.kld_weight != 0:
325 | with torch.no_grad():
326 | _, _, mu, log_var = model.encoder(data)
327 | a = mu + torch.exp(0.5 * log_var)
328 | elif args.mmd_weight != 0:
329 | with torch.no_grad():
330 | a, _, _, _ = model.encoder(data)
331 | xT = process.reverse_sampling(data, a)
332 | xT_original = xT.repeat(args.sampling_number, 1, 1, 1)
333 | a_original = a.repeat(args.sampling_number, 1)
334 | xT = torch.randn_like(xT_original)
335 | batch = process.sampling(xT=xT, a=a_original)
336 | os.makedirs(root, exist_ok=True)
337 | for batch_num, img in enumerate(batch):
338 | img = torch.clip(img, min=-1, max=1)
339 | img = ((img + 1)/2) # normalize to 0 - 1
340 | path = os.path.join(root, f'sample-{batch_num:06d}.png')
341 | save_image(img, path)
342 | elif args.mode == 'plot_latent':
343 | all_a, all_attr = [], []
344 | dataloader = get_dataset(args)
345 | for idx, data in enumerate(dataloader):
346 | if args.dataset in ['fmnist', 'celeba', 'cifar10', 'dsprites', 'mnist']:
347 | data_all = data
348 | data = data_all[0]
349 | if args.dataset in ['celeba', 'fmnist', 'mnist']:
350 | latents_classes = data_all[1]
351 | elif args.dataset == 'dsprites':
352 | latents_classes = data_all[2]
353 | data = data.to(device=device)
354 | if (args.mmd_weight == 0 and args.kld_weight == 0):
355 | with torch.no_grad():
356 | a, _, _, _ = model.encoder(data)
357 | elif (args.mmd_weight != 0):
358 | with torch.no_grad():
359 | a, _, _, _ = model.encoder(data)
360 | else:
361 | with torch.no_grad():
362 | _, _, mu, _ = model.encoder(data)
363 | a = mu
364 | all_a.append(a.cpu().numpy())
365 | all_attr.append(latents_classes)
366 | all_a = np.concatenate(all_a)
367 | all_attr = np.concatenate(all_attr)
368 | plt.scatter(all_a[:, 0], all_a[:, 1], c = all_attr, cmap = 'tab10', s=5)
369 | path = save_images(args)
370 | plt.savefig(path)
371 | elif args.mode == 'disentangle':
372 | dataloader = get_dataset(args)
373 | for idx, data in enumerate(dataloader):
374 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites']:
375 | data_all = data
376 | data = data_all[0]
377 | if args.dataset == 'celeba':
378 | latents_classes = data_all[1]
379 | elif args.dataset == 'dsprites':
380 | latents_classes = data_all[2]
381 | if idx == args.img_id:
382 | break
383 | data = data.to(device=device)
384 | # eta = [-3, -2.4, -1.8, -1.2, -0.6, 0.0, 0.6, 1.2, 1.8, 2.4, 3.0]
385 | eta = [-1.5, -1.2, -0.9, -0.6, -0.3, 0.0, 0.3, 0.6, 0.9, 1.2, 1.5]
386 | if args.kld_weight != 0:
387 | with torch.no_grad():
388 | _, _, mu, _ = model.encoder(data)
389 | a = mu
390 | elif args.mmd_weight != 0:
391 | with torch.no_grad():
392 | a, _, _, _ = model.encoder(data)
393 | if args.model == 'diff':
394 | xT = process.reverse_sampling(data, a)
395 | xT = xT.repeat(len(eta), 1, 1, 1)
396 | for k in range(args.a_dim):
397 | a_list = []
398 | for e in eta:
399 | if args.kld_weight != 0:
400 | with torch.no_grad():
401 | _, _, mu, log_var = model.encoder(data)
402 | a = mu
403 | print(mu, log_var)
404 | elif args.mmd_weight != 0:
405 | with torch.no_grad():
406 | a, _, _, _ = model.encoder(data)
407 | a[0][k] = e
408 | a_list.append(a)
409 | a = torch.stack(a_list).squeeze(dim=1)
410 | if args.model == 'diff':
411 | sample = process.sampling(xT=xT, a=a)
412 | elif args.model == 'vae':
413 | sample = model.decoder(a)
414 | save_images(args, sample, sample_num=k)
415 | elif args.mode == 'save_latent':
416 | all_a, all_attr = [], []
417 | dataloader = get_dataset(args)
418 | for idx, data in enumerate(dataloader):
419 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10', 'dsprites']:
420 | data_all = data
421 | data = data_all[0]
422 | if args.dataset in ['celeba', 'fmnist', 'mnist', 'cifar10']:
423 | latents_classes = data_all[1]
424 | elif args.dataset == 'dsprites':
425 | latents_classes = data_all[2]
426 | else:
427 | latents_classes = ['No Attributes']
428 | data = data.to(device=device)
429 | if args.kld_weight != 0:
430 | with torch.no_grad():
431 | _, _, mu, _ = model.encoder(data)
432 | a = mu
433 | elif args.mmd_weight != 0:
434 | with torch.no_grad():
435 | a, _, _, _ = model.encoder(data)
436 | elif (args.mmd_weight == 0 and args.kld_weight == 0):
437 | with torch.no_grad():
438 | a, _, _, _ = model.encoder(data)
439 | all_a.append(a.cpu().numpy())
440 | all_attr.append(latents_classes)
441 | all_a = np.concatenate(all_a)
442 | all_attr = np.concatenate(all_attr)
443 | np.savez("{}_{}_latent".format(args.model, generate_exp_string(args).replace(".", "_")), all_a = all_a, all_attr = all_attr)
444 | elif args.mode == 'interpolate':
445 | dataloader = get_dataset(args)
446 | for idx, data in enumerate(dataloader):
447 | if args.dataset in ['fmnist', 'mnist', 'celeba', 'cifar10']:
448 | data = data[0]
449 | if idx == args.img_id:
450 | break
451 | data = data.to(device=device)
452 | if args.kld_weight != 0:
453 | with torch.no_grad():
454 | _, _, mu, _ = model.encoder(data)
455 | a = mu
456 | elif args.mmd_weight != 0:
457 | with torch.no_grad():
458 | a, _, _, _ = model.encoder(data)
459 | elif (args.mmd_weight == 0 and args.kld_weight == 0):
460 | with torch.no_grad():
461 | a, _, _, _ = model.encoder(data)
462 | if args.model in ['diff', 'vanilla']:
463 | xT = process.reverse_sampling(data, a)
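    | # theta: angle between the two encoded x_T codes, used for spherical interpolation
    | # (slerp) below; a is interpolated on a quarter circle via cos/sin weights.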
464 | theta = torch.arccos(cos(xT[0], xT[1]))
465 | a1 = a[0]
466 | a2 = a[1]
467 | eta = [0.0, 0.11, 0.22, 0.33, 0.44, 0.55, 0.66, 0.77, 0.88, 1.0]
468 | intp_a_list = []
469 | intp_x_list = []
470 | for e in eta:
471 | intp_a_list.append(np.cos(e * np.pi / 2) * a1 + np.sin(e * np.pi / 2) * a2)
472 | if args.model in ['diff', 'vanilla']:
473 | intp_x = (torch.sin((1 - e) * theta) * xT[0] + torch.sin(e * theta) * xT[1]) / torch.sin(theta)
474 | intp_x_list.append(intp_x)
475 | intp_a = torch.stack(intp_a_list)
476 | if args.model in ['diff', 'vanilla']:
477 | intp_x = torch.stack(intp_x_list).squeeze(dim=1)
478 | sample = process.sampling(xT=intp_x, a=intp_a)
479 | elif args.model == 'vae':
480 | sample = model.decoder(intp_a)
481 | save_images(args, sample)
482 | elif args.mode == 'train_latent_ddim':
483 | dataset = LatentDataset("{}_{}_latent.npz".format(args.model, generate_exp_string(args).replace(".", "_")))
484 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
485 | seed_everything(args.r_seed)
486 | log_dir = f'{args.log_folder}'
487 | log_dir = os.path.join(log_dir, generate_exp_string(args))
488 | log_dir += '_latent'
489 | tb_logger = SummaryWriter(log_dir=log_dir) if args.tb_logger else None
490 | device = "cuda" if torch.cuda.is_available() else "cpu"
491 | shape = (1, args.a_dim, args.a_dim)
492 | model = Diff(args, device, shape)
493 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
494 |
495 | losses = AverageMeter('Loss', ':.4f')
496 | progress = ProgressMeter(args.epochs, [losses], prefix='Epoch ')
497 |
498 | cosineScheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
499 | optimizer=optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1)
500 | warmUpScheduler = GradualWarmupScheduler(
501 | optimizer=optimizer, multiplier=2., warm_epoch=1, after_scheduler=cosineScheduler)
502 |
503 | global_step = 0
504 | for curr_epoch in trange(0, args.epochs, desc="Epoch #"):
505 | total_loss = 0
506 | batch_bar = tqdm(dataloader, desc="Batch #")
507 | for idx, data in enumerate(batch_bar):
508 | data = data.to(device=device)
509 | loss = model.loss_fn(args=args, x=data, curr_epoch=curr_epoch)
510 | batch_bar.set_postfix(loss=format(loss,'.4f'))
511 | optimizer.zero_grad()
512 | loss.backward()
513 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
514 | optimizer.step()
515 | total_loss += loss.item()
516 | global_step += 1
517 | if tb_logger:
518 | tb_logger.add_scalar('train/loss', loss.item(), global_step)
519 | losses.update(total_loss / (idx + 1))
520 | current_epoch = curr_epoch
521 | progress.display(current_epoch)
522 | current_epoch += 1
523 | warmUpScheduler.step()
524 | losses.reset()
525 | if current_epoch % args.save_epochs == 0:
526 | save_model(args, current_epoch, model)
527 |
528 |
529 | if __name__ == '__main__':
530 | args = parse_args()
531 | if args.mode in ['train']:
532 | train(args)
533 | elif args.mode in ['eval', 'eval_fid', 'latent_quality', 'disentangle', 'interpolate',
534 | 'save_latent', 'train_latent_ddim', 'plot_latent']:
535 | if args.mode in ['disentangle', 'latent_quality']:
536 | args.batch_size = 1
537 | elif args.mode == 'interpolate':
538 | args.batch_size = 2
539 | eval(args)
540 | elif args.mode in ['save_original_img']:
541 | from tqdm import tqdm
542 | import os
543 | from torchvision.utils import save_image
544 | output_folder = f'./{args.dataset}_imgs/'
545 | os.makedirs(output_folder, exist_ok=True)
546 | dataloader = get_dataset(args)
547 | for i, img in enumerate(tqdm(dataloader)):
548 | img = ((img[0] + 1)/2) # normalize to 0 - 1
549 | save_image(img, f'{output_folder}/{i:06d}.png')
--------------------------------------------------------------------------------
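As a companion to `eval()` above, a minimal sampling sketch (assumptions: a trained InfoDiff checkpoint exists; the checkpoint path and args values below are hypothetical, with field names mirroring `run.py`'s parser):

```python
import torch
from types import SimpleNamespace

from models import InfoDiff
from sampling import DiffusionProcess

args = SimpleNamespace(
    beta1=1e-5, betaT=1e-2, diffusion_steps=1000, input_size=32,
    is_bottleneck=False, unets_channels=64, encoder_channels=64,
    a_dim=32, mmd_weight=0.1, kld_weight=0.0,
    deterministic=True, model='diff')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
shape = (3, 32, 32)

model = InfoDiff(args, device, shape)
state = torch.load('models/model-50.pth', map_location=device)  # hypothetical checkpoint path
model.load_state_dict(state, strict=False)
model.eval()

process = DiffusionProcess(args, model, device, shape)
sample = process.sampling(sampling_number=16)  # a defaults to N(0, I) when not provided
```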
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run.py --model diff --mode train --mmd_weight 0.1 --a_dim 32 --epochs 50 --dataset celeba --batch_size 32 --save_epochs 5 --deterministic --prior regular --r_seed 64
4 |
--------------------------------------------------------------------------------
/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class DiffusionProcess():
4 | def __init__(self, args, diffusion_fn, device, shape):
5 | '''
6 | beta_1 : initial value of the noise schedule
7 | beta_T : final value of the noise schedule
8 | T : number of diffusion steps
9 | diffusion_fn : trained denoising network
10 | shape : data shape
11 | '''
12 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps)
13 | self.alphas = 1 - self.betas
14 | self.alpha_bars = torch.cumprod(self.alphas, dim=0).to(device=device)
15 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
16 | self.shape = shape
17 | self.deterministic = args.deterministic
18 | self.a_dim = args.a_dim
19 | self.model = args.model
20 | self.diffusion_fn = diffusion_fn.to(device=device)
21 | self.device = device
22 |
23 | def _ddpm_one_diffusion_step(self, x, a=None):
24 | '''
25 | x : perturbed data
26 | '''
27 | for idx in reversed(range(len(self.alpha_bars))):
28 |
29 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x)
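    | # DDPM ancestral step: posterior std is sqrt(tilde_beta_t) with
    | # tilde_beta_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t, and posterior mean
    | # mu_theta = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)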
30 | sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx])
31 | if self.model == 'vanilla':
32 | predict_epsilon = self.diffusion_fn(x, idx)
33 | else:
34 | predict_epsilon = self.diffusion_fn(x, idx, a)
35 | mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * (x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon)
36 |
37 | x = mu_theta_xt + sqrt_tilde_beta * noise
38 |
39 | yield x
40 |
41 | def _ddim_one_diffusion_step(self, x, a=None):
42 | '''
43 | x : perturbed data
44 | '''
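    | # eta controls stochasticity: eta = 0 recovers deterministic DDIM; a small eta adds slight noise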
45 | eta = 0.01
46 | for idx in reversed(range(len(self.alpha_bars))):
47 |
48 | if self.model == 'vanilla':
49 | predict_epsilon = self.diffusion_fn(x, idx)
50 | else:
51 | predict_epsilon = self.diffusion_fn(x, idx, a)
52 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
53 | if idx == 0:
54 | x = x_0
55 | else:
56 | noise = torch.randn_like(x)
57 | sigma = eta * torch.sqrt((1 - self.alpha_prev_bars[idx-1]) / (1 - self.alpha_bars[idx-1])) * torch.sqrt(self.betas[idx-1])
58 | x = torch.sqrt(self.alpha_prev_bars[idx-1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx-1] - sigma**2) * predict_epsilon
59 | x += sigma * noise
60 | yield x
61 |
62 | def _ddim_one_reverse_diffusion_step(self, x, a=None):
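    | # Reverse DDIM: run the deterministic update forward in time to encode x_0 into x_T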
63 | for idx in range(len(self.alpha_bars)-1):
64 | if idx == 0:
65 | yield x
66 | else:
67 | if self.model == 'vanilla':
68 | predict_epsilon = self.diffusion_fn(x, idx)
69 | else:
70 | predict_epsilon = self.diffusion_fn(x, idx, a)
71 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
72 | x = torch.sqrt(self.alpha_prev_bars[idx+1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx+1]) * predict_epsilon
73 | yield x
74 |
75 | def _one_diffusion_step(self, sample, a=None, deterministic=False):
76 | if deterministic:
77 | return self._ddim_one_diffusion_step(sample, a)
78 | else:
79 | return self._ddpm_one_diffusion_step(sample, a)
80 |
81 | @torch.no_grad()
82 | def reverse_sampling(self, x0, a=None):
83 | sample = x0
84 | for sample in self._ddim_one_reverse_diffusion_step(sample):
85 | final = sample
86 |
87 | return final
88 |
89 | @torch.no_grad()
90 | def sampling(self, sampling_number=16, xT=None, a=None):
91 | if xT is None:
92 | xT = torch.randn([sampling_number, *self.shape]).to(device=self.device)
93 | if self.model != 'vanilla':
94 | if a is None:
95 | a = torch.randn([sampling_number, self.a_dim]).to(device=self.device)
96 |
97 | sample = xT
98 | for sample in self._one_diffusion_step(sample=sample, a=a, deterministic=self.deterministic):
99 | final = sample
100 |
101 | return final
102 |
103 |
104 | class TwoPhaseDiffusionProcess():
105 | def __init__(self, args, diffusion_fn_1, diffusion_fn_2, device, shape):
106 |         '''
107 |         args : provides beta1, betaT, diffusion_steps and split_step
108 |         diffusion_fn_1 : trained auxiliary-variable diffusion network (conditioned on a)
109 |         diffusion_fn_2 : trained unconditional diffusion network
110 |         device : device to run sampling on
111 |         shape : shape of a single data sample
112 |         '''
113 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps)
114 | self.alphas = 1 - self.betas
115 |         self.alpha_bars = torch.cumprod(self.alphas, dim=0).to(device=device)
116 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
117 | self.shape = shape
118 | self.deterministic = args.deterministic
119 | self.a_dim = args.a_dim
120 | self.model = args.model
121 | self.split_step = args.split_step
122 | self.mode = args.mode
123 |
124 | self.diffusion_fn_1 = diffusion_fn_1.to(device=device)
125 | self.diffusion_fn_2 = diffusion_fn_2.to(device=device)
126 | self.device = device
127 |
128 |     def _ddpm_one_diffusion_step(self, x, a=None):
129 |         '''
130 |         x : perturbed data (x_T), denoised step by step with DDPM ancestral sampling
131 |         '''
132 |         for t, idx in enumerate(reversed(range(len(self.alpha_bars)))):  # t counts sampling steps from 0
133 |
134 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x)
135 | sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx])
136 |             if t <= self.split_step:
137 |                 predict_epsilon = self.diffusion_fn_2(x, idx)  # first phase: unconditional network
138 |             else:
139 |                 predict_epsilon = self.diffusion_fn_1(x, idx, a)  # second phase: auxiliary-variable network
140 | mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * (x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon)
141 |
142 | x = mu_theta_xt + sqrt_tilde_beta * noise
143 |
144 | yield x
145 |
146 |     def _ddim_one_diffusion_step(self, x, a=None):
147 |         '''
148 |         x : perturbed data (x_T), denoised step by step with DDIM sampling
149 |         '''
150 |         eta = 0.01  # stochasticity scale; eta = 0 gives fully deterministic DDIM
151 |         for t, idx in enumerate(reversed(range(len(self.alpha_bars)))):  # t counts sampling steps from 0
152 |
153 |             if t <= self.split_step:
154 |                 predict_epsilon = self.diffusion_fn_2(x, idx)  # first phase: unconditional network
155 |             else:
156 |                 predict_epsilon = self.diffusion_fn_1(x, idx, a)  # second phase: auxiliary-variable network
157 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
158 | if idx == 0:
159 | x = x_0
160 | else:
161 | noise = torch.randn_like(x)
162 | sigma = eta * torch.sqrt((1 - self.alpha_prev_bars[idx-1]) / (1 - self.alpha_bars[idx-1])) * torch.sqrt(self.betas[idx-1])
163 | x = torch.sqrt(self.alpha_prev_bars[idx-1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx-1] - sigma**2) * predict_epsilon
164 | x += sigma * noise
165 | yield x
166 |
167 | def _ddim_one_reverse_diffusion_step(self, x, a=None):
168 | for idx in range(len(self.alpha_bars)-1):
169 | if idx == 0:
170 | yield x
171 | else:
172 | predict_epsilon = self.diffusion_fn_1(x, idx, a)
173 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
174 | x = torch.sqrt(self.alpha_prev_bars[idx+1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx+1]) * predict_epsilon
175 | yield x
176 |
177 |     def _one_diffusion_step(self, sample, a=None, deterministic=False):
178 |         if deterministic:
179 |             return self._ddim_one_diffusion_step(sample, a)
180 |         else:
181 |             return self._ddpm_one_diffusion_step(sample, a)
182 |
183 | @torch.no_grad()
184 | def reverse_sampling(self, x0, a=None):
185 | sample = x0
186 |         for sample in self._ddim_one_reverse_diffusion_step(sample, a):
187 | final = sample
188 |
189 | return final
190 |
191 | @torch.no_grad()
192 | def sampling(self, sampling_number=16, xT=None, a=None):
193 | if xT is None:
194 | xT = torch.randn([sampling_number, *self.shape]).to(device=self.device)
195 | if a is None:
196 | a = torch.randn([sampling_number, self.a_dim]).to(device=self.device)
197 |
198 |         sample = xT
199 |         # the step counter driving the two-phase split is tracked inside the generator
200 |         for sample in self._one_diffusion_step(sample=sample, a=a, deterministic=self.deterministic):
201 |             final = sample
202 |
203 |
204 | return final
205 |
206 |
207 | class LatentDiffusionProcess():
208 | def __init__(self, args, diffusion_fn, device):
209 |         '''
210 |         args : provides beta1, betaT (noise-schedule endpoints) and
211 |                diffusion_steps (number of diffusion steps T)
212 |         diffusion_fn : trained latent diffusion network over the auxiliary variable
213 |         device : device to run sampling on
214 |         '''
215 | self.betas = torch.linspace(start=args.beta1, end=args.betaT, steps=args.diffusion_steps)
216 | self.alphas = 1 - self.betas
217 |         self.alpha_bars = torch.cumprod(self.alphas, dim=0).to(device=device)
218 | self.alpha_prev_bars = torch.cat([torch.Tensor([1]).to(device=device), self.alpha_bars[:-1]])
219 | self.deterministic = args.deterministic
220 | self.a_dim = args.a_dim
221 | self.model = args.model
222 | self.split_step = args.split_step
223 | self.mode = args.mode
224 | self.diffusion_fn = diffusion_fn.to(device=device)
225 | self.device = device
226 |
227 | def _ddpm_one_diffusion_step(self, x):
228 |         '''
229 |         x : perturbed auxiliary latent, denoised step by step with DDPM ancestral sampling
230 |         '''
231 | for idx in reversed(range(len(self.alpha_bars))):
232 |
233 | noise = torch.zeros_like(x) if idx == 0 else torch.randn_like(x)
234 | sqrt_tilde_beta = torch.sqrt((1 - self.alpha_prev_bars[idx]) / (1 - self.alpha_bars[idx]) * self.betas[idx])
235 | predict_epsilon = self.diffusion_fn(x, idx)
236 | mu_theta_xt = torch.sqrt(1 / self.alphas[idx]) * (x - self.betas[idx] / torch.sqrt(1 - self.alpha_bars[idx]) * predict_epsilon)
237 |
238 | x = mu_theta_xt + sqrt_tilde_beta * noise
239 |
240 | yield x
241 |
242 | def _ddim_one_diffusion_step(self, x):
243 |         '''
244 |         x : perturbed auxiliary latent, denoised step by step with DDIM sampling
245 |         '''
246 |         eta = 0.01  # stochasticity scale; eta = 0 gives fully deterministic DDIM
247 | for idx in reversed(range(len(self.alpha_bars))):
248 | predict_epsilon = self.diffusion_fn(x, idx)
249 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
250 | if idx == 0:
251 | x = x_0
252 | else:
253 | noise = torch.randn_like(x)
254 | sigma = eta * torch.sqrt((1 - self.alpha_prev_bars[idx-1]) / (1 - self.alpha_bars[idx-1])) * torch.sqrt(self.betas[idx-1])
255 | x = torch.sqrt(self.alpha_prev_bars[idx-1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx-1] - sigma**2) * predict_epsilon
256 | x += sigma * noise
257 | yield x
258 |
259 | def _ddim_one_reverse_diffusion_step(self, x):
260 | for idx in range(len(self.alpha_bars)-1):
261 | if idx == 0:
262 | yield x
263 | else:
264 | predict_epsilon = self.diffusion_fn(x, idx)
265 | x_0 = (x - torch.sqrt(1 - self.alpha_prev_bars[idx]) * predict_epsilon) / torch.sqrt(self.alpha_prev_bars[idx])
266 | x = torch.sqrt(self.alpha_prev_bars[idx+1]) * x_0 + torch.sqrt(1 - self.alpha_prev_bars[idx+1]) * predict_epsilon
267 | yield x
268 |
269 | def _one_diffusion_step(self, sample, deterministic=False):
270 | if deterministic:
271 | return self._ddim_one_diffusion_step(sample)
272 | else:
273 | return self._ddpm_one_diffusion_step(sample)
274 |
275 | @torch.no_grad()
276 | def reverse_sampling(self, x0):
277 | sample = x0
278 | for sample in self._ddim_one_reverse_diffusion_step(sample):
279 | final = sample
280 |
281 | return final
282 |
283 | @torch.no_grad()
284 | def sampling(self, sampling_number=16, xT=None):
285 | if xT is None:
286 | xT = torch.randn([sampling_number, self.a_dim]).to(device=self.device)
287 |
288 | sample = xT
289 | for sample in self._one_diffusion_step(sample=sample, deterministic=self.deterministic):
290 | final = sample
291 |
292 | return final
--------------------------------------------------------------------------------
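A minimal sketch of driving `DiffusionProcess` above end to end. The `args` namespace and the zero-predicting `ToyEps` network are hypothetical stand-ins (real runs construct both in `run.py`); the sketch only illustrates the calling convention of `sampling` and `reverse_sampling`.

```python
import argparse
import torch
from sampling import DiffusionProcess

# Hypothetical arguments; run.py builds the real ones from the CLI.
args = argparse.Namespace(beta1=1e-4, betaT=0.02, diffusion_steps=1000,
                          deterministic=True, a_dim=32, model='vanilla')

class ToyEps(torch.nn.Module):
    """Stand-in epsilon network that predicts zero noise at every step."""
    def forward(self, x, idx, a=None):
        return torch.zeros_like(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
process = DiffusionProcess(args, ToyEps(), device, shape=(3, 32, 32))
x0 = process.sampling(sampling_number=4)  # x_T -> x_0 (DDIM, since deterministic=True)
xT = process.reverse_sampling(x0)         # reverse DDIM: x_0 -> x_T
```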
/save_latent.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0
3 | python run.py --model diff --mode save_latent --a_dim 256 --mmd_weight 0.1 --epochs 50 --dataset celeba --sampling_number 16 --deterministic --prior regular --r_seed 64
4 |
--------------------------------------------------------------------------------
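`save_latent.sh` writes the auxiliary latents inferred by a trained model to disk through `run.py`'s `save_latent` mode; `LatentDataset` at the bottom of `utils.py` reads them back from an `.npz` archive stored under the key `all_a`. A minimal loading sketch (the path below is hypothetical):

```python
from torch.utils.data import DataLoader
from utils import LatentDataset

# Hypothetical path: wherever the save_latent run wrote its .npz archive.
dataset = LatentDataset('results/celeba_256d_0.1mmd/latent.npz')
loader = DataLoader(dataset, batch_size=128, shuffle=True)  # batches of [128, a_dim] latent codes
```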
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 | from torch.nn import functional as F
7 | from torch.utils.data import Dataset
8 | from sklearn.datasets import make_swiss_roll
9 |
10 |
11 | def gaussian_mixture(batch_size, n_dim=2, n_labels=10,
12 | x_var=0.5, y_var=0.1, label_indices=None):
13 | if n_dim % 2 != 0:
14 |         raise ValueError("n_dim must be a multiple of 2.")
15 |
16 |     def sample(x, y, label, n_labels):  # place a 2D point on the mixture component indexed by label
17 |         shift = 1.4  # radial distance of each mixture component from the origin
18 | if label >= n_labels:
19 | label = np.random.randint(0, n_labels)
20 | r = 2.0 * np.pi / float(n_labels) * float(label)
21 | new_x = x * math.cos(r) - y * math.sin(r)
22 | new_y = x * math.sin(r) + y * math.cos(r)
23 | new_x += shift * math.cos(r)
24 | new_y += shift * math.sin(r)
25 | return np.array([new_x, new_y]).reshape((2,))
26 |
27 | x = np.random.normal(0, x_var, (batch_size, n_dim // 2))
28 | y = np.random.normal(0, y_var, (batch_size, n_dim // 2))
29 | z = np.empty((batch_size, n_dim), dtype=np.float32)
30 | for batch in range(batch_size):
31 | for zi in range(n_dim // 2):
32 | if label_indices is not None:
33 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], label_indices[batch], n_labels)
34 | else:
35 | z[batch, zi*2:zi*2+2] = sample(x[batch, zi], y[batch, zi], np.random.randint(0, n_labels), n_labels)
36 |
37 | return z
38 |
39 | def swiss_roll(batch_size, noise=0.5):
40 |     return make_swiss_roll(n_samples=batch_size, noise=noise)[0][:, [0, 2]] / 5.  # 2D slice of the roll, rescaled
41 |
42 | def cos(a, b):  # cosine similarity between two flattened tensors
43 | a = a.view(-1)
44 | b = b.view(-1)
45 | a = F.normalize(a, dim=0)
46 | b = F.normalize(b, dim=0)
47 | return (a * b).sum()
48 |
49 | def generate_exp_string(args) -> str:
50 | root = f'{args.dataset}_{args.a_dim}d'
51 | if args.kld_weight != 0:
52 | root += f'_{args.kld_weight}kld'
53 | if args.use_C:
54 | root += f'_{args.C_max}C'
55 | if args.mmd_weight != 0:
56 | root += f'_{args.mmd_weight}mmd'
57 | if args.prior != 'regular':
58 | root += f'_{args.prior}'
59 | if args.is_bottleneck:
60 | root += '_bottleneck'
61 | return root
62 |
63 |
64 | def seed_everything(r_seed):
65 | print("Set seed: ", r_seed)
66 | random.seed(r_seed)
67 | np.random.seed(r_seed)
68 | torch.manual_seed(r_seed)
69 | torch.cuda.manual_seed(r_seed)
70 | torch.cuda.manual_seed_all(r_seed)
71 | torch.backends.cudnn.deterministic = True
72 |
73 |
74 | @torch.jit.script
75 | def compute_kernel(x, y):
76 | x_size = x.shape[0]
77 | y_size = y.shape[0]
78 | dim = x.shape[1]
79 |
80 | tiled_x = x.view(x_size, 1, dim).repeat(1, y_size, 1)
81 | tiled_y = y.view(1, y_size, dim).repeat(x_size, 1, 1)
82 |
83 |     return torch.exp(-torch.mean((tiled_x - tiled_y)**2, dim=2) / dim)  # RBF kernel, bandwidth set to the latent dim
84 |
85 | @torch.jit.script
86 | def compute_mmd(x, y):
87 | x_kernel = compute_kernel(x, x)
88 | y_kernel = compute_kernel(y, y)
89 | xy_kernel = compute_kernel(x, y)
90 | return torch.mean(x_kernel) + torch.mean(y_kernel) - 2*torch.mean(xy_kernel)
91 |
92 |
93 | class AverageMeter(object):
94 | def __init__(self, name, fmt=':f'):
95 | self.name = name
96 | self.fmt = fmt
97 | self.reset()
98 |
99 | def reset(self):
100 | self.val = 0
101 | self.avg = 0
102 | self.sum = 0
103 | self.count = 0
104 |
105 | def update(self, val, n=1):
106 | self.val = val
107 | self.sum += val * n
108 |         self.count += n
109 | self.avg = self.sum / self.count
110 |
111 | def __str__(self):
112 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
113 | return fmtstr.format(**self.__dict__)
114 |
115 |
116 | class ProgressMeter(object):
117 | def __init__(self, num_batches, meters, prefix=""):
118 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
119 | self.meters = meters
120 | self.prefix = prefix
121 |
122 | def display(self, batch):
123 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
124 | entries += [str(meter) for meter in self.meters]
125 | print('\r' + '\t'.join(entries), end='')
126 |
127 | def _get_batch_fmtstr(self, num_batches):
128 |         num_digits = len(str(num_batches))
129 | fmt = '{:' + str(num_digits) + 'd}'
130 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
131 |
132 |
133 | class GradualWarmupScheduler(_LRScheduler):
134 | def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
135 | self.multiplier = multiplier
136 | self.total_epoch = warm_epoch
137 | self.after_scheduler = after_scheduler
138 | self.finished = False
139 | self.last_epoch = None
140 | self.base_lrs = None
141 | super().__init__(optimizer)
142 |
143 | def get_lr(self):
144 | if self.last_epoch > self.total_epoch:
145 | if self.after_scheduler:
146 | if not self.finished:
147 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
148 | self.finished = True
149 | return self.after_scheduler.get_lr()
150 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
151 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
152 |
153 | def step(self, epoch=None, metrics=None):
154 | if self.finished and self.after_scheduler:
155 | if epoch is None:
156 | self.after_scheduler.step(None)
157 | else:
158 | self.after_scheduler.step(epoch - self.total_epoch)
159 | else:
160 |             return super().step(epoch)
161 |
162 |
163 | class LatentDataset(Dataset):
164 | def __init__(self, data_path):
165 | data = np.load(data_path)
166 |         self.x = torch.from_numpy(data['all_a']).float()  # auxiliary latent codes saved under key 'all_a'
167 |
168 | def __getitem__(self, index):
169 | return self.x[index]
170 |
171 | def __len__(self):
172 | return len(self.x)
--------------------------------------------------------------------------------
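Two quick usage sketches for the utilities above. First, the jit-scripted `compute_mmd` penalizes the discrepancy between encoder outputs and prior samples, with `gaussian_mixture` (or `swiss_roll`) supplying the non-Gaussian priors; the batch size and dimensions below are arbitrary:

```python
import torch
from utils import compute_mmd, gaussian_mixture

a_post = torch.randn(128, 2)  # stand-in encoder outputs
a_prior = torch.from_numpy(gaussian_mixture(128, n_dim=2)).float()
penalty = compute_mmd(a_post, a_prior)  # scalar; approaches 0 as the two distributions match
```

Second, `GradualWarmupScheduler` scales the learning rate linearly from `base_lr` up to `multiplier * base_lr` over `warm_epoch` epochs before handing off to `after_scheduler`. A sketch with a stand-in model and a cosine schedule:

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils import GradualWarmupScheduler

net = torch.nn.Linear(4, 4)  # stand-in model
opt = Adam(net.parameters(), lr=1e-4)
sched = GradualWarmupScheduler(opt, multiplier=2.0, warm_epoch=10,
                               after_scheduler=CosineAnnealingLR(opt, T_max=90))
for epoch in range(100):
    opt.step()          # one epoch of training would go here
    sched.step(epoch)
```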