├── LICENSE ├── README.md ├── scripts ├── load_model.py ├── setup_fid.py └── style-gan-pytorch │ ├── MLproject │ ├── README.md │ ├── conda.yaml │ ├── dnnlib │ ├── __init__.py │ ├── submission │ │ ├── __init__.py │ │ ├── _internal │ │ │ └── run.py │ │ ├── run_context.py │ │ └── submit.py │ ├── tflib │ │ ├── .ipynb_checkpoints │ │ │ ├── __init__-checkpoint.py │ │ │ └── tfutil-checkpoint.py │ │ ├── __init__.py │ │ ├── autosummary.py │ │ ├── network.py │ │ ├── optimizer.py │ │ └── tfutil.py │ └── util.py │ ├── generate.py │ ├── loss_criterions │ ├── __init__.py │ ├── base_loss_criterions.py │ └── gradient_losses.py │ ├── networks │ ├── __init__.py │ ├── building_blocks.py │ ├── custom_layers.py │ └── style_gan_net.py │ ├── train.py │ └── utils.py └── src ├── __init__.py ├── configs ├── checkpoint │ ├── after_each_epoch.yaml │ ├── after_each_epoch_fid.yaml │ ├── every_n_train_steps.yaml │ └── every_n_train_steps_fid.yaml ├── dataset │ ├── imagefolder.yaml │ ├── lsun.yaml │ ├── multiimagefolder.yaml │ ├── multilsun.yaml │ ├── nodata.yaml │ └── other_image_dataset.yaml ├── experiment │ ├── blobgan.yaml │ ├── debug.yaml │ ├── gan.yaml │ ├── invertblobgan.yaml │ ├── jitter.yaml │ └── local.yaml └── fit.yaml ├── data ├── __init__.py ├── imagefolder.py ├── multiimagefolder.py ├── nodata.py └── utils.py ├── models ├── __init__.py ├── base.py ├── blobgan.py ├── gan.py ├── invertblobgan.py └── networks │ ├── __init__.py │ ├── layoutnet.py │ ├── layoutstylegan.py │ ├── op │ ├── __init__.py │ ├── conv2d_gradfix.py │ ├── conv2d_gradfix_111andon.py │ ├── conv2d_gradfix_pre111.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu │ └── stylegan.py ├── run.py └── utils ├── __init__.py ├── colab.py ├── distributed.py ├── io.py ├── logging.py ├── misc.py ├── training.py └── wandb_logger.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, Dave Epstein 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## BlobGAN: Spatially Disentangled Scene Representations
Official PyTorch Implementation
2 | 3 | ### [Paper](https://arxiv.org/abs/2205.02837) | [Project Page](https://dave.ml/blobgan) | [Video](https://www.youtube.com/watch?v=KpUv82VsU5k) | [Interactive Demo ![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://dave.ml/blobgan/demo) 4 | 5 | https://user-images.githubusercontent.com/5674727/168323496-990b46a2-a11d-4192-898a-f5b683d20265.mp4 6 | 7 | This repository contains: 8 | 9 | * 🚂 Pre-trained BlobGAN models on three datasets: bedrooms, conference rooms, and a combination of kitchens, living rooms, and dining rooms 10 | * 💻 Code based on PyTorch Lightning ⚡ and Hydra 🐍 which fully supports CPU, single GPU, or multi GPU/node training and inference 11 | 12 | We also provide an [📓 interactive demo notebook](https://dave.ml/blobgan/demo) to help get started using our model. Download this notebook and run it on your own Python environment, or test it out on Colab. You can: 13 | 14 | * 🖌️️ Generate and edit realistic images with an interactive UI 15 | * 📹 Create animated videos showing off your edited scenes 16 | * 📸 **(new!)** Upload your own image and convert it into blobs! 17 | 18 | And, coming soon: 19 | * 🧬 More edits, as shown in the paper! Code for cloning, restyling, rotating, and reshaping blobs. 20 | 21 | ## Setup 22 | 23 | Run the commands below one at a time to download the latest version of the BlobGAN code, create a Conda environment, and install necessary packages and utilities. 24 | 25 | ```bash 26 | git clone https://github.com/dave-epstein/blobgan.git 27 | mkdir -p blobgan/logs/wandb 28 | conda create -y -n blobgan python=3.9 29 | conda activate blobgan 30 | conda install -y pytorch=1.11.0 torchvision=0.12.0 torchaudio cudatoolkit=11.3 -c pytorch 31 | conda install -y cudatoolkit-dev=11.3 -c conda-forge 32 | pip install tqdm==4.64.0 hydra-core==1.1.2 omegaconf==2.1.2 clean-fid==0.1.23 wandb==0.12.11 ipdb==0.13.9 lpips==0.1.4 einops==0.4.1 inputimeout==1.0.4 pytorch-lightning==1.5.10 matplotlib==3.5.2 "mpl_interactions[jupyter]==0.21.0" protobuf~=3.19.0 moviepy==1.0.3 33 | cd blobgan 34 | python scripts/setup_fid.py 35 | ``` 36 | And if you haven't installed `ninja` yet on your machine (to compile custom C++ code), do that. On Linux, this looks like: 37 | ``` 38 | wget -q --show-progress https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip 39 | sudo unzip -q ninja-linux.zip -d /usr/local/bin/ 40 | sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 41 | ``` 42 | 43 | 44 | ## Running pretrained models 45 | 46 | See `scripts/load_model.py` for an example of how to load a pre-trained model (using the provided `load_{blobgan/stylegan}_model` functions, which can be called from elsewhere) and generate images with it. You can also run the file from the command line to generate images and save them to disk. For example, from the `blobgan` directory, you can run: 47 | 48 | ```bash 49 | python scripts/load_model.py --model_data bed --dl_dir models --save_dir out --n_imgs 32 --save_blobs --label_blobs 50 | ``` 51 | 52 | Or 53 | 54 | ```bash 55 | python scripts/load_model.py --model_name stylegan --model_data conference --truncate 0.4 56 | ``` 57 | Note that the first run may take a minute or two longer as custom C++ code is compiled. 
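If you'd rather call the models from your own code, here is a minimal sketch that mirrors what `scripts/load_model.py` does in its `__main__` block. It assumes you run it from the repository root with the environment above activated and a CUDA device available; the sample count and output filenames are arbitrary.

```python
import sys
import torch
from PIL import Image

sys.path.append('scripts')  # makes scripts/load_model.py importable; it adds src/ to the path itself
from load_model import load_blobgan_model, for_canvas

model = load_blobgan_model('bed', 'models', 'cuda')  # downloads the pretrained checkpoint into ./models
z = torch.randn((4, 512), device='cuda')
with torch.no_grad():
    # ret_layout=True in model.render_kwargs, so gen() returns the blob layout dict and an image batch
    layout, imgs = model.gen(z=z, truncate=0.4, **model.render_kwargs)
for i in range(len(imgs)):
    Image.fromarray(for_canvas(imgs[i:i + 1])).save(f'sample_{i:04d}.png')
```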
For more details and options, see the command's help: `python scripts/load_model.py --help` 58 | 59 | Using these functions, you can access pretrained models on bedrooms (trained with or without jitter); conference rooms; and kitchens, living rooms, and dining rooms. 60 | 61 | ## Training your own model 62 | 63 | **Before training your model**, you'll need to modify `src/configs/experiment/local.yaml` to include your WandB information and machine-specific configuration (such as the path to data -- `dataset.path` or `dataset.basepath` -- and the number of GPUs `trainer.gpus`). Alternatively, you can exclude `local` from the `experiment` option in the commands below and specify these parameters directly on the command line. 64 | 65 | To turn off logging entirely, pass `logger=false`, or to log to disk only without writing to the server, pass `wandb.offline=true`. Our code currently only supports WandB logging. 66 | 67 | Here's an example command that will train a model on LSUN bedrooms. We list the configuration modules to load for this experiment (`blobgan`, `local`, `jitter`) and then specify any other options we desire. For example, to train a model without jitter, we could simply remove that module from the `experiment` list. 68 | 69 | ```bash 70 | python src/run.py +experiment=[blobgan,local,jitter] wandb.name='10-blob BlobGAN on bedrooms' 71 | ``` 72 | 73 | In some shells, you may need to add extra quotes around some of these options to prevent them from being parsed immediately on the command line. 74 | 75 | Train on the LSUN category of your choice by passing in `dataset.category`, e.g. `dataset.category=church`. Tackle multiple categories at once with `dataset=multilsun` and `dataset.categories=[kitchen,bedroom]`. 76 | 77 | You can also train on any collection of images by selecting `dataset=imagefolder` and passing in the path. The code expects at least a subfolder named `train` and optional subfolders named `validate` and `test`. The command below also illustrates how to set arbitrary options using Hydra syntax, such as turning off FID logging or changing the dataloader batch size: 78 | 79 | ```bash 80 | python src/run.py +experiment=[blobgan,local,jitter] wandb.name='20-blob BlobGAN on Places' dataset.dataloader.batch_size=24 +model.log_fid_every_epoch=false dataset=imagefolder +dataset.path=/path/to/places/ model.n_features=20 81 | ``` 82 | 83 | Other parameters of interest are likely `trainer.log_every_n_steps`, `model.log_images_every_n_steps`, and `model.log_fid_every_n_steps`, which control the logging frequency of scalars, images, and FID (set either of the latter two to -1 to disable). Also check out `checkpoint.every_n_train_steps` and `checkpoint.save_top_k`, which dictate checkpoint saving frequency and how many of the most recent checkpoints to keep (`-1` means keep everything). 84 | 85 | ### Changing model feature dimensions 86 | 87 | To change `d_in`, set both `model.layout_net.feature_dim` and `model.generator.override_c_in` to the same value. To change `d_style`, change `model.dim`. 
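For example, a training command overriding both dimensions might look like the following (the specific values 512 and 768 are purely illustrative, not recommended settings):

```bash
python src/run.py +experiment=[blobgan,local,jitter] wandb.name='BlobGAN with custom dims' model.layout_net.feature_dim=512 model.generator.override_c_in=512 model.dim=768
```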
88 | 89 | ### Logging FID during training and at test 90 | 91 | In the initial codebase setup, you should have run `scripts/setup_fid.py` which will download and install FID statistics for three different datasets: 92 | 93 | * Bedrooms: `lsun_bedroom` 94 | * Conference rooms: `lsun_conference` 95 | * Kitchens, living rooms, dining rooms: `lsun_kld` 96 | 97 | **If either `model.log_fid_every_n_steps > -1` or `model.log_fid_every_epoch == true`, make sure that `model.fid_stats_name` is passed in.** If you are training on one of the three datasets from the paper, just pass in the string from the list above. 98 | 99 | If you are training on your own data, you'll need to first run `setup_fid.py` to precompute statistics on that. The command might look something like: 100 | 101 | ```bash 102 | python scripts/setup_fid.py --action compute_new --path /path/to/new/data --name newdata -j 32 -bs 256 103 | ``` 104 | 105 | Then, pass `model.fid_stats_name=newdata` on the command line. 106 | 107 | Note that the precomputed FID statistics are on 256px images. You will need to recompute if training at higher resolution. 108 | 109 | To run FID logging at test time, a simple snippet such as the following will return the score: 110 | ```python 111 | model = load_blobgan_model('bed', 'models', 'cuda', fixed_noise=False) 112 | model.fid_stats_name = 'lsun_bedroom' 113 | model.fid_n_imgs = 50000 114 | print(model.log_fid('train')) 115 | ``` 116 | 117 | ### Resuming training 118 | 119 | To continue a training run that was terminated, simply add `resume.id=PREVIOUS RUN ID`. To resume from a previous run but start a new WandB run (e.g. to avoid overwriting previous checkpoints), also pass in `wandb.id=null`. 120 | 121 | ### Training StyleGAN2 122 | 123 | Many of the above command line options apply (for controlling data and logging). For example, to train a StyleGAN2 model on LSUN conference rooms, run: 124 | 125 | ```bash 126 | python src/run.py +experiment=[gan,local] wandb.name='Conference room StyleGAN2' dataset.category=conference 127 | ``` 128 | 129 | This uses default StyleGAN2 hyperparameters: R1 regularization on D every 16 steps, path length regularization on G every 4, R1 weight 50 or gamma=100 (the weight is gamma/2). 130 | 131 | ### Training inversion encoders 132 | 133 | The same is true for training an inversion encoder. See this example command: 134 | 135 | ```bash 136 | python src/run.py +experiment=[invertblobgan,local] wandb.name='Inversion model' +model.G_pretrained.id="BLOBGAN MODEL ID HERE" +model.trunc_min=0.2 +model.trunc_max=0.4 model.lambda.fake_latents_MSE=10 137 | ``` 138 | 139 | Be sure to specify `model.G_pretrained.id` to match the ID of the BlobGAN model you are trying to invert. Also, you can set `model.G_pretrained.log_dir` to tell the program where to look for the model logs (this defaults to `./logs` if unspecified). The options `trunc_min` and `trunc_max` specify what truncation level to use (randomly sampled within the specified interval) when sampling fake images. If both are set to the same value (including zero, the default), this value will always be used. 140 | 141 | ## Citation 142 | 143 | If our code or models aided your research, please cite our [paper](https://arxiv.org/abs/2205.02837): 144 | ``` 145 | @misc{epstein2022blobgan, 146 | title={BlobGAN: Spatially Disentangled Scene Representations}, 147 | author={Dave Epstein and Taesung Park and Richard Zhang and Eli Shechtman and Alexei A. 
Efros}, 148 | year={2022}, 149 | eprint={2205.02837}, 150 | archivePrefix={arXiv}, 151 | primaryClass={cs.CV} 152 | } 153 | ``` 154 | 155 | ## Code acknowledgments 156 | 157 | This repository is built on top of rosinality's excellent [PyTorch re-implementation of StyleGAN2](https://github.com/rosinality/stylegan2-pytorch) and Bill Peebles' [GANgealing codebase](https://github.com/wpeebles/gangealing). 158 | -------------------------------------------------------------------------------- /scripts/load_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os, sys 3 | import torch 4 | from PIL import Image 5 | from tqdm import tqdm, trange 6 | 7 | here_dir = os.path.dirname(__file__) 8 | 9 | sys.path.append(os.path.join(here_dir, '..', 'src')) 10 | os.environ['PYTHONPATH'] = os.path.join(here_dir, '..', 'src') 11 | 12 | from models import BlobGAN, GAN, BlobGANInverter 13 | from utils import download_model, download_mean_latent, download_cherrypicked, KLD_COLORS, BED_CONF_COLORS, \ 14 | viz_score_fn, for_canvas, draw_labels, download 15 | 16 | 17 | def load_SGAN1_bedrooms(path, device='cuda'): 18 | ckpt = download(path=path, file='SGAN1_bedrooms.ckpt', load=True) 19 | sys.path.append(os.path.join(here_dir, 'style-gan-pytorch')) 20 | from networks.style_gan_net import Generator 21 | 22 | model = Generator(resolution=256) 23 | model.load_state_dict(ckpt) 24 | model.eval() 25 | return model.to(device) 26 | 27 | 28 | def load_stylegan1_model(model_data, path, device='cuda'): 29 | if model_data.startswith('bed'): 30 | model = load_SGAN1_bedrooms(path, device) 31 | Z = torch.randn((10000, 512)).to(device) 32 | latents = [model.g_mapping(Z[_:_ + 1])[0] for _ in trange(10000, desc='Computing mean latent')] 33 | model.mean_latent = torch.stack(latents, 0).mean(0) 34 | 35 | def SGAN1_gen(z, truncate): 36 | a = 1 - truncate 37 | dlatents = model.g_mapping(z).clone() 38 | if a < 1: 39 | dlatents = a * dlatents + (1 - a) * model.mean_latent 40 | x = model.g_synthesis(dlatents, 8, 1).clone() 41 | xx = ((x.clamp(min=-1, max=1) + 1) / 2.0) * 255 42 | return xx 43 | 44 | model.gen = SGAN1_gen 45 | else: 46 | raise ValueError('Only bedrooms supported for SGAN1.') 47 | 48 | 49 | def load_stylegan_model(model_data, path, device='cuda'): 50 | if model_data.startswith('bed'): 51 | datastr = 'bed' 52 | else: 53 | datastr = 'conference' if model_data.startswith('conference') else 'kitchenlivingdining' 54 | ckpt = download(path=path, file=f'SGAN2_{datastr}.ckpt') 55 | model = GAN.load_from_checkpoint(ckpt, strict=False).to(device) 56 | model.get_mean_latent() 57 | return model 58 | 59 | 60 | def load_blobgan_model(model_data, path, device='cuda', fixed_noise=False): 61 | ckpt = download_model(model_data, path) 62 | model = BlobGAN.load_from_checkpoint(ckpt, strict=False).to(device) 63 | try: 64 | model.mean_latent = download_mean_latent(model_data, path).to(device) 65 | except: 66 | model.get_mean_latent() 67 | try: 68 | model.cherry_picked = download_cherrypicked(model_data, path).to(device) 69 | except: 70 | pass 71 | COLORS = KLD_COLORS if 'kitchen' in model_data else BED_CONF_COLORS 72 | model.colors = COLORS 73 | noise = [torch.randn((1, 1, 16 * 2 ** ((i + 1) // 2), 16 * 2 ** ((i + 1) // 2))).to(device) for i in 74 | range(model.generator_ema.num_layers)] if fixed_noise else None 75 | model.noise = noise 76 | render_kwargs = { 77 | 'no_jitter': True, 78 | 'ret_layout': True, 79 | 'viz': True, 80 | 'ema': True, 81 | 'viz_colors': COLORS, 
82 | 'norm_img': True, 83 | 'viz_score_fn': viz_score_fn, 84 | 'noise': noise 85 | } 86 | model.render_kwargs = render_kwargs 87 | return model 88 | 89 | 90 | def load_inversion_model(model_data, path, device): 91 | ckpt = download(model_data, suffix='_invert.ckpt', path=path, load=False) 92 | d_out = torch.load(ckpt, map_location='cpu')['state_dict']['inverter.final_linear.1.weight'].shape[0] 93 | model = BlobGANInverter.load_from_checkpoint(ckpt, strict=False, load_only_inverter=True, inverter_d_out=d_out).to( 94 | device) 95 | return model 96 | 97 | 98 | if __name__ == "__main__": 99 | import argparse 100 | 101 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 102 | parser.add_argument("-m", "--model_name", default='blobgan', 103 | choices=['blobgan', 'stylegan', 'stylegan1']) 104 | parser.add_argument("-d", "--model_data", default='bed', 105 | help="Choose a pretrained model. This must be a string that begins either with `bed_no_jitter` (bedrooms, trained without jitter), " 106 | "`bed` (bedrooms)," 107 | " `kitchen` (kitchens, living rooms, and dining rooms)," 108 | " or `conference` (conference rooms).") 109 | parser.add_argument("-dl", "--dl_dir", default='models', 110 | help='Path to a directory where model files will be downloaded.') 111 | parser.add_argument("-s", "--save_dir", default='out', 112 | help='Path to the directory where output images will be saved.') 113 | parser.add_argument("-n", "--n_imgs", default=100, type=int, help='Number of random images to generate.') 114 | parser.add_argument('-bs', '--batch_size', default=32, 115 | help='Number of images to generate in one forward pass. Adjust based on available GPU memory.', 116 | type=int) 117 | parser.add_argument('-t', '--truncate', default=0.4, 118 | help='Amount of truncation to use when generating images. 0 means no truncation, 1 means full truncation.', 119 | type=float) 120 | parser.add_argument("--save_blobs", action='store_true', 121 | help='If passed, save images of blob maps (when `--model_name` is BlobGAN).') 122 | parser.add_argument("--label_blobs", action='store_true', 123 | help='If passed, add numeric blob labels to blob map images, when `--save_blobs` is true.') 124 | parser.add_argument('--size_threshold', default=-3, 125 | help='Threshold for blob size parameter above which to render blob labels, when `--label_blobs` is true.', 126 | type=float) 127 | parser.add_argument('--device', default='cuda', 128 | help='Specify the device on which to run the code, in PyTorch syntax, e.g. `cuda`, `cpu`, `cuda:3`.') 129 | parser.add_argument('--fixed_spatial_noise', action='store_true', 130 | help='Whether to use random spatial noise to generate images. 
' 131 | 'This is false by default for general use cases, but set it to true for things like animation.') 132 | args = parser.parse_args() 133 | 134 | blobgan = args.model_name == 'blobgan' 135 | stylegan = args.model_name == 'stylegan' 136 | sgan1 = args.model_name == 'stylegan1' 137 | 138 | save_dir = Path(args.save_dir) 139 | (save_dir / 'imgs').mkdir(exist_ok=True, parents=True) 140 | 141 | if blobgan: 142 | model = load_blobgan_model(args.model_data, args.dl_dir, args.device, fixed_noise=args.fixed_spatial_noise) 143 | 144 | if args.save_blobs: 145 | (save_dir / 'blobs').mkdir(exist_ok=True, parents=True) 146 | if args.label_blobs: 147 | (save_dir / 'blobs_labeled').mkdir(exist_ok=True, parents=True) 148 | elif stylegan: 149 | model = load_stylegan_model(args.model_data, args.dl_dir, args.device) 150 | elif sgan1: 151 | model = load_stylegan1_model(args.model_data, args.dl_dir, args.device) 152 | else: 153 | raise NotImplementedError('Inversion of images from command line not yet supported. ') 154 | 155 | n_to_gen = args.n_imgs 156 | n_gen = 0 157 | 158 | torch.set_grad_enabled(False) 159 | 160 | with tqdm(total=args.n_imgs, desc='Generating images') as pbar: 161 | while n_to_gen > 0: 162 | bs = min(args.batch_size, n_to_gen) 163 | z = torch.randn((bs, 512)).to(args.device) 164 | 165 | if blobgan: 166 | layout, orig_img = model.gen(z=z, truncate=args.truncate, **model.render_kwargs) 167 | else: 168 | orig_img = model.gen(z=z, truncate=args.truncate) 169 | 170 | for i in range(len(orig_img)): 171 | img_i = for_canvas(orig_img[i:i + 1]) 172 | Image.fromarray(img_i).save(str(save_dir / 'imgs' / f'{i + n_gen:04d}.png')) 173 | if blobgan and args.save_blobs: 174 | blobs_i = for_canvas(layout['feature_img'][i:i + 1].mul(255)) 175 | Image.fromarray(blobs_i).save(str(save_dir / 'blobs' / f'{i + n_gen:04d}.png')) 176 | if args.label_blobs: 177 | labeled_blobs, labeled_blobs_img = draw_labels(blobs_i, layout, args.size_threshold, 178 | model.colors, layout_i=i) 179 | labeled_blobs_img.save(str(save_dir / 'blobs_labeled' / f'{i + n_gen:04d}.png')) 180 | 181 | n_to_gen -= bs 182 | n_gen += bs 183 | pbar.update(bs) 184 | -------------------------------------------------------------------------------- /scripts/setup_fid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | import cleanfid 5 | import torch 6 | import numpy as np 7 | from cleanfid.features import build_feature_extractor 8 | from cleanfid.fid import get_folder_features 9 | from torchvision.transforms import functional as F 10 | 11 | here_dir = os.path.dirname(__file__) 12 | 13 | sys.path.append(os.path.join(here_dir, '..', 'src')) 14 | os.environ['PYTHONPATH'] = os.path.join(here_dir, '..', 'src') 15 | 16 | from utils import download 17 | 18 | 19 | if __name__ == "__main__": 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | parser.add_argument('--action', default='download', choices=['download', 'compute_new'], 24 | help='All other options only apply if action is set to `compute_new`.' 25 | ' Download mode (default) simply configures precomputed stats used in the BlobGAN paper.') 26 | parser.add_argument('--path', default='', type=str, 27 | help='Path to custom folder from which to sample `--n_imgs` images and compute FID statistics.') 28 | parser.add_argument('--n_imgs', type=int, default=-1, 29 | help='Number of images to randomly sample for FID stats. 
Set to -1 to use all images.') 30 | parser.add_argument('--shuffle', action='store_true', 31 | help='Shuffle the files in a directory before selecting `--n_imgs` for computation.') 32 | parser.add_argument('--name', default=None, help='Name to give custom stats.') 33 | parser.add_argument('-bs', '--batch_size', default=32, 34 | help='Number of images to analyze in one forward pass. Adjust based on available GPU memory.', 35 | type=int) 36 | parser.add_argument('-j', '--num_workers', default=8, 37 | help='Number of workers to use for FID stats generation.', 38 | type=int) 39 | parser.add_argument('-r', '--resolution', default=256, 40 | help='Image resolution to use before feeding images into FID pipeline (where they are resized to 299).', 41 | type=int) 42 | parser.add_argument('--device', default='cuda', 43 | help='Specify the device on which to run the code, in PyTorch syntax, ' 44 | 'e.g. `cuda`, `cpu`, `cuda:3`.') 45 | args = parser.parse_args() 46 | 47 | 48 | def load_fn(x): 49 | x = F.resize(torch.from_numpy(x).permute(2, 0, 1), args.resolution) 50 | x = F.center_crop(x, args.resolution).permute(1, 2, 0) 51 | return np.array(x) 52 | 53 | if args.action == 'download': 54 | path = os.path.join(os.path.dirname(cleanfid.__file__), "stats") 55 | stats = download(path=path, file='fid_stats.tar.gz') 56 | subprocess.run(["tar", "xvzf", stats, '-C', path, '--strip-components', '1']) 57 | else: 58 | print('Calculating...') 59 | name, mode, device, fdir, num = args.name, "clean", torch.device(args.device), args.path, args.n_imgs 60 | assert name 61 | stats_folder = os.path.join(os.path.dirname(cleanfid.__file__), "stats") 62 | os.makedirs(stats_folder, exist_ok=True) 63 | split, res = "custom", "na" 64 | outname = f"{name}_{mode}_{split}_{res}.npz" 65 | outf = os.path.join(stats_folder, outname) 66 | # if the custom stat file already exists 67 | if os.path.exists(outf): 68 | msg = f"The statistics file {name} already exists. " 69 | msg += "Use remove_custom_stats function to delete it first." 
70 | raise Exception(msg) 71 | 72 | feat_model = build_feature_extractor(mode, device) 73 | fbname = os.path.basename(fdir) 74 | # get all inception features for folder images 75 | if num < 0: num = None 76 | np_feats = get_folder_features(fdir, feat_model, num_workers=args.num_workers, num=num, shuffle=args.shuffle, 77 | batch_size=args.batch_size, device=device, custom_image_tranform=load_fn, 78 | mode=mode, description=f"custom stats: {fbname} : ") 79 | mu = np.mean(np_feats, axis=0) 80 | sigma = np.cov(np_feats, rowvar=False) 81 | print(f"Saving custom FID stats to {outf}") 82 | np.savez_compressed(outf, mu=mu, sigma=sigma) 83 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/MLproject: -------------------------------------------------------------------------------- 1 | name: style-gan 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | convert: {type: string, default: True} 9 | official_checkpoints: {type: string, default: True} 10 | random_seed: {type: int, default: 77} 11 | dataset: {type: string, default: ffhq} 12 | nrow: {type: int, default: 2} 13 | ncol: {type: int, default: 2} 14 | g_checkpoint: {type: string, default: ./checkpoints/generator.64x64.0.759840.3460000.158.pt} 15 | target_resolution: {type: int, default: 128} 16 | command: | 17 | python generate.py \ 18 | --dataset {dataset} \ 19 | --convert {convert} \ 20 | --use-official-checkpoints {official_checkpoints} \ 21 | --random-seed {random_seed} \ 22 | --nrow {nrow} \ 23 | --ncol {ncol} \ 24 | --g-checkpoint {g_checkpoint} \ 25 | --target-resolution {target_resolution} 26 | 27 | train: 28 | parameters: 29 | data_root: {type: string, default: ./data/celeba} 30 | resume: {type: string, default: True} 31 | g_checkpoint: {type: string, default: ./checkpoints/generator.64x64.0.759840.3460000.158.pt} 32 | d_checkpoint: {type: string, default: ./checkpoints/discriminator.64x64.0.759840.3460000.158.pt} 33 | target_resolution: {type: int, default: 128} 34 | n_gpu: {type: int, default: 1} 35 | command: | 36 | python train.py \ 37 | --data-root {data_root} \ 38 | --resume {resume} \ 39 | --g-checkpoint {g_checkpoint} \ 40 | --d-checkpoint {d_checkpoint} \ 41 | --target-resolution {target_resolution} \ 42 | --n-gpu {n_gpu} 43 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/README.md: -------------------------------------------------------------------------------- 1 | # StyleGAN Pytorch Implementation 2 | This is a Pytorch implementation of StyleGAN (https://arxiv.org/abs/1812.04948), with the capability of generating 1024x1024 pictures. Training to grow to 1024x1024 is also supported. A 1080 Ti is recommended for faster training speed. 3 | 4 | ## Prerequisites 5 | ### Dependencies 6 | See conda.yaml. Please note that I have cuda 10.0 installed. Change your conda.yaml accordingly if you use different cuda version. 7 | ### Image generation using official implementation's TensorFlow checkpoints 8 | 9 | **This step is not needed if running the generation command succeeds for downloading.** If downloading fails for Google Drive, manual download is required: 10 | 11 | * [ffhq-1024x1024](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ) 12 | * [bedrooms-256x256](https://drive.google.com/open?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF) 13 | * [cats-256x256](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ) 14 | 15 | And place them in ./pretrained directory. 
16 | 17 | ### Training prerequisites on CelebA dataset 18 | Download the [celeba](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg) dataset. Unzip the .zip file into the ./data/celeba directory. 19 | 20 | ## Image Generation 21 | ### Image generation using official checkpoints 22 | Run the command: 23 | ```bash 24 | mlflow run . -e generate -P dataset=cats 25 | ``` 26 | The default random seed is 77. To generate a different set of images with a different grid layout, run the following (note that the number of images you can generate at once is limited by your GPU memory): 27 | ```bash 28 | mlflow run . -e generate -P dataset=cats -P random_seed=777 -P nrow=2 -P ncol=5 29 | ``` 30 | This will generate 10 images at once. 31 | ### Image generation using checkpoints generated by this code 32 | ```bash 33 | mlflow run . -e generate \ 34 | -P official_checkpoints=False \ 35 | -P g_checkpoint=[path_to_generator_checkpoint] \ 36 | -P target_resolution=128 \ 37 | -P nrow=2 \ 38 | -P ncol=2 39 | ``` 40 | This will generate images using checkpoints trained by this code. 41 | 42 | ## Training on CelebA dataset 43 | Run the command to start from scratch: 44 | ```bash 45 | mlflow run . -e train -P resume=False 46 | ``` 47 | This will kick off training at 128x128 resolution on the CelebA dataset. During training, model checkpoints are stored under ./checkpoints, and fake images are generated for inspection under ./checks/fake\_imgs. Note that this is a progressive process starting from 8x8, so you will see 8x8 images at the beginning and 128x128 images at the end of the training process. 48 | 49 | To resume training: 50 | ```bash 51 | mlflow run . -e train \ 52 | -P resume=True \ 53 | -P g_checkpoint=[path_to_generator_checkpoint] \ 54 | -P d_checkpoint=[path_to_discriminator_checkpoint] 55 | ``` 56 | For other training options, please check the MLproject file. For hyperparameters, please check train.py and NVIDIA's official implementation. 57 | 58 | ## TODO 59 | 1. Add truncation trick 60 | 2. Add and experiment with other loss functions (some are in the repo but not tried) 61 | 3. Add TensorBoard support 62 | 4. Add moving average of generator's weights 63 | 64 | Multi-GPU support is added but has not been tested due to hardware limitations. 65 | 66 | ## License 67 | This project is under the BSD-3 license. 68 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/conda.yaml: -------------------------------------------------------------------------------- 1 | name: pytorch_stylegan 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - python=3.6 7 | - cudatoolkit=10.0 # my machine has cuda 10.0 8 | - pytorch=1.1.0 9 | - torchvision=0.3.0 10 | - tensorflow-gpu=1.13.1 11 | - jupyter=1.0.0 12 | - matplotlib=3.1.0 13 | - pip: 14 | - mlflow>=1.0 15 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from .
import submission 9 | 10 | from .submission.run_context import RunContext 11 | 12 | from .submission.submit import SubmitTarget 13 | from .submission.submit import PathType 14 | from .submission.submit import SubmitConfig 15 | from .submission.submit import get_path_from_template 16 | from .submission.submit import submit_run 17 | 18 | from .util import EasyDict 19 | 20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function. 21 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/submission/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import run_context 9 | from . import submit 10 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/submission/_internal/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for launching run functions in computing clusters. 9 | 10 | During the submit process, this file is copied to the appropriate run dir. 11 | When the job is launched in the cluster, this module is the first thing that 12 | is run inside the docker container. 13 | """ 14 | 15 | import os 16 | import pickle 17 | import sys 18 | 19 | # PYTHONPATH should have been set so that the run_dir/src is in it 20 | import dnnlib 21 | 22 | def main(): 23 | if not len(sys.argv) >= 4: 24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!") 25 | 26 | run_dir = str(sys.argv[1]) 27 | task_name = str(sys.argv[2]) 28 | host_name = str(sys.argv[3]) 29 | 30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl") 31 | 32 | # SubmitConfig should have been pickled to the run dir 33 | if not os.path.exists(submit_config_path): 34 | raise RuntimeError("SubmitConfig pickle file does not exist!") 35 | 36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb")) 37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name) 38 | 39 | submit_config.task_name = task_name 40 | submit_config.host_name = host_name 41 | 42 | dnnlib.submission.submit.run_wrapper(submit_config) 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/submission/run_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helpers for managing the run/training loop.""" 9 | 10 | import datetime 11 | import json 12 | import os 13 | import pprint 14 | import time 15 | import types 16 | 17 | from typing import Any 18 | 19 | from . import submit 20 | 21 | 22 | class RunContext(object): 23 | """Helper class for managing the run/training loop. 24 | 25 | The context will hide the implementation details of a basic run/training loop. 26 | It will set things up properly, tell if run should be stopped, and then cleans up. 27 | User should call update periodically and use should_stop to determine if run should be stopped. 28 | 29 | Args: 30 | submit_config: The SubmitConfig that is used for the current run. 31 | config_module: The whole config module that is used for the current run. 32 | max_epoch: Optional cached value for the max_epoch variable used in update. 33 | """ 34 | 35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None): 36 | self.submit_config = submit_config 37 | self.should_stop_flag = False 38 | self.has_closed = False 39 | self.start_time = time.time() 40 | self.last_update_time = time.time() 41 | self.last_update_interval = 0.0 42 | self.max_epoch = max_epoch 43 | 44 | # pretty print the all the relevant content of the config module to a text file 45 | if config_module is not None: 46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f: 47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))} 48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False) 49 | 50 | # write out details about the run to a text file 51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")} 52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f: 53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 54 | 55 | def __enter__(self) -> "RunContext": 56 | return self 57 | 58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 59 | self.close() 60 | 61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None: 62 | """Do general housekeeping and keep the state of the context up-to-date. 
63 | Should be called often enough but not in a tight loop.""" 64 | assert not self.has_closed 65 | 66 | self.last_update_interval = time.time() - self.last_update_time 67 | self.last_update_time = time.time() 68 | 69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")): 70 | self.should_stop_flag = True 71 | 72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch 73 | 74 | def should_stop(self) -> bool: 75 | """Tell whether a stopping condition has been triggered one way or another.""" 76 | return self.should_stop_flag 77 | 78 | def get_time_since_start(self) -> float: 79 | """How much time has passed since the creation of the context.""" 80 | return time.time() - self.start_time 81 | 82 | def get_time_since_last_update(self) -> float: 83 | """How much time has passed since the last call to update.""" 84 | return time.time() - self.last_update_time 85 | 86 | def get_last_update_interval(self) -> float: 87 | """How much time passed between the previous two calls to update.""" 88 | return self.last_update_interval 89 | 90 | def close(self) -> None: 91 | """Close the context and clean up. 92 | Should only be called once.""" 93 | if not self.has_closed: 94 | # update the run.txt with stopping time 95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ") 96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f: 97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False) 98 | 99 | self.has_closed = True 100 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/tflib/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/tflib/.ipynb_checkpoints/tfutil-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous helper utils for Tensorflow.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from typing import Any, Iterable, List, Union 15 | 16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 17 | """A type that represents a valid Tensorflow expression.""" 18 | 19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 20 | """A type that can be converted to a valid Tensorflow expression.""" 21 | 22 | 23 | def run(*args, **kwargs) -> Any: 24 | """Run the specified ops in the default session.""" 25 | assert_tf_initialized() 26 | return tf.get_default_session().run(*args, **kwargs) 27 | 28 | 29 | def is_tf_expression(x: Any) -> bool: 30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 32 | 33 | 34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 35 | """Convert a Tensorflow shape to a list of ints.""" 36 | return [dim.value for dim in shape] 37 | 38 | 39 | def flatten(x: TfExpressionEx) -> TfExpression: 40 | """Shortcut function for flattening a tensor.""" 41 | with tf.name_scope("Flatten"): 42 | return tf.reshape(x, [-1]) 43 | 44 | 45 | def log2(x: TfExpressionEx) -> TfExpression: 46 | """Logarithm in base 2.""" 47 | with tf.name_scope("Log2"): 48 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 49 | 50 | 51 | def exp2(x: TfExpressionEx) -> TfExpression: 52 | """Exponent in base 2.""" 53 | with tf.name_scope("Exp2"): 54 | return tf.exp(x * np.float32(np.log(2.0))) 55 | 56 | 57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 58 | """Linear interpolation.""" 59 | with tf.name_scope("Lerp"): 60 | return a + (b - a) * t 61 | 62 | 63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 64 | """Linear interpolation with clip.""" 65 | with tf.name_scope("LerpClip"): 66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 67 | 68 | 69 | def absolute_name_scope(scope: str) -> tf.name_scope: 70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 71 | return tf.name_scope(scope + "/") 72 | 73 | 74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 77 | 78 | 79 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 80 | # Defaults. 81 | cfg = dict() 82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 87 | 88 | # User overrides. 
89 | if config_dict is not None: 90 | cfg.update(config_dict) 91 | return cfg 92 | 93 | 94 | def init_tf(config_dict: dict = None) -> None: 95 | """Initialize TensorFlow session using good default settings.""" 96 | # Skip if already initialized. 97 | if tf.get_default_session() is not None: 98 | return 99 | 100 | # Setup config dict and random seeds. 101 | cfg = _sanitize_tf_config(config_dict) 102 | np_random_seed = cfg["rnd.np_random_seed"] 103 | if np_random_seed is not None: 104 | np.random.seed(np_random_seed) 105 | tf_random_seed = cfg["rnd.tf_random_seed"] 106 | if tf_random_seed == "auto": 107 | tf_random_seed = np.random.randint(1 << 31) 108 | if tf_random_seed is not None: 109 | tf.set_random_seed(tf_random_seed) 110 | 111 | # Setup environment variables. 112 | for key, value in list(cfg.items()): 113 | fields = key.split(".") 114 | if fields[0] == "env": 115 | assert len(fields) == 2 116 | os.environ[fields[1]] = str(value) 117 | 118 | # Create default TensorFlow session. 119 | create_session(cfg, force_as_default=True) 120 | 121 | 122 | def assert_tf_initialized(): 123 | """Check that TensorFlow session has been initialized.""" 124 | if tf.get_default_session() is None: 125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 126 | 127 | 128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 129 | """Create tf.Session based on config dict.""" 130 | # Setup TensorFlow config proto. 131 | cfg = _sanitize_tf_config(config_dict) 132 | config_proto = tf.ConfigProto() 133 | for key, value in cfg.items(): 134 | fields = key.split(".") 135 | if fields[0] not in ["rnd", "env"]: 136 | obj = config_proto 137 | for field in fields[:-1]: 138 | obj = getattr(obj, field) 139 | setattr(obj, fields[-1], value) 140 | 141 | # Create session. 142 | session = tf.Session(config=config_proto) 143 | if force_as_default: 144 | # pylint: disable=protected-access 145 | session._default_session = session.as_default() 146 | session._default_session.enforce_nesting = False 147 | session._default_session.__enter__() # pylint: disable=no-member 148 | 149 | return session 150 | 151 | 152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 153 | """Initialize all tf.Variables that have not already been initialized. 154 | 155 | Equivalent to the following, but more efficient and does not bloat the tf graph: 156 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 157 | """ 158 | assert_tf_initialized() 159 | if target_vars is None: 160 | target_vars = tf.global_variables() 161 | 162 | test_vars = [] 163 | test_ops = [] 164 | 165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 166 | for var in target_vars: 167 | assert is_tf_expression(var) 168 | 169 | try: 170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 171 | except KeyError: 172 | # Op does not exist => variable may be uninitialized. 173 | test_vars.append(var) 174 | 175 | with absolute_name_scope(var.name.split(":")[0]): 176 | test_ops.append(tf.is_variable_initialized(var)) 177 | 178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 179 | run([var.initializer for var in init_vars]) 180 | 181 | 182 | def set_vars(var_to_value_dict: dict) -> None: 183 | """Set the values of given tf.Variables. 
184 | 185 | Equivalent to the following, but more efficient and does not bloat the tf graph: 186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 187 | """ 188 | assert_tf_initialized() 189 | ops = [] 190 | feed_dict = {} 191 | 192 | for var, value in var_to_value_dict.items(): 193 | assert is_tf_expression(var) 194 | 195 | try: 196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 197 | except KeyError: 198 | with absolute_name_scope(var.name.split(":")[0]): 199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 201 | 202 | ops.append(setter) 203 | feed_dict[setter.op.inputs[1]] = value 204 | 205 | run(ops, feed_dict) 206 | 207 | 208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 209 | """Create tf.Variable with large initial value without bloating the tf graph.""" 210 | assert_tf_initialized() 211 | assert isinstance(initial_value, np.ndarray) 212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 213 | var = tf.Variable(zeros, *args, **kwargs) 214 | set_vars({var: initial_value}) 215 | return var 216 | 217 | 218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 220 | Can be used as an input transformation for Network.run(). 221 | """ 222 | images = tf.cast(images, tf.float32) 223 | if nhwc_to_nchw: 224 | images = tf.transpose(images, [0, 3, 1, 2]) 225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255) 226 | 227 | 228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 230 | Can be used as an output transformation for Network.run(). 231 | """ 232 | images = tf.cast(images, tf.float32) 233 | if shrink > 1: 234 | ksize = [1, 1, shrink, shrink] 235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 236 | if nchw_to_nhwc: 237 | images = tf.transpose(images, [0, 2, 3, 1]) 238 | scale = 255 / (drange[1] - drange[0]) 239 | images = images * scale + (0.5 - drange[0] * scale) 240 | return tf.saturate_cast(images, tf.uint8) 241 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | from . import autosummary 9 | from . import network 10 | from . import optimizer 11 | from . import tfutil 12 | 13 | from .tfutil import * 14 | from .network import Network 15 | 16 | from .optimizer import Optimizer 17 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/tflib/autosummary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper for adding automatically tracked values to Tensorboard. 9 | 10 | Autosummary creates an identity op that internally keeps track of the input 11 | values and automatically shows up in TensorBoard. The reported value 12 | represents an average over input components. The average is accumulated 13 | constantly over time and flushed when save_summaries() is called. 14 | 15 | Notes: 16 | - The output tensor must be used as an input for something else in the 17 | graph. Otherwise, the autosummary op will not get executed, and the average 18 | value will not get accumulated. 19 | - It is perfectly fine to include autosummaries with the same name in 20 | several places throughout the graph, even if they are executed concurrently. 21 | - It is ok to also pass in a python scalar or numpy array. In this case, it 22 | is added to the average immediately. 23 | """ 24 | 25 | from collections import OrderedDict 26 | import numpy as np 27 | import tensorflow as tf 28 | from tensorboard import summary as summary_lib 29 | from tensorboard.plugins.custom_scalar import layout_pb2 30 | 31 | from . import tfutil 32 | from .tfutil import TfExpression 33 | from .tfutil import TfExpressionEx 34 | 35 | _dtype = tf.float64 36 | _vars = OrderedDict() # name => [var, ...] 37 | _immediate = OrderedDict() # name => update_op, update_value 38 | _finalized = False 39 | _merge_op = None 40 | 41 | 42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression: 43 | """Internal helper for creating autosummary accumulators.""" 44 | assert not _finalized 45 | name_id = name.replace("/", "_") 46 | v = tf.cast(value_expr, _dtype) 47 | 48 | if v.shape.is_fully_defined(): 49 | size = np.prod(tfutil.shape_to_list(v.shape)) 50 | size_expr = tf.constant(size, dtype=_dtype) 51 | else: 52 | size = None 53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype)) 54 | 55 | if size == 1: 56 | if v.shape.ndims != 0: 57 | v = tf.reshape(v, []) 58 | v = [size_expr, v, tf.square(v)] 59 | else: 60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))] 61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype)) 62 | 63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None): 64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)] 65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v)) 66 | 67 | if name in _vars: 68 | _vars[name].append(var) 69 | else: 70 | _vars[name] = [var] 71 | return update_op 72 | 73 | 74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx: 75 | """Create a new autosummary. 76 | 77 | Args: 78 | name: Name to use in TensorBoard 79 | value: TensorFlow expression or python value to track 80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node. 
81 | 82 | Example use of the passthru mechanism: 83 | 84 | n = autosummary('l2loss', loss, passthru=n) 85 | 86 | This is a shorthand for the following code: 87 | 88 | with tf.control_dependencies([autosummary('l2loss', loss)]): 89 | n = tf.identity(n) 90 | """ 91 | tfutil.assert_tf_initialized() 92 | name_id = name.replace("/", "_") 93 | 94 | if tfutil.is_tf_expression(value): 95 | with tf.name_scope("summary_" + name_id), tf.device(value.device): 96 | update_op = _create_var(name, value) 97 | with tf.control_dependencies([update_op]): 98 | return tf.identity(value if passthru is None else passthru) 99 | 100 | else: # python scalar or numpy array 101 | if name not in _immediate: 102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None): 103 | update_value = tf.placeholder(_dtype) 104 | update_op = _create_var(name, update_value) 105 | _immediate[name] = update_op, update_value 106 | 107 | update_op, update_value = _immediate[name] 108 | tfutil.run(update_op, {update_value: value}) 109 | return value if passthru is None else passthru 110 | 111 | 112 | def finalize_autosummaries() -> None: 113 | """Create the necessary ops to include autosummaries in TensorBoard report. 114 | Note: This should be done only once per graph. 115 | """ 116 | global _finalized 117 | tfutil.assert_tf_initialized() 118 | 119 | if _finalized: 120 | return None 121 | 122 | _finalized = True 123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list]) 124 | 125 | # Create summary ops. 126 | with tf.device(None), tf.control_dependencies(None): 127 | for name, vars_list in _vars.items(): 128 | name_id = name.replace("/", "_") 129 | with tfutil.absolute_name_scope("Autosummary/" + name_id): 130 | moments = tf.add_n(vars_list) 131 | moments /= moments[0] 132 | with tf.control_dependencies([moments]): # read before resetting 133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list] 134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting 135 | mean = moments[1] 136 | std = tf.sqrt(moments[2] - tf.square(moments[1])) 137 | tf.summary.scalar(name, mean) 138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std) 139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std) 140 | 141 | # Group by category and chart name. 142 | cat_dict = OrderedDict() 143 | for series_name in sorted(_vars.keys()): 144 | p = series_name.split("/") 145 | cat = p[0] if len(p) >= 2 else "" 146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1] 147 | if cat not in cat_dict: 148 | cat_dict[cat] = OrderedDict() 149 | if chart not in cat_dict[cat]: 150 | cat_dict[cat][chart] = [] 151 | cat_dict[cat][chart].append(series_name) 152 | 153 | # Setup custom_scalar layout. 
154 | categories = [] 155 | for cat_name, chart_dict in cat_dict.items(): 156 | charts = [] 157 | for chart_name, series_names in chart_dict.items(): 158 | series = [] 159 | for series_name in series_names: 160 | series.append(layout_pb2.MarginChartContent.Series( 161 | value=series_name, 162 | lower="xCustomScalars/" + series_name + "/margin_lo", 163 | upper="xCustomScalars/" + series_name + "/margin_hi")) 164 | margin = layout_pb2.MarginChartContent(series=series) 165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin)) 166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts)) 167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories)) 168 | return layout 169 | 170 | def save_summaries(file_writer, global_step=None): 171 | """Call FileWriter.add_summary() with all summaries in the default graph, 172 | automatically finalizing and merging them on the first call. 173 | """ 174 | global _merge_op 175 | tfutil.assert_tf_initialized() 176 | 177 | if _merge_op is None: 178 | layout = finalize_autosummaries() 179 | if layout is not None: 180 | file_writer.add_summary(layout) 181 | with tf.device(None), tf.control_dependencies(None): 182 | _merge_op = tf.summary.merge_all() 183 | 184 | file_writer.add_summary(_merge_op.eval(), global_step) 185 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/tflib/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Helper wrapper for a Tensorflow optimizer.""" 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from collections import OrderedDict 14 | from typing import List, Union 15 | 16 | from . import autosummary 17 | from . import tfutil 18 | from .. import util 19 | 20 | from .tfutil import TfExpression, TfExpressionEx 21 | 22 | try: 23 | # TensorFlow 1.13 24 | from tensorflow.python.ops import nccl_ops 25 | except: 26 | # Older TensorFlow versions 27 | import tensorflow.contrib.nccl as nccl_ops 28 | 29 | class Optimizer: 30 | """A Wrapper for tf.train.Optimizer. 31 | 32 | Automatically takes care of: 33 | - Gradient averaging for multi-GPU training. 34 | - Dynamic loss scaling and typecasts for FP16 training. 35 | - Ignoring corrupted gradients that contain NaNs/Infs. 36 | - Reporting statistics. 37 | - Well-chosen default settings. 38 | """ 39 | 40 | def __init__(self, 41 | name: str = "Train", 42 | tf_optimizer: str = "tf.train.AdamOptimizer", 43 | learning_rate: TfExpressionEx = 0.001, 44 | use_loss_scaling: bool = False, 45 | loss_scaling_init: float = 64.0, 46 | loss_scaling_inc: float = 0.0005, 47 | loss_scaling_dec: float = 1.0, 48 | **kwargs): 49 | 50 | # Init fields. 
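        # Note on dynamic loss scaling (only active when use_loss_scaling=True): the
        # per-device variable created by get_loss_scaling_var() stores log2 of the
        # scaling factor, so the loss is initially multiplied by 2**loss_scaling_init.
        # apply_updates() raises the exponent by loss_scaling_inc after every step with
        # finite gradients and lowers it by loss_scaling_dec when a NaN/Inf is detected.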
51 | self.name = name 52 | self.learning_rate = tf.convert_to_tensor(learning_rate) 53 | self.id = self.name.replace("/", ".") 54 | self.scope = tf.get_default_graph().unique_name(self.id) 55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer) 56 | self.optimizer_kwargs = dict(kwargs) 57 | self.use_loss_scaling = use_loss_scaling 58 | self.loss_scaling_init = loss_scaling_init 59 | self.loss_scaling_inc = loss_scaling_inc 60 | self.loss_scaling_dec = loss_scaling_dec 61 | self._grad_shapes = None # [shape, ...] 62 | self._dev_opt = OrderedDict() # device => optimizer 63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...] 64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor) 65 | self._updates_applied = False 66 | 67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None: 68 | """Register the gradients of the given loss function with respect to the given variables. 69 | Intended to be called once per GPU.""" 70 | assert not self._updates_applied 71 | 72 | # Validate arguments. 73 | if isinstance(trainable_vars, dict): 74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars 75 | 76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1 77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss]) 78 | 79 | if self._grad_shapes is None: 80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars] 81 | 82 | assert len(trainable_vars) == len(self._grad_shapes) 83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes)) 84 | 85 | dev = loss.device 86 | 87 | assert all(var.device == dev for var in trainable_vars) 88 | 89 | # Register device and compute gradients. 90 | with tf.name_scope(self.id + "_grad"), tf.device(dev): 91 | if dev not in self._dev_opt: 92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt) 93 | assert callable(self.optimizer_class) 94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs) 95 | self._dev_grads[dev] = [] 96 | 97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32)) 98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage 99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros 100 | self._dev_grads[dev].append(grads) 101 | 102 | def apply_updates(self) -> tf.Operation: 103 | """Construct training op to update the registered variables based on their gradients.""" 104 | tfutil.assert_tf_initialized() 105 | assert not self._updates_applied 106 | self._updates_applied = True 107 | devices = list(self._dev_grads.keys()) 108 | total_grads = sum(len(grads) for grads in self._dev_grads.values()) 109 | assert len(devices) >= 1 and total_grads >= 1 110 | ops = [] 111 | 112 | with tfutil.absolute_name_scope(self.scope): 113 | # Cast gradients to FP32 and calculate partial sum within each device. 114 | dev_grads = OrderedDict() # device => [(grad, var), ...] 
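            # The training op is assembled in three phases below:
            #   1) per device: cast each registered gradient to float32 and sum the
            #      partial gradients accumulated on that device,
            #   2) across devices (if more than one): all-reduce each gradient with NCCL,
            #   3) per device: rescale by 1/total_grads (and undo loss scaling), skip the
            #      step entirely if any gradient contains NaN/Inf, otherwise apply the
            #      gradients and adjust the loss-scaling exponent.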
115 | 116 | for dev_idx, dev in enumerate(devices): 117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev): 118 | sums = [] 119 | 120 | for gv in zip(*self._dev_grads[dev]): 121 | assert all(v is gv[0][1] for g, v in gv) 122 | g = [tf.cast(g, tf.float32) for g, v in gv] 123 | g = g[0] if len(g) == 1 else tf.add_n(g) 124 | sums.append((g, gv[0][1])) 125 | 126 | dev_grads[dev] = sums 127 | 128 | # Sum gradients across devices. 129 | if len(devices) > 1: 130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None): 131 | for var_idx, grad_shape in enumerate(self._grad_shapes): 132 | g = [dev_grads[dev][var_idx][0] for dev in devices] 133 | 134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors 135 | g = nccl_ops.all_sum(g) 136 | 137 | for dev, gg in zip(devices, g): 138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1]) 139 | 140 | # Apply updates separately on each device. 141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()): 142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev): 143 | # Scale gradients as needed. 144 | if self.use_loss_scaling or total_grads > 1: 145 | with tf.name_scope("Scale"): 146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef") 147 | coef = self.undo_loss_scaling(coef) 148 | grads = [(g * coef, v) for g, v in grads] 149 | 150 | # Check for overflows. 151 | with tf.name_scope("CheckOverflow"): 152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads])) 153 | 154 | # Update weights and adjust loss scaling. 155 | with tf.name_scope("UpdateWeights"): 156 | # pylint: disable=cell-var-from-loop 157 | opt = self._dev_opt[dev] 158 | ls_var = self.get_loss_scaling_var(dev) 159 | 160 | if not self.use_loss_scaling: 161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op)) 162 | else: 163 | ops.append(tf.cond(grad_ok, 164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)), 165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec)))) 166 | 167 | # Report statistics on the last device. 168 | if dev == devices[-1]: 169 | with tf.name_scope("Statistics"): 170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate)) 171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1))) 172 | 173 | if self.use_loss_scaling: 174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var)) 175 | 176 | # Initialize variables and group everything into a single op. 
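            # Typical call pattern for this wrapper (sketch only; `loss_per_gpu` and
            # `net.trainables` are illustrative placeholders, not names defined here):
            #
            #   opt = Optimizer(name="TrainG", learning_rate=0.001)
            #   for gpu_idx, dev_loss in enumerate(loss_per_gpu):
            #       with tf.device("/gpu:%d" % gpu_idx):
            #           opt.register_gradients(dev_loss, net.trainables)
            #   train_op = opt.apply_updates()  # the grouped "TrainingOp" returned below
            #   tfutil.run(train_op)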
177 | self.reset_optimizer_state() 178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values())) 179 | 180 | return tf.group(*ops, name="TrainingOp") 181 | 182 | def reset_optimizer_state(self) -> None: 183 | """Reset internal state of the underlying optimizer.""" 184 | tfutil.assert_tf_initialized() 185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()]) 186 | 187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]: 188 | """Get or create variable representing log2 of the current dynamic loss scaling factor.""" 189 | if not self.use_loss_scaling: 190 | return None 191 | 192 | if device not in self._dev_ls_var: 193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None): 194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var") 195 | 196 | return self._dev_ls_var[device] 197 | 198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression: 199 | """Apply dynamic loss scaling for the given expression.""" 200 | assert tfutil.is_tf_expression(value) 201 | 202 | if not self.use_loss_scaling: 203 | return value 204 | 205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device)) 206 | 207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression: 208 | """Undo the effect of dynamic loss scaling for the given expression.""" 209 | assert tfutil.is_tf_expression(value) 210 | 211 | if not self.use_loss_scaling: 212 | return value 213 | 214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type 215 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/dnnlib/tflib/tfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
7 | 8 | """Miscellaneous helper utils for Tensorflow.""" 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from typing import Any, Iterable, List, Union 15 | 16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation] 17 | """A type that represents a valid Tensorflow expression.""" 18 | 19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray] 20 | """A type that can be converted to a valid Tensorflow expression.""" 21 | 22 | 23 | def run(*args, **kwargs) -> Any: 24 | """Run the specified ops in the default session.""" 25 | assert_tf_initialized() 26 | return tf.get_default_session().run(*args, **kwargs) 27 | 28 | 29 | def is_tf_expression(x: Any) -> bool: 30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.""" 31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation)) 32 | 33 | 34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]: 35 | """Convert a Tensorflow shape to a list of ints.""" 36 | return [dim.value for dim in shape] 37 | 38 | 39 | def flatten(x: TfExpressionEx) -> TfExpression: 40 | """Shortcut function for flattening a tensor.""" 41 | with tf.name_scope("Flatten"): 42 | return tf.reshape(x, [-1]) 43 | 44 | 45 | def log2(x: TfExpressionEx) -> TfExpression: 46 | """Logarithm in base 2.""" 47 | with tf.name_scope("Log2"): 48 | return tf.log(x) * np.float32(1.0 / np.log(2.0)) 49 | 50 | 51 | def exp2(x: TfExpressionEx) -> TfExpression: 52 | """Exponent in base 2.""" 53 | with tf.name_scope("Exp2"): 54 | return tf.exp(x * np.float32(np.log(2.0))) 55 | 56 | 57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx: 58 | """Linear interpolation.""" 59 | with tf.name_scope("Lerp"): 60 | return a + (b - a) * t 61 | 62 | 63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression: 64 | """Linear interpolation with clip.""" 65 | with tf.name_scope("LerpClip"): 66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0) 67 | 68 | 69 | def absolute_name_scope(scope: str) -> tf.name_scope: 70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes.""" 71 | return tf.name_scope(scope + "/") 72 | 73 | 74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope: 75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes.""" 76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False) 77 | 78 | 79 | def _sanitize_tf_config(config_dict: dict = None) -> dict: 80 | # Defaults. 81 | cfg = dict() 82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is. 83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is. 84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info. 85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used. 86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed. 87 | 88 | # User overrides. 
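    # config_dict entries use the same dotted keys as the defaults above, e.g.
    # (illustrative values): {"gpu_options.allow_growth": False,
    #                         "env.CUDA_VISIBLE_DEVICES": "0",
    #                         "rnd.np_random_seed": 123}
    # "env.*" keys are exported as environment variables by init_tf(), "rnd.*" keys
    # seed NumPy/TensorFlow, and all remaining keys are applied to tf.ConfigProto
    # in create_session().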
89 | if config_dict is not None: 90 | cfg.update(config_dict) 91 | return cfg 92 | 93 | 94 | def init_tf(config_dict: dict = None) -> None: 95 | """Initialize TensorFlow session using good default settings.""" 96 | # Skip if already initialized. 97 | if tf.get_default_session() is not None: 98 | return 99 | 100 | # Setup config dict and random seeds. 101 | cfg = _sanitize_tf_config(config_dict) 102 | np_random_seed = cfg["rnd.np_random_seed"] 103 | if np_random_seed is not None: 104 | np.random.seed(np_random_seed) 105 | tf_random_seed = cfg["rnd.tf_random_seed"] 106 | if tf_random_seed == "auto": 107 | tf_random_seed = np.random.randint(1 << 31) 108 | if tf_random_seed is not None: 109 | tf.set_random_seed(tf_random_seed) 110 | 111 | # Setup environment variables. 112 | for key, value in list(cfg.items()): 113 | fields = key.split(".") 114 | if fields[0] == "env": 115 | assert len(fields) == 2 116 | os.environ[fields[1]] = str(value) 117 | 118 | # Create default TensorFlow session. 119 | create_session(cfg, force_as_default=True) 120 | 121 | 122 | def assert_tf_initialized(): 123 | """Check that TensorFlow session has been initialized.""" 124 | if tf.get_default_session() is None: 125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().") 126 | 127 | 128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session: 129 | """Create tf.Session based on config dict.""" 130 | # Setup TensorFlow config proto. 131 | cfg = _sanitize_tf_config(config_dict) 132 | config_proto = tf.ConfigProto() 133 | for key, value in cfg.items(): 134 | fields = key.split(".") 135 | if fields[0] not in ["rnd", "env"]: 136 | obj = config_proto 137 | for field in fields[:-1]: 138 | obj = getattr(obj, field) 139 | setattr(obj, fields[-1], value) 140 | 141 | # Create session. 142 | session = tf.Session(config=config_proto) 143 | if force_as_default: 144 | # pylint: disable=protected-access 145 | session._default_session = session.as_default() 146 | session._default_session.enforce_nesting = False 147 | session._default_session.__enter__() # pylint: disable=no-member 148 | 149 | return session 150 | 151 | 152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None: 153 | """Initialize all tf.Variables that have not already been initialized. 154 | 155 | Equivalent to the following, but more efficient and does not bloat the tf graph: 156 | tf.variables_initializer(tf.report_uninitialized_variables()).run() 157 | """ 158 | assert_tf_initialized() 159 | if target_vars is None: 160 | target_vars = tf.global_variables() 161 | 162 | test_vars = [] 163 | test_ops = [] 164 | 165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 166 | for var in target_vars: 167 | assert is_tf_expression(var) 168 | 169 | try: 170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0")) 171 | except KeyError: 172 | # Op does not exist => variable may be uninitialized. 173 | test_vars.append(var) 174 | 175 | with absolute_name_scope(var.name.split(":")[0]): 176 | test_ops.append(tf.is_variable_initialized(var)) 177 | 178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited] 179 | run([var.initializer for var in init_vars]) 180 | 181 | 182 | def set_vars(var_to_value_dict: dict) -> None: 183 | """Set the values of given tf.Variables. 
184 | 185 | Equivalent to the following, but more efficient and does not bloat the tf graph: 186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()] 187 | """ 188 | assert_tf_initialized() 189 | ops = [] 190 | feed_dict = {} 191 | 192 | for var, value in var_to_value_dict.items(): 193 | assert is_tf_expression(var) 194 | 195 | try: 196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op 197 | except KeyError: 198 | with absolute_name_scope(var.name.split(":")[0]): 199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies 200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter 201 | 202 | ops.append(setter) 203 | feed_dict[setter.op.inputs[1]] = value 204 | 205 | run(ops, feed_dict) 206 | 207 | 208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs): 209 | """Create tf.Variable with large initial value without bloating the tf graph.""" 210 | assert_tf_initialized() 211 | assert isinstance(initial_value, np.ndarray) 212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype) 213 | var = tf.Variable(zeros, *args, **kwargs) 214 | set_vars({var: initial_value}) 215 | return var 216 | 217 | 218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False): 219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range. 220 | Can be used as an input transformation for Network.run(). 221 | """ 222 | images = tf.cast(images, tf.float32) 223 | if nhwc_to_nchw: 224 | images = tf.transpose(images, [0, 3, 1, 2]) 225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255) 226 | 227 | 228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1): 229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range. 230 | Can be used as an output transformation for Network.run(). 231 | """ 232 | images = tf.cast(images, tf.float32) 233 | if shrink > 1: 234 | ksize = [1, 1, shrink, shrink] 235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") 236 | if nchw_to_nhwc: 237 | images = tf.transpose(images, [0, 2, 3, 1]) 238 | scale = 255 / (drange[1] - drange[0]) 239 | images = images * scale + (0.5 - drange[0] * scale) 240 | return tf.saturate_cast(images, tf.uint8) 241 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/loss_criterions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dave-epstein/blobgan/c527f1c27447dffe3cf4cf3901571a83ce59f1fe/scripts/style-gan-pytorch/loss_criterions/__init__.py -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/loss_criterions/base_loss_criterions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # source: https://github.com/facebookresearch/pytorch_GAN_zoo/blob/master/models/loss_criterions/base_loss_criterions.py 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class BaseLossWrapper: 8 | r""" 9 | Loss criterion class. Must define 4 members: 10 | sizeDecisionLayer : size of the decision layer of the discrimator 11 | 12 | getCriterion : how the loss is actually computed 13 | 14 | !! 
The activation function of the discriminator is computed within the 15 | loss !! 16 | """ 17 | 18 | def __init__(self, device): 19 | self.device = device 20 | 21 | def getCriterion(self, input, status): 22 | r""" 23 | Given an input tensor and its targeted status (detected as real or 24 | detected as fake) build the associated loss 25 | 26 | Args: 27 | 28 | - input (Tensor): decision tensor build by the model's discrimator 29 | - status (bool): if True -> this tensor should have been detected 30 | as a real input 31 | else -> it shouldn't have 32 | """ 33 | pass 34 | 35 | 36 | class MSE(BaseLossWrapper): 37 | r""" 38 | Mean Square error loss. 39 | """ 40 | 41 | def __init__(self, device): 42 | self.generationActivation = F.tanh 43 | self.sizeDecisionLayer = 1 44 | 45 | BaseLossWrapper.__init__(self, device) 46 | 47 | def getCriterion(self, input, status): 48 | size = input.size()[0] 49 | value = float(status) 50 | reference = torch.tensor([value]).expand(size, 1).to(self.device) 51 | return F.mse_loss(F.sigmoid(input[:, :self.sizeDecisionLayer]), 52 | reference) 53 | 54 | 55 | class WGANGP(BaseLossWrapper): 56 | r""" 57 | Paper WGANGP loss : linear activation for the generator. 58 | https://arxiv.org/pdf/1704.00028.pdf 59 | """ 60 | 61 | def __init__(self, device): 62 | 63 | self.generationActivation = None 64 | self.sizeDecisionLayer = 1 65 | 66 | BaseLossWrapper.__init__(self, device) 67 | 68 | def getCriterion(self, input, status): 69 | if status: 70 | return -input[:, 0].sum() 71 | return input[:, 0].sum() 72 | 73 | 74 | class Logistic(BaseLossWrapper): 75 | r""" 76 | "Which training method of GANs actually converge" 77 | https://arxiv.org/pdf/1801.04406.pdf 78 | """ 79 | 80 | def __init__(self, device): 81 | 82 | self.generationActivation = None 83 | self.sizeDecisionLayer = 1 84 | BaseLossWrapper.__init__(self, device) 85 | 86 | def getCriterion(self, input, status): 87 | if status: 88 | return F.softplus(-input[:, 0]).mean() 89 | return F.softplus(input[:, 0]).mean() 90 | 91 | 92 | class DCGAN(BaseLossWrapper): 93 | r""" 94 | Cross entropy loss. 95 | """ 96 | 97 | def __init__(self, device): 98 | 99 | self.generationActivation = F.tanh 100 | self.sizeDecisionLayer = 1 101 | 102 | BaseLossWrapper.__init__(self, device) 103 | 104 | def getCriterion(self, input, status): 105 | size = input.size()[0] 106 | value = int(status) 107 | reference = torch.tensor( 108 | [value], dtype=torch.float).expand(size).to(self.device) 109 | return F.binary_cross_entropy(torch.sigmoid(input[:, :self.sizeDecisionLayer]), reference) -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/loss_criterions/gradient_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | 5 | def WGANGPGradientPenalty(input, fake, discriminator, weight, backward=True): 6 | r""" 7 | Gradient penalty as described in 8 | "Improved Training of Wasserstein GANs" 9 | https://arxiv.org/pdf/1704.00028.pdf 10 | 11 | Args: 12 | 13 | - input (Tensor): batch of real data 14 | - fake (Tensor): batch of generated data. 
Must have the same size 15 | as the input 16 | - discrimator (nn.Module): discriminator network 17 | - weight (float): weight to apply to the penalty term 18 | - backward (bool): loss backpropagation 19 | """ 20 | 21 | batchSize = input.size(0) 22 | alpha = torch.rand(batchSize, 1) 23 | alpha = alpha.expand(batchSize, int(input.nelement() / 24 | batchSize)).contiguous().view( 25 | input.size()) 26 | alpha = alpha.to(input.device) 27 | interpolates = alpha * input + ((1 - alpha) * fake) 28 | 29 | interpolates = torch.autograd.Variable( 30 | interpolates, requires_grad=True) 31 | 32 | decisionInterpolate = discriminator(interpolates, False) 33 | decisionInterpolate = decisionInterpolate[:, 0].sum() 34 | 35 | gradients = torch.autograd.grad(outputs=decisionInterpolate, 36 | inputs=interpolates, 37 | create_graph=True, retain_graph=True) 38 | 39 | gradients = gradients[0].view(batchSize, -1) 40 | gradients = (gradients * gradients).sum(dim=1).sqrt() 41 | gradient_penalty = (((gradients - 1.0)**2)).sum() * weight 42 | 43 | if backward: 44 | gradient_penalty.backward(retain_graph=True) 45 | 46 | return gradient_penalty.item() 47 | 48 | 49 | def logisticGradientPenalty(input, discrimator, res, alpha, weight, backward=True): 50 | r""" 51 | Gradient penalty described in "Which training method of GANs actually 52 | converge 53 | https://arxiv.org/pdf/1801.04406.pdf 54 | 55 | Args: 56 | 57 | - input (Tensor): batch of real data 58 | - discrimator (nn.Module): discriminator network 59 | - weight (float): weight to apply to the penalty term 60 | - backward (bool): loss backpropagation 61 | """ 62 | 63 | locInput = torch.autograd.Variable( 64 | input, requires_grad=True) 65 | gradients = torch.autograd.grad(outputs=discrimator(locInput, res, alpha)[:, 0].sum(), 66 | inputs=locInput, 67 | create_graph=True, retain_graph=True)[0] 68 | 69 | gradients = gradients.view(gradients.size(0), -1) 70 | gradients = (gradients * gradients).sum(dim=1).mean() 71 | 72 | gradient_penalty = gradients * weight 73 | if backward: 74 | gradient_penalty.backward(retain_graph=True) 75 | 76 | return gradient_penalty.item() 77 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dave-epstein/blobgan/c527f1c27447dffe3cf4cf3901571a83ce59f1fe/scripts/style-gan-pytorch/networks/__init__.py -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/networks/building_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from networks.custom_layers import * 5 | 6 | 7 | class LayerEpilogue(nn.Module): 8 | """ 9 | Things to do at the end of each layer 10 | 1. mixin scaled noise 11 | 2. mixin style with AdaIN 12 | """ 13 | def __init__(self, 14 | num_channels, 15 | dlatent_size, # Disentangled latent (W) dimensionality, 16 | use_wscale, # Enable equalized learning rate? 17 | use_pixel_norm, # Enable pixel-wise feature vector normalization? 
18 | use_instance_norm, 19 | use_noise, 20 | use_styles, 21 | nonlinearity, 22 | ): 23 | super(LayerEpilogue, self).__init__() 24 | 25 | act = { 26 | 'relu': torch.relu, 27 | 'lrelu': nn.LeakyReLU(negative_slope=0.2) 28 | }[nonlinearity] 29 | 30 | layers = [] 31 | if use_noise: 32 | layers.append(('noise', NoiseMixin(num_channels))) 33 | layers.append(('act', act)) 34 | 35 | # to follow the tf implementation 36 | if use_pixel_norm: 37 | layers.append(('pixel_norm', NormalizationLayer())) 38 | if use_instance_norm: 39 | layers.append(('instance_norm', nn.InstanceNorm2d(num_channels))) 40 | # now we need to mixin styles 41 | self.pre_style_op = nn.Sequential(OrderedDict(layers)) 42 | 43 | if use_styles: 44 | self.style_mod = StyleMixin(dlatent_size, 45 | num_channels, 46 | use_wscale=use_wscale) 47 | def forward(self, x, dlatent): 48 | # dlatent is w 49 | x = self.pre_style_op(x) 50 | if self.style_mod: 51 | x = self.style_mod(x, dlatent) 52 | return x 53 | 54 | 55 | class EarlySynthesisBlock(nn.Module): 56 | """ 57 | The first block for 4x4 resolution 58 | """ 59 | def __init__(self, 60 | in_channels, 61 | dlatent_size, 62 | const_input_layer, 63 | use_wscale, 64 | use_noise, 65 | use_pixel_norm, 66 | use_instance_norm, 67 | use_styles, 68 | nonlinearity 69 | ): 70 | super(EarlySynthesisBlock, self).__init__() 71 | self.const_input_layer = const_input_layer 72 | self.in_channels = in_channels 73 | 74 | if const_input_layer: 75 | self.const = nn.Parameter(torch.ones(1, in_channels, 4, 4)) 76 | self.bias = nn.Parameter(torch.ones(in_channels)) 77 | else: 78 | self.dense = EqualizedLinear(dlatent_size, in_channels * 16, use_wscale=use_wscale) 79 | 80 | self.epi0 = LayerEpilogue(num_channels=in_channels, 81 | dlatent_size=dlatent_size, 82 | use_wscale=use_wscale, 83 | use_noise=use_noise, 84 | use_pixel_norm=use_pixel_norm, 85 | use_instance_norm=use_instance_norm, 86 | use_styles=use_styles, 87 | nonlinearity=nonlinearity 88 | ) 89 | # kernel size must be 3 or other odd numbers 90 | # so that we have 'same' padding 91 | self.conv = EqualizedConv2d(in_channels=in_channels, 92 | out_channels=in_channels, 93 | kernel_size=3, 94 | padding=3//2) 95 | 96 | self.epi1 = LayerEpilogue(num_channels=in_channels, 97 | dlatent_size=dlatent_size, 98 | use_wscale=use_wscale, 99 | use_noise=use_noise, 100 | use_pixel_norm=use_pixel_norm, 101 | use_instance_norm=use_instance_norm, 102 | use_styles=use_styles, 103 | nonlinearity=nonlinearity 104 | ) 105 | 106 | def forward(self, dlatents): 107 | # note dlatents is broadcast one 108 | dlatents_0 = dlatents[:, 0] 109 | dlatents_1 = dlatents[:, 1] 110 | batch_size = dlatents.size(0) 111 | if self.const_input_layer: 112 | x = self.const.expand(batch_size, -1, -1, -1) 113 | x = x + self.bias.view(1, -1, 1, 1) 114 | else: 115 | x = self.dense(dlatents_0).view(batch_size, self.in_channels, 4, 4) 116 | 117 | x = self.epi0(x, dlatents_0) 118 | x = self.conv(x) 119 | x = self.epi1(x, dlatents_1) 120 | return x 121 | 122 | 123 | class LaterSynthesisBlock(nn.Module): 124 | """ 125 | The following blocks for res 8x8...etc. 126 | """ 127 | 128 | def __init__(self, 129 | in_channels, 130 | out_channels, 131 | dlatent_size, 132 | use_wscale, 133 | use_noise, 134 | use_pixel_norm, 135 | use_instance_norm, 136 | use_styles, 137 | nonlinearity, 138 | blur_filter, 139 | res, 140 | ): 141 | super(LaterSynthesisBlock, self).__init__() 142 | 143 | # res = log2(H), H is 4, 8, 16, 32 ... 
1024 144 | 145 | assert isinstance(res, int) and (2 <= res <= 10) 146 | 147 | self.res = res 148 | 149 | if blur_filter: 150 | self.blur = Blur2d(blur_filter) 151 | #blur = Blur2d(blur_filter) 152 | else: 153 | self.blur = None 154 | 155 | # name 'conv0_up' is used in tf implementation 156 | self.conv0_up = Upscale2dConv2d(res=res, 157 | in_channels=in_channels, 158 | out_channels=out_channels, 159 | kernel_size=3, 160 | use_wscale=use_wscale) 161 | # self.conv0_up = Upscale2dConv2d2( 162 | # input_channels=in_channels, 163 | # output_channels=out_channels, 164 | # kernel_size=3, 165 | # gain=np.sqrt(2), 166 | # use_wscale=use_wscale, 167 | # intermediate=blur, 168 | # upscale=True 169 | # ) 170 | 171 | self.epi0 = LayerEpilogue(num_channels=out_channels, 172 | dlatent_size=dlatent_size, 173 | use_wscale=use_wscale, 174 | use_pixel_norm=use_pixel_norm, 175 | use_noise=use_noise, 176 | use_instance_norm=use_instance_norm, 177 | use_styles=use_styles, 178 | nonlinearity=nonlinearity) 179 | 180 | # name 'conv1' is used in tf implementation 181 | # kernel size must be 3 or other odd numbers 182 | # so that we have 'same' padding 183 | # no upsclaing 184 | self.conv1 = EqualizedConv2d(in_channels=out_channels, 185 | out_channels=out_channels, 186 | kernel_size=3, 187 | padding=3//2) 188 | 189 | self.epi1 = LayerEpilogue(num_channels=out_channels, 190 | dlatent_size=dlatent_size, 191 | use_wscale=use_wscale, 192 | use_pixel_norm=use_pixel_norm, 193 | use_noise=use_noise, 194 | use_instance_norm=use_instance_norm, 195 | use_styles=use_styles, 196 | nonlinearity=nonlinearity) 197 | 198 | 199 | def forward(self, x, dlatents): 200 | 201 | x = self.conv0_up(x) 202 | if self.blur is not None: 203 | x = self.blur(x) 204 | x = self.epi0(x, dlatents[:, self.res * 2 - 4]) 205 | x = self.conv1(x) 206 | x = self.epi1(x, dlatents[:, self.res * 2 - 3]) 207 | return x 208 | 209 | 210 | class EarlyDiscriminatorBlock(nn.Sequential): 211 | def __init__(self, 212 | res, 213 | in_channels, 214 | out_channels, 215 | use_wscale, 216 | blur_filter, 217 | fused_scale, 218 | nonlinearity): 219 | act = { 220 | 'relu': torch.relu, 221 | 'lrelu': nn.LeakyReLU(negative_slope=0.2) 222 | }[nonlinearity] 223 | 224 | layers = [] 225 | 226 | layers.append(('conv0', EqualizedConv2d(in_channels=in_channels, 227 | out_channels=in_channels, 228 | kernel_size=3, 229 | padding=3//2, 230 | use_wscale=use_wscale))) 231 | # note that we don't have layer epilogue in discriminator, so we need to add activation layer mannually 232 | layers.append(('act0', act)) 233 | 234 | layers.append(('blur', Blur2d(blur_filter))) 235 | 236 | layers.append(('conv1_down', Downscale2dConv2d(res=res, 237 | in_channels=in_channels, 238 | out_channels=out_channels, 239 | kernel_size=3, 240 | fused_scale=fused_scale, 241 | use_wscale=use_wscale))) 242 | layers.append(('act1', act)) 243 | 244 | super().__init__(OrderedDict(layers)) 245 | 246 | 247 | class LaterDiscriminatorBlock(nn.Sequential): 248 | 249 | def __init__(self, 250 | in_channels, 251 | out_channels, 252 | use_wscale, 253 | nonlinearity, 254 | mbstd_group_size, 255 | mbstd_num_features, 256 | res, 257 | ): 258 | act = { 259 | 'relu': torch.relu, 260 | 'lrelu': nn.LeakyReLU(negative_slope=0.2) 261 | }[nonlinearity] 262 | 263 | resolution = 2 ** res 264 | layers = [] 265 | layers.append(('minibatchstddev', MiniBatchStdDev(mbstd_group_size, mbstd_num_features))) 266 | layers.append(('conv', EqualizedConv2d(in_channels=in_channels + mbstd_num_features, 267 | out_channels=in_channels, 268 | 
kernel_size=3, 269 | padding=3//2, 270 | use_wscale=use_wscale))) 271 | layers.append(('act0', act)) 272 | layers.append(('flatten', Flatten())) 273 | layers.append(('dense0', EqualizedLinear(in_channels=in_channels * (resolution**2), 274 | out_channels=in_channels, 275 | use_wscale=use_wscale))) 276 | layers.append(('act1', act)) 277 | # no activation for the last fc 278 | layers.append(('dense1', EqualizedLinear(in_channels=in_channels, 279 | out_channels=out_channels))) 280 | 281 | super().__init__(OrderedDict(layers)) 282 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/networks/style_gan_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from collections import OrderedDict 6 | from networks.custom_layers import EqualizedLinear, EqualizedConv2d, \ 7 | NormalizationLayer, _upscale2d 8 | from networks.building_blocks import EarlySynthesisBlock, LaterSynthesisBlock, \ 9 | EarlyDiscriminatorBlock, LaterDiscriminatorBlock 10 | 11 | class MappingNet(nn.Sequential): 12 | """ 13 | A mapping network f implemented using an 8-layer MLP 14 | """ 15 | def __init__(self, 16 | resolution = 1024, 17 | num_layers = 8, 18 | dlatent_size = 512, 19 | normalize_latents = True, 20 | nonlinearity = 'lrelu', 21 | maping_lrmul = 0.01, # We thus reduce the learning rate by two orders of magnitude for the mapping network 22 | **kwargs): # other parameters are ignored 23 | 24 | resolution_log2: int = int(np.log2(resolution)) 25 | 26 | assert resolution == 2**resolution_log2 and 4 <= resolution <= 1024 27 | 28 | act = { 29 | 'relu': torch.relu, 30 | 'lrelu': nn.LeakyReLU(negative_slope=0.2) 31 | }[nonlinearity] 32 | 33 | self.dlatent_broadcast = resolution_log2 * 2 - 2 34 | layers = [] 35 | if normalize_latents: 36 | layers.append(('pixel_norm', NormalizationLayer())) 37 | for i in range(num_layers): 38 | layers.append(('dense{}'.format(i), EqualizedLinear(dlatent_size, 39 | dlatent_size, 40 | use_wscale=True, 41 | lrmul=maping_lrmul))) 42 | layers.append(('dense{}_act'.format(i), act)) 43 | 44 | super().__init__(OrderedDict(layers)) 45 | 46 | def forward(self, x): 47 | # N x 512 48 | w = super().forward(x) 49 | if self.dlatent_broadcast is not None: 50 | # broadcast 51 | # tf.tile in the official tf implementation: 52 | # w = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1]) 53 | w = w.unsqueeze(1).expand(-1, self.dlatent_broadcast, -1) 54 | return w 55 | 56 | 57 | class SynthesisNet(nn.Module): 58 | """ 59 | Synthesis network 60 | """ 61 | def __init__(self, 62 | dlatent_size = 512, 63 | num_channels = 3, 64 | resolution = 1024, 65 | fmap_base = 8192, 66 | fmap_decay = 1.0, 67 | fmap_max = 512, 68 | use_styles = True, 69 | const_input_layer = True, 70 | use_noise = True, 71 | nonlinearity = 'lrelu', 72 | use_wscale = True, 73 | use_pixel_norm = False, 74 | use_instance_norm = True, 75 | blur_filter = [1, 2, 1], # low-pass filer to apply when resampling activations. 
None = no filtering 76 | **kwargs # other parameters are ignored 77 | ): 78 | super(SynthesisNet, self).__init__() 79 | 80 | # copied from tf implementation 81 | 82 | resolution_log2: int = int(np.log2(resolution)) 83 | 84 | assert resolution == 2**resolution_log2 and 4 <= resolution <= 1024 85 | 86 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 87 | 88 | act = { 89 | 'relu': torch.relu, 90 | 'lrelu': nn.LeakyReLU(negative_slope=0.2) 91 | }[nonlinearity] 92 | 93 | num_layers = resolution_log2 * 2 - 2 94 | 95 | num_styles = num_layers if use_styles else 1 96 | 97 | blocks = [] 98 | torgbs = [] 99 | 100 | # 2....10 (inclusive) for 1024 resolution 101 | for res in range(2, resolution_log2 + 1): 102 | channels = nf(res - 1) 103 | block_name = '{s}x{s}'.format(s=2**res) 104 | torgb_name = 'torgb_lod{}'.format(resolution_log2 - res) 105 | if res == 2: 106 | # early block 107 | block = (block_name, EarlySynthesisBlock(channels, 108 | dlatent_size, 109 | const_input_layer, 110 | use_wscale, 111 | use_noise, 112 | use_pixel_norm, 113 | use_instance_norm, 114 | use_styles, 115 | nonlinearity)) 116 | else: 117 | block = (block_name, LaterSynthesisBlock(last_channels, 118 | out_channels=channels, 119 | dlatent_size=dlatent_size, 120 | use_wscale=use_wscale, 121 | use_noise=use_noise, 122 | use_pixel_norm=use_pixel_norm, 123 | use_instance_norm=use_instance_norm, 124 | use_styles=use_styles, 125 | nonlinearity=nonlinearity, 126 | blur_filter=blur_filter, 127 | res=res, 128 | )) 129 | 130 | # torgb block 131 | torgb = (torgb_name, EqualizedConv2d(channels, num_channels, 1, use_wscale=use_wscale)) 132 | 133 | blocks.append(block) 134 | torgbs.append(torgb) 135 | last_channels = channels 136 | 137 | # the last one has bias 138 | self.torgbs = nn.ModuleDict(OrderedDict(torgbs)) 139 | 140 | #self.torgb = Upscale2dConv2d2(channels, num_channels, 1, gain=1, use_wscale=use_wscale, bias=True) 141 | self.blocks = nn.ModuleDict(OrderedDict(blocks)) 142 | 143 | 144 | def forward(self, dlatents, res, alpha): 145 | assert 2 <= res <= 10 146 | # step 1...9 147 | step = res - 1 148 | block_list = list(self.blocks.values())[:step] 149 | torgb_list = list(self.torgbs.values())[:step] 150 | 151 | # starting from 8x8 we have skip connections 152 | if step > 1: 153 | skip_torgb = torgb_list[-2] 154 | this_rgb = torgb_list[-1] 155 | 156 | for i, block in enumerate(block_list): 157 | 158 | if i == 0: 159 | x = block(dlatents) 160 | else: 161 | x = block(x, dlatents) 162 | 163 | # step - 1 is the last index 164 | # so step - 2 is the second last 165 | if i == step - 2: 166 | # get the skip result 167 | skip_x = _upscale2d(skip_torgb(x), 2) 168 | 169 | # finally for current resolution, to rgb: 170 | x = this_rgb(x) 171 | 172 | x = (1 - alpha) * skip_x + alpha * x 173 | 174 | return x 175 | 176 | 177 | # a convenient wrapping class 178 | class Generator(nn.Sequential): 179 | def __init__(self, **kwargs): 180 | super().__init__(OrderedDict([ 181 | ('g_mapping', MappingNet(**kwargs)), 182 | ('g_synthesis', SynthesisNet(**kwargs)) 183 | ])) 184 | 185 | def forward(self, latents, res, alpha): 186 | dlatents = self.g_mapping(latents) 187 | x = self.g_synthesis(dlatents, res, alpha) 188 | return x 189 | 190 | 191 | class BasicDiscriminator(nn.Module): 192 | 193 | def __init__(self, 194 | num_channels = 3, 195 | resolution = 1024, 196 | fmap_base = 8192, 197 | fmap_decay = 1.0, 198 | fmap_max = 512, 199 | nonlinearity = 'lrelu', 200 | mbstd_group_size = 4, 201 | mbstd_num_features = 1, 202 | 
use_wscale = True, 203 | fused_scale = 'auto', 204 | blur_filter = [1, 2, 1], 205 | ): 206 | super(BasicDiscriminator, self).__init__() 207 | 208 | resolution_log2: int = int(np.log2(resolution)) 209 | 210 | assert resolution == 2**resolution_log2 and 4 <= resolution <= 1024 211 | 212 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) 213 | 214 | act = { 215 | 'relu': torch.relu, 216 | 'lrelu': nn.LeakyReLU(negative_slope=0.2) 217 | }[nonlinearity] 218 | # this is fixed. We need to grow it... 219 | blocks = [] 220 | fromrgbs = [] 221 | for res in range(resolution_log2, 1, -1): 222 | block_name = '{s}x{s}'.format(s=2 ** res) 223 | fromrgb_name = 'fromrgb_lod{}'.format(resolution_log2 - res) 224 | if res != 2: 225 | blocks.append((block_name, EarlyDiscriminatorBlock(res=res, 226 | in_channels=nf(res-1), 227 | out_channels=nf(res-2), 228 | use_wscale=use_wscale, 229 | blur_filter=blur_filter, 230 | fused_scale=fused_scale, 231 | nonlinearity=nonlinearity))) 232 | else: 233 | blocks.append((block_name, LaterDiscriminatorBlock(in_channels=nf(res), 234 | out_channels=1, 235 | mbstd_group_size=mbstd_group_size, 236 | mbstd_num_features=mbstd_num_features, 237 | use_wscale=use_wscale, 238 | nonlinearity=nonlinearity, 239 | res=2, 240 | ))) 241 | 242 | fromrgbs.append((fromrgb_name, EqualizedConv2d(num_channels, nf(res - 1), 1, use_wscale=use_wscale))) 243 | 244 | self.blocks = nn.ModuleDict(OrderedDict(blocks)) 245 | self.fromrgbs = nn.ModuleDict(OrderedDict(fromrgbs)) 246 | 247 | 248 | def forward(self, x, res, alpha): 249 | assert 2 <= res <= 10 250 | # step 1...9 251 | step = res - 1 252 | block_list = list(self.blocks.values())[-step:] 253 | fromrgb_list = list(self.fromrgbs.values())[-step:] 254 | 255 | if step > 1: 256 | skip_fromrgb = fromrgb_list[1] 257 | this_fromrgb = fromrgb_list[0] 258 | 259 | for i, block in enumerate(block_list): 260 | if i == 0: 261 | skip_x = skip_fromrgb(F.avg_pool2d(x, 2)) 262 | x = block(this_fromrgb(x)) 263 | x = (1 - alpha) * skip_x + alpha * x 264 | else: 265 | x = block(x) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /scripts/style-gan-pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | if isinstance(v, bool): 5 | return v 6 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 7 | return True 8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 9 | return False 10 | else: 11 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dave-epstein/blobgan/c527f1c27447dffe3cf4cf3901571a83ce59f1fe/src/__init__.py -------------------------------------------------------------------------------- /src/configs/checkpoint/after_each_epoch.yaml: -------------------------------------------------------------------------------- 1 | save_top_k: 3 2 | every_n_epochs: 1 3 | save_last: true 4 | monitor: validate_total_loss -------------------------------------------------------------------------------- /src/configs/checkpoint/after_each_epoch_fid.yaml: -------------------------------------------------------------------------------- 1 | save_top_k: -1 2 | every_n_epochs: 1 3 | save_last: true 4 | monitor: train/fid 5 | save_on_train_epoch_end: true 6 | 
auto_insert_metric_name: true -------------------------------------------------------------------------------- /src/configs/checkpoint/every_n_train_steps.yaml: -------------------------------------------------------------------------------- 1 | every_n_train_steps: 3000 2 | save_top_k: -1 3 | mode: max 4 | monitor: step -------------------------------------------------------------------------------- /src/configs/checkpoint/every_n_train_steps_fid.yaml: -------------------------------------------------------------------------------- 1 | every_n_train_steps: 3000 2 | save_top_k: -1 3 | monitor: train/fid -------------------------------------------------------------------------------- /src/configs/dataset/imagefolder.yaml: -------------------------------------------------------------------------------- 1 | name: ImageFolderDataModule 2 | resolution: 128 3 | dataloader: 4 | num_workers: 12 5 | batch_size: 24 6 | drop_last: true -------------------------------------------------------------------------------- /src/configs/dataset/lsun.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - imagefolder 3 | category: bedroom 4 | basepath: /path/to/lsun # Must have train, validate, test subfolders (validate/test can be empty) 5 | path: ${.basepath}/${.category} 6 | dataloader: 7 | batch_size: 24 -------------------------------------------------------------------------------- /src/configs/dataset/multiimagefolder.yaml: -------------------------------------------------------------------------------- 1 | name: MultiImageFolderDataModule 2 | resolution: 128 3 | dataloader: 4 | num_workers: 12 5 | batch_size: 24 6 | drop_last: true -------------------------------------------------------------------------------- /src/configs/dataset/multilsun.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - multiimagefolder 3 | categories: [kitchen,living,dining] 4 | category: null 5 | dataloader: 6 | batch_size: 24 -------------------------------------------------------------------------------- /src/configs/dataset/nodata.yaml: -------------------------------------------------------------------------------- 1 | name: NullDataModule 2 | dataloader: 3 | num_workers: 4 4 | batch_size: 128 -------------------------------------------------------------------------------- /src/configs/dataset/other_image_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - imagefolder 3 | path: /path/to/images # Must have train, validate, test subfolders (validate/test can be empty) 4 | dataloader: 5 | batch_size: 24 -------------------------------------------------------------------------------- /src/configs/experiment/blobgan.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /checkpoint: every_n_train_steps 4 | - /dataset: lsun 5 | checkpoint: 6 | every_n_train_steps: 1500 7 | wandb: 8 | name: BlobGAN 9 | dataset: 10 | category: bedroom 11 | resolution: ${model.resolution} 12 | dataloader: 13 | batch_size: 24 14 | drop_last: true 15 | model: 16 | name: BlobGAN 17 | lr: 0.002 18 | dim: 512 19 | noise_dim: 512 20 | resolution: 256 21 | lambda: # Needed for convenience since can't input λ on command line 22 | D_real: 1 23 | D_fake: 1 24 | D_R1: 50 25 | G: 1 26 | G_path: 2 27 | G_feature_mean: 10 28 | G_feature_variance: 10 29 | discriminator: 30 | name: StyleGANDiscriminator 31 | size: 
${model.resolution} 32 | generator: 33 | name: models.networks.layoutstylegan.LayoutStyleGANGenerator 34 | style_dim: ${model.dim} 35 | n_mlp: 8 36 | size_in: 16 37 | c_model: 96 38 | spatial_style: ${model.spatial_style} 39 | size: ${model.resolution} 40 | layout_net: 41 | name: models.networks.layoutnet.LayoutGenerator 42 | n_features_max: ${model.n_features_max} 43 | feature_dim: 768 44 | style_dim: ${model.dim} 45 | noise_dim: ${model.noise_dim} 46 | norm_features: true 47 | mlp_lr_mul: 0.01 48 | mlp_hidden_dim: 1024 49 | spatial_style: ${model.spatial_style} 50 | D_reg_every: 16 51 | G_reg_every: -1 52 | λ: ${.lambda} 53 | log_images_every_n_steps: 1000 54 | n_features_min: ${model.n_features} 55 | n_features_max: ${model.n_features} 56 | n_features: 10 57 | spatial_style: true 58 | trainer: 59 | limit_val_batches: 0 60 | precision: 32 61 | plugins: null 62 | deterministic: false -------------------------------------------------------------------------------- /src/configs/experiment/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | dataset: 3 | dataloader: 4 | num_workers: 0 5 | persistent_workers: false 6 | batch_size: 4 7 | trainer: 8 | gpus: 1 9 | accelerator: null 10 | plugins: null 11 | overfit_batches: 20 12 | wandb: 13 | group: debug 14 | detect_anomalies: true 15 | logger: false -------------------------------------------------------------------------------- /src/configs/experiment/gan.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /checkpoint: every_n_train_steps 4 | - /dataset: lsun 5 | wandb: 6 | name: GAN 7 | dataset: 8 | resolution: ${model.resolution} 9 | dataloader: 10 | drop_last: true 11 | model: 12 | name: GAN 13 | lr: 0.002 14 | dim: 512 15 | resolution: 256 16 | lambda: # Needed for convenience since can't input λ on command line 17 | D_real: 1 18 | D_fake: 1 19 | D_R1: 50 20 | G: 1 21 | G_path: 2 22 | discriminator: 23 | name: StyleGANDiscriminator 24 | size: ${model.resolution} 25 | generator: 26 | name: models.networks.stylegan.StyleGANGenerator 27 | style_dim: 512 28 | dim: 512 29 | n_mlp: 8 30 | size: ${model.resolution} 31 | D_reg_every: 16 32 | λ: ${.lambda} 33 | log_images_every_n_steps: 1000 34 | trainer: 35 | limit_val_batches: 0 36 | precision: 32 37 | plugins: null -------------------------------------------------------------------------------- /src/configs/experiment/invertblobgan.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /dataset: lsun 4 | checkpoint: 5 | save_top_k: 3 6 | save_last: true 7 | monitor: validate/total_loss 8 | every_n_train_steps: 5000 9 | wandb: 10 | name: InvertBlobGAN 11 | dataset: 12 | category: bedroom 13 | resolution: ${model.G.resolution} 14 | dataloader: 15 | batch_size: 16 16 | drop_last: true 17 | model: 18 | name: BlobGANInverter 19 | lr: 0.002 20 | log_images_every_n_steps: 1000 21 | lambda: # Needed for convenience since can't input λ on command line 22 | real_LPIPS: 1 23 | real_MSE: 1 24 | fake_LPIPS: 1 25 | fake_MSE: 1 26 | fake_latents_MSE: 1 27 | λ: ${.lambda} 28 | G_pretrained: 29 | key: state_dict 30 | log_dir: null # Defaults to $PWD/logs 31 | project: ${wandb.project} 32 | generator: ${model.G} 33 | generator_pretrained: ${model.G_pretrained} 34 | inverter: 35 | name: StyleGANDiscriminator 36 | size: ${model.G.resolution} 37 | discriminate_stddev: false 38 | G: 39 | lr: 0.002 
40 | dim: 512 41 | noise_dim: 512 42 | resolution: 256 43 | lambda: # Needed for convenience since can't input λ on command line 44 | D_real: 1 45 | D_fake: 1 46 | D_R1: 50 47 | G: 1 48 | G_path: 2 49 | G_feature_mean: 10 50 | G_feature_variance: 10 51 | discriminator: 52 | name: StyleGANDiscriminator 53 | size: ${model.G.resolution} 54 | generator: 55 | name: models.networks.layoutstylegan.LayoutStyleGANGenerator 56 | style_dim: ${model.G.dim} 57 | n_mlp: 8 58 | size_in: 16 59 | c_model: 96 60 | spatial_style: ${model.G.spatial_style} 61 | size: ${model.G.resolution} 62 | layout_net: 63 | name: models.networks.layoutnet.LayoutGenerator 64 | n_features_max: ${model.G.n_features_max} 65 | feature_dim: 768 66 | style_dim: ${model.G.dim} 67 | noise_dim: ${model.G.noise_dim} 68 | norm_features: true 69 | mlp_lr_mul: 0.01 70 | mlp_hidden_dim: 1024 71 | spatial_style: ${model.G.spatial_style} 72 | D_reg_every: 16 73 | G_reg_every: -1 74 | λ: ${.lambda} 75 | log_images_every_n_steps: 1000 76 | n_features_min: ${model.G.n_features} 77 | n_features_max: ${model.G.n_features} 78 | n_features: 10 79 | spatial_style: true 80 | trainer: 81 | precision: 32 82 | plugins: null 83 | deterministic: false -------------------------------------------------------------------------------- /src/configs/experiment/jitter.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | feature_jitter_xy: 0.04 4 | feature_jitter_shift: 0.5 5 | feature_jitter_angle: 0.1 -------------------------------------------------------------------------------- /src/configs/experiment/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | dataset: 3 | basepath: /path/to/your/lsun # Change to your path 4 | trainer: 5 | gpus: YOUR_NGPUS # Change to your number of GPUs 6 | wandb: # Fill in your settings 7 | group: YOUR_GROUP 8 | project: YOUR_PROJECT 9 | entity: YOUR_ENTITY -------------------------------------------------------------------------------- /src/configs/fit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: . 
5 | output_subdir: null 6 | resume: 7 | id: null 8 | step: null 9 | epoch: null 10 | last: true 11 | best: false 12 | clobber_hparams: false 13 | project: ${wandb.project} 14 | log_dir: ${wandb.log_dir} 15 | model_only: false 16 | logger: wandb 17 | wandb: 18 | save_code: true 19 | offline: false 20 | log_dir: ./logs 21 | id: ${resume.id} 22 | trainer: 23 | accelerator: ddp 24 | benchmark: false 25 | deterministic: true 26 | gpus: 8 27 | precision: 16 28 | plugins: null 29 | max_steps: 10000000 30 | profiler: simple 31 | num_sanity_val_steps: 0 32 | log_every_n_steps: 200 33 | dataset: 34 | dataloader: 35 | prefetch_factor: 2 36 | pin_memory: true 37 | drop_last: true 38 | persistent_workers: true 39 | mode: fit 40 | seed: 0 -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from utils import to_dataclass_cfg 4 | from .nodata import * 5 | from .imagefolder import * 6 | from .multiimagefolder import * 7 | 8 | def get_datamodule(name: str, **kwargs) -> LightningDataModule: 9 | cls = getattr(sys.modules[__name__], name) 10 | return cls(**to_dataclass_cfg(kwargs, cls)) 11 | -------------------------------------------------------------------------------- /src/data/imagefolder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Any, Optional, Union, Dict 6 | 7 | from pytorch_lightning import LightningDataModule 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.transforms import InterpolationMode 11 | 12 | from data.nodata import NullIterableDataset 13 | from data.utils import ImageFolderWithFilenames 14 | from utils import print_once 15 | 16 | _all__ = ['ImageFolderDataModule'] 17 | 18 | 19 | @dataclass 20 | class ImageFolderDataModule(LightningDataModule): 21 | path: Union[str, Path] # Root 22 | dataloader: Dict[str, Any] 23 | resolution: int = 256 # Image dimension 24 | 25 | def __post_init__(self): 26 | super().__init__() 27 | self.path = Path(self.path) 28 | self.stats = {'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)} 29 | self.transform = transforms.Compose([ 30 | t for t in [ 31 | transforms.Resize(self.resolution, InterpolationMode.LANCZOS), 32 | transforms.CenterCrop(self.resolution), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | transforms.Normalize(self.stats['mean'], self.stats['std'], inplace=True), 36 | ] 37 | ]) 38 | self.data = {} 39 | 40 | def setup(self, stage: Optional[str] = None): 41 | for split in ('train', 'validate', 'test'): 42 | path = self.path / split 43 | empty = True 44 | if path.exists(): 45 | try: 46 | self.data[split] = ImageFolderWithFilenames(path, transform=self.transform) 47 | empty = False 48 | except FileNotFoundError: 49 | pass 50 | if empty: 51 | print_once( 52 | f'Warning: no images found in {path}. Using empty dataset for split {split}. 
' 53 | f'Perhaps you set `dataset.path` incorrectly?') 54 | self.data[split] = NullIterableDataset(1) 55 | 56 | def train_dataloader(self) -> DataLoader: 57 | return self._get_dataloader('train') 58 | 59 | def val_dataloader(self) -> DataLoader: 60 | return self._get_dataloader('validate') 61 | 62 | def test_dataloader(self) -> DataLoader: 63 | return self._get_dataloader('test') 64 | 65 | def _get_dataloader(self, split: str): 66 | return DataLoader(self.data[split], **self.dataloader) 67 | -------------------------------------------------------------------------------- /src/data/multiimagefolder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Any, Optional, Union, Dict, List, Callable 7 | 8 | from pytorch_lightning import LightningDataModule 9 | from torch.utils.data import DataLoader, Dataset 10 | from torchvision import transforms 11 | from torchvision.transforms import InterpolationMode 12 | 13 | from data.utils import ImageFolderWithFilenames 14 | from utils import print_once 15 | 16 | _all__ = ['MultiImageFolderDataModule'] 17 | 18 | 19 | @dataclass 20 | class MultiImageFolderDataModule(LightningDataModule): 21 | basepath: Union[str, Path] # Root 22 | categories: List[str] 23 | dataloader: Dict[str, Any] 24 | resolution: int = 256 # Image dimension 25 | 26 | def __post_init__(self): 27 | super().__init__() 28 | self.path = Path(self.basepath) 29 | self.stats = {'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)} 30 | self.transform = transforms.Compose([ 31 | t for t in [ 32 | transforms.Resize(self.resolution, InterpolationMode.LANCZOS), 33 | transforms.CenterCrop(self.resolution), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize(self.stats['mean'], self.stats['std'], inplace=True), 37 | ] 38 | ]) 39 | self.data = {} 40 | 41 | def setup(self, stage: Optional[str] = None): 42 | for split in ('train', 'validate', 'test'): 43 | try: 44 | self.data[split] = MultiImageFolderWithFilenames(self.basepath, self.categories, split, 45 | transform=self.transform) 46 | except FileNotFoundError: 47 | print_once(f'Could not create dataset for split {split}') 48 | 49 | def train_dataloader(self) -> DataLoader: 50 | return self._get_dataloader('train') 51 | 52 | def val_dataloader(self) -> DataLoader: 53 | return self._get_dataloader('validate') 54 | 55 | def test_dataloader(self) -> DataLoader: 56 | return self._get_dataloader('test') 57 | 58 | def _get_dataloader(self, split: str): 59 | return DataLoader(self.data[split], **self.dataloader) 60 | 61 | 62 | @dataclass 63 | class MultiImageFolderWithFilenames(Dataset): 64 | basepath: Union[str, Path] # Root 65 | categories: List[str] 66 | split: str 67 | transform: Callable 68 | 69 | def __post_init__(self): 70 | super().__init__() 71 | self.datasets = [ImageFolderWithFilenames(os.path.join(self.basepath, c, self.split), self.transform) for c in 72 | self.categories] 73 | self._n_datasets = len(self.datasets) 74 | self._dataset_lens = [len(d) for d in self.datasets] 75 | self._len = self._n_datasets * max(self._dataset_lens) 76 | print_once(f'Created dataset with {self.categories}. ' 77 | f'Lengths are {self._dataset_lens}. 
Effective dataset length is {self._len}.') 78 | 79 | def __getitem__(self, index): 80 | dataset_idx = index % self._n_datasets 81 | item_idx = (index // self._n_datasets) % self._dataset_lens[dataset_idx] 82 | return self.datasets[dataset_idx][item_idx] 83 | 84 | def __len__(self): 85 | return self._len 86 | 87 | -------------------------------------------------------------------------------- /src/data/nodata.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from dataclasses import dataclass 3 | from typing import Any, Dict 4 | 5 | from pytorch_lightning import LightningDataModule 6 | from torch.utils.data import DataLoader, IterableDataset 7 | 8 | __all__ = ["NullDataModule"] 9 | 10 | 11 | @dataclass 12 | class NullIterableDataset(IterableDataset): 13 | size: int 14 | 15 | def __post_init__(self): 16 | super().__init__() 17 | 18 | def __iter__(self): 19 | if self.size >= 0: 20 | return iter(range(self.size)) 21 | else: 22 | return itertools.count(0, 0) 23 | 24 | 25 | @dataclass 26 | class NullDataModule(LightningDataModule): 27 | dataloader: Dict[str, Any] 28 | train_size: int = -1 29 | validate_size: int = -1 30 | test_size: int = -1 31 | 32 | def __post_init__(self): 33 | super().__init__() 34 | 35 | def train_dataloader(self) -> DataLoader: 36 | return self._get_dataloader(self.train_size) 37 | 38 | def val_dataloader(self) -> DataLoader: 39 | return self._get_dataloader(self.validate_size) 40 | 41 | def test_dataloader(self) -> DataLoader: 42 | return self._get_dataloader(self.test_size) 43 | 44 | def _get_dataloader(self, size: int) -> DataLoader: 45 | return DataLoader(NullIterableDataset(size), **self.dataloader) 46 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Tuple, List 3 | from typing import Optional, Callable, Any 4 | 5 | import torch 6 | from torchvision.datasets.folder import default_loader, ImageFolder, make_dataset 7 | 8 | from utils import is_rank_zero, print_once 9 | 10 | 11 | class ImageFolderWithFilenames(ImageFolder): 12 | def __init__(self, root: str, transform: Optional[Callable] = None, 13 | target_transform: Optional[Callable] = None, 14 | loader: Callable[[str], Any] = default_loader, 15 | is_valid_file: Optional[Callable[[str], bool]] = None): 16 | super().__init__(root=root, transform=transform, target_transform=target_transform, 17 | loader=loader, is_valid_file=is_valid_file) 18 | 19 | @staticmethod 20 | def make_dataset( 21 | directory: str, 22 | class_to_idx: Dict[str, int], 23 | extensions: Optional[Tuple[str, ...]] = None, 24 | is_valid_file: Optional[Callable[[str], bool]] = None, 25 | ) -> List[Tuple[str, int]]: 26 | if class_to_idx is None: 27 | # prevent potential bug since make_dataset() would use the class_to_idx logic of the 28 | # find_classes() function, instead of using that of the find_classes() method, which 29 | # is potentially overridden and thus could have a different logic. 30 | raise ValueError( 31 | "The class_to_idx parameter cannot be None." 
32 | ) 33 | cache_path = os.path.join(directory, 'cache.pt') 34 | try: 35 | dataset = torch.load(cache_path, map_location='cpu') 36 | print_once(f'Loading dataset from cache in {directory}') 37 | except FileNotFoundError: 38 | print_once(f'Creating dataset and saving to cache in {directory}') 39 | dataset = make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) 40 | if is_rank_zero(): 41 | torch.save(dataset, cache_path) 42 | except EOFError: 43 | print_once(f'Error loading cache from {directory},' 44 | f' likely because dataset is small and read/write were attempted concurrently. ' 45 | f'Proceeding by remaking dataset in-memory.') 46 | dataset = make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) 47 | print_once(f'{len(dataset)} images in dataset') 48 | return dataset 49 | 50 | def __getitem__(self, i): 51 | x, y = super().__getitem__(i) 52 | return x, {'labels': y, 'filenames': self.imgs[i][0]} 53 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Any, Tuple, Dict 3 | 4 | from pytorch_lightning import LightningModule 5 | 6 | from models import networks 7 | from utils import to_dataclass_cfg 8 | # from .segmenter import * 9 | from .blobgan import * 10 | from .gan import * 11 | from .invertblobgan import * 12 | 13 | 14 | def get_model(name: str, return_cfg: bool = False, **kwargs) -> Tuple[LightningModule, Dict[str, Any]]: 15 | cls = getattr(sys.modules[__name__], name) 16 | cfg = to_dataclass_cfg(kwargs, cls) 17 | if return_cfg: 18 | return cls(**cfg), cfg 19 | else: 20 | return cls(**cfg) 21 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | from numbers import Number 3 | from typing import Union, Any, Optional, Dict, Tuple, List 4 | 5 | import numpy as np 6 | import torch 7 | from einops import rearrange 8 | from pytorch_lightning import LightningModule 9 | from torch import Tensor 10 | 11 | from utils import scalars_to_log_dict, run_at_step, epoch_outputs_to_log_dict, is_rank_zero, get_rank, print_once 12 | 13 | 14 | class BaseModule(LightningModule): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | # Control flow 19 | def training_step(self, batch: Tuple[Tensor, dict], batch_idx: int, optimizer_idx: Optional[int] = None) -> Tensor: 20 | return self.shared_step(batch, batch_idx, optimizer_idx, 'train') 21 | 22 | def validation_step(self, batch: Tuple[Tensor, dict], batch_idx: int): 23 | return self.shared_step(batch, batch_idx, mode='validate') 24 | 25 | def test_step(self, batch: Tuple[Tensor, dict], batch_idx: int): 26 | return self.shared_step(batch, batch_idx, mode='test') 27 | 28 | def valtest_epoch_end(self, outputs: List[Dict[str, Tensor]], mode: str): 29 | if self.logger is None: 30 | return 31 | # Either log each step's output separately (results have been all_gathered in this case) 32 | if self.valtest_log_all: 33 | for image_dict in outputs: 34 | self._log_image_dict(image_dict, mode, commit=True) 35 | # Or just log a random batch worth of images from master process 36 | else: 37 | self._log_image_dict(epoch_outputs_to_log_dict(outputs, n_max="batch", shuffle=True), mode) 38 | 39 | def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]): 40 | 
self.valtest_epoch_end(outputs, 'validate') 41 | 42 | def test_epoch_end(self, outputs: List[Dict[str, Tensor]]): 43 | self.valtest_epoch_end(outputs, 'test') 44 | 45 | # Utility methods for logging 46 | def gather_tensor(self, t: Tensor) -> Tensor: 47 | return rearrange(self.all_gather(t), "m n c h w -> (m n) c h w") 48 | 49 | def gather_tensor_dict(self, d: Dict[Any, Tensor]) -> Dict[Any, Tensor]: 50 | return {k: rearrange(v.cpu(), "m n c h w -> (m n) c h w") for k, v in self.all_gather(d).items()} 51 | 52 | def log_scalars(self, scalars: Dict[Any, Union[Number, Tensor]], mode: str, **kwargs): 53 | if 'sync_dist' not in kwargs: 54 | kwargs['sync_dist'] = mode != 'train' 55 | self.log_dict(scalars_to_log_dict(scalars, mode), **kwargs) 56 | 57 | def _log_image_dict(self, img_dict: Dict[str, Tensor], mode: str, commit: bool = False, **kwargs): 58 | if self.logger is not None: 59 | for k, v in img_dict.items(): 60 | self.logger.log_image_batch(f'{mode}/{k}', v, commit=commit, **kwargs) 61 | 62 | def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): 63 | optimizer.zero_grad(set_to_none=True) # Improves performance 64 | 65 | def alert_nan_loss(self, loss: Tensor, batch_idx: int): 66 | if loss != loss: 67 | print( 68 | f'NaN loss in epoch {self.current_epoch}, batch index {batch_idx}, global step {self.global_step}, ' 69 | f'local rank {get_rank()}. Skipping.') 70 | return loss != loss 71 | 72 | def _log_profiler(self): 73 | if run_at_step(self.trainer.global_step, self.log_timing_every_n_steps): 74 | report, total_duration = self.trainer.profiler._make_report() 75 | report_log = dict([kv for action, durations, duration_per in report for kv in 76 | [(f'profiler/mean_t/{action}', np.mean(durations)), 77 | (f'profiler/n_calls/{action}', len(durations)), 78 | (f'profiler/total_t/{action}', np.sum(durations)), 79 | (f'profiler/pct/{action}', duration_per)]]) 80 | self.log_dict(report_log) 81 | self.logger.save_to_file('profiler_summary.txt', self.trainer.profiler.summary(), unique_filename=False) 82 | 83 | def on_train_start(self): 84 | if self.logger: 85 | self.logger.log_model_summary(self) 86 | 87 | def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None: 88 | self.log_dict({'grads/' + k: v for k, v in grad_norm_dict.items()}) 89 | 90 | def on_after_backward(self) -> None: 91 | if not getattr(self, 'validate_gradients', False): 92 | return 93 | 94 | valid_gradients = True 95 | invalid_params = [] 96 | for name, param in self.named_parameters(): 97 | if param.grad is not None: 98 | this_param_valid = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()) 99 | valid_gradients &= this_param_valid 100 | if not this_param_valid: 101 | invalid_params.append(name) 102 | # if not valid_gradients: 103 | # break 104 | 105 | if not valid_gradients: 106 | depth_two_params = [k for k, _ in groupby( 107 | ['.'.join(n.split('.')[:2]).replace('.weight', '').replace('.bias', '') for n in invalid_params])] 108 | print_once(f'Detected inf/NaN gradients for parameters {", ".join(depth_two_params)}. 
' 109 | f'Skipping epoch {self.current_epoch}, batch index {self.batch_idx}, global step {self.global_step}.') 110 | self.zero_grad(set_to_none=True) 111 | -------------------------------------------------------------------------------- /src/models/gan.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ["GAN"] 4 | 5 | from dataclasses import dataclass 6 | from typing import Optional, Union, List, Callable, Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from cleanfid import fid 12 | from torch import nn, Tensor 13 | from torch.cuda.amp import autocast 14 | from torch.optim import Optimizer 15 | 16 | from models import networks 17 | from models.base import BaseModule 18 | from utils import FromConfig, run_at_step, get_D_stats, G_path_loss, D_R1_loss, freeze, is_rank_zero, accumulate, \ 19 | mixing_noise, print_once 20 | 21 | 22 | @dataclass 23 | class Lossλs: 24 | D_real: float = 1 25 | D_fake: float = 1 26 | D_R1: float = 5 27 | G: float = 1 28 | G_path: float = 2 29 | 30 | def __getitem__(self, key): 31 | return super().__getattribute__(key) 32 | 33 | 34 | @dataclass(eq=False) 35 | class GAN(BaseModule): 36 | # Modules 37 | generator: FromConfig[nn.Module] 38 | discriminator: FromConfig[nn.Module] 39 | # Module parameters 40 | dim: int = 256 41 | resolution: int = 128 42 | p_mixing_noise: float = 0.9 43 | n_ema_sample: int = 16 44 | freeze_G: bool = False 45 | # Optimization 46 | lr: float = 1e-3 47 | eps: float = 1e-5 48 | # Regularization 49 | D_reg_every: int = 16 50 | G_reg_every: int = 4 51 | path_len: float = 0 52 | # Loss parameters 53 | λ: FromConfig[Lossλs] = None 54 | # Logging 55 | log_images_every_n_steps: Optional[int] = 500 56 | log_timing_every_n_steps: Optional[int] = -1 57 | log_fid_every_n_steps: Optional[int] = -1 58 | log_fid_every_epoch: bool = True 59 | fid_n_imgs: Optional[int] = 50000 60 | fid_stats_name: Optional[str] = None 61 | fid_num_workers: Optional[int] = 24 62 | valtest_log_all: bool = False 63 | accumulate: bool = True 64 | 65 | def __post_init__(self): 66 | super().__init__() 67 | self.save_hyperparameters() 68 | self.discriminator = networks.get_network(**self.discriminator) 69 | self.generator_ema = networks.get_network(**self.generator) 70 | self.generator = networks.get_network(**self.generator) 71 | if self.freeze_G: 72 | self.generator.eval() 73 | freeze(self.generator) 74 | if self.accumulate: 75 | self.generator_ema.eval() 76 | freeze(self.generator_ema) 77 | accumulate(self.generator_ema, self.generator, 0) 78 | else: 79 | del self.generator_ema 80 | self.λ = Lossλs(**self.λ) 81 | self.register_buffer('sample_z', torch.randn(self.n_ema_sample, self.dim)) 82 | # self.sample_z = torch.randn(self.n_ema_sample, self.dim) 83 | 84 | # Initialization and state management 85 | def on_train_start(self): 86 | super().on_train_start() 87 | # Validate parameters w.r.t. 
trainer (must be done here since trainer is not attached as property yet in init) 88 | assert self.log_images_every_n_steps % self.trainer.log_every_n_steps == 0, \ 89 | '`model.log_images_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder' 90 | if self.log_timing_every_n_steps > -1: 91 | assert self.log_timing_every_n_steps % self.trainer.log_every_n_steps == 0, \ 92 | '`model.log_images_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder' 93 | assert self.log_fid_every_n_steps < 0 or self.log_fid_every_n_steps % self.trainer.log_every_n_steps == 0, \ 94 | '`model.log_fid_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder' 95 | assert not ((self.log_fid_every_n_steps > -1 or self.log_fid_every_epoch) and (not self.fid_stats_name)), \ 96 | 'Cannot compute FID without name of statistics file to use.' 97 | 98 | def configure_optimizers(self) -> Union[optim, List[optim]]: 99 | G_reg_ratio = self.G_reg_every / ((self.G_reg_every + 1) or -1) 100 | D_reg_ratio = self.D_reg_every / ((self.D_reg_every + 1) or -1) 101 | _requires_grad = lambda p: p.requires_grad 102 | G_optim = torch.optim.Adam(filter(_requires_grad, self.generator.parameters()), lr=self.lr * G_reg_ratio, 103 | betas=(0 ** G_reg_ratio, 0.99 ** G_reg_ratio), eps=self.eps) 104 | D_optim = torch.optim.Adam(filter(_requires_grad, self.discriminator.parameters()), lr=self.lr * D_reg_ratio, 105 | betas=(0 ** D_reg_ratio, 0.99 ** D_reg_ratio), eps=self.eps) 106 | if self.freeze_G: 107 | return D_optim 108 | else: 109 | return G_optim, D_optim 110 | 111 | def optimizer_step( 112 | self, 113 | epoch: int = None, 114 | batch_idx: int = None, 115 | optimizer: Optimizer = None, 116 | optimizer_idx: int = None, 117 | optimizer_closure: Optional[Callable] = None, 118 | on_tpu: bool = None, 119 | using_native_amp: bool = None, 120 | using_lbfgs: bool = None, 121 | ): 122 | optimizer.step(closure=optimizer_closure) 123 | 124 | def training_epoch_end(self, *args, **kwargs): 125 | if self.log_fid_every_epoch: 126 | try: 127 | self.log_fid("train") 128 | except: 129 | pass 130 | 131 | def gen(self, z, truncate, ema=True, norm_img=True): 132 | G = self.generator_ema if ema else self.generator 133 | try: 134 | imgs = G([z], return_image_only=True, truncation=1 - truncate, 135 | truncation_latent=self.mean_latent) 136 | except AttributeError: 137 | print_once('Computing mean latent for generation.') 138 | self.get_mean_latent() 139 | imgs = G([z], return_image_only=True, truncation=1 - truncate, 140 | truncation_latent=self.mean_latent) 141 | if norm_img: 142 | imgs = imgs.add_(1).div_(2).mul_(255) 143 | return imgs 144 | 145 | @torch.no_grad() 146 | def log_fid(self, mode, **kwargs): 147 | def gen_fn(z): 148 | if self.accumulate: 149 | out = self.generator_ema([z], return_image_only=True).add_(1).div_(2).mul_(255) 150 | else: 151 | out = self.generator([z], return_image_only=True).add_(1).div_(2).mul_(255) 152 | return out.clamp(min=0, max=255) 153 | 154 | if is_rank_zero(): 155 | fid_score = fid.compute_fid(gen=gen_fn, dataset_name=self.fid_stats_name, 156 | dataset_res=256, num_gen=self.fid_n_imgs, 157 | dataset_split="custom", device=self.device, 158 | num_workers=self.fid_num_workers) 159 | else: 160 | fid_score = 0.0 161 | try: 162 | fid_score = self.all_gather(fid_score).max().item() 163 | self.log_scalars({'fid': fid_score}, mode, **kwargs) 164 | except AttributeError: 165 | pass 166 | return fid_score 167 | 168 | def get_mean_latent(self, n_trunc: int = 
10000, ema=True): 169 | G = self.generator_ema if ema else self.generator 170 | mean_latent = self.mean_latent = G.mean_latent(n_trunc) 171 | return mean_latent 172 | 173 | # Training and evaluation 174 | def shared_step(self, batch: Tuple[Tensor, dict], batch_idx: int, 175 | optimizer_idx: Optional[int] = None, mode: str = 'train') -> Optional[Union[Tensor, dict]]: 176 | """ 177 | Args: 178 | batch: tuple of tensor of shape N x C x H x W of images and a dictionary of batch metadata/labels 179 | batch_idx: pytorch lightning training loop batch index 180 | optimizer_idx: pytorch lightning optimizer index (0 = G, 1 = D) 181 | mode: 182 | `train` returns the total loss and logs losses and images/profiling info. 183 | `validate`/`test` log total loss and return images 184 | Returns: see description for `mode` above 185 | """ 186 | # Set up modules and data 187 | train = mode == 'train' 188 | train_G = train and optimizer_idx == 0 and not self.freeze_G 189 | train_D = train and (optimizer_idx == 1 or self.freeze_G) 190 | batch_real, batch_labels = batch 191 | # z = torch.randn(len(batch_real), self.dim).type_as(batch_real) 192 | info = dict() 193 | losses = dict() 194 | z = mixing_noise(batch_real, self.dim, self.p_mixing_noise) 195 | 196 | gen_imgs, latents = self.generator(z, return_latents=True) 197 | 198 | if latents is not None: 199 | if latents.ndim == 3: 200 | latents = latents[:, 0] 201 | info['latent_norm'] = latents.norm(2, 1).mean() 202 | info['latent_stdev'] = latents.std(0).mean() 203 | 204 | # Compute various losses 205 | logits_fake = self.discriminator(gen_imgs) 206 | if train_G or not train: 207 | # Log 208 | losses['G'] = F.softplus(-logits_fake).mean() 209 | if train_D or not train: 210 | # Discriminate real images 211 | logits_real = self.discriminator(batch_real) 212 | # Log 213 | losses['D_real'] = F.softplus(-logits_real).mean() 214 | losses['D_fake'] = F.softplus(logits_fake).mean() 215 | info.update(get_D_stats('fake', logits_fake, gt=False)) 216 | info.update(get_D_stats('real', logits_real, gt=True)) 217 | 218 | # Save images 219 | imgs = { 220 | 'real_imgs': batch_real, 221 | 'gen_imgs': gen_imgs, 222 | } 223 | imgs = {k: v.clone().detach().float().cpu() for k, v in imgs.items()} 224 | 225 | # Compute train regularization loss 226 | if train_G and run_at_step(batch_idx, self.G_reg_every): 227 | if self.λ.G_path: 228 | z = mixing_noise(batch_real, self.dim, self.p_mixing_noise) 229 | gen_imgs, latents = self.generator(z, return_latents=True) 230 | losses['G_path'], self.path_len, info['G_path_len'] = G_path_loss(gen_imgs, latents, self.path_len) 231 | losses['G_path'] = losses['G_path'] * self.G_reg_every 232 | elif train_D and run_at_step(batch_idx, self.D_reg_every): 233 | if self.λ.D_R1: 234 | with autocast(enabled=False): 235 | batch_real.requires_grad = True 236 | logits_real = self.discriminator(batch_real) 237 | R1 = D_R1_loss(logits_real, batch_real) 238 | info['D_R1_unscaled'] = R1 239 | losses['D_R1'] = R1 * self.D_reg_every 240 | 241 | # Compute final loss and log 242 | losses['total_loss'] = sum(map(lambda k: losses[k] * self.λ[k], losses)) 243 | # if losses['total_loss'] > 20 and is_rank_zero(): 244 | # import ipdb 245 | # ipdb.set_trace() 246 | if self.alert_nan_loss(losses['total_loss'], batch_idx): 247 | if is_rank_zero(): 248 | import ipdb 249 | ipdb.set_trace() 250 | return 251 | self.log_scalars(losses, mode) 252 | self.log_scalars(info, mode) 253 | # Further logging and terminate 254 | if mode == "train": 255 | if train_G and self.accumulate: 256 | 
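# EMA decay 0.5 ** (32 / (10 * 1000)) ≈ 0.9978, i.e. the usual StyleGAN2 setting of a 10k-image half-life at an effective batch of 32.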
accumulate(self.generator_ema, self.generator, 0.5 ** (32 / (10 * 1000))) 257 | if run_at_step(self.trainer.global_step, self.log_images_every_n_steps): 258 | if self.accumulate: 259 | with torch.no_grad(): 260 | imgs['gen_imgs_ema'], _ = self.generator_ema([self.sample_z]) 261 | self._log_image_dict(imgs, mode, square_grid=False, ncol=len(batch_real)) 262 | if run_at_step(self.trainer.global_step, self.log_fid_every_n_steps) and is_rank_zero() and train_G: 263 | self.log_fid(mode) 264 | self._log_profiler() 265 | return losses['total_loss'] 266 | else: 267 | if self.valtest_log_all: 268 | imgs = self.gather_tensor_dict(imgs) 269 | return imgs 270 | -------------------------------------------------------------------------------- /src/models/invertblobgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | __all__ = ["BlobGANInverter"] 4 | 5 | import random 6 | from dataclasses import dataclass 7 | from typing import Optional, Union, List, Callable, Tuple 8 | 9 | import torch 10 | import torch.optim as optim 11 | from PIL import Image 12 | from lpips import LPIPS 13 | from omegaconf import DictConfig 14 | from torch import nn, Tensor 15 | from torch.optim import Optimizer 16 | from torchvision.utils import make_grid 17 | 18 | from models import networks, BlobGAN 19 | from models.base import BaseModule 20 | from utils import FromConfig, run_at_step, freeze, is_rank_zero, load_pretrained_weights, to_dataclass_cfg, print_once 21 | 22 | # SPLAT_KEYS = ['spatial_style', 'xs', 'ys', 'covs', 'sizes'] 23 | SPLAT_KEYS = ['spatial_style', 'scores_pyramid'] 24 | _ = Image 25 | _ = make_grid 26 | 27 | 28 | @dataclass 29 | class Lossλs: 30 | real_LPIPS: float = 1. 31 | real_MSE: float = 1. 32 | fake_LPIPS: float = 1. 33 | fake_MSE: float = 1. 34 | fake_latents_MSE: float = 1. 
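# Loss weights used by the inverter: LPIPS + pixel MSE reconstruction terms on both real and generated images, plus an MSE between the blob layout predicted from a generated image and the layout actually used to render it.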
35 | 36 | def __getitem__(self, key): 37 | return super().__getattribute__(key) 38 | 39 | 40 | @dataclass(eq=False) 41 | class BlobGANInverter(BaseModule): 42 | # Modules 43 | inverter: FromConfig[nn.Module] 44 | generator: FromConfig[BlobGAN] 45 | # Loss parameters 46 | λ: FromConfig[Lossλs] = None 47 | # Logging 48 | log_images_every_n_steps: Optional[int] = 500 49 | log_timing_every_n_steps: Optional[int] = -1 50 | log_grads_every_n_steps: Optional[int] = -1 51 | valtest_log_all: bool = False 52 | # Resuming 53 | generator_pretrained: Optional[Union[str, DictConfig]] = None 54 | load_only_inverter: bool = False 55 | inverter_d_out: Optional[int] = None 56 | # Optim 57 | lr: float = 0.002 58 | eps: float = 1e-5 59 | # Training 60 | trunc_min: float = 0.0 61 | trunc_max: float = 0.0 62 | 63 | def __post_init__(self): 64 | super().__init__() 65 | self.save_hyperparameters() 66 | cfg = to_dataclass_cfg(self.generator, BlobGAN) 67 | if self.generator_pretrained.log_dir is None: 68 | self.generator_pretrained.log_dir = 'logs/' 69 | if not self.load_only_inverter: 70 | self.generator = load_pretrained_weights('BlobGAN', self.generator_pretrained, BlobGAN(**cfg), strict=False) 71 | del self.generator.discriminator 72 | del self.generator.generator 73 | del self.generator.layout_net 74 | freeze(self.generator) 75 | self.inverter = networks.get_network(**self.inverter, 76 | d_out=self.inverter_d_out or 77 | self.generator.layout_net_ema.mlp[-1].weight.shape[0]) 78 | self.L_LPIPS = LPIPS(net='vgg', verbose=False) 79 | freeze(self.L_LPIPS) 80 | self.λ = Lossλs(**self.λ) 81 | 82 | # Initialization and state management 83 | def on_train_start(self): 84 | super().on_train_start() 85 | # Validate parameters w.r.t. trainer (must be done here since trainer is not attached as property yet in init) 86 | assert self.log_images_every_n_steps % self.trainer.log_every_n_steps == 0, \ 87 | '`model.log_images_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder. ' \ 88 | f'Got {self.log_images_every_n_steps} and {self.trainer.log_every_n_steps}.' 89 | assert self.log_timing_every_n_steps < 0 or self.log_timing_every_n_steps % self.trainer.log_every_n_steps == 0, \ 90 | '`model.log_images_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder' 91 | 92 | def configure_optimizers(self) -> Union[optim, List[optim]]: 93 | params = list(self.inverter.parameters()) 94 | print_once(f'Optimizing {sum([p.numel() for p in params]) / 1e6:.2f}M params') 95 | return torch.optim.AdamW(params, lr=self.lr, eps=self.eps, weight_decay=0) 96 | 97 | def optimizer_step( 98 | self, 99 | epoch: int = None, 100 | batch_idx: int = None, 101 | optimizer: Optimizer = None, 102 | optimizer_idx: int = None, 103 | optimizer_closure: Optional[Callable] = None, 104 | on_tpu: bool = None, 105 | using_native_amp: bool = None, 106 | using_lbfgs: bool = None, 107 | ): 108 | self.batch_idx = batch_idx 109 | optimizer.step(closure=optimizer_closure) 110 | 111 | def shared_step(self, batch: Tuple[Tensor, dict], batch_idx: int, 112 | optimizer_idx: Optional[int] = None, mode: str = 'train') -> Optional[Union[Tensor, dict]]: 113 | """ 114 | Args: 115 | batch: tuple of tensor of shape N x C x H x W of images and a dictionary of batch metadata/labels 116 | batch_idx: pytorch lightning training loop batch index 117 | optimizer_idx: pytorch lightning optimizer index (0 = G, 1 = D) 118 | mode: 119 | `train` returns the total loss and logs losses and images/profiling info. 
120 | `validate`/`test` log total loss and return images 121 | Returns: see description for `mode` above 122 | """ 123 | # Set up modules and data 124 | batch_real, batch_labels = batch 125 | log_images = run_at_step(self.trainer.global_step, self.log_images_every_n_steps) 126 | 127 | z = torch.randn(len(batch_real), self.generator.noise_dim).type_as(batch_real) 128 | 129 | with torch.no_grad(): 130 | truncate = self.trunc_min if self.trunc_min == self.trunc_max \ 131 | else random.uniform(self.trunc_min, self.trunc_max) 132 | layout_gt_fake, gen_imgs = self.generator.gen(z, truncate=truncate, ema=True, viz=log_images, 133 | ret_layout=True) 134 | 135 | losses = dict() 136 | 137 | z_pred_fake = self.inverter(gen_imgs.detach()) 138 | 139 | layout_pred_fake, reconstr_fake = self.generator.gen(z_pred_fake, ema=True, viz=log_images, ret_layout=True, 140 | mlp_idx=len(self.generator.layout_net_ema.mlp)) 141 | 142 | losses['fake_MSE'] = (gen_imgs - reconstr_fake).pow(2).mean() 143 | losses['fake_LPIPS'] = self.L_LPIPS(reconstr_fake, gen_imgs).mean() 144 | latent_l2_loss = [] 145 | for k in ('xs', 'ys', 'covs', 'sizes', 'features', 'spatial_style'): 146 | latent_l2_loss.append((layout_pred_fake[k] - layout_gt_fake[k].detach()).pow(2).mean()) 147 | losses['fake_latents_MSE'] = sum(latent_l2_loss) / len(latent_l2_loss) 148 | 149 | z_pred_real = self.inverter(batch_real.detach()) 150 | layout_pred_real, reconstr_real = self.generator.gen(z_pred_real, ema=True, viz=log_images, ret_layout=True, 151 | mlp_idx=len(self.generator.layout_net_ema.mlp)) 152 | 153 | losses['real_MSE'] = (batch_real - reconstr_real).pow(2).mean() 154 | losses['real_LPIPS'] = self.L_LPIPS(reconstr_real, batch_real).mean() 155 | 156 | total_loss = f'total_loss' 157 | losses[total_loss] = sum(map(lambda k: losses[k] * self.λ[k], losses)) 158 | isnan = self.alert_nan_loss(losses[total_loss], batch_idx) 159 | if self.all_gather(isnan).any(): 160 | if self.ipdb_on_nan and is_rank_zero(): 161 | import ipdb 162 | ipdb.set_trace() 163 | return 164 | self.log_scalars(losses, mode) 165 | 166 | imgs = { 167 | 'real': batch_real, 168 | 'real_reconstr': reconstr_real, 169 | 'fake': gen_imgs, 170 | 'fake_reconstr': reconstr_fake, 171 | 'real_reconstr_feats': layout_pred_real['feature_img'], 172 | 'fake_reconstr_feats': layout_pred_fake['feature_img'], 173 | 'fake_feats': layout_gt_fake['feature_img'] 174 | } 175 | if mode == "train": 176 | if log_images and is_rank_zero(): 177 | imgs = {k: v.clone().detach().float().cpu() for k, v in imgs.items()} 178 | self._log_image_dict(imgs, mode, square_grid=False, ncol=len(batch_real)) 179 | return losses[total_loss] 180 | else: 181 | if self.valtest_log_all: 182 | imgs = self.gather_tensor_dict(imgs) 183 | return imgs 184 | -------------------------------------------------------------------------------- /src/models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional, Union 3 | 4 | import torch 5 | from omegaconf import DictConfig 6 | from torch import nn 7 | 8 | from utils import import_external, is_rank_zero, get_checkpoint_path, load_pretrained_weights 9 | from .stylegan import * 10 | from .layoutnet import * 11 | 12 | 13 | def get_network(name: str, pretrained: Optional[Union[str, DictConfig]] = None, **kwargs) -> nn.Module: 14 | if '.' 
in name: 15 | ret = import_external(name, pretrained, **kwargs) 16 | return ret 17 | else: 18 | ret = getattr(sys.modules[__name__], name)(**kwargs) 19 | return load_pretrained_weights(name, pretrained, ret) 20 | -------------------------------------------------------------------------------- /src/models/networks/layoutnet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py 2 | import random 3 | from dataclasses import dataclass 4 | from typing import Optional, Dict 5 | 6 | import torch 7 | from einops import rearrange 8 | from torch import nn, Tensor 9 | 10 | __all__ = ["LayoutGenerator"] 11 | 12 | from models.networks.stylegan import StyleMLP, pixel_norm 13 | from utils import derange_tensor 14 | 15 | 16 | @dataclass(eq=False) 17 | class LayoutGenerator(nn.Module): 18 | noise_dim: int = 512 19 | feature_dim: int = 512 20 | style_dim: int = 512 21 | # MLP options 22 | mlp_n_layers: int = 8 23 | mlp_trunk_n_layers: int = 4 24 | mlp_hidden_dim: int = 1024 25 | n_features_max: int = 5 26 | norm_features: bool = False 27 | # Transformer options 28 | spatial_style: bool = False 29 | # Training options 30 | mlp_lr_mul: float = 0.01 31 | shuffle_features: bool = False 32 | p_swap_style: float = 0.0 33 | feature_jitter_xy: float = 0.0 # Legacy, unused 34 | feature_dropout: float = 0.0 35 | 36 | def __post_init__(self): 37 | super().__init__() 38 | if self.feature_jitter_xy: 39 | print('Warning! This parameter is here only to support loading of old checkpoints, and does not function. ' 40 | 'Unless you are loading a model that has this value set, it should not be used. To control jitter, ' 41 | 'set model.feature_jitter_xy directly.') 42 | # {x_i, y_i, feature_i, covariance_i}, bg feature, and cluster sizes 43 | maybe_style_dim = int(self.spatial_style) * self.style_dim 44 | ndim = (self.feature_dim + maybe_style_dim + 2 + 4 + 1) * self.n_features_max + \ 45 | (maybe_style_dim + self.feature_dim + 1) 46 | self.mlp = StyleMLP(self.mlp_n_layers, self.mlp_hidden_dim, self.mlp_lr_mul, first_dim=self.noise_dim, 47 | last_dim=ndim, last_relu=False) 48 | 49 | def forward(self, noise: Tensor, n_features: int, 50 | mlp_idx: Optional[int] = None) -> Optional[Dict[str, Tensor]]: 51 | """ 52 | Args: 53 | noise: [N x noise_dim] or [N x M x noise_dim] 54 | mlp_idx: which IDX to start running MLP from, useful for truncation 55 | n_features: int num features to output 56 | Returns: three tensors x coordinates [N x M], y coordinates [N x M], features [N x M x feature_dim] 57 | """ 58 | if mlp_idx is None: 59 | out = self.mlp(noise) 60 | else: 61 | out = self.mlp[mlp_idx:](noise) 62 | sizes, out = out.tensor_split((self.n_features_max + 1,), dim=1) 63 | bg_feat, out = out.tensor_split((self.feature_dim,), dim=1) 64 | if self.spatial_style: 65 | bg_style_feat, out = out.tensor_split((self.style_dim,), dim=1) 66 | out = rearrange(out, 'n (m d) -> n m d', m=self.n_features_max) 67 | if self.shuffle_features: 68 | idxs = torch.randperm(self.n_features_max)[:n_features] 69 | else: 70 | idxs = torch.arange(n_features) 71 | out = out[:, idxs] 72 | sizes = sizes[:, [0] + idxs.add(1).tolist()] 73 | if self.feature_dropout: 74 | keep = torch.rand((out.size(1),)) > self.feature_dropout 75 | if not keep.any(): 76 | keep[0] = True 77 | out = out[:, keep] 78 | sizes = sizes[:, [True] + keep.tolist()] 79 | xy = out[..., :2].sigmoid() # .mul(self.max_coord) 80 | ret = {'xs': xy[..., 0], 'ys': xy[..., 1], 'sizes': sizes[:, 
:n_features + 1], 'covs': out[..., 2:6]} 81 | end_dim = self.feature_dim + 6 82 | features = out[..., 6:end_dim] 83 | features = torch.cat((bg_feat[:, None], features), 1) 84 | ret['features'] = features 85 | # return [xy[..., 0], xy[..., 1], features, covs, sizes[:, :n_features + 1].softmax(-1)] 86 | if self.spatial_style: 87 | style_features = out[..., end_dim:] 88 | style_features = torch.cat((bg_style_feat[:, None], style_features), 1) 89 | ret['spatial_style'] = style_features 90 | # ret['covs'] = ret['covs'].detach() 91 | if self.norm_features: 92 | for k in ('features', 'spatial_style', 'shape_features'): 93 | if k in ret: 94 | ret[k] = pixel_norm(ret[k]) 95 | if self.p_swap_style: 96 | if random.random() <= self.p_swap_style: 97 | n = random.randint(0, ret['spatial_style'].size(1) - 1) 98 | shuffle = torch.randperm(ret['spatial_style'].size(1) - 1).add(1)[:n] 99 | ret['spatial_style'][:, shuffle] = derange_tensor(ret['spatial_style'][:, shuffle]) 100 | return ret 101 | -------------------------------------------------------------------------------- /src/models/networks/op/__init__.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/rosinality/stylegan2-pytorch/tree/3dee637b8937bf3830991c066ed8d9cc58afd661/op 2 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 3 | from .upfirdn2d import upfirdn2d 4 | -------------------------------------------------------------------------------- /src/models/networks/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from distutils.version import LooseVersion 3 | 4 | if LooseVersion(torch.__version__) >= LooseVersion('1.11.0'): 5 | # New conv refactoring started at version 1.11, it seems. 6 | from .conv2d_gradfix_111andon import conv2d, conv_transpose2d 7 | else: 8 | from .conv2d_gradfix_pre111 import conv2d, conv_transpose2d -------------------------------------------------------------------------------- /src/models/networks/op/conv2d_gradfix_111andon.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | # THANKS https://github.com/pytorch/pytorch/issues/74437 !!!!! 12 | import warnings 13 | import contextlib 14 | import torch 15 | from distutils.version import LooseVersion 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | # ---------------------------------------------------------------------------- 22 | 23 | enabled = True # Enable the custom op by setting this to true. 24 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
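# The no_weight_gradients() context manager below toggles this flag; an illustrative use (a sketch, not a call made in this file) is wrapping a double-backward whose weight gradients are not needed, e.g. an R1 penalty: #   with no_weight_gradients(): #       grad_real, = torch.autograd.grad(logits_real.sum(), batch_real, create_graph=True)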
25 | 26 | 27 | @contextlib.contextmanager 28 | def no_weight_gradients(): 29 | global weight_gradients_disabled 30 | old = weight_gradients_disabled 31 | weight_gradients_disabled = True 32 | yield 33 | weight_gradients_disabled = old 34 | 35 | 36 | # ---------------------------------------------------------------------------- 37 | 38 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 39 | if _should_use_custom_op(input): 40 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, 41 | output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 42 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, 43 | dilation=dilation, groups=groups) 44 | 45 | 46 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 47 | if _should_use_custom_op(input): 48 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, 49 | output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, 50 | bias) 51 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, 52 | output_padding=output_padding, groups=groups, dilation=dilation) 53 | 54 | 55 | # ---------------------------------------------------------------------------- 56 | 57 | def _should_use_custom_op(input): 58 | assert isinstance(input, torch.Tensor) 59 | if (not enabled) or (not torch.backends.cudnn.enabled): 60 | return False 61 | if input.device.type != 'cuda': 62 | return False 63 | if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): 64 | return True 65 | warnings.warn( 66 | f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 67 | return False 68 | 69 | 70 | def _tuple_of_ints(xs, ndim): 71 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 72 | assert len(xs) == ndim 73 | assert all(isinstance(x, int) for x in xs) 74 | return xs 75 | 76 | 77 | # ---------------------------------------------------------------------------- 78 | 79 | _conv2d_gradfix_cache = dict() 80 | 81 | 82 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 83 | # Parse arguments. 84 | ndim = 2 85 | weight_shape = tuple(weight_shape) 86 | stride = _tuple_of_ints(stride, ndim) 87 | padding = _tuple_of_ints(padding, ndim) 88 | output_padding = _tuple_of_ints(output_padding, ndim) 89 | dilation = _tuple_of_ints(dilation, ndim) 90 | 91 | # Lookup from cache. 92 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 93 | if key in _conv2d_gradfix_cache: 94 | return _conv2d_gradfix_cache[key] 95 | 96 | # Validate arguments. 97 | assert groups >= 1 98 | assert len(weight_shape) == ndim + 2 99 | assert all(stride[i] >= 1 for i in range(ndim)) 100 | assert all(padding[i] >= 0 for i in range(ndim)) 101 | assert all(dilation[i] >= 0 for i in range(ndim)) 102 | if not transpose: 103 | assert all(output_padding[i] == 0 for i in range(ndim)) 104 | else: # transpose 105 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 106 | 107 | # Helpers. 
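# calc_output_padding picks the output_padding that lets the transposed convolution used in the backward pass produce gradients with exactly the original input's spatial size.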
108 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 109 | 110 | def calc_output_padding(input_shape, output_shape): 111 | if transpose: 112 | return [0, 0] 113 | return [ 114 | input_shape[i + 2] 115 | - (output_shape[i + 2] - 1) * stride[i] 116 | - (1 - 2 * padding[i]) 117 | - dilation[i] * (weight_shape[i + 2] - 1) 118 | for i in range(ndim) 119 | ] 120 | 121 | # Forward & backward. 122 | class Conv2d(torch.autograd.Function): 123 | @staticmethod 124 | def forward(ctx, input, weight, bias): 125 | assert weight.shape == weight_shape 126 | if not transpose: 127 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 128 | else: # transpose 129 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, 130 | output_padding=output_padding, **common_kwargs) 131 | ctx.save_for_backward(input, weight, bias) 132 | return output 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | input, weight, bias = ctx.saved_tensors 137 | grad_input = None 138 | grad_weight = None 139 | grad_bias = None 140 | 141 | if ctx.needs_input_grad[0]: 142 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 143 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, 144 | **common_kwargs).apply(grad_output, weight, None) 145 | assert grad_input.shape == input.shape 146 | 147 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 148 | grad_weight = Conv2dGradWeight.apply(grad_output, input, bias) 149 | assert grad_weight.shape == weight_shape 150 | 151 | if ctx.needs_input_grad[2]: 152 | grad_bias = grad_output.sum([0, 2, 3]) 153 | 154 | return grad_input, grad_weight, grad_bias 155 | 156 | # Gradient with respect to the weights. 
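# Conv2dGradWeight requests only the weight gradient from aten::convolution_backward (output_mask=[0, 1, 0]) and defines its own backward so that second-order gradients (e.g. for R1 or path-length regularization) keep flowing through the convolution.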
157 | class Conv2dGradWeight(torch.autograd.Function): 158 | @staticmethod 159 | def forward(ctx, grad_output, input, bias): 160 | bias_shape = bias.shape if (bias is not None) else None 161 | # empty_weight = torch.empty(weight_shape, dtype=input.dtype, layout=input.layout, device=input.device) 162 | empty_weight = torch.tensor(0.0, dtype=input.dtype, device=input.device).expand(weight_shape) 163 | grad_weight = \ 164 | torch.ops.aten.convolution_backward(grad_output, input, empty_weight, bias_sizes=bias_shape, 165 | stride=stride, 166 | padding=padding, dilation=dilation, transposed=transpose, 167 | output_padding=output_padding, groups=groups, 168 | output_mask=[0, 1, 0])[1] 169 | assert grad_weight.shape == weight_shape 170 | ctx.save_for_backward(grad_output, input) 171 | return grad_weight 172 | 173 | @staticmethod 174 | def backward(ctx, grad2_grad_weight): 175 | grad_output, input = ctx.saved_tensors 176 | grad2_grad_output = None 177 | grad2_input = None 178 | 179 | if ctx.needs_input_grad[0]: 180 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 181 | assert grad2_grad_output.shape == grad_output.shape 182 | 183 | if ctx.needs_input_grad[1]: 184 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 185 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, 186 | **common_kwargs).apply(grad_output, grad2_grad_weight, None) 187 | assert grad2_input.shape == input.shape 188 | 189 | return grad2_grad_output, grad2_input, None 190 | 191 | _conv2d_gradfix_cache[key] = Conv2d 192 | return Conv2d 193 | 194 | # ---------------------------------------------------------------------------- 195 | -------------------------------------------------------------------------------- /src/models/networks/op/conv2d_gradfix_pre111.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | from utils import is_rank_zero 9 | 10 | enabled = True 11 | weight_gradients_disabled = False 12 | 13 | 14 | @contextlib.contextmanager 15 | def no_weight_gradients(): 16 | global weight_gradients_disabled 17 | 18 | old = weight_gradients_disabled 19 | weight_gradients_disabled = True 20 | yield 21 | weight_gradients_disabled = old 22 | 23 | 24 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 25 | if could_use_op(input): 26 | return conv2d_gradfix( 27 | transpose=False, 28 | weight_shape=weight.shape, 29 | stride=stride, 30 | padding=padding, 31 | output_padding=0, 32 | dilation=dilation, 33 | groups=groups, 34 | ).apply(input, weight, bias) 35 | 36 | return F.conv2d( 37 | input=input, 38 | weight=weight, 39 | bias=bias, 40 | stride=stride, 41 | padding=padding, 42 | dilation=dilation, 43 | groups=groups, 44 | ) 45 | 46 | 47 | def conv_transpose2d( 48 | input, 49 | weight, 50 | bias=None, 51 | stride=1, 52 | padding=0, 53 | output_padding=0, 54 | groups=1, 55 | dilation=1, 56 | ): 57 | if could_use_op(input): 58 | return conv2d_gradfix( 59 | transpose=True, 60 | weight_shape=weight.shape, 61 | stride=stride, 62 | padding=padding, 63 | output_padding=output_padding, 64 | groups=groups, 65 | dilation=dilation, 66 | ).apply(input, weight, bias) 67 | 68 | return F.conv_transpose2d( 69 | input=input, 70 | weight=weight, 71 | bias=bias, 72 | stride=stride, 73 | padding=padding, 74 | output_padding=output_padding, 75 | 
dilation=dilation, 76 | groups=groups, 77 | ) 78 | 79 | 80 | def could_use_op(input): 81 | if (not enabled) or (not torch.backends.cudnn.enabled) or input.device.type != "cuda": 82 | if is_rank_zero(): 83 | warnings.warn("CUDNN disabled, no GPUs, or custom ops otherwise not enabled, so not being used.") 84 | return False 85 | 86 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9.", "1.10.", "1.11"]): 87 | if is_rank_zero(): 88 | warnings.warn("Using custom ops") 89 | return True 90 | 91 | if is_rank_zero(): 92 | warnings.warn( 93 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 94 | ) 95 | 96 | return False 97 | 98 | 99 | def ensure_tuple(xs, ndim): 100 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 101 | 102 | return xs 103 | 104 | 105 | conv2d_gradfix_cache = dict() 106 | 107 | 108 | def conv2d_gradfix( 109 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 110 | ): 111 | ndim = 2 112 | weight_shape = tuple(weight_shape) 113 | stride = ensure_tuple(stride, ndim) 114 | padding = ensure_tuple(padding, ndim) 115 | output_padding = ensure_tuple(output_padding, ndim) 116 | dilation = ensure_tuple(dilation, ndim) 117 | 118 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 119 | if key in conv2d_gradfix_cache: 120 | return conv2d_gradfix_cache[key] 121 | 122 | common_kwargs = dict( 123 | stride=stride, padding=padding, dilation=dilation, groups=groups 124 | ) 125 | 126 | def calc_output_padding(input_shape, output_shape): 127 | if transpose: 128 | return [0, 0] 129 | 130 | return [ 131 | input_shape[i + 2] 132 | - (output_shape[i + 2] - 1) * stride[i] 133 | - (1 - 2 * padding[i]) 134 | - dilation[i] * (weight_shape[i + 2] - 1) 135 | for i in range(ndim) 136 | ] 137 | 138 | class Conv2d(autograd.Function): 139 | @staticmethod 140 | def forward(ctx, input, weight, bias): 141 | if not transpose: 142 | out = F.conv2d(input=input, weight=weight.to(input.dtype), 143 | bias=bias.to(input.dtype) if bias is not None else bias, 144 | **common_kwargs) 145 | 146 | else: 147 | out = F.conv_transpose2d( 148 | input=input, 149 | weight=weight.to(input.dtype), 150 | bias=bias.to(input.dtype) if bias else bias, 151 | output_padding=output_padding, 152 | **common_kwargs, 153 | ) 154 | 155 | ctx.save_for_backward(input, weight) 156 | 157 | return out 158 | 159 | @staticmethod 160 | def backward(ctx, grad_output): 161 | input, weight = ctx.saved_tensors 162 | grad_input, grad_weight, grad_bias = None, None, None 163 | 164 | if ctx.needs_input_grad[0]: 165 | p = calc_output_padding( 166 | input_shape=input.shape, output_shape=grad_output.shape 167 | ) 168 | grad_input = conv2d_gradfix( 169 | transpose=(not transpose), 170 | weight_shape=weight_shape, 171 | output_padding=p, 172 | **common_kwargs, 173 | ).apply(grad_output, weight, None) 174 | 175 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 176 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 177 | 178 | if ctx.needs_input_grad[2]: 179 | grad_bias = grad_output.sum((0, 2, 3)) 180 | 181 | return grad_input, grad_weight, grad_bias 182 | 183 | class Conv2dGradWeight(autograd.Function): 184 | @staticmethod 185 | def forward(ctx, grad_output, input): 186 | op = torch._C._jit_get_operation( 187 | "aten::cudnn_convolution_backward_weight" 188 | if not transpose 189 | else "aten::cudnn_convolution_transpose_backward_weight" 190 | ) 191 | flags = [ 192 | 
torch.backends.cudnn.benchmark, 193 | torch.backends.cudnn.deterministic, 194 | torch.backends.cudnn.allow_tf32, 195 | ] 196 | grad_weight = op( 197 | weight_shape, 198 | grad_output, 199 | input.to(grad_output.dtype), 200 | padding, 201 | stride, 202 | dilation, 203 | groups, 204 | *flags, 205 | ) 206 | ctx.save_for_backward(grad_output, input) 207 | 208 | return grad_weight 209 | 210 | @staticmethod 211 | def backward(ctx, grad_grad_weight): 212 | grad_output, input = ctx.saved_tensors 213 | grad_grad_output, grad_grad_input = None, None 214 | 215 | if ctx.needs_input_grad[0]: 216 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 217 | 218 | if ctx.needs_input_grad[1]: 219 | p = calc_output_padding( 220 | input_shape=input.shape, output_shape=grad_output.shape 221 | ) 222 | grad_grad_input = conv2d_gradfix( 223 | transpose=(not transpose), 224 | weight_shape=weight_shape, 225 | output_padding=p, 226 | **common_kwargs, 227 | ).apply(grad_output, grad_grad_weight, None) 228 | 229 | return grad_grad_output, grad_grad_input 230 | 231 | conv2d_gradfix_cache[key] = Conv2d 232 | 233 | return Conv2d 234 | -------------------------------------------------------------------------------- /src/models/networks/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input.float(), bias, empty.float(), 3, 0, negative_slope, scale).to(input.dtype) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def 
__init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /src/models/networks/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /src/models/networks/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ?
x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /src/models/networks/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /src/models/networks/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output.float(), 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 |
up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ).to(grad_output.dtype) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input.float(), 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input.float(), kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ).to(input.dtype) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | 
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | import torch 7 | from omegaconf import OmegaConf, DictConfig 8 | from pytorch_lightning import seed_everything 9 | 10 | import data 11 | import models 12 | import utils 13 | from utils import scale_logging_rates, print_once, Checkpoint 14 | 15 | 16 | @hydra.main(config_path="configs", config_name="fit") 17 | def run(config: DictConfig): 18 | torch.backends.cudnn.deterministic = config.trainer.deterministic 19 | torch.backends.cudnn.benchmark = config.trainer.benchmark 20 | torch.use_deterministic_algorithms(config.trainer.deterministic) 21 | if config.trainer.deterministic: 22 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 23 | 24 | print_once(OmegaConf.to_yaml(config, resolve=True)) 25 | 26 | seed_everything(config.seed, workers=True) 27 | 28 | scale_logging_rates(config, 1 / config.trainer.get('accumulate_grad_batches', 1)) 29 | 30 | if config.get('detect_anomalies', False): 31 | print_once('Anomaly detection mode ACTIVATED') 32 | torch.autograd.set_detect_anomaly(True) 33 | 34 | config.resume.id = utils.resolve_resume_id(**config.resume) 35 | 36 | if config.logger: 37 | logger = utils.Logger(**config[config.logger]) 38 | logger.log_config(config) 39 | logger.log_code() 40 | else: 41 | logger = False 42 | 43 | datamodule = data.get_datamodule(**config.dataset) 44 | 45 | model, model_cfg = models.get_model(**config.model, return_cfg=True) 46 | 47 | if config.resume.id is not None: 48 | checkpoint = utils.get_checkpoint_path(**config.resume) 49 | if config.mode != 'fit' or config.resume.model_only: 50 | # Automatically load model weights in validate/test mode as opposed to using built-in PL argument to 51 | # validate or test methods since need custom logic e.g. 
to remove non-dataclass args 52 | model = model.load_from_checkpoint(checkpoint, **(model_cfg if config.resume.clobber_hparams else {})) 53 | else: 54 | checkpoint = None 55 | 56 | if logger: 57 | if os.environ.get("EXP_LOG_DIR", None) is None: 58 | # Needed because in distributed training, the logger is not properly initialized on clone processes 59 | # If dirname for the checkpointer is not the same on all processes, training hangs 60 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/5319 61 | os.environ["EXP_LOG_DIR"] = logger.experiment.dir 62 | 63 | callbacks = [] 64 | checkpoint_callback = 'checkpoint' in config and config.checkpoint is not None 65 | 66 | if logger and checkpoint_callback: 67 | checkpoint_cb = Checkpoint(**config.checkpoint, 68 | dirpath=Path(os.environ["EXP_LOG_DIR"]) / 'checkpoints') 69 | checkpoint_cb.CHECKPOINT_NAME_LAST = checkpoint_cb.CHECKPOINT_JOIN_CHAR.join(["{epoch}", "{step}", "last"]) 70 | callbacks.append(checkpoint_cb) 71 | 72 | trainer = pl.Trainer( 73 | resume_from_checkpoint=None if config.resume.model_only else checkpoint, 74 | logger=logger, 75 | callbacks=callbacks, 76 | checkpoint_callback=checkpoint_callback, 77 | **config.trainer 78 | ) 79 | 80 | if config.mode == 'fit': 81 | trainer.fit(model, datamodule=datamodule) 82 | elif config.mode == 'validate': 83 | trainer.validate(model, datamodule=datamodule) 84 | elif config.mode == 'test': 85 | trainer.test(model, datamodule=datamodule) 86 | 87 | 88 | if __name__ == "__main__": 89 | run() 90 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | from .training import * 3 | from .io import * 4 | from .colab import * 5 | from .distributed import * 6 | from .wandb_logger import * 7 | from .logging import * 8 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # From Bill Peebles, thanks Bill!
2 | # https://raw.githubusercontent.com/wpeebles/gangealing/739da2a25de62702d54d83fad6b644646512039c/utils/distributed.py 3 | import torch 4 | from torch import distributed as dist 5 | import os 6 | 7 | 8 | def is_rank_zero(): 9 | return get_rank() == 0 10 | 11 | 12 | def print_once(s): 13 | if is_rank_zero(): 14 | print(s) 15 | 16 | 17 | def get_rank(): 18 | return int(os.environ.get('LOCAL_RANK', 0)) 19 | 20 | 21 | def get_rank_colab(): 22 | if not dist.is_available(): 23 | return 0 24 | 25 | if not dist.is_initialized(): 26 | return 0 27 | 28 | return dist.get_rank() 29 | 30 | 31 | def primary(): 32 | if not dist.is_available(): 33 | return True 34 | 35 | if not dist.is_initialized(): 36 | return True 37 | 38 | return get_rank_colab() == 0 39 | 40 | 41 | def synchronize(): 42 | if not dist.is_available(): 43 | return 44 | 45 | if not dist.is_initialized(): 46 | return 47 | 48 | world_size = dist.get_world_size() 49 | 50 | if world_size == 1: 51 | return 52 | 53 | dist.barrier() 54 | 55 | 56 | def get_world_size(): 57 | if not dist.is_available(): 58 | return 1 59 | 60 | if not dist.is_initialized(): 61 | return 1 62 | 63 | return dist.get_world_size() 64 | 65 | 66 | def reduce_sum(tensor): 67 | if not dist.is_available(): 68 | return tensor 69 | 70 | if not dist.is_initialized(): 71 | return tensor 72 | 73 | tensor = tensor.clone() 74 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 75 | 76 | return tensor 77 | 78 | 79 | def gather_grad(params): 80 | world_size = get_world_size() 81 | 82 | if world_size == 1: 83 | return 84 | 85 | for param in params: 86 | if param.grad is not None: 87 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 88 | param.grad.data.div_(world_size) 89 | 90 | 91 | def all_gather(input, cat=True): 92 | if get_world_size() == 1: 93 | if cat: 94 | return input 95 | else: 96 | return input.unsqueeze(0) 97 | input_list = [torch.zeros_like(input) for _ in range(get_world_size())] 98 | synchronize() 99 | torch.distributed.all_gather(input_list, input, async_op=False) 100 | if cat: 101 | inputs = torch.cat(input_list, dim=0) 102 | else: 103 | inputs = torch.stack(input_list, dim=0) 104 | return inputs 105 | 106 | 107 | def all_gatherv(input, return_boundaries=False): 108 | """Variable-sized all_gather""" 109 | 110 | # Broadcast the number of elements in every process: 111 | num_elements = torch.tensor(input.size(0), device=input.device) 112 | num_elements_per_process = all_gather(num_elements, cat=False) 113 | max_elements = num_elements_per_process.max() 114 | # Add padding so every input is the same size: 115 | difference = max_elements - input.size(0) 116 | if difference > 0: 117 | input = torch.cat([input, torch.zeros(difference, *input.size()[1:], device=input.device, dtype=input.dtype)], 118 | 0) 119 | inputs = all_gather(input, cat=False) 120 | # Remove padding: 121 | inputs = torch.cat([row[:num_ele] for row, num_ele in zip(inputs, num_elements_per_process)], 0) 122 | if return_boundaries: 123 | boundaries = torch.cumsum(num_elements_per_process, dim=0) 124 | boundaries = torch.cat([torch.zeros(1, device=input.device, dtype=torch.int), boundaries], 0) 125 | return inputs, boundaries.long() 126 | else: 127 | return inputs 128 | 129 | 130 | def all_reduce(input, device): 131 | num_local = torch.tensor([input.size(0)], dtype=torch.float, device=device) 132 | input = input.sum(dim=0, keepdim=True).to(device) 133 | num_global = all_gather(num_local).sum() 134 | input = all_gather(input) 135 | input = input.sum(dim=0).div(num_global) 136 | return input 137 | 138 | 
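# Illustrative usage sketch for all_gatherv (hypothetical two-rank run and shapes, for
# reading convenience only):
#
#     local = torch.randn(2 if get_rank() == 0 else 3, 4, device='cuda')   # ranks hold 2 and 3 rows
#     rows, bounds = all_gatherv(local, return_boundaries=True)
#     # rows.shape == (5, 4); bounds == tensor([0, 2, 5]) marks each rank's slice.
#     # all_gatherv pads every rank's tensor to the largest per-rank length, gathers,
#     # then strips the padding, so ranks may contribute different numbers of rows.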
139 | def rank0_to_all(input): 140 | input = all_gather(input) 141 | rank0_input = input[0] 142 | return rank0_input 143 | 144 | 145 | def reduce_loss_dict(loss_dict): 146 | world_size = get_world_size() 147 | 148 | if world_size < 2: 149 | return loss_dict 150 | 151 | with torch.no_grad(): 152 | keys = [] 153 | losses = [] 154 | 155 | for k in sorted(loss_dict.keys()): 156 | keys.append(k) 157 | losses.append(loss_dict[k]) 158 | 159 | losses = torch.stack(losses, 0) 160 | dist.reduce(losses, dst=0) 161 | 162 | if dist.get_rank() == 0: 163 | losses /= world_size 164 | 165 | reduced_losses = {k: v for k, v in zip(keys, losses)} 166 | 167 | return reduced_losses 168 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union, Optional, Tuple, Dict, List 2 | import torch 3 | from numbers import Number 4 | 5 | from omegaconf import DictConfig 6 | from torch import Tensor 7 | from .distributed import print_once 8 | 9 | 10 | def scalars_to_log_dict(scalars: Dict[Any, Union[Number, Tensor]], mode: str) -> Dict[str, Number]: 11 | return {f'{mode}/{k}': (v.item() if isinstance(v, Tensor) else v) for k, v in scalars.items()} 12 | 13 | 14 | def epoch_outputs_to_log_dict(outputs: List[Dict[str, Tensor]], 15 | n_max: Optional[Union[int, str]] = None, 16 | shuffle: bool = False, 17 | reduce: Optional[str] = None) -> Dict[str, Tensor]: 18 | # Converts list of dicts (per-batch return values) into dict of concatenated list element dict values 19 | # Optionally return a tensor of length at most n_max for each key, and shuffle 20 | # If n_max is "batch", return one batch worth of tensors 21 | # Either cat or stack, depending on whether batch output is 0-d tensor (scalar) or not 22 | def merge_fn(v): 23 | return (torch.cat if len(v.shape) else torch.stack) if torch.is_tensor(v) else Tensor 24 | 25 | reduce_fn = lambda x: x 26 | if reduce is not None: 27 | if reduce == 'mean': 28 | reduce_fn = torch.mean 29 | elif reduce == 'sum': 30 | reduce_fn = torch.sum 31 | else: 32 | raise ValueError('reduce must be either `mean` or `sum`') 33 | out_dict = {k: reduce_fn(merge_fn(v)([o[k] for o in outputs])) for k, v in outputs[0].items() if v is not None} 34 | if n_max is not None: 35 | for k, v in out_dict.items(): 36 | if shuffle: 37 | v = v[torch.randperm(len(v))] 38 | n_max_ = len(outputs[0][k]) if n_max == "batch" else n_max 39 | out_dict[k] = v[:n_max_] 40 | return out_dict 41 | 42 | 43 | def scale_logging_rates(d: DictConfig, c: Number, strs: Tuple[str] = ('log', 'every_n_steps'), prefix: str = 'config'): 44 | if c == 1: 45 | return 46 | for k, v in d.items(): 47 | if all([s in k for s in strs]): 48 | d[k] = type(v)(v * c) 49 | print_once(f'Scaling {prefix}.{k} from {v} to {type(v)(v * c)} due to gradient accumulation') 50 | elif isinstance(v, DictConfig): 51 | scale_logging_rates(v, c, strs, prefix=prefix + '.' 
+ k) 52 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | from dataclasses import is_dataclass, fields 4 | from math import pi 5 | from typing import Any, Union, TypeVar, Tuple, Optional, Dict, OrderedDict 6 | 7 | import einops 8 | import numpy as np 9 | import torch 10 | from PIL import Image, ImageDraw 11 | from omegaconf import DictConfig 12 | from torch import Tensor 13 | from torch.nn import functional as F 14 | from torchvision.transforms.functional import to_tensor 15 | 16 | from utils.io import load_pretrained_weights 17 | from .distributed import is_rank_zero 18 | 19 | T = TypeVar('T') 20 | FromConfig = Union[T, Dict[str, Any]] 21 | NTuple = Tuple[T, ...] 22 | StateDict = OrderedDict[str, torch.Tensor] 23 | 24 | TORCH_EINSUM = True 25 | einsum = torch.einsum if TORCH_EINSUM else oe.contract 26 | 27 | 28 | def recursive_compare(d1: dict, d2: dict, level: str = 'root') -> str: 29 | ret = [] 30 | if isinstance(d1, dict) and isinstance(d2, dict): 31 | if d1.keys() != d2.keys(): 32 | s1 = set(d1.keys()) 33 | s2 = set(d2.keys()) 34 | ret.append('{:<20} - {} + {}'.format(level, ','.join(s1 - s2), ','.join(s2 - s1))) 35 | common_keys = s1 & s2 36 | else: 37 | common_keys = set(d1.keys()) 38 | 39 | for k in common_keys: 40 | ret.append(recursive_compare(d1[k], d2[k], level='{}.{}'.format(level, k))) 41 | elif isinstance(d1, list) and isinstance(d2, list): 42 | if len(d1) != len(d2): 43 | ret.append('{:<20} len1={}; len2={}'.format(level, len(d1), len(d2))) 44 | common_len = min(len(d1), len(d2)) 45 | 46 | for i in range(common_len): 47 | ret.append(recursive_compare(d1[i], d2[i], level='{}[{}]'.format(level, i))) 48 | else: 49 | if d1 != d2: 50 | ret.append('{:<20} {} -> {}'.format(level, d1, d2)) 51 | return '\n'.join(filter(None, ret)) 52 | 53 | 54 | def import_external(name: str, pretrained: Optional[Union[str, DictConfig]] = None, **kwargs): 55 | module, name = name.rsplit('.', 1) 56 | ret = getattr(importlib.import_module(module), name) 57 | ret = ret(**to_dataclass_cfg(kwargs, ret)) 58 | return load_pretrained_weights(name, pretrained, ret) 59 | 60 | 61 | def run_at_step(step: int, freq: int): 62 | return (freq > 0) and ((step + 1) % freq == 0) 63 | 64 | 65 | def rotation_matrix(theta): 66 | cos = torch.cos(theta) 67 | sin = torch.sin(theta) 68 | return torch.stack([cos, sin, -sin, cos], dim=-1).view(*theta.shape, 2, 2) 69 | 70 | 71 | def gaussian(window_size, sigma): 72 | def gauss_fcn(x): 73 | return -(x - window_size // 2) ** 2 / float(2 * sigma ** 2) 74 | 75 | gauss = torch.stack( 76 | [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)]) 77 | return gauss / gauss.sum() 78 | 79 | 80 | def splat_features_from_scores(scores: Tensor, features: Tensor, size: Optional[int], 81 | channels_last: bool = True) -> Tensor: 82 | """ 83 | 84 | Args: 85 | channels_last: expect input with M at end or not, see below 86 | scores: [N, H, W, M] (or [N, M, H, W] if not channels last) 87 | features: [N, M, C] 88 | size: dimension of map to return 89 | Returns: [N, C, H, W] 90 | 91 | """ 92 | if size and not (scores.shape[2] == size): 93 | if channels_last: 94 | scores = einops.rearrange(scores, 'n h w m -> n m h w') 95 | scores = F.interpolate(scores, size, mode='bilinear', align_corners=False) 96 | einstr = 'nmhw,nmc->nchw' 97 | else: 98 | einstr = 'nhwm,nmc->nchw' if channels_last else 'nmhw,nmc->nchw' 99 | 
return einsum(einstr, scores, features).contiguous() 100 | 101 | 102 | def cholesky_to_matrix(covs: Tensor) -> Tensor: 103 | covs[..., ::3] = covs[:, :, ::3].exp() 104 | covs[..., 2] = 0 105 | covs = einops.rearrange(covs, 'n m (x y) -> n m x y', x=2, y=2) 106 | covs = einsum('nmji,nmjk->nmik', covs, covs) # [n, m, 2, 2] 107 | return covs 108 | 109 | 110 | def jitter_image_batch(images: Tensor, dy: int, dx: int) -> Tensor: 111 | # images: N x C x H x W 112 | images = torch.roll(images, (dy, dx), (2, 3)) 113 | if dy > 0: 114 | images[:, :, :dy, :] = 0 115 | else: 116 | images[:, :, dy:, :] = 0 117 | if dx > 0: 118 | images[:, :, :, :dx] = 0 119 | else: 120 | images[:, :, :, dx:] = 0 121 | return images 122 | 123 | 124 | DERANGEMENT_WARNED = False 125 | 126 | 127 | def derangement(n: int) -> Tensor: 128 | global DERANGEMENT_WARNED 129 | orig = torch.arange(n) 130 | shuffle = torch.randperm(n) 131 | if n == 1 and not DERANGEMENT_WARNED: 132 | if is_rank_zero(): 133 | print('Warning: called derangement with n=1!') 134 | DERANGEMENT_WARNED = True 135 | while (n > 1) and (shuffle == orig).any(): 136 | shuffle = torch.randperm(n) 137 | return shuffle 138 | 139 | 140 | def pyramid_resize(img, cutoff): 141 | """ 142 | 143 | Args: 144 | img: [N x C x H x W] 145 | cutoff: threshold at which to stop pyramid 146 | 147 | Returns: gaussian pyramid 148 | 149 | """ 150 | out = [img] 151 | while img.shape[-1] > cutoff: 152 | img = F.interpolate(img, img.shape[-1] // 2, mode='bilinear', align_corners=False) 153 | out.append(img) 154 | return {i.size(-1): i for i in out} 155 | 156 | 157 | def derange_tensor(x: Tensor, dim: int = 0) -> Tensor: 158 | if dim == 0: 159 | return x[derangement(len(x))] 160 | elif dim == 1: 161 | return x[:, derangement(len(x[0]))] 162 | 163 | 164 | def derange_tensor_n_times(x: Tensor, n: int, dim: int = 0, stack_dim: int = 0) -> Tensor: 165 | return torch.stack([derange_tensor(x, dim) for _ in range(n)], stack_dim) 166 | 167 | 168 | def to_dataclass_cfg(cfg, cls): 169 | """ 170 | Can't add **kwargs catch-all to dataclass, so need to strip dict of keys that are not fields 171 | """ 172 | if is_dataclass(cls): 173 | return {k: v for k, v in cfg.items() if k in [f.name for f in fields(cls)]} 174 | return cfg 175 | 176 | 177 | def random_polygons(size: int, shape: Union[int, Tuple[int, ...]]): 178 | if type(shape) is int: 179 | shape = (shape,) 180 | n = np.prod(shape) 181 | return torch.stack([random_polygon(size) for _ in range(n)]).view(*shape, 1, size, size) 182 | 183 | 184 | def random_polygon(size: int): 185 | # Logic from Copy Paste GAN 186 | img = Image.new("RGB", (size, size), "black") 187 | f = lambda s: round(size * s) 188 | to_xy = lambda r, θ, p: (p + r * np.array([np.cos(θ), np.sin(θ)])) 189 | c = np.array([random.randint(f(0.1), f(0.9)), random.randint(f(0.1), f(0.9))]) 190 | n_vert = random.randint(4, 6) 191 | coords = [] 192 | while len(coords) < n_vert: 193 | coord = to_xy(random.uniform(f(0.1), f(0.5)), random.uniform(0, 2 * pi), c) 194 | if coord.min() >= 0: 195 | coords.append(tuple(coord)) 196 | ImageDraw.Draw(img).polygon(coords, fill="white") 197 | return to_tensor(img)[:1] 198 | -------------------------------------------------------------------------------- /src/utils/training.py: -------------------------------------------------------------------------------- 1 | # Training utilities 2 | import math 3 | import random 4 | from itertools import groupby 5 | from numbers import Number 6 | from typing import Dict, List 7 | 8 | import torch 9 | from torch import 
nn, Tensor, autograd 10 | 11 | from .distributed import is_rank_zero 12 | 13 | 14 | def make_noise(batch, latent_dim, n_noise): 15 | if n_noise == 1: 16 | return torch.randn(len(batch), latent_dim).type_as(batch) 17 | return torch.randn(n_noise, len(batch), latent_dim).type_as(batch).unbind(0) 18 | 19 | 20 | def mixing_noise(batch, latent_dim, prob): 21 | if prob > 0 and random.random() < prob: 22 | return make_noise(batch, latent_dim, 2) 23 | else: 24 | return [make_noise(batch, latent_dim, 1)] 25 | 26 | 27 | ACCUM_WARN = False 28 | 29 | 30 | def accumulate(model1, model2, decay=0.999): 31 | global ACCUM_WARN 32 | par1 = dict(model1.named_parameters()) 33 | par2 = dict(model2.named_parameters()) 34 | if len(par1.keys() & par2.keys()) == 0: 35 | if is_rank_zero() and not ACCUM_WARN: 36 | print('Cannot accumulate, likely due to FSDP parameter flattening. Skipping.') 37 | ACCUM_WARN = True 38 | return 39 | device = next(model1.parameters()).device 40 | for k in par1.keys(): 41 | par1[k].data.mul_(decay).add_(par2[k].data.to(device), alpha=1 - decay) 42 | 43 | 44 | 45 | 46 | def freeze(model: nn.Module, layers: List[str] = None): 47 | frozen = [] 48 | for name, param in model.named_parameters(): 49 | if layers is None or any(name.startswith(l) for l in layers): 50 | param.requires_grad = False 51 | frozen.append(name) 52 | if is_rank_zero(): 53 | depth_two_params = [k for k, _ in groupby( 54 | ['.'.join(n.split('.')[:2]).replace('.weight', '').replace('.bias', '') for n in frozen])] 55 | print(f'Froze {len(frozen)} parameters - {depth_two_params} - for model of type {model.__class__.__name__}') 56 | 57 | 58 | def requires_grad(model: nn.Module, requires: bool): 59 | for param in model.parameters(): 60 | param.requires_grad = requires 61 | 62 | 63 | def unfreeze(model: nn.Module): 64 | for param in model.parameters(): 65 | param.requires_grad = True 66 | 67 | 68 | def fill_module_uniform(module, range, blacklist=None): 69 | if blacklist is None: blacklist = [] 70 | for n, p in module.named_parameters(): 71 | if not any([b in n for b in blacklist]): 72 | nn.init.uniform_(p, -range, range) 73 | 74 | 75 | def zero_module(module): 76 | for p in module.parameters(): 77 | p.detach().zero_() 78 | return module 79 | 80 | 81 | def get_D_stats(key: str, scores: Tensor, gt: bool) -> Dict[str, Number]: 82 | acc = 100 * (scores > 0).sum() / len(scores) 83 | if not gt: 84 | acc = 100 - acc 85 | return { 86 | f'score_{key}': scores.mean(), 87 | f'acc_{key}': acc 88 | } 89 | 90 | 91 | # Losses adapted from https://github.com/rosinality/stylegan2-pytorch/blob/master/train.py 92 | def D_R1_loss(real_pred, real_img): 93 | grad_real, = autograd.grad( 94 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 95 | ) 96 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 97 | return grad_penalty 98 | 99 | 100 | def G_path_loss(fake_img, latents, mean_path_length, decay=0.01): 101 | noise = torch.randn_like(fake_img) / math.sqrt( 102 | fake_img.shape[2] * fake_img.shape[3] 103 | ) 104 | grad, = autograd.grad( 105 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 106 | ) 107 | 108 | if grad.ndim == 3: # [N_batch x N_latent x D] 109 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 110 | elif grad.ndim == 2: # [N_batch x D] 111 | path_lengths = torch.sqrt(grad.pow(2).sum(1)) 112 | 113 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 114 | 115 | path_penalty = (path_lengths - path_mean).pow(2).mean() 116 | 117 | return 
path_penalty, path_mean.detach(), path_lengths.mean() 118 | -------------------------------------------------------------------------------- /src/utils/wandb_logger.py: -------------------------------------------------------------------------------- 1 | # From Tim 2 | from __future__ import annotations 3 | 4 | __all__ = ["Logger"] 5 | 6 | import os 7 | import re 8 | import subprocess 9 | import sys 10 | from ast import literal_eval 11 | from math import sqrt, ceil 12 | from pathlib import Path 13 | from typing import Optional, Union 14 | 15 | import omegaconf 16 | import torch 17 | import torchvision.utils as utils 18 | import wandb 19 | from pytorch_lightning import LightningModule 20 | from pytorch_lightning.core.memory import ModelSummary 21 | from pytorch_lightning.loggers import WandbLogger 22 | from pytorch_lightning.utilities import rank_zero_only 23 | from torch import Tensor 24 | from wandb.sdk.lib.config_util import ConfigError 25 | 26 | from .distributed import is_rank_zero 27 | from .io import yes_or_no 28 | from .misc import recursive_compare 29 | 30 | 31 | class Logger(WandbLogger): 32 | def __init__( 33 | self, 34 | *, 35 | name: str, 36 | project: str, 37 | entity: str, 38 | group: Optional[str] = None, 39 | offline: bool = False, 40 | log_dir: Optional[str] = './logs', 41 | **kwargs 42 | ): 43 | log_dir = str(Path(log_dir).absolute()) 44 | 45 | super().__init__( 46 | name=name, 47 | save_dir=log_dir, 48 | offline=offline, 49 | project=project, 50 | log_model=False, 51 | entity=entity, 52 | group=group, 53 | **kwargs 54 | ) 55 | 56 | def log_hyperparams(self, *args, **kwargs): 57 | pass 58 | 59 | def _file_exists(self, path: str) -> bool: 60 | try: 61 | self.experiment.restore(path) 62 | return True 63 | except ValueError: 64 | return False 65 | 66 | def _get_unique_fn(self, filename: str, sep: str = '_') -> str: 67 | orig_filename, ext = os.path.splitext(filename) 68 | cfg_ctr = 0 69 | while self._file_exists(filename): 70 | cfg_ctr += 1 71 | filename = f"{orig_filename}{sep}{cfg_ctr}{ext}" 72 | return filename 73 | 74 | @rank_zero_only 75 | def save_to_file(self, filename: str, contents: Union[str, bytes], unique_filename: bool = True) -> str: 76 | if not is_rank_zero(): 77 | return 78 | if unique_filename: 79 | filename = self._get_unique_fn(filename) 80 | self.experiment.save(filename) 81 | t = type(contents) 82 | if t is str: 83 | mode = 'w' 84 | elif t is bytes: 85 | mode = 'wb' 86 | else: 87 | raise TypeError('Can only save str or bytes') 88 | (Path(self.experiment.dir) / filename).open(mode).write(contents) 89 | return filename 90 | 91 | @rank_zero_only 92 | def log_config(self, config: omegaconf.DictConfig): 93 | if not is_rank_zero(): 94 | return 95 | filename = self.save_to_file("hydra_config.yaml", omegaconf.OmegaConf.to_yaml(config)) 96 | params = omegaconf.OmegaConf.to_container(config) 97 | assert isinstance(params, dict) 98 | params.pop("wandb", None) 99 | 100 | try: 101 | self.experiment.config.update(params) 102 | except ConfigError as e: 103 | # Config has changed, so confirm with user that this is okay before proceeding 104 | msg = e.message.split("\n")[0] 105 | 106 | def try_literal_eval(x): 107 | try: 108 | return literal_eval(x) 109 | except ValueError: 110 | return x 111 | 112 | key, old, new = map(try_literal_eval, re.search("key (.*) from (.*) to (.*)", msg).groups()) 113 | print(f'Caution! 
Parameters have changed!') 114 | if not (type(old) == type(new) == dict): 115 | old = {key: old} 116 | new = {key: new} 117 | print(recursive_compare(old, new, level=key)) 118 | if yes_or_no('Was this intended?', default=True, timeout=10): 119 | print(f'Saving new parameters to {filename} and updating W and B config.') 120 | self.experiment.config.update(params, allow_val_change=True) 121 | else: 122 | sys.exit(1) 123 | 124 | @rank_zero_only 125 | def log_model_summary(self, model: LightningModule): 126 | if not is_rank_zero(): 127 | return 128 | self.save_to_file("model_summary.txt", str(ModelSummary(model, max_depth=-1))) 129 | 130 | @torch.no_grad() 131 | @rank_zero_only 132 | def log_image_batch(self, name: str, images: Tensor, square_grid: bool = True, commit: bool = False, 133 | ncol: Optional[int] = None, **kwargs): 134 | """ 135 | Args: 136 | name: Name of key to use for logging 137 | images: N x C x H x W tensor of images 138 | square_grid: whether to render images into a square grid 139 | commit: whether to commit log to wandb or not 140 | ncol: analogous to nrow in make_grid, control how many images are in each column 141 | **kwargs: passed onto make_grid 142 | """ 143 | if not is_rank_zero(): 144 | return 145 | assert not (square_grid and ncol is not None), "Set either square_grid or ncol" 146 | if square_grid: 147 | kwargs['nrow'] = ceil(sqrt(len(images))) 148 | elif ncol is not None: 149 | kwargs['nrow'] = ceil(len(images) / ncol) 150 | image_grid = utils.make_grid( 151 | images.float(), normalize=True, value_range=(-1, 1), **kwargs 152 | ) 153 | wandb_image = wandb.Image(image_grid.float().cpu()) 154 | self.experiment.log({name: wandb_image}, commit=commit) 155 | 156 | @rank_zero_only 157 | def log_code(self): 158 | if not is_rank_zero(): 159 | return 160 | codetar = subprocess.run( 161 | ['tar', '--exclude=*.pyc', '--exclude=__pycache__', '--exclude=*.pt','--exclude=*.pkl', '-cvJf', '-', 'src'], 162 | stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout 163 | self.save_to_file('code.tar.xz', codetar) 164 | --------------------------------------------------------------------------------
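A note on the custom ops above: fused_act.py and upfirdn2d.py each expose one functional entry point (fused_leaky_relu and upfirdn2d) that dispatches to the compiled CUDA extension for GPU tensors and falls back to plain PyTorch ops on CPU. The sketch below is illustrative only: it assumes the extensions build when the modules are first imported, that the repository root is on PYTHONPATH so the src.* import paths resolve, and the shapes and blur kernel are hypothetical.

    import torch
    from src.models.networks.op.fused_act import fused_leaky_relu
    from src.models.networks.op.upfirdn2d import upfirdn2d

    x = torch.randn(4, 8, 16, 16, device='cuda')
    bias = torch.zeros(8, device='cuda')

    # Bias add + leaky ReLU + sqrt(2) gain in a single fused kernel.
    y = fused_leaky_relu(x, bias)                    # shape (4, 8, 16, 16)

    # upfirdn2d = zero-insertion upsample, FIR filter, then downsample.
    # With up=2, down=1, a 4-tap kernel and pad=(1, 2) this gives a blurred 2x upsample:
    # out_w = (16 * 2 + 1 + 2 - 4 + 1) // 1 = 32, per the formula in upfirdn2d_native.
    k = torch.tensor([1., 3., 3., 1.], device='cuda')
    k = (k[None, :] * k[:, None]) / k.sum() ** 2     # separable 4x4 blur kernel
    up = upfirdn2d(x, k, up=2, down=1, pad=(1, 2))   # shape (4, 8, 32, 32)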