├── .gitignore
├── LICENSE
├── README.md
├── biggan_pytorch
│   ├── BigGAN.py
│   ├── BigGANdeep.py
│   ├── LICENSE
│   ├── README.md
│   ├── TFHub
│   │   ├── README.md
│   │   ├── biggan_v1.py
│   │   └── converter.py
│   ├── animal_hash.py
│   ├── calculate_inception_moments.py
│   ├── datasets.py
│   ├── imgs
│   │   ├── D Singular Values.png
│   │   ├── DeepSamples.png
│   │   ├── DogBall.png
│   │   ├── G Singular Values.png
│   │   ├── IS_FID.png
│   │   ├── Losses.png
│   │   ├── header_image.jpg
│   │   └── interp_sample.jpg
│   ├── inception_tf13.py
│   ├── inception_utils.py
│   ├── layers.py
│   ├── logs
│   │   ├── BigGAN_ch96_bs256x8.jsonl
│   │   ├── compare_IS.m
│   │   ├── metalog.txt
│   │   ├── process_inception_log.m
│   │   └── process_training.m
│   ├── losses.py
│   ├── make_hdf5.py
│   ├── sample.py
│   ├── scripts
│   │   ├── launch_BigGAN_bs256x8.sh
│   │   ├── launch_BigGAN_bs512x4.sh
│   │   ├── launch_BigGAN_ch64_bs256x8.sh
│   │   ├── launch_BigGAN_deep.sh
│   │   ├── launch_SAGAN_bs128x2_ema.sh
│   │   ├── launch_SNGAN.sh
│   │   ├── launch_cifar_ema.sh
│   │   ├── sample_BigGAN_bs256x8.sh
│   │   ├── sample_cifar_ema.sh
│   │   └── utils
│   │       ├── duplicate.sh
│   │       └── prepare_data.sh
│   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── batchnorm.py
│   │   ├── batchnorm_reimpl.py
│   │   ├── comm.py
│   │   ├── replicate.py
│   │   └── unittest.py
│   ├── train.py
│   ├── train_fns.py
│   └── utils.py
├── datasets
│   ├── __init__.py
│   ├── datasets.py
│   ├── imagenet2synset.json
│   ├── imagenet_utils.py
│   └── imagenetclass992.json
├── models.py
├── prepare_biggan_images.py
├── prepare_imagenet_images.py
├── sample_dataset.py
├── train.py
├── utils.py
└── viz
    └── sample_overall.jpg
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## BigDatasetGAN: Synthesizing ImageNet with Pixel-wise Annotations
2 |
3 | This is the official code for:
4 |
5 | #### BigDatasetGAN: Synthesizing ImageNet with Pixel-wise Annotations
6 |
7 | [Daiqing Li](https://scholar.google.ca/citations?user=8q2ISMIAAAAJ&hl=en), [Huan Ling](http://www.cs.toronto.edu/~linghuan/), [Seung Wook Kim](https://seung-kim.github.io/seungkim/), [Karsten Kreis](https://scholar.google.de/citations?user=rFd-DiAAAAAJ&hl=de), Adela Barriuso, [Sanja Fidler](http://www.cs.toronto.edu/~fidler/), [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/)\
8 | **[[Paper](https://arxiv.org/abs/2201.04684)] [[Bibtex](https://nv-tlabs.github.io/big-datasetgan/resources/bibtex.txt)] [[Project Page](https://nv-tlabs.github.io/big-datasetgan/)]**
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Requirements
20 | * Python 3.7
21 | * CUDA v11.0+
22 | * gcc v7.3.0
23 | * PyTorch 1.9.0+
24 |
25 | ## Pretrained BigGAN weights
26 | Our annotations and model are based on BigGAN-512. Please download the model from https://tfhub.dev/deepmind/biggan-512/2 and store it in the `./pretrain` folder. Since the original model was trained in TensorFlow, you need to convert the model weights to PyTorch following the instructions at https://github.com/ajbrock/BigGAN-PyTorch/tree/master/TFHub. Note that the model is licensed under Apache-2.0 by DeepMind.
27 |
28 | ## Dataset preparation
29 | We only release our annotations on sampled [BigGAN](https://github.com/ajbrock/BigGAN-PyTorch/blob/master/LICENSE) images and images from [ImageNet](https://www.image-net.org/index.php), along with the latents used to generate the sampled images. For their licenses, please refer to their websites. Note that our dataset release is under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. Please see the [License](#license) section for details.
30 |
31 | 1. Download ImageNet from [here](https://www.image-net.org/index.php).
32 | 2. Download our annotations `annotations.zip` and latent codes `latents.zip` from [gdrive](https://drive.google.com/drive/folders/1NC0VSZrVZsd1Z_IRSdadyfCfTXMAPsf1). Unzip them into the data folder `./data/`.
33 | 3. Process images from ImageNet into our dataset format. Run the following script
34 | ```
35 | python prepare_imagenet_images.py --imagenet_dir [path to imagenet dir] --dataset_dir ./data/
36 | ```
37 | 4. Prepare images generated from BigGAN. Please download the pretrained weights following the [Pretrained BigGAN weights](#pretrained-biggan-weights) section, then run
38 | ```
39 | python prepare_biggan_images.py --biggan_ckpt ./pretrain/biggan-512.pth --dataset_dir ./data/
40 | ```
41 | After the processing steps, you should have a data folder structure like this:
42 | ```
43 | data
44 | |
45 | └───annotations
46 | │ |
47 | │ └───biggan512
48 | │ | │ n01440764
49 | │ | │ ...
50 | | └───real-random-list.txt
51 | └───images
52 | │ |
53 | │ └───biggan512
54 | │ | │ n01440764
55 | │ | │ ...
56 | | └───real-random
57 | │ │ n01440764
58 | │ │ ...
59 | └───latents
60 | │ |
61 | │ └───biggan512
62 | │ | │ n01440764
63 | │ | │ ...
64 | ```
65 |
66 | ## Training
67 | After the dataset preparation, we can now train BigDatasetGAN to synthesize datasets.
68 |
69 | Run the following
70 | ```
71 | python train.py --gan_ckpt ./pretrain/biggan-512.pth \
72 | --dataset_dir ./data/ \
73 | --save_dir ./logs/
74 | ```
75 |
76 | You can monitor the training progress in TensorBoard, as well as the training predictions in the logs dir.
77 |
78 | By default, training runs for 5k iterations with a batch size of 4; you can adjust these to match your hardware capacity.
79 |
80 | ## Sampling dataset
81 | After the training is done, we can synthesize ImageNet with pixel-wise labels.
82 |
83 | Run the following
84 | ```
85 | python sample_dataset.py --ckpt [path to your pretrained BigDatasetGAN weights] \
86 | --save_dir ./dataset_viz/ \
87 | --class_idx 225, 200, [you can give it more ImageNet class indices] \
88 | --samples_per_class 10 \
89 | --z_var 0.9
90 | ```
91 |
92 |
93 |
94 |
95 |
96 |
97 | As an example, here we sample classes 225 and 200 with 10 samples each.
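For intuition, `--z_var` plausibly scales the variance of the latent prior, trading sample diversity for fidelity (similar in spirit to the truncation trick). A hypothetical sketch of that sampling, assuming BigGAN-512's `dim_z = 128`:

```python
import torch

# Hypothetical sketch: draw z from N(0, z_var * I) instead of N(0, I),
# which narrows the latent distribution and tends to improve fidelity.
z_var = 0.9
samples_per_class = 10
dim_z = 128  # latent dimensionality of BigGAN-512
z = torch.randn(samples_per_class, dim_z) * z_var ** 0.5
```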
98 |
99 | ## License
100 | For any code dependency related to BigGAN, the license is under the MIT License, see https://github.com/ajbrock/BigGAN-PyTorch/blob/master/LICENSE.
101 |
102 | The BigDatasetGAN code is released under the Creative Commons BY-NC 4.0 license; the full text is at http://creativecommons.org/licenses/by-nc/4.0/legalcode.
103 |
104 | The dataset of BigDatasetGAN is released under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made.
105 |
106 | ## Citation
107 |
108 | Please cite the following paper if you use the code in this repository.
109 |
110 | ```
111 | @misc{li2022bigdatasetgan,
112 | title={BigDatasetGAN: Synthesizing ImageNet with Pixel-wise Annotations},
113 | author={Daiqing Li and Huan Ling and Seung Wook Kim and Karsten Kreis and Adela Barriuso and Sanja Fidler and Antonio Torralba},
114 | year={2022},
115 | eprint={2201.04684},
116 | archivePrefix={arXiv},
117 | primaryClass={cs.CV}
118 | }
119 | ```
120 |
121 |
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
/biggan_pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Andy Brock
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/biggan_pytorch/README.md:
--------------------------------------------------------------------------------
1 | # BigGAN-PyTorch
2 | The author's officially unofficial PyTorch BigGAN implementation.
3 |
4 | 
5 |
6 |
7 | This repo contains code for 4-8 GPU training of BigGANs from [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/abs/1809.11096) by Andrew Brock, Jeff Donahue, and Karen Simonyan.
8 |
9 | This code is by Andy Brock and Alex Andonian.
10 |
11 | ## How To Use This Code
12 | You will need:
13 |
14 | - [PyTorch](https://PyTorch.org/), version 1.0.1
15 | - tqdm, numpy, scipy, and h5py
16 | - The ImageNet training set
17 |
18 | First, you may optionally prepare a pre-processed HDF5 version of your target dataset for faster I/O. Following this (or not), you'll need the Inception moments needed to calculate FID. These can both be done by modifying and running
19 |
20 | ```sh
21 | sh scripts/utils/prepare_data.sh
22 | ```
23 |
24 | Which by default assumes your ImageNet training set is downloaded into the root folder `data` in this directory, and will prepare the cached HDF5 at 128x128 pixel resolution.
25 |
26 | In the scripts folder, there are multiple bash scripts which will train BigGANs with different batch sizes. This code assumes you do not have access to a full TPU pod, and accordingly
27 | spoofs mega-batches by using gradient accumulation (averaging grads over multiple minibatches, and only taking an optimizer step after N accumulations). By default, the `launch_BigGAN_bs256x8.sh` script trains a
28 | full-sized BigGAN model with a batch size of 256 and 8 gradient accumulations, for a total batch size of 2048. On 8xV100 with full-precision training (no Tensor cores), this script takes 15 days to train to 150k iterations.
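For illustration, here is a minimal, self-contained sketch of the accumulation pattern (this is not the repo's actual training loop; the model, data, and loss below are stand-ins):

```python
import torch
from torch import nn

# Emulate a batch 8x larger than fits in memory: average gradients over
# 8 minibatches, then take a single optimizer step.
model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
num_accumulations = 8
minibatches = [(torch.randn(32, 10), torch.randn(32, 1))
               for _ in range(num_accumulations)]

opt.zero_grad()
for x, y in minibatches:
    loss = nn.functional.mse_loss(model(x), y)
    (loss / num_accumulations).backward()  # scale so grads average, not sum
opt.step()
```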
29 |
30 | You will first need to figure out the maximum batch size your setup can support. The pre-trained models provided here were trained on 8xV100 (16GB VRAM each) which can support slightly more than the BS256 used by default.
31 | Once you've determined this, you should modify the script so that the batch size times the number of gradient accumulations is equal to your desired total batch size (BigGAN defaults to 2048).
32 |
33 | Note also that this script uses the `--load_in_mem` arg, which loads the entire (~64GB) I128.hdf5 file into RAM for faster data loading. If you don't have enough RAM to support this (probably 96GB+), remove this argument.
34 |
35 |
36 | ## Metrics and Sampling
37 | 
38 |
39 | During training, this script will output logs with training metrics and test metrics, will save multiple copies (2 most recent and 5 highest-scoring) of the model weights/optimizer params, and will produce samples and interpolations every time it saves weights.
40 | The logs folder contains scripts to process these logs and plot the results using MATLAB (sorry not sorry).
41 |
42 | After training, one can use `sample.py` to produce additional samples and interpolations, test with different truncation values, batch sizes, number of standing stat accumulations, etc. See the `sample_BigGAN_bs256x8.sh` script for an example.
43 |
44 | By default, everything is saved to weights/samples/logs/data folders which are assumed to be in the same folder as this repo.
45 | You can point all of these to a different base folder using the `--base_root` argument, or pick specific locations for each of these with their respective arguments (e.g. `--logs_root`).
46 |
47 | We include scripts to run BigGAN-deep, but we have not fully trained a model using them, so consider them untested. Additionally, we include scripts to run a model on CIFAR, and to run SA-GAN (with EMA) and SN-GAN on ImageNet. The SA-GAN code assumes you have 4xTitanX (or equivalent in terms of GPU RAM) and will run with a batch size of 128 and 2 gradient accumulations.
48 |
49 | ## An Important Note on Inception Metrics
50 | This repo uses the PyTorch in-built inception network to calculate IS and FID.
51 | These scores are different from the scores you would get using the official TF inception code, and are only for monitoring purposes!
52 | Run sample.py on your model, with the `--sample_npz` argument, then run inception_tf13 to calculate the actual TensorFlow IS. Note that you will need to have TensorFlow 1.3 or earlier installed, as TF1.4+ breaks the original IS code.
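For reference, IS is computed from the Inception softmax outputs as exp(E_x[KL(p(y|x) || p(y))]), averaged over splits. A minimal NumPy sketch of that formula (a simplified stand-in, not this repo's `inception_utils` implementation):

```python
import numpy as np

def inception_score(probs, num_splits=10):
    """probs: [N, 1000] array of Inception softmax outputs for generated images."""
    scores = []
    for part in np.array_split(probs, num_splits):
        py = part.mean(axis=0, keepdims=True)                  # marginal p(y)
        kl = (part * (np.log(part) - np.log(py))).sum(axis=1)  # per-image KL
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))
```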
53 |
54 | ## Pretrained models
55 | 
56 | We include two pretrained model checkpoints (with G, D, the EMA copy of G, the optimizers, and the state dict):
57 | - The main checkpoint is for a BigGAN trained on ImageNet at 128x128, using BS256 and 8 gradient accumulations, taken just before collapse, with a TF Inception Score of 97.35 +/- 1.79: [LINK](https://drive.google.com/open?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW)
58 | - An earlier checkpoint of the first model (100k G iters), at high performance but well before collapse, which may be easier to fine-tune: [LINK](https://drive.google.com/open?id=1dmZrcVJUAWkPBGza_XgswSuT-UODXZcO)
59 |
60 |
61 |
62 | Pretrained models for Places-365 coming soon.
63 |
64 | This repo also contains scripts for porting the original TFHub BigGAN Generator weights to PyTorch. See the scripts in the TFHub folder for more details.
65 |
66 | ## Fine-tuning, Using Your Own Dataset, or Making New Training Functions
67 | 
68 |
69 | If you wish to resume interrupted training or fine-tune a pre-trained model, run the same launch script but with the `--resume` argument added.
70 | Experiment names are automatically generated from the configuration, but can be overridden using the `--experiment_name` arg (for example, if you wish to fine-tune a model using modified optimizer settings).
71 |
72 | To prep your own dataset, you will need to add it to datasets.py and modify the convenience dicts in utils.py (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict) to have the appropriate metadata for your dataset.
73 | Repeat the process in prepare_data.sh (optionally produce an HDF5 preprocessed copy, and calculate the Inception Moments for FID).
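As a hypothetical illustration of the kind of entries those dicts expect (the dict names are real, but the dataset name and values below are made up; check utils.py for the actual formats):

```python
# Inside utils.py: register a hypothetical 64x64, 10-class dataset 'MyData64'.
import torchvision.datasets as dset

dset_dict['MyData64'] = dset.ImageFolder   # class used to load the data
imsize_dict['MyData64'] = 64               # native resolution
root_dict['MyData64'] = 'MyData64'         # folder name under the data root
nclass_dict['MyData64'] = 10               # number of classes
classes_per_sheet_dict['MyData64'] = 10    # classes per sample sheet
```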
74 |
75 | By default, the training script will save the top 5 best checkpoints as measured by Inception Score.
76 | For datasets other than ImageNet, Inception Score can be a very poor measure of quality, so you will likely want to use `--which_best FID` instead.
77 |
78 | To use your own training function (e.g. train a BigVAE): either modify train_fns.GAN_training_function or add a new train fn and add it after the `if config['which_train_fn'] == 'GAN':` line in `train.py`.
79 |
80 |
81 | ## Neat Stuff
82 | - We include the full training and metrics logs [here](https://drive.google.com/open?id=1ZhY9Mg2b_S4QwxNmt57aXJ9FOC3ZN1qb) for reference. I've found that one of the hardest things about re-implementing a paper can be checking if the logs line up early in training,
83 | especially if training takes multiple weeks. Hopefully these will be helpful for future work.
84 | - We include an accelerated FID calculation--the original scipy version can require upwards of 10 minutes to calculate the matrix sqrt; this version uses an accelerated PyTorch routine to calculate it in under a second (see the sketch after this list).
85 | - We include an accelerated, low-memory consumption ortho reg implementation.
86 | - By default, we only compute the top singular value (the spectral norm), but this code supports computing more SVs through the `--num_G_SVs` argument.
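A minimal sketch of the idea behind the accelerated matrix sqrt (a Newton-Schulz iteration that runs entirely on the GPU; the version this repo actually uses is adapted from the matrix-sqrt code credited in the Acknowledgments):

```python
import torch

def sqrt_newton_schulz(A, num_iters=50):
    # Approximate the square root of a PSD matrix via Newton-Schulz:
    # only matmuls, so it avoids scipy's slow CPU sqrtm entirely.
    dim = A.size(0)
    norm_A = A.norm()  # pre-scale so the iteration converges
    Y = A / norm_A
    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = I.clone()
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    return Y * norm_A.sqrt()  # undo the pre-scaling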
87 |
88 | ## Key Differences Between This Code And The Original BigGAN
89 | - We use the optimizer settings from SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1, as opposed to BigGAN's G_lr=5e-5, D_lr=2e-4, num_D_steps=2).
90 | While slightly less performant, this was the first corner we cut to bring training times down.
91 | - By default, we do not use Cross-Replica BatchNorm (AKA Synced BatchNorm).
92 | The two variants we tried (a custom, naive one and the one included in this repo) have slightly different gradients (albeit identical forward passes) from the built-in BatchNorm, which appear to be sufficient to cripple training.
93 | - Gradient accumulation means that we update the SV estimates and the BN statistics 8 times more frequently. This means that the BN stats are much closer to standing stats, and that the singular value estimates tend to be more accurate.
94 | Because of this, we measure metrics by default with G in test mode (using the BatchNorm running stat estimates instead of computing standing stats as in the paper). We do still support standing stats (see the sample.sh scripts).
95 | This could also conceivably result in gradients from the earlier accumulations being stale, but in practice this does not appear to be a problem.
96 | - The currently provided pretrained models were not trained with orthogonal regularization. Training without ortho reg seems to increase the probability that models will not be amenable to truncation,
97 | but it looks like this particular model got a winning ticket. Regardless, we provide two highly optimized (fast and minimal memory consumption) ortho reg implementations which directly compute the ortho reg gradients (a sketch of the idea follows this list).
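A sketch of the direct-gradient form of the modified ortho reg (the tuned implementations live in utils.py; the names here are generic):

```python
import torch

def apply_ortho_reg(model, strength=1e-4):
    # Modified ortho reg penalizes off-diagonal entries of W W^T.
    # Rather than adding the penalty to the loss, add its gradient,
    # 2 * strength * (W W^T * (1 - I)) W, directly to each weight's grad.
    with torch.no_grad():
        for param in model.parameters():
            if param.ndim < 2 or param.grad is None:
                continue
            w = param.view(param.shape[0], -1)
            off_diag = 1.0 - torch.eye(w.shape[0], device=w.device)
            grad = 2.0 * ((w @ w.t()) * off_diag) @ w
            param.grad += strength * grad.view(param.shape)
```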
98 |
99 | ## A Note On The Design Of This Repo
100 | This code is designed from the ground up to serve as an extensible, hackable base for further research code.
101 | We've put a lot of thought into making sure the abstractions are the *right* thickness for research--not so thick as to be impenetrable, but not so thin as to be useless.
102 | The key idea is that if you want to experiment with a SOTA setup and make some modification (try out your own new loss function, architecture, self-attention block, etc) you should be able to easily do so just by dropping your code in one or two places, without having to worry about the rest of the codebase.
103 | Things like the use of self.which_conv and functools.partial in the BigGAN.py model definition were put together with this in mind, as was the design of the Spectral Norm class inheritance.
104 |
105 | With that said, this is a somewhat large codebase for a single project. While we tried to be thorough with the comments, if there's something you think could be more clear, better written, or better refactored, please feel free to raise an issue or a pull request.
106 |
107 | ## Feature Requests
108 | Want to work on or improve this code? There are a couple things this repo would benefit from, but which don't yet work.
109 |
110 | - Synchronized BatchNorm (AKA Cross-Replica BatchNorm). We tried out two variants of this, but for some unknown reason it crippled training each time.
111 | We have not tried the [apex](https://github.com/NVIDIA/apex) SyncBN as my school's servers are on ancient NVIDIA drivers that don't support it--apex would probably be a good place to start.
112 | - Mixed precision training and making use of Tensor cores. This repo includes a naive mixed-precision Adam implementation which works early in training but leads to early collapse, and doesn't do anything to activate Tensor cores (it just reduces memory consumption).
113 | As above, integrating [apex](https://github.com/NVIDIA/apex) into this code and employing its mixed-precision training techniques to take advantage of Tensor cores and reduce memory consumption could yield substantial speed gains.
114 |
115 | ## Misc Notes
116 | See [This directory](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for ImageNet labels.
117 |
118 | If you use this code, please cite
119 | ```text
120 | @inproceedings{
121 | brock2018large,
122 | title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis},
123 | author={Andrew Brock and Jeff Donahue and Karen Simonyan},
124 | booktitle={International Conference on Learning Representations},
125 | year={2019},
126 | url={https://openreview.net/forum?id=B1xsqj09Fm},
127 | }
128 | ```
129 |
130 | ## Acknowledgments
131 | Thanks to Google for the generous cloud credit donations.
132 |
133 | [SyncBN](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) by Jiayuan Mao and Tete Xiao.
134 |
135 | [Progress bar](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) originally from Jan Schlüter.
136 |
137 | Test metrics logger from [VoxNet.](https://github.com/dimatura/voxnet)
138 |
139 | PyTorch [implementation of cov](https://discuss.PyTorch.org/t/covariance-and-gradient-support/16217/2) from Modar M. Alfadly.
140 |
141 | PyTorch [fast Matrix Sqrt](https://github.com/msubhransu/matrix-sqrt) for FID from Tsung-Yu Lin and Subhransu Maji.
142 |
143 | TensorFlow Inception Score code from [OpenAI's Improved-GAN.](https://github.com/openai/improved-gan)
144 |
145 |
--------------------------------------------------------------------------------
/biggan_pytorch/TFHub/README.md:
--------------------------------------------------------------------------------
1 | # BigGAN-PyTorch TFHub converter
2 | This dir contains scripts for taking the [pre-trained generator weights from TFHub](https://tfhub.dev/s?q=biggan) and porting them to BigGAN-Pytorch.
3 |
4 | In addition to the base libraries for BigGAN-PyTorch, to run this code you will need:
5 |
6 | - TensorFlow
7 | - TFHub
8 | - parse
9 |
10 | Note that this code is only presently set up to run the ported models without truncation--you'll need to accumulate standing stats at each truncation level yourself if you wish to employ it.
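As a sketch, one way to accumulate standing stats with plain `torch.nn.BatchNorm2d` modules (the main repo has its own utilities for this; `G`, `z`, and `y` are assumed to already exist):

```python
import torch

def accumulate_standing_stats(G, z, y, num_accumulations=16):
    # Reset BN buffers, then run forward passes in train mode so the
    # running mean/var accumulate statistics for the z distribution
    # you intend to sample from.
    G.train()
    for m in G.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            m.reset_running_stats()
            m.momentum = None  # None => cumulative moving average
    with torch.no_grad():
        for _ in range(num_accumulations):
            z.normal_()  # re-draw z here (apply your truncation if desired)
            G(z, G.shared(y))
    G.eval()
```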
11 |
12 | To port the 128x128 model from TFHub, produce a pretrained weights .pth file, and generate samples using all your GPUs, run
13 |
14 | `python converter.py -r 128 --generate_samples --parallel`
--------------------------------------------------------------------------------
/biggan_pytorch/TFHub/biggan_v1.py:
--------------------------------------------------------------------------------
1 | # BigGAN V1:
2 | # This is now deprecated code used for porting the TFHub modules to pytorch,
3 | # included here for reference only.
4 | import numpy as np
5 | import torch
6 | from scipy.stats import truncnorm
7 | from torch import nn
8 | from torch.nn import Parameter
9 | from torch.nn import functional as F
10 |
11 |
12 | def l2normalize(v, eps=1e-4):
13 | return v / (v.norm() + eps)
14 |
15 |
16 | def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None):
17 | state = None if seed is None else np.random.RandomState(seed)
18 | values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
19 | return truncation * values
20 |
21 |
22 | def denorm(x):
23 | out = (x + 1) / 2
24 | return out.clamp_(0, 1)
25 |
26 |
27 | class SpectralNorm(nn.Module):
28 | def __init__(self, module, name='weight', power_iterations=1):
29 | super(SpectralNorm, self).__init__()
30 | self.module = module
31 | self.name = name
32 | self.power_iterations = power_iterations
33 | if not self._made_params():
34 | self._make_params()
35 |
36 | def _update_u_v(self):
37 | u = getattr(self.module, self.name + "_u")
38 | v = getattr(self.module, self.name + "_v")
39 | w = getattr(self.module, self.name + "_bar")
40 |
41 | height = w.data.shape[0]
42 | _w = w.view(height, -1)
43 | for _ in range(self.power_iterations):
44 | v = l2normalize(torch.matmul(_w.t(), u))
45 | u = l2normalize(torch.matmul(_w, v))
46 |
47 | sigma = u.dot((_w).mv(v))
48 | setattr(self.module, self.name, w / sigma.expand_as(w))
49 |
50 | def _made_params(self):
51 | try:
52 | getattr(self.module, self.name + "_u")
53 | getattr(self.module, self.name + "_v")
54 | getattr(self.module, self.name + "_bar")
55 | return True
56 | except AttributeError:
57 | return False
58 |
59 | def _make_params(self):
60 | w = getattr(self.module, self.name)
61 |
62 | height = w.data.shape[0]
63 | width = w.view(height, -1).data.shape[1]
64 |
65 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
66 | v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
67 | u.data = l2normalize(u.data)
68 | v.data = l2normalize(v.data)
69 | w_bar = Parameter(w.data)
70 |
71 | del self.module._parameters[self.name]
72 | self.module.register_parameter(self.name + "_u", u)
73 | self.module.register_parameter(self.name + "_v", v)
74 | self.module.register_parameter(self.name + "_bar", w_bar)
75 |
76 | def forward(self, *args):
77 | self._update_u_v()
78 | return self.module.forward(*args)
79 |
80 |
81 | class SelfAttention(nn.Module):
82 | """ Self Attention Layer"""
83 |
84 | def __init__(self, in_dim, activation=F.relu):
85 | super().__init__()
86 | self.channel_in = in_dim
87 | self.activation = activation
88 |
89 | self.theta = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False))
90 | self.phi = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False))
91 | self.pool = nn.MaxPool2d(2, 2)
92 | self.g = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False))
93 | self.o_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False))
94 | self.gamma = nn.Parameter(torch.zeros(1))
95 |
96 | self.softmax = nn.Softmax(dim=-1)
97 |
98 | def forward(self, x):
99 | m_batchsize, C, width, height = x.size()
100 | N = height * width
101 |
102 | theta = self.theta(x)
103 | phi = self.phi(x)
104 | phi = self.pool(phi)
105 | phi = phi.view(m_batchsize, -1, N // 4)
106 | theta = theta.view(m_batchsize, -1, N)
107 | theta = theta.permute(0, 2, 1)
108 | attention = self.softmax(torch.bmm(theta, phi))
109 | g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4)
110 | attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(m_batchsize, -1, width, height)
111 | out = self.o_conv(attn_g)
112 | return self.gamma * out + x
113 |
114 |
115 | class ConditionalBatchNorm2d(nn.Module):
116 | def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1):
117 | super().__init__()
118 | self.num_features = num_features
119 | self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
120 | self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
121 | self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
122 |
123 | def forward(self, x, y):
124 | out = self.bn(x)
125 | gamma = self.gamma_embed(y) + 1
126 | beta = self.beta_embed(y)
127 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
128 | return out
129 |
130 |
131 | class GBlock(nn.Module):
132 | def __init__(
133 | self,
134 | in_channel,
135 | out_channel,
136 | kernel_size=[3, 3],
137 | padding=1,
138 | stride=1,
139 | n_class=None,
140 | bn=True,
141 | activation=F.relu,
142 | upsample=True,
143 | downsample=False,
144 | z_dim=148,
145 | ):
146 | super().__init__()
147 |
148 | self.conv0 = SpectralNorm(
149 | nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=True)  # always biased; the ported TFHub convs have biases
150 | )
151 | self.conv1 = SpectralNorm(
152 | nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding, bias=True)
153 | )
154 |
155 | self.skip_proj = False
156 | if in_channel != out_channel or upsample or downsample:
157 | self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
158 | self.skip_proj = True
159 |
160 | self.upsample = upsample
161 | self.downsample = downsample
162 | self.activation = activation
163 | self.bn = bn
164 | if bn:
165 | self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim)
166 | self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim)
167 |
168 | def forward(self, input, condition=None):
169 | out = input
170 |
171 | if self.bn:
172 | out = self.HyperBN(out, condition)
173 | out = self.activation(out)
174 | if self.upsample:
175 | out = F.interpolate(out, scale_factor=2)
176 | out = self.conv0(out)
177 | if self.bn:
178 | out = self.HyperBN_1(out, condition)
179 | out = self.activation(out)
180 | out = self.conv1(out)
181 |
182 | if self.downsample:
183 | out = F.avg_pool2d(out, 2)
184 |
185 | if self.skip_proj:
186 | skip = input
187 | if self.upsample:
188 | skip = F.interpolate(skip, scale_factor=2)
189 | skip = self.conv_sc(skip)
190 | if self.downsample:
191 | skip = F.avg_pool2d(skip, 2)
192 | else:
193 | skip = input
194 | return out + skip
195 |
196 |
197 | class Generator128(nn.Module):
198 | def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False):
199 | super().__init__()
200 |
201 | self.linear = nn.Linear(n_class, 128, bias=False)
202 |
203 | if debug:
204 | chn = 8
205 |
206 | self.first_view = 16 * chn
207 |
208 | self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))
209 |
210 | z_dim = code_dim + 28
211 |
212 | self.GBlock = nn.ModuleList([
213 | GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
214 | GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
215 | GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
216 | GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
217 | GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
218 | ])
219 |
220 | self.sa_id = 4
221 | self.num_split = len(self.GBlock) + 1
222 | self.attention = SelfAttention(2 * chn)
223 | self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
224 | self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
225 |
226 | def forward(self, input, class_id):
227 | codes = torch.chunk(input, self.num_split, 1)
228 | class_emb = self.linear(class_id) # 128
229 |
230 | out = self.G_linear(codes[0])
231 | out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
232 | for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
233 | if i == self.sa_id:
234 | out = self.attention(out)
235 | condition = torch.cat([code, class_emb], 1)
236 | out = GBlock(out, condition)
237 |
238 | out = self.ScaledCrossReplicaBN(out)
239 | out = F.relu(out)
240 | out = self.colorize(out)
241 | return torch.tanh(out)
242 |
243 |
244 | class Generator256(nn.Module):
245 | def __init__(self, code_dim=140, n_class=1000, chn=96, debug=False):
246 | super().__init__()
247 |
248 | self.linear = nn.Linear(n_class, 128, bias=False)
249 |
250 | if debug:
251 | chn = 8
252 |
253 | self.first_view = 16 * chn
254 |
255 | self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))
256 |
257 | self.GBlock = nn.ModuleList([
258 | GBlock(16 * chn, 16 * chn, n_class=n_class),
259 | GBlock(16 * chn, 8 * chn, n_class=n_class),
260 | GBlock(8 * chn, 8 * chn, n_class=n_class),
261 | GBlock(8 * chn, 4 * chn, n_class=n_class),
262 | GBlock(4 * chn, 2 * chn, n_class=n_class),
263 | GBlock(2 * chn, 1 * chn, n_class=n_class),
264 | ])
265 |
266 | self.sa_id = 5
267 | self.num_split = len(self.GBlock) + 1
268 | self.attention = SelfAttention(2 * chn)
269 | self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
270 | self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
271 |
272 | def forward(self, input, class_id):
273 | codes = torch.chunk(input, self.num_split, 1)
274 | class_emb = self.linear(class_id) # 128
275 |
276 | out = self.G_linear(codes[0])
277 | out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
278 | for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
279 | if i == self.sa_id:
280 | out = self.attention(out)
281 | condition = torch.cat([code, class_emb], 1)
282 | out = GBlock(out, condition)
283 |
284 | out = self.ScaledCrossReplicaBN(out)
285 | out = F.relu(out)
286 | out = self.colorize(out)
287 | return torch.tanh(out)
288 |
289 |
290 | class Generator512(nn.Module):
291 | def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False):
292 | super().__init__()
293 |
294 | self.linear = nn.Linear(n_class, 128, bias=False)
295 |
296 | if debug:
297 | chn = 8
298 |
299 | self.first_view = 16 * chn
300 |
301 | self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn))
302 |
303 | z_dim = code_dim + 16
304 |
305 | self.GBlock = nn.ModuleList([
306 | GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
307 | GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
308 | GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
309 | GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
310 | GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
311 | GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
312 | GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
313 | ])
314 |
315 | self.sa_id = 4
316 | self.num_split = len(self.GBlock) + 1
317 | self.attention = SelfAttention(4 * chn)
318 | self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn)
319 | self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
320 |
321 | def forward(self, input, class_id):
322 | codes = torch.chunk(input, self.num_split, 1)
323 | class_emb = self.linear(class_id) # 128
324 |
325 | out = self.G_linear(codes[0])
326 | out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
327 | for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
328 | if i == self.sa_id:
329 | out = self.attention(out)
330 | condition = torch.cat([code, class_emb], 1)
331 | out = GBlock(out, condition)
332 |
333 | out = self.ScaledCrossReplicaBN(out)
334 | out = F.relu(out)
335 | out = self.colorize(out)
336 | return torch.tanh(out)
337 |
338 |
339 | class Discriminator(nn.Module):
340 | def __init__(self, n_class=1000, chn=96, debug=False):
341 | super().__init__()
342 |
343 | def conv(in_channel, out_channel, downsample=True):
344 | return GBlock(in_channel, out_channel, bn=False, upsample=False, downsample=downsample)
345 |
346 | if debug:
347 | chn = 8
348 | self.debug = debug
349 |
350 | self.pre_conv = nn.Sequential(
351 | SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)),
352 | nn.ReLU(),
353 | SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)),
354 | nn.AvgPool2d(2),
355 | )
356 | self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1))
357 |
358 | self.conv = nn.Sequential(
359 | conv(1 * chn, 1 * chn, downsample=True),
360 | conv(1 * chn, 2 * chn, downsample=True),
361 | SelfAttention(2 * chn),
362 | conv(2 * chn, 2 * chn, downsample=True),
363 | conv(2 * chn, 4 * chn, downsample=True),
364 | conv(4 * chn, 8 * chn, downsample=True),
365 | conv(8 * chn, 8 * chn, downsample=True),
366 | conv(8 * chn, 16 * chn, downsample=True),
367 | conv(16 * chn, 16 * chn, downsample=False),
368 | )
369 |
370 | self.linear = SpectralNorm(nn.Linear(16 * chn, 1))
371 |
372 | self.embed = nn.Embedding(n_class, 16 * chn)
373 | self.embed.weight.data.uniform_(-0.1, 0.1)
374 | self.embed = SpectralNorm(self.embed)
375 |
376 | def forward(self, input, class_id):
377 |
378 | out = self.pre_conv(input)
379 | out += self.pre_skip(F.avg_pool2d(input, 2))
380 | out = self.conv(out)
381 | out = F.relu(out)
382 | out = out.view(out.size(0), out.size(1), -1)
383 | out = out.sum(2)
384 | out_linear = self.linear(out).squeeze(1)
385 | embed = self.embed(class_id)
386 |
387 | prod = (out * embed).sum(1)
388 |
389 | return out_linear + prod
--------------------------------------------------------------------------------
/biggan_pytorch/TFHub/converter.py:
--------------------------------------------------------------------------------
1 | """Utilities for converting TFHub BigGAN generator weights to PyTorch.
2 |
3 | Recommended usage:
4 |
5 | To convert all BigGAN variants and generate test samples, use:
6 |
7 | ```bash
8 | CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples
9 | ```
10 |
11 | See `parse_args` for additional options.
12 | """
13 |
14 | import argparse
15 | import os
16 | import sys
17 |
18 | import h5py
19 | import torch
20 | import torch.nn as nn
21 | from torchvision.utils import save_image
22 | import tensorflow as tf
23 | import tensorflow_hub as hub
24 | import parse
25 |
26 | # import reference biggan from this folder
27 | import biggan_v1 as biggan_for_conversion
28 |
29 | # Import model from main folder
30 | sys.path.append('..')
31 | import BigGAN
32 |
33 |
34 |
35 |
36 | DEVICE = 'cuda'
37 | HDF5_TMPL = 'biggan-{}.h5'
38 | PTH_TMPL = 'biggan-{}.pth'
39 | MODULE_PATH_TMPL = 'https://tfhub.dev/deepmind/biggan-{}/2'
40 | Z_DIMS = {
41 | 128: 120,
42 | 256: 140,
43 | 512: 128}
44 | RESOLUTIONS = list(Z_DIMS)
45 |
46 |
47 | def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False):
48 | """Loads TFHub weights and saves them to intermediate HDF5 file.
49 |
50 | Args:
51 | module_path ([Path-like]): Path to TFHub module.
52 | hdf5_path ([Path-like]): Path to output HDF5 file.
53 |
54 | Returns:
55 | [h5py.File]: Loaded hdf5 file containing module weights.
56 | """
57 | if os.path.exists(hdf5_path) and (not redownload):
58 | print('Loading BigGAN hdf5 file from:', hdf5_path)
59 | return h5py.File(hdf5_path, 'r')
60 |
61 | print('Loading BigGAN module from:', module_path)
62 | tf.reset_default_graph()
63 | hub.Module(module_path)
64 | print('Loaded BigGAN module from:', module_path)
65 |
66 | initializer = tf.global_variables_initializer()
67 | sess = tf.Session()
68 | sess.run(initializer)
69 |
70 | print('Saving BigGAN weights to :', hdf5_path)
71 | h5f = h5py.File(hdf5_path, 'w')
72 | for var in tf.global_variables():
73 | val = sess.run(var)
74 | h5f.create_dataset(var.name, data=val)
75 | print(f'Saving {var.name} with shape {val.shape}')
76 | h5f.close()
77 | return h5py.File(hdf5_path, 'r')
78 |
79 |
80 | class TFHub2Pytorch(object):
81 |
82 | TF_ROOT = 'module'
83 |
84 | NUM_GBLOCK = {
85 | 128: 5,
86 | 256: 6,
87 | 512: 7
88 | }
89 |
90 | w = 'w'
91 | b = 'b'
92 | u = 'u0'
93 | v = 'u1'
94 | gamma = 'gamma'
95 | beta = 'beta'
96 |
97 | def __init__(self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False):
98 | self.state_dict = state_dict
99 | self.tf_weights = tf_weights
100 | self.resolution = resolution
101 | self.verbose = verbose
102 | if load_ema:
103 | for name in ['w', 'b', 'gamma', 'beta']:
104 | setattr(self, name, getattr(self, name) + '/ema_b999900')
105 |
106 | def load(self):
107 | self.load_generator()
108 | return self.state_dict
109 |
110 | def load_generator(self):
111 | GENERATOR_ROOT = os.path.join(self.TF_ROOT, 'Generator')
112 |
113 | for i in range(self.NUM_GBLOCK[self.resolution]):
114 | name_tf = os.path.join(GENERATOR_ROOT, 'GBlock')
115 | name_tf += f'_{i}' if i != 0 else ''
116 | self.load_GBlock(f'GBlock.{i}.', name_tf)
117 |
118 | self.load_attention('attention.', os.path.join(GENERATOR_ROOT, 'attention'))
119 | self.load_linear('linear', os.path.join(self.TF_ROOT, 'linear'), bias=False)
120 | self.load_snlinear('G_linear', os.path.join(GENERATOR_ROOT, 'G_Z', 'G_linear'))
121 | self.load_colorize('colorize', os.path.join(GENERATOR_ROOT, 'conv_2d'))
122 | self.load_ScaledCrossReplicaBNs('ScaledCrossReplicaBN',
123 | os.path.join(GENERATOR_ROOT, 'ScaledCrossReplicaBN'))
124 |
125 | def load_linear(self, name_pth, name_tf, bias=True):
126 | self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0)
127 | if bias:
128 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b)
129 |
130 | def load_snlinear(self, name_pth, name_tf, bias=True):
131 | self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze()
132 | self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze()
133 | self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(1, 0)
134 | if bias:
135 | self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b)
136 |
137 | def load_colorize(self, name_pth, name_tf):
138 | self.load_snconv(name_pth, name_tf)
139 |
140 | def load_GBlock(self, name_pth, name_tf):
141 | self.load_convs(name_pth, name_tf)
142 | self.load_HyperBNs(name_pth, name_tf)
143 |
144 | def load_convs(self, name_pth, name_tf):
145 | self.load_snconv(name_pth + 'conv0', os.path.join(name_tf, 'conv0'))
146 | self.load_snconv(name_pth + 'conv1', os.path.join(name_tf, 'conv1'))
147 | self.load_snconv(name_pth + 'conv_sc', os.path.join(name_tf, 'conv_sc'))
148 |
149 | def load_snconv(self, name_pth, name_tf, bias=True):
150 | if self.verbose:
151 | print(f'loading: {name_pth} from {name_tf}')
152 | self.state_dict[name_pth + '.module.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze()
153 | self.state_dict[name_pth + '.module.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze()
154 | self.state_dict[name_pth + '.module.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1)
155 | if bias:
156 | self.state_dict[name_pth + '.module.bias'] = self.load_tf_tensor(name_tf, self.b).squeeze()
157 |
158 | def load_conv(self, name_pth, name_tf, bias=True):
159 |
160 | self.state_dict[name_pth + '.weight_u'] = self.load_tf_tensor(name_tf, self.u).squeeze()
161 | self.state_dict[name_pth + '.weight_v'] = self.load_tf_tensor(name_tf, self.v).squeeze()
162 | self.state_dict[name_pth + '.weight_bar'] = self.load_tf_tensor(name_tf, self.w).permute(3, 2, 0, 1)
163 | if bias:
164 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.b)
165 |
166 | def load_HyperBNs(self, name_pth, name_tf):
167 | self.load_HyperBN(name_pth + 'HyperBN', os.path.join(name_tf, 'HyperBN'))
168 | self.load_HyperBN(name_pth + 'HyperBN_1', os.path.join(name_tf, 'HyperBN_1'))
169 |
170 | def load_ScaledCrossReplicaBNs(self, name_pth, name_tf):
171 | self.state_dict[name_pth + '.bias'] = self.load_tf_tensor(name_tf, self.beta).squeeze()
172 | self.state_dict[name_pth + '.weight'] = self.load_tf_tensor(name_tf, self.gamma).squeeze()
173 | self.state_dict[name_pth + '.running_mean'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_mean')
174 | self.state_dict[name_pth + '.running_var'] = self.load_tf_tensor(name_tf + 'bn', 'accumulated_var')
175 | self.state_dict[name_pth + '.num_batches_tracked'] = torch.tensor(
176 | self.tf_weights[os.path.join(name_tf + 'bn', 'accumulation_counter:0')][()], dtype=torch.float32)
177 |
178 | def load_HyperBN(self, name_pth, name_tf):
179 | if self.verbose:
180 | print(f'loading: {name_pth} from {name_tf}')
181 | beta = name_pth + '.beta_embed.module'
182 | gamma = name_pth + '.gamma_embed.module'
183 | self.state_dict[beta + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.u).squeeze()
184 | self.state_dict[gamma + '.weight_u'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.u).squeeze()
185 | self.state_dict[beta + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.v).squeeze()
186 | self.state_dict[gamma + '.weight_v'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.v).squeeze()
187 | self.state_dict[beta + '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'beta'), self.w).permute(1, 0)
188 | self.state_dict[gamma +
189 | '.weight_bar'] = self.load_tf_tensor(os.path.join(name_tf, 'gamma'), self.w).permute(1, 0)
190 |
191 | cr_bn_name = name_tf.replace('HyperBN', 'CrossReplicaBN')
192 | self.state_dict[name_pth + '.bn.running_mean'] = self.load_tf_tensor(cr_bn_name, 'accumulated_mean')
193 | self.state_dict[name_pth + '.bn.running_var'] = self.load_tf_tensor(cr_bn_name, 'accumulated_var')
194 | self.state_dict[name_pth + '.bn.num_batches_tracked'] = torch.tensor(
195 | self.tf_weights[os.path.join(cr_bn_name, 'accumulation_counter:0')][()], dtype=torch.float32)
196 |
197 | def load_attention(self, name_pth, name_tf):
198 |
199 | self.load_snconv(name_pth + 'theta', os.path.join(name_tf, 'theta'), bias=False)
200 | self.load_snconv(name_pth + 'phi', os.path.join(name_tf, 'phi'), bias=False)
201 | self.load_snconv(name_pth + 'g', os.path.join(name_tf, 'g'), bias=False)
202 | self.load_snconv(name_pth + 'o_conv', os.path.join(name_tf, 'o_conv'), bias=False)
203 | self.state_dict[name_pth + 'gamma'] = self.load_tf_tensor(name_tf, self.gamma)
204 |
205 | def load_tf_tensor(self, prefix, var, device='0'):
206 | name = os.path.join(prefix, var) + f':{device}'
207 | return torch.from_numpy(self.tf_weights[name][:])
208 |
209 | # Convert from v1: this function maps the biggan_v1 state dict names onto the layer names of the main BigGAN.py model.
210 | def convert_from_v1(hub_dict, resolution=128):
211 | weightname_dict = {'weight_u': 'u0', 'weight_bar': 'weight', 'bias': 'bias'}
212 | convnum_dict = {'conv0': 'conv1', 'conv1': 'conv2', 'conv_sc': 'conv_sc'}
213 | attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
214 | hub2me = {'linear.weight': 'shared.weight', # This is actually the shared weight
215 | # Linear stuff
216 | 'G_linear.module.weight_bar': 'linear.weight',
217 | 'G_linear.module.bias': 'linear.bias',
218 | 'G_linear.module.weight_u': 'linear.u0',
219 | # output layer stuff
220 | 'ScaledCrossReplicaBN.weight': 'output_layer.0.gain',
221 | 'ScaledCrossReplicaBN.bias': 'output_layer.0.bias',
222 | 'ScaledCrossReplicaBN.running_mean': 'output_layer.0.stored_mean',
223 | 'ScaledCrossReplicaBN.running_var': 'output_layer.0.stored_var',
224 | 'colorize.module.weight_bar': 'output_layer.2.weight',
225 | 'colorize.module.bias': 'output_layer.2.bias',
226 | 'colorize.module.weight_u': 'output_layer.2.u0',
227 | # Attention stuff
228 | 'attention.gamma': 'blocks.%d.1.gamma' % attention_blocknum,
229 | 'attention.theta.module.weight_u': 'blocks.%d.1.theta.u0' % attention_blocknum,
230 | 'attention.theta.module.weight_bar': 'blocks.%d.1.theta.weight' % attention_blocknum,
231 | 'attention.phi.module.weight_u': 'blocks.%d.1.phi.u0' % attention_blocknum,
232 | 'attention.phi.module.weight_bar': 'blocks.%d.1.phi.weight' % attention_blocknum,
233 | 'attention.g.module.weight_u': 'blocks.%d.1.g.u0' % attention_blocknum,
234 | 'attention.g.module.weight_bar': 'blocks.%d.1.g.weight' % attention_blocknum,
235 | 'attention.o_conv.module.weight_u': 'blocks.%d.1.o.u0' % attention_blocknum,
236 | 'attention.o_conv.module.weight_bar':'blocks.%d.1.o.weight' % attention_blocknum,
237 | }
238 |
239 | # Loop over the hub dict and build the hub2me map
240 | for name in hub_dict.keys():
241 | if 'GBlock' in name:
242 | if 'HyperBN' not in name: # it's a conv
243 | out = parse.parse('GBlock.{:d}.{}.module.{}',name)
244 | blocknum, convnum, weightname = out
245 | if weightname not in weightname_dict:
246 | continue # else hyperBN in
247 | out_name = 'blocks.%d.0.%s.%s' % (blocknum, convnum_dict[convnum], weightname_dict[weightname]) # Increment conv number by 1
248 | else: # hyperbn not conv
249 | BNnum = 2 if 'HyperBN_1' in name else 1
250 | if 'embed' in name:
251 | out = parse.parse('GBlock.{:d}.{}.module.{}',name)
252 | blocknum, gamma_or_beta, weightname = out
253 | if weightname not in weightname_dict: # Ignore weight_v
254 | continue
255 | out_name = 'blocks.%d.0.bn%d.%s.%s' % (blocknum, BNnum, 'gain' if 'gamma' in gamma_or_beta else 'bias', weightname_dict[weightname])
256 | else:
257 | out = parse.parse('GBlock.{:d}.{}.bn.{}',name)
258 | blocknum, dummy, mean_or_var = out
259 | if 'num_batches_tracked' in mean_or_var:
260 | continue
261 | out_name = 'blocks.%d.0.bn%d.%s' % (blocknum, BNnum, 'stored_mean' if 'mean' in mean_or_var else 'stored_var')
262 | hub2me[name] = out_name
263 |
264 |
265 | # Invert the hub2me map
266 | me2hub = {hub2me[item]: item for item in hub2me}
267 | new_dict = {}
268 | dimz_dict = {128: 20, 256: 20, 512:16}
269 | for item in me2hub:
270 | # Swap input dim ordering on batchnorm bois to account for my arbitrary change of ordering when concatenating Ys and Zs
271 | if ('bn' in item and 'weight' in item) and ('gain' in item or 'bias' in item) and ('output_layer' not in item):
272 | new_dict[item] = torch.cat([hub_dict[me2hub[item]][:, -128:], hub_dict[me2hub[item]][:, :dimz_dict[resolution]]], 1)
273 | # Reshape the first linear weight, bias, and u0
274 | elif item == 'linear.weight':
275 | new_dict[item] = hub_dict[me2hub[item]].contiguous().view(4, 4, 96 * 16, -1).permute(2,0,1,3).contiguous().view(-1,dimz_dict[resolution])
276 | elif item == 'linear.bias':
277 | new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2,0,1).contiguous().view(-1)
278 | elif item == 'linear.u0':
279 | new_dict[item] = hub_dict[me2hub[item]].view(4, 4, 96 * 16).permute(2,0,1).contiguous().view(1, -1)
280 | elif me2hub[item] == 'linear.weight': # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER
281 | # Transpose shared weight so that it's an embedding
282 | new_dict[item] = hub_dict[me2hub[item]].t()
283 | elif 'weight_u' in me2hub[item]: # Unsqueeze u0s
284 | new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
285 | else:
286 | new_dict[item] = hub_dict[me2hub[item]]
287 | return new_dict
288 |
289 | def get_config(resolution):
290 | attn_dict = {128: '64', 256: '128', 512: '64'}
291 | dim_z_dict = {128: 120, 256: 140, 512: 128}
292 | config = {'G_param': 'SN', 'D_param': 'SN',
293 | 'G_ch': 96, 'D_ch': 96,
294 | 'D_wide': True, 'G_shared': True,
295 | 'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
296 | 'hier': True, 'cross_replica': False,
297 | 'mybn': False, 'G_activation': nn.ReLU(inplace=True),
298 | 'G_attn': attn_dict[resolution],
299 | 'norm_style': 'bn',
300 | 'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
301 | 'G_fp16': False, 'G_mixed_precision': False,
302 | 'accumulate_stats': False, 'num_standing_accumulations': 16,
303 | 'G_eval_mode': True,
304 | 'BN_eps': 1e-04, 'SN_eps': 1e-04,
305 | 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
306 | 'n_classes': 1000}
307 | return config
308 |
309 |
310 | def convert_biggan(resolution, weight_dir, redownload=False, no_ema=False, verbose=False):
311 | module_path = MODULE_PATH_TMPL.format(resolution)
312 | hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution))
313 | pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution))
314 |
315 | tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload)
316 | G_temp = getattr(biggan_for_conversion, f'Generator{resolution}')()
317 | state_dict_temp = G_temp.state_dict()
318 |
319 | converter = TFHub2Pytorch(state_dict_temp, tf_weights, resolution=resolution,
320 | load_ema=(not no_ema), verbose=verbose)
321 | state_dict_v1 = converter.load()
322 | state_dict = convert_from_v1(state_dict_v1, resolution)
323 | # Get the config, build the model
324 | config = get_config(resolution)
325 | G = BigGAN.Generator(**config)
326 | G.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries
327 | torch.save(state_dict, pth_path)
328 |
329 | # output_location ='pretrained_weights/TFHub-PyTorch-128.pth'
330 |
331 | return G
332 |
333 |
334 | def generate_sample(G, z_dim, batch_size, filename, parallel=False):
335 |
336 | G.eval()
337 | G.to(DEVICE)
338 | with torch.no_grad():
339 | z = torch.randn(batch_size, G.dim_z).to(DEVICE)
340 | y = torch.randint(low=0, high=1000, size=(batch_size,),
341 | device=DEVICE, dtype=torch.int64, requires_grad=False)
342 | if parallel:
343 | images = nn.parallel.data_parallel(G, (z, G.shared(y)))
344 | else:
345 | images = G(z, G.shared(y))
346 | save_image(images, filename, scale_each=True, normalize=True)
347 |
348 | def parse_args():
349 | usage = 'Parser for conversion script.'
350 | parser = argparse.ArgumentParser(description=usage)
351 | parser.add_argument(
352 | '--resolution', '-r', type=int, default=None, choices=[128, 256, 512],
353 | help='Resolution of TFHub module to convert. Converts all resolutions if None.')
354 | parser.add_argument(
355 | '--redownload', action='store_true', default=False,
356 | help='Redownload weights and overwrite current hdf5 file, if present.')
357 | parser.add_argument(
358 | '--weights_dir', type=str, default='pretrained_weights')
359 | parser.add_argument(
360 | '--samples_dir', type=str, default='pretrained_samples')
361 | parser.add_argument(
362 | '--no_ema', action='store_true', default=False,
363 | help='Do not load ema weights.')
364 | parser.add_argument(
365 | '--verbose', action='store_true', default=False,
366 | help='Additional logging.')
367 | parser.add_argument(
368 | '--generate_samples', action='store_true', default=False,
369 | help='Generate test sample with pretrained model.')
370 | parser.add_argument(
371 | '--batch_size', type=int, default=64,
372 | help='Batch size used for test sample.')
373 | parser.add_argument(
374 | '--parallel', action='store_true', default=False,
375 | help='Parallelize G?')
376 | args = parser.parse_args()
377 | return args
378 |
379 |
380 | if __name__ == '__main__':
381 |
382 | args = parse_args()
383 | os.makedirs(args.weights_dir, exist_ok=True)
384 | os.makedirs(args.samples_dir, exist_ok=True)
385 |
386 | if args.resolution is not None:
387 | G = convert_biggan(args.resolution, args.weights_dir,
388 | redownload=args.redownload,
389 | no_ema=args.no_ema, verbose=args.verbose)
390 | if args.generate_samples:
391 | filename = os.path.join(args.samples_dir, f'biggan{args.resolution}_samples.jpg')
392 | print('Generating samples...')
393 | generate_sample(G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel)
394 | else:
395 | for res in RESOLUTIONS:
396 | G = convert_biggan(res, args.weights_dir,
397 | redownload=args.redownload,
398 | no_ema=args.no_ema, verbose=args.verbose)
399 | if args.generate_samples:
400 | filename = os.path.join(args.samples_dir, f'biggan{res}_samples.jpg')
401 | print('Generating samples...')
402 | generate_sample(G, Z_DIMS[res], args.batch_size, filename, args.parallel)
--------------------------------------------------------------------------------
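For quick verification of a converted checkpoint, a minimal sketch follows (assuming the script above is importable as a module named converter; the .pth path is a placeholder for whatever PTH_TMPL produces, and z_dim=120 is the standard BigGAN-128 latent size):

import torch
import BigGAN
import converter  # the script above, imported as a module (hypothetical name)

config = converter.get_config(128)
G = BigGAN.Generator(**config)
G.load_state_dict(torch.load('pretrained_weights/biggan_128.pth'),
                  strict=False)  # sv0 entries may be absent, as noted above
converter.generate_sample(G, z_dim=120, batch_size=16, filename='check_128.jpg')

Running the script directly, e.g. python converter.py -r 128 --generate_samples, exercises the same path via __main__.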
/biggan_pytorch/calculate_inception_moments.py:
--------------------------------------------------------------------------------
1 | ''' Calculate Inception Moments
2 | This script iterates over the dataset and calculates the moments of the
3 | activations of the Inception net (needed for FID), and also returns
4 | the Inception Score of the training data.
5 |
6 | Note that if you don't shuffle the data, the IS of true data will be under-
7 | estimated as it is label-ordered. By default, the data is not shuffled
8 | so as to reduce non-determinism. '''
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | import utils
15 | import inception_utils
16 | from tqdm import tqdm, trange
17 | from argparse import ArgumentParser
18 |
19 | def prepare_parser():
20 | usage = 'Calculate and store inception metrics.'
21 | parser = ArgumentParser(description=usage)
22 | parser.add_argument(
23 | '--dataset', type=str, default='I128_hdf5',
24 | help='Which Dataset to train on, out of I128, I256, C10, C100... '
25 | 'Append _hdf5 to use the hdf5 version of the dataset. (default: %(default)s)')
26 | parser.add_argument(
27 | '--data_root', type=str, default='data',
28 | help='Default location where data is stored (default: %(default)s)')
29 | parser.add_argument(
30 | '--batch_size', type=int, default=64,
31 | help='Default overall batchsize (default: %(default)s)')
32 | parser.add_argument(
33 | '--parallel', action='store_true', default=False,
34 | help='Train with multiple GPUs (default: %(default)s)')
35 | parser.add_argument(
36 | '--augment', action='store_true', default=False,
37 | help='Augment with random crops and flips (default: %(default)s)')
38 | parser.add_argument(
39 | '--num_workers', type=int, default=8,
40 | help='Number of dataloader workers (default: %(default)s)')
41 | parser.add_argument(
42 | '--shuffle', action='store_true', default=False,
43 | help='Shuffle the data? (default: %(default)s)')
44 | parser.add_argument(
45 | '--seed', type=int, default=0,
46 | help='Random seed to use.')
47 | return parser
48 |
49 | def run(config):
50 | # Get loader
51 | config['drop_last'] = False
52 | loaders = utils.get_data_loaders(**config)
53 |
54 | # Load inception net
55 | net = inception_utils.load_inception_net(parallel=config['parallel'])
56 | pool, logits, labels = [], [], []
57 | device = 'cuda'
58 | for i, (x, y) in enumerate(tqdm(loaders[0])):
59 | x = x.to(device)
60 | with torch.no_grad():
61 | pool_val, logits_val = net(x)
62 | pool += [np.asarray(pool_val.cpu())]
63 | logits += [np.asarray(F.softmax(logits_val, 1).cpu())]
64 | labels += [np.asarray(y.cpu())]
65 |
66 | pool, logits, labels = [np.concatenate(item, 0) for item in [pool, logits, labels]]
67 | # uncomment to save pool, logits, and labels to disk
68 | # print('Saving pool, logits, and labels to disk...')
69 | # np.savez(config['dataset']+'_inception_activations.npz',
70 | # {'pool': pool, 'logits': logits, 'labels': labels})
71 | # Calculate inception metrics and report them
72 | print('Calculating inception metrics...')
73 | IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
74 | print('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config['dataset'], IS_mean, IS_std))
75 | # Prepare mu and sigma, save to disk. Remove "hdf5" by default
76 | # (the FID code also knows to strip "hdf5")
77 | print('Calculating means and covariances...')
78 | mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
79 | print('Saving calculated means and covariances to disk...')
80 | np.savez(config['dataset'].replace('_hdf5', '') + '_inception_moments.npz', **{'mu': mu, 'sigma': sigma})
81 |
82 | def main():
83 | # parse command line
84 | parser = prepare_parser()
85 | config = vars(parser.parse_args())
86 | print(config)
87 | run(config)
88 |
89 |
90 | if __name__ == '__main__':
91 | main()
--------------------------------------------------------------------------------
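A minimal usage sketch for the script above (the dataset names are placeholders drawn from the parser's own examples; for the _hdf5 variants, make_hdf5.py must have been run first):

# Shell form (hypothetical dataset/root):
#   python calculate_inception_moments.py --dataset I128_hdf5 --data_root data
# Programmatic form, bypassing the command line:
from calculate_inception_moments import prepare_parser, run

config = vars(prepare_parser().parse_args(['--dataset', 'C10', '--shuffle']))
run(config)  # writes C10_inception_moments.npz for later FID use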
/biggan_pytorch/datasets.py:
--------------------------------------------------------------------------------
1 | ''' Datasets
2 | This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets
3 | '''
4 | import os
5 | import os.path
6 | import sys
7 | from PIL import Image
8 | import numpy as np
9 | from tqdm import tqdm, trange
10 |
11 | import torchvision.datasets as dset
12 | import torchvision.transforms as transforms
13 | from torchvision.datasets.utils import download_url, check_integrity
14 | import torch.utils.data as data
15 | from torch.utils.data import DataLoader
16 |
17 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
18 |
19 |
20 | def is_image_file(filename):
21 | """Checks if a file is an image.
22 |
23 | Args:
24 | filename (string): path to a file
25 |
26 | Returns:
27 | bool: True if the filename ends with a known image extension
28 | """
29 | filename_lower = filename.lower()
30 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
31 |
32 |
33 | def find_classes(dir):
34 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
35 | classes.sort()
36 | class_to_idx = {classes[i]: i for i in range(len(classes))}
37 | return classes, class_to_idx
38 |
39 |
40 | def make_dataset(dir, class_to_idx):
41 | images = []
42 | dir = os.path.expanduser(dir)
43 | for target in tqdm(sorted(os.listdir(dir))):
44 | d = os.path.join(dir, target)
45 | if not os.path.isdir(d):
46 | continue
47 |
48 | for root, _, fnames in sorted(os.walk(d)):
49 | for fname in sorted(fnames):
50 | if is_image_file(fname):
51 | path = os.path.join(root, fname)
52 | item = (path, class_to_idx[target])
53 | images.append(item)
54 |
55 | return images
56 |
57 |
58 | def pil_loader(path):
59 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
60 | with open(path, 'rb') as f:
61 | img = Image.open(f)
62 | return img.convert('RGB')
63 |
64 |
65 | def accimage_loader(path):
66 | import accimage
67 | try:
68 | return accimage.Image(path)
69 | except IOError:
70 | # Potentially a decoding problem, fall back to PIL.Image
71 | return pil_loader(path)
72 |
73 |
74 | def default_loader(path):
75 | from torchvision import get_image_backend
76 | if get_image_backend() == 'accimage':
77 | return accimage_loader(path)
78 | else:
79 | return pil_loader(path)
80 |
81 |
82 | class ImageFolder(data.Dataset):
83 | """A generic data loader where the images are arranged in this way: ::
84 |
85 | root/dogball/xxx.png
86 | root/dogball/xxy.png
87 | root/dogball/xxz.png
88 |
89 | root/cat/123.png
90 | root/cat/nsdf3.png
91 | root/cat/asd932_.png
92 |
93 | Args:
94 | root (string): Root directory path.
95 | transform (callable, optional): A function/transform that takes in an PIL image
96 | and returns a transformed version. E.g, ``transforms.RandomCrop``
97 | target_transform (callable, optional): A function/transform that takes in the
98 | target and transforms it.
99 | loader (callable, optional): A function to load an image given its path.
100 |
101 | Attributes:
102 | classes (list): List of the class names.
103 | class_to_idx (dict): Dict with items (class_name, class_index).
104 | imgs (list): List of (image path, class_index) tuples
105 | """
106 |
107 | def __init__(self, root, transform=None, target_transform=None,
108 | loader=default_loader, load_in_mem=False,
109 | index_filename='imagenet_imgs.npz', **kwargs):
110 | classes, class_to_idx = find_classes(root)
111 | # Load pre-computed image directory walk
112 | if os.path.exists(index_filename):
113 | print('Loading pre-saved Index file %s...' % index_filename)
114 | imgs = np.load(index_filename)['imgs']
115 | # If first time, walk the folder directory and save the
116 | # results to a pre-computed file.
117 | else:
118 | print('Generating Index file %s...' % index_filename)
119 | imgs = make_dataset(root, class_to_idx)
120 | np.savez_compressed(index_filename, **{'imgs' : imgs})
121 | if len(imgs) == 0:
122 | raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
123 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))
124 |
125 | self.root = root
126 | self.imgs = imgs
127 | self.classes = classes
128 | self.class_to_idx = class_to_idx
129 | self.transform = transform
130 | self.target_transform = target_transform
131 | self.loader = loader
132 | self.load_in_mem = load_in_mem
133 |
134 | if self.load_in_mem:
135 | print('Loading all images into memory...')
136 | self.data, self.labels = [], []
137 | for index in tqdm(range(len(self.imgs))):
138 | path, target = imgs[index][0], imgs[index][1]
139 | self.data.append(self.transform(self.loader(path)))
140 | self.labels.append(target)
141 |
142 |
143 | def __getitem__(self, index):
144 | """
145 | Args:
146 | index (int): Index
147 |
148 | Returns:
149 | tuple: (image, target) where target is class_index of the target class.
150 | """
151 | if self.load_in_mem:
152 | img = self.data[index]
153 | target = self.labels[index]
154 | else:
155 | path, target = self.imgs[index]
156 | img = self.loader(str(path))
157 | if self.transform is not None:
158 | img = self.transform(img)
159 |
160 | if self.target_transform is not None:
161 | target = self.target_transform(target)
162 |
163 | # print(img.size(), target)
164 | return img, int(target)
165 |
166 | def __len__(self):
167 | return len(self.imgs)
168 |
169 | def __repr__(self):
170 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
171 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
172 | fmt_str += ' Root Location: {}\n'.format(self.root)
173 | tmp = ' Transforms (if any): '
174 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
175 | tmp = ' Target Transforms (if any): '
176 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
177 | return fmt_str
178 |
179 |
180 | ''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid
181 | having to load individual images all the time. '''
182 | import h5py as h5
183 | import torch
184 | class ILSVRC_HDF5(data.Dataset):
185 | def __init__(self, root, transform=None, target_transform=None,
186 | load_in_mem=False, train=True,download=False, validate_seed=0,
187 | val_split=0, **kwargs): # last four are dummies
188 |
189 | self.root = root
190 | self.num_imgs = len(h5.File(root, 'r')['labels'])
191 |
192 | # self.transform = transform
193 | self.target_transform = target_transform
194 |
195 | # Set the transform here
196 | self.transform = transform
197 |
198 | # load the entire dataset into memory?
199 | self.load_in_mem = load_in_mem
200 |
201 | # If loading into memory, do so now
202 | if self.load_in_mem:
203 | print('Loading %s into memory...' % root)
204 | with h5.File(root,'r') as f:
205 | self.data = f['imgs'][:]
206 | self.labels = f['labels'][:]
207 |
208 | def __getitem__(self, index):
209 | """
210 | Args:
211 | index (int): Index
212 |
213 | Returns:
214 | tuple: (image, target) where target is class_index of the target class.
215 | """
216 | # If loaded the entire dataset in RAM, get image from memory
217 | if self.load_in_mem:
218 | img = self.data[index]
219 | target = self.labels[index]
220 |
221 | # Else load it from disk
222 | else:
223 | with h5.File(self.root,'r') as f:
224 | img = f['imgs'][index]
225 | target = f['labels'][index]
226 |
227 |
228 | # if self.transform is not None:
229 | # img = self.transform(img)
230 | # Apply my own transform
231 | img = ((torch.from_numpy(img).float() / 255) - 0.5) * 2
232 |
233 | if self.target_transform is not None:
234 | target = self.target_transform(target)
235 |
236 | return img, int(target)
237 |
238 | def __len__(self):
239 | return self.num_imgs
240 | # return len(self.f['imgs'])
241 |
242 | import pickle
243 | class CIFAR10(dset.CIFAR10):
244 |
245 | def __init__(self, root, train=True,
246 | transform=None, target_transform=None,
247 | download=True, validate_seed=0,
248 | val_split=0, load_in_mem=True, **kwargs):
249 | self.root = os.path.expanduser(root)
250 | self.transform = transform
251 | self.target_transform = target_transform
252 | self.train = train # training set or test set
253 | self.val_split = val_split
254 |
255 | if download:
256 | self.download()
257 |
258 | if not self._check_integrity():
259 | raise RuntimeError('Dataset not found or corrupted.' +
260 | ' You can use download=True to download it')
261 |
262 | # now load the picked numpy arrays
263 | self.data = []
264 | self.labels= []
265 | for fentry in self.train_list:
266 | f = fentry[0]
267 | file = os.path.join(self.root, self.base_folder, f)
268 | fo = open(file, 'rb')
269 | if sys.version_info[0] == 2:
270 | entry = pickle.load(fo)
271 | else:
272 | entry = pickle.load(fo, encoding='latin1')
273 | self.data.append(entry['data'])
274 | if 'labels' in entry:
275 | self.labels += entry['labels']
276 | else:
277 | self.labels += entry['fine_labels']
278 | fo.close()
279 |
280 | self.data = np.concatenate(self.data)
281 | # Randomly select indices for validation
282 | if self.val_split > 0:
283 | label_indices = [[] for _ in range(max(self.labels)+1)]
284 | for i,l in enumerate(self.labels):
285 | label_indices[l] += [i]
286 | label_indices = np.asarray(label_indices)
287 |
288 | # randomly grab an equal val_split share of each class (e.g. 500/class for a 10% split of CIFAR-10)
289 | np.random.seed(validate_seed)
290 | self.val_indices = []
291 | for l_i in label_indices:
292 | self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1), replace=False)])
293 |
294 | if self.train=='validate':
295 | self.data = self.data[self.val_indices]
296 | self.labels = list(np.asarray(self.labels)[self.val_indices])
297 |
298 | self.data = self.data.reshape((int(50e3 * self.val_split), 3, 32, 32))
299 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
300 |
301 | elif self.train:
302 | print(np.shape(self.data))
303 | if self.val_split > 0:
304 | self.data = np.delete(self.data,self.val_indices,axis=0)
305 | self.labels = list(np.delete(np.asarray(self.labels),self.val_indices,axis=0))
306 |
307 | self.data = self.data.reshape((int(50e3 * (1.-self.val_split)), 3, 32, 32))
308 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
309 | else:
310 | f = self.test_list[0][0]
311 | file = os.path.join(self.root, self.base_folder, f)
312 | fo = open(file, 'rb')
313 | if sys.version_info[0] == 2:
314 | entry = pickle.load(fo)
315 | else:
316 | entry = pickle.load(fo, encoding='latin1')
317 | self.data = entry['data']
318 | if 'labels' in entry:
319 | self.labels = entry['labels']
320 | else:
321 | self.labels = entry['fine_labels']
322 | fo.close()
323 | self.data = self.data.reshape((10000, 3, 32, 32))
324 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
325 |
326 | def __getitem__(self, index):
327 | """
328 | Args:
329 | index (int): Index
330 | Returns:
331 | tuple: (image, target) where target is index of the target class.
332 | """
333 | img, target = self.data[index], self.labels[index]
334 |
335 | # doing this so that it is consistent with all other datasets
336 | # to return a PIL Image
337 | img = Image.fromarray(img)
338 |
339 | if self.transform is not None:
340 | img = self.transform(img)
341 |
342 | if self.target_transform is not None:
343 | target = self.target_transform(target)
344 |
345 | return img, target
346 |
347 | def __len__(self):
348 | return len(self.data)
349 |
350 |
351 | class CIFAR100(CIFAR10):
352 | base_folder = 'cifar-100-python'
353 | url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
354 | filename = "cifar-100-python.tar.gz"
355 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
356 | train_list = [
357 | ['train', '16019d7e3df5f24257cddd939b257f8d'],
358 | ]
359 |
360 | test_list = [
361 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
362 | ]
363 |
--------------------------------------------------------------------------------
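A short sketch of feeding ILSVRC_HDF5 above into a loader (the .hdf5 path is a placeholder; such files are produced by make_hdf5.py):

from torch.utils.data import DataLoader
import datasets

train_set = datasets.ILSVRC_HDF5('data/ILSVRC128.hdf5', load_in_mem=False)
loader = DataLoader(train_set, batch_size=64, shuffle=True,
                    num_workers=8, pin_memory=True)
imgs, labels = next(iter(loader))
# imgs arrive as float tensors already scaled to [-1, 1] by __getitem__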
/biggan_pytorch/imgs/D Singular Values.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/D Singular Values.png
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/DeepSamples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/DeepSamples.png
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/DogBall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/DogBall.png
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/G Singular Values.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/G Singular Values.png
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/IS_FID.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/IS_FID.png
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/Losses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/Losses.png
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/header_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/header_image.jpg
--------------------------------------------------------------------------------
/biggan_pytorch/imgs/interp_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/biggan_pytorch/imgs/interp_sample.jpg
--------------------------------------------------------------------------------
/biggan_pytorch/inception_tf13.py:
--------------------------------------------------------------------------------
1 | ''' Tensorflow inception score code
2 | Derived from https://github.com/openai/improved-gan
3 | Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
4 | THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE
5 |
6 | To use this code, run sample.py on your model with --sample_npz, and then
7 | pass the experiment name via --experiment_name.
8 | This code also saves pool3 stats to an npz file for FID calculation
9 | '''
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import os.path
15 | import sys
16 | import tarfile
17 | import math
18 | from tqdm import tqdm, trange
19 | from argparse import ArgumentParser
20 |
21 | import numpy as np
22 | from six.moves import urllib
23 | import tensorflow as tf
24 |
25 | MODEL_DIR = ''
26 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
27 | softmax = None
28 |
29 | def prepare_parser():
30 | usage = 'Parser for TF1.3- Inception Score scripts.'
31 | parser = ArgumentParser(description=usage)
32 | parser.add_argument(
33 | '--experiment_name', type=str, default='',
34 | help="Which experiment's samples.npz file to pull and evaluate")
35 | parser.add_argument(
36 | '--experiment_root', type=str, default='samples',
37 | help='Default location where samples are stored (default: %(default)s)')
38 | parser.add_argument(
39 | '--batch_size', type=int, default=500,
40 | help='Default overall batchsize (default: %(default)s)')
41 | return parser
42 |
43 |
44 | def run(config):
45 | # Inception with TF1.3 or earlier.
46 | # Call this function with list of images. Each of elements should be a
47 | # numpy array with values ranging from 0 to 255.
48 | def get_inception_score(images, splits=10):
49 | assert(type(images) == list)
50 | assert(type(images[0]) == np.ndarray)
51 | assert(len(images[0].shape) == 3)
52 | assert(np.max(images[0]) > 10)
53 | assert(np.min(images[0]) >= 0.0)
54 | inps = []
55 | for img in images:
56 | img = img.astype(np.float32)
57 | inps.append(np.expand_dims(img, 0))
58 | bs = config['batch_size']
59 | with tf.Session() as sess:
60 | preds, pools = [], []
61 | n_batches = int(math.ceil(float(len(inps)) / float(bs)))
62 | for i in trange(n_batches):
63 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
64 | inp = np.concatenate(inp, 0)
65 | pred, pool = sess.run([softmax, pool3], {'ExpandDims:0': inp})
66 | preds.append(pred)
67 | pools.append(pool)
68 | preds = np.concatenate(preds, 0)
69 | scores = []
70 | for i in range(splits):
71 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
72 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
73 | kl = np.mean(np.sum(kl, 1))
74 | scores.append(np.exp(kl))
75 | return np.mean(scores), np.std(scores), np.squeeze(np.concatenate(pools, 0))
76 | # Init inception
77 | def _init_inception():
78 | global softmax, pool3
79 | if not os.path.exists(MODEL_DIR):
80 | os.makedirs(MODEL_DIR)
81 | filename = DATA_URL.split('/')[-1]
82 | filepath = os.path.join(MODEL_DIR, filename)
83 | if not os.path.exists(filepath):
84 | def _progress(count, block_size, total_size):
85 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
86 | filename, float(count * block_size) / float(total_size) * 100.0))
87 | sys.stdout.flush()
88 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
89 | print()
90 | statinfo = os.stat(filepath)
91 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
92 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
93 | with tf.gfile.FastGFile(os.path.join(
94 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
95 | graph_def = tf.GraphDef()
96 | graph_def.ParseFromString(f.read())
97 | _ = tf.import_graph_def(graph_def, name='')
98 | # Works with an arbitrary minibatch size.
99 | with tf.Session() as sess:
100 | pool3 = sess.graph.get_tensor_by_name('pool_3:0')
101 | ops = pool3.graph.get_operations()
102 | for op_idx, op in enumerate(ops):
103 | for o in op.outputs:
104 | shape = o.get_shape()
105 | shape = [s.value for s in shape]
106 | new_shape = []
107 | for j, s in enumerate(shape):
108 | if s == 1 and j == 0:
109 | new_shape.append(None)
110 | else:
111 | new_shape.append(s)
112 | o._shape = tf.TensorShape(new_shape)
113 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
114 | logits = tf.matmul(tf.squeeze(pool3), w)
115 | softmax = tf.nn.softmax(logits)
116 |
117 | # if softmax is None: # No need to functionalize like this.
118 | _init_inception()
119 |
120 | fname = '%s/%s/samples.npz' % (config['experiment_root'], config['experiment_name'])
121 | print('loading %s ...'%fname)
122 | ims = np.load(fname)['x']
123 | import time
124 | t0 = time.time()
125 | inc_mean, inc_std, pool_activations = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=10)
126 | t1 = time.time()
127 | print('Saving pool to numpy file for FID calculations...')
128 | np.savez('%s/%s/TF_pool.npz' % (config['experiment_root'], config['experiment_name']), **{'pool_mean': np.mean(pool_activations,axis=0), 'pool_var': np.cov(pool_activations, rowvar=False)})
129 | print('Inception took %.3f seconds, score of %.3f +/- %.3f.' % (t1 - t0, inc_mean, inc_std))
130 | def main():
131 | # parse command line and run
132 | parser = prepare_parser()
133 | config = vars(parser.parse_args())
134 | print(config)
135 | run(config)
136 |
137 | if __name__ == '__main__':
138 | main()
--------------------------------------------------------------------------------
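For reference, the per-split quantity computed in get_inception_score above is the standard Inception Score,

    \mathrm{IS} = \exp\big( \mathbb{E}_x \, \mathrm{KL}( p(y \mid x) \,\|\, p(y) ) \big),
    \qquad p(y) \approx \tfrac{1}{N} \sum_i p(y \mid x_i),

which the kl lines implement with p(y|x) read off the softmax output. Typical usage, per the docstring: dump samples first (python sample.py ... --sample_npz), then score them with python inception_tf13.py --experiment_name <name>.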
/biggan_pytorch/inception_utils.py:
--------------------------------------------------------------------------------
1 | ''' Inception utilities
2 | This file contains methods for calculating IS and FID, using either
3 | the original numpy code or an accelerated fully-pytorch version that
4 | uses a fast newton-schulz approximation for the matrix sqrt. There are also
5 | methods for acquiring a desired number of samples from the Generator,
6 | and parallelizing the inbuilt PyTorch inception network.
7 |
8 | NOTE that Inception Scores and FIDs calculated using these methods will
9 | *not* be directly comparable to values calculated using the original TF
10 | IS/FID code. You *must* use the TF model if you wish to report and compare
11 | numbers. This code tends to produce IS values that are 5-10% lower than
12 | those obtained through TF.
13 | '''
14 | import numpy as np
15 | from scipy import linalg # For numpy FID
16 | import time
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch.nn import Parameter as P
22 | from torchvision.models.inception import inception_v3
23 |
24 |
25 | # Module that wraps the inception network to enable use with dataparallel and
26 | # returning pool features and logits.
27 | class WrapInception(nn.Module):
28 | def __init__(self, net):
29 | super(WrapInception,self).__init__()
30 | self.net = net
31 | self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1),
32 | requires_grad=False)
33 | self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1),
34 | requires_grad=False)
35 | def forward(self, x):
36 | # Normalize x
37 | x = (x + 1.) / 2.0
38 | x = (x - self.mean) / self.std
39 | # Upsample if necessary
40 | if x.shape[2] != 299 or x.shape[3] != 299:
41 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
42 | # 299 x 299 x 3
43 | x = self.net.Conv2d_1a_3x3(x)
44 | # 149 x 149 x 32
45 | x = self.net.Conv2d_2a_3x3(x)
46 | # 147 x 147 x 32
47 | x = self.net.Conv2d_2b_3x3(x)
48 | # 147 x 147 x 64
49 | x = F.max_pool2d(x, kernel_size=3, stride=2)
50 | # 73 x 73 x 64
51 | x = self.net.Conv2d_3b_1x1(x)
52 | # 73 x 73 x 80
53 | x = self.net.Conv2d_4a_3x3(x)
54 | # 71 x 71 x 192
55 | x = F.max_pool2d(x, kernel_size=3, stride=2)
56 | # 35 x 35 x 192
57 | x = self.net.Mixed_5b(x)
58 | # 35 x 35 x 256
59 | x = self.net.Mixed_5c(x)
60 | # 35 x 35 x 288
61 | x = self.net.Mixed_5d(x)
62 | # 35 x 35 x 288
63 | x = self.net.Mixed_6a(x)
64 | # 17 x 17 x 768
65 | x = self.net.Mixed_6b(x)
66 | # 17 x 17 x 768
67 | x = self.net.Mixed_6c(x)
68 | # 17 x 17 x 768
69 | x = self.net.Mixed_6d(x)
70 | # 17 x 17 x 768
71 | x = self.net.Mixed_6e(x)
72 | # 17 x 17 x 768
73 | # 17 x 17 x 768
74 | x = self.net.Mixed_7a(x)
75 | # 8 x 8 x 1280
76 | x = self.net.Mixed_7b(x)
77 | # 8 x 8 x 2048
78 | x = self.net.Mixed_7c(x)
79 | # 8 x 8 x 2048
80 | pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2)
81 | # 1 x 1 x 2048
82 | logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1))
83 | # 1000 (num_classes)
84 | return pool, logits
85 |
86 |
87 | # A pytorch implementation of cov, from Modar M. Alfadly
88 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
89 | def torch_cov(m, rowvar=False):
90 | '''Estimate a covariance matrix given data.
91 |
92 | Covariance indicates the level to which two variables vary together.
93 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
94 | then the covariance matrix element `C_{ij}` is the covariance of
95 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
96 |
97 | Args:
98 | m: A 1-D or 2-D array containing multiple variables and observations.
99 | Each row of `m` represents a variable, and each column a single
100 | observation of all those variables.
101 | rowvar: If `rowvar` is True, then each row represents a
102 | variable, with observations in the columns. Otherwise, the
103 | relationship is transposed: each column represents a variable,
104 | while the rows contain observations.
105 |
106 | Returns:
107 | The covariance matrix of the variables.
108 | '''
109 | if m.dim() > 2:
110 | raise ValueError('m has more than 2 dimensions')
111 | if m.dim() < 2:
112 | m = m.view(1, -1)
113 | if not rowvar and m.size(0) != 1:
114 | m = m.t()
115 | # m = m.type(torch.double) # uncomment this line if desired
116 | fact = 1.0 / (m.size(1) - 1)
117 | m -= torch.mean(m, dim=1, keepdim=True)
118 | mt = m.t() # if complex: mt = m.t().conj()
119 | return fact * m.matmul(mt).squeeze()
120 |
121 |
122 | # Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
123 | # https://github.com/msubhransu/matrix-sqrt
124 | def sqrt_newton_schulz(A, numIters, dtype=None):
125 | with torch.no_grad():
126 | if dtype is None:
127 | dtype = A.type()
128 | batchSize = A.shape[0]
129 | dim = A.shape[1]
130 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
131 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A))
132 | I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
133 | Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
134 | for i in range(numIters):
135 | T = 0.5*(3.0*I - Z.bmm(Y))
136 | Y = Y.bmm(T)
137 | Z = T.bmm(Z)
138 | sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
139 | return sA
140 |
141 |
142 | # FID calculator from TTUR--consider replacing this with GPU-accelerated cov
143 | # calculations using torch?
144 | def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
145 | """Numpy implementation of the Frechet Distance.
146 | Taken from https://github.com/bioinf-jku/TTUR
147 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
148 | and X_2 ~ N(mu_2, C_2) is
149 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
150 | Stable version by Dougal J. Sutherland.
151 | Params:
152 | -- mu1 : Numpy array containing the activations of a layer of the
153 | inception net (like returned by the function 'get_predictions')
154 | for generated samples.
155 | -- mu2 : The sample mean over activations, precalculated on a
156 | representative data set.
157 | -- sigma1: The covariance matrix over activations for generated samples.
158 | -- sigma2: The covariance matrix over activations, precalculated on a
159 | representative data set.
160 | Returns:
161 | -- : The Frechet Distance.
162 | """
163 |
164 | mu1 = np.atleast_1d(mu1)
165 | mu2 = np.atleast_1d(mu2)
166 |
167 | sigma1 = np.atleast_2d(sigma1)
168 | sigma2 = np.atleast_2d(sigma2)
169 |
170 | assert mu1.shape == mu2.shape, \
171 | 'Training and test mean vectors have different lengths'
172 | assert sigma1.shape == sigma2.shape, \
173 | 'Training and test covariances have different dimensions'
174 |
175 | diff = mu1 - mu2
176 |
177 | # Product might be almost singular
178 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
179 | if not np.isfinite(covmean).all():
180 | msg = ('fid calculation produces singular product; '
181 | 'adding %s to diagonal of cov estimates') % eps
182 | print(msg)
183 | offset = np.eye(sigma1.shape[0]) * eps
184 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
185 |
186 | # Numerical error might give slight imaginary component
187 | if np.iscomplexobj(covmean):
188 | print('FID sqrtm produced a complex result; taking the real part')
189 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
190 | m = np.max(np.abs(covmean.imag))
191 | raise ValueError('Imaginary component {}'.format(m))
192 | covmean = covmean.real
193 |
194 | tr_covmean = np.trace(covmean)
195 |
196 | out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
197 | return out
198 |
199 |
200 | def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
201 | """Pytorch implementation of the Frechet Distance.
202 | Taken from https://github.com/bioinf-jku/TTUR
203 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
204 | and X_2 ~ N(mu_2, C_2) is
205 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
206 | Stable version by Dougal J. Sutherland.
207 | Params:
208 | -- mu1 : Numpy array containing the activations of a layer of the
209 | inception net (like returned by the function 'get_predictions')
210 | for generated samples.
211 | -- mu2 : The sample mean over activations, precalculated on a
212 | representative data set.
213 | -- sigma1: The covariance matrix over activations for generated samples.
214 | -- sigma2: The covariance matrix over activations, precalculated on a
215 | representative data set.
216 | Returns:
217 | -- : The Frechet Distance.
218 | """
219 |
220 |
221 | assert mu1.shape == mu2.shape, \
222 | 'Training and test mean vectors have different lengths'
223 | assert sigma1.shape == sigma2.shape, \
224 | 'Training and test covariances have different dimensions'
225 |
226 | diff = mu1 - mu2
227 | # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2
228 | covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze()
229 | out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2)
230 | - 2 * torch.trace(covmean))
231 | return out
232 |
233 |
234 | # Calculate Inception Score mean + std given softmax'd logits and number of splits
235 | def calculate_inception_score(pred, num_splits=10):
236 | scores = []
237 | for index in range(num_splits):
238 | pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :]
239 | kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0)))
240 | kl_inception = np.mean(np.sum(kl_inception, 1))
241 | scores.append(np.exp(kl_inception))
242 | return np.mean(scores), np.std(scores)
243 |
244 |
245 | # Loop and run the sampler and the net until it accumulates num_inception_images
246 | # activations. Return the pool, the logits, and the labels (if one wants
247 | # Inception Accuracy the labels of the generated class will be needed)
248 | def accumulate_inception_activations(sample, net, num_inception_images=50000):
249 | pool, logits, labels = [], [], []
250 | while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images:
251 | with torch.no_grad():
252 | images, labels_val = sample()
253 | pool_val, logits_val = net(images.float())
254 | pool += [pool_val]
255 | logits += [F.softmax(logits_val, 1)]
256 | labels += [labels_val]
257 | return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0)
258 |
259 |
260 | # Load and wrap the Inception model
261 | def load_inception_net(parallel=False):
262 | inception_model = inception_v3(pretrained=True, transform_input=False)
263 | inception_model = WrapInception(inception_model.eval()).cuda()
264 | if parallel:
265 | print('Parallelizing Inception module...')
266 | inception_model = nn.DataParallel(inception_model)
267 | return inception_model
268 |
269 |
270 | # This produces a function which takes in a sampler that returns a set number
271 | # of samples and iterates until it accumulates config['num_inception_images']
272 | # images. The sampler can return samples with a different batch size than used
273 | # in training, via the setting config['inception_batchsize'].
274 | def prepare_inception_metrics(dataset, parallel, no_fid=False):
275 | # Load metrics; this is intentionally not in a try-except loop so that
276 | # the script will crash here if it cannot find the Inception moments.
277 | # By default, remove the "hdf5" from dataset
278 | dataset = dataset.replace('_hdf5', '')  # str.strip would remove characters, not the suffix
279 | data_mu = np.load(dataset+'_inception_moments.npz')['mu']
280 | data_sigma = np.load(dataset+'_inception_moments.npz')['sigma']
281 | # Load network
282 | net = load_inception_net(parallel)
283 | def get_inception_metrics(sample, num_inception_images, num_splits=10,
284 | prints=True, use_torch=True):
285 | if prints:
286 | print('Gathering activations...')
287 | pool, logits, labels = accumulate_inception_activations(sample, net, num_inception_images)
288 | if prints:
289 | print('Calculating Inception Score...')
290 | IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits)
291 | if no_fid:
292 | FID = 9999.0
293 | else:
294 | if prints:
295 | print('Calculating means and covariances...')
296 | if use_torch:
297 | mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
298 | else:
299 | mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False)
300 | if prints:
301 | print('Covariances calculated, getting FID...')
302 | if use_torch:
303 | FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda())
304 | FID = float(FID.cpu().numpy())
305 | else:
306 | FID = numpy_calculate_frechet_distance(mu.cpu().numpy(), sigma.cpu().numpy(), data_mu, data_sigma)
307 | # Delete mu, sigma, pool, logits, and labels, just in case
308 | del mu, sigma, pool, logits, labels
309 | return IS_mean, IS_std, FID
310 | return get_inception_metrics
--------------------------------------------------------------------------------
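A hedged sketch of the entry point above: prepare_inception_metrics expects <dataset>_inception_moments.npz to exist (see calculate_inception_moments.py), and sample is any zero-argument callable returning an (images, labels) batch on the GPU. The trained generator G assumed below is a placeholder, not part of this file:

import torch
import inception_utils

get_metrics = inception_utils.prepare_inception_metrics('I128', parallel=False)

def sample():
    # placeholder sampler wrapped around an assumed generator G
    z = torch.randn(64, G.dim_z, device='cuda')
    y = torch.randint(0, 1000, (64,), device='cuda')
    return G(z, G.shared(y)), y

IS_mean, IS_std, FID = get_metrics(sample, num_inception_images=50000)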
/biggan_pytorch/logs/BigGAN_ch96_bs256x8.jsonl:
--------------------------------------------------------------------------------
1 | {"itr": 2000, "IS_mean": 2.806771755218506, "IS_std": 0.019480662420392036, "FID": 173.76484159711126, "_stamp": 1551403232.0425167}
2 | {"itr": 4000, "IS_mean": 4.962374687194824, "IS_std": 0.07276841998100281, "FID": 113.86730514283107, "_stamp": 1551422228.743057}
3 | {"itr": 6000, "IS_mean": 6.939817905426025, "IS_std": 0.11417163163423538, "FID": 101.63548498447199, "_stamp": 1551457139.3400874}
4 | {"itr": 8000, "IS_mean": 8.142985343933105, "IS_std": 0.11931543797254562, "FID": 92.0014385772705, "_stamp": 1551476217.2409613}
5 | {"itr": 10000, "IS_mean": 10.355518341064453, "IS_std": 0.09094739705324173, "FID": 83.58068997965364, "_stamp": 1551494854.2419689}
6 | {"itr": 12000, "IS_mean": 11.288347244262695, "IS_std": 0.14952820539474487, "FID": 80.98066299357106, "_stamp": 1551513232.5049698}
7 | {"itr": 14000, "IS_mean": 11.755794525146484, "IS_std": 0.17969024181365967, "FID": 76.80603924280956, "_stamp": 1551531425.150371}
8 | {"itr": 18000, "IS_mean": 13.65534496307373, "IS_std": 0.11151058971881866, "FID": 65.95736694335938, "_stamp": 1551588271.9177916}
9 | {"itr": 20000, "IS_mean": 14.817827224731445, "IS_std": 0.23588882386684418, "FID": 61.32061767578125, "_stamp": 1551606713.6567464}
10 | {"itr": 22000, "IS_mean": 17.16551399230957, "IS_std": 0.19506946206092834, "FID": 53.387969970703125, "_stamp": 1551624876.6513028}
11 | {"itr": 24000, "IS_mean": 19.60654067993164, "IS_std": 0.5591856837272644, "FID": 46.5386962890625, "_stamp": 1551642822.6126688}
12 | {"itr": 26000, "IS_mean": 21.74416732788086, "IS_std": 0.2850531041622162, "FID": 41.595001220703125, "_stamp": 1551663522.6019194}
13 | {"itr": 28000, "IS_mean": 23.923612594604492, "IS_std": 0.41587772965431213, "FID": 37.894744873046875, "_stamp": 1551681794.6567173}
14 | {"itr": 30000, "IS_mean": 25.569377899169922, "IS_std": 0.3333457112312317, "FID": 35.49310302734375, "_stamp": 1551699773.7080302}
15 | {"itr": 32000, "IS_mean": 26.867944717407227, "IS_std": 0.5968036651611328, "FID": 33.4849853515625, "_stamp": 1551717623.887933}
16 | {"itr": 34000, "IS_mean": 28.719074249267578, "IS_std": 0.5698027014732361, "FID": 31.375518798828125, "_stamp": 1551735411.1578612}
17 | {"itr": 36000, "IS_mean": 30.587574005126953, "IS_std": 0.5044271349906921, "FID": 29.432281494140625, "_stamp": 1551783380.6357439}
18 | {"itr": 38000, "IS_mean": 32.08299255371094, "IS_std": 0.49342143535614014, "FID": 28.099456787109375, "_stamp": 1551801179.6495197}
19 | {"itr": 40000, "IS_mean": 34.24657440185547, "IS_std": 0.7709177732467651, "FID": 26.53802490234375, "_stamp": 1551818775.171794}
20 | {"itr": 42000, "IS_mean": 35.891212463378906, "IS_std": 0.7036871314048767, "FID": 25.03021240234375, "_stamp": 1551836329.6873965}
21 | {"itr": 44000, "IS_mean": 38.184898376464844, "IS_std": 0.32996198534965515, "FID": 23.4940185546875, "_stamp": 1551897864.911537}
22 | {"itr": 46000, "IS_mean": 40.239479064941406, "IS_std": 0.7761151194572449, "FID": 22.53167724609375, "_stamp": 1551915406.4840703}
23 | {"itr": 48000, "IS_mean": 41.46656036376953, "IS_std": 1.1031498908996582, "FID": 21.5338134765625, "_stamp": 1551932899.6074848}
24 | {"itr": 50000, "IS_mean": 43.31670379638672, "IS_std": 0.7796809077262878, "FID": 20.53253173828125, "_stamp": 1551950390.345334}
25 | {"itr": 52000, "IS_mean": 45.1517333984375, "IS_std": 1.2925242185592651, "FID": 19.656646728515625, "_stamp": 1551967838.1501615}
26 | {"itr": 54000, "IS_mean": 47.638771057128906, "IS_std": 1.0689665079116821, "FID": 18.898162841796875, "_stamp": 1552044534.5349634}
27 | {"itr": 56000, "IS_mean": 48.87520217895508, "IS_std": 1.1317559480667114, "FID": 18.1248779296875, "_stamp": 1552061763.3080354}
28 | {"itr": 58000, "IS_mean": 49.40987014770508, "IS_std": 1.1866596937179565, "FID": 17.751922607421875, "_stamp": 1552078939.9828825}
29 | {"itr": 60000, "IS_mean": 51.051334381103516, "IS_std": 1.2281248569488525, "FID": 17.19964599609375, "_stamp": 1552096167.889482}
30 | {"itr": 62000, "IS_mean": 52.0235481262207, "IS_std": 0.5391153693199158, "FID": 16.62115478515625, "_stamp": 1552113417.9520617}
31 | {"itr": 64000, "IS_mean": 53.868492126464844, "IS_std": 1.327082633972168, "FID": 16.237335205078125, "_stamp": 1552142961.09602}
32 | {"itr": 66000, "IS_mean": 54.978721618652344, "IS_std": 0.9502049088478088, "FID": 15.81170654296875, "_stamp": 1552162403.2232807}
33 | {"itr": 68000, "IS_mean": 55.73248291015625, "IS_std": 1.0323851108551025, "FID": 15.545623779296875, "_stamp": 1552181112.676657}
34 | {"itr": 70000, "IS_mean": 56.78422927856445, "IS_std": 1.211003303527832, "FID": 15.28369140625, "_stamp": 1552199498.887533}
35 | {"itr": 72000, "IS_mean": 57.972999572753906, "IS_std": 0.8668608665466309, "FID": 14.86395263671875, "_stamp": 1552217782.2738616}
36 | {"itr": 74000, "IS_mean": 58.845054626464844, "IS_std": 1.4297977685928345, "FID": 14.620635986328125, "_stamp": 1552251085.1781816}
37 | {"itr": 76000, "IS_mean": 59.60982131958008, "IS_std": 0.9095696210861206, "FID": 14.360198974609375, "_stamp": 1552270214.9345307}
38 | {"itr": 78000, "IS_mean": 60.71195602416992, "IS_std": 0.960899829864502, "FID": 14.07183837890625, "_stamp": 1552288697.1580262}
39 | {"itr": 80000, "IS_mean": 61.772125244140625, "IS_std": 0.6913255453109741, "FID": 13.781585693359375, "_stamp": 1552307170.0280282}
40 | {"itr": 82000, "IS_mean": 62.98079299926758, "IS_std": 1.4735801219940186, "FID": 13.55389404296875, "_stamp": 1552325252.8553352}
41 | {"itr": 84000, "IS_mean": 64.95240783691406, "IS_std": 0.9018951654434204, "FID": 13.231689453125, "_stamp": 1552344135.3111835}
42 | {"itr": 86000, "IS_mean": 65.13968658447266, "IS_std": 0.8772205114364624, "FID": 13.176849365234375, "_stamp": 1552362429.6782444}
43 | {"itr": 88000, "IS_mean": 65.84476470947266, "IS_std": 1.167534351348877, "FID": 12.87078857421875, "_stamp": 1552380560.7988124}
44 | {"itr": 90000, "IS_mean": 67.41099548339844, "IS_std": 1.6899267435073853, "FID": 12.586517333984375, "_stamp": 1552398550.2060475}
45 | {"itr": 92000, "IS_mean": 68.63685607910156, "IS_std": 1.9431978464126587, "FID": 12.49505615234375, "_stamp": 1552430781.6406457}
46 | {"itr": 94000, "IS_mean": 70.09907531738281, "IS_std": 1.0715738534927368, "FID": 12.047607421875, "_stamp": 1552449001.1950285}
47 | {"itr": 96000, "IS_mean": 70.34623718261719, "IS_std": 1.7962944507598877, "FID": 11.896697998046875, "_stamp": 1552466989.3587568}
48 | {"itr": 98000, "IS_mean": 71.08210754394531, "IS_std": 1.458209753036499, "FID": 11.73046875, "_stamp": 1552484800.7138846}
49 | {"itr": 100000, "IS_mean": 72.24256896972656, "IS_std": 1.3259714841842651, "FID": 11.7386474609375, "_stamp": 1552502538.0269725}
50 | {"itr": 102000, "IS_mean": 73.19488525390625, "IS_std": 1.3439149856567383, "FID": 11.50494384765625, "_stamp": 1552523284.4514356}
51 | {"itr": 104000, "IS_mean": 73.38243103027344, "IS_std": 1.4162707328796387, "FID": 11.374542236328125, "_stamp": 1552541012.0651608}
52 | {"itr": 106000, "IS_mean": 74.95563507080078, "IS_std": 1.089124083518982, "FID": 11.10479736328125, "_stamp": 1552558577.7458107}
53 | {"itr": 108000, "IS_mean": 76.42997741699219, "IS_std": 1.9282453060150146, "FID": 10.998870849609375, "_stamp": 1552576111.9480467}
54 | {"itr": 110000, "IS_mean": 76.89225769042969, "IS_std": 1.4771150350570679, "FID": 10.847015380859375, "_stamp": 1552593659.445132}
55 | {"itr": 112000, "IS_mean": 78.04684448242188, "IS_std": 1.4850096702575684, "FID": 10.772552490234375, "_stamp": 1552616479.5201895}
56 | {"itr": 114000, "IS_mean": 79.67677307128906, "IS_std": 2.0147368907928467, "FID": 10.528045654296875, "_stamp": 1552633850.9315467}
57 | {"itr": 116000, "IS_mean": 79.8828125, "IS_std": 0.978247344493866, "FID": 10.626068115234375, "_stamp": 1552651198.9012825}
58 | {"itr": 118000, "IS_mean": 79.95381164550781, "IS_std": 1.8608143329620361, "FID": 10.46771240234375, "_stamp": 1552668560.4420238}
59 | {"itr": 120000, "IS_mean": 82.37217712402344, "IS_std": 1.8909310102462769, "FID": 10.259033203125, "_stamp": 1552749673.4319007}
60 | {"itr": 122000, "IS_mean": 83.49666595458984, "IS_std": 2.38446044921875, "FID": 9.996185302734375, "_stamp": 1552766698.2706933}
61 | {"itr": 124000, "IS_mean": 83.05189514160156, "IS_std": 1.8844469785690308, "FID": 10.164398193359375, "_stamp": 1552783762.891172}
62 | {"itr": 126000, "IS_mean": 84.27763366699219, "IS_std": 0.9329544901847839, "FID": 10.03509521484375, "_stamp": 1552800953.5724175}
63 | {"itr": 128000, "IS_mean": 85.84852600097656, "IS_std": 2.2698562145233154, "FID": 9.91644287109375, "_stamp": 1552818112.227726}
64 | {"itr": 130000, "IS_mean": 87.356689453125, "IS_std": 2.0958640575408936, "FID": 9.771148681640625, "_stamp": 1552837539.995247}
65 | {"itr": 132000, "IS_mean": 88.72562408447266, "IS_std": 1.7551432847976685, "FID": 9.8258056640625, "_stamp": 1552859685.9305944}
66 | {"itr": 134000, "IS_mean": 88.0631103515625, "IS_std": 1.8199039697647095, "FID": 9.957183837890625, "_stamp": 1552880037.5408435}
67 | {"itr": 136000, "IS_mean": 91.50938415527344, "IS_std": 1.9926033020019531, "FID": 9.876556396484375, "_stamp": 1552899854.652669}
68 | {"itr": 138000, "IS_mean": 93.09217834472656, "IS_std": 2.3062736988067627, "FID": 9.908477783203125, "_stamp": 1552921580.958927}
--------------------------------------------------------------------------------
/biggan_pytorch/logs/compare_IS.m:
--------------------------------------------------------------------------------
1 | clc
2 | clear all
3 | close all
4 | fclose all;
5 |
6 |
7 |
8 | %% Get All logs and sort them
9 | s = {};
10 | d = dir();
11 | j = 1;
12 | for i = 1:length(d)
13 | if any(strfind(d(i).name,'.jsonl'))
14 | s = [s; d(i).name];
15 | end
16 | end
17 |
18 |
19 | j = 1;
20 | for i = 1:length(s)
21 | fname = s{i,1};
22 | % Check if the Inception metrics log exists, and if so, plot it
23 | [itr, IS, FID, t] = process_inception_log(fname(1:end - 10), 'log.jsonl');
24 | s{i,2} = itr;
25 | s{i,3} = IS;
26 | s{i,4} = FID;
27 | s{i,5} = max(IS);
28 | s{i,6} = min(FID);
29 | s{i,7} = t;
30 | end
31 | % Sort by Inception Score?
32 | [IS_sorted, IS_index] = sort(cell2mat(s(:,5)));
33 | % Cutoff inception scores below a certain value?
34 | threshold = 22;
35 | IS_index = IS_index(IS_sorted > threshold);
36 |
37 | % Sort by FID?
38 | [FID_sorted, FID_index] = sort(cell2mat(s(:,6)));
39 | % Cutoff also based on IS?
40 | % threshold = 0;
41 | FID_index = FID_index(IS_sorted > threshold);
42 |
43 |
44 |
45 | %% Plot things?
46 | cc = hsv(length(IS_index));
47 | legend1 = {};
48 | legend2 = {};
49 | make_axis = true; % Turn this on to see the axis out to 1e6 iterations
50 | for i=1:length(IS_index)
51 | legend1 = [legend1; s{IS_index(i), 1}];
52 | figure(1)
53 | plot(s{IS_index(i),2}, s{IS_index(i),3}, 'color', cc(i,:),'linewidth',2)
54 | hold on;
55 | xlabel('itr'); ylabel('IS');
56 | grid on;
57 | if make_axis
58 | axis([0, 1e6, 0, 80]); % grid on;
59 | end
60 | legend(legend1,'Interpreter','none')
61 | %pause(1) % Turn this on to animate stuff
62 | legend2 = [legend2; s{IS_index(i), 1}];
63 | figure(2)
64 | plot(s{IS_index(i),2}, s{IS_index(i),4}, 'color', cc(i,:),'linewidth',2)
65 | hold on;
66 | xlabel('itr'); ylabel('FID');
67 | j = j + 1;
68 | grid on;
69 | if make_axis
70 | axis([0, 1e6, 0, 50]); % grid on;
71 | end
72 | legend(legend2, 'Interpreter','none')
73 |
74 | end
75 |
76 | %% Quick script to plot IS versus timesteps
77 | if 0
78 | figure(3);
79 | this_index=4;
80 | subplot(2,1,1);
81 | %plot(s{this_index, 2}(2:end), s{this_index, 7}(2:end) - s{this_index, 7}(1:end-1), 'r*');
82 | % xlabel('Iteration');ylabel('\Delta T')
83 | plot(s{this_index, 2}, s{this_index, 7}, 'r*');
84 | xlabel('Iteration');ylabel('T')
85 | subplot(2,1,2);
86 | plot(s{this_index, 2}, s{this_index, 3}, 'r', 'linewidth',2);
87 | xlabel('Iteration'), ylabel('Inception score')
88 | title(s{this_index,1})
89 | end
--------------------------------------------------------------------------------
/biggan_pytorch/logs/metalog.txt:
--------------------------------------------------------------------------------
1 | datetime: 2019-03-18 13:27:59.181225
2 | config: {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 400, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': True, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}
3 | state: {'itr': 137500, 'epoch': 2, 'save_num': 0, 'save_best_num': 1, 'best_IS': 91.509384, 'best_FID': tensor(9.7711, 'config': {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': False, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'BN_sync': False, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}}
4 |
--------------------------------------------------------------------------------
/biggan_pytorch/logs/process_inception_log.m:
--------------------------------------------------------------------------------
1 | function [itr, IS, FID, t] = process_inception_log(fname, which_log)
2 | f = sprintf('%s_%s',fname, which_log);%'G_loss.log');
3 | fid = fopen(f,'r');
4 | itr = [];
5 | IS = [];
6 | FID = [];
7 | t = [];
8 | i = 1;
9 | while ~feof(fid);
10 | s = fgets(fid);
11 | parsed = sscanf(s,'{"itr": %d, "IS_mean": %f, "IS_std": %f, "FID": %f, "_stamp": %f}');
12 | itr(i) = parsed(1);
13 | IS(i) = parsed(2);
14 | FID(i) = parsed(4);
15 | t(i) = parsed(5);
16 | i = i + 1;
17 | end
18 | fclose(fid);
19 | end
--------------------------------------------------------------------------------
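A Python counterpart to the MATLAB helper above, parsing the same .jsonl records shown in BigGAN_ch96_bs256x8.jsonl (field names match those log lines):

import json

def process_inception_log(path):
    itr, IS, FID, t = [], [], [], []
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            itr.append(rec['itr'])
            IS.append(rec['IS_mean'])
            FID.append(rec['FID'])
            t.append(rec['_stamp'])
    return itr, IS, FID, t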
/biggan_pytorch/logs/process_training.m:
--------------------------------------------------------------------------------
1 | clc
2 | clear all
3 | close all
4 | fclose all;
5 |
6 |
7 |
8 | %% Get all training logs for a given run
9 | target_dir = '.';
10 | s = {};
11 | nm = {};
12 | d = dir(target_dir);
13 | j = 1;
14 | for i = 1:length(d)
15 | if any(strfind(d(i).name,'.log'))
16 | s = [s; sprintf('%s\\%s', target_dir, d(i).name)];
17 | nm = [nm; d(i).name];
18 | end
19 | end
20 | %% Loop over training logs and acquire data
21 | D_count = 0;
22 | G_count = 0;
23 | for i = 1:length(s)
24 | fname = s{i,1};
25 | fid = fopen(s{i,1},'r');
26 | % Prepare bookkeeping for sv0
27 | if any(strfind(s{i,1},'sv'))
28 | if any(strfind(s{i,1},'G_'))
29 | G_count = G_count +1;
30 | else
31 | D_count = D_count + 1;
32 | end
33 | end
34 | itr = [];
35 | val = [];
36 | j = 1;
37 | while ~feof(fid);
38 | line = fgets(fid);
39 | parsed = sscanf(line, '%d: %e');
40 | itr(j) = parsed(1);
41 | val(j) = parsed(2);
42 | j = j + 1;
43 | end
44 | s{i,2} = itr;
45 | s{i,3} = val;
46 | fclose(fid);
47 | end
48 |
49 | %% Plot SVs and losses
50 | close all;
51 | Gcc = hsv(G_count);
52 | Dcc = hsv(D_count);
53 | gi = 1;
54 | di = 1;
55 | li = 1;
56 | legendG = {};
57 | legendD = {};
58 | legendL = {};
59 | thresh=2; % wavelet denoising threshold
60 | losses = {};
61 | for i=1:length(s)
62 | if any(strfind(s{i,1},'D_loss_real.log')) || any(strfind(s{i,1},'D_loss_fake.log')) || any(strfind(s{i,1},'G_loss.log'))
63 | % Select colors
64 | if any(strfind(s{i,1},'D_loss_real.log'))
65 | color1 = [0.7,0.7,1.0];
66 | color2 = [0, 0, 1];
67 | dlr = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2};
68 | losses = [losses; dlr];
69 | elseif any(strfind(s{i,1},'D_loss_fake.log'))
70 | color1 = [0.7,1.0,0.7];
71 | color2 = [0, 1, 0];
72 | dlf = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2};
73 | losses = [losses; dlf];
74 | else % g loss
75 | color1 = [1.0, 0.7,0.7];
76 | color2 = [1, 0, 0];
77 | gl = {s{i,2}, s{i,3}, wden(s{i,3},'sqtwolog','s','mln', thresh, 'sym4'), color1, color2};
78 | losses = [losses; gl];
79 | end
80 | figure(1); hold on;
81 | % Plot the unsmoothed losses; we'll plot the smoothed losses later
82 | plot(s{i,2},s{i,3},'color', color1, 'HandleVisibility','off');
83 | legendL = [legendL; nm{i}];
84 | continue
85 | end
86 | if any(strfind(s{i,1},'G_'))
87 | legendG = [legendG; nm{i}];
88 | figure(2); hold on;
89 | plot(s{i,2},s{i,3},'color',Gcc(gi,:),'linewidth',2);
90 | gi = gi+1;
91 | elseif any(strfind(s{i,1},'D_'))
92 | legendD = [legendD; nm{i}];
93 | figure(3); hold on;
94 | plot(s{i,2},s{i,3},'color',Dcc(di,:),'linewidth',2);
95 | di = di+1;
96 | else
97 | s{i,1} % Debug print to show the name of the log that was not processed.
98 | end
99 | end
100 | figure(1);
101 | % Plot the smoothed losses last
102 | for i = 1:3
103 | % plot(losses{i,1}, losses{i,2},'color', losses{i,4}, 'HandleVisibility','off');
104 | plot(losses{i,1},losses{i,3},'color',losses{i,5});
105 | end
106 | legend(legendL, 'Interpreter', 'none'); title('Losses'); xlabel('Generator itr'); ylabel('loss'); axis([0, max(s{end,2}), -1, 4]);
107 |
108 | figure(2); legend(legendG,'Interpreter','none'); title('Singular Values in G'); xlabel('Generator itr'); ylabel('SV0');
109 | figure(3); legend(legendD, 'Interpreter', 'none'); title('Singular Values in D'); xlabel('Generator itr'); ylabel('SV0');
110 |
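111 | % Added commentary: figure 1 overlays the raw G/D losses with their
112 | % wavelet-denoised ('sym4') versions; figures 2 and 3 plot the logged top
113 | % singular value (sv0) for each parameterized layer of G and D respectively.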
--------------------------------------------------------------------------------
/biggan_pytorch/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | # DCGAN loss
5 | def loss_dcgan_dis(dis_fake, dis_real):
6 |   L1 = torch.mean(F.softplus(-dis_real))  # -log sigmoid(D(x)): real-sample term
7 |   L2 = torch.mean(F.softplus(dis_fake))  # -log(1 - sigmoid(D(G(z)))): fake-sample term
8 | return L1, L2
9 |
10 |
11 | def loss_dcgan_gen(dis_fake):
12 | loss = torch.mean(F.softplus(-dis_fake))
13 | return loss
14 |
15 |
16 | # Hinge Loss
17 | def loss_hinge_dis(dis_fake, dis_real):
18 | loss_real = torch.mean(F.relu(1. - dis_real))
19 | loss_fake = torch.mean(F.relu(1. + dis_fake))
20 | return loss_real, loss_fake
21 | # def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss
22 | # loss = torch.mean(F.relu(1. - dis_real))
23 | # loss += torch.mean(F.relu(1. + dis_fake))
24 | # return loss
25 |
26 |
27 | def loss_hinge_gen(dis_fake):
28 | loss = -torch.mean(dis_fake)
29 | return loss
30 |
31 | # Default to hinge loss
32 | generator_loss = loss_hinge_gen
33 | discriminator_loss = loss_hinge_dis
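34 |
35 | # Added commentary: a minimal usage sketch, where D, G, x, z and y are
36 | # hypothetical networks/batches and D outputs raw (pre-sigmoid) scores:
37 | #   loss_real, loss_fake = discriminator_loss(D(G(z), y), D(x, y))
38 | #   D_loss = loss_real + loss_fake        # hinge: push D(x) above +1, D(G(z)) below -1
39 | #   G_loss = generator_loss(D(G(z), y))   # G pushes D's score on fakes upward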
--------------------------------------------------------------------------------
/biggan_pytorch/make_hdf5.py:
--------------------------------------------------------------------------------
1 | """ Convert dataset to HDF5
2 | This script preprocesses a dataset and saves it (images and labels) to
3 | an HDF5 file for improved I/O. """
4 | import os
5 | import sys
6 | from argparse import ArgumentParser
7 | from tqdm import tqdm, trange
8 | import h5py as h5
9 |
10 | import numpy as np
11 | import torch
12 | import torchvision.datasets as dset
13 | import torchvision.transforms as transforms
14 | from torchvision.utils import save_image
15 |
16 | from torch.utils.data import DataLoader
17 |
18 | import utils
19 |
20 | def prepare_parser():
21 | usage = 'Parser for ImageNet HDF5 scripts.'
22 | parser = ArgumentParser(description=usage)
23 | parser.add_argument(
24 | '--dataset', type=str, default='I128',
25 | help='Which Dataset to train on, out of I128, I256, C10, C100;'
26 |     'Append "_hdf5" to use the hdf5 version for ILSVRC (default: %(default)s)')
27 | parser.add_argument(
28 | '--data_root', type=str, default='data',
29 | help='Default location where data is stored (default: %(default)s)')
30 | parser.add_argument(
31 | '--batch_size', type=int, default=256,
32 | help='Default overall batchsize (default: %(default)s)')
33 | parser.add_argument(
34 | '--num_workers', type=int, default=16,
35 | help='Number of dataloader workers (default: %(default)s)')
36 | parser.add_argument(
37 | '--chunk_size', type=int, default=500,
38 |     help='Chunk size for the HDF5 dataset (default: %(default)s)')
39 | parser.add_argument(
40 | '--compression', action='store_true', default=False,
41 | help='Use LZF compression? (default: %(default)s)')
42 | return parser
43 |
44 |
45 | def run(config):
46 | if 'hdf5' in config['dataset']:
47 | raise ValueError('Reading from an HDF5 file which you will probably be '
48 | 'about to overwrite! Override this error only if you know '
49 |                      'what you\'re doing!')
50 | # Get image size
51 | config['image_size'] = utils.imsize_dict[config['dataset']]
52 |
53 | # Update compression entry
54 | config['compression'] = 'lzf' if config['compression'] else None #No compression; can also use 'lzf'
55 |
56 | # Get dataset
57 | kwargs = {'num_workers': config['num_workers'], 'pin_memory': False, 'drop_last': False}
58 | train_loader = utils.get_data_loaders(dataset=config['dataset'],
59 | batch_size=config['batch_size'],
60 | shuffle=False,
61 | data_root=config['data_root'],
62 | use_multiepoch_sampler=False,
63 | **kwargs)[0]
64 |
65 | # HDF5 supports chunking and compression. You may want to experiment
66 | # with different chunk sizes to see how it runs on your machines.
67 |   # Chunk Size/compression     Read speed @ 256x256   Read speed @ 128x128   Filesize @ 128x128   Time to write @ 128x128
68 |   # 1 / None                   20/s
69 |   # 500 / None                 ramps up to 77/s       102/s                  61GB                 23min
70 |   # 500 / LZF                                         8/s                    56GB                 23min
71 |   # 1000 / None                78/s
72 |   # 5000 / None                81/s
73 |   # auto:(125,1,16,32) / None                         11/s                   61GB
74 |
75 | print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (config['dataset'], config['chunk_size'], config['compression']))
76 | # Loop over train loader
77 | for i,(x,y) in enumerate(tqdm(train_loader)):
78 |     # Stick x into the range [0, 255]: the train loader yields images in [-1, 1]
79 | x = (255 * ((x + 1) / 2.0)).byte().numpy()
80 | # Numpyify y
81 | y = y.numpy()
82 | # If we're on the first batch, prepare the hdf5
83 | if i==0:
84 | with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'w') as f:
85 | print('Producing dataset of len %d' % len(train_loader.dataset))
86 | imgs_dset = f.create_dataset('imgs', x.shape,dtype='uint8', maxshape=(len(train_loader.dataset), 3, config['image_size'], config['image_size']),
87 | chunks=(config['chunk_size'], 3, config['image_size'], config['image_size']), compression=config['compression'])
88 | print('Image chunks chosen as ' + str(imgs_dset.chunks))
89 | imgs_dset[...] = x
90 | labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(train_loader.dataset),), chunks=(config['chunk_size'],), compression=config['compression'])
91 | print('Label chunks chosen as ' + str(labels_dset.chunks))
92 | labels_dset[...] = y
93 | # Else append to the hdf5
94 | else:
95 | with h5.File(config['data_root'] + '/ILSVRC%i.hdf5' % config['image_size'], 'a') as f:
96 | f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0)
97 | f['imgs'][-x.shape[0]:] = x
98 | f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0)
99 | f['labels'][-y.shape[0]:] = y
100 |
101 |
102 | def main():
103 | # parse command line and run
104 | parser = prepare_parser()
105 | config = vars(parser.parse_args())
106 | print(config)
107 | run(config)
108 |
109 | if __name__ == '__main__':
110 | main()
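111 |
112 | # Added commentary: a quick read-back sanity check, assuming the default
113 | # data_root and a 128x128 run:
114 | #   with h5.File('data/ILSVRC128.hdf5', 'r') as f:
115 | #     print(f['imgs'].shape, f['labels'].shape)  # (N, 3, 128, 128) uint8, (N,) int64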
--------------------------------------------------------------------------------
/biggan_pytorch/sample.py:
--------------------------------------------------------------------------------
1 | ''' Sample
2 | This script loads a pretrained net and a weightsfile and samples from it. '''
3 | import functools
4 | import math
5 | import numpy as np
6 | from tqdm import tqdm, trange
7 |
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import init
12 | import torch.optim as optim
13 | import torch.nn.functional as F
14 | from torch.nn import Parameter as P
15 | import torchvision
16 |
17 | # Import my stuff
18 | import inception_utils
19 | import utils
20 | import losses
21 |
22 |
23 |
24 | def run(config):
25 | # Prepare state dict, which holds things like epoch # and itr #
26 | state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
27 | 'best_IS': 0, 'best_FID': 999999, 'config': config}
28 |
29 | # Optionally, get the configuration from the state dict. This allows for
30 | # recovery of the config provided only a state dict and experiment name,
31 | # and can be convenient for writing less verbose sample shell scripts.
32 | if config['config_from_name']:
33 | utils.load_weights(None, None, state_dict, config['weights_root'],
34 | config['experiment_name'], config['load_weights'], None,
35 | strict=False, load_optim=False)
36 | # Ignore items which we might want to overwrite from the command line
37 | for item in state_dict['config']:
38 | if item not in ['z_var', 'base_root', 'batch_size', 'G_batch_size', 'use_ema', 'G_eval_mode']:
39 | config[item] = state_dict['config'][item]
40 |
41 | # update config (see train.py for explanation)
42 | config['resolution'] = utils.imsize_dict[config['dataset']]
43 | config['n_classes'] = utils.nclass_dict[config['dataset']]
44 | config['G_activation'] = utils.activation_dict[config['G_nl']]
45 | config['D_activation'] = utils.activation_dict[config['D_nl']]
46 | config = utils.update_config_roots(config)
47 | config['skip_init'] = True
48 | config['no_optim'] = True
49 | device = 'cuda'
50 |
51 | # Seed RNG
52 | utils.seed_rng(config['seed'])
53 |
54 | # Setup cudnn.benchmark for free speed
55 | torch.backends.cudnn.benchmark = True
56 |
57 | # Import the model--this line allows us to dynamically select different files.
58 | model = __import__(config['model'])
59 | experiment_name = (config['experiment_name'] if config['experiment_name']
60 | else utils.name_from_config(config))
61 | print('Experiment name is %s' % experiment_name)
62 |
63 | G = model.Generator(**config).cuda()
64 | utils.count_parameters(G)
65 |
66 | # Load weights
67 | print('Loading weights...')
68 | # Here is where we deal with the ema--load ema weights or load normal weights
69 | utils.load_weights(G if not (config['use_ema']) else None, None, state_dict,
70 | config['weights_root'], experiment_name, config['load_weights'],
71 | G if config['ema'] and config['use_ema'] else None,
72 | strict=False, load_optim=False)
73 | # Update batch size setting used for G
74 | G_batch_size = max(config['G_batch_size'], config['batch_size'])
75 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
76 | device=device, fp16=config['G_fp16'],
77 | z_var=config['z_var'])
78 |
79 | if config['G_eval_mode']:
80 | print('Putting G in eval mode..')
81 | G.eval()
82 | else:
83 | print('G is in %s mode...' % ('training' if G.training else 'eval'))
84 |
85 | #Sample function
86 | sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
87 | if config['accumulate_stats']:
88 | print('Accumulating standing stats across %d accumulations...' % config['num_standing_accumulations'])
89 | utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
90 | config['num_standing_accumulations'])
91 |
92 |
93 | # Sample a number of images and save them to an NPZ, for use with TF-Inception
94 | if config['sample_npz']:
95 | # Lists to hold images and labels for images
96 | x, y = [], []
97 | print('Sampling %d images and saving them to npz...' % config['sample_num_npz'])
98 | for i in trange(int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
99 | with torch.no_grad():
100 | images, labels = sample()
101 | x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
102 | y += [labels.cpu().numpy()]
103 | x = np.concatenate(x, 0)[:config['sample_num_npz']]
104 | y = np.concatenate(y, 0)[:config['sample_num_npz']]
105 | print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
106 | npz_filename = '%s/%s/samples.npz' % (config['samples_root'], experiment_name)
107 | print('Saving npz to %s...' % npz_filename)
108 | np.savez(npz_filename, **{'x' : x, 'y' : y})
109 |
110 | # Prepare sample sheets
111 | if config['sample_sheets']:
112 | print('Preparing conditional sample sheets...')
113 | utils.sample_sheet(G, classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
114 | num_classes=config['n_classes'],
115 | samples_per_class=10, parallel=config['parallel'],
116 | samples_root=config['samples_root'],
117 | experiment_name=experiment_name,
118 | folder_number=config['sample_sheet_folder_num'],
119 | z_=z_,)
120 | # Sample interp sheets
121 | if config['sample_interps']:
122 | print('Preparing interp sheets...')
123 | for fix_z, fix_y in zip([False, False, True], [False, True, False]):
124 | utils.interp_sheet(G, num_per_sheet=16, num_midpoints=8,
125 | num_classes=config['n_classes'],
126 | parallel=config['parallel'],
127 | samples_root=config['samples_root'],
128 | experiment_name=experiment_name,
129 | folder_number=config['sample_sheet_folder_num'],
130 | sheet_number=0,
131 | fix_z=fix_z, fix_y=fix_y, device='cuda')
132 | # Sample random sheet
133 | if config['sample_random']:
134 | print('Preparing random sample sheet...')
135 | images, labels = sample()
136 | torchvision.utils.save_image(images.float(),
137 | '%s/%s/random_samples.jpg' % (config['samples_root'], experiment_name),
138 | nrow=int(G_batch_size**0.5),
139 | normalize=True)
140 |
141 | # Get Inception Score and FID
142 | get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])
143 |   # Prepare a simple get_metrics function that we also use for trunc curves
144 | def get_metrics():
145 | sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
146 | IS_mean, IS_std, FID = get_inception_metrics(sample, config['num_inception_images'], num_splits=10, prints=False)
147 | # Prepare output string
148 | outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema')
149 | outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else 'training')
150 | outstring += 'with noise variance %3.3f, ' % z_.var
151 | outstring += 'over %d images, ' % config['num_inception_images']
152 | if config['accumulate_stats'] or not config['G_eval_mode']:
153 | outstring += 'with batch size %d, ' % G_batch_size
154 | if config['accumulate_stats']:
155 | outstring += 'using %d standing stat accumulations, ' % config['num_standing_accumulations']
156 | outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID)
157 | print(outstring)
158 | if config['sample_inception_metrics']:
159 | print('Calculating Inception metrics...')
160 | get_metrics()
161 |
162 | # Sample truncation curve stuff. This is basically the same as the inception metrics code
163 | if config['sample_trunc_curves']:
164 | start, step, end = [float(item) for item in config['sample_trunc_curves'].split('_')]
165 | print('Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...' % (start, step, end))
166 | for var in np.arange(start, end + step, step):
167 | z_.var = var
168 | # Optionally comment this out if you want to run with standing stats
169 | # accumulated at one z variance setting
170 | if config['accumulate_stats']:
171 | utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
172 | config['num_standing_accumulations'])
173 | get_metrics()
174 | def main():
175 | # parse command line and run
176 | parser = utils.prepare_parser()
177 | parser = utils.add_sample_parser(parser)
178 | config = vars(parser.parse_args())
179 | print(config)
180 | run(config)
181 |
182 | if __name__ == '__main__':
183 | main()
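184 |
185 | # Added commentary: --sample_trunc_curves takes a 'start_step_end' string,
186 | # e.g. '0.05_0.05_1.0' sweeps the z variance from 0.05 to 1.0 in steps of
187 | # 0.05, recomputing IS/FID at each variance setting via get_metrics().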
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_BigGAN_bs256x8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py \
3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \
4 | --num_G_accumulations 8 --num_D_accumulations 8 \
5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
6 | --G_attn 64 --D_attn 64 \
7 | --G_nl inplace_relu --D_nl inplace_relu \
8 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
9 | --G_ortho 0.0 \
10 | --G_shared \
11 | --G_init ortho --D_init ortho \
12 | --hier --dim_z 120 --shared_dim 128 \
13 | --G_eval_mode \
14 | --G_ch 96 --D_ch 96 \
15 | --ema --use_ema --ema_start 20000 \
16 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
17 | --use_multiepoch_sampler
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_BigGAN_bs512x4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py \
3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 512 --load_in_mem \
4 | --num_G_accumulations 4 --num_D_accumulations 4 \
5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
6 | --G_attn 64 --D_attn 64 \
7 | --G_nl inplace_relu --D_nl inplace_relu \
8 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
9 | --G_ortho 0.0 \
10 | --G_shared \
11 | --G_init ortho --D_init ortho \
12 | --hier --dim_z 120 --shared_dim 128 \
13 | --G_eval_mode \
14 | --G_ch 96 --D_ch 96 \
15 | --ema --use_ema --ema_start 20000 \
16 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
17 | --use_multiepoch_sampler
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_BigGAN_ch64_bs256x8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py \
3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \
4 | --num_G_accumulations 8 --num_D_accumulations 8 \
5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
6 | --G_attn 64 --D_attn 64 \
7 | --G_nl inplace_relu --D_nl inplace_relu \
8 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
9 | --G_ortho 0.0 \
10 | --G_shared \
11 | --G_init ortho --D_init ortho \
12 | --hier --dim_z 120 --shared_dim 128 \
13 | --G_eval_mode \
14 | --G_ch 64 --D_ch 64 \
15 | --ema --use_ema --ema_start 20000 \
16 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
17 | --use_multiepoch_sampler
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_BigGAN_deep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py \
3 | --model BigGANdeep \
4 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \
5 | --num_G_accumulations 8 --num_D_accumulations 8 \
6 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
7 | --G_attn 64 --D_attn 64 \
8 | --G_ch 128 --D_ch 128 \
9 | --G_depth 2 --D_depth 2 \
10 | --G_nl inplace_relu --D_nl inplace_relu \
11 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
12 | --G_ortho 0.0 \
13 | --G_shared \
14 | --G_init ortho --D_init ortho \
15 | --hier --dim_z 128 --shared_dim 128 \
16 | --ema --use_ema --ema_start 20000 --G_eval_mode \
17 | --test_every 2000 --save_every 500 --num_best_copies 5 --num_save_copies 2 --seed 0 \
18 | --use_multiepoch_sampler
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_SAGAN_bs128x2_ema.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py \
3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 128 \
4 | --num_G_accumulations 2 --num_D_accumulations 2 \
5 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
6 | --G_attn 64 --D_attn 64 \
7 | --G_nl relu --D_nl relu \
8 | --SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \
9 | --G_ortho 0.0 \
10 | --G_init xavier --D_init xavier \
11 | --ema --use_ema --ema_start 2000 --G_eval_mode \
12 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
13 | --name_suffix SAGAN_ema
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_SNGAN.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py \
3 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 64 \
4 | --num_G_accumulations 1 --num_D_accumulations 1 \
5 | --num_D_steps 5 --G_lr 2e-4 --D_lr 2e-4 --D_B2 0.900 --G_B2 0.900 \
6 | --G_attn 0 --D_attn 0 \
7 | --G_nl relu --D_nl relu \
8 | --SN_eps 1e-8 --BN_eps 1e-5 --adam_eps 1e-8 \
9 | --G_ortho 0.0 \
10 | --D_thin \
11 | --G_init xavier --D_init xavier \
12 | --G_eval_mode \
13 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
14 | --name_suffix SNGAN
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/launch_cifar_ema.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=0,1 python train.py \
3 | --shuffle --batch_size 50 --parallel \
4 | --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \
5 | --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \
6 | --dataset C10 \
7 | --G_ortho 0.0 \
8 | --G_attn 0 --D_attn 0 \
9 | --G_init N02 --D_init N02 \
10 | --ema --use_ema --ema_start 1000 \
11 | --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/sample_BigGAN_bs256x8.sh:
--------------------------------------------------------------------------------
1 | # use --z_var to change the variance of z for all the sampling
2 | # use --mybn --accumulate_stats --num_standing_accumulations 32 to
3 | # use running stats
4 | python sample.py \
5 | --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 \
6 | --num_G_accumulations 8 --num_D_accumulations 8 \
7 | --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
8 | --G_attn 64 --D_attn 64 \
9 | --G_ch 96 --D_ch 96 \
10 | --G_nl inplace_relu --D_nl inplace_relu \
11 | --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
12 | --G_ortho 0.0 \
13 | --G_shared \
14 | --G_init ortho --D_init ortho --skip_init \
15 | --hier --dim_z 120 --shared_dim 128 \
16 | --ema --ema_start 20000 \
17 | --use_multiepoch_sampler \
18 | --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
19 | --G_batch_size 512 --use_ema --G_eval_mode --sample_trunc_curves 0.05_0.05_1.0 \
20 | --sample_inception_metrics --sample_npz --sample_random --sample_sheets --sample_interps
21 |
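22 | # Added note: to sample with running/standing stats as described in the header
23 | # comment, append e.g. --mybn --accumulate_stats --num_standing_accumulations 32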
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/sample_cifar_ema.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=0,1 python sample.py \
3 | --shuffle --batch_size 50 --G_batch_size 256 --parallel \
4 | --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 \
5 | --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 \
6 | --dataset C10 \
7 | --G_ortho 0.0 \
8 | --G_attn 0 --D_attn 0 \
9 | --G_init N02 --D_init N02 \
10 | --ema --use_ema --ema_start 1000 \
11 | --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --seed 0
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/utils/duplicate.sh:
--------------------------------------------------------------------------------
1 | #duplicate.sh
2 | source=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0
3 | target=BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs256_Glr1.0e-04_Dlr4.0e-04_Gnlinplace_relu_Dnlinplace_relu_Ginitxavier_Dinitxavier_Gshared_alex0A
4 | logs_root=logs
5 | weights_root=weights
6 | echo "copying ${source} to ${target}"
7 | cp -r ${logs_root}/${source} ${logs_root}/${target}
8 | cp ${logs_root}/${source}_log.jsonl ${logs_root}/${target}_log.jsonl
9 | cp ${weights_root}/${source}_G.pth ${weights_root}/${target}_G.pth
10 | cp ${weights_root}/${source}_G_ema.pth ${weights_root}/${target}_G_ema.pth
11 | cp ${weights_root}/${source}_D.pth ${weights_root}/${target}_D.pth
12 | cp ${weights_root}/${source}_G_optim.pth ${weights_root}/${target}_G_optim.pth
13 | cp ${weights_root}/${source}_D_optim.pth ${weights_root}/${target}_D_optim.pth
14 | cp ${weights_root}/${source}_state_dict.pth ${weights_root}/${target}_state_dict.pth
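15 | # Added note: run from the directory containing logs/ and weights/; this clones
16 | # every checkpoint artifact of ${source} under the ${target} experiment name.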
--------------------------------------------------------------------------------
/biggan_pytorch/scripts/utils/prepare_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python make_hdf5.py --dataset I128 --batch_size 256 --data_root data
3 | python calculate_inception_moments.py --dataset I128_hdf5 --data_root data
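4 | # Added note: the first command packs ImageNet at 128x128 into
5 | # data/ILSVRC128.hdf5; the second computes the Inception activation moments
6 | # (mean and covariance) used later for FID.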
--------------------------------------------------------------------------------
/biggan_pytorch/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/biggan_pytorch/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | # _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'ssum', 'sum_size'])
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input, gain=None, bias=None):
49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | out = F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 | if gain is not None:
55 | out = out + gain
56 | if bias is not None:
57 | out = out + bias
58 | return out
59 |
60 | # Resize the input to (B, C, -1).
61 | input_shape = input.size()
62 | # print(input_shape)
63 | input = input.view(input.size(0), input.size(1), -1)
64 |
65 | # Compute the sum and square-sum.
66 | sum_size = input.size(0) * input.size(2)
67 | input_sum = _sum_ft(input)
68 | input_ssum = _sum_ft(input ** 2)
69 | # Reduce-and-broadcast the statistics.
70 | # print('it begins')
71 | if self._parallel_id == 0:
72 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
73 | else:
74 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
75 | # if self._parallel_id == 0:
76 | # # print('here')
77 | # sum, ssum, num = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
78 | # else:
79 | # # print('there')
80 | # sum, ssum, num = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
81 |
82 | # print('how2')
83 | # num = sum_size
84 | # print('Sum: %f, ssum: %f, sumsize: %f, insum: %f' %(float(sum.sum().cpu()), float(ssum.sum().cpu()), float(sum_size), float(input_sum.sum().cpu())))
85 | # Fix the graph
86 | # sum = (sum.detach() - input_sum.detach()) + input_sum
87 | # ssum = (ssum.detach() - input_ssum.detach()) + input_ssum
88 |
89 | # mean = sum / num
90 | # var = ssum / num - mean ** 2
91 | # # var = (ssum - mean * sum) / num
92 | # inv_std = torch.rsqrt(var + self.eps)
93 |
94 | # Compute the output.
95 | if gain is not None:
96 | # print('gaining')
97 | # scale = _unsqueeze_ft(inv_std) * gain.squeeze(-1)
98 | # shift = _unsqueeze_ft(mean) * scale - bias.squeeze(-1)
99 | # output = input * scale - shift
100 | output = (input - _unsqueeze_ft(mean)) * (_unsqueeze_ft(inv_std) * gain.squeeze(-1)) + bias.squeeze(-1)
101 | elif self.affine:
102 | # MJY:: Fuse the multiplication for speed.
103 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
104 | else:
105 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
106 |
107 | # Reshape it.
108 | return output.view(input_shape)
109 |
110 | def __data_parallel_replicate__(self, ctx, copy_id):
111 | self._is_parallel = True
112 | self._parallel_id = copy_id
113 |
114 | # parallel_id == 0 means master device.
115 | if self._parallel_id == 0:
116 | ctx.sync_master = self._sync_master
117 | else:
118 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
119 |
120 | def _data_parallel_master(self, intermediates):
121 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
122 |
123 | # Always using same "device order" makes the ReduceAdd operation faster.
124 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
125 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
126 |
127 | to_reduce = [i[1][:2] for i in intermediates]
128 | to_reduce = [j for i in to_reduce for j in i] # flatten
129 | target_gpus = [i[1].sum.get_device() for i in intermediates]
130 |
131 | sum_size = sum([i[1].sum_size for i in intermediates])
132 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
133 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
134 |
135 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
136 | # print('a')
137 | # print(type(sum_), type(ssum), type(sum_size), sum_.shape, ssum.shape, sum_size)
138 | # broadcasted = Broadcast.apply(target_gpus, sum_, ssum, torch.tensor(sum_size).float().to(sum_.device))
139 | # print('b')
140 | outputs = []
141 | for i, rec in enumerate(intermediates):
142 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
143 | # outputs.append((rec[0], _MasterMessage(*broadcasted[i*3:i*3+3])))
144 |
145 | return outputs
146 |
147 | def _compute_mean_std(self, sum_, ssum, size):
148 | """Compute the mean and standard-deviation with sum and square-sum. This method
149 | also maintains the moving average on the master device."""
150 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
151 | mean = sum_ / size
152 | sumvar = ssum - sum_ * mean
153 | unbias_var = sumvar / (size - 1)
154 | bias_var = sumvar / size
155 |
156 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
157 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
158 | return mean, torch.rsqrt(bias_var + self.eps)
159 | # return mean, bias_var.clamp(self.eps) ** -0.5
160 |
161 |
162 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
163 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
164 | mini-batch.
165 |
166 | .. math::
167 |
168 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
169 |
170 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
171 | standard-deviation are reduced across all devices during training.
172 |
173 | For example, when one uses `nn.DataParallel` to wrap the network during
174 |     training, PyTorch's implementation normalizes the tensor on each device using
175 |     the statistics only on that device, which accelerates the computation and
176 | is also easy to implement, but the statistics might be inaccurate.
177 | Instead, in this synchronized version, the statistics will be computed
178 | over all training samples distributed on multiple devices.
179 |
180 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
181 | as the built-in PyTorch implementation.
182 |
183 | The mean and standard-deviation are calculated per-dimension over
184 | the mini-batches and gamma and beta are learnable parameter vectors
185 | of size C (where C is the input size).
186 |
187 | During training, this layer keeps a running estimate of its computed mean
188 | and variance. The running sum is kept with a default momentum of 0.1.
189 |
190 | During evaluation, this running mean/variance is used for normalization.
191 |
192 | Because the BatchNorm is done over the `C` dimension, computing statistics
193 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
194 |
195 | Args:
196 | num_features: num_features from an expected input of size
197 | `batch_size x num_features [x width]`
198 | eps: a value added to the denominator for numerical stability.
199 | Default: 1e-5
200 | momentum: the value used for the running_mean and running_var
201 | computation. Default: 0.1
202 | affine: a boolean value that when set to ``True``, gives the layer learnable
203 | affine parameters. Default: ``True``
204 |
205 | Shape:
206 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
207 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
208 |
209 | Examples:
210 | >>> # With Learnable Parameters
211 | >>> m = SynchronizedBatchNorm1d(100)
212 | >>> # Without Learnable Parameters
213 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
214 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
215 | >>> output = m(input)
216 | """
217 |
218 | def _check_input_dim(self, input):
219 | if input.dim() != 2 and input.dim() != 3:
220 | raise ValueError('expected 2D or 3D input (got {}D input)'
221 | .format(input.dim()))
222 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
223 |
224 |
225 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
226 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
227 | of 3d inputs
228 |
229 | .. math::
230 |
231 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
232 |
233 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
234 | standard-deviation are reduced across all devices during training.
235 |
236 | For example, when one uses `nn.DataParallel` to wrap the network during
237 |     training, PyTorch's implementation normalizes the tensor on each device using
238 |     the statistics only on that device, which accelerates the computation and
239 | is also easy to implement, but the statistics might be inaccurate.
240 | Instead, in this synchronized version, the statistics will be computed
241 | over all training samples distributed on multiple devices.
242 |
243 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
244 | as the built-in PyTorch implementation.
245 |
246 | The mean and standard-deviation are calculated per-dimension over
247 | the mini-batches and gamma and beta are learnable parameter vectors
248 | of size C (where C is the input size).
249 |
250 | During training, this layer keeps a running estimate of its computed mean
251 | and variance. The running sum is kept with a default momentum of 0.1.
252 |
253 | During evaluation, this running mean/variance is used for normalization.
254 |
255 | Because the BatchNorm is done over the `C` dimension, computing statistics
256 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
257 |
258 | Args:
259 | num_features: num_features from an expected input of
260 | size batch_size x num_features x height x width
261 | eps: a value added to the denominator for numerical stability.
262 | Default: 1e-5
263 | momentum: the value used for the running_mean and running_var
264 | computation. Default: 0.1
265 | affine: a boolean value that when set to ``True``, gives the layer learnable
266 | affine parameters. Default: ``True``
267 |
268 | Shape:
269 | - Input: :math:`(N, C, H, W)`
270 | - Output: :math:`(N, C, H, W)` (same shape as input)
271 |
272 | Examples:
273 | >>> # With Learnable Parameters
274 | >>> m = SynchronizedBatchNorm2d(100)
275 | >>> # Without Learnable Parameters
276 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
277 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
278 | >>> output = m(input)
279 | """
280 |
281 | def _check_input_dim(self, input):
282 | if input.dim() != 4:
283 | raise ValueError('expected 4D input (got {}D input)'
284 | .format(input.dim()))
285 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
286 |
287 |
288 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
289 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
290 | of 4d inputs
291 |
292 | .. math::
293 |
294 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
295 |
296 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
297 | standard-deviation are reduced across all devices during training.
298 |
299 | For example, when one uses `nn.DataParallel` to wrap the network during
300 |     training, PyTorch's implementation normalizes the tensor on each device using
301 |     the statistics only on that device, which accelerates the computation and
302 | is also easy to implement, but the statistics might be inaccurate.
303 | Instead, in this synchronized version, the statistics will be computed
304 | over all training samples distributed on multiple devices.
305 |
306 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
307 | as the built-in PyTorch implementation.
308 |
309 | The mean and standard-deviation are calculated per-dimension over
310 | the mini-batches and gamma and beta are learnable parameter vectors
311 | of size C (where C is the input size).
312 |
313 | During training, this layer keeps a running estimate of its computed mean
314 | and variance. The running sum is kept with a default momentum of 0.1.
315 |
316 | During evaluation, this running mean/variance is used for normalization.
317 |
318 | Because the BatchNorm is done over the `C` dimension, computing statistics
319 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
320 | or Spatio-temporal BatchNorm
321 |
322 | Args:
323 | num_features: num_features from an expected input of
324 | size batch_size x num_features x depth x height x width
325 | eps: a value added to the denominator for numerical stability.
326 | Default: 1e-5
327 | momentum: the value used for the running_mean and running_var
328 | computation. Default: 0.1
329 | affine: a boolean value that when set to ``True``, gives the layer learnable
330 | affine parameters. Default: ``True``
331 |
332 | Shape:
333 | - Input: :math:`(N, C, D, H, W)`
334 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
335 |
336 | Examples:
337 | >>> # With Learnable Parameters
338 | >>> m = SynchronizedBatchNorm3d(100)
339 | >>> # Without Learnable Parameters
340 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
341 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
342 | >>> output = m(input)
343 | """
344 |
345 | def _check_input_dim(self, input):
346 | if input.dim() != 5:
347 | raise ValueError('expected 5D input (got {}D input)'
348 | .format(input.dim()))
349 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
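350 |
351 | # Added commentary: in the parallel training path above, each replica reduces
352 | # its (sum, sum-of-squares, count) to the master, which computes mean and
353 | # inv_std once and broadcasts them back, so all replicas normalize with
354 | # identical full-batch statistics. Running stats use the unbiased variance;
355 | # normalization uses the biased one (see _compute_mean_std).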
--------------------------------------------------------------------------------
/biggan_pytorch/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
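76 | # Added commentary: an illustrative single-device check against the built-in
77 | # implementation, e.g.
78 | #   bn = BatchNorm2dReimpl(16)
79 | #   out = bn(torch.randn(8, 16, 32, 32))  # same shape; per-channel normalized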
--------------------------------------------------------------------------------
/biggan_pytorch/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected,
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to be
64 |     passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 |             identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback will be invoked to compute the message to be sent back to each device
107 | (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
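139 | # Added commentary: the per-forward-pass protocol is strict: each slave calls
140 | # run_slave() exactly once and blocks on its FutureResult; the master calls
141 | # run_master() once; the trailing queue.put(True)/get() handshake confirms
142 | # every slave has received its result before the next pass begins.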
--------------------------------------------------------------------------------
/biggan_pytorch/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55 | original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/biggan_pytorch/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/biggan_pytorch/train.py:
--------------------------------------------------------------------------------
1 | """ BigGAN: The Authorized Unofficial PyTorch release
2 | Code by A. Brock and A. Andonian
3 | This code is an unofficial reimplementation of
4 | "Large-Scale GAN Training for High Fidelity Natural Image Synthesis,"
5 | by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096).
6 |
7 | Let's go.
8 | """
9 |
10 | import os
11 | import functools
12 | import math
13 | import numpy as np
14 | from tqdm import tqdm, trange
15 |
16 |
17 | import torch
18 | import torch.nn as nn
19 | from torch.nn import init
20 | import torch.optim as optim
21 | import torch.nn.functional as F
22 | from torch.nn import Parameter as P
23 | import torchvision
24 |
25 | # Import my stuff
26 | import inception_utils
27 | import utils
28 | import losses
29 | import train_fns
30 | from sync_batchnorm import patch_replication_callback
31 |
32 | # The main training file. Config is a dictionary specifying the configuration
33 | # of this training run.
34 | def run(config):
35 |
36 | # Update the config dict as necessary
37 | # This is for convenience, to add settings derived from the user-specified
38 | # configuration into the config-dict (e.g. inferring the number of classes
39 | # and size of the images from the dataset, passing in a pytorch object
40 | # for the activation specified as a string)
41 | config['resolution'] = utils.imsize_dict[config['dataset']]
42 | config['n_classes'] = utils.nclass_dict[config['dataset']]
43 | config['G_activation'] = utils.activation_dict[config['G_nl']]
44 | config['D_activation'] = utils.activation_dict[config['D_nl']]
45 | # By default, skip init if resuming training.
46 | if config['resume']:
47 | print('Skipping initialization for training resumption...')
48 | config['skip_init'] = True
49 | config = utils.update_config_roots(config)
50 | device = 'cuda'
51 |
52 | # Seed RNG
53 | utils.seed_rng(config['seed'])
54 |
55 | # Prepare root folders if necessary
56 | utils.prepare_root(config)
57 |
58 | # Setup cudnn.benchmark for free speed
59 | torch.backends.cudnn.benchmark = True
60 |
61 | # Import the model--this line allows us to dynamically select different files.
62 | model = __import__(config['model'])
63 | experiment_name = (config['experiment_name'] if config['experiment_name']
64 | else utils.name_from_config(config))
65 | print('Experiment name is %s' % experiment_name)
66 |
67 | # Next, build the model
68 | G = model.Generator(**config).to(device)
69 | D = model.Discriminator(**config).to(device)
70 |
71 | # If using EMA, prepare it
72 | if config['ema']:
73 | print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
74 | G_ema = model.Generator(**{**config, 'skip_init':True,
75 | 'no_optim': True}).to(device)
76 | ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
77 | else:
78 | G_ema, ema = None, None
79 |
80 | # FP16?
81 | if config['G_fp16']:
82 | print('Casting G to float16...')
83 | G = G.half()
84 | if config['ema']:
85 | G_ema = G_ema.half()
86 | if config['D_fp16']:
87 | print('Casting D to fp16...')
88 | D = D.half()
89 | # Consider automatically reducing SN_eps?
90 | GD = model.G_D(G, D)
91 | print(G)
92 | print(D)
93 | print('Number of params in G: {} D: {}'.format(
94 | *[sum([p.data.nelement() for p in net.parameters()]) for net in [G,D]]))
95 | # Prepare state dict, which holds things like epoch # and itr #
96 | state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
97 | 'best_IS': 0, 'best_FID': 999999, 'config': config}
98 |
99 | # If loading from a pre-trained model, load weights
100 | if config['resume']:
101 | print('Loading weights...')
102 | utils.load_weights(G, D, state_dict,
103 | config['weights_root'], experiment_name,
104 | config['load_weights'] if config['load_weights'] else None,
105 | G_ema if config['ema'] else None)
106 |
107 | # If parallel, parallelize the GD module
108 | if config['parallel']:
109 | GD = nn.DataParallel(GD)
110 | if config['cross_replica']:
111 | patch_replication_callback(GD)
112 |
113 | # Prepare loggers for stats; metrics holds test metrics,
114 | # lmetrics holds any desired training metrics.
115 | test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
116 | experiment_name)
117 | train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
118 | print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
119 | test_log = utils.MetricsLogger(test_metrics_fname,
120 | reinitialize=(not config['resume']))
121 | print('Training Metrics will be saved to {}'.format(train_metrics_fname))
122 | train_log = utils.MyLogger(train_metrics_fname,
123 | reinitialize=(not config['resume']),
124 | logstyle=config['logstyle'])
125 | # Write metadata
126 | utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
127 | # Prepare data; the Discriminator's batch size is all that needs to be passed
128 | # to the dataloader, as G doesn't require dataloading.
129 | # Note that at every loader iteration we pass in enough data to complete
130 | # a full D iteration (regardless of number of D steps and accumulations)
131 | D_batch_size = (config['batch_size'] * config['num_D_steps']
132 | * config['num_D_accumulations'])
133 | loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
134 | 'start_itr': state_dict['itr']})
135 |
136 | # Prepare inception metrics: FID and IS
137 | get_inception_metrics = inception_utils.prepare_inception_metrics(config['dataset'], config['parallel'], config['no_fid'])
138 |
139 | # Prepare noise and randomly sampled label arrays
140 | # Allow for different batch sizes in G
141 | G_batch_size = max(config['G_batch_size'], config['batch_size'])
142 | z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
143 | device=device, fp16=config['G_fp16'])
144 |   # Prepare a fixed z & y to see individual sample evolution throughout training
145 | fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
146 | config['n_classes'], device=device,
147 | fp16=config['G_fp16'])
148 | fixed_z.sample_()
149 | fixed_y.sample_()
150 | # Loaders are loaded, prepare the training function
151 | if config['which_train_fn'] == 'GAN':
152 | train = train_fns.GAN_training_function(G, D, GD, z_, y_,
153 | ema, state_dict, config)
154 | # Else, assume debugging and use the dummy train fn
155 | else:
156 | train = train_fns.dummy_training_function()
157 | # Prepare Sample function for use with inception metrics
158 | sample = functools.partial(utils.sample,
159 | G=(G_ema if config['ema'] and config['use_ema']
160 | else G),
161 | z_=z_, y_=y_, config=config)
162 |
163 | print('Beginning training at epoch %d...' % state_dict['epoch'])
164 | # Train for specified number of epochs, although we mostly track G iterations.
165 | for epoch in range(state_dict['epoch'], config['num_epochs']):
166 | # Which progressbar to use? TQDM or my own?
167 | if config['pbar'] == 'mine':
168 |       pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
169 | else:
170 | pbar = tqdm(loaders[0])
171 | for i, (x, y) in enumerate(pbar):
172 | # Increment the iteration counter
173 | state_dict['itr'] += 1
174 | # Make sure G and D are in training mode, just in case they got set to eval
175 | # For D, which typically doesn't have BN, this shouldn't matter much.
176 | G.train()
177 | D.train()
178 | if config['ema']:
179 | G_ema.train()
180 | if config['D_fp16']:
181 | x, y = x.to(device).half(), y.to(device)
182 | else:
183 | x, y = x.to(device), y.to(device)
184 | metrics = train(x, y)
185 | train_log.log(itr=int(state_dict['itr']), **metrics)
186 |
187 | # Every sv_log_interval, log singular values
188 | if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
189 | train_log.log(itr=int(state_dict['itr']),
190 | **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
191 |
192 | # If using my progbar, print metrics.
193 | if config['pbar'] == 'mine':
194 | print(', '.join(['itr: %d' % state_dict['itr']]
195 | + ['%s : %+4.3f' % (key, metrics[key])
196 | for key in metrics]), end=' ')
197 |
198 | # Save weights and copies as configured at specified interval
199 | if not (state_dict['itr'] % config['save_every']):
200 | if config['G_eval_mode']:
201 |           print('Switching G to eval mode...')
202 | G.eval()
203 | if config['ema']:
204 | G_ema.eval()
205 | train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
206 | state_dict, config, experiment_name)
207 |
208 | # Test every specified interval
209 | if not (state_dict['itr'] % config['test_every']):
210 | if config['G_eval_mode']:
211 |           print('Switching G to eval mode...')
212 | G.eval()
213 | train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
214 | get_inception_metrics, experiment_name, test_log)
215 | # Increment epoch counter at end of epoch
216 | state_dict['epoch'] += 1
217 |
218 |
219 | def main():
220 | # parse command line and run
221 | parser = utils.prepare_parser()
222 | config = vars(parser.parse_args())
223 | print(config)
224 | run(config)
225 |
226 | if __name__ == '__main__':
227 | main()
--------------------------------------------------------------------------------
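The D_batch_size arithmetic above is the contract between run() and the dataloader: each loader iteration must supply enough samples for every D step and every gradient accumulation of one full D iteration. A minimal sketch of that bookkeeping, with hypothetical config values standing in for the real argument parser:

import torch

# Hypothetical values; the real ones come from utils.prepare_parser().
config = {'batch_size': 8, 'num_D_steps': 2, 'num_D_accumulations': 4}

D_batch_size = (config['batch_size'] * config['num_D_steps']
                * config['num_D_accumulations'])   # 8 * 2 * 4 = 64

x = torch.randn(D_batch_size, 3, 32, 32)           # stand-in for one loader batch
chunks = torch.split(x, config['batch_size'])      # 8 chunks of batch_size samples
assert len(chunks) == config['num_D_steps'] * config['num_D_accumulations']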
/biggan_pytorch/train_fns.py:
--------------------------------------------------------------------------------
1 | ''' train_fns.py
2 | Functions for the main loop of training different conditional image models
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torchvision
7 | import os
8 |
9 | import utils
10 | import losses
11 |
12 |
13 | # Dummy training function for debugging
14 | def dummy_training_function():
15 | def train(x, y):
16 | return {}
17 | return train
18 |
19 |
20 | def GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config):
21 | def train(x, y):
22 | G.optim.zero_grad()
23 | D.optim.zero_grad()
24 | # How many chunks to split x and y into?
25 | x = torch.split(x, config['batch_size'])
26 | y = torch.split(y, config['batch_size'])
27 | counter = 0
28 |
29 |     # Optionally toggle D and G's "requires_grad"
30 | if config['toggle_grads']:
31 | utils.toggle_grad(D, True)
32 | utils.toggle_grad(G, False)
33 |
34 | for step_index in range(config['num_D_steps']):
35 | # If accumulating gradients, loop multiple times before an optimizer step
36 | D.optim.zero_grad()
37 | for accumulation_index in range(config['num_D_accumulations']):
38 | z_.sample_()
39 | y_.sample_()
40 | D_fake, D_real = GD(z_[:config['batch_size']], y_[:config['batch_size']],
41 | x[counter], y[counter], train_G=False,
42 | split_D=config['split_D'])
43 |
44 | # Compute components of D's loss, average them, and divide by
45 | # the number of gradient accumulations
46 | D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
47 | D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
48 | D_loss.backward()
49 | counter += 1
50 |
51 | # Optionally apply ortho reg in D
52 | if config['D_ortho'] > 0.0:
53 | # Debug print to indicate we're using ortho reg in D.
54 | print('using modified ortho reg in D')
55 | utils.ortho(D, config['D_ortho'])
56 |
57 | D.optim.step()
58 |
59 | # Optionally toggle "requires_grad"
60 | if config['toggle_grads']:
61 | utils.toggle_grad(D, False)
62 | utils.toggle_grad(G, True)
63 |
64 | # Zero G's gradients by default before training G, for safety
65 | G.optim.zero_grad()
66 |
67 | # If accumulating gradients, loop multiple times
68 | for accumulation_index in range(config['num_G_accumulations']):
69 | z_.sample_()
70 | y_.sample_()
71 | D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
72 | G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
73 | G_loss.backward()
74 |
75 | # Optionally apply modified ortho reg in G
76 | if config['G_ortho'] > 0.0:
77 | print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G
78 |       # Don't ortho reg the shared embedding; regularizing it makes no sense. Really we should blacklist any embeddings here.
79 | utils.ortho(G, config['G_ortho'],
80 | blacklist=[param for param in G.shared.parameters()])
81 | G.optim.step()
82 |
83 |     # If we have an EMA, update it, regardless of whether we test with it
84 | if config['ema']:
85 | ema.update(state_dict['itr'])
86 |
87 | out = {'G_loss': float(G_loss.item()),
88 | 'D_loss_real': float(D_loss_real.item()),
89 | 'D_loss_fake': float(D_loss_fake.item())}
90 | # Return G's loss and the components of D's loss.
91 | return out
92 | return train
93 |
94 | ''' This function takes in the model, saves the weights (multiple copies if
95 | requested), and prepares sample sheets: one consisting of samples given
96 | a fixed noise seed (to show how the model evolves throughout training),
97 | a set of full conditional sample sheets, and a set of interp sheets. '''
98 | def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
99 | state_dict, config, experiment_name):
100 | utils.save_weights(G, D, state_dict, config['weights_root'],
101 | experiment_name, None, G_ema if config['ema'] else None)
102 |   # Save an additional copy to mitigate accidental corruption if the process
103 | # is killed during a save (it's happened to me before -.-)
104 | if config['num_save_copies'] > 0:
105 | utils.save_weights(G, D, state_dict, config['weights_root'],
106 | experiment_name,
107 | 'copy%d' % state_dict['save_num'],
108 | G_ema if config['ema'] else None)
109 |     state_dict['save_num'] = (state_dict['save_num'] + 1) % config['num_save_copies']
110 |
111 | # Use EMA G for samples or non-EMA?
112 | which_G = G_ema if config['ema'] and config['use_ema'] else G
113 |
114 | # Accumulate standing statistics?
115 | if config['accumulate_stats']:
116 | utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G,
117 | z_, y_, config['n_classes'],
118 | config['num_standing_accumulations'])
119 |
120 | # Save a random sample sheet with fixed z and y
121 | with torch.no_grad():
122 | if config['parallel']:
123 | fixed_Gz = nn.parallel.data_parallel(which_G, (fixed_z, which_G.shared(fixed_y)))
124 | else:
125 | fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))
126 | if not os.path.isdir('%s/%s' % (config['samples_root'], experiment_name)):
127 | os.mkdir('%s/%s' % (config['samples_root'], experiment_name))
128 | image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'],
129 | experiment_name,
130 | state_dict['itr'])
131 | torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,
132 |                                nrow=int(fixed_Gz.shape[0] ** 0.5), normalize=True)
133 | # For now, every time we save, also save sample sheets
134 | utils.sample_sheet(which_G,
135 | classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
136 | num_classes=config['n_classes'],
137 | samples_per_class=10, parallel=config['parallel'],
138 | samples_root=config['samples_root'],
139 | experiment_name=experiment_name,
140 | folder_number=state_dict['itr'],
141 | z_=z_)
142 | # Also save interp sheets
143 | for fix_z, fix_y in zip([False, False, True], [False, True, False]):
144 | utils.interp_sheet(which_G,
145 | num_per_sheet=16,
146 | num_midpoints=8,
147 | num_classes=config['n_classes'],
148 | parallel=config['parallel'],
149 | samples_root=config['samples_root'],
150 | experiment_name=experiment_name,
151 | folder_number=state_dict['itr'],
152 | sheet_number=0,
153 | fix_z=fix_z, fix_y=fix_y, device='cuda')
154 |
155 |
156 |
157 | ''' This function runs the inception metrics code, checks if the results
158 | are an improvement over the previous best (either in IS or FID,
159 | user-specified), logs the results, and saves a best_ copy if it's an
160 | improvement. '''
161 | def test(G, D, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics,
162 | experiment_name, test_log):
163 | print('Gathering inception metrics...')
164 | if config['accumulate_stats']:
165 | utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G,
166 | z_, y_, config['n_classes'],
167 | config['num_standing_accumulations'])
168 | IS_mean, IS_std, FID = get_inception_metrics(sample,
169 | config['num_inception_images'],
170 | num_splits=10)
171 | print('Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID))
172 |   # If improved over previous best metric, save appropriate copy
173 | if ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS'])
174 | or (config['which_best'] == 'FID' and FID < state_dict['best_FID'])):
175 | print('%s improved over previous best, saving checkpoint...' % config['which_best'])
176 | utils.save_weights(G, D, state_dict, config['weights_root'],
177 | experiment_name, 'best%d' % state_dict['save_best_num'],
178 | G_ema if config['ema'] else None)
179 |     state_dict['save_best_num'] = (state_dict['save_best_num'] + 1) % config['num_best_copies']
180 | state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
181 | state_dict['best_FID'] = min(state_dict['best_FID'], FID)
182 | # Log results to file
183 | test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean),
184 | IS_std=float(IS_std), FID=float(FID))
--------------------------------------------------------------------------------
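Both loops in GAN_training_function above use the same gradient-accumulation pattern: divide each chunk's loss by the number of accumulations, call backward() per chunk so gradients sum in .grad, and take one optimizer step per accumulated batch. A self-contained sketch of the pattern on a toy module (the linear net and random data are placeholders, not the real G or D):

import torch
import torch.nn as nn
import torch.optim as optim

num_accumulations = 4
net = nn.Linear(10, 1)                      # toy stand-in for D or G
opt = optim.Adam(net.parameters(), lr=1e-3)

opt.zero_grad()
for _ in range(num_accumulations):
    x = torch.randn(8, 10)                  # one chunk of the split batch
    # Scaling by the accumulation count makes the summed gradients equal the
    # gradient of the mean loss over the full effective batch.
    loss = net(x).mean() / float(num_accumulations)
    loss.backward()                         # gradients accumulate in .grad
opt.step()                                  # one step per accumulated batch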
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import numpy as np
11 | import torch
12 | from torch.utils.data import Dataset
13 | from PIL import Image
14 |
15 | class ImagenetDataset(Dataset):
16 |     # Dataset of BigGAN latents and segmentation annotations for training BigDatasetGAN
17 | def __init__(self, data_root):
18 | self.label_dir = os.path.join(data_root, 'annotations/biggan512/')
19 | self.latent_dir = os.path.join(data_root, 'latents/biggan512/')
20 |
21 | self._prepare_data_list()
22 |
23 | def _prepare_data_list(self):
24 | class_list = sorted(os.listdir(self.label_dir))
25 | label_list = []
26 | latent_list = []
27 | for class_n in class_list:
28 | label_file_list = sorted(os.listdir(os.path.join(self.label_dir, class_n)))
29 | latent_file_list = sorted(os.listdir(os.path.join(self.latent_dir, class_n)))
30 |
31 | for label_file_n, latent_file_n in zip(label_file_list, latent_file_list):
32 | label_list.append(os.path.join(class_n, label_file_n))
33 | latent_list.append(os.path.join(class_n, latent_file_n))
34 |
35 | self.label_list = label_list
36 | self.latent_list = latent_list
37 | self.data_size = len(self.label_list)
38 |
39 | def __len__(self):
40 | return self.data_size
41 |
42 | def __getitem__(self, idx):
43 | latent_z = np.load(os.path.join(self.latent_dir, self.latent_list[idx]))[0]
44 | label_pil = Image.open(os.path.join(self.label_dir, self.label_list[idx])).convert('L')
45 | label_np = np.array(label_pil)
46 |         # Binarize the mask: set all nonzero pixels to 1
47 | label_np[label_np != 0] = 1
48 | class_y = int(self.label_list[idx].split('.')[0].split('_')[-2])
49 |
50 | latent_z = torch.tensor(latent_z, dtype=torch.float)
51 | label_tensor = torch.tensor(label_np, dtype=torch.long)
52 | class_y_tensor = torch.tensor(class_y, dtype=torch.long)
53 |
54 | return {
55 | 'latent': latent_z,
56 | 'label': label_tensor,
57 | 'y': class_y_tensor,
58 | }
59 |
60 |
--------------------------------------------------------------------------------
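ImagenetDataset assumes latents and masks live in parallel class folders with matching sorted filenames, and that the class index is the second-to-last underscore-separated token of the annotation name. A sketch of that parsing on a hypothetical filename (the real names are whatever the released annotation set uses):

# Hypothetical name following the split('.')[0].split('_')[-2] convention
# used in __getitem__ above.
label_path = 'n02123045/sample_281_0003.png'

class_y = int(label_path.split('.')[0].split('_')[-2])
print(class_y)  # -> 281, the ImageNet class index encoded in the name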
/datasets/imagenet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import json
10 |
11 | # synset_to_imagenet is required by get_imagenet_id_list below.
12 | with open("./imagenet2synset.json", 'r') as f:
13 |     imagenet_to_synset = json.load(f)
14 | with open("./imagenetclass992.json", 'r') as f:
15 |     synset_to_imagenet = json.load(f)
16 |
17 | pascal_to_synset = {
18 | "bird": ["cock (n01514668)", "hen (n01514859)", "water ouzel (n01601694)", "house finch (n01532829)", "brambling (n01530575)", "junco (n01534433)", "goldfinch (n01531178)", "indigo bunting (n01537544)", "chickadee (n01592084)", "robin (n01558993)", "bulbul (n01560419)", "magpie (n01582220)", "jay (n01580077)", "black swan (n01860187)", "crane (n02012849)", "spoonbill (n02006656)", "flamingo (n02007558)", "bustard (n02018795)", "limpkin (n02013706)", "bittern (n02011460)", "little blue heron (n02009229)", "American egret (n02009912)", "oystercatcher (n02037110)", "dowitcher (n02033041)", "red-backed sandpiper (n02027492)", "redshank (n02028035)", "ruddy turnstone (n02025239)", "black stork (n02002724)", "white stork (n02002556)", "American coot (n02018207)", "king penguin (n02056570)", "albatross (n02058221)", "pelican (n02051845)", "European gallinule (n02017213)", "goose (n01855672)", "drake (n01847000)", "red-breasted merganser (n01855032)", "coucal (n01824575)", "hummingbird (n01833805)", "ostrich (n01518878)", "ruffed grouse (n01797886)", "black grouse (n01795545)", "prairie chicken (n01798484)", "ptarmigan (n01796340)", "quail (n01806567)", "partridge (n01807496)", "peacock (n01806143)", "bee eater (n01828970)", "hornbill (n01829413)", "vulture (n01616318)", "great grey owl (n01622779)", "kite (n01608432)", "bald eagle (n01614925)", "jacamar (n01843065)", "toucan (n01843383)", "macaw (n01818515)", "African grey (n01817953)", "lorikeet (n01820546)", "sulphur-crested cockatoo (n01819313)"],
19 | "boat": ["gondola (n03447447)", "fireboat (n03344393)", "yawl (n04612504)", "canoe (n02951358)", "lifeboat (n03662601)", "speedboat (n04273569)"],
20 | "bottle": ["beer bottle (n02823428)", "wine bottle (n04591713)", "water bottle (n04557648)", "pop bottle (n03983396)", "pill bottle (n03937543)", "whiskey jug (n04579145)", "water jug (n04560804)"],
21 | "bus": ["minibus (n03769881)", "school bus (n04146614)", "trolleybus (n04487081)"],
22 | "car": ["ambulance (n02701002)", "limousine (n03670208)", "jeep (n03594945)", "Model T (n03777568)", "cab (n02930766)", "minivan (n03770679)", "convertible (n03100240)", "racer (n04037443)", "beach wagon (n02814533)", "sports car (n04285008)"],
23 | "cat": ["Egyptian cat (n02124075)", "Persian cat (n02123394)", "tabby (n02123045)", "Siamese cat (n02123597)", "tiger cat (n02123159)", "cougar (n02125311)", "lynx (n02127052)"],
24 | "chair": ["barber chair (n02791124)", "rocking chair (n04099969)", "folding chair (n03376595)", "throne (n04429376)"],
25 | "diningtable": ["dining table (n03201208)"],
26 | "dog": ["dalmatian (n02110341)", "Mexican hairless (n02113978)", "pug (n02110958)", "Newfoundland (n02111277)", "Leonberg (n02111129)", "basenji (n02110806)", "Great Pyrenees (n02111500)", "Eskimo dog (n02109961)", "bull mastiff (n02108422)", "Saint Bernard (n02109525)", "Great Dane (n02109047)", "boxer (n02108089)", "Rottweiler (n02106550)", "Old English sheepdog (n02105641)", "Shetland sheepdog (n02105855)", "kelpie (n02105412)", "Border collie (n02106166)", "Bouvier des Flandres (n02106382)", "German shepherd (n02106662)", "komondor (n02105505)", "briard (n02105251)", "collie (n02106030)", "groenendael (n02105056)", "malinois (n02105162)", "French bulldog (n02108915)", "kuvasz (n02104029)", "schipperke (n02104365)", "Doberman (n02107142)", "affenpinscher (n02110627)", "miniature pinscher (n02107312)", "Tibetan mastiff (n02108551)", "Siberian husky (n02110185)", "malamute (n02110063)", "Bernese mountain dog (n02107683)", "Appenzeller (n02107908)", "EntleBucher (n02108000)", "Greater Swiss Mountain dog (n02107574)", "toy poodle (n02113624)", "miniature poodle (n02113712)", "standard poodle (n02113799)", "Pembroke (n02113023)", "Cardigan (n02113186)", "Rhodesian ridgeback (n02087394)", "Scottish deerhound (n02092002)", "bloodhound (n02088466)", "otterhound (n02091635)", "Afghan hound (n02088094)", "redbone (n02090379)", "bluetick (n02088632)", "basset (n02088238)", "Ibizan hound (n02091244)", "Saluki (n02091831)", "Norwegian elkhound (n02091467)", "beagle (n02088364)", "Weimaraner (n02092339)", "black-and-tan coonhound (n02089078)", "borzoi (n02090622)", "Irish wolfhound (n02090721)", "English foxhound (n02089973)", "Walker hound (n02089867)", "whippet (n02091134)", "Italian greyhound (n02091032)", "Dandie Dinmont (n02096437)", "Norwich terrier (n02094258)", "Border terrier (n02093754)", "West Highland white terrier (n02098286)", "Yorkshire terrier (n02094433)", "Airedale (n02096051)", "Irish terrier (n02093991)", "Bedlington terrier (n02093647)", "Norfolk terrier (n02094114)", "Lhasa (n02098413)", "silky terrier (n02097658)", "Kerry blue terrier (n02093859)", "Scotch terrier (n02097298)", "Tibetan terrier (n02097474)", "cairn (n02096177)", "soft-coated wheaten terrier (n02098105)", "Boston bull (n02096585)", "Australian terrier (n02096294)", "Staffordshire bullterrier (n02093256)", "American Staffordshire terrier (n02093428)", "Lakeland terrier (n02095570)", "Sealyham terrier (n02095889)", "giant schnauzer (n02097130)", "miniature schnauzer (n02097047)", "standard schnauzer (n02097209)", "wire-haired fox terrier (n02095314)", "curly-coated retriever (n02099429)", "flat-coated retriever (n02099267)", "golden retriever (n02099601)", "Chesapeake Bay retriever (n02099849)", "Labrador retriever (n02099712)", "Sussex spaniel (n02102480)", "Brittany spaniel (n02101388)", "clumber (n02101556)", "cocker spaniel (n02102318)", "English springer (n02102040)", "Welsh springer spaniel (n02102177)", "Irish water spaniel (n02102973)", "vizsla (n02100583)", "German short-haired pointer (n02100236)", "Irish setter (n02100877)", "Gordon setter (n02101006)", "English setter (n02100735)", "Maltese dog (n02085936)", "Chihuahua (n02085620)", "Pekinese (n02086079)", "Shih-Tzu (n02086240)", "toy terrier (n02087046)", "Japanese spaniel (n02085782)", "papillon (n02086910)", "Blenheim spaniel (n02086646)", "Brabancon griffon (n02112706)", "Samoyed (n02111889)", "Pomeranian (n02112018)", "keeshond (n02112350)", "chow (n02112137)"],
27 | "horse": ["sorrel (n02389026)"],
28 | #"motorbike": ["moped (n03785016)"],
29 | "person": ["ballplayer (n09835506)", "groom (n10148035)", "scuba diver (n10565667)"],
30 | #"pottedplant": ["daisy (n11939491)", "yellow lady slipper (n12057211)"],
31 | "sheep": ["ram (n02412080)"],
32 | #"sofa": ["studio couch (n04344873)"],
33 | "train": ["bullet train (n02917067)"],
34 | #"tvmonitor": ["monitor (n03782006)"],
35 | "aeroplane": ["airliner (n02690373)"],
36 | "bicycle": ["bicycle-built-for-two (n02835271)", "mountain bike (n03792782)"],
37 | }
38 |
39 | def parse_synset_str(x):
40 |     # Extract the synset ID from the parenthesized suffix,
41 |     # e.g. "cock (n01514668)" -> "n01514668".
42 |     synset = ''
43 |     i = 0
44 |     while True:
45 |         if x[i] == '(':
46 |             j = i + 1
47 |             while x[j] != ')':
48 |                 synset += x[j]
49 |                 j += 1
50 |             break
51 |         i += 1
52 |
53 |     return synset
54 |
55 |
56 |
57 |
58 | pascal_to_id = {
59 | 'background': 0,
60 | 'aeroplane': 1,
61 | 'bicycle': 2,
62 | 'bird': 3,
63 | 'boat': 4,
64 | 'bottle': 5,
65 | 'bus': 6,
66 | 'car': 7,
67 | 'cat': 8,
68 | 'chair': 9,
69 | 'diningtable': 10,
70 | 'dog': 11,
71 | 'horse': 12,
72 | #'motorbike': 13,
73 | 'person': 13,
74 | #'pottedplant': 15,
75 | 'sheep': 14,
76 | #'sofa': 17,
77 | 'train': 15,
78 | #'tvmonitor': 19,
79 | }
80 |
81 | pascal_to_coco = {
82 | 'background': 0,
83 | 'aeroplane': 5,
84 | 'bicycle': 2,
85 | 'bird': 16,
86 | 'boat': 9,
87 | 'bottle': 44,
88 | 'bus': 6,
89 | 'car': 3,
90 | 'cat': 17,
91 | 'chair': 62,
92 | #'cow': 21,
93 | 'diningtable': 67,
94 | 'dog': 18,
95 | 'horse': 19,
96 | #'motorbike': 4,
97 | 'person': 1,
98 | #'pottedplant': 64,
99 | 'sheep': 20,
100 | #'sofa': 63,
101 | 'train': 7,
102 | #'tvmonitor': 72,
103 | }
104 |
105 | random_synset_100 = ['n02966193', 'n02825657', 'n02708093', 'n07716906',
106 | 'n02325366', 'n02025239', 'n02097209', 'n02106662', 'n02277742', 'n02117135',
107 | 'n02087394', 'n01601694', 'n02447366', 'n01682714', 'n03884397', 'n01537544', 'n03063689',
108 | 'n04606251', 'n02493509', 'n02090622', 'n02071294', 'n13037406', 'n04146614', 'n02342885',
109 | 'n02110958', 'n03223299', 'n02963159', 'n02093859', 'n01494475', 'n01955084', 'n02490219',
110 | 'n02840245', 'n02108000', 'n01944390', 'n01860187', 'n02113799', 'n01910747', 'n02086910',
111 | 'n01978455', 'n02107312', 'n02965783', 'n02013706', 'n04033901', 'n01692333', 'n03207941',
112 | 'n02109961', 'n02687172', 'n02002724', 'n01775062', 'n02104365', 'n01749939', 'n01945685',
113 | 'n01704323', 'n04136333', 'n02105855', 'n02443484', 'n02056570', 'n02403003', 'n02134418',
114 | 'n03417042', 'n02096051', 'n02978881', 'n01531178', 'n03065424', 'n01806567', 'n02100877',
115 | 'n03126707', 'n01843065', 'n02814860', 'n02088238', 'n02999410', 'n01484850', 'n02259212',
116 | 'n02097474', 'n02877765', 'n02099712', 'n02123159', 'n01630670', 'n04252077', 'n03218198',
117 | 'n02489166', 'n02727426', 'n02097047', 'n02492035', 'n01728572', 'n03337140', 'n02268853',
118 | 'n01872401', 'n02094433', 'n02206856', 'n01753488', 'n02910353', 'n02114855', 'n03179701',
119 | 'n01498041', 'n04009552', 'n02177972', 'n03016953', 'n02894605', 'n01843383']
120 |
121 |
122 | def get_imagenet_id_list(class_name):
123 | id_list = []
124 | if class_name == 'imagenet-dog':
125 | for x in pascal_to_synset['dog']:
126 | synset_id = parse_synset_str(x)
127 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id'])
128 |
129 | elif class_name == 'imagenet-bird':
130 | for x in pascal_to_synset['bird']:
131 | synset_id = parse_synset_str(x)
132 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id'])
133 |
134 | elif class_name == 'imagenet-pascal':
135 | for key in pascal_to_synset:
136 | for x in pascal_to_synset[key]:
137 | synset_id = parse_synset_str(x)
138 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id'])
139 |
140 | elif class_name == 'imagenet-100':
141 | for synset_id in random_synset_100:
142 | id_list.append(synset_to_imagenet[synset_id]['imagenet_id'])
143 |
144 | else:
145 | for key in synset_to_imagenet:
146 | if key != 'background':
147 | id_list.append(synset_to_imagenet[key]['imagenet_id'])
148 |
149 | return id_list
--------------------------------------------------------------------------------
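parse_synset_str simply pulls the synset ID out of the parenthesized suffix used throughout pascal_to_synset, and get_imagenet_id_list then maps those IDs to contiguous class indices via the synset_to_imagenet table loaded from imagenetclass992.json. A quick usage check, run in this module's namespace:

assert parse_synset_str("cock (n01514668)") == "n01514668"
print(parse_synset_str(pascal_to_synset["train"][0]))  # -> n02917067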
/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import sys
13 | sys.path.append('./biggan_pytorch')
14 | from biggan_pytorch import BigGAN
15 |
16 | class SegBlock(nn.Module):
17 | def __init__(self, in_channels, out_channels, con_channels,
18 | which_conv=nn.Conv2d, which_linear=None, activation=None,
19 | upsample=None):
20 | super(SegBlock, self).__init__()
21 |
22 | self.in_channels, self.out_channels = in_channels, out_channels
23 | self.which_conv, self.which_linear = which_conv, which_linear
24 | self.activation = activation
25 | self.upsample = upsample
26 | # Conv layers
27 | self.conv1 = self.which_conv(self.in_channels, self.out_channels)
28 | self.conv2 = self.which_conv(self.out_channels, self.out_channels)
29 | self.learnable_sc = in_channels != out_channels or upsample
30 | if self.learnable_sc:
31 | self.conv_sc = self.which_conv(in_channels, out_channels,
32 | kernel_size=1, padding=0)
33 | # Batchnorm layers
34 | self.bn1 = BigGAN.layers.ccbn(in_channels, con_channels, self.which_linear, eps=1e-4, norm_style='bn')
35 | self.bn2 = BigGAN.layers.ccbn(out_channels, con_channels, self.which_linear, eps=1e-4, norm_style='bn')
36 | # upsample layers
37 | self.upsample = upsample
38 |
39 | def forward(self, x, y):
40 | h = self.activation(self.bn1(x, y))
41 | if self.upsample:
42 | h = self.upsample(h)
43 | x = self.upsample(x)
44 | h = self.conv1(h)
45 | h = self.activation(self.bn2(h, y))
46 | h = self.conv2(h)
47 | if self.learnable_sc:
48 | x = self.conv_sc(x)
49 | return h + x
50 |
51 | def get_config(resolution):
52 | attn_dict = {128: '64', 256: '128', 512: '64'}
53 | dim_z_dict = {128: 120, 256: 140, 512: 128}
54 | config = {'G_param': 'SN', 'D_param': 'SN',
55 | 'G_ch': 96, 'D_ch': 96,
56 | 'D_wide': True, 'G_shared': True,
57 | 'shared_dim': 128, 'dim_z': dim_z_dict[resolution],
58 | 'hier': True, 'cross_replica': False,
59 | 'mybn': False, 'G_activation': nn.ReLU(inplace=True),
60 | 'G_attn': attn_dict[resolution],
61 | 'norm_style': 'bn',
62 | 'G_init': 'ortho', 'skip_init': True, 'no_optim': True,
63 | 'G_fp16': False, 'G_mixed_precision': False,
64 | 'accumulate_stats': False, 'num_standing_accumulations': 16,
65 | 'G_eval_mode': True,
66 | 'BN_eps': 1e-04, 'SN_eps': 1e-04,
67 | 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution,
68 | 'n_classes': 1000}
69 | return config
70 |
71 | class BigdatasetGANModel(nn.Module):
72 | def __init__(self, resolution, out_dim, biggan_ckpt=None):
73 | super(BigdatasetGANModel, self).__init__()
74 | self.biggan_ckpt = biggan_ckpt
75 | self.resolution = resolution
76 | # load biggan model
77 | self._prepare_biggan_model()
78 |
79 | self.low_feature_size = 32
80 | self.mid_feature_size = 128
81 | self.high_feature_size = 512
82 |
83 | low_feature_channel = 128
84 | mid_feature_channel = 64
85 | high_feature_channel = 32
86 |
87 | self.low_feature_conv = nn.Sequential(
88 | nn.Conv2d(3072, low_feature_channel, kernel_size=1, bias=False),
89 | #nn.ReLU(),
90 | )
91 | self.mid_feature_conv = nn.Sequential(
92 | nn.Conv2d(960, mid_feature_channel, kernel_size=1, bias=False),
93 | #nn.ReLU(),
94 | )
95 | self.mid_feature_mix_conv = SegBlock(
96 | in_channels=low_feature_channel+mid_feature_channel,
97 | out_channels=low_feature_channel+mid_feature_channel,
98 | con_channels=self.biggan_model.shared_dim,
99 | which_conv=self.biggan_model.which_conv,
100 | which_linear=self.biggan_model.which_linear,
101 | activation=self.biggan_model.activation,
102 | upsample=False,
103 | )
104 |
105 | self.high_feature_conv = nn.Sequential(
106 | nn.Conv2d(192, high_feature_channel, kernel_size=1, bias=False),
107 | #nn.ReLU(),
108 | )
109 |
110 | self.high_feature_mix_conv = SegBlock(
111 | in_channels=low_feature_channel+mid_feature_channel+high_feature_channel,
112 | out_channels=low_feature_channel+mid_feature_channel+high_feature_channel,
113 | con_channels=self.biggan_model.shared_dim,
114 | which_conv=self.biggan_model.which_conv,
115 | which_linear=self.biggan_model.which_linear,
116 | activation=self.biggan_model.activation,
117 | upsample=False,
118 | )
119 |
120 |         # self.out_layer = nn.Conv2d(low_feature_channel+mid_feature_channel+high_feature_channel,
121 |         #                            out_dim, kernel_size=3, padding=1)  # superseded by the head below
122 | self.out_layer = nn.Sequential(
123 | BigGAN.layers.bn(low_feature_channel+mid_feature_channel+high_feature_channel),
124 | self.biggan_model.activation,
125 | self.biggan_model.which_conv(low_feature_channel+mid_feature_channel+high_feature_channel, out_dim)
126 | )
127 |
128 | def _prepare_biggan_model(self):
129 | biggan_config = get_config(self.resolution)
130 | self.biggan_model = BigGAN.Generator(**biggan_config)
131 |         if self.biggan_ckpt is not None:
132 | state_dict = torch.load(self.biggan_ckpt)
133 | self.biggan_model.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries
134 | self.biggan_model.eval()
135 |
136 | def _prepare_features(self, features, upsample='bilinear'):
137 | # for low feature
138 | low_features = [
139 | F.interpolate(features[0], size=self.low_feature_size, mode=upsample, align_corners=False),
140 | F.interpolate(features[1], size=self.low_feature_size, mode=upsample, align_corners=False),
141 | features[2],
142 | ]
143 | low_features = torch.cat(low_features, dim=1)
144 | # for mid feature
145 | mid_features = [
146 | F.interpolate(features[3], size=self.mid_feature_size, mode=upsample, align_corners=False),
147 | F.interpolate(features[4], size=self.mid_feature_size, mode=upsample, align_corners=False),
148 | features[5],
149 | ]
150 | mid_features = torch.cat(mid_features, dim=1)
151 | # for high feature
152 | high_features = [
153 | F.interpolate(features[6], size=self.high_feature_size, mode=upsample, align_corners=False),
154 | #F.interpolate(features[7], size=self.high_feature_size, mode=upsample, align_corners=False),
155 | features[7],
156 | ]
157 | high_features = torch.cat(high_features, dim=1)
158 |
159 | features_dict = {
160 | 'low': low_features,
161 | 'mid': mid_features,
162 | 'high': high_features,
163 | }
164 |
165 | return features_dict
166 |
167 | @torch.no_grad()
168 | def _get_biggan_features(self, z, y):
169 | features = []
170 | y = self.biggan_model.shared(y)
171 | # forward thru biggan
172 | # If hierarchical, concatenate zs and ys
173 | if self.biggan_model.hier:
174 | zs = torch.split(z, self.biggan_model.z_chunk_size, 1)
175 | z = zs[0]
176 | ys = [torch.cat([y, item], 1) for item in zs[1:]]
177 | else:
178 | ys = [y] * len(self.biggan_model.blocks)
179 |
180 | # First linear layer
181 | h = self.biggan_model.linear(z)
182 | # Reshape
183 | h = h.view(h.size(0), -1, self.biggan_model.bottom_width, self.biggan_model.bottom_width)
184 |
185 | # Loop over blocks
186 | for index, blocklist in enumerate(self.biggan_model.blocks):
187 | # Second inner loop in case block has multiple layers
188 | for block in blocklist:
189 | h = block(h, ys[index])
190 | # save feature
191 | features.append(h)
192 | #print(index, h.shape)
193 |
194 | features_dict = self._prepare_features(features)
195 |
196 | return features_dict, y, h
197 |
198 | def forward(self, z, y):
199 | features_dict, y, _ = self._get_biggan_features(z, y)
200 |
201 | # for low features
202 | low_feat = self.low_feature_conv(features_dict['low'])
203 | low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
204 | # for mid features
205 | mid_feat = self.mid_feature_conv(features_dict['mid'])
206 | mid_feat = torch.cat([low_feat, mid_feat], dim=1)
207 | mid_feat = self.mid_feature_mix_conv(mid_feat, y)
208 | mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
209 | # for high features
210 | high_feat = self.high_feature_conv(features_dict['high'])
211 | high_feat = torch.cat([mid_feat, high_feat], dim=1)
212 | high_feat = self.high_feature_mix_conv(high_feat, y)
213 | out = self.out_layer(high_feat)
214 |
215 | return out
216 |
217 | @torch.no_grad()
218 | def sample(self, z, y):
219 | features_dict, y, h = self._get_biggan_features(z,y)
220 |
221 | image = torch.tanh(self.biggan_model.output_layer(h.detach()))
222 |
223 | # for low features
224 | low_feat = self.low_feature_conv(features_dict['low'])
225 | low_feat = F.interpolate(low_feat, size=self.mid_feature_size, mode='bilinear', align_corners=False)
226 | # for mid features
227 | mid_feat = self.mid_feature_conv(features_dict['mid'])
228 | mid_feat = torch.cat([low_feat, mid_feat], dim=1)
229 | mid_feat = self.mid_feature_mix_conv(mid_feat, y)
230 | mid_feat = F.interpolate(mid_feat, size=self.high_feature_size, mode='bilinear', align_corners=False)
231 | # for high features
232 | high_feat = self.high_feature_conv(features_dict['high'])
233 | high_feat = torch.cat([mid_feat, high_feat], dim=1)
234 | high_feat = self.high_feature_mix_conv(high_feat, y)
235 | out = self.out_layer(high_feat)
236 |
237 | return image, out
238 |
239 |
240 | if __name__ == '__main__':
241 |
242 | biggan_ckpt = './pretrain/biggan-512.pth'
243 | model = BigdatasetGANModel(512, 1, biggan_ckpt).cuda()
--------------------------------------------------------------------------------
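Mirroring sample_dataset.py, a minimal sketch of driving BigdatasetGANModel at inference: sample a latent with the generator's own dim_z and pass a class index alongside it. The class index and the use of unloaded (random) BigGAN weights here are assumptions for illustration only:

import torch
from models import BigdatasetGANModel

model = BigdatasetGANModel(resolution=512, out_dim=1, biggan_ckpt=None)
model = model.eval()   # keep the heads and the frozen BigGAN in eval mode

z = torch.randn(1, model.biggan_model.dim_z)   # one latent vector
y = torch.tensor([207], dtype=torch.long)      # hypothetical ImageNet class

image, logits = model.sample(z, y)             # (1, 3, 512, 512), (1, 1, 512, 512)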
/prepare_biggan_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | from models import BigdatasetGANModel
11 | import numpy as np
12 | import argparse
13 | import torch
14 | import torchvision
15 |
16 | def parse_args():
17 |     usage = 'Parser for the BigGAN image generation script.'
18 | parser = argparse.ArgumentParser(description=usage)
19 |
20 | parser.add_argument(
21 | '--biggan_ckpt', type=str, default='./pretrain/biggan-512.pth', help='path to the pretrained biggan ckpt')
22 | parser.add_argument(
23 | '--dataset_dir', type=str, default='./data/', help='path to the dataset dir')
24 |
25 | args = parser.parse_args()
26 |
27 | return args
28 |
29 | if __name__ == '__main__':
30 | args = parse_args()
31 |
32 | device = 'cuda'
33 |
34 | latents_dir = os.path.join(args.dataset_dir, 'latents/biggan512/')
35 | images_dir = os.path.join(args.dataset_dir, 'images/biggan512/')
36 |
37 | # loading model
38 | generator = BigdatasetGANModel(resolution=512, out_dim=1, biggan_ckpt=args.biggan_ckpt).to(device)
39 |
40 | generator.eval()
41 |
42 | class_list = os.listdir(latents_dir)
43 |
44 | for class_n in class_list:
45 | latent_class_dir = os.path.join(latents_dir, class_n)
46 | image_class_dir = os.path.join(images_dir, class_n)
47 |
48 | os.makedirs(image_class_dir, exist_ok=True)
49 |
50 | latent_list = os.listdir(latent_class_dir)
51 |
52 | for latent_n in latent_list:
53 | image_name = latent_n.split('.')[0]
54 | latent_np = np.load(os.path.join(latent_class_dir, latent_n))[0]
55 | class_y = int(image_name.split('_')[-2])
56 |
57 | latent_tensor = torch.tensor(latent_np, dtype=torch.float).unsqueeze(0).to(device)
58 | class_y_tensor = torch.tensor([class_y], dtype=torch.long).to(device)
59 |
60 | image_tensor, _ = generator.sample(latent_tensor, class_y_tensor)
61 |
62 |             print("Saving BigGAN image decoded from the latent to: ", os.path.join(image_class_dir, image_name+'.png'))
63 | # save image
64 | torchvision.utils.save_image(image_tensor, os.path.join(image_class_dir, image_name+'.png'), normalize=True)
65 |
--------------------------------------------------------------------------------
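save_image(..., normalize=True) is what lets the tanh output of the generator be written directly: it rescales the tensor from its own min/max range into [0, 1] before encoding. A tiny check of that behavior on a stand-in tensor (the output path is arbitrary):

import torch
import torchvision

img = torch.tanh(torch.randn(1, 3, 64, 64))   # stand-in output in [-1, 1]
torchvision.utils.save_image(img, 'sample.png', normalize=True)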
/prepare_imagenet_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import shutil
11 | import argparse
12 | import numpy as np
13 |
14 | def parse_args():
15 |     usage = 'Parser for the ImageNet image preparation script.'
16 | parser = argparse.ArgumentParser(description=usage)
17 |
18 | parser.add_argument(
19 | '--imagenet_dir', type=str, required=True, help='path to the imagenet folder')
20 | parser.add_argument(
21 | '--dataset_dir', type=str, default='./data/', help='path to the dataset dir')
22 |
23 | args = parser.parse_args()
24 |
25 | return args
26 |
27 | if __name__ == '__main__':
28 | args = parse_args()
29 |
30 | real_list_path = os.path.join(args.dataset_dir, 'annotations/real-random-list.txt')
31 | real_list = np.loadtxt(real_list_path, dtype=str)
32 |
33 | for file_list in real_list:
34 | class_n = file_list.split('/')[0]
35 | real_anno_n = file_list.split('/')[1]
36 |
37 | img_id = real_anno_n.split('_')[-1]
38 | imagenet_file_name = class_n + '_' + img_id + '.JPEG'
39 | # copy imagenet image to dataset
40 | imagenet_file_path = os.path.join(args.imagenet_dir, class_n, imagenet_file_name)
41 | real_image_dir = os.path.join(args.dataset_dir, 'images/real-random/', class_n)
42 | os.makedirs(real_image_dir, exist_ok=True)
43 | save_path = os.path.join(real_image_dir, real_anno_n + '.jpg')
44 |         print('Copying image from {0} to {1}'.format(imagenet_file_path, save_path))
45 | shutil.copy(imagenet_file_path, save_path)
46 |
47 |
--------------------------------------------------------------------------------
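The copy loop rebuilds each original ImageNet filename from the annotation entry: the synset folder name plus the trailing image ID. A sketch with a hypothetical list entry (real entries follow whatever '<synset>/<annotation>' layout real-random-list.txt uses):

# Hypothetical entry; the split logic matches the loop above.
file_list = 'n02123045/n02123045_mask_00001234'

class_n = file_list.split('/')[0]                # 'n02123045'
img_id = file_list.split('/')[1].split('_')[-1]  # '00001234'
imagenet_file_name = class_n + '_' + img_id + '.JPEG'
print(imagenet_file_name)                        # -> n02123045_00001234.JPEG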
/sample_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import torch
11 | import argparse
12 | import torchvision
13 |
14 | from models import BigdatasetGANModel
15 | from utils import VOCColorize
16 |
17 | def parse_args():
18 |     usage = 'Parser for the dataset sampling script.'
19 | parser = argparse.ArgumentParser(description=usage)
20 | parser.add_argument(
21 | '--resolution', '-r', type=int, default=512,
22 |         help='Resolution of the generated images; we use biggan-512 by default')
23 | parser.add_argument(
24 | '--ckpt', type=str, required=True,
25 | help='Path to the pretrained BigDatasetGAN weights')
26 | parser.add_argument(
27 | '--save_dir', type=str, default='./generated_datasets/',
28 | help='Path to save dataset')
29 | parser.add_argument(
30 | '--z_var', type=float, default=0.9,
31 |         help='Truncation value (std of z)')
32 | parser.add_argument(
33 | '--class_idx', type=int, default=[225, 200], nargs='+',
34 | help='Imagenet class index')
35 | parser.add_argument(
36 | '--samples_per_class', type=int, default=10,
37 |         help='Number of data samples per class')
38 |
39 | args = parser.parse_args()
40 |
41 | return args
42 |
43 | def main(args):
44 | device = 'cuda'
45 |
46 | # build seg model
47 | model = BigdatasetGANModel(resolution=args.resolution, out_dim=1, biggan_ckpt=None)
48 |
49 | # load pretrain model
50 | state_dict = torch.load(args.ckpt)
51 | model.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries
52 |
53 | model = model.to(device)
54 | model = model.eval()
55 |
56 | voc_col = VOCColorize(n=1000)
57 |
58 | overall_viz = []
59 |
60 | os.makedirs(args.save_dir, exist_ok=True)
61 |
62 | for class_y in args.class_idx:
63 | print("Start sampling dataset with class idx: {0}, total samples: {1}".format(class_y, args.samples_per_class))
64 |
65 | class_y_tensor = torch.tensor([class_y], dtype=torch.long).to(device)
66 |
67 | sample_imgs, sample_labels = [], []
68 | for i in range(args.samples_per_class):
69 | z = torch.empty(1, model.biggan_model.dim_z).normal_(mean=0, std=args.z_var).to(device)
70 | sample_img, sample_pred = model.sample(z, class_y_tensor)
71 | sample_img, sample_pred = sample_img.cpu(), sample_pred.cpu()
72 |
73 | label_pred_prob = torch.sigmoid(sample_pred)
74 | label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.long)
75 | label_pred_mask[label_pred_prob>0.5] = 1
76 |
77 | label_pred_rgb = voc_col(label_pred_mask[0][0].cpu().numpy()*class_y)
78 | label_pred_rgb = torch.from_numpy(label_pred_rgb).float()
79 |
80 | sample_imgs.append(sample_img)
81 | sample_labels.append(label_pred_rgb)
82 |
83 | sample_imgs = torch.cat(sample_imgs, dim=0)
84 | sample_labels = torch.stack(sample_labels, dim=0)
85 |
86 | sample_imgs_grid = torchvision.utils.make_grid(sample_imgs, nrow=args.samples_per_class, normalize=True, scale_each=True)
87 | sample_labels_grid = torchvision.utils.make_grid(sample_labels, nrow=args.samples_per_class, normalize=True, scale_each=True)
88 |         class_viz_tensor = torchvision.utils.make_grid(torch.stack([sample_imgs_grid, sample_labels_grid], dim=0), nrow=1)
89 | overall_viz.append(class_viz_tensor)
90 |
91 | overall_viz = torch.stack(overall_viz, dim=0)
92 | torchvision.utils.save_image(overall_viz, os.path.join(args.save_dir, 'sample_overall.jpg'), nrow=1)
93 |
94 | if __name__ == '__main__':
95 |
96 | args = parse_args()
97 |
98 | main(args)
99 |
--------------------------------------------------------------------------------
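--z_var trades diversity for fidelity by shrinking the standard deviation of the Gaussian latents, a lightweight stand-in for full truncation, and the mask comes from thresholding the sigmoid of the single-channel logits at 0.5. A sketch of both steps with hypothetical shapes:

import torch

dim_z, z_var = 128, 0.9                       # dim_z assumed; 0.9 is the default
z = torch.empty(1, dim_z).normal_(mean=0, std=z_var)

logits = torch.randn(1, 1, 512, 512)          # stand-in for model.sample() output
prob = torch.sigmoid(logits)
mask = torch.zeros_like(prob, dtype=torch.long)
mask[prob > 0.5] = 1                          # binary foreground/background mask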
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import torch
11 | import torch.nn as nn
12 | import argparse
13 | import torchvision
14 |
15 | import numpy as np
16 | from torch.utils.data import DataLoader
17 | import torch.optim as optim
18 |
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | from models import BigdatasetGANModel
22 | from datasets.datasets import ImagenetDataset
23 | from utils import VOCColorize
24 |
25 | def parse_args():
26 |     usage = 'Parser for the BigDatasetGAN training script.'
27 | parser = argparse.ArgumentParser(description=usage)
28 | parser.add_argument(
29 | '--resolution', '-r', type=int, default=512,
30 |         help='Resolution of the generated images; we use biggan-512 by default')
31 |
32 | parser.add_argument(
33 | '--gan_ckpt', type=str, default='./pretrain/biggan-512.pth',
34 | help='Path to the pretrained gan ckpt')
35 | parser.add_argument(
36 | '--dataset_dir', type=str, default='./data/',
37 | help='Path to the dataset folder')
38 |
39 | parser.add_argument(
40 | '--save_dir', type=str, default='./logs/checkpoint_biggan512_label_conv/',
41 | help='Path to save logs')
42 | parser.add_argument(
43 | '--batch_size', type=int, default=4,
44 | help='training batch size')
45 | parser.add_argument(
46 | '--max_iter', type=int, default=5000,
47 | help='maximum iteration of training')
48 | parser.add_argument(
49 | '--lr', type=float, default=0.001,
50 | help='learning rate')
51 |
52 | args = parser.parse_args()
53 |
54 | return args
55 |
56 | def sample_data(loader):
57 | while True:
58 | for batch in loader:
59 | yield batch
60 |
61 | def main(args):
62 | device = 'cuda'
63 | # build checkpoint dir
64 | from datetime import datetime
65 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
66 | ckpt_dir = os.path.join(args.save_dir, 'run-'+current_time)
67 | os.makedirs(ckpt_dir, exist_ok=True)
68 | writer = SummaryWriter(log_dir=os.path.join(ckpt_dir, 'logs'))
69 | os.makedirs(os.path.join(ckpt_dir, 'training'), exist_ok=True)
70 | # os.makedirs(os.path.join(ckpt_dir, 'samples'), exist_ok=True)
71 |
72 | # build dataset
73 | dataset = ImagenetDataset(args.dataset_dir)
74 | #dataset = BigganDataset(args.dataset_dir, single_class=args.single_class)
75 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
76 |
77 |     print("Loaded dataset, size: ", len(dataset))
78 |
79 | # build seg model
80 | g_seg = BigdatasetGANModel(resolution=args.resolution, out_dim=1, biggan_ckpt=args.gan_ckpt)
81 |
82 | g_seg = g_seg.to(device)
83 |
84 | g_optim = optim.Adam(
85 | g_seg.parameters(),
86 | lr=args.lr
87 | )
88 |
89 | loss_fn = nn.BCEWithLogitsLoss()
90 |
91 | dataloader = sample_data(dataloader)
92 |
93 | voc_col = VOCColorize(n=1000)
94 |
95 |     print("Start training for a maximum of {0} iterations.".format(args.max_iter))
96 |
97 | for i, batch_data in enumerate(dataloader):
98 |
99 | if i > args.max_iter:
100 | break
101 |
102 | z = batch_data['latent'].to(device)
103 | label_gt = batch_data['label'].to(device)
104 | y = batch_data['y'].to(device)
105 |
106 |         # Set g_seg to train mode, keeping the frozen BigGAN backbone in eval mode
107 | g_seg.train()
108 | g_seg.biggan_model.eval()
109 |
110 | label_pred = g_seg(z, y)
111 |
112 | loss = loss_fn(label_pred, label_gt.float().unsqueeze(1))
113 |
114 | g_optim.zero_grad()
115 | loss.backward()
116 | g_optim.step()
117 |
118 | writer.add_scalar('train/loss', loss.item(), global_step=i)
119 |
120 | if i % 10 == 0:
121 |             print("Training step: {0:05d}/{1:05d}, loss: {2:0.4f}".format(i, args.max_iter, loss.item()))
122 |
123 | if i % 100 == 0:
124 | # save train pred
125 | g_seg.eval()
126 | sample_imgs, sample_pred = g_seg.sample(z, y)
127 | sample_imgs, sample_pred = sample_imgs.cpu(), sample_pred.cpu()
128 |
129 | label_pred_prob = torch.sigmoid(label_pred)
130 | label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.long)
131 | label_pred_mask[label_pred_prob>0.5] = 1
132 |
133 | label_pred_rgb = voc_col(label_pred_mask[0][0].cpu().numpy()*y[0].cpu().numpy())
134 | label_pred_rgb = torch.from_numpy(label_pred_rgb).float()
135 |
136 | label_gt_rgb = voc_col(label_gt[0].cpu().numpy()*y[0].cpu().numpy())
137 | label_gt_rgb = torch.from_numpy(label_gt_rgb).float()
138 |
139 | viz_tensor = torch.stack([sample_imgs[0], label_gt_rgb, label_pred_rgb], dim=0)
140 |
141 | torchvision.utils.save_image(viz_tensor, os.path.join(ckpt_dir,
142 | 'training/viz_sample_{0:05d}.jpg'.format(i)), normalize=True, scale_each=True)
143 |
144 | if i % 1000 == 0:
145 | # save checkpoint
146 | print("Saving latest checkpoint.")
147 | torch.save(g_seg.state_dict(), os.path.join(ckpt_dir, 'checkpoint_latest.pth'))
148 |
149 | if __name__ == '__main__':
150 |
151 | args = parse_args()
152 |
153 | main(args)
154 |
--------------------------------------------------------------------------------
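sample_data is the usual trick for iteration-based rather than epoch-based training: wrap the DataLoader in a generator that restarts it whenever it is exhausted, so the main loop can simply count iterations. A self-contained sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset

def sample_data(loader):
    while True:                  # restart the loader when it runs out
        for batch in loader:
            yield batch

dataset = TensorDataset(torch.arange(10).float())
stream = sample_data(DataLoader(dataset, batch_size=4, shuffle=True))

for i, (batch,) in enumerate(stream):
    if i >= 5:                   # more steps than one epoch (3 batches) holds
        break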
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class VOCColorize(object):
4 | def __init__(self, n):
5 | self.cmap = color_map(n)
6 |
7 | def __call__(self, gray_image):
8 | size = gray_image.shape
9 | color_image = np.zeros((3, size[0], size[1]), dtype=np.uint8)
10 |
11 | for label in range(0, len(self.cmap)):
12 | mask = (label == gray_image)
13 | color_image[0][mask] = self.cmap[label][0]
14 | color_image[1][mask] = self.cmap[label][1]
15 | color_image[2][mask] = self.cmap[label][2]
16 |
17 | # handle void
18 | # mask = (255 == gray_image)
19 | # color_image[0][mask] = color_image[1][mask] = color_image[2][mask] = 255
20 |
21 | return color_image
22 |
23 | def color_map(N, normalized=False):
24 | def bitget(byteval, idx):
25 | return ((byteval & (1 << idx)) != 0)
26 |
27 | dtype = 'float32' if normalized else 'uint8'
28 | cmap = np.zeros((N, 3), dtype=dtype)
29 | for i in range(N):
30 | r = g = b = 0
31 | c = i
32 | for j in range(8):
33 | r = r | (bitget(c, 0) << 7-j)
34 | g = g | (bitget(c, 1) << 7-j)
35 | b = b | (bitget(c, 2) << 7-j)
36 | c = c >> 3
37 |
38 | cmap[i] = np.array([r, g, b])
39 |
40 | cmap = cmap/255 if normalized else cmap
41 | return cmap
--------------------------------------------------------------------------------
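color_map generates the standard PASCAL VOC palette: each class index is expanded bit by bit into an RGB triple, so consecutive indices get visually distinct colors. The first few entries it produces, run in this module's namespace:

cmap = color_map(8)
print(cmap[:5])
# [[  0   0   0]     index 0: background
#  [128   0   0]     index 1
#  [  0 128   0]     index 2
#  [128 128   0]     index 3
#  [  0   0 128]]    index 4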
/viz/sample_overall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/bigdatasetgan_code/679f547cafaa004110f2cd5e5a08035fe0293223/viz/sample_overall.jpg
--------------------------------------------------------------------------------