├── .gitignore ├── LICENSE ├── README.md ├── biggan_pytorch ├── BigGAN.py ├── BigGANdeep.py ├── LICENSE ├── README.md ├── TFHub │ ├── README.md │ ├── biggan_v1.py │ └── converter.py ├── animal_hash.py ├── calculate_inception_moments.py ├── datasets.py ├── imgs │ ├── D Singular Values.png │ ├── DeepSamples.png │ ├── DogBall.png │ ├── G Singular Values.png │ ├── IS_FID.png │ ├── Losses.png │ ├── header_image.jpg │ └── interp_sample.jpg ├── inception_tf13.py ├── inception_utils.py ├── layers.py ├── logs │ ├── BigGAN_ch96_bs256x8.jsonl │ ├── compare_IS.m │ ├── metalog.txt │ ├── process_inception_log.m │ └── process_training.m ├── losses.py ├── make_hdf5.py ├── sample.py ├── scripts │ ├── launch_BigGAN_bs256x8.sh │ ├── launch_BigGAN_bs512x4.sh │ ├── launch_BigGAN_ch64_bs256x8.sh │ ├── launch_BigGAN_deep.sh │ ├── launch_SAGAN_bs128x2_ema.sh │ ├── launch_SNGAN.sh │ ├── launch_cifar_ema.sh │ ├── sample_BigGAN_bs256x8.sh │ ├── sample_cifar_ema.sh │ └── utils │ │ ├── duplicate.sh │ │ └── prepare_data.sh ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── train.py ├── train_fns.py └── utils.py ├── datasets ├── __init__.py ├── datasets.py ├── imagenet2synset.json ├── imagenet_utils.py └── imagenetclass992.json ├── models.py ├── prepare_biggan_images.py ├── prepare_imagenet_images.py ├── sample_dataset.py ├── train.py ├── utils.py └── viz └── sample_overall.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually 
these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ##

BigDatasetGAN: Synthesizing ImageNet with Pixel-wise Annotations

2 | 3 | This is the official code for: 4 | 5 | #### BigDatasetGAN: Synthesizing ImageNet with Pixel-wise Annotations 6 | 7 | [Daiqing Li](https://scholar.google.ca/citations?user=8q2ISMIAAAAJ&hl=en), [Huan Ling](http://www.cs.toronto.edu/~linghuan/), [Seung Wook Kim](https://seung-kim.github.io/seungkim/), [Karsten Kreis](https://scholar.google.de/citations?user=rFd-DiAAAAAJ&hl=de), Adela Barriuso, [Sanja Fidler](http://www.cs.toronto.edu/~fidler/), [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/)\ 8 | **[[Paper](https://arxiv.org/abs/2201.04684)] [[Bibtex](https://nv-tlabs.github.io/big-datasetgan/resources/bibtex.txt)] [[Project Page](https://nv-tlabs.github.io/big-datasetgan/)]** 9 | 10 |

11 | 12 |

13 | 14 | 15 |

16 | 17 |

18 | 19 | ## Requirements 20 | * Python 3.7 21 | * CUDA v11.0+ 22 | * gcc v7.3.0 23 | * PyTorch 1.9.0+ 24 | 25 | ## Pretrained BigGAN weights 26 | Our annotations and model are based on BigGAN-512. Please download the model from https://tfhub.dev/deepmind/biggan-512/2 and store it in the `./pretrain` folder. Since the original model was trained using TensorFlow, you need to convert the model weights to PyTorch by following the instructions at https://github.com/ajbrock/BigGAN-PyTorch/tree/master/TFHub. Note that the model is licensed under Apache-2.0 by DeepMind. 27 | 28 | ## Dataset preparation 29 | We only release our annotations on sampled [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/master/LICENSE) images and images from [ImageNet](https://www.image-net.org/index.php), along with the latents used to generate the sampled images. For their licenses, please refer to their websites. Note that our dataset is released under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. Please see the [License](#license) section for details. 30 | 31 | 1. Download ImageNet from [here](https://www.image-net.org/index.php). 32 | 2. Download our annotations `annotations.zip` and latent codes `latents.zip` from [gdrive](https://drive.google.com/drive/folders/1NC0VSZrVZsd1Z_IRSdadyfCfTXMAPsf1). Unzip them into the `./data/` folder. 33 | 3. Process images from ImageNet into our dataset format. Run the following script 34 | ``` 35 | python prepare_imagenet_images.py --imagenet_dir [path to imagenet dir] --dataset_dir ./data/ 36 | ``` 37 | 4. Prepare images generated from BigGAN. Please download the pretrained weights as described in [this](#pretrained-biggan-weights) section.
And run 38 | ``` 39 | python prepare_biggan_images.py --biggan_ckpt ./pretrain/biggan-512.pth --dataset_dir ./data/ 40 | ``` 41 | After the processing steps, you should have a data folder structure like this 42 | ``` 43 | data 44 | | 45 | └───annotations 46 | │ | 47 | │ └───biggan512 48 | │ | │ n01440764 49 | │ | │ ... 50 | | └───real-random-list.txt 51 | └───images 52 | │ | 53 | │ └───biggan512 54 | │ | │ n01440764 55 | │ | │ ... 56 | | └───real-random 57 | │ │ n01440764 58 | │ │ ... 59 | └───latents 60 | │ | 61 | │ └───biggan512 62 | │ | │ n01440764 63 | │ | │ ... 64 | ``` 65 | 66 | ## Training 67 | After the dataset preparation, we can now train BigDatasetGAN to synthesize the dataset. 68 | 69 | Run the following 70 | ``` 71 | python train.py --gan_ckpt ./pretrain/biggan-512.pth \ 72 | --dataset_dir ./data/ \ 73 | --save_dir ./logs/ 74 | ``` 75 | 76 | You can monitor the training progress in TensorBoard, and inspect the training predictions in the logs directory. 77 | 78 | By default, the training runs for 5k iterations with a batch size of 4; you can adjust these to suit your hardware capacity. 79 | 80 | ## Sampling dataset 81 | After the training is done, we can synthesize ImageNet with pixel-wise labels. 82 | 83 | Run the following 84 | ``` 85 | python sample_dataset.py --ckpt [path to your pretrained BigDatasetGAN weights] \ 86 | --save_dir ./dataset_viz/ \ 87 | --class_idx 225, 200, [you can give it more ImageNet class indices] \ 88 | --samples_per_class 10 \ 89 | --z_var 0.9 90 | ``` 91 | 92 |

93 | 94 |

95 | 96 | 97 | As an example, here we sample class 225 and 200 with 10 samples each. 98 | 99 | ## License 100 | For any code dependency related to BigGAN, the license is under the MIT License, see https://github.com/ajbrock/BigGAN-PyTorch/blob/master/LICENSE. 101 | 102 | The work BigDatasetGAN code is released under Creative Commons BY-NC 4.0 license, full text at http://creativecommons.org/licenses/by-nc/4.0/legalcode. 103 | 104 | The dataset of BigDatasetGAN is released under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made. 105 | 106 | ## Citation 107 | 108 | Please cite the following paper if you used the code in this repository. 109 | 110 | ``` 111 | @misc{li2022bigdatasetgan, 112 | title={BigDatasetGAN: Synthesizing ImageNet with Pixel-wise Annotations}, 113 | author={Daiqing Li and Huan Ling and Seung Wook Kim and Karsten Kreis and Adela Barriuso and Sanja Fidler and Antonio Torralba}, 114 | year={2022}, 115 | eprint={2201.04684}, 116 | archivePrefix={arXiv}, 117 | primaryClass={cs.CV} 118 | } 119 | ``` 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /biggan_pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Andy Brock 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, 
subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /biggan_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # BigGAN-PyTorch 2 | The author's officially unofficial PyTorch BigGAN implementation. 3 | 4 | ![Dogball? Dogball!](imgs/header_image.jpg?raw=true "Dogball? Dogball!") 5 | 6 | 7 | This repo contains code for 4-8 GPU training of BigGANs from [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/abs/1809.11096) by Andrew Brock, Jeff Donahue, and Karen Simonyan. 8 | 9 | This code is by Andy Brock and Alex Andonian. 10 | 11 | ## How To Use This Code 12 | You will need: 13 | 14 | - [PyTorch](https://PyTorch.org/), version 1.0.1 15 | - tqdm, numpy, scipy, and h5py 16 | - The ImageNet training set 17 | 18 | First, you may optionally prepare a pre-processed HDF5 version of your target dataset for faster I/O. Following this (or not), you'll need the Inception moments needed to calculate FID. 
These can both be done by modifying and running 19 | 20 | ```sh 21 | sh scripts/utils/prepare_data.sh 22 | ``` 23 | 24 | Which by default assumes your ImageNet training set is downloaded into the root folder `data` in this directory, and will prepare the cached HDF5 at 128x128 pixel resolution. 25 | 26 | In the scripts folder, there are multiple bash scripts which will train BigGANs with different batch sizes. This code assumes you do not have access to a full TPU pod, and accordingly 27 | spoofs mega-batches by using gradient accumulation (averaging grads over multiple minibatches, and only taking an optimizer step after N accumulations). By default, the `launch_BigGAN_bs256x8.sh` script trains a 28 | full-sized BigGAN model with a batch size of 256 and 8 gradient accumulations, for a total batch size of 2048. On 8xV100 with full-precision training (no Tensor cores), this script takes 15 days to train to 150k iterations. 29 | 30 | You will first need to figure out the maximum batch size your setup can support. The pre-trained models provided here were trained on 8xV100 (16GB VRAM each) which can support slightly more than the BS256 used by default. 31 | Once you've determined this, you should modify the script so that the batch size times the number of gradient accumulations is equal to your desired total batch size (BigGAN defaults to 2048). 32 | 33 | Note also that this script uses the `--load_in_mem` arg, which loads the entire (~64GB) I128.hdf5 file into RAM for faster data loading. If you don't have enough RAM to support this (probably 96GB+), remove this argument. 34 | 35 | 36 | ## Metrics and Sampling 37 | ![I believe I can fly!](imgs/interp_sample.jpg?raw=true "I believe I can fly!") 38 | 39 | During training, this script will output logs with training metrics and test metrics, will save multiple copies (2 most recent and 5 highest-scoring) of the model weights/optimizer params, and will produce samples and interpolations every time it saves weights. 
40 | The logs folder contains scripts to process these logs and plot the results using MATLAB (sorry not sorry). 41 | 42 | After training, one can use `sample.py` to produce additional samples and interpolations, test with different truncation values, batch sizes, number of standing stat accumulations, etc. See the `sample_BigGAN_bs256x8.sh` script for an example. 43 | 44 | By default, everything is saved to weights/samples/logs/data folders which are assumed to be in the same folder as this repo. 45 | You can point all of these to a different base folder using the `--base_root` argument, or pick specific locations for each of these with their respective arguments (e.g. `--logs_root`). 46 | 47 | We include scripts to run BigGAN-deep, but we have not fully trained a model using them, so consider them untested. Additionally, we include scripts to run a model on CIFAR, and to run SA-GAN (with EMA) and SN-GAN on ImageNet. The SA-GAN code assumes you have 4xTitanX (or equivalent in terms of GPU RAM) and will run with a batch size of 128 and 2 gradient accumulations. 48 | 49 | ## An Important Note on Inception Metrics 50 | This repo uses the PyTorch in-built inception network to calculate IS and FID. 51 | These scores are different from the scores you would get using the official TF inception code, and are only for monitoring purposes! 52 | Run sample.py on your model, with the `--sample_npz` argument, then run inception_tf13 to calculate the actual TensorFlow IS. Note that you will need to have TensorFlow 1.3 or earlier installed, as TF1.4+ breaks the original IS code. 
53 | 54 | ## Pretrained models 55 | ![PyTorch Inception Score and FID](imgs/IS_FID.png) 56 | We include two pretrained model checkpoints (with G, D, the EMA copy of G, the optimizers, and the state dict): 57 | - The main checkpoint is for a BigGAN trained on ImageNet at 128x128, using BS256 and 8 gradient accumulations, taken just before collapse, with a TF Inception Score of 97.35 +/- 1.79: [LINK](https://drive.google.com/open?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW) 58 | - An earlier checkpoint of the first model (100k G iters), at high performance but well before collapse, which may be easier to fine-tune: [LINK](https://drive.google.com/open?id=1dmZrcVJUAWkPBGza_XgswSuT-UODXZcO) 59 | 60 | 61 | 62 | Pretrained models for Places-365 coming soon. 63 | 64 | This repo also contains scripts for porting the original TFHub BigGAN Generator weights to PyTorch. See the scripts in the TFHub folder for more details. 65 | 66 | ## Fine-tuning, Using Your Own Dataset, or Making New Training Functions 67 | ![That's deep, man](imgs/DeepSamples.png?raw=true "Deep Samples") 68 | 69 | If you wish to resume interrupted training or fine-tune a pre-trained model, run the same launch script but with the `--resume` argument added. 70 | Experiment names are automatically generated from the configuration, but can be overridden using the `--experiment_name` arg (for example, if you wish to fine-tune a model using modified optimizer settings). 71 | 72 | To prep your own dataset, you will need to add it to datasets.py and modify the convenience dicts in utils.py (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict) to have the appropriate metadata for your dataset. 73 | Repeat the process in prepare_data.sh (optionally produce an HDF5 preprocessed copy, and calculate the Inception Moments for FID). 74 | 75 | By default, the training script will save the top 5 best checkpoints as measured by Inception Score. 
76 | For datasets other than ImageNet, Inception Score can be a very poor measure of quality, so you will likely want to use `--which_best FID` instead. 77 | 78 | To use your own training function (e.g. train a BigVAE): either modify train_fns.GAN_training_function or add a new train fn and add it after the `if config['which_train_fn'] == 'GAN':` line in `train.py`. 79 | 80 | 81 | ## Neat Stuff 82 | - We include the full training and metrics logs [here](https://drive.google.com/open?id=1ZhY9Mg2b_S4QwxNmt57aXJ9FOC3ZN1qb) for reference. I've found that one of the hardest things about re-implementing a paper can be checking if the logs line up early in training, 83 | especially if training takes multiple weeks. Hopefully these will be helpful for future work. 84 | - We include an accelerated FID calculation--the original scipy version can require upwards of 10 minutes to calculate the matrix sqrt, this version uses an accelerated PyTorch version to calculate it in under a second. 85 | - We include an accelerated, low-memory consumption ortho reg implementation. 86 | - By default, we only compute the top singular value (the spectral norm), but this code supports computing more SVs through the `--num_G_SVs` argument. 87 | 88 | ## Key Differences Between This Code And The Original BigGAN 89 | - We use the optimizer settings from SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1, as opposed to BigGAN's G_lr=5e-5, D_lr=2e-4, num_D_steps=2). 90 | While slightly less performant, this was the first corner we cut to bring training times down. 91 | - By default, we do not use Cross-Replica BatchNorm (AKA Synced BatchNorm). 92 | The two variants we tried (a custom, naive one and the one included in this repo) have slightly different gradients (albeit identical forward passes) from the built-in BatchNorm, which appear to be sufficient to cripple training. 93 | - Gradient accumulation means that we update the SV estimates and the BN statistics 8 times more frequently. 
This means that the BN stats are much closer to standing stats, and that the singular value estimates tend to be more accurate. 94 | Because of this, we measure metrics by default with G in test mode (using the BatchNorm running stat estimates instead of computing standing stats as in the paper). We do still support standing stats (see the sample.sh scripts). 95 | This could also conceivably result in gradients from the earlier accumulations being stale, but in practice this does not appear to be a problem. 96 | - The currently provided pretrained models were not trained with orthogonal regularization. Training without ortho reg seems to increase the probability that models will not be amenable to truncation, 97 | but it looks like this particular model got a winning ticket. Regardless, we provide two highly optimized (fast and minimal memory consumption) ortho reg implementations which directly compute the ortho reg. gradients. 98 | 99 | ## A Note On The Design Of This Repo 100 | This code is designed from the ground up to serve as an extensible, hackable base for further research code. 101 | We've put a lot of thought into making sure the abstractions are the *right* thickness for research--not so thick as to be impenetrable, but not so thin as to be useless. 102 | The key idea is that if you want to experiment with a SOTA setup and make some modification (try out your own new loss function, architecture, self-attention block, etc) you should be able to easily do so just by dropping your code in one or two places, without having to worry about the rest of the codebase. 103 | Things like the use of self.which_conv and functools.partial in the BigGAN.py model definition were put together with this in mind, as was the design of the Spectral Norm class inheritance. 104 | 105 | With that said, this is a somewhat large codebase for a single project. 
While we tried to be thorough with the comments, if there's something you think could be more clear, better written, or better refactored, please feel free to raise an issue or a pull request. 106 | 107 | ## Feature Requests 108 | Want to work on or improve this code? There are a couple things this repo would benefit from, but which don't yet work. 109 | 110 | - Synchronized BatchNorm (AKA Cross-Replica BatchNorm). We tried out two variants of this, but for some unknown reason it crippled training each time. 111 | We have not tried the [apex](https://github.com/NVIDIA/apex) SyncBN as my school's servers are on ancient NVIDIA drivers that don't support it--apex would probably be a good place to start. 112 | - Mixed precision training and making use of Tensor cores. This repo includes a naive mixed-precision Adam implementation which works early in training but leads to early collapse, and doesn't do anything to activate Tensor cores (it just reduces memory consumption). 113 | As above, integrating [apex](https://github.com/NVIDIA/apex) into this code and employing its mixed-precision training techniques to take advantage of Tensor cores and reduce memory consumption could yield substantial speed gains. 114 | 115 | ## Misc Notes 116 | See [This directory](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for ImageNet labels. 117 | 118 | If you use this code, please cite 119 | ```text 120 | @inproceedings{ 121 | brock2018large, 122 | title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis}, 123 | author={Andrew Brock and Jeff Donahue and Karen Simonyan}, 124 | booktitle={International Conference on Learning Representations}, 125 | year={2019}, 126 | url={https://openreview.net/forum?id=B1xsqj09Fm}, 127 | } 128 | ``` 129 | 130 | ## Acknowledgments 131 | Thanks to Google for the generous cloud credit donations. 132 | 133 | [SyncBN](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) by Jiayuan Mao and Tete Xiao. 
134 | 135 | [Progress bar](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) originally from Jan Schlüter. 136 | 137 | Test metrics logger from [VoxNet.](https://github.com/dimatura/voxnet) 138 | 139 | PyTorch [implementation of cov](https://discuss.PyTorch.org/t/covariance-and-gradient-support/16217/2) from Modar M. Alfadly. 140 | 141 | PyTorch [fast Matrix Sqrt](https://github.com/msubhransu/matrix-sqrt) for FID from Tsung-Yu Lin and Subhransu Maji. 142 | 143 | TensorFlow Inception Score code from [OpenAI's Improved-GAN.](https://github.com/openai/improved-gan) 144 | 145 | -------------------------------------------------------------------------------- /biggan_pytorch/TFHub/README.md: -------------------------------------------------------------------------------- 1 | # BigGAN-PyTorch TFHub converter 2 | This dir contains scripts for taking the [pre-trained generator weights from TFHub](https://tfhub.dev/s?q=biggan) and porting them to BigGAN-Pytorch. 3 | 4 | In addition to the base libraries for BigGAN-PyTorch, to run this code you will need: 5 | 6 | TensorFlow 7 | TFHub 8 | parse 9 | 10 | Note that this code is only presently set up to run the ported models without truncation--you'll need to accumulate standing stats at each truncation level yourself if you wish to employ it. 11 | 12 | To port the 128x128 model from tfhub, produce a pretrained weights .pth file, and generate samples using all your GPUs, run 13 | 14 | `python converter.py -r 128 --generate_samples --parallel` -------------------------------------------------------------------------------- /biggan_pytorch/TFHub/biggan_v1.py: -------------------------------------------------------------------------------- 1 | # BigGAN V1: 2 | # This is now deprecated code used for porting the TFHub modules to pytorch, 3 | # included here for reference only. 
4 | import numpy as np 5 | import torch 6 | from scipy.stats import truncnorm 7 | from torch import nn 8 | from torch.nn import Parameter 9 | from torch.nn import functional as F 10 | 11 | 12 | def l2normalize(v, eps=1e-4): 13 | return v / (v.norm() + eps) 14 | 15 | 16 | def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None): 17 | state = None if seed is None else np.random.RandomState(seed) 18 | values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state) 19 | return truncation * values 20 | 21 | 22 | def denorm(x): 23 | out = (x + 1) / 2 24 | return out.clamp_(0, 1) 25 | 26 | 27 | class SpectralNorm(nn.Module): 28 | def __init__(self, module, name='weight', power_iterations=1): 29 | super(SpectralNorm, self).__init__() 30 | self.module = module 31 | self.name = name 32 | self.power_iterations = power_iterations 33 | if not self._made_params(): 34 | self._make_params() 35 | 36 | def _update_u_v(self): 37 | u = getattr(self.module, self.name + "_u") 38 | v = getattr(self.module, self.name + "_v") 39 | w = getattr(self.module, self.name + "_bar") 40 | 41 | height = w.data.shape[0] 42 | _w = w.view(height, -1) 43 | for _ in range(self.power_iterations): 44 | v = l2normalize(torch.matmul(_w.t(), u)) 45 | u = l2normalize(torch.matmul(_w, v)) 46 | 47 | sigma = u.dot((_w).mv(v)) 48 | setattr(self.module, self.name, w / sigma.expand_as(w)) 49 | 50 | def _made_params(self): 51 | try: 52 | getattr(self.module, self.name + "_u") 53 | getattr(self.module, self.name + "_v") 54 | getattr(self.module, self.name + "_bar") 55 | return True 56 | except AttributeError: 57 | return False 58 | 59 | def _make_params(self): 60 | w = getattr(self.module, self.name) 61 | 62 | height = w.data.shape[0] 63 | width = w.view(height, -1).data.shape[1] 64 | 65 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 66 | v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 67 | u.data = l2normalize(u.data) 68 | v.data = 
l2normalize(v.data) 69 | w_bar = Parameter(w.data) 70 | 71 | del self.module._parameters[self.name] 72 | self.module.register_parameter(self.name + "_u", u) 73 | self.module.register_parameter(self.name + "_v", v) 74 | self.module.register_parameter(self.name + "_bar", w_bar) 75 | 76 | def forward(self, *args): 77 | self._update_u_v() 78 | return self.module.forward(*args) 79 | 80 | 81 | class SelfAttention(nn.Module): 82 | """ Self Attention Layer""" 83 | 84 | def __init__(self, in_dim, activation=F.relu): 85 | super().__init__() 86 | self.chanel_in = in_dim 87 | self.activation = activation 88 | 89 | self.theta = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) 90 | self.phi = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False)) 91 | self.pool = nn.MaxPool2d(2, 2) 92 | self.g = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False)) 93 | self.o_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False)) 94 | self.gamma = nn.Parameter(torch.zeros(1)) 95 | 96 | self.softmax = nn.Softmax(dim=-1) 97 | 98 | def forward(self, x): 99 | m_batchsize, C, width, height = x.size() 100 | N = height * width 101 | 102 | theta = self.theta(x) 103 | phi = self.phi(x) 104 | phi = self.pool(phi) 105 | phi = phi.view(m_batchsize, -1, N // 4) 106 | theta = theta.view(m_batchsize, -1, N) 107 | theta = theta.permute(0, 2, 1) 108 | attention = self.softmax(torch.bmm(theta, phi)) 109 | g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4) 110 | attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(m_batchsize, -1, width, height) 111 | out = self.o_conv(attn_g) 112 | return self.gamma * out + x 113 | 114 | 115 | class ConditionalBatchNorm2d(nn.Module): 116 | def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1): 117 | super().__init__() 118 | self.num_features = num_features 119 | self.bn = 
class GBlock(nn.Module):
    """Spectrally-normalised residual block used by the generators and the discriminator.

    Main path: ``[HyperBN ->] activation -> [2x upsample ->] conv0 ->
    [HyperBN_1 ->] activation -> conv1 [-> 2x average pool]``. The result is
    added to a skip path that is the identity, or a 1x1 projection whenever
    the channel count or the spatial size changes.

    Args:
        in_channel: Number of channels of the input feature map.
        out_channel: Number of channels produced by the block.
        kernel_size: Kernel size of the two main-path convolutions.
        padding: Padding of the two main-path convolutions.
        stride: Stride of the two main-path convolutions.
        n_class: Accepted so existing call sites keep working; the block
            itself never reads it.
        bn: If True, a conditional batch norm precedes each activation.
        activation: Callable applied before each main-path convolution.
        upsample: If True, ``F.interpolate(scale_factor=2)`` is applied on
            both the main and the skip path.
        downsample: If True, ``F.avg_pool2d(..., 2)`` is applied at the end
            of both paths.
        z_dim: Width of the conditioning vector consumed by the conditional
            batch-norm layers (only relevant when ``bn`` is True).
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=(3, 3),  # immutable default; nn.Conv2d accepts tuples and lists alike
        padding=1,
        stride=1,
        n_class=None,
        bn=True,
        activation=F.relu,
        upsample=True,
        downsample=False,
        z_dim=148,
    ):
        super().__init__()

        # Sub-modules are created in the same order as before so that the
        # state-dict layout (and random initialisation order) is unchanged.
        self.conv0 = SpectralNorm(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=True)
        )
        self.conv1 = SpectralNorm(
            nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=True)
        )

        # The skip path needs a learned projection as soon as the block
        # changes either the channel count or the resolution.
        self.skip_proj = False
        if in_channel != out_channel or upsample or downsample:
            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
            self.skip_proj = True

        self.upsample = upsample
        self.downsample = downsample
        self.activation = activation
        self.bn = bn
        if bn:
            self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim)
            self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim)

    def forward(self, input, condition=None):
        """Run the block.

        Args:
            input: Feature map fed to both the main and the skip path.
            condition: Conditioning vector for the conditional batch norms;
                required when the block was built with ``bn=True``.

        Returns:
            Sum of the main-path and skip-path outputs.

        Raises:
            ValueError: If the block uses batch norm and ``condition`` is None.
        """
        if self.bn and condition is None:
            # Fail early with a readable message instead of deep inside the
            # conditional batch-norm embedding.
            raise ValueError('GBlock was built with bn=True and needs a `condition` tensor')

        out = input
        if self.bn:
            out = self.HyperBN(out, condition)
        out = self.activation(out)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2)
        out = self.conv0(out)
        if self.bn:
            out = self.HyperBN_1(out, condition)
        out = self.activation(out)
        out = self.conv1(out)
        if self.downsample:
            out = F.avg_pool2d(out, 2)

        skip = input
        if self.skip_proj:
            # Mirror the resolution changes of the main path around the 1x1 conv.
            if self.upsample:
                skip = F.interpolate(skip, scale_factor=2)
            skip = self.conv_sc(skip)
            if self.downsample:
                skip = F.avg_pool2d(skip, 2)
        return out + skip
class Generator512(nn.Module):
    """Generator built from seven upsampling GBlocks (4x4 seed -> 512x512, 3 channels).

    The latent input is cut into ``len(self.GBlock) + 1`` chunks: the first
    one seeds the 4x4 feature map, every other chunk is concatenated with the
    class embedding and conditions one GBlock.

    Args:
        code_dim: Width of the latent vector.
        n_class: Input width of the class-embedding linear layer.
        chn: Base channel multiplier.
        debug: If True, ``chn`` is overridden with 8 to build a tiny model.

    NOTE(review): the seed chunk width (16) and the embedding width (128) are
    hard-coded, so the sizes only appear to line up for the default
    ``code_dim=128`` - confirm before using another value.
    """

    def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False):
        super().__init__()

        self.linear = nn.Linear(n_class, 128, bias=False)

        width = 8 if debug else chn
        self.first_view = 16 * width
        self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * width))

        cond_dim = code_dim + 16
        # (input multiplier, output multiplier) of each block, in order.
        stages = [(16, 16), (16, 8), (8, 8), (8, 4), (4, 2), (2, 1), (1, 1)]
        self.GBlock = nn.ModuleList([
            GBlock(mult_in * width, mult_out * width, n_class=n_class, z_dim=cond_dim)
            for mult_in, mult_out in stages
        ])

        self.sa_id = 4  # self-attention runs right before the block with this index
        self.num_split = len(self.GBlock) + 1
        self.attention = SelfAttention(4 * width)
        self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * width)
        self.colorize = SpectralNorm(nn.Conv2d(1 * width, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        """Generate images from a latent batch and its class vectors.

        Args:
            input: Latent batch, split along dim 1 into ``self.num_split`` chunks.
            class_id: Class input passed through ``self.linear``
                (presumably one-hot rows of width ``n_class`` - confirm with caller).

        Returns:
            ``tanh`` of the 3-channel output of the final convolution.
        """
        z_chunks = torch.chunk(input, self.num_split, 1)
        class_emb = self.linear(class_id)

        # The seed features come out as (N, 4, 4, C); move channels to dim 1.
        feat = self.G_linear(z_chunks[0])
        feat = feat.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)

        for idx, (z_part, block) in enumerate(zip(z_chunks[1:], self.GBlock)):
            if idx == self.sa_id:
                feat = self.attention(feat)
            feat = block(feat, torch.cat([z_part, class_emb], 1))

        feat = F.relu(self.ScaledCrossReplicaBN(feat))
        return torch.tanh(self.colorize(feat))
downsample=True), 361 | SelfAttention(2 * chn), 362 | conv(2 * chn, 2 * chn, downsample=True), 363 | conv(2 * chn, 4 * chn, downsample=True), 364 | conv(4 * chn, 8 * chn, downsample=True), 365 | conv(8 * chn, 8 * chn, downsample=True), 366 | conv(8 * chn, 16 * chn, downsample=True), 367 | conv(16 * chn, 16 * chn, downsample=False), 368 | ) 369 | 370 | self.linear = SpectralNorm(nn.Linear(16 * chn, 1)) 371 | 372 | self.embed = nn.Embedding(n_class, 16 * chn) 373 | self.embed.weight.data.uniform_(-0.1, 0.1) 374 | self.embed = SpectralNorm(self.embed) 375 | 376 | def forward(self, input, class_id): 377 | 378 | out = self.pre_conv(input) 379 | out += self.pre_skip(F.avg_pool2d(input, 2)) 380 | out = self.conv(out) 381 | out = F.relu(out) 382 | out = out.view(out.size(0), out.size(1), -1) 383 | out = out.sum(2) 384 | out_linear = self.linear(out).squeeze(1) 385 | embed = self.embed(class_id) 386 | 387 | prod = (out * embed).sum(1) 388 | 389 | return out_linear + prod -------------------------------------------------------------------------------- /biggan_pytorch/TFHub/converter.py: -------------------------------------------------------------------------------- 1 | """Utilities for converting TFHub BigGAN generator weights to PyTorch. 2 | 3 | Recommended usage: 4 | 5 | To convert all BigGAN variants and generate test samples, use: 6 | 7 | ```bash 8 | CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples 9 | ``` 10 | 11 | See `parse_args` for additional options. 
def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False):
    """Load TFHub weights and cache them in an intermediate HDF5 file.

    Args:
        module_path ([Path-like]): Path to TFHub module.
        hdf5_path ([Path-like]): Path to output HDF5 file.
        redownload (bool): If True, rebuild the HDF5 file even when it
            already exists.

    Returns:
        [h5py.File]: HDF5 file opened read-only, containing one dataset per
        TF global variable. The caller is responsible for closing it.
    """
    if os.path.exists(hdf5_path) and (not redownload):
        print('Loading BigGAN hdf5 file from:', hdf5_path)
        return h5py.File(hdf5_path, 'r')

    print('Loading BigGAN module from:', module_path)
    tf.reset_default_graph()
    hub.Module(module_path)
    print('Loaded BigGAN module from:', module_path)

    initializer = tf.global_variables_initializer()

    print('Saving BigGAN weights to :', hdf5_path)
    # Write to a scratch file first: an interrupted dump must not leave a
    # truncated file at `hdf5_path`, because the next run would reuse it.
    partial_path = f'{hdf5_path}.partial'
    try:
        # Context managers release the session and the file handle even when
        # evaluating or writing a variable fails.
        with tf.Session() as sess, h5py.File(partial_path, 'w') as h5f:
            sess.run(initializer)
            for var in tf.global_variables():
                val = sess.run(var)
                h5f.create_dataset(var.name, data=val)
                print(f'Saving {var.name} with shape {val.shape}')
        os.replace(partial_path, hdf5_path)
    finally:
        # Only present if something went wrong before the rename.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return h5py.File(hdf5_path, 'r')
