├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── celebAHQ.yaml
│   ├── default.yaml
│   ├── imagenet.yaml
│   ├── lsun_bedroom.yaml
│   ├── lsun_bridge.yaml
│   ├── lsun_church.yaml
│   ├── lsun_tower.yaml
│   └── pretrained
│       ├── celebAHQ_pretrained.yaml
│       ├── celebA_pretrained.yaml
│       ├── imagenet_pretrained.yaml
│       ├── lsun_bedroom_pretrained.yaml
│       ├── lsun_bridge_pretrained.yaml
│       ├── lsun_church_pretrained.yaml
│       └── lsun_tower_pretrained.yaml
├── gan_training
│   ├── __init__.py
│   ├── checkpoints.py
│   ├── config.py
│   ├── distributions.py
│   ├── eval.py
│   ├── inputs.py
│   ├── logger.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── inception_score.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   ├── resnet2.py
│   │   ├── resnet3.py
│   │   └── resnet4.py
│   ├── ops.py
│   ├── train.py
│   └── utils.py
├── interpolate.py
├── interpolate_class.py
├── notebooks
│   ├── DiracGAN.ipynb
│   ├── create_video.sh
│   └── diracgan
│       ├── __init__.py
│       ├── gans.py
│       ├── plotting.py
│       ├── simulate.py
│       ├── subplots.py
│       └── util.py
├── results
│   ├── celebA-HQ.jpg
│   ├── imagenet_00.jpg
│   ├── imagenet_01.jpg
│   ├── imagenet_02.jpg
│   ├── imagenet_03.jpg
│   └── imagenet_04.jpg
├── test.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | data 3 | *_lmdb 4 | __pycache__ 5 | *.pyc 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lars Mescheder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GAN stability 2 | This repository contains the experiments in the supplementary material for the paper [Which Training Methods for GANs do actually Converge?](https://avg.is.tuebingen.mpg.de/publications/meschedericml2018). 3 | 4 | To cite this work, please use 5 | ``` 6 | @INPROCEEDINGS{Mescheder2018ICML, 7 | author = {Lars Mescheder and Sebastian Nowozin and Andreas Geiger}, 8 | title = {Which Training Methods for GANs do actually Converge?}, 9 | booktitle = {International Conference on Machine Learning (ICML)}, 10 | year = {2018} 11 | } 12 | ``` 13 | You can find further details on [our project page](https://avg.is.tuebingen.mpg.de/research_projects/convergence-and-stability-of-gan-training). 
14 | 15 | # Usage 16 | First download your data and put it into the `./data` folder. 17 | 18 | To train a new model, first create a config script similar to the ones provided in the `./configs` folder. You can then train your model using 19 | ``` 20 | python train.py PATH_TO_CONFIG 21 | ``` 22 | 23 | To compute the inception score for your model and generate samples, use 24 | ``` 25 | python test.py PATH_TO_CONFIG 26 | ``` 27 | 28 | Finally, you can create nice latent space interpolations using 29 | ``` 30 | python interpolate.py PATH_TO_CONFIG 31 | ``` 32 | or 33 | ``` 34 | python interpolate_class.py PATH_TO_CONFIG 35 | ``` 36 | 37 | # Pretrained models 38 | We also provide several pretrained models. 39 | 40 | You can use the models for sampling by entering 41 | ``` 42 | python test.py PATH_TO_CONFIG 43 | ``` 44 | where `PATH_TO_CONFIG` is one of the config files 45 | ``` 46 | configs/pretrained/celebA_pretrained.yaml 47 | configs/pretrained/celebAHQ_pretrained.yaml 48 | configs/pretrained/imagenet_pretrained.yaml 49 | configs/pretrained/lsun_bedroom_pretrained.yaml 50 | configs/pretrained/lsun_bridge_pretrained.yaml 51 | configs/pretrained/lsun_church_pretrained.yaml 52 | configs/pretrained/lsun_tower_pretrained.yaml 53 | ``` 54 | Our script will automatically download the model checkpoints and run the generation. 55 | You can find the outputs in the `output/pretrained` folders. 56 | Similarly, you can use the scripts `interpolate.py` and `interpolate_class.py` for generating interpolations for the pretrained models. 57 | 58 | Please note that the config files `*_pretrained.yaml` are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model. 59 | 60 | # Notes 61 | * Batch normalization is currently *not* supported when using an exponential running average, as the running average is only computed over the parameters of the model, not over its other buffers (such as the running statistics of batch normalization layers); see the sketch below. 
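To make the last note concrete, here is a minimal sketch of such a parameter-only running average (modelled on `update_average` in `gan_training/train.py`; the toy convolution + BatchNorm network is for illustration only and not part of the repository). After one forward pass and one averaging step, the convolution weights of the averaged model have moved, but its BatchNorm running statistics are still at their initial values:
```
import torch
from torch import nn

def update_average(model_tgt, model_src, beta=0.999):
    # Averages *parameters* only, as in gan_training/train.py;
    # buffers such as BatchNorm running statistics are skipped.
    param_dict_src = dict(model_src.named_parameters())
    with torch.no_grad():
        for p_name, p_tgt in model_tgt.named_parameters():
            p_src = param_dict_src[p_name]
            p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src)

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
net_avg = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))

net.train()
net(torch.randn(4, 3, 16, 16))   # updates the running stats of net[1] only
update_average(net_avg, net)     # copies nothing into net_avg[1]'s buffers

print(torch.allclose(net[1].running_mean, net_avg[1].running_mean))  # False
```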
62 | 63 | # Results 64 | ## celebA-HQ 65 | ![celebA-HQ](results/celebA-HQ.jpg) 66 | 67 | ## Imagenet 68 | ![Imagenet 0](results/imagenet_00.jpg) 69 | ![Imagenet 1](results/imagenet_01.jpg) 70 | ![Imagenet 2](results/imagenet_02.jpg) 71 | ![Imagenet 3](results/imagenet_03.jpg) 72 | ![Imagenet 4](results/imagenet_04.jpg) 73 | -------------------------------------------------------------------------------- /configs/celebAHQ.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: npy 3 | train_dir: data/celebA-HQ 4 | test_dir: data/celebA-HQ 5 | img_size: 1024 6 | generator: 7 | name: resnet 8 | kwargs: 9 | nfilter: 16 10 | nfilter_max: 512 11 | embed_size: 1 12 | discriminator: 13 | name: resnet 14 | kwargs: 15 | nfilter: 16 16 | nfilter_max: 512 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/celebAHQ 23 | batch_size: 24 24 | test: 25 | batch_size: 4 26 | sample_size: 6 27 | sample_nrow: 3 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bedroom_train] 6 | lsun_categories_test: [bedroom_test] 7 | img_size: 256 8 | nlabels: 1 9 | generator: 10 | name: resnet 11 | kwargs: 12 | discriminator: 13 | name: resnet 14 | kwargs: 15 | z_dist: 16 | type: gauss 17 | dim: 256 18 | training: 19 | out_dir: output/default 20 | gan_type: standard 21 | reg_type: real 22 | reg_param: 10. 23 | batch_size: 64 24 | nworkers: 16 25 | take_model_average: true 26 | model_average_beta: 0.999 27 | model_average_reinit: false 28 | monitoring: tensorboard 29 | sample_every: 1000 30 | sample_nlabels: 20 31 | inception_every: -1 32 | save_every: 900 33 | backup_every: 100000 34 | restart_every: -1 35 | optimizer: rmsprop 36 | lr_g: 0.0001 37 | lr_d: 0.0001 38 | lr_anneal: 1. 
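# lr_anneal is the gamma of the StepLR scheduler built in gan_training/config.py,
# applied every lr_anneal_every iterations; 1. keeps the learning rates constant.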
39 | lr_anneal_every: 150000 40 | d_steps: 1 41 | equalize_lr: false 42 | model_file: model.pt 43 | test: 44 | batch_size: 32 45 | sample_size: 64 46 | sample_nrow: 8 47 | use_model_average: true 48 | compute_inception: false 49 | conditional_samples: false 50 | model_file: model.pt 51 | interpolations: 52 | nzs: 10 53 | nsubsteps: 75 54 | -------------------------------------------------------------------------------- /configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: image 3 | train_dir: data/Imagenet 4 | test_dir: data/Imagenet 5 | img_size: 128 6 | nlabels: 1000 7 | generator: 8 | name: resnet2 9 | kwargs: 10 | nfilter: 64 11 | nfilter_max: 1024 12 | embed_size: 256 13 | discriminator: 14 | name: resnet2 15 | kwargs: 16 | nfilter: 64 17 | nfilter_max: 1024 18 | embed_size: 256 19 | z_dist: 20 | type: gauss 21 | dim: 256 22 | training: 23 | out_dir: output/imagenet 24 | gan_type: standard 25 | sample_nlabels: 20 26 | inception_every: 10000 27 | batch_size: 128 28 | test: 29 | batch_size: 32 30 | sample_size: 64 31 | sample_nrow: 8 32 | compute_inception: true 33 | conditional_samples: true 34 | interpolations: 35 | ys: [15, 157, 307, 321, 442, 483, 484, 525, 36 | 536, 598, 607, 734, 768, 795, 927, 977, 37 | 963, 946, 979] 38 | nzs: 10 39 | nsubsteps: 75 40 | -------------------------------------------------------------------------------- /configs/lsun_bedroom.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bedroom_train] 6 | lsun_categories_test: [bedroom_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_bedroom 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /configs/lsun_bridge.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bridge_train] 6 | lsun_categories_test: [bridge_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_bridge 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /configs/lsun_church.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [church_outdoor_train] 6 | lsun_categories_test: [church_outdoor_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20
| z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_church 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /configs/lsun_tower.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [tower_train] 6 | lsun_categories_test: [tower_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_tower 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /configs/pretrained/celebAHQ_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: npy 3 | train_dir: data/celebA-HQ 4 | test_dir: data/celebA-HQ 5 | img_size: 1024 6 | generator: 7 | name: resnet 8 | kwargs: 9 | nfilter: 16 10 | nfilter_max: 512 11 | embed_size: 1 12 | discriminator: 13 | name: resnet 14 | kwargs: 15 | nfilter: 16 16 | nfilter_max: 512 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/celebAHQ 23 | batch_size: 24 24 | test: 25 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celebahq-baab46b2.pt 26 | batch_size: 4 27 | sample_size: 6 28 | sample_nrow: 3 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /configs/pretrained/celebA_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: image 3 | train_dir: data/celebA 4 | test_dir: data/celebA 5 | img_size: 256 6 | generator: 7 | name: resnet4 8 | kwargs: 9 | nfilter: 64 10 | embed_size: 1 11 | discriminator: 12 | name: resnet4 13 | kwargs: 14 | nfilter: 64 15 | embed_size: 1 16 | z_dist: 17 | type: gauss 18 | dim: 256 19 | training: 20 | out_dir: output/pretrained/celebA 21 | test: 22 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celeba-ab478c9d.pt 23 | batch_size: 32 24 | sample_size: 64 25 | sample_nrow: 8 26 | interpolations: 27 | nzs: 10 28 | nsubsteps: 75 29 | -------------------------------------------------------------------------------- /configs/pretrained/imagenet_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: image 3 | train_dir: data/Imagenet 4 | test_dir: data/Imagenet 5 | img_size: 128 6 | nlabels: 1000 7 | generator: 8 | name: resnet2 9 | kwargs: 10 | nfilter: 64 11 | nfilter_max: 1024 12 | embed_size: 256 13 | discriminator: 14 | name: resnet2 15 | kwargs: 16 | nfilter: 64 17 | nfilter_max: 1024 18 | embed_size: 256 19 | z_dist: 20 | type: gauss 21 | dim: 256 22 | training: 23 | out_dir: output/pretrained/imagenet 24 | sample_nlabels: 20 25 | inception_every: 10000 26 | batch_size: 128 27 | test: 28 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/imagenet-8c505f47.pt 29 | 
batch_size: 32 30 | sample_size: 64 31 | sample_nrow: 8 32 | compute_inception: false 33 | conditional_samples: true 34 | interpolations: 35 | ys: [15, 157, 307, 321, 442, 483, 484, 525, 36 | 536, 598, 607, 734, 768, 795, 927, 977, 37 | 963, 946, 979] 38 | nzs: 10 39 | nsubsteps: 75 40 | -------------------------------------------------------------------------------- /configs/pretrained/lsun_bedroom_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bedroom_train] 6 | lsun_categories_test: [bedroom_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_bedroom 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bedroom-df4e7dd2.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /configs/pretrained/lsun_bridge_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bridge_train] 6 | lsun_categories_test: [bridge_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_bridge 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bridge-82887d22.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /configs/pretrained/lsun_church_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [church_outdoor_train] 6 | lsun_categories_test: [church_outdoor_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_church 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_church-b6f0191b.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /configs/pretrained/lsun_tower_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [tower_train] 6 | lsun_categories_test: [tower_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | 
embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_tower 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_tower-1af5e570.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /gan_training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/gan_training/__init__.py -------------------------------------------------------------------------------- /gan_training/checkpoints.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import urllib 4 | import torch 5 | from torch.utils import model_zoo 6 | 7 | 8 | class CheckpointIO(object): 9 | ''' CheckpointIO class. 10 | 11 | It handles saving and loading checkpoints. 12 | 13 | Args: 14 | checkpoint_dir (str): path where checkpoints are saved 15 | ''' 16 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 17 | self.module_dict = kwargs 18 | self.checkpoint_dir = checkpoint_dir 19 | if not os.path.exists(checkpoint_dir): 20 | os.makedirs(checkpoint_dir) 21 | 22 | def register_modules(self, **kwargs): 23 | ''' Registers modules in current module dictionary. 24 | ''' 25 | self.module_dict.update(kwargs) 26 | 27 | def save(self, filename, **kwargs): 28 | ''' Saves the current module dictionary. 29 | 30 | Args: 31 | filename (str): name of output file 32 | ''' 33 | if not os.path.isabs(filename): 34 | filename = os.path.join(self.checkpoint_dir, filename) 35 | 36 | outdict = kwargs 37 | for k, v in self.module_dict.items(): 38 | outdict[k] = v.state_dict() 39 | torch.save(outdict, filename) 40 | 41 | def load(self, filename): 42 | '''Loads a module dictionary from local file or url. 43 | 44 | Args: 45 | filename (str): name of saved module dictionary 46 | ''' 47 | if is_url(filename): 48 | return self.load_url(filename) 49 | else: 50 | return self.load_file(filename) 51 | 52 | def load_file(self, filename): 53 | '''Loads a module dictionary from file. 54 | 55 | Args: 56 | filename (str): name of saved module dictionary 57 | ''' 58 | 59 | if not os.path.isabs(filename): 60 | filename = os.path.join(self.checkpoint_dir, filename) 61 | 62 | if os.path.exists(filename): 63 | print(filename) 64 | print('=> Loading checkpoint from local file...') 65 | state_dict = torch.load(filename) 66 | scalars = self.parse_state_dict(state_dict) 67 | return scalars 68 | else: 69 | raise FileNotFoundError 70 | 71 | def load_url(self, url): 72 | '''Load a module dictionary from url. 73 | 74 | Args: 75 | url (str): url to saved model 76 | ''' 77 | print(url) 78 | print('=> Loading checkpoint from url...') 79 | state_dict = model_zoo.load_url(url, progress=True) 80 | scalars = self.parse_state_dict(state_dict) 81 | return scalars 82 | 83 | def parse_state_dict(self, state_dict): 84 | '''Parse state_dict of model and return scalars. 85 | 86 | Args: 87 | state_dict (dict): State dict of model 88 | ''' 89 | 90 | for k, v in self.module_dict.items(): 91 | if k in state_dict: 92 | v.load_state_dict(state_dict[k]) 93 | else: 94 | print('Warning: Could not find %s in checkpoint!' 
% k) 95 | scalars = {k: v for k, v in state_dict.items() 96 | if k not in self.module_dict} 97 | return scalars 98 | 99 | def is_url(url): 100 | scheme = urllib.parse.urlparse(url).scheme 101 | return scheme in ('http', 'https') -------------------------------------------------------------------------------- /gan_training/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torch import optim 3 | from os import path 4 | from gan_training.models import generator_dict, discriminator_dict 5 | from gan_training.train import toggle_grad 6 | 7 | 8 | # General config 9 | def load_config(path, default_path): 10 | ''' Loads config file. 11 | 12 | Args: 13 | path (str): path to config file 14 | default_path (str): path to the default config file 15 | ''' 16 | # Load configuration from file itself 17 | with open(path, 'r') as f: 18 | cfg_special = yaml.load(f, Loader=yaml.SafeLoader) 19 | 20 | # Check if we should inherit from a config 21 | inherit_from = cfg_special.get('inherit_from') 22 | 23 | # If yes, load this config first as default 24 | # If no, use the default_path 25 | if inherit_from is not None: 26 | cfg = load_config(inherit_from, default_path) 27 | elif default_path is not None: 28 | with open(default_path, 'r') as f: 29 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 30 | else: 31 | cfg = dict() 32 | 33 | # Include main configuration 34 | update_recursive(cfg, cfg_special) 35 | 36 | return cfg 37 | 38 | 39 | def update_recursive(dict1, dict2): 40 | ''' Update two config dictionaries recursively. 41 | 42 | Args: 43 | dict1 (dict): first dictionary to be updated 44 | dict2 (dict): second dictionary whose entries should be used 45 | 46 | ''' 47 | for k, v in dict2.items(): 48 | # Add item if not yet in dict1 49 | if k not in dict1: 50 | dict1[k] = None 51 | # Update 52 | if isinstance(dict1[k], dict): 53 | update_recursive(dict1[k], v) 54 | else: 55 | dict1[k] = v 56 | 57 | 58 | def build_models(config): 59 | # Get classes 60 | Generator = generator_dict[config['generator']['name']] 61 | Discriminator = discriminator_dict[config['discriminator']['name']] 62 | 63 | # Build models 64 | generator = Generator( 65 | z_dim=config['z_dist']['dim'], 66 | nlabels=config['data']['nlabels'], 67 | size=config['data']['img_size'], 68 | **config['generator']['kwargs'] 69 | ) 70 | discriminator = Discriminator( 71 | z_dim=config['z_dist']['dim'], 72 | nlabels=config['data']['nlabels'], 73 | size=config['data']['img_size'], 74 | **config['discriminator']['kwargs'] 75 | ) 76 | 77 | return generator, discriminator 78 | 79 | 80 | def build_optimizers(generator, discriminator, config): 81 | optimizer = config['training']['optimizer'] 82 | lr_g = config['training']['lr_g'] 83 | lr_d = config['training']['lr_d'] 84 | equalize_lr = config['training']['equalize_lr'] 85 | 86 | toggle_grad(generator, True) 87 | toggle_grad(discriminator, True) 88 | 89 | if equalize_lr: 90 | g_gradient_scales = getattr(generator, 'gradient_scales', dict()) 91 | d_gradient_scales = getattr(discriminator, 'gradient_scales', dict()) 92 | 93 | g_params = get_parameter_groups(generator.parameters(), 94 | g_gradient_scales, 95 | base_lr=lr_g) 96 | d_params = get_parameter_groups(discriminator.parameters(), 97 | d_gradient_scales, 98 | base_lr=lr_d) 99 | else: 100 | g_params = generator.parameters() 101 | d_params = discriminator.parameters() 102 | 103 | # Optimizers 104 | if optimizer == 'rmsprop': 105 | g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8) 106 | d_optimizer = optim.RMSprop(d_params, lr=lr_d,
alpha=0.99, eps=1e-8) 107 | elif optimizer == 'adam': 108 | g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(0., 0.99), eps=1e-8) 109 | d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(0., 0.99), eps=1e-8) 110 | elif optimizer == 'sgd': 111 | g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.) 112 | d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.) 113 | 114 | return g_optimizer, d_optimizer 115 | 116 | 117 | def build_lr_scheduler(optimizer, config, last_epoch=-1): 118 | lr_scheduler = optim.lr_scheduler.StepLR( 119 | optimizer, 120 | step_size=config['training']['lr_anneal_every'], 121 | gamma=config['training']['lr_anneal'], 122 | last_epoch=last_epoch 123 | ) 124 | return lr_scheduler 125 | 126 | 127 | # Some utility functions 128 | def get_parameter_groups(parameters, gradient_scales, base_lr): 129 | param_groups = [] 130 | for p in parameters: 131 | c = gradient_scales.get(p, 1.) 132 | param_groups.append({ 133 | 'params': [p], 134 | 'lr': c * base_lr 135 | }) 136 | return param_groups 137 | -------------------------------------------------------------------------------- /gan_training/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | 4 | 5 | def get_zdist(dist_name, dim, device=None): 6 | # Get distribution 7 | if dist_name == 'uniform': 8 | low = -torch.ones(dim, device=device) 9 | high = torch.ones(dim, device=device) 10 | zdist = distributions.Uniform(low, high) 11 | elif dist_name == 'gauss': 12 | mu = torch.zeros(dim, device=device) 13 | scale = torch.ones(dim, device=device) 14 | zdist = distributions.Normal(mu, scale) 15 | else: 16 | raise NotImplementedError 17 | 18 | # Add dim attribute 19 | zdist.dim = dim 20 | 21 | return zdist 22 | 23 | 24 | def get_ydist(nlabels, device=None): 25 | logits = torch.zeros(nlabels, device=device) 26 | ydist = distributions.categorical.Categorical(logits=logits) 27 | 28 | # Add nlabels attribute 29 | ydist.nlabels = nlabels 30 | 31 | return ydist 32 | 33 | 34 | def interpolate_sphere(z1, z2, t): 35 | p = (z1 * z2).sum(dim=-1, keepdim=True) 36 | p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt() 37 | p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt() 38 | omega = torch.acos(p) 39 | s1 = torch.sin((1-t)*omega)/torch.sin(omega) 40 | s2 = torch.sin(t*omega)/torch.sin(omega) 41 | z = s1 * z1 + s2 * z2 42 | 43 | return z 44 | -------------------------------------------------------------------------------- /gan_training/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gan_training.metrics import inception_score 3 | 4 | 5 | class Evaluator(object): 6 | def __init__(self, generator, zdist, ydist, batch_size=64, 7 | inception_nsamples=60000, device=None): 8 | self.generator = generator 9 | self.zdist = zdist 10 | self.ydist = ydist 11 | self.inception_nsamples = inception_nsamples 12 | self.batch_size = batch_size 13 | self.device = device 14 | 15 | def compute_inception_score(self): 16 | self.generator.eval() 17 | imgs = [] 18 | while len(imgs) < self.inception_nsamples: 19 | ztest = self.zdist.sample((self.batch_size,)) 20 | ytest = self.ydist.sample((self.batch_size,)) 21 | 22 | with torch.no_grad(): samples = self.generator(ztest, ytest)  # no autograd graph needed for sampling 23 | samples = [s.data.cpu().numpy() for s in samples] 24 | imgs.extend(samples) 25 | 26 | imgs = imgs[:self.inception_nsamples] 27 | score, score_std = inception_score( 28 | imgs, device=self.device, resize=True, splits=10 29 | ) 30 | 31 | return
score, score_std 32 | 33 | def create_samples(self, z, y=None): 34 | self.generator.eval() 35 | batch_size = z.size(0) 36 | # Parse y 37 | if y is None: 38 | y = self.ydist.sample((batch_size,)) 39 | elif isinstance(y, int): 40 | y = torch.full((batch_size,), y, 41 | device=self.device, dtype=torch.int64) 42 | # Sample x 43 | with torch.no_grad(): 44 | x = self.generator(z, y) 45 | return x 46 | -------------------------------------------------------------------------------- /gan_training/inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torchvision.datasets as datasets 4 | import numpy as np 5 | 6 | 7 | def get_dataset(name, data_dir, size=64, lsun_categories=None): 8 | transform = transforms.Compose([ 9 | transforms.Resize(size), 10 | transforms.CenterCrop(size), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 14 | transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())),  # dequantization noise 15 | ]) 16 | 17 | if name == 'image': 18 | dataset = datasets.ImageFolder(data_dir, transform) 19 | nlabels = len(dataset.classes) 20 | elif name == 'npy': 21 | # Only support normalization for now 22 | dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy']) 23 | nlabels = len(dataset.classes) 24 | elif name == 'cifar10': 25 | dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, 26 | transform=transform) 27 | nlabels = 10 28 | elif name == 'lsun': 29 | if lsun_categories is None: 30 | lsun_categories = 'train' 31 | dataset = datasets.LSUN(data_dir, lsun_categories, transform) 32 | nlabels = len(dataset.classes) 33 | elif name == 'lsun_class': 34 | dataset = datasets.LSUNClass(data_dir, transform, 35 | target_transform=(lambda t: 0)) 36 | nlabels = 1 37 | else: 38 | raise NotImplementedError 39 | 40 | return dataset, nlabels 41 | 42 | 43 | def npy_loader(path): 44 | img = np.load(path) 45 | 46 | if img.dtype == np.uint8: 47 | img = img.astype(np.float32) 48 | img = img/127.5 - 1. 49 | elif img.dtype == np.float32: 50 | img = img * 2 - 1.
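# float32 arrays are assumed to lie in [0, 1] already, so this maps
# them to [-1, 1] to match the uint8 branch above.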
51 | else: 52 | raise NotImplementedError 53 | 54 | img = torch.Tensor(img) 55 | if len(img.size()) == 4: 56 | img.squeeze_(0) 57 | 58 | return img 59 | -------------------------------------------------------------------------------- /gan_training/logger.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import torchvision 4 | 5 | 6 | class Logger(object): 7 | def __init__(self, log_dir='./logs', img_dir='./imgs', 8 | monitoring=None, monitoring_dir=None): 9 | self.stats = dict() 10 | self.log_dir = log_dir 11 | self.img_dir = img_dir 12 | 13 | if not os.path.exists(log_dir): 14 | os.makedirs(log_dir) 15 | 16 | if not os.path.exists(img_dir): 17 | os.makedirs(img_dir) 18 | 19 | if not (monitoring is None or monitoring == 'none'): 20 | self.setup_monitoring(monitoring, monitoring_dir) 21 | else: 22 | self.monitoring = None 23 | self.monitoring_dir = None 24 | 25 | def setup_monitoring(self, monitoring, monitoring_dir=None): 26 | self.monitoring = monitoring 27 | self.monitoring_dir = monitoring_dir 28 | 29 | if monitoring == 'telemetry': 30 | import telemetry 31 | self.tm = telemetry.ApplicationTelemetry() 32 | if self.tm.get_status() == 0: 33 | print('Telemetry successfully connected.') 34 | elif monitoring == 'tensorboard': 35 | import tensorboardX 36 | self.tb = tensorboardX.SummaryWriter(monitoring_dir) 37 | else: 38 | raise NotImplementedError('Monitoring tool "%s" not supported!' 39 | % monitoring) 40 | 41 | def add(self, category, k, v, it): 42 | if category not in self.stats: 43 | self.stats[category] = {} 44 | 45 | if k not in self.stats[category]: 46 | self.stats[category][k] = [] 47 | 48 | self.stats[category][k].append((it, v)) 49 | 50 | k_name = '%s/%s' % (category, k) 51 | if self.monitoring == 'telemetry': 52 | self.tm.metric_push_async({ 53 | 'metric': k_name, 'value': v, 'it': it 54 | }) 55 | elif self.monitoring == 'tensorboard': 56 | self.tb.add_scalar(k_name, v, it) 57 | 58 | def add_imgs(self, imgs, class_name, it): 59 | outdir = os.path.join(self.img_dir, class_name) 60 | if not os.path.exists(outdir): 61 | os.makedirs(outdir) 62 | outfile = os.path.join(outdir, '%08d.png' % it) 63 | 64 | imgs = imgs / 2 + 0.5 65 | imgs = torchvision.utils.make_grid(imgs) 66 | torchvision.utils.save_image(imgs, outfile, nrow=8) 67 | 68 | if self.monitoring == 'tensorboard': 69 | self.tb.add_image(class_name, imgs, it) 70 | 71 | def get_last(self, category, k, default=0.): 72 | if category not in self.stats: 73 | return default 74 | elif k not in self.stats[category]: 75 | return default 76 | else: 77 | return self.stats[category][k][-1][1] 78 | 79 | def save_stats(self, filename): 80 | filename = os.path.join(self.log_dir, filename) 81 | with open(filename, 'wb') as f: 82 | pickle.dump(self.stats, f) 83 | 84 | def load_stats(self, filename): 85 | filename = os.path.join(self.log_dir, filename) 86 | if not os.path.exists(filename): 87 | print('Warning: file "%s" does not exist!' 
% filename) 88 | return 89 | 90 | try: 91 | with open(filename, 'rb') as f: 92 | self.stats = pickle.load(f) 93 | except EOFError: 94 | print('Warning: log file corrupted!') 95 | -------------------------------------------------------------------------------- /gan_training/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from gan_training.metrics.inception_score import inception_score 2 | 3 | __all__ = [ 4 | 'inception_score' 5 | ] 6 | -------------------------------------------------------------------------------- /gan_training/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch.utils.data 5 | 6 | from torchvision.models.inception import inception_v3 7 | 8 | import numpy as np 9 | from scipy.stats import entropy 10 | 11 | 12 | def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1): 13 | """Computes the inception score of the generated images imgs 14 | 15 | Args: 16 | imgs: Torch dataset of (3xHxW) numpy images normalized in the 17 | range [-1, 1] 18 | device: device on which to run Inception v3 19 | batch_size: batch size for feeding into Inception v3 20 | splits: number of splits 21 | """ 22 | N = len(imgs) 23 | 24 | assert batch_size > 0 25 | assert N > batch_size 26 | 27 | # Set up dataloader 28 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 29 | 30 | # Load inception model 31 | inception_model = inception_v3(pretrained=True, transform_input=False) 32 | inception_model = inception_model.to(device) 33 | inception_model.eval() 34 | up = nn.Upsample(size=(299, 299), mode='bilinear').to(device) 35 | 36 | def get_pred(x): 37 | with torch.no_grad(): 38 | if resize: 39 | x = up(x) 40 | x = inception_model(x) 41 | out = F.softmax(x, dim=-1) 42 | out = out.cpu().numpy() 43 | return out 44 | 45 | # Get predictions 46 | preds = np.zeros((N, 1000)) 47 | 48 | for i, batch in enumerate(dataloader, 0): 49 | batchv = batch.to(device) 50 | batch_size_i = batch.size()[0] 51 | 52 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 53 | 54 | # Now compute the mean kl-div 55 | split_scores = [] 56 | 57 | for k in range(splits): 58 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 59 | py = np.mean(part, axis=0) 60 | scores = [] 61 | for i in range(part.shape[0]): 62 | pyx = part[i, :] 63 | scores.append(entropy(pyx, py)) 64 | split_scores.append(np.exp(np.mean(scores))) 65 | 66 | return np.mean(split_scores), np.std(split_scores) 67 | -------------------------------------------------------------------------------- /gan_training/models/__init__.py: -------------------------------------------------------------------------------- 1 | from gan_training.models import ( 2 | resnet, resnet2, resnet3, resnet4, 3 | ) 4 | 5 | generator_dict = { 6 | 'resnet': resnet.Generator, 7 | 'resnet2': resnet2.Generator, 8 | 'resnet3': resnet3.Generator, 9 | 'resnet4': resnet4.Generator, 10 | } 11 | 12 | discriminator_dict = { 13 | 'resnet': resnet.Discriminator, 14 | 'resnet2': resnet2.Discriminator, 15 | 'resnet3': resnet3.Discriminator, 16 | 'resnet4': resnet4.Discriminator, 17 | } 18 | -------------------------------------------------------------------------------- /gan_training/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import
functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | import numpy as np 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, nfilter_max=512, **kwargs): 12 | super().__init__() 13 | s0 = self.s0 = 4 14 | nf = self.nf = nfilter 15 | nf_max = self.nf_max = nfilter_max 16 | 17 | self.z_dim = z_dim 18 | 19 | # Submodules 20 | nlayers = int(np.log2(size / s0)) 21 | self.nf0 = min(nf_max, nf * 2**nlayers) 22 | 23 | self.embedding = nn.Embedding(nlabels, embed_size) 24 | self.fc = nn.Linear(z_dim + embed_size, self.nf0*s0*s0) 25 | 26 | blocks = [] 27 | for i in range(nlayers): 28 | nf0 = min(nf * 2**(nlayers-i), nf_max) 29 | nf1 = min(nf * 2**(nlayers-i-1), nf_max) 30 | blocks += [ 31 | ResnetBlock(nf0, nf1), 32 | nn.Upsample(scale_factor=2) 33 | ] 34 | 35 | blocks += [ 36 | ResnetBlock(nf, nf), 37 | ] 38 | 39 | self.resnet = nn.Sequential(*blocks) 40 | self.conv_img = nn.Conv2d(nf, 3, 3, padding=1) 41 | 42 | def forward(self, z, y): 43 | assert(z.size(0) == y.size(0)) 44 | batch_size = z.size(0) 45 | 46 | if y.dtype is torch.int64: 47 | yembed = self.embedding(y) 48 | else: 49 | yembed = y 50 | 51 | yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) 52 | 53 | yz = torch.cat([z, yembed], dim=1) 54 | out = self.fc(yz) 55 | out = out.view(batch_size, self.nf0, self.s0, self.s0) 56 | 57 | out = self.resnet(out) 58 | 59 | out = self.conv_img(actvn(out)) 60 | out = torch.tanh(out) 61 | 62 | return out 63 | 64 | 65 | class Discriminator(nn.Module): 66 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, nfilter_max=1024): 67 | super().__init__() 68 | self.embed_size = embed_size 69 | s0 = self.s0 = 4 70 | nf = self.nf = nfilter 71 | nf_max = self.nf_max = nfilter_max 72 | 73 | # Submodules 74 | nlayers = int(np.log2(size / s0)) 75 | self.nf0 = min(nf_max, nf * 2**nlayers) 76 | 77 | blocks = [ 78 | ResnetBlock(nf, nf) 79 | ] 80 | 81 | for i in range(nlayers): 82 | nf0 = min(nf * 2**i, nf_max) 83 | nf1 = min(nf * 2**(i+1), nf_max) 84 | blocks += [ 85 | nn.AvgPool2d(3, stride=2, padding=1), 86 | ResnetBlock(nf0, nf1), 87 | ] 88 | 89 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1) 90 | self.resnet = nn.Sequential(*blocks) 91 | self.fc = nn.Linear(self.nf0*s0*s0, nlabels) 92 | 93 | def forward(self, x, y): 94 | assert(x.size(0) == y.size(0)) 95 | batch_size = x.size(0) 96 | 97 | out = self.conv_img(x) 98 | out = self.resnet(out) 99 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 100 | out = self.fc(actvn(out)) 101 | 102 | index = Variable(torch.LongTensor(range(out.size(0)))) 103 | if y.is_cuda: 104 | index = index.cuda() 105 | out = out[index, y] 106 | 107 | return out 108 | 109 | 110 | class ResnetBlock(nn.Module): 111 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 112 | super().__init__() 113 | # Attributes 114 | self.is_bias = is_bias 115 | self.learned_shortcut = (fin != fout) 116 | self.fin = fin 117 | self.fout = fout 118 | if fhidden is None: 119 | self.fhidden = min(fin, fout) 120 | else: 121 | self.fhidden = fhidden 122 | 123 | # Submodules 124 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 125 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 126 | if self.learned_shortcut: 127 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 128 | 129 | def forward(self, x): 130 | x_s = self._shortcut(x) 131 | dx = 
self.conv_0(actvn(x)) 132 | dx = self.conv_1(actvn(dx)) 133 | out = x_s + 0.1*dx 134 | 135 | return out 136 | 137 | def _shortcut(self, x): 138 | if self.learned_shortcut: 139 | x_s = self.conv_s(x) 140 | else: 141 | x_s = x 142 | return x_s 143 | 144 | 145 | def actvn(x): 146 | out = F.leaky_relu(x, 2e-1) 147 | return out 148 | -------------------------------------------------------------------------------- /gan_training/models/resnet2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 11 | super().__init__() 12 | s0 = self.s0 = size // 32 13 | nf = self.nf = nfilter 14 | self.z_dim = z_dim 15 | 16 | # Submodules 17 | self.embedding = nn.Embedding(nlabels, embed_size) 18 | self.fc = nn.Linear(z_dim + embed_size, 16*nf*s0*s0) 19 | 20 | self.resnet_0_0 = ResnetBlock(16*nf, 16*nf) 21 | self.resnet_0_1 = ResnetBlock(16*nf, 16*nf) 22 | 23 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf) 24 | self.resnet_1_1 = ResnetBlock(16*nf, 16*nf) 25 | 26 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf) 27 | self.resnet_2_1 = ResnetBlock(8*nf, 8*nf) 28 | 29 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf) 30 | self.resnet_3_1 = ResnetBlock(4*nf, 4*nf) 31 | 32 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf) 33 | self.resnet_4_1 = ResnetBlock(2*nf, 2*nf) 34 | 35 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf) 36 | self.resnet_5_1 = ResnetBlock(1*nf, 1*nf) 37 | 38 | self.conv_img = nn.Conv2d(nf, 3, 3, padding=1) 39 | 40 | def forward(self, z, y): 41 | assert(z.size(0) == y.size(0)) 42 | batch_size = z.size(0) 43 | 44 | if y.dtype is torch.int64: 45 | yembed = self.embedding(y) 46 | else: 47 | yembed = y 48 | 49 | yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) 50 | 51 | yz = torch.cat([z, yembed], dim=1) 52 | out = self.fc(yz) 53 | out = out.view(batch_size, 16*self.nf, self.s0, self.s0) 54 | 55 | out = self.resnet_0_0(out) 56 | out = self.resnet_0_1(out) 57 | 58 | out = F.interpolate(out, scale_factor=2) 59 | out = self.resnet_1_0(out) 60 | out = self.resnet_1_1(out) 61 | 62 | out = F.interpolate(out, scale_factor=2) 63 | out = self.resnet_2_0(out) 64 | out = self.resnet_2_1(out) 65 | 66 | out = F.interpolate(out, scale_factor=2) 67 | out = self.resnet_3_0(out) 68 | out = self.resnet_3_1(out) 69 | 70 | out = F.interpolate(out, scale_factor=2) 71 | out = self.resnet_4_0(out) 72 | out = self.resnet_4_1(out) 73 | 74 | out = F.interpolate(out, scale_factor=2) 75 | out = self.resnet_5_0(out) 76 | out = self.resnet_5_1(out) 77 | 78 | out = self.conv_img(actvn(out)) 79 | out = torch.tanh(out) 80 | 81 | return out 82 | 83 | 84 | class Discriminator(nn.Module): 85 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 86 | super().__init__() 87 | self.embed_size = embed_size 88 | s0 = self.s0 = size // 32 89 | nf = self.nf = nfilter 90 | ny = nlabels 91 | 92 | # Submodules 93 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1) 94 | 95 | self.resnet_0_0 = ResnetBlock(1*nf, 1*nf) 96 | self.resnet_0_1 = ResnetBlock(1*nf, 2*nf) 97 | 98 | self.resnet_1_0 = ResnetBlock(2*nf, 2*nf) 99 | self.resnet_1_1 = ResnetBlock(2*nf, 4*nf) 100 | 101 | self.resnet_2_0 = ResnetBlock(4*nf, 4*nf) 102 | self.resnet_2_1 = ResnetBlock(4*nf, 8*nf) 103 | 104 | self.resnet_3_0 = 
ResnetBlock(8*nf, 8*nf) 105 | self.resnet_3_1 = ResnetBlock(8*nf, 16*nf) 106 | 107 | self.resnet_4_0 = ResnetBlock(16*nf, 16*nf) 108 | self.resnet_4_1 = ResnetBlock(16*nf, 16*nf) 109 | 110 | self.resnet_5_0 = ResnetBlock(16*nf, 16*nf) 111 | self.resnet_5_1 = ResnetBlock(16*nf, 16*nf) 112 | 113 | self.fc = nn.Linear(16*nf*s0*s0, nlabels) 114 | 115 | 116 | def forward(self, x, y): 117 | assert(x.size(0) == y.size(0)) 118 | batch_size = x.size(0) 119 | 120 | out = self.conv_img(x) 121 | 122 | out = self.resnet_0_0(out) 123 | out = self.resnet_0_1(out) 124 | 125 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 126 | out = self.resnet_1_0(out) 127 | out = self.resnet_1_1(out) 128 | 129 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 130 | out = self.resnet_2_0(out) 131 | out = self.resnet_2_1(out) 132 | 133 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 134 | out = self.resnet_3_0(out) 135 | out = self.resnet_3_1(out) 136 | 137 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 138 | out = self.resnet_4_0(out) 139 | out = self.resnet_4_1(out) 140 | 141 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 142 | out = self.resnet_5_0(out) 143 | out = self.resnet_5_1(out) 144 | 145 | out = out.view(batch_size, 16*self.nf*self.s0*self.s0) 146 | out = self.fc(actvn(out)) 147 | 148 | index = Variable(torch.LongTensor(range(out.size(0)))) 149 | if y.is_cuda: 150 | index = index.cuda() 151 | out = out[index, y] 152 | 153 | return out 154 | 155 | 156 | class ResnetBlock(nn.Module): 157 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 158 | super().__init__() 159 | # Attributes 160 | self.is_bias = is_bias 161 | self.learned_shortcut = (fin != fout) 162 | self.fin = fin 163 | self.fout = fout 164 | if fhidden is None: 165 | self.fhidden = min(fin, fout) 166 | else: 167 | self.fhidden = fhidden 168 | 169 | # Submodules 170 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 171 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 172 | if self.learned_shortcut: 173 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 174 | 175 | 176 | def forward(self, x): 177 | x_s = self._shortcut(x) 178 | dx = self.conv_0(actvn(x)) 179 | dx = self.conv_1(actvn(dx)) 180 | out = x_s + 0.1*dx 181 | 182 | return out 183 | 184 | def _shortcut(self, x): 185 | if self.learned_shortcut: 186 | x_s = self.conv_s(x) 187 | else: 188 | x_s = x 189 | return x_s 190 | 191 | 192 | def actvn(x): 193 | out = F.leaky_relu(x, 2e-1) 194 | return out 195 | -------------------------------------------------------------------------------- /gan_training/models/resnet3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 11 | super().__init__() 12 | s0 = self.s0 = size // 64 13 | nf = self.nf = nfilter 14 | self.z_dim = z_dim 15 | 16 | # Submodules 17 | self.embedding = nn.Embedding(nlabels, embed_size) 18 | self.fc = nn.Linear(z_dim + embed_size, 32*nf*s0*s0) 19 | 20 | self.resnet_0_0 = ResnetBlock(32*nf, 16*nf) 21 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf) 22 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf) 23 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf) 24 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf) 25 
| self.resnet_5_0 = ResnetBlock(2*nf, 1*nf) 26 | self.conv_img = nn.Conv2d(nf, 3, 7, padding=3) 27 | 28 | def forward(self, z, y): 29 | assert(z.size(0) == y.size(0)) 30 | batch_size = z.size(0) 31 | 32 | yembed = self.embedding(y) 33 | yz = torch.cat([z, yembed], dim=1) 34 | out = self.fc(yz) 35 | out = out.view(batch_size, 32*self.nf, self.s0, self.s0) 36 | 37 | out = self.resnet_0_0(out) 38 | 39 | out = F.interpolate(out, scale_factor=2) 40 | out = self.resnet_1_0(out) 41 | 42 | out = F.interpolate(out, scale_factor=2) 43 | out = self.resnet_2_0(out) 44 | 45 | out = F.interpolate(out, scale_factor=2) 46 | out = self.resnet_3_0(out) 47 | 48 | out = F.interpolate(out, scale_factor=2) 49 | out = self.resnet_4_0(out) 50 | 51 | out = F.interpolate(out, scale_factor=2) 52 | out = self.resnet_5_0(out) 53 | 54 | out = F.interpolate(out, scale_factor=2) 55 | 56 | out = self.conv_img(actvn(out)) 57 | out = torch.tanh(out) 58 | 59 | return out 60 | 61 | 62 | class Discriminator(nn.Module): 63 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 64 | super().__init__() 65 | self.embed_size = embed_size 66 | s0 = self.s0 = size // 64 67 | nf = self.nf = nfilter 68 | 69 | # Submodules 70 | self.conv_img = nn.Conv2d(3, 1*nf, 7, padding=3) 71 | 72 | self.resnet_0_0 = ResnetBlock(1*nf, 2*nf) 73 | self.resnet_1_0 = ResnetBlock(2*nf, 4*nf) 74 | self.resnet_2_0 = ResnetBlock(4*nf, 8*nf) 75 | self.resnet_3_0 = ResnetBlock(8*nf, 16*nf) 76 | self.resnet_4_0 = ResnetBlock(16*nf, 16*nf) 77 | self.resnet_5_0 = ResnetBlock(16*nf, 32*nf) 78 | 79 | self.fc = nn.Linear(32*nf*s0*s0, nlabels) 80 | 81 | def forward(self, x, y): 82 | assert(x.size(0) == y.size(0)) 83 | batch_size = x.size(0) 84 | 85 | out = self.conv_img(x) 86 | 87 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 88 | out = self.resnet_0_0(out) 89 | 90 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 91 | out = self.resnet_1_0(out) 92 | 93 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 94 | out = self.resnet_2_0(out) 95 | 96 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 97 | out = self.resnet_3_0(out) 98 | 99 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 100 | out = self.resnet_4_0(out) 101 | 102 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 103 | out = self.resnet_5_0(out) 104 | 105 | out = out.view(batch_size, 32*self.nf*self.s0*self.s0) 106 | out = self.fc(actvn(out)) 107 | 108 | index = Variable(torch.LongTensor(range(out.size(0)))) 109 | if y.is_cuda: 110 | index = index.cuda() 111 | out = out[index, y] 112 | 113 | return out 114 | 115 | 116 | class ResnetBlock(nn.Module): 117 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 118 | super().__init__() 119 | # Attributes 120 | self.is_bias = is_bias 121 | self.learned_shortcut = (fin != fout) 122 | self.fin = fin 123 | self.fout = fout 124 | if fhidden is None: 125 | self.fhidden = min(fin, fout) 126 | else: 127 | self.fhidden = fhidden 128 | 129 | # Submodules 130 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 131 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 132 | if self.learned_shortcut: 133 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 134 | 135 | def forward(self, x): 136 | x_s = self._shortcut(x) 137 | dx = self.conv_0(actvn(x)) 138 | dx = self.conv_1(actvn(dx)) 139 | out = x_s + 0.1*dx 140 | 141 | return out 142 | 143 | def _shortcut(self, x): 144 | if self.learned_shortcut: 145 | x_s = self.conv_s(x) 146 | else: 147 | x_s = x 148 
| return x_s 149 | 150 | 151 | def actvn(x): 152 | out = F.leaky_relu(x, 2e-1) 153 | return out 154 | -------------------------------------------------------------------------------- /gan_training/models/resnet4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 11 | super().__init__() 12 | s0 = self.s0 = size // 64 13 | nf = self.nf = nfilter 14 | self.z_dim = z_dim 15 | 16 | # Submodules 17 | self.embedding = nn.Embedding(nlabels, embed_size) 18 | self.fc = nn.Linear(z_dim + embed_size, 16*nf*s0*s0) 19 | 20 | self.resnet_0_0 = ResnetBlock(16*nf, 16*nf) 21 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf) 22 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf) 23 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf) 24 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf) 25 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf) 26 | self.resnet_6_0 = ResnetBlock(1*nf, 1*nf) 27 | self.conv_img = nn.Conv2d(nf, 3, 7, padding=3) 28 | 29 | 30 | def forward(self, z, y): 31 | assert(z.size(0) == y.size(0)) 32 | batch_size = z.size(0) 33 | 34 | yembed = self.embedding(y) 35 | yz = torch.cat([z, yembed], dim=1) 36 | out = self.fc(yz) 37 | out = out.view(batch_size, 16*self.nf, self.s0, self.s0) 38 | 39 | out = self.resnet_0_0(out) 40 | 41 | out = F.interpolate(out, scale_factor=2) 42 | out = self.resnet_1_0(out) 43 | 44 | out = F.interpolate(out, scale_factor=2) 45 | out = self.resnet_2_0(out) 46 | 47 | out = F.interpolate(out, scale_factor=2) 48 | out = self.resnet_3_0(out) 49 | 50 | out = F.interpolate(out, scale_factor=2) 51 | out = self.resnet_4_0(out) 52 | 53 | out = F.interpolate(out, scale_factor=2) 54 | out = self.resnet_5_0(out) 55 | 56 | out = F.interpolate(out, scale_factor=2) 57 | out = self.resnet_6_0(out) 58 | out = self.conv_img(actvn(out)) 59 | out = torch.tanh(out) 60 | 61 | return out 62 | 63 | 64 | class Discriminator(nn.Module): 65 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 66 | super().__init__() 67 | self.embed_size = embed_size 68 | s0 = self.s0 = size // 64 69 | nf = self.nf = nfilter 70 | 71 | # Submodules 72 | self.conv_img = nn.Conv2d(3, 1*nf, 7, padding=3) 73 | 74 | self.resnet_0_0 = ResnetBlock(1*nf, 1*nf) 75 | self.resnet_1_0 = ResnetBlock(1*nf, 2*nf) 76 | self.resnet_2_0 = ResnetBlock(2*nf, 4*nf) 77 | self.resnet_3_0 = ResnetBlock(4*nf, 8*nf) 78 | self.resnet_4_0 = ResnetBlock(8*nf, 16*nf) 79 | self.resnet_5_0 = ResnetBlock(16*nf, 16*nf) 80 | self.resnet_6_0 = ResnetBlock(16*nf, 16*nf) 81 | 82 | self.fc = nn.Linear(16*nf*s0*s0, nlabels) 83 | 84 | def forward(self, x, y): 85 | assert(x.size(0) == y.size(0)) 86 | batch_size = x.size(0) 87 | 88 | out = self.conv_img(x) 89 | out = self.resnet_0_0(out) 90 | 91 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 92 | out = self.resnet_1_0(out) 93 | 94 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 95 | out = self.resnet_2_0(out) 96 | 97 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 98 | out = self.resnet_3_0(out) 99 | 100 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 101 | out = self.resnet_4_0(out) 102 | 103 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 104 | out = self.resnet_5_0(out) 105 | 106 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 107 | out = 
self.resnet_6_0(out) 108 | 109 | out = out.view(batch_size, 16*self.nf*self.s0*self.s0) 110 | out = self.fc(actvn(out)) 111 | 112 | index = Variable(torch.LongTensor(range(out.size(0)))) 113 | if y.is_cuda: 114 | index = index.cuda() 115 | out = out[index, y] 116 | 117 | return out 118 | 119 | 120 | class ResnetBlock(nn.Module): 121 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 122 | super().__init__() 123 | # Attributes 124 | self.is_bias = is_bias 125 | self.learned_shortcut = (fin != fout) 126 | self.fin = fin 127 | self.fout = fout 128 | if fhidden is None: 129 | self.fhidden = min(fin, fout) 130 | else: 131 | self.fhidden = fhidden 132 | 133 | # Submodules 134 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 135 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 136 | if self.learned_shortcut: 137 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 138 | 139 | def forward(self, x): 140 | x_s = self._shortcut(x) 141 | dx = self.conv_0(actvn(x)) 142 | dx = self.conv_1(actvn(dx)) 143 | out = x_s + 0.1*dx 144 | 145 | return out 146 | 147 | def _shortcut(self, x): 148 | if self.learned_shortcut: 149 | x_s = self.conv_s(x) 150 | else: 151 | x_s = x 152 | return x_s 153 | 154 | 155 | def actvn(x): 156 | out = F.leaky_relu(x, 2e-1) 157 | return out 158 | -------------------------------------------------------------------------------- /gan_training/ops.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import Parameter 4 | 5 | 6 | class SpectralNorm(nn.Module): 7 | def __init__(self, module, name='weight', power_iterations=1): 8 | super(SpectralNorm, self).__init__() 9 | self.module = module 10 | self.name = name 11 | self.power_iterations = power_iterations 12 | if not self._made_params(): 13 | self._make_params() 14 | 15 | def _update_u_v(self): 16 | u = getattr(self.module, self.name + "_u") 17 | v = getattr(self.module, self.name + "_v") 18 | w = getattr(self.module, self.name + "_bar") 19 | 20 | height = w.data.shape[0] 21 | for _ in range(self.power_iterations): 22 | v.data = l2normalize( 23 | torch.mv(torch.t(w.view(height, -1).data), u.data)) 24 | u.data = l2normalize( 25 | torch.mv(w.view(height, -1).data, v.data)) 26 | 27 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 28 | sigma = u.dot(w.view(height, -1).mv(v)) 29 | setattr(self.module, self.name, w / sigma.expand_as(w)) 30 | 31 | def _made_params(self): 32 | made_params = ( 33 | hasattr(self.module, self.name + "_u") 34 | and hasattr(self.module, self.name + "_v") 35 | and hasattr(self.module, self.name + "_bar") 36 | ) 37 | return made_params 38 | 39 | def _make_params(self): 40 | w = getattr(self.module, self.name) 41 | 42 | height = w.data.shape[0] 43 | width = w.view(height, -1).data.shape[1] 44 | 45 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 46 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 47 | u.data = l2normalize(u.data) 48 | v.data = l2normalize(v.data) 49 | w_bar = Parameter(w.data) 50 | 51 | del self.module._parameters[self.name] 52 | 53 | self.module.register_parameter(self.name + "_u", u) 54 | self.module.register_parameter(self.name + "_v", v) 55 | self.module.register_parameter(self.name + "_bar", w_bar) 56 | 57 | def forward(self, *args): 58 | self._update_u_v() 59 | return self.module.forward(*args) 60 | 61 | 62 | def l2normalize(v, 
eps=1e-12): 63 | return v / (v.norm() + eps) 64 | 65 | 66 | class CBatchNorm(nn.Module): 67 | def __init__(self, nfilter, nlabels): 68 | super().__init__() 69 | # Attributes 70 | self.nlabels = nlabels 71 | self.nfilter = nfilter 72 | # Submodules 73 | self.alpha_embedding = nn.Embedding(nlabels, nfilter) 74 | self.beta_embedding = nn.Embedding(nlabels, nfilter) 75 | self.bn = nn.BatchNorm2d(nfilter, affine=False) 76 | # Initialize 77 | nn.init.constant_(self.alpha_embedding.weight, 1.) 78 | nn.init.constant_(self.beta_embedding.weight, 0.) 79 | 80 | def forward(self, x, y): 81 | dim = len(x.size()) 82 | batch_size = x.size(0) 83 | assert(dim >= 2) 84 | assert(x.size(1) == self.nfilter) 85 | 86 | s = [batch_size, self.nfilter] + [1] * (dim - 2) 87 | alpha = self.alpha_embedding(y) 88 | alpha = alpha.view(s) 89 | beta = self.beta_embedding(y) 90 | beta = beta.view(s) 91 | 92 | out = self.bn(x) 93 | out = alpha * out + beta 94 | 95 | return out 96 | 97 | 98 | class CInstanceNorm(nn.Module): 99 | def __init__(self, nfilter, nlabels): 100 | super().__init__() 101 | # Attributes 102 | self.nlabels = nlabels 103 | self.nfilter = nfilter 104 | # Submodules 105 | self.alpha_embedding = nn.Embedding(nlabels, nfilter) 106 | self.beta_embedding = nn.Embedding(nlabels, nfilter) 107 | self.bn = nn.InstanceNorm2d(nfilter, affine=False) 108 | # Initialize 109 | nn.init.uniform_(self.alpha_embedding.weight, -1., 1.) 110 | nn.init.constant_(self.beta_embedding.weight, 0.) 111 | 112 | def forward(self, x, y): 113 | dim = len(x.size()) 114 | batch_size = x.size(0) 115 | assert(dim >= 2) 116 | assert(x.size(1) == self.nfilter) 117 | 118 | s = [batch_size, self.nfilter] + [1] * (dim - 2) 119 | alpha = self.alpha_embedding(y) 120 | alpha = alpha.view(s) 121 | beta = self.beta_embedding(y) 122 | beta = beta.view(s) 123 | 124 | out = self.bn(x) 125 | out = alpha * out + beta 126 | 127 | return out 128 | -------------------------------------------------------------------------------- /gan_training/train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.utils.data 5 | import torch.utils.data.distributed 6 | from torch import autograd 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, generator, discriminator, g_optimizer, d_optimizer, 11 | gan_type, reg_type, reg_param): 12 | self.generator = generator 13 | self.discriminator = discriminator 14 | self.g_optimizer = g_optimizer 15 | self.d_optimizer = d_optimizer 16 | 17 | self.gan_type = gan_type 18 | self.reg_type = reg_type 19 | self.reg_param = reg_param 20 | 21 | def generator_trainstep(self, y, z): 22 | assert(y.size(0) == z.size(0)) 23 | toggle_grad(self.generator, True) 24 | toggle_grad(self.discriminator, False) 25 | self.generator.train() 26 | self.discriminator.train() 27 | self.g_optimizer.zero_grad() 28 | 29 | x_fake = self.generator(z, y) 30 | d_fake = self.discriminator(x_fake, y) 31 | gloss = self.compute_loss(d_fake, 1) 32 | gloss.backward() 33 | 34 | self.g_optimizer.step() 35 | 36 | return gloss.item() 37 | 38 | def discriminator_trainstep(self, x_real, y, z): 39 | toggle_grad(self.generator, False) 40 | toggle_grad(self.discriminator, True) 41 | self.generator.train() 42 | self.discriminator.train() 43 | self.d_optimizer.zero_grad() 44 | 45 | # On real data 46 | x_real.requires_grad_() 47 | 48 | d_real = self.discriminator(x_real, y) 49 | dloss_real = self.compute_loss(d_real, 1) 50 | 51 | if self.reg_type ==
'real' or self.reg_type == 'real_fake': 52 | dloss_real.backward(retain_graph=True) 53 | reg = self.reg_param * compute_grad2(d_real, x_real).mean() 54 | reg.backward() 55 | else: 56 | dloss_real.backward() 57 | 58 | # On fake data 59 | with torch.no_grad(): 60 | x_fake = self.generator(z, y) 61 | 62 | x_fake.requires_grad_() 63 | d_fake = self.discriminator(x_fake, y) 64 | dloss_fake = self.compute_loss(d_fake, 0) 65 | 66 | if self.reg_type == 'fake' or self.reg_type == 'real_fake': 67 | dloss_fake.backward(retain_graph=True) 68 | reg = self.reg_param * compute_grad2(d_fake, x_fake).mean() 69 | reg.backward() 70 | else: 71 | dloss_fake.backward() 72 | 73 | if self.reg_type == 'wgangp': 74 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y) 75 | reg.backward() 76 | elif self.reg_type == 'wgangp0': 77 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y, center=0.) 78 | reg.backward() 79 | 80 | self.d_optimizer.step() 81 | 82 | toggle_grad(self.discriminator, False) 83 | 84 | # Output 85 | dloss = (dloss_real + dloss_fake) 86 | 87 | if self.reg_type == 'none': 88 | reg = torch.tensor(0.) 89 | 90 | return dloss.item(), reg.item() 91 | 92 | def compute_loss(self, d_out, target): 93 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 94 | 95 | if self.gan_type == 'standard': 96 | loss = F.binary_cross_entropy_with_logits(d_out, targets) 97 | elif self.gan_type == 'wgan': 98 | loss = (2*target - 1) * d_out.mean() 99 | else: 100 | raise NotImplementedError 101 | 102 | return loss 103 | 104 | def wgan_gp_reg(self, x_real, x_fake, y, center=1.): 105 | batch_size = y.size(0) 106 | eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1) 107 | x_interp = (1 - eps) * x_real + eps * x_fake 108 | x_interp = x_interp.detach() 109 | x_interp.requires_grad_() 110 | d_out = self.discriminator(x_interp, y) 111 | 112 | reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean() 113 | 114 | return reg 115 | 116 | 117 | # Utility functions 118 | def toggle_grad(model, requires_grad): 119 | for p in model.parameters(): 120 | p.requires_grad_(requires_grad) 121 | 122 | 123 | def compute_grad2(d_out, x_in): 124 | batch_size = x_in.size(0) 125 | grad_dout = autograd.grad( 126 | outputs=d_out.sum(), inputs=x_in, 127 | create_graph=True, retain_graph=True, only_inputs=True 128 | )[0] 129 | grad_dout2 = grad_dout.pow(2) 130 | assert(grad_dout2.size() == x_in.size()) 131 | reg = grad_dout2.view(batch_size, -1).sum(1) 132 | return reg 133 | 134 | 135 | def update_average(model_tgt, model_src, beta): 136 | toggle_grad(model_src, False) 137 | toggle_grad(model_tgt, False) 138 | 139 | param_dict_src = dict(model_src.named_parameters()) 140 | 141 | for p_name, p_tgt in model_tgt.named_parameters(): 142 | p_src = param_dict_src[p_name] 143 | assert(p_src is not p_tgt) 144 | p_tgt.copy_(beta*p_tgt + (1. 
- beta)*p_src) 145 | -------------------------------------------------------------------------------- /gan_training/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.utils.data 4 | import torch.utils.data.distributed 5 | import torchvision 6 | 7 | 8 | def save_images(imgs, outfile, nrow=8): 9 | imgs = imgs / 2 + 0.5 # unnormalize 10 | torchvision.utils.save_image(imgs, outfile, nrow=nrow) 11 | 12 | 13 | def get_nsamples(data_loader, N): 14 | x = [] 15 | y = [] 16 | n = 0 17 | data_iter = iter(data_loader) # create the iterator once, not once per batch 18 | while n < N: 19 | try: 20 | x_next, y_next = next(data_iter) 21 | except StopIteration: # restart when the loader is exhausted 22 | data_iter = iter(data_loader) 23 | x_next, y_next = next(data_iter) 24 | x.append(x_next) 25 | y.append(y_next) 26 | n += x_next.size(0) 27 | x = torch.cat(x, dim=0)[:N] 28 | y = torch.cat(y, dim=0)[:N] 29 | return x, y 30 | 31 | 32 | def update_average(model_tgt, model_src, beta): 33 | param_dict_src = dict(model_src.named_parameters()) 34 | 35 | with torch.no_grad(): # in-place parameter copies must not be tracked by autograd 36 | for p_name, p_tgt in model_tgt.named_parameters(): 37 | p_src = param_dict_src[p_name] 38 | assert(p_src is not p_tgt) 39 | p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src) 40 | -------------------------------------------------------------------------------- /interpolate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import copy 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.checkpoints import CheckpointIO 10 | from gan_training.distributions import get_ydist, get_zdist, interpolate_sphere 11 | from gan_training.config import ( 12 | load_config, build_models 13 | ) 14 | 15 | # Arguments 16 | parser = argparse.ArgumentParser( 17 | description='Create interpolations for a trained GAN.' 18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | 22 | args = parser.parse_args() 23 | 24 | config = load_config(args.config, 'configs/default.yaml') 25 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 26 | 27 | # Shorthands 28 | nlabels = config['data']['nlabels'] 29 | out_dir = config['training']['out_dir'] 30 | batch_size = config['test']['batch_size'] 31 | sample_size = config['test']['sample_size'] 32 | sample_nrow = config['test']['sample_nrow'] 33 | checkpoint_dir = path.join(out_dir, 'chkpts') 34 | interp_dir = path.join(out_dir, 'test', 'interp') 35 | 36 | # Create missing directories 37 | if not path.exists(interp_dir): 38 | os.makedirs(interp_dir) 39 | 40 | # Checkpointing 41 | checkpoint_io = CheckpointIO( 42 | checkpoint_dir=checkpoint_dir 43 | ) 44 | 45 | # Get model file 46 | model_file = config['test']['model_file'] 47 | 48 | # Models 49 | device = torch.device("cuda:0" if is_cuda else "cpu") 50 | 51 | generator, discriminator = build_models(config) 52 | print(generator) 53 | print(discriminator) 54 | 55 | # Put models on gpu if needed 56 | generator = generator.to(device) 57 | discriminator = discriminator.to(device) 58 | 59 | # Use multiple GPUs if possible 60 | generator = nn.DataParallel(generator) 61 | discriminator = nn.DataParallel(discriminator) 62 | 63 | # Register modules to checkpoint 64 | checkpoint_io.register_modules( 65 | generator=generator, 66 | discriminator=discriminator, 67 | ) 68 | 69 | # Test generator 70 | if config['test']['use_model_average']: 71 | generator_test = copy.deepcopy(generator) 72 | checkpoint_io.register_modules(generator_test=generator_test) 73 | else: 74 | generator_test = generator
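# interpolate_sphere (imported above from gan_training.distributions) moves
# between two latent codes along a great circle rather than a straight line,
# which better preserves the norm of Gaussian latent samples. Conceptually,
# a spherical interpolation (slerp) of this kind computes
#     omega = arccos(<z1/||z1||, z2/||z2||>)
#     z(t) = (sin((1 - t)*omega)*z1 + sin(t*omega)*z2) / sin(omega)
# for each pair of latent vectors; see gan_training/distributions.py for the
# exact implementation used here.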
75 | 76 | # Distributions 77 | ydist = get_ydist(nlabels, device=device) 78 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 79 | device=device) 80 | 81 | 82 | # Load checkpoint if it exists 83 | load_dict = checkpoint_io.load(model_file) 84 | it = load_dict.get('it', -1) 85 | epoch_idx = load_dict.get('epoch_idx', -1) 86 | 87 | # Interpolations 88 | print('Creating interpolations...') 89 | nsteps = config['interpolations']['nzs'] 90 | nsubsteps = config['interpolations']['nsubsteps'] 91 | 92 | y = ydist.sample((sample_size,)) 93 | zs = [zdist.sample((sample_size,)) for i in range(nsteps)] 94 | ts = np.linspace(0, 1, nsubsteps) 95 | 96 | it = 0 97 | for z1, z2 in zip(zs, zs[1:] + [zs[0]]): 98 | for t in ts: 99 | z = interpolate_sphere(z1, z2, float(t)) 100 | with torch.no_grad(): 101 | x = generator_test(z, y) 102 | utils.save_images(x, path.join(interp_dir, '%04d.png' % it), 103 | nrow=sample_nrow) 104 | it += 1 105 | print('%d/%d done!' % (it, nsteps * nsubsteps)) 106 | -------------------------------------------------------------------------------- /interpolate_class.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import copy 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.checkpoints import CheckpointIO 10 | from gan_training.distributions import get_ydist, get_zdist, interpolate_sphere 11 | from gan_training.config import ( 12 | load_config, build_models 13 | ) 14 | 15 | # Arguments 16 | parser = argparse.ArgumentParser( 17 | description='Create interpolations for a trained GAN.' 18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | 22 | args = parser.parse_args() 23 | 24 | config = load_config(args.config, 'configs/default.yaml') 25 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 26 | 27 | # Shorthands 28 | nlabels = config['data']['nlabels'] 29 | out_dir = config['training']['out_dir'] 30 | batch_size = config['test']['batch_size'] 31 | sample_size = config['test']['sample_size'] 32 | sample_nrow = config['test']['sample_nrow'] 33 | checkpoint_dir = path.join(out_dir, 'chkpts') 34 | interp_dir = path.join(out_dir, 'test', 'interp_class') 35 | 36 | # Create missing directories 37 | if not path.exists(interp_dir): 38 | os.makedirs(interp_dir) 39 | 40 | # Checkpointing 41 | checkpoint_io = CheckpointIO( 42 | checkpoint_dir=checkpoint_dir 43 | ) 44 | 45 | # Get model file 46 | model_file = config['test']['model_file'] 47 | 48 | # Models 49 | device = torch.device("cuda:0" if is_cuda else "cpu") 50 | 51 | generator, discriminator = build_models(config) 52 | print(generator) 53 | print(discriminator) 54 | 55 | # Put models on gpu if needed 56 | generator = generator.to(device) 57 | discriminator = discriminator.to(device) 58 | 59 | # Use multiple GPUs if possible 60 | generator = nn.DataParallel(generator) 61 | discriminator = nn.DataParallel(discriminator) 62 | 63 | # Register modules to checkpoint 64 | checkpoint_io.register_modules( 65 | generator=generator, 66 | discriminator=discriminator, 67 | ) 68 | 69 | # Test generator 70 | if config['test']['use_model_average']: 71 | generator_test = copy.deepcopy(generator) 72 | checkpoint_io.register_modules(generator_test=generator_test) 73 | else: 74 | generator_test = generator 75 | 76 | # Distributions 77 | ydist =
get_ydist(nlabels, device=device) 78 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 79 | device=device) 80 | 81 | 82 | # Load checkpoint if it exists 83 | load_dict = checkpoint_io.load(model_file) 84 | it = load_dict.get('it', -1) 85 | epoch_idx = load_dict.get('epoch_idx', -1) 86 | 87 | # Interpolations 88 | print('Creating interpolations...') 89 | nsubsteps = config['interpolations']['nsubsteps'] 90 | ys = config['interpolations']['ys'] 91 | 92 | nsteps = len(ys) 93 | z = zdist.sample((sample_size,)) 94 | ts = np.linspace(0, 1, nsubsteps) 95 | 96 | it = 0 97 | for y1, y2 in zip(ys, ys[1:] + [ys[0]]): 98 | for t in ts: 99 | y1_pt = torch.full((sample_size,), y1, dtype=torch.int64, device=device) 100 | y2_pt = torch.full((sample_size,), y2, dtype=torch.int64, device=device) 101 | y1_embed = generator_test.module.embedding(y1_pt) 102 | y2_embed = generator_test.module.embedding(y2_pt) 103 | t = float(t) 104 | y_embed = (1 - t) * y1_embed + t * y2_embed 105 | with torch.no_grad(): 106 | x = generator_test(z, y_embed) 107 | utils.save_images(x, path.join(interp_dir, '%04d.png' % it), 108 | nrow=sample_nrow) 109 | it += 1 110 | print('%d/%d done!' % (it, nsteps * nsubsteps)) 111 | -------------------------------------------------------------------------------- /notebooks/create_video.sh: -------------------------------------------------------------------------------- 1 | EXEC=ffmpeg 2 | declare -a FOLDERS=( 3 | "simgd" 4 | "altgd1" 5 | "altgd5" 6 | ) 7 | declare -a SUBFOLDERS=( 8 | "gan" 9 | "gan_consensus" 10 | "gan_gradpen" 11 | "gan_gradpen_critical" 12 | "gan_instnoise" 13 | "nsgan" 14 | "nsgan_gradpen" 15 | "wgan" 16 | "wgan_gp" 17 | ) 18 | OPTIONS="-y" 19 | 20 | cd ./out 21 | for FOLDER in ${FOLDERS[@]}; do 22 | for SUBFOLDER in ${SUBFOLDERS[@]}; do 23 | INPUT="$FOLDER/animations/$SUBFOLDER/%06d.png" 24 | OUTPUT="$FOLDER/animations/$SUBFOLDER.mp4" 25 | $EXEC -framerate 30 -i $INPUT $OPTIONS $OUTPUT 26 | echo $FOLDER 27 | echo $SUBFOLDER 28 | done 29 | 30 | done 31 | -------------------------------------------------------------------------------- /notebooks/diracgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/notebooks/diracgan/__init__.py -------------------------------------------------------------------------------- /notebooks/diracgan/gans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from diracgan.util import sigmoid, clip 3 | 4 | 5 | class VectorField(object): 6 | def __call__(self, theta, psi): 7 | theta_isfloat = isinstance(theta, float) 8 | psi_isfloat = isinstance(psi, float) 9 | if theta_isfloat: 10 | theta = np.array([theta]) 11 | if psi_isfloat: 12 | psi = np.array([psi]) 13 | 14 | v1, v2 = self._get_vector(theta, psi) 15 | 16 | if theta_isfloat: 17 | v1 = v1[0] 18 | if psi_isfloat: 19 | v2 = v2[0] 20 | 21 | return v1, v2 22 | 23 | def postprocess(self, theta, psi): 24 | theta_isfloat = isinstance(theta, float) 25 | psi_isfloat = isinstance(psi, float) 26 | if theta_isfloat: 27 | theta = np.array([theta]) 28 | if psi_isfloat: 29 | psi = np.array([psi]) 30 | theta, psi = self._postprocess(theta, psi) 31 | if theta_isfloat: 32 | theta = theta[0] 33 | if psi_isfloat: 34 | psi = psi[0] 35 | 36 | return theta, psi 37 | 38 | def step_sizes(self, h): 39 | return h, h 40 | 41 | def _get_vector(self, theta, psi): 42 | raise NotImplementedError 43 | 44 |
def _postprocess(self, theta, psi): 45 | return theta, psi 46 | 47 | 48 | # GANs 49 | def fp(x): 50 | return sigmoid(-x) 51 | 52 | 53 | def fp2(x): 54 | return -sigmoid(-x) * sigmoid(x) 55 | 56 | 57 | class GAN(VectorField): 58 | def _get_vector(self, theta, psi): 59 | v1 = -psi * fp(psi*theta) 60 | v2 = theta * fp(psi*theta) 61 | return v1, v2 62 | 63 | 64 | class NSGAN(VectorField): 65 | def _get_vector(self, theta, psi): 66 | v1 = -psi * fp(-psi*theta) 67 | v2 = theta * fp(psi*theta) 68 | return v1, v2 69 | 70 | 71 | class WGAN(VectorField): 72 | def __init__(self, clip=0.3): 73 | super().__init__() 74 | self.clip = clip 75 | 76 | def _get_vector(self, theta, psi): 77 | v1 = -psi 78 | v2 = theta 79 | 80 | return v1, v2 81 | 82 | def _postprocess(self, theta, psi): 83 | psi = clip(psi, self.clip) 84 | return theta, psi 85 | 86 | 87 | class WGAN_GP(VectorField): 88 | def __init__(self, reg=1., target=0.3): 89 | super().__init__() 90 | self.reg = reg 91 | self.target = target 92 | 93 | def _get_vector(self, theta, psi): 94 | v1 = -psi 95 | v2 = theta - self.reg * (np.abs(psi) - self.target) * np.sign(psi) 96 | return v1, v2 97 | 98 | 99 | class GAN_InstNoise(VectorField): 100 | def __init__(self, std=1): 101 | self.std = std 102 | 103 | def _get_vector(self, theta, psi): 104 | theta_eps = ( 105 | theta + self.std*np.random.randn(*([1000] + list(theta.shape))) 106 | ) 107 | x_eps = ( 108 | self.std * np.random.randn(*([1000] + list(theta.shape))) 109 | ) 110 | v1 = -psi * fp(psi*theta_eps) 111 | v2 = theta_eps * fp(psi*theta_eps) - x_eps * fp(-x_eps * psi) 112 | v1 = v1.mean(axis=0) 113 | v2 = v2.mean(axis=0) 114 | return v1, v2 115 | 116 | 117 | class GAN_GradPenalty(VectorField): 118 | def __init__(self, reg=0.3): 119 | self.reg = reg 120 | 121 | def _get_vector(self, theta, psi): 122 | v1 = -psi * fp(psi*theta) 123 | v2 = +theta * fp(psi*theta) - self.reg * psi 124 | return v1, v2 125 | 126 | 127 | class NSGAN_GradPenalty(VectorField): 128 | def __init__(self, reg=0.3): 129 | self.reg = reg 130 | 131 | def _get_vector(self, theta, psi): 132 | v1 = -psi * fp(-psi*theta) 133 | v2 = theta * fp(psi*theta) - self.reg * psi 134 | return v1, v2 135 | 136 | 137 | class GAN_Consensus(VectorField): 138 | def __init__(self, reg=0.3): 139 | self.reg = reg 140 | 141 | def _get_vector(self, theta, psi): 142 | v1 = -psi * fp(psi*theta) 143 | v2 = +theta * fp(psi*theta) 144 | 145 | # L 0.5*(psi**2 + theta**2)*f(psi*theta)**2 146 | v1reg = ( 147 | theta * fp(psi*theta)**2 148 | + 0.5*psi * (psi**2 + theta**2) * fp(psi*theta)*fp2(psi*theta) 149 | ) 150 | v2reg = ( 151 | psi * fp(psi*theta)**2 152 | + 0.5*theta * (psi**2 + theta**2) * fp(psi*theta)*fp2(psi*theta) 153 | ) 154 | v1 -= self.reg * v1reg 155 | v2 -= self.reg * v2reg 156 | 157 | return v1, v2 158 | 159 | -------------------------------------------------------------------------------- /notebooks/diracgan/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import matplotlib.patches as patches 4 | import os 5 | from diracgan.gans import WGAN 6 | from diracgan.subplots import vector_field_plot 7 | from tqdm import tqdm 8 | 9 | 10 | def plot_vector(vecfn, theta, psi, outfile, trajectory=None, marker='b^'): 11 | fig, ax = plt.subplots(1, 1) 12 | theta, psi = np.meshgrid(theta, psi) 13 | v1, v2 = vecfn(theta, psi) 14 | if isinstance(vecfn, WGAN): 15 | clip_y = vecfn.clip 16 | else: 17 | clip_y = None 18 | vector_field_plot(theta, psi, v1, v2, 
trajectory, clip_y=clip_y, marker=marker) 19 | plt.savefig(outfile, bbox_inches='tight') 20 | plt.show() 21 | 22 | 23 | def simulate_trajectories(vecfn, theta, psi, trajectory, outfolder, maxframes=300): 24 | if not os.path.exists(outfolder): 25 | os.makedirs(outfolder) 26 | theta, psi = np.meshgrid(theta, psi) 27 | 28 | N = min(len(trajectory[0]), maxframes) 29 | 30 | v1, v2 = vecfn(theta, psi) 31 | if isinstance(vecfn, WGAN): 32 | clip_y = vecfn.clip 33 | else: 34 | clip_y = None 35 | 36 | for i in tqdm(range(1, N)): 37 | fig, (ax1, ax2) = plt.subplots(1, 2, 38 | subplot_kw=dict(adjustable='box', aspect=0.7)) 39 | 40 | plt.sca(ax1) 41 | trajectory_i = [trajectory[0][:i], trajectory[1][:i]] 42 | vector_field_plot(theta, psi, v1, v2, trajectory_i, clip_y=clip_y, marker='b-') 43 | plt.plot(trajectory_i[0][-1], trajectory_i[1][-1], 'bo') 44 | 45 | plt.sca(ax2) 46 | ax2.set_axisbelow(True) 47 | plt.grid() 48 | 49 | x = np.linspace(np.min(theta), np.max(theta)) 50 | y = x*trajectory[1][i] 51 | plt.plot(x, y, 'C1') 52 | 53 | ax2.add_patch(patches.Rectangle( 54 | (-0.05, 0), .1, 2.5, facecolor='C2' 55 | )) 56 | 57 | ax2.add_patch(patches.Rectangle( 58 | (trajectory[0][i]-0.05, 0), .1, 2.5, facecolor='C0' 59 | )) 60 | 61 | plt.xlim(np.min(theta), np.max(theta)) 62 | plt.ylim(-1, 3.) 63 | plt.xlabel(r'$\theta$') 64 | plt.xticks(np.linspace(np.min(theta), np.max(theta), 5)) 65 | ax2.set_yticklabels([]) 66 | 67 | plt.savefig(os.path.join(outfolder, '%06d.png' % i), dpi=200, bbox_inches='tight') 68 | plt.close() 69 | -------------------------------------------------------------------------------- /notebooks/diracgan/simulate.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Simulate 4 | def trajectory_simgd(vec_fn, theta0, psi0, 5 | nsteps=50, hs_g=0.1, hs_d=0.1): 6 | theta, psi = vec_fn.postprocess(theta0, psi0) 7 | thetas, psis = [theta], [psi] 8 | 9 | if isinstance(hs_g, float): 10 | hs_g = [hs_g] * nsteps 11 | if isinstance(hs_d, float): 12 | hs_d = [hs_d] * nsteps 13 | assert(len(hs_g) == nsteps) 14 | assert(len(hs_d) == nsteps) 15 | 16 | for h_g, h_d in zip(hs_g, hs_d): 17 | v1, v2 = vec_fn(theta, psi) 18 | theta += h_g * v1 19 | psi += h_d * v2 20 | theta, psi = vec_fn.postprocess(theta, psi) 21 | thetas.append(theta) 22 | psis.append(psi) 23 | 24 | return thetas, psis 25 | 26 | 27 | def trajectory_altgd(vec_fn, theta0, psi0, 28 | nsteps=50, hs_g=0.1, hs_d=0.1, gsteps=1, dsteps=1): 29 | theta, psi = vec_fn.postprocess(theta0, psi0) 30 | thetas, psis = [theta], [psi] 31 | 32 | if isinstance(hs_g, float): 33 | hs_g = [hs_g] * nsteps 34 | if isinstance(hs_d, float): 35 | hs_d = [hs_d] * nsteps 36 | assert(len(hs_g) == nsteps) 37 | assert(len(hs_d) == nsteps) 38 | 39 | for h_g, h_d in zip(hs_g, hs_d): 40 | for it in range(gsteps): 41 | v1, v2 = vec_fn(theta, psi) 42 | theta += h_g * v1 43 | theta, psi = vec_fn.postprocess(theta, psi) 44 | 45 | for it in range(dsteps): 46 | v1, v2 = vec_fn(theta, psi) 47 | psi += h_d * v2 48 | theta, psi = vec_fn.postprocess(theta, psi) 49 | thetas.append(theta) 50 | psis.append(psi) 51 | 52 | return thetas, psis 53 | -------------------------------------------------------------------------------- /notebooks/diracgan/subplots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | 5 | def arrow_plot(x, y, color='C1'): 6 | plt.quiver(x[:-1], y[:-1], x[1:]-x[:-1], y[1:]-y[:-1], 7 | color=color, scale_units='xy', 
angles='xy', scale=1) 8 | 9 | 10 | def vector_field_plot(theta, psi, v1, v2, trajectory=None, clip_y=None, marker='b^'): 11 | plt.quiver(theta, psi, v1, v2) 12 | if clip_y is not None: 13 | plt.axhspan(np.min(psi), -clip_y, facecolor='0.2', alpha=0.5) 14 | plt.plot([np.min(theta), np.max(theta)], [-clip_y, -clip_y], 'k-') 15 | plt.axhspan(clip_y, np.max(psi), facecolor='0.2', alpha=0.5) 16 | plt.plot([np.min(theta), np.max(theta)], [clip_y, clip_y], 'k-') 17 | 18 | if trajectory is not None: 19 | thetas, psis = trajectory # trajectory is (thetas, psis); theta is plotted on the x-axis 20 | plt.plot(thetas, psis, marker, markerfacecolor='None') 21 | plt.plot(thetas[0], psis[0], 'ro') 22 | 23 | plt.xlim(np.min(theta), np.max(theta)) 24 | plt.ylim(np.min(psi), np.max(psi)) 25 | plt.xlabel(r'$\theta$') 26 | plt.ylabel(r'$\psi$') 27 | plt.xticks(np.linspace(np.min(theta), np.max(theta), 5)) 28 | plt.yticks(np.linspace(np.min(psi), np.max(psi), 5)) 29 | -------------------------------------------------------------------------------- /notebooks/diracgan/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def sigmoid(x): 4 | m = np.minimum(0, x) 5 | return np.exp(m)/(np.exp(m) + np.exp(-x + m)) 6 | 7 | 8 | def clip(x, clipval=0.3): 9 | x = np.clip(x, -clipval, clipval) 10 | return x 11 | -------------------------------------------------------------------------------- /results/celebA-HQ.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/results/celebA-HQ.jpg -------------------------------------------------------------------------------- /results/imagenet_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/results/imagenet_00.jpg -------------------------------------------------------------------------------- /results/imagenet_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/results/imagenet_01.jpg -------------------------------------------------------------------------------- /results/imagenet_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/results/imagenet_02.jpg -------------------------------------------------------------------------------- /results/imagenet_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/results/imagenet_03.jpg -------------------------------------------------------------------------------- /results/imagenet_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LMescheder/GAN_stability/c1f64c9efeac371453065e5ce71860f4c2b97357/results/imagenet_04.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import copy 5 | from tqdm import tqdm 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from
gan_training.checkpoints import CheckpointIO 10 | from gan_training.distributions import get_ydist, get_zdist 11 | from gan_training.eval import Evaluator 12 | from gan_training.config import ( 13 | load_config, build_models 14 | ) 15 | 16 | # Arguments 17 | parser = argparse.ArgumentParser( 18 | description='Test a trained GAN and create visualizations.' 19 | ) 20 | parser.add_argument('config', type=str, help='Path to config file.') 21 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 22 | 23 | args = parser.parse_args() 24 | 25 | config = load_config(args.config, 'configs/default.yaml') 26 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 27 | 28 | # Shorthands 29 | nlabels = config['data']['nlabels'] 30 | out_dir = config['training']['out_dir'] 31 | batch_size = config['test']['batch_size'] 32 | sample_size = config['test']['sample_size'] 33 | sample_nrow = config['test']['sample_nrow'] 34 | checkpoint_dir = path.join(out_dir, 'chkpts') 35 | img_dir = path.join(out_dir, 'test', 'img') 36 | img_all_dir = path.join(out_dir, 'test', 'img_all') 37 | 38 | # Create missing directories 39 | if not path.exists(img_dir): 40 | os.makedirs(img_dir) 41 | if not path.exists(img_all_dir): 42 | os.makedirs(img_all_dir) 43 | 44 | # Checkpointing 45 | checkpoint_io = CheckpointIO( 46 | checkpoint_dir=checkpoint_dir 47 | ) 48 | 49 | # Get model file 50 | model_file = config['test']['model_file'] 51 | 52 | # Models 53 | device = torch.device("cuda:0" if is_cuda else "cpu") 54 | 55 | generator, discriminator = build_models(config) 56 | print(generator) 57 | print(discriminator) 58 | 59 | # Put models on gpu if needed 60 | generator = generator.to(device) 61 | discriminator = discriminator.to(device) 62 | 63 | # Use multiple GPUs if possible 64 | generator = nn.DataParallel(generator) 65 | discriminator = nn.DataParallel(discriminator) 66 | 67 | # Register modules to checkpoint 68 | checkpoint_io.register_modules( 69 | generator=generator, 70 | discriminator=discriminator, 71 | ) 72 | 73 | # Test generator 74 | if config['test']['use_model_average']: 75 | generator_test = copy.deepcopy(generator) 76 | checkpoint_io.register_modules(generator_test=generator_test) 77 | else: 78 | generator_test = generator 79 | 80 | # Distributions 81 | ydist = get_ydist(nlabels, device=device) 82 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 83 | device=device) 84 | 85 | # Evaluator 86 | evaluator = Evaluator(generator_test, zdist, ydist, 87 | batch_size=batch_size, device=device) 88 | 89 | # Load checkpoint if it exists 90 | load_dict = checkpoint_io.load(model_file) 91 | it = load_dict.get('it', -1) 92 | epoch_idx = load_dict.get('epoch_idx', -1) 93 | 94 | # Inception score 95 | if config['test']['compute_inception']: 96 | print('Computing inception score...') 97 | inception_mean, inception_std = evaluator.compute_inception_score() 98 | print('Inception score: %.4f +- %.4f' % (inception_mean, inception_std)) 99 | 100 | # Samples 101 | print('Creating samples...') 102 | ztest = zdist.sample((sample_size,)) 103 | x = evaluator.create_samples(ztest) 104 | utils.save_images(x, path.join(img_all_dir, '%08d.png' % it), 105 | nrow=sample_nrow) 106 | if config['test']['conditional_samples']: 107 | for y_inst in tqdm(range(nlabels)): 108 | x = evaluator.create_samples(ztest, y_inst) 109 | utils.save_images(x, path.join(img_dir, '%04d.png' % y_inst), 110 | nrow=sample_nrow) 111 | -------------------------------------------------------------------------------- /train.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import time 5 | import copy 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.train import Trainer, update_average 10 | from gan_training.logger import Logger 11 | from gan_training.checkpoints import CheckpointIO 12 | from gan_training.inputs import get_dataset 13 | from gan_training.distributions import get_ydist, get_zdist 14 | from gan_training.eval import Evaluator 15 | from gan_training.config import ( 16 | load_config, build_models, build_optimizers, build_lr_scheduler, 17 | ) 18 | 19 | # Arguments 20 | parser = argparse.ArgumentParser( 21 | description='Train a GAN with different regularization strategies.' 22 | ) 23 | parser.add_argument('config', type=str, help='Path to config file.') 24 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 25 | 26 | args = parser.parse_args() 27 | 28 | config = load_config(args.config, 'configs/default.yaml') 29 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 30 | 31 | # Shorthands 32 | batch_size = config['training']['batch_size'] 33 | d_steps = config['training']['d_steps'] 34 | restart_every = config['training']['restart_every'] 35 | inception_every = config['training']['inception_every'] 36 | save_every = config['training']['save_every'] 37 | backup_every = config['training']['backup_every'] 38 | sample_nlabels = config['training']['sample_nlabels'] 39 | 40 | out_dir = config['training']['out_dir'] 41 | checkpoint_dir = path.join(out_dir, 'chkpts') 42 | 43 | # Create missing directories 44 | if not path.exists(out_dir): 45 | os.makedirs(out_dir) 46 | if not path.exists(checkpoint_dir): 47 | os.makedirs(checkpoint_dir) 48 | 49 | # Checkpointing 50 | checkpoint_io = CheckpointIO( 51 | checkpoint_dir=checkpoint_dir 52 | ) 53 | 54 | device = torch.device("cuda:0" if is_cuda else "cpu") 55 | 56 | 57 | # Dataset 58 | train_dataset, nlabels = get_dataset( 59 | name=config['data']['type'], 60 | data_dir=config['data']['train_dir'], 61 | size=config['data']['img_size'], 62 | lsun_categories=config['data']['lsun_categories_train'] 63 | ) 64 | train_loader = torch.utils.data.DataLoader( 65 | train_dataset, 66 | batch_size=batch_size, 67 | num_workers=config['training']['nworkers'], 68 | shuffle=True, pin_memory=True, sampler=None, drop_last=True 69 | ) 70 | 71 | # Number of labels 72 | nlabels = min(nlabels, config['data']['nlabels']) 73 | sample_nlabels = min(nlabels, sample_nlabels) 74 | 75 | # Create models 76 | generator, discriminator = build_models(config) 77 | print(generator) 78 | print(discriminator) 79 | 80 | # Put models on gpu if needed 81 | generator = generator.to(device) 82 | discriminator = discriminator.to(device) 83 | 84 | g_optimizer, d_optimizer = build_optimizers( 85 | generator, discriminator, config 86 | ) 87 | 88 | # Use multiple GPUs if possible 89 | generator = nn.DataParallel(generator) 90 | discriminator = nn.DataParallel(discriminator) 91 | 92 | # Register modules to checkpoint 93 | checkpoint_io.register_modules( 94 | generator=generator, 95 | discriminator=discriminator, 96 | g_optimizer=g_optimizer, 97 | d_optimizer=d_optimizer, 98 | ) 99 | 100 | # Get model file 101 | model_file = config['training']['model_file'] 102 | 103 | # Logger 104 | logger = Logger( 105 | log_dir=path.join(out_dir, 'logs'), 106 | img_dir=path.join(out_dir, 'imgs'), 107 | monitoring=config['training']['monitoring'], 108 |
monitoring_dir=path.join(out_dir, 'monitoring') 109 | ) 110 | 111 | # Distributions 112 | ydist = get_ydist(nlabels, device=device) 113 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 114 | device=device) 115 | 116 | # Save for tests 117 | ntest = batch_size 118 | x_real, ytest = utils.get_nsamples(train_loader, ntest) 119 | ytest.clamp_(None, nlabels-1) 120 | ztest = zdist.sample((ntest,)) 121 | utils.save_images(x_real, path.join(out_dir, 'real.png')) 122 | 123 | # Test generator 124 | if config['training']['take_model_average']: 125 | generator_test = copy.deepcopy(generator) 126 | checkpoint_io.register_modules(generator_test=generator_test) 127 | else: 128 | generator_test = generator 129 | 130 | # Evaluator 131 | evaluator = Evaluator(generator_test, zdist, ydist, 132 | batch_size=batch_size, device=device) 133 | 134 | # Train 135 | tstart = t0 = time.time() 136 | 137 | # Load checkpoint if it exists 138 | try: 139 | load_dict = checkpoint_io.load(model_file) 140 | except FileNotFoundError: 141 | it = epoch_idx = -1 142 | else: 143 | it = load_dict.get('it', -1) 144 | epoch_idx = load_dict.get('epoch_idx', -1) 145 | logger.load_stats('stats.p') 146 | 147 | # Reinitialize model average if needed 148 | if (config['training']['take_model_average'] 149 | and config['training']['model_average_reinit']): 150 | update_average(generator_test, generator, 0.) 151 | 152 | # Learning rate annealing 153 | g_scheduler = build_lr_scheduler(g_optimizer, config, last_epoch=it) 154 | d_scheduler = build_lr_scheduler(d_optimizer, config, last_epoch=it) 155 | 156 | # Trainer 157 | trainer = Trainer( 158 | generator, discriminator, g_optimizer, d_optimizer, 159 | gan_type=config['training']['gan_type'], 160 | reg_type=config['training']['reg_type'], 161 | reg_param=config['training']['reg_param'] 162 | ) 163 | 164 | # Training loop 165 | print('Start training...') 166 | while True: 167 | epoch_idx += 1 168 | print('Start epoch %d...'
% epoch_idx) 169 | 170 | for x_real, y in train_loader: 171 | it += 1 172 | g_scheduler.step() 173 | d_scheduler.step() 174 | 175 | d_lr = d_optimizer.param_groups[0]['lr'] 176 | g_lr = g_optimizer.param_groups[0]['lr'] 177 | logger.add('learning_rates', 'discriminator', d_lr, it=it) 178 | logger.add('learning_rates', 'generator', g_lr, it=it) 179 | 180 | x_real, y = x_real.to(device), y.to(device) 181 | y.clamp_(None, nlabels-1) 182 | 183 | # Discriminator updates 184 | z = zdist.sample((batch_size,)) 185 | dloss, reg = trainer.discriminator_trainstep(x_real, y, z) 186 | logger.add('losses', 'discriminator', dloss, it=it) 187 | logger.add('losses', 'regularizer', reg, it=it) 188 | 189 | # Generator updates 190 | if ((it + 1) % d_steps) == 0: 191 | z = zdist.sample((batch_size,)) 192 | gloss = trainer.generator_trainstep(y, z) 193 | logger.add('losses', 'generator', gloss, it=it) 194 | 195 | if config['training']['take_model_average']: 196 | update_average(generator_test, generator, 197 | beta=config['training']['model_average_beta']) 198 | 199 | # Print stats 200 | g_loss_last = logger.get_last('losses', 'generator') 201 | d_loss_last = logger.get_last('losses', 'discriminator') 202 | d_reg_last = logger.get_last('losses', 'regularizer') 203 | print('[epoch %0d, it %4d] g_loss = %.4f, d_loss = %.4f, reg=%.4f' 204 | % (epoch_idx, it, g_loss_last, d_loss_last, d_reg_last)) 205 | 206 | # (i) Sample if necessary 207 | if (it % config['training']['sample_every']) == 0: 208 | print('Creating samples...') 209 | x = evaluator.create_samples(ztest, ytest) 210 | logger.add_imgs(x, 'all', it) 211 | for y_inst in range(sample_nlabels): 212 | x = evaluator.create_samples(ztest, y_inst) 213 | logger.add_imgs(x, '%04d' % y_inst, it) 214 | 215 | # (ii) Compute inception if necessary 216 | if inception_every > 0 and ((it + 1) % inception_every) == 0: 217 | inception_mean, inception_std = evaluator.compute_inception_score() 218 | logger.add('inception_score', 'mean', inception_mean, it=it) 219 | logger.add('inception_score', 'stddev', inception_std, it=it) 220 | 221 | # (iii) Backup if necessary 222 | if ((it + 1) % backup_every) == 0: 223 | print('Saving backup...') 224 | checkpoint_io.save('model_%08d.pt' % it, it=it) 225 | logger.save_stats('stats_%08d.p' % it) 226 | 227 | # (iv) Save checkpoint if necessary 228 | if time.time() - t0 > save_every: 229 | print('Saving checkpoint...') 230 | checkpoint_io.save(model_file, it=it) 231 | logger.save_stats('stats.p') 232 | t0 = time.time() 233 | 234 | if (restart_every > 0 and t0 - tstart > restart_every): 235 | exit(3) # non-zero exit code signals that training should be restarted 236 | --------------------------------------------------------------------------------
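For reference, the `reg_type` options consumed by the `Trainer` above map onto the regularizers studied in the paper: `'real'` and `'fake'` penalize the squared gradient norm of the discriminator at real or fake data points (the R1 and R2 penalties), while `'wgangp'` and `'wgangp0'` are WGAN-GP penalties centered at 1 and 0. A minimal standalone sketch of the real-data penalty, mirroring `compute_grad2` in `gan_training/train.py` (the names `discriminator`, `x_real`, `y`, and `reg_param` here are placeholders for this example, not part of the repository):

```
import torch
from torch import autograd

def grad_penalty_real(discriminator, x_real, y, reg_param=10.):
    # We differentiate w.r.t. the input, so it must be a leaf that requires grad
    x_real = x_real.detach().requires_grad_()
    d_real = discriminator(x_real, y)
    # Per-sample gradient of the discriminator output w.r.t. the input batch
    grad, = autograd.grad(d_real.sum(), x_real, create_graph=True)
    # Squared L2 norm per sample, averaged over the batch
    return reg_param * grad.pow(2).reshape(grad.size(0), -1).sum(1).mean()
```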