├── stylegan2_pytorch
│   ├── version.py
│   ├── __init__.py
│   ├── diff_augment.py
│   ├── cli.py
│   └── stylegan2_pytorch.py
├── setup.cfg
├── images
│   └── aim.png
├── samples
│   ├── hands.jpg
│   ├── cities.jpg
│   ├── flowers.jpg
│   ├── flowers-2.jpg
│   ├── celebrities.jpg
│   └── celebrities-2.jpg
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── setup.py
├── .gitignore
└── README.md

/stylegan2_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.9.0'
2 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | # Inside of setup.cfg
2 | [metadata]
3 | description-file = README.md
--------------------------------------------------------------------------------
/images/aim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/images/aim.png
--------------------------------------------------------------------------------
/samples/hands.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/samples/hands.jpg
--------------------------------------------------------------------------------
/samples/cities.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/samples/cities.jpg
--------------------------------------------------------------------------------
/samples/flowers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/samples/flowers.jpg
--------------------------------------------------------------------------------
/samples/flowers-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/samples/flowers-2.jpg
--------------------------------------------------------------------------------
/samples/celebrities.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/samples/celebrities.jpg
--------------------------------------------------------------------------------
/samples/celebrities-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/stylegan2-pytorch/HEAD/samples/celebrities-2.jpg
--------------------------------------------------------------------------------
/stylegan2_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from stylegan2_pytorch.stylegan2_pytorch import Trainer, StyleGAN2, NanException, ModelLoader
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to
permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup, find_packages 3 | 4 | sys.path[0:0] = ['stylegan2_pytorch'] 5 | from version import __version__ 6 | 7 | setup( 8 | name = 'stylegan2_pytorch', 9 | packages = find_packages(), 10 | entry_points={ 11 | 'console_scripts': [ 12 | 'stylegan2_pytorch = stylegan2_pytorch.cli:main', 13 | ], 14 | }, 15 | version = __version__, 16 | license = 'MIT', 17 | description = 'StyleGan2 in Pytorch', 18 | long_description_content_type = 'text/markdown', 19 | author = 'Phil Wang', 20 | author_email = 'lucidrains@gmail.com', 21 | url = 'https://github.com/lucidrains/stylegan2-pytorch', 22 | download_url = 'https://github.com/lucidrains/stylegan2-pytorch/archive/v_036.tar.gz', 23 | keywords = [ 24 | 'generative adversarial networks', 25 | 'artificial intelligence' 26 | ], 27 | install_requires=[ 28 | 'aim', 29 | 'einops>=0.8.0', 30 | 'contrastive_learner>=0.1.0', 31 | 'fire', 32 | 'kornia>=0.5.4', 33 | 'numpy', 34 | 'retry', 35 | 'tqdm', 36 | 'torch>=2.2', 37 | 'torchvision', 38 | 'pillow', 39 | 'vector-quantize-pytorch==0.1.0' 40 | ], 41 | classifiers=[ 42 | 'Development Status :: 4 - Beta', 43 | 'Intended Audience :: Developers', 44 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 45 | 'License :: OSI Approved :: MIT License', 46 | 'Programming 
Language :: Python :: 3.6', 47 | ], 48 | ) 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/stylegan2_pytorch/diff_augment.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import random
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def DiffAugment(x, types=[]):
8 |     for p in types:
9 |         for f in AUGMENT_FNS[p]:
10 |             x = f(x)
11 |     return x.contiguous()
12 |
13 |
14 | # """
15 | # Augmentation functions receive images as `x`,
16 | # where `x` is a tensor with these dimensions:
17 | # 0 - number of images in the batch
18 | # 1 - channels
19 | # 2 - width
20 | # 3 - height of the image
21 | # """
22 |
23 | def rand_brightness(x, scale):
24 |     x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * scale
25 |     return x
26 |
27 | def rand_saturation(x, scale):
28 |     x_mean = x.mean(dim=1, keepdim=True)
29 |     x = (x - x_mean) * (((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
30 |     return x
31 |
32 | def rand_contrast(x, scale):
33 |     x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
34 |     x = (x - x_mean) * (((torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) * 2.0 * scale) + 1.0) + x_mean
35 |     return x
36 |
37 | def rand_translation(x, ratio=0.125):
38 |     shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
39 |     translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
40 |     translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
41 |     grid_batch, grid_x, grid_y = torch.meshgrid(
42 |         torch.arange(x.size(0), dtype=torch.long, device=x.device),
43 |         torch.arange(x.size(2), dtype=torch.long, device=x.device),
44 |         torch.arange(x.size(3), dtype=torch.long, device=x.device),
45 |     )
46 |     grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
47 |     grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
48 |     x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
49 |     x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
50 |     return x
51 |
52 | def rand_offset(x, ratio=1, ratio_h=1, ratio_v=1):
53 |     w, h = x.size(2), x.size(3)
54 |
55 |     imgs = []
56 |     for img in x.unbind(dim = 0):
57 |         max_h = int(w * ratio * ratio_h)
58 |         max_v = int(h * ratio * ratio_v)
59 |
60 |         value_h = random.randint(0, max_h) * 2 - max_h
61 |         value_v = random.randint(0, max_v) * 2 - max_v
62 |
63 |         if abs(value_h) > 0:
64 |             img = torch.roll(img, value_h, 2)
65 |
66 |         if abs(value_v) > 0:
67 |             img = torch.roll(img, value_v, 1)
68 |
69 |         imgs.append(img)
70 |
71 |     return torch.stack(imgs)
72 |
73 | def rand_offset_h(x, ratio=1):
74 |     return rand_offset(x, ratio=1, ratio_h=ratio, ratio_v=0)
75 |
76 | def rand_offset_v(x, ratio=1):
77 |     return rand_offset(x, ratio=1, ratio_h=0, ratio_v=ratio)
78 |
79 | def rand_cutout(x, ratio=0.5):
80 |     cutout_size = int(x.size(2) *
ratio + 0.5), int(x.size(3) * ratio + 0.5) 81 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 82 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 83 | grid_batch, grid_x, grid_y = torch.meshgrid( 84 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 85 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 86 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 87 | ) 88 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 89 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 90 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 91 | mask[grid_batch, grid_x, grid_y] = 0 92 | x = x * mask.unsqueeze(1) 93 | return x 94 | 95 | AUGMENT_FNS = { 96 | 'brightness': [partial(rand_brightness, scale=1.)], 97 | 'lightbrightness': [partial(rand_brightness, scale=.65)], 98 | 'contrast': [partial(rand_contrast, scale=.5)], 99 | 'lightcontrast': [partial(rand_contrast, scale=.25)], 100 | 'saturation': [partial(rand_saturation, scale=1.)], 101 | 'lightsaturation': [partial(rand_saturation, scale=.5)], 102 | 'color': [partial(rand_brightness, scale=1.), partial(rand_saturation, scale=1.), partial(rand_contrast, scale=0.5)], 103 | 'lightcolor': [partial(rand_brightness, scale=0.65), partial(rand_saturation, scale=.5), partial(rand_contrast, scale=0.5)], 104 | 'offset': [rand_offset], 105 | 'offset_h': [rand_offset_h], 106 | 'offset_v': [rand_offset_v], 107 | 'translation': [rand_translation], 108 | 'cutout': [rand_cutout], 109 | } 110 | -------------------------------------------------------------------------------- /stylegan2_pytorch/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fire 3 | import random 4 | from retry.api import retry_call 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | from functools import wraps 8 | from stylegan2_pytorch import Trainer, NanException 9 | 10 | import torch 11 | import torch.multiprocessing as mp 12 | import torch.distributed as dist 13 | 14 | import numpy as np 15 | 16 | def cast_list(el): 17 | return el if isinstance(el, list) else [el] 18 | 19 | def timestamped_filename(prefix = 'generated-'): 20 | now = datetime.now() 21 | timestamp = now.strftime("%m-%d-%Y_%H-%M-%S") 22 | return f'{prefix}{timestamp}' 23 | 24 | def set_seed(seed): 25 | torch.manual_seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | np.random.seed(seed) 29 | random.seed(seed) 30 | 31 | def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed): 32 | is_main = rank == 0 33 | is_ddp = world_size > 1 34 | 35 | if is_ddp: 36 | set_seed(seed) 37 | os.environ['MASTER_ADDR'] = 'localhost' 38 | os.environ['MASTER_PORT'] = '12355' 39 | dist.init_process_group('nccl', rank=rank, world_size=world_size) 40 | 41 | print(f"{rank + 1}/{world_size} process initialized.") 42 | 43 | model_args.update( 44 | is_ddp = is_ddp, 45 | rank = rank, 46 | world_size = world_size 47 | ) 48 | 49 | model = Trainer(**model_args) 50 | 51 | if not new: 52 | model.load(load_from) 53 | else: 54 | model.clear() 55 | 56 | model.set_data_src(data) 57 | 58 | progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>') 59 | while model.steps < 
num_train_steps: 60 | retry_call(model.train, tries=3, exceptions=NanException) 61 | progress_bar.n = model.steps 62 | progress_bar.refresh() 63 | if is_main and model.steps % 50 == 0: 64 | model.print_log() 65 | 66 | model.save(model.checkpoint_num) 67 | 68 | if is_ddp: 69 | dist.destroy_process_group() 70 | 71 | def train_from_folder( 72 | data = './data', 73 | results_dir = './results', 74 | models_dir = './models', 75 | name = 'default', 76 | new = False, 77 | load_from = -1, 78 | image_size = 128, 79 | network_capacity = 16, 80 | fmap_max = 512, 81 | transparent = False, 82 | batch_size = 5, 83 | gradient_accumulate_every = 6, 84 | num_train_steps = 150000, 85 | learning_rate = 2e-4, 86 | lr_mlp = 0.1, 87 | ttur_mult = 1.5, 88 | rel_disc_loss = False, 89 | num_workers = None, 90 | save_every = 1000, 91 | evaluate_every = 1000, 92 | generate = False, 93 | num_generate = 1, 94 | generate_interpolation = False, 95 | interpolation_num_steps = 100, 96 | save_frames = False, 97 | num_image_tiles = 8, 98 | trunc_psi = 0.75, 99 | mixed_prob = 0.9, 100 | fp16 = False, 101 | no_pl_reg = False, 102 | cl_reg = False, 103 | fq_layers = [], 104 | fq_dict_size = 256, 105 | attn_layers = [], 106 | no_const = False, 107 | aug_prob = 0., 108 | aug_types = ['translation', 'cutout'], 109 | top_k_training = False, 110 | generator_top_k_gamma = 0.99, 111 | generator_top_k_frac = 0.5, 112 | dual_contrast_loss = False, 113 | dataset_aug_prob = 0., 114 | multi_gpus = False, 115 | calculate_fid_every = None, 116 | calculate_fid_num_images = 12800, 117 | clear_fid_cache = False, 118 | seed = 42, 119 | log = False 120 | ): 121 | model_args = dict( 122 | name = name, 123 | results_dir = results_dir, 124 | models_dir = models_dir, 125 | batch_size = batch_size, 126 | gradient_accumulate_every = gradient_accumulate_every, 127 | image_size = image_size, 128 | network_capacity = network_capacity, 129 | fmap_max = fmap_max, 130 | transparent = transparent, 131 | lr = learning_rate, 132 | lr_mlp = lr_mlp, 133 | ttur_mult = ttur_mult, 134 | rel_disc_loss = rel_disc_loss, 135 | num_workers = num_workers, 136 | save_every = save_every, 137 | evaluate_every = evaluate_every, 138 | num_image_tiles = num_image_tiles, 139 | trunc_psi = trunc_psi, 140 | fp16 = fp16, 141 | no_pl_reg = no_pl_reg, 142 | cl_reg = cl_reg, 143 | fq_layers = fq_layers, 144 | fq_dict_size = fq_dict_size, 145 | attn_layers = attn_layers, 146 | no_const = no_const, 147 | aug_prob = aug_prob, 148 | aug_types = cast_list(aug_types), 149 | top_k_training = top_k_training, 150 | generator_top_k_gamma = generator_top_k_gamma, 151 | generator_top_k_frac = generator_top_k_frac, 152 | dual_contrast_loss = dual_contrast_loss, 153 | dataset_aug_prob = dataset_aug_prob, 154 | calculate_fid_every = calculate_fid_every, 155 | calculate_fid_num_images = calculate_fid_num_images, 156 | clear_fid_cache = clear_fid_cache, 157 | mixed_prob = mixed_prob, 158 | log = log 159 | ) 160 | 161 | if generate: 162 | model = Trainer(**model_args) 163 | model.load(load_from) 164 | samples_name = timestamped_filename() 165 | for num in tqdm(range(num_generate)): 166 | model.evaluate(f'{samples_name}-{num}', num_image_tiles) 167 | print(f'sample images generated at {results_dir}/{name}/{samples_name}') 168 | return 169 | 170 | if generate_interpolation: 171 | model = Trainer(**model_args) 172 | model.load(load_from) 173 | samples_name = timestamped_filename() 174 | model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = 
save_frames)
175 |         print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
176 |         return
177 |
178 |     world_size = torch.cuda.device_count()
179 |
180 |     if world_size == 1 or not multi_gpus:
181 |         run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
182 |         return
183 |
184 |     mp.spawn(run_training,
185 |         args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed),
186 |         nprocs=world_size,
187 |         join=True)
188 |
189 | def main():
190 |     fire.Fire(train_from_folder)
191 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Simple StyleGan2 for Pytorch
2 | [![PyPI version](https://badge.fury.io/py/stylegan2-pytorch.svg)](https://badge.fury.io/py/stylegan2-pytorch)
3 |
4 | Simple Pytorch implementation of Stylegan2 based on https://arxiv.org/abs/1912.04958 that can be trained entirely from the command line, no coding needed.
5 |
6 | Below are some flowers that do not exist.
7 |
8 | ![flowers](./samples/flowers.jpg)
9 |
10 | ![flowers-2](./samples/flowers-2.jpg)
11 |
12 | Neither do these hands.
13 |
14 | ![hands](./samples/hands.jpg)
15 |
16 | Nor these cities.
17 |
18 | ![cities](./samples/cities.jpg)
19 |
20 | Nor these celebrities (trained by @yoniker).
21 |
22 | ![celebrities](./samples/celebrities.jpg)
23 |
24 | ![celebrities-2](./samples/celebrities-2.jpg)
25 |
26 |
27 | ## Install
28 |
29 | You will need a machine with a GPU and CUDA installed. Then pip install the package like this
30 |
31 | ```bash
32 | $ pip install stylegan2_pytorch
33 | ```
34 |
35 | If you are using a Windows machine, the following commands reportedly work.
36 |
37 | ```bash
38 | $ conda install pytorch torchvision -c pytorch
39 | $ pip install stylegan2_pytorch
40 | ```
41 |
42 | ## Use
43 |
44 | ```bash
45 | $ stylegan2_pytorch --data /path/to/images
46 | ```
47 |
48 | That's it. Sample images will be saved to `results/default` and models will be saved periodically to `models/default`.
49 |
50 | ## Advanced Use
51 |
52 | You can specify the name of your project with
53 |
54 | ```bash
55 | $ stylegan2_pytorch --data /path/to/images --name my-project-name
56 | ```
57 |
58 | You can also specify the location where intermediate results and model checkpoints should be stored with
59 |
60 | ```bash
61 | $ stylegan2_pytorch --data /path/to/images --name my-project-name --results_dir /path/to/results/dir --models_dir /path/to/models/dir
62 | ```
63 |
64 | You can increase the network capacity (which defaults to `16`) to improve generation results, at the cost of more memory.
65 |
66 | ```bash
67 | $ stylegan2_pytorch --data /path/to/images --network-capacity 256
68 | ```
69 |
70 | By default, if the training gets cut off, it will automatically resume from the last checkpointed file. If you want to restart with new settings, just add a `--new` flag
71 |
72 | ```bash
73 | $ stylegan2_pytorch --new --data /path/to/images --name my-project-name --image-size 512 --batch-size 1 --gradient-accumulate-every 16 --network-capacity 10
74 | ```
75 |
76 | Once you have finished training, you can generate images from your latest checkpoint like so.
77 |
78 | ```bash
79 | $ stylegan2_pytorch --generate
80 | ```
81 |
82 | To generate a video of an interpolation through two random points in latent space:
83 |
84 | ```bash
85 | $ stylegan2_pytorch --generate-interpolation --interpolation-num-steps 100
86 | ```
87 |
88 | To save each individual frame of the interpolation:
89 |
90 | ```bash
91 | $ stylegan2_pytorch --generate-interpolation --save-frames
92 | ```
93 |
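Programmatically, a similar interpolation can be sketched with the `ModelLoader` class documented in the Coding section further below. This is a minimal sketch, assuming training was run from the current directory under the default project name; the simple linear blend in style space is illustrative, and the built-in routine may interpolate differently.

```python
import torch
from torchvision.utils import save_image
from stylegan2_pytorch import ModelLoader

loader = ModelLoader(base_dir = '.', name = 'default')  # where the CLI was invoked, and the project name

# two random points in latent space, mapped into style space
z1, z2 = torch.randn(1, 512).cuda(), torch.randn(1, 512).cuda()
w1 = loader.noise_to_styles(z1, trunc_psi = 0.75)
w2 = loader.noise_to_styles(z2, trunc_psi = 0.75)

for step in range(10):
    t = step / 9
    frame = loader.styles_to_images(torch.lerp(w1, w2, t))  # blend the two styles, then decode
    save_image(frame, f'./frame-{step}.jpg')
```
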
94 | If a previous checkpoint contained a better generator (which often happens, as generators can start degrading towards the end of training), you can load from a previous checkpoint with another flag
95 |
96 | ```bash
97 | $ stylegan2_pytorch --generate --load-from {checkpoint number}
98 | ```
99 |
100 | A technique used in both StyleGAN and BigGAN is truncating the latent values so that they fall close to the mean. The smaller the truncation value, the better the samples will appear, at the cost of sample variety. You can control this with the `--trunc-psi` flag, where values typically fall between `0.5` and `1`. It is set to `0.75` by default
101 |
102 | ```bash
103 | $ stylegan2_pytorch --generate --trunc-psi 0.5
104 | ```
105 |
106 | ## Multi-GPU training
107 |
108 | If you have one machine with multiple GPUs, the repository offers a way to utilize all of them for training. With multiple GPUs, each batch will be divided evenly amongst the GPUs available. For example, with 2 GPUs and a batch size of 32, each GPU will see 16 samples.
109 |
110 | You simply have to add a `--multi-gpus` flag; everything else is taken care of. If you would like to restrict training to specific GPUs, you can use the `CUDA_VISIBLE_DEVICES` environment variable to control what devices can be used (e.g. with `CUDA_VISIBLE_DEVICES=0,2,3`, only devices 0, 2, and 3 are available).
111 |
112 | ```bash
113 | $ stylegan2_pytorch --data ./data --multi-gpus --batch-size 32 --gradient-accumulate-every 1
114 | ```
115 |
116 | ## Low amounts of Training Data
117 |
118 | In the past, GANs needed a lot of data to learn how to generate well. The faces model, as an example, took **70k** high quality images from Flickr.
119 |
120 | However, in May 2020, researchers all across the world independently converged on a simple technique to reduce that number to as low as **1-2k**. That simple idea was to differentiably augment all images, generated or real, going into the discriminator during training.
121 |
122 | If you augment at a low enough probability, the augmentations will not 'leak' into the generations.
123 |
124 | In the low-data setting, you can use the feature with a simple flag.
125 |
126 | ```bash
127 | # find a suitable probability between 0. -> 0.7 at maximum
128 | $ stylegan2_pytorch --data ./data --aug-prob 0.25
129 | ```
130 |
131 | By default, the augmentations used are `translation` and `cutout`. If you would like to add `color`, you can do so with the `--aug-types` argument.
132 |
133 | ```bash
134 | # make sure there are no spaces between items!
135 | $ stylegan2_pytorch --data ./data --aug-prob 0.25 --aug-types [translation,cutout,color]
136 | ```
137 |
138 | You can customize it to any combination of the three you would like. The differentiable augmentation code was copied and slightly modified from the DiffAugment implementation of Zhao et al. (2020).
139 |
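Under the hood, a wrapper around the discriminator applies `DiffAugment` (from `stylegan2_pytorch/diff_augment.py`) stochastically to each incoming batch, real and generated alike. Below is a minimal sketch of the idea; the `discriminator` argument is an illustrative stand-in, and the source's own `AugWrapper` class additionally applies a random horizontal flip.

```python
import random
from stylegan2_pytorch.diff_augment import DiffAugment

def discriminate_augmented(discriminator, images, prob = 0.25, types = ['translation', 'cutout']):
    # with probability `prob`, differentiably augment the whole batch before it is
    # scored, so gradients still flow through the augmentations to the generator
    if random.random() < prob:
        images = DiffAugment(images, types = types)
    return discriminator(images)
```
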
140 | ## When do I stop training?
141 |
142 | For as long as possible, until the adversarial game between the two neural nets falls apart (we call this divergence). By default, the number of training steps is set to `150000` for 128x128 images, but you will certainly want this number to be higher if the GAN doesn't diverge by the end of training, or if you are training at a higher resolution.
143 |
144 | ```bash
145 | $ stylegan2_pytorch --data ./data --image-size 512 --num-train-steps 1000000
146 | ```
147 |
148 | ## Attention
149 |
150 | This framework also allows you to add an efficient form of self-attention to the designated layers of the discriminator (and the symmetric layer of the generator), which will greatly improve results. The more attention you can afford, the better!
151 |
152 | ```bash
153 | # add self attention after the output of layer 1
154 | $ stylegan2_pytorch --data ./data --attn-layers 1
155 | ```
156 |
157 | ```bash
158 | # add self attention after the output of layers 1 and 2
159 | # do not put a space after the comma in the list!
160 | $ stylegan2_pytorch --data ./data --attn-layers [1,2]
161 | ```
162 |
163 | ## Bonus
164 |
165 | Training on transparent images
166 |
167 | ```bash
168 | $ stylegan2_pytorch --data ./transparent/images/path --transparent
169 | ```
170 |
171 | ## Memory considerations
172 |
173 | The more GPU memory you have, the bigger and better the image generation will be. Nvidia recommended having up to 16GB for training 1024x1024 images. If you have less than that, there are a couple of settings you can play with so that the model fits.
174 |
175 | ```bash
176 | $ stylegan2_pytorch --data /path/to/data \
177 |     --batch-size 3 \
178 |     --gradient-accumulate-every 5 \
179 |     --network-capacity 16
180 | ```
181 |
182 | 1. Batch size - You can decrease the `batch-size` down to 1, but you should increase the `gradient-accumulate-every` correspondingly so that the mini-batch the network sees is not too small. This may be confusing to a layperson, so I'll think about how I could automate the choice of `gradient-accumulate-every` going forward.
183 |
184 | 2. Network capacity - You can decrease the neural network capacity to lessen the memory requirements. Just be aware that this has been shown to degrade generation performance.
185 |
186 | If none of this works, you can settle for 'Lightweight' GAN, which will let you trade off quality to train at greater resolutions in a reasonable amount of time.
187 |
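The two settings above multiply: the effective batch the networks are trained on is `batch_size * gradient_accumulate_every` (15 in the command shown). A minimal sketch of the accumulation idea follows, with `model`, `optimizer`, `loss_fn`, and `batches` as illustrative stand-ins for the trainer's internals.

```python
def accumulated_step(model, optimizer, loss_fn, batches, gradient_accumulate_every = 5):
    # gradients from several small batches are summed before a single optimizer
    # step, emulating one large batch when GPU memory is limited
    optimizer.zero_grad()
    for _ in range(gradient_accumulate_every):
        loss = loss_fn(model(next(batches)))
        (loss / gradient_accumulate_every).backward()  # average, so the step size is unchanged
    optimizer.step()
```
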
188 | ## Deployment on AWS
189 |
190 | Below are some steps which may be helpful for deployment using Amazon Web Services. In order to use this, you will have
191 | to provision a GPU-backed EC2 instance. An appropriate instance type would be from the p2 or p3 series. I (iboates) tried
192 | a p2.xlarge (the cheapest option) and it was quite slow, slower in fact than using Google Colab. More powerful instance
193 | types may be better, but they are more expensive. You can read more about them
194 | [here](https://aws.amazon.com/ec2/instance-types/#Accelerated_Computing).
195 |
196 | ### Setup steps
197 |
198 | 1. Archive your training data and upload it to an S3 bucket
199 | 2. Provision your EC2 instance (I used an Ubuntu AMI)
200 | 3. Log into your EC2 instance via SSH
201 | 4. Install the aws CLI client and configure it:
202 |
203 | ```bash
204 | sudo snap install aws-cli --classic
205 | aws configure
206 | ```
207 |
208 | You will then have to enter your AWS access keys, which you can retrieve from the management console under AWS
209 | Management Console > Profile > My Security Credentials > Access Keys
210 |
211 | Then, run these commands, or maybe put them in a shell script and execute that:
212 |
213 | ```bash
214 | mkdir data
215 | curl -O https://bootstrap.pypa.io/get-pip.py
216 | sudo apt-get install python3-distutils
217 | python3 get-pip.py
218 | pip3 install stylegan2_pytorch
219 | export PATH=$PATH:/home/ubuntu/.local/bin
220 | aws s3 sync s3://<your-bucket-name> ~/data
221 | cd data
222 | tar -xf ../train.tar.gz
223 | ```
224 |
225 | Now you should be able to train by simply calling `stylegan2_pytorch [args]`.
226 |
227 | Notes:
228 |
229 | * If you have a lot of training data, you may need to provision extra block storage via EBS.
230 | * Also, you may need to spread your data across multiple archives.
231 | * You should run this on a `screen` window so it won't terminate once you log out of the SSH session.
232 |
233 | ## Research
234 |
235 | ### FID Scores
236 |
237 | Thanks to GetsEclectic, you can now calculate the FID score periodically! Again, made super simple with one extra argument, as shown below.
238 |
239 | Firstly, install the `pytorch_fid` package
240 |
241 | ```bash
242 | $ pip install pytorch-fid
243 | ```
244 |
245 | Followed by
246 |
247 | ```bash
248 | $ stylegan2_pytorch --data ./data --calculate-fid-every 5000
249 | ```
250 |
251 | FID results will be logged to `./results/{name}/fid_scores.txt`
252 |
253 | ### Coding
254 |
255 | If you would like to sample images programmatically, you can do so with the following simple `ModelLoader` class.
256 |
257 | ```python
258 | import torch
259 | from torchvision.utils import save_image
260 | from stylegan2_pytorch import ModelLoader
261 |
262 | loader = ModelLoader(
263 |     base_dir = '/path/to/directory',   # path to where you invoked the command line tool
264 |     name = 'default'                   # the project name, defaults to 'default'
265 | )
266 |
267 | noise  = torch.randn(1, 512).cuda()                       # noise
268 | styles = loader.noise_to_styles(noise, trunc_psi = 0.7)   # pass through mapping network
269 | images = loader.styles_to_images(styles)                  # call the generator on intermediate style vectors
270 |
271 | save_image(images, './sample.jpg')                        # save your images, or do whatever you desire
272 | ```
273 |
274 | ### Logging to experiment tracker
275 |
276 | To log the losses to an open source experiment tracker (Aim), you simply need to pass an extra flag like so.
277 |
278 | ```bash
279 | $ stylegan2_pytorch --data ./data --log
280 | ```
281 |
282 | Then make sure you have Docker installed; following the instructions from Aim, execute the following in your terminal.
283 |
284 | ```bash
285 | $ aim up
286 | ```
287 |
288 | Then open up your browser to the address and you should see
289 |
290 | ![aim](./images/aim.png)
291 |
292 |
293 | ## Experimental
294 |
295 | ### Top-k Training for Generator
296 |
297 | A new paper has produced evidence that by simply zeroing out the gradient contributions from samples that are deemed fake by the discriminator, the generator learns significantly better, achieving new state of the art.
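Concretely, only the top fraction of generated samples (those the discriminator currently finds most convincing) contribute to the generator update. A conceptual sketch, not the trainer's exact code; `per_sample_gen_loss` stands in for the per-sample generator loss on a batch of fakes:

```python
import math

def top_k_generator_loss(per_sample_gen_loss, frac = 0.5):
    # keep the k samples the discriminator found most convincing (lowest
    # per-sample generator loss); the discarded samples contribute no gradient
    k = max(1, math.ceil(frac * per_sample_gen_loss.shape[0]))
    kept, _ = per_sample_gen_loss.topk(k, largest = False)
    return kept.mean()
```
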
298 | From the command line:
299 |
300 | ```bash
301 | $ stylegan2_pytorch --data ./data --top-k-training
302 | ```
303 |
304 | Gamma sets a decay schedule that slowly decreases k from the full batch size down to the target fraction of 50% (both are modifiable hyperparameters).
305 |
306 | ```bash
307 | $ stylegan2_pytorch --data ./data --top-k-training --generator-top-k-frac 0.5 --generator-top-k-gamma 0.99
308 | ```
309 |
310 | ### Feature Quantization
311 |
312 | A recent paper reported improved results if intermediate representations of the discriminator are vector quantized. Although I have not noticed any dramatic changes, I have decided to add this as a feature, so other minds out there can investigate. To use it, you have to specify which layer(s) you would like to vector quantize. The default dictionary size is `256` and is also tunable.
313 |
314 | ```bash
315 | # feature quantize layers 1 and 2, with a dictionary size of 512 each
316 | # do not put a space after the comma in the list!
317 | $ stylegan2_pytorch --data ./data --fq-layers [1,2] --fq-dict-size 512
318 | ```
319 |
320 | ### Contrastive Loss Regularization
321 |
322 | I have tried contrastive learning on the discriminator (in step with the usual GAN training) and possibly observed improved stability and quality of final results. You can turn on this experimental feature with a simple flag as shown below.
323 |
324 | ```bash
325 | $ stylegan2_pytorch --data ./data --cl-reg
326 | ```
327 |
328 | ### Relativistic Discriminator Loss
329 |
330 | This was proposed in the Relativistic GAN paper to stabilize training. I have had mixed results, but will include the feature for those who want to experiment with it.
331 |
332 | ```bash
333 | $ stylegan2_pytorch --data ./data --rel-disc-loss
334 | ```
335 |
336 | ### Non-constant 4x4 Block
337 |
338 | By default, the StyleGAN architecture styles a constant learned 4x4 block as it is progressively upsampled. This is an experimental feature that makes it so the 4x4 block is learned from the style vector `w` instead.
339 |
340 | ```bash
341 | $ stylegan2_pytorch --data ./data --no-const
342 | ```
343 |
344 | ### Dual Contrastive Loss
345 |
346 | A recent paper has proposed that a novel contrastive loss between the real and fake logits can improve quality over other types of losses. (The default in this repository is hinge loss, and the paper shows a slight improvement.)
347 |
348 | ```bash
349 | $ stylegan2_pytorch --data ./data --dual-contrast-loss
350 | ```
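This loss is implemented as `dual_contrastive_loss` in `stylegan2_pytorch/stylegan2_pytorch.py`; it scores each real logit against all fake logits (and vice versa, with signs flipped) via a cross-entropy. A toy invocation, with random logits standing in for discriminator outputs (note that the package asserts a CUDA-capable GPU at import time):

```python
import torch
from stylegan2_pytorch.stylegan2_pytorch import dual_contrastive_loss

# random logits standing in for discriminator scores on a batch of 8
fake_logits = torch.randn(8, requires_grad = True)
real_logits = torch.randn(8, requires_grad = True)

loss = dual_contrastive_loss(fake_logits, real_logits)
loss.backward()  # during training this would propagate into the discriminator
```
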
351 |
352 | ## Alternatives
353 |
354 | Stylegan2 + Unet Discriminator
355 |
356 | I have gotten really good results with a unet discriminator, but the architectural change was too big to fit as an option in this repository. If you are aiming for perfection, feel free to try it.
357 |
358 | If you would like me to give the royal treatment to some other GAN architecture (BigGAN), feel free to reach out at my email. Happy to hear your pitch.
359 |
360 | ## Appreciation
361 |
362 | Thank you to Matthew Mann for his inspiring [simple port](https://github.com/manicman1999/StyleGAN2-Tensorflow-2.0) for Tensorflow 2.0
363 |
364 | ## References
365 |
366 | ```bibtex
367 | @article{Karras2019stylegan2,
368 |     title = {Analyzing and Improving the Image Quality of {StyleGAN}},
369 |     author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
370 |     journal = {CoRR},
371 |     volume = {abs/1912.04958},
372 |     year = {2019},
373 | }
374 | ```
375 |
376 | ```bibtex
377 | @misc{zhao2020feature,
378 |     title = {Feature Quantization Improves GAN Training},
379 |     author = {Yang Zhao and Chunyuan Li and Ping Yu and Jianfeng Gao and Changyou Chen},
380 |     year = {2020}
381 | }
382 | ```
383 |
384 | ```bibtex
385 | @misc{chen2020simple,
386 |     title = {A Simple Framework for Contrastive Learning of Visual Representations},
387 |     author = {Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton},
388 |     year = {2020}
389 | }
390 | ```
391 |
392 | ```bibtex
393 | @article{nilsback2008flowers,
394 |     title = {Oxford 102 Flowers},
395 |     author = {Nilsback, M-E. and Zisserman, A.},
396 |     year = {2008},
397 |     abstract = {A 102 category dataset consisting of 102 flower categories, commonly occurring in the United Kingdom. Each class consists of 40 to 258 images. The images have large scale, pose and light variations.}
398 | }
399 | ```
400 |
401 | ```bibtex
402 | @article{afifi201911k,
403 |     title = {11K Hands: gender recognition and biometric identification using a large dataset of hand images},
404 |     author = {Afifi, Mahmoud},
405 |     journal = {Multimedia Tools and Applications},
406 |     year = {2019}
407 | }
408 | ```
409 |
410 | ```bibtex
411 | @misc{zhang2018selfattention,
412 |     title = {Self-Attention Generative Adversarial Networks},
413 |     author = {Han Zhang and Ian Goodfellow and Dimitris Metaxas and Augustus Odena},
414 |     year = {2018},
415 |     eprint = {1805.08318},
416 |     archivePrefix = {arXiv}
417 | }
418 | ```
419 |
420 | ```bibtex
421 | @article{shen2019efficient,
422 |     author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
423 |     title = {Efficient Attention: Attention with Linear Complexities},
424 |     journal = {CoRR},
425 |     year = {2018},
426 |     url = {http://arxiv.org/abs/1812.01243},
427 | }
428 | ```
429 |
430 | ```bibtex
431 | @article{zhao2020diffaugment,
432 |     title = {Differentiable Augmentation for Data-Efficient GAN Training},
433 |     author = {Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
434 |     journal = {arXiv preprint arXiv:2006.10738},
435 |     year = {2020}
436 | }
437 | ```
438 |
439 | ```bibtex
440 | @misc{zhao2020image,
441 |     title = {Image Augmentations for GAN Training},
442 |     author = {Zhengli Zhao and Zizhao Zhang and Ting Chen and Sameer Singh and Han Zhang},
443 |     year = {2020},
444 |     eprint = {2006.02595},
445 |     archivePrefix = {arXiv}
446 | }
447 | ```
448 |
449 | ```bibtex
450 | @misc{karras2020training,
451 |     title = {Training Generative Adversarial Networks with Limited Data},
452 |     author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
453 |     year = {2020},
454 |     eprint = {2006.06676},
455 |     archivePrefix = {arXiv},
456 |     primaryClass = {cs.CV}
457 | }
458 | ```
459 |
460 | ```bibtex
461 | @misc{jolicoeurmartineau2018relativistic,
462 |     title = {The relativistic discriminator: a key element missing from standard GAN},
463 |     author = {Alexia Jolicoeur-Martineau},
464 |     year = {2018},
465 |     eprint = {1807.00734},
466 |     archivePrefix = {arXiv},
467 |     primaryClass = {cs.LG}
468 | }
469 | ```
470 |
471 | ```bibtex
472 | @misc{sinha2020topk,
473 |     title = {Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
474 |     author = {Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
475 |     year = {2020},
476 |     eprint = {2002.06224},
477 |     archivePrefix = {arXiv},
478 |     primaryClass = {stat.ML}
479 | }
480 | ```
481 |
482 | ```bibtex
483 | @misc{yu2021dual,
484 |     title = {Dual Contrastive Loss and Attention for GANs},
485 |     author = {Ning Yu and Guilin Liu and Aysegul Dundar and Andrew Tao and Bryan Catanzaro and Larry Davis and Mario Fritz},
486 |     year = {2021},
487 |     eprint = {2103.16748},
488 |     archivePrefix = {arXiv},
489 |     primaryClass = {cs.CV}
490 | }
491 | ```
492 |
493 | ```bibtex
494 | @inproceedings{Huang2025TheGI,
495 |     title = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
496 |     author = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
497 |     year = {2025},
498 |     url = {https://api.semanticscholar.org/CorpusID:275405495}
499 | }
500 | ```
501 |
--------------------------------------------------------------------------------
/stylegan2_pytorch/stylegan2_pytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import fire
5 | import json
6 |
7 | from tqdm import tqdm
8 | from math import floor, log2
9 | from random import random
10 | from shutil import rmtree
11 | from functools import partial
12 | import multiprocessing
13 | from contextlib import contextmanager, ExitStack
14 |
15 | import numpy as np
16 |
17 | import torch
18 | from torch import nn, einsum
19 | from torch.utils import data
20 | from torch.optim import Adam
21 | import torch.nn.functional as F
22 | from torch.autograd import grad as torch_grad
23 | from torch.utils.data.distributed import DistributedSampler
24 | from torch.nn.parallel import DistributedDataParallel as DDP
25 |
26 | from einops import rearrange, repeat
27 | from kornia.filters import filter2d
28 |
29 | import torchvision
30 | from torchvision import transforms
31 | from stylegan2_pytorch.version import __version__
32 | from stylegan2_pytorch.diff_augment import DiffAugment
33 |
34 | from vector_quantize_pytorch import VectorQuantize
35 |
36 | from PIL import Image
37 | from pathlib import Path
38 |
39 | try:
40 |     from apex import amp
41 |     APEX_AVAILABLE = True
42 | except ImportError:
43 |     APEX_AVAILABLE = False
44 |
45 | import aim
46 |
47 | assert torch.cuda.is_available(), 'You need to have an Nvidia GPU with CUDA installed.'
48 | 49 | 50 | # constants 51 | 52 | NUM_CORES = multiprocessing.cpu_count() 53 | EXTS = ['jpg', 'jpeg', 'png'] 54 | 55 | # helper classes 56 | 57 | class NanException(Exception): 58 | pass 59 | 60 | class EMA(): 61 | def __init__(self, beta): 62 | super().__init__() 63 | self.beta = beta 64 | def update_average(self, old, new): 65 | if not exists(old): 66 | return new 67 | return old * self.beta + (1 - self.beta) * new 68 | 69 | class Flatten(nn.Module): 70 | def forward(self, x): 71 | return x.reshape(x.shape[0], -1) 72 | 73 | class RandomApply(nn.Module): 74 | def __init__(self, prob, fn, fn_else = lambda x: x): 75 | super().__init__() 76 | self.fn = fn 77 | self.fn_else = fn_else 78 | self.prob = prob 79 | def forward(self, x): 80 | fn = self.fn if random() < self.prob else self.fn_else 81 | return fn(x) 82 | 83 | class Residual(nn.Module): 84 | def __init__(self, fn): 85 | super().__init__() 86 | self.fn = fn 87 | def forward(self, x): 88 | return self.fn(x) + x 89 | 90 | class ChanNorm(nn.Module): 91 | def __init__(self, dim, eps = 1e-5): 92 | super().__init__() 93 | self.eps = eps 94 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 95 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) 96 | 97 | def forward(self, x): 98 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 99 | mean = torch.mean(x, dim = 1, keepdim = True) 100 | return (x - mean) / (var + self.eps).sqrt() * self.g + self.b 101 | 102 | class PreNorm(nn.Module): 103 | def __init__(self, dim, fn): 104 | super().__init__() 105 | self.fn = fn 106 | self.norm = ChanNorm(dim) 107 | 108 | def forward(self, x): 109 | return self.fn(self.norm(x)) 110 | 111 | class PermuteToFrom(nn.Module): 112 | def __init__(self, fn): 113 | super().__init__() 114 | self.fn = fn 115 | def forward(self, x): 116 | x = x.permute(0, 2, 3, 1) 117 | out, *_, loss = self.fn(x) 118 | out = out.permute(0, 3, 1, 2) 119 | return out, loss 120 | 121 | class Blur(nn.Module): 122 | def __init__(self): 123 | super().__init__() 124 | f = torch.Tensor([1, 2, 1]) 125 | self.register_buffer('f', f) 126 | def forward(self, x): 127 | f = self.f 128 | f = f[None, None, :] * f [None, :, None] 129 | return filter2d(x, f, normalized=True) 130 | 131 | # attention 132 | 133 | class DepthWiseConv2d(nn.Module): 134 | def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True): 135 | super().__init__() 136 | self.net = nn.Sequential( 137 | nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias), 138 | nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias) 139 | ) 140 | def forward(self, x): 141 | return self.net(x) 142 | 143 | class LinearAttention(nn.Module): 144 | def __init__(self, dim, dim_head = 64, heads = 8): 145 | super().__init__() 146 | self.scale = dim_head ** -0.5 147 | self.heads = heads 148 | inner_dim = dim_head * heads 149 | 150 | self.nonlin = nn.GELU() 151 | self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False) 152 | self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False) 153 | self.to_out = nn.Conv2d(inner_dim, dim, 1) 154 | 155 | def forward(self, fmap): 156 | h, x, y = self.heads, *fmap.shape[-2:] 157 | q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1)) 158 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v)) 159 | 160 | q = q.softmax(dim = -1) 161 | k = k.softmax(dim = -2) 162 | 163 | q = q * self.scale 164 | 165 | context = einsum('b n d, b n e -> b d e', k, v) 166 | out = 
einsum('b n d, b d e -> b n e', q, context) 167 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y) 168 | 169 | out = self.nonlin(out) 170 | return self.to_out(out) 171 | 172 | # one layer of self-attention and feedforward, for images 173 | 174 | attn_and_ff = lambda chan: nn.Sequential(*[ 175 | Residual(PreNorm(chan, LinearAttention(chan))), 176 | Residual(PreNorm(chan, nn.Sequential(nn.Conv2d(chan, chan * 2, 1), leaky_relu(), nn.Conv2d(chan * 2, chan, 1)))) 177 | ]) 178 | 179 | # helpers 180 | 181 | def exists(val): 182 | return val is not None 183 | 184 | @contextmanager 185 | def null_context(): 186 | yield 187 | 188 | def combine_contexts(contexts): 189 | @contextmanager 190 | def multi_contexts(): 191 | with ExitStack() as stack: 192 | yield [stack.enter_context(ctx()) for ctx in contexts] 193 | return multi_contexts 194 | 195 | def default(value, d): 196 | return value if exists(value) else d 197 | 198 | def cycle(iterable): 199 | while True: 200 | for i in iterable: 201 | yield i 202 | 203 | def cast_list(el): 204 | return el if isinstance(el, list) else [el] 205 | 206 | def is_empty(t): 207 | if isinstance(t, torch.Tensor): 208 | return t.nelement() == 0 209 | return not exists(t) 210 | 211 | def raise_if_nan(t): 212 | if torch.isnan(t): 213 | raise NanException 214 | 215 | def gradient_accumulate_contexts(gradient_accumulate_every, is_ddp, ddps): 216 | if is_ddp: 217 | num_no_syncs = gradient_accumulate_every - 1 218 | head = [combine_contexts(map(lambda ddp: ddp.no_sync, ddps))] * num_no_syncs 219 | tail = [null_context] 220 | contexts = head + tail 221 | else: 222 | contexts = [null_context] * gradient_accumulate_every 223 | 224 | for context in contexts: 225 | with context(): 226 | yield 227 | 228 | def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs): 229 | if fp16: 230 | with amp.scale_loss(loss, optimizer, loss_id) as scaled_loss: 231 | scaled_loss.backward(**kwargs) 232 | else: 233 | loss.backward(**kwargs) 234 | 235 | def gradient_penalty(images, output, weight = 10, center = 0.): 236 | batch_size = images.shape[0] 237 | gradients = torch_grad(outputs=output, inputs=images, 238 | grad_outputs=torch.ones(output.size(), device=images.device), 239 | create_graph=True, retain_graph=True, only_inputs=True)[0] 240 | 241 | gradients = gradients.reshape(batch_size, -1) 242 | return weight * ((gradients.norm(2, dim=1) - center) ** 2).mean() 243 | 244 | def calc_pl_lengths(styles, images): 245 | device = images.device 246 | num_pixels = images.shape[2] * images.shape[3] 247 | pl_noise = torch.randn(images.shape, device=device) / math.sqrt(num_pixels) 248 | outputs = (images * pl_noise).sum() 249 | 250 | pl_grads = torch_grad(outputs=outputs, inputs=styles, 251 | grad_outputs=torch.ones(outputs.shape, device=device), 252 | create_graph=True, retain_graph=True, only_inputs=True)[0] 253 | 254 | return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt() 255 | 256 | def noise(n, latent_dim, device): 257 | return torch.randn(n, latent_dim).cuda(device) 258 | 259 | def noise_list(n, layers, latent_dim, device): 260 | return [(noise(n, latent_dim, device), layers)] 261 | 262 | def mixed_list(n, layers, latent_dim, device): 263 | tt = int(torch.rand(()).numpy() * layers) 264 | return noise_list(n, tt, latent_dim, device) + noise_list(n, layers - tt, latent_dim, device) 265 | 266 | def latent_to_w(style_vectorizer, latent_descr): 267 | return [(style_vectorizer(z), num_layers) for z, num_layers in latent_descr] 268 | 269 | def image_noise(n, im_size, device): 270 | 
return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device) 271 | 272 | def leaky_relu(p=0.2): 273 | return nn.LeakyReLU(p, inplace=True) 274 | 275 | def evaluate_in_chunks(max_batch_size, model, *args): 276 | split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args)))) 277 | chunked_outputs = [model(*i) for i in split_args] 278 | if len(chunked_outputs) == 1: 279 | return chunked_outputs[0] 280 | return torch.cat(chunked_outputs, dim=0) 281 | 282 | def styles_def_to_tensor(styles_def): 283 | return torch.cat([t[:, None, :].expand(-1, n, -1) for t, n in styles_def], dim=1) 284 | 285 | def set_requires_grad(model, bool): 286 | for p in model.parameters(): 287 | p.requires_grad = bool 288 | 289 | def slerp(val, low, high): 290 | low_norm = low / torch.norm(low, dim=1, keepdim=True) 291 | high_norm = high / torch.norm(high, dim=1, keepdim=True) 292 | omega = torch.acos((low_norm * high_norm).sum(1)) 293 | so = torch.sin(omega) 294 | res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high 295 | return res 296 | 297 | # losses 298 | 299 | def gen_hinge_loss(fake, real): 300 | return fake.mean() 301 | 302 | def hinge_loss(fake, real): 303 | return (F.relu(1 + real) + F.relu(1 - fake)).mean() 304 | 305 | def dual_contrastive_loss(fake_logits, real_logits): 306 | device = real_logits.device 307 | real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits)) 308 | 309 | def loss_half(t1, t2): 310 | t1 = rearrange(t1, 'i -> i 1') 311 | t2 = repeat(t2, 'j -> i j', i = t1.shape[0]) 312 | t = torch.cat((t1, t2), dim = -1) 313 | return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long)) 314 | 315 | return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits) 316 | 317 | # dataset 318 | 319 | def convert_rgb_to_transparent(image): 320 | if image.mode != 'RGBA': 321 | return image.convert('RGBA') 322 | return image 323 | 324 | def convert_transparent_to_rgb(image): 325 | if image.mode != 'RGB': 326 | return image.convert('RGB') 327 | return image 328 | 329 | class expand_greyscale(object): 330 | def __init__(self, transparent): 331 | self.transparent = transparent 332 | 333 | def __call__(self, tensor): 334 | channels = tensor.shape[0] 335 | num_target_channels = 4 if self.transparent else 3 336 | 337 | if channels == num_target_channels: 338 | return tensor 339 | 340 | alpha = None 341 | if channels == 1: 342 | color = tensor.expand(3, -1, -1) 343 | elif channels == 2: 344 | color = tensor[:1].expand(3, -1, -1) 345 | alpha = tensor[1:] 346 | else: 347 | raise Exception(f'image with invalid number of channels given {channels}') 348 | 349 | if not exists(alpha) and self.transparent: 350 | alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device) 351 | 352 | return color if not self.transparent else torch.cat((color, alpha)) 353 | 354 | def resize_to_minimum_size(min_size, image): 355 | if max(*image.size) < min_size: 356 | return torchvision.transforms.functional.resize(image, min_size) 357 | return image 358 | 359 | class Dataset(data.Dataset): 360 | def __init__(self, folder, image_size, transparent = False, aug_prob = 0.): 361 | super().__init__() 362 | self.folder = folder 363 | self.image_size = image_size 364 | self.paths = [p for ext in EXTS for p in Path(f'{folder}').glob(f'**/*.{ext}')] 365 | assert len(self.paths) > 0, f'No images were found in {folder} for training' 366 | 367 | convert_image_fn = 
convert_transparent_to_rgb if not transparent else convert_rgb_to_transparent 368 | num_channels = 3 if not transparent else 4 369 | 370 | self.transform = transforms.Compose([ 371 | transforms.Lambda(convert_image_fn), 372 | transforms.Lambda(partial(resize_to_minimum_size, image_size)), 373 | transforms.Resize(image_size), 374 | RandomApply(aug_prob, transforms.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.98, 1.02)), transforms.CenterCrop(image_size)), 375 | transforms.ToTensor(), 376 | transforms.Lambda(expand_greyscale(transparent)) 377 | ]) 378 | 379 | def __len__(self): 380 | return len(self.paths) 381 | 382 | def __getitem__(self, index): 383 | path = self.paths[index] 384 | img = Image.open(path) 385 | return self.transform(img) 386 | 387 | # augmentations 388 | 389 | def random_hflip(tensor, prob): 390 | if prob < random(): 391 | return tensor 392 | return torch.flip(tensor, dims=(3,)) 393 | 394 | class AugWrapper(nn.Module): 395 | def __init__(self, D, image_size): 396 | super().__init__() 397 | self.D = D 398 | 399 | def forward(self, images, prob = 0., types = [], detach = False, return_aug_images = False, input_requires_grad = False): 400 | if random() < prob: 401 | images = random_hflip(images, prob=0.5) 402 | images = DiffAugment(images, types=types) 403 | 404 | if detach: 405 | images = images.detach() 406 | 407 | if input_requires_grad: 408 | images.requires_grad_() 409 | 410 | logits = self.D(images) 411 | 412 | if not return_aug_images: 413 | return logits 414 | 415 | return images, logits 416 | 417 | # stylegan2 classes 418 | 419 | class EqualLinear(nn.Module): 420 | def __init__(self, in_dim, out_dim, lr_mul = 1, bias = True): 421 | super().__init__() 422 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) 423 | if bias: 424 | self.bias = nn.Parameter(torch.zeros(out_dim)) 425 | 426 | self.lr_mul = lr_mul 427 | 428 | def forward(self, input): 429 | return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul) 430 | 431 | class StyleVectorizer(nn.Module): 432 | def __init__(self, emb, depth, lr_mul = 0.1): 433 | super().__init__() 434 | 435 | layers = [] 436 | for i in range(depth): 437 | layers.extend([EqualLinear(emb, emb, lr_mul), leaky_relu()]) 438 | 439 | self.net = nn.Sequential(*layers) 440 | 441 | def forward(self, x): 442 | x = F.normalize(x, dim=1) 443 | return self.net(x) 444 | 445 | class RGBBlock(nn.Module): 446 | def __init__(self, latent_dim, input_channel, upsample, rgba = False): 447 | super().__init__() 448 | self.input_channel = input_channel 449 | self.to_style = nn.Linear(latent_dim, input_channel) 450 | 451 | out_filters = 3 if not rgba else 4 452 | self.conv = Conv2DMod(input_channel, out_filters, 1, demod=False) 453 | 454 | self.upsample = nn.Sequential( 455 | nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False), 456 | Blur() 457 | ) if upsample else None 458 | 459 | def forward(self, x, prev_rgb, istyle): 460 | b, c, h, w = x.shape 461 | style = self.to_style(istyle) 462 | x = self.conv(x, style) 463 | 464 | if exists(prev_rgb): 465 | x = x + prev_rgb 466 | 467 | if exists(self.upsample): 468 | x = self.upsample(x) 469 | 470 | return x 471 | 472 | class Conv2DMod(nn.Module): 473 | def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps = 1e-8, **kwargs): 474 | super().__init__() 475 | self.filters = out_chan 476 | self.demod = demod 477 | self.kernel = kernel 478 | self.stride = stride 479 | self.dilation = dilation 480 | self.weight = 
nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel))) 481 | self.eps = eps 482 | nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 483 | 484 | def _get_same_padding(self, size, kernel, dilation, stride): 485 | return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 486 | 487 | def forward(self, x, y): 488 | b, c, h, w = x.shape 489 | 490 | w1 = y[:, None, :, None, None] 491 | w2 = self.weight[None, :, :, :, :] 492 | weights = w2 * (w1 + 1) 493 | 494 | if self.demod: 495 | d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps) 496 | weights = weights * d 497 | 498 | x = x.reshape(1, -1, h, w) 499 | 500 | _, _, *ws = weights.shape 501 | weights = weights.reshape(b * self.filters, *ws) 502 | 503 | padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride) 504 | x = F.conv2d(x, weights, padding=padding, groups=b) 505 | 506 | x = x.reshape(-1, self.filters, h, w) 507 | return x 508 | 509 | class GeneratorBlock(nn.Module): 510 | def __init__(self, latent_dim, input_channels, filters, upsample = True, upsample_rgb = True, rgba = False): 511 | super().__init__() 512 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) if upsample else None 513 | 514 | self.to_style1 = nn.Linear(latent_dim, input_channels) 515 | self.to_noise1 = nn.Linear(1, filters) 516 | self.conv1 = Conv2DMod(input_channels, filters, 3) 517 | 518 | self.to_style2 = nn.Linear(latent_dim, filters) 519 | self.to_noise2 = nn.Linear(1, filters) 520 | self.conv2 = Conv2DMod(filters, filters, 3) 521 | 522 | self.activation = leaky_relu() 523 | self.to_rgb = RGBBlock(latent_dim, filters, upsample_rgb, rgba) 524 | 525 | def forward(self, x, prev_rgb, istyle, inoise): 526 | if exists(self.upsample): 527 | x = self.upsample(x) 528 | 529 | inoise = inoise[:, :x.shape[2], :x.shape[3], :] 530 | noise1 = self.to_noise1(inoise).permute((0, 3, 2, 1)) 531 | noise2 = self.to_noise2(inoise).permute((0, 3, 2, 1)) 532 | 533 | style1 = self.to_style1(istyle) 534 | x = self.conv1(x, style1) 535 | x = self.activation(x + noise1) 536 | 537 | style2 = self.to_style2(istyle) 538 | x = self.conv2(x, style2) 539 | x = self.activation(x + noise2) 540 | 541 | rgb = self.to_rgb(x, prev_rgb, istyle) 542 | return x, rgb 543 | 544 | class DiscriminatorBlock(nn.Module): 545 | def __init__(self, input_channels, filters, downsample=True): 546 | super().__init__() 547 | self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1)) 548 | 549 | self.net = nn.Sequential( 550 | nn.Conv2d(input_channels, filters, 3, padding=1), 551 | leaky_relu(), 552 | nn.Conv2d(filters, filters, 3, padding=1), 553 | leaky_relu() 554 | ) 555 | 556 | self.downsample = nn.Sequential( 557 | Blur(), 558 | nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) 559 | ) if downsample else None 560 | 561 | def forward(self, x): 562 | res = self.conv_res(x) 563 | x = self.net(x) 564 | if exists(self.downsample): 565 | x = self.downsample(x) 566 | x = (x + res) * (1 / math.sqrt(2)) 567 | return x 568 | 569 | class Generator(nn.Module): 570 | def __init__(self, image_size, latent_dim, network_capacity = 16, transparent = False, attn_layers = [], no_const = False, fmap_max = 512): 571 | super().__init__() 572 | self.image_size = image_size 573 | self.latent_dim = latent_dim 574 | self.num_layers = int(log2(image_size) - 1) 575 | 576 | filters = [network_capacity * (2 ** (i + 1)) for i in range(self.num_layers)][::-1] 577 | 578 | set_fmap_max = 
partial(min, fmap_max) 579 | filters = list(map(set_fmap_max, filters)) 580 | init_channels = filters[0] 581 | filters = [init_channels, *filters] 582 | 583 | in_out_pairs = zip(filters[:-1], filters[1:]) 584 | self.no_const = no_const 585 | 586 | if no_const: 587 | self.to_initial_block = nn.ConvTranspose2d(latent_dim, init_channels, 4, 1, 0, bias=False) 588 | else: 589 | self.initial_block = nn.Parameter(torch.randn((1, init_channels, 4, 4))) 590 | 591 | self.initial_conv = nn.Conv2d(filters[0], filters[0], 3, padding=1) 592 | self.blocks = nn.ModuleList([]) 593 | self.attns = nn.ModuleList([]) 594 | 595 | for ind, (in_chan, out_chan) in enumerate(in_out_pairs): 596 | not_first = ind != 0 597 | not_last = ind != (self.num_layers - 1) 598 | num_layer = self.num_layers - ind 599 | 600 | attn_fn = attn_and_ff(in_chan) if num_layer in attn_layers else None 601 | 602 | self.attns.append(attn_fn) 603 | 604 | block = GeneratorBlock( 605 | latent_dim, 606 | in_chan, 607 | out_chan, 608 | upsample = not_first, 609 | upsample_rgb = not_last, 610 | rgba = transparent 611 | ) 612 | self.blocks.append(block) 613 | 614 | def forward(self, styles, input_noise): 615 | batch_size = styles.shape[0] 616 | image_size = self.image_size 617 | 618 | if self.no_const: 619 | avg_style = styles.mean(dim=1)[:, :, None, None] 620 | x = self.to_initial_block(avg_style) 621 | else: 622 | x = self.initial_block.expand(batch_size, -1, -1, -1) 623 | 624 | rgb = None 625 | styles = styles.transpose(0, 1) 626 | x = self.initial_conv(x) 627 | 628 | for style, block, attn in zip(styles, self.blocks, self.attns): 629 | if exists(attn): 630 | x = attn(x) 631 | x, rgb = block(x, rgb, style, input_noise) 632 | 633 | return rgb 634 | 635 | class Discriminator(nn.Module): 636 | def __init__(self, image_size, network_capacity = 16, fq_layers = [], fq_dict_size = 256, attn_layers = [], transparent = False, fmap_max = 512): 637 | super().__init__() 638 | num_layers = int(log2(image_size) - 1) 639 | num_init_filters = 3 if not transparent else 4 640 | 641 | # build the channel progression, capping filter counts at fmap_max 642 | filters = [num_init_filters] + [(network_capacity * 4) * (2 ** i) for i in range(num_layers + 1)] 643 | 644 | set_fmap_max = partial(min, fmap_max) 645 | filters = list(map(set_fmap_max, filters)) 646 | chan_in_out = list(zip(filters[:-1], filters[1:])) 647 | 648 | blocks = [] 649 | attn_blocks = [] 650 | quantize_blocks = [] 651 | 652 | for ind, (in_chan, out_chan) in enumerate(chan_in_out): 653 | num_layer = ind + 1 654 | is_not_last = ind != (len(chan_in_out) - 1) 655 | 656 | block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last) 657 | blocks.append(block) 658 | 659 | attn_fn = attn_and_ff(out_chan) if num_layer in attn_layers else None 660 | 661 | attn_blocks.append(attn_fn) 662 | 663 | quantize_fn = PermuteToFrom(VectorQuantize(out_chan, fq_dict_size)) if num_layer in fq_layers else None 664 | quantize_blocks.append(quantize_fn) 665 | 666 | self.blocks = nn.ModuleList(blocks) 667 | self.attn_blocks = nn.ModuleList(attn_blocks) 668 | self.quantize_blocks = nn.ModuleList(quantize_blocks) 669 | 670 | chan_last = filters[-1] 671 | latent_dim = 2 * 2 * chan_last 672 | 673 | self.final_conv = nn.Conv2d(chan_last, chan_last, 3, padding=1) 674 | self.flatten = Flatten() 675 | self.to_logit = nn.Linear(latent_dim, 1) 676 | 677 | def forward(self, x): 678 | b, *_ = x.shape 679 | 680 | quantize_loss = torch.zeros(1).to(x) 681 | 682 | for (block, attn_block, q_block) in zip(self.blocks, self.attn_blocks, self.quantize_blocks): 683 | x = block(x) 684 |
685 | if exists(attn_block): 686 | x = attn_block(x) 687 | 688 | if exists(q_block): 689 | x, loss = q_block(x) 690 | quantize_loss += loss 691 | 692 | x = self.final_conv(x) 693 | x = self.flatten(x) 694 | x = self.to_logit(x) 695 | return x.squeeze(), quantize_loss 696 | 697 | class StyleGAN2(nn.Module): 698 | def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0): 699 | super().__init__() 700 | self.lr = lr 701 | self.steps = steps 702 | self.ema_updater = EMA(0.995) 703 | 704 | self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp) 705 | self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max) 706 | self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max) 707 | 708 | self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp) 709 | self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max) # pass fmap_max so the EMA generator mirrors G exactly 710 | 711 | self.D_cl = None 712 | 713 | if cl_reg: 714 | from contrastive_learner import ContrastiveLearner 715 | # experimental contrastive loss discriminator regularization 716 | assert not transparent, 'contrastive loss regularization does not work with transparent images yet' 717 | self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') 718 | 719 | # wrapper for augmenting all images going into the discriminator 720 | self.D_aug = AugWrapper(self.D, image_size) 721 | 722 | # turn off grad for exponential moving averages 723 | set_requires_grad(self.SE, False) 724 | set_requires_grad(self.GE, False) 725 | 726 | # init optimizers 727 | generator_params = list(self.G.parameters()) + list(self.S.parameters()) 728 | self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9)) 729 | self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9)) 730 | 731 | # init weights 732 | self._init_weights() 733 | self.reset_parameter_averaging() 734 | 735 | self.cuda(rank) 736 | 737 | # startup apex mixed precision 738 | self.fp16 = fp16 739 | if fp16: 740 | (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1', num_losses=3) 741 | 742 | def _init_weights(self): 743 | for m in self.modules(): 744 | if type(m) in {nn.Conv2d, nn.Linear}: 745 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu') 746 | 747 | for block in self.G.blocks: 748 | nn.init.zeros_(block.to_noise1.weight) 749 | nn.init.zeros_(block.to_noise2.weight) 750 | nn.init.zeros_(block.to_noise1.bias) 751 | nn.init.zeros_(block.to_noise2.bias) 752 | 753 | def EMA(self): 754 | def update_moving_average(ma_model, current_model): 755 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 756 | old_weight, up_weight = ma_params.data, current_params.data 757 | ma_params.data = self.ema_updater.update_average(old_weight, up_weight) 758 | 759 | update_moving_average(self.SE, self.S) 760 | update_moving_average(self.GE, self.G) 761 | 762 | def reset_parameter_averaging(self): 763 |
self.SE.load_state_dict(self.S.state_dict()) 764 | self.GE.load_state_dict(self.G.state_dict()) 765 | 766 | def forward(self, x): 767 | return x 768 | 769 | class Trainer(): 770 | def __init__( 771 | self, 772 | name = 'default', 773 | results_dir = 'results', 774 | models_dir = 'models', 775 | base_dir = './', 776 | image_size = 128, 777 | network_capacity = 16, 778 | fmap_max = 512, 779 | transparent = False, 780 | batch_size = 4, 781 | mixed_prob = 0.9, 782 | gradient_accumulate_every = 1, 783 | lr = 2e-4, 784 | lr_mlp = 0.1, 785 | ttur_mult = 2, 786 | rel_disc_loss = False, 787 | num_workers = None, 788 | save_every = 1000, 789 | evaluate_every = 1000, 790 | num_image_tiles = 8, 791 | trunc_psi = 0.6, 792 | fp16 = False, 793 | cl_reg = False, 794 | no_pl_reg = False, 795 | fq_layers = [], 796 | fq_dict_size = 256, 797 | attn_layers = [], 798 | no_const = False, 799 | aug_prob = 0., 800 | aug_types = ['translation', 'cutout'], 801 | top_k_training = False, 802 | generator_top_k_gamma = 0.99, 803 | generator_top_k_frac = 0.5, 804 | dual_contrast_loss = False, 805 | dataset_aug_prob = 0., 806 | calculate_fid_every = None, 807 | calculate_fid_num_images = 12800, 808 | clear_fid_cache = False, 809 | is_ddp = False, 810 | rank = 0, 811 | world_size = 1, 812 | log = False, 813 | *args, 814 | **kwargs 815 | ): 816 | self.GAN_params = [args, kwargs] 817 | self.GAN = None 818 | 819 | self.name = name 820 | 821 | base_dir = Path(base_dir) 822 | self.base_dir = base_dir 823 | self.results_dir = base_dir / results_dir 824 | self.models_dir = base_dir / models_dir 825 | self.fid_dir = base_dir / 'fid' / name 826 | self.config_path = self.models_dir / name / '.config.json' 827 | 828 | assert log2(image_size).is_integer(), 'image size must be a power of 2 (64, 128, 256, 512, 1024)' 829 | self.image_size = image_size 830 | self.network_capacity = network_capacity 831 | self.fmap_max = fmap_max 832 | self.transparent = transparent 833 | 834 | self.fq_layers = cast_list(fq_layers) 835 | self.fq_dict_size = fq_dict_size 836 | self.has_fq = len(self.fq_layers) > 0 837 | 838 | self.attn_layers = cast_list(attn_layers) 839 | self.no_const = no_const 840 | 841 | self.aug_prob = aug_prob 842 | self.aug_types = aug_types 843 | 844 | self.lr = lr 845 | self.lr_mlp = lr_mlp 846 | self.ttur_mult = ttur_mult 847 | self.rel_disc_loss = rel_disc_loss 848 | self.batch_size = batch_size 849 | self.num_workers = num_workers 850 | self.mixed_prob = mixed_prob 851 | 852 | self.num_image_tiles = num_image_tiles 853 | self.evaluate_every = evaluate_every 854 | self.save_every = save_every 855 | self.steps = 0 856 | 857 | self.av = None 858 | self.trunc_psi = trunc_psi 859 | 860 | self.no_pl_reg = no_pl_reg 861 | self.pl_mean = None 862 | 863 | self.gradient_accumulate_every = gradient_accumulate_every 864 | 865 | assert not fp16 or APEX_AVAILABLE, 'Apex must be installed in order to use mixed precision (fp16) training' 866 | self.fp16 = fp16 867 | 868 | self.cl_reg = cl_reg 869 | 870 | self.d_loss = 0 871 | self.g_loss = 0 872 | self.q_loss = None 873 | self.last_gp_loss = None 874 | self.last_cr_loss = None 875 | self.last_fid = None 876 | 877 | self.pl_length_ma = EMA(0.99) 878 | self.init_folders() 879 | 880 | self.loader = None 881 | self.dataset_aug_prob = dataset_aug_prob 882 | 883 | self.calculate_fid_every = calculate_fid_every 884 | self.calculate_fid_num_images = calculate_fid_num_images 885 | self.clear_fid_cache = clear_fid_cache 886 | 887 | self.top_k_training = top_k_training 888 | self.generator_top_k_gamma =
generator_top_k_gamma 889 | self.generator_top_k_frac = generator_top_k_frac 890 | 891 | self.dual_contrast_loss = dual_contrast_loss 892 | 893 | assert not (is_ddp and cl_reg), 'Contrastive loss regularization does not work well with multiple GPUs yet' 894 | self.is_ddp = is_ddp 895 | self.is_main = rank == 0 896 | self.rank = rank 897 | self.world_size = world_size 898 | 899 | self.logger = aim.Session(experiment=name) if log else None 900 | 901 | @property 902 | def image_extension(self): 903 | return 'jpg' if not self.transparent else 'png' 904 | 905 | @property 906 | def checkpoint_num(self): 907 | return self.steps // self.save_every 908 | 909 | @property 910 | def hparams(self): 911 | return {'image_size': self.image_size, 'network_capacity': self.network_capacity} 912 | 913 | def init_GAN(self): 914 | args, kwargs = self.GAN_params 915 | self.GAN = StyleGAN2(lr = self.lr, lr_mlp = self.lr_mlp, ttur_mult = self.ttur_mult, image_size = self.image_size, network_capacity = self.network_capacity, fmap_max = self.fmap_max, transparent = self.transparent, fq_layers = self.fq_layers, fq_dict_size = self.fq_dict_size, attn_layers = self.attn_layers, fp16 = self.fp16, cl_reg = self.cl_reg, no_const = self.no_const, rank = self.rank, *args, **kwargs) 916 | 917 | if self.is_ddp: 918 | ddp_kwargs = {'device_ids': [self.rank]} 919 | self.S_ddp = DDP(self.GAN.S, **ddp_kwargs) 920 | self.G_ddp = DDP(self.GAN.G, **ddp_kwargs) 921 | self.D_ddp = DDP(self.GAN.D, **ddp_kwargs) 922 | self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs) 923 | 924 | if exists(self.logger): 925 | self.logger.set_params(self.hparams) 926 | 927 | def write_config(self): 928 | self.config_path.write_text(json.dumps(self.config())) 929 | 930 | def load_config(self): 931 | config = self.config() if not self.config_path.exists() else json.loads(self.config_path.read_text()) 932 | self.image_size = config['image_size'] 933 | self.network_capacity = config['network_capacity'] 934 | self.transparent = config['transparent'] 935 | self.fq_layers = config['fq_layers'] 936 | self.fq_dict_size = config['fq_dict_size'] 937 | self.fmap_max = config.pop('fmap_max', 512) 938 | self.attn_layers = config.pop('attn_layers', []) 939 | self.no_const = config.pop('no_const', False) 940 | self.lr_mlp = config.pop('lr_mlp', 0.1) 941 | del self.GAN 942 | self.init_GAN() 943 | 944 | def config(self): 945 | return {'image_size': self.image_size, 'network_capacity': self.network_capacity, 'lr_mlp': self.lr_mlp, 'transparent': self.transparent, 'fq_layers': self.fq_layers, 'fq_dict_size': self.fq_dict_size, 'attn_layers': self.attn_layers, 'no_const': self.no_const} 946 | 947 | def set_data_src(self, folder): 948 | self.dataset = Dataset(folder, self.image_size, transparent = self.transparent, aug_prob = self.dataset_aug_prob) 949 | num_workers = default(self.num_workers, NUM_CORES if not self.is_ddp else 0) 950 | sampler = DistributedSampler(self.dataset, rank=self.rank, num_replicas=self.world_size, shuffle=True) if self.is_ddp else None 951 | dataloader = data.DataLoader(self.dataset, num_workers = num_workers, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True) 952 | self.loader = cycle(dataloader) 953 | 954 | # automatically set the augmentation probability when the dataset is detected to be small 955 | num_samples = len(self.dataset) 956 | if not exists(self.aug_prob) and num_samples < 1e5: 957 | self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6) 958 |
print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%') 959 | 960 | def train(self): 961 | assert exists(self.loader), 'You must first initialize the data source with `.set_data_src()`' 962 | 963 | if not exists(self.GAN): 964 | self.init_GAN() 965 | 966 | self.GAN.train() 967 | total_disc_loss = torch.tensor(0.).cuda(self.rank) 968 | total_gen_loss = torch.tensor(0.).cuda(self.rank) 969 | 970 | batch_size = math.ceil(self.batch_size / self.world_size) 971 | 972 | image_size = self.GAN.G.image_size 973 | latent_dim = self.GAN.G.latent_dim 974 | num_layers = self.GAN.G.num_layers 975 | 976 | aug_prob = self.aug_prob 977 | aug_types = self.aug_types 978 | aug_kwargs = {'prob': aug_prob, 'types': aug_types} 979 | 980 | apply_gradient_penalty = self.steps % 4 == 0 981 | apply_path_penalty = not self.no_pl_reg and self.steps > 5000 and self.steps % 32 == 0 982 | apply_cl_reg_to_generated = self.steps > 20000 983 | 984 | S = self.GAN.S if not self.is_ddp else self.S_ddp 985 | G = self.GAN.G if not self.is_ddp else self.G_ddp 986 | D = self.GAN.D if not self.is_ddp else self.D_ddp 987 | D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp 988 | 989 | backwards = partial(loss_backwards, self.fp16) 990 | 991 | if exists(self.GAN.D_cl): 992 | self.GAN.D_opt.zero_grad() 993 | 994 | if apply_cl_reg_to_generated: 995 | for i in range(self.gradient_accumulate_every): 996 | get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list 997 | style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank) 998 | noise = image_noise(batch_size, image_size, device=self.rank) 999 | 1000 | w_space = latent_to_w(self.GAN.S, style) 1001 | w_styles = styles_def_to_tensor(w_space) 1002 | 1003 | generated_images = self.GAN.G(w_styles, noise) 1004 | self.GAN.D_cl(generated_images.clone().detach(), accumulate=True) 1005 | 1006 | for i in range(self.gradient_accumulate_every): 1007 | image_batch = next(self.loader).cuda(self.rank) 1008 | self.GAN.D_cl(image_batch, accumulate=True) 1009 | 1010 | loss = self.GAN.D_cl.calculate_loss() 1011 | self.last_cr_loss = loss.clone().detach().item() 1012 | backwards(loss, self.GAN.D_opt, loss_id = 0) 1013 | 1014 | self.GAN.D_opt.step() 1015 | 1016 | # setup losses 1017 | 1018 | if not self.dual_contrast_loss: 1019 | D_loss_fn = hinge_loss 1020 | G_loss_fn = gen_hinge_loss 1021 | G_requires_reals = False 1022 | else: 1023 | D_loss_fn = dual_contrastive_loss 1024 | G_loss_fn = dual_contrastive_loss 1025 | G_requires_reals = True 1026 | 1027 | # train discriminator 1028 | 1029 | avg_pl_length = self.pl_mean 1030 | self.GAN.D_opt.zero_grad() 1031 | 1032 | for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[D_aug, S, G]): 1033 | get_latents_fn = mixed_list if random() < self.mixed_prob else noise_list 1034 | style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank) 1035 | noise = image_noise(batch_size, image_size, device=self.rank) 1036 | 1037 | w_space = latent_to_w(S, style) 1038 | w_styles = styles_def_to_tensor(w_space) 1039 | 1040 | generated_images = G(w_styles, noise) 1041 | generated_images, (fake_output, fake_q_loss) = D_aug(generated_images.clone().detach(), return_aug_images = True, input_requires_grad = apply_gradient_penalty, detach = True, **aug_kwargs) 1042 | 1043 | image_batch = next(self.loader).cuda(self.rank) 1044 | 1045 | if apply_gradient_penalty: 1046 | image_batch.requires_grad_() 1047 | 1048 | real_output, real_q_loss = D_aug(image_batch, 
**aug_kwargs) 1049 | 1050 | real_output_loss = real_output 1051 | fake_output_loss = fake_output 1052 | 1053 | if self.rel_disc_loss: 1054 | real_output_loss = real_output_loss - fake_output.mean() 1055 | fake_output_loss = fake_output_loss - real_output.mean() 1056 | 1057 | divergence = D_loss_fn(fake_output_loss, real_output_loss) 1058 | disc_loss = divergence 1059 | 1060 | if self.has_fq: 1061 | quantize_loss = (fake_q_loss + real_q_loss).mean() 1062 | self.q_loss = float(quantize_loss.detach().item()) 1063 | 1064 | disc_loss = disc_loss + quantize_loss 1065 | 1066 | if apply_gradient_penalty: 1067 | gp = gradient_penalty(image_batch, real_output) + gradient_penalty(generated_images, fake_output) 1068 | self.last_gp_loss = gp.clone().detach().item() 1069 | self.track(self.last_gp_loss, 'GP') 1070 | disc_loss = disc_loss + gp 1071 | 1072 | disc_loss = disc_loss / self.gradient_accumulate_every 1073 | disc_loss.register_hook(raise_if_nan) 1074 | backwards(disc_loss, self.GAN.D_opt, loss_id = 1) 1075 | 1076 | total_disc_loss += divergence.detach().item() / self.gradient_accumulate_every 1077 | 1078 | self.d_loss = float(total_disc_loss) 1079 | self.track(self.d_loss, 'D') 1080 | 1081 | self.GAN.D_opt.step() 1082 | 1083 | # train generator 1084 | 1085 | self.GAN.G_opt.zero_grad() 1086 | 1087 | for i in gradient_accumulate_contexts(self.gradient_accumulate_every, self.is_ddp, ddps=[S, G, D_aug]): 1088 | style = get_latents_fn(batch_size, num_layers, latent_dim, device=self.rank) 1089 | noise = image_noise(batch_size, image_size, device=self.rank) 1090 | 1091 | w_space = latent_to_w(S, style) 1092 | w_styles = styles_def_to_tensor(w_space) 1093 | 1094 | generated_images = G(w_styles, noise) 1095 | fake_output, _ = D_aug(generated_images, **aug_kwargs) 1096 | fake_output_loss = fake_output 1097 | 1098 | real_output = None 1099 | if G_requires_reals: 1100 | image_batch = next(self.loader).cuda(self.rank) 1101 | real_output, _ = D_aug(image_batch, detach = True, **aug_kwargs) 1102 | real_output = real_output.detach() 1103 | 1104 | if self.top_k_training: 1105 | epochs = (self.steps * batch_size * self.gradient_accumulate_every) / len(self.dataset) 1106 | k_frac = max(self.generator_top_k_gamma ** epochs, self.generator_top_k_frac) 1107 | k = math.ceil(batch_size * k_frac) 1108 | 1109 | if k != batch_size: 1110 | fake_output_loss, _ = fake_output_loss.topk(k=k, largest=False) 1111 | 1112 | loss = G_loss_fn(fake_output_loss, real_output) 1113 | gen_loss = loss 1114 | 1115 | if apply_path_penalty: 1116 | pl_lengths = calc_pl_lengths(w_styles, generated_images) 1117 | avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy()) 1118 | 1119 | if not is_empty(self.pl_mean): 1120 | pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean() 1121 | if not torch.isnan(pl_loss): 1122 | gen_loss = gen_loss + pl_loss 1123 | 1124 | gen_loss = gen_loss / self.gradient_accumulate_every 1125 | gen_loss.register_hook(raise_if_nan) 1126 | backwards(gen_loss, self.GAN.G_opt, loss_id = 2) 1127 | 1128 | total_gen_loss += loss.detach().item() / self.gradient_accumulate_every 1129 | 1130 | self.g_loss = float(total_gen_loss) 1131 | self.track(self.g_loss, 'G') 1132 | 1133 | self.GAN.G_opt.step() 1134 | 1135 | # calculate moving averages 1136 | 1137 | if apply_path_penalty and not np.isnan(avg_pl_length): 1138 | self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length) 1139 | self.track(self.pl_mean, 'PL') 1140 | 1141 | if self.is_main and self.steps % 10 == 0 and self.steps > 20000: 1142 | self.GAN.EMA() 1143 
| 1144 | if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2: 1145 | self.GAN.reset_parameter_averaging() 1146 | 1147 | # recover from NaN errors 1148 | 1149 | if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)): 1150 | print(f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}') 1151 | self.load(self.checkpoint_num) 1152 | raise NanException 1153 | 1154 | # periodically save results 1155 | 1156 | if self.is_main: 1157 | if self.steps % self.save_every == 0: 1158 | self.save(self.checkpoint_num) 1159 | 1160 | if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 2500): 1161 | self.evaluate(floor(self.steps / self.evaluate_every)) 1162 | 1163 | if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0: 1164 | num_batches = math.ceil(self.calculate_fid_num_images / self.batch_size) 1165 | fid = self.calculate_fid(num_batches) 1166 | self.last_fid = fid 1167 | 1168 | with open(str(self.results_dir / self.name / 'fid_scores.txt'), 'a') as f: 1169 | f.write(f'{self.steps},{fid}\n') 1170 | 1171 | self.steps += 1 1172 | self.av = None 1173 | 1174 | @torch.no_grad() 1175 | def evaluate(self, num = 0, trunc = 1.0): 1176 | self.GAN.eval() 1177 | ext = self.image_extension 1178 | num_rows = self.num_image_tiles 1179 | 1180 | latent_dim = self.GAN.G.latent_dim 1181 | image_size = self.GAN.G.image_size 1182 | num_layers = self.GAN.G.num_layers 1183 | 1184 | # latents and noise 1185 | 1186 | latents = noise_list(num_rows ** 2, num_layers, latent_dim, device=self.rank) 1187 | n = image_noise(num_rows ** 2, image_size, device=self.rank) 1188 | 1189 | # regular 1190 | 1191 | generated_images = self.generate_truncated(self.GAN.S, self.GAN.G, latents, n, trunc_psi = self.trunc_psi) 1192 | torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}.{ext}'), nrow=num_rows) 1193 | 1194 | # moving averages 1195 | 1196 | generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi) 1197 | torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows) 1198 | 1199 | # style mixing regularization 1200 | 1201 | def tile(a, dim, n_tile): 1202 | init_dim = a.size(dim) 1203 | repeat_idx = [1] * a.dim() 1204 | repeat_idx[dim] = n_tile 1205 | a = a.repeat(*(repeat_idx)) 1206 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda(self.rank) 1207 | return torch.index_select(a, dim, order_index) 1208 | 1209 | nn = noise(num_rows, latent_dim, device=self.rank) 1210 | tmp1 = tile(nn, 0, num_rows) 1211 | tmp2 = nn.repeat(num_rows, 1) 1212 | 1213 | tt = int(num_layers / 2) 1214 | mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)] 1215 | 1216 | generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, mixed_latents, n, trunc_psi = self.trunc_psi) 1217 | torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-mr.{ext}'), nrow=num_rows) 1218 | 1219 | @torch.no_grad() 1220 | def calculate_fid(self, num_batches): 1221 | from pytorch_fid import fid_score 1222 | torch.cuda.empty_cache() 1223 | 1224 | real_path = self.fid_dir / 'real' 1225 | fake_path = self.fid_dir / 'fake' 1226 | 1227 | # remove any existing files used for fid calculation and recreate directories 1228 | 1229 | if not real_path.exists() or self.clear_fid_cache: 1230 | rmtree(real_path,
ignore_errors=True) 1231 | os.makedirs(real_path) 1232 | 1233 | for batch_num in tqdm(range(num_batches), desc='calculating FID - saving reals'): 1234 | real_batch = next(self.loader) 1235 | for k, image in enumerate(real_batch.unbind(0)): 1236 | filename = str(k + batch_num * self.batch_size) 1237 | torchvision.utils.save_image(image, str(real_path / f'{filename}.png')) 1238 | 1239 | # generate fake images into the fid fake directory (fid_dir / 'fake') 1240 | 1241 | rmtree(fake_path, ignore_errors=True) 1242 | os.makedirs(fake_path) 1243 | 1244 | self.GAN.eval() 1245 | ext = self.image_extension 1246 | 1247 | latent_dim = self.GAN.G.latent_dim 1248 | image_size = self.GAN.G.image_size 1249 | num_layers = self.GAN.G.num_layers 1250 | 1251 | for batch_num in tqdm(range(num_batches), desc='calculating FID - saving generated'): 1252 | # latents and noise 1253 | latents = noise_list(self.batch_size, num_layers, latent_dim, device=self.rank) 1254 | noise = image_noise(self.batch_size, image_size, device=self.rank) 1255 | 1256 | # moving averages 1257 | generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, noise, trunc_psi = self.trunc_psi) 1258 | 1259 | for j, image in enumerate(generated_images.unbind(0)): 1260 | torchvision.utils.save_image(image, str(fake_path / f'{str(j + batch_num * self.batch_size)}-ema.{ext}')) 1261 | 1262 | return fid_score.calculate_fid_given_paths([str(real_path), str(fake_path)], 256, noise.device, 2048) 1263 | 1264 | @torch.no_grad() 1265 | def truncate_style(self, tensor, S = None, trunc_psi = 0.75): 1266 | S = default(S, self.GAN.S) 1267 | batch_size = self.batch_size 1268 | latent_dim = self.GAN.G.latent_dim 1269 | 1270 | if not exists(self.av): 1271 | z = noise(2000, latent_dim, device=self.rank) 1272 | samples = evaluate_in_chunks(batch_size, S, z).cpu().numpy() 1273 | self.av = np.mean(samples, axis = 0) 1274 | self.av = np.expand_dims(self.av, axis = 0) 1275 | 1276 | av_torch = torch.from_numpy(self.av).cuda(self.rank) 1277 | tensor = trunc_psi * (tensor - av_torch) + av_torch 1278 | return tensor 1279 | 1280 | @torch.no_grad() 1281 | def truncate_style_defs(self, w, S = None, trunc_psi = 0.75): 1282 | w_space = [] 1283 | for tensor, num_layers in w: 1284 | tensor = self.truncate_style(tensor, S = S, trunc_psi = trunc_psi) 1285 | w_space.append((tensor, num_layers)) 1286 | return w_space 1287 | 1288 | @torch.no_grad() 1289 | def generate_truncated(self, S, G, style, noi, trunc_psi = 0.75, num_image_tiles = 8): 1290 | w = map(lambda t: (S(t[0]), t[1]), style) 1291 | w_truncated = self.truncate_style_defs(w, S = S, trunc_psi = trunc_psi) 1292 | w_styles = styles_def_to_tensor(w_truncated) 1293 | generated_images = evaluate_in_chunks(self.batch_size, G, w_styles, noi) 1294 | return generated_images.clamp_(0., 1.)
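# note: generate_truncated above relies on the truncation trick. truncate_style
# estimates the mean style vector (self.av) from 2000 samples of S(z), then pulls
# every style toward it:
#
#   w_truncated = trunc_psi * (w - w_avg) + w_avg
#
# trunc_psi = 1.0 leaves styles unchanged; values below 1 trade sample diversity
# for fidelity (w_avg is used here as shorthand for the averaged style in self.av)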
1295 | 1296 | @torch.no_grad() 1297 | def generate_interpolation(self, num = 0, num_image_tiles = 8, trunc = 1.0, num_steps = 100, save_frames = False): 1298 | self.GAN.eval() 1299 | ext = self.image_extension 1300 | num_rows = num_image_tiles 1301 | 1302 | latent_dim = self.GAN.G.latent_dim 1303 | image_size = self.GAN.G.image_size 1304 | num_layers = self.GAN.G.num_layers 1305 | 1306 | # latents and noise 1307 | 1308 | latents_low = noise(num_rows ** 2, latent_dim, device=self.rank) 1309 | latents_high = noise(num_rows ** 2, latent_dim, device=self.rank) 1310 | n = image_noise(num_rows ** 2, image_size, device=self.rank) 1311 | 1312 | ratios = torch.linspace(0., 8., num_steps) 1313 | 1314 | frames = [] 1315 | for ratio in tqdm(ratios): 1316 | interp_latents = slerp(ratio, latents_low, latents_high) 1317 | latents = [(interp_latents, num_layers)] 1318 | generated_images = self.generate_truncated(self.GAN.SE, self.GAN.GE, latents, n, trunc_psi = self.trunc_psi) 1319 | images_grid = torchvision.utils.make_grid(generated_images, nrow = num_rows) 1320 | pil_image = transforms.ToPILImage()(images_grid.cpu()) 1321 | 1322 | if self.transparent: 1323 | background = Image.new("RGBA", pil_image.size, (255, 255, 255)) 1324 | pil_image = Image.alpha_composite(background, pil_image) 1325 | 1326 | frames.append(pil_image) 1327 | 1328 | frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'), save_all=True, append_images=frames[1:], duration=80, loop=0, optimize=True) 1329 | 1330 | if save_frames: 1331 | folder_path = (self.results_dir / self.name / f'{str(num)}') 1332 | folder_path.mkdir(parents=True, exist_ok=True) 1333 | for ind, frame in enumerate(frames): 1334 | frame.save(str(folder_path / f'{str(ind)}.{ext}')) 1335 | 1336 | def print_log(self): 1337 | data = [ 1338 | ('G', self.g_loss), 1339 | ('D', self.d_loss), 1340 | ('GP', self.last_gp_loss), 1341 | ('PL', self.pl_mean), 1342 | ('CR', self.last_cr_loss), 1343 | ('Q', self.q_loss), 1344 | ('FID', self.last_fid) 1345 | ] 1346 | 1347 | data = [d for d in data if exists(d[1])] 1348 | log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data)) 1349 | print(log) 1350 | 1351 | def track(self, value, name): 1352 | if not exists(self.logger): 1353 | return 1354 | self.logger.track(value, name = name) 1355 | 1356 | def model_name(self, num): 1357 | return str(self.models_dir / self.name / f'model_{num}.pt') 1358 | 1359 | def init_folders(self): 1360 | (self.results_dir / self.name).mkdir(parents=True, exist_ok=True) 1361 | (self.models_dir / self.name).mkdir(parents=True, exist_ok=True) 1362 | 1363 | def clear(self): 1364 | rmtree(str(self.models_dir / self.name), True) 1365 | rmtree(str(self.results_dir / self.name), True) 1366 | rmtree(str(self.fid_dir), True) 1367 | rmtree(str(self.config_path), True) 1368 | self.init_folders() 1369 | 1370 | def save(self, num): 1371 | save_data = { 1372 | 'GAN': self.GAN.state_dict(), 1373 | 'version': __version__ 1374 | } 1375 | 1376 | if self.GAN.fp16: 1377 | save_data['amp'] = amp.state_dict() 1378 | 1379 | torch.save(save_data, self.model_name(num)) 1380 | self.write_config() 1381 | 1382 | def load(self, num = -1): 1383 | self.load_config() 1384 | 1385 | name = num 1386 | if num == -1: 1387 | file_paths = [p for p in Path(self.models_dir / self.name).glob('model_*.pt')] 1388 | saved_nums = sorted(map(lambda x: int(x.stem.split('_')[1]), file_paths)) 1389 | if len(saved_nums) == 0: 1390 | return 1391 | name = saved_nums[-1] 1392 | print(f'continuing from previous epoch - {name}') 1393 | 1394 | 
self.steps = name * self.save_every 1395 | 1396 | load_data = torch.load(self.model_name(name), weights_only = True) 1397 | 1398 | if 'version' in load_data: 1399 | print(f"loading from version {load_data['version']}") 1400 | 1401 | try: 1402 | self.GAN.load_state_dict(load_data['GAN']) 1403 | except Exception as e: 1404 | print('unable to load the saved model. please try downgrading the package to the version specified by the saved model') 1405 | raise e 1406 | if self.GAN.fp16 and 'amp' in load_data: 1407 | amp.load_state_dict(load_data['amp']) 1408 | 1409 | class ModelLoader: 1410 | def __init__(self, *, base_dir, name = 'default', load_from = -1): 1411 | self.model = Trainer(name = name, base_dir = base_dir) 1412 | self.model.load(load_from) 1413 | 1414 | def noise_to_styles(self, noise, trunc_psi = None): 1415 | noise = noise.cuda() 1416 | w = self.model.GAN.SE(noise) 1417 | if exists(trunc_psi): 1418 | w = self.model.truncate_style(w, trunc_psi = trunc_psi) 1419 | return w 1420 | 1421 | def styles_to_images(self, w): 1422 | batch_size, *_ = w.shape 1423 | num_layers = self.model.GAN.GE.num_layers 1424 | image_size = self.model.image_size 1425 | w_def = [(w, num_layers)] 1426 | 1427 | w_tensors = styles_def_to_tensor(w_def) 1428 | noise = image_noise(batch_size, image_size, device = 0) 1429 | 1430 | images = self.model.GAN.GE(w_tensors, noise) 1431 | images.clamp_(0., 1.) 1432 | return images 1433 | --------------------------------------------------------------------------------
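# usage sketch for the ModelLoader inference API above (illustrative only; it
# assumes a checkpoint has already been saved under ./models/default by a prior
# training run, and that a CUDA device is available)

import torch
from torchvision.utils import save_image
from stylegan2_pytorch import ModelLoader

loader = ModelLoader(base_dir = './', name = 'default')   # loads the latest model_*.pt

noise  = torch.randn(1, 512).cuda()                       # latent z (latent_dim defaults to 512)
styles = loader.noise_to_styles(noise, trunc_psi = 0.75)  # z -> w, truncated toward the mean style
images = loader.styles_to_images(styles)                  # w -> image tensor clamped to [0, 1]

save_image(images, './sample.jpg')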