'G_linear')) 121 | self.load_colorize('colorize', os.path.join(GENERATOR_ROOT, 'conv_2d')) 122 | self.load_ScaledCrossReplicaBNs('ScaledCrossReplicaBN', 123 | os.path.join(GENERATOR_ROOT, 'ScaledCrossReplicaBN')) 124 | 125 | def load_linear(self, name_pth, name_tf, bias=True): 126 | self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) 127 | if bias: 128 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) 129 | 130 | def load_snlinear(self, name_pth, name_tf, bias=True): 131 | self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() 132 | self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() 133 | self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0) 134 | if bias: 135 | self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b) 136 | 137 | def load_colorize(self, name_pth, name_tf): 138 | self.load_snconv(name_pth, name_tf) 139 | 140 | def load_GBlock(self, name_pth, name_tf): 141 | self.load_convs(name_pth, name_tf) 142 | self.load_HyperBNs(name_pth, name_tf) 143 | 144 | def load_convs(self, name_pth, name_tf): 145 | self.load_snconv(name_pth + 'conv0', os.path.join(name_tf, 'conv0')) 146 | self.load_snconv(name_pth + 'conv1', os.path.join(name_tf, 'conv1')) 147 | self.load_snconv(name_pth + 'conv_sc', os.path.join(name_tf, 'conv_sc')) 148 | 149 | def load_snconv(self, name_pth, name_tf, bias=True): 150 | if self.verbose: 151 | print(f'loading: {name_pth} from {name_tf}') 152 | self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() 153 | self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() 154 | self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1) 155 | if bias: 156 | self.state_dict[name_pth + '.module.bias'] = 
self.load_tf_tensor(name_tf, self.b).squeeze() 157 | 158 | def load_conv(self, name_pth, name_tf, bias=True): 159 | 160 | self.state_dict[name_pth + '.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze() 161 | self.state_dict[name_pth + '.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze() 162 | self.state_dict[name_pth + '.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1) 163 | if bias: 164 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b) 165 | 166 | def load_HyperBNs(self, name_pth, name_tf): 167 | self.load_HyperBN(name_pth + 'HyperBN', os.path.join(name_tf, 'HyperBN')) 168 | self.load_HyperBN(name_pth + 'HyperBN_1', os.path.join(name_tf, 'HyperBN_1')) 169 | 170 | def load_ScaledCrossReplicaBNs(self, name_pth, name_tf): 171 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.beta).squeeze() 172 | self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.gamma).squeeze() 173 | self.state_dict[name_pth + '.running_mean'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_mean') 174 | self.state_dict[name_pth + '.running_var'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_var') 175 | self.state_dict[name_pth + '.num_batches_tracked'] = torch.tensor( 176 | self.tf_weights[os.path.join(name_tf + 'bn', 'accumulation_counter:0')][()], dtype=torch.float32) 177 | 178 | def load_HyperBN(self, name_pth, name_tf): 179 | if self.verbose: 180 | print(f'loading: {name_pth} from {name_tf}') 181 | beta = name_pth + '.beta_embed.module' 182 | gamma = name_pth + '.gamma_embed.module' 183 | self.state_dict[beta + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.u).squeeze() 184 | self.state_dict[gamma + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.u).squeeze() 185 | self.state_dict[beta + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.v).squeeze() 186 | self.state_dict[gamma + '.weight_v'] = 
def convert_from_v1(hub_dict, resolution=128):
    """Rename and reshape a v1-layout generator state dict for ``BigGAN.Generator``.

    Every key of ``hub_dict`` that has a counterpart is mapped to its new
    name; ``weight_v`` vectors and ``num_batches_tracked`` counters of the
    GBlocks are dropped. A few tensors are also re-laid-out:

    * conditional batch-norm embedding weights: the last 128 input columns
      are moved in front of the first z-chunk columns (the two layouts
      concatenate class embedding and z chunk in opposite order);
    * the first linear layer (weight, bias, u0): reordered from a
      (4, 4, channels) to a (channels, 4, 4) flattening, with the channel
      count hard-coded to ``96 * 16``;
    * the shared class embedding: transposed;
    * every remaining ``weight_u`` vector: given a leading dimension of 1.

    Args:
        hub_dict: Mapping from v1 parameter names to tensors.
        resolution: One of 128, 256 or 512.

    Returns:
        dict mapping new parameter names to tensors.
    """
    suffix_map = {'weight_u': 'u0', 'weight_bar': 'weight', 'bias': 'bias'}
    conv_map = {'conv0': 'conv1', 'conv1': 'conv2', 'conv_sc': 'conv_sc'}
    attn = 'blocks.%d.1' % {128: 3, 256: 4, 512: 3}[resolution]

    hub2me = {
        'linear.weight': 'shared.weight',  # the shared class embedding, not the first linear layer
        # First linear layer
        'G_linear.module.weight_bar': 'linear.weight',
        'G_linear.module.bias': 'linear.bias',
        'G_linear.module.weight_u': 'linear.u0',
        # Output layer
        'ScaledCrossReplicaBN.weight': 'output_layer.0.gain',
        'ScaledCrossReplicaBN.bias': 'output_layer.0.bias',
        'ScaledCrossReplicaBN.running_mean': 'output_layer.0.stored_mean',
        'ScaledCrossReplicaBN.running_var': 'output_layer.0.stored_var',
        'colorize.module.weight_bar': 'output_layer.2.weight',
        'colorize.module.bias': 'output_layer.2.bias',
        'colorize.module.weight_u': 'output_layer.2.u0',
        # Attention
        'attention.gamma': attn + '.gamma',
    }
    for hub_layer, my_layer in (('theta', 'theta'), ('phi', 'phi'), ('g', 'g'), ('o_conv', 'o')):
        hub2me['attention.%s.module.weight_u' % hub_layer] = '%s.%s.u0' % (attn, my_layer)
        hub2me['attention.%s.module.weight_bar' % hub_layer] = '%s.%s.weight' % (attn, my_layer)

    # Derive the names of the per-block parameters from the keys themselves.
    for name in hub_dict:
        if 'GBlock' not in name:
            continue
        if 'HyperBN' not in name:
            # Convolution: conv numbering is shifted by one in the new layout.
            block_idx, conv_name, suffix = parse.parse('GBlock.{:d}.{}.module.{}', name)
            if suffix not in suffix_map:
                continue
            target = 'blocks.%d.0.%s.%s' % (block_idx, conv_map[conv_name], suffix_map[suffix])
        else:
            bn_idx = 2 if 'HyperBN_1' in name else 1
            if 'embed' in name:
                # gamma/beta embedding of a conditional batch norm
                block_idx, embed_name, suffix = parse.parse('GBlock.{:d}.{}.module.{}', name)
                if suffix not in suffix_map:  # skips weight_v
                    continue
                role = 'gain' if 'gamma' in embed_name else 'bias'
                target = 'blocks.%d.0.bn%d.%s.%s' % (block_idx, bn_idx, role, suffix_map[suffix])
            else:
                # running statistics of the underlying batch norm
                block_idx, _, stat = parse.parse('GBlock.{:d}.{}.bn.{}', name)
                if 'num_batches_tracked' in stat:
                    continue
                stat_name = 'stored_mean' if 'mean' in stat else 'stored_var'
                target = 'blocks.%d.0.bn%d.%s' % (block_idx, bn_idx, stat_name)
        hub2me[name] = target

    me2hub = {mine: theirs for theirs, mine in hub2me.items()}
    z_chunk = {128: 20, 256: 20, 512: 16}[resolution]

    new_dict = {}
    for mine, theirs in me2hub.items():
        tensor = hub_dict[theirs]
        is_bn_embedding = (
            'bn' in mine
            and 'weight' in mine
            and ('gain' in mine or 'bias' in mine)
            and 'output_layer' not in mine
        )
        if is_bn_embedding:
            # Move the class-embedding columns in front of the z-chunk columns.
            converted = torch.cat([tensor[:, -128:], tensor[:, :z_chunk]], 1)
        elif mine == 'linear.weight':
            converted = (tensor.contiguous().view(4, 4, 96 * 16, -1)
                         .permute(2, 0, 1, 3).contiguous().view(-1, z_chunk))
        elif mine == 'linear.bias':
            converted = tensor.view(4, 4, 96 * 16).permute(2, 0, 1).contiguous().view(-1)
        elif mine == 'linear.u0':
            converted = tensor.view(4, 4, 96 * 16).permute(2, 0, 1).contiguous().view(1, -1)
        elif theirs == 'linear.weight':
            # Shared weight: transpose so that it can be used as an embedding.
            converted = tensor.t()
        elif 'weight_u' in theirs:
            converted = tensor.unsqueeze(0)
        else:
            converted = tensor
        new_dict[mine] = converted
    return new_dict
def convert_biggan(resolution, weight_dir, redownload=False, no_ema=False, verbose=False):
    """Convert the TFHub BigGAN generator of one resolution to a PyTorch checkpoint.

    The TFHub variables are dumped to (or read from) an HDF5 cache in
    ``weight_dir``, loaded into the v1 state-dict layout, renamed for
    ``BigGAN.Generator`` and saved as a ``.pth`` file next to the cache.

    Args:
        resolution: Image resolution of the module to convert (128, 256 or 512).
        weight_dir: Directory holding the HDF5 cache and the output ``.pth``.
        redownload: If True, rebuild the HDF5 cache even if it exists.
        no_ema: If True, load the raw weights instead of the EMA weights.
        verbose: If True, print the parameters as they are loaded.

    Returns:
        The ``BigGAN.Generator`` instance with the converted weights loaded.
    """
    module_path = MODULE_PATH_TMPL.format(resolution)
    hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution))
    pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution))

    tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload)
    try:
        G_temp = getattr(biggan_for_conversion, f'Generator{resolution}')()
        converter = TFHub2Pytorch(G_temp.state_dict(), tf_weights, resolution=resolution,
                                  load_ema=(not no_ema), verbose=verbose)
        state_dict_v1 = converter.load()
    finally:
        # The converter copies every dataset into memory while loading, so the
        # HDF5 handle can be released here instead of leaking one open file
        # per converted resolution.
        tf_weights.close()

    state_dict = convert_from_v1(state_dict_v1, resolution)
    # Get the config, build the model
    config = get_config(resolution)
    G = BigGAN.Generator(**config)
    G.load_state_dict(state_dict, strict=False)  # Ignore missing sv0 entries
    torch.save(state_dict, pth_path)

    return G
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with the attributes ``resolution``, ``redownload``,
        ``weights_dir``, ``samples_dir``, ``no_ema``, ``verbose``,
        ``generate_samples``, ``batch_size`` and ``parallel``.
    """
    switch = {'action': 'store_true', 'default': False}
    # (flags, add_argument keyword arguments), in the order they are registered.
    option_table = [
        (('--resolution', '-r'),
         dict(type=int, default=None, choices=[128, 256, 512],
              help='Resolution of TFHub module to convert. Converts all resolutions if None.')),
        (('--redownload',),
         dict(switch, help='Redownload weights and overwrite current hdf5 file, if present.')),
        (('--weights_dir',), dict(type=str, default='pretrained_weights')),
        (('--samples_dir',), dict(type=str, default='pretrained_samples')),
        (('--no_ema',), dict(switch, help='Do not load ema weights.')),
        (('--verbose',), dict(switch, help='Additionally logging.')),
        (('--generate_samples',),
         dict(switch, help='Generate test sample with pretrained model.')),
        (('--batch_size',),
         dict(type=int, default=64, help='Batch size used for test sample.')),
        (('--parallel',), dict(switch, help='Parallelize G?')),
    ]

    parser = argparse.ArgumentParser(description='Parser for conversion script.')
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
def prepare_parser():
    """Build the argument parser of the inception-moments script.

    Returns:
        ArgumentParser defining ``--dataset``, ``--data_root``,
        ``--batch_size``, ``--parallel``, ``--augment``, ``--num_workers``,
        ``--shuffle`` and ``--seed``.
    """
    usage = 'Calculate and store inception metrics.'
    parser = ArgumentParser(description=usage)
    parser.add_argument(
        '--dataset', type=str, default='I128_hdf5',
        # The two literals are joined by implicit concatenation, so the
        # separating space has to be part of the first one.
        help='Which Dataset to train on, out of I128, I256, C10, C100... '
             'Append _hdf5 to use the hdf5 version of the dataset. (default: %(default)s)')
    parser.add_argument(
        '--data_root', type=str, default='data',
        help='Default location where data is stored (default: %(default)s)')
    parser.add_argument(
        '--batch_size', type=int, default=64,
        help='Default overall batchsize (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', default=False,
        help='Train with multiple GPUs (default: %(default)s)')
    parser.add_argument(
        '--augment', action='store_true', default=False,
        help='Augment with random crops and flips (default: %(default)s)')
    parser.add_argument(
        '--num_workers', type=int, default=8,
        help='Number of dataloader workers (default: %(default)s)')
    parser.add_argument(
        '--shuffle', action='store_true', default=False,
        help='Shuffle the data? (default: %(default)s)')
    parser.add_argument(
        '--seed', type=int, default=0,
        help='Random seed to use.')
    return parser
def main():
    """Read the command-line options, echo them and compute the inception moments."""
    cli_options = prepare_parser().parse_args()
    config = vars(cli_options)  # `run` expects a plain dict
    print(config)
    run(config)
def find_classes(dir):
    """List the class sub-directories of ``dir`` and number them.

    Args:
        dir (string): directory whose immediate sub-directories are the classes

    Returns:
        tuple: (sorted list of class names, dict mapping class name to its
        position in that list)
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    """Walk every class directory under ``dir`` and collect its image files.

    Args:
        dir (string): root directory (``~`` is expanded)
        class_to_idx (dict): maps a class directory name to its class index

    Returns:
        list: (image path, class index) tuples, in sorted traversal order
    """
    dir = os.path.expanduser(dir)
    images = []
    for target in tqdm(sorted(os.listdir(dir))):
        class_dir = os.path.join(dir, target)
        if os.path.isdir(class_dir):
            for root, _, fnames in sorted(os.walk(class_dir)):
                # The index is looked up per image, so a directory without
                # images never needs an entry in `class_to_idx`.
                images.extend(
                    (os.path.join(root, fname), class_to_idx[target])
                    for fname in sorted(fnames)
                    if is_image_file(fname)
                )
    return images


