├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── celea2K.json
│   ├── celeba.json
│   ├── fashionmnist.json
│   ├── gmm_aniso.json
│   ├── gmm_iso.json
│   └── mnist.json
├── gmm.py
├── grbm.py
├── main.py
├── requirements.txt
├── utils.py
└── vis
    ├── CelebA.gif
    └── MNIST.gif

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 lrjconan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gaussian-Bernoulli Restricted Boltzmann Machines (GRBMs)
2 | This is the official PyTorch implementation of [Gaussian-Bernoulli RBMs Without Tears](https://arxiv.org/abs/2210.10318), as described in the following paper:
3 | 
4 | ```
5 | @article{liao2022grbm,
6 |   title={Gaussian-Bernoulli RBMs Without Tears},
7 |   author={Liao, Renjie and Kornblith, Simon and Ren, Mengye and Fleet, David J and Hinton, Geoffrey},
8 |   journal={arXiv preprint arXiv:2210.10318},
9 |   year={2022}
10 | }
11 | ```
12 | 
13 | ## Sampling processes of learned GRBMs on MNIST and CelebA (32 x 32):
14 | ![](vis/MNIST.gif)
15 | ![](vis/CelebA.gif)
16 | 
17 | ## Dependencies
18 | Python 3 and PyTorch (1.12.0). Other dependencies can be installed via ```pip install -r requirements.txt```.
19 | 
20 | 
21 | ## Run Demos
22 | 
23 | ### Train
24 | * To train on experiment ```X```, where ```X``` is one of {```gmm_iso```, ```gmm_aniso```, ```mnist```, ```fashionmnist```, ```celeba```, ```celeba2K```}:
25 | 
26 | ```python main.py -d X```
27 | 
28 | **Note**:
29 | 
30 | * Please check the folder ```config``` for the configuration JSON files, where most hyperparameters are self-explanatory.
31 | * Important hyperparameters include:
32 |     * CD_step: number of CD steps used to generate negative samples
33 |     * inference_method: must be one of ```Gibbs```, ```Langevin```, ```Gibbs-Langevin```
34 |     * Langevin_step: number of inner-loop Langevin steps for the Gibbs-Langevin sampling method
35 |     * Langevin_eta: step size for both the Langevin and Gibbs-Langevin sampling methods
36 |     * Langevin_adjust_step: when set to ```X```, it enables Metropolis adjustment from the ```X```-th to the ```CD_step```-th step
37 |     * is_vis_verbose: when set to True, it saves learned filters and hidden activations (consider turning it off for better efficiency if you have many filters and large images)
38 | * For CelebA experiments, you need to download the dataset and set the relative path as ```data/celeba```
39 | 
40 | ## Cite
41 | Please consider citing our paper if you use this code in your research work.
42 | 
43 | ## Questions/Bugs
44 | Please submit a GitHub issue or contact rjliao@ece.ubc.ca if you have any questions or find any bugs.
45 | 
--------------------------------------------------------------------------------
/config/celea2K.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset": "CelebA2K",
3 |     "cuda": true,
4 |     "model": "GRBM",
5 |     "batch_size": 100,
6 |     "epochs": 10000,
7 |     "lr": 0.01,
8 |     "clip_norm": 10.0,
9 |     "wd": 0.0,
10 |     "resume": 0,
11 |     "is_vis_verbose": false,
12 |     "init_var": 1.0,
13 |     "CD_step": 100,
14 |     "CD_burnin": 0,
15 |     "Langevin_step": 10,
16 |     "Langevin_eta": 20.0,
17 |     "Langevin_adjust_warmup_epoch": 0,
18 |     "Langevin_adjust_step": 100,
19 |     "inference_method": "Gibbs-Langevin",
20 |     "sampling_batch_size": 64,
21 |     "sampling_steps": 100,
22 |     "sampling_gap": 5,
23 |     "sampling_nrow": 8,
24 |     "height": 64,
25 |     "width": 64,
26 |     "channel": 3,
27 |     "crop_size": 140,
28 |     "img_mean": [
29 |         0.5239999890327454,
30 |         0.41519999504089355,
31 |         0.35899999737739563
32 |     ],
33 |     "img_std": [
34 |         0.28679999709129333,
35 |         0.2529999911785126,
36 |         0.24529999494552612
37 |     ],
38 |     "log_interval": 1,
39 |     "save_interval": 5,
40 |     "hidden_size": 10000
41 | }
--------------------------------------------------------------------------------
/config/celeba.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset": "CelebA",
3 |     "cuda": true,
4 |     "model": "GRBM",
5 |     "batch_size": 512,
6 |     "epochs": 10000,
7 |     "lr": 0.01,
8 |     "clip_norm": 10.0,
9 |     "wd": 0.0,
10 |     "resume": 0,
11 |     "is_vis_verbose": false,
12 |     "init_var": 1.0,
13 |     "CD_step": 100,
14 |     "CD_burnin": 0,
15 |     "Langevin_step": 10,
16 |     "Langevin_eta": 20.0,
17 |     "Langevin_adjust_warmup_epoch": 0,
18 |     "Langevin_adjust_step": 100,
19 |     "inference_method": "Gibbs-Langevin",
20 |     "sampling_batch_size": 100,
21 |     "sampling_steps": 100,
22 |     "sampling_gap": 5,
23 |     "sampling_nrow": 10,
24 |     "height": 32,
25 |     "width": 32,
26 |     "channel": 3,
27 |     "crop_size": 140,
28 |     "img_mean": [
29 |         0.5239999890327454,
30 |         0.41519999504089355,
31 |         0.35899999737739563
32 |     ],
33 |     "img_std": [
34 |         0.28679999709129333,
35 |         0.2529999911785126,
36 |         0.24529999494552612
37 |     ],
38 |     "log_interval": 1,
39 |     "save_interval": 5,
40 |     "hidden_size": 10000
41 | }
--------------------------------------------------------------------------------
/config/fashionmnist.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset": "FashionMNIST",
3 |     "cuda": true,
4 |     "model": "GRBM",
5 |     "batch_size": 512,
6 |     "epochs": 3000,
7 |     "lr":
0.01, 8 | "clip_norm": 10.0, 9 | "wd": 0.0, 10 | "resume": 0, 11 | "is_vis_verbose": false, 12 | "init_var": 1.0, 13 | "CD_step": 100, 14 | "CD_burnin": 0, 15 | "Langevin_step": 10, 16 | "Langevin_eta": 20.0, 17 | "Langevin_adjust_warmup_epoch": 1000, 18 | "Langevin_adjust_step": 100, 19 | "inference_method": "Gibbs-Langevin", 20 | "sampling_batch_size": 100, 21 | "sampling_steps": 100, 22 | "sampling_gap": 5, 23 | "sampling_nrow": 10, 24 | "height": 28, 25 | "width": 28, 26 | "channel": 1, 27 | "img_mean": [ 28 | 0.28600001335144043 29 | ], 30 | "img_std": [ 31 | 0.3529999852180481 32 | ], 33 | "log_interval": 10, 34 | "save_interval": 100, 35 | "hidden_size": 10000 36 | } -------------------------------------------------------------------------------- /config/gmm_aniso.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "GMM_aniso", 3 | "cuda": true, 4 | "model": "GRBM", 5 | "batch_size": 100, 6 | "epochs": 50000, 7 | "lr": 0.01, 8 | "clip_norm": 10.0, 9 | "wd": 0.0, 10 | "resume": 0, 11 | "is_vis_verbose": true, 12 | "init_var": 1.0, 13 | "CD_step": 100, 14 | "CD_burnin": 0, 15 | "Langevin_step": 10, 16 | "Langevin_eta": 10.0, 17 | "Langevin_adjust_warmup_epoch": 0, 18 | "Langevin_adjust_step": 0, 19 | "inference_method": "Gibbs-Langevin", 20 | "sampling_batch_size": 100, 21 | "sampling_steps": 100, 22 | "sampling_gap": 5, 23 | "sampling_nrow": 10, 24 | "num_samples": 1000, 25 | "height": 1, 26 | "width": 1, 27 | "channel": 2, 28 | "log_interval": 100, 29 | "save_interval": 50000, 30 | "hidden_size": 256 31 | } -------------------------------------------------------------------------------- /config/gmm_iso.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "GMM_iso", 3 | "cuda": true, 4 | "model": "GRBM", 5 | "batch_size": 100, 6 | "epochs": 50000, 7 | "lr": 0.01, 8 | "clip_norm": 10.0, 9 | "wd": 0.0, 10 | "resume": 0, 11 | "is_vis_verbose": true, 12 | "init_var": 1.0, 13 | "CD_step": 100, 14 | "CD_burnin": 0, 15 | "Langevin_step": 10, 16 | "Langevin_eta": 10.0, 17 | "Langevin_adjust_warmup_epoch": 0, 18 | "Langevin_adjust_step": 0, 19 | "inference_method": "Gibbs-Langevin", 20 | "sampling_batch_size": 100, 21 | "sampling_steps": 100, 22 | "sampling_gap": 5, 23 | "sampling_nrow": 10, 24 | "num_samples": 1000, 25 | "height": 1, 26 | "width": 1, 27 | "channel": 2, 28 | "log_interval": 100, 29 | "save_interval": 50000, 30 | "hidden_size": 256 31 | } -------------------------------------------------------------------------------- /config/mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "MNIST", 3 | "cuda": true, 4 | "model": "GRBM", 5 | "batch_size": 512, 6 | "epochs": 3000, 7 | "lr": 0.01, 8 | "clip_norm": 10.0, 9 | "wd": 0.0, 10 | "resume": 0, 11 | "is_vis_verbose": true, 12 | "init_var": 1.0, 13 | "CD_step": 100, 14 | "CD_burnin": 0, 15 | "Langevin_step": 10, 16 | "Langevin_eta": 20.0, 17 | "Langevin_adjust_warmup_epoch": 0, 18 | "Langevin_adjust_step": 100, 19 | "inference_method": "Gibbs-Langevin", 20 | "sampling_batch_size": 100, 21 | "sampling_steps": 100, 22 | "sampling_gap": 5, 23 | "sampling_nrow": 10, 24 | "height": 28, 25 | "width": 28, 26 | "channel": 1, 27 | "img_mean": [ 28 | 0.1307000070810318 29 | ], 30 | "img_std": [ 31 | 0.30809998512268066 32 | ], 33 | "log_interval": 10, 34 | "save_interval": 100, 35 | "hidden_size": 4096 36 | } 
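
**Note on the ```img_mean```/```img_std``` entries.** The per-channel values hard-coded in the configs above are statistics of each training set, computed on pixel values in [0, 1] after ```ToTensor()```; ```main.py``` uses them in ```transforms.Normalize``` and to un-normalize samples for visualization. If you adapt the code to a new dataset, the sketch below shows one way such statistics can be estimated. It is not part of the repository — ```estimate_channel_stats``` is a hypothetical helper — and it only illustrates the convention the configs appear to follow (for MNIST it reproduces roughly 0.1307 / 0.3081, matching ```config/mnist.json```).

```python
# Hypothetical helper (not part of this repo): estimate the per-channel
# "img_mean"/"img_std" statistics that the config files above expect.
import torch
from torchvision import datasets, transforms


def estimate_channel_stats(dataset, batch_size=1024):
    """Per-channel mean/std over all pixels of a dataset of [0, 1] image tensors."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    n_pixels, channel_sum, channel_sq_sum = 0, 0.0, 0.0
    for imgs, _ in loader:
        b, _, h, w = imgs.shape          # imgs: B x C x H x W after ToTensor()
        n_pixels += b * h * w
        channel_sum = channel_sum + imgs.sum(dim=(0, 2, 3))
        channel_sq_sum = channel_sq_sum + (imgs ** 2).sum(dim=(0, 2, 3))
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
    return mean.tolist(), std.tolist()


if __name__ == "__main__":
    train_set = datasets.MNIST("./data", train=True, download=True,
                               transform=transforms.ToTensor())
    print(estimate_channel_stats(train_set))  # ~([0.1307], [0.3081]) for MNIST
```

For CelebA, the same statistics would presumably be computed after the ```CenterCrop```/```Resize``` transforms used in ```create_dataset``` in ```main.py```.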
-------------------------------------------------------------------------------- /gmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from utils import cosine_schedule 5 | 6 | 7 | class GMMDataset(torch.utils.data.Dataset): 8 | 9 | def __init__(self, samples): 10 | self.samples = samples 11 | 12 | def __len__(self): 13 | return self.samples.shape[0] 14 | 15 | def __getitem__(self, idx): 16 | return self.samples[idx, :], torch.ones(1).to(self.samples.device) 17 | 18 | 19 | class GMM(nn.Module): 20 | """ Gaussian Mixture Models 21 | N.B.: covariance is assumed to be diagonal 22 | """ 23 | 24 | def __init__(self, w, mu, sigma): 25 | """ 26 | p(x) = sum_i w[i] N(mu[i], sigma[i]^2 * I) 27 | 28 | config: 29 | w: shape K X 1, mixture coefficients, must sum to 1 30 | mu: shape K X D, mean 31 | sigma: shape K X D, (diagonal) variance 32 | """ 33 | super().__init__() 34 | self.register_buffer('w', w) 35 | self.register_buffer('mu', mu) 36 | self.register_buffer('sigma', sigma) 37 | self.K = w.shape[0] 38 | self.D = mu.shape[1] 39 | 40 | @torch.no_grad() 41 | def log_gaussian(self, x, mu, sigma): 42 | """ log density of single (diagonal-covariance) multivariate Gaussian""" 43 | return -0.5 * ((x - mu)**2 / sigma**2).sum(dim=1) - 0.5 * ( 44 | self.D * np.log(2 * np.pi) + torch.log(torch.prod(sigma**2))) 45 | 46 | @torch.no_grad() 47 | def log_prob(self, x): 48 | return torch.logsumexp( 49 | torch.stack([ 50 | torch.log(self.w[kk]) + 51 | self.log_gaussian(x, self.mu[kk], self.sigma[kk]) 52 | for kk in range(self.K) 53 | ]), 0) 54 | 55 | @torch.no_grad() 56 | def sampling(self, num_samples): 57 | m = torch.distributions.Categorical(self.w) 58 | idx = m.sample((num_samples,)) 59 | return self.mu[idx, :] + torch.randn(num_samples, self.D).to( 60 | self.w.device) * self.sigma[idx, :] 61 | 62 | @torch.no_grad() 63 | def langevin_sampling(self, x, num_steps=10, eta=1.0e+0, is_anneal=False): 64 | eta_list = cosine_schedule(eta_max=eta, T=num_steps) 65 | for ii in range(num_steps): 66 | eta_ii = eta_list[ii] if is_anneal else eta 67 | x = x.detach() 68 | x.requires_grad = True 69 | eng = -self.log_prob(x).sum() 70 | grad = torch.autograd.grad(eng, x)[0] 71 | x = x - eta_ii * grad + torch.randn_like(x) * np.sqrt(eta_ii * 2) 72 | 73 | return x.detach() 74 | -------------------------------------------------------------------------------- /grbm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from utils import cosine_schedule 6 | 7 | class GRBM(nn.Module): 8 | """ Gaussian-Bernoulli Restricted Boltzmann Machines (GRBM) """ 9 | 10 | def __init__(self, 11 | visible_size, 12 | hidden_size, 13 | CD_step=1, 14 | CD_burnin=0, 15 | init_var=1e-0, 16 | inference_method='Gibbs', 17 | Langevin_step=10, 18 | Langevin_eta=1.0, 19 | is_anneal_Langevin=True, 20 | Langevin_adjust_step=0) -> None: 21 | super().__init__() 22 | # we use samples in [CD_burnin, CD_step) steps 23 | assert CD_burnin >= 0 and CD_burnin <= CD_step 24 | assert inference_method in ['Gibbs', 'Langevin', 'Gibbs-Langevin'] 25 | 26 | self.visible_size = visible_size 27 | self.hidden_size = hidden_size 28 | self.CD_step = CD_step 29 | self.CD_burnin = CD_burnin 30 | self.init_var = init_var 31 | self.inference_method = inference_method 32 | self.Langevin_step = Langevin_step 33 | self.Langevin_eta = Langevin_eta 34 
| self.is_anneal_Langevin = is_anneal_Langevin 35 | self.Langevin_adjust_step = Langevin_adjust_step 36 | 37 | self.W = nn.Parameter(torch.Tensor(visible_size, hidden_size)) 38 | self.b = nn.Parameter(torch.Tensor(hidden_size)) 39 | self.mu = nn.Parameter(torch.Tensor(visible_size)) 40 | self.log_var = nn.Parameter(torch.Tensor(visible_size)) 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | nn.init.normal_(self.W, 45 | std=1.0 * self.init_var / 46 | np.sqrt(self.visible_size + self.hidden_size)) 47 | nn.init.constant_(self.b, 0.0) 48 | nn.init.constant_(self.mu, 0.0) 49 | nn.init.constant_(self.log_var, 50 | np.log(self.init_var)) # init variance = 1.0 51 | 52 | def get_var(self): 53 | return self.log_var.exp().clip(min=1e-8) 54 | 55 | def set_Langevin_eta(self, eta): 56 | self.Langevin_eta = eta 57 | 58 | def set_Langevin_adjust_step(self, step): 59 | self.Langevin_adjust_step = step 60 | 61 | @torch.no_grad() 62 | def energy(self, v, h): 63 | # compute per-sample energy averaged over batch size 64 | B = v.shape[0] 65 | var = self.get_var() 66 | eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1) 67 | eng -= ((v / var).mm(self.W) * h).sum(dim=1) + h.mv(self.b) 68 | return eng / B 69 | 70 | @torch.no_grad() 71 | def marginal_energy(self, v): 72 | # compute per-sample energy averaged over batch size 73 | B = v.shape[0] 74 | var = self.get_var() 75 | eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1) 76 | eng -= F.softplus((v / var).mm(self.W) + self.b).sum(dim=1) 77 | return eng / B 78 | 79 | @torch.no_grad() 80 | def energy_grad_v(self, v, h): 81 | # compute the gradient (sample) of energy averaged over batch size 82 | B = v.shape[0] 83 | var = self.get_var() 84 | return ((v - self.mu) / var - h.mm(self.W.T) / var) / B 85 | 86 | @torch.no_grad() 87 | def marginal_energy_grad_v(self, v): 88 | # compute the gradient (sample) of energy averaged over batch size 89 | B = v.shape[0] 90 | var = self.get_var() 91 | return ((v - self.mu) / var - torch.sigmoid((v / var).mm(self.W) + self.b).mm(self.W.T) / var) / B 92 | 93 | @torch.no_grad() 94 | def energy_grad_param(self, v, h): 95 | # compute the gradient (parameter) of energy averaged over batch size 96 | var = self.get_var() 97 | grad = {} 98 | grad['W'] = -torch.einsum("bi,bj->ij", v / var, h) / v.shape[0] 99 | grad['b'] = -h.mean(dim=0) 100 | grad['mu'] = ((self.mu - v) / var).mean(dim=0) 101 | grad['log_var'] = (-0.5 * (v - self.mu)**2 / var + 102 | ((v / var) * h.mm(self.W.T))).mean(dim=0) 103 | return grad 104 | 105 | @torch.no_grad() 106 | def marginal_energy_grad_param(self, v): 107 | # compute the gradient (parameter) of energy averaged over batch size 108 | var = self.get_var() 109 | vv = v / var 110 | tmp = torch.sigmoid(vv.mm(self.W) + self.b) 111 | grad = {} 112 | grad['W'] = -torch.einsum("bi,bj->ij", vv, tmp) / v.shape[0] 113 | grad['b'] = -tmp.mean(dim=0) 114 | grad['mu'] = ((self.mu - v) / var).mean(dim=0) 115 | grad['log_var'] = (-0.5 * (v - self.mu)**2 / var + 116 | (vv * tmp.mm(self.W.T))).mean(dim=0) 117 | return grad 118 | 119 | @torch.no_grad() 120 | def prob_h_given_v(self, v, var): 121 | return torch.sigmoid((v / var).mm(self.W) + self.b) 122 | 123 | @torch.no_grad() 124 | def prob_v_given_h(self, h): 125 | return h.mm(self.W.T) + self.mu 126 | 127 | @torch.no_grad() 128 | def log_metropolis_ratio_Gibbs_Langevin(self, v_old, h_old, v_new, h_new, eta_list): 129 | """ Metropolis-Hasting ratio of accepting the move from old to new state """ 130 | B = v_old.shape[0] 131 | var = self.get_var() 132 | eng_diff = 
-self.energy(v_new, h_new) + self.energy(v_old, h_old) 133 | state_h_new = (v_new / var).mm(self.W) + self.b 134 | state_h_old = (v_old / var).mm(self.W) + self.b 135 | log_prob_h_given_v_new = - \ 136 | F.binary_cross_entropy_with_logits( 137 | state_h_old, h_old, reduction='none').sum(dim=1) 138 | log_prob_h_given_v_old = - \ 139 | F.binary_cross_entropy_with_logits( 140 | state_h_new, h_new, reduction='none').sum(dim=1) 141 | 142 | eta = torch.tensor(eta_list).to(var.device) # shape K X 1 143 | beta_in = 1.0 - eta.unsqueeze(1) / (B * var.unsqueeze(0)) # shape K X D 144 | beta = torch.flip(torch.cumprod( 145 | torch.flip(beta_in, [0]), 0), [0]) # shape K X D 146 | beta = F.pad(beta, [0, 0, 0, 1], "constant", 1.0) # shape (K+1) X D 147 | va = (beta[1:] * eta.view(-1, 1)).sum(dim=0) / (B * var) # shape 1 X D 148 | tilde_sigma_sqrt = ( 149 | (beta[1:]**2 * eta.view(-1, 1)).sum(dim=0)).sqrt() # shape 1 X D 150 | proposal_eng_new = - torch.pow((v_old - beta[0] * v_new - va * ( 151 | self.mu + h_new.mm(self.W.T))) / (2 * tilde_sigma_sqrt), 2.0).sum(dim=1) 152 | proposal_eng_old = - torch.pow((v_new - beta[0] * v_old - va * ( 153 | self.mu + h_old.mm(self.W.T))) / (2 * tilde_sigma_sqrt), 2.0).sum(dim=1) 154 | 155 | return eng_diff + proposal_eng_new - proposal_eng_old + log_prob_h_given_v_new - log_prob_h_given_v_old 156 | 157 | @torch.no_grad() 158 | def log_metropolis_ratio_Langevin_one_step(self, v_old, v_new, grad_old, eta): 159 | """ Metropolis-Hasting ratio of accepting the move from old to new state """ 160 | eng_diff = -self.marginal_energy(v_new) + self.marginal_energy(v_old) 161 | proposal_eng_new = - \ 162 | torch.pow(v_old - v_new + eta * 163 | self.marginal_energy_grad_v(v_new), 2.0).sum(dim=1) / (4 * eta) 164 | proposal_eng_old = - \ 165 | torch.pow(v_new - v_old + eta * grad_old, 166 | 2.0).sum(dim=1) / (4 * eta) 167 | 168 | return eng_diff + proposal_eng_new - proposal_eng_old 169 | 170 | @torch.no_grad() 171 | def Gibbs_sampling_vh(self, v, num_steps=10, burn_in=0): 172 | samples, var = [], self.get_var() 173 | std = var.sqrt() 174 | h = torch.bernoulli(self.prob_h_given_v(v, var)) 175 | for ii in range(num_steps): 176 | # backward sampling 177 | mu = self.prob_v_given_h(h) 178 | v = mu + torch.randn_like(mu) * std 179 | 180 | # forward sampling 181 | h = torch.bernoulli(self.prob_h_given_v(v, var)) 182 | 183 | if ii >= burn_in: 184 | samples += [(v, h)] 185 | 186 | return samples 187 | 188 | @torch.no_grad() 189 | def Langevin_sampling_v(self, 190 | v, 191 | num_steps=10, 192 | eta=1.0e+0, 193 | burn_in=0, 194 | is_anneal=True, 195 | adjust_step=0): 196 | eta_list = cosine_schedule(eta_max=eta, T=num_steps) 197 | samples = [] 198 | 199 | for ii in range(num_steps): 200 | eta_ii = eta_list[ii] if is_anneal else eta 201 | grad_v = self.marginal_energy_grad_v(v) 202 | 203 | v_new = v - eta_ii * grad_v + \ 204 | torch.randn_like(v) * np.sqrt(eta_ii * 2) 205 | 206 | if ii >= adjust_step: 207 | tmp_u = torch.rand(v.shape[0]).to(v.device) 208 | log_ratio = self.log_metropolis_ratio_Langevin_one_step( 209 | v, v_new, grad_v, eta_ii) 210 | ratio = torch.minimum( 211 | torch.ones_like(log_ratio), log_ratio.exp()) 212 | v = v_new * (tmp_u < ratio).float().view( 213 | -1, 1) + v * (tmp_u >= ratio).float().view(-1, 1) 214 | else: 215 | v = v_new 216 | 217 | if ii >= burn_in: 218 | samples += [v] 219 | 220 | return samples 221 | 222 | @torch.no_grad() 223 | def Gibbs_Langevin_sampling_vh(self, 224 | v, 225 | num_steps=10, 226 | num_steps_Langevin=10, 227 | eta=1.0e+0, 228 | burn_in=0, 229 | 
is_anneal=True, 230 | adjust_step=0): 231 | samples, var = [], self.get_var() 232 | eta_list = cosine_schedule(eta_max=eta, T=num_steps_Langevin) 233 | 234 | h = torch.bernoulli(self.prob_h_given_v(v, var)) 235 | 236 | for ii in range(num_steps): 237 | v_old, h_old = v, h 238 | # backward sampling 239 | for jj in range(num_steps_Langevin): 240 | eta_jj = eta_list[jj] if is_anneal else eta 241 | grad_v = self.energy_grad_v(v, h) 242 | v = v - eta_jj * grad_v + \ 243 | torch.randn_like(v) * np.sqrt(eta_jj * 2) 244 | 245 | # forward sampling 246 | h = torch.bernoulli(self.prob_h_given_v(v, var)) 247 | 248 | if ii >= adjust_step: 249 | tmp_u = torch.rand(v.shape[0]).to(v.device) 250 | log_ratio = self.log_metropolis_ratio_Gibbs_Langevin( 251 | v_old, h_old, v, h, eta_list) 252 | ratio = torch.minimum( 253 | torch.ones_like(log_ratio), log_ratio.exp()) 254 | v = v * (tmp_u < ratio).float().view( 255 | -1, 1) + v_old * (tmp_u >= ratio).float().view(-1, 1) 256 | h = h * (tmp_u < ratio).float().view( 257 | -1, 1) + h_old * (tmp_u >= ratio).float().view(-1, 1) 258 | 259 | if ii >= burn_in: 260 | samples += [(v, h)] 261 | 262 | return samples 263 | 264 | @torch.no_grad() 265 | def reconstruction(self, v): 266 | v, var = v.view(v.shape[0], -1), self.get_var() 267 | prob_h = self.prob_h_given_v(v, var) 268 | v_bar = self.prob_v_given_h(prob_h) 269 | return F.mse_loss(v, v_bar) 270 | 271 | @torch.no_grad() 272 | def sampling(self, v_init, num_steps=1, save_gap=1): 273 | v_shape = v_init.shape 274 | v = v_init.view(v_shape[0], -1) 275 | var = self.get_var() 276 | var_mean = var.mean().item() 277 | 278 | if self.inference_method == 'Gibbs': 279 | samples = self.Gibbs_sampling_vh(v, num_steps=num_steps - 1) 280 | samples = [xx[0] for xx in samples] # extract v 281 | elif self.inference_method == 'Langevin': 282 | samples = self.Langevin_sampling_v(v, 283 | num_steps=num_steps - 1, 284 | eta=self.Langevin_eta * var_mean, 285 | is_anneal=self.is_anneal_Langevin, 286 | adjust_step=self.Langevin_adjust_step) 287 | elif self.inference_method == 'Gibbs-Langevin': 288 | samples = self.Gibbs_Langevin_sampling_vh( 289 | v, 290 | num_steps=num_steps - 1, 291 | num_steps_Langevin=self.Langevin_step, 292 | eta=self.Langevin_eta * var_mean, 293 | is_anneal=self.is_anneal_Langevin, 294 | adjust_step=self.Langevin_adjust_step) 295 | samples = [xx[0] for xx in samples] # extract v 296 | 297 | # use conditional mean as the last sample 298 | h = torch.bernoulli(self.prob_h_given_v(samples[-1], var)) 299 | mu = self.prob_v_given_h(h) 300 | v_list = [(0, v_init)] + [(ii + 1, samples[ii].view(v_shape).detach()) 301 | for ii in range(num_steps - 1) 302 | if (ii + 1) % save_gap == 0 303 | ] + [(num_steps, mu.view(v_shape).detach())] 304 | 305 | return v_list 306 | 307 | @torch.no_grad() 308 | def positive_grad(self, v): 309 | h = torch.bernoulli(self.prob_h_given_v(v, self.get_var())) 310 | grad = self.energy_grad_param(v, h) 311 | return grad 312 | 313 | @torch.no_grad() 314 | def negative_grad(self, v): 315 | var = self.get_var() 316 | var_mean = var.mean().item() 317 | if self.inference_method == 'Gibbs': 318 | samples = self.Gibbs_sampling_vh(v, 319 | num_steps=self.CD_step, 320 | burn_in=self.CD_burnin) 321 | v_neg = torch.cat([xx[0] for xx in samples], dim=0) 322 | h_neg = torch.cat([xx[1] for xx in samples], dim=0) 323 | grad = self.energy_grad_param(v_neg, h_neg) 324 | elif self.inference_method == 'Langevin': 325 | samples = self.Langevin_sampling_v(v, 326 | num_steps=self.CD_step, 327 | burn_in=self.CD_burnin, 328 | 
eta=self.Langevin_eta * var_mean, 329 | is_anneal=self.is_anneal_Langevin, 330 | adjust_step=self.Langevin_adjust_step) 331 | v_neg = torch.cat(samples, dim=0) 332 | grad = self.marginal_energy_grad_param(v_neg) 333 | 334 | elif self.inference_method == 'Gibbs-Langevin': 335 | samples = self.Gibbs_Langevin_sampling_vh( 336 | v, 337 | num_steps=self.CD_step, 338 | burn_in=self.CD_burnin, 339 | num_steps_Langevin=self.Langevin_step, 340 | eta=self.Langevin_eta * var_mean, 341 | is_anneal=self.is_anneal_Langevin, 342 | adjust_step=self.Langevin_adjust_step) 343 | v_neg = torch.cat([xx[0] for xx in samples], dim=0) 344 | h_neg = torch.cat([xx[1] for xx in samples], dim=0) 345 | grad = self.energy_grad_param(v_neg, h_neg) 346 | 347 | return grad 348 | 349 | @torch.no_grad() 350 | def CD_grad(self, v): 351 | v = v.view(v.shape[0], -1) 352 | # postive gradient 353 | grad_pos = self.positive_grad(v) 354 | 355 | # negative gradient 356 | v_neg = torch.randn_like(v) 357 | grad_neg = self.negative_grad(v_neg) 358 | 359 | # compute update 360 | for name, param in self.named_parameters(): 361 | param.grad = grad_pos[name] - grad_neg[name] 362 | 363 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms, utils 8 | import numpy as np 9 | from tqdm import tqdm 10 | from utils import setup_logging, vis_density_GMM, vis_2D_samples, visualize_sampling 11 | from gmm import GMM, GMMDataset 12 | from grbm import GRBM 13 | 14 | EPS = 1e-7 15 | SEED = 1234 16 | torch.manual_seed(SEED) 17 | np.random.seed(SEED) 18 | torch.set_default_dtype(torch.float32) 19 | 20 | 21 | def save(model, results_folder, epoch): 22 | data = {'epoch': epoch, 'model': model.state_dict()} 23 | torch.save(data, f'{results_folder}/model-{epoch}.pt') 24 | 25 | 26 | def load(model, results_folder, epoch): 27 | data = torch.load(f'{results_folder}/model-{epoch}.pt') 28 | model.load_state_dict(data['model']) 29 | 30 | 31 | def train(model, 32 | train_loader, 33 | optimizer, 34 | config): 35 | model.train() 36 | for ii, (data, _) in enumerate(tqdm(train_loader)): 37 | if config['cuda']: 38 | data = data.cuda() 39 | 40 | optimizer.zero_grad() 41 | model.CD_grad(data) 42 | if config['clip_norm'] > 0: 43 | nn.utils.clip_grad_norm_(model.parameters(), config['clip_norm']) 44 | optimizer.step() 45 | 46 | if ii == len(train_loader) - 1: 47 | recon_loss = model.reconstruction(data).item() 48 | 49 | return recon_loss 50 | 51 | 52 | def create_dataset(config): 53 | if 'GMM' in config['dataset']: 54 | if config['dataset'] == 'GMM_iso': 55 | # isotropic 56 | gmm_model = GMM(torch.tensor([0.33, 0.33, 0.34]), 57 | torch.tensor([[-5, -5], [5, -5], [0, 5]]), 58 | torch.tensor([[1, 1], [1, 1], [1, 1]])).cuda() 59 | else: 60 | # anisotropic 61 | gmm_model = GMM(torch.tensor([0.33, 0.33, 0.34]), 62 | torch.tensor([[-5, -5], [5, -5], [0, 5]]), 63 | torch.tensor([[1.25, 0.5], [1.25, 0.5], [0.5, 64 | 1.25]])).cuda() 65 | 66 | vis_density_GMM(gmm_model, config) 67 | samples = gmm_model.sampling(config['num_samples']) 68 | vis_2D_samples(samples.cpu().numpy(), config, tags='ground_truth') 69 | train_set = GMMDataset(samples) 70 | elif config['dataset'] == 'MNIST': 71 | train_set = datasets.MNIST('./data', 72 | train=True, 73 | download=True, 74 | transform=transforms.Compose([ 75 
| transforms.ToTensor(), 76 | transforms.Normalize(config['img_mean'], 77 | config['img_std']) 78 | ])) 79 | elif config['dataset'] == 'CelebA': 80 | train_set = datasets.CelebA('./data', 81 | split='train', 82 | download=False, 83 | transform=transforms.Compose([ 84 | transforms.CenterCrop( 85 | config['crop_size']), 86 | transforms.Resize(config['height']), 87 | transforms.ToTensor(), 88 | transforms.Normalize(config['img_mean'], 89 | config['img_std']) 90 | ])) 91 | elif config['dataset'] == 'CelebA2K': 92 | train_set = datasets.CelebA('./data', 93 | split='train', 94 | download=False, 95 | transform=transforms.Compose([ 96 | transforms.CenterCrop( 97 | config['crop_size']), 98 | transforms.Resize(config['height']), 99 | transforms.ToTensor(), 100 | transforms.Normalize(config['img_mean'], 101 | config['img_std']) 102 | ])) 103 | train_set = torch.utils.data.Subset(train_set, range(2000)) 104 | elif config['dataset'] == 'FashionMNIST': 105 | train_set = datasets.FashionMNIST('./data', 106 | train=True, 107 | download=True, 108 | transform=transforms.Compose([ 109 | transforms.ToTensor(), 110 | transforms.Normalize(config['img_mean'], 111 | config['img_std']) 112 | ])) 113 | 114 | if 'GMM' not in config['dataset']: 115 | config['img_mean'] = torch.tensor(config['img_mean']) 116 | config['img_std'] = torch.tensor(config['img_std']) 117 | 118 | return train_set 119 | 120 | 121 | def train_model(args): 122 | """Let us train a GRBM and see how it performs""" 123 | pid = os.getpid() 124 | # Load config 125 | with open(f'config/{args.dataset}.json') as json_file: 126 | config = json.load(json_file) 127 | 128 | config['exp_folder'] = f"exp/{config['dataset']}_{config['model']}_{pid}_inference={config['inference_method']}_H={config['hidden_size']}_B={config['batch_size']}_CD={config['CD_step']}" 129 | 130 | if not os.path.isdir(config['exp_folder']): 131 | os.makedirs(config['exp_folder']) 132 | 133 | log_file = os.path.join(config['exp_folder'], f'log_exp_{pid}.txt') 134 | logger = setup_logging('INFO', log_file) 135 | logger.info('Writing log file to {}'.format(log_file)) 136 | 137 | with open(os.path.join(config['exp_folder'], f'config_{pid}.json'), 138 | 'w') as outfile: 139 | json.dump(config, outfile, indent=4) 140 | 141 | config['visible_size'] = config['height'] * \ 142 | config['width'] * config['channel'] 143 | train_set = create_dataset(config) 144 | train_loader = torch.utils.data.DataLoader(train_set, 145 | batch_size=config['batch_size'], 146 | shuffle=True) 147 | 148 | model = GRBM(config['visible_size'], 149 | config['hidden_size'], 150 | CD_step=config['CD_step'], 151 | CD_burnin=config['CD_burnin'], 152 | init_var=config['init_var'], 153 | inference_method=config['inference_method'], 154 | Langevin_step=config['Langevin_step'], 155 | Langevin_eta=config['Langevin_eta'], 156 | is_anneal_Langevin=True, 157 | Langevin_adjust_step=config['Langevin_adjust_step']) 158 | 159 | if config['cuda']: 160 | model.cuda() 161 | 162 | param_wd, param_no_wd = [], [] 163 | for xx, yy in model.named_parameters(): 164 | if 'W' in xx: 165 | param_wd += [yy] 166 | else: 167 | param_no_wd += [yy] 168 | 169 | optimizer = optim.SGD([{ 170 | 'params': param_no_wd, 171 | 'weight_decay': 0 172 | }, { 173 | 'params': param_wd 174 | }], 175 | lr=config['lr'], 176 | momentum=0.0, 177 | weight_decay=config['wd']) 178 | 179 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 180 | optimizer, config['epochs']) 181 | 182 | if config['resume'] > 0: 183 | load(model, config['exp_folder'], config['resume']) 
184 | 185 | for epoch in range(config['resume']): 186 | scheduler.step() 187 | 188 | is_show_training_data = False 189 | for epoch in range(config['resume'] + 1, config['epochs'] + 1): 190 | if epoch <= config['Langevin_adjust_warmup_epoch']: 191 | model.set_Langevin_adjust_step(config['CD_step']) 192 | else: 193 | model.set_Langevin_adjust_step(config['Langevin_adjust_step']) 194 | 195 | recon_loss = train(model, 196 | train_loader, 197 | optimizer, 198 | config) 199 | 200 | var = model.get_var().detach().cpu().numpy() 201 | 202 | # show samples periodically 203 | if epoch % config['log_interval'] == 0: 204 | if 'GMM' in config['dataset']: 205 | logger.info( 206 | f'PID={pid} || {epoch} epoch || mean = {model.mu.detach().cpu().numpy()} || var={model.get_var().detach().cpu().numpy()} || Reconstruction Loss = {recon_loss}' 207 | ) 208 | else: 209 | logger.info( 210 | f'PID={pid} || {epoch} epoch || var={model.get_var().mean().item()} || Reconstruction Loss = {recon_loss}' 211 | ) 212 | 213 | visualize_sampling(model, 214 | epoch, 215 | config, 216 | is_show_gif=config['is_vis_verbose']) 217 | 218 | # visualize one mini-batch of training data 219 | if not is_show_training_data and 'GMM' not in config['dataset']: 220 | data, _ = next(iter(train_loader)) 221 | mean = config['img_mean'].view(1, -1, 1, 1).to(data.device) 222 | std = config['img_std'].view(1, -1, 1, 1).to(data.device) 223 | vis_data = (data * std + mean).clamp(min=0, max=1) 224 | utils.save_image( 225 | utils.make_grid(vis_data, 226 | nrow=config['sampling_nrow'], 227 | normalize=False, 228 | padding=1, 229 | pad_value=1.0).cpu(), 230 | f"{config['exp_folder']}/training_imgs.png") 231 | is_show_training_data = True 232 | 233 | # visualize filters & hidden states 234 | if config['is_vis_verbose']: 235 | filters = model.W.T.view(model.W.shape[1], config['channel'], 236 | config['height'], config['width']) 237 | utils.save_image( 238 | filters, 239 | f"{config['exp_folder']}/filters_epoch_{epoch:05d}.png", 240 | nrow=8, 241 | normalize=True, 242 | padding=1, 243 | pad_value=1.0) 244 | 245 | # visualize hidden states 246 | data, _ = next(iter(train_loader)) 247 | h_pos = model.prob_h_given_v( 248 | data.view(data.shape[0], -1).cuda(), model.get_var()) 249 | utils.save_image(h_pos.view(1, 1, -1, config['hidden_size']), 250 | f"{config['exp_folder']}/hidden_epoch_{epoch:05d}.png", 251 | normalize=True) 252 | 253 | # save models periodically 254 | if epoch % config['save_interval'] == 0: 255 | save(model, config['exp_folder'], epoch) 256 | 257 | scheduler.step() 258 | 259 | 260 | if __name__ == '__main__': 261 | parser = argparse.ArgumentParser() 262 | parser.add_argument('-d', '--dataset', type=str, default='mnist', 263 | help='Dataset name {gmm_iso, gmm_aniso, mnist, fashionmnist, celeba, celeba2K}') 264 | args = parser.parse_args() 265 | train_model(args) 266 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | tqdm 3 | seaborn -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import utils 2 | from mpl_toolkits.axes_grid1 import make_axes_locatable 3 | import numpy as np 4 | import seaborn as sns 5 | import math 6 | import logging 7 | import torch 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt # NOQA 11 | from PIL 
import Image # NOQA 12 | 13 | sns.set_theme(style="darkgrid") 14 | 15 | 16 | def setup_logging(log_level, log_file, logger_name="exp_logger"): 17 | """ Setup logging """ 18 | numeric_level = getattr(logging, log_level.upper(), None) 19 | if not isinstance(numeric_level, int): 20 | raise ValueError("Invalid log level: %s" % log_level) 21 | 22 | logging.basicConfig( 23 | filename=log_file, 24 | filemode="w", 25 | format="%(levelname)-5s | %(asctime)s | File %(filename)-20s | Line %(lineno)-5d | %(message)s", 26 | datefmt="%m/%d/%Y %I:%M:%S %p", 27 | level=numeric_level) 28 | 29 | console = logging.StreamHandler() 30 | console.setLevel(numeric_level) 31 | formatter = logging.Formatter( 32 | "%(levelname)-5s | %(asctime)s | %(filename)-25s | line %(lineno)-5d: %(message)s" 33 | ) 34 | console.setFormatter(formatter) 35 | logging.getLogger(logger_name).addHandler(console) 36 | 37 | return get_logger(logger_name) 38 | 39 | 40 | def get_logger(logger_name="exp_logger"): 41 | return logging.getLogger(logger_name) 42 | 43 | 44 | def cosine_schedule(eta_min=0, eta_max=1, T=10): 45 | return [ 46 | eta_min + (eta_max - eta_min) * (1 + math.cos(tt * math.pi / T)) / 2 47 | for tt in range(T) 48 | ] 49 | 50 | 51 | def unnormalize_img_tuple(img_tuple, mean, std): 52 | if isinstance(std, torch.Tensor): 53 | mean = mean.view(1, -1, 1, 1).to(img_tuple[0][1].device) 54 | std = std.view(1, -1, 1, 1).to(img_tuple[0][1].device) 55 | 56 | return [(xx[0], (xx[1] * std + mean).clamp(min=0, max=1)) for xx in img_tuple] 57 | 58 | 59 | def fig2img(fig): 60 | import io 61 | buf = io.BytesIO() 62 | fig.savefig(buf, bbox_inches='tight') 63 | buf.seek(0) 64 | img = Image.open(buf) 65 | return img 66 | 67 | 68 | def show_img(matrix, title): 69 | plt.figure() 70 | plt.axis('off') 71 | plt.gray() 72 | img = np.array(matrix, np.float64) 73 | plt.imshow(img) 74 | plt.title(title) 75 | 76 | fig = plt.gcf() 77 | img_out = fig2img(fig) 78 | plt.close() 79 | 80 | return img_out 81 | 82 | 83 | def save_gif_fancy(imgs, nrow, save_name): 84 | imgs = (show_img(utils.make_grid(xx[1], 85 | nrow=nrow, 86 | normalize=False, 87 | padding=1, 88 | pad_value=1.0).permute(1, 2, 0).cpu().numpy(), f'sample at {xx[0]:03d} step') for xx in imgs) 89 | img = next(imgs) 90 | img.save(fp=save_name, 91 | format='GIF', 92 | append_images=imgs, 93 | save_all=True, 94 | duration=400, 95 | loop=0) 96 | 97 | 98 | def visualize_sampling(model, epoch, config, tag=None, is_show_gif=True): 99 | tag = '' if tag is None else tag 100 | B, C, H, W = config['sampling_batch_size'], config['channel'], config[ 101 | 'height'], config['width'] 102 | v_init = torch.randn(B, C, H, W).cuda() 103 | v_list = model.sampling(v_init, 104 | num_steps=config['sampling_steps'], 105 | save_gap=config['sampling_gap']) 106 | 107 | if 'GMM' in config['dataset']: 108 | samples = v_list[-1][1].view(B, -1).cpu().numpy() 109 | vis_2D_samples(samples, config, tags=f'{epoch:05d}') 110 | vis_density_GRBM(model, config, epoch=epoch) 111 | else: 112 | if is_show_gif: 113 | v_list = unnormalize_img_tuple(v_list, config['img_mean'], 114 | config['img_std']) 115 | save_gif_fancy( 116 | v_list, config['sampling_nrow'], 117 | f"{config['exp_folder']}/sample_imgs_epoch_{epoch:05d}{tag}.gif") 118 | img_vis = v_list[-1][1] 119 | else: 120 | if isinstance(config['img_std'], torch.Tensor): 121 | mean = config['img_mean'].view(1, -1, 1, 1).cuda() 122 | std = config['img_std'].view(1, -1, 1, 1).cuda() 123 | else: 124 | mean = config['img_mean'] 125 | std = config['img_std'] 126 | 127 | img_vis = 
(v_list[-1][1] * std + mean).clamp(min=0, max=1) 128 | 129 | utils.save_image( 130 | utils.make_grid(img_vis, 131 | nrow=config['sampling_nrow'], 132 | normalize=False, 133 | padding=1, 134 | pad_value=1.0).cpu(), 135 | f"{config['exp_folder']}/sample_imgs_epoch_{epoch:05d}{tag}.png") 136 | 137 | 138 | def vis_2D_samples(samples, config, tags=None): 139 | f, ax = plt.subplots(figsize=(6, 6)) 140 | sns.scatterplot(x=samples[:, 0], y=samples[:, 1], color="#4CB391") 141 | ax.set(xlim=(-10, 10)) 142 | ax.set(ylim=(-10, 10)) 143 | plt.show() 144 | plt.savefig( 145 | f"{config['exp_folder']}/samples_{tags}.png", bbox_inches='tight') 146 | plt.close() 147 | 148 | 149 | def vis_density_GMM(model, config): 150 | fig, ax = plt.subplots() 151 | x_density, y_density = 500, 500 152 | xses = np.linspace(-10, 10, x_density) 153 | yses = np.linspace(-10, 10, y_density) 154 | xy = torch.tensor([[[x, y] for x in xses] 155 | for y in yses]).view(-1, 2).cuda().float() 156 | log_density_values = model.log_prob(xy) 157 | log_density_values = log_density_values.detach().view( 158 | x_density, y_density).cpu().numpy() 159 | dx = (xses[1] - xses[0]) / 2 160 | dy = (yses[1] - yses[0]) / 2 161 | extent = [xses[0] - dx, xses[-1] + dx, yses[0] - dy, yses[-1] + dy] 162 | im = ax.imshow(np.exp(log_density_values), 163 | extent=extent, 164 | origin='lower', 165 | cmap='viridis') 166 | divider = make_axes_locatable(ax) 167 | cax = divider.append_axes('right', size='5%', pad=0.05) 168 | cb = fig.colorbar(im, cax=cax) 169 | cb.set_label('probability density') 170 | plt.show() 171 | plt.savefig(f"{config['exp_folder']}/GMM_density.png", bbox_inches='tight') 172 | plt.close() 173 | 174 | 175 | def vis_density_GRBM(model, config, epoch=None): 176 | fig, ax = plt.subplots() 177 | x_density, y_density = 500, 500 178 | xses = np.linspace(-10, 10, x_density) 179 | yses = np.linspace(-10, 10, y_density) 180 | xy = torch.tensor([[[x, y] for x in xses] 181 | for y in yses]).view(-1, 2).cuda().float() 182 | eng_val = -model.marginal_energy(xy) 183 | eng_val = eng_val.detach().view(x_density, y_density).cpu().numpy() 184 | dx = (xses[1] - xses[0]) / 2 185 | dy = (yses[1] - yses[0]) / 2 186 | extent = [xses[0] - dx, xses[-1] + dx, yses[0] - dy, yses[-1] + dy] 187 | im = ax.imshow(eng_val, extent=extent, origin='lower', cmap='viridis') 188 | divider = make_axes_locatable(ax) 189 | cax = divider.append_axes('right', size='5%', pad=0.05) 190 | cb = fig.colorbar(im, cax=cax) 191 | cb.set_label('negative energy') 192 | plt.show() 193 | plt.savefig(f"{config['exp_folder']}/GRBM_density_{epoch:05d}.png", 194 | bbox_inches='tight') 195 | plt.close() 196 | -------------------------------------------------------------------------------- /vis/CelebA.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSL-Lab/GRBM/308b8a5c8da6b998e2acc7b19e52839346ac72f8/vis/CelebA.gif -------------------------------------------------------------------------------- /vis/MNIST.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSL-Lab/GRBM/308b8a5c8da6b998e2acc7b19e52839346ac72f8/vis/MNIST.gif --------------------------------------------------------------------------------
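
For completeness, here is a minimal post-training sketch showing how the pieces above fit together: it restores a checkpoint written by ```save()``` in ```main.py``` and draws samples with ```GRBM.sampling()```, mirroring ```visualize_sampling()``` in ```utils.py```. This script is not part of the repository; the checkpoint path is a placeholder, the hyperparameters are copied from ```config/mnist.json```, and a CUDA device is assumed.

```python
# Hypothetical sampling script (not part of this repo); the checkpoint path is a placeholder.
import torch
from torchvision import utils
from grbm import GRBM

# hyperparameters mirroring config/mnist.json
height, width, channel, hidden_size = 28, 28, 1, 4096
visible_size = height * width * channel

model = GRBM(visible_size,
             hidden_size,
             CD_step=100,
             inference_method='Gibbs-Langevin',
             Langevin_step=10,
             Langevin_eta=20.0,
             Langevin_adjust_step=100).cuda()

# save() in main.py stores {'epoch': ..., 'model': state_dict} at exp/.../model-<epoch>.pt
ckpt = torch.load('exp/MNIST_GRBM_.../model-3000.pt')  # placeholder path
model.load_state_dict(ckpt['model'])

# start from Gaussian noise and run the learned sampler (cf. visualize_sampling in utils.py)
v_init = torch.randn(64, channel, height, width).cuda()
v_list = model.sampling(v_init, num_steps=100, save_gap=5)  # list of (step, v) tuples

# un-normalize with the MNIST statistics from the config and save a grid of samples
img_mean, img_std = 0.1307000070810318, 0.30809998512268066
imgs = (v_list[-1][1] * img_std + img_mean).clamp(min=0, max=1)
utils.save_image(utils.make_grid(imgs, nrow=8, padding=1, pad_value=1.0).cpu(),
                 'samples_mnist.png')
```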