├── .gitignore
├── README.md
├── models
│   ├── __init__.py
│   ├── ccvae.py
│   └── networks.py
├── ss_vae.py
└── utils
    ├── __init__.py
    └── dataset_cached.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

#code
*.vscode

data

.vector_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Capturing label characteristics in VAEs

PyTorch repository for [Capturing Label Characteristics in VAEs, ICLR 2021](https://openreview.net/pdf?id=wQRlSUZ5V7B). We kindly ask that you cite our work if you plan to use this codebase:

    @inproceedings{
    joy2021capturing,
    title={Capturing Label Characteristics in {\{}VAE{\}}s},
    author={Tom Joy and Sebastian Schmon and Philip Torr and Siddharth N and Tom Rainforth},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=wQRlSUZ5V7B}
    }

## Usage

Ensure that CelebA is in the directory `data/datasets/celeba`, such that the path `data/datasets/celeba/celeba/img_align_celeba/*` is accessible.

To train, run:

> `python ss_vae.py -sup <frac> --cuda`

where `<frac>` is the fraction of supervised data (e.g. 0.004, 0.06, 0.2, 1.0).
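For example, a run with 20% of the labels and the default hyper-parameters (the flag names come from `parser_args` in `ss_vae.py`; 0.2 is just an illustrative choice of `<frac>`):

> `python ss_vae.py -sup 0.2 --cuda -n 200 -bs 200 -zd 45`

Per-epoch reconstructions are written to `./data/output/` and model checkpoints to the `--data_dir` path (default `./data`); note that the script does not create `./data/output/`, so it should exist before training.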
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thwjoy/ccvae_pytorch/b4af1d3273340044b6ab20b994dc3f4e92503505/models/__init__.py
--------------------------------------------------------------------------------
/models/ccvae.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
import torch.distributions as dist
import os
from utils.dataset_cached import CELEBA_EASY_LABELS
from .networks import (Classifier, CondPrior,
                       CELEBADecoder, CELEBAEncoder)


def compute_kl(locs_q, scale_q, locs_p=None, scale_p=None):
    """
    Computes the KL(q||p)
    """
    if locs_p is None:
        locs_p = torch.zeros_like(locs_q)
    if scale_p is None:
        scale_p = torch.ones_like(scale_q)

    dist_q = dist.Normal(locs_q, scale_q)
    dist_p = dist.Normal(locs_p, scale_p)
    return dist.kl.kl_divergence(dist_q, dist_p).sum(dim=-1)


def img_log_likelihood(recon, xs):
    return dist.Laplace(recon, torch.ones_like(recon)).log_prob(xs).sum(dim=(1, 2, 3))


class CCVAE(nn.Module):
    """
    CCVAE
    """
    def __init__(self, z_dim, num_classes,
                 im_shape, use_cuda, prior_fn):
        super(CCVAE, self).__init__()
        self.z_dim = z_dim
        self.z_classify = num_classes
        self.z_style = z_dim - num_classes
        self.im_shape = im_shape
        self.use_cuda = use_cuda
        self.num_classes = num_classes
        self.ones = torch.ones(1, self.z_style)
        self.zeros = torch.zeros(1, self.z_style)
        self.y_prior_params = prior_fn()

        self.encoder = CELEBAEncoder(self.z_dim)
        self.decoder = CELEBADecoder(self.z_dim)
        self.classifier = Classifier(self.num_classes)
        self.cond_prior = CondPrior(self.num_classes)

        if self.use_cuda:
            self.ones = self.ones.cuda()
            self.zeros = self.zeros.cuda()
            self.y_prior_params = self.y_prior_params.cuda()
            self.cuda()

    def unsup(self, x):
        bs = x.shape[0]
        # inference
        post_params = self.encoder(x)
        z = dist.Normal(*post_params).rsample()
        zc, zs = z.split([self.z_classify, self.z_style], 1)
        qyzc = dist.Bernoulli(logits=self.classifier(zc))
        y = qyzc.sample()
        log_qy = qyzc.log_prob(y).sum(dim=-1)

        # compute kl
        locs_p_zc, scales_p_zc = self.cond_prior(y)
        prior_params = (torch.cat([locs_p_zc, self.zeros.expand(bs, -1)], dim=1),
                        torch.cat([scales_p_zc, self.ones.expand(bs, -1)], dim=1))
        kl = compute_kl(*post_params, *prior_params)

        # compute log probs for x and y
        recon = self.decoder(z)
        log_py = dist.Bernoulli(self.y_prior_params.expand(bs, -1)).log_prob(y).sum(dim=-1)
        elbo = (img_log_likelihood(recon, x) + log_py - kl - log_qy).mean()
        return -elbo

    def sup(self, x, y):
        bs = x.shape[0]
        # inference
        post_params = self.encoder(x)
        z = dist.Normal(*post_params).rsample()
        zc, zs = z.split([self.z_classify, self.z_style], 1)
        qyzc = dist.Bernoulli(logits=self.classifier(zc))
        log_qyzc = qyzc.log_prob(y).sum(dim=-1)

        # compute kl
        locs_p_zc, scales_p_zc = self.cond_prior(y)
        prior_params = (torch.cat([locs_p_zc, self.zeros.expand(bs, -1)], dim=1),
                        torch.cat([scales_p_zc, self.ones.expand(bs, -1)], dim=1))
        # prior_params = (self.zeros.expand(bs, -1), self.ones.expand(bs, -1))
        kl = compute_kl(*post_params, *prior_params)

        # compute log probs for x and y
        recon = self.decoder(z)
        log_py = dist.Bernoulli(self.y_prior_params.expand(bs, -1)).log_prob(y).sum(dim=-1)
        log_qyx = self.classifier_loss(x, y)
        log_pxz = img_log_likelihood(recon, x)

        # we only want gradients wrt the params of qyz, so stop them propagating to qzx
        log_qyzc_ = dist.Bernoulli(logits=self.classifier(zc.detach())).log_prob(y).sum(dim=-1)
        w = torch.exp(log_qyzc_ - log_qyx)
        elbo = (w * (log_pxz - kl - log_qyzc) + log_py + log_qyx).mean()
        return -elbo

    def classifier_loss(self, x, y, k=100):
        """
        Computes the classifier loss.
        """
        zc, _ = dist.Normal(*self.encoder(x)).rsample(torch.tensor([k])).split([self.z_classify, self.z_style], -1)
        logits = self.classifier(zc.view(-1, self.z_classify))
        d = dist.Bernoulli(logits=logits)
        y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes)
        lqy_z = d.log_prob(y).view(k, x.shape[0], self.num_classes).sum(dim=-1)
        lqy_x = torch.logsumexp(lqy_z, dim=0) - np.log(k)
        return lqy_x

    def reconstruct_img(self, x):
        return self.decoder(dist.Normal(*self.encoder(x)).rsample())

    def classifier_acc(self, x, y=None, k=1):
        zc, _ = dist.Normal(*self.encoder(x)).rsample(torch.tensor([k])).split([self.z_classify, self.z_style], -1)
        logits = self.classifier(zc.view(-1, self.z_classify)).view(-1, self.num_classes)
        y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes)
        preds = torch.round(torch.sigmoid(logits))
        acc = (preds.eq(y)).float().mean()
        return acc

    def save_models(self, path='./data'):
        torch.save(self.encoder, os.path.join(path, 'encoder.pt'))
        torch.save(self.decoder, os.path.join(path, 'decoder.pt'))
        torch.save(self.classifier, os.path.join(path, 'classifier.pt'))
        torch.save(self.cond_prior, os.path.join(path, 'cond_prior.pt'))

    def accuracy(self, data_loader, *args, **kwargs):
        acc = 0.0
        for (x, y) in data_loader:
            if self.use_cuda:
                x, y = x.cuda(), y.cuda()
            batch_acc = self.classifier_acc(x, y)
            acc += batch_acc
        return acc / len(data_loader)

    def latent_walk(self, image, save_dir):
        """
        Does latent walk between all possible classes
        """
        mult = 5
        num_imgs = 5
        z_ = dist.Normal(*self.encoder(image.unsqueeze(0))).sample()
        for i in range(self.num_classes):
            y_1 = torch.zeros(1, self.num_classes)
            if self.use_cuda:
                y_1 = y_1.cuda()
            locs_false, scales_false = self.cond_prior(y_1)
            y_1[:, i].fill_(1.0)
            locs_true, scales_true = self.cond_prior(y_1)
            sign = torch.sign(locs_true[:, i] - locs_false[:, i])
            # y axis
            z_1_false_lim = (locs_false[:, i] - mult * sign * scales_false[:, i]).item()
            z_1_true_lim = (locs_true[:, i] + mult * sign * scales_true[:, i]).item()
            for j in range(self.num_classes):
                z = z_.clone()
                z = z.expand(num_imgs**2, -1).contiguous()
                if i == j:
                    continue
                y_2 = torch.zeros(1, self.num_classes)
                if self.use_cuda:
                    y_2 = y_2.cuda()
                locs_false, scales_false = self.cond_prior(y_2)
                y_2[:, i].fill_(1.0)
                locs_true, scales_true = self.cond_prior(y_2)
                sign = torch.sign(locs_true[:, i] - locs_false[:, i])
                # x axis
                z_2_false_lim = (locs_false[:, i] - mult * sign * scales_false[:, i]).item()
                z_2_true_lim = (locs_true[:, i] + mult * sign * scales_true[:, i]).item()

                # construct grid
                range_1 = torch.linspace(z_1_false_lim, z_1_true_lim, num_imgs)
                range_2 = torch.linspace(z_2_false_lim, z_2_true_lim, num_imgs)
                grid_1, grid_2 = torch.meshgrid(range_1, range_2)
                z[:, i] = grid_1.reshape(-1)
                z[:, j] = grid_2.reshape(-1)

                imgs = self.decoder(z).view(-1, *self.im_shape)
                grid = make_grid(imgs, nrow=num_imgs)
                save_image(grid, os.path.join(save_dir, "latent_walk_%s_and_%s.png"
                           % (CELEBA_EASY_LABELS[i], CELEBA_EASY_LABELS[j])))

        mult = 8
        for j in range(self.num_classes):
            z = z_.clone()
            z = z.expand(10, -1).contiguous()
            y = torch.zeros(1, self.num_classes)
            if self.use_cuda:
                y = y.cuda()
            locs_false, scales_false = self.cond_prior(y)
            y[:, j].fill_(1.0)
            locs_true, scales_true = self.cond_prior(y)
            sign = torch.sign(locs_true[:, j] - locs_false[:, j])
            z_false_lim = (locs_false[:, j] - mult * sign * scales_false[:, j]).item()
            z_true_lim = (locs_true[:, j] + mult * sign * scales_true[:, j]).item()
            range_ = torch.linspace(z_false_lim, z_true_lim, 10)
            z[:, j] = range_

            imgs = self.decoder(z).view(-1, *self.im_shape)
            grid = make_grid(imgs, nrow=10)
            save_image(grid, os.path.join(save_dir, "latent_walk_%s.png"
                       % CELEBA_EASY_LABELS[j]))
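As an aside, the losses above can be exercised in isolation. The following is a minimal sketch (not part of the repository) that drives `CCVAE` with random tensors standing in for a CelebA batch and a uniform label prior standing in for the empirical prior normally supplied by the data loader:

```python
# Minimal smoke test for CCVAE: random tensors stand in for a CelebA batch,
# and a uniform Bernoulli prior stands in for the empirical label prior.
import torch
from models.ccvae import CCVAE
from utils.dataset_cached import CELEBA_EASY_LABELS

num_classes = len(CELEBA_EASY_LABELS)
model = CCVAE(z_dim=45, num_classes=num_classes, im_shape=(3, 64, 64),
              use_cuda=False, prior_fn=lambda: torch.ones(1, num_classes) / 2)

x = torch.rand(8, 3, 64, 64)                        # images scaled to [0, 1]
y = torch.randint(0, 2, (8, num_classes)).float()   # binary attribute labels

loss_sup = model.sup(x, y)      # negative supervised ELBO
loss_unsup = model.unsup(x)     # negative unsupervised ELBO
(loss_sup + loss_unsup).backward()
```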
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class View(nn.Module):
    def __init__(self, size):
        super(View, self).__init__()
        self.size = size

    def forward(self, tensor):
        return tensor.view(self.size)


class CELEBAEncoder(nn.Module):
    def __init__(self, z_dim, hidden_dim=256, *args, **kwargs):
        super().__init__()
        # setup the three linear transformations used
        self.z_dim = z_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(128, hidden_dim, 4, 1),
            nn.ReLU(True),
            View((-1, hidden_dim*1*1))
        )

        self.locs = nn.Linear(hidden_dim, z_dim)
        self.scales = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        hidden = self.encoder(x)
        return self.locs(hidden), torch.clamp(F.softplus(self.scales(hidden)), min=1e-3)


class CELEBADecoder(nn.Module):
    def __init__(self, z_dim, hidden_dim=256, *args, **kwargs):
        super().__init__()
        # setup the two linear transformations used
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            View((-1, hidden_dim, 1, 1)),
            nn.ReLU(True),
            nn.ConvTranspose2d(hidden_dim, 128, 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        m = self.decoder(z)
        return m


class Diagonal(nn.Module):
    def __init__(self, dim):
        super(Diagonal, self).__init__()
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(self.dim))
        self.bias = nn.Parameter(torch.zeros(self.dim))

    def forward(self, x):
        return x * self.weight + self.bias


class Classifier(nn.Module):
    def __init__(self, dim):
        super(Classifier, self).__init__()
        self.dim = dim
        self.diag = Diagonal(self.dim)

    def forward(self, x):
        return self.diag(x)


class CondPrior(nn.Module):
    def __init__(self, dim):
        super(CondPrior, self).__init__()
        self.dim = dim
        self.diag_loc_true = nn.Parameter(torch.zeros(self.dim))
        self.diag_loc_false = nn.Parameter(torch.zeros(self.dim))
        self.diag_scale_true = nn.Parameter(torch.ones(self.dim))
        self.diag_scale_false = nn.Parameter(torch.ones(self.dim))

    def forward(self, x):
        loc = x * self.diag_loc_true + (1 - x) * self.diag_loc_false
        scale = x * self.diag_scale_true + (1 - x) * self.diag_scale_false
        return loc, torch.clamp(F.softplus(scale), min=1e-3)
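A quick shape check (illustrative only) shows how these modules fit together: the encoder maps a batch of 64×64 RGB images to a mean and a strictly positive scale of size `z_dim`, and the decoder maps a latent vector back to an image in (0, 1):

```python
# Shape sanity check for the CelebA encoder/decoder pair.
import torch
from models.networks import CELEBAEncoder, CELEBADecoder

enc, dec = CELEBAEncoder(z_dim=45), CELEBADecoder(z_dim=45)
x = torch.rand(4, 3, 64, 64)
locs, scales = enc(x)      # both (4, 45); scales kept positive via softplus
recon = dec(locs)          # (4, 3, 64, 64), squashed to (0, 1) by the sigmoid
print(locs.shape, scales.shape, recon.shape)
```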
--------------------------------------------------------------------------------
/ss_vae.py:
--------------------------------------------------------------------------------
import argparse

import torch
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from utils.dataset_cached import setup_data_loaders, CELEBA_EASY_LABELS
from utils.dataset_cached import CELEBACached
from models.ccvae import CCVAE

import numpy as np
import os


def main(args):
    """
    run inference for SS-VAE
    :param args: arguments for SS-VAE
    :return: None
    """

    im_shape = (3, 64, 64)

    data_loaders = setup_data_loaders(args.cuda,
                                      args.batch_size,
                                      cache_data=True,
                                      sup_frac=args.sup_frac,
                                      root='./data/datasets/celeba')

    cc_vae = CCVAE(z_dim=args.z_dim,
                   num_classes=len(CELEBA_EASY_LABELS),
                   im_shape=im_shape,
                   use_cuda=args.cuda,
                   prior_fn=data_loaders['test'].dataset.prior_fn)

    optim = torch.optim.Adam(params=cc_vae.parameters(), lr=args.learning_rate)

    # run inference for a certain number of epochs
    for epoch in range(0, args.num_epochs):

        # compute number of batches for an epoch
        if args.sup_frac == 1.0:  # fully supervised
            batches_per_epoch = len(data_loaders["sup"])
            period_sup_batches = 1
            sup_batches = batches_per_epoch
        elif args.sup_frac > 0.0:  # semi-supervised
            sup_batches = len(data_loaders["sup"])
            unsup_batches = len(data_loaders["unsup"])
            batches_per_epoch = sup_batches + unsup_batches
            period_sup_batches = int(batches_per_epoch / sup_batches)
        elif args.sup_frac == 0.0:  # unsupervised
            sup_batches = 0.0
            batches_per_epoch = len(data_loaders["unsup"])
            period_sup_batches = np.Inf
        else:
            assert False, "Data frac not correct"

        # initialize variables to store loss values
        epoch_losses_sup = 0.0
        epoch_losses_unsup = 0.0

        # setup the iterators for training data loaders
        if args.sup_frac != 0.0:
            sup_iter = iter(data_loaders["sup"])
        if args.sup_frac != 1.0:
            unsup_iter = iter(data_loaders["unsup"])

        # count the number of supervised batches seen in this epoch
        ctr_sup = 0

        for i in tqdm(range(batches_per_epoch)):
            # whether this batch is supervised or not
            is_supervised = (i % period_sup_batches == 0) and ctr_sup < sup_batches
            # extract the corresponding batch
            if is_supervised:
                (xs, ys) = next(sup_iter)
                ctr_sup += 1
            else:
                (xs, ys) = next(unsup_iter)

            if args.cuda:
                xs, ys = xs.cuda(), ys.cuda()

            if is_supervised:
                loss = cc_vae.sup(xs, ys)
                epoch_losses_sup += loss.detach().item()
            else:
                loss = cc_vae.unsup(xs)
                epoch_losses_unsup += loss.detach().item()

            loss.backward()
            optim.step()
            optim.zero_grad()

        if args.sup_frac != 0.0:
            with torch.no_grad():
                validation_accuracy = cc_vae.accuracy(data_loaders['valid'])
        else:
            validation_accuracy = np.nan

        with torch.no_grad():
            # save some reconstructions
            img = CELEBACached.fixed_imgs
            if args.cuda:
                img = img.cuda()
            recon = cc_vae.reconstruct_img(img).view(-1, *im_shape)
            save_image(make_grid(recon, nrow=8), './data/output/recon.png')
            save_image(make_grid(img, nrow=8), './data/output/img.png')

        print("[Epoch %03d] Sup Loss %.3f, Unsup Loss %.3f, Val Acc %.3f" %
              (epoch, epoch_losses_sup, epoch_losses_unsup, validation_accuracy))
        cc_vae.save_models(args.data_dir)

    test_acc = cc_vae.accuracy(data_loaders['test'])
    print("Test acc %.3f" % test_acc)
    cc_vae.latent_walk(img[5], './data/output')
    return


def parser_args(parser):
    parser.add_argument('--cuda', action='store_true',
                        help="use GPU(s) to speed up training")
    parser.add_argument('-n', '--num-epochs', default=200, type=int,
                        help="number of epochs to run")
    parser.add_argument('-sup', '--sup-frac', default=1.0,
                        type=float, help="supervised fractional amount of the data i.e. "
                                         "how many of the images have supervised labels. "
127 | "Should be a multiple of train_size / batch_size") 128 | parser.add_argument('-zd', '--z_dim', default=45, type=int, 129 | help="size of the tensor representing the latent variable z " 130 | "variable (handwriting style for our MNIST dataset)") 131 | parser.add_argument('-lr', '--learning-rate', default=1e-4, type=float, 132 | help="learning rate for Adam optimizer") 133 | parser.add_argument('-bs', '--batch-size', default=200, type=int, 134 | help="number of images (and labels) to be considered in a batch") 135 | parser.add_argument('--data_dir', type=str, default='./data', 136 | help='Data path') 137 | return parser 138 | 139 | if __name__ == "__main__": 140 | parser = argparse.ArgumentParser() 141 | parser = parser_args(parser) 142 | args = parser.parse_args() 143 | 144 | main(args) 145 | 146 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thwjoy/ccvae_pytorch/b4af1d3273340044b6ab20b994dc3f4e92503505/utils/__init__.py -------------------------------------------------------------------------------- /utils/dataset_cached.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import PIL 4 | from functools import reduce 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torchvision.datasets import CelebA 9 | import torchvision.transforms as transforms 10 | 11 | 12 | def split_celeba(X, y, sup_frac, validation_num): 13 | """ 14 | splits celeba 15 | """ 16 | 17 | # validation set is the last 10,000 examples 18 | X_valid = X[-validation_num:] 19 | y_valid = y[-validation_num:] 20 | 21 | X = X[0:-validation_num] 22 | y = y[0:-validation_num] 23 | 24 | if sup_frac == 0.0: 25 | return None, None, X, y, X_valid, y_valid 26 | 27 | if sup_frac == 1.0: 28 | return X, y, None, None, X_valid, y_valid 29 | 30 | split = int(sup_frac * len(X)) 31 | X_sup = X[0:split] 32 | y_sup = y[0:split] 33 | X_unsup = X[split:] 34 | y_unsup = y[split:] 35 | 36 | return X_sup, y_sup, X_unsup, y_unsup, X_valid, y_valid 37 | 38 | 39 | CELEBA_LABELS = ['5_o_Clock_Shadow', 'Arched_Eyebrows','Attractive','Bags_Under_Eyes','Bald','Bangs','Big_Lips','Big_Nose','Black_Hair','Blond_Hair','Blurry','Brown_Hair','Bushy_Eyebrows', \ 40 | 'Chubby', 'Double_Chin','Eyeglasses','Goatee','Gray_Hair','Heavy_Makeup','High_Cheekbones','Male','Mouth_Slightly_Open','Mustache','Narrow_Eyes', 'No_Beard', 'Oval_Face', \ 41 | 'Pale_Skin','Pointy_Nose','Receding_Hairline','Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', \ 42 | 'Wearing_Necklace', 'Wearing_Necktie', 'Young'] 43 | 44 | CELEBA_EASY_LABELS = ['Arched_Eyebrows', 'Bags_Under_Eyes', 'Bangs', 'Black_Hair', 'Blond_Hair','Brown_Hair','Bushy_Eyebrows', 'Chubby','Eyeglasses', 'Heavy_Makeup', 'Male', \ 45 | 'No_Beard', 'Pale_Skin', 'Receding_Hairline', 'Smiling', 'Wavy_Hair', 'Wearing_Necktie', 'Young'] 46 | 47 | 48 | class CELEBACached(CelebA): 49 | """ 50 | a wrapper around CelebA to load and cache the transformed data 51 | once at the beginning of the inference 52 | """ 53 | # static class variables for caching training data 54 | train_data_sup, train_labels_sup = None, None 55 | train_data_unsup, train_labels_unsup = None, None 56 | train_data, test_labels = None, None 57 | prior = torch.ones(1, len(CELEBA_EASY_LABELS)) / 2 58 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thwjoy/ccvae_pytorch/b4af1d3273340044b6ab20b994dc3f4e92503505/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataset_cached.py:
--------------------------------------------------------------------------------
import errno
import os
import PIL
from functools import reduce
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CelebA
import torchvision.transforms as transforms


def split_celeba(X, y, sup_frac, validation_num):
    """
    splits celeba
    """

    # validation set is the last `validation_num` examples
    X_valid = X[-validation_num:]
    y_valid = y[-validation_num:]

    X = X[0:-validation_num]
    y = y[0:-validation_num]

    if sup_frac == 0.0:
        return None, None, X, y, X_valid, y_valid

    if sup_frac == 1.0:
        return X, y, None, None, X_valid, y_valid

    split = int(sup_frac * len(X))
    X_sup = X[0:split]
    y_sup = y[0:split]
    X_unsup = X[split:]
    y_unsup = y[split:]

    return X_sup, y_sup, X_unsup, y_unsup, X_valid, y_valid


CELEBA_LABELS = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows',
                 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face',
                 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
                 'Wearing_Necklace', 'Wearing_Necktie', 'Young']

CELEBA_EASY_LABELS = ['Arched_Eyebrows', 'Bags_Under_Eyes', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Eyeglasses', 'Heavy_Makeup', 'Male',
                      'No_Beard', 'Pale_Skin', 'Receding_Hairline', 'Smiling', 'Wavy_Hair', 'Wearing_Necktie', 'Young']


class CELEBACached(CelebA):
    """
    a wrapper around CelebA to load and cache the transformed data
    once at the beginning of the inference
    """
    # static class variables for caching training data
    train_data_sup, train_labels_sup = None, None
    train_data_unsup, train_labels_unsup = None, None
    train_data, test_labels = None, None
    prior = torch.ones(1, len(CELEBA_EASY_LABELS)) / 2
    fixed_imgs = None
    validation_size = 20000
    data_valid, labels_valid = None, None

    def prior_fn(self):
        return CELEBACached.prior

    def clear_cache():
        CELEBACached.train_data, CELEBACached.test_labels = None, None

    def __init__(self, mode, sup_frac=None, *args, **kwargs):
        super(CELEBACached, self).__init__(split='train' if mode in ["sup", "unsup", "valid"] else 'test', *args, **kwargs)
        self.sub_label_inds = [i for i in range(len(CELEBA_LABELS)) if CELEBA_LABELS[i] in CELEBA_EASY_LABELS]
        self.mode = mode
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ])

        assert mode in ["sup", "unsup", "test", "valid"], "invalid train/test option values"

        if mode in ["sup", "unsup", "valid"]:

            if CELEBACached.train_data is None:
                print("Splitting Dataset")

                CELEBACached.train_data = self.filename
                CELEBACached.train_targets = self.attr

                CELEBACached.train_data_sup, CELEBACached.train_labels_sup, \
                    CELEBACached.train_data_unsup, CELEBACached.train_labels_unsup, \
                    CELEBACached.data_valid, CELEBACached.labels_valid = \
                    split_celeba(CELEBACached.train_data, CELEBACached.train_targets,
                                 sup_frac, CELEBACached.validation_size)

            if mode == "sup":
                self.data, self.targets = CELEBACached.train_data_sup, CELEBACached.train_labels_sup
                CELEBACached.prior = torch.mean(self.targets[:, self.sub_label_inds].float(), dim=0)
            elif mode == "unsup":
                self.data = CELEBACached.train_data_unsup
                # making sure that the unsupervised labels are not available to inference
                self.targets = CELEBACached.train_labels_unsup * np.nan
            else:
                self.data, self.targets = CELEBACached.data_valid, CELEBACached.labels_valid

        else:
            self.data = self.filename
            self.targets = self.attr

        # create a batch of fixed images
        if CELEBACached.fixed_imgs is None:
            temp = []
            for i, f in enumerate(self.data[:64]):
                temp.append(self.transform(PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", f))))
            CELEBACached.fixed_imgs = torch.stack(temp, dim=0)

    def __getitem__(self, index):

        X = self.transform(PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.data[index])))

        target = self.targets[index].float()
        target = target[self.sub_label_inds]

        return X, target

    def __len__(self):
        return len(self.data)


def setup_data_loaders(use_cuda, batch_size, sup_frac=1.0, root=None, cache_data=False, **kwargs):
    """
    helper function for setting up pytorch data loaders for a semi-supervised dataset
    :param use_cuda: use GPU(s) for training
    :param batch_size: size of a batch of data to output when iterating over the data loaders
    :param sup_frac: fraction of supervised data examples
    :param cache_data: saves dataset to memory, prevents reading from file every time
    :param kwargs: other params for the pytorch data loader
    :return: dict of data loaders, keyed by mode ("sup", "unsup", "valid", "test")
    """

    if root is None:
        root = get_data_directory(__file__)
    if 'num_workers' not in kwargs:
        kwargs = {'num_workers': 4, 'pin_memory': True}
    cached_data = {}
    loaders = {}

    # clear previous cache
    CELEBACached.clear_cache()

    if sup_frac == 0.0:
        modes = ["unsup", "test"]
    elif sup_frac == 1.0:
        modes = ["sup", "test", "valid"]
    else:
        modes = ["unsup", "test", "sup", "valid"]

    for mode in modes:
        cached_data[mode] = CELEBACached(root=root, mode=mode, download=True, sup_frac=sup_frac)
        loaders[mode] = DataLoader(cached_data[mode], batch_size=batch_size, shuffle=True, **kwargs)
    return loaders

--------------------------------------------------------------------------------
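Putting the pieces together, a typical way to obtain the loaders outside of `ss_vae.py` looks like the sketch below (the 0.2 supervision fraction and the path are example values, and CelebA is assumed to already be available under `root`):

```python
# Example wiring of the cached CelebA loaders (values are illustrative).
from utils.dataset_cached import setup_data_loaders, CELEBA_EASY_LABELS

loaders = setup_data_loaders(use_cuda=False, batch_size=200, sup_frac=0.2,
                             root='./data/datasets/celeba', cache_data=True)
xs, ys = next(iter(loaders["sup"]))
print(xs.shape)                  # torch.Size([200, 3, 64, 64])
print(ys.shape)                  # torch.Size([200, 18]) -- the CELEBA_EASY_LABELS subset
print(len(CELEBA_EASY_LABELS))   # 18
```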