def pil_loader(path):
    """Load the image at ``path`` with PIL and return it converted to RGB."""
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def accimage_loader(path):
    """Load the image at ``path`` with accimage, using ``pil_loader`` on IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image


def default_loader(path):
    """Load the image at ``path`` with the loader matching torchvision's image backend."""
    from torchvision import get_image_backend
    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)
def __init__(self, root, transform=None, target_transform=None,
             loader=default_loader, load_in_mem=False,
             index_filename='imagenet_imgs.npz', **kwargs):
    """Index the image folder and optionally preload every image.

    Args:
        root (string): root directory; each sub-directory is one class.
        transform (callable, optional): applied to each loaded image.
        target_transform (callable, optional): applied to each target.
        loader (callable, optional): loads an image given its path.
        load_in_mem (bool): if True, load (and transform) all images now
            and keep them in ``self.data`` / ``self.labels``.
        index_filename (string): ``.npz`` file caching the directory walk
            under the key ``'imgs'``.
        **kwargs: accepted and ignored.

    Raises:
        RuntimeError: if the index contains no images.
    """
    classes, class_to_idx = find_classes(root)
    if os.path.exists(index_filename):
        # Reuse the cached directory walk. NOTE: only the file's existence
        # is checked; a stale index is not validated against ``root``.
        print('Loading pre-saved Index file %s...' % index_filename)
        imgs = np.load(index_filename)['imgs']
    else:
        # First run: walk the folder tree and cache the result.
        print('Generating Index file %s...' % index_filename)
        imgs = make_dataset(root, class_to_idx)
        # Do not cache an empty walk: it would keep raising on later runs
        # even after the directory has been fixed.
        if len(imgs) != 0:
            np.savez_compressed(index_filename, **{'imgs': imgs})
    if len(imgs) == 0:
        raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                           "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

    self.root = root
    self.imgs = imgs
    self.classes = classes
    self.class_to_idx = class_to_idx
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    self.load_in_mem = load_in_mem

    if self.load_in_mem:
        print('Loading all images into memory...')
        self.data, self.labels = [], []
        for index in tqdm(range(len(self.imgs))):
            path, target = imgs[index][0], imgs[index][1]
            # str(): entries loaded from the .npz index are numpy strings;
            # __getitem__ converts the same way before calling the loader.
            img = self.loader(str(path))
            # transform is optional; the preloaded copy stores the raw
            # loader output when none was given.
            if self.transform is not None:
                img = self.transform(img)
            self.data.append(img)
            self.labels.append(target)
class ILSVRC_HDF5(data.Dataset):
    """Dataset reading images and labels from a single HDF5 file.

    The file at ``root`` must contain an ``'imgs'`` and a ``'labels'``
    dataset. Items are returned as ``(image, target)`` where the image is a
    float tensor rescaled from [0, 255] to [-1, 1] and the target is an int.

    Args:
        root (string): path to the HDF5 file.
        transform (callable, optional): stored on the instance but NOT
            applied; ``__getitem__`` uses its own fixed rescaling.
        target_transform (callable, optional): applied to each target.
        load_in_mem (bool): if True, read both datasets fully into memory.
        train, download, validate_seed, val_split, **kwargs: accepted for
            signature compatibility and ignored.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 load_in_mem=False, train=True, download=False, validate_seed=0,
                 val_split=0, **kwargs):  # last four are dummies
        self.root = root
        self.target_transform = target_transform
        self.transform = transform
        # load the entire dataset into memory?
        self.load_in_mem = load_in_mem

        if self.load_in_mem:
            print('Loading %s into memory...' % root)
        # Open the file once and close it again; previously the handle used
        # to count the labels was never closed.
        with h5.File(root, 'r') as f:
            self.num_imgs = len(f['labels'])
            if self.load_in_mem:
                self.data = f['imgs'][:]
                self.labels = f['labels'][:]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        if self.load_in_mem:
            # Whole dataset is in RAM
            img = self.data[index]
            target = self.labels[index]
        else:
            # Re-open the file per item so no handle is held on the instance
            with h5.File(self.root, 'r') as f:
                img = f['imgs'][index]
                target = f['labels'][index]

        # Fixed rescaling from [0, 255] to [-1, 1]; self.transform is
        # deliberately not used here.
        img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, int(target)

    def __len__(self):
        return self.num_imgs
+ 260 | ' You can use download=True to download it') 261 | 262 | # now load the picked numpy arrays 263 | self.data = [] 264 | self.labels= [] 265 | for fentry in self.train_list: 266 | f = fentry[0] 267 | file = os.path.join(self.root, self.base_folder, f) 268 | fo = open(file, 'rb') 269 | if sys.version_info[0] == 2: 270 | entry = pickle.load(fo) 271 | else: 272 | entry = pickle.load(fo, encoding='latin1') 273 | self.data.append(entry['data']) 274 | if 'labels' in entry: 275 | self.labels += entry['labels'] 276 | else: 277 | self.labels += entry['fine_labels'] 278 | fo.close() 279 | 280 | self.data = np.concatenate(self.data) 281 | # Randomly select indices for validation 282 | if self.val_split > 0: 283 | label_indices = [[] for _ in range(max(self.labels)+1)] 284 | for i,l in enumerate(self.labels): 285 | label_indices[l] += [i] 286 | label_indices = np.asarray(label_indices) 287 | 288 | # randomly grab 500 elements of each class 289 | np.random.seed(validate_seed) 290 | self.val_indices = [] 291 | for l_i in label_indices: 292 | self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)]) 293 | 294 | if self.train=='validate': 295 | self.data = self.data[self.val_indices] 296 | self.labels = list(np.asarray(self.labels)[self.val_indices]) 297 | 298 | self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32)) 299 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 300 | 301 | elif self.train: 302 | print(np.shape(self.data)) 303 | if self.val_split > 0: 304 | self.data = np.delete(self.data,self.val_indices,axis=0) 305 | self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0)) 306 | 307 | self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32)) 308 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 309 | else: 310 | f = self.test_list[0][0] 311 | file = os.path.join(self.root, self.base_folder, f) 312 | fo = 
class CIFAR100(CIFAR10):
    """CIFAR-100 configuration of the CIFAR10 dataset class above.

    Only class-level settings are overridden; all loading logic is
    inherited. The inherited loader joins ``root``, ``base_folder`` and the
    first element of each ``train_list`` / ``test_list`` entry to locate
    the pickled batches, and falls back to the ``'fine_labels'`` key when an
    entry has no ``'labels'`` key.
    """
    # Sub-directory of ``root`` holding the extracted batches.
    base_folder = 'cifar-100-python'
    # Archive location and file name.
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    # NOTE(review): presumably the MD5 of the archive, consumed by the
    # torchvision base class's download/integrity helpers; confirm.
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # Each entry is [file name, hash]; the loader in this file reads only
    # the file name. NOTE(review): the hash looks like a per-file MD5.
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/DeepSamples.png -------------------------------------------------------------------------------- /biggan_pytorch/imgs/DogBall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/DogBall.png -------------------------------------------------------------------------------- /biggan_pytorch/imgs/G Singular Values.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/G Singular Values.png -------------------------------------------------------------------------------- /biggan_pytorch/imgs/IS_FID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/IS_FID.png -------------------------------------------------------------------------------- /biggan_pytorch/imgs/Losses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/Losses.png -------------------------------------------------------------------------------- /biggan_pytorch/imgs/header_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/header_image.jpg -------------------------------------------------------------------------------- 
def prepare_parser():
    """Build the command-line parser for the TF1.3 Inception Score script.

    Returns:
        ArgumentParser: parser exposing ``--experiment_name`` (str,
        default ``''``), ``--experiment_root`` (str, default ``'samples'``)
        and ``--batch_size`` (int, default 500).
    """
    usage = 'Parser for TF1.3- Inception Score scripts.'
    parser = ArgumentParser(description=usage)
    parser.add_argument(
        '--experiment_name', type=str, default='',
        # Double quotes so the apostrophe survives; the old form
        # 'experiment''s' was two adjacent literals and read "experiments".
        help="Which experiment's samples.npz file to pull and evaluate")
    parser.add_argument(
        '--experiment_root', type=str, default='samples',
        help='Default location where samples are stored (default: %(default)s)')
    parser.add_argument(
        '--batch_size', type=int, default=500,
        help='Default overall batchsize (default: %(default)s)')
    return parser
filename = DATA_URL.split('/')[-1] 82 | filepath = os.path.join(MODEL_DIR, filename) 83 | if not os.path.exists(filepath): 84 | def _progress(count, block_size, total_size): 85 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 86 | filename, float(count * block_size) / float(total_size) * 100.0)) 87 | sys.stdout.flush() 88 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 89 | print() 90 | statinfo = os.stat(filepath) 91 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 92 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 93 | with tf.gfile.FastGFile(os.path.join( 94 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 95 | graph_def = tf.GraphDef() 96 | graph_def.ParseFromString(f.read()) 97 | _ = tf.import_graph_def(graph_def, name='') 98 | # Works with an arbitrary minibatch size. 99 | with tf.Session() as sess: 100 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 101 | ops = pool3.graph.get_operations() 102 | for op_idx, op in enumerate(ops): 103 | for o in op.outputs: 104 | shape = o.get_shape() 105 | shape = [s.value for s in shape] 106 | new_shape = [] 107 | for j, s in enumerate(shape): 108 | if s == 1 and j == 0: 109 | new_shape.append(None) 110 | else: 111 | new_shape.append(s) 112 | o._shape = tf.TensorShape(new_shape) 113 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 114 | logits = tf.matmul(tf.squeeze(pool3), w) 115 | softmax = tf.nn.softmax(logits) 116 | 117 | # if softmax is None: # No need to functionalize like this. 
def main():
    """Script entry point: parse the command line and run the evaluation.

    The parsed options are converted to a dict, printed, and passed to
    ``run``.
    """
    namespace = prepare_parser().parse_args()
    config = vars(namespace)
    print(config)
    run(config)
13 | ''' 14 | import numpy as np 15 | from scipy import linalg # For numpy FID 16 | import time 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import Parameter as P 22 | from torchvision.models.inception import inception_v3 23 | 24 | 25 | # Module that wraps the inception network to enable use with dataparallel and 26 | # returning pool features and logits. 27 | class WrapInception(nn.Module): 28 | def __init__(self, net): 29 | super(WrapInception,self).__init__() 30 | self.net = net 31 | self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1), 32 | requires_grad=False) 33 | self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1), 34 | requires_grad=False) 35 | def forward(self, x): 36 | # Normalize x 37 | x = (x + 1.) / 2.0 38 | x = (x - self.mean) / self.std 39 | # Upsample if necessary 40 | if x.shape[2] != 299 or x.shape[3] != 299: 41 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 42 | # 299 x 299 x 3 43 | x = self.net.Conv2d_1a_3x3(x) 44 | # 149 x 149 x 32 45 | x = self.net.Conv2d_2a_3x3(x) 46 | # 147 x 147 x 32 47 | x = self.net.Conv2d_2b_3x3(x) 48 | # 147 x 147 x 64 49 | x = F.max_pool2d(x, kernel_size=3, stride=2) 50 | # 73 x 73 x 64 51 | x = self.net.Conv2d_3b_1x1(x) 52 | # 73 x 73 x 80 53 | x = self.net.Conv2d_4a_3x3(x) 54 | # 71 x 71 x 192 55 | x = F.max_pool2d(x, kernel_size=3, stride=2) 56 | # 35 x 35 x 192 57 | x = self.net.Mixed_5b(x) 58 | # 35 x 35 x 256 59 | x = self.net.Mixed_5c(x) 60 | # 35 x 35 x 288 61 | x = self.net.Mixed_5d(x) 62 | # 35 x 35 x 288 63 | x = self.net.Mixed_6a(x) 64 | # 17 x 17 x 768 65 | x = self.net.Mixed_6b(x) 66 | # 17 x 17 x 768 67 | x = self.net.Mixed_6c(x) 68 | # 17 x 17 x 768 69 | x = self.net.Mixed_6d(x) 70 | # 17 x 17 x 768 71 | x = self.net.Mixed_6e(x) 72 | # 17 x 17 x 768 73 | # 17 x 17 x 768 74 | x = self.net.Mixed_7a(x) 75 | # 8 x 8 x 1280 76 | x = self.net.Mixed_7b(x) 77 | # 8 x 8 x 2048 78 | x = 
# A pytorch implementation of cov, from Modar M. Alfadly
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def torch_cov(m, rowvar=False):
    '''Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    The input tensor is never modified.

    Args:
        m: A 1-D or 2-D tensor containing multiple variables and
            observations. A 1-D tensor is treated as one variable.
        rowvar: If `rowvar` is True, each row represents a variable, with
            observations in the columns. Otherwise (the default) each
            column represents a variable, while the rows contain
            observations.

    Returns:
        The covariance matrix of the variables, normalised by
        (number of observations - 1).

    Raises:
        ValueError: if `m` has more than 2 dimensions.
        ZeroDivisionError: if there is only a single observation.
    '''
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()
    # Cast m to torch.double before this point if more precision is desired.
    fact = 1.0 / (m.size(1) - 1)
    # Out-of-place subtraction: `m` may be the caller's tensor or a
    # transposed view of it, so `m -= ...` would silently centre the
    # caller's data.
    m = m - torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()
155 | -- mu2 : The sample mean over activations, precalculated on an 156 | representive data set. 157 | -- sigma1: The covariance matrix over activations for generated samples. 158 | -- sigma2: The covariance matrix over activations, precalculated on an 159 | representive data set. 160 | Returns: 161 | -- : The Frechet Distance. 162 | """ 163 | 164 | mu1 = np.atleast_1d(mu1) 165 | mu2 = np.atleast_1d(mu2) 166 | 167 | sigma1 = np.atleast_2d(sigma1) 168 | sigma2 = np.atleast_2d(sigma2) 169 | 170 | assert mu1.shape == mu2.shape, \ 171 | 'Training and test mean vectors have different lengths' 172 | assert sigma1.shape == sigma2.shape, \ 173 | 'Training and test covariances have different dimensions' 174 | 175 | diff = mu1 - mu2 176 | 177 | # Product might be almost singular 178 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 179 | if not np.isfinite(covmean).all(): 180 | msg = ('fid calculation produces singular product; ' 181 | 'adding %s to diagonal of cov estimates') % eps 182 | print(msg) 183 | offset = np.eye(sigma1.shape[0]) * eps 184 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 185 | 186 | # Numerical error might give slight imaginary component 187 | if np.iscomplexobj(covmean): 188 | print('wat') 189 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 190 | m = np.max(np.abs(covmean.imag)) 191 | raise ValueError('Imaginary component {}'.format(m)) 192 | covmean = covmean.real 193 | 194 | tr_covmean = np.trace(covmean) 195 | 196 | out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 197 | return out 198 | 199 | 200 | def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 201 | """Pytorch implementation of the Frechet Distance. 202 | Taken from https://github.com/bioinf-jku/TTUR 203 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 204 | and X_2 ~ N(mu_2, C_2) is 205 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 
# Calculate Inception Score mean + std given softmax'd logits and number of splits
def calculate_inception_score(pred, num_splits=10):
    """Compute the Inception Score over equal-sized chunks of ``pred``.

    Args:
        pred: 2-D array of softmax outputs, one row per sample.
        num_splits: number of consecutive chunks to score; each chunk has
            ``pred.shape[0] // num_splits`` rows, so trailing rows that do
            not fill a chunk are not used.

    Returns:
        tuple: (mean, standard deviation) of the per-chunk scores, where a
        chunk's score is exp(mean over rows of KL(row || chunk marginal)).
    """
    rows_per_split = pred.shape[0] // num_splits
    split_scores = []
    for k in range(num_splits):
        chunk = pred[k * rows_per_split:(k + 1) * rows_per_split, :]
        # Marginal class distribution of this chunk, kept 2-D to broadcast.
        marginal = np.expand_dims(np.mean(chunk, 0), 0)
        per_sample_kl = np.sum(chunk * (np.log(chunk) - np.log(marginal)), 1)
        split_scores.append(np.exp(np.mean(per_sample_kl)))
    return np.mean(split_scores), np.std(split_scores)
