├── nets ├── __init__.py ├── layers │ ├── __init__.py │ ├── categorical_batch_norm.py │ └── spectral_norm.py ├── dcgan.py └── resnet.py ├── utils ├── __init__.py ├── losses.py ├── sample.py ├── yaml_utils.py └── load.py ├── training ├── __init__.py ├── scheduler.py ├── evaluator.py └── trainer.py ├── datasets ├── __init__.py ├── cifar10.py ├── lsun.py ├── stl10.py ├── dataset.py └── imagenet_dog.py ├── samples ├── 50000_img.png ├── 250000_img.png └── inception_scores.png ├── LICENSE ├── configs ├── sn_cifar10_conditional.yml ├── sn_projection_dog.yml └── sn_projection_dog64.yml ├── eval.py ├── train.py ├── .gitignore ├── README.md └── generate.py /nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nets/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /samples/50000_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balansky/pytorch_gan/HEAD/samples/50000_img.png -------------------------------------------------------------------------------- /samples/250000_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balansky/pytorch_gan/HEAD/samples/250000_img.png -------------------------------------------------------------------------------- /samples/inception_scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/balansky/pytorch_gan/HEAD/samples/inception_scores.png -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def loss_hinge_dis(dis_fake, dis_real): 4 | loss = torch.nn.functional.relu(1.0 - dis_real).mean() + \ 5 | torch.nn.functional.relu(1.0 + dis_fake).mean() 6 | return loss 7 | 8 | def loss_hinge_gen(dis_fake): 9 | loss = -dis_fake.mean() 10 | return loss 11 | -------------------------------------------------------------------------------- /utils/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def sample_noises(n_samples, noise_dim, n_categories=None, device=torch.device("cpu")): 5 | noise = torch.randn(n_samples, noise_dim, device=device) 6 | if n_categories: 7 | y_fake = torch.randint(low=0, high=n_categories, size=(n_samples,), dtype=torch.long, 8 | device=device) 9 | else: 10 | y_fake = None 11 | return noise, y_fake -------------------------------------------------------------------------------- /utils/yaml_utils.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import yaml 5 | 6 | 7 | # Copy from tgans repo. 8 | class Config(object): 9 | def __init__(self, config_dict): 10 | self.config = config_dict 11 | 12 | def __getattr__(self, key): 13 | if key in self.config: 14 | return self.config[key] 15 | else: 16 | raise AttributeError(key) 17 | 18 | def __getitem__(self, key): 19 | return self.config[key] 20 | 21 | def __repr__(self): 22 | return yaml.dump(self.config, default_flow_style=False) 23 | 24 | 25 | -------------------------------------------------------------------------------- /training/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LinearDecayLR(torch.optim.lr_scheduler._LRScheduler): 4 | 5 | def __init__(self, optimizer, decay_start, max_iterations, last_epoch=-1): 6 | self.decay_start = decay_start 7 | self.max_iterations = max_iterations 8 | super(LinearDecayLR, self).__init__(optimizer, last_epoch) 9 | self.step_gamma = [base_lr / (max_iterations - decay_start) for base_lr in self.base_lrs] 10 | 11 | def get_lr(self): 12 | if self.last_epoch < self.decay_start: 13 | return self.base_lrs 14 | else: 15 | return [(base_lr - step_gamma*(self.last_epoch - self.decay_start + 1)) for step_gamma, base_lr 16 | in zip(self.step_gamma, self.base_lrs)] 17 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset import Dataset 2 | from torchvision import datasets, transforms 3 | 4 | 5 | class Cifar10(Dataset): 6 | 7 | def __init__(self, root='/tmp', train=True, image_transform=None, **kwargs): 8 | if not image_transform: 9 | image_transform = transforms.Compose([ 10 | transforms.ToTensor(), 11 | transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) 12 | # transforms.Normalize(mean=[.485, .456, .406], std=[0.229, 0.224, 0.225]) 13 | ]) 14 | 15 | cifar_dataset = datasets.CIFAR10(root=root, train=train, download=True, 16 | transform=image_transform) 17 | super(Cifar10, self).__init__(cifar_dataset, **kwargs) 18 | -------------------------------------------------------------------------------- /datasets/lsun.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset import Dataset 2 | from torchvision import datasets, transforms 3 | 4 | 5 | class Lsun(Dataset): 6 | 7 | def __init__(self, root='/tmp', train=True, image_transform=None, **kwargs): 8 | if not image_transform: 9 | image_transform = transforms.Compose([ 10 | transforms.Resize(64), 11 | transforms.ToTensor(), 12 | transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) 13 | ]) 14 | if train: 15 | cls = "train" 16 | else: 17 | cls = "val" 18 | lsun_dataset = datasets.LSUN(root=root, classes=cls, 19 | transform=image_transform) 20 | super(Lsun, self).__init__(lsun_dataset, **kwargs) -------------------------------------------------------------------------------- /datasets/stl10.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset import Dataset 2 | from torchvision import datasets, transforms 3 | 4 | 5 | class Stl10(Dataset): 6 | 7 | def __init__(self, root='/tmp', train=True, image_transform=None, **kwargs): 8 | if not image_transform: 9 | image_transform = transforms.Compose([ 10 | transforms.RandomCrop(64), 11 | transforms.ToTensor(), 12 | transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) 13 | ]) 14 | if train: 15 | split = 'train' 16 | else: 17 | split = 'test' 18 | cifar_dataset = datasets.STL10(root=root, split=split, download=True, 19 | transform=image_transform) 20 | super(Stl10, self).__init__(cifar_dataset, **kwargs) -------------------------------------------------------------------------------- /nets/layers/categorical_batch_norm.py: -------------------------------------------------------------------------------- 1 | from torch.nn import BatchNorm2d 2 | import torch 3 | 4 | 5 | 6 | class CategoricalBatchNorm(torch.nn.Module): 7 | 8 | def __init__(self, num_features, num_categories, eps=1e-5, momentum=0.1, affine=False, 9 | track_running_stats=True): 10 | super(CategoricalBatchNorm, self).__init__() 11 | self.batch_norm = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) 12 | self.gamma_c = torch.nn.Embedding(num_categories, num_features) 13 | self.beta_c = torch.nn.Embedding(num_categories, num_features) 14 | torch.nn.init.constant_(self.batch_norm.running_var.data, 0) 15 | torch.nn.init.constant_(self.gamma_c.weight.data, 1) 16 | torch.nn.init.constant_(self.beta_c.weight.data, 0) 17 | 18 | def forward(self, input, y): 19 | ret = self.batch_norm(input) 20 | gamma = self.gamma_c(y) 21 | beta = self.beta_c(y) 22 | gamma_b = gamma.unsqueeze(2).unsqueeze(3).expand_as(ret) 23 | beta_b = beta.unsqueeze(2).unsqueeze(3).expand_as(ret) 24 | return gamma_b*ret + beta_b 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 balansky 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import DataLoader, _DataLoaderIter, default_collate 2 | 3 | 4 | class Dataset(DataLoader): 5 | 6 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 7 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 8 | timeout=0, worker_init_fn=None): 9 | super(Dataset, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, 10 | num_workers, collate_fn, pin_memory, drop_last, 11 | timeout, worker_init_fn) 12 | self.epochs = 1 13 | self.data_iter = iter(_DataLoaderIter(self)) 14 | 15 | def reinitialize_iter(self): 16 | self.data_iter = iter(_DataLoaderIter(self)) 17 | 18 | def get_next(self): 19 | try: 20 | data = next(self.data_iter) 21 | except StopIteration: 22 | self.reinitialize_iter() 23 | self.epochs += 1 24 | return self.get_next() 25 | return data 26 | 27 | def __iter__(self): 28 | return self.data_iter 29 | -------------------------------------------------------------------------------- /configs/sn_cifar10_conditional.yml: -------------------------------------------------------------------------------- 1 | # conditional CIFAR10 generation with SN and projection discriminator 2 | iteration: 50000 3 | 4 | models: 5 | generator: 6 | fn: nets/resnet.py 7 | name: ResnetGenerator32 8 | args: 9 | z_dim: 128 10 | bottom_width: 4 11 | ch: 256 12 | n_categories: 10 13 | 14 | 15 | discriminator: 16 | fn: nets/resnet.py 17 | name: ResnetDiscriminator32 18 | args: 19 | ch: 128 20 | n_categories: 10 21 | spectral_norm: 1 22 | 23 | dataset: 24 | fn: datasets/cifar10.py 25 | name: Cifar10 26 | args: 27 | train: True 28 | shuffle: True 29 | pin_memory: True 30 | drop_last: True 31 | 32 | optimizer: 33 | name: adam 34 | alpha: 0.0002 35 | beta1: 0.0 36 | beta2: 0.9 37 | 38 | trainer: 39 | fn: training/trainer.py 40 | name: GanTrainer 41 | args: 42 | n_dis: 5 43 | n_gen_samples: 64 44 | loss_type: hinge 45 | display_interval: 100 46 | snapshot_interval: 10000 47 | evaluation_interval: 1000 48 | 49 | 50 | evaluator: 51 | fn: training/evaluator.py 52 | name: Inception 53 | args: 54 | n_images: 5000 55 | batch_size: 100 56 | splits: 1 57 | -------------------------------------------------------------------------------- /configs/sn_projection_dog.yml: -------------------------------------------------------------------------------- 1 | iteration: 250000 2 | 3 | models: 4 | generator: 5 | fn: nets/resnet.py 6 | name: ResnetGenerator64 7 | args: 8 | z_dim: 128 9 | bottom_width: 4 10 | ch: 64 11 | n_categories: 120 12 | 13 | 14 | discriminator: 15 | fn: nets/resnet.py 16 | name: ResnetDiscriminator64 17 | args: 18 | ch: 64 19 | n_categories: 120 20 | spectral_norm: 1 21 | 22 | dataset: 23 | fn: datasets/imagenet_dog.py 24 | name: ImageNetDogDataset 25 | args: 26 | size: 64 27 | augmentation: True 28 | shuffle: True 29 | pin_memory: False 30 | drop_last: True 31 | 32 | optimizer: 33 | name: adam 34 | alpha: 0.0002 35 | beta1: 0.0 36 | beta2: 0.9 37 | 38 | scheduler: 39 | fn: training/scheduler.py 40 | name: LinearDecayLR 41 | args: 42 | decay_start: 200000 43 | max_iterations: 250000 44 | 45 | trainer: 46 | fn: training/trainer.py 47 | name: GanTrainer 48 | args: 49 | n_dis: 5 50 | n_gen_samples: 64 51 | loss_type: hinge 52 | display_interval: 100 53 | snapshot_interval: 1000 54 | evaluation_interval: 5000 55 | 56 | evaluator: 57 | fn: training/evaluator.py 58 | name: Inception 59 | args: 60 | n_images: 50000 61 | batch_size: 100 62 | splits: 10 63 | -------------------------------------------------------------------------------- /configs/sn_projection_dog64.yml: -------------------------------------------------------------------------------- 1 | iteration: 250000 2 | seed: 0 3 | 4 | models: 5 | generator: 6 | fn: nets/resnet.py 7 | name: ResnetGenerator64 8 | args: 9 | z_dim: 128 10 | bottom_width: 4 11 | ch: 64 12 | n_categories: 120 13 | 14 | 15 | discriminator: 16 | fn: nets/resnet.py 17 | name: ResnetDiscriminator64 18 | args: 19 | ch: 64 20 | n_categories: 120 21 | spectral_norm: 1 22 | 23 | dataset: 24 | fn: datasets/imagenet_dog.py 25 | name: ImageNetDogDataset 26 | args: 27 | size: 64 28 | augmentation: True 29 | shuffle: True 30 | pin_memory: True 31 | drop_last: True 32 | 33 | optimizer: 34 | name: adam 35 | alpha: 0.0002 36 | beta1: 0.0 37 | beta2: 0.9 38 | 39 | scheduler: 40 | fn: training/scheduler.py 41 | name: LinearDecayLR 42 | args: 43 | decay_start: 200000 44 | max_iterations: 250000 45 | 46 | trainer: 47 | fn: training/trainer.py 48 | name: GanTrainer 49 | args: 50 | n_dis: 5 51 | n_gen_samples: 64 52 | loss_type: hinge 53 | display_interval: 100 54 | snapshot_interval: 1000 55 | evaluation_interval: 10000 56 | 57 | evaluator: 58 | fn: training/evaluator.py 59 | name: Inception 60 | args: 61 | n_images: 50000 62 | batch_size: 100 63 | splits: 10 64 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from utils import yaml_utils 4 | from utils.load import * 5 | from training.evaluator import Inception 6 | 7 | 8 | def main(args): 9 | if args.device: 10 | device = torch.device(args.device) 11 | else: 12 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 13 | config = yaml_utils.Config(yaml.load(open(args.config_path))) 14 | gen, _ = load_gan_model(config) 15 | gen.load_state_dict(torch.load(args.model_path)) 16 | gen.eval().to(device) 17 | evaluator = Inception(n_images=args.n_eval, batch_size=args.batch_size, splits=args.splits, device=device) 18 | print("Evaluating Inception Score....") 19 | kl_score, kl_std = evaluator.eval_gen(gen) 20 | print("Inception Score: %.4f, Std: %.4f" % (kl_score, kl_std)) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--model_path', type=str, default='./results/gans', 26 | help='saved model path') 27 | parser.add_argument('--config_path', type=str, default='configs/sn_cifar10_conditional.yml', 28 | help='model configuration file') 29 | parser.add_argument('--batch_size', type=int, default=100, help="evaluation batch size(default:100)") 30 | parser.add_argument('--splits', type=int, default=10, help="splits for inception score(default: 10)") 31 | parser.add_argument('--n_eval', type=int, default=50000, help="total number of evaluations(default:50000)") 32 | parser.add_argument('--device', type=str, default=None, help="cpu or gpu") 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from utils import yaml_utils 4 | from utils.load import * 5 | from training.trainer import GanTrainer 6 | 7 | torch.manual_seed(0) 8 | torch.cuda.manual_seed_all(0) 9 | 10 | 11 | def main(args): 12 | device = torch.device("cuda:0") 13 | config = yaml_utils.Config(yaml.load(open(args.config_path))) 14 | gen, dis = load_gan_model(config) 15 | gen_optimizer = load_optimizer(config, gen.parameters()) 16 | dis_optimizer = load_optimizer(config, dis.parameters()) 17 | 18 | scheduler_g = load_scheduler(config, gen_optimizer) 19 | scheduler_d = load_scheduler(config, dis_optimizer) 20 | 21 | 22 | dataset = load_dataset(args.batch_size, args.data_dir, args.loaderjob, config) 23 | 24 | evaluator = load_evaluator(config, device) 25 | 26 | trainer = GanTrainer(args.iterations, dataset, gen, dis, gen_optimizer, dis_optimizer, args.result_dir, 27 | scheduler_g, scheduler_d, evaluator=evaluator, device=device, **config.trainer['args']) 28 | 29 | trainer.run() 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--config_path', type=str, default='configs/base.yml', help='path to config file') 35 | parser.add_argument('--data_dir', type=str, default='./data/imagenet') 36 | parser.add_argument('--iterations', type=int, default=250000) 37 | parser.add_argument('--result_dir', type=str, default='./results/gans', 38 | help='directory to save the results to') 39 | parser.add_argument('--batch_size', type=int, default=64, help="mini batch size") 40 | parser.add_argument('--loaderjob', type=int, default=4, 41 | help='number of parallel data loading processes') 42 | args = parser.parse_args() 43 | main(args) -------------------------------------------------------------------------------- /datasets/imagenet_dog.py: -------------------------------------------------------------------------------- 1 | 2 | from datasets.dataset import Dataset 3 | from torchvision import datasets, transforms 4 | import torch 5 | import random 6 | 7 | class Crop(object): 8 | 9 | def __init__(self, augmentation=True, crop_ratio=0.9): 10 | self.augmentation = augmentation 11 | self.crop_ratio = crop_ratio 12 | 13 | def __call__(self, img): 14 | w, h = img.size 15 | short_side = h if h < w else w 16 | if self.augmentation: 17 | crop_size = int(short_side * self.crop_ratio) 18 | top = random.randint(0, h - crop_size - 1) 19 | left = random.randint(0, w - crop_size - 1) 20 | else: 21 | crop_size = short_side 22 | top = (h - crop_size) // 2 23 | left = (w - crop_size) // 2 24 | bottom = top + crop_size 25 | right = left + crop_size 26 | img = img.crop((left, top, right, bottom)) 27 | return img 28 | 29 | def add_noise(tensor): 30 | noise = torch.rand_like(tensor) * (1 / 128.) 31 | tensor += noise 32 | return tensor 33 | 34 | 35 | 36 | class ImageNetDogDataset(Dataset): 37 | 38 | def __init__(self, root='/tmp', size=128, augmentation=True, image_transform=None, **kwargs): 39 | 40 | if not image_transform: 41 | image_transform = transforms.Compose([ 42 | Crop(augmentation=augmentation), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.Resize((size, size)), 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]), 47 | # transforms.Lambda(add_noise) 48 | ]) 49 | 50 | dog_dataset = datasets.ImageFolder(root, image_transform) 51 | super(ImageNetDogDataset, self).__init__(dog_dataset, **kwargs) 52 | 53 | 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/ 107 | 108 | models/ 109 | *.ini 110 | _vizdoom 111 | test.py 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spectral Normalization and Projection Discriminator(Pytorch) 2 | This project attempts to reproduce the results from "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. The Official Chainer implementation [**link**](https://github.com/pfnet-research/sngan_projection) 3 | 4 | ### Setup: 5 | `pip install pytorch pyyaml` 6 | 7 | ### Training(cifar10): 8 | ```angular2html 9 | python train.py --config_path configs/sn_cifar10_conditional.yml --batch_size 64 10 | ``` 11 | 12 | ### Evaluation: 13 | Inception Score: 14 | ```angular2html 15 | python eval.py --config_path configs/sn_cifar10_conditional.yml --model_path=/path/to/model 16 | ``` 17 | 18 | Generate Samples: 19 | ```angular2html 20 | python generate.py --config_path configs/sn_cifar10_conditional.yml --model_path=/path/to/model 21 | ``` 22 | 23 | 24 | ### 32x32 Image Samples 25 | ![](samples/50000_img.png) 26 | 27 | model [download](https://drive.google.com/file/d/1SXUSAIPj2gPlKB_EzVV4_ix8X2Bn_cnn/view?usp=sharing) 28 | 29 | 30 | ### 64x64 Dog Samples 31 | ![](samples/250000_img.png) 32 | 33 | model [download](https://drive.google.com/file/d/1luHjHZnLOclmGr684FdY1sXWM7_hcIpE/view?usp=sharing) 34 | 35 | ### Notes 36 | The Inception Score of PyTorch implementation is roughly 1.57 less than tf implementation. The inception score of my implementation is 6.63 which is matched the claim(8.22 - 1.57) from the origin paper. 37 | ![](samples/inception_scores.png) 38 | from [A Note on the Inception Score](https://arxiv.org/pdf/1801.01973.pdf) 39 | 40 | ### References 41 | - Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida. *Spectral Normalization for Generative Adversarial Networks*. ICLR2018. [OpenReview][sngans] 42 | - Takeru Miyato, Masanori Koyama. *cGANs with Projection Discriminator*. ICLR2018. [OpenReview][pcgans] 43 | 44 | -------------------------------------------------------------------------------- /nets/dcgan.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from nets.layers.spectral_norm import SpectralNorm 3 | 4 | 5 | class Generator(nn.Module): 6 | 7 | def __init__(self, nz): 8 | super(Generator, self).__init__() 9 | self.fc = nn.Linear(nz, 7*7*64, bias=False) 10 | self.fc_bn = nn.BatchNorm1d(7*7*64) 11 | self.fc_relu = nn.ReLU() 12 | self.main = nn.Sequential( 13 | 14 | nn.ConvTranspose2d(64, 64, 5, 1, 2, bias=False), 15 | nn.BatchNorm2d(64), 16 | nn.ReLU(), 17 | 18 | nn.ConvTranspose2d(64, 32, 5, 2, 2, output_padding=1, bias=False), 19 | nn.BatchNorm2d(32), 20 | nn.ReLU(), 21 | 22 | nn.ConvTranspose2d(32, 1, 5, 2, 2, output_padding=1, bias=False), 23 | nn.Tanh() 24 | ) 25 | 26 | def forward(self, input): 27 | x = self.fc(input) 28 | x = self.fc_bn(x) 29 | x = self.fc_relu(x) 30 | x = x.view(-1, 64, 7, 7) 31 | x = self.main(x) 32 | return x 33 | 34 | 35 | class Descriminator(nn.Module): 36 | 37 | def __init__(self, nc): 38 | super(Descriminator, self).__init__() 39 | self.main = nn.Sequential( 40 | SpectralNorm(nn.Conv2d(nc, 64, 5, 2, 16, bias=False)), 41 | # nn.Conv2d(nc, 64, 5, 2, 16, bias=False), 42 | nn.LeakyReLU(), 43 | nn.Dropout(0.3), 44 | SpectralNorm(nn.Conv2d(64, 128, 5, 2, 16, bias=False)), 45 | # nn.Conv2d(64, 128, 5, 2, 16, bias=False), 46 | nn.LeakyReLU(), 47 | nn.Dropout(0.3), 48 | ) 49 | self.fc = nn.Linear(28*28*128, 1, bias=False) 50 | self.a = nn.Sigmoid() 51 | 52 | def forward(self, input): 53 | x = self.main(input) 54 | x = x.view(x.size(0), -1) 55 | x = self.fc(x) 56 | x = self.a(x) 57 | return x -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from utils import yaml_utils 4 | from utils.load import * 5 | from utils.sample import sample_noises 6 | import torchvision 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def main(args): 12 | if args.device: 13 | device = torch.device(args.device) 14 | else: 15 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 16 | config = yaml_utils.Config(yaml.load(open(args.config_path))) 17 | gen, _ = load_gan_model(config) 18 | gen.load_state_dict(torch.load(args.model_path)) 19 | gen.eval().to(device) 20 | if not args.g_category: 21 | batch_noise, batch_labels = sample_noises(args.n_samples, gen.z_dim, gen.n_categories, device) 22 | else: 23 | batch_noise, _ = sample_noises(args.n_samples, gen.z_dim, device=device) 24 | batch_labels = batch_noise.new_full((args.n_samples,), fill_value=args.g_category, dtype=torch.long) 25 | samples = gen(batch_noise, batch_labels).detach().cpu() * .5 + .5 26 | grid = torchvision.utils.make_grid(samples).numpy() 27 | grid = np.transpose(grid, (1, 2, 0)) 28 | plt.imshow(grid) 29 | plt.show() 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--model_path', type=str, default='./results/gans', 35 | help='saved model path') 36 | parser.add_argument('--config_path', type=str, default='configs/sn_cifar10_conditional.yml', 37 | help='model configuration file') 38 | parser.add_argument('--g_category', type=int, default=None, help="category index to generate") 39 | parser.add_argument('--n_samples', type=int, default=64, help="number of samples to generate") 40 | parser.add_argument('--device', type=str, default=None, help="cpu or gpu") 41 | 42 | args = parser.parse_args() 43 | main(args) -------------------------------------------------------------------------------- /utils/load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | 5 | Optimizer = { 6 | "adam": torch.optim.Adam 7 | } 8 | 9 | 10 | def load_module(fn, name): 11 | mod_name = os.path.splitext(os.path.basename(fn))[0] 12 | mod_path = os.path.dirname(fn) 13 | sys.path.insert(0, mod_path) 14 | return getattr(__import__(mod_name), name) 15 | 16 | 17 | def load_model(model_fn, model_name, args=None): 18 | model = load_module(model_fn, model_name) 19 | if args: 20 | return model(**args) 21 | return model() 22 | 23 | 24 | def load_optimizer(config, params): 25 | optimizer_config = config.optimizer 26 | optimizer = Optimizer[optimizer_config['name']] 27 | return optimizer(params, lr=optimizer_config['alpha'], 28 | betas=(optimizer_config['beta1'], optimizer_config['beta2'])) 29 | 30 | 31 | def load_scheduler(config, optimizer): 32 | if not hasattr(config, 'scheduler'): 33 | return None 34 | scheduler_config = config.scheduler 35 | scheduler = load_module(scheduler_config['fn'], 36 | scheduler_config['name']) 37 | return scheduler(optimizer, **scheduler_config['args']) 38 | 39 | 40 | def load_evaluator(config, device): 41 | if not hasattr(config, 'evaluator'): 42 | return None 43 | evaluator_config = config.evaluator 44 | evaluator = load_module(evaluator_config['fn'], evaluator_config['name']) 45 | return evaluator(device=device, **evaluator_config['args']) 46 | 47 | 48 | def load_dataset(batch_size, root_dir, num_workers, config): 49 | dataset = load_module(config.dataset['fn'], 50 | config.dataset['name']) 51 | return dataset(root=root_dir, batch_size=batch_size, num_workers=num_workers, **config.dataset['args']) 52 | 53 | 54 | def load_gan_model(config): 55 | gen_config = config.models['generator'] 56 | dis_config = config.models['discriminator'] 57 | gen = load_model(gen_config['fn'], gen_config['name'], gen_config['args']) 58 | dis = load_model(dis_config['fn'], dis_config['name'], dis_config['args']) 59 | return gen, dis 60 | 61 | 62 | def load_updater_class(config): 63 | return load_module(config.updater['fn'], config.updater['name']) -------------------------------------------------------------------------------- /training/evaluator.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import inception_v3 2 | from utils.sample import sample_noises 3 | import torch 4 | import math 5 | 6 | class Inception(object): 7 | 8 | def __init__(self, n_images=50000, batch_size=100, splits=10, device=torch.device("cpu")): 9 | self.n_images = n_images 10 | self.batch_size = batch_size 11 | self.splits = splits 12 | self.n_batches = int(math.ceil(float(n_images)/float(batch_size))) 13 | self.device = device 14 | self.inception_model = inception_v3(pretrained=True, transform_input=False) 15 | # self.mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(1).unsqueeze(2).to(self.device) 16 | # self.std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(2).to(self.device) 17 | self.inception_model.eval().to(device) 18 | 19 | def generate_images(self, gen): 20 | with torch.no_grad(): 21 | batch_noise, batch_y = sample_noises(self.batch_size, gen.z_dim, gen.n_categories, self.device) 22 | # batch_images = (gen(batch_noise, batch_y).detach() * .5 + .5) 23 | # batch_images = (batch_images - self.mean) / self.std 24 | batch_images = gen(batch_noise, batch_y).detach() 25 | return batch_images 26 | 27 | def inception_softmax(self, batch_images): 28 | with torch.no_grad(): 29 | if batch_images.shape[-1] != 299 or batch_images.shape[-2] != 299: 30 | batch_images = torch.nn.functional.interpolate(batch_images, size=(299, 299), mode='bilinear', 31 | align_corners=False) 32 | 33 | y = self.inception_model(batch_images) 34 | y = torch.nn.functional.softmax(y, dim=1) 35 | return y 36 | 37 | def kl_scores(self, ys): 38 | scores = [] 39 | with torch.no_grad(): 40 | for j in range(self.splits): 41 | part = ys[(j*self.n_images//self.splits): ((j+1)*self.n_images // self.splits), :] 42 | kl = part * (torch.log(part) - torch.log(torch.unsqueeze(torch.mean(part, 0), 0))) 43 | kl = torch.mean(torch.sum(kl, 1)) 44 | kl = torch.exp(kl) 45 | scores.append(kl.unsqueeze(0)) 46 | scores = torch.cat(scores, 0) 47 | m_scores = torch.mean(scores).detach().cpu().numpy() 48 | m_std = torch.std(scores).detach().cpu().numpy() 49 | return m_scores, m_std 50 | 51 | def eval_gen(self, gen): 52 | ys = [] 53 | for i in range(self.n_batches): 54 | batch_images = self.generate_images(gen) 55 | y = self.inception_softmax(batch_images) 56 | ys.append(y) 57 | ys = torch.cat(ys, 0) 58 | m_scores, m_std = self.kl_scores(ys) 59 | return m_scores, m_std 60 | 61 | def eval_dataset(self, dataset): 62 | ys = [] 63 | for i in range(self.n_batches): 64 | batch_images = dataset.get_next() 65 | if isinstance(batch_images, list): 66 | batch_images = batch_images[0] 67 | batch_images = batch_images.to(self.device) 68 | y = self.inception_softmax(batch_images) 69 | ys.append(y) 70 | ys = torch.cat(ys, 0) 71 | m_scores, m_std = self.kl_scores(ys) 72 | return m_scores, m_std 73 | 74 | -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from torchvision import utils 4 | import torch 5 | import copy 6 | from utils.sample import sample_noises 7 | from utils.losses import loss_hinge_dis, loss_hinge_gen 8 | import time 9 | 10 | 11 | class GanTrainer(object): 12 | 13 | def __init__(self, iteration, dataset, genenerator, discriminator, gen_optimizer, dis_optimizer, output_dir, 14 | scheduler_g=None, scheduler_d=None, evaluator=None, n_gen_samples=64, n_dis=5, loss_type='hinge', 15 | display_interval=100, snapshot_interval=1000, evaluation_interval=1000, device=torch.device('cpu')): 16 | self.gen = genenerator.to(device) 17 | self.dis = discriminator.to(device) 18 | # self.mirror_gen = copy.deepcopy(genenerator).to(device) 19 | # self.mirror_gen.eval() 20 | self.gen_optimizer = gen_optimizer 21 | self.dis_optimizer = dis_optimizer 22 | self.scheduler_g = scheduler_g 23 | self.scheduler_d = scheduler_d 24 | self.num_categories = genenerator.n_categories 25 | self.dataset = dataset 26 | self.batch_size = dataset.batch_size 27 | self.z_dim = genenerator.z_dim 28 | self.device = device 29 | self.n_dis = n_dis 30 | self.n_gen_samples = n_gen_samples 31 | self.fixed_noise, self.fixed_y = sample_noises(self.n_gen_samples, self.z_dim, self.num_categories, device) 32 | if loss_type == "hinge": 33 | self.loss_gen = loss_hinge_gen 34 | self.loss_dis = loss_hinge_dis 35 | else: 36 | raise NotImplementedError 37 | 38 | self.iteration = iteration 39 | self.evaluator = evaluator 40 | self.display_interval = display_interval 41 | self.snapshot_interval = snapshot_interval 42 | self.evaluation_interval = evaluation_interval 43 | self.output_dir = output_dir 44 | self.snapshot_dir = os.path.join(output_dir, 'snapshots') 45 | self.sample_dir = os.path.join(output_dir, 'samples') 46 | self.n_row = max(int(math.sqrt(n_gen_samples)), 1) 47 | 48 | def create_snapshot_dir(self): 49 | if not os.path.exists(self.output_dir): 50 | os.mkdir(self.output_dir) 51 | if not os.path.exists(self.snapshot_dir): 52 | os.mkdir(self.snapshot_dir) 53 | if not os.path.exists(self.sample_dir): 54 | os.mkdir(self.sample_dir) 55 | 56 | def gen_samples(self, gen): 57 | with torch.no_grad(): 58 | fake = gen(self.fixed_noise, self.fixed_y).detach().cpu() * .5 + .5 59 | return fake 60 | 61 | def save(self, gen_path, dis_path): 62 | torch.save(self.gen.state_dict(), gen_path) 63 | torch.save(self.dis.state_dict(), dis_path) 64 | 65 | def load(self, gen_path, dis_path): 66 | self.gen.load_state_dict(torch.load(gen_path)) 67 | self.dis.load_state_dict(torch.load(dis_path)) 68 | 69 | def update(self, x_, y_): 70 | 71 | self.gen_optimizer.zero_grad() 72 | 73 | noise, y_fake = sample_noises(self.batch_size, self.z_dim, self.num_categories, self.device) 74 | # noise, y_fake = sample_noises(self.batch_size*(1 + self.n_dis), self.z_dim, self.num_categories, self.device) 75 | x_fake = self.gen(noise, y_fake) 76 | dis_fake = self.dis(x_fake, y_fake) 77 | gen_loss = loss_hinge_gen(dis_fake) 78 | gen_loss.backward() 79 | 80 | self.gen_optimizer.step() 81 | 82 | for i in range(self.n_dis): 83 | x_real = x_[self.batch_size*i: self.batch_size*i + self.batch_size] 84 | y_real = y_[self.batch_size*i: self.batch_size*i + self.batch_size] 85 | noise, y_fake = sample_noises(self.batch_size, self.z_dim, self.num_categories, self.device) 86 | self.dis_optimizer.zero_grad() 87 | dis_real = self.dis(x_real, y_real) 88 | dis_fake = self.dis(self.gen(noise, y_fake).detach(), y_fake) 89 | 90 | disc_loss = self.loss_dis(dis_fake, dis_real) 91 | disc_loss.backward() 92 | self.dis_optimizer.step() 93 | 94 | if self.scheduler_d and self.scheduler_g: 95 | self.scheduler_g.step() 96 | self.scheduler_d.step() 97 | 98 | return disc_loss, gen_loss 99 | 100 | def run(self): 101 | self.create_snapshot_dir() 102 | update_t = 0 103 | dis_losses = [] 104 | gen_losses = [] 105 | for i in range(1, self.iteration + 1): 106 | st_t = time.time() 107 | x = [] 108 | y = [] 109 | for _ in range(self.n_dis): 110 | x_, y_ = self.dataset.get_next() 111 | x_ = x_.to(self.device) 112 | y_ = y_.to(self.device) 113 | x.append(x_) 114 | y.append(y_) 115 | x = torch.cat(x, 0) 116 | y = torch.cat(y, 0) 117 | disc_loss, gen_loss = self.update(x, y) 118 | update_t += (time.time() - st_t) 119 | dis_losses.append(disc_loss.item()) 120 | gen_losses.append(gen_loss.item()) 121 | if i % self.display_interval == 0 or i == self.iteration: 122 | interval = self.display_interval if i != self.iteration else self.iteration % self.display_interval 123 | if interval == 0: 124 | interval = self.display_interval 125 | diff_t = interval / update_t 126 | print('[%d]\tLoss_D: %.4f\tLoss_G: %.4f, %.4f iters/sec' 127 | % (i, sum(dis_losses)/float(interval), sum(gen_losses)/float(interval), diff_t)) 128 | dis_losses = [] 129 | gen_losses = [] 130 | update_t = 0 131 | if i % self.snapshot_interval == 0 or i == self.iteration: 132 | self.save(os.path.join(self.snapshot_dir, "gen_%d.pt" % i), 133 | os.path.join(self.snapshot_dir, "dis_%d.pt" % i)) 134 | if self.evaluator and (i % self.evaluation_interval == 0 or i == self.iteration): 135 | print("evaluating inception score....") 136 | self.gen.eval() 137 | # self.mirror_gen.load_state_dict(self.gen.state_dict()) 138 | score, _ = self.evaluator.eval_gen(self.gen) 139 | fake = self.gen_samples(self.gen) 140 | utils.save_image(fake, os.path.join(self.sample_dir, "%d_img.png" % i), nrow=self.n_row, padding=2) 141 | print("[%d] evaluated inception score: %.4f" % (i, score)) 142 | self.gen.train() 143 | print("Training Done !") 144 | 145 | 146 | -------------------------------------------------------------------------------- /nets/layers/spectral_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn import Parameter 5 | 6 | 7 | def l2normalize(v, eps=1e-12): 8 | return v / (v.norm() + eps) 9 | 10 | 11 | # def weight_bar(w, u, pi): 12 | # # _, sigma, _ = torch.svd(w) 13 | # # sigma = sigma[0] 14 | # 15 | # w_mat = w.data.view(w.data.shape[0], -1) 16 | # 17 | # for _ in range(pi): 18 | # v = l2normalize(torch.mv(torch.t(w_mat), u)) 19 | # 20 | # u = l2normalize(torch.mv(w_mat, v)) 21 | # 22 | # sigma = torch.dot(torch.mv(torch.t(w_mat), u), v) 23 | # w_bar = w / sigma 24 | # 25 | # return w_bar, u, sigma 26 | 27 | # 28 | # class SpectralNorm(torch.nn.Module): 29 | # 30 | # def __init__(self, out_features, power_iterations=1): 31 | # super(SpectralNorm, self).__init__() 32 | # self.power_iterations = power_iterations 33 | # self.out_features = out_features 34 | # # self.register_buffer("u", torch.randn(out_features, requires_grad=False)) 35 | # 36 | # self.register_buffer("u", torch.randn((1, out_features), requires_grad=False)) 37 | # 38 | # def forward(self, w): 39 | # w_mat = w.view(w.data.shape[0], -1) 40 | # 41 | # # with torch.no_grad(): 42 | # # _, sigma, _ = torch.svd(w_mat) 43 | # # sigma = sigma[0] 44 | # 45 | # # 46 | # u = self.u 47 | # with torch.no_grad(): 48 | # for _ in range(self.power_iterations): 49 | # v = l2normalize(torch.mm(u, w_mat.data)) 50 | # 51 | # u = l2normalize(torch.mm(v, torch.t(w_mat.data))) 52 | # 53 | # # v = l2normalize(torch.mv(torch.t(w_mat), self.u)) 54 | # 55 | # # u = l2normalize(torch.mv(w_mat, v)) 56 | # 57 | # # sigma = u.dot(w_mat.mv(v)) 58 | # sigma = torch.sum(torch.mm(u, w_mat) * v) 59 | # 60 | # if self.training: 61 | # self.u = u 62 | # w_bar = torch.div(w, sigma) 63 | # # w_bar = w / sigma.expand_as(w.data) 64 | # 65 | # return w_bar, sigma 66 | 67 | 68 | def max_singular_value(w_mat, u, power_iterations): 69 | 70 | for _ in range(power_iterations): 71 | v = l2normalize(torch.mm(u, w_mat.data)) 72 | 73 | u = l2normalize(torch.mm(v, torch.t(w_mat.data))) 74 | 75 | sigma = torch.sum(torch.mm(u, w_mat) * v) 76 | 77 | return u, sigma, v 78 | 79 | 80 | 81 | class Linear(torch.nn.Linear): 82 | 83 | def __init__(self, *args, spectral_norm_pi=1, **kwargs): 84 | super(Linear, self).__init__(*args, **kwargs) 85 | self.spectral_norm_pi = spectral_norm_pi 86 | if spectral_norm_pi > 0: 87 | self.register_buffer("u", torch.randn((1, self.out_features), requires_grad=False)) 88 | else: 89 | self.register_buffer("u", None) 90 | if self.bias is not None: 91 | torch.nn.init.constant_(self.bias.data, 0) 92 | 93 | 94 | def forward(self, input): 95 | if self.spectral_norm_pi > 0: 96 | w_mat = self.weight.view(self.out_features, -1) 97 | u, sigma, _ = max_singular_value(w_mat, self.u, self.spectral_norm_pi) 98 | 99 | # w_bar = torch.div(w_mat, sigma) 100 | w_bar = torch.div(self.weight, sigma) 101 | if self.training: 102 | self.u = u 103 | # self.w_bar = w_bar.detach() 104 | # self.sigma = sigma.detach() 105 | else: 106 | w_bar = self.weight 107 | return F.linear(input, w_bar, self.bias) 108 | 109 | 110 | class Conv2d(torch.nn.Conv2d): 111 | 112 | def __init__(self, *args, spectral_norm_pi=1, **kwargs): 113 | super(Conv2d, self).__init__(*args, **kwargs) 114 | self.spectral_norm_pi = spectral_norm_pi 115 | if spectral_norm_pi > 0: 116 | self.register_buffer("u", torch.randn((1, self.out_channels), requires_grad=False)) 117 | else: 118 | self.register_buffer("u", None) 119 | if self.bias is not None: 120 | torch.nn.init.constant_(self.bias.data, 0) 121 | 122 | def forward(self, input): 123 | if self.spectral_norm_pi > 0: 124 | w_mat = self.weight.view(self.out_channels, -1) 125 | u, sigma, _ = max_singular_value(w_mat, self.u, self.spectral_norm_pi) 126 | w_bar = torch.div(self.weight, sigma) 127 | if self.training: 128 | self.u = u 129 | else: 130 | w_bar = self.weight 131 | 132 | return F.conv2d(input, w_bar, self.bias, self.stride, 133 | self.padding, self.dilation, self.groups) 134 | 135 | 136 | class Embedding(torch.nn.Embedding): 137 | 138 | def __init__(self, *args, spectral_norm_pi=1, **kwargs): 139 | super(Embedding, self).__init__(*args, **kwargs) 140 | self.spectral_norm_pi = spectral_norm_pi 141 | if spectral_norm_pi > 0: 142 | self.register_buffer("u", torch.randn((1, self.num_embeddings), requires_grad=False)) 143 | else: 144 | self.register_buffer("u", None) 145 | 146 | def forward(self, input): 147 | if self.spectral_norm_pi > 0: 148 | w_mat = self.weight.view(self.num_embeddings, -1) 149 | u, sigma, _ = max_singular_value(w_mat, self.u, self.spectral_norm_pi) 150 | w_bar = torch.div(self.weight, sigma) 151 | if self.training: 152 | self.u = u 153 | else: 154 | w_bar = self.weight 155 | 156 | return F.embedding( 157 | input, w_bar, self.padding_idx, self.max_norm, 158 | self.norm_type, self.scale_grad_by_freq, self.sparse) 159 | 160 | 161 | # class SpectralNorm(nn.Module): 162 | # def __init__(self, module, name='weight', power_iterations=1): 163 | # super(SpectralNorm, self).__init__() 164 | # self.module = module 165 | # self.name = name 166 | # self.power_iterations = power_iterations 167 | # self.sigma = None 168 | # self.w_bar = None 169 | # self.register_buffer("u", torch.randn(getattr(module, name).data.shape[0], requires_grad=False)) 170 | # 171 | # 172 | # def _update_u_v(self): 173 | # 174 | # w = getattr(self.module, self.name) 175 | # _, sigma, _ = torch.svd(w.data) 176 | # self.sigma = sigma[0] 177 | # 178 | # self.w_bar = w / self.sigma.expand_as(w) 179 | # 180 | # w.data = self.w_bar 181 | # 182 | # # w_mat = w.data.view(w.data.shape[0], -1) 183 | # 184 | # # for _ in range(self.power_iterations): 185 | # # v = l2normalize(torch.mv(torch.t(w_mat), self.u)) 186 | # # 187 | # # self.u = l2normalize(torch.mv(w_mat, v)) 188 | # # 189 | # # self.sigma = torch.dot(torch.mv(torch.t(w_mat), self.u), v) 190 | # # w.data = w.data/self.sigma 191 | # 192 | # # setattr(self.module, self.name, w) 193 | # 194 | # 195 | # def forward(self, input): 196 | # # self._update_u_v() 197 | # w = getattr(self.module, self.name) 198 | # _, sigma, _ = torch.svd(w.data) 199 | # self.sigma = sigma[0] 200 | # 201 | # self.w_bar = w / self.sigma.expand_as(w) 202 | # return torch.nn.functional.linear(input, self.w_bar, self.module.bias) 203 | # return self.module.forward(*args) -------------------------------------------------------------------------------- /nets/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from nets.layers.categorical_batch_norm import CategoricalBatchNorm 3 | from nets.layers.spectral_norm import * 4 | 5 | 6 | class Block(torch.nn.Module): 7 | 8 | def __init__(self, in_channels, out_channels, hidden_channels=None, 9 | kernel_size=3, stride=1, padding=1, optimized=False, spectral_norm=1): 10 | super(Block, self).__init__() 11 | self.in_channels = in_channels 12 | self.out_channels = out_channels 13 | self.optimized = optimized 14 | self.hidden_channels = out_channels if not hidden_channels else hidden_channels 15 | 16 | self.conv1 = Conv2d(self.in_channels, self.hidden_channels, 17 | kernel_size=kernel_size, stride=stride, padding=padding, spectral_norm_pi=spectral_norm) 18 | self.conv2 = Conv2d(self.hidden_channels, self.out_channels, 19 | kernel_size=kernel_size, stride=stride, padding=padding, spectral_norm_pi=spectral_norm) 20 | self.s_conv = None 21 | torch.nn.init.xavier_uniform_(self.conv1.weight.data, math.sqrt(2)) 22 | torch.nn.init.xavier_uniform_(self.conv2.weight.data, math.sqrt(2)) 23 | if self.in_channels != self.out_channels or optimized: 24 | self.s_conv = Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0, 25 | spectral_norm_pi=spectral_norm) 26 | torch.nn.init.xavier_uniform_(self.s_conv.weight.data, 1.) 27 | 28 | self.activate = torch.nn.ReLU() 29 | 30 | def residual(self, input): 31 | x = self.conv1(input) 32 | x = self.activate(x) 33 | x = self.conv2(x) 34 | if self.optimized: 35 | x = torch.nn.functional.avg_pool2d(x, 2) 36 | return x 37 | 38 | def shortcut(self, input): 39 | x = input 40 | if self.optimized: 41 | x = torch.nn.functional.avg_pool2d(x, 2) 42 | if self.s_conv: 43 | x = self.s_conv(x) 44 | return x 45 | 46 | 47 | def forward(self, input): 48 | x = self.residual(input) 49 | x_r = self.shortcut(input) 50 | return x + x_r 51 | 52 | 53 | class Gblock(Block): 54 | 55 | def __init__(self, in_channels, out_channels, hidden_channels=None, num_categories=None, 56 | kernel_size=3, stride=1, padding=1, upsample=True): 57 | super(Gblock, self).__init__(in_channels, out_channels, hidden_channels, kernel_size, stride, padding, 58 | upsample, spectral_norm=0) 59 | self.upsample = upsample 60 | self.num_categories = num_categories 61 | 62 | self.bn1 = self.batch_norm(self.in_channels) 63 | self.bn2 = self.batch_norm(self.hidden_channels) 64 | if upsample: 65 | # self.up = torch.nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2) 66 | self.up = lambda a: torch.nn.functional.interpolate(a, scale_factor=2) 67 | else: 68 | self.up = lambda a: None 69 | 70 | def batch_norm(self, num_features): 71 | return torch.nn.BatchNorm2d(num_features) if not self.num_categories \ 72 | else CategoricalBatchNorm(num_features, self.num_categories) 73 | 74 | def residual(self, input, y=None): 75 | x = input 76 | x = self.bn1(x, y) if self.num_categories else self.bn1(x) 77 | x = self.activate(x) 78 | if self.upsample: 79 | x = self.up(x) 80 | # output_size = list(input.size()) 81 | # output_size[-1] = output_size[-1] * 2 82 | # output_size[-2] = output_size[-2] * 2 83 | # x = self.up(x, output_size=output_size) 84 | x = self.conv1(x) 85 | x = self.bn2(x, y) if self.num_categories else self.bn2(x) 86 | x = self.activate(x) 87 | x = self.conv2(x) 88 | return x 89 | 90 | def shortcut(self, input): 91 | x = input 92 | if self.upsample: 93 | x = self.up(x) 94 | if self.s_conv: 95 | x = self.s_conv(x) 96 | return x 97 | 98 | def forward(self, input, y=None): 99 | x = self.residual(input, y) 100 | x_r = self.shortcut(input) 101 | return x + x_r 102 | 103 | 104 | class Dblock(Block): 105 | 106 | def __init__(self, in_channels, out_channels, hidden_channels=None, kernel_size=3, stride=1, padding=1, 107 | downsample=False, spectral_norm=1): 108 | super(Dblock, self).__init__(in_channels, out_channels, hidden_channels, kernel_size, stride, padding, 109 | downsample, spectral_norm) 110 | self.downsample = downsample 111 | 112 | def residual(self, input): 113 | x = self.activate(input) 114 | x = self.conv1(x) 115 | x = self.activate(x) 116 | x = self.conv2(x) 117 | if self.downsample: 118 | x = torch.nn.functional.avg_pool2d(x, 2) 119 | return x 120 | 121 | def shortcut(self, input): 122 | x = input 123 | if self.s_conv: 124 | x = self.s_conv(x) 125 | if self.downsample: 126 | x = torch.nn.functional.avg_pool2d(x, 2) 127 | return x 128 | 129 | def forward(self, input): 130 | x = self.residual(input) 131 | x_r = self.shortcut(input) 132 | return x + x_r 133 | 134 | 135 | class BaseGenerator(torch.nn.Module): 136 | 137 | def __init__(self, z_dim, ch, d_ch=None, n_categories=None, bottom_width=4): 138 | super(BaseGenerator, self).__init__() 139 | self.z_dim = z_dim 140 | self.ch = ch 141 | self.d_ch = d_ch if d_ch else ch 142 | self.n_categories = n_categories 143 | self.bottom_width = bottom_width 144 | self.dense = torch.nn.Linear(self.z_dim, self.bottom_width * self.bottom_width * self.d_ch) 145 | torch.nn.init.xavier_uniform_(self.dense.weight.data, 1.) 146 | self.blocks = torch.nn.ModuleList() 147 | self.final = self.final_block() 148 | 149 | def final_block(self): 150 | conv = torch.nn.Conv2d(self.ch, 3, kernel_size=3, stride=1, padding=1) 151 | torch.nn.init.xavier_uniform_(conv.weight.data, 1.) 152 | final_ = torch.nn.Sequential( 153 | torch.nn.BatchNorm2d(self.ch), 154 | torch.nn.ReLU(), 155 | conv, 156 | torch.nn.Tanh() 157 | ) 158 | return final_ 159 | 160 | 161 | def forward(self, input, y=None): 162 | x = self.dense(input) 163 | x = x.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 164 | for block in self.blocks: 165 | x = block(x, y) 166 | x = self.final(x) 167 | return x 168 | 169 | 170 | class ResnetGenerator(BaseGenerator): 171 | 172 | def __init__(self, ch=64, z_dim=128, n_categories=None, bottom_width=4): 173 | super(ResnetGenerator, self).__init__(z_dim, ch, ch*16, n_categories, bottom_width) 174 | self.blocks.append(Gblock(self.ch * 16, self.ch * 16, upsample=True, num_categories=self.n_categories)) 175 | self.blocks.append(Gblock(self.ch * 16, self.ch * 8, upsample=True, num_categories=self.n_categories)) 176 | self.blocks.append(Gblock(self.ch * 8, self.ch * 4, upsample=True, num_categories=self.n_categories)) 177 | self.blocks.append(Gblock(self.ch * 4, self.ch * 2, upsample=True, num_categories=self.n_categories)) 178 | self.blocks.append(Gblock(self.ch * 2, self.ch, upsample=True, num_categories=self.n_categories)) 179 | 180 | 181 | class ResnetGenerator32(BaseGenerator): 182 | 183 | def __init__(self, ch=256, z_dim=128, n_categories=None, bottom_width=4): 184 | super(ResnetGenerator32, self).__init__(z_dim, ch, ch, n_categories, bottom_width) 185 | self.blocks.append(Gblock(self.ch, self.ch, upsample=True, num_categories=self.n_categories)) 186 | self.blocks.append(Gblock(self.ch, self.ch, upsample=True, num_categories=self.n_categories)) 187 | self.blocks.append(Gblock(self.ch, self.ch, upsample=True, num_categories=self.n_categories)) 188 | 189 | 190 | class ResnetGenerator64(BaseGenerator): 191 | 192 | def __init__(self, ch=64, z_dim=128, n_categories=None, bottom_width=4): 193 | super(ResnetGenerator64, self).__init__(z_dim, ch, ch*16, n_categories, bottom_width) 194 | self.blocks.append(Gblock(self.ch*16, self.ch*8, upsample=True, num_categories=self.n_categories)) 195 | self.blocks.append(Gblock(self.ch*8, self.ch*4, upsample=True, num_categories=self.n_categories)) 196 | self.blocks.append(Gblock(self.ch*4, self.ch*2, upsample=True, num_categories=self.n_categories)) 197 | self.blocks.append(Gblock(self.ch*2, self.ch, upsample=True, num_categories=self.n_categories)) 198 | 199 | 200 | class BaseDiscriminator(torch.nn.Module): 201 | 202 | def __init__(self, in_ch, out_ch=None, n_categories=0, l_bias=True, spectral_norm=1): 203 | super(BaseDiscriminator, self).__init__() 204 | self.activate = torch.nn.ReLU() 205 | self.ch = in_ch 206 | self.out_ch = out_ch if out_ch else in_ch 207 | self.n_categories = n_categories 208 | self.blocks = torch.nn.ModuleList([Block(3, self.ch, optimized=True, spectral_norm=spectral_norm)]) 209 | self.l = Linear(self.out_ch, 1, l_bias, spectral_norm_pi=spectral_norm) 210 | torch.nn.init.xavier_uniform_(self.l.weight.data, 1.) 211 | if n_categories > 0: 212 | self.l_y = Embedding(n_categories, self.out_ch, spectral_norm_pi=spectral_norm) 213 | torch.nn.init.xavier_uniform_(self.l_y.weight.data, 1.) 214 | 215 | def forward(self, input, y=None): 216 | x = input 217 | for block in self.blocks: 218 | x = block(x) 219 | x = self.activate(x) 220 | x = torch.sum(x, (2, 3)) 221 | output = self.l(x) 222 | if y is not None: 223 | w_y = self.l_y(y) 224 | output += torch.sum(w_y*x, dim=1, keepdim=True) 225 | return output 226 | 227 | 228 | class ResnetDiscriminator(BaseDiscriminator): 229 | 230 | def __init__(self, ch=64, n_categories=0, spectral_norm=1): 231 | super(ResnetDiscriminator, self).__init__(ch, ch*16, n_categories, spectral_norm=spectral_norm) 232 | self.blocks.append(Dblock(self.ch, self.ch*2, downsample=True, spectral_norm=spectral_norm)) 233 | self.blocks.append(Dblock(self.ch*2, self.ch*4, downsample=True, spectral_norm=spectral_norm)) 234 | self.blocks.append(Dblock(self.ch*4, self.ch*8, downsample=True, spectral_norm=spectral_norm)) 235 | self.blocks.append(Dblock(self.ch*8, self.ch*16, downsample=True, spectral_norm=spectral_norm)) 236 | self.blocks.append(Dblock(self.ch*16, self.ch*16, downsample=False, spectral_norm=spectral_norm)) 237 | 238 | 239 | class ResnetDiscriminator32(BaseDiscriminator): 240 | 241 | def __init__(self, ch=128, n_categories=0, spectral_norm=1): 242 | super(ResnetDiscriminator32, self).__init__(ch, ch, n_categories, l_bias=False, spectral_norm=spectral_norm) 243 | self.blocks.append(Dblock(self.ch, self.ch, downsample=True, spectral_norm=spectral_norm)) 244 | self.blocks.append(Dblock(self.ch, self.ch, downsample=False, spectral_norm=spectral_norm)) 245 | self.blocks.append(Dblock(self.ch, self.ch, downsample=False, spectral_norm=spectral_norm)) 246 | 247 | 248 | class ResnetDiscriminator64(BaseDiscriminator): 249 | 250 | def __init__(self, ch=64, n_categories=0, spectral_norm=1): 251 | super(ResnetDiscriminator64, self).__init__(ch, ch*16, n_categories, spectral_norm=spectral_norm) 252 | self.blocks.append(Dblock(self.ch, self.ch*2, downsample=True, spectral_norm=spectral_norm)) 253 | self.blocks.append(Dblock(self.ch*2, self.ch*4, downsample=True, spectral_norm=spectral_norm)) 254 | self.blocks.append(Dblock(self.ch*4, self.ch*8, downsample=True, spectral_norm=spectral_norm)) 255 | self.blocks.append(Dblock(self.ch*8, self.ch*16, downsample=True, spectral_norm=spectral_norm)) 256 | --------------------------------------------------------------------------------