├── .gitignore
├── LICENSE
├── README.md
├── model.py
├── sample.png
├── samples
│   └── .gitignore
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Kim Seonghyeon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# igebm-pytorch

A PyTorch implementation of Implicit Generation and Generalization in Energy-Based Models (https://arxiv.org/abs/1903.08689).

Train on CIFAR-10 with `python train.py`; images from the Langevin sampler are periodically written to `samples/`.

![Sample](sample.png)

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F


class SpectralNorm:
    # Spectral normalization with one power-iteration step per forward pass.
    # With bound=True the weight is rescaled only when its spectral norm
    # exceeds 1, softly bounding the layer's Lipschitz constant instead of
    # pinning it to 1.
    def __init__(self, name, bound=False):
        self.name = name
        self.bound = bound

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        weight_mat = weight.contiguous().view(size[0], -1)

        with torch.no_grad():
            # One step of power iteration to track the leading singular vectors.
            v = weight_mat.t() @ u
            v = v / v.norm()
            u = weight_mat @ v
            u = u / u.norm()

        sigma = u @ weight_mat @ v  # estimate of the largest singular value

        if self.bound:
            weight_sn = weight / (sigma + 1e-6) * torch.clamp(sigma, max=1)

        else:
            weight_sn = weight / sigma

        return weight_sn, u

    @staticmethod
    def apply(module, name, bound):
        fn = SpectralNorm(name, bound)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', weight)
        input_size = weight.size(0)
        u = weight.new_empty(input_size).normal_()
        module.register_buffer(name, weight)
        module.register_buffer(name + '_u', u)

        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight_sn, u = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)


def spectral_norm(module, init=True, std=1, bound=False):
    if init:
        nn.init.normal_(module.weight, 0, std)

    if hasattr(module, 'bias') and module.bias is not None:
        module.bias.data.zero_()

    SpectralNorm.apply(module, 'weight', bound=bound)

    return module


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, n_class=None, downsample=False):
        super().__init__()

        self.conv1 = spectral_norm(
            nn.Conv2d(
                in_channel,
                out_channel,
                3,
                padding=1,
                bias=n_class is None,  # the conditional shift below replaces the bias
            )
        )

        self.conv2 = spectral_norm(
            nn.Conv2d(
                out_channel,
                out_channel,
                3,
                padding=1,
                bias=n_class is None,
            ),
            std=1e-10,
            bound=True,
        )

        self.class_embed = None

        if n_class is not None:
            # Per-class (scale1, scale2, shift1, shift2), initialized to the identity
            # transform: scales start at 1, shifts at 0.
            class_embed = nn.Embedding(n_class, out_channel * 2 * 2)
            class_embed.weight.data[:, : out_channel * 2] = 1
            class_embed.weight.data[:, out_channel * 2 :] = 0

            self.class_embed = class_embed

        self.skip = None

        if in_channel != out_channel or downsample:
            self.skip = nn.Sequential(
                spectral_norm(nn.Conv2d(in_channel, out_channel, 1, bias=False))
            )

        self.downsample = downsample

    def forward(self, input, class_id=None):
        out = input

        out = self.conv1(out)

        if self.class_embed is not None:
            embed = self.class_embed(class_id).view(input.shape[0], -1, 1, 1)
            weight1, weight2, bias1, bias2 = embed.chunk(4, 1)
            out = weight1 * out + bias1

        out = F.leaky_relu(out, negative_slope=0.2)

        out = self.conv2(out)

        if self.class_embed is not None:
            out = weight2 * out + bias2

        if self.skip is not None:
            skip = self.skip(input)

        else:
            skip = input

        out = out + skip

        if self.downsample:
            out = F.avg_pool2d(out, 2)

        out = F.leaky_relu(out, negative_slope=0.2)

        return out


class IGEBM(nn.Module):
    def __init__(self, n_class=None):
        super().__init__()

        self.conv1 = spectral_norm(nn.Conv2d(3, 128, 3, padding=1), std=1)

        self.blocks = nn.ModuleList(
            [
                ResBlock(128, 128, n_class, downsample=True),
                ResBlock(128, 128, n_class),
                ResBlock(128, 256, n_class, downsample=True),
                ResBlock(256, 256, n_class),
                ResBlock(256, 256, n_class, downsample=True),
                ResBlock(256, 256, n_class),
            ]
        )

        self.linear = nn.Linear(256, 1)

    def forward(self, input, class_id=None):
        out = self.conv1(input)

        out = F.leaky_relu(out, negative_slope=0.2)

        for block in self.blocks:
            out = block(out, class_id)

        out = F.relu(out)
        out = out.view(out.shape[0], out.shape[1], -1).sum(2)  # global sum pooling
        out = self.linear(out)  # one scalar energy per image

        return out
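A quick sanity check of the energy network (not part of the repository; a minimal sketch whose shapes and `IGEBM(10)` constructor follow model.py above):

    import torch
    from model import IGEBM

    model = IGEBM(n_class=10)             # class-conditional energy function
    images = torch.rand(4, 3, 32, 32)     # CIFAR-10-sized inputs in [0, 1]
    labels = torch.randint(0, 10, (4,))

    energy = model(images, labels)
    print(energy.shape)                   # torch.Size([4, 1])

    # Gradients of the energy w.r.t. the input are what drive the Langevin
    # sampler in train.py:
    images.requires_grad_(True)
    model(images, labels).sum().backward()
    print(images.grad.shape)              # torch.Size([4, 3, 32, 32])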
--------------------------------------------------------------------------------
/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rosinality/igebm-pytorch/3c4f104af8781f232fd994759c60b7c19f13eba7/sample.png

--------------------------------------------------------------------------------
/samples/.gitignore:
--------------------------------------------------------------------------------
*.png

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm import tqdm

from model import IGEBM


class SampleBuffer:
    # FIFO replay buffer of past negative samples; restarting Langevin chains
    # from old samples lets short chains approximate much longer ones.
    def __init__(self, max_samples=10000):
        self.max_samples = max_samples
        self.buffer = []

    def __len__(self):
        return len(self.buffer)

    def push(self, samples, class_ids=None):
        samples = samples.detach().to('cpu')
        class_ids = class_ids.detach().to('cpu')

        for sample, class_id in zip(samples, class_ids):
            self.buffer.append((sample.detach(), class_id))

            if len(self.buffer) > self.max_samples:
                self.buffer.pop(0)

    def get(self, n_samples, device='cuda'):
        items = random.choices(self.buffer, k=n_samples)
        samples, class_ids = zip(*items)
        samples = torch.stack(samples, 0)
        class_ids = torch.tensor(class_ids)
        samples = samples.to(device)
        class_ids = class_ids.to(device)

        return samples, class_ids


def sample_buffer(buffer, batch_size=128, p=0.95, device='cuda'):
    if len(buffer) < 1:
        return (
            torch.rand(batch_size, 3, 32, 32, device=device),
            torch.randint(0, 10, (batch_size,), device=device),
        )

    # Restart each chain from the replay buffer with probability p,
    # otherwise from fresh uniform noise.
    n_replay = (np.random.rand(batch_size) < p).sum()

    replay_sample, replay_id = buffer.get(n_replay)
    random_sample = torch.rand(batch_size - n_replay, 3, 32, 32, device=device)
    random_id = torch.randint(0, 10, (batch_size - n_replay,), device=device)

    return (
        torch.cat([replay_sample, random_sample], 0),
        torch.cat([replay_id, random_id], 0),
    )


def sample_data(loader):
    # Yield batches forever, restarting the loader at the end of each epoch.
    loader_iter = iter(loader)

    while True:
        try:
            yield next(loader_iter)

        except StopIteration:
            loader_iter = iter(loader)

            yield next(loader_iter)


def requires_grad(parameters, flag=True):
    for p in parameters:
        p.requires_grad = flag


def clip_grad(parameters, optimizer):
    # Clip each gradient entry to +/- (3 * bias-corrected RMS estimate + 0.1),
    # reusing the second-moment statistics that Adam already tracks.
    with torch.no_grad():
        for group in optimizer.param_groups:
            for p in group['params']:
                state = optimizer.state[p]

                if 'step' not in state or state['step'] < 1:
                    continue

                step = state['step']
                exp_avg_sq = state['exp_avg_sq']
                _, beta2 = group['betas']

                bound = 3 * torch.sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1
                p.grad.data.copy_(torch.max(torch.min(p.grad.data, bound), -bound))
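# The inner loop in train() below runs clipped Langevin dynamics: each step adds
# N(0, 0.005^2) noise to the image, clips the energy gradient w.r.t. the image
# to +/-0.01, takes a gradient step of size step_size, and clamps the result
# back to [0, 1]. Minimizing pos_out - neg_out + alpha * (pos_out^2 + neg_out^2)
# then lowers the energy of data while raising the energy of sampled images,
# with the quadratic term keeping the energy magnitudes from drifting.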
def train(model, alpha=1, step_size=10, sample_step=60, device='cuda'):
    dataset = datasets.CIFAR10('.', download=True, transform=transforms.ToTensor())
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
    loader = tqdm(enumerate(sample_data(loader)))

    buffer = SampleBuffer()

    noise = torch.randn(128, 3, 32, 32, device=device)

    # Materialize the parameter list: model.parameters() is a generator and
    # would be exhausted after its first traversal, silently turning the later
    # requires_grad() calls into no-ops.
    parameters = list(model.parameters())
    optimizer = optim.Adam(parameters, lr=1e-4, betas=(0.0, 0.999))

    for i, (pos_img, pos_id) in loader:
        pos_img, pos_id = pos_img.to(device), pos_id.to(device)

        neg_img, neg_id = sample_buffer(buffer, pos_img.shape[0])
        neg_img.requires_grad = True

        # Freeze the model during sampling; only the images receive gradients.
        requires_grad(parameters, False)
        model.eval()

        for k in tqdm(range(sample_step)):
            if noise.shape[0] != neg_img.shape[0]:
                noise = torch.randn(neg_img.shape[0], 3, 32, 32, device=device)

            noise.normal_(0, 0.005)
            neg_img.data.add_(noise.data)

            neg_out = model(neg_img, neg_id)
            neg_out.sum().backward()
            neg_img.grad.data.clamp_(-0.01, 0.01)

            neg_img.data.add_(neg_img.grad.data, alpha=-step_size)

            neg_img.grad.detach_()
            neg_img.grad.zero_()

            neg_img.data.clamp_(0, 1)

        neg_img = neg_img.detach()

        requires_grad(parameters, True)
        model.train()

        model.zero_grad()

        pos_out = model(pos_img, pos_id)
        neg_out = model(neg_img, neg_id)

        loss = alpha * (pos_out ** 2 + neg_out ** 2)
        loss = loss + (pos_out - neg_out)
        loss = loss.mean()
        loss.backward()

        clip_grad(parameters, optimizer)

        optimizer.step()

        buffer.push(neg_img, neg_id)

        loader.set_description(f'loss: {loss.item():.5f}')

        if i % 100 == 0:
            utils.save_image(
                neg_img.detach().to('cpu'),
                f'samples/{str(i).zfill(5)}.png',
                nrow=16,
                normalize=True,
                range=(0, 1),
            )


if __name__ == '__main__':
    model = IGEBM(10).to('cuda')
    train(model)

--------------------------------------------------------------------------------
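To draw samples from a trained model outside the training loop, the same Langevin procedure can be reused. A minimal sketch, assuming the weights were saved beforehand with torch.save(model.state_dict(), 'checkpoint.pt') — the checkpoint path is hypothetical, since train.py itself never saves one:

    import torch
    from torchvision import utils

    from model import IGEBM

    device = 'cuda'
    model = IGEBM(10).to(device)
    model.load_state_dict(torch.load('checkpoint.pt'))  # hypothetical checkpoint
    model.eval()

    for p in model.parameters():
        p.requires_grad_(False)  # only the images need gradients

    # Start chains from uniform noise, as sample_buffer does for fresh chains.
    images = torch.rand(16, 3, 32, 32, device=device, requires_grad=True)
    labels = torch.randint(0, 10, (16,), device=device)

    for _ in range(60):  # sample_step from train.py
        images.data.add_(torch.randn_like(images) * 0.005)

        model(images, labels).sum().backward()
        images.grad.data.clamp_(-0.01, 0.01)
        images.data.add_(images.grad.data, alpha=-10)  # step_size from train.py

        images.grad.detach_()
        images.grad.zero_()
        images.data.clamp_(0, 1)

    utils.save_image(images.detach().cpu(), 'generated.png', nrow=4, normalize=True)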