# This produces a function which takes in an iterator which returns a set number of samples
# and iterates until it accumulates config['num_inception_images'] images.
# The iterator can return samples with a different batch size than used in
# training, using the setting confg['inception_batchsize']
def prepare_inception_metrics(dataset, parallel, no_fid=False):
    """Load the reference moments and the Inception net; return a metric fn.

    Args:
        dataset (str): dataset name; a trailing ``'_hdf5'`` is removed
            before looking for ``<name>_inception_moments.npz``.
        parallel (bool): forwarded to ``load_inception_net``.
        no_fid (bool): if True the returned function skips the FID
            computation and reports the placeholder ``9999.0``.

    Returns:
        callable: ``get_inception_metrics(sample, num_inception_images,
        num_splits=10, prints=True, use_torch=True)`` returning
        ``(IS_mean, IS_std, FID)``.
    """
    import os  # local import: this module has no file-level `import os`

    # Remove the literal suffix. `str.strip('_hdf5')` strips any of the
    # characters '_', 'h', 'd', 'f', '5' from BOTH ends, which mangles
    # names such as 'food5'.
    suffix = '_hdf5'
    base_name = dataset[:-len(suffix)] if dataset.endswith(suffix) else dataset
    moments_path = base_name + '_inception_moments.npz'
    # Moments files written with the old character-stripping rule may exist
    # on disk under the mangled name; keep finding them.
    legacy_path = dataset.strip('_hdf5') + '_inception_moments.npz'
    if not os.path.exists(moments_path) and os.path.exists(legacy_path):
        moments_path = legacy_path
    # Load metrics; this is intentionally not in a try-except so that the
    # script will crash here if it cannot find the Inception moments.
    with np.load(moments_path) as moments:
        data_mu, data_sigma = moments['mu'], moments['sigma']
    # Load network
    net = load_inception_net(parallel)

    def get_inception_metrics(sample, num_inception_images, num_splits=10,
                              prints=True, use_torch=True):
        if prints:
            print('Gathering activations...')
        pool, logits, labels = accumulate_inception_activations(
            sample, net, num_inception_images)
        if prints:
            print('Calculating Inception Score...')
        IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits)
        if no_fid:
            FID = 9999.0
        else:
            if prints:
                print('Calculating means and covariances...')
            if use_torch:
                mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
            else:
                pool_np = pool.cpu().numpy()
                mu, sigma = np.mean(pool_np, axis=0), np.cov(pool_np, rowvar=False)
            if prints:
                print('Covariances calculated, getting FID...')
            if use_torch:
                FID = torch_calculate_frechet_distance(
                    mu, sigma,
                    torch.tensor(data_mu).float().to(mu.device),
                    torch.tensor(data_sigma).float().to(mu.device))
                FID = float(FID.cpu().numpy())
            else:
                # mu and sigma are already numpy arrays on this path;
                # calling .cpu() on them raised AttributeError.
                FID = numpy_calculate_frechet_distance(mu, sigma, data_mu, data_sigma)
        # Drop the large activation tensors eagerly. mu/sigma are not
        # deleted here because they do not exist when no_fid is True
        # (the old `del mu, sigma, ...` raised in that case).
        del pool, logits, labels
        return IS_mean, IS_std, FID
    return get_inception_metrics
"FID": 113.86730514283107, "_stamp": 1551422228.743057} 3 | {"itr": 6000, "IS_mean": 6.939817905426025, "IS_std": 0.11417163163423538, "FID": 101.63548498447199, "_stamp": 1551457139.3400874} 4 | {"itr": 8000, "IS_mean": 8.142985343933105, "IS_std": 0.11931543797254562, "FID": 92.0014385772705, "_stamp": 1551476217.2409613} 5 | {"itr": 10000, "IS_mean": 10.355518341064453, "IS_std": 0.09094739705324173, "FID": 83.58068997965364, "_stamp": 1551494854.2419689} 6 | {"itr": 12000, "IS_mean": 11.288347244262695, "IS_std": 0.14952820539474487, "FID": 80.98066299357106, "_stamp": 1551513232.5049698} 7 | {"itr": 14000, "IS_mean": 11.755794525146484, "IS_std": 0.17969024181365967, "FID": 76.80603924280956, "_stamp": 1551531425.150371} 8 | {"itr": 18000, "IS_mean": 13.65534496307373, "IS_std": 0.11151058971881866, "FID": 65.95736694335938, "_stamp": 1551588271.9177916} 9 | {"itr": 20000, "IS_mean": 14.817827224731445, "IS_std": 0.23588882386684418, "FID": 61.32061767578125, "_stamp": 1551606713.6567464} 10 | {"itr": 22000, "IS_mean": 17.16551399230957, "IS_std": 0.19506946206092834, "FID": 53.387969970703125, "_stamp": 1551624876.6513028} 11 | {"itr": 24000, "IS_mean": 19.60654067993164, "IS_std": 0.5591856837272644, "FID": 46.5386962890625, "_stamp": 1551642822.6126688} 12 | {"itr": 26000, "IS_mean": 21.74416732788086, "IS_std": 0.2850531041622162, "FID": 41.595001220703125, "_stamp": 1551663522.6019194} 13 | {"itr": 28000, "IS_mean": 23.923612594604492, "IS_std": 0.41587772965431213, "FID": 37.894744873046875, "_stamp": 1551681794.6567173} 14 | {"itr": 30000, "IS_mean": 25.569377899169922, "IS_std": 0.3333457112312317, "FID": 35.49310302734375, "_stamp": 1551699773.7080302} 15 | {"itr": 32000, "IS_mean": 26.867944717407227, "IS_std": 0.5968036651611328, "FID": 33.4849853515625, "_stamp": 1551717623.887933} 16 | {"itr": 34000, "IS_mean": 28.719074249267578, "IS_std": 0.5698027014732361, "FID": 31.375518798828125, "_stamp": 1551735411.1578612} 17 | {"itr": 36000, "IS_mean": 
30.587574005126953, "IS_std": 0.5044271349906921, "FID": 29.432281494140625, "_stamp": 1551783380.6357439} 18 | {"itr": 38000, "IS_mean": 32.08299255371094, "IS_std": 0.49342143535614014, "FID": 28.099456787109375, "_stamp": 1551801179.6495197} 19 | {"itr": 40000, "IS_mean": 34.24657440185547, "IS_std": 0.7709177732467651, "FID": 26.53802490234375, "_stamp": 1551818775.171794} 20 | {"itr": 42000, "IS_mean": 35.891212463378906, "IS_std": 0.7036871314048767, "FID": 25.03021240234375, "_stamp": 1551836329.6873965} 21 | {"itr": 44000, "IS_mean": 38.184898376464844, "IS_std": 0.32996198534965515, "FID": 23.4940185546875, "_stamp": 1551897864.911537} 22 | {"itr": 46000, "IS_mean": 40.239479064941406, "IS_std": 0.7761151194572449, "FID": 22.53167724609375, "_stamp": 1551915406.4840703} 23 | {"itr": 48000, "IS_mean": 41.46656036376953, "IS_std": 1.1031498908996582, "FID": 21.5338134765625, "_stamp": 1551932899.6074848} 24 | {"itr": 50000, "IS_mean": 43.31670379638672, "IS_std": 0.7796809077262878, "FID": 20.53253173828125, "_stamp": 1551950390.345334} 25 | {"itr": 52000, "IS_mean": 45.1517333984375, "IS_std": 1.2925242185592651, "FID": 19.656646728515625, "_stamp": 1551967838.1501615} 26 | {"itr": 54000, "IS_mean": 47.638771057128906, "IS_std": 1.0689665079116821, "FID": 18.898162841796875, "_stamp": 1552044534.5349634} 27 | {"itr": 56000, "IS_mean": 48.87520217895508, "IS_std": 1.1317559480667114, "FID": 18.1248779296875, "_stamp": 1552061763.3080354} 28 | {"itr": 58000, "IS_mean": 49.40987014770508, "IS_std": 1.1866596937179565, "FID": 17.751922607421875, "_stamp": 1552078939.9828825} 29 | {"itr": 60000, "IS_mean": 51.051334381103516, "IS_std": 1.2281248569488525, "FID": 17.19964599609375, "_stamp": 1552096167.889482} 30 | {"itr": 62000, "IS_mean": 52.0235481262207, "IS_std": 0.5391153693199158, "FID": 16.62115478515625, "_stamp": 1552113417.9520617} 31 | {"itr": 64000, "IS_mean": 53.868492126464844, "IS_std": 1.327082633972168, "FID": 16.237335205078125, "_stamp": 
1552142961.09602} 32 | {"itr": 66000, "IS_mean": 54.978721618652344, "IS_std": 0.9502049088478088, "FID": 15.81170654296875, "_stamp": 1552162403.2232807} 33 | {"itr": 68000, "IS_mean": 55.73248291015625, "IS_std": 1.0323851108551025, "FID": 15.545623779296875, "_stamp": 1552181112.676657} 34 | {"itr": 70000, "IS_mean": 56.78422927856445, "IS_std": 1.211003303527832, "FID": 15.28369140625, "_stamp": 1552199498.887533} 35 | {"itr": 72000, "IS_mean": 57.972999572753906, "IS_std": 0.8668608665466309, "FID": 14.86395263671875, "_stamp": 1552217782.2738616} 36 | {"itr": 74000, "IS_mean": 58.845054626464844, "IS_std": 1.4297977685928345, "FID": 14.620635986328125, "_stamp": 1552251085.1781816} 37 | {"itr": 76000, "IS_mean": 59.60982131958008, "IS_std": 0.9095696210861206, "FID": 14.360198974609375, "_stamp": 1552270214.9345307} 38 | {"itr": 78000, "IS_mean": 60.71195602416992, "IS_std": 0.960899829864502, "FID": 14.07183837890625, "_stamp": 1552288697.1580262} 39 | {"itr": 80000, "IS_mean": 61.772125244140625, "IS_std": 0.6913255453109741, "FID": 13.781585693359375, "_stamp": 1552307170.0280282} 40 | {"itr": 82000, "IS_mean": 62.98079299926758, "IS_std": 1.4735801219940186, "FID": 13.55389404296875, "_stamp": 1552325252.8553352} 41 | {"itr": 84000, "IS_mean": 64.95240783691406, "IS_std": 0.9018951654434204, "FID": 13.231689453125, "_stamp": 1552344135.3111835} 42 | {"itr": 86000, "IS_mean": 65.13968658447266, "IS_std": 0.8772205114364624, "FID": 13.176849365234375, "_stamp": 1552362429.6782444} 43 | {"itr": 88000, "IS_mean": 65.84476470947266, "IS_std": 1.167534351348877, "FID": 12.87078857421875, "_stamp": 1552380560.7988124} 44 | {"itr": 90000, "IS_mean": 67.41099548339844, "IS_std": 1.6899267435073853, "FID": 12.586517333984375, "_stamp": 1552398550.2060475} 45 | {"itr": 92000, "IS_mean": 68.63685607910156, "IS_std": 1.9431978464126587, "FID": 12.49505615234375, "_stamp": 1552430781.6406457} 46 | {"itr": 94000, "IS_mean": 70.09907531738281, "IS_std": 
1.0715738534927368, "FID": 12.047607421875, "_stamp": 1552449001.1950285} 47 | {"itr": 96000, "IS_mean": 70.34623718261719, "IS_std": 1.7962944507598877, "FID": 11.896697998046875, "_stamp": 1552466989.3587568} 48 | {"itr": 98000, "IS_mean": 71.08210754394531, "IS_std": 1.458209753036499, "FID": 11.73046875, "_stamp": 1552484800.7138846} 49 | {"itr": 100000, "IS_mean": 72.24256896972656, "IS_std": 1.3259714841842651, "FID": 11.7386474609375, "_stamp": 1552502538.0269725} 50 | {"itr": 102000, "IS_mean": 73.19488525390625, "IS_std": 1.3439149856567383, "FID": 11.50494384765625, "_stamp": 1552523284.4514356} 51 | {"itr": 104000, "IS_mean": 73.38243103027344, "IS_std": 1.4162707328796387, "FID": 11.374542236328125, "_stamp": 1552541012.0651608} 52 | {"itr": 106000, "IS_mean": 74.95563507080078, "IS_std": 1.089124083518982, "FID": 11.10479736328125, "_stamp": 1552558577.7458107} 53 | {"itr": 108000, "IS_mean": 76.42997741699219, "IS_std": 1.9282453060150146, "FID": 10.998870849609375, "_stamp": 1552576111.9480467} 54 | {"itr": 110000, "IS_mean": 76.89225769042969, "IS_std": 1.4771150350570679, "FID": 10.847015380859375, "_stamp": 1552593659.445132} 55 | {"itr": 112000, "IS_mean": 78.04684448242188, "IS_std": 1.4850096702575684, "FID": 10.772552490234375, "_stamp": 1552616479.5201895} 56 | {"itr": 114000, "IS_mean": 79.67677307128906, "IS_std": 2.0147368907928467, "FID": 10.528045654296875, "_stamp": 1552633850.9315467} 57 | {"itr": 116000, "IS_mean": 79.8828125, "IS_std": 0.978247344493866, "FID": 10.626068115234375, "_stamp": 1552651198.9012825} 58 | {"itr": 118000, "IS_mean": 79.95381164550781, "IS_std": 1.8608143329620361, "FID": 10.46771240234375, "_stamp": 1552668560.4420238} 59 | {"itr": 120000, "IS_mean": 82.37217712402344, "IS_std": 1.8909310102462769, "FID": 10.259033203125, "_stamp": 1552749673.4319007} 60 | {"itr": 122000, "IS_mean": 83.49666595458984, "IS_std": 2.38446044921875, "FID": 9.996185302734375, "_stamp": 1552766698.2706933} 61 | {"itr": 124000, 
"IS_mean": 83.05189514160156, "IS_std": 1.8844469785690308, "FID": 10.164398193359375, "_stamp": 1552783762.891172} 62 | {"itr": 126000, "IS_mean": 84.27763366699219, "IS_std": 0.9329544901847839, "FID": 10.03509521484375, "_stamp": 1552800953.5724175} 63 | {"itr": 128000, "IS_mean": 85.84852600097656, "IS_std": 2.2698562145233154, "FID": 9.91644287109375, "_stamp": 1552818112.227726} 64 | {"itr": 130000, "IS_mean": 87.356689453125, "IS_std": 2.0958640575408936, "FID": 9.771148681640625, "_stamp": 1552837539.995247} 65 | {"itr": 132000, "IS_mean": 88.72562408447266, "IS_std": 1.7551432847976685, "FID": 9.8258056640625, "_stamp": 1552859685.9305944} 66 | {"itr": 134000, "IS_mean": 88.0631103515625, "IS_std": 1.8199039697647095, "FID": 9.957183837890625, "_stamp": 1552880037.5408435} 67 | {"itr": 136000, "IS_mean": 91.50938415527344, "IS_std": 1.9926033020019531, "FID": 9.876556396484375, "_stamp": 1552899854.652669} 68 | {"itr": 138000, "IS_mean": 93.09217834472656, "IS_std": 2.3062736988067627, "FID": 9.908477783203125, "_stamp": 1552921580.958927} -------------------------------------------------------------------------------- /biggan_pytorch/logs/compare_IS.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear all 3 | close all 4 | fclose all; 5 | 6 | 7 | 8 | %% Get All logs and sort them 9 | s = {}; 10 | d = dir(); 11 | j = 1; 12 | for i = 1:length(d) 13 | if any(strfind(d(i).name,'.jsonl')) 14 | s = [s; d(i).name]; 15 | end 16 | end 17 | 18 | 19 | j = 1; 20 | for i = 1:length(s) 21 | fname = s{i,1}; 22 | % Check if the Inception metrics log exists, and if so, plot it 23 | [itr, IS, FID, t] = process_inception_log(fname(1:end - 10), 'log.jsonl'); 24 | s{i,2} = itr; 25 | s{i,3} = IS; 26 | s{i,4} = FID; 27 | s{i,5} = max(IS); 28 | s{i,6} = min(FID); 29 | s{i,7} = t; 30 | end 31 | % Sort by Inception Score? 32 | [IS_sorted, IS_index] = sort(cell2mat(s(:,5))); 33 | % Cutoff inception scores below a certain value? 
34 | threshold = 22; 35 | IS_index = IS_index(IS_sorted > threshold); 36 | 37 | % Sort by FID? 38 | [FID_sorted, FID_index] = sort(cell2mat(s(:,6))); 39 | % Cutoff also based on IS? 40 | % threshold = 0; 41 | FID_index = FID_index(IS_sorted > threshold); 42 | 43 | 44 | 45 | %% Plot things? 46 | cc = hsv(length(IS_index)); 47 | legend1 = {}; 48 | legend2 = {}; 49 | make_axis=true;%false % Turn this on to see the axis out to 1e6 iterations 50 | for i=1:length(IS_index) 51 | legend1 = [legend1; s{IS_index(i), 1}]; 52 | figure(1) 53 | plot(s{IS_index(i),2}, s{IS_index(i),3}, 'color', cc(i,:),'linewidth',2) 54 | hold on; 55 | xlabel('itr'); ylabel('IS'); 56 | grid on; 57 | if make_axis 58 | axis([0,1e6,0,80]); % 50% grid on; 59 | end 60 | legend(legend1,'Interpreter','none') 61 | %pause(1) % Turn this on to animate stuff 62 | legend2 = [legend2; s{IS_index(i), 1}]; 63 | figure(2) 64 | plot(s{IS_index(i),2}, s{IS_index(i),4}, 'color', cc(i,:),'linewidth',2) 65 | hold on; 66 | xlabel('itr'); ylabel('FID'); 67 | j = j + 1; 68 | grid on; 69 | if make_axis 70 | axis([0,1e6,0,50]);% grid on; 71 | end 72 | legend(legend2, 'Interpreter','none') 73 | 74 | end 75 | 76 | %% Quick script to plot IS versus timesteps 77 | if 0 78 | figure(3); 79 | this_index=4; 80 | subplot(2,1,1); 81 | %plot(s{this_index, 2}(2:end), s{this_index, 7}(2:end) - s{this_index, 7}(1:end-1), 'r*'); 82 | % xlabel('Iteration');ylabel('\Delta T') 83 | plot(s{this_index, 2}, s{this_index, 7}, 'r*'); 84 | xlabel('Iteration');ylabel('T') 85 | subplot(2,1,2); 86 | plot(s{this_index, 2}, s{this_index, 3}, 'r', 'linewidth',2); 87 | xlabel('Iteration'), ylabel('Inception score') 88 | title(s{this_index,1}) 89 | end -------------------------------------------------------------------------------- /biggan_pytorch/logs/metalog.txt: -------------------------------------------------------------------------------- 1 | datetime: 2019-03-18 13:27:59.181225 2 | config: {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 
8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 400, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': True, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)} 3 | state: {'itr': 137500, 'epoch': 2, 'save_num': 0, 'save_best_num': 1, 'best_IS': 91.509384, 'best_FID': tensor(9.7711, 'config': {'dataset': 'I128_hdf5', 'augment': False, 
'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': False, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'BN_sync': False, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}} 4 | -------------------------------------------------------------------------------- /biggan_pytorch/logs/process_inception_log.m: 
function [itr, IS, FID, t] = process_inception_log(fname, which_log)
% PROCESS_INCEPTION_LOG  Parse an Inception-metrics log stored as JSON lines.
%
%   [itr, IS, FID, t] = process_inception_log(fname, which_log) opens the
%   file named sprintf('%s_%s', fname, which_log) and reads one record per
%   line of the form
%       {"itr": N, "IS_mean": x, "IS_std": x, "FID": x, "_stamp": x}
%
%   Outputs (row vectors, one entry per successfully parsed line):
%       itr - training iteration of the record
%       IS  - mean Inception Score ("IS_mean")
%       FID - Frechet Inception Distance ("FID")
%       t   - value of the "_stamp" field
%   "IS_std" is parsed but not returned.
%
%   Raises an error if the file cannot be opened. Lines that do not match
%   the expected record layout (e.g. a blank trailing line) are skipped
%   instead of aborting with an index-out-of-bounds error.

