├── .gitignore ├── LICENSE ├── README.md ├── arts ├── generated_samples.jpg └── loss-function.png ├── const.py ├── data.py ├── lsun.py ├── main.py ├── model.py ├── requirements.txt ├── train.py ├── utils.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | 104 | datasets/ 105 | checkpoints/ 106 | .env 107 | _cache_* 108 | .download.cgi* 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ha Junsoo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-wgan-gp 2 | PyTorch implementation of [Improved Training of Wasserstein GANs, arXiv:1704.00028](https://arxiv.org/abs/1704.00028) 3 | 4 | ![loss-function-with-penalty](./arts/loss-function.png) 5 | 6 | 7 | ## Results 8 | 9 | Generated samples after training for 1 epoch on the LSUN bedroom dataset 10 | 11 | ![generated samples](./arts/generated_samples.jpg) 12 | 13 | 14 | ## Installation 15 | ``` 16 | $ git clone https://github.com/kuc2477/pytorch-wgan-gp && cd pytorch-wgan-gp 17 | $ pip install -r requirements.txt 18 | ``` 19 | 20 | ## CLI 21 | 22 | #### Train 23 | ``` 24 | $ # Download the LSUN dataset (optional) 25 | $ ./lsun.py --category=bedroom 26 | 27 | $ # Run a Visdom server and start training on the LSUN dataset. 28 | $ python -m visdom.server 29 | $ ./main.py --train --dataset=lsun [--resume] 30 | ``` 31 | 32 | #### Test 33 | ``` 34 | $ # Check out the "./samples" directory for the generated images. 35 | $ ./main.py --test --dataset=lsun 36 | ``` 37 | 38 | 39 | ## References 40 | - [Improved Training of Wasserstein GANs, arXiv:1704.00028](https://arxiv.org/abs/1704.00028) 41 | - [caogang/wgan-gp](https://github.com/caogang/wgan-gp) 42 | 43 | 44 | ## Author 45 | Ha Junsoo / [@kuc2477](https://github.com/kuc2477) / MIT License 46 | -------------------------------------------------------------------------------- /arts/generated_samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-wgan-gp/0a1e8bd719577ffd7061059fcf26bc61fe2bd076/arts/generated_samples.jpg -------------------------------------------------------------------------------- /arts/loss-function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuc2477/pytorch-wgan-gp/0a1e8bd719577ffd7061059fcf26bc61fe2bd076/arts/loss-function.png -------------------------------------------------------------------------------- /const.py: -------------------------------------------------------------------------------- 1 | EPSILON = 1e-16 2 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | 5 | # ================== 6 | # Dataset Transforms 7 | # ================== 8 | 9 | _MNIST_TRAIN_TRANSFORMS = _MNIST_TEST_TRANSFORMS = [ 10 | transforms.ToTensor(), 11 | transforms.ToPILImage(), 12 | transforms.Pad(2), 13 | transforms.ToTensor(), 14 | ] 15 | 16 | _CIFAR_TRAIN_TRANSFORMS = [ 17 | transforms.RandomCrop(32), 18 | transforms.RandomHorizontalFlip(), 19 | transforms.ToTensor(), 20 | ] 21 | 22 | _CIFAR_TEST_TRANSFORMS = [ 23 | transforms.ToTensor(), 24 | ] 25 | 26 | _SVHN_IMAGE_SIZE = 32 27 | _SVHN_TRAIN_TRANSFORMS = _SVHN_TEST_TRANSFORMS = [ 28 | transforms.Scale(_SVHN_IMAGE_SIZE), 29 | transforms.CenterCrop(_SVHN_IMAGE_SIZE), 30 | transforms.ToTensor(), 31 | ] 32 | 33 | _LSUN_IMAGE_SIZE = 64 34 | _LSUN_TRAIN_TRANSFORMS = _LSUN_TEST_TRANSFORMS = [ 35 | transforms.Scale(_LSUN_IMAGE_SIZE), 36 | transforms.CenterCrop(_LSUN_IMAGE_SIZE), 37 | transforms.ToTensor(), 38 | ] 39 | 40 | 41 | def _LSUN_COLLATE_FN(batch): 42 | return torch.stack([x[0] for x in batch]) 43 | 44 | 45 | # ======== 46 | # Datasets 47 | # ======== 48 | 49 | TRAIN_DATASETS = { 50 | 
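# NOTE: each entry is a zero-argument factory, so a dataset is only instantiated (and downloaded) once it is actually selected via --dataset.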
'mnist': lambda: datasets.MNIST( 51 | './datasets/mnist', train=True, download=True, 52 | transform=transforms.Compose(_MNIST_TRAIN_TRANSFORMS) 53 | ), 54 | 'cifar10': lambda: datasets.CIFAR10( 55 | './datasets/cifar10', train=True, download=True, 56 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS) 57 | ), 58 | 'cifar100': lambda: datasets.CIFAR100( 59 | './datasets/cifar100', train=True, download=True, 60 | transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS) 61 | ), 62 | 'svhn': lambda: datasets.SVHN( 63 | './datasets/svhn', download=True, split='train', 64 | transform=transforms.Compose(_SVHN_TRAIN_TRANSFORMS) 65 | ), 66 | 'lsun': lambda: datasets.LSUNClass( 67 | './datasets/lsun/bedroom_train', 68 | transform=transforms.Compose(_LSUN_TRAIN_TRANSFORMS) 69 | ) 70 | } 71 | 72 | TEST_DATASETS = { 73 | 'mnist': lambda: datasets.MNIST( 74 | './datasets/mnist', train=False, 75 | transform=transforms.Compose(_MNIST_TEST_TRANSFORMS) 76 | ), 77 | 'cifar10': lambda: datasets.CIFAR10( 78 | './datasets/cifar10', train=False, 79 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS) 80 | ), 81 | 'cifar100': lambda: datasets.CIFAR100( 82 | './datasets/cifar100', train=False, 83 | transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS) 84 | ), 85 | 'svhn': lambda: datasets.SVHN( 86 | './datasets/svhn', download=True, split='test', 87 | transform=transforms.Compose(_SVHN_TEST_TRANSFORMS) 88 | ), 89 | 'lsun': lambda: datasets.LSUNClass( 90 | './datasets/lsun/bedroom_test', 91 | transform=transforms.Compose(_LSUN_TEST_TRANSFORMS) 92 | ) 93 | } 94 | 95 | DATASET_CONFIGS = { 96 | 'mnist': {'size': 32, 'channels': 1, 'classes': 10}, 97 | 'cifar10': {'size': 32, 'channels': 3, 'classes': 10}, 98 | 'cifar100': {'size': 32, 'channels': 3, 'classes': 100}, 99 | 'svhn': {'size': _SVHN_IMAGE_SIZE, 'channels': 3}, 100 | 'lsun': {'size': _LSUN_IMAGE_SIZE, 'channels': 3, 101 | 'collate_fn': _LSUN_COLLATE_FN}, 102 | } 103 | -------------------------------------------------------------------------------- /lsun.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from contextlib import contextmanager 3 | import os 4 | import os.path 5 | import argparse 6 | import subprocess 7 | import zipfile 8 | from colorama import Fore 9 | from tqdm import tqdm 10 | import requests 11 | 12 | 13 | # ================ 14 | # Helper Functions 15 | # ================ 16 | 17 | def c(string, color): 18 | return '{}{}{}'.format(getattr(Fore, color.upper()), string, Fore.RESET) 19 | 20 | 21 | @contextmanager 22 | def log(start, end, start_color='yellow', end_color='cyan'): 23 | print(c('>> ' + start, start_color)) 24 | yield 25 | print(c('>> ' + end, end_color) + '\n') 26 | 27 | 28 | def _download(url, filename=None): 29 | local_filename = filename or url.split('/')[-1] 30 | temp_filename = '.{}'.format(local_filename) 31 | response = requests.get(url, stream=True) 32 | total_size = int(response.headers.get('content-length', 0)) 33 | 34 | with open(temp_filename, 'wb') as f: 35 | for chunk in tqdm( 36 | response.iter_content(1024 * 32), 37 | total=total_size // (1024 * 32), 38 | unit='KiB', unit_scale=True, 39 | ): 40 | if chunk: 41 | f.write(chunk) 42 | response.close() 43 | os.rename(temp_filename, local_filename) 44 | return local_filename 45 | 46 | 47 | def _extract_zip(zipfile_path, extraction_path='.'): 48 | with zipfile.ZipFile(zipfile_path) as zf: 49 | extracted_dirname = zf.namelist()[0] 50 | zf.extractall(extraction_path) 51 | return extracted_dirname 52 | 53 | 54 | 
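# gzip -d decompresses in place and drops the ".gz" suffix, so the extracted filename is recovered below by stripping the last extension from the path.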
def _extract_gz(gzfile_path, extraction_path='.'): 55 | cmd = ['gzip', '-d', gzfile_path] 56 | subprocess.call(cmd) 57 | return '.'.join(gzfile_path.split('.')[:-1]) 58 | 59 | 60 | def _download_zip_dataset( 61 | url, dataset_dirpath, dataset_dirname, download_path=None): 62 | download_path = _download(url) 63 | download_dirpath = os.path.dirname(download_path) 64 | extracted_dirname = _extract_zip(download_path) 65 | 66 | os.remove(download_path) 67 | os.renames(os.path.join(download_dirpath, extracted_dirname), 68 | os.path.join(dataset_dirpath, dataset_dirname)) 69 | 70 | 71 | def _download_gz_dataset( 72 | url, dataset_dirpath, dataset_dirname, download_path=None): 73 | download_path = _download(url) 74 | download_dirpath = os.path.dirname(download_path) 75 | extracted_filename = _extract_gz(download_path) 76 | 77 | os.renames(os.path.join(download_dirpath, extracted_filename), 78 | os.path.join(dataset_dirpath, dataset_dirname, 79 | extracted_filename)) 80 | 81 | 82 | # ==== 83 | # Main 84 | # ==== 85 | 86 | def maybe_download_lsun(dataset_dirpath, dataset_dirname, 87 | category, set_name, tag='latest'): 88 | dataset_path = os.path.join(dataset_dirpath, dataset_dirname) 89 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 90 | '&category={category}&set={set_name}'.format(**locals()) 91 | 92 | # check existence 93 | if os.path.exists(dataset_path): 94 | print(c( 95 | 'lsun dataset already exists: {}' 96 | .format(dataset_path), 'red' 97 | )) 98 | return 99 | 100 | # start downloading lsun dataset 101 | with log( 102 | 'download lsun dataset from {}'.format(url), 103 | 'downloaded lsun dataset to {}'.format(dataset_path)): 104 | _download_zip_dataset(url, dataset_dirpath, dataset_dirname) 105 | 106 | 107 | parser = argparse.ArgumentParser(description='LSUN dataset downloading CLI') 108 | parser.add_argument('--dataset-dir', type=str, default='./datasets/lsun') 109 | parser.add_argument('--category', type=str, default='bedroom') 110 | parser.add_argument('--test', action='store_false', dest='train') 111 | 112 | 113 | if __name__ == '__main__': 114 | args = parser.parse_args() 115 | category = args.category 116 | set_name = 'train' if args.train else 'test' 117 | maybe_download_lsun( 118 | args.dataset_dir, 119 | '{category}_{set_name}'.format(category=category, set_name=set_name), 120 | category, 121 | set_name, 122 | ) 123 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import os.path 4 | import argparse 5 | import torch 6 | from data import DATASET_CONFIGS, TRAIN_DATASETS 7 | from model import WGAN 8 | from train import train 9 | import utils 10 | 11 | 12 | parser = argparse.ArgumentParser('PyTorch Implementation of WGAN-GP') 13 | parser.add_argument( 14 | '--dataset', type=str, 15 | choices=list(TRAIN_DATASETS.keys()), default='cifar100' 16 | ) 17 | 18 | parser.add_argument('--z-size', type=int, default=100) 19 | parser.add_argument('--g-channel-size', type=int, default=64) 20 | parser.add_argument('--c-channel-size', type=int, default=64) 21 | parser.add_argument('--lamda', type=float, default=10.)
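# "lamda" (spelled without the "b" to avoid shadowing the Python keyword) is the gradient penalty coefficient; the default of 10 is the value used in the paper (arXiv:1704.00028).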
22 | 23 | parser.add_argument('--lr', type=float, default=1e-04) 24 | parser.add_argument('--weight-decay', type=float, default=1e-04) 25 | parser.add_argument('--beta1', type=float, default=0.5) 26 | parser.add_argument('--beta2', type=float, default=0.9) 27 | parser.add_argument('--epochs', type=int, default=100) 28 | parser.add_argument('--batch-size', type=int, default=32) 29 | parser.add_argument('--sample-size', type=int, default=36) 30 | parser.add_argument('--d-trains-per-g-train', type=int, default=5) 31 | 32 | parser.add_argument('--sample-dir', type=str, default='samples') 33 | parser.add_argument('--checkpoint-dir', type=str, default='checkpoints') 34 | parser.add_argument('--loss-log-interval', type=int, default=30) 35 | parser.add_argument('--image-log-interval', type=int, default=100) 36 | parser.add_argument('--checkpoint-interval', type=int, default=1000) 37 | parser.add_argument('--resume', action='store_true') 38 | parser.add_argument('--no-gpus', action='store_false', dest='cuda') 39 | 40 | command = parser.add_mutually_exclusive_group(required=True) 41 | command.add_argument('--test', action='store_true', dest='test') 42 | command.add_argument('--train', action='store_false', dest='test') 43 | 44 | 45 | if __name__ == '__main__': 46 | args = parser.parse_args() 47 | cuda = torch.cuda.is_available() and args.cuda 48 | dataset = TRAIN_DATASETS[args.dataset]() 49 | dataset_config = DATASET_CONFIGS[args.dataset] 50 | 51 | wgan = WGAN( 52 | label=args.dataset, 53 | z_size=args.z_size, 54 | image_size=dataset_config['size'], 55 | image_channel_size=dataset_config['channels'], 56 | c_channel_size=args.c_channel_size, 57 | g_channel_size=args.g_channel_size, 58 | ) 59 | 60 | utils.gaussian_initialize(wgan, 0.02) 61 | 62 | if cuda: 63 | wgan.cuda() 64 | 65 | if args.test: 66 | path = os.path.join(args.sample_dir, '{}-sample'.format(wgan.name)) 67 | utils.load_checkpoint(wgan, args.checkpoint_dir) 68 | utils.test_model(wgan, args.sample_size, path) 69 | else: 70 | train( 71 | wgan, dataset, 72 | collate_fn=dataset_config.get('collate_fn', None), 73 | lr=args.lr, 74 | weight_decay=args.weight_decay, 75 | beta1=args.beta1, 76 | beta2=args.beta2, 77 | lamda=args.lamda, 78 | batch_size=args.batch_size, 79 | sample_size=args.sample_size, 80 | epochs=args.epochs, 81 | d_trains_per_g_train=args.d_trains_per_g_train, 82 | checkpoint_dir=args.checkpoint_dir, 83 | checkpoint_interval=args.checkpoint_interval, 84 | image_log_interval=args.image_log_interval, 85 | loss_log_interval=args.loss_log_interval, 86 | resume=args.resume, 87 | cuda=cuda, 88 | ) 89 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | from torch import autograd 6 | from const import EPSILON 7 | 8 | 9 | class Critic(nn.Module): 10 | def __init__(self, image_size, image_channel_size, channel_size): 11 | # configurations 12 | super().__init__() 13 | self.image_size = image_size 14 | self.image_channel_size = image_channel_size 15 | self.channel_size = channel_size 16 | 17 | # layers 18 | self.conv1 = nn.Conv2d( 19 | image_channel_size, channel_size, 20 | kernel_size=4, stride=2, padding=1 21 | ) 22 | self.conv2 = nn.Conv2d( 23 | channel_size, channel_size*2, 24 | kernel_size=4, stride=2, padding=1 25 | ) 26 | self.conv3 = nn.Conv2d( 27 | channel_size*2, channel_size*4, 28 | 
kernel_size=4, stride=2, padding=1 29 | ) 30 | self.conv4 = nn.Conv2d( 31 | channel_size*4, channel_size*8, 32 | kernel_size=4, stride=1, padding=1, 33 | ) 34 | self.fc = nn.Linear((image_size//8 - 1)**2 * channel_size*8, 1)  # conv4 (kernel 4, stride 1, padding 1) shrinks each spatial dim by 1 and outputs channel_size*8 maps 35 | 36 | def forward(self, x): 37 | x = F.leaky_relu(self.conv1(x)) 38 | x = F.leaky_relu(self.conv2(x)) 39 | x = F.leaky_relu(self.conv3(x)) 40 | x = F.leaky_relu(self.conv4(x)) 41 | x = x.view(x.size(0), -1)  # flatten per sample; matches self.fc's input size 42 | return self.fc(x) 43 | 44 | 45 | class Generator(nn.Module): 46 | def __init__(self, z_size, image_size, image_channel_size, channel_size): 47 | # configurations 48 | super().__init__() 49 | self.z_size = z_size 50 | self.image_size = image_size 51 | self.image_channel_size = image_channel_size 52 | self.channel_size = channel_size 53 | 54 | # layers 55 | self.fc = nn.Linear(z_size, (image_size//8)**2 * channel_size*8) 56 | self.bn0 = nn.BatchNorm2d(channel_size*8) 57 | self.bn1 = nn.BatchNorm2d(channel_size*4) 58 | self.deconv1 = nn.ConvTranspose2d( 59 | channel_size*8, channel_size*4, 60 | kernel_size=4, stride=2, padding=1 61 | ) 62 | self.bn2 = nn.BatchNorm2d(channel_size*2) 63 | self.deconv2 = nn.ConvTranspose2d( 64 | channel_size*4, channel_size*2, 65 | kernel_size=4, stride=2, padding=1, 66 | ) 67 | self.bn3 = nn.BatchNorm2d(channel_size) 68 | self.deconv3 = nn.ConvTranspose2d( 69 | channel_size*2, channel_size, 70 | kernel_size=4, stride=2, padding=1 71 | ) 72 | self.deconv4 = nn.ConvTranspose2d( 73 | channel_size, image_channel_size, 74 | kernel_size=3, stride=1, padding=1 75 | ) 76 | 77 | def forward(self, z): 78 | g = F.relu(self.bn0(self.fc(z).view( 79 | z.size(0), 80 | self.channel_size*8, 81 | self.image_size//8, 82 | self.image_size//8, 83 | ))) 84 | g = F.relu(self.bn1(self.deconv1(g))) 85 | g = F.relu(self.bn2(self.deconv2(g))) 86 | g = F.relu(self.bn3(self.deconv3(g))) 87 | g = self.deconv4(g) 88 | return F.sigmoid(g) 89 | 90 | 91 | class WGAN(nn.Module): 92 | def __init__(self, label, z_size, 93 | image_size, image_channel_size, 94 | c_channel_size, g_channel_size): 95 | # configurations 96 | super().__init__() 97 | self.label = label 98 | self.z_size = z_size 99 | self.image_size = image_size 100 | self.image_channel_size = image_channel_size 101 | self.c_channel_size = c_channel_size 102 | self.g_channel_size = g_channel_size 103 | 104 | # components 105 | self.critic = Critic( 106 | image_size=self.image_size, 107 | image_channel_size=self.image_channel_size, 108 | channel_size=self.c_channel_size, 109 | ) 110 | self.generator = Generator( 111 | z_size=self.z_size, 112 | image_size=self.image_size, 113 | image_channel_size=self.image_channel_size, 114 | channel_size=self.g_channel_size, 115 | ) 116 | 117 | @property 118 | def name(self): 119 | return ( 120 | 'WGAN-GP' 121 | '-z{z_size}' 122 | '-c{c_channel_size}' 123 | '-g{g_channel_size}' 124 | '-{label}-{image_size}x{image_size}x{image_channel_size}' 125 | ).format( 126 | z_size=self.z_size, 127 | c_channel_size=self.c_channel_size, 128 | g_channel_size=self.g_channel_size, 129 | label=self.label, 130 | image_size=self.image_size, 131 | image_channel_size=self.image_channel_size, 132 | ) 133 | 134 | def c_loss(self, x, z, return_g=False): 135 | g = self.generator(z) 136 | c_x = self.critic(x).mean() 137 | c_g = self.critic(g).mean() 138 | l = -(c_x-c_g) 139 | return (l, g) if return_g else l 140 | 141 | def g_loss(self, z, return_g=False): 142 | g = self.generator(z) 143 | l = -self.critic(g).mean() 144 | return (l, g) if return_g else l 145 | 146 | 
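# Sign conventions above: c_loss is the negated critic objective -(E[C(x)] - E[C(g)]), so -c_loss estimates the Wasserstein distance, and g_loss is -E[C(g)]. The gradient penalty below is added to c_loss separately in train.py.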
def sample_image(self, size): 147 | return self.generator(self.sample_noise(size)) 148 | 149 | def sample_noise(self, size): 150 | z = Variable(torch.randn(size, self.z_size)) * .1 151 | return z.cuda() if self._is_on_cuda() else z 152 | 153 | def gradient_penalty(self, x, g, lamda): 154 | assert x.size() == g.size() 155 | a = torch.rand(x.size(0), 1) 156 | a = a.cuda() if self._is_on_cuda() else a 157 | a = a\ 158 | .expand(x.size(0), x.nelement()//x.size(0))\ 159 | .contiguous()\ 160 | .view( 161 | x.size(0), 162 | self.image_channel_size, 163 | self.image_size, 164 | self.image_size 165 | ) 166 | interpolated = Variable(a*x.data + (1-a)*g.data, requires_grad=True) 167 | c = self.critic(interpolated) 168 | gradients = autograd.grad( 169 | c, interpolated, grad_outputs=( 170 | torch.ones(c.size()).cuda() if self._is_on_cuda() else 171 | torch.ones(c.size()) 172 | ), 173 | create_graph=True, 174 | retain_graph=True, 175 | )[0] 176 | return lamda * ((1-(gradients.view(gradients.size(0), -1)+EPSILON).norm(2, dim=1))**2).mean()  # flatten so the L2 norm is taken per sample over all channels and pixels, not per channel 177 | 178 | def _is_on_cuda(self): 179 | return next(self.parameters()).is_cuda 180 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch 2 | http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl 3 | torchvision==0.1.9 4 | torchtext==0.1.1 5 | visdom 6 | 7 | # data 8 | scipy 9 | scikit-learn 10 | numpy 11 | pillow 12 | 13 | # utils (others) 14 | colorama 15 | tqdm 16 | lmdb 17 | requests 18 | fake-useragent 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.autograd import Variable 3 | from tqdm import tqdm 4 | import utils 5 | import visual 6 | 7 | 8 | def train(model, dataset, collate_fn=None, 9 | lr=1e-04, weight_decay=1e-04, beta1=0.5, beta2=.999, lamda=10., 10 | batch_size=32, sample_size=32, epochs=10, 11 | d_trains_per_g_train=2, 12 | checkpoint_dir='checkpoints', 13 | checkpoint_interval=1000, 14 | image_log_interval=100, 15 | loss_log_interval=30, 16 | resume=False, cuda=False): 17 | # define the optimizers. 18 | generator_optimizer = optim.Adam( 19 | model.generator.parameters(), lr=lr, betas=(beta1, beta2), 20 | weight_decay=weight_decay 21 | ) 22 | critic_optimizer = optim.Adam( 23 | model.critic.parameters(), lr=lr, betas=(beta1, beta2), 24 | weight_decay=weight_decay 25 | ) 26 | 27 | # prepare the model and statistics. 28 | model.train() 29 | epoch_start = 1 30 | 31 | # load checkpoint if needed. 32 | if resume: 33 | iteration = utils.load_checkpoint(model, checkpoint_dir) 34 | epoch_start = iteration // (len(dataset) // batch_size) + 1 35 | 36 | for epoch in range(epoch_start, epochs+1): 37 | data_loader = utils.get_data_loader( 38 | dataset, batch_size, 39 | cuda=cuda, collate_fn=collate_fn, 40 | ) 41 | data_stream = tqdm(enumerate(data_loader, 1)) 42 | for batch_index, data in data_stream: 43 | # unpack the data if needed. 44 | try: 45 | x, _ = data 46 | except ValueError: 47 | x = data 48 | 49 | # where are we? 50 | dataset_size = len(data_loader.dataset) 51 | dataset_batches = len(data_loader) 52 | iteration = ( 53 | (epoch-1)*(dataset_size // batch_size) + 54 | batch_index  # enumerate() already starts at 1 55 | ) 56 | 57 | # prepare the data.
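# Warm-up heuristic: for the first 25 batches (and on every 500th batch) the critic is updated 30 times per generator step instead of d_trains_per_g_train, keeping it close to optimal while training starts.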
58 | x = Variable(x).cuda() if cuda else Variable(x) 59 | d_trains = ( 60 | 30 if (batch_index < 25 or batch_index % 500 == 0) else 61 | d_trains_per_g_train 62 | ) 63 | 64 | # run the critic and backpropagate the errors. 65 | for _ in range(d_trains): 66 | critic_optimizer.zero_grad() 67 | z = model.sample_noise(batch_size) 68 | c_loss, g = model.c_loss(x, z, return_g=True) 69 | c_loss_gp = c_loss + model.gradient_penalty(x, g, lamda=lamda) 70 | c_loss_gp.backward() 71 | critic_optimizer.step() 72 | 73 | # run the generator and backpropagate the errors. 74 | generator_optimizer.zero_grad() 75 | z = model.sample_noise(batch_size) 76 | g_loss = model.g_loss(z) 77 | g_loss.backward() 78 | generator_optimizer.step() 79 | 80 | # update the progress. 81 | data_stream.set_description(( 82 | 'epoch: {epoch}/{epochs} | ' 83 | 'iteration: {iteration} | ' 84 | 'progress: [{trained}/{total}] ({progress:.0f}%) | ' 85 | 'loss => ' 86 | 'g: {g_loss:.4} / ' 87 | 'w: {w_dist:.4}' 88 | ).format( 89 | epoch=epoch, 90 | epochs=epochs, 91 | iteration=iteration, 92 | trained=batch_index*batch_size, 93 | total=dataset_size, 94 | progress=(100.*batch_index/dataset_batches), 95 | g_loss=g_loss.data[0], 96 | w_dist=-c_loss.data[0], 97 | )) 98 | 99 | # send losses to the visdom server. 100 | if iteration % loss_log_interval == 0: 101 | visual.visualize_scalar( 102 | -c_loss.data[0], 103 | 'estimated wasserstein distance between x and g', 104 | iteration=iteration, 105 | env=model.name 106 | ) 107 | visual.visualize_scalar( 108 | g_loss.data[0], 109 | 'generator loss', 110 | iteration=iteration, 111 | env=model.name 112 | ) 113 | 114 | # send sample images to the visdom server. 115 | if iteration % image_log_interval == 0: 116 | visual.visualize_images( 117 | model.sample_image(sample_size).data, 118 | 'generated samples', 119 | env=model.name 120 | ) 121 | 122 | # save the model at checkpoints. 123 | if iteration % checkpoint_interval == 0: 124 | # notify that we've reached a new checkpoint. 125 | print() 126 | print() 127 | print('#############') 128 | print('# checkpoint!') 129 | print('#############') 130 | print() 131 | 132 | utils.save_checkpoint(model, checkpoint_dir, iteration) 133 | 134 | print() 135 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.dataloader import default_collate 6 | from torch.nn import init 7 | import torchvision 8 | 9 | 10 | def get_data_loader(dataset, batch_size, cuda=False, collate_fn=None): 11 | collate_fn = collate_fn or default_collate 12 | return DataLoader( 13 | dataset, batch_size=batch_size, 14 | shuffle=True, drop_last=True, collate_fn=collate_fn, 15 | **({'num_workers': 0, 'pin_memory': True} if cuda else {}) 16 | ) 17 | 18 | 19 | def save_checkpoint(model, model_dir, iteration): 20 | path = os.path.join(model_dir, model.name) 21 | 22 | # save the checkpoint. 23 | if not os.path.exists(model_dir): 24 | os.makedirs(model_dir) 25 | torch.save({ 26 | 'state': model.state_dict(), 27 | 'iteration': iteration, 28 | }, path) 29 | 30 | # notify that we successfully saved the checkpoint. 
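# Note that checkpoints are keyed by model.name, which encodes the dataset label, z size, channel sizes and image shape, so runs with different configurations never overwrite each other.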
31 | print('=> saved the model {name} to {path}'.format( 32 | name=model.name, path=path 33 | )) 34 | 35 | return iteration 36 | 37 | 38 | def load_checkpoint(model, model_dir): 39 | path = os.path.join(model_dir, model.name) 40 | 41 | # load the checkpoint. 42 | checkpoint = torch.load(path) 43 | print('=> loaded checkpoint of {name} from {path}'.format( 44 | name=model.name, path=path 45 | )) 46 | 47 | # load parameters and return the checkpoint's iteration. 48 | model.load_state_dict(checkpoint['state']) 49 | iteration = checkpoint['iteration'] 50 | return iteration 51 | 52 | 53 | def test_model(model, sample_size, path): 54 | os.makedirs(os.path.dirname(path), exist_ok=True) 55 | torchvision.utils.save_image( 56 | model.sample_image(sample_size).data, 57 | path + '.jpg' 58 | ) 59 | print('=> generated sample images at "{}".'.format(path)) 60 | 61 | 62 | def xavier_initialize(model): 63 | modules = [ 64 | m for n, m in model.named_modules() if 65 | 'conv' in n or 'fc' in n 66 | ] 67 | 68 | parameters = [ 69 | p for 70 | m in modules for 71 | p in m.parameters() 72 | ] 73 | 74 | for p in parameters: 75 | if p.dim() >= 2: 76 | init.xavier_normal(p) 77 | else: 78 | init.constant(p, 0) 79 | 80 | 81 | def gaussian_initialize(model, std=.01): 82 | modules = [ 83 | m for n, m in model.named_modules() if 84 | 'conv' in n or 'fc' in n 85 | ] 86 | 87 | parameters = [ 88 | p for 89 | m in modules for 90 | p in m.parameters() 91 | ] 92 | 93 | for p in parameters: 94 | if p.dim() >= 2: 95 | init.normal(p, std=std) 96 | else: 97 | init.constant(p, 0) 98 | -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.cuda import FloatTensor as CUDATensor 3 | from visdom import Visdom 4 | 5 | _WINDOW_CACHE = {} 6 | 7 | 8 | def _vis(env='main'): 9 | return Visdom(env=env) 10 | 11 | 12 | def visualize_image(tensor, name, label=None, env='main', w=250, h=250, 13 | update_window_without_label=False): 14 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 15 | title = name + ('-{}'.format(label) if label is not None else '') 16 | 17 | _WINDOW_CACHE[title] = _vis(env).image( 18 | tensor.numpy(), win=_WINDOW_CACHE.get(title), 19 | opts=dict(title=title, width=w, height=h) 20 | ) 21 | 22 | # This is useful when you want to maintain the most recent images. 23 | if update_window_without_label: 24 | _WINDOW_CACHE[name] = _vis(env).image( 25 | tensor.numpy(), win=_WINDOW_CACHE.get(name), 26 | opts=dict(title=name, width=w, height=h) 27 | ) 28 | 29 | 30 | def visualize_images(tensor, name, label=None, env='main', w=400, h=400, 31 | update_window_without_label=False): 32 | tensor = tensor.cpu() if isinstance(tensor, CUDATensor) else tensor 33 | title = name + ('-{}'.format(label) if label is not None else '') 34 | 35 | _WINDOW_CACHE[title] = _vis(env).images( 36 | tensor.numpy(), win=_WINDOW_CACHE.get(title), nrow=6, 37 | opts=dict(title=title, width=w, height=h) 38 | ) 39 | 40 | # This is useful when you want to maintain the most recent images. 41 | if update_window_without_label: 42 | _WINDOW_CACHE[name] = _vis(env).images( 43 | tensor.numpy(), win=_WINDOW_CACHE.get(name), nrow=6, 44 | opts=dict(title=name, width=w, height=h) 45 | ) 46 | 47 | 48 | def visualize_kernel(kernel, name, label=None, env='main', w=250, h=250, 49 | update_window_without_label=False, compress_tensor=False): 50 | # Do not visualize kernels that do not exist. 
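# 2D weight matrices are displayed directly; 4D convolution kernels are either reduced to per-filter mean squared weights (compress_tensor=True) or tiled into a single 2D grid before display.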
51 | if kernel is None: 52 | return 53 | 54 | assert len(kernel.size()) in (2, 4) 55 | title = name + ('-{}'.format(label) if label is not None else '') 56 | kernel = kernel.cpu() if isinstance(kernel, CUDATensor) else kernel 57 | kernel_norm = kernel if len(kernel.size()) == 2 else ( 58 | (kernel**2).mean(-1).mean(-1) if compress_tensor else 59 | kernel.view( 60 | kernel.size()[0] * kernel.size()[2], 61 | kernel.size()[1] * kernel.size()[3], 62 | ) 63 | ) 64 | kernel_norm = kernel_norm.abs() 65 | 66 | visualized = ( 67 | (kernel_norm - kernel_norm.min()) / 68 | (kernel_norm.max() - kernel_norm.min()) 69 | ).numpy() 70 | 71 | _WINDOW_CACHE[title] = _vis(env).image( 72 | visualized, win=_WINDOW_CACHE.get(title), 73 | opts=dict(title=title, width=w, height=h) 74 | ) 75 | 76 | # This is useful when you want to maintain the most recent images. 77 | if update_window_without_label: 78 | _WINDOW_CACHE[name] = _vis(env).image( 79 | visualized, win=_WINDOW_CACHE.get(name), 80 | opts=dict(title=name, width=w, height=h) 81 | ) 82 | 83 | 84 | def visualize_scalar(scalar, name, iteration, env='main'): 85 | visualize_scalars( 86 | [scalar] if isinstance(scalar, float) or len(scalar) == 1 else scalar, 87 | [name], name, iteration, env=env 88 | ) 89 | 90 | 91 | def visualize_scalars(scalars, names, title, iteration, env='main'): 92 | assert len(scalars) == len(names) 93 | # Convert scalar tensors to numpy arrays. 94 | scalars, names = list(scalars), list(names) 95 | scalars = [s.cpu() if isinstance(s, CUDATensor) else s for s in scalars] 96 | scalars = [s.numpy() if hasattr(s, 'numpy') else np.array([s]) for s in 97 | scalars] 98 | multi = len(scalars) > 1 99 | num = len(scalars) 100 | 101 | options = dict( 102 | fillarea=True, 103 | legend=names, 104 | width=400, 105 | height=400, 106 | xlabel='Iterations', 107 | ylabel=title, 108 | title=title, 109 | marginleft=30, 110 | marginright=30, 111 | marginbottom=80, 112 | margintop=30, 113 | ) 114 | 115 | X = ( 116 | np.column_stack(np.array([iteration] * num)) if multi else 117 | np.array([iteration] * num) 118 | ) 119 | Y = np.column_stack(scalars) if multi else scalars[0] 120 | 121 | if title in _WINDOW_CACHE: 122 | _vis(env).updateTrace(X=X, Y=Y, win=_WINDOW_CACHE[title], opts=options) 123 | else: 124 | _WINDOW_CACHE[title] = _vis(env).line(X=X, Y=Y, opts=options) 125 | --------------------------------------------------------------------------------