├── .gitignore
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── README.rst
├── pro_gan_pytorch
│   ├── __init__.py
│   ├── custom_layers.py
│   ├── data_tools.py
│   ├── gan.py
│   ├── losses.py
│   ├── modules.py
│   ├── networks.py
│   ├── test
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── test_custom_layers.py
│   │   ├── test_gan.py
│   │   ├── test_networks.py
│   │   └── utils.py
│   └── utils.py
├── pro_gan_pytorch_scripts
│   ├── __init__.py
│   ├── compute_fid.py
│   ├── latent_space_interpolation.py
│   └── train.py
├── requirements-dev.txt
├── requirements.txt
├── samples
│   ├── .gitignore
│   └── celebA-HQ.gif
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # ignore pycharm data
107 | .idea/
108 |
109 | # ignore the virtual environment for the project as well
110 | pro_gan_pytorch_env/
111 |
112 | # ignore the test_train folder created by one of the tests:
113 | test_train/
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Animesh Karnewar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # need to include the following files additionally for the setup.py to work
2 | include requirements.txt
3 | include scripts/*
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pro_gan_pytorch
  2 | **Unofficial PyTorch** implementation of the paper "Progressive Growing of GANs for Improved
  3 | Quality, Stability, and Variation".
4 | For the official TensorFlow code, please refer to
5 | [this repo](https://github.com/tkarras/progressive_growing_of_gans)
6 |
7 | 
8 | 
9 |
10 | # How to use:
11 | ### Using the package
 12 | **Requirements (i.e. what we tested with):**
13 | 1. **Ubuntu** `20.04.3` or above
14 | 2. Python `3.8.3`
 15 | 3. Nvidia GPU `GeForce 1080 Ti or above` with a minimum of `8GB` GPU memory
16 | 4. Nvidia drivers >= `470.86`
 17 | 5. Nvidia CUDA `11.3` | can be skipped since PyTorch ships with CUDA, cuDNN etc.
18 |
19 | **Installing the package**
 20 | 1. The easiest way is to create a new virtual environment
 21 | so that your global Python environment doesn't get corrupted
 22 | 2. Create and switch to your new virtual environment
23 | ```
 24 | (your-machine):~$ python3 -m venv ./pro_gan_pth_env
 25 | (your-machine):~$ source ./pro_gan_pth_env/bin/activate
26 | ```
 27 | 3. Install the `pro-gan-pth` package from PyPI, assuming you meet
 28 | all the above dependencies
29 | ```
30 | (pro_gan_pth_env)(your-machine):~$ pip install pro-gan-pth
31 | ```
 32 | 4. Once installed, you can either use the installed command-line tools
 33 | `progan_train`, `progan_lsid` and `progan_fid` (see the help output and the
 34 | example invocation below). Note that `progan_train` can be used with multiple GPUs
 35 | (if you have many :smile:). Just ensure that the GPUs are visible through the
 36 | `CUDA_VISIBLE_DEVICES` environment variable (e.g. `CUDA_VISIBLE_DEVICES=0,1,2`).
 37 | The other two tools only use a single GPU.
38 |
39 |
40 | ```
41 | (your-machine):~$ progan_train --help
42 | usage: Train Progressively grown GAN
43 | [-h]
44 | [--retrain RETRAIN]
45 | [--generator_path GENERATOR_PATH]
46 | [--discriminator_path DISCRIMINATOR_PATH]
47 | [--rec_dir REC_DIR]
48 | [--flip_horizontal FLIP_HORIZONTAL]
49 | [--depth DEPTH]
50 | [--num_channels NUM_CHANNELS]
51 | [--latent_size LATENT_SIZE]
52 | [--use_eql USE_EQL]
53 | [--use_ema USE_EMA]
54 | [--ema_beta EMA_BETA]
55 | [--epochs EPOCHS [EPOCHS ...]]
56 | [--batch_sizes BATCH_SIZES [BATCH_SIZES ...]]
57 | [--batch_repeats BATCH_REPEATS]
58 | [--fade_in_percentages FADE_IN_PERCENTAGES [FADE_IN_PERCENTAGES ...]]
59 | [--loss_fn LOSS_FN]
60 | [--g_lrate G_LRATE]
61 | [--d_lrate D_LRATE]
62 | [--num_feedback_samples NUM_FEEDBACK_SAMPLES]
63 | [--start_depth START_DEPTH]
64 | [--num_workers NUM_WORKERS]
65 | [--feedback_factor FEEDBACK_FACTOR]
66 | [--checkpoint_factor CHECKPOINT_FACTOR]
67 | train_path
68 | output_dir
69 |
70 | positional arguments:
71 | train_path Path to the images folder for training the ProGAN
72 | output_dir Path to the directory for saving the logs and models
73 |
74 | optional arguments:
75 | -h, --help show this help message and exit
 76 | --retrain RETRAIN whether you want to resume training from saved models (default: False)
77 | --generator_path GENERATOR_PATH
78 | Path to the generator model for retraining the ProGAN (default: None)
79 | --discriminator_path DISCRIMINATOR_PATH
 80 | Path to the discriminator model for retraining the ProGAN (default: None)
81 | --rec_dir REC_DIR whether images stored under one folder or has a recursive dir structure (default: True)
82 | --flip_horizontal FLIP_HORIZONTAL
83 | whether to apply mirror augmentation (default: True)
84 | --depth DEPTH depth of the generator and the discriminator (default: 10)
85 | --num_channels NUM_CHANNELS
86 | number of channels of in the image data (default: 3)
87 | --latent_size LATENT_SIZE
88 | latent size of the generator and the discriminator (default: 512)
89 | --use_eql USE_EQL whether to use the equalized learning rate (default: True)
90 | --use_ema USE_EMA whether to use the exponential moving averages (default: True)
91 | --ema_beta EMA_BETA value of the ema beta (default: 0.999)
92 | --epochs EPOCHS [EPOCHS ...]
93 | number of epochs over the training dataset per stage (default: [42, 42, 42, 42, 42, 42, 42, 42, 42])
94 | --batch_sizes BATCH_SIZES [BATCH_SIZES ...]
95 | batch size used for training the model per stage (default: [32, 32, 32, 32, 16, 16, 8, 4, 2])
96 | --batch_repeats BATCH_REPEATS
97 | number of G and D steps executed per training iteration (default: 4)
98 | --fade_in_percentages FADE_IN_PERCENTAGES [FADE_IN_PERCENTAGES ...]
99 | number of iterations for which fading of new layer happens. Measured in percentage (default: [50, 50, 50, 50, 50, 50, 50, 50, 50])
100 | --loss_fn LOSS_FN loss function used for training the GAN. Current options: [wgan_gp, standard_gan] (default: wgan_gp)
101 | --g_lrate G_LRATE learning rate used by the generator (default: 0.003)
102 | --d_lrate D_LRATE learning rate used by the discriminator (default: 0.003)
103 | --num_feedback_samples NUM_FEEDBACK_SAMPLES
104 | number of samples used for fixed seed gan feedback (default: 4)
105 | --start_depth START_DEPTH
106 | resolution to start the training from. Example: 2 --> (4x4) | 3 --> (8x8) ... | 10 --> (1024x1024). Note that this is not a way to restart a partial training. Resuming is not
107 | supported currently, but will be soon. (default: 2)
108 | --num_workers NUM_WORKERS
109 | number of dataloader subprocesses. It's a pytorch thing, you can ignore it ;). Leave it to the default value unless things are weirdly slow for you. (default: 4)
110 | --feedback_factor FEEDBACK_FACTOR
111 | number of feedback logs written per epoch (default: 10)
112 | --checkpoint_factor CHECKPOINT_FACTOR
113 | number of epochs after which a model snapshot is saved per training stage (default: 10)
114 |
115 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------
116 |
117 | (your-machine):~$ progan_lsid --help
118 | usage: ProGAN latent-space walk demo video creation tool [-h] [--output_path OUTPUT_PATH] [--generation_depth GENERATION_DEPTH] [--time TIME] [--fps FPS] [--smoothing SMOOTHING] model_path
119 |
120 | positional arguments:
121 | model_path path to the trained_model.bin file
122 |
123 | optional arguments:
124 | -h, --help show this help message and exit
125 | --output_path OUTPUT_PATH
126 | path to the output video file location. Please only use mp4 format with this tool (.mp4 extension). I have banged my head too much to get anything else to work :(. (default:
127 | ./latent_space_walk.mp4)
128 | --generation_depth GENERATION_DEPTH
129 | depth at which the images should be generated. Starts from 2 --> (4x4) | 3 --> (8x8) etc. (default: None)
130 | --time TIME number of seconds in the video (default: 30)
131 | --fps FPS fps of the generated video (default: 60)
132 | --smoothing SMOOTHING
133 | smoothness of walking in the latent-space. High values corresponds to more smoothing. (default: 0.75)
134 |
135 | ------------------------------------------------------------------------------------------------------------------------------------------------------------------
136 |
137 | (your-machine):~$ progan_fid --help
138 | usage: ProGAN fid_score computation tool [-h] [--generated_images_path GENERATED_IMAGES_PATH] [--batch_size BATCH_SIZE] [--num_generated_images NUM_GENERATED_IMAGES] model_path dataset_path
139 |
140 | positional arguments:
141 | model_path path to the trained_model.bin file
142 | dataset_path path to the directory containing the images from the dataset. Note that this needs to be a flat directory
143 |
144 | optional arguments:
145 | -h, --help show this help message and exit
146 | --generated_images_path GENERATED_IMAGES_PATH
147 | path to the directory where the generated images are to be written. Uses a temporary directory by default. Provide this path if you'd like to see the generated images yourself
148 | :). (default: None)
149 | --batch_size BATCH_SIZE
150 | batch size used for generating random images (default: 4)
151 | --num_generated_images NUM_GENERATED_IMAGES
152 | number of generated images used for computing the FID (default: 50000)
153 | ```
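
For instance, a minimal training run at 256 x 256 (depth 8) could look like the following. The data and output paths are placeholders; the flags are the ones documented in the help output above, with one value per stage (depth - 1 stages):

```
(pro_gan_pth_env)(your-machine):~$ CUDA_VISIBLE_DEVICES=0,1 progan_train \
    /path/to/training_images /path/to/output_dir \
    --depth 8 \
    --epochs 20 20 20 20 20 20 20 \
    --batch_sizes 64 64 32 32 16 8 4 \
    --fade_in_percentages 50 50 50 50 50 50 50
```

Model snapshots and the logged sample grids should then show up under `models/` and `logs/` inside the output directory.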
154 |
155 | 5. Or, you could import this as a python package in your code
156 | for more advanced use-cases:
157 | ```
158 | import pro_gan_pytorch as pg
159 | ```
160 | You can use all the modules in the package such as: `pg.networks.Generator`,
161 | `pg.networks.Discriminator`, `pg.gan.ProGAN` etc. Mostly, you'll only need
162 | the `pg.gan.ProGAN` module for training. For inference, you will probably only
163 | need the `pg.networks.Generator`. Please refer to the scripts for the tools from
164 | step 4 under `pro_gan_pytorch_scripts/` for examples of how to use the package.
165 | Besides, please feel free to just read the code. It's really easy to follow
166 | (or at least I hope so :sweat_smile: :grimacing:).
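
To make that concrete, here is a minimal, untested sketch of training a small 64 x 64 model through the Python API and then sampling from it. The paths, depth, latent size and per-stage schedules below are illustrative, the submodules are imported explicitly, and the `Discriminator` is assumed to take arguments analogous to the `Generator`'s:

```
from pathlib import Path

import torch

from pro_gan_pytorch.data_tools import ImageDirectoryDataset, get_transform
from pro_gan_pytorch.gan import ProGAN
from pro_gan_pytorch.networks import Discriminator, Generator

depth, latent_size = 6, 256  # 2 ** 6 => 64 x 64 images

# build the two networks (the Discriminator arguments mirror the Generator's here)
gen = Generator(depth=depth, num_channels=3, latent_size=latent_size)
dis = Discriminator(depth=depth, num_channels=3, latent_size=latent_size)

# the ProGAN wrapper moves the networks to the device and maintains the EMA copy
progan = ProGAN(gen, dis, device=torch.device("cuda"), use_ema=True)

# flat folder of images, resized to the final resolution
dataset = ImageDirectoryDataset(
    Path("/path/to/images"),
    transform=get_transform(new_size=(64, 64), flip_horizontal=True),
)

# one value per stage (depth - 1 stages: 4x4 up to 64x64)
progan.train(
    dataset=dataset,
    epochs=[10, 10, 10, 10, 10],
    batch_sizes=[64, 64, 32, 16, 8],
    fade_in_percentages=[50, 50, 50, 50, 50],
    save_dir=Path("./train_64"),
)

# sample a few images from the trained generator (values are in [-1, 1])
with torch.no_grad():
    noise = torch.randn(4, latent_size).to(progan.device)
    images = progan.gen(noise, depth, 1.0)
```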
167 |
168 | ### Developing the package
169 | For more advanced use-cases in your project, or if you'd like to contribute new
170 | features to this project, the following steps would help you get this project setup
171 | for developing. There are no standard set of rules for contributing here
172 | (no `CONTRIBUTING.md`) but let's try to maintain the overall ethos of the
173 | codebase :smile:.
174 |
175 | 1. clone this repository
176 | ```
177 | (your-machine):~$ cd
178 | (your-machine):$ git clone https://github.com/akanimax/pro_gan_pytorch.git
179 | ```
180 | 2. Apologies in advance: step 1 will take a while. I ended up
181 | pushing gifs and other large binary assets to git back then.
182 | I didn't know better :disappointed:. I'll see if this can be sorted out somehow.
183 | Once that's done, set up a development virtual env,
184 | ```
185 | (your-machine):$ python3 -m venv pro-gan-pth-dev-env
186 | (your-machine):$ source pro-gan-pth-dev-env/bin/activate
187 | ```
188 | 3. Install the package in development mode:
189 | ```
190 | (pro-gan-pth-dev-env)(your-machine):$ pip install -e .
191 | ```
192 | 4. Also install the dev requirements:
193 | ```
194 | (pro-gan-pth-dev-env)(your-machine):$ pip install -r requirements-dev.txt
195 | ```
196 | 5. Now open the project in the editor of your choice, and you are good to go.
197 | I use `pytest` for testing and `black` for code formatting. Check out
198 | [this link](https://black.readthedocs.io/en/stable/integrations/editors.html) for
199 | how to set up `black` with various IDEs.
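
For instance, running the tests and the formatter from the repository root could look like this (just a sketch; the paths mirror the layout of this repo):

```
(pro-gan-pth-dev-env)(your-machine):$ pytest pro_gan_pytorch/test
(pro-gan-pth-dev-env)(your-machine):$ black pro_gan_pytorch pro_gan_pytorch_scripts
```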
200 |
201 | 6. There is no fancy CI, or automated testing, or docs building since this is a
202 | fairly tiny project. But we are open to considering these tools if more features
203 | keep getting added to this project.
204 |
205 | # Trained Models
206 | We will be training models using this package on different datasets over time.
207 | Also, please feel free to open PRs for the following table if you end up training
208 | models for your own datasets. If you are contributing, then please set up
209 | a file hosting solution for serving the trained models.
210 |
211 | | Courtesy | Dataset | Size |Resolution | GPUs used | #Epochs per stage | Training time | FID score | Link | Qualitative samples |
212 | | :--- | :--- | :--- |:--- | :--- | :--- | :--- | :--- | :--- | :--- |
213 | | @owang | Metfaces | ~1.3K |1024 x 1024 | 1 V100-32GB | 42 | 24 hrs | 101.624 | [model_link](http://geometry.cs.ucl.ac.uk/projects/2021/pro_gan_pytorch/model_metfaces.bin) | 
214 |
215 |
216 | **Note that we compute the FID using the clean-fid implementation from
217 | [Parmar et al.](https://www.cs.cmu.edu/~clean-fid/)**
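
For example, scoring a trained checkpoint against a dataset with the bundled tool could look like the following (paths are placeholders):

```
(pro_gan_pth_env)(your-machine):~$ progan_fid /path/to/trained_model.bin /path/to/dataset_images \
    --num_generated_images 10000
```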
218 |
219 | # General cool stuff :smile:
220 | ### Training timelapse (fixed latent points):
221 | The training timelapse created from the images logged during the training
222 | looks really cool.
223 |
224 |
226 |
227 |
228 |
229 | Checkout this [YT video](https://www.youtube.com/watch?v=lzTm6Lq76Mo) for a
230 | 4K version :smile:.
231 |
232 | If interested please feel free to check out this
233 | [medium blog]( https://medium.com/@animeshsk3/the-unprecedented-effectiveness-of-progressive-growing-of-gans-37475c88afa3)
234 | I wrote explaining the progressive growing technique.
235 |
236 | # References
237 |
238 | 1. Tero Karras, Timo Aila, Samuli Laine, & Jaakko Lehtinen (2018).
239 | Progressive Growing of GANs for Improved Quality, Stability, and Variation.
240 | In International Conference on Learning Representations.
241 |
242 | 2. Parmar, Gaurav, Richard Zhang, and Jun-Yan Zhu.
243 | "On Buggy Resizing Libraries and Surprising Subtleties in FID Calculation."
244 | arXiv preprint arXiv:2104.11222 (2021).
245 |
246 | # Feature requests
247 | - [ ] Conditional GAN support
248 | - [ ] Tool for generating time-lapse video from the log images
249 | - [ ] Integrating FID computation into the training logging
250 |
251 | # Thanks
252 | As always,
253 | please feel free to open PRs/issues/suggestions here.
254 | Hope this work is useful in your project :smile:.
255 |
256 | cheers :beers:!
257 | @akanimax :sunglasses:
258 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/README.rst
--------------------------------------------------------------------------------
/pro_gan_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | """ Package has implementation of ProGAN (progressive growing of GANs)
2 | as an extension of PyTorch Module
3 | """
4 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/custom_layers.py:
--------------------------------------------------------------------------------
1 | """ Module contains custom layers """
2 | from typing import Any
3 |
4 | import numpy as np
5 |
6 | import torch
7 | from torch import Tensor
8 | from torch.nn import Conv2d, ConvTranspose2d, Linear
9 |
10 |
11 | def update_average(model_tgt, model_src, beta):
12 | """
13 | function to calculate the Exponential moving averages for the Generator weights
14 | This function updates the exponential average weights based on the current training
15 | Args:
16 | model_tgt: target model
17 | model_src: source model
18 | beta: value of decay beta
19 | Returns: None (updates the target model)
20 | """
21 |
22 | with torch.no_grad():
23 | param_dict_src = dict(model_src.named_parameters())
24 |
25 | for p_name, p_tgt in model_tgt.named_parameters():
26 | p_src = param_dict_src[p_name]
27 | assert p_src is not p_tgt
28 | p_tgt.copy_(beta * p_tgt + (1.0 - beta) * p_src)
29 |
30 |
31 | class EqualizedConv2d(Conv2d):
32 | def __init__(
33 | self,
34 | in_channels,
35 | out_channels,
36 | kernel_size,
37 | stride=1,
38 | padding=0,
39 | dilation=1,
40 | groups=1,
41 | bias=True,
42 | padding_mode="zeros",
43 | ) -> None:
44 | super().__init__(
45 | in_channels,
46 | out_channels,
47 | kernel_size,
48 | stride,
49 | padding,
50 | dilation,
51 | groups,
52 | bias,
53 | padding_mode,
54 | )
55 | # make sure that the self.weight and self.bias are initialized according to
56 | # random normal distribution
57 | torch.nn.init.normal_(self.weight)
58 | if bias:
59 | torch.nn.init.zeros_(self.bias)
60 |
61 | # define the scale for the weights:
62 | fan_in = np.prod(self.kernel_size) * self.in_channels
63 | self.scale = np.sqrt(2) / np.sqrt(fan_in)
64 |
65 | def forward(self, x: Tensor) -> Tensor:
66 | return torch.conv2d(
67 | input=x,
68 | weight=self.weight * self.scale, # scale the weight on runtime
69 | bias=self.bias,
70 | stride=self.stride,
71 | padding=self.padding,
72 | dilation=self.dilation,
73 | groups=self.groups,
74 | )
75 |
76 |
77 | class EqualizedConvTranspose2d(ConvTranspose2d):
78 | def __init__(
79 | self,
80 | in_channels,
81 | out_channels,
82 | kernel_size,
83 | stride=1,
84 | padding=0,
85 | output_padding=0,
86 | groups=1,
87 | bias=True,
88 | dilation=1,
89 | padding_mode="zeros",
90 | ) -> None:
91 | super().__init__(
92 | in_channels,
93 | out_channels,
94 | kernel_size,
95 | stride,
96 | padding,
97 | output_padding,
98 | groups,
99 | bias,
100 | dilation,
101 | padding_mode,
102 | )
103 | # make sure that the self.weight and self.bias are initialized according to
104 | # random normal distribution
105 | torch.nn.init.normal_(self.weight)
106 | if bias:
107 | torch.nn.init.zeros_(self.bias)
108 |
109 | # define the scale for the weights:
110 | fan_in = self.in_channels
111 | self.scale = np.sqrt(2) / np.sqrt(fan_in)
112 |
113 | def forward(self, x: Tensor, output_size: Any = None) -> Tensor:
114 | output_padding = self._output_padding(
115 | x, output_size, self.stride, self.padding, self.kernel_size
116 | )
117 | return torch.conv_transpose2d(
118 | input=x,
119 | weight=self.weight * self.scale, # scale the weight on runtime
120 | bias=self.bias,
121 | stride=self.stride,
122 | padding=self.padding,
123 | output_padding=output_padding,
124 | groups=self.groups,
125 | dilation=self.dilation,
126 | )
127 |
128 |
129 | class EqualizedLinear(Linear):
130 | def __init__(self, in_features, out_features, bias=True) -> None:
131 | super().__init__(in_features, out_features, bias)
132 |
133 | # make sure that the self.weight and self.bias are initialized according to
134 | # random normal distribution
135 | torch.nn.init.normal_(self.weight)
136 | if bias:
137 | torch.nn.init.zeros_(self.bias)
138 |
139 | # define the scale for the weights:
140 | fan_in = self.in_features
141 | self.scale = np.sqrt(2) / np.sqrt(fan_in)
142 |
143 | def forward(self, x: Tensor) -> Tensor:
144 | return torch.nn.functional.linear(x, self.weight * self.scale, self.bias)
145 |
146 |
147 | class PixelwiseNorm(torch.nn.Module):
148 | """
149 | ------------------------------------------------------------------------------------
150 | Pixelwise feature vector normalization.
151 | reference:
152 | https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py#L120
153 | ------------------------------------------------------------------------------------
154 | """
155 |
156 | def __init__(self):
157 | super(PixelwiseNorm, self).__init__()
158 |
159 | @staticmethod
160 | def forward(x: Tensor, alpha: float = 1e-8) -> Tensor:
161 | y = x.pow(2.0).mean(dim=1, keepdim=True).add(alpha).sqrt() # [N1HW]
162 | y = x / y # normalize the input x volume
163 | return y
164 |
165 |
166 | class MinibatchStdDev(torch.nn.Module):
167 | """
168 | Minibatch standard deviation layer for the discriminator
169 | Args:
170 | group_size: Size of each group into which the batch is split
171 | """
172 |
173 | def __init__(self, group_size: int = 4) -> None:
174 | """
175 |
176 | Args:
177 | group_size: Size of each group into which the batch is split
178 | """
179 | super(MinibatchStdDev, self).__init__()
180 | self.group_size = group_size
181 |
182 | def extra_repr(self) -> str:
183 | return f"group_size={self.group_size}"
184 |
185 | def forward(self, x: Tensor, alpha: float = 1e-8) -> Tensor:
186 | """
187 | forward pass of the layer
188 | Args:
189 | x: input activation volume
190 | alpha: small number for numerical stability
191 | Returns: y => x appended with standard deviation constant map
192 | """
193 | batch_size, channels, height, width = x.shape
194 | if batch_size > self.group_size:
195 | assert batch_size % self.group_size == 0, (
196 | f"batch_size {batch_size} should be "
197 | f"perfectly divisible by group_size {self.group_size}"
198 | )
199 | group_size = self.group_size
200 | else:
201 | group_size = batch_size
202 |
203 | # reshape x into a more amenable sized tensor
204 | y = torch.reshape(x, [group_size, -1, channels, height, width])
205 |
206 | # indicated shapes are after performing the operation
207 | # [G x M x C x H x W] Subtract mean over groups
208 | y = y - y.mean(dim=0, keepdim=True)
209 |
210 | # [M x C x H x W] Calc standard deviation over the groups
211 | y = torch.sqrt(y.square().mean(dim=0, keepdim=False) + alpha)
212 |
213 | # [M x 1 x 1 x 1] Take average over feature_maps and pixels.
214 | y = y.mean(dim=[1, 2, 3], keepdim=True)
215 |
216 | # [B x 1 x H x W] Replicate over group and pixels
217 | y = y.repeat(group_size, 1, height, width)
218 |
219 | # [B x (C + 1) x H x W] Append as new feature_map.
220 | y = torch.cat([x, y], 1)
221 |
222 | # return the computed values:
223 | return y
224 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/data_tools.py:
--------------------------------------------------------------------------------
1 | """ Module for the data loading pipeline for the model to train """
2 | from pathlib import Path
3 | from typing import Any, Callable, List, Optional, Tuple
4 |
5 | import numpy as np
6 | from PIL import Image
7 |
8 | from torch import Tensor
9 | from torch.utils.data import DataLoader, Dataset
10 | from torchvision.transforms import Compose, RandomHorizontalFlip, Resize, ToTensor
11 |
12 | from .utils import adjust_dynamic_range
13 |
14 |
15 | class NoOp(object):
16 | """A NoOp image transform utility. Does nothing, but makes the code cleaner"""
17 |
18 | def __call__(self, whatever: Any) -> Any:
19 | return whatever
20 |
21 | def __repr__(self) -> str:
22 | return self.__class__.__name__ + "()"
23 |
24 |
25 | def get_transform(
26 | new_size: Optional[Tuple[int, int]] = None, flip_horizontal: bool = False
27 | ) -> Callable[[Image.Image], Tensor]:
28 | """
29 | obtain the image transforms required for the input data
30 | Args:
31 | new_size: size of the resized images (if needed, could be None)
32 | flip_horizontal: whether to randomly mirror input images during training
33 | Returns: requested transform object from TorchVision
34 | """
35 | return Compose(
36 | [
37 | RandomHorizontalFlip(p=0.5) if flip_horizontal else NoOp(),
38 | Resize(new_size) if new_size is not None else NoOp(),
39 | ToTensor(),
40 | ]
41 | )
42 |
43 |
44 | class ImageDirectoryDataset(Dataset):
45 | """pyTorch Dataset wrapper for the simple case of flat directory images dataset
46 | Args:
47 | data_dir: directory containing all the images
48 | transform: whether to apply a certain transformation to the images
49 | rec_dir: whether to search all the sub-level directories for files
50 | recursively
51 | """
52 |
53 | def __init__(
54 | self,
55 | data_dir: Path,
56 | transform: Callable[[Image.Image], Tensor] = get_transform(),
57 | input_data_range: Tuple[float, float] = (0.0, 1.0),
58 | output_data_range: Tuple[float, float] = (-1.0, 1.0),
59 | rec_dir: bool = False,
60 | ) -> None:
61 | # define the state of the object
62 | self.rec_dir = rec_dir
63 | self.data_dir = data_dir
64 | self.transform = transform
65 | self.output_data_range = output_data_range
66 | self.input_data_range = input_data_range
67 |
68 | # setup the files for reading
69 | self.files = self._get_files(data_dir, rec_dir)
70 |
71 | def _get_files(self, path: Path, rec: bool = False) -> List[Path]:
72 | """
 73 | helper function to search the given directory and obtain all the files in its
74 | structure
75 | Args:
76 | path: path to the root directory
77 | rec: whether to search all the sub-level directories for files recursively
78 | Returns: list of all found paths
79 | """
80 | files = []
81 | for possible_file in path.iterdir():
82 | if possible_file.is_file():
83 | files.append(possible_file)
84 | elif rec and possible_file.is_dir():
 85 | files.extend(self._get_files(possible_file, rec))
86 | return files
87 |
88 | def __len__(self) -> int:
89 | """
90 | compute the length of the dataset
91 | Returns: len => length of dataset
92 | """
93 | return len(self.files)
94 |
95 | def __getitem__(self, item: int) -> Tensor:
96 | """
97 | obtain the image (read and transform)
98 | Args:
99 | item: index for the required image
100 | Returns: img => image array
101 | """
102 | # read the image:
103 | image = self.files[item]
104 | if image.name.endswith(".npy"):
105 | img = np.load(str(image))
106 | img = Image.fromarray(img.squeeze(0).transpose(1, 2, 0))
107 | else:
108 | img = Image.open(image)
109 |
110 | # apply the transforms on the image
111 | if self.transform is not None:
112 | img = self.transform(img)
113 |
114 | # bring the image in the required range
115 | img = adjust_dynamic_range(
116 | img, drange_in=self.input_data_range, drange_out=self.output_data_range
117 | )
118 |
119 | return img
120 |
121 |
122 | def get_data_loader(
123 | dataset: Dataset, batch_size: int, num_workers: int = 3
124 | ) -> DataLoader:
125 | """
126 | generate the data_loader from the given dataset
127 | Args:
128 | dataset: Torch dataset object
129 | batch_size: batch size for training
130 | num_workers: num of parallel readers for reading the data
131 | Returns: dataloader for the dataset
132 | """
133 | return DataLoader(
134 | dataset,
135 | batch_size=batch_size,
136 | shuffle=True,
137 | num_workers=num_workers,
138 | drop_last=True,
139 | )
140 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/gan.py:
--------------------------------------------------------------------------------
1 | """ Module implementing ProGAN which is trained using the Progressive growing
2 | technique -> https://arxiv.org/abs/1710.10196
3 | """
4 | import copy
5 | import datetime
6 | import time
7 | import timeit
8 | from pathlib import Path
9 | from typing import Any, Dict, List, Optional
10 |
11 | import numpy as np
12 |
13 | import torch
14 | from torch import Tensor
15 | from torch.nn import DataParallel, Module
16 | from torch.nn.functional import avg_pool2d, interpolate
17 | from torch.optim.optimizer import Optimizer
18 | from torch.utils.data import Dataset
19 | from torch.utils.tensorboard import SummaryWriter
20 | from torchvision.utils import save_image
21 |
22 | from .custom_layers import update_average
23 | from .data_tools import get_data_loader
24 | from .losses import GANLoss, WganGP
25 | from .networks import Discriminator, Generator
26 | from .utils import adjust_dynamic_range
27 |
28 |
29 | class ProGAN:
30 | def __init__(
31 | self,
32 | gen: Generator,
33 | dis: Discriminator,
34 | device=torch.device("cpu"),
35 | use_ema: bool = True,
36 | ema_beta: float = 0.999,
37 | ):
38 | assert gen.depth == dis.depth, (
39 | f"Generator and Discriminator depths are not compatible. "
40 | f"GEN_Depth: {gen.depth} DIS_Depth: {dis.depth}"
41 | )
42 | self.gen = gen.to(device)
43 | self.dis = dis.to(device)
44 | self.use_ema = use_ema
45 | self.ema_beta = ema_beta
46 | self.depth = gen.depth
47 | self.latent_size = gen.latent_size
48 | self.device = device
49 |
50 | # if code is to be run on GPU, we can use DataParallel:
51 | if device == torch.device("cuda"):
52 | self.gen = DataParallel(self.gen)
53 | self.dis = DataParallel(self.dis)
54 |
55 | print(f"Generator Network: {self.gen}")
56 | print(f"Discriminator Network: {self.dis}")
57 |
58 | if self.use_ema:
59 | # create a shadow copy of the generator
60 | self.gen_shadow = copy.deepcopy(self.gen)
61 |
62 | # initialize the gen_shadow weights equal to the
63 | # weights of gen
64 | update_average(self.gen_shadow, self.gen, beta=0)
65 |
66 | # counters to maintain generator and discriminator gradient overflows
67 | self.gen_overflow_count = 0
68 | self.dis_overflow_count = 0
69 |
70 | def progressive_downsample_batch(self, real_batch, depth, alpha):
71 | """
72 | private helper for downsampling the original images in order to facilitate the
73 | progressive growing of the layers.
74 | Args:
75 | real_batch: batch of real samples
76 | depth: depth at which training is going on
77 | alpha: current value of the fader alpha
78 |
79 | Returns: modified real batch of samples
80 |
81 | """
82 | # downsample the real_batch for the given depth
83 | down_sample_factor = int(2 ** (self.depth - depth))
84 | prior_downsample_factor = int(2 ** (self.depth - depth + 1))
85 |
86 | ds_real_samples = avg_pool2d(
87 | real_batch, kernel_size=down_sample_factor, stride=down_sample_factor
88 | )
89 |
90 | if depth > 2:
91 | prior_ds_real_samples = interpolate(
92 | avg_pool2d(
93 | real_batch,
94 | kernel_size=prior_downsample_factor,
95 | stride=prior_downsample_factor,
96 | ),
97 | scale_factor=2,
98 | )
99 | else:
100 | prior_ds_real_samples = ds_real_samples
101 |
102 | # real samples are a linear combination of
103 | # ds_real_samples and prior_ds_real_samples
104 | real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)
105 |
106 | return real_samples
107 |
108 | def optimize_discriminator(
109 | self,
110 | loss: GANLoss,
111 | dis_optim: Optimizer,
112 | noise: Tensor,
113 | real_batch: Tensor,
114 | depth: int,
115 | alpha: float,
116 | labels: Optional[Tensor] = None,
117 | ) -> float:
118 | """
119 | performs a single weight update step on discriminator using the batch of data
120 | and the noise
121 | Args:
122 | loss: the loss function to be used for the optimization
123 | dis_optim: discriminator optimizer
124 | noise: input noise for sample generation
125 | real_batch: real samples batch
126 | depth: current depth of optimization
127 | alpha: current alpha for fade-in
128 | labels: labels for conditional discrimination
129 |
130 | Returns: discriminator loss value
131 | """
132 | real_samples = self.progressive_downsample_batch(real_batch, depth, alpha)
133 |
134 | # generate a batch of samples
135 | fake_samples = self.gen(noise, depth, alpha).detach()
136 | dis_loss = loss.dis_loss(
137 | self.dis, real_samples, fake_samples, depth, alpha, labels=labels
138 | )
139 |
140 | # optimize discriminator
141 | dis_optim.zero_grad()
142 | dis_loss.backward()
143 | if self._check_grad_ok(self.dis):
144 | dis_optim.step()
145 | else:
146 | self.dis_overflow_count += 1
147 |
148 | return dis_loss.item()
149 |
150 | def optimize_generator(
151 | self,
152 | loss: GANLoss,
153 | gen_optim: Optimizer,
154 | noise: Tensor,
155 | real_batch: Tensor,
156 | depth: int,
157 | alpha: float,
158 | labels: Optional[Tensor] = None,
159 | ) -> float:
160 | """
161 | performs a single weight update step on generator using the batch of data
162 | and the noise
163 | Args:
164 | loss: the loss function to be used for the optimization
165 | gen_optim: generator optimizer
166 | noise: input noise for sample generation
167 | real_batch: real samples batch
168 | depth: current depth of optimization
169 | alpha: current alpha for fade-in
170 | labels: labels for conditional discrimination
171 |
172 | Returns: generator loss value
173 | """
174 | real_samples = self.progressive_downsample_batch(real_batch, depth, alpha)
175 |
176 | # generate fake samples:
177 | fake_samples = self.gen(noise, depth, alpha)
178 |
179 | gen_loss = loss.gen_loss(
180 | self.dis, real_samples, fake_samples, depth, alpha, labels=labels
181 | )
182 |
183 | # optimize the generator
184 | gen_optim.zero_grad()
185 | gen_loss.backward()
186 | if self._check_grad_ok(self.gen):
187 | gen_optim.step()
188 | else:
189 | self.gen_overflow_count += 1
190 |
191 | return gen_loss.item()
192 |
193 | @staticmethod
194 | def create_grid(
195 | samples: Tensor,
196 | scale_factor: int,
197 | img_file: Path,
198 | ) -> None:
199 | """
200 | utility function to create a grid of GAN samples
201 | Args:
202 | samples: generated samples for feedback
203 | scale_factor: factor for upscaling the image
204 | img_file: name of file to write
205 | Returns: None (saves a file)
206 | """
207 | # upsample the image
208 | if scale_factor > 1:
209 | samples = interpolate(samples, scale_factor=scale_factor)
210 |
211 | samples = adjust_dynamic_range(
212 | samples, drange_in=(-1.0, 1.0), drange_out=(0.0, 1.0)
213 | )
214 |
215 | # save the images:
216 | save_image(samples, img_file, nrow=int(np.sqrt(len(samples))), padding=0)
217 |
218 | def _toggle_all_networks(self, mode="train"):
219 | for network in (self.gen, self.dis):
220 | if mode.lower() == "train":
221 | network.train()
222 | elif mode.lower() == "eval":
223 | network.eval()
224 | else:
225 | raise ValueError(f"Unknown mode requested: {mode}")
226 |
227 | @staticmethod
228 | def _check_grad_ok(network: Module) -> bool:
229 | grad_ok = True
230 | for _, param in network.named_parameters():
231 | if param.grad is not None:
232 | param_ok = (
233 | torch.sum(torch.isnan(param.grad)) == 0
234 | and torch.sum(torch.isinf(param.grad)) == 0
235 | )
236 | if not param_ok:
237 | grad_ok = False
238 | break
239 | return grad_ok
240 |
241 | def get_save_info(
242 | self, gen_optim: Optimizer, dis_optim: Optimizer
243 | ) -> Dict[str, Any]:
244 |
245 | if self.device == torch.device("cpu"):
246 | generator_save_info = self.gen.get_save_info()
247 | discriminator_save_info = self.dis.get_save_info()
248 | else:
249 | generator_save_info = self.gen.module.get_save_info()
250 | discriminator_save_info = self.dis.module.get_save_info()
251 | save_info = {
252 | "generator": generator_save_info,
253 | "discriminator": discriminator_save_info,
254 | "gen_optim": gen_optim.state_dict(),
255 | "dis_optim": dis_optim.state_dict(),
256 | }
257 | if self.use_ema:
258 | save_info["shadow_generator"] = (
259 | self.gen_shadow.get_save_info()
260 | if self.device == torch.device("cpu")
261 | else self.gen_shadow.module.get_save_info()
262 | )
263 | return save_info
264 |
265 | def train(
266 | self,
267 | dataset: Dataset,
268 | epochs: List[int],
269 | batch_sizes: List[int],
270 | fade_in_percentages: List[int],
271 | loss_fn: GANLoss = WganGP(),
272 | batch_repeats: int = 4,
273 | gen_learning_rate: float = 0.003,
274 | dis_learning_rate: float = 0.003,
275 | num_samples: int = 16,
276 | start_depth: int = 2,
277 | num_workers: int = 3,
278 | feedback_factor: int = 100,
279 | save_dir=Path("./train"),
280 | checkpoint_factor: int = 10,
281 | ):
282 | """
283 | # TODO implement support for conditional GAN here
284 | Utility method for training the ProGAN.
285 | Note that you don't have to necessarily use this method. You can use the
286 | optimize_generator and optimize_discriminator and define your own
287 | training routine
288 | Args:
289 | dataset: object of the dataset used for training.
290 | Note that this is not the dataloader (we create dataloader in this
291 | method since the batch_sizes for resolutions can be different)
292 | epochs: list of number of epochs to train the network for every resolution
293 | batch_sizes: list of batch_sizes for every resolution
294 | fade_in_percentages: list of percentages of epochs per resolution
295 | used for fading in the new layer. Not used for the
296 | first resolution, but a dummy value is still needed
297 | loss_fn: loss function used for training
298 | batch_repeats: number of iterations to perform on a single batch
299 | gen_learning_rate: generator learning rate
300 | dis_learning_rate: discriminator learning rate
301 | num_samples: number of samples generated in sample_sheet
302 | start_depth: start training from this depth
303 | num_workers: number of workers for reading the data
304 | feedback_factor: number of logs per epoch
305 | save_dir: directory for saving the models (.bin files)
306 | checkpoint_factor: save a model snapshot after this many epochs.
307 | Returns: None (Writes multiple files to disk)
308 | """
309 |
310 | print(f"Loaded the dataset with: {len(dataset)} images ...")
311 | assert (self.depth - 1) == len(
312 | batch_sizes
313 | ), "batch_sizes are not compatible with depth"
314 | assert (self.depth - 1) == len(epochs), "epochs are not compatible with depth"
315 |
316 | self._toggle_all_networks("train")
317 |
318 | # create the generator and discriminator optimizers
319 | gen_optim = torch.optim.Adam(
320 | params=self.gen.parameters(),
321 | lr=gen_learning_rate,
322 | betas=(0, 0.99),
323 | eps=1e-8,
324 | )
325 | dis_optim = torch.optim.Adam(
326 | params=self.dis.parameters(),
327 | lr=dis_learning_rate,
328 | betas=(0, 0.99),
329 | eps=1e-8,
330 | )
331 |
332 | # verbose stuff
333 | print("setting up the image saving mechanism")
334 | model_dir, log_dir = save_dir / "models", save_dir / "logs"
335 | model_dir.mkdir(parents=True, exist_ok=True)
336 | log_dir.mkdir(parents=True, exist_ok=True)
337 |
338 | feedback_generator = self.gen_shadow if self.use_ema else self.gen
339 |
340 | # image saving mechanism
341 | with torch.no_grad():
342 | dummy_data_loader = get_data_loader(dataset, num_samples, num_workers)
343 | real_images_for_render = next(iter(dummy_data_loader))
344 | fixed_input = torch.randn(num_samples, self.latent_size).to(self.device)
345 | self.create_grid(
346 | real_images_for_render,
347 | scale_factor=1,
348 | img_file=log_dir / "real_images.png",
349 | )
350 | self.create_grid(
351 | feedback_generator(fixed_input, self.depth, 1).detach(),
352 | scale_factor=1,
353 | img_file=log_dir / "fake_images_0.png",
354 | )
355 |
356 | # tensorboard summarywriter:
357 | summary = SummaryWriter(str(log_dir / "tensorboard"))
358 |
359 | # create a global time counter
360 | global_time = time.time()
361 | global_step = 0
362 |
363 | print("Starting the training process ... ")
364 | for current_depth in range(start_depth, self.depth + 1):
365 | current_res = int(2 ** current_depth)
366 | print(f"\n\nCurrently working on Depth: {current_depth}")
367 | print("Current resolution: %d x %d" % (current_res, current_res))
368 | depth_list_index = current_depth - 2
369 | current_batch_size = batch_sizes[depth_list_index]
370 | data = get_data_loader(dataset, current_batch_size, num_workers)
371 | ticker = 1
372 | for epoch in range(1, epochs[depth_list_index] + 1):
373 | start = timeit.default_timer() # record time at the start of epoch
374 | print(f"\nEpoch: {epoch}")
375 | total_batches = len(data)
376 |
377 | # compute the fader point
378 | fader_point = int(
379 | (fade_in_percentages[depth_list_index] / 100)
380 | * epochs[depth_list_index]
381 | * total_batches
382 | )
383 |
384 | for (i, batch) in enumerate(data, start=1):
385 | # calculate the alpha for fading in the layers
386 | alpha = ticker / fader_point if ticker <= fader_point else 1
387 |
388 | # extract current batch of data for training
389 | images = batch.to(self.device)
390 |
391 | gan_input = torch.randn(current_batch_size, self.latent_size).to(
392 | self.device
393 | )
394 |
395 | gen_loss, dis_loss = None, None
396 | for _ in range(batch_repeats):
397 | # optimize the discriminator:
398 | dis_loss = self.optimize_discriminator(
399 | loss_fn, dis_optim, gan_input, images, current_depth, alpha
400 | )
401 |
402 | # no idea why this needs to be done after discriminator
403 | # iteration, but this is where it is done in the original
404 | # code
405 | if self.use_ema:
406 | update_average(
407 | self.gen_shadow, self.gen, beta=self.ema_beta
408 | )
409 |
410 | # optimize the generator:
411 | gen_loss = self.optimize_generator(
412 | loss_fn, gen_optim, gan_input, images, current_depth, alpha
413 | )
414 | global_step += 1
415 |
416 | # provide a loss feedback
417 | if (
418 | i % max(int(total_batches / max(feedback_factor, 1)), 1) == 0
419 | or i == 1
420 | or i == total_batches
421 | ):
422 | elapsed = time.time() - global_time
423 | elapsed = str(datetime.timedelta(seconds=elapsed))
424 | print(
425 | "Elapsed: [%s] batch: %d d_loss: %f g_loss: %f"
426 | % (elapsed, i, dis_loss, gen_loss)
427 | )
428 | summary.add_scalar(
429 | "dis_loss", dis_loss, global_step=global_step
430 | )
431 | summary.add_scalar(
432 | "gen_loss", gen_loss, global_step=global_step
433 | )
434 | # create a grid of samples and save it
435 | resolution_dir = log_dir / str(int(2 ** current_depth))
436 | resolution_dir.mkdir(exist_ok=True)
437 | gen_img_file = resolution_dir / f"{epoch}_{i}.png"
438 |
439 | # this is done to allow for more GPU space
440 | with torch.no_grad():
441 | self.create_grid(
442 | samples=feedback_generator(
443 | fixed_input, current_depth, alpha
444 | ).detach(),
445 | scale_factor=int(2 ** (self.depth - current_depth)),
446 | img_file=gen_img_file,
447 | )
448 |
449 | # increment the alpha ticker and the step
450 | ticker += 1
451 |
452 | stop = timeit.default_timer()
453 | print("Time taken for epoch: %.3f secs" % (stop - start))
454 |
455 | if (
456 | epoch % checkpoint_factor == 0
457 | or epoch == 1
458 | or epoch == epochs[depth_list_index]
459 | ):
460 | save_file = model_dir / f"depth_{current_depth}_epoch_{epoch}.bin"
461 | torch.save(self.get_save_info(gen_optim, dis_optim), save_file)
462 |
463 | self._toggle_all_networks("eval")
464 | print("Training completed ...")
465 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/losses.py:
--------------------------------------------------------------------------------
1 | """ Module implementing various loss functions """
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor
6 | from torch.nn import BCEWithLogitsLoss
7 |
8 | from .networks import Discriminator
9 |
10 |
11 | class GANLoss:
12 | def dis_loss(
13 | self,
14 | discriminator: Discriminator,
15 | real_samples: Tensor,
16 | fake_samples: Tensor,
17 | depth: int,
18 | alpha: float,
19 | labels: Optional[Tensor] = None,
20 | ) -> Tensor:
21 | """
22 | calculate the discriminator loss using the following data
23 | Args:
24 | discriminator: the Discriminator used by the GAN
25 | real_samples: real batch of samples
26 | fake_samples: fake batch of samples
27 | depth: resolution log 2 of the images under consideration
28 | alpha: alpha value of the fader
29 | labels: optional in case of the conditional discriminator
30 |
31 | Returns: computed discriminator loss
32 | """
33 | raise NotImplementedError("dis_loss method has not been implemented")
34 |
35 | def gen_loss(
36 | self,
37 | discriminator: Discriminator,
38 | real_samples: Tensor,
39 | fake_samples: Tensor,
40 | depth: int,
41 | alpha: float,
42 | labels: Optional[Tensor] = None,
43 | ) -> Tensor:
44 | """
45 | calculate the generator loss using the following data
46 | Args:
47 | discriminator: the Discriminator used by the GAN
48 | real_samples: real batch of samples
49 | fake_samples: fake batch of samples
50 | depth: resolution log 2 of the images under consideration
51 | alpha: alpha value of the fader
52 | labels: optional in case of the conditional discriminator
53 |
 54 | Returns: computed generator loss
55 | """
56 | raise NotImplementedError("gen_loss method has not been implemented")
57 |
58 |
59 | class StandardGAN(GANLoss):
60 | def __init__(self):
61 | self.criterion = BCEWithLogitsLoss()
62 |
63 | def dis_loss(
64 | self,
65 | discriminator: Discriminator,
66 | real_samples: Tensor,
67 | fake_samples: Tensor,
68 | depth: int,
69 | alpha: float,
70 | labels: Optional[Tensor] = None,
71 | ) -> Tensor:
72 | if labels is not None:
73 | assert discriminator.conditional, "labels passed to an unconditional dis"
74 | real_scores = discriminator(real_samples, depth, alpha, labels)
75 | fake_scores = discriminator(fake_samples, depth, alpha, labels)
76 | else:
77 | real_scores = discriminator(real_samples, depth, alpha)
78 | fake_scores = discriminator(fake_samples, depth, alpha)
79 |
80 | real_loss = self.criterion(
81 | real_scores, torch.ones(real_scores.shape).to(real_scores.device)
82 | )
83 | fake_loss = self.criterion(
84 | fake_scores, torch.zeros(fake_scores.shape).to(fake_scores.device)
85 | )
86 | return (real_loss + fake_loss) / 2
87 |
88 | def gen_loss(
89 | self,
90 | discriminator: Discriminator,
91 | _: Tensor,
92 | fake_samples: Tensor,
93 | depth: int,
94 | alpha: float,
95 | labels: Optional[Tensor] = None,
96 | ) -> Tensor:
97 | if labels is not None:
98 | assert discriminator.conditional, "labels passed to an unconditional dis"
99 | fake_scores = discriminator(fake_samples, depth, alpha, labels)
100 | else:
101 | fake_scores = discriminator(fake_samples, depth, alpha)
102 | return self.criterion(
103 | fake_scores, torch.ones(fake_scores.shape).to(fake_scores.device)
104 | )
105 |
106 |
107 | class WganGP(GANLoss):
108 | """
109 | Wgan-GP loss function. The discriminator is required for computing the gradient
110 | penalty.
111 | Args:
112 | drift: weight for the drift penalty
113 | """
114 |
115 | def __init__(self, drift: float = 0.001) -> None:
116 | self.drift = drift
117 |
118 | @staticmethod
119 | def _gradient_penalty(
120 | dis: Discriminator,
121 | real_samples: Tensor,
122 | fake_samples: Tensor,
123 | depth: int,
124 | alpha: float,
125 | reg_lambda: float = 10,
126 | labels: Optional[Tensor] = None,
127 | ) -> Tensor:
128 | """
129 | private helper for calculating the gradient penalty
130 | Args:
131 | dis: the discriminator used for computing the penalty
132 | real_samples: real samples
133 | fake_samples: fake samples
134 | depth: current depth in the optimization
135 | alpha: current alpha for fade-in
136 | reg_lambda: regularisation lambda
137 |
138 | Returns: computed gradient penalty
139 | """
140 | batch_size = real_samples.shape[0]
141 |
142 | # generate random epsilon
143 | epsilon = torch.rand((batch_size, 1, 1, 1)).to(real_samples.device)
144 |
145 | # create the merge of both real and fake samples
146 | merged = epsilon * real_samples + ((1 - epsilon) * fake_samples)
147 | merged.requires_grad_(True)
148 |
149 | # forward pass
150 | if labels is not None:
151 | assert dis.conditional, "labels passed to an unconditional discriminator"
152 | op = dis(merged, depth, alpha, labels)
153 | else:
154 | op = dis(merged, depth, alpha)
155 |
156 | # perform backward pass from op to merged for obtaining the gradients
157 | gradient = torch.autograd.grad(
158 | outputs=op,
159 | inputs=merged,
160 | grad_outputs=torch.ones_like(op),
161 | create_graph=True,
162 | retain_graph=True,
163 | only_inputs=True,
164 | )[0]
165 |
166 | # calculate the penalty using these gradients
167 | gradient = gradient.view(gradient.shape[0], -1)
168 | penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()
169 |
170 | # return the calculated penalty:
171 | return penalty
172 |
173 | def dis_loss(
174 | self,
175 | discriminator: Discriminator,
176 | real_samples: Tensor,
177 | fake_samples: Tensor,
178 | depth: int,
179 | alpha: float,
180 | labels: Optional[Tensor] = None,
181 | ) -> Tensor:
182 | if labels is not None:
183 | assert discriminator.conditional, "labels passed to an unconditional dis"
184 | real_scores = discriminator(real_samples, depth, alpha, labels)
185 | fake_scores = discriminator(fake_samples, depth, alpha, labels)
186 | else:
187 | real_scores = discriminator(real_samples, depth, alpha)
188 | fake_scores = discriminator(fake_samples, depth, alpha)
189 | loss = (
190 | torch.mean(fake_scores)
191 | - torch.mean(real_scores)
192 | + (self.drift * torch.mean(real_scores ** 2))
193 | )
194 |
195 | # calculate the WGAN-GP (gradient penalty)
196 | gp = self._gradient_penalty(
197 | discriminator, real_samples, fake_samples, depth, alpha, labels=labels
198 | )
199 | loss += gp
200 |
201 | return loss
202 |
203 | def gen_loss(
204 | self,
205 | discriminator: Discriminator,
206 | _: Tensor,
207 | fake_samples: Tensor,
208 | depth: int,
209 | alpha: float,
210 | labels: Optional[Tensor] = None,
211 | ) -> Tensor:
212 | if labels is not None:
213 | assert discriminator.conditional, "labels passed to an unconditional dis"
214 | fake_scores = discriminator(fake_samples, depth, alpha, labels)
215 | else:
216 | fake_scores = discriminator(fake_samples, depth, alpha)
217 | return -torch.mean(fake_scores)
218 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .custom_layers import (
3 | EqualizedConv2d,
4 | EqualizedConvTranspose2d,
5 | MinibatchStdDev,
6 | PixelwiseNorm,
7 | )
8 | from torch import Tensor
9 | from torch.nn import AvgPool2d, Conv2d, ConvTranspose2d, Embedding, LeakyReLU, Module
10 | from torch.nn.functional import interpolate
11 |
12 |
13 | class GenInitialBlock(Module):
14 | """
15 | Module implementing the initial block of the input
16 | Args:
17 | in_channels: number of input channels to the block
18 | out_channels: number of output channels of the block
19 | use_eql: whether to use equalized learning rate
20 | """
21 |
22 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
23 | super(GenInitialBlock, self).__init__()
24 | self.use_eql = use_eql
25 |
26 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
27 | ConvTransposeBlock = EqualizedConvTranspose2d if use_eql else ConvTranspose2d
28 |
29 | self.conv_1 = ConvTransposeBlock(in_channels, out_channels, (4, 4), bias=True)
30 | self.conv_2 = ConvBlock(
31 | out_channels, out_channels, (3, 3), padding=1, bias=True
32 | )
33 | self.pixNorm = PixelwiseNorm()
34 | self.lrelu = LeakyReLU(0.2)
35 |
36 | def forward(self, x: Tensor) -> Tensor:
37 | y = torch.unsqueeze(torch.unsqueeze(x, -1), -1)
38 | y = self.pixNorm(y) # normalize the latents to hypersphere
39 | y = self.lrelu(self.conv_1(y))
40 | y = self.lrelu(self.conv_2(y))
41 | y = self.pixNorm(y)
42 | return y
43 |
44 |
45 | class GenGeneralConvBlock(torch.nn.Module):
46 | """
47 | Module implementing a general convolutional block
48 | Args:
49 | in_channels: number of input channels to the block
50 | out_channels: number of output channels required
51 | use_eql: whether to use equalized learning rate
52 | """
53 |
54 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
55 | super(GenGeneralConvBlock, self).__init__()
56 | self.in_channels = in_channels
 57 | self.out_channels = out_channels
58 | self.use_eql = use_eql
59 |
60 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
61 |
62 | self.conv_1 = ConvBlock(in_channels, out_channels, (3, 3), padding=1, bias=True)
63 | self.conv_2 = ConvBlock(
64 | out_channels, out_channels, (3, 3), padding=1, bias=True
65 | )
66 | self.pixNorm = PixelwiseNorm()
67 | self.lrelu = LeakyReLU(0.2)
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | y = interpolate(x, scale_factor=2)
71 | y = self.pixNorm(self.lrelu(self.conv_1(y)))
72 | y = self.pixNorm(self.lrelu(self.conv_2(y)))
73 |
74 | return y
75 |
76 |
77 | class DisFinalBlock(torch.nn.Module):
78 | """
79 | Final block for the Discriminator
80 | Args:
81 | in_channels: number of input channels
82 | use_eql: whether to use equalized learning rate
83 | """
84 |
85 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
86 | super(DisFinalBlock, self).__init__()
87 | self.in_channels = in_channels
88 | self.out_channels = out_channels
89 | self.use_eql = use_eql
90 |
91 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
92 |
93 | self.conv_1 = ConvBlock(
94 | in_channels + 1, in_channels, (3, 3), padding=1, bias=True
95 | )
96 | self.conv_2 = ConvBlock(in_channels, out_channels, (4, 4), bias=True)
97 | self.conv_3 = ConvBlock(out_channels, 1, (1, 1), bias=True)
98 | self.batch_discriminator = MinibatchStdDev()
99 | self.lrelu = LeakyReLU(0.2)
100 |
101 | def forward(self, x: Tensor) -> Tensor:
102 | y = self.batch_discriminator(x)
103 | y = self.lrelu(self.conv_1(y))
104 | y = self.lrelu(self.conv_2(y))
105 | y = self.conv_3(y)
106 | return y.view(-1)
107 |
108 |
109 | class ConDisFinalBlock(torch.nn.Module):
110 | """ Final block for the Conditional Discriminator
111 | Uses the Projection mechanism
112 | from the paper -> https://arxiv.org/pdf/1802.05637.pdf
113 | Args:
114 | in_channels: number of input channels
115 | num_classes: number of classes for conditional discrimination
116 | use_eql: whether to use equalized learning rate
117 | """
118 |
119 | def __init__(
120 | self, in_channels: int, out_channels: int, num_classes: int, use_eql: bool
121 | ) -> None:
122 | super(ConDisFinalBlock, self).__init__()
123 | self.in_channels = in_channels
124 | self.out_channels = out_channels
125 | self.num_classes = num_classes
126 | self.use_eql = use_eql
127 |
128 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
129 |
130 | self.conv_1 = ConvBlock(
131 | in_channels + 1, in_channels, (3, 3), padding=1, bias=True
132 | )
133 | self.conv_2 = ConvBlock(in_channels, out_channels, (4, 4), bias=True)
134 | self.conv_3 = ConvBlock(out_channels, 1, (1, 1), bias=True)
135 |
136 | # we also need an embedding matrix for the label vectors
137 | self.label_embedder = Embedding(num_classes, out_channels, max_norm=1)
138 | self.batch_discriminator = MinibatchStdDev()
139 | self.lrelu = LeakyReLU(0.2)
140 |
141 | def forward(self, x: Tensor, labels: Tensor) -> Tensor:
142 | y = self.batch_discriminator(x)
143 | y = self.lrelu(self.conv_1(y))
144 | y = self.lrelu(self.conv_2(y))
145 |
146 | # embed the labels
147 | labels = self.label_embedder(labels) # [B x C]
148 |
149 | # compute the inner product with the label embeddings
150 | y_ = torch.squeeze(torch.squeeze(y, dim=-1), dim=-1) # [B x C]
151 | projection_scores = (y_ * labels).sum(dim=-1) # [B]
152 |
153 | # normal discrimination score
154 |         y = self.conv_3(y)  # this layer has a linear activation
155 |
156 | # calculate the total score
157 | final_score = y.view(-1) + projection_scores
158 |
159 | # return the output raw discriminator scores
160 | return final_score
161 |
162 |
163 | class DisGeneralConvBlock(torch.nn.Module):
164 | """
165 | General block in the discriminator
166 | Args:
167 | in_channels: number of input channels
168 | out_channels: number of output channels
169 | use_eql: whether to use equalized learning rate
170 | """
171 |
172 | def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
173 | super(DisGeneralConvBlock, self).__init__()
174 | self.in_channels = in_channels
175 | self.out_channels = out_channels
176 | self.use_eql = use_eql
177 |
178 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
179 |
180 | self.conv_1 = ConvBlock(in_channels, in_channels, (3, 3), padding=1, bias=True)
181 | self.conv_2 = ConvBlock(in_channels, out_channels, (3, 3), padding=1, bias=True)
182 | self.downSampler = AvgPool2d(2)
183 | self.lrelu = LeakyReLU(0.2)
184 |
185 | def forward(self, x: Tensor) -> Tensor:
186 | y = self.lrelu(self.conv_1(x))
187 | y = self.lrelu(self.conv_2(y))
188 | y = self.downSampler(y)
189 | return y
190 |
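
A minimal shape-check sketch for the blocks above (assuming the package is importable as pro_gan_pytorch; the channel counts here are illustrative, not taken from any particular configuration): GenGeneralConvBlock doubles the spatial resolution before its two convolutions, while DisGeneralConvBlock halves it again after its two convolutions.

    import torch
    from pro_gan_pytorch.modules import DisGeneralConvBlock, GenGeneralConvBlock

    gen_block = GenGeneralConvBlock(64, 32, use_eql=True)   # 64 -> 32 channels, 2x upsample
    dis_block = DisGeneralConvBlock(32, 64, use_eql=True)   # 32 -> 64 channels, 2x downsample

    x = torch.randn(4, 64, 8, 8)
    up = gen_block(x)       # torch.Size([4, 32, 16, 16])
    down = dis_block(up)    # torch.Size([4, 64, 8, 8])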
--------------------------------------------------------------------------------
/pro_gan_pytorch/networks.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Dict, Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 |
7 | import torch as th
8 | from .custom_layers import EqualizedConv2d
9 | from .modules import (
10 | ConDisFinalBlock,
11 | DisFinalBlock,
12 | DisGeneralConvBlock,
13 | GenGeneralConvBlock,
14 | GenInitialBlock,
15 | )
16 | from torch import Tensor
17 | from torch.nn import Conv2d, LeakyReLU, ModuleList, Sequential
18 | from torch.nn.functional import avg_pool2d, interpolate
19 |
20 |
21 | def nf(
22 | stage: int,
23 | fmap_base: int = 16 << 10,
24 | fmap_decay: float = 1.0,
25 | fmap_min: int = 1,
26 | fmap_max: int = 512,
27 | ) -> int:
28 | """
29 | computes the number of fmaps present in each stage
30 | Args:
31 | stage: stage level
32 | fmap_base: base number of fmaps
33 | fmap_decay: decay rate for the fmaps in the network
34 | fmap_min: minimum number of fmaps
35 | fmap_max: maximum number of fmaps
36 |
37 |     Returns: number of fmaps that should be present at that stage
38 | """
39 | return int(
40 | np.clip(
41 | int(fmap_base / (2.0 ** (stage * fmap_decay))),
42 | fmap_min,
43 | fmap_max,
44 | ).item()
45 | )
46 |
47 |
48 | class Generator(th.nn.Module):
49 | """
50 | Generator Module (block) of the GAN network
51 | Args:
52 |         depth: required depth of the Network (note: starts from 2)
53 | num_channels: number of output channels (default = 3 for RGB)
54 | latent_size: size of the latent manifold
55 | use_eql: whether to use equalized learning rate
56 | """
57 |
58 | def __init__(
59 | self,
60 | depth: int = 10,
61 | num_channels: int = 3,
62 | latent_size: int = 512,
63 | use_eql: bool = True,
64 | ) -> None:
65 | super().__init__()
66 |
67 | # object state:
68 | self.depth = depth
69 | self.latent_size = latent_size
70 | self.num_channels = num_channels
71 | self.use_eql = use_eql
72 |
73 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
74 |
75 | self.layers = ModuleList(
76 | [GenInitialBlock(latent_size, nf(1), use_eql=self.use_eql)]
77 | )
78 | for stage in range(1, depth - 1):
79 | self.layers.append(GenGeneralConvBlock(nf(stage), nf(stage + 1), use_eql))
80 |
81 | self.rgb_converters = ModuleList(
82 | [
83 | ConvBlock(nf(stage), num_channels, kernel_size=(1, 1))
84 | for stage in range(1, depth)
85 | ]
86 | )
87 |
88 | def forward(
89 | self, x: Tensor, depth: Optional[int] = None, alpha: float = 1.0
90 | ) -> Tensor:
91 | """
92 | forward pass of the Generator
93 | Args:
94 | x: input latent noise
95 | depth: depth from where the network's output is required
96 | alpha: value of alpha for fade-in effect
97 |
98 |         Returns: generated images at the given depth's resolution
99 | """
100 | depth = self.depth if depth is None else depth
101 | assert depth <= self.depth, f"Requested output depth {depth} cannot be produced"
102 |
103 | if depth == 2:
104 | y = self.rgb_converters[0](self.layers[0](x))
105 | else:
106 | y = x
107 | for layer_block in self.layers[: depth - 2]:
108 | y = layer_block(y)
109 | residual = interpolate(self.rgb_converters[depth - 3](y), scale_factor=2)
110 | straight = self.rgb_converters[depth - 2](self.layers[depth - 2](y))
111 | y = (alpha * straight) + ((1 - alpha) * residual)
112 | return y
113 |
114 | def get_save_info(self) -> Dict[str, Any]:
115 | return {
116 | "conf": {
117 | "depth": self.depth,
118 | "num_channels": self.num_channels,
119 | "latent_size": self.latent_size,
120 | "use_eql": self.use_eql,
121 | },
122 | "state_dict": self.state_dict(),
123 | }
124 |
125 |
126 | class Discriminator(th.nn.Module):
127 | """
128 | Discriminator of the GAN
129 | Args:
130 | depth: depth of the discriminator. log_2(resolution)
131 | num_channels: number of channels of the input images (Default = 3 for RGB)
132 | latent_size: latent size of the final layer
133 | use_eql: whether to use the equalized learning rate
134 |         num_classes: number of classes for a conditional discriminator
135 |                      (Default = None, meaning an unconditional discriminator)
136 | """
137 |
138 | def __init__(
139 | self,
140 | depth: int = 7,
141 | num_channels: int = 3,
142 | latent_size: int = 512,
143 | use_eql: bool = True,
144 | num_classes: Optional[int] = None,
145 | ) -> None:
146 | super().__init__()
147 | self.depth = depth
148 | self.num_channels = num_channels
149 | self.latent_size = latent_size
150 | self.use_eql = use_eql
151 | self.num_classes = num_classes
152 | self.conditional = num_classes is not None
153 |
154 | ConvBlock = EqualizedConv2d if use_eql else Conv2d
155 |
156 | if self.conditional:
157 | self.layers = [ConDisFinalBlock(nf(1), latent_size, num_classes, use_eql)]
158 | else:
159 | self.layers = [DisFinalBlock(nf(1), latent_size, use_eql)]
160 |
161 | for stage in range(1, depth - 1):
162 | self.layers.insert(
163 | 0, DisGeneralConvBlock(nf(stage + 1), nf(stage), use_eql)
164 | )
165 | self.layers = ModuleList(self.layers)
166 | self.from_rgb = ModuleList(
167 | reversed(
168 | [
169 | Sequential(
170 | ConvBlock(num_channels, nf(stage), kernel_size=(1, 1)),
171 | LeakyReLU(0.2),
172 | )
173 | for stage in range(1, depth)
174 | ]
175 | )
176 | )
177 |
178 | def forward(
179 | self, x: Tensor, depth: int, alpha: float, labels: Optional[Tensor] = None
180 | ) -> Tensor:
181 | """
182 | forward pass of the discriminator
183 | Args:
184 | x: input to the network
185 | depth: current depth of operation (Progressive GAN)
186 | alpha: current value of alpha for fade-in
187 | labels: labels for conditional discriminator (Default = None)
188 | shape => (Batch_size,) shouldn't be a column vector
189 |
190 | Returns: raw discriminator scores
191 | """
192 | assert (
193 | depth <= self.depth
194 | ), f"Requested output depth {depth} cannot be evaluated"
195 |
196 | if self.conditional:
197 |             assert labels is not None, "Conditional discriminator requires labels"
198 |
199 | if depth > 2:
200 | residual = self.from_rgb[-(depth - 2)](
201 | avg_pool2d(x, kernel_size=2, stride=2)
202 | )
203 | straight = self.layers[-(depth - 1)](self.from_rgb[-(depth - 1)](x))
204 | y = (alpha * straight) + ((1 - alpha) * residual)
205 | for layer_block in self.layers[-(depth - 2) : -1]:
206 | y = layer_block(y)
207 | else:
208 | y = self.from_rgb[-1](x)
209 | if self.conditional:
210 | y = self.layers[-1](y, labels)
211 | else:
212 | y = self.layers[-1](y)
213 | return y
214 |
215 | def get_save_info(self) -> Dict[str, Any]:
216 | return {
217 | "conf": {
218 | "depth": self.depth,
219 | "num_channels": self.num_channels,
220 | "latent_size": self.latent_size,
221 | "use_eql": self.use_eql,
222 | "num_classes": self.num_classes,
223 | },
224 | "state_dict": self.state_dict(),
225 | }
226 |
227 |
228 | def create_generator_from_saved_model(saved_model_path: Path) -> Generator:
229 | # load the data from the saved_model
230 | loaded_data = torch.load(saved_model_path)
231 |
232 | # create a generator from the loaded data:
233 | generator_data = (
234 | loaded_data["shadow_generator"]
235 | if "shadow_generator" in loaded_data
236 | else loaded_data["generator"]
237 | )
238 | generator = Generator(**generator_data["conf"])
239 | generator.load_state_dict(generator_data["state_dict"])
240 |
241 | return generator
242 |
243 |
244 | def create_discriminator_from_saved_model(saved_model_path: Path) -> Discriminator:
245 | # load the data from the saved_model
246 | loaded_data = torch.load(saved_model_path)
247 |
248 | # create a discriminator from the loaded data:
249 | discriminator_data = (
250 | loaded_data.get("shadow_discriminator", loaded_data["discriminator"])
251 | )
252 | discriminator = Discriminator(**discriminator_data["conf"])
253 | discriminator.load_state_dict(discriminator_data["state_dict"])
254 |
255 | return discriminator
256 |
257 |
258 | def load_models(generator_path: Path, discriminator_path: Path) -> Tuple[Generator, Discriminator]:
259 | generator = create_generator_from_saved_model(generator_path)
260 | discriminator = create_discriminator_from_saved_model(discriminator_path)
261 | return generator, discriminator
262 |
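
With the default fmap settings, nf(stage) works out to min(512, 16384 / 2**stage): 512 feature maps up to stage 5, then 256, 128, 64 and 32 for stages 6 through 9. Below is a minimal inference sketch (the checkpoint path is a placeholder; any checkpoint saved by ProGAN training that contains a "generator" or "shadow_generator" entry should work):

    from pathlib import Path

    import torch
    from pro_gan_pytorch.networks import create_generator_from_saved_model

    generator = create_generator_from_saved_model(Path("path/to/model.bin")).eval()

    with torch.no_grad():
        latents = torch.randn(4, generator.latent_size)
        # depth defaults to the full depth; pass a smaller value (>= 2) for coarser outputs
        images = generator(latents, depth=generator.depth, alpha=1.0)
    print(images.shape)  # (4, num_channels, 2 ** depth, 2 ** depth)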
--------------------------------------------------------------------------------
/pro_gan_pytorch/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/pro_gan_pytorch/test/__init__.py
--------------------------------------------------------------------------------
/pro_gan_pytorch/test/conftest.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | # noinspection PyPackageRequirements
4 | import pytest
5 |
6 |
7 | @pytest.fixture
8 | def test_data_path() -> Path:
9 | return Path("/home/animesh/work/data/3d_scenes/forest_synthetic_struct/images")
10 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/test/test_custom_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..custom_layers import (
4 | EqualizedConv2d,
5 | EqualizedConvTranspose2d,
6 | EqualizedLinear,
7 | MinibatchStdDev,
8 | PixelwiseNorm,
9 | )
10 | from .utils import assert_almost_equal, device, assert_tensor_validity
11 |
12 |
13 | # noinspection PyPep8Naming
14 | def test_EqualizedConv2d() -> None:
15 | mock_in = torch.randn(32, 21, 16, 16).to(device)
16 | conv_block = EqualizedConv2d(21, 3, kernel_size=(3, 3), padding=1).to(device)
17 | print(f"Equalized conv block: {conv_block}")
18 |
19 | mock_out = conv_block(mock_in)
20 |
21 | # check output
22 | assert_tensor_validity(mock_out, (32, 3, 16, 16))
23 |
24 | # check the weight's scale
25 | assert_almost_equal(conv_block.weight.data.std().cpu(), 1, error_margin=1e-1)
26 |
27 |
28 | # noinspection PyPep8Naming
29 | def test_EqualizedConvTranspose2d() -> None:
30 | mock_in = torch.randn(32, 21, 16, 16).to(device)
31 |
32 | conv_transpose_block = EqualizedConvTranspose2d(
33 | 21, 3, kernel_size=(3, 3), padding=1
34 | ).to(device)
35 | print(f"Equalized conv transpose block: {conv_transpose_block}")
36 |
37 | mock_out = conv_transpose_block(mock_in)
38 |
39 | # check output
40 | assert_tensor_validity(mock_out, (32, 3, 16, 16))
41 |
42 | # check the weight's scale
43 | assert_almost_equal(
44 | conv_transpose_block.weight.data.std().cpu(), 1, error_margin=1e-1
45 | )
46 |
47 |
48 | # noinspection PyPep8Naming
49 | def test_EqualizedLinear() -> None:
50 | # test the forward for the first res block
51 | mock_in = torch.randn(32, 13).to(device)
52 |
53 | lin_block = EqualizedLinear(13, 52).to(device)
54 | print(f"Equalized linear block: {lin_block}")
55 |
56 | mock_out = lin_block(mock_in)
57 |
58 | # check output
59 | assert_tensor_validity(mock_out, (32, 52))
60 |
61 | # check the weight's scale
62 | assert_almost_equal(lin_block.weight.data.std().cpu(), 1, error_margin=1e-1)
63 |
64 |
65 | # noinspection PyPep8Naming
66 | def test_PixelwiseNorm() -> None:
67 | mock_in = torch.randn(1, 13, 1, 1).to(device)
68 | normalizer = PixelwiseNorm()
69 | print(f"\nNormalizerBlock: {normalizer}")
70 | mock_out = normalizer(mock_in)
71 |
72 | # check output
73 | assert_tensor_validity(mock_out, mock_in.shape)
74 |
75 | # we cannot comment that the norm of the output tensor
76 | # will always be less than the norm of the input tensor
77 | # so no more checking can be done
78 |
79 |
80 | # noinspection PyPep8Naming
81 | def test_MinibatchStdDev() -> None:
82 | mock_in = torch.randn(16, 13, 16, 16).to(device)
83 | minStdD = MinibatchStdDev()
84 | print(f"\nMiniBatchStdDevBlock: {minStdD}")
85 | mock_out = minStdD(mock_in)
86 |
87 | # check output
88 | assert mock_out.shape[1] == mock_in.shape[1] + 1
89 | assert_tensor_validity(mock_out, (16, 14, 16, 16))
90 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/test/test_gan.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | # noinspection PyPackageRequirements
4 | import matplotlib.pyplot as plt
5 |
6 | import torch
7 |
8 | from ..data_tools import ImageDirectoryDataset, get_transform
9 | from ..gan import ProGAN
10 | from ..networks import Discriminator, Generator
11 | from .utils import device
12 |
13 |
14 | def test_pro_gan_progressive_downsample_batch() -> None:
15 | batch = torch.randn((4, 3, 1024, 1024)).to(device)
16 | batch = torch.clamp(batch, min=0, max=1)
17 | progan = ProGAN(Generator(10), Discriminator(10), device=device)
18 |
19 | for res_log2 in range(2, 10):
20 | modified_batch = progan.progressive_downsample_batch(
21 | batch, depth=res_log2, alpha=0.001
22 | )
23 | print(f"Downsampled batch at res_log2 {res_log2}: {modified_batch.shape}")
24 | plt.figure()
25 | plt.title(f"Image at resolution: {int(2 ** res_log2)}x{int(2 ** res_log2)}")
26 | plt.imshow(modified_batch.permute((0, 2, 3, 1))[0].cpu().numpy())
27 | assert modified_batch.shape == (
28 | batch.shape[0],
29 | batch.shape[1],
30 | int(2 ** res_log2),
31 | int(2 ** res_log2),
32 | )
33 |
34 | plt.figure()
35 | plt.title(f"Image at resolution: {1024}x{1024}")
36 | plt.imshow(batch.permute((0, 2, 3, 1))[0].cpu().numpy())
37 | plt.show()
38 |
39 |
40 | def test_pro_gan_train(test_data_path: Path) -> None:
41 | depth = 4
42 | progan = ProGAN(Generator(depth), Discriminator(depth), device=device)
43 | progan.train(
44 | dataset=ImageDirectoryDataset(
45 | test_data_path,
46 | transform=get_transform(
47 | new_size=(int(2 ** depth), int(2 ** depth)), flip_horizontal=False
48 | ),
49 | rec_dir=False,
50 | ),
51 | epochs=[10 for _ in range(3)],
52 | batch_sizes=[256, 256, 256],
53 | fade_in_percentages=[50 for _ in range(3)],
54 | save_dir=Path("./test_train"),
55 | num_samples=64,
56 | feedback_factor=10,
57 | )
58 | print("test_finished")
59 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/test/test_networks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | from ..networks import Discriminator, Generator
6 | from .utils import device
7 |
8 |
9 | # noinspection PyPep8Naming
10 | def test_Generator() -> None:
11 | batch_size, latent_size = 2, 512
12 | num_channels = 3
13 | depth = 10 # resolution 1024 x 1024
14 | mock_generator = Generator(depth=depth, num_channels=num_channels).to(device)
15 | mock_latent = torch.randn((batch_size, latent_size)).to(device)
16 |
17 | print(f"Generator Network:\n{mock_generator}")
18 |
19 | with torch.no_grad():
20 | for res_log2 in range(2, depth + 1):
21 | rgb_images = mock_generator(mock_latent, depth=res_log2, alpha=1)
22 | print(f"RGB output shape at depth {res_log2}: {rgb_images.shape}")
23 | assert rgb_images.shape == (
24 | batch_size,
25 | num_channels,
26 | 2 ** res_log2,
27 | 2 ** res_log2,
28 | )
29 | assert torch.isnan(rgb_images).sum().item() == 0
30 | assert torch.isinf(rgb_images).sum().item() == 0
31 |
32 |
33 | # noinspection PyPep8Naming
34 | def test_DiscriminatorUnconditional() -> None:
35 | batch_size, latent_size = 2, 512
36 | num_channels = 3
37 | depth = 10 # resolution 1024 x 1024
38 | mock_discriminator = Discriminator(depth=depth, num_channels=num_channels).to(
39 | device
40 | )
41 | mock_inputs = [
42 | torch.randn((batch_size, num_channels, 2 ** stage, 2 ** stage)).to(device)
43 | for stage in range(2, depth + 1)
44 | ]
45 |
46 | print(f"Discriminator Network:\n{mock_discriminator}")
47 |
48 | with torch.no_grad():
49 | for res_log2 in range(2, depth + 1):
50 | mock_input = mock_inputs[res_log2 - 2]
51 | print(f"RGB input image shape at depth {res_log2}: {mock_input.shape}")
52 | score = mock_discriminator(mock_input, depth=res_log2, alpha=1)
53 | assert score.shape == (batch_size,)
54 | assert torch.isnan(score).sum().item() == 0
55 | assert torch.isinf(score).sum().item() == 0
56 |
57 |
58 | # noinspection PyPep8Naming
59 | def test_DiscriminatorConditional() -> None:
60 | batch_size, latent_size = 2, 512
61 | num_channels = 3
62 | depth = 10 # resolution 1024 x 1024
63 | mock_discriminator = Discriminator(
64 | depth=depth, num_channels=num_channels, num_classes=10
65 | ).to(device)
66 | mock_inputs = [
67 | torch.randn((batch_size, num_channels, 2 ** stage, 2 ** stage)).to(device)
68 | for stage in range(2, depth + 1)
69 | ]
70 | mock_labels = torch.from_numpy(np.array([3, 7])).to(device)
71 |
72 | print(f"Discriminator Network:\n{mock_discriminator}")
73 | with torch.no_grad():
74 | for res_log2 in range(2, depth + 1):
75 | mock_input = mock_inputs[res_log2 - 2]
76 | print(f"RGB input image shape at depth {res_log2}: {mock_input.shape}")
77 | score = mock_discriminator(
78 | mock_input, depth=res_log2, alpha=1, labels=mock_labels
79 | )
80 | assert score.shape == (batch_size,)
81 | assert torch.isnan(score).sum().item() == 0
82 | assert torch.isinf(score).sum().item() == 0
83 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/test/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple
2 |
3 | import numpy as np
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | def assert_almost_equal(x: Any, y: Any, error_margin: float = 3.0) -> None:
12 | assert np.abs(x - y) <= error_margin
13 |
14 |
15 | def assert_tensor_validity(
16 | test_tensor: Tensor, expected_shape: Tuple[int, ...]
17 | ) -> None:
18 | assert test_tensor.shape == expected_shape
19 | assert torch.isnan(test_tensor).sum().item() == 0
20 | assert torch.isinf(test_tensor).sum().item() == 0
21 |
--------------------------------------------------------------------------------
/pro_gan_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from typing import Optional, Tuple
4 |
5 | import numpy as np
6 |
7 | import torch
8 | from torch import Tensor
9 |
10 | from pro_gan_pytorch import losses
11 | from pro_gan_pytorch.losses import WganGP, StandardGAN
12 |
13 |
14 | def adjust_dynamic_range(
15 | data: Tensor,
16 | drange_in: Optional[Tuple[float, float]] = (-1.0, 1.0),
17 | drange_out: Optional[Tuple[float, float]] = (0.0, 1.0),
18 | ):
19 | if drange_in != drange_out:
20 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
21 | np.float32(drange_in[1]) - np.float32(drange_in[0])
22 | )
23 | bias = np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale
24 | data = data * scale + bias
25 |
26 | return torch.clamp(data, min=drange_out[0], max=drange_out[1])
27 |
28 |
29 | def post_process_generated_images(imgs: Tensor) -> np.ndarray:
30 | imgs = adjust_dynamic_range(
31 | imgs.permute(0, 2, 3, 1), drange_in=(-1.0, 1.0), drange_out=(0.0, 1.0)
32 | )
33 | return (imgs * 255.0).detach().cpu().numpy().astype(np.uint8)
34 |
35 |
36 | def str2bool(v):
37 | if isinstance(v, bool):
38 | return v
39 | if v.lower() in ("yes", "true", "t", "y", "1"):
40 | return True
41 | elif v.lower() in ("no", "false", "f", "n", "0"):
42 | return False
43 | else:
44 | raise argparse.ArgumentTypeError("Boolean value expected.")
45 |
46 |
47 | # noinspection PyPep8Naming
48 | def str2GANLoss(v):
49 | if v.lower() == "wgan_gp":
50 | return WganGP()
51 | elif v.lower() == "standard_gan":
52 | return StandardGAN()
53 | else:
54 | raise argparse.ArgumentTypeError(
55 |             "Unknown gan-loss function requested. "
56 |             "Please consider contributing your GANLoss to: "
57 | f"{str(Path(losses.__file__).absolute())}"
58 | )
59 |
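
As a worked example of adjust_dynamic_range with its default ranges: mapping (-1, 1) to (0, 1) gives scale = (1 - 0) / (1 - (-1)) = 0.5 and bias = 0 - (-1) * 0.5 = 0.5, so each value is transformed as x -> 0.5 * x + 0.5 and then clamped to [0, 1].

    import torch
    from pro_gan_pytorch.utils import adjust_dynamic_range

    x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
    print(adjust_dynamic_range(x))  # tensor([0.0000, 0.5000, 1.0000, 1.0000])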
--------------------------------------------------------------------------------
/pro_gan_pytorch_scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/pro_gan_pytorch_scripts/__init__.py
--------------------------------------------------------------------------------
/pro_gan_pytorch_scripts/compute_fid.py:
--------------------------------------------------------------------------------
1 | """ script for computing the fid of a trained model when compared with the dataset images """
2 | import argparse
3 | import tempfile
4 | from pathlib import Path
5 |
6 | import imageio
7 | import torch
8 | from cleanfid import fid
9 | from torch.backends import cudnn
10 | from tqdm import tqdm
11 |
12 | from pro_gan_pytorch.networks import create_generator_from_saved_model
13 | from pro_gan_pytorch.utils import post_process_generated_images
14 |
15 | # turn fast mode on
16 | cudnn.benchmark = True
17 |
18 |
19 | def parse_arguments() -> argparse.Namespace:
20 | """
21 | Returns: parsed arguments object
22 | """
23 | parser = argparse.ArgumentParser("ProGAN fid_score computation tool",
24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,)
25 |
26 | # fmt: off
27 | # required arguments
28 | parser.add_argument("model_path", action="store", type=Path,
29 | help="path to the trained_model.bin file")
30 | parser.add_argument("dataset_path", action="store", type=Path,
31 | help="path to the directory containing the images from the dataset. "
32 | "Note that this needs to be a flat directory")
33 |
34 | # optional arguments
35 | parser.add_argument("--generated_images_path", action="store", type=Path, default=None, required=False,
36 | help="path to the directory where the generated images are to be written. "
37 | "Uses a temporary directory by default. Provide this path if you'd like "
38 | "to see the generated images yourself :).")
39 | parser.add_argument("--batch_size", action="store", type=int, default=4, required=False,
40 | help="batch size used for generating random images")
41 | parser.add_argument("--num_generated_images", action="store", type=int, default=50_000, required=False,
42 | help="number of generated images used for computing the FID")
43 | # fmt: on
44 |
45 | args = parser.parse_args()
46 |
47 | return args
48 |
49 |
50 | def compute_fid(args: argparse.Namespace) -> None:
51 | """
52 | compute the fid for a given trained pro-gan model
53 | Args:
54 | args: configuration used for the fid computation
55 | Returns: None
56 |
57 | """
58 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59 |
60 | # load the data from the trained-model
61 | print(f"loading data from the trained model at: {args.model_path}")
62 | generator = create_generator_from_saved_model(args.model_path).to(device)
63 |
64 | # create the generated images directory:
65 | if args.generated_images_path is not None:
66 | args.generated_images_path.mkdir(parents=True, exist_ok=True)
67 | generated_images_path = (
68 | args.generated_images_path
69 | if args.generated_images_path is not None
70 | else tempfile.TemporaryDirectory()
71 | )
72 | if args.generated_images_path is None:
73 | image_writing_path = Path(generated_images_path.name)
74 | else:
75 | image_writing_path = generated_images_path
76 |
77 | print("generating random images from the trained generator ...")
78 | with torch.no_grad():
79 | for img_num in tqdm(range(0, args.num_generated_images, args.batch_size)):
80 | num_imgs = min(args.batch_size, args.num_generated_images - img_num)
81 | random_latents = torch.randn(num_imgs, generator.latent_size, device=device)
82 | gen_imgs = post_process_generated_images(generator(random_latents))
83 |
84 | # write the batch of generated images:
85 | for batch_num, gen_img in enumerate(gen_imgs, start=1):
86 | imageio.imwrite(
87 | image_writing_path / f"{img_num + batch_num}.png",
88 | gen_img,
89 | )
90 |
91 | # compute the fid once all images are generated
92 | print("computing fid ...")
93 | score = fid.compute_fid(
94 | fdir1=args.dataset_path,
95 | fdir2=image_writing_path,
96 | mode="clean",
97 | num_workers=4,
98 | )
99 | print(f"fid score: {score: .3f}")
100 |
101 | # most importantly, don't forget to do the cleanup on the temporary directory:
102 | if hasattr(generated_images_path, "cleanup"):
103 | generated_images_path.cleanup()
104 |
105 |
106 | def main() -> None:
107 | """
108 | Main function of the script
109 | Returns: None
110 | """
111 | compute_fid(parse_arguments())
112 |
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
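
The same computation can be driven programmatically by passing an argparse.Namespace with the fields parsed above (all paths here are placeholders); this mirrors what the progan_fid console entry point registered in setup.py does:

    import argparse
    from pathlib import Path

    from pro_gan_pytorch_scripts.compute_fid import compute_fid

    compute_fid(argparse.Namespace(
        model_path=Path("path/to/model.bin"),
        dataset_path=Path("path/to/dataset_images"),
        generated_images_path=None,  # None -> a temporary directory is used
        batch_size=8,
        num_generated_images=10_000,
    ))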
--------------------------------------------------------------------------------
/pro_gan_pytorch_scripts/latent_space_interpolation.py:
--------------------------------------------------------------------------------
1 | """ script for writing a video of the latent space interpolation from a trained model """
2 | import argparse
3 | from pathlib import Path
4 |
5 | import cv2
6 | import torch
7 | from scipy.ndimage import gaussian_filter
8 | from torch.backends import cudnn
9 | from tqdm import tqdm
10 |
11 | from pro_gan_pytorch.networks import create_generator_from_saved_model
12 | from pro_gan_pytorch.utils import post_process_generated_images
13 |
14 | # turn fast mode on
15 | cudnn.benchmark = True
16 |
17 |
18 | def parse_arguments():
19 | """
20 | command line arguments parser
21 | :return: args => parsed command line arguments
22 | """
23 | parser = argparse.ArgumentParser("ProGAN latent-space walk demo video creation tool",
24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,)
25 |
26 | # fmt: off
27 | # required arguments
28 | parser.add_argument("model_path", action="store", type=Path,
29 | help="path to the trained_model.bin file")
30 |
31 | # options related to the video
32 | parser.add_argument("--output_path", action="store", type=Path, required=False,
33 | default="./latent_space_walk.mp4",
34 | help="path to the output video file location. "
35 | "Please only use mp4 format with this tool (.mp4 extension). "
36 | "I have banged my head too much to get anything else to work :(.")
37 | parser.add_argument("--generation_depth", action="store", type=int, default=None, required=False,
38 | help="depth at which the images should be generated. "
39 | "Starts from 2 --> (4x4) | 3 --> (8x8) etc. Uses the highest resolution by default. ")
40 | parser.add_argument("--time", action="store", type=float, default=30, required=False,
41 | help="number of seconds in the video")
42 | parser.add_argument("--fps", action="store", type=int, default=60, required=False,
43 | help="fps of the generated video")
44 | parser.add_argument("--smoothing", action="store", type=float, default=0.75, required=False,
45 | help="smoothness of walking in the latent-space."
46 |                              " High values correspond to more smoothing.")
47 | # fmt: on
48 |
49 | args = parser.parse_args()
50 |
51 | return args
52 |
53 |
54 | def latent_space_interpolation(args):
55 | """
56 | Generate a video of the latent space walk (interpolation)
57 | Args:
58 | args: configuration used for the lsid
59 | Returns: None (writes generated video to disk)
60 | """
61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62 |
63 | # load the data from the trained-model
64 | print(f"loading data from the trained model at: {args.model_path}")
65 | generator = create_generator_from_saved_model(args.model_path).to(device)
66 |
67 | # total_frames in the video:
68 | total_frames = int(args.time * args.fps)
69 |
70 | # create the video from the latent space interpolation (walk)
71 | # all latent vectors for each and every frame:
72 | all_latents = torch.randn(total_frames, generator.latent_size).to(device)
73 | all_latents = gaussian_filter(all_latents.cpu(), [args.smoothing * args.fps, 0])
74 | all_latents = torch.from_numpy(all_latents).to(device)
75 |
76 | # create output directory
77 | args.output_path.parent.mkdir(parents=True, exist_ok=True)
78 |
79 | # make the cv2 video object
80 | print("Generating the video frames ...")
81 | generation_depth = (
82 | generator.depth if args.generation_depth is None else args.generation_depth
83 | )
84 | img_dim = 2 ** generation_depth
85 | video_out = cv2.VideoWriter(
86 | str(args.output_path),
87 | cv2.VideoWriter_fourcc(*"mp4v"),
88 | args.fps,
89 | (img_dim, img_dim),
90 | )
91 |
92 | # Run the main loop for the interpolation:
93 | with torch.no_grad(): # no need to compute gradients here :)
94 | for latent in tqdm(all_latents):
95 | latent = torch.unsqueeze(latent, dim=0)
96 |
97 | # generate the image for this latent vector:
98 | frame = post_process_generated_images(
99 | generator(latent, depth=generation_depth)
100 | )
101 | frame = frame[0, ..., ::-1] # need to reverse the channel order for cv2 :D
102 |
103 | # write the generated frame to the video
104 | video_out.write(frame)
105 |
106 | print(f"video has been generated and saved to {args.output_path}")
107 |
108 | # don't forget to close the video stream :)
109 | video_out.release()
110 |
111 |
112 | def main() -> None:
113 | """
114 | Main function of the script
115 | Returns: None
116 | """
117 | latent_space_interpolation(parse_arguments())
118 |
119 |
120 | if __name__ == "__main__":
121 | main()
122 |
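
As with the FID script, the interpolation can be invoked programmatically with a Namespace mirroring the arguments above (placeholder paths); the progan_lsid console entry point from setup.py wraps the same main() function:

    import argparse
    from pathlib import Path

    from pro_gan_pytorch_scripts.latent_space_interpolation import latent_space_interpolation

    latent_space_interpolation(argparse.Namespace(
        model_path=Path("path/to/model.bin"),
        output_path=Path("./latent_space_walk.mp4"),
        generation_depth=None,  # None -> use the generator's full depth
        time=10.0,
        fps=30,
        smoothing=0.75,
    ))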
--------------------------------------------------------------------------------
/pro_gan_pytorch_scripts/train.py:
--------------------------------------------------------------------------------
1 | """ script for training a ProGAN (Progressively grown gan model) """
2 | import argparse
3 | import sys
4 | from pathlib import Path
5 |
6 | import torch
7 | from torch.backends import cudnn
8 |
9 | from pro_gan_pytorch.data_tools import ImageDirectoryDataset, get_transform
10 | from pro_gan_pytorch.gan import ProGAN
11 | from pro_gan_pytorch.networks import Discriminator, Generator, load_models
12 | from pro_gan_pytorch.utils import str2bool, str2GANLoss
13 |
14 | # turn fast mode on
15 | cudnn.benchmark = True
16 |
17 | # define the device for the training script
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 |
21 | def parse_arguments() -> argparse.Namespace:
22 | """
23 | command line arguments parser
24 | Returns: args => parsed command line arguments
25 | """
26 | parser = argparse.ArgumentParser(
27 | "Train Progressively grown GAN",
28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
29 | )
30 |
31 | # fmt: off
32 | # Required arguments (input path to the data and the output directory for saving training assets)
33 | parser.add_argument("train_path", action="store", type=Path,
34 | help="Path to the images folder for training the ProGAN")
35 | parser.add_argument("output_dir", action="store", type=Path,
36 | help="Path to the directory for saving the logs and models")
37 |
38 | # Optional arguments
39 |     # options for retraining a model:
40 | parser.add_argument("--retrain", action="store", type=str2bool, default=False, required=False,
41 |                         help="whether to resume training from saved models")
42 | parser.add_argument("--generator_path", action="store", type=Path, required="--retrain" in sys.argv,
43 | help="Path to the generator model for retraining the ProGAN")
44 | parser.add_argument("--discriminator_path", action="store", type=Path, required="--retrain" in sys.argv,
45 | help="Path to the discriminator model for retraining the ProGAN")
46 | # dataset related options:
47 | parser.add_argument("--rec_dir", action="store", type=str2bool, default=False, required=False,
48 | help="whether images are stored under one folder or under a recursive dir structure")
49 | parser.add_argument("--flip_horizontal", action="store", type=str2bool, default=True, required=False,
50 | help="whether to apply mirror (horizontal) augmentation")
51 |
52 | # model architecture related options:
53 | parser.add_argument("--depth", action="store", type=int, default=10, required=False,
54 | help="depth of the generator and the discriminator. Starts from 2. "
55 | "Example 2 --> (4x4) | 3 --> (8x8) ... | 10 --> (1024x1024)")
56 | parser.add_argument("--num_channels", action="store", type=int, default=3, required=False,
57 | help="number of channels in the image data")
58 | parser.add_argument("--latent_size", action="store", type=int, default=512, required=False,
59 | help="latent size of the generator and the discriminator")
60 |
61 | # training related options:
62 | parser.add_argument("--use_eql", action="store", type=str2bool, default=True, required=False,
63 | help="whether to use the equalized learning rate")
64 | parser.add_argument("--use_ema", action="store", type=str2bool, default=True, required=False,
65 | help="whether to use the exponential moving average of generator weights. "
66 | "Keeps two copies of the generator model; an instantaneous one and "
67 | "the averaged one.")
68 | parser.add_argument("--ema_beta", action="store", type=float, default=0.999, required=False,
69 | help="value of the ema beta")
70 | parser.add_argument("--epochs", action="store", type=int, required=False, nargs="+",
71 | default=[42 for _ in range(9)],
72 | help="number of epochs over the training dataset per stage")
73 | parser.add_argument("--batch_sizes", action="store", type=int, required=False, nargs="+",
74 | default=[32, 32, 32, 32, 16, 16, 8, 4, 2],
75 | help="batch size used for training the model per stage")
76 | parser.add_argument("--batch_repeats", action="store", type=int, required=False, default=4,
77 | help="number of G and D steps executed per training iteration")
78 | parser.add_argument("--fade_in_percentages", action="store", type=int, required=False, nargs="+",
79 | default=[50 for _ in range(9)],
80 |                         help="percentage of each stage's iterations over which the new layer is faded in")
81 | parser.add_argument("--loss_fn", action="store", type=str2GANLoss, required=False, default="wgan_gp",
82 | help="loss function used for training the GAN. "
83 | "Current options: [wgan_gp, standard_gan]")
84 | parser.add_argument("--g_lrate", action="store", type=float, required=False, default=0.003,
85 | help="learning rate used by the generator")
86 | parser.add_argument("--d_lrate", action="store", type=float, required=False, default=0.003,
87 | help="learning rate used by the discriminator")
88 | parser.add_argument("--num_feedback_samples", action="store", type=int, required=False, default=4,
89 | help="number of samples used for fixed seed gan feedback")
90 | parser.add_argument("--start_depth", action="store", type=int, required=False, default=2,
91 | help="resolution to start the training from. "
92 | "Example 2 --> (4x4) | 3 --> (8x8) ... | 10 --> (1024x1024). "
93 | "Note that this is not a way to restart a partial training. "
94 |                              "Resuming is not supported currently, but will be soon.")
95 | parser.add_argument("--num_workers", action="store", type=int, required=False, default=4,
96 | help="number of dataloader subprocesses. It's a pytorch thing, you can ignore it ;)."
97 | " Leave it to the default value unless things are weirdly slow for you.")
98 | parser.add_argument("--feedback_factor", action="store", type=int, required=False, default=10,
99 | help="number of feedback logs written per epoch")
100 | parser.add_argument("--checkpoint_factor", action="store", type=int, required=False, default=10,
101 | help="number of epochs after which a model snapshot is saved per training stage")
102 | # fmt: on
103 |
104 | parsed_args = parser.parse_args()
105 | return parsed_args
106 |
107 |
108 | def train_progan(args: argparse.Namespace) -> None:
109 | """
110 | method to train the progan (progressively grown gan) given the configuration parameters
111 | Args:
112 | args: configuration used for the training
113 | Returns: None
114 | """
115 | print(f"Selected arguments: {args}")
116 |
117 | if args.retrain:
118 | print(f"Retraining the ProGAN: `depth`, `num_channels`, `latent_size`, `use_eql` parameters will be ignored if "
119 | f"specified.")
120 | generator, discriminator = load_models(args.generator_path, args.discriminator_path)
121 | args.depth = generator.depth
122 | args.num_channels = generator.num_channels
123 | args.latent_size = generator.latent_size
124 | args.use_eql = generator.use_eql
125 | else:
126 | generator = Generator(
127 | depth=args.depth,
128 | num_channels=args.num_channels,
129 | latent_size=args.latent_size,
130 | use_eql=args.use_eql,
131 | )
132 | discriminator = Discriminator(
133 | depth=args.depth,
134 | num_channels=args.num_channels,
135 | latent_size=args.latent_size,
136 | use_eql=args.use_eql,
137 | )
138 |
139 | progan = ProGAN(
140 | generator,
141 | discriminator,
142 | device=device,
143 | use_ema=args.use_ema,
144 | ema_beta=args.ema_beta,
145 | )
146 |
147 | progan.train(
148 | dataset=ImageDirectoryDataset(
149 | args.train_path,
150 | transform=get_transform(
151 | new_size=(int(2 ** args.depth), int(2 ** args.depth)),
152 | flip_horizontal=args.flip_horizontal,
153 | ),
154 | rec_dir=args.rec_dir,
155 | ),
156 | epochs=args.epochs,
157 | batch_sizes=args.batch_sizes,
158 | fade_in_percentages=args.fade_in_percentages,
159 | loss_fn=args.loss_fn,
160 | batch_repeats=args.batch_repeats,
161 | gen_learning_rate=args.g_lrate,
162 | dis_learning_rate=args.d_lrate,
163 | num_samples=args.num_feedback_samples,
164 | start_depth=args.start_depth,
165 | num_workers=args.num_workers,
166 | feedback_factor=args.feedback_factor,
167 | checkpoint_factor=args.checkpoint_factor,
168 | save_dir=args.output_dir,
169 | )
170 |
171 |
172 | def main() -> None:
173 | """
174 | Main function of the script
175 | Returns: None
176 | """
177 | train_progan(parse_arguments())
178 |
179 |
180 | if __name__ == "__main__":
181 | main()
182 |
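
Once the package is installed, the progan_train console script registered in setup.py wraps this module. A typical invocation, with placeholder paths and a reduced depth of 8 (256x256 output), might look like the snippet below; judging by the defaults above, --epochs, --batch_sizes and --fade_in_percentages each expect one value per training stage, i.e. depth - 1 entries.

    progan_train ./data/images ./train_outputs --depth 8 \
        --epochs 20 20 20 20 20 20 20 \
        --batch_sizes 64 64 64 32 32 16 8 \
        --fade_in_percentages 50 50 50 50 50 50 50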
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest==6.2.2
2 | black==20.8b1
3 | matplotlib==3.5.0
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21.4
2 | torch==1.10.0
3 | torchvision==0.11.1
4 | Pillow==9.0.0
5 | tensorboard==2.7.0
6 | imageio==2.12.0
7 | tqdm==4.62.3
8 | scipy==1.7.2
9 | opencv-python==4.5.4.60
10 | clean-fid==0.1.15
11 |
--------------------------------------------------------------------------------
/samples/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore the full version of the training video
2 | pro-gan_training_video_smaller.mp4
3 |
4 | # also ignore the trained model weights
5 | GAN_GEN_SHADOW_8.pth
6 |
7 | # ignore some huge videos:
8 | interpolation.mp4
9 | video_2.gif
10 | video_3.gif
11 |
12 | # ignore the new latent_space interpolation video
13 | new_interp.mp4
14 |
15 | frames_pro/
16 | frames_mine/
17 | M_GAN_GEN_SHADOW_8.pth
18 |
--------------------------------------------------------------------------------
/samples/celebA-HQ.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/62066139ec8b467ffe26ce18a76dad43a0c2058e/samples/celebA-HQ.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | with open("requirements.txt", "r") as file_:
4 | project_requirements = file_.read().split("\n")
5 |
6 | setup(
7 | name="pro-gan-pth",
8 | version="3.4",
9 | packages=["pro_gan_pytorch", "pro_gan_pytorch_scripts"],
10 | url="https://github.com/akanimax/pro_gan_pytorch",
11 | license="MIT",
12 | author="akanimax",
13 | author_email="akanimax@gmail.com",
14 | setup_requires=['wheel'],
15 | description="ProGAN package implemented as an extension of PyTorch nn.Module",
16 | install_requires=project_requirements,
17 | entry_points={
18 | "console_scripts": [
19 |             "progan_train=pro_gan_pytorch_scripts.train:main",
20 |             "progan_lsid=pro_gan_pytorch_scripts.latent_space_interpolation:main",
21 |             "progan_fid=pro_gan_pytorch_scripts.compute_fid:main",
22 | ]
23 | },
24 | )
25 |
--------------------------------------------------------------------------------