f = sprintf('%s_%s', fname, which_log);
fid = fopen(f, 'r');
if fid == -1
    error('process_inception_log:cannotOpen', 'Could not open log file "%s".', f);
end

itr = [];
IS = [];
FID = [];
t = [];
i = 1;
while ~feof(fid)
    s = fgets(fid);
    % fgets returns -1 (not a char array) once nothing is left to read.
    if ~ischar(s)
        break
    end
    parsed = sscanf(s, '{"itr": %d, "IS_mean": %f, "IS_std": %f, "FID": %f, "_stamp": %f}');
    % A complete record yields exactly five numbers; skip anything else.
    if numel(parsed) < 5
        continue
    end
    itr(i) = parsed(1);
    IS(i) = parsed(2);
    FID(i) = parsed(4);
    t(i) = parsed(5);
    i = i + 1;
end
fclose(fid);
end
threshold 60 | losses = {}; 61 | for i=1:length(s) 62 | if any(strfind(s{i,1},'D_loss_real.log')) || any(strfind(s{i,1},'D_loss_fake.log')) || any(strfind(s{i,1},'G_loss.log')) 63 | % Select colors 64 | if any(strfind(s{i,1},'D_loss_real.log')) 65 | color1 = [0.7,0.7,1.0]; 66 | color2 = [0, 0, 1]; 67 | dlr = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; 68 | losses = [losses; dlr]; 69 | elseif any(strfind(s{i,1},'D_loss_fake.log')) 70 | color1 = [0.7,1.0,0.7]; 71 | color2 = [0, 1, 0]; 72 | dlf = {s{i,2},s{i,3} wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2}; 73 | losses = [losses; dlf]; 74 | else % g loss 75 | color1 = [1.0, 0.7,0.7]; 76 | color2 = [1, 0, 0]; 77 | gl = {s{i,2},s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1 color2}; 78 | losses = [losses; gl]; 79 | end 80 | figure(1); hold on; 81 | % Plot the unsmoothed losses; we'll plot the smoothed losses later 82 | plot(s{i,2},s{i,3},'color', color1, 'HandleVisibility','off'); 83 | legendL = [legendL; nm{i}]; 84 | continue 85 | end 86 | if any(strfind(s{i,1},'G_')) 87 | legendG = [legendG; nm{i}]; 88 | figure(2); hold on; 89 | plot(s{i,2},s{i,3},'color',Gcc(gi,:),'linewidth',2); 90 | gi = gi+1; 91 | elseif any(strfind(s{i,1},'D_')) 92 | legendD = [legendD; nm{i}]; 93 | figure(3); hold on; 94 | plot(s{i,2},s{i,3},'color',Dcc(di,:),'linewidth',2); 95 | di = di+1; 96 | else 97 | s{i,1} % Debug print to show the name of the log that was not processed. 
import torch
import torch.nn.functional as F


# DCGAN loss
def loss_dcgan_dis(dis_fake, dis_real):
    """Return the DCGAN discriminator loss as a ``(real_term, fake_term)`` pair.

    Each term is the mean softplus over the given discriminator outputs:
    ``softplus(-dis_real)`` for the real term, ``softplus(dis_fake)`` for
    the fake term.
    """
    real_term = F.softplus(-dis_real).mean()
    fake_term = F.softplus(dis_fake).mean()
    return real_term, fake_term


def loss_dcgan_gen(dis_fake):
    """Return the DCGAN generator loss: mean of ``softplus(-dis_fake)``."""
    return F.softplus(-dis_fake).mean()


# Hinge Loss
def loss_hinge_dis(dis_fake, dis_real):
    """Return the hinge discriminator loss as a ``(real_term, fake_term)`` pair.

    The two terms are returned separately; adding them gives the
    single-value form of the hinge loss.
    """
    real_margin = F.relu(1. - dis_real)
    fake_margin = F.relu(1. + dis_fake)
    return real_margin.mean(), fake_margin.mean()


def loss_hinge_gen(dis_fake):
    """Return the hinge generator loss: the negated mean of ``dis_fake``."""
    return -dis_fake.mean()


# Default to hinge loss
generator_loss = loss_hinge_gen
discriminator_loss = loss_hinge_dis
def run(config):
    """Dump the dataset selected by ``config`` into a single HDF5 file.

    Reads ``dataset``, ``compression``, ``num_workers``, ``batch_size``,
    ``data_root`` and ``chunk_size`` from ``config``. As a side effect it
    sets ``config['image_size']`` and replaces the boolean
    ``config['compression']`` with ``'lzf'`` or ``None``.

    The output is written to ``<data_root>/ILSVRC<image_size>.hdf5`` with two
    resizable datasets: ``imgs`` (uint8) and ``labels`` (int64). The file is
    created from the first batch and re-opened in append mode for every
    later batch.

    Raises:
        ValueError: if the dataset name contains ``'hdf5'``, since the
            script would then read from the file it is about to overwrite.
    """
    if 'hdf5' in config['dataset']:
        # The original message used a doubled quote ('you''re'), which Python
        # treats as implicit concatenation and renders as "youre".
        raise ValueError("Reading from an HDF5 file which you will probably be "
                         "about to overwrite! Override this error only if you know "
                         "what you're doing!")
    # Get image size
    config['image_size'] = utils.imsize_dict[config['dataset']]

    # Translate the boolean flag into the value h5py expects.
    config['compression'] = 'lzf' if config['compression'] else None

    # Get dataset
    kwargs = {'num_workers': config['num_workers'], 'pin_memory': False, 'drop_last': False}
    train_loader = utils.get_data_loaders(dataset=config['dataset'],
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          data_root=config['data_root'],
                                          use_multiepoch_sampler=False,
                                          **kwargs)[0]

    # HDF5 supports chunking and compression. You may want to experiment
    # with different chunk sizes to see how it runs on your machines.
    # Chunk Size/compression     Read speed @ 256x256   Read speed @ 128x128  Filesize @ 128x128    Time to write @128x128
    # 1 / None                   20/s
    # 500 / None                 ramps up to 77/s       102/s                 61GB                  23min
    # 500 / LZF                                         8/s                   56GB                  23min
    # 1000 / None                78/s
    # 5000 / None                81/s
    # auto:(125,1,16,32) / None                         11/s                  61GB

    print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (config['dataset'], config['chunk_size'], config['compression']))

    # Build the output path once instead of once per batch.
    image_size = config['image_size']
    hdf5_path = os.path.join(config['data_root'], 'ILSVRC%i.hdf5' % image_size)
    num_images = len(train_loader.dataset)

    # Loop over train loader
    for i, (x, y) in enumerate(tqdm(train_loader)):
        # Map X to [0, 255]; assumes the loader yields images scaled to
        # [-1, 1] -- TODO confirm against utils.get_data_loaders.
        x = (255 * ((x + 1) / 2.0)).byte().numpy()
        # Numpyify y
        y = y.numpy()
        if i == 0:
            # First batch: create the file and both resizable datasets.
            with h5.File(hdf5_path, 'w') as f:
                print('Producing dataset of len %d' % num_images)
                imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8',
                                             maxshape=(num_images, 3, image_size, image_size),
                                             chunks=(config['chunk_size'], 3, image_size, image_size),
                                             compression=config['compression'])
                print('Image chunks chosen as ' + str(imgs_dset.chunks))
                imgs_dset[...] = x
                labels_dset = f.create_dataset('labels', y.shape, dtype='int64',
                                               maxshape=(num_images,),
                                               chunks=(config['chunk_size'],),
                                               compression=config['compression'])
                print('Label chunks chosen as ' + str(labels_dset.chunks))
                labels_dset[...] = y
        else:
            # Later batches: grow each dataset along axis 0 and fill the tail.
            with h5.File(hdf5_path, 'a') as f:
                f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
                f['imgs'][-x.shape[0]:] = x
                f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0)
                f['labels'][-y.shape[0]:] = y
32 | if config['config_from_name']: 33 | utils.load_weights(None, None, state_dict, config['weights_root'], 34 | config['experiment_name'], config['load_weights'], None, 35 | strict=False, load_optim=False) 36 | # Ignore items which we might want to overwrite from the command line 37 | for item in state_dict['config']: 38 | if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']: 39 | config[item] = state_dict['config'][item] 40 | 41 | # update config (see train.py for explanation) 42 | config['resolution'] = utils.imsize_dict[config['dataset']] 43 | config['n_classes'] = utils.nclass_dict[config['dataset']] 44 | config['G_activation'] = utils.activation_dict[config['G_nl']] 45 | config['D_activation'] = utils.activation_dict[config['D_nl']] 46 | config = utils.update_config_roots(config) 47 | config['skip_init'] = True 48 | config['no_optim'] = True 49 | device = 'cuda' 50 | 51 | # Seed RNG 52 | utils.seed_rng(config['seed']) 53 | 54 | # Setup cudnn.benchmark for free speed 55 | torch.backends.cudnn.benchmark = True 56 | 57 | # Import the model--this line allows us to dynamically select different files. 
58 | model = __import__(config['model']) 59 | experiment_name = (config['experiment_name'] if config['experiment_name'] 60 | else utils.name_from_config(config)) 61 | print('Experiment name is %s' % experiment_name) 62 | 63 | G = model.Generator(**config).cuda() 64 | utils.count_parameters(G) 65 | 66 | # Load weights 67 | print('Loading weights...') 68 | # Here is where we deal with the ema--load ema weights or load normal weights 69 | utils.load_weights(G if not (config['use_ema']) else None, None, state_dict, 70 | config['weights_root'], experiment_name, config['load_weights'], 71 | G if config['ema'] and config['use_ema'] else None, 72 | strict=False, load_optim=False) 73 | # Update batch size setting used for G 74 | G_batch_size = max(config['G_batch_size'], config['batch_size']) 75 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], 76 | device=device, fp16=config['G_fp16'], 77 | z_var=config['z_var']) 78 | 79 | if config['G_eval_mode']: 80 | print('Putting G in eval mode..') 81 | G.eval() 82 | else: 83 | print('G is in %s mode...' % ('training' if G.training else 'eval')) 84 | 85 | #Sample function 86 | sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) 87 | if config['accumulate_stats']: 88 | print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations']) 89 | utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], 90 | config['num_standing_accumulations']) 91 | 92 | 93 | # Sample a number of images and save them to an NPZ, for use with TF-Inception 94 | if config['sample_npz']: 95 | # Lists to hold images and labels for images 96 | x, y = [], [] 97 | print('Sampling %d images and saving them to npz...' 
% config['sample_num_npz']) 98 | for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))): 99 | with torch.no_grad(): 100 | images, labels = sample() 101 | x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)] 102 | y += [labels.cpu().numpy()] 103 | x = np.concatenate(x, 0)[:config['sample_num_npz']] 104 | y = np.concatenate(y, 0)[:config['sample_num_npz']] 105 | print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape)) 106 | npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name) 107 | print('Saving npz to %s...' % npz_filename) 108 | np.savez(npz_filename, **{'x' : x, 'y' : y}) 109 | 110 | # Prepare sample sheets 111 | if config['sample_sheets']: 112 | print('Preparing conditional sample sheets...') 113 | utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], 114 | num_classes=config['n_classes'], 115 | samples_per_class=10, parallel=config['parallel'], 116 | samples_root=config['samples_root'], 117 | experiment_name=experiment_name, 118 | folder_number=config['sample_sheet_folder_num'], 119 | z_=z_,) 120 | # Sample interp sheets 121 | if config['sample_interps']: 122 | print('Preparing interp sheets...') 123 | for fix_z, fix_y in zip([False, False, True], [False, True, False]): 124 | utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8, 125 | num_classes=config['n_classes'], 126 | parallel=config['parallel'], 127 | samples_root=config['samples_root'], 128 | experiment_name=experiment_name, 129 | folder_number=config['sample_sheet_folder_num'], 130 | sheet_number=0, 131 | fix_z=fix_z, fix_y=fix_y, device='cuda') 132 | # Sample random sheet 133 | if config['sample_random']: 134 | print('Preparing random sample sheet...') 135 | images, labels = sample() 136 | torchvision.utils.save_image(images.float(), 137 | '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name), 138 | nrow=int(G_batch_size**0.5), 139 | normalize=True) 140 | 141 | # Get Inception Score and FID 
142 | get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) 143 | # Prepare a simple function get metrics that we use for trunc curves 144 | def get_metrics(): 145 | sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config) 146 | IS_mean, IS_std, FID = get_inception_metrics(sample, config['num_inception_images'], num_splits=10, prints=False) 147 | # Prepare output string 148 | outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema') 149 | outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training') 150 | outstring += 'with noise variance %3.3f, ' % z_.var 151 | outstring += 'over %d images, ' % config['num_inception_images'] 152 | if config['accumulate_stats'] or not config['G_eval_mode']: 153 | outstring += 'with batch size %d, ' % G_batch_size 154 | if config['accumulate_stats']: 155 | outstring += 'using %d standing stat accumulations, ' % config['num_standing_accumulations'] 156 | outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID) 157 | print(outstring) 158 | if config['sample_inception_metrics']: 159 | print('Calculating Inception metrics...') 160 | get_metrics() 161 | 162 | # Sample truncation curve stuff. This is basically the same as the inception metrics code 163 | if config['sample_trunc_curves']: 164 | start, step, end = [float(item) for item in config['sample_trunc_curves'].split('_')] 165 | print('Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' 
% (start, step, end)) 166 | for var in np.arange(start, end + step, step): 167 | z_.var = var 168 | # Optionally comment this out if you want to run with standing stats 169 | # accumulated at one z variance setting 170 | if config['accumulate_stats']: 171 | utils.accumulate_standing_stats(G, z_, y_, config['n_classes'], 172 | config['num_standing_accumulations']) 173 | get_metrics() 174 | def main(): 175 | # parse command line and run 176 | parser = utils.prepare_parser() 177 | parser = utils.add_sample_parser(parser) 178 | config = vars(parser.parse_args()) 179 | print(config) 180 | run(config) 181 | 182 | if __name__ == '__main__': 183 | main() -------------------------------------------------------------------------------- /biggan_pytorch/scripts/launch_BigGAN_bs256x8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \ 4 | --num_G_accumulations 8 --num_D_accumulations 8 \ 5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \ 6 | --G_attn 64 --D_attn 64 \ 7 | --G_nl inplace_relu --D_nl inplace_relu \ 8 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \ 9 | --G_ortho 0.0 \ 10 | --G_shared \ 11 | --G_init ortho --D_init ortho \ 12 | --hier --dim_z 120 --shared_dim 128 \ 13 | --G_eval_mode \ 14 | --G_ch 96 --D_ch 96 \ 15 | --ema --use_ema --ema_start 20000 \ 16 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \ 17 | --use_multiepoch_sampler \ -------------------------------------------------------------------------------- /biggan_pytorch/scripts/launch_BigGAN_bs512x4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 512 --load_in_mem \ 4 | --num_G_accumulations 4 --num_D_accumulations 4 \ 5 | --num_D_steps 1 
#!/bin/bash
# Launch train.py with the 64-channel configuration: batch size 256 with
# 8 gradient accumulations for both G and D.
# NOTE: the channel width is passed for both networks (--G_ch and --D_ch);
# the previous version passed --G_ch twice and never set --D_ch.
python train.py \
--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \
--num_G_accumulations 8 --num_D_accumulations 8 \
--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
--G_attn 64 --D_attn 64 \
--G_nl inplace_relu --D_nl inplace_relu \
--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
--G_ortho 0.0 \
--G_shared \
--G_init ortho --D_init ortho \
--hier --dim_z 120 --shared_dim 128 \
--G_eval_mode \
--G_ch 64 --D_ch 64 \
--ema --use_ema --ema_start 20000 \
--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
--use_multiepoch_sampler
#!/bin/bash
# Launch train.py with the SAGAN configuration: batch size 128 with
# 2 gradient accumulations, xavier init, and the EMA flags enabled.
# The final argument deliberately has no trailing backslash, so nothing
# appended after this command can be swallowed into its argument list.
python train.py \
--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 128 \
--num_G_accumulations 2 --num_D_accumulations 2 \
--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
--G_attn 64 --D_attn 64 \
--G_nl relu --D_nl relu \
--SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \
--G_ortho 0.0 \
--G_init xavier --D_init xavier \
--ema --use_ema --ema_start 2000 --G_eval_mode \
--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
--name_suffix SAGAN_ema
#!/bin/bash
# Sample from a model trained with the bs256x8 BigGAN configuration.
# Use z_var to change the variance of z for all the sampling.
# Use --mybn --accumulate_stats --num_standing_accumulations 32 to
# use running stats.
python sample.py \
--dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \
--num_G_accumulations 8 --num_D_accumulations 8 \
--num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
--G_attn 64 --D_attn 64 \
--G_ch 96 --D_ch 96 \
--G_nl inplace_relu --D_nl inplace_relu \
--SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
--G_ortho 0.0 \
--G_shared \
--G_init ortho --D_init ortho --skip_init \
--hier --dim_z 120 --shared_dim 128 \
--ema --ema_start 20000 \
--use_multiepoch_sampler \
--test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
--G_batch_size 512 --use_ema --G_eval_mode --sample_trunc_curves 0.05_0.05_1.0 \
--sample_inception_metrics --sample_npz --sample_random --sample_sheets --sample_interps
#!/bin/bash
# duplicate.sh
# Copy the log directory, the jsonl log, and the weight/optimizer checkpoints
# of experiment ${source} to the new experiment name ${target}.
# Each copy is attempted independently: a missing file does not stop the
# remaining copies.
source=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0
target=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0A
logs_root=logs
weights_root=weights
echo "copying ${source} to ${target}"
# Quote every expansion so that roots or names containing whitespace or glob
# characters are passed to cp as single arguments.
cp -r "${logs_root}/${source}" "${logs_root}/${target}"
cp "${logs_root}/${source}_log.jsonl" "${logs_root}/${target}_log.jsonl"
for suffix in G G_ema D G_optim D_optim state_dict; do
  cp "${weights_root}/${source}_${suffix}.pth" "${weights_root}/${target}_${suffix}.pth"
