├── .gitignore ├── README.md ├── data.zip ├── datasets ├── __init__.py ├── mnist.py ├── sbmnist.py └── toy.py ├── ivae_ardae.py ├── models ├── __init__.py ├── aux.py ├── dae │ └── mlp.py ├── graddae │ └── mlp.py ├── ivae │ ├── auxconv.py │ ├── auxmnist.py │ ├── auxresconv.py │ ├── auxresconv2.py │ ├── auxtoy.py │ ├── conv.py │ ├── mnist.py │ ├── resconv.py │ └── toy.py ├── layers.py ├── layers2.py ├── reparam.py ├── resdae │ └── mlp.py └── vae │ ├── auxconv.py │ ├── auxmnist.py │ ├── auxresconv.py │ ├── auxtoy.py │ ├── conv.py │ ├── mnist.py │ ├── resconv.py │ └── toy.py ├── notebooks ├── ardae_fit.ipynb ├── ardae_toy.ipynb └── dae_toy.ipynb ├── run_vae_25gaussians.sh ├── run_vae_dbmnist.sh ├── run_vae_sbmnist.sh ├── utils ├── __init__.py ├── distributions.py ├── energy.py ├── jacobian_clamping.py ├── lr_scheduler.py ├── models.py ├── msc.py ├── optim.py ├── sample.py ├── stat.py ├── vae.py └── visualization.py └── vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.sw* 3 | *.pdf 4 | *.DS_Store 5 | *.out 6 | *.png 7 | *.pt 8 | 9 | data 10 | experiments 11 | experiments* 12 | cache 13 | cache* 14 | 15 | contrib 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AR-DAE: Towards Unbiased Neural Entropy Gradient Estimation 2 | PyTorch implementation of AR-DAE from our paper: 3 | > Jae Hyun Lim, Aaron Courville, Christopher Pal, Chin-Wei Huang, [*AR-DAE: Towards Unbiased Neural Entropy Gradient Estimation*](https://arxiv.org/abs/2006.05164) (2020) 4 | 5 | ## Toy example of AR-DAE 6 | Example code to train AR-DAE on the swiss roll dataset: 7 | [ipython-notebook](https://github.com/lim0606/pytorch-ardae-vae/tree/master/notebooks/ardae_toy.ipynb) 8 | 9 | ## Energy function fitting with AR-DAE 10 | Example code to train an implicit sampler using the AR-DAE-based entropy gradient estimator: 11 | [ipython-notebook](https://github.com/lim0606/pytorch-ardae-vae/tree/master/notebooks/ardae_fit.ipynb) 12 | 13 | ## AR-DAE VAE 14 | ### Getting Started 15 | 16 | #### Requirements 17 | `python>=3.6` 18 | `pytorch==1.2.0` 19 | `tensorflow` (for tensorboardX) 20 | `tensorboardX` 21 | `git+https://github.com/lim0606/contrib.git` 22 | 23 | #### Dataset 24 | ```sh 25 | # http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist 26 | unzip data.zip -d .
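# note: the loaders in `datasets/` expect the following layout under ./data
# (inferred from datasets/sbmnist.py, datasets/mnist.py, and datasets/toy.py; each
#  loader downloads or regenerates its files automatically if they are missing):
#   data/bmnist/binarized_mnist_{train,valid,test}.amat   # statically binarized MNIST (Larochelle, 2011)
#   data/MNIST/                                           # torchvision MNIST cache (mnist / cmnist / dbmnist)
#   data/toy/<name>.pt                                    # cached toy datasets (swissroll, 25gaussians, ...)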
27 | ``` 28 | #### Structure 29 | - `data`: data folder 30 | - `datasets`: dataloader definitions 31 | - `models`: model definitions 32 | - `utils`: miscellaneous functions 33 | - `ivae_ardae.py`: main script to train the AR-DAE VAE model 34 | - `vae.py`: main script to train the baseline VAEs 35 | 36 | ### Experiments 37 | #### Train 38 | - For example, you can train a VAE with an implicit posterior on a mixture of 25 Gaussians as follows: 39 | ```sh 40 | python ivae_ardae.py \ 41 | --cache experiments/25gaussians \ 42 | --dataset 25gaussians --nheight 1 --nchannels 2 \ 43 | --model mlp-concat --model-z-dim 2 --model-h-dim 256 --model-n-layers 2 --model-nonlin relu --model-n-dim 10 --model-clip-z0-logvar none --model-clip-z-logvar none \ 44 | --cdae mlp-grad --cdae-h-dim 256 --cdae-n-layers 3 --cdae-nonlin softplus --cdae-ctx-type lt0 \ 45 | --train-batch-size 512 --eval-batch-size 1 --train-nz-cdae 256 --train-nz-model 1 \ 46 | --delta 0.1 --std-scale 10000 --num-cdae-updates 1 \ 47 | --m-lr 0.0001 --m-optimizer adam --m-momentum 0.5 --m-beta1 0.5 \ 48 | --d-lr 0.0001 --d-optimizer rmsprop --d-momentum 0.5 --d-beta1 0.5 \ 49 | --epochs 16 \ 50 | --eval-iws-interval 0 --iws-samples 64 --log-interval 100 --vis-interval 100 --ckpt-interval 1000 --exp-num 1 51 | ``` 52 | For more options, see the example scripts `run_vae_25gaussians.sh`, `run_vae_dbmnist.sh`, and `run_vae_sbmnist.sh`. 53 | 54 | ## SAC-AR-DAE 55 | Please find the code at https://github.com/lim0606/pytorch-ardae-rl 56 | 57 | ## Contact 58 | For questions and comments, feel free to contact [Jae Hyun Lim](mailto:jae.hyun.lim@umontreal.ca) and [Chin-Wei Huang](mailto:chin-wei.huang@umontreal.ca). 59 | 60 | ## License 61 | MIT License 62 | 63 | ## Reference 64 | ``` 65 | @article{jaehyun2020ardae, 66 | title={{AR-DAE}: Towards Unbiased Neural Entropy Gradient Estimation}, 67 | author={Jae Hyun Lim and 68 | Aaron Courville and 69 | Christopher J.
Pal and 70 | Chin-Wei Huang}, 71 | journal={arXiv preprint arXiv:2006.05164}, 72 | year={2020} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /data.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim0606/pytorch-ardae-vae/52f460a90fa5822692031ab7dcca39fa9168988e/data.zip -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.toy import get_toy_dataset 2 | from datasets.mnist import get_image_dataset as get_mnist_dataset 3 | #from datasets.image import get_image_dataset 4 | 5 | def get_dataset(dataset, train_batch_size, eval_batch_size=None, cuda=False, final_mode=False): 6 | if dataset in ['swissroll', '25gaussians']: 7 | assert final_mode == False 8 | return get_toy_dataset(dataset, train_batch_size, eval_batch_size, cuda) 9 | elif dataset in ['mnist', 'sbmnist', 'dbmnist', 'dbmnist-val5k', 'cmnist',]: 10 | return get_mnist_dataset(dataset, train_batch_size, eval_batch_size, cuda, final_mode=final_mode) 11 | #elif dataset in ['celeba',]: 12 | # assert final_mode == False 13 | # return get_image_dataset(dataset, train_batch_size, eval_batch_size, cuda) 14 | else: 15 | raise NotImplementedError('dataset: {}'.format(dataset)) 16 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torchvision import datasets, transforms 5 | from datasets.sbmnist import load_sbmnist_image 6 | 7 | 8 | class StackedMNIST(datasets.MNIST): 9 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False): 10 | super().__init__(root=root, train=train, transform=transform, target_transform=target_transform, download=download) 11 | 12 | def __getitem__(self, index): 13 | # get indices 14 | mnist_size = self.__len__() 15 | indices = np.random.randint(mnist_size, size=2) 16 | index1 = indices[0] 17 | index2 = indices[1] 18 | index3 = index 19 | 20 | # get item 21 | img1, target1 = super().__getitem__(index1) 22 | img2, target2 = super().__getitem__(index2) 23 | img3, target3 = super().__getitem__(index3) 24 | target = 100*target1 + 10*target2 + 1*target3 25 | img = torch.cat([img1, img2, img3], dim=0) 26 | return img, target 27 | 28 | def get_mnist_transform(image_size=28, binary=False, center=False): 29 | # resize image 30 | if image_size != 28: 31 | trsfms = [transforms.Resize(image_size)] 32 | else: 33 | trsfms = [] 34 | 35 | # default 36 | trsfms += [transforms.ToTensor()] 37 | 38 | # binary 39 | if binary: 40 | trsfms += [torch.bernoulli] 41 | 42 | # center 43 | if center: 44 | trsfms += [transforms.Normalize((0.5,), (0.5,))] 45 | 46 | # return 47 | return transforms.Compose(trsfms) 48 | 49 | def get_mnist(train_batch_size, eval_batch_size, dataset, kwargs, binary=False, center=False, image_size=28, val_size=10000, final_mode=False): 50 | assert dataset in ['mnist', 'cmnist', 'dbmnist', 'dbmnist-val5k',] 51 | if dataset in ['mnist', 'cmnist', 'dbmnist', 'dbmnist-val5k']: 52 | DATASET = datasets.MNIST 53 | nclasses = 10 54 | else: 55 | raise NotImplementedError 56 | 57 | # init dataset (train / val) 58 | train_dataset = DATASET('data', train=True, download=True, transform=get_mnist_transform(binary=binary, center=center,
image_size=image_size)) 59 | val_dataset = DATASET('data', train=True, download=True, transform=get_mnist_transform(binary=binary, center=center, image_size=image_size)) if not final_mode else None 60 | 61 | # final mode 62 | if not final_mode: 63 | n = len(train_dataset.data) 64 | split_filename = os.path.join('data/MNIST', '{}-val{}-split.pt'.format(dataset, val_size)) 65 | if os.path.exists(split_filename): 66 | indices = torch.load(split_filename) 67 | else: 68 | indices = torch.from_numpy(np.random.permutation(n)) 69 | torch.save(indices, open(split_filename, 'wb')) 70 | train_dataset.data = torch.index_select(train_dataset.data, 0, indices[:n-val_size]) 71 | train_dataset.targets = torch.index_select(train_dataset.targets, 0, indices[:n-val_size]) 72 | val_dataset.data = torch.index_select(val_dataset.data, 0, indices[n-val_size:]) 73 | val_dataset.targets = torch.index_select(val_dataset.targets, 0, indices[n-val_size:]) 74 | else: 75 | pass 76 | 77 | # init dataset test 78 | test_dataset = DATASET('data', train=False, transform=get_mnist_transform(binary=binary, center=center, image_size=image_size)) 79 | 80 | # init dataloader 81 | train_loader = torch.utils.data.DataLoader(train_dataset, 82 | batch_size=train_batch_size, shuffle=True, **kwargs) 83 | val_loader = torch.utils.data.DataLoader(val_dataset, 84 | batch_size=eval_batch_size, shuffle=False, **kwargs) if not final_mode else None 85 | test_loader = torch.utils.data.DataLoader(test_dataset, 86 | batch_size=eval_batch_size, shuffle=False, **kwargs) 87 | 88 | # init info 89 | info = {} 90 | info['nclasses'] = nclasses 91 | 92 | return train_loader, val_loader, test_loader, info 93 | 94 | def get_sbmnist(train_batch_size, eval_batch_size, dataset, kwargs, final_mode=False): 95 | # get bmnist 96 | train_data, val_data, test_data = load_sbmnist_image('data') 97 | train_labels, val_labels, test_labels = torch.zeros(50000).long(), torch.zeros(10000).long(), torch.zeros(10000).long() 98 | 99 | # final mode 100 | if final_mode: 101 | train_data = torch.cat([train_data, val_data], dim=0) 102 | train_labels = torch.cat([train_labels, val_labels], dim=0) 103 | val_data = None 104 | val_labels = None 105 | 106 | # init datasets 107 | train_dataset = torch.utils.data.TensorDataset(train_data, train_labels) 108 | val_dataset = torch.utils.data.TensorDataset(val_data, val_labels) if not final_mode else None 109 | test_dataset = torch.utils.data.TensorDataset(test_data, test_labels) 110 | 111 | # init dataloader 112 | train_loader = torch.utils.data.DataLoader(train_dataset, 113 | batch_size=train_batch_size, shuffle=True, **kwargs) 114 | val_loader = torch.utils.data.DataLoader(val_dataset, 115 | batch_size=eval_batch_size, shuffle=False, **kwargs) if not final_mode else None 116 | test_loader = torch.utils.data.DataLoader(test_dataset, 117 | batch_size=eval_batch_size, shuffle=False, **kwargs) 118 | 119 | # init info 120 | info = {} 121 | info['nclasses'] = 10 122 | 123 | return train_loader, val_loader, test_loader, info 124 | 125 | def get_image_dataset(dataset, train_batch_size, eval_batch_size=None, cuda=False, final_mode=False): 126 | # init arguments 127 | if eval_batch_size is None: 128 | eval_batch_size = train_batch_size 129 | kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {} 130 | 131 | # get dataset 132 | if dataset in ['mnist']: 133 | return get_mnist(train_batch_size, eval_batch_size, dataset, kwargs, image_size=28, binary=False, center=False, final_mode=final_mode) 134 | elif dataset in ['cmnist']: 135 | return 
get_mnist(train_batch_size, eval_batch_size, dataset, kwargs, image_size=28, binary=False, center=True, final_mode=final_mode) 136 | elif dataset in ['dbmnist']: 137 | return get_mnist(train_batch_size, eval_batch_size, dataset, kwargs, image_size=28, binary=True, center=False, final_mode=final_mode) 138 | elif dataset in ['dbmnist-val5k']: 139 | return get_mnist(train_batch_size, eval_batch_size, dataset, kwargs, image_size=28, binary=True, center=False, val_size=5000, final_mode=final_mode) 140 | elif dataset in ['sbmnist']: 141 | return get_sbmnist(train_batch_size, eval_batch_size, dataset, kwargs, final_mode=final_mode) 142 | elif dataset in ['mnist32']: 143 | return get_mnist(train_batch_size, eval_batch_size, dataset, kwargs, image_size=32, binary=False, center=False, final_mode=final_mode) 144 | else: 145 | raise NotImplementedError('dataset: {}'.format(dataset)) 146 | -------------------------------------------------------------------------------- /datasets/sbmnist.py: -------------------------------------------------------------------------------- 1 | ''' 2 | copied and modified from 3 | 1) https://github.com/CW-Huang/torchkit/blob/8da6c100c48a1d1464765928fbf68fdfd99fff8a/torchkit/datasets.py 4 | 2) https://github.com/CW-Huang/torchkit/blob/8da6c100c48a1d1464765928fbf68fdfd99fff8a/torchkit/downloader.py 5 | ''' 6 | import os 7 | import urllib.request 8 | import numpy as np 9 | import torch 10 | floatX = 'float32' 11 | 12 | 13 | def create(*args): 14 | path = '/'.join(a for a in args) 15 | if not os.path.isdir(path): 16 | os.makedirs(path) 17 | 18 | def download_sbmnist(savedir): 19 | #print 'dynamically binarized mnist' 20 | #mnist_filenames = ['train-images-idx3-ubyte', 't10k-images-idx3-ubyte'] 21 | 22 | #for filename in mnist_filenames: 23 | # local_filename = os.path.join(savedir, filename) 24 | # urllib.request.urlretrieve( 25 | # "http://yann.lecun.com/exdb/mnist/{}.gz".format( 26 | # filename),local_filename+'.gz') 27 | # with gzip.open(local_filename+'.gz', 'rb') as f: 28 | # file_content = f.read() 29 | # with open(local_filename, 'wb') as f: 30 | # f.write(file_content) 31 | # np.savetxt(local_filename,load_mnist_images_np(local_filename)) 32 | # os.remove(local_filename+'.gz') 33 | 34 | print('download statically binarized mnist') 35 | subdatasets = ['train', 'valid', 'test'] 36 | for subdataset in subdatasets: 37 | filename = 'binarized_mnist_{}.amat'.format(subdataset) 38 | url = 'http://www.cs.toronto.edu/~larocheh/'\ 39 | 'public/datasets/binarized_mnist/'\ 40 | 'binarized_mnist_{}.amat'.format(subdataset) 41 | local_filename = os.path.join(savedir, filename) 42 | urllib.request.urlretrieve(url, local_filename) 43 | 44 | def load_sbmnist_image(root='data'): 45 | create(root, 'bmnist') 46 | droot = root+'/'+'bmnist' 47 | 48 | if not os.path.exists('{}/binarized_mnist_train.amat'.format(droot)): 49 | # download sbmnist 50 | download_sbmnist(droot) 51 | 52 | if not os.path.exists('{}/binarized_mnist_train.pt'.format(droot)): 53 | # Larochelle 2011 54 | path_tr = '{}/binarized_mnist_train.amat'.format(droot) 55 | path_va = '{}/binarized_mnist_valid.amat'.format(droot) 56 | path_te = '{}/binarized_mnist_test.amat'.format(droot) 57 | train_x = np.loadtxt(path_tr).astype(floatX).reshape(50000,784) 58 | valid_x = np.loadtxt(path_va).astype(floatX).reshape(10000,784) 59 | test_x = np.loadtxt(path_te).astype(floatX).reshape(10000,784) 60 | 61 | # save 62 | path_tr = '{}/binarized_mnist_train.pt'.format(droot) 63 | path_va = '{}/binarized_mnist_valid.pt'.format(droot) 64 | 
path_te = '{}/binarized_mnist_test.pt'.format(droot) 65 | train_x_pt = torch.from_numpy(train_x) 66 | valid_x_pt = torch.from_numpy(valid_x) 67 | test_x_pt = torch.from_numpy(test_x) 68 | torch.save(train_x_pt, open(path_tr, 'wb')) 69 | torch.save(valid_x_pt, open(path_va, 'wb')) 70 | torch.save(test_x_pt, open(path_te, 'wb')) 71 | 72 | else: 73 | path_tr = '{}/binarized_mnist_train.pt'.format(droot) 74 | path_va = '{}/binarized_mnist_valid.pt'.format(droot) 75 | path_te = '{}/binarized_mnist_test.pt'.format(droot) 76 | train_x_pt = torch.load(path_tr) 77 | valid_x_pt = torch.load(path_va) 78 | test_x_pt = torch.load(path_te) 79 | 80 | return train_x_pt, valid_x_pt, test_x_pt 81 | -------------------------------------------------------------------------------- /datasets/toy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | 5 | import numpy as np 6 | import sklearn.datasets 7 | from scipy.stats import multivariate_normal 8 | 9 | import torch 10 | import torch.utils.data 11 | 12 | 13 | def _normal_energy_func(x, mu=0., logvar=0.): 14 | energy = logvar + (x - mu)**2 / math.exp(logvar) + math.log(2.*math.pi) 15 | energy = 0.5 * energy 16 | return energy 17 | 18 | def normal_energy_func(x, mu, logvar, ndim): 19 | ''' 20 | x = b x ndim 21 | mu = 1 x ndim vector 22 | logvar = scalar 23 | ndim = scalar 24 | ''' 25 | assert x.dim() == 2 26 | assert x.size(1) == ndim 27 | batch_size = x.size(0) 28 | x = x.view(batch_size, -1) 29 | mu = mu.view(1, ndim).contiguous() 30 | 31 | energy = torch.sum(_normal_energy_func(x, mu.expand_as(x), logvar), dim=1, keepdim=True) 32 | return energy 33 | 34 | def normal_logprob(x, mu, std, ndim): 35 | ''' 36 | Inputs:⋅ 37 | x: b1 x ndim⋅ 38 | mu: 1 x ndim 39 | logvar: scalar 40 | Outputs: 41 | prob: b1 x 1 42 | ''' 43 | var = std**2 44 | logvar = math.log(var) 45 | logprob = - normal_energy_func(x, mu, logvar, ndim) 46 | return logprob 47 | 48 | def normal_prob(x, mu, std, ndim): 49 | logprob = normal_logprob(x, mu, std, ndim) 50 | prob = torch.exp(logprob) 51 | return prob 52 | 53 | 54 | # swissroll 55 | def get_swissroll(num_data): 56 | ''' 57 | copied and modified from https://github.com/caogang/wgan-gp/blob/ae47a185ed2e938c39cf3eb2f06b32dc1b6a2064/gan_toy.py#L153 58 | ''' 59 | # data 60 | data = sklearn.datasets.make_swiss_roll( 61 | n_samples=num_data, 62 | noise=0.75, #0.25 63 | ) 64 | data = data[0] 65 | data = data.astype('float32')[:, [0, 2]] 66 | data /= 3. 
#/= 7.5 # stdev plus a little 67 | 68 | # target 69 | target = np.zeros(num_data) 70 | 71 | # convert to torch tensor 72 | data = torch.from_numpy(data) 73 | target = torch.from_numpy(target) 74 | 75 | return data, target, None, None 76 | 77 | # exp1: single gaussian 78 | def exp1(num_data=1000): 79 | 80 | var = 1.0 81 | std = var**(0.5) 82 | N = 1 83 | if num_data % N != 0: 84 | raise ValueError('num_data should be multiple of {} (num_data = {})'.format(N, num_data)) 85 | 86 | # data and label 87 | mu = np.array([[0, 0]], dtype=np.float) 88 | mu = torch.from_numpy(mu) 89 | num_data_per_mixture = num_data // N 90 | sigma = math.sqrt(var) 91 | x = torch.zeros(num_data, 2) 92 | label = torch.LongTensor(num_data).zero_() 93 | for i in range(N): 94 | xx = x[i*num_data_per_mixture:(i+1)*num_data_per_mixture, :] 95 | xx.copy_(torch.cat( 96 | (torch.FloatTensor(num_data_per_mixture).normal_(mu[i,0], sigma).view(num_data_per_mixture, 1), 97 | torch.FloatTensor(num_data_per_mixture).normal_(mu[i,1], sigma).view(num_data_per_mixture, 1)), 1)) 98 | label[i*num_data_per_mixture:(i+1)*num_data_per_mixture] = i 99 | 100 | # classifier 101 | def classifier(x): 102 | numer = [(1.0 / float(N)) * normal_prob(x, mu[i, :], std, 2) for i in range(N)] 103 | numer = torch.cat(numer, dim=1) 104 | denom = torch.sum(numer, dim=1, keepdim=True) 105 | prob = numer / (denom + 1e-10) 106 | pred = prob.argmax(dim=1, keepdim=True) 107 | return pred 108 | 109 | # pdf 110 | def pdf(x): 111 | prob = 0 112 | for i in range(N): 113 | _mu = mu[i, :] 114 | prob += (1.0 / float(N)) * normal_prob(x, _mu, std, 2) 115 | return prob 116 | 117 | def logpdf(x): 118 | return torch.log((pdf(x) + 1e-10)) 119 | 120 | #return x, label, sumloglikelihood 121 | #return x, label, logpdf, classifier#, sumloglikelihood 122 | return x, label, None, None#, sumloglikelihood 123 | 124 | # mixture of three Gaussians 125 | def exp3(num_data=1000): 126 | if num_data % 4 != 0: 127 | raise ValueError('num_data should be multiple of 4. 
num_data = {}'.format(num_data)) 128 | 129 | center = 2 130 | sigma = 0.5 #math.sqrt(3) 131 | num_of_modes = 3 132 | 133 | # init data 134 | d1x = torch.FloatTensor(num_data//num_of_modes, 1) #//4, 1) 135 | d1y = torch.FloatTensor(num_data//num_of_modes, 1) #//4, 1) 136 | d1x.normal_(center, sigma * 1) #3) 137 | d1y.normal_(center, sigma * 1) 138 | 139 | #d2x = torch.FloatTensor(num_data//4, 1) 140 | #d2y = torch.FloatTensor(num_data//4, 1) 141 | #d2x.normal_(-center, sigma * 1) 142 | #d2y.normal_(center, sigma * 1) #3) 143 | 144 | d3x = torch.FloatTensor(num_data//num_of_modes, 1) 145 | d3y = torch.FloatTensor(num_data//num_of_modes, 1) 146 | d3x.normal_(center, sigma * 1) #3) 147 | d3y.normal_(-center, sigma * 1) #2) 148 | 149 | d4x = torch.FloatTensor(num_data//num_of_modes, 1) #//4, 1) 150 | d4y = torch.FloatTensor(num_data//num_of_modes, 1) #//4, 1) 151 | d4x.normal_(-center, sigma * 1) #2) 152 | d4y.normal_(-center, sigma * 1) #2) 153 | 154 | d1 = torch.cat((d1x, d1y), 1) 155 | #d2 = torch.cat((d2x, d2y), 1) 156 | d3 = torch.cat((d3x, d3y), 1) 157 | d4 = torch.cat((d4x, d4y), 1) 158 | 159 | #d = torch.cat((d1, d2, d3, d4), 0) 160 | d = torch.cat((d1, d3, d4), 0) 161 | 162 | # label 163 | label = torch.LongTensor((num_data//num_of_modes)*num_of_modes).zero_() 164 | #for i in range(4): 165 | # label[i*(num_data//4):(i+1)*(num_data//4)] = i 166 | for i in range(num_of_modes): 167 | label[i*(num_data//num_of_modes):(i+1)*(num_data//num_of_modes)] = i 168 | 169 | # shuffle 170 | shuffle = torch.randperm(d.size()[0]) 171 | d = torch.index_select(d, 0, shuffle) 172 | label = torch.index_select(label, 0, shuffle) 173 | 174 | # pdf 175 | #rv1 = multivariate_normal([ center, center], [[math.pow(sigma * 3, 2), 0.0], [0.0, math.pow(sigma * 1, 2)]]) 176 | #rv2 = multivariate_normal([-center, center], [[math.pow(sigma * 1, 2), 0.0], [0.0, math.pow(sigma * 3, 2)]]) 177 | #rv3 = multivariate_normal([ center, -center], [[math.pow(sigma * 3, 2), 0.0], [0.0, math.pow(sigma * 2, 2)]]) 178 | #rv4 = multivariate_normal([-center, -center], [[math.pow(sigma * 2, 2), 0.0], [0.0, math.pow(sigma * 2, 2)]]) 179 | rv1 = multivariate_normal([ center, center], [[math.pow(sigma * 1, 2), 0.0], [0.0, math.pow(sigma * 1, 2)]]) 180 | rv3 = multivariate_normal([ center, -center], [[math.pow(sigma * 1, 2), 0.0], [0.0, math.pow(sigma * 1, 2)]]) 181 | rv4 = multivariate_normal([-center, -center], [[math.pow(sigma * 1, 2), 0.0], [0.0, math.pow(sigma * 1, 2)]]) 182 | 183 | def pdf(x): 184 | #prob = 0.25 * rv1.pdf(x) + 0.25 * rv2.pdf(x) + 0.25 * rv3.pdf(x) + 0.25 * rv4.pdf(x) 185 | prob = 1./float(num_of_modes) * (rv1.pdf(x) + rv3.pdf(x) + rv4.pdf(x)) 186 | return prob 187 | 188 | def sumloglikelihood(x): 189 | return np.sum(np.log((pdf(x) + 1e-10))) 190 | 191 | #return d, label, sumloglikelihood 192 | return d, label, None, None 193 | 194 | # exp4: grid shapes 195 | def exp4(num_data=1000): 196 | 197 | var = 0.1 198 | std = var**(0.5) 199 | max_x = 4 #21 200 | max_y = 4 #21 201 | min_x = -max_x 202 | min_y = -max_y 203 | n = 5 204 | 205 | # init 206 | nx, ny = (n, n) 207 | x = np.linspace(min_x, max_x, nx) 208 | y = np.linspace(min_y, max_y, ny) 209 | xv, yv = np.meshgrid(x, y) 210 | N = xv.size 211 | if num_data % N != 0: 212 | raise ValueError('num_data should be multiple of {} (num_data = {})'.format(N, num_data)) 213 | 214 | # data and label 215 | mu = np.concatenate((xv.reshape(N,1), yv.reshape(N,1)), axis=1) 216 | mu = torch.FloatTensor(mu) 217 | num_data_per_mixture = num_data // N 218 | sigma = math.sqrt(var) 219 | x 
= torch.zeros(num_data, 2) 220 | label = torch.LongTensor(num_data).zero_() 221 | for i in range(N): 222 | xx = x[i*num_data_per_mixture:(i+1)*num_data_per_mixture, :] 223 | xx.copy_(torch.cat( 224 | (torch.FloatTensor(num_data_per_mixture).normal_(mu[i,0], sigma).view(num_data_per_mixture, 1), 225 | torch.FloatTensor(num_data_per_mixture).normal_(mu[i,1], sigma).view(num_data_per_mixture, 1)), 1)) 226 | label[i*num_data_per_mixture:(i+1)*num_data_per_mixture] = i 227 | 228 | # classifier 229 | def classifier(x): 230 | numer = [(1.0 / float(N)) * normal_prob(x, mu[i, :], std, 2) for i in range(N)] 231 | numer = torch.cat(numer, dim=1) 232 | denom = torch.sum(numer, dim=1, keepdim=True) 233 | prob = numer / (denom + 1e-10) 234 | pred = prob.argmax(dim=1, keepdim=True) 235 | return pred 236 | 237 | # pdf 238 | def pdf(x): 239 | prob = 0 240 | for i in range(N): 241 | _mu = mu[i, :] 242 | prob += (1.0 / float(N)) * normal_prob(x, _mu, std, 2) 243 | return prob 244 | 245 | def logpdf(x): 246 | return torch.log((pdf(x) + 1e-10)) 247 | 248 | #return x, label, sumloglikelihood 249 | #return x, label, logpdf, classifier#, sumloglikelihood 250 | return x, label, None, None 251 | 252 | def get_toy_data(name, num_data): 253 | if name == 'swissroll': 254 | return get_swissroll(num_data) 255 | elif name == 'gaussian': 256 | return exp1(num_data) 257 | elif name == '25gaussians': 258 | return exp4(num_data) 259 | elif name == 'toy3': 260 | return exp3(num_data) 261 | else: 262 | raise NotImplementedError('no toy data: {}'.format(name)) 263 | 264 | def generate_data(name, num_train_samples=2000000, num_test_samples=20000, num_val_samples=2000): 265 | path = 'data/toy' 266 | os.system('mkdir -p {}'.format(path)) 267 | 268 | # generate 269 | train_data_tensor, train_target_tensor, logpdf, classifier = get_toy_data(name, num_train_samples) 270 | val_data_tensor, val_target_tensor, _, _ = get_toy_data(name, num_val_samples) 271 | test_data_tensor, test_target_tensor, _, _ = get_toy_data(name, num_test_samples) 272 | 273 | # save 274 | with open(os.path.join(path, '{}.pt'.format(name)), 'wb') as f: 275 | torch.save({ 276 | 'train': (train_data_tensor, train_target_tensor), 277 | 'val': (val_data_tensor, val_target_tensor), 278 | 'test': (test_data_tensor, test_target_tensor), 279 | 'logpdf': logpdf, 280 | 'classifier': classifier, 281 | }, f) 282 | 283 | # return 284 | return (logpdf, classifier, 285 | train_data_tensor, train_target_tensor, 286 | val_data_tensor, val_target_tensor, 287 | test_data_tensor, test_target_tensor, 288 | ) 289 | 290 | ''' 291 | get dataset with name 292 | ''' 293 | def get_toy_dataset_with_name(name, train_batch_size, eval_batch_size, kwargs): 294 | path = 'data/toy' 295 | filename = os.path.join(path, '{}.pt'.format(name)) 296 | if os.path.exists(filename): 297 | data = torch.load(filename) 298 | logpdf = data['logpdf'] 299 | classifier = data['classifier'] 300 | train_data_tensor, train_target_tensor = data['train'] 301 | val_data_tensor, val_target_tensor = data['val'] 302 | test_data_tensor, test_target_tensor = data['test'] 303 | else: 304 | (logpdf, classifier, 305 | train_data_tensor, train_target_tensor, 306 | val_data_tensor, val_target_tensor, 307 | test_data_tensor, test_target_tensor, 308 | ) = generate_data(name) 309 | 310 | # init dataset (train / val) 311 | train_dataset = torch.utils.data.TensorDataset(train_data_tensor, train_target_tensor.long()) 312 | val_dataset = torch.utils.data.TensorDataset(val_data_tensor, val_target_tensor.long()) 313 | test_dataset = 
torch.utils.data.TensorDataset(test_data_tensor, test_target_tensor.long()) 314 | 315 | # init dataloader 316 | train_loader = torch.utils.data.DataLoader(train_dataset, 317 | batch_size=train_batch_size, shuffle=True, **kwargs) 318 | val_loader = torch.utils.data.DataLoader(val_dataset, 319 | batch_size=train_batch_size, shuffle=False, **kwargs) 320 | test_loader = torch.utils.data.DataLoader(test_dataset, 321 | batch_size=eval_batch_size, shuffle=False, **kwargs) 322 | 323 | # init info 324 | info = {} 325 | info['nclasses'] = len(torch.unique(train_target_tensor)) 326 | info['classifier'] = classifier 327 | info['logpdf'] = logpdf 328 | 329 | return train_loader, val_loader, test_loader, info 330 | 331 | 332 | ''' 333 | get dataset 334 | ''' 335 | def get_toy_dataset(dataset, train_batch_size, eval_batch_size=None, cuda=False): 336 | # init arguments 337 | if eval_batch_size is None: 338 | eval_batch_size = train_batch_size 339 | kwargs = {'num_workers': 0, 'pin_memory': True} if cuda else {} 340 | 341 | # get dataset 342 | if dataset in ['swissroll', 'toy3', '25gaussians', 'gaussian']: 343 | return get_toy_dataset_with_name(dataset, train_batch_size, eval_batch_size, kwargs) 344 | else: 345 | raise NotImplementedError('dataset: {}'.format(dataset)) 346 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # dae 2 | from models.dae.mlp import DAE as MLPDAE 3 | from models.dae.mlp import ConditionalDAE as MLPCDAE 4 | 5 | from models.resdae.mlp import DAE as MLPResDAE 6 | from models.resdae.mlp import ARDAE as MLPResARDAE 7 | from models.resdae.mlp import ConditionalDAE as MLPResCDAE 8 | from models.resdae.mlp import ConditionalARDAE as MLPResCARDAE 9 | 10 | from models.graddae.mlp import DAE as MLPGradDAE 11 | from models.graddae.mlp import ARDAE as MLPGradARDAE 12 | from models.graddae.mlp import ConditionalDAE as MLPGradCDAE 13 | from models.graddae.mlp import ConditionalARDAE as MLPGradCARDAE 14 | 15 | # vae 16 | from models.vae.toy import VAE as ToyVAE 17 | from models.vae.mnist import VAE as MNISTVAE 18 | from models.vae.conv import VAE as MNISTConvVAE 19 | from models.vae.resconv import VAE as MNISTResConvVAE 20 | from models.vae.auxtoy import VAE as ToyAuxVAE 21 | from models.vae.auxmnist import VAE as MNISTAuxVAE 22 | from models.vae.auxconv import VAE as MNISTConvAuxVAE 23 | from models.vae.auxresconv import VAE as MNISTResConvAuxVAE 24 | 25 | # ivae 26 | from models.ivae.toy import ImplicitPosteriorVAE as ToyIPVAE 27 | from models.ivae.mnist import ImplicitPosteriorVAE as MNISTIPVAE 28 | from models.ivae.conv import ImplicitPosteriorVAE as ConvIPVAE 29 | from models.ivae.resconv import ImplicitPosteriorVAE as ResConvIPVAE 30 | from models.ivae.auxtoy import ImplicitPosteriorVAE as ToyAuxIPVAE 31 | from models.ivae.auxmnist import ImplicitPosteriorVAE as MNISTAuxIPVAE 32 | from models.ivae.auxconv import ImplicitPosteriorVAE as MNISTConvAuxIPVAE 33 | from models.ivae.auxresconv import ImplicitPosteriorVAE as MNISTResConvAuxIPVAE 34 | from models.ivae.auxresconv2 import ImplicitPosteriorVAE as MNISTResConvAuxIPVAEClipped 35 | -------------------------------------------------------------------------------- /models/aux.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd 3 | from torch.autograd import Function 4 | import torch.nn.functional as F 5 | 6 | ''' 7 | 
https://pytorch.org/docs/stable/notes/extending.html 8 | ''' 9 | class AuxLossForGradFunction(Function): 10 | 11 | # Note that both forward and backward are @staticmethods 12 | @staticmethod 13 | def forward(ctx, input, grad): 14 | ctx.save_for_backward(input, grad.detach()) 15 | return torch.sum(torch.zeros_like(input)) 16 | 17 | # This function has only a single output, so it gets only one gradient 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | # This is a pattern that is very convenient - at the top of backward 21 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 22 | # None. Thanks to the fact that additional trailing Nones are 23 | # ignored, the return statement is simple even when the function has 24 | # optional inputs. 25 | input, grad = ctx.saved_tensors 26 | grad_input = grad_grad = None 27 | 28 | # These needs_input_grad checks are optional and there only to 29 | # improve efficiency. If you want to make your code simpler, you can 30 | # skip them. Returning gradients for inputs that don't require it is 31 | # not an error. 32 | if ctx.needs_input_grad[0]: 33 | grad_input = grad 34 | 35 | return grad_input, grad_grad 36 | 37 | aux_loss_for_grad = AuxLossForGradFunction.apply 38 | 39 | 40 | ''' test ''' 41 | #import ipdb 42 | if __name__ == '__main__': 43 | batch_size = 10 44 | input_dim = 20 45 | input = torch.randn(batch_size, input_dim, dtype=torch.double, requires_grad=True) 46 | grad = torch.randn(batch_size, input_dim, dtype=torch.double, requires_grad=False) 47 | target = torch.randn(batch_size, input_dim, dtype=torch.double) 48 | 49 | #loss = F.mse_loss(input, target) 50 | #loss.backward() 51 | #ipdb.set_trace() 52 | 53 | loss = aux_loss_for_grad(input, grad) 54 | loss.backward() 55 | #ipdb.set_trace() 56 | print(input.grad) 57 | print(grad) 58 | print(torch.allclose(input.grad, grad)) 59 | -------------------------------------------------------------------------------- /models/ivae/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autograd 8 | from torch.distributions import MultivariateNormal 9 | 10 | from models.layers import Identity, MLP 11 | from models.reparam import NormalDistributionLinear 12 | from utils import loss_kld_gaussian, loss_kld_gaussian_vs_gaussian, loss_recon_gaussian, loss_recon_bernoulli_with_logit, normal_energy_func 13 | from utils import logprob_gaussian, get_covmat 14 | from utils import get_nonlinear_func 15 | from utils import conv_out_size, deconv_out_size 16 | 17 | from models.vae.conv import Decoder 18 | from models.vae.auxconv import weight_init 19 | 20 | from utils import expand_tensor 21 | from utils import cond_jac_clamping_loss 22 | 23 | 24 | def sample_noise(sz, std=None, device=torch.device('cpu')): 25 | std = std if std is not None else 1 26 | eps = torch.randn(*sz).to(device) 27 | return std * eps 28 | 29 | def sample_gaussian(mu, logvar): 30 | std = torch.exp(0.5*logvar) 31 | eps = torch.randn_like(std) 32 | return mu + std * eps 33 | 34 | def convert_2d_3d_tensor(input, sample_size): 35 | assert input.dim() == 2 36 | input_expanded, _ = expand_tensor(input, sample_size, do_unsqueeze=True) 37 | return input_expanded 38 | 39 | def convert_4d_5d_tensor(input, sample_size): 40 | assert input.dim() == 4 41 | input_expanded, _ = expand_tensor(input, sample_size, do_unsqueeze=True) 42 | return input_expanded 43 | 44 | 
class Encoder(nn.Module): 45 | def __init__(self, 46 | input_height=28, 47 | input_channels=1, 48 | noise_dim=100, 49 | z_dim=32, 50 | nonlinearity='softplus', 51 | enc_noise=False, 52 | ): 53 | super().__init__() 54 | self.input_height = input_height 55 | self.input_channels = input_channels 56 | self.noise_dim = noise_dim 57 | self.z_dim = z_dim 58 | self.nonlinearity = nonlinearity 59 | self.enc_noise = enc_noise 60 | h_dim = 256 61 | nos_dim = noise_dim if not enc_noise else h_dim 62 | 63 | s_h = input_height 64 | s_h2 = conv_out_size(s_h, 5, 2, 2) 65 | s_h4 = conv_out_size(s_h2, 5, 2, 2) 66 | s_h8 = conv_out_size(s_h4, 5, 2, 2) 67 | #print(s_h, s_h2, s_h4, s_h8) 68 | 69 | self.afun = get_nonlinear_func(nonlinearity) 70 | self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True) 71 | self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True) 72 | self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True) 73 | self.fc4 = nn.Linear(s_h8*s_h8*32 + nos_dim, 800, bias=True) 74 | self.fc5 = nn.Linear(800, z_dim, bias=True) 75 | 76 | self.nos_encode = Identity() if not enc_noise \ 77 | else MLP(input_dim=noise_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=2, use_nonlinearity_output=True) 78 | 79 | def sample_noise(self, batch_size, std=None, device=None): 80 | return sample_noise(sz=[batch_size, self.noise_dim], std=std, device=device) 81 | 82 | def _forward_inp(self, x): 83 | batch_size = x.size(0) 84 | x = x.view(batch_size, self.input_channels, self.input_height, self.input_height) 85 | 86 | # rescale 87 | x = 2*x -1 88 | 89 | # enc 90 | h1 = self.afun(self.conv1(x)) 91 | h2 = self.afun(self.conv2(h1)) 92 | h3 = self.afun(self.conv3(h2)) 93 | inp = h3.view(batch_size, -1) 94 | 95 | return inp 96 | 97 | def _forward_nos(self, batch_size=None, noise=None, std=None, device=None): 98 | assert batch_size is not None or noise is not None 99 | if noise is None: 100 | noise = self.sample_noise(batch_size, std=std, device=device) 101 | 102 | # enc 103 | nos = self.nos_encode(noise) 104 | 105 | return nos 106 | 107 | def _forward_all(self, inp, nos): 108 | # concat 109 | inp_nos = torch.cat([inp, nos], dim=1) 110 | 111 | # forward 112 | h4 = self.afun(self.fc4(inp_nos)) 113 | z = self.fc5(h4) 114 | 115 | return z 116 | 117 | def forward(self, x, noise=None, std=None, nz=1): 118 | batch_size = x.size(0) 119 | if noise is None: 120 | noise = self.sample_noise(batch_size*nz, std=std, device=x.device) 121 | else: 122 | assert noise.size(0) == batch_size*nz 123 | assert noise.size(1) == self.noise_dim 124 | 125 | # enc 126 | nos = self._forward_nos(noise=noise, std=std, device=x.device) 127 | inp = self._forward_inp(x) 128 | 129 | # view 130 | inp = inp.unsqueeze(1).expand(-1, nz, -1).contiguous() 131 | inp = inp.view(batch_size*nz, -1) 132 | 133 | # forward 134 | z = self._forward_all(inp, nos) 135 | 136 | return z.view(batch_size, nz, -1) 137 | 138 | class ImplicitPosteriorVAE(nn.Module): 139 | def __init__(self, 140 | energy_func=normal_energy_func, 141 | input_height=28, 142 | input_channels=1, 143 | z_dim=32, 144 | noise_dim=100, 145 | nonlinearity='softplus', 146 | do_xavier=True, 147 | #do_m5bias=False, 148 | ): 149 | super().__init__() 150 | self.energy_func = energy_func 151 | self.input_height = input_height 152 | self.input_channels = input_channels 153 | self.z_dim = z_dim 154 | self.latent_dim = z_dim # for ais 155 | self.noise_dim = noise_dim 156 | self.nonlinearity = nonlinearity 157 | self.do_xavier = do_xavier 158 | #self.do_m5bias = do_m5bias 159 | 160 
| self.encode = Encoder(input_height, input_channels, noise_dim, z_dim, nonlinearity=nonlinearity) 161 | self.decode = Decoder(input_height, input_channels, z_dim, nonlinearity=nonlinearity) 162 | self.reset_parameters() 163 | 164 | def reset_parameters(self): 165 | if self.do_xavier: 166 | self.apply(weight_init) 167 | #if self.do_m5bias: 168 | # torch.nn.init.constant_(self.decode.reparam.logit_fn.bias, -5) 169 | 170 | def loss(self, z, logit_x, target_x, beta=1.0): 171 | # loss from energy func 172 | prior_loss = self.energy_func(z.view(-1, self.z_dim)) 173 | 174 | # recon loss (neg likelihood): -log p(x|z) 175 | recon_loss = loss_recon_bernoulli_with_logit(logit_x, target_x, do_sum=False) 176 | 177 | # add loss 178 | loss = recon_loss + beta*prior_loss 179 | return loss.mean(), recon_loss.mean(), prior_loss.mean() 180 | 181 | def jac_clamping_loss(self, input, z, eps, std, nz, eta_min, p=2, EPS=1.): 182 | raise NotImplementedError 183 | 184 | def forward_hidden(self, input, std=None, nz=1): 185 | # init 186 | batch_size = input.size(0) 187 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 188 | 189 | # gen noise source 190 | eps = sample_noise(sz=[batch_size*nz, self.noise_dim], std=std, device=input.device) 191 | 192 | # sample z 193 | z = self.encode(input, noise=eps, std=std, nz=nz) 194 | 195 | return z 196 | 197 | def forward(self, input, beta=1.0, eta=0.0, lmbd=0.0, std=None, nz=1): 198 | # init 199 | batch_size = input.size(0) 200 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 201 | input_expanded = convert_4d_5d_tensor(input, sample_size=nz) 202 | input_expanded_flattened = input_expanded.view(batch_size*nz, self.input_channels, self.input_height, self.input_height) 203 | 204 | # gen noise source 205 | eps = sample_noise(sz=[batch_size*nz, self.noise_dim], std=std, device=input.device) 206 | 207 | # sample z 208 | z = self.encode(input, noise=eps, std=std, nz=nz) 209 | 210 | # z flattten 211 | z_flattened = z.view(batch_size*nz, -1) 212 | 213 | # decode 214 | x, logit_x = self.decode(z_flattened) 215 | 216 | # loss 217 | if lmbd > 0: 218 | raise NotImplementedError 219 | jaclmp_loss = lmbd*self.jac_clamping_loss(input, z, eps, std=std, nz=nz, eta_min=eta) 220 | else: 221 | jaclmp_loss = 0 222 | loss, recon_loss, prior_loss = self.loss( 223 | z_flattened, 224 | logit_x, input_expanded_flattened, 225 | beta=beta, 226 | ) 227 | loss += jaclmp_loss 228 | 229 | # return 230 | return x, torch.sigmoid(logit_x), z, loss, recon_loss.detach(), prior_loss.detach() 231 | 232 | def generate(self, batch_size=1): 233 | # init mu_z and logvar_z (as unit normal dist) 234 | weight = next(self.parameters()) 235 | mu_z = weight.new_zeros(batch_size, self.z_dim) 236 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 237 | 238 | # sample z (from unit normal dist) 239 | z = sample_gaussian(mu_z, logvar_z) # sample z 240 | 241 | # decode 242 | output, logit_x = self.decode(z) 243 | 244 | # return 245 | return output, torch.sigmoid(logit_x), z 246 | 247 | def logprob(self, input, sample_size=128, z=None, std=None): 248 | return self.logprob_w_cov_gaussian_posterior(input, sample_size, z, std) 249 | 250 | def logprob_w_cov_gaussian_posterior(self, input, sample_size=128, z=None, std=None): 251 | # init 252 | batch_size = input.size(0) 253 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 254 | assert sample_size >= 2*self.z_dim 255 | 256 | ''' get z and pseudo log q(newz|x) 
''' 257 | z, newz = [], [] 258 | logposterior = [] 259 | inp = self.encode._forward_inp(input).detach() 260 | for i in range(batch_size): 261 | _inp = inp[i:i+1, :].expand(sample_size, inp.size(1)) 262 | _nos = self.encode._forward_nos(batch_size=sample_size, std=std, device=input.device).detach() 263 | _z = self.encode._forward_all(_inp, _nos) # ssz x zdim 264 | z += [_z.detach().unsqueeze(0)] 265 | z = torch.cat(z, dim=0) # bsz x ssz x zdim 266 | mu_qz = torch.mean(z, dim=1) # bsz x zdim 267 | for i in range(batch_size): 268 | _cov_qz = get_covmat(z[i, :, :]) 269 | _rv_z = MultivariateNormal(mu_qz[i], _cov_qz) 270 | _newz = _rv_z.rsample(torch.Size([1, sample_size])) 271 | _logposterior = _rv_z.log_prob(_newz) 272 | 273 | newz += [_newz] 274 | logposterior += [_logposterior] 275 | newz = torch.cat(newz, dim=0) # bsz x ssz x zdim 276 | logposterior = torch.cat(logposterior, dim=0) # bsz x ssz 277 | 278 | ''' get log p(z) ''' 279 | # get prior (as unit normal dist) 280 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 281 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 282 | logprior = logprob_gaussian(mu_pz, logvar_pz, newz, do_unsqueeze=False, do_mean=False) 283 | logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 284 | 285 | ''' get log p(x|z) ''' 286 | # decode 287 | logit_x = [] 288 | #for i in range(sample_size): 289 | for i in range(batch_size): 290 | _, _logit_x = self.decode(newz[i, :, :]) # ssz x zdim 291 | logit_x += [_logit_x.detach().unsqueeze(0)] 292 | logit_x = torch.cat(logit_x, dim=0) # bsz x ssz x input_dim 293 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_channels, self.input_height, self.input_height) # bsz x ssz x input_dim 294 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 295 | loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size, -1), dim=2) # bsz x ssz 296 | 297 | ''' get log p(x|z)p(z)/q(z|x) ''' 298 | logprob = loglikelihood + logprior - logposterior # bsz x ssz 299 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 300 | rprob = (logprob - logprob_max).exp() # relative prob 301 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 302 | 303 | # return 304 | return logprob.mean() 305 | -------------------------------------------------------------------------------- /models/ivae/resconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autograd 8 | from torch.distributions import MultivariateNormal 9 | 10 | from models.layers import Identity, MLP, ResMLP 11 | from models.reparam import NormalDistributionLinear 12 | from utils import loss_kld_gaussian, loss_kld_gaussian_vs_gaussian, loss_recon_gaussian, loss_recon_bernoulli_with_logit, normal_energy_func 13 | from utils import logprob_gaussian, get_covmat 14 | from utils import get_nonlinear_func 15 | from utils import conv_out_size, deconv_out_size 16 | import models.layers2 as nn_ 17 | 18 | from models.vae.resconv import Decoder 19 | 20 | from utils import expand_tensor 21 | from utils import cond_jac_clamping_loss 22 | 23 | def get_afunc(nonlinearity_type='elu'): 24 | if nonlinearity_type == 'relu': 25 | return nn.ReLU() 26 | elif nonlinearity_type == 'elu': 27 | return nn.ELU() 28 | elif nonlinearity_type == 'softplus': 29 | return 
nn.Softplus() 30 | else: 31 | raise NotImplementedError 32 | 33 | def sample_noise(sz, std=None, device=torch.device('cpu')): 34 | std = std if std is not None else 1 35 | eps = torch.randn(*sz).to(device) 36 | return std * eps 37 | 38 | def sample_gaussian(mu, logvar): 39 | std = torch.exp(0.5*logvar) 40 | eps = torch.randn_like(std) 41 | return mu + std * eps 42 | 43 | def convert_2d_3d_tensor(input, sample_size): 44 | assert input.dim() == 2 45 | input_expanded, _ = expand_tensor(input, sample_size, do_unsqueeze=True) 46 | return input_expanded 47 | 48 | def convert_4d_5d_tensor(input, sample_size): 49 | assert input.dim() == 4 50 | input_expanded, _ = expand_tensor(input, sample_size, do_unsqueeze=True) 51 | return input_expanded 52 | 53 | class Encoder(nn.Module): 54 | def __init__(self, 55 | noise_dim=100, 56 | z_dim=32, 57 | c_dim=512,#450, 58 | h_dim=800, 59 | num_hidden_layers=1, 60 | nonlinearity='elu', #act=nn.ELU(), 61 | do_center=False, 62 | enc_noise=False, 63 | enc_type='mlp', 64 | ): 65 | super().__init__() 66 | self.noise_dim = noise_dim 67 | self.z_dim = z_dim 68 | self.c_dim = c_dim 69 | self.h_dim = h_dim 70 | self.num_hidden_layers = num_hidden_layers 71 | assert num_hidden_layers > 0 72 | self.nonlinearity = nonlinearity 73 | self.do_center = do_center 74 | self.enc_noise = enc_noise 75 | nos_dim = noise_dim if not enc_noise else c_dim 76 | self.enc_type = enc_type 77 | assert enc_type in ['mlp', 'res-wn-mlp', 'res-mlp', 'res-wn-mlp-lin', 'res-mlp-lin'] 78 | 79 | act = get_afunc(nonlinearity_type=nonlinearity) 80 | 81 | self.inp_encode = nn.Sequential( 82 | nn_.ResConv2d(1,16,3,2,padding=1,activation=act), 83 | act, 84 | nn_.ResConv2d(16,16,3,1,padding=1,activation=act), 85 | act, 86 | nn_.ResConv2d(16,32,3,2,padding=1,activation=act), 87 | act, 88 | nn_.ResConv2d(32,32,3,1,padding=1,activation=act), 89 | act, 90 | nn_.ResConv2d(32,32,3,2,padding=1,activation=act), 91 | act, 92 | nn_.Reshape((-1,32*4*4)), 93 | nn_.ResLinear(32*4*4,c_dim), 94 | act 95 | ) 96 | #self.fc = nn.Sequential( 97 | # nn.Linear(c_dim + nos_dim, h_dim, bias=True), 98 | # act, 99 | # nn.Linear(h_dim, z_dim, bias=True), 100 | # ) 101 | if enc_type == 'mlp': 102 | self.fc = MLP(input_dim=c_dim+nos_dim, hidden_dim=h_dim, output_dim=z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers, use_nonlinearity_output=False) 103 | elif enc_type == 'res-wn-mlp': 104 | self.fc = ResMLP(input_dim=c_dim+nos_dim, hidden_dim=h_dim, output_dim=z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers, use_nonlinearity_output=False, layer='wnlinear') 105 | elif enc_type == 'res-mlp': 106 | self.fc = ResMLP(input_dim=c_dim+nos_dim, hidden_dim=h_dim, output_dim=z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers, use_nonlinearity_output=False, layer='linear') 107 | elif enc_type == 'res-wn-mlp-lin': 108 | self.fc = nn.Sequential( 109 | ResMLP(input_dim=c_dim+nos_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True, layer='wnlinear'), 110 | nn.Linear(h_dim, z_dim, bias=True), 111 | ) 112 | elif enc_type == 'res-mlp-lin': 113 | self.fc = nn.Sequential( 114 | ResMLP(input_dim=c_dim+nos_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True, layer='linear'), 115 | nn.Linear(h_dim, z_dim, bias=True), 116 | ) 117 | else: 118 | raise NotImplementedError 119 | 120 | self.nos_encode = Identity() if not enc_noise \ 121 | 
else nn.Sequential( 122 | nn.Linear(noise_dim, c_dim, bias=True), 123 | act, 124 | ) 125 | 126 | def sample_noise(self, batch_size, std=None, device=None): 127 | return sample_noise(sz=[batch_size, self.noise_dim], std=std, device=device) 128 | 129 | def _forward_inp(self, x): 130 | batch_size = x.size(0) 131 | x = x.view(batch_size, 1, 28, 28) 132 | 133 | # rescale 134 | if self.do_center: 135 | x = 2*x -1 136 | 137 | # enc 138 | inp = self.inp_encode(x) 139 | 140 | return inp 141 | 142 | def _forward_nos(self, batch_size=None, noise=None, std=None, device=None): 143 | assert batch_size is not None or noise is not None 144 | if noise is None: 145 | noise = self.sample_noise(batch_size, std=std, device=device) 146 | 147 | # enc 148 | nos = self.nos_encode(noise) 149 | 150 | return nos 151 | 152 | def _forward_all(self, inp, nos): 153 | # concat 154 | inp_nos = torch.cat([inp, nos], dim=1) 155 | 156 | # forward 157 | z = self.fc(inp_nos) 158 | 159 | return z 160 | 161 | def forward(self, x, noise=None, std=None, nz=1): 162 | batch_size = x.size(0) 163 | if noise is None: 164 | noise = self.sample_noise(batch_size*nz, std=std, device=x.device) 165 | else: 166 | assert noise.size(0) == batch_size*nz 167 | assert noise.size(1) == self.noise_dim 168 | 169 | # enc 170 | nos = self._forward_nos(noise=noise, std=std, device=x.device) 171 | inp = self._forward_inp(x) 172 | 173 | # view 174 | inp = inp.unsqueeze(1).expand(-1, nz, -1).contiguous() 175 | inp = inp.view(batch_size*nz, -1) 176 | 177 | # forward 178 | z = self._forward_all(inp, nos) 179 | 180 | return z.view(batch_size, nz, -1) 181 | 182 | class ImplicitPosteriorVAE(nn.Module): 183 | def __init__(self, 184 | energy_func=normal_energy_func, 185 | input_height=28, 186 | input_channels=1, 187 | z_dim=32, 188 | noise_dim=100, 189 | c_dim=512, #450, 190 | h_dim=800, 191 | num_hidden_layers=1, 192 | nonlinearity='elu', 193 | do_center=False, 194 | do_m5bias=False, 195 | enc_noise=False, 196 | enc_type='mlp', 197 | ): 198 | super().__init__() 199 | self.energy_func = energy_func 200 | self.input_height = input_height 201 | self.input_channels = input_channels 202 | self.z_dim = z_dim 203 | self.latent_dim = z_dim # for ais 204 | self.noise_dim = noise_dim 205 | self.c_dim = c_dim 206 | self.h_dim = h_dim 207 | self.num_hidden_layers = num_hidden_layers 208 | self.nonlinearity = nonlinearity 209 | self.do_center = do_center 210 | self.do_m5bias = do_m5bias 211 | self.enc_noise = enc_noise 212 | self.enc_type = enc_type 213 | 214 | assert input_height == 28 215 | assert input_channels == 1 216 | #if nonlinearity == 'elu': 217 | # afunc = nn.ELU() 218 | #elif nonlinearity == 'softplus': 219 | # afunc = nn.Softplus() 220 | #else: 221 | # raise NotImplementedError 222 | 223 | self.encode = Encoder(noise_dim=noise_dim, z_dim=z_dim, c_dim=c_dim, h_dim=h_dim, num_hidden_layers=num_hidden_layers, nonlinearity=nonlinearity, do_center=do_center, enc_noise=enc_noise, enc_type=enc_type) 224 | self.decode = Decoder(z_dim=z_dim, c_dim=c_dim, act=nn.ELU(), do_m5bias=do_m5bias) 225 | 226 | def loss(self, z, logit_x, target_x, beta=1.0): 227 | # loss from energy func 228 | prior_loss = self.energy_func(z.view(-1, self.z_dim)) 229 | 230 | # recon loss (neg likelihood): -log p(x|z) 231 | recon_loss = loss_recon_bernoulli_with_logit(logit_x, target_x, do_sum=False) 232 | 233 | # add loss 234 | loss = recon_loss + beta*prior_loss 235 | return loss.mean(), recon_loss.mean(), prior_loss.mean() 236 | 237 | def jac_clamping_loss(self, input, z, eps, std, nz, eta_min, 
p=2, EPS=1.): 238 | raise NotImplementedError 239 | 240 | def forward_hidden(self, input, std=None, nz=1): 241 | # init 242 | batch_size = input.size(0) 243 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 244 | 245 | # gen noise source 246 | eps = sample_noise(sz=[batch_size*nz, self.noise_dim], std=std, device=input.device) 247 | 248 | # sample z 249 | z = self.encode(input, noise=eps, std=std, nz=nz) 250 | 251 | return z 252 | 253 | def forward(self, input, beta=1.0, eta=0.0, lmbd=0.0, std=None, nz=1): 254 | # init 255 | batch_size = input.size(0) 256 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 257 | input_expanded = convert_4d_5d_tensor(input, sample_size=nz) 258 | input_expanded_flattened = input_expanded.view(batch_size*nz, self.input_channels, self.input_height, self.input_height) 259 | 260 | # gen noise source 261 | eps = sample_noise(sz=[batch_size*nz, self.noise_dim], std=std, device=input.device) 262 | 263 | # sample z 264 | z = self.encode(input, noise=eps, std=std, nz=nz) 265 | 266 | # z flattten 267 | z_flattened = z.view(batch_size*nz, -1) 268 | 269 | # decode 270 | x, logit_x = self.decode(z_flattened) 271 | 272 | # loss 273 | if lmbd > 0: 274 | raise NotImplementedError 275 | jaclmp_loss = lmbd*self.jac_clamping_loss(input, z, eps, std=std, nz=nz, eta_min=eta) 276 | else: 277 | jaclmp_loss = 0 278 | loss, recon_loss, prior_loss = self.loss( 279 | z_flattened, 280 | logit_x, input_expanded_flattened, 281 | beta=beta, 282 | ) 283 | loss += jaclmp_loss 284 | 285 | # return 286 | return x, torch.sigmoid(logit_x), z, loss, recon_loss.detach(), prior_loss.detach() 287 | 288 | def generate(self, batch_size=1): 289 | # init mu_z and logvar_z (as unit normal dist) 290 | weight = next(self.parameters()) 291 | mu_z = weight.new_zeros(batch_size, self.z_dim) 292 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 293 | 294 | # sample z (from unit normal dist) 295 | z = sample_gaussian(mu_z, logvar_z) # sample z 296 | 297 | # decode 298 | output, logit_x = self.decode(z) 299 | 300 | # return 301 | return output, torch.sigmoid(logit_x), z 302 | 303 | def logprob(self, input, sample_size=128, z=None, std=None): 304 | return self.logprob_w_cov_gaussian_posterior(input, sample_size, z, std) 305 | 306 | def logprob_w_cov_gaussian_posterior(self, input, sample_size=128, z=None, std=None): 307 | # init 308 | batch_size = input.size(0) 309 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 310 | assert sample_size >= 2*self.z_dim 311 | 312 | ''' get z and pseudo log q(newz|x) ''' 313 | z, newz = [], [] 314 | logposterior = [] 315 | inp = self.encode._forward_inp(input).detach() 316 | for i in range(batch_size): 317 | _inp = inp[i:i+1, :].expand(sample_size, inp.size(1)) 318 | _nos = self.encode._forward_nos(batch_size=sample_size, std=std, device=input.device).detach() 319 | _z = self.encode._forward_all(_inp, _nos) # ssz x zdim 320 | z += [_z.detach().unsqueeze(0)] 321 | z = torch.cat(z, dim=0) # bsz x ssz x zdim 322 | mu_qz = torch.mean(z, dim=1) # bsz x zdim 323 | for i in range(batch_size): 324 | _cov_qz = get_covmat(z[i, :, :]) 325 | _rv_z = MultivariateNormal(mu_qz[i], _cov_qz) 326 | _newz = _rv_z.rsample(torch.Size([1, sample_size])) 327 | _logposterior = _rv_z.log_prob(_newz) 328 | 329 | newz += [_newz] 330 | logposterior += [_logposterior] 331 | newz = torch.cat(newz, dim=0) # bsz x ssz x zdim 332 | logposterior = torch.cat(logposterior, dim=0) # bsz x 
ssz 333 | 334 | ''' get log p(z) ''' 335 | # get prior (as unit normal dist) 336 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 337 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 338 | logprior = logprob_gaussian(mu_pz, logvar_pz, newz, do_unsqueeze=False, do_mean=False) 339 | logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 340 | 341 | ''' get log p(x|z) ''' 342 | # decode 343 | logit_x = [] 344 | #for i in range(sample_size): 345 | for i in range(batch_size): 346 | _, _logit_x = self.decode(newz[i, :, :]) # ssz x zdim 347 | logit_x += [_logit_x.detach().unsqueeze(0)] 348 | logit_x = torch.cat(logit_x, dim=0) # bsz x ssz x input_dim 349 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_channels, self.input_height, self.input_height) # bsz x ssz x input_dim 350 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 351 | loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size, -1), dim=2) # bsz x ssz 352 | 353 | ''' get log p(x|z)p(z)/q(z|x) ''' 354 | logprob = loglikelihood + logprior - logposterior # bsz x ssz 355 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 356 | rprob = (logprob - logprob_max).exp() # relative prob 357 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 358 | 359 | # return 360 | return logprob.mean() 361 | -------------------------------------------------------------------------------- /models/layers2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Dec 11 13:58:12 2017 5 | @author: CW 6 | """ 7 | 8 | 9 | 10 | import math 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import Module 16 | from torch.nn import functional as F 17 | from torch.nn.parameter import Parameter 18 | from torch.nn.modules.utils import _pair 19 | from torch.autograd import Variable 20 | 21 | # aliasing 22 | N_ = None 23 | 24 | 25 | delta = 1e-6 26 | softplus_ = nn.Softplus() 27 | softplus = lambda x: softplus_(x) + delta 28 | sigmoid_ = nn.Sigmoid() 29 | sigmoid = lambda x: sigmoid_(x) * (1-delta) + 0.5 * delta 30 | sigmoid2 = lambda x: sigmoid(x) * 2.0 31 | logsigmoid = lambda x: -softplus(-x) 32 | logit = lambda x: torch.log 33 | log = lambda x: torch.log(x*1e2)-np.log(1e2) 34 | logit = lambda x: log(x) - log(1-x) 35 | def softmax(x, dim=-1): 36 | e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0]) 37 | out = e_x / e_x.sum(dim=dim, keepdim=True) 38 | return out 39 | 40 | sum1 = lambda x: x.sum(1) 41 | sum_from_one = lambda x: sum_from_one(sum1(x)) if len(x.size())>2 else sum1(x) 42 | 43 | 44 | 45 | class Sigmoid(Module): 46 | def forward(self, x): 47 | return sigmoid(x) 48 | 49 | 50 | class WNlinear(Module): 51 | 52 | def __init__(self, in_features, out_features, 53 | bias=True, mask=N_, norm=True): 54 | super(WNlinear, self).__init__() 55 | self.in_features = in_features 56 | self.out_features = out_features 57 | self.register_buffer('mask',mask) 58 | self.norm = norm 59 | self.direction = Parameter(torch.Tensor(out_features, in_features)) 60 | self.scale = Parameter(torch.Tensor(out_features)) 61 | if bias: 62 | self.bias = Parameter(torch.Tensor(out_features)) 63 | else: 64 | self.register_parameter('bias', N_) 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | stdv = 1. 
/ math.sqrt(self.direction.size(1)) 69 | self.direction.data.uniform_(-stdv, stdv) 70 | self.scale.data.uniform_(1, 1) 71 | if self.bias is not N_: 72 | self.bias.data.uniform_(-stdv, stdv) 73 | 74 | def forward(self, input): 75 | if self.norm: 76 | dir_ = self.direction 77 | direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:,N_]) 78 | weight = self.scale[:,N_].mul(direction) 79 | else: 80 | weight = self.scale[:,N_].mul(self.direction) 81 | if self.mask is not N_: 82 | #weight = weight * getattr(self.mask, 83 | # ('cpu', 'cuda')[weight.is_cuda])() 84 | weight = weight * Variable(self.mask) 85 | return F.linear(input, weight, self.bias) 86 | 87 | def __repr__(self): 88 | return self.__class__.__name__ + '(' \ 89 | + 'in_features=' + str(self.in_features) \ 90 | + ', out_features=' + str(self.out_features) + ')' 91 | 92 | 93 | 94 | 95 | class CWNlinear(Module): 96 | 97 | def __init__(self, in_features, out_features, context_features, 98 | mask=N_, norm=True): 99 | super(CWNlinear, self).__init__() 100 | self.in_features = in_features 101 | self.out_features = out_features 102 | self.context_features = context_features 103 | self.register_buffer('mask',mask) 104 | self.norm = norm 105 | self.direction = Parameter(torch.Tensor(out_features, in_features)) 106 | self.cscale = nn.Linear(context_features, out_features) 107 | self.cbias = nn.Linear(context_features, out_features) 108 | self.reset_parameters() 109 | self.cscale.weight.data.normal_(0, 0.001) 110 | self.cbias.weight.data.normal_(0, 0.001) 111 | 112 | def reset_parameters(self): 113 | self.direction.data.normal_(0, 0.001) 114 | 115 | def forward(self, inputs): 116 | input, context = inputs 117 | scale = self.cscale(context) 118 | bias = self.cbias(context) 119 | if self.norm: 120 | dir_ = self.direction 121 | direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:,N_]) 122 | weight = direction 123 | else: 124 | weight = self.direction 125 | if self.mask is not N_: 126 | #weight = weight * getattr(self.mask, 127 | # ('cpu', 'cuda')[weight.is_cuda])() 128 | weight = weight * Variable(self.mask) 129 | return scale * F.linear(input, weight, None) + bias, context 130 | 131 | def __repr__(self): 132 | return self.__class__.__name__ + '(' \ 133 | + 'in_features=' + str(self.in_features) \ 134 | + ', out_features=' + str(self.out_features) + ')' 135 | 136 | 137 | 138 | 139 | class WNBilinear(Module): 140 | 141 | def __init__(self, in1_features, in2_features, out_features, bias=True): 142 | super(WNBilinear, self).__init__() 143 | self.in1_features = in1_features 144 | self.in2_features = in2_features 145 | self.out_features = out_features 146 | self.direction = Parameter(torch.Tensor( 147 | out_features, in1_features, in2_features)) 148 | self.scale = Parameter(torch.Tensor(out_features)) 149 | if bias: 150 | self.bias = Parameter(torch.Tensor(out_features)) 151 | else: 152 | self.register_parameter('bias', N_) 153 | self.reset_parameters() 154 | 155 | def reset_parameters(self): 156 | stdv = 1. 
/ math.sqrt(self.direction.size(1)) 157 | self.direction.data.uniform_(-stdv, stdv) 158 | self.scale.data.uniform_(1, 1) 159 | if self.bias is not N_: 160 | self.bias.data.uniform_(-stdv, stdv) 161 | 162 | def forward(self, input1, input2): 163 | dir_ = self.direction 164 | direction = dir_.div(dir_.pow(2).sum(1).sum(1).sqrt()[:,N_,N_]) 165 | weight = self.scale[:,N_,N_].mul(direction) 166 | return F.bilinear(input1, input2, weight, self.bias) 167 | 168 | def __repr__(self): 169 | return self.__class__.__name__ + '(' \ 170 | + 'in1_features=' + str(self.in1_features) \ 171 | + ', in2_features=' + str(self.in2_features) \ 172 | + ', out_features=' + str(self.out_features) + ')' 173 | 174 | 175 | 176 | class _WNconvNd(Module): 177 | 178 | def __init__(self, in_channels, out_channels, kernel_size, stride, 179 | padding, dilation, transposed, output_padding, groups, bias): 180 | super(_WNconvNd, self).__init__() 181 | if in_channels % groups != 0: 182 | raise ValueError('in_channels must be divisible by groups') 183 | if out_channels % groups != 0: 184 | raise ValueError('out_channels must be divisible by groups') 185 | self.in_channels = in_channels 186 | self.out_channels = out_channels 187 | self.kernel_size = kernel_size 188 | self.stride = stride 189 | self.padding = padding 190 | self.dilation = dilation 191 | self.transposed = transposed 192 | self.output_padding = output_padding 193 | self.groups = groups 194 | 195 | # weight – filters tensor (out_channels x in_channels/groups x kH x kW) 196 | if transposed: 197 | self.direction = Parameter(torch.Tensor( 198 | in_channels, out_channels // groups, *kernel_size)) 199 | self.scale = Parameter(torch.Tensor(in_channels)) 200 | else: 201 | self.direction = Parameter(torch.Tensor( 202 | out_channels, in_channels // groups, *kernel_size)) 203 | self.scale = Parameter(torch.Tensor(out_channels)) 204 | if bias: 205 | self.bias = Parameter(torch.Tensor(out_channels)) 206 | else: 207 | self.register_parameter('bias', N_) 208 | self.reset_parameters() 209 | 210 | def reset_parameters(self): 211 | n = self.in_channels 212 | for k in self.kernel_size: 213 | n *= k 214 | stdv = 1. 
/ math.sqrt(n) 215 | self.direction.data.uniform_(-stdv, stdv) 216 | self.scale.data.uniform_(1, 1) 217 | if self.bias is not N_: 218 | self.bias.data.uniform_(-stdv, stdv) 219 | 220 | def __repr__(self): 221 | s = ('{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}' 222 | ', stride={stride}') 223 | if self.padding != (0,) * len(self.padding): 224 | s += ', padding={padding}' 225 | if self.dilation != (1,) * len(self.dilation): 226 | s += ', dilation={dilation}' 227 | if self.output_padding != (0,) * len(self.output_padding): 228 | s += ', output_padding={output_padding}' 229 | if self.groups != 1: 230 | s += ', groups={groups}' 231 | if self.bias is N_: 232 | s += ', bias=False' 233 | s += ')' 234 | return s.format(name=self.__class__.__name__, **self.__dict__) 235 | 236 | 237 | class WNconv2d(_WNconvNd): 238 | 239 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 240 | padding=0, dilation=1, groups=1, bias=True, 241 | mask=N_, norm=True): 242 | kernel_size = _pair(kernel_size) 243 | stride = _pair(stride) 244 | padding = _pair(padding) 245 | dilation = _pair(dilation) 246 | super(WNconv2d, self).__init__( 247 | in_channels, out_channels, kernel_size, stride, padding, dilation, 248 | False, _pair(0), groups, bias) 249 | 250 | self.register_buffer('mask',mask) 251 | self.norm = norm 252 | 253 | def forward(self, input): 254 | if self.norm: 255 | dir_ = self.direction 256 | direction = dir_.div( 257 | dir_.pow(2).sum(1).sum(1).sum(1).sqrt()[:,N_,N_,N_]) 258 | weight = self.scale[:,N_,N_,N_].mul(direction) 259 | else: 260 | weight = self.scale[:,N_,N_,N_].mul(self.direction) 261 | if self.mask is not None: 262 | weight = weight * Variable(self.mask) 263 | return F.conv2d(input, weight, self.bias, self.stride, 264 | self.padding, self.dilation, self.groups) 265 | 266 | 267 | class CWNconv2d(_WNconvNd): 268 | 269 | def __init__(self, context_features, in_channels, out_channels, 270 | kernel_size, stride=1, 271 | padding=0, dilation=1, groups=1, 272 | mask=N_, norm=True): 273 | kernel_size = _pair(kernel_size) 274 | stride = _pair(stride) 275 | padding = _pair(padding) 276 | dilation = _pair(dilation) 277 | super(CWNconv2d, self).__init__( 278 | in_channels, out_channels, kernel_size, stride, padding, dilation, 279 | False, _pair(0), groups, False) 280 | 281 | self.register_buffer('mask',mask) 282 | self.norm = norm 283 | self.cscale = nn.Linear(context_features, out_channels) 284 | self.cbias = nn.Linear(context_features, out_channels) 285 | 286 | def forward(self, inputs): 287 | input, context = inputs 288 | scale = self.cscale(context)[:,:,N_,N_] 289 | bias = self.cbias(context)[:,:,N_,N_] 290 | if self.norm: 291 | dir_ = self.direction 292 | direction = dir_.div( 293 | dir_.pow(2).sum(1).sum(1).sum(1).sqrt()[:,N_,N_,N_]) 294 | weight = direction 295 | else: 296 | weight = self.direction 297 | if self.mask is not None: 298 | weight = weight * Variable(self.mask) 299 | pre = F.conv2d( 300 | input, weight, None, self.stride, 301 | self.padding, self.dilation, self.groups) 302 | return pre * scale + bias, context 303 | 304 | 305 | class ResConv2d(nn.Module): 306 | 307 | def __init__( 308 | self, in_channels, out_channels, kernel_size, stride=1, 309 | padding=0, dilation=1, groups=1, bias=True, activation=nn.ReLU(), 310 | oper=WNconv2d): 311 | super(ResConv2d, self).__init__() 312 | 313 | self.conv_0h = oper( 314 | in_channels, out_channels, kernel_size, stride, 315 | padding, dilation, groups, bias) 316 | self.conv_h1 = oper( 317 | out_channels, out_channels, 
3, 1, 1, 1, 1, True) 318 | self.conv_01 = oper( 319 | in_channels, out_channels, kernel_size, stride, 320 | padding, dilation, groups, bias) 321 | 322 | self.activation = activation 323 | 324 | def forward(self, input): 325 | h = self.activation(self.conv_0h(input)) 326 | out_nonlinear = self.conv_h1(h) 327 | out_skip = self.conv_01(input) 328 | return out_nonlinear + out_skip 329 | 330 | 331 | class ResLinear(nn.Module): 332 | 333 | def __init__( 334 | self, in_features, out_features, bias=True, same_dim=False, 335 | activation=nn.ReLU(), oper=WNlinear): 336 | super(ResLinear, self).__init__() 337 | 338 | self.same_dim = same_dim 339 | 340 | self.dot_0h = oper(in_features, out_features, bias) 341 | self.dot_h1 = oper(out_features, out_features, bias) 342 | if not same_dim: 343 | self.dot_01 = oper(in_features, out_features, bias) 344 | 345 | self.activation = activation 346 | 347 | def forward(self, input): 348 | h = self.activation(self.dot_0h(input)) 349 | out_nonlinear = self.dot_h1(h) 350 | out_skip = input if self.same_dim else self.dot_01(input) 351 | return out_nonlinear + out_skip 352 | 353 | 354 | 355 | class GatingLinear(nn.Module): 356 | 357 | def __init__( 358 | self, in_features, out_features, oper=WNlinear, **kwargs): 359 | super(GatingLinear, self).__init__() 360 | 361 | 362 | self.dot = oper(in_features, out_features, **kwargs) 363 | self.gate = oper(in_features, out_features, **kwargs) 364 | 365 | def forward(self, input): 366 | h = self.dot(input) 367 | s = sigmoid_(self.gate(input)) 368 | return s * h 369 | 370 | 371 | 372 | 373 | class Reshape(nn.Module): 374 | 375 | def __init__(self, shape): 376 | super(Reshape, self).__init__() 377 | self.shape = shape 378 | 379 | def forward(self, input): 380 | return input.view(self.shape) 381 | 382 | 383 | class Slice(nn.Module): 384 | 385 | def __init__(self, slc): 386 | super(Slice, self).__init__() 387 | self.slc = slc 388 | 389 | def forward(self, input): 390 | return input.__getitem__(self.slc) 391 | 392 | 393 | class SliceFactory(object): 394 | def __init__(self): 395 | pass 396 | 397 | def __getitem__(self, slc): 398 | return Slice(slc) 399 | slicer = SliceFactory() 400 | 401 | 402 | class Lambda(nn.Module): 403 | 404 | def __init__(self, function): 405 | super(Lambda, self).__init__() 406 | self.function = function 407 | 408 | def forward(self, input): 409 | return self.function(input) 410 | 411 | 412 | class SequentialFlow(nn.Sequential): 413 | 414 | def sample(self, n=1, context=None, **kwargs): 415 | dim = self[0].dim 416 | if isinstance(dim, int): 417 | dim = [dim,] 418 | 419 | spl = torch.autograd.Variable(torch.FloatTensor(n,*dim).normal_()) 420 | lgd = torch.autograd.Variable(torch.from_numpy( 421 | np.random.rand(n).astype('float32'))) 422 | if context is None: 423 | context = torch.autograd.Variable(torch.from_numpy( 424 | np.zeros((n,self[0].context_dim)).astype('float32'))) 425 | 426 | if hasattr(self, 'gpu'): 427 | if self.gpu: 428 | spl = spl.cuda() 429 | lgd = lgd.cuda() 430 | context = context.cuda() 431 | 432 | return self.forward((spl, lgd, context)) 433 | 434 | 435 | 436 | def cuda(self): 437 | self.gpu = True 438 | return super(SequentialFlow, self).cuda() 439 | 440 | 441 | class ContextWrapper(nn.Module): 442 | def __init__(self, module): 443 | super(ContextWrapper, self).__init__() 444 | self.module = module 445 | 446 | def forward(self, inputs): 447 | input, context = inputs 448 | output = self.module.forward(input) 449 | return output, context 450 | 451 | 452 | if __name__ == '__main__': 453 | 454 | 
mdl = CWNlinear(2,5,3) 455 | 456 | 457 | inp = torch.autograd.Variable( 458 | torch.from_numpy(np.random.rand(2,2).astype('float32'))) 459 | con = torch.autograd.Variable( 460 | torch.from_numpy(np.random.rand(2,3).astype('float32'))) 461 | 462 | print(mdl((inp, con))[0].size()) 463 | -------------------------------------------------------------------------------- /models/reparam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Categorical 5 | 6 | 7 | MIN_LOGVAR = -4. 8 | MAX_LOGVAR = 2. 9 | 10 | 11 | ''' NormalDistribution ''' 12 | class NormalDistribution(nn.Module): 13 | def __init__(self, nonlinearity=None): 14 | super(NormalDistribution, self).__init__() 15 | self.nonlinearity = nonlinearity 16 | 17 | def clip_logvar(self, logvar): 18 | # clip logvar values 19 | if self.nonlinearity == 'hard': 20 | logvar = torch.max(logvar, MIN_LOGVAR*torch.ones_like(logvar)) 21 | logvar = torch.min(logvar, MAX_LOGVAR*torch.ones_like(logvar)) 22 | elif self.nonlinearity == 'softplus': 23 | logvar = F.softplus(logvar) 24 | elif self.nonlinearity == 'spm10': 25 | logvar = F.softplus(logvar+10.) - 10. 26 | elif self.nonlinearity == 'spm6': 27 | logvar = F.softplus(logvar+6.) - 6. 28 | elif self.nonlinearity == 'spm5': 29 | logvar = F.softplus(logvar+5.) - 5. 30 | elif self.nonlinearity == 'spm4': 31 | logvar = F.softplus(logvar+4.) - 4. 32 | elif self.nonlinearity == 'spm3': 33 | logvar = F.softplus(logvar+3.) - 3. 34 | elif self.nonlinearity == 'spm2': 35 | logvar = F.softplus(logvar+2.) - 2. 36 | elif self.nonlinearity == 'tanh': 37 | logvar = F.tanh(logvar) 38 | elif self.nonlinearity == '2tanh': 39 | logvar = 2.0*F.tanh(logvar) 40 | return logvar 41 | 42 | def sample_gaussian(self, mu, logvar): 43 | #if self.training: 44 | # std = torch.exp(0.5*logvar) 45 | # eps = torch.randn_like(std) 46 | # return mu + std * eps 47 | #else: 48 | # return mu 49 | std = torch.exp(0.5*logvar) 50 | eps = torch.randn_like(std) 51 | return mu + std * eps 52 | 53 | #def forward(self, input): 54 | # raise NotImplementedError() 55 | def forward(self, input): 56 | mu = self.mean_fn(input) 57 | logvar = self.clip_logvar(self.logvar_fn(input)) 58 | return mu, logvar 59 | #output = self.sample_gaussian(mu, logvar) 60 | #return mu, logvar, output 61 | 62 | class NormalDistributionLinear(NormalDistribution): 63 | def __init__(self, input_size, output_size, nonlinearity=None): 64 | super(NormalDistributionLinear, self).__init__(nonlinearity=nonlinearity) 65 | 66 | self.input_size = input_size 67 | self.output_size = output_size 68 | 69 | # define net 70 | self.mean_fn = nn.Linear(input_size, output_size) 71 | self.logvar_fn = nn.Linear(input_size, output_size) 72 | 73 | #def forward(self, input): 74 | # mu = self.mean_fn(input) 75 | # logvar = self.clip_logvar(self.logvar_fn(input)) 76 | # return mu, logvar 77 | 78 | class NormalDistributionConv2d(NormalDistribution): 79 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, nonlinearity=None): 80 | super(NormalDistributionConv2d, self).__init__(nonlinearity=nonlinearity) 81 | 82 | # define net 83 | self.mean_fn = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 84 | self.logvar_fn = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 85 | 86 | #def forward(self, input): 87 | # mu = self.mean_fn(input) 88 | # logvar = 
self.clip_logvar(self.logvar_fn(input)) 89 | # return mu, logvar 90 | 91 | class NormalDistributionConvTranspose2d(NormalDistribution): 92 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, bias=True, nonlinearity=None): 93 | super(NormalDistributionConvTranspose2d, self).__init__(nonlinearity=nonlinearity) 94 | 95 | # define net 96 | self.mean_fn = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=bias) 97 | self.logvar_fn = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=bias) 98 | 99 | #def forward(self, input): 100 | # mu = self.mean_fn(input) 101 | # logvar = self.clip_logvar(self.logvar_fn(input)) 102 | # return mu, logvar 103 | 104 | 105 | ''' BernoulliDistribution ''' 106 | class BernoulliDistribution(nn.Module): 107 | def __init__(self, hard=False): 108 | super(BernoulliDistribution, self).__init__() 109 | self.hard = hard 110 | 111 | def _sample_logistic(self, logits, eps=1e-20): 112 | ''' Sample from Logistic(0, 1) ''' 113 | noise = torch.rand_like(logits) 114 | return torch.log(torch.div(noise, 1.-noise) + eps) 115 | 116 | def _sample_logistic_sigmoid(self, logits, temperature): 117 | ''' Draw a sample from the Logistic-Sigmoid distribution (Binary Concrete distribution) ''' 118 | ''' See, https://arxiv.org/abs/1611.00712 ''' 119 | y = logits + self._sample_logistic(logits) 120 | return torch.sigmoid(y / temperature) 121 | 122 | def sample_logistic_sigmoid(self, logits, temperature=1.0, hard=False): 123 | ''' Sample from the Logistic-Sigmoid distribution and optionally discretize. 124 | Args: 125 | logits: [batch_size, output_size] unnormalized log-probs 126 | temperature: non-negative scalar 127 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 128 | Returns: 129 | [batch_size, output_size] sample from the Logistic-Sigmoid distribution. 130 | If hard=True, then the returned sample will be one-hot, otherwise it will 131 | be a probabilitiy distribution that sums to 1 across classes 132 | ''' 133 | y = self._sample_logistic_sigmoid(logits, temperature) 134 | 135 | if hard: 136 | raise NotImplementedError('current code for torch 0.1') 137 | # see https://github.com/yandexdataschool/gumbel_lstm/blob/master/gumbel_sigmoid.py 138 | #"""computes a hard indicator function. 
Not differentiable""" 139 | #y = T.switch(T.gt(logits,0),1,0) 140 | 141 | # check dimension 142 | assert y.dim() == 2 143 | 144 | # init y_hard 145 | if args.cuda: 146 | y_hard = torch.cuda.FloatTensor(logits.size()).zero_() 147 | else: 148 | y_hard = torch.FloatTensor(logits.size()).zero_() 149 | y_hard = Variable(y_hard) 150 | 151 | # get hard representation for y (into y_hard) 152 | y_hard.data[torch.gt(y.data, 0.5)] = 1.0 153 | 154 | # get hard 155 | y_hard.data.add(-1, y.data) 156 | y = y_hard + y 157 | 158 | return y 159 | 160 | def forward(self, input): 161 | raise NotImplementedError() 162 | 163 | class BernoulliDistributionLinear(BernoulliDistribution): 164 | def __init__(self, input_size, output_size, hard=False): 165 | super(BernoulliDistributionLinear, self).__init__(hard=hard) 166 | self.input_size = input_size 167 | self.output_size = output_size 168 | 169 | # define net 170 | self.logit_fn = nn.Linear(input_size, output_size) 171 | 172 | def forward(self, input): 173 | logits = self.logit_fn(input) 174 | return logits 175 | #output = self.sample_logistic_sigmoid(logits, temperature, self.hard) 176 | #return logits, output 177 | 178 | class BernoulliDistributionConv2d(BernoulliDistribution): 179 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, hard=False): 180 | super(BernoulliDistributionConv2d, self).__init__(hard=hard) 181 | 182 | # define net 183 | self.logit_fn = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 184 | 185 | def forward(self, input): 186 | logits = self.logit_fn(input) 187 | return logits 188 | #output = self.sample_logistic_sigmoid(logits, temperature, self.hard) 189 | #return logits, output 190 | 191 | class BernoulliDistributionConvTranspose2d(BernoulliDistribution): 192 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, bias=True, hard=False): 193 | super(BernoulliDistributionConvTranspose2d, self).__init__(hard=hard) 194 | 195 | # define net 196 | self.logit_fn = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=bias) 197 | 198 | def forward(self, input): 199 | logits = self.logit_fn(input) 200 | return logits 201 | #output = self.sample_logistic_sigmoid(logits, temperature, self.hard) 202 | #return logits, output 203 | 204 | 205 | ''' CategoricalDistribution ''' 206 | class CategoricalDistribution(nn.Module): 207 | def __init__(self, hard=False): 208 | super(CategoricalDistribution, self).__init__() 209 | self.hard = hard 210 | 211 | def _sample_gumbel(self, input, eps=1e-20): 212 | ''' Sample from Gumbel(0, 1) ''' 213 | noise = torch.rand_like(input) 214 | return -torch.log(-torch.log(noise + eps) + eps) 215 | 216 | def _sample_gumbel_softmax(self, logits, temperature): 217 | ''' Draw a sample from the Gumbel-Softmax distribution ''' 218 | y = logits + self._sample_gumbel(input) 219 | return self.softmax(y / temperature) 220 | 221 | def sample_gumbel_softmax(self, logits, temperature=1.0, hard=False): 222 | ''' Sample from the Gumbel-Softmax distribution and optionally discretize. 223 | Args: 224 | logits: [batch_size, num_class] unnormalized log-probs 225 | temperature: non-negative scalar 226 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 227 | Returns: 228 | [batch_size, num_class] sample from the Gumbel-Softmax distribution. 
229 | If hard=True, then the returned sample will be one-hot, otherwise it will 230 | be a probabilitiy distribution that sums to 1 across classes 231 | ''' 232 | y = self._sample_gumbel_softmax(logits, temperature) 233 | 234 | if hard: 235 | raise NotImplementedError('current code for torch 0.1') 236 | #k = tf.shape(logits)[-1] 237 | ##y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype) 238 | #y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype) 239 | #y = tf.stop_gradient(y_hard - y) + y 240 | 241 | # check dimension 242 | assert y.dim() == 2 243 | 244 | # init y_hard 245 | if args.cuda: 246 | y_hard = torch.cuda.FloatTensor(logits.size()).zero_() 247 | else: 248 | y_hard = torch.FloatTensor(logits.size()).zero_() 249 | y_hard = Variable(y_hard) 250 | 251 | # get one-hot representation for y (into y_hard) 252 | val, ind = torch.max(y.data, 1) 253 | y_hard.data.scatter_(1, ind, 1) 254 | 255 | # get hard 256 | y_hard.data.add(-1, y.data) 257 | y = y_hard + y 258 | 259 | return y 260 | 261 | def forward(self, input): 262 | raise NotImplementedError() 263 | 264 | class CategoricalDistributionLinear(CategoricalDistribution): 265 | def __init__(self, input_size, num_class, hard=False): 266 | super(CategoricalDistributionLinear, self).__init__(hard=hard) 267 | 268 | self.input_size = input_size 269 | self.num_class = num_class 270 | 271 | # define net 272 | self.logit_fn = nn.Linear(input_size, num_class) 273 | 274 | def forward(self, input): 275 | logits = self.logit_fn(input) 276 | return logits 277 | #output = self.gumbel_softmax(logits, temperature, self.hard) 278 | #return logits, output 279 | 280 | class CategoricalDistributionConv2d(CategoricalDistribution): 281 | def __init__(self, in_channels, num_class, kernel_size, stride=1, padding=0, hard=False): 282 | super(CategoricalDistributionConv2d, self).__init__(hard=hard) 283 | self.num_class = num_class 284 | 285 | # define net 286 | self.logit_fn = nn.Conv2d(in_channels, num_class, kernel_size, stride, padding) 287 | 288 | def forward(self, input): 289 | logits = self.logit_fn(input) 290 | 291 | return logits 292 | #output = self.gumbel_softmax(logits, temperature, self.hard) 293 | #return logits, output 294 | 295 | def sample_gumbel_softmax(self, logits, temperature=1.0, hard=False): 296 | batch_size = logits.size(0) 297 | n_channels = logits.size(1) 298 | height = logits.size(2) 299 | width = logits.size(3) 300 | 301 | # 4d tensor [b c h w] to 2d tensor [bhw c] 302 | logits = torch.transpose(logits.view(batch_size, n_channels, height*width), 1, 2).view(batch_size*height*width, n_channels) 303 | 304 | # sample 305 | y = super(CategoricalDistributionConv2d, self).sample_gumbel_softmax(logits, temperature, hard) 306 | 307 | # 2d tensor [bhw c] to 4d tensor [b c h w] 308 | y = torch.transpose(y.view(batch_size, height*width, n_channels), 1, 2).view(batch_size, n_channels, height, width) 309 | return y 310 | -------------------------------------------------------------------------------- /models/resdae/mlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from utils import get_nonlinear_func, expand_tensor, sample_laplace_noise, sample_unit_laplace_noise 9 | from models.layers import MLP, WNMLP, Identity 10 | 11 | 12 | def add_gaussian_noise(input, std): 13 | eps = torch.randn_like(input) 14 | return input + std*eps, eps 15 | 16 | def 
add_uniform_noise(input, val): 17 | #raise NotImplementedError 18 | #eps = 2.*val*torch.rand_like(input) - val 19 | eps = torch.rand_like(input) 20 | return input + 2.*val*eps-val, eps 21 | 22 | def add_laplace_noise(input, scale): 23 | eps = sample_unit_laplace_noise(shape=input.size(), dtype=input.dtype, device=input.device) 24 | return input + scale*eps, eps 25 | 26 | 27 | class DAE(nn.Module): 28 | def __init__(self, 29 | input_dim=2, 30 | h_dim=1000, 31 | std=0.1, 32 | num_hidden_layers=1, 33 | nonlinearity='tanh', 34 | noise_type='gaussian', 35 | #init=True, 36 | ): 37 | super().__init__() 38 | self.input_dim = input_dim 39 | self.h_dim = h_dim 40 | self.std = std 41 | self.num_hidden_layers = num_hidden_layers 42 | self.nonlinearity = nonlinearity 43 | self.noise_type = noise_type 44 | 45 | self.main = MLP(input_dim, h_dim, input_dim, use_nonlinearity_output=False, num_hidden_layers=num_hidden_layers, nonlinearity=nonlinearity) 46 | 47 | def add_noise(self, input, std=None): 48 | std = self.std if std is None else std 49 | if self.noise_type == 'gaussian': 50 | return add_gaussian_noise(input, std) 51 | elif self.noise_type == 'uniform': 52 | return add_uniform_noise(input, std) 53 | elif self.noise_type == 'laplace': 54 | return add_laplace_noise(input, std) 55 | else: 56 | raise NotImplementedError 57 | 58 | def loss(self, input, target): 59 | # recon loss (likelihood) 60 | recon_loss = F.mse_loss(input, target)#, reduction='sum') 61 | return recon_loss 62 | 63 | def forward(self, input, std=None): 64 | # init 65 | std = self.std if std is None else std 66 | batch_size = input.size(0) 67 | input = input.view(-1, self.input_dim) 68 | 69 | # add noise 70 | x_bar, eps = self.add_noise(input, std) 71 | 72 | # predict 73 | glogprob = self.main(x_bar) 74 | 75 | ''' get loss ''' 76 | #loss = (std**2)*self.loss(std*glogprob, -eps) 77 | loss = self.loss(std*glogprob, -eps) 78 | 79 | # return 80 | return None, loss 81 | 82 | def glogprob(self, input, std=None): 83 | std = self.std if std is None else std 84 | batch_size = input.size(0) 85 | input = input.view(-1, self.input_dim) 86 | 87 | # predict 88 | glogprob = self.main(input) 89 | 90 | return glogprob 91 | 92 | class ARDAE(nn.Module): 93 | def __init__(self, 94 | input_dim=2, 95 | h_dim=1000, 96 | std=0.1, 97 | num_hidden_layers=1, 98 | nonlinearity='tanh', 99 | noise_type='gaussian', 100 | #init=True, 101 | ): 102 | super().__init__() 103 | self.input_dim = input_dim 104 | self.h_dim = h_dim 105 | self.std = std 106 | self.num_hidden_layers = num_hidden_layers 107 | self.nonlinearity = nonlinearity 108 | self.noise_type = noise_type 109 | #self.init = init 110 | 111 | self.main = MLP(input_dim+1, h_dim, input_dim, use_nonlinearity_output=False, num_hidden_layers=num_hidden_layers, nonlinearity=nonlinearity) 112 | 113 | def add_noise(self, input, std=None): 114 | std = self.std if std is None else std 115 | if self.noise_type == 'gaussian': 116 | return add_gaussian_noise(input, std) 117 | elif self.noise_type == 'uniform': 118 | return add_uniform_noise(input, std) 119 | elif self.noise_type == 'laplace': 120 | return add_laplace_noise(input, std) 121 | else: 122 | raise NotImplementedError 123 | 124 | def loss(self, input, target): 125 | # recon loss (likelihood) 126 | recon_loss = F.mse_loss(input, target)#, reduction='sum') 127 | return recon_loss 128 | 129 | def forward(self, input, std=None): 130 | # init 131 | batch_size = input.size(0) 132 | input = input.view(-1, self.input_dim) 133 | if std is None: 134 | std = 
input.new_zeros(batch_size, 1) 135 | else: 136 | assert torch.is_tensor(std) 137 | 138 | # add noise 139 | x_bar, eps = self.add_noise(input, std) 140 | 141 | # concat 142 | h = torch.cat([x_bar, std], dim=1) 143 | 144 | # predict 145 | glogprob = self.main(h) 146 | 147 | ''' get loss ''' 148 | loss = self.loss(std*glogprob, -eps) 149 | 150 | # return 151 | return None, loss 152 | 153 | def glogprob(self, input, std=None): 154 | batch_size = input.size(0) 155 | input = input.view(-1, self.input_dim) 156 | if std is None: 157 | std = input.new_zeros(batch_size, 1) 158 | else: 159 | assert torch.is_tensor(std) 160 | 161 | # concat 162 | h = torch.cat([input, std], dim=1) 163 | 164 | # predict 165 | glogprob = self.main(h) 166 | 167 | return glogprob 168 | 169 | 170 | class ConditionalDAE(nn.Module): 171 | def __init__(self, 172 | input_dim=2, #10, 173 | h_dim=128, 174 | context_dim=2, 175 | std=0.01, 176 | num_hidden_layers=1, 177 | nonlinearity='tanh', 178 | noise_type='gaussian', 179 | enc_input=True, 180 | enc_ctx=True, 181 | #init=True, 182 | ): 183 | super().__init__() 184 | self.input_dim = input_dim 185 | self.h_dim = h_dim 186 | self.context_dim = context_dim 187 | self.std = std 188 | self.num_hidden_layers = num_hidden_layers 189 | self.nonlinearity = nonlinearity 190 | self.noise_type = noise_type 191 | self.enc_input = enc_input 192 | if self.enc_input: 193 | inp_dim = h_dim 194 | else: 195 | inp_dim = input_dim 196 | self.enc_ctx = enc_ctx 197 | if self.enc_ctx: 198 | ctx_dim = h_dim 199 | else: 200 | ctx_dim = context_dim 201 | #self.init = init 202 | 203 | self.ctx_encode = Identity() if not self.enc_ctx \ 204 | else MLP(context_dim, h_dim, h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 205 | self.inp_encode = Identity() if not self.enc_input \ 206 | else MLP(input_dim, h_dim, h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 207 | self.dae = MLP(inp_dim+ctx_dim, h_dim, input_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers, use_nonlinearity_output=False) 208 | 209 | def reset_parameters(self): 210 | nn.init.normal_(self.dae.fc.weight) 211 | 212 | def add_noise(self, input, std=None): 213 | std = self.std if std is None else std 214 | if self.noise_type == 'gaussian': 215 | return add_gaussian_noise(input, std) 216 | elif self.noise_type == 'uniform': 217 | return add_uniform_noise(input, std) 218 | elif self.noise_type == 'laplace': 219 | return add_laplace_noise(input, std) 220 | else: 221 | raise NotImplementedError 222 | 223 | def loss(self, input, target): 224 | # recon loss (likelihood) 225 | recon_loss = F.mse_loss(input, target)#, reduction='sum') 226 | return recon_loss 227 | 228 | def forward(self, input, context, std=None): 229 | # init 230 | assert input.dim() == 3 # bsz x ssz x x_dim 231 | assert context.dim() == 3 # bsz x 1 x ctx_dim 232 | std = self.std if std is None else std 233 | batch_size = input.size(0) 234 | sample_size = input.size(1) 235 | 236 | # reschape 237 | input = input.view(batch_size*sample_size, self.input_dim) # bsz*ssz x xdim 238 | _, context = expand_tensor(context, sample_size=sample_size, do_unsqueeze=False) # bsz*ssz x xdim 239 | #context = context.view(batch_size*sample_size, -1) # bsz*ssz x xdim 240 | 241 | # add noise 242 | x_bar, eps = self.add_noise(input, std) 243 | 244 | # encode 245 | ctx = self.ctx_encode(context) 246 | inp = self.inp_encode(x_bar) 247 | 248 | # concat 249 | h = torch.cat([inp, ctx], 
dim=1) 250 | 251 | # de-noise with context 252 | glogprob = self.dae(h) 253 | 254 | ''' get loss ''' 255 | #loss = (std**2)*self.loss(std*glogprob, -eps) 256 | loss = self.loss(std*glogprob, -eps) 257 | 258 | # return 259 | return None, loss 260 | 261 | def glogprob(self, input, context, std=None): 262 | # init 263 | assert input.dim() == 3 # bsz x ssz x x_dim 264 | assert context.dim() == 3 # bsz x 1 x ctx_dim 265 | std = self.std if std is None else std 266 | batch_size = input.size(0) 267 | sample_size = input.size(1) 268 | 269 | # reschape 270 | input = input.view(batch_size*sample_size, self.input_dim) # bsz*ssz x xdim 271 | _, context = expand_tensor(context, sample_size=sample_size, do_unsqueeze=False) # bsz*ssz x xdim 272 | #context = context.view(batch_size*sample_size, -1) # bsz*ssz x xdim 273 | 274 | # encode 275 | ctx = self.ctx_encode(context) 276 | inp = self.inp_encode(input) 277 | 278 | # concat 279 | h = torch.cat([inp, ctx], dim=1) 280 | 281 | # de-noise with context 282 | glogprob = self.dae(h) 283 | 284 | return glogprob.view(batch_size, sample_size, self.input_dim) 285 | 286 | class ConditionalARDAE(nn.Module): 287 | def __init__(self, 288 | input_dim=2, #10, 289 | h_dim=128, 290 | context_dim=2, 291 | std=0.01, 292 | num_hidden_layers=1, 293 | nonlinearity='tanh', 294 | noise_type='gaussian', 295 | enc_input=True, 296 | enc_ctx=True, 297 | #init=True, 298 | std_method='default', 299 | ): 300 | super().__init__() 301 | self.input_dim = input_dim 302 | self.h_dim = h_dim 303 | self.context_dim = context_dim 304 | self.std = std 305 | self.num_hidden_layers = num_hidden_layers 306 | self.nonlinearity = nonlinearity 307 | self.noise_type = noise_type 308 | self.enc_input = enc_input 309 | if self.enc_input: 310 | inp_dim = h_dim 311 | else: 312 | inp_dim = input_dim 313 | self.enc_ctx = enc_ctx 314 | if self.enc_ctx: 315 | ctx_dim = h_dim 316 | else: 317 | ctx_dim = context_dim 318 | 319 | self.ctx_encode = Identity() if not self.enc_ctx \ 320 | else MLP(context_dim, h_dim, h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 321 | self.inp_encode = Identity() if not self.enc_input \ 322 | else MLP(input_dim, h_dim, h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 323 | self.dae = MLP(inp_dim+ctx_dim+1, h_dim, input_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers, use_nonlinearity_output=False) 324 | 325 | def reset_parameters(self): 326 | nn.init.normal_(self.dae.fc.weight) 327 | 328 | def add_noise(self, input, std=None): 329 | std = self.std if std is None else std 330 | if self.noise_type == 'gaussian': 331 | return add_gaussian_noise(input, std) 332 | elif self.noise_type == 'uniform': 333 | return add_uniform_noise(input, std) 334 | elif self.noise_type == 'laplace': 335 | return add_laplace_noise(input, std) 336 | else: 337 | raise NotImplementedError 338 | 339 | def loss(self, input, target): 340 | # recon loss (likelihood) 341 | recon_loss = F.mse_loss(input, target)#, reduction='sum') 342 | return recon_loss 343 | 344 | def forward(self, input, context, std=None, scale=None): 345 | # init 346 | assert input.dim() == 3 # bsz x ssz x x_dim 347 | assert context.dim() == 3 # bsz x 1 x ctx_dim 348 | batch_size = input.size(0) 349 | sample_size = input.size(1) 350 | if std is None: 351 | std = input.new_zeros(batch_size, sample_size, 1) 352 | else: 353 | assert torch.is_tensor(std) 354 | if scale is None: 355 | scale = 1. 
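# (Editor's note, a hedged sketch of what the block below appears to do, not part of the original source:)
# the (input, context) pair is perturbed as x_bar = input + std*eps with a per-sample noise
# scale `std`, and that same `std` is concatenated to the encoded input and context before the
# denoiser; the loss then trains std * dae(x_bar, ctx, std) to regress -eps, i.e. the residual
# AR-DAE objective whose rescaled output approximates the score grad_x log q(x | ctx) as std -> 0.
# `scale` is accepted here but does not appear to be used in this forward pass.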
356 | 357 | # reschape 358 | input = input.view(batch_size*sample_size, self.input_dim) # bsz*ssz x xdim 359 | _, context = expand_tensor(context, sample_size=sample_size, do_unsqueeze=False) # bsz*ssz x xdim 360 | #context = context.view(batch_size*sample_size, -1) # bsz*ssz x xdim 361 | std = std.view(batch_size*sample_size, 1) 362 | 363 | # add noise 364 | x_bar, eps = self.add_noise(input, std) 365 | 366 | # encode 367 | ctx = self.ctx_encode(context) 368 | inp = self.inp_encode(x_bar) 369 | 370 | # concat 371 | h = torch.cat([inp, ctx, std], dim=1) 372 | 373 | # de-noise with context 374 | glogprob = self.dae(h) 375 | 376 | ''' get loss ''' 377 | #loss = (std**2)*self.loss(std*glogprob, -eps) 378 | loss = self.loss(std*glogprob, -eps) 379 | 380 | # return 381 | return None, loss 382 | 383 | def glogprob(self, input, context, std=None, scale=None): 384 | # init 385 | assert input.dim() == 3 # bsz x ssz x x_dim 386 | assert context.dim() == 3 # bsz x 1 x ctx_dim 387 | #std = self.std if std is None else std 388 | batch_size = input.size(0) 389 | sample_size = input.size(1) 390 | if std is None: 391 | std = input.new_zeros(batch_size*sample_size, 1) 392 | else: 393 | assert torch.is_tensor(std) 394 | if scale is None: 395 | scale = 1. 396 | 397 | # reschape 398 | input = input.view(batch_size*sample_size, self.input_dim) # bsz*ssz x xdim 399 | _, context = expand_tensor(context, sample_size=sample_size, do_unsqueeze=False) # bsz*ssz x xdim 400 | #context = context.view(batch_size*sample_size, -1) # bsz*ssz x xdim 401 | std = std.view(batch_size*sample_size, 1) 402 | 403 | # encode 404 | ctx = self.ctx_encode(context) 405 | inp = self.inp_encode(input) 406 | 407 | # concat 408 | h = torch.cat([inp, ctx, std], dim=1) 409 | 410 | # de-noise with context 411 | glogprob = self.dae(h) 412 | 413 | return glogprob.view(batch_size, sample_size, self.input_dim) 414 | -------------------------------------------------------------------------------- /models/vae/auxconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autograd 8 | 9 | from models.layers import Identity, MLP 10 | from models.reparam import NormalDistributionLinear, BernoulliDistributionLinear, BernoulliDistributionConvTranspose2d 11 | from utils import loss_kld_gaussian, loss_kld_gaussian_vs_gaussian, loss_recon_gaussian, loss_recon_bernoulli_with_logit, normal_energy_func 12 | from utils import logprob_gaussian 13 | from utils import get_nonlinear_func 14 | from utils import conv_out_size, deconv_out_size 15 | from models.vae.conv import Decoder 16 | 17 | 18 | def weight_init(m): 19 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear): 20 | torch.nn.init.xavier_uniform_(m.weight) 21 | #torch.nn.init.xavier_normal_(m.weight) 22 | if m.bias is not None: 23 | torch.nn.init.zeros_(m.bias) 24 | 25 | def sample_gaussian(mu, logvar, _std=1.): 26 | if _std is None: 27 | _std = 1. 
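# (Editor's note, not part of the original source:) standard reparameterization trick,
# z = mu + _std * exp(0.5 * logvar) * eps with eps ~ N(0, I); `_std` is an extra multiplier on the
# posterior standard deviation (1. by default), presumably to allow scaled or reduced-noise sampling.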
28 | std = _std*torch.exp(0.5*logvar) 29 | eps = torch.randn_like(std) 30 | return mu + std * eps 31 | 32 | class AuxEncoder(nn.Module): 33 | def __init__(self, 34 | input_height=28, 35 | input_channels=1, 36 | z0_dim=32, 37 | nonlinearity='softplus', 38 | ): 39 | super().__init__() 40 | self.input_height = input_height 41 | self.input_channels = input_channels 42 | self.z0_dim = z0_dim 43 | self.nonlinearity = nonlinearity 44 | 45 | s_h = input_height 46 | s_h2 = conv_out_size(s_h, 5, 2, 2) 47 | s_h4 = conv_out_size(s_h2, 5, 2, 2) 48 | s_h8 = conv_out_size(s_h4, 5, 2, 2) 49 | #print(s_h, s_h2, s_h4, s_h8) 50 | 51 | self.afun = get_nonlinear_func(nonlinearity) 52 | self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True) 53 | self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True) 54 | self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True) 55 | self.fc = nn.Linear(s_h8*s_h8*32, 800, bias=True) 56 | self.reparam = NormalDistributionLinear(800, z0_dim) 57 | 58 | def sample(self, mu, logvar, _std=1.): 59 | return sample_gaussian(mu, logvar, _std=_std) 60 | 61 | def forward(self, x, _std=1.): 62 | batch_size = x.size(0) 63 | x = x.view(batch_size, self.input_channels, self.input_height, self.input_height) 64 | 65 | # rescale 66 | x = 2*x -1 67 | 68 | # forward 69 | h1 = self.afun(self.conv1(x)) 70 | h2 = self.afun(self.conv2(h1)) 71 | h3 = self.afun(self.conv3(h2)) 72 | h3 = h3.view(batch_size, -1) 73 | h4 = self.afun(self.fc(h3)) 74 | mu, logvar = self.reparam(h4) 75 | 76 | # sample 77 | z0 = self.sample(mu, logvar, _std=_std) 78 | 79 | return z0, mu, logvar, h4 80 | 81 | class Encoder(nn.Module): 82 | def __init__(self, 83 | input_height=28, 84 | input_channels=1, 85 | z0_dim=100, 86 | z_dim=32, 87 | nonlinearity='softplus', 88 | ): 89 | super().__init__() 90 | self.input_height = input_height 91 | self.input_channels = input_channels 92 | self.z0_dim = z0_dim 93 | self.z_dim = z_dim 94 | self.nonlinearity = nonlinearity 95 | 96 | s_h = input_height 97 | s_h2 = conv_out_size(s_h, 5, 2, 2) 98 | s_h4 = conv_out_size(s_h2, 5, 2, 2) 99 | s_h8 = conv_out_size(s_h4, 5, 2, 2) 100 | #print(s_h, s_h2, s_h4, s_h8) 101 | 102 | self.afun = get_nonlinear_func(nonlinearity) 103 | self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True) 104 | self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True) 105 | self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True) 106 | self.fc = nn.Linear(s_h8*s_h8*32 + z0_dim, 800, bias=True) 107 | self.reparam = NormalDistributionLinear(800, z_dim) 108 | 109 | def sample(self, mu_z, logvar_z): 110 | return self.reparam.sample_gaussian(mu_z, logvar_z) 111 | 112 | def forward(self, x, z0, nz=1): 113 | batch_size = x.size(0) 114 | x = x.view(batch_size, self.input_channels, self.input_height, self.input_height) 115 | assert z0.size(0) == batch_size*nz 116 | 117 | # rescale 118 | x = 2*x -1 119 | 120 | # forward 121 | h1 = self.afun(self.conv1(x)) 122 | h2 = self.afun(self.conv2(h1)) 123 | h3 = self.afun(self.conv3(h2)) 124 | h3 = h3.view(batch_size, -1) 125 | 126 | # view 127 | h3 = h3.unsqueeze(1).expand(-1, nz, -1).contiguous() 128 | h3 = h3.view(batch_size*nz, -1) 129 | 130 | # concat 131 | h3z0 = torch.cat([h3, z0], dim=1) 132 | 133 | # forward 134 | h4 = self.afun(self.fc(h3z0)) 135 | mu, logvar = self.reparam(h4) 136 | 137 | # sample 138 | z = self.sample(mu, logvar) 139 | 140 | return z, mu, logvar, h4 141 | 142 | class AuxDecoder(nn.Module): 143 | def __init__(self, 144 | input_height=28, 145 | input_channels=1, 146 | z_dim=32, 147 | z0_dim=100, 148 | 
nonlinearity='softplus', 149 | ): 150 | super().__init__() 151 | self.input_height = input_height 152 | self.input_channels = input_channels 153 | self.z_dim = z_dim 154 | self.z0_dim = z0_dim 155 | self.nonlinearity = nonlinearity 156 | 157 | s_h = input_height 158 | s_h2 = conv_out_size(s_h, 5, 2, 2) 159 | s_h4 = conv_out_size(s_h2, 5, 2, 2) 160 | s_h8 = conv_out_size(s_h4, 5, 2, 2) 161 | #print(s_h, s_h2, s_h4, s_h8) 162 | 163 | self.afun = get_nonlinear_func(nonlinearity) 164 | self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True) 165 | self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True) 166 | self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True) 167 | self.fc = nn.Linear(s_h8*s_h8*32 + z_dim, 800, bias=True) 168 | self.reparam = NormalDistributionLinear(800, z0_dim) 169 | 170 | def sample(self, mu, logvar): 171 | return self.reparam.sample_gaussian(mu, logvar) 172 | 173 | def forward(self, x, z, nz=1): 174 | batch_size = x.size(0) 175 | x = x.view(batch_size, self.input_channels, self.input_height, self.input_height) 176 | 177 | # rescale 178 | x = 2*x -1 179 | 180 | # forward 181 | h1 = self.afun(self.conv1(x)) 182 | h2 = self.afun(self.conv2(h1)) 183 | h3 = self.afun(self.conv3(h2)) 184 | h3 = h3.view(batch_size, -1) 185 | 186 | # view 187 | assert z.size(0) == batch_size*nz 188 | h3 = h3.unsqueeze(1).expand(-1, nz, -1).contiguous() 189 | h3 = h3.view(batch_size*nz, -1) 190 | 191 | # concat 192 | h3z = torch.cat([h3, z], dim=1) 193 | 194 | # forward 195 | h4 = self.afun(self.fc(h3z)) 196 | mu, logvar = self.reparam(h4) 197 | 198 | # sample 199 | z0 = self.sample(mu, logvar) 200 | 201 | return z0, mu, logvar 202 | 203 | class VAE(nn.Module): 204 | def __init__(self, 205 | energy_func=normal_energy_func, 206 | input_height=28, 207 | input_channels=1, 208 | z0_dim=100, 209 | z_dim=32, 210 | nonlinearity='softplus', 211 | do_xavier=True, 212 | do_m5bias=False, 213 | ): 214 | super().__init__() 215 | self.energy_func = energy_func 216 | self.input_height = input_height 217 | self.input_channels = input_channels 218 | self.z0_dim = z0_dim 219 | self.z_dim = z_dim 220 | self.latent_dim = z_dim # for ais 221 | self.nonlinearity = nonlinearity 222 | self.do_xavier = do_xavier 223 | self.do_m5bias = do_m5bias 224 | 225 | self.aux_encode = AuxEncoder(input_height, input_channels, z0_dim, nonlinearity=nonlinearity) 226 | self.encode = Encoder(input_height, input_channels, z0_dim, z_dim, nonlinearity=nonlinearity) 227 | self.decode = Decoder(input_height, input_channels, z_dim, nonlinearity=nonlinearity) 228 | self.aux_decode = AuxDecoder(input_height, input_channels, z_dim, z0_dim, nonlinearity=nonlinearity) 229 | self.reset_parameters() 230 | 231 | def reset_parameters(self): 232 | if self.do_xavier: 233 | self.apply(weight_init) 234 | if self.do_m5bias: 235 | torch.nn.init.constant_(self.decode.reparam.logit_fn.bias, -5) 236 | 237 | def loss(self, 238 | mu_qz, logvar_qz, 239 | mu_qz0, logvar_qz0, 240 | mu_pz0, logvar_pz0, 241 | logit_px, target_x, 242 | beta=1.0, 243 | ): 244 | # kld loss: log q(z|z0, x) - log p(z) 245 | kld_loss = loss_kld_gaussian(mu_qz, logvar_qz, do_sum=False) 246 | 247 | # aux dec loss: -log r(z0|z,x) 248 | aux_kld_loss = loss_kld_gaussian_vs_gaussian( 249 | mu_qz0, logvar_qz0, 250 | mu_pz0, logvar_pz0, 251 | do_sum=False, 252 | ) 253 | 254 | # recon loss (neg likelihood): -log p(x|z) 255 | recon_loss = loss_recon_bernoulli_with_logit(logit_px, target_x, do_sum=False) 256 | 257 | # add loss 258 | loss = recon_loss + beta*kld_loss + beta*aux_kld_loss 259 | 
return loss.mean(), recon_loss.mean(), kld_loss.mean(), aux_kld_loss.mean() 260 | 261 | def forward(self, input, beta=1.0): 262 | # init 263 | batch_size = input.size(0) 264 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 265 | 266 | # aux encode 267 | z0, mu_qz0, logvar_qz0, _ = self.aux_encode(input) 268 | 269 | # encode 270 | z, mu_qz, logvar_qz, _ = self.encode(input, z0) 271 | 272 | # aux decode 273 | _, mu_pz0, logvar_pz0 = self.aux_decode(input, z) 274 | 275 | # decode 276 | x, logit_px = self.decode(z) 277 | 278 | ''' get loss ''' 279 | loss, recon_loss, kld_loss, aux_kld_loss = self.loss( 280 | mu_qz, logvar_qz, 281 | mu_qz0, logvar_qz0, 282 | mu_pz0, logvar_pz0, 283 | logit_px, input, 284 | beta=beta, 285 | ) 286 | 287 | # return 288 | return x, torch.sigmoid(logit_px), z, loss, recon_loss.detach(), kld_loss.detach()+aux_kld_loss.detach() 289 | 290 | def generate(self, batch_size=1): 291 | # init mu_z and logvar_z (as unit normal dist) 292 | weight = next(self.parameters()) 293 | mu_z = weight.new_zeros(batch_size, self.z_dim) 294 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 295 | 296 | # sample z (from unit normal dist) 297 | z = sample_gaussian(mu_z, logvar_z) # sample z 298 | 299 | # decode 300 | output, logit_px = self.decode(z) 301 | 302 | # return 303 | return output, torch.sigmoid(logit_px), z 304 | 305 | def logprob(self, input, sample_size=128, z=None): 306 | #assert int(math.sqrt(sample_size))**2 == sample_size 307 | # init 308 | batch_size = input.size(0) 309 | sample_size1 = sample_size #int(math.sqrt(sample_size)) 310 | sample_size2 = 1 #int(math.sqrt(sample_size)) 311 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 312 | 313 | ''' get - (log q(z|z0,x) + log q(z0|z) - log p(z0|z,x) - log p(z)) ''' 314 | ''' get log q(z0|x) ''' 315 | _, mu_qz0, logvar_qz0, _ = self.aux_encode(input) 316 | mu_qz0 = mu_qz0.unsqueeze(1).expand(batch_size, sample_size1, self.z0_dim).contiguous().view(batch_size*sample_size1, self.z0_dim) # bsz*ssz1 x z0_dim 317 | logvar_qz0 = logvar_qz0.unsqueeze(1).expand(batch_size, sample_size1, self.z0_dim).contiguous().view(batch_size*sample_size1, self.z0_dim) # bsz*ssz1 x z0_dim 318 | z0 = self.aux_encode.sample(mu_qz0, logvar_qz0) # bsz*ssz1 x z0_dim 319 | log_qz0 = logprob_gaussian(mu_qz0, logvar_qz0, z0, do_unsqueeze=False, do_mean=False) 320 | log_qz0 = torch.sum(log_qz0.view(batch_size, sample_size1, self.z0_dim), dim=2) # bsz x ssz1 321 | log_qz0 = log_qz0.unsqueeze(2).expand(batch_size, sample_size1, sample_size2).contiguous().view(batch_size, sample_size1*sample_size2) # bsz x ssz1*ssz2 322 | 323 | ''' get log q(z|z0,x) ''' 324 | # forward 325 | _, mu_qz, logvar_qz, _ = self.encode(input, z0, nz=sample_size1) # bsz*ssz1 x z_dim 326 | mu_qz = mu_qz.detach().repeat(1, sample_size2).view(batch_size*sample_size1, sample_size2, self.z_dim) 327 | logvar_qz = logvar_qz.detach().repeat(1, sample_size2).view(batch_size*sample_size1, sample_size2, self.z_dim) 328 | z = self.encode.sample(mu_qz, logvar_qz) # bsz x ssz1 x ssz2 x z_dim 329 | log_qz = logprob_gaussian(mu_qz, logvar_qz, z, do_unsqueeze=False, do_mean=False) 330 | log_qz = torch.sum(log_qz.view(batch_size, sample_size1*sample_size2, self.z_dim), dim=2) # bsz x ssz1*ssz2 331 | 332 | ''' get log p(z0|z,x) ''' 333 | # encode 334 | _z0 = z0.unsqueeze(1).expand(batch_size*sample_size1, sample_size2, self.z0_dim).contiguous().view(batch_size, sample_size1, sample_size2, self.z0_dim).detach() 335 | _, 
mu_pz0, logvar_pz0 = self.aux_decode(input, z.view(-1, self.z_dim), nz=sample_size1*sample_size2) # bsz*ssz1 x z_dim 336 | mu_pz0 = mu_pz0.view(batch_size, sample_size1, sample_size2, self.z0_dim) 337 | logvar_pz0 = logvar_pz0.view(batch_size, sample_size1, sample_size2, self.z0_dim) 338 | log_pz0 = logprob_gaussian(mu_pz0, logvar_pz0, _z0, do_unsqueeze=False, do_mean=False) # bsz x ssz1 x ssz2 xz0_dim 339 | log_pz0 = torch.sum(log_pz0.view(batch_size, sample_size1*sample_size2, self.z0_dim), dim=2) # bsz x ssz1*ssz2 340 | 341 | ''' get log p(z) ''' 342 | # get prior (as unit normal dist) 343 | mu_pz = input.new_zeros(batch_size*sample_size1, sample_size2, self.z_dim) 344 | logvar_pz = input.new_zeros(batch_size*sample_size1, sample_size2, self.z_dim) 345 | log_pz = logprob_gaussian(mu_pz, logvar_pz, z, do_unsqueeze=False, do_mean=False) 346 | log_pz = torch.sum(log_pz.view(batch_size, sample_size1*sample_size2, self.z_dim), dim=2) # bsz x ssz1*ssz2 347 | 348 | ''' get log p(x|z) ''' 349 | # decode 350 | _input = input.unsqueeze(1).unsqueeze(1).expand( 351 | batch_size, sample_size1, sample_size2, self.input_channels, self.input_height, self.input_height) # bsz x ssz1 x ssz2 x input_dim 352 | _z = z.view(-1, self.z_dim) 353 | #_, mu_x, logvar_x = self.decode(_z) # bsz*ssz1*ssz2 x zdim 354 | #mu_x = mu_x.view(batch_size, sample_size1, sample_size2, self.input_dim) 355 | #logvar_x = logvar_x.view(batch_size, sample_size1, sample_size2, self.input_dim) 356 | #loglikelihood = logprob_gaussian(mu_x, logvar_x, _input, do_unsqueeze=False, do_mean=False) 357 | _, logit_px = self.decode(_z) # bsz*ssz1*ssz2 x zdim 358 | logit_px = logit_px.view(batch_size, sample_size1, sample_size2, self.input_channels, self.input_height, self.input_height) 359 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_px, _input, reduction='none') 360 | loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size1*sample_size2, -1), dim=2) # bsz x ssz1*ssz2 361 | 362 | ''' get log p(x|z)p(z)/q(z|x) ''' 363 | logprob = loglikelihood + log_pz + log_pz0 - log_qz - log_qz0 # bsz x ssz1*ssz2 364 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 365 | rprob = (logprob - logprob_max).exp() # relative prob 366 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 367 | 368 | # return 369 | return logprob.mean() 370 | -------------------------------------------------------------------------------- /models/vae/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autograd 8 | 9 | from models.layers import MLP 10 | from models.reparam import NormalDistributionLinear, BernoulliDistributionConvTranspose2d 11 | from utils import loss_kld_gaussian, loss_recon_bernoulli_with_logit, normal_energy_func 12 | from utils import logprob_gaussian 13 | from utils import get_nonlinear_func 14 | from utils import conv_out_size, deconv_out_size 15 | 16 | 17 | def weight_init(m): 18 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear): 19 | torch.nn.init.xavier_uniform_(m.weight) 20 | #torch.nn.init.xavier_normal_(m.weight) 21 | if m.bias is not None: 22 | torch.nn.init.zeros_(m.bias) 23 | 24 | def sample_gaussian(mu, logvar): 25 | std = torch.exp(0.5*logvar) 26 | eps = torch.randn_like(std) 27 | return mu + std * eps 28 | 29 | class Encoder(nn.Module): 30 | def __init__(self, 31 | 
input_height=28, 32 | input_channels=1, 33 | z_dim=32, 34 | nonlinearity='softplus', 35 | ): 36 | super().__init__() 37 | self.input_height = input_height 38 | self.input_channels = input_channels 39 | self.z_dim = z_dim 40 | self.nonlinearity = nonlinearity 41 | 42 | s_h = input_height 43 | s_h2 = conv_out_size(s_h, 5, 2, 2) 44 | s_h4 = conv_out_size(s_h2, 5, 2, 2) 45 | s_h8 = conv_out_size(s_h4, 5, 2, 2) 46 | #print(s_h, s_h2, s_h4, s_h8) 47 | #ipdb.set_trace() 48 | 49 | self.afun = get_nonlinear_func(nonlinearity) 50 | self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True) 51 | self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True) 52 | self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True) 53 | self.fc = nn.Linear(s_h8*s_h8*32, 800, bias=True) 54 | self.reparam = NormalDistributionLinear(800, z_dim) 55 | 56 | def sample(self, mu, logvar): 57 | return self.reparam.sample_gaussian(mu, logvar) 58 | 59 | def forward(self, x): 60 | batch_size = x.size(0) 61 | x = x.view(batch_size, self.input_channels, self.input_height, self.input_height) 62 | 63 | # rescale 64 | x = 2*x -1 65 | 66 | # forward 67 | h1 = self.afun(self.conv1(x)) 68 | h2 = self.afun(self.conv2(h1)) 69 | h3 = self.afun(self.conv3(h2)) 70 | h3 = h3.view(batch_size, -1) 71 | h4 = self.afun(self.fc(h3)) 72 | mu, logvar = self.reparam(h4) 73 | 74 | # sample 75 | z = self.sample(mu, logvar) 76 | 77 | return z, mu, logvar 78 | 79 | class Decoder(nn.Module): 80 | def __init__(self, 81 | input_height=28, 82 | input_channels=1, 83 | z_dim=32, 84 | nonlinearity='softplus', 85 | #do_trim=True, 86 | ): 87 | super().__init__() 88 | self.input_height = input_height 89 | self.input_channels = input_channels 90 | self.z_dim = z_dim 91 | self.nonlinearity = nonlinearity 92 | #self.do_trim = do_trim 93 | 94 | s_h = input_height 95 | s_h2 = conv_out_size(s_h, 5, 2, 2) 96 | s_h4 = conv_out_size(s_h2, 5, 2, 2) 97 | s_h8 = conv_out_size(s_h4, 5, 2, 2) 98 | #print(s_h, s_h2, s_h4, s_h8) 99 | #_s_h8 = s_h8 100 | #_s_h4 = deconv_out_size(_s_h8, 5, 2, 2, 0) 101 | #_s_h2 = deconv_out_size(_s_h4+1, 5, 2, 2, 0) 102 | #_s_h = deconv_out_size(_s_h2, 5, 2, 2, 0) 103 | #if self.do_trim: 104 | #else: 105 | # _s_h = deconv_out_size(_s_h2, 5, 2, 2, 1) 106 | #print(_s_h, _s_h2, _s_h4, _s_h8) 107 | #ipdb.set_trace() 108 | self.s_h8 = s_h8 109 | 110 | self.afun = get_nonlinear_func(nonlinearity) 111 | self.fc = MLP(input_dim=z_dim, hidden_dim=300, output_dim=s_h8*s_h8*32, nonlinearity=nonlinearity, num_hidden_layers=1, use_nonlinearity_output=True) 112 | self.deconv1 = nn.ConvTranspose2d(32, 32, 5, 2, 2, 0, bias=True) 113 | self.pad1 = nn.ZeroPad2d((0, 1, 0, 1)) 114 | self.deconv2 = nn.ConvTranspose2d(32, 16, 5, 2, 2, 0, bias=True) 115 | self.reparam = BernoulliDistributionConvTranspose2d(16, self.input_channels, 5, 2, 2, 0, bias=True) 116 | self.padr = nn.ZeroPad2d((0, -1, 0, -1)) 117 | 118 | def sample(self, logit): 119 | return self.reparam.sample_logistic_sigmoid(logit) 120 | 121 | def forward(self, z): 122 | batch_size = z.size(0) 123 | z = z.view(batch_size, -1) 124 | 125 | # forward 126 | h1 = self.fc(z) 127 | h1 = h1.view(batch_size, 32, self.s_h8, self.s_h8) 128 | h2 = self.pad1(self.afun(self.deconv1(h1))) 129 | h3 = self.afun(self.deconv2(h2)) 130 | logit = self.reparam(h3) 131 | logit = self.padr(logit) 132 | 133 | # sample 134 | x = self.sample(logit) 135 | 136 | return x, logit 137 | 138 | class VAE(nn.Module): 139 | def __init__(self, 140 | energy_func=normal_energy_func, 141 | input_height=28, 142 | input_channels=1, 143 | z_dim=32, 144 | 
nonlinearity='softplus', 145 | do_xavier=False, 146 | do_m5bias=False, 147 | #do_trim=True, 148 | ): 149 | super().__init__() 150 | self.energy_func = energy_func 151 | self.input_height = input_height 152 | self.input_channels = input_channels 153 | self.z_dim = z_dim 154 | self.latent_dim = self.z_dim # for ais 155 | self.nonlinearity = nonlinearity 156 | self.do_xavier = do_xavier 157 | self.do_m5bias = do_m5bias 158 | #self.do_trim = do_trim 159 | 160 | self.encode = Encoder(input_height, input_channels, z_dim, nonlinearity=nonlinearity) 161 | self.decode = Decoder(input_height, input_channels, z_dim, nonlinearity=nonlinearity)#, do_trim=do_trim) 162 | self.reset_parameters() 163 | 164 | def reset_parameters(self): 165 | if self.do_xavier: 166 | self.apply(weight_init) 167 | if self.do_m5bias: 168 | torch.nn.init.constant_(self.decode.reparam.logit_fn.bias, -5) 169 | 170 | def loss(self, mu_z, logvar_z, logit_x, target_x, beta=1.0): 171 | # kld loss 172 | kld_loss = loss_kld_gaussian(mu_z, logvar_z, do_sum=False) 173 | 174 | # recon loss (likelihood) 175 | recon_loss = loss_recon_bernoulli_with_logit(logit_x, target_x, do_sum=False) 176 | 177 | # add loss 178 | loss = recon_loss + beta*kld_loss 179 | return loss.mean(), recon_loss.mean(), kld_loss.mean() 180 | 181 | def forward(self, input, beta=1.0): 182 | # init 183 | batch_size = input.size(0) 184 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 185 | 186 | # encode 187 | z, mu_z, logvar_z = self.encode(input) 188 | 189 | # decode 190 | x, logit_x = self.decode(z) 191 | 192 | # loss 193 | loss, recon_loss, kld_loss \ 194 | = self.loss(mu_z, logvar_z, 195 | logit_x, 196 | input, 197 | beta=beta, 198 | ) 199 | 200 | # return 201 | return x, torch.sigmoid(logit_x), z, loss, recon_loss.detach(), kld_loss.detach() 202 | 203 | def generate(self, batch_size=1): 204 | # init mu_z and logvar_z (as unit normal dist) 205 | weight = next(self.parameters()) 206 | mu_z = weight.new_zeros(batch_size, self.z_dim) 207 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 208 | 209 | # sample z (from unit normal dist) 210 | z = sample_gaussian(mu_z, logvar_z) # sample z 211 | 212 | # decode 213 | output, logit_x = self.decode(z) 214 | 215 | # return 216 | return output, torch.sigmoid(logit_x), z 217 | 218 | def logprob(self, input, sample_size=128, z=None): 219 | ''' 220 | input: positive samples 221 | ''' 222 | # init 223 | batch_size = input.size(0) 224 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 225 | 226 | ''' get log q(z|x) ''' 227 | _, mu_qz, logvar_qz = self.encode(input) 228 | mu_qz = mu_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 229 | logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 230 | z = self.encode.sample(mu_qz, logvar_qz) 231 | logposterior = logprob_gaussian(mu_qz, logvar_qz, z, do_unsqueeze=False, do_mean=False) 232 | logposterior = torch.sum(logposterior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 233 | 234 | ''' get log p(z) ''' 235 | # get prior (as unit normal dist) 236 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 237 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 238 | logprior = logprob_gaussian(mu_pz, logvar_pz, z, do_unsqueeze=False, do_mean=False) 239 | logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 240 | 241 | ''' get log p(x|z) ''' 242 | # decode 243 | 
logit_x = [] 244 | #for i in range(sample_size): 245 | for i in range(batch_size): 246 | _, _logit_x = self.decode(z[i, :, :]) # ssz x zdim 247 | logit_x += [_logit_x.detach().unsqueeze(0)] 248 | logit_x = torch.cat(logit_x, dim=0) # bsz x ssz x input_dim 249 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_channels, self.input_height, self.input_height) # bsz x ssz x input_dim 250 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 251 | loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size, -1), dim=2) # bsz x ssz 252 | 253 | ''' get log p(x|z)p(z)/q(z|x) ''' 254 | logprob = loglikelihood + logprior - logposterior # bsz x ssz 255 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 256 | rprob = (logprob - logprob_max).exp() # relative prob 257 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 258 | 259 | # return 260 | return logprob.mean() 261 | 262 | def logprob_w_prior(self, input, sample_size=128, z=None): 263 | ''' 264 | input: positive samples 265 | ''' 266 | # init 267 | batch_size = input.size(0) 268 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 269 | 270 | ''' get z samples from p(z) ''' 271 | # get prior (as unit normal dist) 272 | if z is None: 273 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 274 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 275 | z = sample_gaussian(mu_pz, logvar_pz) # sample z 276 | 277 | ''' get log p(x|z) ''' 278 | # decode 279 | logit_x = [] 280 | for i in range(sample_size): 281 | _, _logit_x = self.decode(z[:, i, :]) 282 | logit_x += [_logit_x.detach().unsqueeze(1)] 283 | logit_x = torch.cat(logit_x, dim=1) # bsz x ssz x input_dim 284 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_channels, self.input_height, self.input_height) # bsz x ssz x input_dim 285 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 286 | loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size, -1), dim=2) # bsz x ssz 287 | 288 | ''' get log p(x) ''' 289 | logprob = loglikelihood # bsz x ssz 290 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 291 | rprob = (logprob-logprob_max).exp() # relative prob 292 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 293 | 294 | # return 295 | return logprob.mean() 296 | -------------------------------------------------------------------------------- /models/vae/mnist.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autograd 8 | 9 | from models.layers import MLP 10 | from models.reparam import NormalDistributionLinear, BernoulliDistributionLinear 11 | from utils import loss_kld_gaussian, loss_recon_bernoulli_with_logit, normal_energy_func 12 | from utils import logprob_gaussian 13 | from utils import get_nonlinear_func 14 | 15 | 16 | def weight_init(m): 17 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear): 18 | torch.nn.init.xavier_uniform_(m.weight) 19 | #torch.nn.init.xavier_normal_(m.weight) 20 | if m.bias is not None: 21 | torch.nn.init.zeros_(m.bias) 22 | 23 | def sample_gaussian(mu, logvar): 24 | std = torch.exp(0.5*logvar) 25 | eps = torch.randn_like(std) 26 | return mu + std * eps 27 | 28 | class 
Encoder(nn.Module): 29 | def __init__(self, 30 | input_dim=784, 31 | h_dim=300, 32 | z_dim=32, 33 | nonlinearity='softplus', 34 | num_hidden_layers=2, 35 | ): 36 | super().__init__() 37 | self.input_dim = input_dim 38 | self.h_dim = h_dim 39 | self.z_dim = z_dim 40 | self.nonlinearity = nonlinearity 41 | self.num_hidden_layers = num_hidden_layers 42 | 43 | self.main = MLP(input_dim=input_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 44 | self.reparam = NormalDistributionLinear(h_dim, z_dim) 45 | 46 | def sample(self, mu, logvar): 47 | return self.reparam.sample_gaussian(mu, logvar) 48 | 49 | def forward(self, x): 50 | batch_size = x.size(0) 51 | x = x.view(batch_size, self.input_dim) 52 | 53 | # rescale 54 | x = 2*x -1 55 | 56 | # forward 57 | h = self.main(x) 58 | mu, logvar = self.reparam(h) 59 | 60 | # sample 61 | z = self.sample(mu, logvar) 62 | 63 | return z, mu, logvar 64 | 65 | class Decoder(nn.Module): 66 | def __init__(self, 67 | input_dim=784, 68 | h_dim=300, 69 | z_dim=32, 70 | nonlinearity='softplus', 71 | num_hidden_layers=2, 72 | ): 73 | super().__init__() 74 | self.input_dim = input_dim 75 | self.h_dim = h_dim 76 | self.z_dim = z_dim 77 | self.nonlinearity = nonlinearity 78 | self.num_hidden_layers = num_hidden_layers 79 | 80 | self.main = MLP(input_dim=z_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 81 | self.reparam = BernoulliDistributionLinear(h_dim, input_dim) 82 | 83 | def sample(self, logit): 84 | return self.reparam.sample_logistic_sigmoid(logit) 85 | 86 | def forward(self, z): 87 | batch_size = z.size(0) 88 | z = z.view(batch_size, -1) 89 | 90 | # forward 91 | h = self.main(z) 92 | logit = self.reparam(h) 93 | 94 | # sample 95 | x = self.sample(logit) 96 | 97 | return x, logit 98 | 99 | class VAE(nn.Module): 100 | def __init__(self, 101 | energy_func=normal_energy_func, 102 | input_dim=784, 103 | h_dim=300, 104 | z_dim=32, 105 | nonlinearity='softplus', 106 | num_hidden_layers=2, 107 | do_xavier=False, 108 | do_m5bias=False, 109 | ): 110 | super().__init__() 111 | self.energy_func = energy_func 112 | self.input_dim = input_dim 113 | self.h_dim = h_dim 114 | self.z_dim = z_dim 115 | self.latent_dim = self.z_dim # for ais 116 | self.nonlinearity = nonlinearity 117 | self.num_hidden_layers = num_hidden_layers 118 | self.do_xavier = do_xavier 119 | self.do_m5bias = do_m5bias 120 | 121 | self.encode = Encoder(input_dim, h_dim, z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers) 122 | self.decode = Decoder(input_dim, h_dim, z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers) 123 | self.reset_parameters() 124 | 125 | def reset_parameters(self): 126 | if self.do_xavier: 127 | self.apply(weight_init) 128 | if self.do_m5bias: 129 | torch.nn.init.constant_(self.decode.reparam.logit_fn.bias, -5) 130 | 131 | def loss(self, mu_z, logvar_z, logit_x, target_x, beta=1.0): 132 | # kld loss 133 | kld_loss = loss_kld_gaussian(mu_z, logvar_z, do_sum=False) 134 | 135 | # recon loss (likelihood) 136 | recon_loss = loss_recon_bernoulli_with_logit(logit_x, target_x.view(-1, self.input_dim), do_sum=False) 137 | 138 | # add loss 139 | loss = recon_loss + beta*kld_loss 140 | return loss.mean(), recon_loss.mean(), kld_loss.mean() 141 | 142 | def forward(self, input, beta=1.0): 143 | # init 144 | batch_size = input.size(0) 145 | input = input.view(batch_size, self.input_dim) 146 
| 147 | # encode 148 | z, mu_z, logvar_z = self.encode(input) 149 | 150 | # decode 151 | x, logit_x = self.decode(z) 152 | 153 | # loss 154 | loss, recon_loss, kld_loss \ 155 | = self.loss(mu_z, logvar_z, 156 | logit_x, 157 | input, 158 | beta=beta, 159 | ) 160 | 161 | # return 162 | return x, torch.sigmoid(logit_x), z, loss, recon_loss.detach(), kld_loss.detach() 163 | 164 | def generate(self, batch_size=1): 165 | # init mu_z and logvar_z (as unit normal dist) 166 | weight = next(self.parameters()) 167 | mu_z = weight.new_zeros(batch_size, self.z_dim) 168 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 169 | 170 | # sample z (from unit normal dist) 171 | z = sample_gaussian(mu_z, logvar_z) # sample z 172 | 173 | # decode 174 | output, logit_x = self.decode(z) 175 | 176 | # return 177 | return output, torch.sigmoid(logit_x), z 178 | 179 | def logprob(self, input, sample_size=128, z=None): 180 | # init 181 | batch_size = input.size(0) 182 | input = input.view(batch_size, self.input_dim) 183 | 184 | ''' get log q(z|x) ''' 185 | _, mu_qz, logvar_qz = self.encode(input) 186 | mu_qz = mu_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 187 | logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 188 | z = self.encode.sample(mu_qz, logvar_qz) 189 | logposterior = logprob_gaussian(mu_qz, logvar_qz, z, do_unsqueeze=False, do_mean=False) 190 | logposterior = torch.sum(logposterior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 191 | 192 | ''' get log p(z) ''' 193 | # get prior (as unit normal dist) 194 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 195 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 196 | logprior = logprob_gaussian(mu_pz, logvar_pz, z, do_unsqueeze=False, do_mean=False) 197 | logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 198 | 199 | ''' get log p(x|z) ''' 200 | # decode 201 | #logit_x = [] 202 | #for i in range(batch_size): 203 | # _, _logit_x = self.decode(z[i, :, :]) # ssz x zdim 204 | # logit_x += [_logit_x.detach().unsqueeze(0)] 205 | #logit_x = torch.cat(logit_x, dim=0) # bsz x ssz x input_dim 206 | _z = z.view(-1, self.z_dim) 207 | _, logit_x = self.decode(_z) # bsz*ssz x zdim 208 | logit_x = logit_x.view(batch_size, sample_size, self.input_dim) 209 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_dim) # bsz x ssz x input_dim 210 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 211 | loglikelihood = torch.sum(loglikelihood, dim=2) # bsz x ssz 212 | 213 | ''' get log p(x|z)p(z)/q(z|x) ''' 214 | logprob = loglikelihood + logprior - logposterior # bsz x ssz 215 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 216 | rprob = (logprob - logprob_max).exp() # relative prob 217 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 218 | 219 | # return 220 | return logprob.mean() 221 | 222 | def logprob_w_prior(self, input, sample_size=128, z=None): 223 | # init 224 | batch_size = input.size(0) 225 | input = input.view(batch_size, self.input_dim) 226 | 227 | ''' get z samples from p(z) ''' 228 | # get prior (as unit normal dist) 229 | if z is None: 230 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 231 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 232 | z = sample_gaussian(mu_pz, logvar_pz) # sample z 233 | 234 | ''' get log p(x|z) ''' 235 | # decode 236 | _z = 
z.view(-1, self.z_dim) 237 | #logit_x = [] 238 | #for i in range(sample_size): 239 | # _, _logit_x = self.decode(z[:, i, :]) 240 | # logit_x += [_logit_x.detach().unsqueeze(1)] 241 | #logit_x = torch.cat(logit_x, dim=1) # bsz x ssz x input_dim 242 | _, logit_x = self.decode(_z) # bsz*ssz x zdim 243 | logit_x = logit_x.view(batch_size, sample_size, self.input_dim) 244 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_dim) # bsz x ssz x input_dim 245 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 246 | loglikelihood = torch.sum(loglikelihood, dim=2) # bsz x ssz 247 | 248 | ''' get log p(x) ''' 249 | logprob = loglikelihood # bsz x ssz 250 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 251 | rprob = (logprob-logprob_max).exp() # relative prob 252 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 253 | 254 | # return 255 | return logprob.mean() 256 | -------------------------------------------------------------------------------- /models/vae/resconv.py: -------------------------------------------------------------------------------- 1 | ''' 2 | copied and modified from https://github.com/CW-Huang/torchkit/blob/master/torchkit/autoencoders.py#L20-L70 3 | ''' 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import autograd 11 | 12 | from models.layers import MLP 13 | from models.reparam import NormalDistributionLinear, BernoulliDistribution#, BernoulliDistributionConvTranspose2d 14 | from utils import loss_kld_gaussian, loss_recon_bernoulli_with_logit, normal_energy_func 15 | from utils import logprob_gaussian 16 | from utils import get_nonlinear_func 17 | from utils import conv_out_size, deconv_out_size 18 | import models.layers2 as nn_ 19 | 20 | 21 | def sample_gaussian(mu, logvar): 22 | std = torch.exp(0.5*logvar) 23 | eps = torch.randn_like(std) 24 | return mu + std * eps 25 | 26 | class Encoder(nn.Module): 27 | def __init__(self, 28 | z_dim=32, 29 | c_dim=450, 30 | act=nn.ELU(), 31 | do_center=False, 32 | ): 33 | super().__init__() 34 | self.z_dim = z_dim 35 | self.c_dim = c_dim 36 | self.do_center = do_center 37 | 38 | self.enc = nn.Sequential( 39 | nn_.ResConv2d(1,16,3,2,padding=1,activation=act), 40 | act, 41 | nn_.ResConv2d(16,16,3,1,padding=1,activation=act), 42 | act, 43 | nn_.ResConv2d(16,32,3,2,padding=1,activation=act), 44 | act, 45 | nn_.ResConv2d(32,32,3,1,padding=1,activation=act), 46 | act, 47 | nn_.ResConv2d(32,32,3,2,padding=1,activation=act), 48 | act, 49 | nn_.Reshape((-1,32*4*4)), 50 | nn_.ResLinear(32*4*4,c_dim), 51 | act 52 | ) 53 | self.reparam = NormalDistributionLinear(c_dim, z_dim) 54 | 55 | def sample(self, mu, logvar): 56 | return self.reparam.sample_gaussian(mu, logvar) 57 | 58 | def forward(self, x): 59 | batch_size = x.size(0) 60 | x = x.view(batch_size, 1, 28, 28) 61 | 62 | # rescale 63 | if self.do_center: 64 | x = 2*x -1 65 | 66 | # enc 67 | ctx = self.enc(x) 68 | mu, logvar = self.reparam(ctx) 69 | 70 | # sample 71 | z = self.sample(mu, logvar) 72 | 73 | return z, mu, logvar 74 | 75 | class Decoder(nn.Module): 76 | def __init__(self, 77 | z_dim=32, 78 | c_dim=450, 79 | act=nn.ELU(), 80 | do_m5bias=False, 81 | ): 82 | super().__init__() 83 | self.z_dim = z_dim 84 | self.c_dim = c_dim 85 | self.do_m5bias = do_m5bias 86 | 87 | self.dec = nn.Sequential( 88 | nn_.ResLinear(z_dim,c_dim), 89 | act, 90 | nn_.ResLinear(c_dim,32*4*4), 91 | act, 92 | 
nn_.Reshape((-1,32,4,4)), 93 | nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True), 94 | nn_.ResConv2d(32,32,3,1,padding=1,activation=act), 95 | act, 96 | nn_.ResConv2d(32,32,3,1,padding=1,activation=act), 97 | act, 98 | nn_.slicer[:,:,:-1,:-1], 99 | nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True), 100 | nn_.ResConv2d(32,16,3,1,padding=1,activation=act), 101 | act, 102 | nn_.ResConv2d(16,16,3,1,padding=1,activation=act), 103 | act, 104 | nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True), 105 | nn_.ResConv2d(16,1,3,1,padding=1,activation=act), 106 | ) 107 | if self.do_m5bias: 108 | self.dec[-1].conv_01.bias.data.normal_(-3, 0.0001) 109 | self.reparam = BernoulliDistribution() 110 | 111 | def sample(self, logit): 112 | return self.reparam.sample_logistic_sigmoid(logit) 113 | 114 | def forward(self, input): 115 | logit = self.dec(input) 116 | 117 | # sample 118 | x = self.sample(logit) 119 | 120 | return x, logit 121 | 122 | class VAE(nn.Module): 123 | def __init__(self, 124 | energy_func=normal_energy_func, 125 | input_height=28, 126 | input_channels=1, 127 | z_dim=32, 128 | c_dim=450, 129 | nonlinearity='elu', 130 | do_center=False, 131 | do_m5bias=False, 132 | ): 133 | super().__init__() 134 | self.energy_func = energy_func 135 | self.input_height = input_height 136 | self.input_channels = input_channels 137 | self.z_dim = z_dim 138 | self.latent_dim = self.z_dim # for ais 139 | self.nonlinearity = nonlinearity 140 | self.do_center = do_center 141 | self.do_m5bias = do_m5bias 142 | 143 | assert input_height == 28 144 | assert input_channels == 1 145 | assert nonlinearity == 'elu' 146 | 147 | self.encode = Encoder(z_dim=z_dim, c_dim=c_dim, act=nn.ELU(), do_center=do_center) 148 | self.decode = Decoder(z_dim=z_dim, c_dim=c_dim, act=nn.ELU(), do_m5bias=do_m5bias) 149 | 150 | def loss(self, mu_z, logvar_z, logit_x, target_x, beta=1.0): 151 | # kld loss 152 | kld_loss = loss_kld_gaussian(mu_z, logvar_z, do_sum=False) 153 | 154 | # recon loss (likelihood) 155 | recon_loss = loss_recon_bernoulli_with_logit(logit_x, target_x, do_sum=False) 156 | 157 | # add loss 158 | loss = recon_loss + beta*kld_loss 159 | return loss.mean(), recon_loss.mean(), kld_loss.mean() 160 | 161 | def forward(self, input, beta=1.0): 162 | # init 163 | batch_size = input.size(0) 164 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 165 | 166 | # encode 167 | z, mu_z, logvar_z = self.encode(input) 168 | 169 | # decode 170 | x, logit_x = self.decode(z) 171 | 172 | # loss 173 | loss, recon_loss, kld_loss \ 174 | = self.loss(mu_z, logvar_z, 175 | logit_x, 176 | input, 177 | beta=beta, 178 | ) 179 | 180 | # return 181 | return x, torch.sigmoid(logit_x), z, loss, recon_loss.detach(), kld_loss.detach() 182 | 183 | def generate(self, batch_size=1): 184 | # init mu_z and logvar_z (as unit normal dist) 185 | weight = next(self.parameters()) 186 | mu_z = weight.new_zeros(batch_size, self.z_dim) 187 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 188 | 189 | # sample z (from unit normal dist) 190 | z = sample_gaussian(mu_z, logvar_z) # sample z 191 | 192 | # decode 193 | output, logit_x = self.decode(z) 194 | 195 | # return 196 | return output, torch.sigmoid(logit_x), z 197 | 198 | def logprob(self, input, sample_size=128, z=None): 199 | ''' 200 | input: positive samples 201 | ''' 202 | # init 203 | batch_size = input.size(0) 204 | input = input.view(batch_size, self.input_channels, self.input_height, self.input_height) 205 | 206 | ''' get log 
q(z|x) ''' 207 | _, mu_qz, logvar_qz = self.encode(input) 208 | mu_qz = mu_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 209 | logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 210 | z = self.encode.sample(mu_qz, logvar_qz) 211 | logposterior = logprob_gaussian(mu_qz, logvar_qz, z, do_unsqueeze=False, do_mean=False) 212 | logposterior = torch.sum(logposterior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 213 | 214 | ''' get log p(z) ''' 215 | # get prior (as unit normal dist) 216 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 217 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 218 | logprior = logprob_gaussian(mu_pz, logvar_pz, z, do_unsqueeze=False, do_mean=False) 219 | logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 220 | 221 | ''' get log p(x|z) ''' 222 | # decode 223 | logit_x = [] 224 | #for i in range(sample_size): 225 | for i in range(batch_size): 226 | _, _logit_x = self.decode(z[i, :, :]) # ssz x zdim 227 | logit_x += [_logit_x.detach().unsqueeze(0)] 228 | logit_x = torch.cat(logit_x, dim=0) # bsz x ssz x input_dim 229 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_channels, self.input_height, self.input_height) # bsz x ssz x input_dim 230 | loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input, reduction='none') 231 | loglikelihood = torch.sum(loglikelihood.view(batch_size, sample_size, -1), dim=2) # bsz x ssz 232 | 233 | ''' get log p(x|z)p(z)/q(z|x) ''' 234 | logprob = loglikelihood + logprior - logposterior # bsz x ssz 235 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 236 | rprob = (logprob - logprob_max).exp() # relative prob 237 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 238 | 239 | # return 240 | return logprob.mean() 241 | -------------------------------------------------------------------------------- /models/vae/toy.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import autograd 8 | 9 | from models.layers import MLP 10 | from models.reparam import NormalDistributionLinear 11 | from utils import loss_kld_gaussian, loss_recon_gaussian, normal_energy_func 12 | from utils import logprob_gaussian 13 | from utils import get_nonlinear_func 14 | 15 | 16 | def sample_gaussian(mu, logvar): 17 | std = torch.exp(0.5*logvar) 18 | eps = torch.randn_like(std) 19 | return mu + std * eps 20 | 21 | class Encoder(nn.Module): 22 | def __init__(self, 23 | input_dim=2, 24 | h_dim=64, 25 | z_dim=2, 26 | nonlinearity='softplus', 27 | num_hidden_layers=1, 28 | ): 29 | super().__init__() 30 | self.input_dim = input_dim 31 | self.h_dim = h_dim 32 | self.z_dim = z_dim 33 | self.nonlinearity = nonlinearity 34 | self.num_hidden_layers = num_hidden_layers 35 | 36 | self.main = MLP(input_dim=input_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 37 | self.reparam = NormalDistributionLinear(h_dim, z_dim) 38 | 39 | def sample(self, mu, logvar): 40 | return self.reparam.sample_gaussian(mu, logvar) 41 | 42 | def forward(self, x): 43 | batch_size = x.size(0) 44 | x = x.view(batch_size, self.input_dim) 45 | 46 | # forward 47 | h = self.main(x) 48 | mu, logvar = self.reparam(h) 49 
| 50 | # sample 51 | z = self.sample(mu, logvar) 52 | 53 | return z, mu, logvar 54 | 55 | class Decoder(nn.Module): 56 | def __init__(self, 57 | input_dim=2, 58 | h_dim=64, 59 | z_dim=2, 60 | nonlinearity='tanh', 61 | num_hidden_layers=1, 62 | init='gaussian', #None, 63 | ): 64 | super().__init__() 65 | self.input_dim = input_dim 66 | self.h_dim = h_dim 67 | self.z_dim = z_dim 68 | self.nonlinearity = nonlinearity 69 | self.num_hidden_layers = num_hidden_layers 70 | self.init = init 71 | 72 | self.main = MLP(input_dim=z_dim, hidden_dim=h_dim, output_dim=h_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers-1, use_nonlinearity_output=True) 73 | self.reparam = NormalDistributionLinear(h_dim, input_dim) 74 | 75 | if self.init == 'gaussian': 76 | self.reset_parameters() 77 | else: 78 | pass 79 | 80 | def reset_parameters(self): 81 | nn.init.normal_(self.reparam.mean_fn.weight) 82 | 83 | def sample(self, mu, logvar): 84 | return self.reparam.sample_gaussian(mu, logvar) 85 | 86 | def forward(self, z): 87 | batch_size = z.size(0) 88 | z = z.view(batch_size, -1) 89 | 90 | # forward 91 | h = self.main(z) 92 | mu, logvar = self.reparam(h) 93 | 94 | # sample 95 | x = self.sample(mu, logvar) 96 | 97 | return x, mu, logvar 98 | 99 | class VAE(nn.Module): 100 | def __init__(self, 101 | energy_func=normal_energy_func, 102 | input_dim=2, 103 | h_dim=64, 104 | z_dim=2, 105 | nonlinearity='softplus', 106 | num_hidden_layers=1, 107 | init='gaussian', #None, 108 | ): 109 | super().__init__() 110 | self.energy_func = energy_func 111 | self.input_dim = input_dim 112 | self.h_dim = h_dim 113 | self.z_dim = z_dim 114 | self.latent_dim = self.z_dim # for ais 115 | self.nonlinearity = nonlinearity 116 | self.num_hidden_layers = num_hidden_layers 117 | self.init = init 118 | 119 | self.encode = Encoder(input_dim, h_dim, z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers) 120 | self.decode = Decoder(input_dim, h_dim, z_dim, nonlinearity=nonlinearity, num_hidden_layers=num_hidden_layers) 121 | 122 | def loss(self, mu_z, logvar_z, mu_x, logvar_x, target_x, beta=1.0): 123 | # kld loss 124 | kld_loss = loss_kld_gaussian(mu_z, logvar_z, do_sum=False) 125 | 126 | # recon loss (likelihood) 127 | recon_loss = loss_recon_gaussian(mu_x, logvar_x, target_x.view(-1, 2), do_sum=False) 128 | 129 | # add loss 130 | loss = recon_loss + beta*kld_loss 131 | return loss.mean(), recon_loss.mean(), kld_loss.mean() 132 | 133 | def forward(self, input, beta=1.0): 134 | # init 135 | batch_size = input.size(0) 136 | input = input.view(batch_size, self.input_dim) 137 | 138 | # encode 139 | z, mu_z, logvar_z = self.encode(input) 140 | 141 | # decode 142 | x, mu_x, logvar_x = self.decode(z) 143 | 144 | # loss 145 | loss, recon_loss, kld_loss \ 146 | = self.loss(mu_z, logvar_z, 147 | mu_x, logvar_x, input, 148 | beta=beta, 149 | ) 150 | 151 | # return 152 | return x, mu_x, z, loss, recon_loss.detach(), kld_loss.detach() 153 | 154 | def generate(self, batch_size=1): 155 | # init mu_z and logvar_z (as unit normal dist) 156 | weight = next(self.parameters()) 157 | mu_z = weight.new_zeros(batch_size, self.z_dim) 158 | logvar_z = weight.new_zeros(batch_size, self.z_dim) 159 | 160 | # sample z (from unit normal dist) 161 | z = sample_gaussian(mu_z, logvar_z) # sample z 162 | 163 | # decode 164 | output, mu_x, logvar_x = self.decode(z) 165 | 166 | # return 167 | return output, mu_x, z 168 | 169 | def logprob(self, input, sample_size=128, z=None): 170 | # init 171 | batch_size = input.size(0) 172 | input = 
input.view(batch_size, self.input_dim) 173 | 174 | ''' get log q(z|x) ''' 175 | _, mu_qz, logvar_qz = self.encode(input) 176 | mu_qz = mu_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 177 | logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(batch_size, sample_size, self.z_dim) 178 | z = self.encode.sample(mu_qz, logvar_qz) 179 | logposterior = logprob_gaussian(mu_qz, logvar_qz, z, do_unsqueeze=False, do_mean=False) 180 | logposterior = torch.sum(logposterior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 181 | 182 | ''' get log p(z) ''' 183 | # get prior (as unit normal dist) 184 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 185 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 186 | logprior = logprob_gaussian(mu_pz, logvar_pz, z, do_unsqueeze=False, do_mean=False) 187 | logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim), dim=2) # bsz x ssz 188 | 189 | ''' get log p(x|z) ''' 190 | # decode 191 | #mu_x, logvar_x = [], [] 192 | #for i in range(batch_size): 193 | # _, _mu_x, _logvar_x = self.decode(z[i, :, :]) # ssz x zdim 194 | # mu_x += [_mu_x.detach().unsqueeze(0)] 195 | # logvar_x += [_logvar_x.detach().unsqueeze(0)] 196 | #mu_x = torch.cat(mu_x, dim=0) # bsz x ssz x input_dim 197 | #logvar_x = torch.cat(logvar_x, dim=0) # bsz x ssz x input_dim 198 | _z = z.view(-1, self.z_dim) 199 | _, mu_x, logvar_x = self.decode(_z) # bsz*ssz x zdim 200 | mu_x = mu_x.view(batch_size, sample_size, self.input_dim) 201 | logvar_x = logvar_x.view(batch_size, sample_size, self.input_dim) 202 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_dim) # bsz x ssz x input_dim 203 | loglikelihood = logprob_gaussian(mu_x, logvar_x, _input, do_unsqueeze=False, do_mean=False) 204 | loglikelihood = torch.sum(loglikelihood, dim=2) # bsz x ssz 205 | 206 | ''' get log p(x|z)p(z)/q(z|x) ''' 207 | logprob = loglikelihood + logprior - logposterior # bsz x ssz 208 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 209 | rprob = (logprob - logprob_max).exp() # relative prob 210 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max # bsz x 1 211 | 212 | # return 213 | return logprob.mean() 214 | 215 | def logprob_w_prior(self, input, sample_size=128, z=None): 216 | # init 217 | batch_size = input.size(0) 218 | input = input.view(batch_size, self.input_dim) 219 | 220 | ''' get z samples from p(z) ''' 221 | # get prior (as unit normal dist) 222 | if z is None: 223 | mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 224 | logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim) 225 | z = sample_gaussian(mu_pz, logvar_pz) # sample z 226 | 227 | ''' get log p(x|z) ''' 228 | # decode 229 | _z = z.view(-1, self.z_dim) 230 | _, mu_x, logvar_x = self.decode(_z) # bsz*ssz x zdim 231 | mu_x = mu_x.view(batch_size, sample_size, self.input_dim) 232 | logvar_x = logvar_x.view(batch_size, sample_size, self.input_dim) 233 | _input = input.unsqueeze(1).expand(batch_size, sample_size, self.input_dim) # bsz x ssz x input_dim 234 | loglikelihood = logprob_gaussian(mu_x, logvar_x, _input, do_unsqueeze=False, do_mean=False) 235 | loglikelihood = torch.sum(loglikelihood, dim=2) # bsz x ssz 236 | 237 | ''' get log p(x) ''' 238 | logprob = loglikelihood # bsz x ssz 239 | logprob_max, _ = torch.max(logprob, dim=1, keepdim=True) 240 | rprob = (logprob-logprob_max).exp() # relative prob 241 | logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + 
logprob_max # bsz x 1 242 | 243 | # return 244 | return logprob.mean() 245 | -------------------------------------------------------------------------------- /run_vae_25gaussians.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python ivae_ardae.py \ 3 | --cache experiments/25gaussians \ 4 | --dataset 25gaussians --nheight 1 --nchannels 2 \ 5 | --model mlp-concat --model-z-dim 2 --model-h-dim 256 --model-n-layers 2 --model-nonlin relu --model-n-dim 10 --model-clip-z0-logvar none --model-clip-z-logvar none \ 6 | --cdae mlp-grad --cdae-h-dim 256 --cdae-n-layers 3 --cdae-nonlin softplus --cdae-ctx-type lt0 \ 7 | --train-batch-size 512 --eval-batch-size 1 --train-nz-cdae 256 --train-nz-model 1 \ 8 | --delta 0.1 --std-scale 10000 --num-cdae-updates 1 \ 9 | --m-lr 0.0001 --m-optimizer adam --m-momentum 0.5 --m-beta1 0.5 \ 10 | --d-lr 0.0001 --d-optimizer rmsprop --d-momentum 0.5 --d-beta1 0.5 \ 11 | --epochs 16 \ 12 | --eval-iws-interval 0 --iws-samples 64 --log-interval 100 --vis-interval 100 --ckpt-interval 1000 --exp-num 1 13 | -------------------------------------------------------------------------------- /run_vae_dbmnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # baseline models 4 | # resconv 5 | python vae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 32 --optimizer adam --momentum 0.9 --beta1 0.9 --model resconv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin elu --model-n-dim 0 --model-clip-logvar none --init-method none --do-m5bias none --exp-num 1 --lr 0.0001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 0 --epochs 6400 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 6 | 7 | # hierarchical resconv 8 | python vae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 32 --optimizer adam --momentum 0.9 --beta1 0.9 --model auxresconv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin elu --model-n-dim 100 --model-clip-logvar none --init-method none --do-m5bias none --exp-num 1 --lr 0.0001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 0 --epochs 6400 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 9 | 10 | # conv 11 | python vae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 32 --optimizer adam --momentum 0.5 --beta1 0.5 --model conv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin softplus --model-n-dim 0 --model-clip-logvar none --init-method xv --do-m5bias none --exp-num 1 --lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --epochs 4700 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 12 | 13 | # hierarchical conv 14 | python vae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 32 --optimizer adam --momentum 0.5 --beta1 0.5 --model auxconv --model-z-dim 32 
--model-h-dim 0 --model-n-layers 0 --model-nonlin softplus --model-n-dim 100 --model-clip-logvar none --init-method xv --do-m5bias none --exp-num 1 --lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --epochs 4700 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 15 | 16 | # mlp 17 | python vae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 32 --optimizer adam --momentum 0.5 --beta1 0.5 --model mnist --model-z-dim 32 --model-h-dim 300 --model-n-layers 2 --model-nonlin softplus --model-n-dim 0 --model-clip-logvar none --init-method xv --do-m5bias none --exp-num 1 --lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --epochs 4700 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 18 | 19 | # hierarchical mlp 20 | python vae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 32 --optimizer adam --momentum 0.5 --beta1 0.5 --model auxmnist --model-z-dim 32 --model-h-dim 300 --model-n-layers 2 --model-nonlin softplus --model-n-dim 100 --model-clip-logvar none --init-method xv --do-m5bias none --exp-num 1 --lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --epochs 4700 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 21 | 22 | 23 | # proposed method 24 | # implicit resconv 25 | python ivae_ardae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.9 --m-beta1 0.9 --d-optimizer rmsprop --d-momentum 0.9 --d-beta1 0.9 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model resconvct-res --model-z-dim 32 --model-h-dim 512 --model-n-layers 1 --model-nonlin elu --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-res --cdae-h-dim 512 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type lt0 --exp-num 1 --m-lr 0.001 --d-lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --delta 0.1 --std-scale 100 --num-cdae-updates 2 --epochs 6400 --eval-iws-interval 10000 --iws-samples 256 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 50000 --ckpt-interval 10000 --train-mode train 26 | 27 | # hierarchical resconv 28 | python ivae_ardae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.9 --m-beta1 0.9 --d-optimizer rmsprop --d-momentum 0.9 --d-beta1 0.9 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model auxresconvct --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin elu --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-res --cdae-h-dim 512 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type hidden1a --exp-num 1 --m-lr 0.001 --d-lr 0.0001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 50000 --delta 0.1 --std-scale 100 --num-cdae-updates 2 --epochs 6400 --eval-iws-interval 10000 --iws-samples 256 --m-weight-avg none 
--m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 50000 --ckpt-interval 5000 --train-mode train 29 | 30 | # implicit conv 31 | python ivae_ardae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.5 --m-beta1 0.5 --d-optimizer rmsprop --d-momentum 0.5 --d-beta1 0.5 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model mnist-conv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin softplus --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-grad --cdae-h-dim 256 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type lt0 --exp-num 1 --m-lr 0.0001 --d-lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --delta 0.1 --std-scale 10000 --num-cdae-updates 1 --epochs 6400 --eval-iws-interval 10000 --iws-samples 1024 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 50000 --ckpt-interval 10000 --train-mode train 32 | 33 | # hierarchical conv 34 | python ivae_ardae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.5 --m-beta1 0.5 --d-optimizer rmsprop --d-momentum 0.5 --d-beta1 0.5 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model auxconv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin softplus --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-grad --cdae-h-dim 256 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type hidden1a --exp-num 1 --m-lr 0.0001 --d-lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --delta 0.1 --std-scale 10000 --num-cdae-updates 1 --epochs 6400 --eval-iws-interval 10000 --iws-samples 1024 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 50000 --ckpt-interval 10000 --train-mode train 35 | 36 | # implicit mlp 37 | python ivae_ardae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.5 --m-beta1 0.5 --d-optimizer rmsprop --d-momentum 0.5 --d-beta1 0.5 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model mnist-concat --model-z-dim 32 --model-h-dim 300 --model-n-layers 2 --model-nonlin softplus --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-grad --cdae-h-dim 256 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type lt0 --exp-num 1 --m-lr 0.0001 --d-lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --delta 0.1 --std-scale 10000 --num-cdae-updates 1 --epochs 6400 --eval-iws-interval 10000 --iws-samples 1024 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 50000 --ckpt-interval 10000 --train-mode train 38 | 39 | # hierarchical mlp 40 | python ivae_ardae.py --cache experiments/dbmnist-val5k --dataset dbmnist-val5k --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.5 --m-beta1 0.5 --d-optimizer rmsprop --d-momentum 0.5 --d-beta1 0.5 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model auxmnist --model-z-dim 32 --model-h-dim 300 --model-n-layers 2 --model-nonlin softplus --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-grad --cdae-h-dim 256 
--cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type hidden1a --exp-num 1 --m-lr 0.0001 --d-lr 0.0001 --beta-init 1.0 --beta-fin 1.0 --beta-annealing 0 --delta 0.1 --std-scale 10000 --num-cdae-updates 1 --epochs 6400 --eval-iws-interval 10000 --iws-samples 1024 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 50000 --ckpt-interval 10000 --train-mode train 41 | -------------------------------------------------------------------------------- /run_vae_sbmnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # baseline models 4 | # resconv 5 | python vae.py --cache experiments/sbmnist --dataset sbmnist --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --optimizer adam --momentum 0.9 --beta1 0.9 --model resconv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin elu --model-n-dim 0 --model-clip-logvar none --exp-num 1 --lr 0.001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 50000 --epochs 6400 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode train 6 | 7 | python vae.py --cache experiments/sbmnist --dataset sbmnist --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --optimizer adam --momentum 0.9 --beta1 0.9 --model resconv --model-z-dim 32 --model-h-dim 0 --model-n-layers 0 --model-nonlin elu --model-n-dim 0 --model-clip-logvar none --exp-num 1 --lr 0.001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 50000 --epochs 6400 --eval-iws-interval 5000 --iws-samples 256 --weight-avg none --weight-avg-start -1 --weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 5000 --train-mode final 8 | 9 | 10 | # proposed method 11 | # implicit resconv 12 | python ivae_ardae.py --cache experiments/sbmnist --dataset sbmnist --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.9 --m-beta1 0.9 --d-optimizer rmsprop --d-momentum 0.9 --d-beta1 0.9 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model resconvct-res --model-z-dim 32 --model-h-dim 512 --model-n-layers 1 --model-nonlin elu --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-res --cdae-h-dim 512 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type lt0 --exp-num 1 --m-lr 0.001 --d-lr 0.0001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 50000 --delta 0.1 --std-scale 100 --num-cdae-updates 2 --epochs 6400 --eval-iws-interval 5000 --iws-samples 256 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 800 --train-mode train 13 | 14 | python ivae_ardae.py --cache experiments/sbmnist --dataset sbmnist --nheight 28 --nchannels 1 --train-batch-size 128 --eval-batch-size 1 --m-optimizer adam --m-momentum 0.9 --m-beta1 0.9 --d-optimizer rmsprop --d-momentum 0.9 --d-beta1 0.9 --train-nstd-cdae 1 --train-nz-cdae 625 --train-nz-model 1 --model resconvct-res --model-z-dim 32 --model-h-dim 512 --model-n-layers 1 --model-nonlin elu --model-n-dim 100 --model-clip-z0-logvar none --model-clip-z-logvar none --cdae mlp-res --cdae-h-dim 512 --cdae-n-layers 5 --cdae-nonlin softplus --cdae-ctx-type lt0 --exp-num 1 --m-lr 0.001 --d-lr 0.0001 --beta-init 0.0001 --beta-fin 1.0 --beta-annealing 50000 --delta 0.1 --std-scale 100 --num-cdae-updates 2 --epochs 6400 
--eval-iws-interval 5000 --iws-samples 256 --m-weight-avg none --m-weight-avg-start -1 --m-weight-avg-decay 0.998 --log-interval 100 --vis-interval 10000 --ckpt-interval 800 --train-mode final
15 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # msc
2 | from utils.msc import save_checkpoint, load_checkpoint, load_end_iter, logging, get_time, annealing_func, EndIterError
3 | from utils.msc import conv_out_size, deconv_out_size
4 | from utils.msc import expand_tensor
5 | 
6 | # visualization
7 | from utils.visualization import convert_npimage_torchimage, get_scatter_plot, get_quiver_plot, get_data_for_quiver_plot, get_prob_from_energy_func_for_vis, get_imshow_plot, get_1d_histogram_plot, get_2d_histogram_plot, get_grid_image
8 | 
9 | # models
10 | from utils.models import get_nonlinear_func
11 | 
12 | # vae
13 | from utils.vae import loss_recon_bernoulli_with_logit, loss_recon_bernoulli, loss_recon_gaussian, loss_recon_gaussian_w_fixed_var, loss_kld_gaussian, loss_kld_gaussian_vs_gaussian
14 | 
15 | # stat
16 | from utils.stat import logprob_gaussian, loss_entropy_gaussian, get_covmat
17 | 
18 | # energy
19 | from utils.energy import energy_func1, energy_func2, energy_func3, energy_func4, regularization_func, normal_energy_func, normal_prob
20 | 
21 | # optim
22 | from utils.optim import Adam, AdamW
23 | 
24 | # lr_scheduler
25 | from utils.lr_scheduler import StepLR
26 | 
27 | # jacobian clamping
28 | from utils.jacobian_clamping import minrelu, jac_clamping_loss, cond_jac_clamping_loss
29 | 
30 | # sample
31 | from utils.sample import sample_laplace_noise, sample_unit_laplace_noise
32 | 
--------------------------------------------------------------------------------
/utils/distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | '''
5 | https://github.com/nicola-decao/s-vae-pytorch/blob/master/hyperspherical_vae/distributions/hyperspherical_uniform.py
6 | '''
7 | def sample_hypershperical_uniform_bsz_ssz(batch_size, sample_size, device=torch.device('cpu')):
8 |     # draw standard normal noise and project it onto the unit hypersphere
9 |     output = torch.zeros(batch_size, sample_size, 1, device=device).normal_(0, 1)
10 |     return output / output.norm(dim=-1, keepdim=True)
11 | 
12 | 
13 | class HypersphericalUniform(torch.distributions.Distribution):
14 | 
15 |     support = torch.distributions.constraints.real
16 |     has_rsample = False
17 |     _mean_carrier_measure = 0
18 | 
19 |     @property
20 |     def dim(self):
21 |         return self._dim
22 | 
23 |     @property
24 |     def device(self):
25 |         return self._device
26 | 
27 |     @device.setter
28 |     def device(self, val):
29 |         self._device = val if isinstance(val, torch.device) else torch.device(val)
30 | 
31 |     def __init__(self, dim, validate_args=None, device="cpu"):
32 |         super(HypersphericalUniform, self).__init__(torch.Size([dim]), validate_args=validate_args)
33 |         self._dim = dim
34 |         self.device = device
35 | 
36 |     def sample(self, shape=torch.Size()):
37 |         output = torch.distributions.Normal(0, 1).sample(
38 |             (shape if isinstance(shape, torch.Size) else torch.Size([shape])) + torch.Size([self._dim + 1])).to(self.device)
39 | 
40 |         return output / output.norm(dim=-1, keepdim=True)
41 | 
42 |     def entropy(self):
43 |         return self.__log_surface_area()
44 | 
45 |     def log_prob(self, x):
46 |         return - torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area()
47 | 
48 |     def __log_surface_area(self):
49 |         return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - torch.lgamma(
50 |             torch.tensor([(self._dim + 1) / 2], device=self.device))
51 | 
--------------------------------------------------------------------------------
/utils/energy.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | 
5 | EPS = 1e-9
6 | 
7 | def regularization_func(x):
8 |     return (torch.relu(x.abs() - 6) ** 2).sum(-1, keepdim=True)
9 | 
10 | def w1(z1):
11 |     return torch.sin(2.*math.pi*z1/4.)
12 | 
13 | def w2(z1):
14 |     return 3.*torch.exp(-0.5*((z1-1)/0.6)**2)
15 | 
16 | def w3(z1):
17 |     return 3.*torch.sigmoid((z1-1.)/0.3)
18 | 
19 | def energy_func1(x):
20 |     assert x.dim() == 2
21 |     assert x.size(1) == 2
22 |     batch_size = x.size(0)
23 |     x1 = x[:, :1]
24 |     x2 = x[:, 1:]
25 |     xnorm = torch.norm(x, dim=1, keepdim=True)
26 |     energy = 0.5 * ((xnorm-2)/0.4)**2 \
27 |         - torch.log(
28 |             torch.exp(-0.5*((x1-2)/0.6)**2)
29 |             + torch.exp(-0.5*((x1+2)/0.6)**2)
30 |             + EPS)
31 |     return energy + regularization_func(x)
32 | 
33 | def energy_func2(x):
34 |     assert x.dim() == 2
35 |     assert x.size(1) == 2
36 |     batch_size = x.size(0)
37 |     x1 = x[:, :1]
38 |     x2 = x[:, 1:]
39 | 
40 |     energy = 0.5 * ((x2-w1(x1))/0.4)**2
41 |     return energy + regularization_func(x)
42 | 
43 | def energy_func3(x):
44 |     assert x.dim() == 2
45 |     assert x.size(1) == 2
46 |     batch_size = x.size(0)
47 |     x1 = x[:, :1]
48 |     x2 = x[:, 1:]
49 | 
50 |     energy = -torch.log(
51 |         torch.exp(-0.5 * ((x2-w1(x1))/0.35)**2)
52 |         + torch.exp(-0.5 * ((x2-w1(x1)+w2(x1))/0.35)**2)
53 |         + EPS)
54 |     return energy + regularization_func(x)
55 | 
56 | def energy_func4(x):
57 |     assert x.dim() == 2
58 |     assert x.size(1) == 2
59 |     batch_size = x.size(0)
60 |     x1 = x[:, :1]
61 |     x2 = x[:, 1:]
62 | 
63 |     energy = -torch.log(
64 |         torch.exp(-0.5 * ((x2-w1(x1))/0.4)**2)
65 |         + torch.exp(-0.5 * ((x2-w1(x1)+w3(x1))/0.35)**2)
66 |         + EPS)
67 |     return energy + regularization_func(x)
68 | 
69 | def _normal_energy_func(x, mu=0., logvar=0.):
70 |     energy = logvar + (x - mu)**2 / math.exp(logvar) + math.log(2.*math.pi)
71 |     energy = 0.5 * energy
72 |     return energy
73 | 
74 | def normal_energy_func(x, mu=0., logvar=0.):
75 |     batch_size = x.size(0)
76 |     x = x.view(batch_size, -1)
77 |     return torch.sum(_normal_energy_func(x, mu, logvar), dim=1)
78 | 
79 | #def normal_energy_func(x,
80 | #                       mu1=0., logvar1=0.,
81 | #                       mu2=0., logvar2=0.,
82 | #                       ):
83 | #    assert x.dim() == 2
84 | #    assert x.size(1) == 2
85 | #    batch_size = x.size(0)
86 | #    x = x.view(batch_size, -1)
87 | #    x1 = x[:, :1]
88 | #    x2 = x[:, 1:]
89 | #
90 | #    energy = _normal_energy_func(x1, mu1, logvar1) \
91 | #             + _normal_energy_func(x2, mu2, logvar2)
92 | #    return energy
93 | 
94 | def normal_prob(x, mu=0., std=1.):
95 |     '''
96 |     Inputs:
97 |         x: b1 x 1
98 |         mu, std: scalar
99 |     Outputs:
100 |         prob: b1 x 1
101 |     '''
102 |     var = std**2
103 |     logvar = math.log(var)
104 |     logprob = - normal_energy_func(x, mu, logvar)
105 |     #return logprob
106 |     prob = torch.exp(logprob)
107 |     return prob
108 | 
109 | #def energy_to_unnormalized_prob(energy):
110 | #    prob = torch.exp(-energy) # unnormalized prob
111 | #    return prob
112 | #
113 | #def get_data_for_heatmap(val=4, num=256):
114 | #    _x = np.linspace(-val, val, num)
115 | #    _y = np.linspace(-val, val, num)
116 | #    _u, _v = np.meshgrid(_x, _y)
117 | #    _data = np.stack([_u.reshape(num**2),
_v.reshape(num**2)], axis=1) 118 | # return _data, _x, _y 119 | # 120 | #def run_energy_fun_for_vis(energy_func, num=256): 121 | # _z, _, _ = get_data_for_heatmap(num=num) 122 | # z = torch.from_numpy(_z) 123 | # prob = energy_func(z) 124 | # _prob = prob.cpu().float().numpy() 125 | # _prob = _prob.reshape(num, num) 126 | # return _prob 127 | # 128 | #def get_energy_func_plot(prob): 129 | # # plot 130 | # fig, ax = plt.subplots(figsize=(5, 5)) 131 | # im = ax.imshow(prob, cmap='jet') 132 | # ax.grid(False) 133 | # 134 | # # draw to canvas 135 | # fig.canvas.draw() # draw the canvas, cache the renderer 136 | # image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 137 | # image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 138 | # 139 | # # close figure 140 | # plt.close() 141 | # return image 142 | 143 | ''' test ''' 144 | ''' 145 | import numpy as np 146 | import torchvision.utils as vutils 147 | from visualization import convert_npimage_torchimage, get_prob_from_energy_func_for_vis, get_imshow_plot 148 | 149 | #energy_func = energy_func4 150 | energy_func = normal_energy_func 151 | _prob = get_prob_from_energy_func_for_vis(energy_func, num=256) 152 | _img = get_imshow_plot(_prob) 153 | filename='hi.png' 154 | vutils.save_image(convert_npimage_torchimage(_img), filename) 155 | ''' 156 | -------------------------------------------------------------------------------- /utils/jacobian_clamping.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Odena, Augustus, et al. "Is generator conditioning causally related to gan performance?." arXiv preprint arXiv:1802.08768 (2018). 3 | Kumar, Abhishek, Ben Poole, and Kevin Murphy. Learning Generative Samplers using Relaxed Injective Flow. Technical report, 2019. 
(https://invertibleworkshop.github.io/accepted_papers/pdfs/INNF_2019_paper_32.pdf) 4 | copy and modified from https://github.com/MCDM2018/Implementations/blob/master/GAN.py 5 | ''' 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | def minrelu(x): 10 | return -F.relu(-x) 11 | 12 | def jac_clamping_loss( 13 | forward, 14 | x, z, 15 | num_pert_samples, 16 | eta_min, p=2, EPS=0.01, 17 | postprocessing=None, 18 | ): 19 | ''' 20 | batch_size == num_z_samples 21 | forward: f(ctx, z, num_pert_samples) 22 | x: batch_size x x_dim 23 | z: batch_size x z_dim 24 | ''' 25 | # init 26 | batch_size = x.size(0) 27 | x_dim = x.size(-1) 28 | z_dim = z.size(-1) 29 | numel = batch_size*num_pert_samples 30 | assert x.size(0) == batch_size 31 | assert z.size(0) == batch_size 32 | assert p == 2 33 | 34 | # get perturb and z_bar 35 | perturb = torch.randn([batch_size, num_pert_samples, z_dim], device=z.device) 36 | z_bar = z.unsqueeze(1) + EPS*perturb # bsz x psz x zdim 37 | 38 | # forward 39 | x = x.unsqueeze(1).expand(batch_size, num_pert_samples, -1).contiguous().view(numel, x_dim) 40 | z_bar = z_bar.view(numel, z_dim) 41 | x_bar = forward(z_bar) # numel x x_dim 42 | if postprocessing: 43 | x = postprocessing(x) 44 | x_bar = postprocessing(x_bar) 45 | 46 | # get diff 47 | x_diff = x_bar - x 48 | 49 | # flatten 50 | x_diff_flattened = x_diff.view(numel, x_dim) 51 | perturb_flattened = perturb.view(numel, z_dim) 52 | 53 | # get jac 54 | unjac_l2sq = torch.sum((x_diff_flattened**2), dim=1)/(EPS**2) #torch.norm(x_diff_flattened, dim=1, p=p) 55 | per_l2sq = torch.sum((perturb_flattened**2), dim=1) #torch.norm(eps_diff_flattened, dim=1, p=p) 56 | jac_l2sq = unjac_l2sq / per_l2sq 57 | 58 | # get loss 59 | loss = (minrelu(jac_l2sq-eta_min))**2 60 | 61 | # return 62 | return loss.mean() 63 | 64 | def cond_jac_clamping_loss( 65 | forward, 66 | x, ctx, z, 67 | num_z_samples, num_pert_samples, 68 | eta_min, p=2, EPS=0.01, 69 | postprocessing=None, 70 | ): 71 | ''' 72 | forward: f(ctx, z, num_z_samples, num_pert_samples) 73 | x: batch_size x num_z_samples x x_dim 74 | z: batch_size x num_z_samples x z_dim 75 | ctx: batch_size x ctx_dim 76 | ''' 77 | # init 78 | batch_size = ctx.size(0) 79 | x_dim = x.size(-1) 80 | z_dim = z.size(-1) 81 | numel = batch_size*num_z_samples*num_pert_samples 82 | assert x.size(0) == batch_size*num_z_samples 83 | assert z.size(0) == batch_size*num_z_samples 84 | assert p == 2 85 | 86 | # get perturb and z_bar 87 | perturb = torch.randn([batch_size*num_z_samples, num_pert_samples, z_dim], device=z.device) 88 | z_bar = z.unsqueeze(1) + EPS*perturb # bsz*zsz x psz x zdim 89 | 90 | # forward 91 | x = x.unsqueeze(1).expand(batch_size*num_z_samples, num_pert_samples, -1).contiguous().view(numel, x_dim) 92 | z_bar = z_bar.view(numel, z_dim) 93 | x_bar = forward(ctx, z_bar, num_z_samples, num_pert_samples) # numel x x_dim 94 | if postprocessing: 95 | x = postprocessing(x) 96 | x_bar = postprocessing(x_bar) 97 | 98 | # get diff 99 | x_diff = x_bar - x 100 | 101 | # flatten 102 | x_diff_flattened = x_diff.view(numel, x_dim) 103 | perturb_flattened = perturb.view(numel, z_dim) 104 | 105 | # get jac 106 | unjac_l2sq = torch.sum((x_diff_flattened**2), dim=1)/(EPS**2) #torch.norm(x_diff_flattened, dim=1, p=p) 107 | per_l2sq = torch.sum((perturb_flattened**2), dim=1) #torch.norm(eps_diff_flattened, dim=1, p=p) 108 | jac_l2sq = unjac_l2sq / per_l2sq 109 | 110 | # get loss 111 | loss = (minrelu(jac_l2sq-eta_min))**2 112 | 113 | # return 114 | return loss.mean() 115 | 
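# A minimal usage sketch (illustrative only; the sampler `g`, its dimensions, and the weight
# `lambda_jc` below are assumptions, not part of this repository). The losses above penalize
# an implicit sampler whenever the finite-difference estimate of its squared Jacobian norm
# falls below `eta_min`, so they would typically be added to the main training objective:
#
#   z = torch.randn(batch_size, z_dim)
#   x = g(z)                                       # unconditional sampler, batch_size x x_dim
#   jc_loss = jac_clamping_loss(forward=g, x=x, z=z,
#                               num_pert_samples=4,   # perturbations per latent sample
#                               eta_min=1.0)          # lower clamp on the Jacobian norm^2
#   loss = main_loss + lambda_jc * jc_loss
#
# `cond_jac_clamping_loss` follows the same pattern, but `forward` additionally receives the
# context `ctx` and the sample counts, matching its docstring above.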
-------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py 3 | ''' 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | class StepLR(_LRScheduler): 7 | """Sets the learning rate of each parameter group to the initial lr 8 | decayed by gamma every step_size epochs. When last_epoch=-1, sets 9 | initial lr as lr. 10 | 11 | Args: 12 | optimizer (Optimizer): Wrapped optimizer. 13 | step_size (int): Period of learning rate decay. 14 | gamma (float): Multiplicative factor of learning rate decay. 15 | Default: 0.1. 16 | last_epoch (int): The index of last epoch. Default: -1. 17 | 18 | Example: 19 | >>> # Assuming optimizer uses lr = 0.05 for all groups 20 | >>> # lr = 0.05 if epoch < 30 21 | >>> # lr = 0.005 if 30 <= epoch < 60 22 | >>> # lr = 0.0005 if 60 <= epoch < 90 23 | >>> # ... 24 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 25 | >>> for epoch in range(100): 26 | >>> train(...) 27 | >>> validate(...) 28 | >>> scheduler.step() 29 | """ 30 | 31 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, min_lr=None): 32 | self.step_size = step_size 33 | self.gamma = gamma 34 | self.min_lr = min_lr 35 | super(StepLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | return [max(self.min_lr, base_lr * self.gamma ** (self.last_epoch // self.step_size)) 39 | for base_lr in self.base_lrs] 40 | -------------------------------------------------------------------------------- /utils/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def softplus(x): 6 | return torch.log(torch.exp(x) + 1) 7 | 8 | def swish(x): 9 | ''' 10 | https://arxiv.org/abs/1710.05941 11 | ''' 12 | return x*torch.sigmoid(x) 13 | 14 | def get_nonlinear_func(nonlinearity_type='elu'): 15 | if nonlinearity_type == 'relu': 16 | return F.relu 17 | elif nonlinearity_type == 'elu': 18 | return F.elu 19 | elif nonlinearity_type == 'tanh': 20 | return torch.tanh 21 | elif nonlinearity_type == 'softplus': 22 | return F.softplus 23 | elif nonlinearity_type == 'csoftplus': 24 | return softplus 25 | elif nonlinearity_type == 'leaky_relu': 26 | def leaky_relu(input): 27 | return F.leaky_relu(input, negative_slope=0.2) 28 | return leaky_relu 29 | elif nonlinearity_type == 'swish': 30 | return swish 31 | else: 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /utils/msc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | miscellaneous functions: learning 3 | ''' 4 | import os 5 | import datetime 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torchvision.utils as vutils 11 | 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | import seaborn as sns 15 | import matplotlib.pyplot as plt 16 | 17 | from sklearn.manifold import TSNE 18 | 19 | 20 | ''' expand tensor ''' 21 | def expand_tensor(input, sample_size, do_unsqueeze): 22 | batch_size = input.size(0) 23 | if do_unsqueeze: 24 | sz_from = [-1]*(input.dim()+1) 25 | sz_from[1] = sample_size 26 | input_expanded = input.unsqueeze(1).expand(*sz_from).contiguous() 27 | 28 | sz_to = list(input.size()) 29 | sz_to[0] = batch_size*sample_size 30 | else: 31 | assert input.size(1) == 1 32 | sz_from = 
[-1]*(input.dim()) 33 | sz_from[1] = sample_size 34 | input_expanded = input.expand(*sz_from).contiguous() 35 | 36 | _sz_to = list(input.size()) 37 | sz_to = _sz_to[0:1]+_sz_to[2:] 38 | sz_to[0] = batch_size*sample_size 39 | input_expanded_flattend = input_expanded.view(*sz_to) 40 | return input_expanded, input_expanded_flattend 41 | 42 | ''' cont out size ''' 43 | def conv_out_size(hin, kernel_size, stride=1, padding=0, dilation=1): 44 | hout = (hin + 2*padding - dilation*(kernel_size-1) - 1)/stride + 1 45 | return int(hout) 46 | 47 | def deconv_out_size(hin, kernel_size, stride=1, padding=0, output_padding=0, dilation=1): 48 | hout = (hin-1)*stride - 2*padding + dilation*(kernel_size-1) + output_padding + 1 49 | return int(hout) 50 | 51 | 52 | ''' annealing ''' 53 | def annealing_func(val_init, val_fin, val_annealing, step): 54 | val = val_init + (val_fin - val_init) / float(val_annealing) * float(min(val_annealing, step)) if val_annealing is not None else val_fin 55 | return float(val) 56 | 57 | 58 | ''' for monitoring lr ''' 59 | def get_lrs(optimizer): 60 | lrs = [float(param_group['lr']) for param_group in optimizer.param_groups] 61 | lr_max = max(lrs) 62 | lr_min = min(lrs) 63 | return lr_min, lr_max 64 | 65 | 66 | ''' save and load ''' 67 | def save_checkpoint(state, opt, is_best, filename='checkpoint.pth.tar'): 68 | filename = os.path.join(opt.path, filename) 69 | print("=> save checkpoint '{}'".format(filename)) 70 | torch.save(state, filename) 71 | if is_best: 72 | shutil.copyfile(filename, 'model_best.pth.tar') 73 | 74 | def load_checkpoint(model, optimizer, opt, filename='checkpoint.pth.tar', verbose=True, device=None, scheduler=None): 75 | filename = os.path.join(opt.path, filename) 76 | if os.path.isfile(filename): 77 | if verbose: 78 | print("=> loading checkpoint '{}'".format(filename)) 79 | checkpoint = torch.load(filename, map_location=device) if device is not None else torch.load(filename) 80 | opt.start_epoch = checkpoint['epoch'] 81 | opt.start_batch_idx = checkpoint['batch_idx'] 82 | opt.best_val_loss = checkpoint['best_val_loss'] 83 | if 'train_num_iters_per_epoch' in checkpoint.keys(): 84 | opt.train_num_iters_per_epoch = checkpoint['train_num_iters_per_epoch'] 85 | if model is not None: 86 | model.load_state_dict(checkpoint['state_dict']) 87 | if optimizer is not None: 88 | optimizer.load_state_dict(checkpoint['optimizer']) 89 | if scheduler is not None: 90 | scheduler.load_state_dict(checkpoint['scheduler']) 91 | if verbose: 92 | print("=> loaded checkpoint '{}'".format(filename)) 93 | if 'start_std' in checkpoint.keys(): 94 | opt.start_std = checkpoint['start_std'] 95 | else: 96 | print("=> no checkpoint found at '{}'".format(filename)) 97 | 98 | def load_end_iter(opt, filename='best-checkpoint.pth.tar', verbose=True, device=None): 99 | filename = os.path.join(opt.path, filename) 100 | if os.path.isfile(filename): 101 | if verbose: 102 | print("=> loading checkpoint '{}'".format(filename)) 103 | checkpoint = torch.load(filename, map_location=device) if device is not None else torch.load(filename) 104 | start_epoch = checkpoint['epoch'] 105 | start_batch_idx = checkpoint['batch_idx'] 106 | train_num_iters_per_epoch = checkpoint['train_num_iters_per_epoch'] 107 | i_ep = (start_epoch-1)*train_num_iters_per_epoch + start_batch_idx 108 | return i_ep-1 109 | else: 110 | raise ValueError("=> no checkpoint found at '{}'".format(filename)) 111 | 112 | class EndIterError(Exception): 113 | pass 114 | 115 | 116 | ''' log ''' 117 | def logging(s, path=None, 
filename='log.txt'): 118 | # print 119 | print(s) 120 | 121 | # save 122 | if path is not None: 123 | assert path, 'path is not define. path: {}'.format(path) 124 | with open(os.path.join(path, filename), 'a+') as f_log: 125 | f_log.write(s + '\n') 126 | 127 | def get_time(): 128 | return datetime.datetime.now().strftime('%y%m%d-%H:%M:%S') 129 | #return datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S.%f') 130 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py 3 | ''' 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class Adam(Optimizer): 10 | r"""Implements Adam algorithm. 11 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 12 | Arguments: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float, optional): learning rate (default: 1e-3) 16 | betas (Tuple[float, float], optional): coefficients used for computing 17 | running averages of gradient and its square (default: (0.9, 0.999)) 18 | eps (float, optional): term added to the denominator to improve 19 | numerical stability (default: 1e-8) 20 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 21 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 22 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 23 | (default: False) 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _On the Convergence of Adam and Beyond: 27 | https://openreview.net/forum?id=ryQu7f-RZ 28 | """ 29 | 30 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 31 | weight_decay=0, amsgrad=False): 32 | if not 0.0 <= lr: 33 | raise ValueError("Invalid learning rate: {}".format(lr)) 34 | if not 0.0 <= eps: 35 | raise ValueError("Invalid epsilon value: {}".format(eps)) 36 | if not 0.0 <= betas[0] < 1.0: 37 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 38 | if not 0.0 <= betas[1] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 40 | defaults = dict(lr=lr, betas=betas, eps=eps, 41 | weight_decay=weight_decay, amsgrad=amsgrad) 42 | super(Adam, self).__init__(params, defaults) 43 | 44 | def __setstate__(self, state): 45 | super(Adam, self).__setstate__(state) 46 | for group in self.param_groups: 47 | group.setdefault('amsgrad', False) 48 | 49 | def step(self, closure=None): 50 | """Performs a single optimization step. 51 | Arguments: 52 | closure (callable, optional): A closure that reevaluates the model 53 | and returns the loss. 
54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | for group in self.param_groups: 60 | for p in group['params']: 61 | if p.grad is None: 62 | continue 63 | grad = p.grad.data 64 | if grad.is_sparse: 65 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 66 | amsgrad = group['amsgrad'] 67 | 68 | state = self.state[p] 69 | 70 | # State initialization 71 | if len(state) == 0: 72 | state['step'] = 0 73 | # Exponential moving average of gradient values 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | # Exponential moving average of squared gradient values 76 | state['exp_avg_sq'] = torch.zeros_like(p.data) 77 | if amsgrad: 78 | # Maintains max of all exp. moving avg. of sq. grad. values 79 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 80 | 81 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 82 | if amsgrad: 83 | max_exp_avg_sq = state['max_exp_avg_sq'] 84 | beta1, beta2 = group['betas'] 85 | 86 | state['step'] += 1 87 | bias_correction1 = 1 - beta1 ** state['step'] 88 | bias_correction2 = 1 - beta2 ** state['step'] 89 | 90 | if group['weight_decay'] != 0: 91 | grad.add_(group['weight_decay'], p.data) 92 | 93 | # Decay the first and second moment running average coefficient 94 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 95 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 96 | if amsgrad: 97 | # Maintains the maximum of all 2nd moment running avg. till now 98 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 99 | # Use the max. for normalizing running avg. of gradient 100 | denom = (max_exp_avg_sq.sqrt().add_(group['eps']) / math.sqrt(bias_correction2)) 101 | else: 102 | denom = (exp_avg_sq.sqrt().add_(group['eps']) / math.sqrt(bias_correction2)) 103 | 104 | step_size = group['lr'] / bias_correction1 105 | 106 | p.data.addcdiv_(-step_size, exp_avg, denom) 107 | 108 | return loss 109 | 110 | 111 | class AdamW(Optimizer): 112 | r"""Implements AdamW algorithm. 113 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 114 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 115 | Arguments: 116 | params (iterable): iterable of parameters to optimize or dicts defining 117 | parameter groups 118 | lr (float, optional): learning rate (default: 1e-3) 119 | betas (Tuple[float, float], optional): coefficients used for computing 120 | running averages of gradient and its square (default: (0.9, 0.999)) 121 | eps (float, optional): term added to the denominator to improve 122 | numerical stability (default: 1e-8) 123 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 124 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 125 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 126 | (default: False) 127 | .. _Adam\: A Method for Stochastic Optimization: 128 | https://arxiv.org/abs/1412.6980 129 | .. _Decoupled Weight Decay Regularization: 130 | https://arxiv.org/abs/1711.05101 131 | .. 
_On the Convergence of Adam and Beyond: 132 | https://openreview.net/forum?id=ryQu7f-RZ 133 | """ 134 | 135 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 136 | weight_decay=1e-2, amsgrad=False): 137 | if not 0.0 <= lr: 138 | raise ValueError("Invalid learning rate: {}".format(lr)) 139 | if not 0.0 <= eps: 140 | raise ValueError("Invalid epsilon value: {}".format(eps)) 141 | if not 0.0 <= betas[0] < 1.0: 142 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 143 | if not 0.0 <= betas[1] < 1.0: 144 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 145 | defaults = dict(lr=lr, betas=betas, eps=eps, 146 | weight_decay=weight_decay, amsgrad=amsgrad) 147 | super(AdamW, self).__init__(params, defaults) 148 | 149 | def __setstate__(self, state): 150 | super(AdamW, self).__setstate__(state) 151 | for group in self.param_groups: 152 | group.setdefault('amsgrad', False) 153 | 154 | def step(self, closure=None): 155 | """Performs a single optimization step. 156 | Arguments: 157 | closure (callable, optional): A closure that reevaluates the model 158 | and returns the loss. 159 | """ 160 | loss = None 161 | if closure is not None: 162 | loss = closure() 163 | 164 | for group in self.param_groups: 165 | for p in group['params']: 166 | if p.grad is None: 167 | continue 168 | 169 | # Perform stepweight decay 170 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 171 | 172 | # Perform optimization step 173 | grad = p.grad.data 174 | if grad.is_sparse: 175 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 176 | amsgrad = group['amsgrad'] 177 | 178 | state = self.state[p] 179 | 180 | # State initialization 181 | if len(state) == 0: 182 | state['step'] = 0 183 | # Exponential moving average of gradient values 184 | state['exp_avg'] = torch.zeros_like(p.data) 185 | # Exponential moving average of squared gradient values 186 | state['exp_avg_sq'] = torch.zeros_like(p.data) 187 | if amsgrad: 188 | # Maintains max of all exp. moving avg. of sq. grad. values 189 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 190 | 191 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 192 | if amsgrad: 193 | max_exp_avg_sq = state['max_exp_avg_sq'] 194 | beta1, beta2 = group['betas'] 195 | 196 | state['step'] += 1 197 | bias_correction1 = 1 - beta1 ** state['step'] 198 | bias_correction2 = 1 - beta2 ** state['step'] 199 | 200 | # Decay the first and second moment running average coefficient 201 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 202 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 203 | if amsgrad: 204 | # Maintains the maximum of all 2nd moment running avg. till now 205 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 206 | # Use the max. for normalizing running avg. 
of gradient 207 | denom = (max_exp_avg_sq.sqrt().add_(group['eps']) / math.sqrt(bias_correction2)) 208 | else: 209 | denom = (exp_avg_sq.sqrt().add_(group['eps']) / math.sqrt(bias_correction2)) 210 | 211 | step_size = group['lr'] / bias_correction1 212 | 213 | p.data.addcdiv_(-step_size, exp_avg, denom) 214 | 215 | return loss 216 | -------------------------------------------------------------------------------- /utils/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def sample_laplace_noise(loc, scale, shape, dtype, device): 4 | ''' 5 | https://github.com/pytorch/pytorch/blob/6911ce19d7fcf06e7af241e6494b23acdc320dc4/torch/distributions/laplace.py 6 | ''' 7 | finfo = torch.finfo(dtype) 8 | u = torch.zeros(shape, dtype=dtype, device=device).uniform_(finfo.eps - 1, 1) 9 | return loc - scale * u.sign() * torch.log1p(-u.abs()) 10 | 11 | def sample_unit_laplace_noise(shape, dtype, device): 12 | return sample_laplace_noise(0., 1., shape, dtype, device) 13 | -------------------------------------------------------------------------------- /utils/stat.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def shuffle(z): 10 | batch_size = z.size(0) 11 | z_dim = z.size(1) 12 | indices = [torch.from_numpy(np.random.permutation(batch_size)).to(z.device) for i in range(z_dim)] 13 | new_z = [z[:,i:i+1].index_select(0, indices[i]) for i in range(z_dim)] 14 | new_z = torch.cat(new_z, dim=1) 15 | return new_z 16 | 17 | def loss_entropy_gaussian(mu, logvar, do_sum=True): 18 | # mu, logvar = nomral distribution 19 | entropy_loss_element = logvar + 1. + math.log(2.*math.pi) 20 | 21 | # do sum 22 | if do_sum: 23 | entropy_loss = torch.sum(entropy_loss_element) * 0.5 24 | return entropy_loss 25 | else: 26 | #entropy_loss_element = torch.sum(entropy_loss_element, 1) * 0.5 + (1. 
+ math.log(2.*math.pi)) * 0.5 27 | entropy_loss_element = entropy_loss_element * 0.5 28 | return entropy_loss_element 29 | 30 | def prob_gaussian(mu, logvar, z, eps=1e-6, do_unsqueeze=True, do_mean=True): 31 | ''' 32 | Inputs: 33 | z: b1 x nz 34 | mu, logvar: b2 x nz 35 | Outputs: 36 | prob: b1 x nz 37 | ''' 38 | if do_unsqueeze: 39 | z = z.unsqueeze(1) 40 | mu = mu.unsqueeze(0) 41 | logvar = logvar.unsqueeze(0) 42 | 43 | var = logvar.exp() + eps 44 | std = torch.sqrt(var) + eps 45 | 46 | prob = torch.exp(- 0.5 * (z - mu)**2 / var) / std / math.sqrt(2.*math.pi) 47 | 48 | if do_mean: 49 | assert do_unsqueeze 50 | prob = torch.mean(prob, dim=1) 51 | 52 | return prob 53 | 54 | def loss_marginal_entropy_gaussian(mu, logvar, z, do_sum=True): 55 | marginal_entropy_loss_element = -torch.log(prob_gaussian(mu, logvar, z)) 56 | 57 | # do sum 58 | if do_sum: 59 | marginal_entropy_loss = torch.sum(marginal_entropy_loss_element) 60 | return marginal_entropy_loss 61 | else: 62 | #marginal_entropy_loss_element = torch.sum(marginal_entropy_loss_element, 1) 63 | return marginal_entropy_loss_element 64 | 65 | def logprob_gaussian(mu, logvar, z, do_unsqueeze=True, do_mean=True): 66 | ''' 67 | Inputs: 68 | z: b1 x nz 69 | mu, logvar: b2 x nz 70 | Outputs: 71 | prob: b1 x nz 72 | ''' 73 | if do_unsqueeze: 74 | z = z.unsqueeze(1) 75 | mu = mu.unsqueeze(0) 76 | logvar = logvar.unsqueeze(0) 77 | 78 | neglogprob = (z - mu)**2 / logvar.exp() + logvar + math.log(2.*math.pi) 79 | logprob = - neglogprob*0.5 80 | 81 | if do_mean: 82 | assert do_unsqueeze 83 | logprob = torch.mean(logprob, dim=1) 84 | 85 | return logprob 86 | 87 | def loss_approx_marginal_entropy_gaussian(mu, logvar, z, do_sum=True): 88 | marginal_entropy_loss_element = -logprob_gaussian(mu, logvar, z) 89 | 90 | # do sum 91 | if do_sum: 92 | marginal_entropy_loss = torch.sum(marginal_entropy_loss_element) 93 | return marginal_entropy_loss 94 | else: 95 | #marginal_entropy_loss_element = torch.sum(marginal_entropy_loss_element, 1) 96 | return marginal_entropy_loss_element 97 | 98 | def logprob_gaussian_w_fixed_var(mu, z, std=1.0, do_unsqueeze=True, do_mean=True): 99 | ''' 100 | Inputs: 101 | z: b1 x nz 102 | mu, logvar: b2 x nz 103 | Outputs: 104 | prob: b1 x nz 105 | ''' 106 | # init var, logvar 107 | var = std**2 108 | logvar = math.log(var) 109 | 110 | if do_unsqueeze: 111 | z = z.unsqueeze(1) 112 | mu = mu.unsqueeze(0) 113 | #logvar = logvar.unsqueeze(0) 114 | 115 | neglogprob = (z - mu)**2 / var + logvar + math.log(2.*math.pi) 116 | logprob = - neglogprob*0.5 117 | 118 | if do_mean: 119 | assert do_unsqueeze 120 | logprob = torch.mean(logprob, dim=1) 121 | 122 | return logprob 123 | 124 | ''' 125 | copied and modified from https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 126 | ''' 127 | def get_covmat(m, rowvar=False): 128 | '''Estimate a covariance matrix given data. 129 | 130 | Covariance indicates the level to which two variables vary together. 131 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 132 | then the covariance matrix element `C_{ij}` is the covariance of 133 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 134 | 135 | Args: 136 | m: A 1-D or 2-D array containing multiple variables and observations. 137 | Each row of `m` represents a variable, and each column a single 138 | observation of all those variables. 139 | rowvar: If `rowvar` is True, then each row represents a 140 | variable, with observations in the columns. 
Otherwise, the 141 | relationship is transposed: each column represents a variable, 142 | while the rows contain observations. 143 | rowvar == True: m: dim x batch_size 144 | rowvar == False: m: batch_size x dim 145 | Returns: 146 | The covariance matrix of the variables. 147 | ''' 148 | if m.dim() > 2: 149 | raise ValueError('m has more than 2 dimensions') 150 | if m.dim() < 2: 151 | m = m.view(1, -1) 152 | if not rowvar and m.size(0) != 1: 153 | m = m.t() 154 | # m = m.type(torch.double) # uncomment this line if desired 155 | fact = 1.0 / (m.size(1) - 1) 156 | m = m - torch.mean(m, dim=1, keepdim=True) 157 | mt = m.t() # if complex: mt = m.t().conj() 158 | return fact * m.matmul(mt).squeeze() 159 | -------------------------------------------------------------------------------- /utils/vae.py: -------------------------------------------------------------------------------- 1 | ''' 2 | miscellaneous functions: prob 3 | ''' 4 | import os 5 | import datetime 6 | import math 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | #from torch.autograd import Variable 13 | #from torch.distributions import Categorical, Normal 14 | 15 | 16 | ''' for vae ''' 17 | #def binary_cross_entropy_with_logits(logit, target): 18 | # zeros = logit.new_zeros(logit.size()) 19 | # return torch.max(logit, zeros) - logit * target + (1 + (-logit.abs()).exp()).log() 20 | 21 | def loss_recon_bernoulli_with_logit(logit, x, do_sum=True): 22 | # p = recon prob 23 | if do_sum: 24 | return F.binary_cross_entropy_with_logits(logit, x, reduction='sum') 25 | #return torch.sum(binary_cross_entropy_with_logits(logit, x)) 26 | else: 27 | batch_size = x.size(0) 28 | cross_entropy = F.binary_cross_entropy_with_logits(logit, x, reduction='none') 29 | #cross_entropy = binary_cross_entropy_with_logits(logit, x) 30 | return torch.sum(cross_entropy.view(batch_size, -1), dim=1) 31 | 32 | def loss_recon_bernoulli(p, x): 33 | # p = recon prob 34 | return F.binary_cross_entropy(p, x, size_average=False) 35 | 36 | def loss_recon_gaussian(mu, logvar, x, const=None, do_sum=True): 37 | # https://math.stackexchange.com/questions/1307381/logarithm-of-gaussian-function-is-whether-convex-or-nonconvex 38 | # mu, logvar = nomral distribution 39 | recon_loss_element = logvar + (x - mu)**2 / logvar.exp() + math.log(2.*math.pi) 40 | 41 | # add const (can be used in change of variable) 42 | if const is not None: 43 | recon_loss_element += const 44 | 45 | # do sum 46 | if do_sum: 47 | recon_loss = torch.sum(recon_loss_element) * 0.5 48 | return recon_loss 49 | else: 50 | batch_size = recon_loss_element.size(0) 51 | recon_loss_element = torch.sum(recon_loss_element.view(batch_size, -1), 1) * 0.5 52 | return recon_loss_element 53 | 54 | def loss_recon_gaussian_w_fixed_var(mu, x, std=1.0, const=None, do_sum=True, add_logvar=True): 55 | # init var, logvar 56 | var = std**2 57 | logvar = math.log(var) 58 | 59 | # estimate loss per element 60 | if add_logvar: 61 | recon_loss_element = logvar + (x - mu)**2 / var + math.log(2.*math.pi) 62 | else: 63 | recon_loss_element = (x - mu)**2 / var + math.log(2.*math.pi) 64 | 65 | # add const (can be used in change of variable) 66 | if const is not None: 67 | recon_loss_element += const 68 | 69 | # do sum 70 | if do_sum: 71 | recon_loss = torch.sum(recon_loss_element) * 0.5 72 | return recon_loss 73 | else: 74 | batch_size = recon_loss_element.size(0) 75 | recon_loss_element = torch.sum(recon_loss_element.view(batch_size, -1), 1) * 0.5 76 | return recon_loss_element 77 | 78 | def 
loss_kld_gaussian(mu, logvar, do_sum=True): 79 | # see Appendix B from VAE paper: 80 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 81 | # https://arxiv.org/abs/1312.6114 82 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 83 | KLD_element = 1 + logvar - mu.pow(2) - logvar.exp() 84 | 85 | # do sum 86 | if do_sum: 87 | KLD = torch.sum(KLD_element) * -0.5 88 | return KLD 89 | else: 90 | batch_size = KLD_element.size(0) 91 | KLD_element = torch.sum(KLD_element.view(batch_size, -1), 1) * -0.5 92 | return KLD_element 93 | 94 | def loss_kld_gaussian_vs_gaussian(mu1, logvar1, mu2, logvar2, do_sum=True): 95 | # see Appendix B from VAE paper: 96 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 97 | # https://arxiv.org/abs/1312.6114 98 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 99 | # https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians 100 | # log(sigma2) - log(sigma1) + 0.5 * (sigma1^2 + (mu1 - mu2)^2) / sigma2^2 - 0.5 101 | # 0 - log(sigma1) + 0.5 * (sigma1^2 + mu1^2) - 0.5 102 | # 0 - log(sigma1) + 0.5 * sigma1^2 + 0.5 * mu1^2 - 0.5 103 | # 0 - 0.5 * log(sigma1^2) + 0.5 * sigma1^2 + 0.5 * mu1^2 - 0.5 104 | # log(sigma2) - log(sigma1) + 0.5 * (sigma1^2 + (mu1 - mu2)^2) / sigma2^2 - 0.5 105 | KLD_element = - logvar2 + logvar1 - (logvar1.exp() + (mu1 - mu2)**2) / logvar2.exp() + 1. 106 | 107 | # do sum 108 | if do_sum: 109 | KLD = torch.sum(KLD_element) * -0.5 110 | return KLD 111 | else: 112 | batch_size = KLD_element.size(0) 113 | KLD_element = torch.sum(KLD_element.view(batch_size, -1), 1) * -0.5 114 | return KLD_element 115 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torchvision.utils as vutils 5 | 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | #import pandas as pd 10 | import seaborn as sns 11 | sns.set() 12 | sns.set_style('whitegrid') 13 | sns.set_palette('colorblind') 14 | 15 | 16 | def convert_npimage_torchimage(image): 17 | return 255*torch.transpose(torch.transpose(torch.from_numpy(image), 0, 2), 1, 2) 18 | 19 | def get_scatter_plot(data, labels=None, n_classes=None, num_samples=1000, xlim=None, ylim=None): 20 | ''' 21 | data : 2d points, batch_size x data_dim (numpy array) 22 | labels : labels, batch_size (numpy array) 23 | ''' 24 | batch_size, data_dim = data.shape 25 | num_samples = min(num_samples, batch_size) 26 | if labels is None: 27 | labels = np.zeros(batch_size, dtype=np.int) 28 | if n_classes is None: 29 | n_classes = len(np.unique(labels)) 30 | 31 | # sub-samples 32 | if num_samples != batch_size: 33 | indices = np.random.permutation(batch_size) 34 | data = data[indices[:num_samples]] 35 | labels = labels[indices[:num_samples]] 36 | 37 | # init config 38 | palette = sns.color_palette(n_colors=n_classes) 39 | palette = [palette[i] for i in np.unique(labels)] 40 | 41 | # plot 42 | fig, ax = plt.subplots(figsize=(5, 5)) 43 | data = {'x': data[:, 0], 44 | 'y': data[:, 1], 45 | 'class': labels} 46 | sns.scatterplot(x='x', y='y', hue='class', data=data, palette=palette) 47 | 48 | # set config 49 | if xlim is not None: 50 | plt.xlim((-xlim, xlim)) 51 | if ylim is not None: 52 | plt.ylim((-ylim, ylim)) 53 | 54 | # draw to canvas 55 | fig.canvas.draw() # draw the canvas, cache the renderer 56 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, 
sep='') 57 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 58 | 59 | # close figure 60 | plt.close() 61 | return image 62 | 63 | def get_data_for_quiver_plot(val, num): 64 | _x = np.linspace(-val, val, num) 65 | _y = np.linspace(-val, val, num) 66 | _u, _v = np.meshgrid(_x, _y) 67 | _vis_data = np.stack([_u.reshape(num**2), _v.reshape(num**2)], axis=1) 68 | vis_data = torch.from_numpy(_vis_data).float() 69 | return vis_data, _x, _y 70 | 71 | def get_quiver_plot(vec, x_pos, y_pos, xlim=None, ylim=None, scale=None): 72 | ''' 73 | vec : 2d points, batch_size x data_dim (numpy array) 74 | pos : 2d points, batch_size x data_dim (numpy array) 75 | ''' 76 | grid_size = x_pos.shape[0] 77 | batch_size = vec.shape[0] 78 | assert batch_size == grid_size**2 79 | assert y_pos.shape[0] == grid_size 80 | 81 | # get x, y, u, v 82 | X = x_pos #np.arange(-10, 10, 1) 83 | Y = y_pos #np.arange(-10, 10, 1) 84 | #U, V = np.meshgrid(X, Y) 85 | U = vec[:, 0].reshape(grid_size, grid_size) 86 | V = vec[:, 1].reshape(grid_size, grid_size) 87 | 88 | # plot 89 | fig, ax = plt.subplots(figsize=(5, 5)) 90 | q = ax.quiver(X, Y, U, V, pivot='mid', scale=scale) 91 | #ax.quiverkey(q, X=0.3, Y=1.1, U=10, 92 | # label='Quiver key, length = 10', labelpos='E') 93 | 94 | # set config 95 | if xlim is not None: 96 | plt.xlim((-xlim, xlim)) 97 | if ylim is not None: 98 | plt.ylim((-ylim, ylim)) 99 | 100 | # tight 101 | plt.tight_layout() 102 | 103 | # draw to canvas 104 | fig.canvas.draw() # draw the canvas, cache the renderer 105 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 106 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 107 | 108 | # close figure 109 | plt.close() 110 | return image 111 | 112 | def get_data_for_heatmap(val=4, num=256): 113 | _x = np.linspace(-val, val, num) 114 | _y = np.linspace(-val, val, num) 115 | _u, _v = np.meshgrid(_x, _y) 116 | _data = np.stack([_u.reshape(num**2), _v.reshape(num**2)], axis=1) 117 | return _data, _x, _y 118 | 119 | def energy_to_unnormalized_prob(energy): 120 | prob = torch.exp(-energy) # unnormalized prob 121 | return prob 122 | 123 | def get_prob_from_energy_func_for_vis(energy_func, val=4, num=256): 124 | # get grid 125 | _z, _, _ = get_data_for_heatmap(val=val, num=num) 126 | z = torch.from_numpy(_z).float() 127 | 128 | # run energy func 129 | energy = energy_func(z) 130 | prob = energy_to_unnormalized_prob(energy) 131 | 132 | # convert to numpy array 133 | _prob = prob.cpu().float().numpy() 134 | _prob = _prob.reshape(num, num) 135 | return _prob 136 | 137 | def get_imshow_plot(prob, val=4, use_grid=True): 138 | # plot 139 | fig, ax = plt.subplots(figsize=(5, 5)) 140 | im = ax.imshow(prob, cmap='jet', extent=[-val, val, -val, val]) 141 | ax.grid(False) 142 | if use_grid: 143 | plt.xticks(np.arange(-val, val+1, step=1)) 144 | plt.yticks(np.arange(-val, val+1, step=1)) 145 | else: 146 | plt.xticks([]) 147 | plt.yticks([]) 148 | 149 | # tight 150 | plt.tight_layout() 151 | 152 | # draw to canvas 153 | fig.canvas.draw() # draw the canvas, cache the renderer 154 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 155 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 156 | 157 | # close figure 158 | plt.close() 159 | return image 160 | 161 | def get_1d_histogram_plot(data, val=4, num=256, use_grid=True): 162 | xmin = 0 163 | xmax = val 164 | 165 | # get data 166 | x = data 167 | 168 | # get histogram 169 | hist, xedges = np.histogram(x, range=[xmin, xmax], bins=num) 170 | 171 
| # plot heatmap 172 | fig, ax = plt.subplots(figsize=(5, 5)) 173 | im = ax.bar(xedges[:-1], hist, width=0.5)#, color='#0504aa',alpha=0.7) 174 | 175 | ax.grid(False) 176 | if use_grid: 177 | plt.xticks(np.arange(0, val+1, step=1)) 178 | else: 179 | plt.xticks([]) 180 | 181 | # tight 182 | plt.tight_layout() 183 | 184 | # draw to canvas 185 | fig.canvas.draw() # draw the canvas, cache the renderer 186 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 187 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 188 | 189 | # close figure 190 | plt.close() 191 | return image 192 | 193 | def get_2d_histogram_plot(data, val=4, num=256, use_grid=True): 194 | xmin = -val 195 | xmax = val 196 | ymin = -val 197 | ymax = val 198 | 199 | # get data 200 | x = data[:, 0] 201 | y = data[:, 1] 202 | 203 | # get histogram 204 | heatmap, xedges, yedges = np.histogram2d(x, y, range=[[xmin, xmax], [ymin, ymax]], bins=num) 205 | extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] 206 | 207 | # plot heatmap 208 | fig, ax = plt.subplots(figsize=(5, 5)) 209 | im = ax.imshow(heatmap.T, extent=extent, cmap='jet') 210 | ax.grid(False) 211 | if use_grid: 212 | plt.xticks(np.arange(-val, val+1, step=1)) 213 | plt.yticks(np.arange(-val, val+1, step=1)) 214 | else: 215 | plt.xticks([]) 216 | plt.yticks([]) 217 | 218 | # tight 219 | plt.tight_layout() 220 | 221 | # draw to canvas 222 | fig.canvas.draw() # draw the canvas, cache the renderer 223 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 224 | image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 225 | 226 | # close figure 227 | plt.close() 228 | return image 229 | 230 | def get_grid_image(input, batch_size, nchannels, nheight, nwidth=None, ncol=8, pad_value=0, do_square=True): 231 | ''' 232 | input : b x c x h x w (where h = w) 233 | ''' 234 | if batch_size > ncol**2 and do_square: 235 | input = input[:ncol**2, :, :, :] 236 | batch_size = ncol**2 237 | nwidth = nwidth if nwidth is not None else nheight 238 | input = input.detach() 239 | output = input.view(batch_size, nchannels, nheight, nwidth).clone().cpu() 240 | output = vutils.make_grid(output, nrow=ncol, normalize=True, scale_each=True, pad_value=pad_value) 241 | #output = vutils.make_grid(output, normalize=False, scale_each=False) 242 | return output 243 | 244 | #def get_canvas(fig): 245 | # fig.canvas.draw() # draw the canvas, cache the renderer 246 | # image = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 247 | # image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 248 | # return image 249 | # # close figure 250 | # plt.close() 251 | # 252 | #def get_contour_with_batch_size(model, batch_size=128, vmin=-10.0, vmax=10.0, title=None): 253 | # model.eval() 254 | # matplotlib.rcParams['xtick.direction'] = 'out' 255 | # matplotlib.rcParams['ytick.direction'] = 'out' 256 | # matplotlib.rcParams['contour.negative_linestyle'] = 'solid' 257 | # 258 | # # tmp 259 | # weight = next(model.parameters()) 260 | # 261 | # # gen grid⋅ 262 | # delta = 0.1 263 | # xv, yv = torch.meshgrid([torch.arange(vmin, vmax, delta), torch.arange(vmin, vmax, delta)]) 264 | # h = yv.size(0) 265 | # w = xv.size(0) 266 | # yv = yv.contiguous().view(-1) 267 | # xv = xv.contiguous().view(-1) 268 | # input = torch.cat([xv.unsqueeze(1), yv.unsqueeze(1)], dim=1).to(weight.device) 269 | # 270 | # # forward 271 | # prob = model.prob(input, batch_size=batch_size) 272 | # 273 | # # convert torch variable to numpy array 274 | # xv = 
xv.cpu().numpy().reshape(h, w) 275 | # yv = yv.cpu().numpy().reshape(h, w) 276 | # zv = prob.detach().cpu().numpy().reshape(h, w) 277 | # 278 | # # plot and save⋅ 279 | # fig = plt.figure() 280 | # CS1 = plt.contourf(xv, yv, zv) 281 | # CS2 = plt.contour(xv, yv, zv, alpha=.7, colors='k') 282 | # plt.clabel(CS2, inline=1, fontsize=10, colors='k') 283 | # #plt.title('Simplest default with labels') 284 | # if title is not None: 285 | # plt.title(title) 286 | # #plt.savefig(filename) 287 | # #plt.close() 288 | # image = get_canvas(fig) 289 | # plt.close() 290 | # 291 | # return image 292 | # 293 | ##def get_contour_with_data(model, data, vmin=-10.0, vmax=10.0, title=None): 294 | ## model.eval() 295 | ## matplotlib.rcParams['xtick.direction'] = 'out' 296 | ## matplotlib.rcParams['ytick.direction'] = 'out' 297 | ## matplotlib.rcParams['contour.negative_linestyle'] = 'solid' 298 | ## 299 | ## # gen grid⋅ 300 | ## delta = 0.1 301 | ## xv, yv = torch.meshgrid([torch.arange(vmin, vmax, delta), torch.arange(vmin, vmax, delta)]) 302 | ## h = yv.size(0) 303 | ## w = xv.size(0) 304 | ## yv = yv.contiguous().view(-1) 305 | ## xv = xv.contiguous().view(-1) 306 | ## input = torch.cat([xv.unsqueeze(1), yv.unsqueeze(1)], dim=1).to(data.device) 307 | ## 308 | ## # forward 309 | ## prob = model.prob(input, data) 310 | ## 311 | ## # convert torch variable to numpy array 312 | ## xv = xv.cpu().numpy().reshape(h, w) 313 | ## yv = yv.cpu().numpy().reshape(h, w) 314 | ## zv = prob.detach().cpu().numpy().reshape(h, w) 315 | ## 316 | ## # plot and save⋅ 317 | ## fig = plt.figure() 318 | ## CS1 = plt.contourf(xv, yv, zv) 319 | ## CS2 = plt.contour(xv, yv, zv, alpha=.7, colors='k') 320 | ## plt.clabel(CS2, inline=1, fontsize=10, colors='k') 321 | ## #plt.title('Simplest default with labels') 322 | ## if title is not None: 323 | ## plt.title(title) 324 | ## #plt.savefig(filename) 325 | ## #plt.close() 326 | ## image = get_canvas(fig) 327 | ## plt.close() 328 | ## 329 | ## return image 330 | # 331 | #def get_contour_with_z(model, z, vmin=-10.0, vmax=10.0, title=None): 332 | # model.eval() 333 | # matplotlib.rcParams['xtick.direction'] = 'out' 334 | # matplotlib.rcParams['ytick.direction'] = 'out' 335 | # matplotlib.rcParams['contour.negative_linestyle'] = 'solid' 336 | # 337 | # # gen grid⋅ 338 | # delta = 0.1 339 | # xv, yv = torch.meshgrid([torch.arange(vmin, vmax, delta), torch.arange(vmin, vmax, delta)]) 340 | # h = yv.size(0) 341 | # w = xv.size(0) 342 | # yv = yv.contiguous().view(-1) 343 | # xv = xv.contiguous().view(-1) 344 | # input = torch.cat([xv.unsqueeze(1), yv.unsqueeze(1)], dim=1).to(z.device) 345 | # 346 | # # forward 347 | # prob = model.prob(input, z=z) 348 | # 349 | # # convert torch variable to numpy array 350 | # xv = xv.cpu().numpy().reshape(h, w) 351 | # yv = yv.cpu().numpy().reshape(h, w) 352 | # zv = prob.detach().cpu().numpy().reshape(h, w) 353 | # 354 | # # plot and save⋅ 355 | # fig = plt.figure() 356 | # CS1 = plt.contourf(xv, yv, zv) 357 | # CS2 = plt.contour(xv, yv, zv, alpha=.7, colors='k') 358 | # plt.clabel(CS2, inline=1, fontsize=10, colors='k') 359 | # #plt.title('Simplest default with labels') 360 | # if title is not None: 361 | # plt.title(title) 362 | # #plt.savefig(filename) 363 | # #plt.close() 364 | # image = get_canvas(fig) 365 | # plt.close() 366 | # 367 | # return image 368 | --------------------------------------------------------------------------------
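As a closing note, a small hedged sketch of how the density-visualization helpers in `utils/visualization.py` can be chained; the ring-shaped toy energy below is an assumption made purely for illustration.

```python
import torch

from utils.visualization import get_prob_from_energy_func_for_vis, get_imshow_plot

def toy_energy(z):
    # z: n x 2; a ring-shaped energy function, purely illustrative
    return (z.norm(dim=1) - 2.0) ** 2 / 0.3

# evaluate exp(-energy) on a 256 x 256 grid over [-4, 4]^2 ...
prob = get_prob_from_energy_func_for_vis(toy_energy, val=4, num=256)

# ... and render it as an H x W x 3 uint8 heatmap image
image = get_imshow_plot(prob, val=4)
```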