├── .gitignore ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── README.rst ├── pro_gan_pytorch ├── __init__.py ├── custom_layers.py ├── data_tools.py ├── gan.py ├── losses.py ├── modules.py ├── networks.py ├── test │ ├── __init__.py │ ├── conftest.py │ ├── test_custom_layers.py │ ├── test_gan.py │ ├── test_networks.py │ └── utils.py └── utils.py ├── pro_gan_pytorch_scripts ├── __init__.py ├── compute_fid.py ├── latent_space_interpolation.py └── train.py ├── requirements-dev.txt ├── requirements.txt ├── samples ├── .gitignore └── celebA-HQ.gif └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # ignore pycharm data 107 | .idea/ 108 | 109 | # ignore the virtual environment for the project as well 110 | pro_gan_pytorch_env/ 111 | 112 | # ignore the test_train folder created by one of the tests: 113 | ./test_train -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Animesh Karnewar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # need to include the following files additionally for the setup.py to work 2 | include requirements.txt 3 | include scripts/* 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pro_gan_pytorch 2 | **Unofficial PyTorch** implementation of the paper "Progressive Growing of GANs for Improved 3 | Quality, Stability, and Variation".<br>
4 | For the official TensorFlow code, please refer to 5 | [this repo](https://github.com/tkarras/progressive_growing_of_gans)
6 | 7 | ![GitHub](https://img.shields.io/github/license/akanimax/pro_gan_pytorch) 8 | ![PyPi](https://img.shields.io/badge/pip--pro--gan--pth-3.4-brightgreen) 9 | 10 | # How to use: 11 | ### Using the package 12 | **Requirements (aka. we tested for):** 13 | 1. **Ubuntu** `20.04.3` or above 14 | 2. Python `3.8.3` 15 | 3. Nvidia GPU `GeForce 1080 Ti or above` min GPU-mem `8GB` 16 | 4. Nvidia drivers >= `470.86` 17 | 5. Nvidia cuda `11.3` | can be skipped since pytorch ships with cuda, cudnn etc. 18 | 19 | **Installing the package** 20 | 1. Easiest way is to create a new virtual-env 21 | so that your global python env doesn't get corrupted 22 | 2. Create and switch to your new virtual environment 23 | ``` 24 | (your-machine):~$ python3 -m venv /pro_gan_pth_env 25 | (pro_gan_pth_env)(your-machine):~$ source /pro_gan_pth_env/bin/activate 26 | ``` 27 | 3. Install the `pro-gan-pth` package from pypi, if you meet 28 | all the above dependencies 29 | ``` 30 | (pro_gan_pth_env)(your-machine):~$ pip install pro-gan-pth 31 | ``` 32 | 4. Once installed, you can either use the installed commandline tools 33 | `progan_train`, `progan_lsid` and `progan_fid`. 34 | Note that the `progan_train` can be used with multiple gpus 35 | (If you have many :smile:). Just ensure that the gpus visible in the 36 | `CUDA_VISIBLE_DEVICES=0,1,2` environment variable. The other two tools only use a 37 | single GPU. 38 | 39 | 40 | ``` 41 | (your-machine):~$ progan_train --help 42 | usage: Train Progressively grown GAN 43 | [-h] 44 | [--retrain RETRAIN] 45 | [--generator_path GENERATOR_PATH] 46 | [--discriminator_path DISCRIMINATOR_PATH] 47 | [--rec_dir REC_DIR] 48 | [--flip_horizontal FLIP_HORIZONTAL] 49 | [--depth DEPTH] 50 | [--num_channels NUM_CHANNELS] 51 | [--latent_size LATENT_SIZE] 52 | [--use_eql USE_EQL] 53 | [--use_ema USE_EMA] 54 | [--ema_beta EMA_BETA] 55 | [--epochs EPOCHS [EPOCHS ...]] 56 | [--batch_sizes BATCH_SIZES [BATCH_SIZES ...]] 57 | [--batch_repeats BATCH_REPEATS] 58 | [--fade_in_percentages FADE_IN_PERCENTAGES [FADE_IN_PERCENTAGES ...]] 59 | [--loss_fn LOSS_FN] 60 | [--g_lrate G_LRATE] 61 | [--d_lrate D_LRATE] 62 | [--num_feedback_samples NUM_FEEDBACK_SAMPLES] 63 | [--start_depth START_DEPTH] 64 | [--num_workers NUM_WORKERS] 65 | [--feedback_factor FEEDBACK_FACTOR] 66 | [--checkpoint_factor CHECKPOINT_FACTOR] 67 | train_path 68 | output_dir 69 | 70 | positional arguments: 71 | train_path Path to the images folder for training the ProGAN 72 | output_dir Path to the directory for saving the logs and models 73 | 74 | optional arguments: 75 | -h, --help show this help message and exit 76 | --retrain RETRAIN whenever you want to resume training from saved models (default: False) 77 | --generator_path GENERATOR_PATH 78 | Path to the generator model for retraining the ProGAN (default: None) 79 | --discriminator_path DISCRIMINATOR_PATH 80 | Path to the discriminat or model for retraining the ProGAN (default: None) 81 | --rec_dir REC_DIR whether images stored under one folder or has a recursive dir structure (default: True) 82 | --flip_horizontal FLIP_HORIZONTAL 83 | whether to apply mirror augmentation (default: True) 84 | --depth DEPTH depth of the generator and the discriminator (default: 10) 85 | --num_channels NUM_CHANNELS 86 | number of channels of in the image data (default: 3) 87 | --latent_size LATENT_SIZE 88 | latent size of the generator and the discriminator (default: 512) 89 | --use_eql USE_EQL whether to use the equalized learning rate (default: True) 90 | --use_ema USE_EMA whether to use the 
exponential moving averages (default: True) 91 | --ema_beta EMA_BETA value of the ema beta (default: 0.999) 92 | --epochs EPOCHS [EPOCHS ...] 93 | number of epochs over the training dataset per stage (default: [42, 42, 42, 42, 42, 42, 42, 42, 42]) 94 | --batch_sizes BATCH_SIZES [BATCH_SIZES ...] 95 | batch size used for training the model per stage (default: [32, 32, 32, 32, 16, 16, 8, 4, 2]) 96 | --batch_repeats BATCH_REPEATS 97 | number of G and D steps executed per training iteration (default: 4) 98 | --fade_in_percentages FADE_IN_PERCENTAGES [FADE_IN_PERCENTAGES ...] 99 | number of iterations for which fading of new layer happens. Measured in percentage (default: [50, 50, 50, 50, 50, 50, 50, 50, 50]) 100 | --loss_fn LOSS_FN loss function used for training the GAN. Current options: [wgan_gp, standard_gan] (default: wgan_gp) 101 | --g_lrate G_LRATE learning rate used by the generator (default: 0.003) 102 | --d_lrate D_LRATE learning rate used by the discriminator (default: 0.003) 103 | --num_feedback_samples NUM_FEEDBACK_SAMPLES 104 | number of samples used for fixed seed gan feedback (default: 4) 105 | --start_depth START_DEPTH 106 | resolution to start the training from. Example 2 --> (4x4) | 3 --> (8x8) ... | 10 --> (1024x1024)Note that this is not a way to restart a partial training. Resuming is not 107 | supported currently. But will be soon. (default: 2) 108 | --num_workers NUM_WORKERS 109 | number of dataloader subprocesses. It's a pytorch thing, you can ignore it ;). Leave it to the default value unless things are weirdly slow for you. (default: 4) 110 | --feedback_factor FEEDBACK_FACTOR 111 | number of feedback logs written per epoch (default: 10) 112 | --checkpoint_factor CHECKPOINT_FACTOR 113 | number of epochs after which a model snapshot is saved per training stage (default: 10) 114 | 115 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ 116 | 117 | (your-machine):~$ progan_lsid --help 118 | usage: ProGAN latent-space walk demo video creation tool [-h] [--output_path OUTPUT_PATH] [--generation_depth GENERATION_DEPTH] [--time TIME] [--fps FPS] [--smoothing SMOOTHING] model_path 119 | 120 | positional arguments: 121 | model_path path to the trained_model.bin file 122 | 123 | optional arguments: 124 | -h, --help show this help message and exit 125 | --output_path OUTPUT_PATH 126 | path to the output video file location. Please only use mp4 format with this tool (.mp4 extension). I have banged my head too much to get anything else to work :(. (default: 127 | ./latent_space_walk.mp4) 128 | --generation_depth GENERATION_DEPTH 129 | depth at which the images should be generated. Starts from 2 --> (4x4) | 3 --> (8x8) etc. (default: None) 130 | --time TIME number of seconds in the video (default: 30) 131 | --fps FPS fps of the generated video (default: 60) 132 | --smoothing SMOOTHING 133 | smoothness of walking in the latent-space. High values corresponds to more smoothing. 
(default: 0.75) 134 | 135 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ 136 | 137 | (your-machine):~$ progan_fid --help 138 | usage: ProGAN fid_score computation tool [-h] [--generated_images_path GENERATED_IMAGES_PATH] [--batch_size BATCH_SIZE] [--num_generated_images NUM_GENERATED_IMAGES] model_path dataset_path 139 | 140 | positional arguments: 141 | model_path path to the trained_model.bin file 142 | dataset_path path to the directory containing the images from the dataset. Note that this needs to be a flat directory 143 | 144 | optional arguments: 145 | -h, --help show this help message and exit 146 | --generated_images_path GENERATED_IMAGES_PATH 147 | path to the directory where the generated images are to be written. Uses a temporary directory by default. Provide this path if you'd like to see the generated images yourself 148 | :). (default: None) 149 | --batch_size BATCH_SIZE 150 | batch size used for generating random images (default: 4) 151 | --num_generated_images NUM_GENERATED_IMAGES 152 | number of generated images used for computing the FID (default: 50000) 153 | ``` 154 | 155 | 5. Or, you could import this as a python package in your code 156 | for more advanced use-cases: 157 | ``` 158 | import pro_gan_pytorch as pg 159 | ``` 160 | You can use all the modules in the package, such as `pg.networks.Generator`, 161 | `pg.networks.Discriminator`, `pg.gan.ProGAN`, etc. Mostly, you'll only need 162 | the `pg.gan.ProGAN` module for training. For inference, you will probably only 163 | need the `pg.networks.Generator`. Please refer to the scripts for the tools from 164 | step 4 under `pro_gan_pytorch_scripts/` for examples of how to use the package. 165 | Besides, please feel free to just read the code. It's really easy to follow 166 | (or at least I hope so :sweat_smile: :grimacing:). 167 | 168 | ### Developing the package 169 | For more advanced use-cases in your project, or if you'd like to contribute new 170 | features to this project, the following steps will help you get the project set up 171 | for development. There is no standard set of rules for contributing here 172 | (no `CONTRIBUTING.md`), but let's try to maintain the overall ethos of the 173 | codebase :smile:. 174 | 175 | 1. Clone this repository: 176 | ``` 177 | (your-machine):~$ cd 178 | (your-machine):$ git clone https://github.com/akanimax/pro_gan_pytorch.git 179 | ``` 180 | 2. Apologies in advance: step 1 will take a while. I ended up 181 | pushing gifs and other large binary assets to git back then. 182 | I didn't know better :sad:. I'll see if this can be sorted out somehow. 183 | Once that is done, set up a development virtual env: 184 | ``` 185 | (your-machine):$ python3 -m venv pro-gan-pth-dev-env 186 | (your-machine):$ source pro-gan-pth-dev-env/bin/activate 187 | ``` 188 | 3. Install the package in development mode: 189 | ``` 190 | (pro-gan-pth-dev-env)(your-machine):$ pip install -e . 191 | ``` 192 | 4. Also install the dev requirements: 193 | ``` 194 | (pro-gan-pth-dev-env)(your-machine):$ pip install -r requirements-dev.txt 195 | ``` 196 | 5. Now open the project in the editor of your choice, and you are good to go. 197 | I use `pytest` for testing and `black` for code formatting. Check out 198 | [this_link](https://black.readthedocs.io/en/stable/integrations/editors.html) for 199 | how to set up `black` with various IDEs. 200 | 201 | 6.
There is no fancy CI, automated testing, or docs building since this is a 202 | fairly tiny project, but we are open to adopting these tools if more features 203 | keep getting added to this project. 204 | 205 | # Trained Models 206 | We will be training models using this package on different datasets over time. 207 | Also, please feel free to open PRs for the following table if you end up training 208 | models for your own datasets. If you are contributing, please set up 209 | a file hosting solution for serving the trained models. 210 | 211 | | Courtesy | Dataset | Size | Resolution | GPUs used | #Epochs per stage | Training time | FID score | Link | Qualitative samples | 212 | | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | 213 | | @owang | Metfaces | ~1.3K | 1024 x 1024 | 1 V100-32GB | 42 | 24 hrs | 101.624 | [model_link](http://geometry.cs.ucl.ac.uk/projects/2021/pro_gan_pytorch/model_metfaces.bin) | ![image](https://drive.google.com/uc?export=view&id=1loYYvM_d1uG7CKtGkJRpKTwY5CQIldxm) 214 | 215 | 216 | **Note that we compute the FID using the clean_fid implementation from 217 | [Parmar et al.](https://www.cs.cmu.edu/~clean-fid/)** 218 | 219 | # General cool stuff :smile: 220 | ### Training timelapse (fixed latent points): 221 | The training timelapse created from the images logged during training 222 | looks really cool. 223 |
224 | ![training timelapse](samples/celebA-HQ.gif) 227 |
228 | 229 | Check out this [YT video](https://www.youtube.com/watch?v=lzTm6Lq76Mo) for a 230 | 4K version :smile:. 231 | 232 | If interested, please feel free to check out this 233 | [medium blog](https://medium.com/@animeshsk3/the-unprecedented-effectiveness-of-progressive-growing-of-gans-37475c88afa3) 234 | I wrote explaining the progressive growing technique. 235 | 236 | # References 237 | 238 | 1. Tero Karras, Timo Aila, Samuli Laine, & Jaakko Lehtinen (2018). 239 | Progressive Growing of GANs for Improved Quality, Stability, and Variation. 240 | In International Conference on Learning Representations. 241 | 242 | 2. Parmar, Gaurav, Richard Zhang, and Jun-Yan Zhu. 243 | "On Buggy Resizing Libraries and Surprising Subtleties in FID Calculation." 244 | arXiv preprint arXiv:2104.11222 (2021). 245 | 246 | # Feature requests 247 | - [ ] Conditional GAN support 248 | - [ ] Tool for generating time-lapse video from the log images 249 | - [ ] Integrating FID computation into the training logging 250 | 251 | # Thanks 252 | As always,
253 | please feel free to open PRs/issues/suggestions here. 254 | Hope this work is useful in your project :smile:. 255 | 256 | cheers :beers:!
257 | @akanimax :sunglasses: 258 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/README.rst -------------------------------------------------------------------------------- /pro_gan_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | """ Package has implementation of ProGAN (progressive growing of GANs) 2 | as an extension of PyTorch Module 3 | """ 4 | -------------------------------------------------------------------------------- /pro_gan_pytorch/custom_layers.py: -------------------------------------------------------------------------------- 1 | """ Module contains custom layers """ 2 | from typing import Any 3 | 4 | import numpy as np 5 | 6 | import torch 7 | from torch import Tensor 8 | from torch.nn import Conv2d, ConvTranspose2d, Linear 9 | 10 | 11 | def update_average(model_tgt, model_src, beta): 12 | """ 13 | function to calculate the Exponential moving averages for the Generator weights 14 | This function updates the exponential average weights based on the current training 15 | Args: 16 | model_tgt: target model 17 | model_src: source model 18 | beta: value of decay beta 19 | Returns: None (updates the target model) 20 | """ 21 | 22 | with torch.no_grad(): 23 | param_dict_src = dict(model_src.named_parameters()) 24 | 25 | for p_name, p_tgt in model_tgt.named_parameters(): 26 | p_src = param_dict_src[p_name] 27 | assert p_src is not p_tgt 28 | p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src) 29 | 30 | 31 | class EqualizedConv2d(Conv2d): 32 | def __init__( 33 | self, 34 | in_channels, 35 | out_channels, 36 | kernel_size, 37 | stride=1, 38 | padding=0, 39 | dilation=1, 40 | groups=1, 41 | bias=True, 42 | padding_mode="zeros", 43 | ) -> None: 44 | super().__init__( 45 | in_channels, 46 | out_channels, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | groups, 52 | bias, 53 | padding_mode, 54 | ) 55 | # make sure that the self.weight and self.bias are initialized according to 56 | # random normal distribution 57 | torch.nn.init.normal_(self.weight) 58 | if bias: 59 | torch.nn.init.zeros_(self.bias) 60 | 61 | # define the scale for the weights: 62 | fan_in = np.prod(self.kernel_size) * self.in_channels 63 | self.scale = np.sqrt(2) / np.sqrt(fan_in) 64 | 65 | def forward(self, x: Tensor) -> Tensor: 66 | return torch.conv2d( 67 | input=x, 68 | weight=self.weight * self.scale, # scale the weight on runtime 69 | bias=self.bias, 70 | stride=self.stride, 71 | padding=self.padding, 72 | dilation=self.dilation, 73 | groups=self.groups, 74 | ) 75 | 76 | 77 | class EqualizedConvTranspose2d(ConvTranspose2d): 78 | def __init__( 79 | self, 80 | in_channels, 81 | out_channels, 82 | kernel_size, 83 | stride=1, 84 | padding=0, 85 | output_padding=0, 86 | groups=1, 87 | bias=True, 88 | dilation=1, 89 | padding_mode="zeros", 90 | ) -> None: 91 | super().__init__( 92 | in_channels, 93 | out_channels, 94 | kernel_size, 95 | stride, 96 | padding, 97 | output_padding, 98 | groups, 99 | bias, 100 | dilation, 101 | padding_mode, 102 | ) 103 | # make sure that the self.weight and self.bias are initialized according to 104 | # random normal distribution 105 | torch.nn.init.normal_(self.weight) 106 | if bias: 107 | torch.nn.init.zeros_(self.bias) 108 | 109 | # define the scale for the weights: 110 | fan_in = 
self.in_channels 111 | self.scale = np.sqrt(2) / np.sqrt(fan_in) 112 | 113 | def forward(self, x: Tensor, output_size: Any = None) -> Tensor: 114 | output_padding = self._output_padding( 115 | input, output_size, self.stride, self.padding, self.kernel_size 116 | ) 117 | return torch.conv_transpose2d( 118 | input=x, 119 | weight=self.weight * self.scale, # scale the weight on runtime 120 | bias=self.bias, 121 | stride=self.stride, 122 | padding=self.padding, 123 | output_padding=output_padding, 124 | groups=self.groups, 125 | dilation=self.dilation, 126 | ) 127 | 128 | 129 | class EqualizedLinear(Linear): 130 | def __init__(self, in_features, out_features, bias=True) -> None: 131 | super().__init__(in_features, out_features, bias) 132 | 133 | # make sure that the self.weight and self.bias are initialized according to 134 | # random normal distribution 135 | torch.nn.init.normal_(self.weight) 136 | if bias: 137 | torch.nn.init.zeros_(self.bias) 138 | 139 | # define the scale for the weights: 140 | fan_in = self.in_features 141 | self.scale = np.sqrt(2) / np.sqrt(fan_in) 142 | 143 | def forward(self, x: Tensor) -> Tensor: 144 | return torch.nn.functional.linear(x, self.weight * self.scale, self.bias) 145 | 146 | 147 | class PixelwiseNorm(torch.nn.Module): 148 | """ 149 | ------------------------------------------------------------------------------------ 150 | Pixelwise feature vector normalization. 151 | reference: 152 | https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120 153 | ------------------------------------------------------------------------------------ 154 | """ 155 | 156 | def __init__(self): 157 | super(PixelwiseNorm, self).__init__() 158 | 159 | @staticmethod 160 | def forward(x: Tensor, alpha: float = 1e-8) -> Tensor: 161 | y = x.pow(2.0).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N1HW] 162 | y = x / y # normalize the input x volume 163 | return y 164 | 165 | 166 | class MinibatchStdDev(torch.nn.Module): 167 | """ 168 | Minibatch standard deviation layer for the discriminator 169 | Args: 170 | group_size: Size of each group into which the batch is split 171 | """ 172 | 173 | def __init__(self, group_size: int = 4) -> None: 174 | """ 175 | 176 | Args: 177 | group_size: Size of each group into which the batch is split 178 | """ 179 | super(MinibatchStdDev, self).__init__() 180 | self.group_size = group_size 181 | 182 | def extra_repr(self) -> str: 183 | return f"group_size={self.group_size}" 184 | 185 | def forward(self, x: Tensor, alpha: float = 1e-8) -> Tensor: 186 | """ 187 | forward pass of the layer 188 | Args: 189 | x: input activation volume 190 | alpha: small number for numerical stability 191 | Returns: y => x appended with standard deviation constant map 192 | """ 193 | batch_size, channels, height, width = x.shape 194 | if batch_size > self.group_size: 195 | assert batch_size % self.group_size == 0, ( 196 | f"batch_size {batch_size} should be " 197 | f"perfectly divisible by group_size {self.group_size}" 198 | ) 199 | group_size = self.group_size 200 | else: 201 | group_size = batch_size 202 | 203 | # reshape x into a more amenable sized tensor 204 | y = torch.reshape(x, [group_size, -1, channels, height, width]) 205 | 206 | # indicated shapes are after performing the operation 207 | # [G x M x C x H x W] Subtract mean over groups 208 | y = y - y.mean(dim=0, keepdim=True) 209 | 210 | # [M x C x H x W] Calc standard deviation over the groups 211 | y = torch.sqrt(y.square().mean(dim=0, keepdim=False) + alpha) 212 | 213 | # [M x 1 x 1 x 
1] Take average over feature_maps and pixels. 214 | y = y.mean(dim=[1, 2, 3], keepdim=True) 215 | 216 | # [B x 1 x H x W] Replicate over group and pixels 217 | y = y.repeat(group_size, 1, height, width) 218 | 219 | # [B x (C + 1) x H x W] Append as new feature_map. 220 | y = torch.cat([x, y], 1) 221 | 222 | # return the computed values: 223 | return y 224 | -------------------------------------------------------------------------------- /pro_gan_pytorch/data_tools.py: -------------------------------------------------------------------------------- 1 | """ Module for the data loading pipeline for the model to train """ 2 | from pathlib import Path 3 | from typing import Any, Callable, List, Optional, Tuple 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from torch import Tensor 9 | from torch.utils.data import DataLoader, Dataset 10 | from torchvision.transforms import Compose, RandomHorizontalFlip, Resize, ToTensor 11 | 12 | from .utils import adjust_dynamic_range 13 | 14 | 15 | class NoOp(object): 16 | """A NoOp image transform utility. Does nothing, but makes the code cleaner""" 17 | 18 | def __call__(self, whatever: Any) -> Any: 19 | return whatever 20 | 21 | def __repr__(self) -> str: 22 | return self.__class__.__name__ + "()" 23 | 24 | 25 | def get_transform( 26 | new_size: Optional[Tuple[int, int]] = None, flip_horizontal: bool = False 27 | ) -> Callable[[Image.Image], Tensor]: 28 | """ 29 | obtain the image transforms required for the input data 30 | Args: 31 | new_size: size of the resized images (if needed, could be None) 32 | flip_horizontal: whether to randomly mirror input images during training 33 | Returns: requested transform object from TorchVision 34 | """ 35 | return Compose( 36 | [ 37 | RandomHorizontalFlip(p=0.5) if flip_horizontal else NoOp(), 38 | Resize(new_size) if new_size is not None else NoOp(), 39 | ToTensor(), 40 | ] 41 | ) 42 | 43 | 44 | class ImageDirectoryDataset(Dataset): 45 | """pyTorch Dataset wrapper for the simple case of flat directory images dataset 46 | Args: 47 | data_dir: directory containing all the images 48 | transform: whether to apply a certain transformation to the images 49 | rec_dir: whether to search all the sub-level directories for files 50 | recursively 51 | """ 52 | 53 | def __init__( 54 | self, 55 | data_dir: Path, 56 | transform: Callable[[Image.Image], Tensor] = get_transform(), 57 | input_data_range: Tuple[float, float] = (0.0, 1.0), 58 | output_data_range: Tuple[float, float] = (-1.0, 1.0), 59 | rec_dir: bool = False, 60 | ) -> None: 61 | # define the state of the object 62 | self.rec_dir = rec_dir 63 | self.data_dir = data_dir 64 | self.transform = transform 65 | self.output_data_range = output_data_range 66 | self.input_data_range = input_data_range 67 | 68 | # setup the files for reading 69 | self.files = self._get_files(data_dir, rec_dir) 70 | 71 | def _get_files(self, path: Path, rec: bool = False) -> List[Path]: 72 | """ 73 | helper function to search the given directory and obtain all the files in it's 74 | structure 75 | Args: 76 | path: path to the root directory 77 | rec: whether to search all the sub-level directories for files recursively 78 | Returns: list of all found paths 79 | """ 80 | files = [] 81 | for possible_file in path.iterdir(): 82 | if possible_file.is_file(): 83 | files.append(possible_file) 84 | elif rec and possible_file.is_dir(): 85 | files.extend(self._get_files(possible_file)) 86 | return files 87 | 88 | def __len__(self) -> int: 89 | """ 90 | compute the length of the dataset 91 | 
Returns: len => length of dataset 92 | """ 93 | return len(self.files) 94 | 95 | def __getitem__(self, item: int) -> Tensor: 96 | """ 97 | obtain the image (read and transform) 98 | Args: 99 | item: index for the required image 100 | Returns: img => image array 101 | """ 102 | # read the image: 103 | image = self.files[item] 104 | if image.name.endswith(".npy"): 105 | img = np.load(str(image)) 106 | img = Image.fromarray(img.squeeze(0).transpose(1, 2, 0)) 107 | else: 108 | img = Image.open(image) 109 | 110 | # apply the transforms on the image 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | 114 | # bring the image in the required range 115 | img = adjust_dynamic_range( 116 | img, drange_in=self.input_data_range, drange_out=self.output_data_range 117 | ) 118 | 119 | return img 120 | 121 | 122 | def get_data_loader( 123 | dataset: Dataset, batch_size: int, num_workers: int = 3 124 | ) -> DataLoader: 125 | """ 126 | generate the data_loader from the given dataset 127 | Args: 128 | dataset: Torch dataset object 129 | batch_size: batch size for training 130 | num_workers: num of parallel readers for reading the data 131 | Returns: dataloader for the dataset 132 | """ 133 | return DataLoader( 134 | dataset, 135 | batch_size=batch_size, 136 | shuffle=True, 137 | num_workers=num_workers, 138 | drop_last=True, 139 | ) 140 | -------------------------------------------------------------------------------- /pro_gan_pytorch/gan.py: -------------------------------------------------------------------------------- 1 | """ Module implementing ProGAN which is trained using the Progressive growing 2 | technique -> https://arxiv.org/abs/1710.10196 3 | """ 4 | import copy 5 | import datetime 6 | import time 7 | import timeit 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Optional 10 | 11 | import numpy as np 12 | 13 | import torch 14 | from torch import Tensor 15 | from torch.nn import DataParallel, Module 16 | from torch.nn.functional import avg_pool2d, interpolate 17 | from torch.optim.optimizer import Optimizer 18 | from torch.utils.data import Dataset 19 | from torch.utils.tensorboard import SummaryWriter 20 | from torchvision.utils import save_image 21 | 22 | from .custom_layers import update_average 23 | from .data_tools import get_data_loader 24 | from .losses import GANLoss, WganGP 25 | from .networks import Discriminator, Generator 26 | from .utils import adjust_dynamic_range 27 | 28 | 29 | class ProGAN: 30 | def __init__( 31 | self, 32 | gen: Generator, 33 | dis: Discriminator, 34 | device=torch.device("cpu"), 35 | use_ema: bool = True, 36 | ema_beta: float = 0.999, 37 | ): 38 | assert gen.depth == dis.depth, ( 39 | f"Generator and Discriminator depths are not compatible. 
" 40 | f"GEN_Depth: {gen.depth} DIS_Depth: {dis.depth}" 41 | ) 42 | self.gen = gen.to(device) 43 | self.dis = dis.to(device) 44 | self.use_ema = use_ema 45 | self.ema_beta = ema_beta 46 | self.depth = gen.depth 47 | self.latent_size = gen.latent_size 48 | self.device = device 49 | 50 | # if code is to be run on GPU, we can use DataParallel: 51 | if device == torch.device("cuda"): 52 | self.gen = DataParallel(self.gen) 53 | self.dis = DataParallel(self.dis) 54 | 55 | print(f"Generator Network: {self.gen}") 56 | print(f"Discriminator Network: {self.dis}") 57 | 58 | if self.use_ema: 59 | # create a shadow copy of the generator 60 | self.gen_shadow = copy.deepcopy(self.gen) 61 | 62 | # initialize the gen_shadow weights equal to the 63 | # weights of gen 64 | update_average(self.gen_shadow, self.gen, beta=0) 65 | 66 | # counters to maintain generator and discriminator gradient overflows 67 | self.gen_overflow_count = 0 68 | self.dis_overflow_count = 0 69 | 70 | def progressive_downsample_batch(self, real_batch, depth, alpha): 71 | """ 72 | private helper for downsampling the original images in order to facilitate the 73 | progressive growing of the layers. 74 | Args: 75 | real_batch: batch of real samples 76 | depth: depth at which training is going on 77 | alpha: current value of the fader alpha 78 | 79 | Returns: modified real batch of samples 80 | 81 | """ 82 | # downsample the real_batch for the given depth 83 | down_sample_factor = int(2 ** (self.depth - depth)) 84 | prior_downsample_factor = int(2 ** (self.depth - depth + 1)) 85 | 86 | ds_real_samples = avg_pool2d( 87 | real_batch, kernel_size=down_sample_factor, stride=down_sample_factor 88 | ) 89 | 90 | if depth > 2: 91 | prior_ds_real_samples = interpolate( 92 | avg_pool2d( 93 | real_batch, 94 | kernel_size=prior_downsample_factor, 95 | stride=prior_downsample_factor, 96 | ), 97 | scale_factor=2, 98 | ) 99 | else: 100 | prior_ds_real_samples = ds_real_samples 101 | 102 | # real samples are a linear combination of 103 | # ds_real_samples and prior_ds_real_samples 104 | real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples) 105 | 106 | return real_samples 107 | 108 | def optimize_discriminator( 109 | self, 110 | loss: GANLoss, 111 | dis_optim: Optimizer, 112 | noise: Tensor, 113 | real_batch: Tensor, 114 | depth: int, 115 | alpha: float, 116 | labels: Optional[Tensor] = None, 117 | ) -> float: 118 | """ 119 | performs a single weight update step on discriminator using the batch of data 120 | and the noise 121 | Args: 122 | loss: the loss function to be used for the optimization 123 | dis_optim: discriminator optimizer 124 | noise: input noise for sample generation 125 | real_batch: real samples batch 126 | depth: current depth of optimization 127 | alpha: current alpha for fade-in 128 | labels: labels for conditional discrimination 129 | 130 | Returns: discriminator loss value 131 | """ 132 | real_samples = self.progressive_downsample_batch(real_batch, depth, alpha) 133 | 134 | # generate a batch of samples 135 | fake_samples = self.gen(noise, depth, alpha).detach() 136 | dis_loss = loss.dis_loss( 137 | self.dis, real_samples, fake_samples, depth, alpha, labels=labels 138 | ) 139 | 140 | # optimize discriminator 141 | dis_optim.zero_grad() 142 | dis_loss.backward() 143 | if self._check_grad_ok(self.dis): 144 | dis_optim.step() 145 | else: 146 | self.dis_overflow_count += 1 147 | 148 | return dis_loss.item() 149 | 150 | def optimize_generator( 151 | self, 152 | loss: GANLoss, 153 | gen_optim: Optimizer, 154 | 
noise: Tensor, 155 | real_batch: Tensor, 156 | depth: int, 157 | alpha: float, 158 | labels: Optional[Tensor] = None, 159 | ) -> float: 160 | """ 161 | performs a single weight update step on generator using the batch of data 162 | and the noise 163 | Args: 164 | loss: the loss function to be used for the optimization 165 | gen_optim: generator optimizer 166 | noise: input noise for sample generation 167 | real_batch: real samples batch 168 | depth: current depth of optimization 169 | alpha: current alpha for fade-in 170 | labels: labels for conditional discrimination 171 | 172 | Returns: generator loss value 173 | """ 174 | real_samples = self.progressive_downsample_batch(real_batch, depth, alpha) 175 | 176 | # generate fake samples: 177 | fake_samples = self.gen(noise, depth, alpha) 178 | 179 | gen_loss = loss.gen_loss( 180 | self.dis, real_samples, fake_samples, depth, alpha, labels=labels 181 | ) 182 | 183 | # optimize the generator 184 | gen_optim.zero_grad() 185 | gen_loss.backward() 186 | if self._check_grad_ok(self.gen): 187 | gen_optim.step() 188 | else: 189 | self.gen_overflow_count += 1 190 | 191 | return gen_loss.item() 192 | 193 | @staticmethod 194 | def create_grid( 195 | samples: Tensor, 196 | scale_factor: int, 197 | img_file: Path, 198 | ) -> None: 199 | """ 200 | utility function to create a grid of GAN samples 201 | Args: 202 | samples: generated samples for feedback 203 | scale_factor: factor for upscaling the image 204 | img_file: name of file to write 205 | Returns: None (saves a file) 206 | """ 207 | # upsample the image 208 | if scale_factor > 1: 209 | samples = interpolate(samples, scale_factor=scale_factor) 210 | 211 | samples = adjust_dynamic_range( 212 | samples, drange_in=(-1.0, 1.0), drange_out=(0.0, 1.0) 213 | ) 214 | 215 | # save the images: 216 | save_image(samples, img_file, nrow=int(np.sqrt(len(samples))), padding=0) 217 | 218 | def _toggle_all_networks(self, mode="train"): 219 | for network in (self.gen, self.dis): 220 | if mode.lower() == "train": 221 | network.train() 222 | elif mode.lower() == "eval": 223 | network.eval() 224 | else: 225 | raise ValueError(f"Unknown mode requested: {mode}") 226 | 227 | @staticmethod 228 | def _check_grad_ok(network: Module) -> bool: 229 | grad_ok = True 230 | for _, param in network.named_parameters(): 231 | if param.grad is not None: 232 | param_ok = ( 233 | torch.sum(torch.isnan(param.grad)) == 0 234 | and torch.sum(torch.isinf(param.grad)) == 0 235 | ) 236 | if not param_ok: 237 | grad_ok = False 238 | break 239 | return grad_ok 240 | 241 | def get_save_info( 242 | self, gen_optim: Optimizer, dis_optim: Optimizer 243 | ) -> Dict[str, Any]: 244 | 245 | if self.device == torch.device("cpu"): 246 | generator_save_info = self.gen.get_save_info() 247 | discriminator_save_info = self.dis.get_save_info() 248 | else: 249 | generator_save_info = self.gen.module.get_save_info() 250 | discriminator_save_info = self.dis.module.get_save_info() 251 | save_info = { 252 | "generator": generator_save_info, 253 | "discriminator": discriminator_save_info, 254 | "gen_optim": gen_optim.state_dict(), 255 | "dis_optim": dis_optim.state_dict(), 256 | } 257 | if self.use_ema: 258 | save_info["shadow_generator"] = ( 259 | self.gen_shadow.get_save_info() 260 | if self.device == torch.device("cpu") 261 | else self.gen_shadow.module.get_save_info() 262 | ) 263 | return save_info 264 | 265 | def train( 266 | self, 267 | dataset: Dataset, 268 | epochs: List[int], 269 | batch_sizes: List[int], 270 | fade_in_percentages: List[int], 271 | loss_fn: 
GANLoss = WganGP(), 272 | batch_repeats: int = 4, 273 | gen_learning_rate: float = 0.003, 274 | dis_learning_rate: float = 0.003, 275 | num_samples: int = 16, 276 | start_depth: int = 2, 277 | num_workers: int = 3, 278 | feedback_factor: int = 100, 279 | save_dir=Path("./train"), 280 | checkpoint_factor: int = 10, 281 | ): 282 | """ 283 | # TODO implement support for conditional GAN here 284 | Utility method for training the ProGAN. 285 | Note that you don't have to necessarily use this method. You can use the 286 | optimize_generator and optimize_discriminator and define your own 287 | training routine 288 | Args: 289 | dataset: object of the dataset used for training. 290 | Note that this is not the dataloader (we create dataloader in this 291 | method since the batch_sizes for resolutions can be different) 292 | epochs: list of number of epochs to train the network for every resolution 293 | batch_sizes: list of batch_sizes for every resolution 294 | fade_in_percentages: list of percentages of epochs per resolution 295 | used for fading in the new layer not used for 296 | first resolution, but dummy value is still needed 297 | loss_fn: loss function used for training 298 | batch_repeats: number of iterations to perform on a single batch 299 | gen_learning_rate: generator learning rate 300 | dis_learning_rate: discriminator learning rate 301 | num_samples: number of samples generated in sample_sheet 302 | start_depth: start training from this depth 303 | num_workers: number of workers for reading the data 304 | feedback_factor: number of logs per epoch 305 | save_dir: directory for saving the models (.bin files) 306 | checkpoint_factor: save model after these many epochs. 307 | Returns: None (Writes multiple files to disk) 308 | """ 309 | 310 | print(f"Loaded the dataset with: {len(dataset)} images ...") 311 | assert (self.depth - 1) == len( 312 | batch_sizes 313 | ), "batch_sizes are not compatible with depth" 314 | assert (self.depth - 1) == len(epochs), "epochs are not compatible with depth" 315 | 316 | self._toggle_all_networks("train") 317 | 318 | # create the generator and discriminator optimizers 319 | gen_optim = torch.optim.Adam( 320 | params=self.gen.parameters(), 321 | lr=gen_learning_rate, 322 | betas=(0, 0.99), 323 | eps=1e-8, 324 | ) 325 | dis_optim = torch.optim.Adam( 326 | params=self.dis.parameters(), 327 | lr=dis_learning_rate, 328 | betas=(0, 0.99), 329 | eps=1e-8, 330 | ) 331 | 332 | # verbose stuff 333 | print("setting up the image saving mechanism") 334 | model_dir, log_dir = save_dir / "models", save_dir / "logs" 335 | model_dir.mkdir(parents=True, exist_ok=True) 336 | log_dir.mkdir(parents=True, exist_ok=True) 337 | 338 | feedback_generator = self.gen_shadow if self.use_ema else self.gen 339 | 340 | # image saving mechanism 341 | with torch.no_grad(): 342 | dummy_data_loader = get_data_loader(dataset, num_samples, num_workers) 343 | real_images_for_render = next(iter(dummy_data_loader)) 344 | fixed_input = torch.randn(num_samples, self.latent_size).to(self.device) 345 | self.create_grid( 346 | real_images_for_render, 347 | scale_factor=1, 348 | img_file=log_dir / "real_images.png", 349 | ) 350 | self.create_grid( 351 | feedback_generator(fixed_input, self.depth, 1).detach(), 352 | scale_factor=1, 353 | img_file=log_dir / "fake_images_0.png", 354 | ) 355 | 356 | # tensorboard summarywriter: 357 | summary = SummaryWriter(str(log_dir / "tensorboard")) 358 | 359 | # create a global time counter 360 | global_time = time.time() 361 | global_step = 0 362 | 363 | 
print("Starting the training process ... ") 364 | for current_depth in range(start_depth, self.depth + 1): 365 | current_res = int(2 ** current_depth) 366 | print(f"\n\nCurrently working on Depth: {current_depth}") 367 | print("Current resolution: %d x %d" % (current_res, current_res)) 368 | depth_list_index = current_depth - 2 369 | current_batch_size = batch_sizes[depth_list_index] 370 | data = get_data_loader(dataset, current_batch_size, num_workers) 371 | ticker = 1 372 | for epoch in range(1, epochs[depth_list_index] + 1): 373 | start = timeit.default_timer() # record time at the start of epoch 374 | print(f"\nEpoch: {epoch}") 375 | total_batches = len(data) 376 | 377 | # compute the fader point 378 | fader_point = int( 379 | (fade_in_percentages[depth_list_index] / 100) 380 | * epochs[depth_list_index] 381 | * total_batches 382 | ) 383 | 384 | for (i, batch) in enumerate(data, start=1): 385 | # calculate the alpha for fading in the layers 386 | alpha = ticker / fader_point if ticker <= fader_point else 1 387 | 388 | # extract current batch of data for training 389 | images = batch.to(self.device) 390 | 391 | gan_input = torch.randn(current_batch_size, self.latent_size).to( 392 | self.device 393 | ) 394 | 395 | gen_loss, dis_loss = None, None 396 | for _ in range(batch_repeats): 397 | # optimize the discriminator: 398 | dis_loss = self.optimize_discriminator( 399 | loss_fn, dis_optim, gan_input, images, current_depth, alpha 400 | ) 401 | 402 | # no idea why this needs to be done after discriminator 403 | # iteration, but this is where it is done in the original 404 | # code 405 | if self.use_ema: 406 | update_average( 407 | self.gen_shadow, self.gen, beta=self.ema_beta 408 | ) 409 | 410 | # optimize the generator: 411 | gen_loss = self.optimize_generator( 412 | loss_fn, gen_optim, gan_input, images, current_depth, alpha 413 | ) 414 | global_step += 1 415 | 416 | # provide a loss feedback 417 | if ( 418 | i % max(int(total_batches / max(feedback_factor, 1)), 1) == 0 419 | or i == 1 420 | or i == total_batches 421 | ): 422 | elapsed = time.time() - global_time 423 | elapsed = str(datetime.timedelta(seconds=elapsed)) 424 | print( 425 | "Elapsed: [%s] batch: %d d_loss: %f g_loss: %f" 426 | % (elapsed, i, dis_loss, gen_loss) 427 | ) 428 | summary.add_scalar( 429 | "dis_loss", dis_loss, global_step=global_step 430 | ) 431 | summary.add_scalar( 432 | "gen_loss", gen_loss, global_step=global_step 433 | ) 434 | # create a grid of samples and save it 435 | resolution_dir = log_dir / str(int(2 ** current_depth)) 436 | resolution_dir.mkdir(exist_ok=True) 437 | gen_img_file = resolution_dir / f"{epoch}_{i}.png" 438 | 439 | # this is done to allow for more GPU space 440 | with torch.no_grad(): 441 | self.create_grid( 442 | samples=feedback_generator( 443 | fixed_input, current_depth, alpha 444 | ).detach(), 445 | scale_factor=int(2 ** (self.depth - current_depth)), 446 | img_file=gen_img_file, 447 | ) 448 | 449 | # increment the alpha ticker and the step 450 | ticker += 1 451 | 452 | stop = timeit.default_timer() 453 | print("Time taken for epoch: %.3f secs" % (stop - start)) 454 | 455 | if ( 456 | epoch % checkpoint_factor == 0 457 | or epoch == 1 458 | or epoch == epochs[depth_list_index] 459 | ): 460 | save_file = model_dir / f"depth_{current_depth}_epoch_{epoch}.bin" 461 | torch.save(self.get_save_info(gen_optim, dis_optim), save_file) 462 | 463 | self._toggle_all_networks("eval") 464 | print("Training completed ...") 465 | 
-------------------------------------------------------------------------------- /pro_gan_pytorch/losses.py: -------------------------------------------------------------------------------- 1 | """ Module implementing various loss functions """ 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import BCEWithLogitsLoss 7 | 8 | from .networks import Discriminator 9 | 10 | 11 | class GANLoss: 12 | def dis_loss( 13 | self, 14 | discriminator: Discriminator, 15 | real_samples: Tensor, 16 | fake_samples: Tensor, 17 | depth: int, 18 | alpha: float, 19 | labels: Optional[Tensor] = None, 20 | ) -> Tensor: 21 | """ 22 | calculate the discriminator loss using the following data 23 | Args: 24 | discriminator: the Discriminator used by the GAN 25 | real_samples: real batch of samples 26 | fake_samples: fake batch of samples 27 | depth: resolution log 2 of the images under consideration 28 | alpha: alpha value of the fader 29 | labels: optional in case of the conditional discriminator 30 | 31 | Returns: computed discriminator loss 32 | """ 33 | raise NotImplementedError("dis_loss method has not been implemented") 34 | 35 | def gen_loss( 36 | self, 37 | discriminator: Discriminator, 38 | real_samples: Tensor, 39 | fake_samples: Tensor, 40 | depth: int, 41 | alpha: float, 42 | labels: Optional[Tensor] = None, 43 | ) -> Tensor: 44 | """ 45 | calculate the generator loss using the following data 46 | Args: 47 | discriminator: the Discriminator used by the GAN 48 | real_samples: real batch of samples 49 | fake_samples: fake batch of samples 50 | depth: resolution log 2 of the images under consideration 51 | alpha: alpha value of the fader 52 | labels: optional in case of the conditional discriminator 53 | 54 | Returns: computed discriminator loss 55 | """ 56 | raise NotImplementedError("gen_loss method has not been implemented") 57 | 58 | 59 | class StandardGAN(GANLoss): 60 | def __init__(self): 61 | self.criterion = BCEWithLogitsLoss() 62 | 63 | def dis_loss( 64 | self, 65 | discriminator: Discriminator, 66 | real_samples: Tensor, 67 | fake_samples: Tensor, 68 | depth: int, 69 | alpha: float, 70 | labels: Optional[Tensor] = None, 71 | ) -> Tensor: 72 | if labels is not None: 73 | assert discriminator.conditional, "labels passed to an unconditional dis" 74 | real_scores = discriminator(real_samples, depth, alpha, labels) 75 | fake_scores = discriminator(fake_samples, depth, alpha, labels) 76 | else: 77 | real_scores = discriminator(real_samples, depth, alpha) 78 | fake_scores = discriminator(fake_samples, depth, alpha) 79 | 80 | real_loss = self.criterion( 81 | real_scores, torch.ones(real_scores.shape).to(real_scores.device) 82 | ) 83 | fake_loss = self.criterion( 84 | fake_scores, torch.zeros(fake_scores.shape).to(fake_scores.device) 85 | ) 86 | return (real_loss + fake_loss) / 2 87 | 88 | def gen_loss( 89 | self, 90 | discriminator: Discriminator, 91 | _: Tensor, 92 | fake_samples: Tensor, 93 | depth: int, 94 | alpha: float, 95 | labels: Optional[Tensor] = None, 96 | ) -> Tensor: 97 | if labels is not None: 98 | assert discriminator.conditional, "labels passed to an unconditional dis" 99 | fake_scores = discriminator(fake_samples, depth, alpha, labels) 100 | else: 101 | fake_scores = discriminator(fake_samples, depth, alpha) 102 | return self.criterion( 103 | fake_scores, torch.ones(fake_scores.shape).to(fake_scores.device) 104 | ) 105 | 106 | 107 | class WganGP(GANLoss): 108 | """ 109 | Wgan-GP loss function. 
The discriminator is required for computing the gradient 110 | penalty. 111 | Args: 112 | drift: weight for the drift penalty 113 | """ 114 | 115 | def __init__(self, drift: float = 0.001) -> None: 116 | self.drift = drift 117 | 118 | @staticmethod 119 | def _gradient_penalty( 120 | dis: Discriminator, 121 | real_samples: Tensor, 122 | fake_samples: Tensor, 123 | depth: int, 124 | alpha: float, 125 | reg_lambda: float = 10, 126 | labels: Optional[Tensor] = None, 127 | ) -> Tensor: 128 | """ 129 | private helper for calculating the gradient penalty 130 | Args: 131 | dis: the discriminator used for computing the penalty 132 | real_samples: real samples 133 | fake_samples: fake samples 134 | depth: current depth in the optimization 135 | alpha: current alpha for fade-in 136 | reg_lambda: regularisation lambda 137 | 138 | Returns: computed gradient penalty 139 | """ 140 | batch_size = real_samples.shape[0] 141 | 142 | # generate random epsilon 143 | epsilon = torch.rand((batch_size, 1, 1, 1)).to(real_samples.device) 144 | 145 | # create the merge of both real and fake samples 146 | merged = epsilon * real_samples + ((1 - epsilon) * fake_samples) 147 | merged.requires_grad_(True) 148 | 149 | # forward pass 150 | if labels is not None: 151 | assert dis.conditional, "labels passed to an unconditional discriminator" 152 | op = dis(merged, depth, alpha, labels) 153 | else: 154 | op = dis(merged, depth, alpha) 155 | 156 | # perform backward pass from op to merged for obtaining the gradients 157 | gradient = torch.autograd.grad( 158 | outputs=op, 159 | inputs=merged, 160 | grad_outputs=torch.ones_like(op), 161 | create_graph=True, 162 | retain_graph=True, 163 | only_inputs=True, 164 | )[0] 165 | 166 | # calculate the penalty using these gradients 167 | gradient = gradient.view(gradient.shape[0], -1) 168 | penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean() 169 | 170 | # return the calculated penalty: 171 | return penalty 172 | 173 | def dis_loss( 174 | self, 175 | discriminator: Discriminator, 176 | real_samples: Tensor, 177 | fake_samples: Tensor, 178 | depth: int, 179 | alpha: float, 180 | labels: Optional[Tensor] = None, 181 | ) -> Tensor: 182 | if labels is not None: 183 | assert discriminator.conditional, "labels passed to an unconditional dis" 184 | real_scores = discriminator(real_samples, depth, alpha, labels) 185 | fake_scores = discriminator(fake_samples, depth, alpha, labels) 186 | else: 187 | real_scores = discriminator(real_samples, depth, alpha) 188 | fake_scores = discriminator(fake_samples, depth, alpha) 189 | loss = ( 190 | torch.mean(fake_scores) 191 | - torch.mean(real_scores) 192 | + (self.drift * torch.mean(real_scores ** 2)) 193 | ) 194 | 195 | # calculate the WGAN-GP (gradient penalty) 196 | gp = self._gradient_penalty( 197 | discriminator, real_samples, fake_samples, depth, alpha, labels=labels 198 | ) 199 | loss += gp 200 | 201 | return loss 202 | 203 | def gen_loss( 204 | self, 205 | discriminator: Discriminator, 206 | _: Tensor, 207 | fake_samples: Tensor, 208 | depth: int, 209 | alpha: float, 210 | labels: Optional[Tensor] = None, 211 | ) -> Tensor: 212 | if labels is not None: 213 | assert discriminator.conditional, "labels passed to an unconditional dis" 214 | fake_scores = discriminator(fake_samples, depth, alpha, labels) 215 | else: 216 | fake_scores = discriminator(fake_samples, depth, alpha) 217 | return -torch.mean(fake_scores) 218 | -------------------------------------------------------------------------------- /pro_gan_pytorch/modules.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from .custom_layers import ( 3 | EqualizedConv2d, 4 | EqualizedConvTranspose2d, 5 | MinibatchStdDev, 6 | PixelwiseNorm, 7 | ) 8 | from torch import Tensor 9 | from torch.nn import AvgPool2d, Conv2d, ConvTranspose2d, Embedding, LeakyReLU, Module 10 | from torch.nn.functional import interpolate 11 | 12 | 13 | class GenInitialBlock(Module): 14 | """ 15 | Module implementing the initial block of the input 16 | Args: 17 | in_channels: number of input channels to the block 18 | out_channels: number of output channels of the block 19 | use_eql: whether to use equalized learning rate 20 | """ 21 | 22 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None: 23 | super(GenInitialBlock, self).__init__() 24 | self.use_eql = use_eql 25 | 26 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 27 | ConvTransposeBlock = EqualizedConvTranspose2d if use_eql else ConvTranspose2d 28 | 29 | self.conv_1 = ConvTransposeBlock(in_channels, out_channels, (4, 4), bias=True) 30 | self.conv_2 = ConvBlock( 31 | out_channels, out_channels, (3, 3), padding=1, bias=True 32 | ) 33 | self.pixNorm = PixelwiseNorm() 34 | self.lrelu = LeakyReLU(0.2) 35 | 36 | def forward(self, x: Tensor) -> Tensor: 37 | y = torch.unsqueeze(torch.unsqueeze(x, -1), -1) 38 | y = self.pixNorm(y) # normalize the latents to hypersphere 39 | y = self.lrelu(self.conv_1(y)) 40 | y = self.lrelu(self.conv_2(y)) 41 | y = self.pixNorm(y) 42 | return y 43 | 44 | 45 | class GenGeneralConvBlock(torch.nn.Module): 46 | """ 47 | Module implementing a general convolutional block 48 | Args: 49 | in_channels: number of input channels to the block 50 | out_channels: number of output channels required 51 | use_eql: whether to use equalized learning rate 52 | """ 53 | 54 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None: 55 | super(GenGeneralConvBlock, self).__init__() 56 | self.in_channels = in_channels 57 | self.out_channels = in_channels 58 | self.use_eql = use_eql 59 | 60 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 61 | 62 | self.conv_1 = ConvBlock(in_channels, out_channels, (3, 3), padding=1, bias=True) 63 | self.conv_2 = ConvBlock( 64 | out_channels, out_channels, (3, 3), padding=1, bias=True 65 | ) 66 | self.pixNorm = PixelwiseNorm() 67 | self.lrelu = LeakyReLU(0.2) 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | y = interpolate(x, scale_factor=2) 71 | y = self.pixNorm(self.lrelu(self.conv_1(y))) 72 | y = self.pixNorm(self.lrelu(self.conv_2(y))) 73 | 74 | return y 75 | 76 | 77 | class DisFinalBlock(torch.nn.Module): 78 | """ 79 | Final block for the Discriminator 80 | Args: 81 | in_channels: number of input channels 82 | use_eql: whether to use equalized learning rate 83 | """ 84 | 85 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None: 86 | super(DisFinalBlock, self).__init__() 87 | self.in_channels = in_channels 88 | self.out_channels = out_channels 89 | self.use_eql = use_eql 90 | 91 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 92 | 93 | self.conv_1 = ConvBlock( 94 | in_channels + 1, in_channels, (3, 3), padding=1, bias=True 95 | ) 96 | self.conv_2 = ConvBlock(in_channels, out_channels, (4, 4), bias=True) 97 | self.conv_3 = ConvBlock(out_channels, 1, (1, 1), bias=True) 98 | self.batch_discriminator = MinibatchStdDev() 99 | self.lrelu = LeakyReLU(0.2) 100 | 101 | def forward(self, x: Tensor) -> Tensor: 102 | y = self.batch_discriminator(x) 103 | y = 
self.lrelu(self.conv_1(y)) 104 | y = self.lrelu(self.conv_2(y)) 105 | y = self.conv_3(y) 106 | return y.view(-1) 107 | 108 | 109 | class ConDisFinalBlock(torch.nn.Module): 110 | """ Final block for the Conditional Discriminator 111 | Uses the Projection mechanism 112 | from the paper -> https://arxiv.org/pdf/1802.05637.pdf 113 | Args: 114 | in_channels: number of input channels 115 | num_classes: number of classes for conditional discrimination 116 | use_eql: whether to use equalized learning rate 117 | """ 118 | 119 | def __init__( 120 | self, in_channels: int, out_channels: int, num_classes: int, use_eql: bool 121 | ) -> None: 122 | super(ConDisFinalBlock, self).__init__() 123 | self.in_channels = in_channels 124 | self.out_channels = out_channels 125 | self.num_classes = num_classes 126 | self.use_eql = use_eql 127 | 128 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 129 | 130 | self.conv_1 = ConvBlock( 131 | in_channels + 1, in_channels, (3, 3), padding=1, bias=True 132 | ) 133 | self.conv_2 = ConvBlock(in_channels, out_channels, (4, 4), bias=True) 134 | self.conv_3 = ConvBlock(out_channels, 1, (1, 1), bias=True) 135 | 136 | # we also need an embedding matrix for the label vectors 137 | self.label_embedder = Embedding(num_classes, out_channels, max_norm=1) 138 | self.batch_discriminator = MinibatchStdDev() 139 | self.lrelu = LeakyReLU(0.2) 140 | 141 | def forward(self, x: Tensor, labels: Tensor) -> Tensor: 142 | y = self.batch_discriminator(x) 143 | y = self.lrelu(self.conv_1(y)) 144 | y = self.lrelu(self.conv_2(y)) 145 | 146 | # embed the labels 147 | labels = self.label_embedder(labels) # [B x C] 148 | 149 | # compute the inner product with the label embeddings 150 | y_ = torch.squeeze(torch.squeeze(y, dim=-1), dim=-1) # [B x C] 151 | projection_scores = (y_ * labels).sum(dim=-1) # [B] 152 | 153 | # normal discrimination score 154 | y = self.lrelu(self.conv_3(y)) # This layer has linear activation 155 | 156 | # calculate the total score 157 | final_score = y.view(-1) + projection_scores 158 | 159 | # return the output raw discriminator scores 160 | return final_score 161 | 162 | 163 | class DisGeneralConvBlock(torch.nn.Module): 164 | """ 165 | General block in the discriminator 166 | Args: 167 | in_channels: number of input channels 168 | out_channels: number of output channels 169 | use_eql: whether to use equalized learning rate 170 | """ 171 | 172 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None: 173 | super(DisGeneralConvBlock, self).__init__() 174 | self.in_channels = in_channels 175 | self.out_channels = out_channels 176 | self.use_eql = use_eql 177 | 178 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 179 | 180 | self.conv_1 = ConvBlock(in_channels, in_channels, (3, 3), padding=1, bias=True) 181 | self.conv_2 = ConvBlock(in_channels, out_channels, (3, 3), padding=1, bias=True) 182 | self.downSampler = AvgPool2d(2) 183 | self.lrelu = LeakyReLU(0.2) 184 | 185 | def forward(self, x: Tensor) -> Tensor: 186 | y = self.lrelu(self.conv_1(x)) 187 | y = self.lrelu(self.conv_2(y)) 188 | y = self.downSampler(y) 189 | return y 190 | -------------------------------------------------------------------------------- /pro_gan_pytorch/networks.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Optional 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import torch as th 8 | from .custom_layers import EqualizedConv2d 9 | from .modules import ( 10 | 
ConDisFinalBlock, 11 | DisFinalBlock, 12 | DisGeneralConvBlock, 13 | GenGeneralConvBlock, 14 | GenInitialBlock, 15 | ) 16 | from torch import Tensor 17 | from torch.nn import Conv2d, LeakyReLU, ModuleList, Sequential 18 | from torch.nn.functional import avg_pool2d, interpolate 19 | 20 | 21 | def nf( 22 | stage: int, 23 | fmap_base: int = 16 << 10, 24 | fmap_decay: float = 1.0, 25 | fmap_min: int = 1, 26 | fmap_max: int = 512, 27 | ) -> int: 28 | """ 29 | computes the number of fmaps present in each stage 30 | Args: 31 | stage: stage level 32 | fmap_base: base number of fmaps 33 | fmap_decay: decay rate for the fmaps in the network 34 | fmap_min: minimum number of fmaps 35 | fmap_max: maximum number of fmaps 36 | 37 | Returns: number of fmaps that should be present at that stage 38 | """ 39 | return int( 40 | np.clip( 41 | int(fmap_base / (2.0 ** (stage * fmap_decay))), 42 | fmap_min, 43 | fmap_max, 44 | ).item() 45 | ) 46 | 47 | 48 | class Generator(th.nn.Module): 49 | """ 50 | Generator Module (block) of the GAN network 51 | Args: 52 | depth: required depth of the Network (**starts from 2) 53 | num_channels: number of output channels (default = 3 for RGB) 54 | latent_size: size of the latent manifold 55 | use_eql: whether to use equalized learning rate 56 | """ 57 | 58 | def __init__( 59 | self, 60 | depth: int = 10, 61 | num_channels: int = 3, 62 | latent_size: int = 512, 63 | use_eql: bool = True, 64 | ) -> None: 65 | super().__init__() 66 | 67 | # object state: 68 | self.depth = depth 69 | self.latent_size = latent_size 70 | self.num_channels = num_channels 71 | self.use_eql = use_eql 72 | 73 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 74 | 75 | self.layers = ModuleList( 76 | [GenInitialBlock(latent_size, nf(1), use_eql=self.use_eql)] 77 | ) 78 | for stage in range(1, depth - 1): 79 | self.layers.append(GenGeneralConvBlock(nf(stage), nf(stage + 1), use_eql)) 80 | 81 | self.rgb_converters = ModuleList( 82 | [ 83 | ConvBlock(nf(stage), num_channels, kernel_size=(1, 1)) 84 | for stage in range(1, depth) 85 | ] 86 | ) 87 | 88 | def forward( 89 | self, x: Tensor, depth: Optional[int] = None, alpha: float = 1.0 90 | ) -> Tensor: 91 | """ 92 | forward pass of the Generator 93 | Args: 94 | x: input latent noise 95 | depth: depth from where the network's output is required 96 | alpha: value of alpha for fade-in effect 97 | 98 | Returns: generated images at the given depth's resolution 99 | """ 100 | depth = self.depth if depth is None else depth 101 | assert depth <= self.depth, f"Requested output depth {depth} cannot be produced" 102 | 103 | if depth == 2: 104 | y = self.rgb_converters[0](self.layers[0](x)) 105 | else: 106 | y = x 107 | for layer_block in self.layers[: depth - 2]: 108 | y = layer_block(y) 109 | residual = interpolate(self.rgb_converters[depth - 3](y), scale_factor=2) 110 | straight = self.rgb_converters[depth - 2](self.layers[depth - 2](y)) 111 | y = (alpha * straight) + ((1 - alpha) * residual) 112 | return y 113 | 114 | def get_save_info(self) -> Dict[str, Any]: 115 | return { 116 | "conf": { 117 | "depth": self.depth, 118 | "num_channels": self.num_channels, 119 | "latent_size": self.latent_size, 120 | "use_eql": self.use_eql, 121 | }, 122 | "state_dict": self.state_dict(), 123 | } 124 | 125 | 126 | class Discriminator(th.nn.Module): 127 | """ 128 | Discriminator of the GAN 129 | Args: 130 | depth: depth of the discriminator. 
log_2(resolution) 131 | num_channels: number of channels of the input images (Default = 3 for RGB) 132 | latent_size: latent size of the final layer 133 | use_eql: whether to use the equalized learning rate 134 | num_classes: number of classes for a conditional discriminator (Default = None) 135 | meaning unconditional discriminator 136 | """ 137 | 138 | def __init__( 139 | self, 140 | depth: int = 7, 141 | num_channels: int = 3, 142 | latent_size: int = 512, 143 | use_eql: bool = True, 144 | num_classes: Optional[int] = None, 145 | ) -> None: 146 | super().__init__() 147 | self.depth = depth 148 | self.num_channels = num_channels 149 | self.latent_size = latent_size 150 | self.use_eql = use_eql 151 | self.num_classes = num_classes 152 | self.conditional = num_classes is not None 153 | 154 | ConvBlock = EqualizedConv2d if use_eql else Conv2d 155 | 156 | if self.conditional: 157 | self.layers = [ConDisFinalBlock(nf(1), latent_size, num_classes, use_eql)] 158 | else: 159 | self.layers = [DisFinalBlock(nf(1), latent_size, use_eql)] 160 | 161 | for stage in range(1, depth - 1): 162 | self.layers.insert( 163 | 0, DisGeneralConvBlock(nf(stage + 1), nf(stage), use_eql) 164 | ) 165 | self.layers = ModuleList(self.layers) 166 | self.from_rgb = ModuleList( 167 | reversed( 168 | [ 169 | Sequential( 170 | ConvBlock(num_channels, nf(stage), kernel_size=(1, 1)), 171 | LeakyReLU(0.2), 172 | ) 173 | for stage in range(1, depth) 174 | ] 175 | ) 176 | ) 177 | 178 | def forward( 179 | self, x: Tensor, depth: int, alpha: float, labels: Optional[Tensor] = None 180 | ) -> Tensor: 181 | """ 182 | forward pass of the discriminator 183 | Args: 184 | x: input to the network 185 | depth: current depth of operation (Progressive GAN) 186 | alpha: current value of alpha for fade-in 187 | labels: labels for conditional discriminator (Default = None) 188 | shape => (Batch_size,) shouldn't be a column vector 189 | 190 | Returns: raw discriminator scores 191 | """ 192 | assert ( 193 | depth <= self.depth 194 | ), f"Requested output depth {depth} cannot be evaluated" 195 | 196 | if self.conditional: 197 | assert labels is not None, "Conditional discriminator requires labels" 198 | 199 | if depth > 2: 200 | residual = self.from_rgb[-(depth - 2)]( 201 | avg_pool2d(x, kernel_size=2, stride=2) 202 | ) 203 | straight = self.layers[-(depth - 1)](self.from_rgb[-(depth - 1)](x)) 204 | y = (alpha * straight) + ((1 - alpha) * residual) 205 | for layer_block in self.layers[-(depth - 2) : -1]: 206 | y = layer_block(y) 207 | else: 208 | y = self.from_rgb[-1](x) 209 | if self.conditional: 210 | y = self.layers[-1](y, labels) 211 | else: 212 | y = self.layers[-1](y) 213 | return y 214 | 215 | def get_save_info(self) -> Dict[str, Any]: 216 | return { 217 | "conf": { 218 | "depth": self.depth, 219 | "num_channels": self.num_channels, 220 | "latent_size": self.latent_size, 221 | "use_eql": self.use_eql, 222 | "num_classes": self.num_classes, 223 | }, 224 | "state_dict": self.state_dict(), 225 | } 226 | 227 | 228 | def create_generator_from_saved_model(saved_model_path: Path) -> Generator: 229 | # load the data from the saved_model 230 | loaded_data = torch.load(saved_model_path) 231 | 232 | # create a generator from the loaded data: 233 | generator_data = ( 234 | loaded_data["shadow_generator"] 235 | if "shadow_generator" in loaded_data 236 | else loaded_data["generator"] 237 | ) 238 | generator = Generator(**generator_data["conf"]) 239 | generator.load_state_dict(generator_data["state_dict"]) 240 | 241 | return generator 242 | 243 | 244 | 
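# Example usage sketch (illustrative only, not part of the original module; the
# checkpoint path below is hypothetical): load a trained generator checkpoint and
# sample a small batch of images at the highest trained resolution.
#
#     generator = create_generator_from_saved_model(Path("saved_models/model.bin"))
#     latents = torch.randn(4, generator.latent_size)
#     with torch.no_grad():
#         images = generator(latents, depth=generator.depth, alpha=1.0)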
def create_discriminator_from_saved_model(saved_model_path: Path) -> Discriminator: 245 | # load the data from the saved_model 246 | loaded_data = torch.load(saved_model_path) 247 | 248 | # create a discriminator from the loaded data: 249 | discriminator_data = ( 250 | loaded_data.get("shadow_discriminator", loaded_data["discriminator"]) 251 | ) 252 | discriminator = Discriminator(**discriminator_data["conf"]) 253 | discriminator.load_state_dict(discriminator_data["state_dict"]) 254 | 255 | return discriminator 256 | 257 | 258 | def load_models(generator_path: Path, discriminator_path: Path) -> Tuple[Generator, Discriminator]: 259 | generator = create_generator_from_saved_model(generator_path) 260 | discriminator = create_discriminator_from_saved_model(discriminator_path) 261 | return generator, discriminator 262 | -------------------------------------------------------------------------------- /pro_gan_pytorch/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/pro_gan_pytorch/test/__init__.py -------------------------------------------------------------------------------- /pro_gan_pytorch/test/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | # noinspection PyPackageRequirements 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def test_data_path() -> Path: 9 | return Path("/home/animesh/work/data/3d_scenes/forest_synthetic_struct/images") 10 | -------------------------------------------------------------------------------- /pro_gan_pytorch/test/test_custom_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..custom_layers import ( 4 | EqualizedConv2d, 5 | EqualizedConvTranspose2d, 6 | EqualizedLinear, 7 | MinibatchStdDev, 8 | PixelwiseNorm, 9 | ) 10 | from .utils import assert_almost_equal, device, assert_tensor_validity 11 | 12 | 13 | # noinspection PyPep8Naming 14 | def test_EqualizedConv2d() -> None: 15 | mock_in = torch.randn(32, 21, 16, 16).to(device) 16 | conv_block = EqualizedConv2d(21, 3, kernel_size=(3, 3), padding=1).to(device) 17 | print(f"Equalized conv block: {conv_block}") 18 | 19 | mock_out = conv_block(mock_in) 20 | 21 | # check output 22 | assert_tensor_validity(mock_out, (32, 3, 16, 16)) 23 | 24 | # check the weight's scale 25 | assert_almost_equal(conv_block.weight.data.std().cpu(), 1, error_margin=1e-1) 26 | 27 | 28 | # noinspection PyPep8Naming 29 | def test_EqualizedConvTranspose2d() -> None: 30 | mock_in = torch.randn(32, 21, 16, 16).to(device) 31 | 32 | conv_transpose_block = EqualizedConvTranspose2d( 33 | 21, 3, kernel_size=(3, 3), padding=1 34 | ).to(device) 35 | print(f"Equalized conv transpose block: {conv_transpose_block}") 36 | 37 | mock_out = conv_transpose_block(mock_in) 38 | 39 | # check output 40 | assert_tensor_validity(mock_out, (32, 3, 16, 16)) 41 | 42 | # check the weight's scale 43 | assert_almost_equal( 44 | conv_transpose_block.weight.data.std().cpu(), 1, error_margin=1e-1 45 | ) 46 | 47 | 48 | # noinspection PyPep8Naming 49 | def test_EqualizedLinear() -> None: 50 | # test the forward for the first res block 51 | mock_in = torch.randn(32, 13).to(device) 52 | 53 | lin_block = EqualizedLinear(13, 52).to(device) 54 | print(f"Equalized linear block: {lin_block}") 55 | 56 | mock_out = lin_block(mock_in) 57 | 58 | # check output 59 | 
assert_tensor_validity(mock_out, (32, 52)) 60 | 61 | # check the weight's scale 62 | assert_almost_equal(lin_block.weight.data.std().cpu(), 1, error_margin=1e-1) 63 | 64 | 65 | # noinspection PyPep8Naming 66 | def test_PixelwiseNorm() -> None: 67 | mock_in = torch.randn(1, 13, 1, 1).to(device) 68 | normalizer = PixelwiseNorm() 69 | print(f"\nNormalizerBlock: {normalizer}") 70 | mock_out = normalizer(mock_in) 71 | 72 | # check output 73 | assert_tensor_validity(mock_out, mock_in.shape) 74 | 75 | # we cannot comment that the norm of the output tensor 76 | # will always be less than the norm of the input tensor 77 | # so no more checking can be done 78 | 79 | 80 | # noinspection PyPep8Naming 81 | def test_MinibatchStdDev() -> None: 82 | mock_in = torch.randn(16, 13, 16, 16).to(device) 83 | minStdD = MinibatchStdDev() 84 | print(f"\nMiniBatchStdDevBlock: {minStdD}") 85 | mock_out = minStdD(mock_in) 86 | 87 | # check output 88 | assert mock_out.shape[1] == mock_in.shape[1] + 1 89 | assert_tensor_validity(mock_out, (16, 14, 16, 16)) 90 | -------------------------------------------------------------------------------- /pro_gan_pytorch/test/test_gan.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | # noinspection PyPackageRequirements 4 | import matplotlib.pyplot as plt 5 | 6 | import torch 7 | 8 | from ..data_tools import ImageDirectoryDataset, get_transform 9 | from ..gan import ProGAN 10 | from ..networks import Discriminator, Generator 11 | from .utils import device 12 | 13 | 14 | def test_pro_gan_progressive_downsample_batch() -> None: 15 | batch = torch.randn((4, 3, 1024, 1024)).to(device) 16 | batch = torch.clamp(batch, min=0, max=1) 17 | progan = ProGAN(Generator(10), Discriminator(10), device=device) 18 | 19 | for res_log2 in range(2, 10): 20 | modified_batch = progan.progressive_downsample_batch( 21 | batch, depth=res_log2, alpha=0.001 22 | ) 23 | print(f"Downsampled batch at res_log2 {res_log2}: {modified_batch.shape}") 24 | plt.figure() 25 | plt.title(f"Image at resolution: {int(2 ** res_log2)}x{int(2 ** res_log2)}") 26 | plt.imshow(modified_batch.permute((0, 2, 3, 1))[0].cpu().numpy()) 27 | assert modified_batch.shape == ( 28 | batch.shape[0], 29 | batch.shape[1], 30 | int(2 ** res_log2), 31 | int(2 ** res_log2), 32 | ) 33 | 34 | plt.figure() 35 | plt.title(f"Image at resolution: {1024}x{1024}") 36 | plt.imshow(batch.permute((0, 2, 3, 1))[0].cpu().numpy()) 37 | plt.show() 38 | 39 | 40 | def test_pro_gan_train(test_data_path: Path) -> None: 41 | depth = 4 42 | progan = ProGAN(Generator(depth), Discriminator(depth), device=device) 43 | progan.train( 44 | dataset=ImageDirectoryDataset( 45 | test_data_path, 46 | transform=get_transform( 47 | new_size=(int(2 ** depth), int(2 ** depth)), flip_horizontal=False 48 | ), 49 | rec_dir=False, 50 | ), 51 | epochs=[10 for _ in range(3)], 52 | batch_sizes=[256, 256, 256], 53 | fade_in_percentages=[50 for _ in range(3)], 54 | save_dir=Path("./test_train"), 55 | num_samples=64, 56 | feedback_factor=10, 57 | ) 58 | print("test_finished") 59 | -------------------------------------------------------------------------------- /pro_gan_pytorch/test/test_networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | from ..networks import Discriminator, Generator 6 | from .utils import device 7 | 8 | 9 | # noinspection PyPep8Naming 10 | def test_Generator() -> None: 11 | batch_size, latent_size = 2, 512 12 
| num_channels = 3 13 | depth = 10 # resolution 1024 x 1024 14 | mock_generator = Generator(depth=depth, num_channels=num_channels).to(device) 15 | mock_latent = torch.randn((batch_size, latent_size)).to(device) 16 | 17 | print(f"Generator Network:\n{mock_generator}") 18 | 19 | with torch.no_grad(): 20 | for res_log2 in range(2, depth + 1): 21 | rgb_images = mock_generator(mock_latent, depth=res_log2, alpha=1) 22 | print(f"RGB output shape at depth {res_log2}: {rgb_images.shape}") 23 | assert rgb_images.shape == ( 24 | batch_size, 25 | num_channels, 26 | 2 ** res_log2, 27 | 2 ** res_log2, 28 | ) 29 | assert torch.isnan(rgb_images).sum().item() == 0 30 | assert torch.isinf(rgb_images).sum().item() == 0 31 | 32 | 33 | # noinspection PyPep8Naming 34 | def test_DiscriminatorUnconditional() -> None: 35 | batch_size, latent_size = 2, 512 36 | num_channels = 3 37 | depth = 10 # resolution 1024 x 1024 38 | mock_discriminator = Discriminator(depth=depth, num_channels=num_channels).to( 39 | device 40 | ) 41 | mock_inputs = [ 42 | torch.randn((batch_size, num_channels, 2 ** stage, 2 ** stage)).to(device) 43 | for stage in range(2, depth + 1) 44 | ] 45 | 46 | print(f"Discriminator Network:\n{mock_discriminator}") 47 | 48 | with torch.no_grad(): 49 | for res_log2 in range(2, depth + 1): 50 | mock_input = mock_inputs[res_log2 - 2] 51 | print(f"RGB input image shape at depth {res_log2}: {mock_input.shape}") 52 | score = mock_discriminator(mock_input, depth=res_log2, alpha=1) 53 | assert score.shape == (batch_size,) 54 | assert torch.isnan(score).sum().item() == 0 55 | assert torch.isinf(score).sum().item() == 0 56 | 57 | 58 | # noinspection PyPep8Naming 59 | def test_DiscriminatorConditional() -> None: 60 | batch_size, latent_size = 2, 512 61 | num_channels = 3 62 | depth = 10 # resolution 1024 x 1024 63 | mock_discriminator = Discriminator( 64 | depth=depth, num_channels=num_channels, num_classes=10 65 | ).to(device) 66 | mock_inputs = [ 67 | torch.randn((batch_size, num_channels, 2 ** stage, 2 ** stage)).to(device) 68 | for stage in range(2, depth + 1) 69 | ] 70 | mock_labels = torch.from_numpy(np.array([3, 7])).to(device) 71 | 72 | print(f"Discriminator Network:\n{mock_discriminator}") 73 | with torch.no_grad(): 74 | for res_log2 in range(2, depth + 1): 75 | mock_input = mock_inputs[res_log2 - 2] 76 | print(f"RGB input image shape at depth {res_log2}: {mock_input.shape}") 77 | score = mock_discriminator( 78 | mock_input, depth=res_log2, alpha=1, labels=mock_labels 79 | ) 80 | assert score.shape == (batch_size,) 81 | assert torch.isnan(score).sum().item() == 0 82 | assert torch.isinf(score).sum().item() == 0 83 | -------------------------------------------------------------------------------- /pro_gan_pytorch/test/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import numpy as np 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def assert_almost_equal(x: Any, y: Any, error_margin: float = 3.0) -> None: 12 | assert np.abs(x - y) <= error_margin 13 | 14 | 15 | def assert_tensor_validity( 16 | test_tensor: Tensor, expected_shape: Tuple[int, ...] 
17 | ) -> None: 18 | assert test_tensor.shape == expected_shape 19 | assert torch.isnan(test_tensor).sum().item() == 0 20 | assert torch.isinf(test_tensor).sum().item() == 0 21 | -------------------------------------------------------------------------------- /pro_gan_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import Optional, Tuple 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | from pro_gan_pytorch import losses 11 | from pro_gan_pytorch.losses import WganGP, StandardGAN 12 | 13 | 14 | def adjust_dynamic_range( 15 | data: Tensor, 16 | drange_in: Optional[Tuple[float, float]] = (-1.0, 1.0), 17 | drange_out: Optional[Tuple[float, float]] = (0.0, 1.0), 18 | ): 19 | if drange_in != drange_out: 20 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / ( 21 | np.float32(drange_in[1]) - np.float32(drange_in[0]) 22 | ) 23 | bias = np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale 24 | data = data * scale + bias 25 | 26 | return torch.clamp(data, min=drange_out[0], max=drange_out[1]) 27 | 28 | 29 | def post_process_generated_images(imgs: Tensor) -> np.array: 30 | imgs = adjust_dynamic_range( 31 | imgs.permute(0, 2, 3, 1), drange_in=(-1.0, 1.0), drange_out=(0.0, 1.0) 32 | ) 33 | return (imgs * 255.0).detach().cpu().numpy().astype(np.uint8) 34 | 35 | 36 | def str2bool(v): 37 | if isinstance(v, bool): 38 | return v 39 | if v.lower() in ("yes", "true", "t", "y", "1"): 40 | return True 41 | elif v.lower() in ("no", "false", "f", "n", "0"): 42 | return False 43 | else: 44 | raise argparse.ArgumentTypeError("Boolean value expected.") 45 | 46 | 47 | # noinspection PyPep8Naming 48 | def str2GANLoss(v): 49 | if v.lower() == "wgan_gp": 50 | return WganGP() 51 | elif v.lower() == "standard_gan": 52 | return StandardGAN() 53 | else: 54 | raise argparse.ArgumentTypeError( 55 | "Unknown gan-loss function requested." 
56 | f"Please consider contributing a your GANLoss to: " 57 | f"{str(Path(losses.__file__).absolute())}" 58 | ) 59 | -------------------------------------------------------------------------------- /pro_gan_pytorch_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/pro_gan_pytorch_scripts/__init__.py -------------------------------------------------------------------------------- /pro_gan_pytorch_scripts/compute_fid.py: -------------------------------------------------------------------------------- 1 | """ script for computing the fid of a trained model when compared with the dataset images """ 2 | import argparse 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import imageio as imageio 7 | import torch 8 | from cleanfid import fid 9 | from torch.backends import cudnn 10 | from tqdm import tqdm 11 | 12 | from pro_gan_pytorch.networks import create_generator_from_saved_model 13 | from pro_gan_pytorch.utils import post_process_generated_images 14 | 15 | # turn fast mode on 16 | cudnn.benchmark = True 17 | 18 | 19 | def parse_arguments() -> argparse.Namespace: 20 | """ 21 | Returns: parsed arguments object 22 | """ 23 | parser = argparse.ArgumentParser("ProGAN fid_score computation tool", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,) 25 | 26 | # fmt: off 27 | # required arguments 28 | parser.add_argument("model_path", action="store", type=Path, 29 | help="path to the trained_model.bin file") 30 | parser.add_argument("dataset_path", action="store", type=Path, 31 | help="path to the directory containing the images from the dataset. " 32 | "Note that this needs to be a flat directory") 33 | 34 | # optional arguments 35 | parser.add_argument("--generated_images_path", action="store", type=Path, default=None, required=False, 36 | help="path to the directory where the generated images are to be written. " 37 | "Uses a temporary directory by default. 
Provide this path if you'd like " 38 | "to see the generated images yourself :).") 39 | parser.add_argument("--batch_size", action="store", type=int, default=4, required=False, 40 | help="batch size used for generating random images") 41 | parser.add_argument("--num_generated_images", action="store", type=int, default=50_000, required=False, 42 | help="number of generated images used for computing the FID") 43 | # fmt: on 44 | 45 | args = parser.parse_args() 46 | 47 | return args 48 | 49 | 50 | def compute_fid(args: argparse.Namespace) -> None: 51 | """ 52 | compute the fid for a given trained pro-gan model 53 | Args: 54 | args: configuration used for the fid computation 55 | Returns: None 56 | 57 | """ 58 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 59 | 60 | # load the data from the trained-model 61 | print(f"loading data from the trained model at: {args.model_path}") 62 | generator = create_generator_from_saved_model(args.model_path).to(device) 63 | 64 | # create the generated images directory: 65 | if args.generated_images_path is not None: 66 | args.generated_images_path.mkdir(parents=True, exist_ok=True) 67 | generated_images_path = ( 68 | args.generated_images_path 69 | if args.generated_images_path is not None 70 | else tempfile.TemporaryDirectory() 71 | ) 72 | if args.generated_images_path is None: 73 | image_writing_path = Path(generated_images_path.name) 74 | else: 75 | image_writing_path = generated_images_path 76 | 77 | print("generating random images from the trained generator ...") 78 | with torch.no_grad(): 79 | for img_num in tqdm(range(0, args.num_generated_images, args.batch_size)): 80 | num_imgs = min(args.batch_size, args.num_generated_images - img_num) 81 | random_latents = torch.randn(num_imgs, generator.latent_size, device=device) 82 | gen_imgs = post_process_generated_images(generator(random_latents)) 83 | 84 | # write the batch of generated images: 85 | for batch_num, gen_img in enumerate(gen_imgs, start=1): 86 | imageio.imwrite( 87 | image_writing_path / f"{img_num + batch_num}.png", 88 | gen_img, 89 | ) 90 | 91 | # compute the fid once all images are generated 92 | print("computing fid ...") 93 | score = fid.compute_fid( 94 | fdir1=args.dataset_path, 95 | fdir2=image_writing_path, 96 | mode="clean", 97 | num_workers=4, 98 | ) 99 | print(f"fid score: {score: .3f}") 100 | 101 | # most importantly, don't forget to do the cleanup on the temporary directory: 102 | if hasattr(generated_images_path, "cleanup"): 103 | generated_images_path.cleanup() 104 | 105 | 106 | def main() -> None: 107 | """ 108 | Main function of the script 109 | Returns: None 110 | """ 111 | compute_fid(parse_arguments()) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /pro_gan_pytorch_scripts/latent_space_interpolation.py: -------------------------------------------------------------------------------- 1 | """ script for writing a video of the latent space interpolation from a trained model """ 2 | import argparse 3 | from pathlib import Path 4 | 5 | import cv2 6 | import torch 7 | from scipy.ndimage import gaussian_filter 8 | from torch.backends import cudnn 9 | from tqdm import tqdm 10 | 11 | from pro_gan_pytorch.networks import create_generator_from_saved_model 12 | from pro_gan_pytorch.utils import post_process_generated_images 13 | 14 | # turn fast mode on 15 | cudnn.benchmark = True 16 | 17 | 18 | def parse_arguments(): 19 | """ 20 | command line arguments parser 21 | 
:return: args => parsed command line arguments 22 | """ 23 | parser = argparse.ArgumentParser("ProGAN latent-space walk demo video creation tool", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,) 25 | 26 | # fmt: off 27 | # required arguments 28 | parser.add_argument("model_path", action="store", type=Path, 29 | help="path to the trained_model.bin file") 30 | 31 | # options related to the video 32 | parser.add_argument("--output_path", action="store", type=Path, required=False, 33 | default="./latent_space_walk.mp4", 34 | help="path to the output video file location. " 35 | "Please only use mp4 format with this tool (.mp4 extension). " 36 | "I have banged my head too much to get anything else to work :(.") 37 | parser.add_argument("--generation_depth", action="store", type=int, default=None, required=False, 38 | help="depth at which the images should be generated. " 39 | "Starts from 2 --> (4x4) | 3 --> (8x8) etc. Uses the highest resolution by default. ") 40 | parser.add_argument("--time", action="store", type=float, default=30, required=False, 41 | help="number of seconds in the video") 42 | parser.add_argument("--fps", action="store", type=int, default=60, required=False, 43 | help="fps of the generated video") 44 | parser.add_argument("--smoothing", action="store", type=float, default=0.75, required=False, 45 | help="smoothness of walking in the latent-space." 46 | " High values corresponds to more smoothing.") 47 | # fmt: on 48 | 49 | args = parser.parse_args() 50 | 51 | return args 52 | 53 | 54 | def latent_space_interpolation(args): 55 | """ 56 | Generate a video of the latent space walk (interpolation) 57 | Args: 58 | args: configuration used for the lsid 59 | Returns: None (writes generated video to disk) 60 | """ 61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | 63 | # load the data from the trained-model 64 | print(f"loading data from the trained model at: {args.model_path}") 65 | generator = create_generator_from_saved_model(args.model_path).to(device) 66 | 67 | # total_frames in the video: 68 | total_frames = int(args.time * args.fps) 69 | 70 | # create the video from the latent space interpolation (walk) 71 | # all latent vectors for each and every frame: 72 | all_latents = torch.randn(total_frames, generator.latent_size).to(device) 73 | all_latents = gaussian_filter(all_latents.cpu(), [args.smoothing * args.fps, 0]) 74 | all_latents = torch.from_numpy(all_latents).to(device) 75 | 76 | # create output directory 77 | args.output_path.parent.mkdir(parents=True, exist_ok=True) 78 | 79 | # make the cv2 video object 80 | print("Generating the video frames ...") 81 | generation_depth = ( 82 | generator.depth if args.generation_depth is None else args.generation_depth 83 | ) 84 | img_dim = 2 ** generation_depth 85 | video_out = cv2.VideoWriter( 86 | str(args.output_path), 87 | cv2.VideoWriter_fourcc(*"mp4v"), 88 | args.fps, 89 | (img_dim, img_dim), 90 | ) 91 | 92 | # Run the main loop for the interpolation: 93 | with torch.no_grad(): # no need to compute gradients here :) 94 | for latent in tqdm(all_latents): 95 | latent = torch.unsqueeze(latent, dim=0) 96 | 97 | # generate the image for this latent vector: 98 | frame = post_process_generated_images( 99 | generator(latent, depth=generation_depth) 100 | ) 101 | frame = frame[0, ..., ::-1] # need to reverse the channel order for cv2 :D 102 | 103 | # write the generated frame to the video 104 | video_out.write(frame) 105 | 106 | print(f"video has been generated and saved to 
{args.output_path}") 107 | 108 | # don't forget to close the video stream :) 109 | video_out.release() 110 | 111 | 112 | def main() -> None: 113 | """ 114 | Main function of the script 115 | Returns: None 116 | """ 117 | latent_space_interpolation(parse_arguments()) 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /pro_gan_pytorch_scripts/train.py: -------------------------------------------------------------------------------- 1 | """ script for training a ProGAN (Progressively grown gan model) """ 2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch.backends import cudnn 8 | 9 | from pro_gan_pytorch.data_tools import ImageDirectoryDataset, get_transform 10 | from pro_gan_pytorch.gan import ProGAN 11 | from pro_gan_pytorch.networks import Discriminator, Generator 12 | from pro_gan_pytorch.utils import str2bool, str2GANLoss 13 | 14 | # turn fast mode on 15 | cudnn.benchmark = True 16 | 17 | # define the device for the training script 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | def parse_arguments() -> argparse.Namespace: 22 | """ 23 | command line arguments parser 24 | Returns: args => parsed command line arguments 25 | """ 26 | parser = argparse.ArgumentParser( 27 | "Train Progressively grown GAN", 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 29 | ) 30 | 31 | # fmt: off 32 | # Required arguments (input path to the data and the output directory for saving training assets) 33 | parser.add_argument("train_path", action="store", type=Path, 34 | help="Path to the images folder for training the ProGAN") 35 | parser.add_argument("output_dir", action="store", type=Path, 36 | help="Path to the directory for saving the logs and models") 37 | 38 | # Optional arguments 39 | # for retraining a model options: 40 | parser.add_argument("--retrain", action="store", type=str2bool, default=False, required=False, 41 | help="whenever you want to resume training from saved models") 42 | parser.add_argument("--generator_path", action="store", type=Path, required="--retrain" in sys.argv, 43 | help="Path to the generator model for retraining the ProGAN") 44 | parser.add_argument("--discriminator_path", action="store", type=Path, required="--retrain" in sys.argv, 45 | help="Path to the discriminator model for retraining the ProGAN") 46 | # dataset related options: 47 | parser.add_argument("--rec_dir", action="store", type=str2bool, default=False, required=False, 48 | help="whether images are stored under one folder or under a recursive dir structure") 49 | parser.add_argument("--flip_horizontal", action="store", type=str2bool, default=True, required=False, 50 | help="whether to apply mirror (horizontal) augmentation") 51 | 52 | # model architecture related options: 53 | parser.add_argument("--depth", action="store", type=int, default=10, required=False, 54 | help="depth of the generator and the discriminator. Starts from 2. " 55 | "Example 2 --> (4x4) | 3 --> (8x8) ... 
| 10 --> (1024x1024)") 56 | parser.add_argument("--num_channels", action="store", type=int, default=3, required=False, 57 | help="number of channels in the image data") 58 | parser.add_argument("--latent_size", action="store", type=int, default=512, required=False, 59 | help="latent size of the generator and the discriminator") 60 | 61 | # training related options: 62 | parser.add_argument("--use_eql", action="store", type=str2bool, default=True, required=False, 63 | help="whether to use the equalized learning rate") 64 | parser.add_argument("--use_ema", action="store", type=str2bool, default=True, required=False, 65 | help="whether to use the exponential moving average of generator weights. " 66 | "Keeps two copies of the generator model; an instantaneous one and " 67 | "the averaged one.") 68 | parser.add_argument("--ema_beta", action="store", type=float, default=0.999, required=False, 69 | help="value of the ema beta") 70 | parser.add_argument("--epochs", action="store", type=int, required=False, nargs="+", 71 | default=[42 for _ in range(9)], 72 | help="number of epochs over the training dataset per stage") 73 | parser.add_argument("--batch_sizes", action="store", type=int, required=False, nargs="+", 74 | default=[32, 32, 32, 32, 16, 16, 8, 4, 2], 75 | help="batch size used for training the model per stage") 76 | parser.add_argument("--batch_repeats", action="store", type=int, required=False, default=4, 77 | help="number of G and D steps executed per training iteration") 78 | parser.add_argument("--fade_in_percentages", action="store", type=int, required=False, nargs="+", 79 | default=[50 for _ in range(9)], 80 | help="number of iterations for which fading of new layer happens. Measured in percentage") 81 | parser.add_argument("--loss_fn", action="store", type=str2GANLoss, required=False, default="wgan_gp", 82 | help="loss function used for training the GAN. " 83 | "Current options: [wgan_gp, standard_gan]") 84 | parser.add_argument("--g_lrate", action="store", type=float, required=False, default=0.003, 85 | help="learning rate used by the generator") 86 | parser.add_argument("--d_lrate", action="store", type=float, required=False, default=0.003, 87 | help="learning rate used by the discriminator") 88 | parser.add_argument("--num_feedback_samples", action="store", type=int, required=False, default=4, 89 | help="number of samples used for fixed seed gan feedback") 90 | parser.add_argument("--start_depth", action="store", type=int, required=False, default=2, 91 | help="resolution to start the training from. " 92 | "Example 2 --> (4x4) | 3 --> (8x8) ... | 10 --> (1024x1024). " 93 | "Note that this is not a way to restart a partial training. " 94 | "Resuming is not supported currently. But will soon be.") 95 | parser.add_argument("--num_workers", action="store", type=int, required=False, default=4, 96 | help="number of dataloader subprocesses. It's a pytorch thing, you can ignore it ;)." 
97 | " Leave it to the default value unless things are weirdly slow for you.") 98 | parser.add_argument("--feedback_factor", action="store", type=int, required=False, default=10, 99 | help="number of feedback logs written per epoch") 100 | parser.add_argument("--checkpoint_factor", action="store", type=int, required=False, default=10, 101 | help="number of epochs after which a model snapshot is saved per training stage") 102 | # fmt: on 103 | 104 | parsed_args = parser.parse_args() 105 | return parsed_args 106 | 107 | 108 | def train_progan(args: argparse.Namespace) -> None: 109 | """ 110 | method to train the progan (progressively grown gan) given the configuration parameters 111 | Args: 112 | args: configuration used for the training 113 | Returns: None 114 | """ 115 | print(f"Selected arguments: {args}") 116 | 117 | if args.retrain: 118 | print(f"Retraining the ProGAN: `depth`, `num_channels`, `latent_size`, `use_eql` parameters will be ignored if " 119 | f"specified.") 120 | generator, discriminator = load_models(args.generator_path, args.discriminator_path) 121 | args.depth = generator.depth 122 | args.num_channels = generator.num_channels 123 | args.latent_size = generator.latent_size 124 | args.use_eql = generator.use_eql 125 | else: 126 | generator = Generator( 127 | depth=args.depth, 128 | num_channels=args.num_channels, 129 | latent_size=args.latent_size, 130 | use_eql=args.use_eql, 131 | ) 132 | discriminator = Discriminator( 133 | depth=args.depth, 134 | num_channels=args.num_channels, 135 | latent_size=args.latent_size, 136 | use_eql=args.use_eql, 137 | ) 138 | 139 | progan = ProGAN( 140 | generator, 141 | discriminator, 142 | device=device, 143 | use_ema=args.use_ema, 144 | ema_beta=args.ema_beta, 145 | ) 146 | 147 | progan.train( 148 | dataset=ImageDirectoryDataset( 149 | args.train_path, 150 | transform=get_transform( 151 | new_size=(int(2 ** args.depth), int(2 ** args.depth)), 152 | flip_horizontal=args.flip_horizontal, 153 | ), 154 | rec_dir=args.rec_dir, 155 | ), 156 | epochs=args.epochs, 157 | batch_sizes=args.batch_sizes, 158 | fade_in_percentages=args.fade_in_percentages, 159 | loss_fn=args.loss_fn, 160 | batch_repeats=args.batch_repeats, 161 | gen_learning_rate=args.g_lrate, 162 | dis_learning_rate=args.d_lrate, 163 | num_samples=args.num_feedback_samples, 164 | start_depth=args.start_depth, 165 | num_workers=args.num_workers, 166 | feedback_factor=args.feedback_factor, 167 | checkpoint_factor=args.checkpoint_factor, 168 | save_dir=args.output_dir, 169 | ) 170 | 171 | 172 | def main() -> None: 173 | """ 174 | Main function of the script 175 | Returns: None 176 | """ 177 | train_progan(parse_arguments()) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest==6.2.2 2 | black==20.8b1 3 | matplotlib==3.5.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.4 2 | torch==1.10.0 3 | torchvision==0.11.1 4 | Pillow==9.0.0 5 | tensorboard==2.7.0 6 | imageio==2.12.0 7 | tqdm==4.62.3 8 | scipy==1.7.2 9 | opencv-python==4.5.4.60 10 | clean-fid==0.1.15 11 | -------------------------------------------------------------------------------- /samples/.gitignore: 
-------------------------------------------------------------------------------- 1 | # ignore the full version of the training video 2 | pro-gan_training_video_smaller.mp4 3 | 4 | # also ignore the trained model weights 5 | GAN_GEN_SHADOW_8.pth 6 | 7 | # ignore some huge videos: 8 | interpolation.mp4 9 | video_2.gif 10 | video_3.gif 11 | 12 | # ignore the new latent_space interpolation video 13 | new_interp.mp4 14 | 15 | frames_pro/ 16 | frames_mine/ 17 | M_GAN_GEN_SHADOW_8.pth 18 | -------------------------------------------------------------------------------- /samples/celebA-HQ.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/samples/celebA-HQ.gif -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("requirements.txt", "r") as file_: 4 | project_requirements = file_.read().split("\n") 5 | 6 | setup( 7 | name="pro-gan-pth", 8 | version="3.4", 9 | packages=["pro_gan_pytorch", "pro_gan_pytorch_scripts"], 10 | url="https://github.com/akanimax/pro_gan_pytorch", 11 | license="MIT", 12 | author="akanimax", 13 | author_email="akanimax@gmail.com", 14 | setup_requires=['wheel'], 15 | description="ProGAN package implemented as an extension of PyTorch nn.Module", 16 | install_requires=project_requirements, 17 | entry_points={ 18 | "console_scripts": [ 19 | f"progan_train=pro_gan_pytorch_scripts.train:main", 20 | f"progan_lsid=pro_gan_pytorch_scripts.latent_space_interpolation:main", 21 | f"progan_fid=pro_gan_pytorch_scripts.compute_fid:main", 22 | ] 23 | }, 24 | ) 25 | --------------------------------------------------------------------------------