done
#!/bin/bash
# Build the I128 HDF5 file, then compute the Inception moments for it.
# Abort on the first failing step: the second command reads the dataset
# named I128_hdf5, which presumably is the file produced by the first step
# (NOTE(review): confirm against make_hdf5.py).
set -e
python make_hdf5.py --dataset I128 --batch_size 256 --data_root data
python calculate_inception_moments.py --dataset I128_hdf5 --data_root data
def _sum_ft(tensor):
    """Sum over the first and last dimensions."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


# Message each replica sends to the master: per-channel sum, per-channel
# square-sum, and the number of elements those sums were taken over.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
# Message the master sends back to each replica. The first field is named
# ``sum`` but is filled with the per-channel mean.
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are reduced across replicas.

    Outside of data-parallel training the layer falls back to
    ``F.batch_norm``. In both paths an optional per-call ``gain`` and ``bias``
    can be supplied; ``gain`` multiplies the normalized activations and
    ``bias`` is added afterwards.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        # Filled in by ``__data_parallel_replicate__`` when the module is replicated.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input, gain=None, bias=None):
        """Normalize ``input`` and optionally apply a per-call gain and bias.

        Args:
            input: tensor whose first two dimensions are batch and channel.
            gain: optional multiplicative scale, broadcastable against ``input``.
            bias: optional additive shift, broadcastable against ``input``.

        Returns:
            A tensor with the same shape as ``input``.
        """
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            out = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)
            # The gain is a scale: apply it multiplicatively so this path
            # agrees with the synchronized path below.
            if gain is not None:
                out = out * gain
            if bias is not None:
                out = out + bias
            return out

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), input.size(1), -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        # Compute the output.
        if gain is not None:
            # ``squeeze(-1)`` drops the trailing singleton so the gain/bias
            # line up with the (B, C, -1) view of the input.
            output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1))
            if bias is not None:
                output = output + bias.squeeze(-1)
        elif self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Replication callback: wire this copy to the shared sync master.

        Args:
            ctx: context object shared by all copies of this module.
            copy_id: index of this copy; ``0`` is the master copy.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: list of ``(identifier, _ChildMessage)`` pairs, one
                per replica.

        Returns:
            A list of ``(identifier, _MasterMessage)`` pairs, one per replica.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            # Two broadcast tensors (mean, inv_std) per device, laid out consecutively.
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard-deviation from sum and square-sum.

        This method also maintains the moving averages: ``running_var`` is
        updated with the unbiased variance, while the returned inverse
        standard-deviation uses the biased variance.

        Args:
            sum_: per-channel sum of the inputs.
            ssum: per-channel sum of the squared inputs.
            size: number of elements each sum was taken over; must be > 1.

        Returns:
            A ``(mean, inv_std)`` tuple.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        return mean, torch.rsqrt(bias_var + self.eps)


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module falls back to
    PyTorch's ``F.batch_norm``.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
        # No delegation to the base class: ``_BatchNorm._check_input_dim``
        # raises NotImplementedError on current PyTorch.
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module falls back to
    PyTorch's ``F.batch_norm``.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D."""
        # No delegation to the base class: ``_BatchNorm._check_input_dim``
        # raises NotImplementedError on current PyTorch.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module falls back to
    PyTorch's ``F.batch_norm``.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5D."""
        # No delegation to the base class: ``_BatchNorm._check_input_dim``
        # raises NotImplementedError on current PyTorch.
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNormReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | 
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    ``None`` is used internally as the "no result yet" marker, so ``None``
    itself cannot be transmitted as a result.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake up the waiting consumer."""
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return and clear it."""
        with self._lock:
            # Re-check the predicate in a loop: a condition variable may wake
            # up without a matching ``put`` (spurious wakeup).
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        """Send ``msg`` to the master and return the master's reply.

        After the reply is received, ``True`` is pushed onto the shared queue
        as an acknowledgement that the master drains in ``run_master``.
        """
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
    call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message passed
    back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is kept; the queue and registry are rebuilt
        # from scratch in ``__setstate__``.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            # First registration after a forward pass: start a fresh registry.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each devices
        (including the master device).

        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Drain the acknowledgement outside of the ``assert`` so the queue
            # is still emptied when assertions are disabled (``python -O``).
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slave devices."""
        return len(self._registry)
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` subclass that runs replication callbacks.

    Right after the stock `replicate` has produced the per-device copies,
    `execute_replication_callbacks` is run on them, which invokes
    `__data_parallel_replicate__(ctx, copy_id)` on every sub-module that
    defines it.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` onto `device_ids`, then fire the callbacks on the copies."""
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
class TorchTestCase(unittest.TestCase):
    """`unittest.TestCase` with a tensor-aware closeness assertion."""

    def assertTensorClose(self, x, y):
        """Assert `torch.allclose(x, y)`, reporting the observed differences on failure.

        The failure message contains:
          * adiff -- the largest absolute element-wise difference;
          * rdiff -- the largest element-wise difference relative to `y`,
            computed only where `y` is non-zero ('NaN' when no such element
            exists), so isolated zeros in `y` no longer turn the figure
            into `inf`.

        Empty tensors are handled (adiff is reported as 0.0) instead of
        failing inside `.max()`.
        """
        diff = (x - y).abs()
        adiff = float(diff.max()) if diff.numel() else 0.0

        # Broadcast |y| to the shape of the difference so the mask lines up
        # even when x and y were broadcast against each other.
        denom = y.abs().expand_as(diff)
        nonzero = denom != 0
        if nonzero.any():
            rdiff = float((diff[nonzero] / denom[nonzero]).max())
        else:
            rdiff = 'NaN'

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
Andonian 3 | This code is an unofficial reimplementation of 4 | "Large-Scale GAN Training for High Fidelity Natural Image Synthesis," 5 | by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096). 6 | 7 | Let's go. 8 | """ 9 | 10 | import os 11 | import functools 12 | import math 13 | import numpy as np 14 | from tqdm import tqdm, trange 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import init 20 | import torch.optim as optim 21 | import torch.nn.functional as F 22 | from torch.nn import Parameter as P 23 | import torchvision 24 | 25 | # Import my stuff 26 | import inception_utils 27 | import utils 28 | import losses 29 | import train_fns 30 | from sync_batchnorm import patch_replication_callback 31 | 32 | # The main training file. Config is a dictionary specifying the configuration 33 | # of this training run. 34 | def run(config): 35 | 36 | # Update the config dict as necessary 37 | # This is for convenience, to add settings derived from the user-specified 38 | # configuration into the config-dict (e.g. inferring the number of classes 39 | # and size of the images from the dataset, passing in a pytorch object 40 | # for the activation specified as a string) 41 | config['resolution'] = utils.imsize_dict[config['dataset']] 42 | config['n_classes'] = utils.nclass_dict[config['dataset']] 43 | config['G_activation'] = utils.activation_dict[config['G_nl']] 44 | config['D_activation'] = utils.activation_dict[config['D_nl']] 45 | # By default, skip init if resuming training. 
46 | if config['resume']: 47 | print('Skipping initialization for training resumption...') 48 | config['skip_init'] = True 49 | config = utils.update_config_roots(config) 50 | device = 'cuda' 51 | 52 | # Seed RNG 53 | utils.seed_rng(config['seed']) 54 | 55 | # Prepare root folders if necessary 56 | utils.prepare_root(config) 57 | 58 | # Setup cudnn.benchmark for free speed 59 | torch.backends.cudnn.benchmark = True 60 | 61 | # Import the model--this line allows us to dynamically select different files. 62 | model = __import__(config['model']) 63 | experiment_name = (config['experiment_name'] if config['experiment_name'] 64 | else utils.name_from_config(config)) 65 | print('Experiment name is %s' % experiment_name) 66 | 67 | # Next, build the model 68 | G = model.Generator(**config).to(device) 69 | D = model.Discriminator(**config).to(device) 70 | 71 | # If using EMA, prepare it 72 | if config['ema']: 73 | print('Preparing EMA for G with decay of {}'.format(config['ema_decay'])) 74 | G_ema = model.Generator(**{**config, 'skip_init':True, 75 | 'no_optim': True}).to(device) 76 | ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start']) 77 | else: 78 | G_ema, ema = None, None 79 | 80 | # FP16? 81 | if config['G_fp16']: 82 | print('Casting G to float16...') 83 | G = G.half() 84 | if config['ema']: 85 | G_ema = G_ema.half() 86 | if config['D_fp16']: 87 | print('Casting D to fp16...') 88 | D = D.half() 89 | # Consider automatically reducing SN_eps? 
90 | GD = model.G_D(G, D) 91 | print(G) 92 | print(D) 93 | print('Number of params in G: {} D: {}'.format( 94 | *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]])) 95 | # Prepare state dict, which holds things like epoch # and itr # 96 | state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 97 | 'best_IS': 0, 'best_FID': 999999, 'config': config} 98 | 99 | # If loading from a pre-trained model, load weights 100 | if config['resume']: 101 | print('Loading weights...') 102 | utils.load_weights(G, D, state_dict, 103 | config['weights_root'], experiment_name, 104 | config['load_weights'] if config['load_weights'] else None, 105 | G_ema if config['ema'] else None) 106 | 107 | # If parallel, parallelize the GD module 108 | if config['parallel']: 109 | GD = nn.DataParallel(GD) 110 | if config['cross_replica']: 111 | patch_replication_callback(GD) 112 | 113 | # Prepare loggers for stats; metrics holds test metrics, 114 | # lmetrics holds any desired training metrics. 115 | test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], 116 | experiment_name) 117 | train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name) 118 | print('Inception Metrics will be saved to {}'.format(test_metrics_fname)) 119 | test_log = utils.MetricsLogger(test_metrics_fname, 120 | reinitialize=(not config['resume'])) 121 | print('Training Metrics will be saved to {}'.format(train_metrics_fname)) 122 | train_log = utils.MyLogger(train_metrics_fname, 123 | reinitialize=(not config['resume']), 124 | logstyle=config['logstyle']) 125 | # Write metadata 126 | utils.write_metadata(config['logs_root'], experiment_name, config, state_dict) 127 | # Prepare data; the Discriminator's batch size is all that needs to be passed 128 | # to the dataloader, as G doesn't require dataloading. 
129 | # Note that at every loader iteration we pass in enough data to complete 130 | # a full D iteration (regardless of number of D steps and accumulations) 131 | D_batch_size = (config['batch_size'] * config['num_D_steps'] 132 | * config['num_D_accumulations']) 133 | loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 134 | 'start_itr': state_dict['itr']}) 135 | 136 | # Prepare inception metrics: FID and IS 137 | get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid']) 138 | 139 | # Prepare noise and randomly sampled label arrays 140 | # Allow for different batch sizes in G 141 | G_batch_size = max(config['G_batch_size'], config['batch_size']) 142 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], 143 | device=device, fp16=config['G_fp16']) 144 | # Prepare a fixed z & y to see individual sample evolution throghout training 145 | fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, 146 | config['n_classes'], device=device, 147 | fp16=config['G_fp16']) 148 | fixed_z.sample_() 149 | fixed_y.sample_() 150 | # Loaders are loaded, prepare the training function 151 | if config['which_train_fn'] == 'GAN': 152 | train = train_fns.GAN_training_function(G, D, GD, z_, y_, 153 | ema, state_dict, config) 154 | # Else, assume debugging and use the dummy train fn 155 | else: 156 | train = train_fns.dummy_training_function() 157 | # Prepare Sample function for use with inception metrics 158 | sample = functools.partial(utils.sample, 159 | G=(G_ema if config['ema'] and config['use_ema'] 160 | else G), 161 | z_=z_, y_=y_, config=config) 162 | 163 | print('Beginning training at epoch %d...' % state_dict['epoch']) 164 | # Train for specified number of epochs, although we mostly track G iterations. 165 | for epoch in range(state_dict['epoch'], config['num_epochs']): 166 | # Which progressbar to use? TQDM or my own? 
167 | if config['pbar'] == 'mine': 168 | pbar = utils.progress(loaders[0],displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta') 169 | else: 170 | pbar = tqdm(loaders[0]) 171 | for i, (x, y) in enumerate(pbar): 172 | # Increment the iteration counter 173 | state_dict['itr'] += 1 174 | # Make sure G and D are in training mode, just in case they got set to eval 175 | # For D, which typically doesn't have BN, this shouldn't matter much. 176 | G.train() 177 | D.train() 178 | if config['ema']: 179 | G_ema.train() 180 | if config['D_fp16']: 181 | x, y = x.to(device).half(), y.to(device) 182 | else: 183 | x, y = x.to(device), y.to(device) 184 | metrics = train(x, y) 185 | train_log.log(itr=int(state_dict['itr']), **metrics) 186 | 187 | # Every sv_log_interval, log singular values 188 | if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])): 189 | train_log.log(itr=int(state_dict['itr']), 190 | **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')}) 191 | 192 | # If using my progbar, print metrics. 
def GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config):
    """Build the per-iteration GAN training closure.

    The returned `train(x, y)` runs `config['num_D_steps']` discriminator
    updates (each accumulating gradients over `config['num_D_accumulations']`
    chunks of the incoming batch), then one generator update (accumulating
    over `config['num_G_accumulations']` freshly sampled z/y batches), then
    optionally updates the EMA, and returns the last loss values as floats.
    """
    def train(x, y):
        G.optim.zero_grad()
        D.optim.zero_grad()

        batch_size = config['batch_size']
        # Cut the incoming data into per-accumulation chunks.
        x_chunks = torch.split(x, batch_size)
        y_chunks = torch.split(y, batch_size)
        chunk_idx = 0

        # Optionally make only D trainable during the D phase.
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for _ in range(config['num_D_steps']):
            # Start each D step from clean gradients, then accumulate.
            D.optim.zero_grad()
            for _ in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                D_fake, D_real = GD(z_[:batch_size], y_[:batch_size],
                                    x_chunks[chunk_idx], y_chunks[chunk_idx],
                                    train_G=False, split_D=config['split_D'])

                # Sum the two loss components and scale by the number of
                # accumulations so the accumulated gradient is an average.
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
                D_loss.backward()
                chunk_idx += 1

            # Optional orthogonal regularisation on D.
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally flip trainability for the G phase.
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Clear G's gradients before its update, for safety.
        G.optim.zero_grad()

        for _ in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
            G_loss.backward()

        # Optional modified orthogonal regularisation on G.
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G
            print('using modified ortho reg in G')
            # The shared embedding is excluded: regularising it makes no sense.
            utils.ortho(G, config['G_ortho'],
                        blacklist=list(G.shared.parameters()))
        G.optim.step()

        # Keep the EMA current whether or not it is used for testing.
        if config['ema']:
            ema.update(state_dict['itr'])

        # Report G's loss and the two components of D's loss.
        return {'G_loss': float(G_loss.item()),
                'D_loss_real': float(D_loss_real.item()),
                'D_loss_fake': float(D_loss_fake.item())}
    return train
