├── .gitignore
├── LICENSE
├── README.md
├── data.py
├── imgs
│   ├── u1.gif
│   ├── u2.gif
│   ├── u3.gif
│   └── u4.gif
├── langevin.py
├── make_langevin_gif.py
├── models.py
├── tests.py
└── train_energy_based_model.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Sangwoong Yoon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-energy-based-model
This repository provides simple, illustrative working examples of energy-based models (EBMs) in PyTorch.

The aim of the repository is to provide an educational resource, to validate each step on toy examples, and to serve as a platform for future experiments.

## Quickstart

The main requirements are `python>=3.6` and `torch>=1.2`; the scripts also use `numpy`, `matplotlib`, and `tqdm`.

```
pip install numpy matplotlib tqdm
```

**Validate Langevin dynamics sampling** (pick one of the toy potentials `u1`, `u2`, `u3`, `u4`)
```
python make_langevin_gif.py u1
```

**Train an energy-based model**
```
python train_energy_based_model.py 8gaussians FCNet
```


## Expected Results

Langevin dynamics sampling from the four toy potentials:

![u1](imgs/u1.gif) ![u2](imgs/u2.gif) ![u3](imgs/u3.gif) ![u4](imgs/u4.gif)


## Files

* `make_langevin_gif.py` : Runs Langevin dynamics sampling on a toy potential and saves the animation as a GIF in `imgs/`.
* `train_energy_based_model.py` : Trains an EBM on samples from a toy distribution.
* `langevin.py` : Code for Langevin dynamics sampling.
* `models.py` : Code for the neural networks.
* `data.py` : Code for generating toy distributions.
* `tests.py` : A minimal smoke test for the Langevin sampler.
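
## How training works

`train_energy_based_model.py` fits an energy function E(x) with a contrastive objective: push the energy down on data ("positive" samples) and up on "negative" samples drawn from the current model by Langevin dynamics. The snippet below is a condensed, CPU-only sketch of one training step; it mirrors the script, with random noise standing in for a data batch:

```python
import torch
from models import FCNet
from langevin import sample_langevin

model = FCNet(in_dim=2, out_dim=1)   # scalar output, interpreted as energy E(x)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

pos_x = torch.randn(32, 2)           # stand-in for a batch of real data
# sample_langevin ascends its argument, so pass -E to descend the energy
neg_x = sample_langevin(torch.randn_like(pos_x), lambda x: -model(x),
                        stepsize=0.1, n_steps=100)

opt.zero_grad()
pos_out, neg_out = model(pos_x), model(neg_x)
# contrastive loss plus a quadratic regularizer keeping energies near zero
loss = (pos_out - neg_out).mean() + 1.0 * (pos_out ** 2 + neg_out ** 2).mean()
loss.backward()
opt.step()
```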

## Further reading

* Du & Mordatch, "Implicit Generation and Modeling with Energy-Based Models" (IGEBM), NeurIPS 2019.
* LeCun et al., "A Tutorial on Energy-Based Learning", 2006.
* Grathwohl et al., "Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One" (JEM), ICLR 2020.

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import math
import torch


def potential_fn(dataset):
    """Toy potential functions.
    Code borrowed from https://github.com/kamenbliznashki/normalizing_flows/blob/master/bnaf.py"""
    w1 = lambda z: torch.sin(2 * math.pi * z[:,0] / 4)
    w2 = lambda z: 3 * torch.exp(-0.5 * ((z[:,0] - 1)/0.6)**2)
    w3 = lambda z: 3 * torch.sigmoid((z[:,0] - 1) / 0.3)

    if dataset == 'u1':
        return lambda z: 0.5 * ((torch.norm(z, p=2, dim=1) - 2) / 0.4)**2 - \
                         torch.log(torch.exp(-0.5*((z[:,0] - 2) / 0.6)**2) + \
                                   torch.exp(-0.5*((z[:,0] + 2) / 0.6)**2) + 1e-10)

    elif dataset == 'u2':
        return lambda z: 0.5 * ((z[:,1] - w1(z)) / 0.4)**2

    elif dataset == 'u3':
        return lambda z: - torch.log(torch.exp(-0.5*((z[:,1] - w1(z))/0.35)**2) + \
                                     torch.exp(-0.5*((z[:,1] - w1(z) + w2(z))/0.35)**2) + 1e-10)

    elif dataset == 'u4':
        return lambda z: - torch.log(torch.exp(-0.5*((z[:,1] - w1(z))/0.4)**2) + \
                                     torch.exp(-0.5*((z[:,1] - w1(z) + w3(z))/0.35)**2) + 1e-10)

    else:
        raise RuntimeError('Invalid potential name to sample from.')


def sample_2d_data(dataset, n_samples):
    """Generate samples from 2D toy distributions.
    Code borrowed from https://github.com/kamenbliznashki/normalizing_flows/blob/master/bnaf.py"""
    z = torch.randn(n_samples, 2)

    if dataset == '8gaussians':
        scale = 4
        sq2 = 1/math.sqrt(2)
        centers = [(1,0), (-1,0), (0,1), (0,-1), (sq2,sq2), (-sq2,sq2), (sq2,-sq2), (-sq2,-sq2)]
        centers = torch.tensor([(scale * x, scale * y) for x,y in centers])
        return sq2 * (0.5 * z + centers[torch.randint(len(centers), size=(n_samples,))])

    elif dataset == '2spirals':
        n = torch.sqrt(torch.rand(n_samples // 2)) * 540 * (2 * math.pi) / 360
        d1x = - torch.cos(n) * n + torch.rand(n_samples // 2) * 0.5
        d1y =   torch.sin(n) * n + torch.rand(n_samples // 2) * 0.5
        x = torch.cat([torch.stack([ d1x,  d1y], dim=1),
                       torch.stack([-d1x, -d1y], dim=1)], dim=0) / 3
        return x + 0.1*z

    elif dataset == 'checkerboard':
        x1 = torch.rand(n_samples) * 4 - 2
        x2_ = torch.rand(n_samples) - torch.randint(0, 2, (n_samples,), dtype=torch.float) * 2
        x2 = x2_ + x1.floor() % 2
        return torch.stack([x1, x2], dim=1) * 2

    elif dataset == 'rings':
        n_samples4 = n_samples3 = n_samples2 = n_samples // 4
        n_samples1 = n_samples - n_samples4 - n_samples3 - n_samples2

        # so as not to have the first point = last point, set endpoint=False as in np.linspace; here shifted by one
        linspace4 = torch.linspace(0, 2 * math.pi, n_samples4 + 1)[:-1]
        linspace3 = torch.linspace(0, 2 * math.pi, n_samples3 + 1)[:-1]
        linspace2 = torch.linspace(0, 2 * math.pi, n_samples2 + 1)[:-1]
        linspace1 = torch.linspace(0, 2 * math.pi, n_samples1 + 1)[:-1]

        circ4_x = torch.cos(linspace4)
        circ4_y = torch.sin(linspace4)
        circ3_x = torch.cos(linspace3) * 0.75   # was torch.cos(linspace4); x and y must use the same angles
        circ3_y = torch.sin(linspace3) * 0.75
        circ2_x = torch.cos(linspace2) * 0.5
        circ2_y = torch.sin(linspace2) * 0.5
        circ1_x = torch.cos(linspace1) * 0.25
        circ1_y = torch.sin(linspace1) * 0.25

        x = torch.stack([torch.cat([circ4_x, circ3_x, circ2_x, circ1_x]),
                         torch.cat([circ4_y, circ3_y, circ2_y, circ1_y])], dim=1) * 3.0

        # random sample
        x = x[torch.randint(0, n_samples, size=(n_samples,))]

        # Add noise
        return x + torch.normal(mean=torch.zeros_like(x), std=0.08*torch.ones_like(x))

    else:
        raise RuntimeError('Invalid `dataset` to sample from.')
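
# Usage sketch (an added example, not part of the original file): both helpers
# are plain functions, so the toy data can be inspected directly:
#
#     import torch
#     from data import sample_2d_data, potential_fn
#
#     x = sample_2d_data('8gaussians', 1000)      # -> (1000, 2) tensor of samples
#     u = potential_fn('u1')(torch.randn(5, 2))   # -> (5,) tensor of potential values
#     print(x.shape, u.shape)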

--------------------------------------------------------------------------------
/imgs/u1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/pytorch-energy-based-model/052c21a8bedd93d9abf3c8924e7ae46cf5b84d57/imgs/u1.gif

--------------------------------------------------------------------------------
/imgs/u2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/pytorch-energy-based-model/052c21a8bedd93d9abf3c8924e7ae46cf5b84d57/imgs/u2.gif

--------------------------------------------------------------------------------
/imgs/u3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/pytorch-energy-based-model/052c21a8bedd93d9abf3c8924e7ae46cf5b84d57/imgs/u3.gif

--------------------------------------------------------------------------------
/imgs/u4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/swyoon/pytorch-energy-based-model/052c21a8bedd93d9abf3c8924e7ae46cf5b84d57/imgs/u4.gif

--------------------------------------------------------------------------------
/langevin.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.autograd as autograd


def sample_langevin(x, model, stepsize, n_steps, noise_scale=None, intermediate_samples=False):
    """Draw samples using Langevin dynamics.

    The sampler performs noisy gradient *ascent* on the output of `model`,
    so `model` should return the unnormalized log-density (i.e., the negative
    energy); to sample from an energy E, pass ``lambda x: -E(x)``.

    x: torch.Tensor, initial points
    model: callable mapping a batch of points to one scalar per point
    stepsize: float
    n_steps: integer
    noise_scale: Optional float. If None, set to np.sqrt(stepsize * 2)
    intermediate_samples: if True, return all intermediate samples and update
        vectors instead of only the final samples
    """
    if noise_scale is None:
        noise_scale = np.sqrt(stepsize * 2)

    l_samples = []
    l_dynamics = []
    x.requires_grad = True
    for _ in range(n_steps):
        l_samples.append(x.detach().to('cpu'))
        noise = torch.randn_like(x) * noise_scale
        out = model(x)
        grad = autograd.grad(out.sum(), x, only_inputs=True)[0]
        dynamics = stepsize * grad + noise
        x = x + dynamics
        l_dynamics.append(dynamics.detach().to('cpu'))
    # Append the final state once, so l_samples[i] is the state on which
    # l_dynamics[i] acted. (The original appended inside the loop twice,
    # duplicating every state and misaligning samples and dynamics.)
    l_samples.append(x.detach().to('cpu'))

    if intermediate_samples:
        return l_samples, l_dynamics
    else:
        return l_samples[-1]
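
# Background note (added): the update above is the Euler-Maruyama
# discretization of the overdamped Langevin SDE,
#
#     x_{t+1} = x_t + stepsize * grad log p(x_t) + noise_t,
#     noise_t ~ N(0, 2 * stepsize * I),
#
# which is why noise_scale defaults to sqrt(2 * stepsize).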

--------------------------------------------------------------------------------
/make_langevin_gif.py:
--------------------------------------------------------------------------------
"""
make_langevin_gif.py
====================
Generate a Langevin dynamics sampling gif file using matplotlib.
Output files are saved in the imgs directory.
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
from matplotlib.animation import FuncAnimation
import argparse

from langevin import sample_langevin
from data import sample_2d_data, potential_fn


parser = argparse.ArgumentParser()
parser.add_argument('energy_function', help='select toy energy function to generate samples from (u1, u2, u3, u4)',
                    choices=['u1', 'u2', 'u3', 'u4'])
parser.add_argument('--no-arrow', action='store_true', help='disable display of arrows')
parser.add_argument('--out', help='the name of the output file; defaults to the name of the energy function, e.g. u1.gif',
                    default=None)
args = parser.parse_args()

def init():
    """initialize animation"""
    global point, arrow
    ax = plt.gca()
    ax.contour(XX, YY, np.exp(-e_grid.view(n_grid, n_grid).numpy()))
    return (point, arrow)


def update(i):
    """update animation for i-th frame"""
    global point, arrow, ax
    g = l_dynamics[i]
    s = l_sample[i]

    point.set_offsets(s)
    if arrow is not None:   # guard: arrow is None when --no-arrow is given
        arrow.set_offsets(s)
        arrow.set_UVC(U=g[:,0], V=g[:,1])
    ax.set_title(f'Step: {i}')
    return (point, arrow)

# configuration
grid_lim = 4
n_grid = 100
n_sample = 100

stepsize = 0.03
n_steps = 100

# prepare for contour plot
energy_fn = potential_fn(args.energy_function)

xs = np.linspace(- grid_lim, grid_lim, n_grid)
ys = np.linspace(- grid_lim, grid_lim, n_grid)
XX, YY = np.meshgrid(xs, ys)
grids = np.stack([XX.flatten(), YY.flatten()]).T
e_grid = energy_fn(torch.tensor(grids))

# run Langevin dynamics; the sampler ascends its argument, so pass the
# unnormalized log-density -E(x) rather than the energy itself
log_density = lambda x: - energy_fn(x)
x0 = torch.randn(n_sample, 2)   # was torch.randn(n_grid, 2); the chain size is n_sample
l_sample, l_dynamics = sample_langevin(x0, log_density, stepsize, n_steps, intermediate_samples=True)

# plot
fig = plt.figure()
ax = plt.gca()
plt.axis('equal')

point = plt.scatter([],[])
if args.no_arrow:
    arrow = None
else:
    arrow = plt.quiver([0], [0], [1], [1], scale=0.5, scale_units='xy', headwidth=2, headlength=2, alpha=0.3)
plt.tight_layout()

anim = FuncAnimation(fig, update, frames=np.arange(n_steps),
                     init_func=init,
                     interval=200, blit=False)
if args.out is None:
    outfile = f'imgs/{args.energy_function}.gif'
else:
    outfile = f'imgs/{args.out}.gif'
anim.save(outfile, writer='pillow', dpi=80)
print(f'file saved in {outfile}')
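
# Usage sketch (an added example, not part of the original script):
# sample_langevin can also be called without the animation machinery, e.g. to
# sanity-check that the chain settles onto the low-energy ring of u1:
#
#     import torch
#     from langevin import sample_langevin
#     from data import potential_fn
#
#     energy = potential_fn('u1')
#     x0 = torch.randn(500, 2)
#     x = sample_langevin(x0, lambda z: -energy(z), stepsize=0.03, n_steps=200)
#     print(x.norm(dim=1).mean())   # should be close to the ring radius of 2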

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self, in_chan=1, out_chan=64, nh=8, out_activation=None):
        """
        ConvNet tailored for MNIST (28x28 images)
        nh: multiplier for the number of filters
        out_activation: optional activation name applied to the output (see get_activation)
        """
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_chan, nh * 4, kernel_size=3, bias=True)
        self.conv2 = nn.Conv2d(nh * 4, nh * 8, kernel_size=3, bias=True)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(nh * 8, nh * 8, kernel_size=3, bias=True)
        self.conv4 = nn.Conv2d(nh * 8, nh * 16, kernel_size=3, bias=True)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.Conv2d(nh * 16, out_chan, kernel_size=4, bias=True)
        self.in_chan, self.out_chan = in_chan, out_chan
        # resolve the output activation once; it was stored but never applied before
        self.out_activation = get_activation(out_activation) if out_activation is not None else None

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max1(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.max2(x)
        x = self.conv5(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x


# Fully Connected Network
def get_activation(s_act):
    if s_act == 'relu':
        return nn.ReLU(inplace=True)
    elif s_act == 'sigmoid':
        return nn.Sigmoid()
    elif s_act == 'softplus':
        return nn.Softplus()
    elif s_act == 'linear':
        return None
    elif s_act == 'tanh':
        return nn.Tanh()
    elif s_act == 'leakyrelu':
        return nn.LeakyReLU(0.2, inplace=True)
    elif s_act == 'softmax':
        return nn.Softmax(dim=1)
    else:
        raise ValueError(f'Unexpected activation: {s_act}')


class FCNet(nn.Module):
    """fully-connected network"""
    def __init__(self, in_dim, out_dim, l_hidden=(50,), activation='sigmoid', out_activation='linear'):
        super(FCNet, self).__init__()
        l_neurons = tuple(l_hidden) + (out_dim,)
        if isinstance(activation, str):
            activation = (activation,) * len(l_hidden)
        activation = tuple(activation) + (out_activation,)

        l_layer = []
        prev_dim = in_dim
        for n_hidden, act in zip(l_neurons, activation):
            l_layer.append(nn.Linear(prev_dim, n_hidden))
            act_fn = get_activation(act)
            if act_fn is not None:
                l_layer.append(act_fn)
            prev_dim = n_hidden

        self.net = nn.Sequential(*l_layer)
        self.in_dim = in_dim
        self.out_shape = (out_dim,)

    def forward(self, x):
        return self.net(x)
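
# Usage sketch (an added example, not part of the original file): a quick
# shape check for both networks used as energy functions:
#
#     import torch
#     from models import FCNet, ConvNet
#
#     fc = FCNet(in_dim=2, out_dim=1, l_hidden=(100, 100), activation='relu')
#     print(fc(torch.randn(8, 2)).shape)             # torch.Size([8, 1])
#
#     conv = ConvNet(in_chan=1, out_chan=1)
#     print(conv(torch.randn(8, 1, 28, 28)).shape)   # torch.Size([8, 1, 1, 1])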

--------------------------------------------------------------------------------
/tests.py:
--------------------------------------------------------------------------------
import torch
from models import FCNet
from langevin import sample_langevin

def test_langevin():
    X = torch.randn(100, 2)
    model = FCNet(2, 1, l_hidden=(50,))
    sample = sample_langevin(X, model, 0.1, 10)
    # the sampler should return final samples with the same shape as the input
    assert sample.shape == X.shape

--------------------------------------------------------------------------------
/train_energy_based_model.py:
--------------------------------------------------------------------------------
"""
train_energy_based_model.py
===========================
Train an energy-based model on a 2D toy distribution, drawing negative samples
with Langevin dynamics.
"""
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import argparse

from models import FCNet, ConvNet
from langevin import sample_langevin
from data import sample_2d_data

parser = argparse.ArgumentParser()
# 'MNIST' was listed as a choice but is not handled by sample_2d_data, so it is dropped here
parser.add_argument('dataset', choices=('8gaussians', '2spirals', 'checkerboard', 'rings'))
parser.add_argument('model', choices=('FCNet', 'ConvNet'))
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate. default: 1e-3')
parser.add_argument('--stepsize', type=float, default=0.1, help='Langevin dynamics step size. default: 0.1')
parser.add_argument('--n_step', type=int, default=100, help='the number of Langevin dynamics steps. default: 100')
parser.add_argument('--n_epoch', type=int, default=100, help='the number of training epochs. default: 100')
parser.add_argument('--alpha', type=float, default=1., help='regularizer coefficient. default: 1.0')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load dataset
N_train = 5000
N_val = 1000
N_test = 5000

X_train = sample_2d_data(args.dataset, N_train)
X_val = sample_2d_data(args.dataset, N_val)
X_test = sample_2d_data(args.dataset, N_test)

train_dl = DataLoader(TensorDataset(X_train), batch_size=32, shuffle=True, num_workers=8)
val_dl = DataLoader(TensorDataset(X_val), batch_size=32, shuffle=True, num_workers=8)    # was X_train
test_dl = DataLoader(TensorDataset(X_test), batch_size=32, shuffle=True, num_workers=8)  # was X_train

# build model
if args.model == 'FCNet':
    model = FCNet(in_dim=2, out_dim=1, l_hidden=(100, 100), activation='relu', out_activation='linear')
elif args.model == 'ConvNet':
    model = ConvNet(in_chan=1, out_chan=1)   # note: ConvNet expects 28x28 image inputs
model.to(device)

opt = Adam(model.parameters(), lr=args.lr)

# train loop
for i_epoch in range(args.n_epoch):
    l_loss = []
    for pos_x, in tqdm(train_dl, desc=f'epoch {i_epoch}'):

        pos_x = pos_x.to(device)

        # draw negative samples by Langevin dynamics on the learned energy;
        # sample_langevin ascends its argument, so pass -E(x) to descend E
        neg_x = torch.randn_like(pos_x)
        neg_x = sample_langevin(neg_x, lambda x: -model(x), args.stepsize, args.n_step,
                                intermediate_samples=False)
        neg_x = neg_x.to(device)   # the sampler returns samples on the CPU

        opt.zero_grad()
        pos_out = model(pos_x)
        neg_out = model(neg_x)

        # contrastive loss plus a regularizer that keeps energies near zero
        loss = (pos_out - neg_out) + args.alpha * (pos_out ** 2 + neg_out ** 2)
        loss = loss.mean()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
        opt.step()

        l_loss.append(loss.item())
    print(f'epoch {i_epoch}: loss {np.mean(l_loss):.4f}')


# draw samples (left unimplemented; see the sketch after this listing)

--------------------------------------------------------------------------------
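
The `# draw samples` stub at the end of `train_energy_based_model.py` is left unimplemented. A minimal sketch of what it could look like for the 2D case, assuming the `model`, `args`, `device`, and `X_test` defined in that script:

```python
import matplotlib.pyplot as plt

# start fresh chains from noise and let them descend the learned energy
x0 = torch.randn(1000, 2).to(device)
samples = sample_langevin(x0, lambda x: -model(x), args.stepsize, args.n_step)

# overlay model samples on held-out data
plt.scatter(X_test[:, 0], X_test[:, 1], s=2, alpha=0.3, label='data')
plt.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.3, label='EBM samples')
plt.legend()
plt.savefig('samples.png')
```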