''' 98 | def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, 99 | state_dict, config, experiment_name): 100 | utils.save_weights(G, D, state_dict, config['weights_root'], 101 | experiment_name, None, G_ema if config['ema'] else None) 102 | # Save an additional copy to mitigate accidental corruption if process 103 | # is killed during a save (it's happened to me before -.-) 104 | if config['num_save_copies'] > 0: 105 | utils.save_weights(G, D, state_dict, config['weights_root'], 106 | experiment_name, 107 | 'copy%d' % state_dict['save_num'], 108 | G_ema if config['ema'] else None) 109 | state_dict['save_num'] = (state_dict['save_num'] + 1 ) % config['num_save_copies'] 110 | 111 | # Use EMA G for samples or non-EMA? 112 | which_G = G_ema if config['ema'] and config['use_ema'] else G 113 | 114 | # Accumulate standing statistics? 115 | if config['accumulate_stats']: 116 | utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G, 117 | z_, y_, config['n_classes'], 118 | config['num_standing_accumulations']) 119 | 120 | # Save a random sample sheet with fixed z and y 121 | with torch.no_grad(): 122 | if config['parallel']: 123 | fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y))) 124 | else: 125 | fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y)) 126 | if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)): 127 | os.mkdir('%s/%s' % (config['samples_root'], experiment_name)) 128 | image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'], 129 | experiment_name, 130 | state_dict['itr']) 131 | torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename, 132 | nrow=int(fixed_Gz.shape[0] **0.5), normalize=True) 133 | # For now, every time we save, also save sample sheets 134 | utils.sample_sheet(which_G, 135 | classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']], 136 | num_classes=config['n_classes'], 137 | samples_per_class=10, parallel=config['parallel'], 
def test(G, D, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics,
         experiment_name, test_log):
    """Run the inception metrics, track the best scores and log the results.

    Computes IS/FID through `get_inception_metrics(sample, ...)`, saves a
    rotating 'best%d' checkpoint when the metric selected by
    `config['which_best']` ('IS' or 'FID') improved, updates
    `state_dict['best_IS']` / `state_dict['best_FID']`, and writes one entry
    to `test_log`.

    When `config['num_best_copies']` is not positive no best checkpoint is
    written; this mirrors how `save_and_sample` treats `num_save_copies` and
    avoids the modulo-by-zero the rotation would otherwise hit.
    """
    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(
            G_ema if config['ema'] and config['use_ema'] else G,
            z_, y_, config['n_classes'],
            config['num_standing_accumulations'])
    IS_mean, IS_std, FID = get_inception_metrics(sample,
                                                 config['num_inception_images'],
                                                 num_splits=10)
    print('Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID))

    # Did the user-selected metric beat the previous best?
    improved = ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS'])
                or (config['which_best'] == 'FID' and FID < state_dict['best_FID']))
    if improved and config['num_best_copies'] > 0:
        print('%s improved over previous best, saving checkpoint...' % config['which_best'])
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, 'best%d' % state_dict['save_best_num'],
                           G_ema if config['ema'] else None)
        # Rotate through the configured number of best-checkpoint slots.
        state_dict['save_best_num'] = (state_dict['save_best_num'] + 1) % config['num_best_copies']

    # Both running bests are updated regardless of which one is tracked.
    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)
    # Log results to file
    test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean),
                 IS_std=float(IS_std), FID=float(FID))


# The name matches pytest's default `test*` pattern; mark it explicitly so a
# test module importing this function does not try to collect it as a test.
test.__test__ = False
class ImagenetDataset(Dataset):
    """Dataset of (latent, label mask, class id) triples for training BigDatasetGAN.

    Reads `<data_root>/annotations/biggan512/<class>/<file>` for the label
    images and `<data_root>/latents/biggan512/<class>/<file>` for the latent
    arrays. Within each class directory the sorted label files are paired
    positionally with the sorted latent files.
    """

    def __init__(self, data_root):
        self.label_dir = os.path.join(data_root, 'annotations/biggan512/')
        self.latent_dir = os.path.join(data_root, 'latents/biggan512/')

        self._prepare_data_list()

    def _prepare_data_list(self):
        """Index every (label, latent) file pair as paths relative to the two roots."""
        labels, latents = [], []
        for class_name in sorted(os.listdir(self.label_dir)):
            label_files = sorted(os.listdir(os.path.join(self.label_dir, class_name)))
            latent_files = sorted(os.listdir(os.path.join(self.latent_dir, class_name)))

            # zip pairs by position and stops at the shorter listing.
            for label_file, latent_file in zip(label_files, latent_files):
                labels.append(os.path.join(class_name, label_file))
                latents.append(os.path.join(class_name, latent_file))

        self.label_list = labels
        self.latent_list = latents
        self.data_size = len(labels)

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        """Load sample `idx` as a dict with keys 'latent', 'label' and 'y'."""
        # First entry of the stored array is used as the latent.
        latent_np = np.load(os.path.join(self.latent_dir, self.latent_list[idx]))[0]

        label_rel = self.label_list[idx]
        mask = np.array(Image.open(os.path.join(self.label_dir, label_rel)).convert('L'))
        # Binarise: every non-zero pixel becomes 1.
        mask[mask != 0] = 1

        # Class id is the second-to-last '_'-separated field of the relative
        # label path, taken up to its first '.'.
        stem = label_rel.split('.')[0]
        class_y = int(stem.split('_')[-2])

        return {
            'latent': torch.tensor(latent_np, dtype=torch.float),
            'label': torch.tensor(mask, dtype=torch.long),
            'y': torch.tensor(class_y, dtype=torch.long),
        }
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import json 10 | 11 | # with open("./imagenet2synset.json", 'r') as f: 12 | # imagenet_to_synset = json.load(f) 13 | 14 | # with open("./imagenetclass992.json", 'r') as f: 15 | # synset_to_imagenet = json.load(f) 16 | 17 | pascal_to_synset = { 18 | "bird": ["cock (n01514668)", "hen (n01514859)", "water ouzel (n01601694)", "house finch (n01532829)", "brambling (n01530575)", "junco (n01534433)", "goldfinch (n01531178)", "indigo bunting (n01537544)", "chickadee (n01592084)", "robin (n01558993)", "bulbul (n01560419)", "magpie (n01582220)", "jay (n01580077)", "black swan (n01860187)", "crane (n02012849)", "spoonbill (n02006656)", "flamingo (n02007558)", "bustard (n02018795)", "limpkin (n02013706)", "bittern (n02011460)", "little blue heron (n02009229)", "American egret (n02009912)", "oystercatcher (n02037110)", "dowitcher (n02033041)", "red-backed sandpiper (n02027492)", "redshank (n02028035)", "ruddy turnstone (n02025239)", "black stork (n02002724)", "white stork (n02002556)", "American coot (n02018207)", "king penguin (n02056570)", "albatross (n02058221)", "pelican (n02051845)", "European gallinule (n02017213)", "goose (n01855672)", "drake (n01847000)", "red-breasted merganser (n01855032)", "coucal (n01824575)", "hummingbird (n01833805)", "ostrich (n01518878)", "ruffed grouse (n01797886)", "black grouse (n01795545)", "prairie chicken (n01798484)", "ptarmigan (n01796340)", "quail (n01806567)", "partridge (n01807496)", "peacock 
(n01806143)", "bee eater (n01828970)", "hornbill (n01829413)", "vulture (n01616318)", "great grey owl (n01622779)", "kite (n01608432)", "bald eagle (n01614925)", "jacamar (n01843065)", "toucan (n01843383)", "macaw (n01818515)", "African grey (n01817953)", "lorikeet (n01820546)", "sulphur-crested cockatoo (n01819313)"], 19 | "boat": ["gondola (n03447447)", "fireboat (n03344393)", "yawl (n04612504)", "canoe (n02951358)", "lifeboat (n03662601)", "speedboat (n04273569)"], 20 | "bottle": ["beer bottle (n02823428)", "wine bottle (n04591713)", "water bottle (n04557648)", "pop bottle (n03983396)", "pill bottle (n03937543)", "whiskey jug (n04579145)", "water jug (n04560804)"], 21 | "bus": ["minibus (n03769881)", "school bus (n04146614)", "trolleybus (n04487081)"], 22 | "car": ["ambulance (n02701002)", "limousine (n03670208)", "jeep (n03594945)", "Model T (n03777568)", "cab (n02930766)", "minivan (n03770679)", "convertible (n03100240)", "racer (n04037443)", "beach wagon (n02814533)", "sports car (n04285008)"], 23 | "cat": ["Egyptian cat (n02124075)", "Persian cat (n02123394)", "tabby (n02123045)", "Siamese cat (n02123597)", "tiger cat (n02123159)", "cougar (n02125311)", "lynx (n02127052)"], 24 | "chair": ["barber chair (n02791124)", "rocking chair (n04099969)", "folding chair (n03376595)", "throne (n04429376)"], 25 | "diningtable": ["dining table (n03201208)"], 26 | "dog": ["dalmatian (n02110341)", "Mexican hairless (n02113978)", "pug (n02110958)", "Newfoundland (n02111277)", "Leonberg (n02111129)", "basenji (n02110806)", "Great Pyrenees (n02111500)", "Eskimo dog (n02109961)", "bull mastiff (n02108422)", "Saint Bernard (n02109525)", "Great Dane (n02109047)", "boxer (n02108089)", "Rottweiler (n02106550)", "Old English sheepdog (n02105641)", "Shetland sheepdog (n02105855)", "kelpie (n02105412)", "Border collie (n02106166)", "Bouvier des Flandres (n02106382)", "German shepherd (n02106662)", "komondor (n02105505)", "briard (n02105251)", "collie (n02106030)", "groenendael 
(n02105056)", "malinois (n02105162)", "French bulldog (n02108915)", "kuvasz (n02104029)", "schipperke (n02104365)", "Doberman (n02107142)", "affenpinscher (n02110627)", "miniature pinscher (n02107312)", "Tibetan mastiff (n02108551)", "Siberian husky (n02110185)", "malamute (n02110063)", "Bernese mountain dog (n02107683)", "Appenzeller (n02107908)", "EntleBucher (n02108000)", "Greater Swiss Mountain dog (n02107574)", "toy poodle (n02113624)", "miniature poodle (n02113712)", "standard poodle (n02113799)", "Pembroke (n02113023)", "Cardigan (n02113186)", "Rhodesian ridgeback (n02087394)", "Scottish deerhound (n02092002)", "bloodhound (n02088466)", "otterhound (n02091635)", "Afghan hound (n02088094)", "redbone (n02090379)", "bluetick (n02088632)", "basset (n02088238)", "Ibizan hound (n02091244)", "Saluki (n02091831)", "Norwegian elkhound (n02091467)", "beagle (n02088364)", "Weimaraner (n02092339)", "black-and-tan coonhound (n02089078)", "borzoi (n02090622)", "Irish wolfhound (n02090721)", "English foxhound (n02089973)", "Walker hound (n02089867)", "whippet (n02091134)", "Italian greyhound (n02091032)", "Dandie Dinmont (n02096437)", "Norwich terrier (n02094258)", "Border terrier (n02093754)", "West Highland white terrier (n02098286)", "Yorkshire terrier (n02094433)", "Airedale (n02096051)", "Irish terrier (n02093991)", "Bedlington terrier (n02093647)", "Norfolk terrier (n02094114)", "Lhasa (n02098413)", "silky terrier (n02097658)", "Kerry blue terrier (n02093859)", "Scotch terrier (n02097298)", "Tibetan terrier (n02097474)", "cairn (n02096177)", "soft-coated wheaten terrier (n02098105)", "Boston bull (n02096585)", "Australian terrier (n02096294)", "Staffordshire bullterrier (n02093256)", "American Staffordshire terrier (n02093428)", "Lakeland terrier (n02095570)", "Sealyham terrier (n02095889)", "giant schnauzer (n02097130)", "miniature schnauzer (n02097047)", "standard schnauzer (n02097209)", "wire-haired fox terrier (n02095314)", "curly-coated retriever (n02099429)", 
def parse_synset_str(x):
    """Return the text between the first '(' in `x` and the next ')'.

    Example: 'cock (n01514668)' -> 'n01514668'.

    Raises:
        IndexError: if `x` has no '(' or no ')' after it. The exception type
            matches what the previous character-by-character scan raised when
            it ran off the end of the string; the message now names the input.
    """
    start = x.find('(')
    end = x.find(')', start + 1) if start != -1 else -1
    if end == -1:
        raise IndexError('no parenthesised synset id in %r' % (x,))
    return x[start + 1:end]
13, 74 | #'pottedplant': 15, 75 | 'sheep': 14, 76 | #'sofa': 17, 77 | 'train': 15, 78 | #'tvmonitor': 19, 79 | } 80 | 81 | pascal_to_coco = { 82 | 'background': 0, 83 | 'aeroplane': 5, 84 | 'bicycle': 2, 85 | 'bird': 16, 86 | 'boat': 9, 87 | 'bottle': 44, 88 | 'bus': 6, 89 | 'car': 3, 90 | 'cat': 17, 91 | 'chair': 62, 92 | #'cow': 21, 93 | 'diningtable': 67, 94 | 'dog': 18, 95 | 'horse': 19, 96 | #'motorbike': 4, 97 | 'person': 1, 98 | #'pottedplant': 64, 99 | 'sheep': 20, 100 | #'sofa': 63, 101 | 'train': 7, 102 | #'tvmonitor': 72, 103 | } 104 | 105 | random_synset_100 = ['n02966193', 'n02825657', 'n02708093', 'n07716906', 106 | 'n02325366', 'n02025239', 'n02097209', 'n02106662', 'n02277742', 'n02117135', 107 | 'n02087394', 'n01601694', 'n02447366', 'n01682714', 'n03884397', 'n01537544', 'n03063689', 108 | 'n04606251', 'n02493509', 'n02090622', 'n02071294', 'n13037406', 'n04146614', 'n02342885', 109 | 'n02110958', 'n03223299', 'n02963159', 'n02093859', 'n01494475', 'n01955084', 'n02490219', 110 | 'n02840245', 'n02108000', 'n01944390', 'n01860187', 'n02113799', 'n01910747', 'n02086910', 111 | 'n01978455', 'n02107312', 'n02965783', 'n02013706', 'n04033901', 'n01692333', 'n03207941', 112 | 'n02109961', 'n02687172', 'n02002724', 'n01775062', 'n02104365', 'n01749939', 'n01945685', 113 | 'n01704323', 'n04136333', 'n02105855', 'n02443484', 'n02056570', 'n02403003', 'n02134418', 114 | 'n03417042', 'n02096051', 'n02978881', 'n01531178', 'n03065424', 'n01806567', 'n02100877', 115 | 'n03126707', 'n01843065', 'n02814860', 'n02088238', 'n02999410', 'n01484850', 'n02259212', 116 | 'n02097474', 'n02877765', 'n02099712', 'n02123159', 'n01630670', 'n04252077', 'n03218198', 117 | 'n02489166', 'n02727426', 'n02097047', 'n02492035', 'n01728572', 'n03337140', 'n02268853', 118 | 'n01872401', 'n02094433', 'n02206856', 'n01753488', 'n02910353', 'n02114855', 'n03179701', 119 | 'n01498041', 'n04009552', 'n02177972', 'n03016953', 'n02894605', 'n01843383'] 120 | 121 | 122 | def 
get_imagenet_id_list(class_name): 123 | id_list = [] 124 | if class_name == 'imagenet-dog': 125 | for x in pascal_to_synset['dog']: 126 | synset_id = parse_synset_str(x) 127 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id']) 128 | 129 | elif class_name == 'imagenet-bird': 130 | for x in pascal_to_synset['bird']: 131 | synset_id = parse_synset_str(x) 132 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id']) 133 | 134 | elif class_name == 'imagenet-pascal': 135 | for key in pascal_to_synset: 136 | for x in pascal_to_synset[key]: 137 | synset_id = parse_synset_str(x) 138 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id']) 139 | 140 | elif class_name == 'imagenet-100': 141 | for synset_id in random_synset_100: 142 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id']) 143 | 144 | else: 145 | for key in synset_to_imagenet: 146 | if key != 'background': 147 | id_list.append(synset_to_imagenet[key]['imagenet_id']) 148 | 149 | return id_list -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 

import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('./biggan_pytorch')
from biggan_pytorch import BigGAN


class SegBlock(nn.Module):
    """Class-conditional residual block used by the segmentation head.

    The main branch is: conditional batch norm -> activation ->
    (optional upsample) -> conv -> conditional batch norm -> activation ->
    conv. The shortcut branch goes through a learned 1x1 conv when the
    channel count changes or an upsample callable is given; otherwise the
    (optionally upsampled) input is added unchanged.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of the output feature map.
        con_channels: size of the conditioning vector passed to the
            conditional batch norms.
        which_conv: factory called as ``which_conv(in_ch, out_ch)``.
            NOTE(review): the default ``nn.Conv2d`` requires a kernel size,
            so callers are expected to pass a pre-configured factory.
        which_linear: linear-layer factory forwarded to
            ``BigGAN.layers.ccbn``.
        activation: callable applied after each norm layer; ``forward``
            calls it unconditionally, so it must be provided.
        upsample: optional callable applied to both branches before the
            first conv.
    """

    def __init__(self, in_channels, out_channels, con_channels,
                 which_conv=nn.Conv2d, which_linear=None, activation=None,
                 upsample=None):
        super(SegBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_linear = which_conv, which_linear
        self.activation = activation
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = BigGAN.layers.ccbn(in_channels, con_channels, self.which_linear, eps=1e-4, norm_style='bn')
        self.bn2 = BigGAN.layers.ccbn(out_channels, con_channels, self.which_linear, eps=1e-4, norm_style='bn')

    def forward(self, x, y):
        """Run the block on feature map ``x`` conditioned on ``y``."""
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x


def get_config(resolution):
    """Return the keyword arguments used to build ``BigGAN.Generator``.

    Args:
        resolution: generator output resolution; one of 128, 256 or 512.

    Returns:
        A dict that is unpacked into ``BigGAN.Generator(**config)``.

    Raises:
        KeyError: if ``resolution`` is not one of the supported values.
    """
    attn_dict = {128: '64', 256: '128', 512: '64'}
    dim_z_dict = {128: 120, 256: 140, 512: 128}
    if resolution not in attn_dict:
        raise KeyError('Unsupported BigGAN resolution {0!r}; expected one of {1}'.format(
            resolution, sorted(attn_dict)))
    config = {'G_param': 'SN', 'D_param': 'SN',
              'G_ch': 96, 'D_ch': 96,
              'D_wide': True, 'G_shared': True,
              'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
              'hier': True, 'cross_replica': False,
              'mybn': False, 'G_activation': nn.ReLU(inplace=True),
              'G_attn': attn_dict[resolution],
              'norm_style': 'bn',
              'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
              'G_fp16': False, 'G_mixed_precision': False,
              'accumulate_stats': False, 'num_standing_accumulations': 16,
              'G_eval_mode': True,
              'BN_eps': 1e-04, 'SN_eps': 1e-04,
              'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
              'n_classes': 1000}
    return config


class BigdatasetGANModel(nn.Module):
    """BigGAN generator plus a feature-interpreter head producing label logits.

    The generator's intermediate block outputs are grouped into low / mid /
    high feature stacks, reduced with 1x1 convs, progressively fused through
    ``SegBlock`` layers and mapped to ``out_dim`` output channels.

    Args:
        resolution: BigGAN resolution passed to ``get_config``.
        out_dim: number of channels of the predicted logit map.
        biggan_ckpt: optional path to a generator state dict; when ``None``
            the generator keeps its freshly constructed weights.

    NOTE(review): the feature sizes (32/128/512) and the input channel counts
    (3072/960/192) are hard-coded; they presumably match the block outputs of
    the 512-resolution generator only -- confirm before using 128 or 256.
    """

    def __init__(self, resolution, out_dim, biggan_ckpt=None):
        super(BigdatasetGANModel, self).__init__()
        self.biggan_ckpt = biggan_ckpt
        self.resolution = resolution
        # load biggan model
        self._prepare_biggan_model()

        self.low_feature_size = 32
        self.mid_feature_size = 128
        self.high_feature_size = 512

        low_feature_channel = 128
        mid_feature_channel = 64
        high_feature_channel = 32
        low_mid_channel = low_feature_channel + mid_feature_channel
        all_channel = low_mid_channel + high_feature_channel

        self.low_feature_conv = nn.Sequential(
            nn.Conv2d(3072, low_feature_channel, kernel_size=1, bias=False),
            #nn.ReLU(),
        )
        self.mid_feature_conv = nn.Sequential(
            nn.Conv2d(960, mid_feature_channel, kernel_size=1, bias=False),
            #nn.ReLU(),
        )
        self.mid_feature_mix_conv = SegBlock(
            in_channels=low_mid_channel,
            out_channels=low_mid_channel,
            con_channels=self.biggan_model.shared_dim,
            which_conv=self.biggan_model.which_conv,
            which_linear=self.biggan_model.which_linear,
            activation=self.biggan_model.activation,
            upsample=False,
        )

        self.high_feature_conv = nn.Sequential(
            nn.Conv2d(192, high_feature_channel, kernel_size=1, bias=False),
            #nn.ReLU(),
        )

        self.high_feature_mix_conv = SegBlock(
            in_channels=all_channel,
            out_channels=all_channel,
            con_channels=self.biggan_model.shared_dim,
            which_conv=self.biggan_model.which_conv,
            which_linear=self.biggan_model.which_linear,
            activation=self.biggan_model.activation,
            upsample=False,
        )

        # A plain nn.Conv2d used to be assigned to out_layer here and was
        # immediately overwritten; only the effective head is built now.
        self.out_layer = nn.Sequential(
            BigGAN.layers.bn(all_channel),
            self.biggan_model.activation,
            self.biggan_model.which_conv(all_channel, out_dim)
        )

    def _prepare_biggan_model(self):
        """Build the BigGAN generator, optionally load weights, set eval mode."""
        biggan_config = get_config(self.resolution)
        self.biggan_model = BigGAN.Generator(**biggan_config)
        if self.biggan_ckpt is not None:
            # Load onto CPU so a checkpoint saved from a GPU also opens on a
            # CPU-only host; callers move the whole model to a device later.
            state_dict = torch.load(self.biggan_ckpt, map_location='cpu')
            self.biggan_model.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries
        self.biggan_model.eval()

    def _prepare_features(self, features, upsample='bilinear'):
        """Group the eight generator feature maps into low/mid/high stacks.

        ``features[0:3]``, ``features[3:6]`` and ``features[6:8]`` are each
        brought to a common spatial size and concatenated along channels.
        The last map of every group is used as-is, so it must already have
        the group's target size.

        Returns:
            dict with keys ``'low'``, ``'mid'`` and ``'high'``.
        """
        # for low feature
        low_features = [
            F.interpolate(features[0], size=self.low_feature_size, mode=upsample, align_corners=False),
            F.interpolate(features[1], size=self.low_feature_size, mode=upsample, align_corners=False),
            features[2],
        ]
        low_features = torch.cat(low_features, dim=1)
        # for mid feature
        mid_features = [
            F.interpolate(features[3], size=self.mid_feature_size, mode=upsample, align_corners=False),
            F.interpolate(features[4], size=self.mid_feature_size, mode=upsample, align_corners=False),
            features[5],
        ]
        mid_features = torch.cat(mid_features, dim=1)
        # for high feature
        high_features = [
            F.interpolate(features[6], size=self.high_feature_size, mode=upsample, align_corners=False),
            #F.interpolate(features[7], size=self.high_feature_size, mode=upsample, align_corners=False),
            features[7],
        ]
        high_features = torch.cat(high_features, dim=1)

        features_dict = {
            'low': low_features,
            'mid': mid_features,
            'high': high_features,
        }

        return features_dict

    @torch.no_grad()
    def _get_biggan_features(self, z, y):
        """Run the generator trunk without gradients and collect features.

        Returns:
            ``(features_dict, y, h)``: the grouped features, the shared class
            embedding of ``y`` and the output of the last generator block.
        """
        features = []
        y = self.biggan_model.shared(y)
        # forward thru biggan
        # If hierarchical, concatenate zs and ys
        if self.biggan_model.hier:
            zs = torch.split(z, self.biggan_model.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.biggan_model.blocks)

        # First linear layer
        h = self.biggan_model.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.biggan_model.bottom_width, self.biggan_model.bottom_width)

        # Loop over blocks
        for index, blocklist in enumerate(self.biggan_model.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, ys[index])
                # save the output of every layer; _prepare_features indexes
                # eight entries
                features.append(h)

        features_dict = self._prepare_features(features)

        return features_dict, y, h

    def _predict_from_features(self, features_dict, y):
        """Fuse low -> mid -> high features and return the output logits."""
        # for low features
        low_feat = self.low_feature_conv(features_dict['low'])
        low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
        # for mid features
        mid_feat = self.mid_feature_conv(features_dict['mid'])
        mid_feat = torch.cat([low_feat, mid_feat], dim=1)
        mid_feat = self.mid_feature_mix_conv(mid_feat, y)
        mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
        # for high features
        high_feat = self.high_feature_conv(features_dict['high'])
        high_feat = torch.cat([mid_feat, high_feat], dim=1)
        high_feat = self.high_feature_mix_conv(high_feat, y)
        return self.out_layer(high_feat)

    def forward(self, z, y):
        """Predict label logits for latent ``z`` and class index ``y``.

        The generator trunk runs without gradients; only the head is
        differentiable.
        """
        features_dict, y, _ = self._get_biggan_features(z, y)
        return self._predict_from_features(features_dict, y)

    @torch.no_grad()
    def sample(self, z, y):
        """Return ``(image, logits)`` for latent ``z`` and class index ``y``.

        The image is the tanh of the generator's output layer applied to the
        last block output; the logits come from the same head as ``forward``.
        """
        features_dict, y, h = self._get_biggan_features(z, y)

        image = torch.tanh(self.biggan_model.output_layer(h.detach()))
        out = self._predict_from_features(features_dict, y)

        return image, out


if __name__ == '__main__':

    biggan_ckpt = './pretrain/biggan-512.pth'
    model = BigdatasetGANModel(512, 1, biggan_ckpt).cuda()

# --------------------------------------------------------------------------------
# /prepare_biggan_images.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
from models import BigdatasetGANModel
import numpy as np
import argparse
import torch
import torchvision


def parse_args():
    """Parse the command-line options of the BigGAN image rendering script."""
    usage = 'Parser for generate biggan images script.'
    parser = argparse.ArgumentParser(description=usage)

    parser.add_argument(
        '--biggan_ckpt', type=str, default='./pretrain/biggan-512.pth', help='path to the pretrained biggan ckpt')
    parser.add_argument(
        '--dataset_dir', type=str, default='./data/', help='path to the dataset dir')

    args = parser.parse_args()

    return args


def main(args):
    """Render one PNG per stored latent under ``<dataset_dir>/images/biggan512/``.

    Latents are read from ``<dataset_dir>/latents/biggan512/<class>/``; the
    ImageNet class index is taken from the second-to-last ``_``-separated
    field of the latent file name.
    """
    # Fall back to CPU instead of failing outright on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    latents_dir = os.path.join(args.dataset_dir, 'latents/biggan512/')
    images_dir = os.path.join(args.dataset_dir, 'images/biggan512/')

    # loading model
    generator = BigdatasetGANModel(resolution=512, out_dim=1, biggan_ckpt=args.biggan_ckpt).to(device)

    generator.eval()

    # sorted() makes the processing order reproducible across file systems.
    for class_n in sorted(os.listdir(latents_dir)):
        latent_class_dir = os.path.join(latents_dir, class_n)
        if not os.path.isdir(latent_class_dir):
            # stray files next to the class folders (e.g. .DS_Store)
            continue
        image_class_dir = os.path.join(images_dir, class_n)

        os.makedirs(image_class_dir, exist_ok=True)

        for latent_n in sorted(os.listdir(latent_class_dir)):
            if not latent_n.endswith('.npy'):
                # np.load(...)[0] below only works for plain .npy arrays
                continue
            image_name = latent_n.split('.')[0]
            latent_np = np.load(os.path.join(latent_class_dir, latent_n))[0]
            class_y = int(image_name.split('_')[-2])

            latent_tensor = torch.tensor(latent_np, dtype=torch.float).unsqueeze(0).to(device)
            class_y_tensor = torch.tensor([class_y], dtype=torch.long).to(device)

            image_tensor, _ = generator.sample(latent_tensor, class_y_tensor)

            image_path = os.path.join(image_class_dir, image_name+'.png')
            print("Saving biggan images from the latent to: ", image_path)
            # save image
            torchvision.utils.save_image(image_tensor, image_path, normalize=True)


if __name__ == '__main__':
    main(parse_args())

# --------------------------------------------------------------------------------
# /prepare_imagenet_images.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import shutil
import argparse
import numpy as np


def parse_args():
    """Parse the command-line options of the ImageNet image copy script."""
    usage = 'Parser for generate biggan images script.'
    parser = argparse.ArgumentParser(description=usage)

    parser.add_argument(
        '--imagenet_dir', type=str, required=True, help='path to the imagenet folder')
    parser.add_argument(
        '--dataset_dir', type=str, default='./data/', help='path to the dataset dir')

    args = parser.parse_args()

    return args


def main(args):
    """Copy every image listed in ``annotations/real-random-list.txt``.

    Each list entry has the form ``<class>/<annotation_name>``; the ImageNet
    file is ``<imagenet_dir>/<class>/<class>_<id>.JPEG`` where ``<id>`` is the
    last ``_``-separated field of the annotation name.
    """
    real_list_path = os.path.join(args.dataset_dir, 'annotations/real-random-list.txt')
    # ndmin=1: without it a one-line list file yields a 0-d array, which
    # cannot be iterated.
    real_list = np.loadtxt(real_list_path, dtype=str, ndmin=1)

    for file_list in real_list:
        class_n = file_list.split('/')[0]
        real_anno_n = file_list.split('/')[1]

        img_id = real_anno_n.split('_')[-1]
        imagenet_file_name = class_n + '_' + img_id + '.JPEG'
        # copy imagenet image to dataset
        imagenet_file_path = os.path.join(args.imagenet_dir, class_n, imagenet_file_name)
        real_image_dir = os.path.join(args.dataset_dir, 'images/real-random/', class_n)
        os.makedirs(real_image_dir, exist_ok=True)
        save_path = os.path.join(real_image_dir, real_anno_n + '.jpg')
        print('Copy image from {0} to {1}'.format(imagenet_file_path, save_path))
        shutil.copy(imagenet_file_path, save_path)


if __name__ == '__main__':
    main(parse_args())

# --------------------------------------------------------------------------------
# /sample_dataset.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import torch
import argparse
import torchvision

from models import BigdatasetGANModel
from utils import VOCColorize


def parse_args():
    """Parse the command-line options of the dataset sampling script."""
    usage = 'Parser for training bigdataset script.'
    parser = argparse.ArgumentParser(description=usage)
    parser.add_argument(
        '--resolution', '-r', type=int, default=512,
        help='Resolution of the generated images, we use biggan-512 by default')
    parser.add_argument(
        '--ckpt', type=str, required=True,
        help='Path to the pretrained BigDatasetGAN weights')
    parser.add_argument(
        '--save_dir', type=str, default='./generated_datasets/',
        help='Path to save dataset')
    parser.add_argument(
        '--z_var', type=float, default=0.9,
        help='Truancation value of z')
    parser.add_argument(
        '--class_idx', type=int, default=[225, 200], nargs='+',
        help='Imagenet class index')
    parser.add_argument(
        '--samples_per_class', type=int, default=10,
        help='data samples per class')

    args = parser.parse_args()

    return args


def main(args):
    """Sample images and predicted masks per class and save one overview grid.

    For every class index a row of generated images is stacked above a row of
    colourised predicted masks; all classes are written to
    ``<save_dir>/sample_overall.jpg``.
    """
    # Fall back to CPU instead of failing outright on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # build seg model
    model = BigdatasetGANModel(resolution=args.resolution, out_dim=1, biggan_ckpt=None)

    # load pretrain model (onto CPU first so GPU-saved weights open anywhere)
    state_dict = torch.load(args.ckpt, map_location='cpu')
    model.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries

    model = model.to(device)
    model = model.eval()

    voc_col = VOCColorize(n=1000)

    overall_viz = []

    os.makedirs(args.save_dir, exist_ok=True)

    for class_y in args.class_idx:
        print("Start sampling dataset with class idx: {0}, total samples: {1}".format(class_y, args.samples_per_class))

        class_y_tensor = torch.tensor([class_y], dtype=torch.long).to(device)

        sample_imgs, sample_labels = [], []
        for i in range(args.samples_per_class):
            # NOTE: despite its name, z_var is used as the standard deviation.
            z = torch.empty(1, model.biggan_model.dim_z).normal_(mean=0, std=args.z_var).to(device)
            sample_img, sample_pred = model.sample(z, class_y_tensor)
            sample_img, sample_pred = sample_img.cpu(), sample_pred.cpu()

            label_pred_prob = torch.sigmoid(sample_pred)
            label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.long)
            label_pred_mask[label_pred_prob>0.5] = 1

            # foreground pixels take the class index so they get its colour
            label_pred_rgb = voc_col(label_pred_mask[0][0].cpu().numpy()*class_y)
            label_pred_rgb = torch.from_numpy(label_pred_rgb).float()

            sample_imgs.append(sample_img)
            sample_labels.append(label_pred_rgb)

        sample_imgs = torch.cat(sample_imgs, dim=0)
        sample_labels = torch.stack(sample_labels, dim=0)

        sample_imgs_grid = torchvision.utils.make_grid(sample_imgs, nrow=args.samples_per_class, normalize=True, scale_each=True)
        sample_labels_grid = torchvision.utils.make_grid(sample_labels, nrow=args.samples_per_class, normalize=True, scale_each=True)
        # dim=0 belongs to torch.stack; it used to be passed to make_grid,
        # which has no such parameter.
        class_viz_tensor = torchvision.utils.make_grid(
            torch.stack([sample_imgs_grid, sample_labels_grid], dim=0), nrow=1)
        overall_viz.append(class_viz_tensor)

    overall_viz = torch.stack(overall_viz, dim=0)
    torchvision.utils.save_image(overall_viz, os.path.join(args.save_dir, 'sample_overall.jpg'), nrow=1)


if __name__ == '__main__':

    args = parse_args()

    main(args)

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import torch
import torch.nn as nn
import argparse
import torchvision

import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

from models import BigdatasetGANModel
from datasets.datasets import ImagenetDataset
from utils import VOCColorize


def parse_args():
    """Parse the command-line options of the training script."""
    usage = 'Parser for training bigdataset script.'
    parser = argparse.ArgumentParser(description=usage)
    parser.add_argument(
        '--resolution', '-r', type=int, default=512,
        help='Resolution of the generated images, we use biggan-512 by default')

    parser.add_argument(
        '--gan_ckpt', type=str, default='./pretrain/biggan-512.pth',
        help='Path to the pretrained gan ckpt')
    parser.add_argument(
        '--dataset_dir', type=str, default='./data/',
        help='Path to the dataset folder')

    parser.add_argument(
        '--save_dir', type=str, default='./logs/checkpoint_biggan512_label_conv/',
        help='Path to save logs')
    parser.add_argument(
        '--batch_size', type=int, default=4,
        help='training batch size')
    parser.add_argument(
        '--max_iter', type=int, default=5000,
        help='maximum iteration of training')
    parser.add_argument(
        '--lr', type=float, default=0.001,
        help='learning rate')

    args = parser.parse_args()

    return args


def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it when exhausted."""
    while True:
        for batch in loader:
            yield batch


def main(args):
    """Train the segmentation head on (latent, class, label) samples.

    Writes TensorBoard scalars, a visualisation image every 100 steps and
    ``checkpoint_latest.pth`` every 1000 steps (plus once at the end) into a
    time-stamped run folder under ``args.save_dir``.
    """
    # Fall back to CPU instead of failing outright on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # build checkpoint dir
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    ckpt_dir = os.path.join(args.save_dir, 'run-'+current_time)
    os.makedirs(ckpt_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(ckpt_dir, 'logs'))
    os.makedirs(os.path.join(ckpt_dir, 'training'), exist_ok=True)
    # os.makedirs(os.path.join(ckpt_dir, 'samples'), exist_ok=True)
    latest_ckpt_path = os.path.join(ckpt_dir, 'checkpoint_latest.pth')

    # build dataset
    dataset = ImagenetDataset(args.dataset_dir)
    #dataset = BigganDataset(args.dataset_dir, single_class=args.single_class)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

    print("loading dataset size: ", len(dataset))

    # build seg model
    g_seg = BigdatasetGANModel(resolution=args.resolution, out_dim=1, biggan_ckpt=args.gan_ckpt)

    g_seg = g_seg.to(device)

    g_optim = optim.Adam(
        g_seg.parameters(),
        lr=args.lr
    )

    loss_fn = nn.BCEWithLogitsLoss()

    dataloader = sample_data(dataloader)

    voc_col = VOCColorize(n=1000)

    print("Start training with maximum {0} iterations.".format(args.max_iter))

    try:
        for i, batch_data in enumerate(dataloader):

            # NOTE: `>` is kept as in the original, so steps 0..max_iter
            # (inclusive) are executed.
            if i > args.max_iter:
                break

            z = batch_data['latent'].to(device)
            label_gt = batch_data['label'].to(device)
            y = batch_data['y'].to(device)

            # set g_Seg in train mode, but keep the generator in eval mode
            g_seg.train()
            g_seg.biggan_model.eval()

            label_pred = g_seg(z, y)

            loss = loss_fn(label_pred, label_gt.float().unsqueeze(1))

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

            loss_value = loss.item()
            writer.add_scalar('train/loss', loss_value, global_step=i)

            if i % 10 == 0:
                print("Training step: {0:05d}/{1:05d}, loss: {2:0.4f}".format(i, args.max_iter, loss_value))

            if i % 100 == 0:
                # save train pred
                g_seg.eval()
                # only the image is used; the mask shown below comes from the
                # train-mode prediction computed above
                sample_imgs, _ = g_seg.sample(z, y)
                sample_imgs = sample_imgs.cpu()

                label_pred_prob = torch.sigmoid(label_pred)
                label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.long)
                label_pred_mask[label_pred_prob>0.5] = 1

                label_pred_rgb = voc_col(label_pred_mask[0][0].cpu().numpy()*y[0].cpu().numpy())
                label_pred_rgb = torch.from_numpy(label_pred_rgb).float()

                label_gt_rgb = voc_col(label_gt[0].cpu().numpy()*y[0].cpu().numpy())
                label_gt_rgb = torch.from_numpy(label_gt_rgb).float()

                viz_tensor = torch.stack([sample_imgs[0], label_gt_rgb, label_pred_rgb], dim=0)

                torchvision.utils.save_image(viz_tensor, os.path.join(ckpt_dir,
                    'training/viz_sample_{0:05d}.jpg'.format(i)), normalize=True, scale_each=True)

            if i % 1000 == 0:
                # save checkpoint
                print("Saving latest checkpoint.")
                torch.save(g_seg.state_dict(), latest_ckpt_path)

        # Persist the final weights even when max_iter is not a multiple of
        # the checkpoint interval.
        torch.save(g_seg.state_dict(), latest_ckpt_path)
    finally:
        # flush pending TensorBoard events
        writer.close()


if __name__ == '__main__':

    args = parse_args()

    main(args)

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import numpy as np


class VOCColorize(object):
    """Map a 2-D label image to a ``(3, H, W)`` uint8 colour image.

    Args:
        n: number of labels in the colour table (labels ``0 .. n-1``).
    """

    def __init__(self, n):
        self.cmap = color_map(n)

    def __call__(self, gray_image):
        """Colourise ``gray_image`` (a 2-D numpy array of label values).

        Pixels whose value is not an integer label inside the colour table
        stay black, exactly as with a per-label scan.
        """
        size = gray_image.shape
        color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8)

        # One table lookup instead of one full-image pass per label.
        labels = gray_image.astype(np.int64)
        valid = (labels == gray_image) & (labels >= 0) & (labels < len(self.cmap))
        color_image[:, valid] = self.cmap[labels[valid]].T

        # handle void
        # mask = (255 == gray_image)
        # color_image[0][mask] = color_image[1][mask] = color_image[2][mask] = 255

        return color_image


def color_map(N, normalized=False):
    """Build an ``(N, 3)`` colour table by spreading label bits over RGB.

    For each label, successive groups of three bits (lowest first) are
    written into the R, G and B channels from the most significant bit down.

    Args:
        N: number of labels.
        normalized: if True return float32 values divided by 255, otherwise
            uint8 values.
    """
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << (7 - j))
            g = g | (bitget(c, 1) << (7 - j))
            b = b | (bitget(c, 2) << (7 - j))
            c = c >> 3

        cmap[i] = np.array([r, g, b])

    cmap = cmap/255 if normalized else cmap
    return cmap

# --------------------------------------------------------------------------------
# /viz/sample_overall.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/viz/sample_overall.jpg
# --------------------------------------------------------------------------------