├── LICENSE
├── README.md
├── scripts
│   ├── load_model.py
│   ├── setup_fid.py
│   └── style-gan-pytorch
│       ├── MLproject
│       ├── README.md
│       ├── conda.yaml
│       ├── dnnlib
│       │   ├── __init__.py
│       │   ├── submission
│       │   │   ├── __init__.py
│       │   │   ├── _internal
│       │   │   │   └── run.py
│       │   │   ├── run_context.py
│       │   │   └── submit.py
│       │   ├── tflib
│       │   │   ├── .ipynb_checkpoints
│       │   │   │   ├── __init__-checkpoint.py
│       │   │   │   └── tfutil-checkpoint.py
│       │   │   ├── __init__.py
│       │   │   ├── autosummary.py
│       │   │   ├── network.py
│       │   │   ├── optimizer.py
│       │   │   └── tfutil.py
│       │   └── util.py
│       ├── generate.py
│       ├── loss_criterions
│       │   ├── __init__.py
│       │   ├── base_loss_criterions.py
│       │   └── gradient_losses.py
│       ├── networks
│       │   ├── __init__.py
│       │   ├── building_blocks.py
│       │   ├── custom_layers.py
│       │   └── style_gan_net.py
│       ├── train.py
│       └── utils.py
└── src
    ├── __init__.py
    ├── configs
    │   ├── checkpoint
    │   │   ├── after_each_epoch.yaml
    │   │   ├── after_each_epoch_fid.yaml
    │   │   ├── every_n_train_steps.yaml
    │   │   └── every_n_train_steps_fid.yaml
    │   ├── dataset
    │   │   ├── imagefolder.yaml
    │   │   ├── lsun.yaml
    │   │   ├── multiimagefolder.yaml
    │   │   ├── multilsun.yaml
    │   │   ├── nodata.yaml
    │   │   └── other_image_dataset.yaml
    │   ├── experiment
    │   │   ├── blobgan.yaml
    │   │   ├── debug.yaml
    │   │   ├── gan.yaml
    │   │   ├── invertblobgan.yaml
    │   │   ├── jitter.yaml
    │   │   └── local.yaml
    │   └── fit.yaml
    ├── data
    │   ├── __init__.py
    │   ├── imagefolder.py
    │   ├── multiimagefolder.py
    │   ├── nodata.py
    │   └── utils.py
    ├── models
    │   ├── __init__.py
    │   ├── base.py
    │   ├── blobgan.py
    │   ├── gan.py
    │   ├── invertblobgan.py
    │   └── networks
    │       ├── __init__.py
    │       ├── layoutnet.py
    │       ├── layoutstylegan.py
    │       ├── op
    │       │   ├── __init__.py
    │       │   ├── conv2d_gradfix.py
    │       │   ├── conv2d_gradfix_111andon.py
    │       │   ├── conv2d_gradfix_pre111.py
    │       │   ├── fused_act.py
    │       │   ├── fused_bias_act.cpp
    │       │   ├── fused_bias_act_kernel.cu
    │       │   ├── upfirdn2d.cpp
    │       │   ├── upfirdn2d.py
    │       │   └── upfirdn2d_kernel.cu
    │       └── stylegan.py
    ├── run.py
    └── utils
        ├── __init__.py
        ├── colab.py
        ├── distributed.py
        ├── io.py
        ├── logging.py
        ├── misc.py
        ├── training.py
        └── wandb_logger.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2022, Dave Epstein
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## BlobGAN: Spatially Disentangled Scene Representations<br>Official PyTorch Implementation
2 |
3 | ### [Paper](https://arxiv.org/abs/2205.02837) | [Project Page](https://dave.ml/blobgan) | [Video](https://www.youtube.com/watch?v=KpUv82VsU5k) | [Interactive Demo ](https://dave.ml/blobgan/demo)
4 |
5 | https://user-images.githubusercontent.com/5674727/168323496-990b46a2-a11d-4192-898a-f5b683d20265.mp4
6 |
7 | This repository contains:
8 |
9 | * 🚂 Pre-trained BlobGAN models on three datasets: bedrooms, conference rooms, and a combination of kitchens, living rooms, and dining rooms
10 | * 💻 Code based on PyTorch Lightning ⚡ and Hydra 🐍, which fully supports CPU, single-GPU, and multi-GPU/node training and inference
11 |
12 | We also provide an [📓 interactive demo notebook](https://dave.ml/blobgan/demo) to help get started using our model. Download this notebook and run it in your own Python environment, or test it out on Colab. You can:
13 |
14 | * 🖌️️ Generate and edit realistic images with an interactive UI
15 | * 📹 Create animated videos showing off your edited scenes
16 | * 📸 **(new!)** Upload your own image and convert it into blobs!
17 |
18 | And, coming soon:
19 | * 🧬 More edits, as shown in the paper! Code for cloning, restyling, rotating, and reshaping blobs.
20 |
21 | ## Setup
22 |
23 | Run the commands below one at a time to download the latest version of the BlobGAN code, create a Conda environment, and install necessary packages and utilities.
24 |
25 | ```bash
26 | git clone https://github.com/dave-epstein/blobgan.git
27 | mkdir -p blobgan/logs/wandb
28 | conda create -y -n blobgan python=3.9
29 | conda activate blobgan
30 | conda install -y pytorch=1.11.0 torchvision=0.12.0 torchaudio cudatoolkit=11.3 -c pytorch
31 | conda install -y cudatoolkit-dev=11.3 -c conda-forge
32 | pip install tqdm==4.64.0 hydra-core==1.1.2 omegaconf==2.1.2 clean-fid==0.1.23 wandb==0.12.11 ipdb==0.13.9 lpips==0.1.4 einops==0.4.1 inputimeout==1.0.4 pytorch-lightning==1.5.10 matplotlib==3.5.2 "mpl_interactions[jupyter]==0.21.0" protobuf~=3.19.0 moviepy==1.0.3
33 | cd blobgan
34 | python scripts/setup_fid.py
35 | ```
36 | If you haven't already installed `ninja` on your machine (needed to compile custom C++ code), do so now. On Linux, this looks like:
37 | ```bash
38 | wget -q --show-progress https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip
39 | sudo unzip -q ninja-linux.zip -d /usr/local/bin/
40 | sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
41 | ```
42 |
43 |
44 | ## Running pretrained models
45 |
46 | See `scripts/load_model.py` for an example of how to load a pre-trained model (using the provided `load_{blobgan/stylegan}_model` functions, which can be called from elsewhere) and generate images with it. You can also run the file from the command line to generate images and save them to disk. For example, from the `blobgan` directory, you can run:
47 |
48 | ```bash
49 | python scripts/load_model.py --model_data bed --dl_dir models --save_dir out --n_imgs 32 --save_blobs --label_blobs
50 | ```
51 |
52 | Or
53 |
54 | ```bash
55 | python scripts/load_model.py --model_name stylegan --model_data conference --truncate 0.4
56 | ```
57 | Note that the first run may take a minute or two longer as custom C++ code is compiled. See the command's help for more details and options: `python scripts/load_model.py --help`
58 |
59 | Using these functions, you can access pretrained models on bedrooms (trained with or without jitter); conference rooms; and kitchens, living rooms, and dining rooms.
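For reference, here is a minimal sketch of doing the same thing from Python rather than the command line. It mirrors what `scripts/load_model.py` does internally; the output paths and the `sys.path` setup (run from the repo root) are assumptions, so adapt them to your environment:

```python
import sys
import torch
from PIL import Image

sys.path += ['src', 'scripts']              # assumption: running from the repo root
from load_model import load_blobgan_model   # defined in scripts/load_model.py
from utils import for_canvas                # src/utils helper: tensor -> uint8 image array

model = load_blobgan_model('bed', path='models', device='cuda')  # downloads weights on first use
z = torch.randn((4, 512), device='cuda')                         # one latent vector per image
with torch.no_grad():
    # For BlobGAN, gen() returns the blob layout dict and the rendered images
    layout, imgs = model.gen(z=z, truncate=0.4, **model.render_kwargs)
Image.fromarray(for_canvas(imgs[0:1])).save('sample.png')
Image.fromarray(for_canvas(layout['feature_img'][0:1].mul(255))).save('blobs.png')
```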
60 |
61 | ## Training your own model
62 |
63 | **Before training your model**, you'll need to modify `src/configs/experiments/local.yaml` to include your WandB information and machine-specific configuration (such as path to data -- `dataset.path` or `dataset.basepath` -- and number of GPUs `trainer.gpus`). Alternatively, you can exclude `local` from the `experiments` option in the commands below and specify these parameters directly on the command line.
64 |
65 | To turn off logging entirely, pass `logger=false`; to log only to disk without syncing to the server, pass `wandb.offline=true`. Our code currently only supports WandB logging.
66 |
67 | Here's an example command which will train a model on LSUN bedrooms. We list the configuration modules to load for this experiment (`blobgan`, `local`, `jitter`) and then specify any other options as we desire. For example, if we wanted to train a model without jitter, we could just remove that module from the `experiments` array.
68 |
69 | ```bash
70 | python src/run.py +experiment=[blobgan,local,jitter] wandb.name='10-blob BlobGAN on bedrooms'
71 | ```
72 |
73 | In some shells (such as zsh), you may need to add quotes around some of these options to keep the shell from interpreting characters like brackets before Hydra sees them; see the example below.
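For instance, an equivalent, fully quoted invocation of the command above might look like this (quoting style is a suggestion, not a requirement of the code):

```bash
python src/run.py '+experiment=[blobgan,local,jitter]' 'wandb.name=10-blob BlobGAN on bedrooms'
```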
74 |
75 | Train on the LSUN category of your choice by passing in `dataset.category`, e.g. `dataset.category=church`. Tackle multiple categories at once with `dataset=multilsun` and `dataset.categories=[kitchen,bedroom]`.
76 |
77 | You can also train on any collection of images by selecting `dataset=imagefolder` and passing in the path. The code expects at least a subfolder named `train` and optional subfolders named `validate` and `test`. The below command also illustrates how to set arbitrary options using Hydra syntax, such as turning off FID logging or changing dataloader batch size:
78 |
79 | ```bash
80 | python src/run.py +experiment=[blobgan,local,jitter] wandb.name='20-blob BlobGAN on Places' dataset.dataloader.batch_size=24 +model.log_fid_every_epoch=false dataset=imagefolder +dataset.path=/path/to/places/ model.n_features=20
81 | ```
82 |
83 | Other parameters of interest are likely `trainer.log_every_n_steps`, `model.log_images_every_n_steps`, and `model.log_fid_every_n_steps`, which control the frequency of logging scalars, images, and FID (set either of the latter two to -1 to disable). Also check out `checkpoint.every_n_train_steps` and `checkpoint.save_top_k`, which dictate how often checkpoints are saved and how many of the most recent checkpoints to keep (`-1` means keep everything).
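As a hypothetical example combining these overrides (the values are placeholders, not recommendations):

```bash
python src/run.py +experiment=[blobgan,local] \
  trainer.log_every_n_steps=50 \
  model.log_images_every_n_steps=1000 \
  model.log_fid_every_n_steps=-1 \
  checkpoint.every_n_train_steps=2000 \
  checkpoint.save_top_k=5
```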
84 |
85 | ### Changing model feature dimensions
86 |
87 | To change `d_in`, set both `model.layout_net.feature_dim` and `model.generator.override_c_in` to the same value. To change `d_style`, change `model.dim`.
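For example, a hypothetical run that sets `d_in` to 512 and `d_style` to 768 might look like the following (all other options as in the commands above):

```bash
python src/run.py +experiment=[blobgan,local] \
  model.layout_net.feature_dim=512 \
  model.generator.override_c_in=512 \
  model.dim=768
```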
88 |
89 | ### Logging FID during training and at test
90 |
91 | In the initial codebase setup, you should have run `scripts/setup_fid.py` which will download and install FID statistics for three different datasets:
92 |
93 | * Bedrooms: `lsun_bedroom`
94 | * Conference rooms: `lsun_conference`
95 | * Kitchens, living rooms, dining rooms: `lsun_kld`
96 |
97 | **If either `model.log_fid_every_n_steps > -1` or `model.log_fid_every_epoch == true`, make sure that `model.fid_stats_name` is passed in.** If you are training on one of the three datasets from the paper, just pass in the string from the list above.
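For example, a hypothetical command that logs FID on bedrooms after every epoch (whether a key needs the `+` prefix depends on whether it already exists in your config):

```bash
python src/run.py +experiment=[blobgan,local,jitter] \
  +model.log_fid_every_epoch=true \
  model.fid_stats_name=lsun_bedroom
```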
98 |
99 | If you are training on your own data, you'll first need to run `setup_fid.py` to precompute statistics on that data. The command might look something like:
100 |
101 | ```bash
102 | python scripts/setup_fid.py --action compute_new --path /path/to/new/data --name newdata -j 32 -bs 256
103 | ```
104 |
105 | Then, pass `model.fid_stats_name=newdata` on the command line.
106 |
107 | Note that the precomputed FID statistics are computed on 256px images. You will need to recompute them if training at a higher resolution.
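For instance, to precompute stats at 512px on your own data, a command along these lines should work (the name and paths are placeholders; all flags are documented in `python scripts/setup_fid.py --help`):

```bash
python scripts/setup_fid.py --action compute_new --path /path/to/new/data --name newdata512 -r 512 -j 32 -bs 128
```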
108 |
109 | To run FID logging at test time, a simple snippet such as the following will return the score:
110 | ```python
111 | model = load_blobgan_model('bed', 'models', 'cuda', fixed_noise=False)
112 | model.fid_stats_name = 'lsun_bedroom'
113 | model.fid_n_imgs = 50000
114 | print(model.log_fid('train'))
115 | ```
116 |
117 | ### Resuming training
118 |
119 | To continue a training run that was terminated, simply add `resume.id=PREVIOUS RUN ID`. To resume from a previous run but start a new WandB run (e.g. to avoid overwriting previous checkpoints), also pass in `wandb.id=null`.
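For example (the run ID is a placeholder; drop `wandb.id=null` if you want to keep logging to the same WandB run):

```bash
python src/run.py +experiment=[blobgan,local,jitter] resume.id=YOUR_RUN_ID wandb.id=null
```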
120 |
121 | ### Training StyleGAN2
122 |
123 | Many of the above command line options apply (for controlling data and logging). For example, to train a StyleGAN2 model on LSUN conference rooms, run:
124 |
125 | ```bash
126 | python src/run.py +experiment=[gan,local] wandb.name='Conference room StyleGAN2' dataset.category=conference
127 | ```
128 |
129 | This uses default StyleGAN2 hyperparameters: R1 regularization on D every 16 steps, path length regularization on G every 4 steps, and an R1 weight of 50, i.e. gamma=100 (the weight is gamma/2).
130 |
131 | ### Training inversion encoders
132 |
133 | The same is true for training an inversion encoder. See this example command:
134 |
135 | ```bash
136 | python src/run.py +experiment=[invertblobgan,local] wandb.name='Inversion model' +model.G_pretrained.id="BLOBGAN MODEL ID HERE" +model.trunc_min=0.2 +model.trunc_max=0.4 model.lambda.fake_latents_MSE=10
137 | ```
138 |
139 | Be sure to specify `model.G_pretrained.id` to match the ID of the BlobGAN model you are trying to invert. Also, you can set `model.G_pretrained.log_dir` to tell the program where to look for the model logs (this defaults to `./logs` if unspecified). The options `trunc_min` and `trunc_max` specify what truncation level to use (randomly sampled within the specified interval) when sampling fake images. If both are set to the same value (including zero, the default), this value will always be used.
140 |
141 | ## Citation
142 |
143 | If our code or models aided your research, please cite our [paper](https://arxiv.org/abs/2205.02837):
144 | ```
145 | @misc{epstein2022blobgan,
146 | title={BlobGAN: Spatially Disentangled Scene Representations},
147 | author={Dave Epstein and Taesung Park and Richard Zhang and Eli Shechtman and Alexei A. Efros},
148 | year={2022},
149 | eprint={2205.02837},
150 | archivePrefix={arXiv},
151 | primaryClass={cs.CV}
152 | }
153 | ```
154 |
155 | ## Code acknowledgments
156 |
157 | This repository is built on top of rosinality's excellent [PyTorch re-implementation of StyleGAN2](https://github.com/rosinality/stylegan2-pytorch) and Bill Peebles' [GANgealing codebase](https://github.com/wpeebles/gangealing).
158 |
--------------------------------------------------------------------------------
/scripts/load_model.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import os, sys
3 | import torch
4 | from PIL import Image
5 | from tqdm import tqdm, trange
6 |
7 | here_dir = os.path.dirname(__file__)
8 |
9 | sys.path.append(os.path.join(here_dir, '..', 'src'))
10 | os.environ['PYTHONPATH'] = os.path.join(here_dir, '..', 'src')
11 |
12 | from models import BlobGAN, GAN, BlobGANInverter
13 | from utils import download_model, download_mean_latent, download_cherrypicked, KLD_COLORS, BED_CONF_COLORS, \
14 | viz_score_fn, for_canvas, draw_labels, download
15 |
16 |
17 | def load_SGAN1_bedrooms(path, device='cuda'):
18 | ckpt = download(path=path, file='SGAN1_bedrooms.ckpt', load=True)
19 | sys.path.append(os.path.join(here_dir, 'style-gan-pytorch'))
20 | from networks.style_gan_net import Generator
21 |
22 | model = Generator(resolution=256)
23 | model.load_state_dict(ckpt)
24 | model.eval()
25 | return model.to(device)
26 |
27 |
28 | def load_stylegan1_model(model_data, path, device='cuda'):
29 | if model_data.startswith('bed'):
30 | model = load_SGAN1_bedrooms(path, device)
31 | Z = torch.randn((10000, 512)).to(device)
32 | latents = [model.g_mapping(Z[_:_ + 1])[0] for _ in trange(10000, desc='Computing mean latent')]
33 | model.mean_latent = torch.stack(latents, 0).mean(0)
34 |
35 | def SGAN1_gen(z, truncate):
36 | a = 1 - truncate
37 | dlatents = model.g_mapping(z).clone()
38 | if a < 1:
39 | dlatents = a * dlatents + (1 - a) * model.mean_latent
40 | x = model.g_synthesis(dlatents, 8, 1).clone()
41 | xx = ((x.clamp(min=-1, max=1) + 1) / 2.0) * 255
42 | return xx
43 |
44 | model.gen = SGAN1_gen
45 | else:
46 | raise ValueError('Only bedrooms supported for SGAN1.')
47 | return model  # return the prepared generator so callers can use model.gen(...)
48 |
49 | def load_stylegan_model(model_data, path, device='cuda'):
50 | if model_data.startswith('bed'):
51 | datastr = 'bed'
52 | else:
53 | datastr = 'conference' if model_data.startswith('conference') else 'kitchenlivingdining'
54 | ckpt = download(path=path, file=f'SGAN2_{datastr}.ckpt')
55 | model = GAN.load_from_checkpoint(ckpt, strict=False).to(device)
56 | model.get_mean_latent()
57 | return model
58 |
59 |
60 | def load_blobgan_model(model_data, path, device='cuda', fixed_noise=False):
61 | ckpt = download_model(model_data, path)
62 | model = BlobGAN.load_from_checkpoint(ckpt, strict=False).to(device)
63 | try:
64 | model.mean_latent = download_mean_latent(model_data, path).to(device)
65 | except:
66 | model.get_mean_latent()
67 | try:
68 | model.cherry_picked = download_cherrypicked(model_data, path).to(device)
69 | except:
70 | pass
71 | COLORS = KLD_COLORS if 'kitchen' in model_data else BED_CONF_COLORS
72 | model.colors = COLORS
73 | noise = [torch.randn((1, 1, 16 * 2 ** ((i + 1) // 2), 16 * 2 ** ((i + 1) // 2))).to(device) for i in
74 | range(model.generator_ema.num_layers)] if fixed_noise else None
75 | model.noise = noise
76 | render_kwargs = {
77 | 'no_jitter': True,
78 | 'ret_layout': True,
79 | 'viz': True,
80 | 'ema': True,
81 | 'viz_colors': COLORS,
82 | 'norm_img': True,
83 | 'viz_score_fn': viz_score_fn,
84 | 'noise': noise
85 | }
86 | model.render_kwargs = render_kwargs
87 | return model
88 |
89 |
90 | def load_inversion_model(model_data, path, device):
91 | ckpt = download(model_data, suffix='_invert.ckpt', path=path, load=False)
92 | d_out = torch.load(ckpt, map_location='cpu')['state_dict']['inverter.final_linear.1.weight'].shape[0]
93 | model = BlobGANInverter.load_from_checkpoint(ckpt, strict=False, load_only_inverter=True, inverter_d_out=d_out).to(
94 | device)
95 | return model
96 |
97 |
98 | if __name__ == "__main__":
99 | import argparse
100 |
101 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
102 | parser.add_argument("-m", "--model_name", default='blobgan',
103 | choices=['blobgan', 'stylegan', 'stylegan1'])
104 | parser.add_argument("-d", "--model_data", default='bed',
105 | help="Choose a pretrained model. This must be a string that begins either with `bed_no_jitter` (bedrooms, trained without jitter), "
106 | "`bed` (bedrooms),"
107 | " `kitchen` (kitchens, living rooms, and dining rooms),"
108 | " or `conference` (conference rooms).")
109 | parser.add_argument("-dl", "--dl_dir", default='models',
110 | help='Path to a directory where model files will be downloaded.')
111 | parser.add_argument("-s", "--save_dir", default='out',
112 | help='Path to the directory where output images will be saved.')
113 | parser.add_argument("-n", "--n_imgs", default=100, type=int, help='Number of random images to generate.')
114 | parser.add_argument('-bs', '--batch_size', default=32,
115 | help='Number of images to generate in one forward pass. Adjust based on available GPU memory.',
116 | type=int)
117 | parser.add_argument('-t', '--truncate', default=0.4,
118 | help='Amount of truncation to use when generating images. 0 means no truncation, 1 means full truncation.',
119 | type=float)
120 | parser.add_argument("--save_blobs", action='store_true',
121 | help='If passed, save images of blob maps (when `--model_name` is BlobGAN).')
122 | parser.add_argument("--label_blobs", action='store_true',
123 | help='If passed, add numeric blob labels to blob map images, when `--save_blobs` is true.')
124 | parser.add_argument('--size_threshold', default=-3,
125 | help='Threshold for blob size parameter above which to render blob labels, when `--label_blobs` is true.',
126 | type=float)
127 | parser.add_argument('--device', default='cuda',
128 | help='Specify the device on which to run the code, in PyTorch syntax, e.g. `cuda`, `cpu`, `cuda:3`.')
129 | parser.add_argument('--fixed_spatial_noise', action='store_true',
130 | help='If passed, use a fixed set of spatial noise maps when generating images. '
131 | 'This is off by default for general use cases, but should be enabled for things like animation.')
132 | args = parser.parse_args()
133 |
134 | blobgan = args.model_name == 'blobgan'
135 | stylegan = args.model_name == 'stylegan'
136 | sgan1 = args.model_name == 'stylegan1'
137 |
138 | save_dir = Path(args.save_dir)
139 | (save_dir / 'imgs').mkdir(exist_ok=True, parents=True)
140 |
141 | if blobgan:
142 | model = load_blobgan_model(args.model_data, args.dl_dir, args.device, fixed_noise=args.fixed_spatial_noise)
143 |
144 | if args.save_blobs:
145 | (save_dir / 'blobs').mkdir(exist_ok=True, parents=True)
146 | if args.label_blobs:
147 | (save_dir / 'blobs_labeled').mkdir(exist_ok=True, parents=True)
148 | elif stylegan:
149 | model = load_stylegan_model(args.model_data, args.dl_dir, args.device)
150 | elif sgan1:
151 | model = load_stylegan1_model(args.model_data, args.dl_dir, args.device)
152 | else:
153 | raise NotImplementedError('Inversion of images from command line not yet supported. ')
154 |
155 | n_to_gen = args.n_imgs
156 | n_gen = 0
157 |
158 | torch.set_grad_enabled(False)
159 |
160 | with tqdm(total=args.n_imgs, desc='Generating images') as pbar:
161 | while n_to_gen > 0:
162 | bs = min(args.batch_size, n_to_gen)
163 | z = torch.randn((bs, 512)).to(args.device)
164 |
165 | if blobgan:
166 | layout, orig_img = model.gen(z=z, truncate=args.truncate, **model.render_kwargs)
167 | else:
168 | orig_img = model.gen(z=z, truncate=args.truncate)
169 |
170 | for i in range(len(orig_img)):
171 | img_i = for_canvas(orig_img[i:i + 1])
172 | Image.fromarray(img_i).save(str(save_dir / 'imgs' / f'{i + n_gen:04d}.png'))
173 | if blobgan and args.save_blobs:
174 | blobs_i = for_canvas(layout['feature_img'][i:i + 1].mul(255))
175 | Image.fromarray(blobs_i).save(str(save_dir / 'blobs' / f'{i + n_gen:04d}.png'))
176 | if args.label_blobs:
177 | labeled_blobs, labeled_blobs_img = draw_labels(blobs_i, layout, args.size_threshold,
178 | model.colors, layout_i=i)
179 | labeled_blobs_img.save(str(save_dir / 'blobs_labeled' / f'{i + n_gen:04d}.png'))
180 |
181 | n_to_gen -= bs
182 | n_gen += bs
183 | pbar.update(bs)
184 |
--------------------------------------------------------------------------------
/scripts/setup_fid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | import cleanfid
5 | import torch
6 | import numpy as np
7 | from cleanfid.features import build_feature_extractor
8 | from cleanfid.fid import get_folder_features
9 | from torchvision.transforms import functional as F
10 |
11 | here_dir = os.path.dirname(__file__)
12 |
13 | sys.path.append(os.path.join(here_dir, '..', 'src'))
14 | os.environ['PYTHONPATH'] = os.path.join(here_dir, '..', 'src')
15 |
16 | from utils import download
17 |
18 |
19 | if __name__ == "__main__":
20 | import argparse
21 |
22 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23 | parser.add_argument('--action', default='download', choices=['download', 'compute_new'],
24 | help='All other options only apply if action is set to `compute_new`.'
25 | ' Download mode (default) simply configures precomputed stats used in the BlobGAN paper.')
26 | parser.add_argument('--path', default='', type=str,
27 | help='Path to custom folder from which to sample `--n_imgs` images and compute FID statistics.')
28 | parser.add_argument('--n_imgs', type=int, default=-1,
29 | help='Number of images to randomly sample for FID stats. Set to -1 to use all images.')
30 | parser.add_argument('--shuffle', action='store_true',
31 | help='Shuffle the files in a directory before selecting `--n_imgs` for computation.')
32 | parser.add_argument('--name', default=None, help='Name to give custom stats.')
33 | parser.add_argument('-bs', '--batch_size', default=32,
34 | help='Number of images to analyze in one forward pass. Adjust based on available GPU memory.',
35 | type=int)
36 | parser.add_argument('-j', '--num_workers', default=8,
37 | help='Number of workers to use for FID stats generation.',
38 | type=int)
39 | parser.add_argument('-r', '--resolution', default=256,
40 | help='Image resolution to use before feeding images into FID pipeline (where they are resized to 299).',
41 | type=int)
42 | parser.add_argument('--device', default='cuda',
43 | help='Specify the device on which to run the code, in PyTorch syntax, '
44 | 'e.g. `cuda`, `cpu`, `cuda:3`.')
45 | args = parser.parse_args()
46 |
47 |
48 | def load_fn(x):
49 | x = F.resize(torch.from_numpy(x).permute(2, 0, 1), args.resolution)
50 | x = F.center_crop(x, args.resolution).permute(1, 2, 0)
51 | return np.array(x)
52 |
53 | if args.action == 'download':
54 | path = os.path.join(os.path.dirname(cleanfid.__file__), "stats")
55 | stats = download(path=path, file='fid_stats.tar.gz')
56 | subprocess.run(["tar", "xvzf", stats, '-C', path, '--strip-components', '1'])
57 | else:
58 | print('Calculating...')
59 | name, mode, device, fdir, num = args.name, "clean", torch.device(args.device), args.path, args.n_imgs
60 | assert name
61 | stats_folder = os.path.join(os.path.dirname(cleanfid.__file__), "stats")
62 | os.makedirs(stats_folder, exist_ok=True)
63 | split, res = "custom", "na"
64 | outname = f"{name}_{mode}_{split}_{res}.npz"
65 | outf = os.path.join(stats_folder, outname)
66 | # if the custom stat file already exists
67 | if os.path.exists(outf):
68 | msg = f"The statistics file {name} already exists. "
69 | msg += "Use remove_custom_stats function to delete it first."
70 | raise Exception(msg)
71 |
72 | feat_model = build_feature_extractor(mode, device)
73 | fbname = os.path.basename(fdir)
74 | # get all inception features for folder images
75 | if num < 0: num = None
76 | np_feats = get_folder_features(fdir, feat_model, num_workers=args.num_workers, num=num, shuffle=args.shuffle,
77 | batch_size=args.batch_size, device=device, custom_image_tranform=load_fn,
78 | mode=mode, description=f"custom stats: {fbname} : ")
79 | mu = np.mean(np_feats, axis=0)
80 | sigma = np.cov(np_feats, rowvar=False)
81 | print(f"Saving custom FID stats to {outf}")
82 | np.savez_compressed(outf, mu=mu, sigma=sigma)
83 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/MLproject:
--------------------------------------------------------------------------------
1 | name: style-gan
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | convert: {type: string, default: True}
9 | official_checkpoints: {type: string, default: True}
10 | random_seed: {type: int, default: 77}
11 | dataset: {type: string, default: ffhq}
12 | nrow: {type: int, default: 2}
13 | ncol: {type: int, default: 2}
14 | g_checkpoint: {type: string, default: ./checkpoints/generator.64x64.0.759840.3460000.158.pt}
15 | target_resolution: {type: int, default: 128}
16 | command: |
17 | python generate.py \
18 | --dataset {dataset} \
19 | --convert {convert} \
20 | --use-official-checkpoints {official_checkpoints} \
21 | --random-seed {random_seed} \
22 | --nrow {nrow} \
23 | --ncol {ncol} \
24 | --g-checkpoint {g_checkpoint} \
25 | --target-resolution {target_resolution}
26 |
27 | train:
28 | parameters:
29 | data_root: {type: string, default: ./data/celeba}
30 | resume: {type: string, default: True}
31 | g_checkpoint: {type: string, default: ./checkpoints/generator.64x64.0.759840.3460000.158.pt}
32 | d_checkpoint: {type: string, default: ./checkpoints/discriminator.64x64.0.759840.3460000.158.pt}
33 | target_resolution: {type: int, default: 128}
34 | n_gpu: {type: int, default: 1}
35 | command: |
36 | python train.py \
37 | --data-root {data_root} \
38 | --resume {resume} \
39 | --g-checkpoint {g_checkpoint} \
40 | --d-checkpoint {d_checkpoint} \
41 | --target-resolution {target_resolution} \
42 | --n-gpu {n_gpu}
43 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/README.md:
--------------------------------------------------------------------------------
1 | # StyleGAN Pytorch Implementation
2 | This is a PyTorch implementation of StyleGAN (https://arxiv.org/abs/1812.04948), capable of generating 1024x1024 images. Progressively growing training up to 1024x1024 is also supported. A 1080 Ti is recommended for faster training.
3 |
4 | ## Prerequisites
5 | ### Dependencies
6 | See conda.yaml. Please note that I have CUDA 10.0 installed. Change your conda.yaml accordingly if you use a different CUDA version.
7 | ### Image generation using official implementation's TensorFlow checkpoints
8 |
9 | **This step is not needed if the generation command succeeds in downloading the checkpoints automatically.** If the Google Drive download fails, download the checkpoints manually:
10 |
11 | * [ffhq-1024x1024](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ)
12 | * [bedrooms-256x256](https://drive.google.com/open?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF)
13 | * [cats-256x256](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ)
14 |
15 | Place them in the ./pretrained directory.
16 |
17 | ### Training prerequisites on CelebA dataset
18 | Download the [celeba](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg) dataset. Unzip the .zip file into the ./data/celeba directory.
19 |
20 | ## Image Generation
21 | ### Image generation using official checkpoints
22 | Run the command:
23 | ```bash
24 | mlflow run -e generate . -P dataset=cats
25 | ```
26 | The default random seed is 77. To generate a different set of images with a different image grid (note that the number of images you can generate at once is limited by your GPU memory), run:
27 | ```bash
28 | mlflow run -e generate . -P dataset=cats -P random_seed=777 -P nrow=2 -P ncol=5
29 | ```
30 | This will generate 10 images at once.
31 | ### Image generation using checkpoints generated by this code
32 | ```bash
33 | mlflow run -e generate . \
34 | -P official_checkpoints=False \
35 | -P g_checkpoint=[path_to_generator_checkpoint] \
36 | -P target_resolution=128 \
37 | -P nrow=2 \
38 | -P ncol=2
39 | ```
40 | This will generate images using checkpoints trained by this code.
41 |
42 | ## Training on CelebA dataset
43 | Run the command to start from scratch:
44 | ```bash
45 | mlflow run -e train . -P resume=False
46 | ```
47 | This will kick off training at 128x128 resolution on the CelebA dataset. During training, model checkpoints are stored under ./checkpoints, and fake images are written to ./checks/fake\_imgs for inspection. Note that training is progressive, starting from 8x8, so you will see 8x8 images at the beginning and 128x128 images at the end of the training process.
48 |
49 | To resume training:
50 | ```bash
51 | mlflow run -e train . \
52 | -P resume=True \
53 | -P g_checkpoint=[path_to_generator_checkpoint] \
54 | -P d_checkpoint=[path_to_discriminator_checkpoint]
55 | ```
56 | For other training options, please check the MLproject file. For hyperparameters, please check train.py and NVIDIA's official implementation.
57 |
58 | ## TODO
59 | 1. Add truncation trick
60 | 2. Add and experiment with other loss functions (some are in the repo but not tried)
61 | 3. Add tensorboard support
62 | 4. Add moving average of generator's weight
63 |
64 | Multi-GPU support is added but has not been tested due to hardware limitations.
65 |
66 | ## License
67 | This project is under BSD-3 license.
68 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/conda.yaml:
--------------------------------------------------------------------------------
1 | name: pytorch_stylegan
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - python=3.6
7 | - cudatoolkit=10.0 # my machine has cuda 10.0
8 | - pytorch=1.1.0
9 | - torchvision=0.3.0
10 | - tensorflow-gpu=1.13.1
11 | - jupyter=1.0.0
12 | - matplotlib=3.1.0
13 | - pip:
14 | - mlflow>=1.0
15 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import submission
9 |
10 | from .submission.run_context import RunContext
11 |
12 | from .submission.submit import SubmitTarget
13 | from .submission.submit import PathType
14 | from .submission.submit import SubmitConfig
15 | from .submission.submit import get_path_from_template
16 | from .submission.submit import submit_run
17 |
18 | from .util import EasyDict
19 |
20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
21 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/submission/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import run_context
9 | from . import submit
10 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/submission/_internal/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for launching run functions in computing clusters.
9 |
10 | During the submit process, this file is copied to the appropriate run dir.
11 | When the job is launched in the cluster, this module is the first thing that
12 | is run inside the docker container.
13 | """
14 |
15 | import os
16 | import pickle
17 | import sys
18 |
19 | # PYTHONPATH should have been set so that the run_dir/src is in it
20 | import dnnlib
21 |
22 | def main():
23 | if not len(sys.argv) >= 4:
24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")
25 |
26 | run_dir = str(sys.argv[1])
27 | task_name = str(sys.argv[2])
28 | host_name = str(sys.argv[3])
29 |
30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl")
31 |
32 | # SubmitConfig should have been pickled to the run dir
33 | if not os.path.exists(submit_config_path):
34 | raise RuntimeError("SubmitConfig pickle file does not exist!")
35 |
36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb"))
37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name)
38 |
39 | submit_config.task_name = task_name
40 | submit_config.host_name = host_name
41 |
42 | dnnlib.submission.submit.run_wrapper(submit_config)
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/submission/run_context.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helpers for managing the run/training loop."""
9 |
10 | import datetime
11 | import json
12 | import os
13 | import pprint
14 | import time
15 | import types
16 |
17 | from typing import Any
18 |
19 | from . import submit
20 |
21 |
22 | class RunContext(object):
23 | """Helper class for managing the run/training loop.
24 |
25 | The context will hide the implementation details of a basic run/training loop.
26 | It will set things up properly, tell if run should be stopped, and then cleans up.
27 | User should call update periodically and use should_stop to determine if run should be stopped.
28 |
29 | Args:
30 | submit_config: The SubmitConfig that is used for the current run.
31 | config_module: The whole config module that is used for the current run.
32 | max_epoch: Optional cached value for the max_epoch variable used in update.
33 | """
34 |
35 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
36 | self.submit_config = submit_config
37 | self.should_stop_flag = False
38 | self.has_closed = False
39 | self.start_time = time.time()
40 | self.last_update_time = time.time()
41 | self.last_update_interval = 0.0
42 | self.max_epoch = max_epoch
43 |
44 | # pretty print the all the relevant content of the config module to a text file
45 | if config_module is not None:
46 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
47 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
48 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)
49 |
50 | # write out details about the run to a text file
51 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
52 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
53 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
54 |
55 | def __enter__(self) -> "RunContext":
56 | return self
57 |
58 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
59 | self.close()
60 |
61 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
62 | """Do general housekeeping and keep the state of the context up-to-date.
63 | Should be called often enough but not in a tight loop."""
64 | assert not self.has_closed
65 |
66 | self.last_update_interval = time.time() - self.last_update_time
67 | self.last_update_time = time.time()
68 |
69 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
70 | self.should_stop_flag = True
71 |
72 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch
73 |
74 | def should_stop(self) -> bool:
75 | """Tell whether a stopping condition has been triggered one way or another."""
76 | return self.should_stop_flag
77 |
78 | def get_time_since_start(self) -> float:
79 | """How much time has passed since the creation of the context."""
80 | return time.time() - self.start_time
81 |
82 | def get_time_since_last_update(self) -> float:
83 | """How much time has passed since the last call to update."""
84 | return time.time() - self.last_update_time
85 |
86 | def get_last_update_interval(self) -> float:
87 | """How much time passed between the previous two calls to update."""
88 | return self.last_update_interval
89 |
90 | def close(self) -> None:
91 | """Close the context and clean up.
92 | Should only be called once."""
93 | if not self.has_closed:
94 | # update the run.txt with stopping time
95 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
96 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
97 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
98 |
99 | self.has_closed = True
100 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/tflib/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import autosummary
9 | from . import network
10 | from . import optimizer
11 | from . import tfutil
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/tflib/.ipynb_checkpoints/tfutil-checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous helper utils for Tensorflow."""
9 |
10 | import os
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from typing import Any, Iterable, List, Union
15 |
16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
17 | """A type that represents a valid Tensorflow expression."""
18 |
19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
20 | """A type that can be converted to a valid Tensorflow expression."""
21 |
22 |
23 | def run(*args, **kwargs) -> Any:
24 | """Run the specified ops in the default session."""
25 | assert_tf_initialized()
26 | return tf.get_default_session().run(*args, **kwargs)
27 |
28 |
29 | def is_tf_expression(x: Any) -> bool:
30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
32 |
33 |
34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
35 | """Convert a Tensorflow shape to a list of ints."""
36 | return [dim.value for dim in shape]
37 |
38 |
39 | def flatten(x: TfExpressionEx) -> TfExpression:
40 | """Shortcut function for flattening a tensor."""
41 | with tf.name_scope("Flatten"):
42 | return tf.reshape(x, [-1])
43 |
44 |
45 | def log2(x: TfExpressionEx) -> TfExpression:
46 | """Logarithm in base 2."""
47 | with tf.name_scope("Log2"):
48 | return tf.log(x) * np.float32(1.0 / np.log(2.0))
49 |
50 |
51 | def exp2(x: TfExpressionEx) -> TfExpression:
52 | """Exponent in base 2."""
53 | with tf.name_scope("Exp2"):
54 | return tf.exp(x * np.float32(np.log(2.0)))
55 |
56 |
57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
58 | """Linear interpolation."""
59 | with tf.name_scope("Lerp"):
60 | return a + (b - a) * t
61 |
62 |
63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
64 | """Linear interpolation with clip."""
65 | with tf.name_scope("LerpClip"):
66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
67 |
68 |
69 | def absolute_name_scope(scope: str) -> tf.name_scope:
70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
71 | return tf.name_scope(scope + "/")
72 |
73 |
74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
77 |
78 |
79 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
80 | # Defaults.
81 | cfg = dict()
82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
87 |
88 | # User overrides.
89 | if config_dict is not None:
90 | cfg.update(config_dict)
91 | return cfg
92 |
93 |
94 | def init_tf(config_dict: dict = None) -> None:
95 | """Initialize TensorFlow session using good default settings."""
96 | # Skip if already initialized.
97 | if tf.get_default_session() is not None:
98 | return
99 |
100 | # Setup config dict and random seeds.
101 | cfg = _sanitize_tf_config(config_dict)
102 | np_random_seed = cfg["rnd.np_random_seed"]
103 | if np_random_seed is not None:
104 | np.random.seed(np_random_seed)
105 | tf_random_seed = cfg["rnd.tf_random_seed"]
106 | if tf_random_seed == "auto":
107 | tf_random_seed = np.random.randint(1 << 31)
108 | if tf_random_seed is not None:
109 | tf.set_random_seed(tf_random_seed)
110 |
111 | # Setup environment variables.
112 | for key, value in list(cfg.items()):
113 | fields = key.split(".")
114 | if fields[0] == "env":
115 | assert len(fields) == 2
116 | os.environ[fields[1]] = str(value)
117 |
118 | # Create default TensorFlow session.
119 | create_session(cfg, force_as_default=True)
120 |
121 |
122 | def assert_tf_initialized():
123 | """Check that TensorFlow session has been initialized."""
124 | if tf.get_default_session() is None:
125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
126 |
127 |
128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
129 | """Create tf.Session based on config dict."""
130 | # Setup TensorFlow config proto.
131 | cfg = _sanitize_tf_config(config_dict)
132 | config_proto = tf.ConfigProto()
133 | for key, value in cfg.items():
134 | fields = key.split(".")
135 | if fields[0] not in ["rnd", "env"]:
136 | obj = config_proto
137 | for field in fields[:-1]:
138 | obj = getattr(obj, field)
139 | setattr(obj, fields[-1], value)
140 |
141 | # Create session.
142 | session = tf.Session(config=config_proto)
143 | if force_as_default:
144 | # pylint: disable=protected-access
145 | session._default_session = session.as_default()
146 | session._default_session.enforce_nesting = False
147 | session._default_session.__enter__() # pylint: disable=no-member
148 |
149 | return session
150 |
151 |
152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
153 | """Initialize all tf.Variables that have not already been initialized.
154 |
155 | Equivalent to the following, but more efficient and does not bloat the tf graph:
156 | tf.variables_initializer(tf.report_uninitialized_variables()).run()
157 | """
158 | assert_tf_initialized()
159 | if target_vars is None:
160 | target_vars = tf.global_variables()
161 |
162 | test_vars = []
163 | test_ops = []
164 |
165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
166 | for var in target_vars:
167 | assert is_tf_expression(var)
168 |
169 | try:
170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
171 | except KeyError:
172 | # Op does not exist => variable may be uninitialized.
173 | test_vars.append(var)
174 |
175 | with absolute_name_scope(var.name.split(":")[0]):
176 | test_ops.append(tf.is_variable_initialized(var))
177 |
178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
179 | run([var.initializer for var in init_vars])
180 |
181 |
182 | def set_vars(var_to_value_dict: dict) -> None:
183 | """Set the values of given tf.Variables.
184 |
185 | Equivalent to the following, but more efficient and does not bloat the tf graph:
186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
187 | """
188 | assert_tf_initialized()
189 | ops = []
190 | feed_dict = {}
191 |
192 | for var, value in var_to_value_dict.items():
193 | assert is_tf_expression(var)
194 |
195 | try:
196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
197 | except KeyError:
198 | with absolute_name_scope(var.name.split(":")[0]):
199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
201 |
202 | ops.append(setter)
203 | feed_dict[setter.op.inputs[1]] = value
204 |
205 | run(ops, feed_dict)
206 |
207 |
208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
209 | """Create tf.Variable with large initial value without bloating the tf graph."""
210 | assert_tf_initialized()
211 | assert isinstance(initial_value, np.ndarray)
212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype)
213 | var = tf.Variable(zeros, *args, **kwargs)
214 | set_vars({var: initial_value})
215 | return var
216 |
217 |
218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
220 | Can be used as an input transformation for Network.run().
221 | """
222 | images = tf.cast(images, tf.float32)
223 | if nhwc_to_nchw:
224 | images = tf.transpose(images, [0, 3, 1, 2])
225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255)
226 |
227 |
228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
230 | Can be used as an output transformation for Network.run().
231 | """
232 | images = tf.cast(images, tf.float32)
233 | if shrink > 1:
234 | ksize = [1, 1, shrink, shrink]
235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
236 | if nchw_to_nhwc:
237 | images = tf.transpose(images, [0, 2, 3, 1])
238 | scale = 255 / (drange[1] - drange[0])
239 | images = images * scale + (0.5 - drange[0] * scale)
240 | return tf.saturate_cast(images, tf.uint8)
241 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import autosummary
9 | from . import network
10 | from . import optimizer
11 | from . import tfutil
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/tflib/autosummary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for adding automatically tracked values to Tensorboard.
9 |
10 | Autosummary creates an identity op that internally keeps track of the input
11 | values and automatically shows up in TensorBoard. The reported value
12 | represents an average over input components. The average is accumulated
13 | constantly over time and flushed when save_summaries() is called.
14 |
15 | Notes:
16 | - The output tensor must be used as an input for something else in the
17 | graph. Otherwise, the autosummary op will not get executed, and the average
18 | value will not get accumulated.
19 | - It is perfectly fine to include autosummaries with the same name in
20 | several places throughout the graph, even if they are executed concurrently.
21 | - It is ok to also pass in a python scalar or numpy array. In this case, it
22 | is added to the average immediately.
23 | """
24 |
25 | from collections import OrderedDict
26 | import numpy as np
27 | import tensorflow as tf
28 | from tensorboard import summary as summary_lib
29 | from tensorboard.plugins.custom_scalar import layout_pb2
30 |
31 | from . import tfutil
32 | from .tfutil import TfExpression
33 | from .tfutil import TfExpressionEx
34 |
35 | _dtype = tf.float64
36 | _vars = OrderedDict() # name => [var, ...]
37 | _immediate = OrderedDict() # name => update_op, update_value
38 | _finalized = False
39 | _merge_op = None
40 |
41 |
42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
43 | """Internal helper for creating autosummary accumulators."""
44 | assert not _finalized
45 | name_id = name.replace("/", "_")
46 | v = tf.cast(value_expr, _dtype)
47 |
48 | if v.shape.is_fully_defined():
49 | size = np.prod(tfutil.shape_to_list(v.shape))
50 | size_expr = tf.constant(size, dtype=_dtype)
51 | else:
52 | size = None
53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
54 |
55 | if size == 1:
56 | if v.shape.ndims != 0:
57 | v = tf.reshape(v, [])
58 | v = [size_expr, v, tf.square(v)]
59 | else:
60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
62 |
63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
66 |
67 | if name in _vars:
68 | _vars[name].append(var)
69 | else:
70 | _vars[name] = [var]
71 | return update_op
72 |
73 |
74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
75 | """Create a new autosummary.
76 |
77 | Args:
78 | name: Name to use in TensorBoard
79 | value: TensorFlow expression or python value to track
80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
81 |
82 | Example use of the passthru mechanism:
83 |
84 | n = autosummary('l2loss', loss, passthru=n)
85 |
86 | This is a shorthand for the following code:
87 |
88 | with tf.control_dependencies([autosummary('l2loss', loss)]):
89 | n = tf.identity(n)
90 | """
91 | tfutil.assert_tf_initialized()
92 | name_id = name.replace("/", "_")
93 |
94 | if tfutil.is_tf_expression(value):
95 | with tf.name_scope("summary_" + name_id), tf.device(value.device):
96 | update_op = _create_var(name, value)
97 | with tf.control_dependencies([update_op]):
98 | return tf.identity(value if passthru is None else passthru)
99 |
100 | else: # python scalar or numpy array
101 | if name not in _immediate:
102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
103 | update_value = tf.placeholder(_dtype)
104 | update_op = _create_var(name, update_value)
105 | _immediate[name] = update_op, update_value
106 |
107 | update_op, update_value = _immediate[name]
108 | tfutil.run(update_op, {update_value: value})
109 | return value if passthru is None else passthru
110 |
111 |
112 | def finalize_autosummaries() -> None:
113 | """Create the necessary ops to include autosummaries in TensorBoard report.
114 | Note: This should be done only once per graph.
115 | """
116 | global _finalized
117 | tfutil.assert_tf_initialized()
118 |
119 | if _finalized:
120 | return None
121 |
122 | _finalized = True
123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
124 |
125 | # Create summary ops.
126 | with tf.device(None), tf.control_dependencies(None):
127 | for name, vars_list in _vars.items():
128 | name_id = name.replace("/", "_")
129 | with tfutil.absolute_name_scope("Autosummary/" + name_id):
130 | moments = tf.add_n(vars_list)
131 | moments /= moments[0]
132 | with tf.control_dependencies([moments]): # read before resetting
133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
135 | mean = moments[1]
136 | std = tf.sqrt(moments[2] - tf.square(moments[1]))
137 | tf.summary.scalar(name, mean)
138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
140 |
141 | # Group by category and chart name.
142 | cat_dict = OrderedDict()
143 | for series_name in sorted(_vars.keys()):
144 | p = series_name.split("/")
145 | cat = p[0] if len(p) >= 2 else ""
146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
147 | if cat not in cat_dict:
148 | cat_dict[cat] = OrderedDict()
149 | if chart not in cat_dict[cat]:
150 | cat_dict[cat][chart] = []
151 | cat_dict[cat][chart].append(series_name)
152 |
153 | # Setup custom_scalar layout.
154 | categories = []
155 | for cat_name, chart_dict in cat_dict.items():
156 | charts = []
157 | for chart_name, series_names in chart_dict.items():
158 | series = []
159 | for series_name in series_names:
160 | series.append(layout_pb2.MarginChartContent.Series(
161 | value=series_name,
162 | lower="xCustomScalars/" + series_name + "/margin_lo",
163 | upper="xCustomScalars/" + series_name + "/margin_hi"))
164 | margin = layout_pb2.MarginChartContent(series=series)
165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts))
167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
168 | return layout
169 |
170 | def save_summaries(file_writer, global_step=None):
171 | """Call FileWriter.add_summary() with all summaries in the default graph,
172 | automatically finalizing and merging them on the first call.
173 | """
174 | global _merge_op
175 | tfutil.assert_tf_initialized()
176 |
177 | if _merge_op is None:
178 | layout = finalize_autosummaries()
179 | if layout is not None:
180 | file_writer.add_summary(layout)
181 | with tf.device(None), tf.control_dependencies(None):
182 | _merge_op = tf.summary.merge_all()
183 |
184 | file_writer.add_summary(_merge_op.eval(), global_step)
185 |
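186 | # Hypothetical usage sketch (illustrative, not part of the original file):
187 | # track a scalar during training and periodically flush the accumulated
188 | # autosummaries to a TensorBoard log directory.
189 | #
190 | #     tfutil.init_tf()
191 | #     loss = autosummary('Loss/total', loss, passthru=loss)   # loss: an existing TF expression
192 | #     writer = tf.summary.FileWriter('logs')
193 | #     ...  # run training steps
194 | #     save_summaries(writer, global_step=step)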
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/tflib/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper wrapper for a Tensorflow optimizer."""
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from collections import OrderedDict
14 | from typing import List, Union
15 |
16 | from . import autosummary
17 | from . import tfutil
18 | from .. import util
19 |
20 | from .tfutil import TfExpression, TfExpressionEx
21 |
22 | try:
23 | # TensorFlow 1.13
24 | from tensorflow.python.ops import nccl_ops
25 | except ImportError:
26 | # Older TensorFlow versions
27 | import tensorflow.contrib.nccl as nccl_ops
28 |
29 | class Optimizer:
30 | """A Wrapper for tf.train.Optimizer.
31 |
32 | Automatically takes care of:
33 | - Gradient averaging for multi-GPU training.
34 | - Dynamic loss scaling and typecasts for FP16 training.
35 | - Ignoring corrupted gradients that contain NaNs/Infs.
36 | - Reporting statistics.
37 | - Well-chosen default settings.
38 | """
39 |
40 | def __init__(self,
41 | name: str = "Train",
42 | tf_optimizer: str = "tf.train.AdamOptimizer",
43 | learning_rate: TfExpressionEx = 0.001,
44 | use_loss_scaling: bool = False,
45 | loss_scaling_init: float = 64.0,
46 | loss_scaling_inc: float = 0.0005,
47 | loss_scaling_dec: float = 1.0,
48 | **kwargs):
49 |
50 | # Init fields.
51 | self.name = name
52 | self.learning_rate = tf.convert_to_tensor(learning_rate)
53 | self.id = self.name.replace("/", ".")
54 | self.scope = tf.get_default_graph().unique_name(self.id)
55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer)
56 | self.optimizer_kwargs = dict(kwargs)
57 | self.use_loss_scaling = use_loss_scaling
58 | self.loss_scaling_init = loss_scaling_init
59 | self.loss_scaling_inc = loss_scaling_inc
60 | self.loss_scaling_dec = loss_scaling_dec
61 | self._grad_shapes = None # [shape, ...]
62 | self._dev_opt = OrderedDict() # device => optimizer
63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...]
64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor)
65 | self._updates_applied = False
66 |
67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
68 | """Register the gradients of the given loss function with respect to the given variables.
69 | Intended to be called once per GPU."""
70 | assert not self._updates_applied
71 |
72 | # Validate arguments.
73 | if isinstance(trainable_vars, dict):
74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
75 |
76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
78 |
79 | if self._grad_shapes is None:
80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
81 |
82 | assert len(trainable_vars) == len(self._grad_shapes)
83 | assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))
84 |
85 | dev = loss.device
86 |
87 | assert all(var.device == dev for var in trainable_vars)
88 |
89 | # Register device and compute gradients.
90 | with tf.name_scope(self.id + "_grad"), tf.device(dev):
91 | if dev not in self._dev_opt:
92 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
93 | assert callable(self.optimizer_class)
94 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
95 | self._dev_grads[dev] = []
96 |
97 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
98 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage
99 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros
100 | self._dev_grads[dev].append(grads)
101 |
102 | def apply_updates(self) -> tf.Operation:
103 | """Construct training op to update the registered variables based on their gradients."""
104 | tfutil.assert_tf_initialized()
105 | assert not self._updates_applied
106 | self._updates_applied = True
107 | devices = list(self._dev_grads.keys())
108 | total_grads = sum(len(grads) for grads in self._dev_grads.values())
109 | assert len(devices) >= 1 and total_grads >= 1
110 | ops = []
111 |
112 | with tfutil.absolute_name_scope(self.scope):
113 | # Cast gradients to FP32 and calculate partial sum within each device.
114 | dev_grads = OrderedDict() # device => [(grad, var), ...]
115 |
116 | for dev_idx, dev in enumerate(devices):
117 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
118 | sums = []
119 |
120 | for gv in zip(*self._dev_grads[dev]):
121 | assert all(v is gv[0][1] for g, v in gv)
122 | g = [tf.cast(g, tf.float32) for g, v in gv]
123 | g = g[0] if len(g) == 1 else tf.add_n(g)
124 | sums.append((g, gv[0][1]))
125 |
126 | dev_grads[dev] = sums
127 |
128 | # Sum gradients across devices.
129 | if len(devices) > 1:
130 | with tf.name_scope("SumAcrossGPUs"), tf.device(None):
131 | for var_idx, grad_shape in enumerate(self._grad_shapes):
132 | g = [dev_grads[dev][var_idx][0] for dev in devices]
133 |
134 | if np.prod(grad_shape): # nccl does not support zero-sized tensors
135 | g = nccl_ops.all_sum(g)
136 |
137 | for dev, gg in zip(devices, g):
138 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])
139 |
140 | # Apply updates separately on each device.
141 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
142 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
143 | # Scale gradients as needed.
144 | if self.use_loss_scaling or total_grads > 1:
145 | with tf.name_scope("Scale"):
146 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
147 | coef = self.undo_loss_scaling(coef)
148 | grads = [(g * coef, v) for g, v in grads]
149 |
150 | # Check for overflows.
151 | with tf.name_scope("CheckOverflow"):
152 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))
153 |
154 | # Update weights and adjust loss scaling.
155 | with tf.name_scope("UpdateWeights"):
156 | # pylint: disable=cell-var-from-loop
157 | opt = self._dev_opt[dev]
158 | ls_var = self.get_loss_scaling_var(dev)
159 |
160 | if not self.use_loss_scaling:
161 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
162 | else:
163 | ops.append(tf.cond(grad_ok,
164 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
165 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))
166 |
167 | # Report statistics on the last device.
168 | if dev == devices[-1]:
169 | with tf.name_scope("Statistics"):
170 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
171 | ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
172 |
173 | if self.use_loss_scaling:
174 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))
175 |
176 | # Initialize variables and group everything into a single op.
177 | self.reset_optimizer_state()
178 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))
179 |
180 | return tf.group(*ops, name="TrainingOp")
181 |
182 | def reset_optimizer_state(self) -> None:
183 | """Reset internal state of the underlying optimizer."""
184 | tfutil.assert_tf_initialized()
185 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])
186 |
187 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
188 | """Get or create variable representing log2 of the current dynamic loss scaling factor."""
189 | if not self.use_loss_scaling:
190 | return None
191 |
192 | if device not in self._dev_ls_var:
193 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
194 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
195 |
196 | return self._dev_ls_var[device]
197 |
198 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
199 | """Apply dynamic loss scaling for the given expression."""
200 | assert tfutil.is_tf_expression(value)
201 |
202 | if not self.use_loss_scaling:
203 | return value
204 |
205 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
206 |
207 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
208 | """Undo the effect of dynamic loss scaling for the given expression."""
209 | assert tfutil.is_tf_expression(value)
210 |
211 | if not self.use_loss_scaling:
212 | return value
213 |
214 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
215 |
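216 | # Hypothetical usage sketch (illustrative, not part of the original file):
217 | # a single Optimizer collects per-GPU gradients and then builds one training
218 | # op that averages them and applies the update.
219 | #
220 | #     opt = Optimizer(name='TrainG', learning_rate=0.001)
221 | #     for gpu in range(num_gpus):
222 | #         with tf.device('/gpu:%d' % gpu):
223 | #             loss = ...  # per-GPU loss for this tower
224 | #             opt.register_gradients(loss, trainable_vars)
225 | #     train_op = opt.apply_updates()
226 | #     tfutil.run(train_op)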
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous helper utils for Tensorflow."""
9 |
10 | import os
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from typing import Any, Iterable, List, Union
15 |
16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
17 | """A type that represents a valid Tensorflow expression."""
18 |
19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
20 | """A type that can be converted to a valid Tensorflow expression."""
21 |
22 |
23 | def run(*args, **kwargs) -> Any:
24 | """Run the specified ops in the default session."""
25 | assert_tf_initialized()
26 | return tf.get_default_session().run(*args, **kwargs)
27 |
28 |
29 | def is_tf_expression(x: Any) -> bool:
30 | """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
32 |
33 |
34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
35 | """Convert a Tensorflow shape to a list of ints."""
36 | return [dim.value for dim in shape]
37 |
38 |
39 | def flatten(x: TfExpressionEx) -> TfExpression:
40 | """Shortcut function for flattening a tensor."""
41 | with tf.name_scope("Flatten"):
42 | return tf.reshape(x, [-1])
43 |
44 |
45 | def log2(x: TfExpressionEx) -> TfExpression:
46 | """Logarithm in base 2."""
47 | with tf.name_scope("Log2"):
48 | return tf.log(x) * np.float32(1.0 / np.log(2.0))
49 |
50 |
51 | def exp2(x: TfExpressionEx) -> TfExpression:
52 | """Exponent in base 2."""
53 | with tf.name_scope("Exp2"):
54 | return tf.exp(x * np.float32(np.log(2.0)))
55 |
56 |
57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
58 | """Linear interpolation."""
59 | with tf.name_scope("Lerp"):
60 | return a + (b - a) * t
61 |
62 |
63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
64 | """Linear interpolation with clip."""
65 | with tf.name_scope("LerpClip"):
66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
67 |
68 |
69 | def absolute_name_scope(scope: str) -> tf.name_scope:
70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
71 | return tf.name_scope(scope + "/")
72 |
73 |
74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
77 |
78 |
79 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
80 | # Defaults.
81 | cfg = dict()
82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
87 |
88 | # User overrides.
89 | if config_dict is not None:
90 | cfg.update(config_dict)
91 | return cfg
92 |
93 |
94 | def init_tf(config_dict: dict = None) -> None:
95 | """Initialize TensorFlow session using good default settings."""
96 | # Skip if already initialized.
97 | if tf.get_default_session() is not None:
98 | return
99 |
100 | # Setup config dict and random seeds.
101 | cfg = _sanitize_tf_config(config_dict)
102 | np_random_seed = cfg["rnd.np_random_seed"]
103 | if np_random_seed is not None:
104 | np.random.seed(np_random_seed)
105 | tf_random_seed = cfg["rnd.tf_random_seed"]
106 | if tf_random_seed == "auto":
107 | tf_random_seed = np.random.randint(1 << 31)
108 | if tf_random_seed is not None:
109 | tf.set_random_seed(tf_random_seed)
110 |
111 | # Setup environment variables.
112 | for key, value in list(cfg.items()):
113 | fields = key.split(".")
114 | if fields[0] == "env":
115 | assert len(fields) == 2
116 | os.environ[fields[1]] = str(value)
117 |
118 | # Create default TensorFlow session.
119 | create_session(cfg, force_as_default=True)
120 |
121 |
122 | def assert_tf_initialized():
123 | """Check that TensorFlow session has been initialized."""
124 | if tf.get_default_session() is None:
125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
126 |
127 |
128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
129 | """Create tf.Session based on config dict."""
130 | # Setup TensorFlow config proto.
131 | cfg = _sanitize_tf_config(config_dict)
132 | config_proto = tf.ConfigProto()
133 | for key, value in cfg.items():
134 | fields = key.split(".")
135 | if fields[0] not in ["rnd", "env"]:
136 | obj = config_proto
137 | for field in fields[:-1]:
138 | obj = getattr(obj, field)
139 | setattr(obj, fields[-1], value)
140 |
141 | # Create session.
142 | session = tf.Session(config=config_proto)
143 | if force_as_default:
144 | # pylint: disable=protected-access
145 | session._default_session = session.as_default()
146 | session._default_session.enforce_nesting = False
147 | session._default_session.__enter__() # pylint: disable=no-member
148 |
149 | return session
150 |
151 |
152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
153 | """Initialize all tf.Variables that have not already been initialized.
154 |
155 | Equivalent to the following, but more efficient and does not bloat the tf graph:
156 | tf.variables_initializer(tf.report_uninitialized_variables()).run()
157 | """
158 | assert_tf_initialized()
159 | if target_vars is None:
160 | target_vars = tf.global_variables()
161 |
162 | test_vars = []
163 | test_ops = []
164 |
165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
166 | for var in target_vars:
167 | assert is_tf_expression(var)
168 |
169 | try:
170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
171 | except KeyError:
172 | # Op does not exist => variable may be uninitialized.
173 | test_vars.append(var)
174 |
175 | with absolute_name_scope(var.name.split(":")[0]):
176 | test_ops.append(tf.is_variable_initialized(var))
177 |
178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
179 | run([var.initializer for var in init_vars])
180 |
181 |
182 | def set_vars(var_to_value_dict: dict) -> None:
183 | """Set the values of given tf.Variables.
184 |
185 | Equivalent to the following, but more efficient and does not bloat the tf graph:
186 | tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
187 | """
188 | assert_tf_initialized()
189 | ops = []
190 | feed_dict = {}
191 |
192 | for var, value in var_to_value_dict.items():
193 | assert is_tf_expression(var)
194 |
195 | try:
196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
197 | except KeyError:
198 | with absolute_name_scope(var.name.split(":")[0]):
199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
201 |
202 | ops.append(setter)
203 | feed_dict[setter.op.inputs[1]] = value
204 |
205 | run(ops, feed_dict)
206 |
207 |
208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
209 | """Create tf.Variable with large initial value without bloating the tf graph."""
210 | assert_tf_initialized()
211 | assert isinstance(initial_value, np.ndarray)
212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype)
213 | var = tf.Variable(zeros, *args, **kwargs)
214 | set_vars({var: initial_value})
215 | return var
216 |
217 |
218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
220 | Can be used as an input transformation for Network.run().
221 | """
222 | images = tf.cast(images, tf.float32)
223 | if nhwc_to_nchw:
224 | images = tf.transpose(images, [0, 3, 1, 2])
225 | return (images - drange[0]) * ((drange[1] - drange[0]) / 255)
226 |
227 |
228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
230 | Can be used as an output transformation for Network.run().
231 | """
232 | images = tf.cast(images, tf.float32)
233 | if shrink > 1:
234 | ksize = [1, 1, shrink, shrink]
235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
236 | if nchw_to_nhwc:
237 | images = tf.transpose(images, [0, 2, 3, 1])
238 | scale = 255 / (drange[1] - drange[0])
239 | images = images * scale + (0.5 - drange[0] * scale)
240 | return tf.saturate_cast(images, tf.uint8)
241 |
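242 | # Hypothetical usage sketch (illustrative, not part of the original file):
243 | # initialize a default session with sensible options, then convert generator
244 | # output in [-1, 1] NCHW format to uint8 NHWC images for saving.
245 | #
246 | #     init_tf({'gpu_options.allow_growth': True})
247 | #     images_uint8 = convert_images_to_uint8(fake_images, drange=[-1, 1], nchw_to_nhwc=True)
248 | #     result = run(images_uint8)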
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/loss_criterions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dave-epstein/blobgan/c527f1c27447dffe3cf4cf3901571a83ce59f1fe/scripts/style-gan-pytorch/loss_criterions/__init__.py
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/loss_criterions/base_loss_criterions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # source: https://github.com/facebookresearch/pytorch_GAN_zoo/blob/master/models/loss_criterions/base_loss_criterions.py
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | class BaseLossWrapper:
8 | r"""
9 |     Loss criterion class. Subclasses must define the following members:
10 |     sizeDecisionLayer : size of the decision layer of the discriminator
11 |
12 | getCriterion : how the loss is actually computed
13 |
14 | !! The activation function of the discriminator is computed within the
15 | loss !!
16 | """
17 |
18 | def __init__(self, device):
19 | self.device = device
20 |
21 | def getCriterion(self, input, status):
22 | r"""
23 |         Given an input tensor and its target status (detected as real or
24 |         detected as fake), build the associated loss.
25 |
26 | Args:
27 |
28 |         - input (Tensor): decision tensor built by the model's discriminator
29 | - status (bool): if True -> this tensor should have been detected
30 | as a real input
31 | else -> it shouldn't have
32 | """
33 | pass
34 |
35 |
36 | class MSE(BaseLossWrapper):
37 | r"""
38 | Mean Square error loss.
39 | """
40 |
41 | def __init__(self, device):
42 | self.generationActivation = F.tanh
43 | self.sizeDecisionLayer = 1
44 |
45 | BaseLossWrapper.__init__(self, device)
46 |
47 | def getCriterion(self, input, status):
48 | size = input.size()[0]
49 | value = float(status)
50 | reference = torch.tensor([value]).expand(size, 1).to(self.device)
51 | return F.mse_loss(F.sigmoid(input[:, :self.sizeDecisionLayer]),
52 | reference)
53 |
54 |
55 | class WGANGP(BaseLossWrapper):
56 | r"""
57 | Paper WGANGP loss : linear activation for the generator.
58 | https://arxiv.org/pdf/1704.00028.pdf
59 | """
60 |
61 | def __init__(self, device):
62 |
63 | self.generationActivation = None
64 | self.sizeDecisionLayer = 1
65 |
66 | BaseLossWrapper.__init__(self, device)
67 |
68 | def getCriterion(self, input, status):
69 | if status:
70 | return -input[:, 0].sum()
71 | return input[:, 0].sum()
72 |
73 |
74 | class Logistic(BaseLossWrapper):
75 | r"""
76 | "Which training method of GANs actually converge"
77 | https://arxiv.org/pdf/1801.04406.pdf
78 | """
79 |
80 | def __init__(self, device):
81 |
82 | self.generationActivation = None
83 | self.sizeDecisionLayer = 1
84 | BaseLossWrapper.__init__(self, device)
85 |
86 | def getCriterion(self, input, status):
87 | if status:
88 | return F.softplus(-input[:, 0]).mean()
89 | return F.softplus(input[:, 0]).mean()
90 |
91 |
92 | class DCGAN(BaseLossWrapper):
93 | r"""
94 | Cross entropy loss.
95 | """
96 |
97 | def __init__(self, device):
98 |
99 | self.generationActivation = F.tanh
100 | self.sizeDecisionLayer = 1
101 |
102 | BaseLossWrapper.__init__(self, device)
103 |
104 | def getCriterion(self, input, status):
105 | size = input.size()[0]
106 | value = int(status)
107 |         reference = torch.tensor(
108 |             [value], dtype=torch.float).expand(size, 1).to(self.device)
109 | return F.binary_cross_entropy(torch.sigmoid(input[:, :self.sizeDecisionLayer]), reference)
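110 | 
111 | # Hypothetical usage sketch (illustrative, not part of the original file):
112 | # a criterion wraps the discriminator's raw decision output; `status` indicates
113 | # whether the batch should have been classified as real.
114 | #
115 | #     criterion = Logistic(device=torch.device('cuda'))
116 | #     d_real = discriminator(real_images)             # assumed discriminator network
117 | #     d_fake = discriminator(generator(z).detach())   # assumed generator network
118 | #     loss_d = criterion.getCriterion(d_real, True) + criterion.getCriterion(d_fake, False)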
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/loss_criterions/gradient_losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import torch
3 |
4 |
5 | def WGANGPGradientPenalty(input, fake, discriminator, weight, backward=True):
6 | r"""
7 | Gradient penalty as described in
8 | "Improved Training of Wasserstein GANs"
9 | https://arxiv.org/pdf/1704.00028.pdf
10 |
11 | Args:
12 |
13 | - input (Tensor): batch of real data
14 | - fake (Tensor): batch of generated data. Must have the same size
15 | as the input
16 |     - discriminator (nn.Module): discriminator network
17 | - weight (float): weight to apply to the penalty term
18 | - backward (bool): loss backpropagation
19 | """
20 |
21 | batchSize = input.size(0)
22 | alpha = torch.rand(batchSize, 1)
23 | alpha = alpha.expand(batchSize, int(input.nelement() /
24 | batchSize)).contiguous().view(
25 | input.size())
26 | alpha = alpha.to(input.device)
27 | interpolates = alpha * input + ((1 - alpha) * fake)
28 |
29 | interpolates = torch.autograd.Variable(
30 | interpolates, requires_grad=True)
31 |
32 | decisionInterpolate = discriminator(interpolates, False)
33 | decisionInterpolate = decisionInterpolate[:, 0].sum()
34 |
35 | gradients = torch.autograd.grad(outputs=decisionInterpolate,
36 | inputs=interpolates,
37 | create_graph=True, retain_graph=True)
38 |
39 | gradients = gradients[0].view(batchSize, -1)
40 | gradients = (gradients * gradients).sum(dim=1).sqrt()
41 | gradient_penalty = (((gradients - 1.0)**2)).sum() * weight
42 |
43 | if backward:
44 | gradient_penalty.backward(retain_graph=True)
45 |
46 | return gradient_penalty.item()
47 |
48 |
49 | def logisticGradientPenalty(input, discrimator, res, alpha, weight, backward=True):
50 | r"""
51 |     Gradient penalty described in "Which Training Methods for GANs Do
52 |     Actually Converge?"
53 | https://arxiv.org/pdf/1801.04406.pdf
54 |
55 | Args:
56 |
57 | - input (Tensor): batch of real data
58 | - discrimator (nn.Module): discriminator network
59 | - weight (float): weight to apply to the penalty term
60 | - backward (bool): loss backpropagation
61 | """
62 |
63 | locInput = torch.autograd.Variable(
64 | input, requires_grad=True)
65 | gradients = torch.autograd.grad(outputs=discrimator(locInput, res, alpha)[:, 0].sum(),
66 | inputs=locInput,
67 | create_graph=True, retain_graph=True)[0]
68 |
69 | gradients = gradients.view(gradients.size(0), -1)
70 | gradients = (gradients * gradients).sum(dim=1).mean()
71 |
72 | gradient_penalty = gradients * weight
73 | if backward:
74 | gradient_penalty.backward(retain_graph=True)
75 |
76 | return gradient_penalty.item()
77 |
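78 | # Hypothetical usage sketch (illustrative, not part of the original file):
79 | # both penalties call backward() internally when backward=True and return a
80 | # plain float suitable for logging.
81 | #
82 | #     gp = WGANGPGradientPenalty(real_batch, fake_batch.detach(), discriminator,
83 | #                                weight=10.0, backward=True)
84 | #     r1 = logisticGradientPenalty(real_batch, discriminator, res, alpha,
85 | #                                  weight=5.0, backward=True)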
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dave-epstein/blobgan/c527f1c27447dffe3cf4cf3901571a83ce59f1fe/scripts/style-gan-pytorch/networks/__init__.py
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/networks/building_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from collections import OrderedDict
4 | from networks.custom_layers import *
5 |
6 |
7 | class LayerEpilogue(nn.Module):
8 | """
9 |     Operations applied at the end of each layer:
10 |     1. mix in scaled noise
11 |     2. mix in style with AdaIN
12 | """
13 | def __init__(self,
14 | num_channels,
15 | dlatent_size, # Disentangled latent (W) dimensionality,
16 | use_wscale, # Enable equalized learning rate?
17 | use_pixel_norm, # Enable pixel-wise feature vector normalization?
18 | use_instance_norm,
19 | use_noise,
20 | use_styles,
21 | nonlinearity,
22 | ):
23 | super(LayerEpilogue, self).__init__()
24 |
25 | act = {
26 | 'relu': torch.relu,
27 | 'lrelu': nn.LeakyReLU(negative_slope=0.2)
28 | }[nonlinearity]
29 |
30 | layers = []
31 | if use_noise:
32 | layers.append(('noise', NoiseMixin(num_channels)))
33 | layers.append(('act', act))
34 |
35 | # to follow the tf implementation
36 | if use_pixel_norm:
37 | layers.append(('pixel_norm', NormalizationLayer()))
38 | if use_instance_norm:
39 | layers.append(('instance_norm', nn.InstanceNorm2d(num_channels)))
40 | # now we need to mixin styles
41 | self.pre_style_op = nn.Sequential(OrderedDict(layers))
42 |         self.style_mod = None
43 |         if use_styles:
44 | self.style_mod = StyleMixin(dlatent_size,
45 | num_channels,
46 | use_wscale=use_wscale)
47 | def forward(self, x, dlatent):
48 | # dlatent is w
49 | x = self.pre_style_op(x)
50 |         if self.style_mod is not None:
51 | x = self.style_mod(x, dlatent)
52 | return x
53 |
54 |
55 | class EarlySynthesisBlock(nn.Module):
56 | """
57 | The first block for 4x4 resolution
58 | """
59 | def __init__(self,
60 | in_channels,
61 | dlatent_size,
62 | const_input_layer,
63 | use_wscale,
64 | use_noise,
65 | use_pixel_norm,
66 | use_instance_norm,
67 | use_styles,
68 | nonlinearity
69 | ):
70 | super(EarlySynthesisBlock, self).__init__()
71 | self.const_input_layer = const_input_layer
72 | self.in_channels = in_channels
73 |
74 | if const_input_layer:
75 | self.const = nn.Parameter(torch.ones(1, in_channels, 4, 4))
76 | self.bias = nn.Parameter(torch.ones(in_channels))
77 | else:
78 | self.dense = EqualizedLinear(dlatent_size, in_channels * 16, use_wscale=use_wscale)
79 |
80 | self.epi0 = LayerEpilogue(num_channels=in_channels,
81 | dlatent_size=dlatent_size,
82 | use_wscale=use_wscale,
83 | use_noise=use_noise,
84 | use_pixel_norm=use_pixel_norm,
85 | use_instance_norm=use_instance_norm,
86 | use_styles=use_styles,
87 | nonlinearity=nonlinearity
88 | )
89 | # kernel size must be 3 or other odd numbers
90 | # so that we have 'same' padding
91 | self.conv = EqualizedConv2d(in_channels=in_channels,
92 | out_channels=in_channels,
93 | kernel_size=3,
94 | padding=3//2)
95 |
96 | self.epi1 = LayerEpilogue(num_channels=in_channels,
97 | dlatent_size=dlatent_size,
98 | use_wscale=use_wscale,
99 | use_noise=use_noise,
100 | use_pixel_norm=use_pixel_norm,
101 | use_instance_norm=use_instance_norm,
102 | use_styles=use_styles,
103 | nonlinearity=nonlinearity
104 | )
105 |
106 | def forward(self, dlatents):
107 |         # note: dlatents has already been broadcast to (batch, num_layers, dlatent_size)
108 | dlatents_0 = dlatents[:, 0]
109 | dlatents_1 = dlatents[:, 1]
110 | batch_size = dlatents.size(0)
111 | if self.const_input_layer:
112 | x = self.const.expand(batch_size, -1, -1, -1)
113 | x = x + self.bias.view(1, -1, 1, 1)
114 | else:
115 | x = self.dense(dlatents_0).view(batch_size, self.in_channels, 4, 4)
116 |
117 | x = self.epi0(x, dlatents_0)
118 | x = self.conv(x)
119 | x = self.epi1(x, dlatents_1)
120 | return x
121 |
122 |
123 | class LaterSynthesisBlock(nn.Module):
124 | """
125 |     The subsequent blocks for resolutions 8x8 and above.
126 | """
127 |
128 | def __init__(self,
129 | in_channels,
130 | out_channels,
131 | dlatent_size,
132 | use_wscale,
133 | use_noise,
134 | use_pixel_norm,
135 | use_instance_norm,
136 | use_styles,
137 | nonlinearity,
138 | blur_filter,
139 | res,
140 | ):
141 | super(LaterSynthesisBlock, self).__init__()
142 |
143 | # res = log2(H), H is 4, 8, 16, 32 ... 1024
144 |
145 | assert isinstance(res, int) and (2 <= res <= 10)
146 |
147 | self.res = res
148 |
149 | if blur_filter:
150 | self.blur = Blur2d(blur_filter)
151 | #blur = Blur2d(blur_filter)
152 | else:
153 | self.blur = None
154 |
155 | # name 'conv0_up' is used in tf implementation
156 | self.conv0_up = Upscale2dConv2d(res=res,
157 | in_channels=in_channels,
158 | out_channels=out_channels,
159 | kernel_size=3,
160 | use_wscale=use_wscale)
161 | # self.conv0_up = Upscale2dConv2d2(
162 | # input_channels=in_channels,
163 | # output_channels=out_channels,
164 | # kernel_size=3,
165 | # gain=np.sqrt(2),
166 | # use_wscale=use_wscale,
167 | # intermediate=blur,
168 | # upscale=True
169 | # )
170 |
171 | self.epi0 = LayerEpilogue(num_channels=out_channels,
172 | dlatent_size=dlatent_size,
173 | use_wscale=use_wscale,
174 | use_pixel_norm=use_pixel_norm,
175 | use_noise=use_noise,
176 | use_instance_norm=use_instance_norm,
177 | use_styles=use_styles,
178 | nonlinearity=nonlinearity)
179 |
180 | # name 'conv1' is used in tf implementation
181 | # kernel size must be 3 or other odd numbers
182 | # so that we have 'same' padding
183 |         # no upscaling
184 | self.conv1 = EqualizedConv2d(in_channels=out_channels,
185 | out_channels=out_channels,
186 | kernel_size=3,
187 | padding=3//2)
188 |
189 | self.epi1 = LayerEpilogue(num_channels=out_channels,
190 | dlatent_size=dlatent_size,
191 | use_wscale=use_wscale,
192 | use_pixel_norm=use_pixel_norm,
193 | use_noise=use_noise,
194 | use_instance_norm=use_instance_norm,
195 | use_styles=use_styles,
196 | nonlinearity=nonlinearity)
197 |
198 |
199 | def forward(self, x, dlatents):
200 |
201 | x = self.conv0_up(x)
202 | if self.blur is not None:
203 | x = self.blur(x)
204 | x = self.epi0(x, dlatents[:, self.res * 2 - 4])
205 | x = self.conv1(x)
206 | x = self.epi1(x, dlatents[:, self.res * 2 - 3])
207 | return x
208 |
209 |
210 | class EarlyDiscriminatorBlock(nn.Sequential):
211 | def __init__(self,
212 | res,
213 | in_channels,
214 | out_channels,
215 | use_wscale,
216 | blur_filter,
217 | fused_scale,
218 | nonlinearity):
219 | act = {
220 | 'relu': torch.relu,
221 | 'lrelu': nn.LeakyReLU(negative_slope=0.2)
222 | }[nonlinearity]
223 |
224 | layers = []
225 |
226 | layers.append(('conv0', EqualizedConv2d(in_channels=in_channels,
227 | out_channels=in_channels,
228 | kernel_size=3,
229 | padding=3//2,
230 | use_wscale=use_wscale)))
231 |         # note that we don't have a layer epilogue in the discriminator, so we need to add the activation layer manually
232 | layers.append(('act0', act))
233 |
234 | layers.append(('blur', Blur2d(blur_filter)))
235 |
236 | layers.append(('conv1_down', Downscale2dConv2d(res=res,
237 | in_channels=in_channels,
238 | out_channels=out_channels,
239 | kernel_size=3,
240 | fused_scale=fused_scale,
241 | use_wscale=use_wscale)))
242 | layers.append(('act1', act))
243 |
244 | super().__init__(OrderedDict(layers))
245 |
246 |
247 | class LaterDiscriminatorBlock(nn.Sequential):
248 |
249 | def __init__(self,
250 | in_channels,
251 | out_channels,
252 | use_wscale,
253 | nonlinearity,
254 | mbstd_group_size,
255 | mbstd_num_features,
256 | res,
257 | ):
258 | act = {
259 | 'relu': torch.relu,
260 | 'lrelu': nn.LeakyReLU(negative_slope=0.2)
261 | }[nonlinearity]
262 |
263 | resolution = 2 ** res
264 | layers = []
265 | layers.append(('minibatchstddev', MiniBatchStdDev(mbstd_group_size, mbstd_num_features)))
266 | layers.append(('conv', EqualizedConv2d(in_channels=in_channels + mbstd_num_features,
267 | out_channels=in_channels,
268 | kernel_size=3,
269 | padding=3//2,
270 | use_wscale=use_wscale)))
271 | layers.append(('act0', act))
272 | layers.append(('flatten', Flatten()))
273 | layers.append(('dense0', EqualizedLinear(in_channels=in_channels * (resolution**2),
274 | out_channels=in_channels,
275 | use_wscale=use_wscale)))
276 | layers.append(('act1', act))
277 | # no activation for the last fc
278 | layers.append(('dense1', EqualizedLinear(in_channels=in_channels,
279 | out_channels=out_channels)))
280 |
281 | super().__init__(OrderedDict(layers))
282 |
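283 | # Hypothetical usage sketch (illustrative, not part of the original file):
284 | # a LayerEpilogue applies noise, activation, normalization and AdaIN-style
285 | # modulation to a feature map, given the broadcast latent w for that layer.
286 | #
287 | #     epi = LayerEpilogue(num_channels=512, dlatent_size=512, use_wscale=True,
288 | #                         use_pixel_norm=False, use_instance_norm=True,
289 | #                         use_noise=True, use_styles=True, nonlinearity='lrelu')
290 | #     x = epi(feature_map, w)   # feature_map: (N, 512, 4, 4), w: (N, 512)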
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/networks/style_gan_net.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from collections import OrderedDict
6 | from networks.custom_layers import EqualizedLinear, EqualizedConv2d, \
7 | NormalizationLayer, _upscale2d
8 | from networks.building_blocks import EarlySynthesisBlock, LaterSynthesisBlock, \
9 | EarlyDiscriminatorBlock, LaterDiscriminatorBlock
10 |
11 | class MappingNet(nn.Sequential):
12 | """
13 | A mapping network f implemented using an 8-layer MLP
14 | """
15 | def __init__(self,
16 | resolution = 1024,
17 | num_layers = 8,
18 | dlatent_size = 512,
19 | normalize_latents = True,
20 | nonlinearity = 'lrelu',
21 | maping_lrmul = 0.01, # We thus reduce the learning rate by two orders of magnitude for the mapping network
22 | **kwargs): # other parameters are ignored
23 |
24 | resolution_log2: int = int(np.log2(resolution))
25 |
26 | assert resolution == 2**resolution_log2 and 4 <= resolution <= 1024
27 |
28 | act = {
29 | 'relu': torch.relu,
30 | 'lrelu': nn.LeakyReLU(negative_slope=0.2)
31 | }[nonlinearity]
32 |
33 | self.dlatent_broadcast = resolution_log2 * 2 - 2
34 | layers = []
35 | if normalize_latents:
36 | layers.append(('pixel_norm', NormalizationLayer()))
37 | for i in range(num_layers):
38 | layers.append(('dense{}'.format(i), EqualizedLinear(dlatent_size,
39 | dlatent_size,
40 | use_wscale=True,
41 | lrmul=maping_lrmul)))
42 | layers.append(('dense{}_act'.format(i), act))
43 |
44 | super().__init__(OrderedDict(layers))
45 |
46 | def forward(self, x):
47 | # N x 512
48 | w = super().forward(x)
49 | if self.dlatent_broadcast is not None:
50 | # broadcast
51 | # tf.tile in the official tf implementation:
52 | # w = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])
53 | w = w.unsqueeze(1).expand(-1, self.dlatent_broadcast, -1)
54 | return w
55 |
56 |
57 | class SynthesisNet(nn.Module):
58 | """
59 | Synthesis network
60 | """
61 | def __init__(self,
62 | dlatent_size = 512,
63 | num_channels = 3,
64 | resolution = 1024,
65 | fmap_base = 8192,
66 | fmap_decay = 1.0,
67 | fmap_max = 512,
68 | use_styles = True,
69 | const_input_layer = True,
70 | use_noise = True,
71 | nonlinearity = 'lrelu',
72 | use_wscale = True,
73 | use_pixel_norm = False,
74 | use_instance_norm = True,
75 |                  blur_filter = [1, 2, 1], # low-pass filter to apply when resampling activations. None = no filtering
76 | **kwargs # other parameters are ignored
77 | ):
78 | super(SynthesisNet, self).__init__()
79 |
80 | # copied from tf implementation
81 |
82 | resolution_log2: int = int(np.log2(resolution))
83 |
84 | assert resolution == 2**resolution_log2 and 4 <= resolution <= 1024
85 |
86 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
87 |
88 | act = {
89 | 'relu': torch.relu,
90 | 'lrelu': nn.LeakyReLU(negative_slope=0.2)
91 | }[nonlinearity]
92 |
93 | num_layers = resolution_log2 * 2 - 2
94 |
95 | num_styles = num_layers if use_styles else 1
96 |
97 | blocks = []
98 | torgbs = []
99 |
100 | # 2....10 (inclusive) for 1024 resolution
101 | for res in range(2, resolution_log2 + 1):
102 | channels = nf(res - 1)
103 | block_name = '{s}x{s}'.format(s=2**res)
104 | torgb_name = 'torgb_lod{}'.format(resolution_log2 - res)
105 | if res == 2:
106 | # early block
107 | block = (block_name, EarlySynthesisBlock(channels,
108 | dlatent_size,
109 | const_input_layer,
110 | use_wscale,
111 | use_noise,
112 | use_pixel_norm,
113 | use_instance_norm,
114 | use_styles,
115 | nonlinearity))
116 | else:
117 | block = (block_name, LaterSynthesisBlock(last_channels,
118 | out_channels=channels,
119 | dlatent_size=dlatent_size,
120 | use_wscale=use_wscale,
121 | use_noise=use_noise,
122 | use_pixel_norm=use_pixel_norm,
123 | use_instance_norm=use_instance_norm,
124 | use_styles=use_styles,
125 | nonlinearity=nonlinearity,
126 | blur_filter=blur_filter,
127 | res=res,
128 | ))
129 |
130 | # torgb block
131 | torgb = (torgb_name, EqualizedConv2d(channels, num_channels, 1, use_wscale=use_wscale))
132 |
133 | blocks.append(block)
134 | torgbs.append(torgb)
135 | last_channels = channels
136 |
137 | # the last one has bias
138 | self.torgbs = nn.ModuleDict(OrderedDict(torgbs))
139 |
140 | #self.torgb = Upscale2dConv2d2(channels, num_channels, 1, gain=1, use_wscale=use_wscale, bias=True)
141 | self.blocks = nn.ModuleDict(OrderedDict(blocks))
142 |
143 |
144 | def forward(self, dlatents, res, alpha):
145 | assert 2 <= res <= 10
146 | # step 1...9
147 | step = res - 1
148 | block_list = list(self.blocks.values())[:step]
149 | torgb_list = list(self.torgbs.values())[:step]
150 |
151 |         # starting from 8x8 we have skip connections
152 |         skip_x = None
153 |         this_rgb = torgb_list[-1]
154 |         if step > 1:
155 |             skip_torgb = torgb_list[-2]
156 | 
157 |         for i, block in enumerate(block_list):
158 |             if i == 0:
159 |                 x = block(dlatents)
160 |             else:
161 |                 x = block(x, dlatents)
162 | 
163 |             # step - 1 is the last index, so step - 2 is the second last
164 |             if step > 1 and i == step - 2:
165 |                 # keep the previous resolution's RGB output, upscaled to match
166 |                 skip_x = _upscale2d(skip_torgb(x), 2)
167 | 
168 |         # finally, convert the current resolution to RGB
169 |         x = this_rgb(x)
170 | 
171 |         # fade in the newest block against the skip branch, when one exists
172 |         if skip_x is not None:
173 |             x = (1 - alpha) * skip_x + alpha * x
174 |         return x
175 |
176 |
177 | # a convenient wrapping class
178 | class Generator(nn.Sequential):
179 | def __init__(self, **kwargs):
180 | super().__init__(OrderedDict([
181 | ('g_mapping', MappingNet(**kwargs)),
182 | ('g_synthesis', SynthesisNet(**kwargs))
183 | ]))
184 |
185 | def forward(self, latents, res, alpha):
186 | dlatents = self.g_mapping(latents)
187 | x = self.g_synthesis(dlatents, res, alpha)
188 | return x
189 |
190 |
191 | class BasicDiscriminator(nn.Module):
192 |
193 | def __init__(self,
194 | num_channels = 3,
195 | resolution = 1024,
196 | fmap_base = 8192,
197 | fmap_decay = 1.0,
198 | fmap_max = 512,
199 | nonlinearity = 'lrelu',
200 | mbstd_group_size = 4,
201 | mbstd_num_features = 1,
202 | use_wscale = True,
203 | fused_scale = 'auto',
204 | blur_filter = [1, 2, 1],
205 | ):
206 | super(BasicDiscriminator, self).__init__()
207 |
208 | resolution_log2: int = int(np.log2(resolution))
209 |
210 | assert resolution == 2**resolution_log2 and 4 <= resolution <= 1024
211 |
212 | def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
213 |
214 | act = {
215 | 'relu': torch.relu,
216 | 'lrelu': nn.LeakyReLU(negative_slope=0.2)
217 | }[nonlinearity]
218 | # this is fixed. We need to grow it...
219 | blocks = []
220 | fromrgbs = []
221 | for res in range(resolution_log2, 1, -1):
222 | block_name = '{s}x{s}'.format(s=2 ** res)
223 | fromrgb_name = 'fromrgb_lod{}'.format(resolution_log2 - res)
224 | if res != 2:
225 | blocks.append((block_name, EarlyDiscriminatorBlock(res=res,
226 | in_channels=nf(res-1),
227 | out_channels=nf(res-2),
228 | use_wscale=use_wscale,
229 | blur_filter=blur_filter,
230 | fused_scale=fused_scale,
231 | nonlinearity=nonlinearity)))
232 | else:
233 | blocks.append((block_name, LaterDiscriminatorBlock(in_channels=nf(res),
234 | out_channels=1,
235 | mbstd_group_size=mbstd_group_size,
236 | mbstd_num_features=mbstd_num_features,
237 | use_wscale=use_wscale,
238 | nonlinearity=nonlinearity,
239 | res=2,
240 | )))
241 |
242 | fromrgbs.append((fromrgb_name, EqualizedConv2d(num_channels, nf(res - 1), 1, use_wscale=use_wscale)))
243 |
244 | self.blocks = nn.ModuleDict(OrderedDict(blocks))
245 | self.fromrgbs = nn.ModuleDict(OrderedDict(fromrgbs))
246 |
247 |
248 | def forward(self, x, res, alpha):
249 | assert 2 <= res <= 10
250 | # step 1...9
251 | step = res - 1
252 | block_list = list(self.blocks.values())[-step:]
253 | fromrgb_list = list(self.fromrgbs.values())[-step:]
254 |
255 |         this_fromrgb = fromrgb_list[0]
256 |         if step > 1:
257 |             skip_fromrgb = fromrgb_list[1]
258 |         for i, block in enumerate(block_list):
259 |             if i == 0:
260 |                 skip_x = skip_fromrgb(F.avg_pool2d(x, 2)) if step > 1 else None
261 |                 x = block(this_fromrgb(x))
262 |                 if skip_x is not None:  # fade in the newest block against the downscaled skip branch
263 |                     x = (1 - alpha) * skip_x + alpha * x
264 |             else:
265 |                 x = block(x)
266 | 
267 |         return x
268 |
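269 | # Hypothetical usage sketch (illustrative, not part of the original file):
270 | # `res` indexes the current resolution as log2(H) and `alpha` fades in the
271 | # newest block during progressive growing.
272 | #
273 | #     G = Generator(resolution=256)
274 | #     D = BasicDiscriminator(resolution=256)
275 | #     z = torch.randn(4, 512)
276 | #     fake = G(z, res=7, alpha=0.5)         # res=7 -> 128x128 images
277 | #     scores = D(fake, res=7, alpha=0.5)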
--------------------------------------------------------------------------------
/scripts/style-gan-pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 | if isinstance(v, bool):
5 | return v
6 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
7 | return True
8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
9 | return False
10 | else:
11 | raise argparse.ArgumentTypeError('Boolean value expected.')
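12 | 
13 | # Hypothetical usage sketch (illustrative, not part of the original file):
14 | #     parser = argparse.ArgumentParser()
15 | #     parser.add_argument('--use_wscale', type=str2bool, default=True)
16 | #     args = parser.parse_args(['--use_wscale', 'false'])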
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dave-epstein/blobgan/c527f1c27447dffe3cf4cf3901571a83ce59f1fe/src/__init__.py
--------------------------------------------------------------------------------
/src/configs/checkpoint/after_each_epoch.yaml:
--------------------------------------------------------------------------------
1 | save_top_k: 3
2 | every_n_epochs: 1
3 | save_last: true
4 | monitor: validate_total_loss
--------------------------------------------------------------------------------
/src/configs/checkpoint/after_each_epoch_fid.yaml:
--------------------------------------------------------------------------------
1 | save_top_k: -1
2 | every_n_epochs: 1
3 | save_last: true
4 | monitor: train/fid
5 | save_on_train_epoch_end: true
6 | auto_insert_metric_name: true
--------------------------------------------------------------------------------
/src/configs/checkpoint/every_n_train_steps.yaml:
--------------------------------------------------------------------------------
1 | every_n_train_steps: 3000
2 | save_top_k: -1
3 | mode: max
4 | monitor: step
--------------------------------------------------------------------------------
/src/configs/checkpoint/every_n_train_steps_fid.yaml:
--------------------------------------------------------------------------------
1 | every_n_train_steps: 3000
2 | save_top_k: -1
3 | monitor: train/fid
--------------------------------------------------------------------------------
/src/configs/dataset/imagefolder.yaml:
--------------------------------------------------------------------------------
1 | name: ImageFolderDataModule
2 | resolution: 128
3 | dataloader:
4 | num_workers: 12
5 | batch_size: 24
6 | drop_last: true
--------------------------------------------------------------------------------
/src/configs/dataset/lsun.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - imagefolder
3 | category: bedroom
4 | basepath: /path/to/lsun # Must have train, validate, test subfolders (validate/test can be empty)
5 | path: ${.basepath}/${.category}
6 | dataloader:
7 | batch_size: 24
--------------------------------------------------------------------------------
/src/configs/dataset/multiimagefolder.yaml:
--------------------------------------------------------------------------------
1 | name: MultiImageFolderDataModule
2 | resolution: 128
3 | dataloader:
4 | num_workers: 12
5 | batch_size: 24
6 | drop_last: true
--------------------------------------------------------------------------------
/src/configs/dataset/multilsun.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - multiimagefolder
3 | categories: [kitchen,living,dining]
4 | category: null
5 | dataloader:
6 | batch_size: 24
--------------------------------------------------------------------------------
/src/configs/dataset/nodata.yaml:
--------------------------------------------------------------------------------
1 | name: NullDataModule
2 | dataloader:
3 | num_workers: 4
4 | batch_size: 128
--------------------------------------------------------------------------------
/src/configs/dataset/other_image_dataset.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - imagefolder
3 | path: /path/to/images # Must have train, validate, test subfolders (validate/test can be empty)
4 | dataloader:
5 | batch_size: 24
--------------------------------------------------------------------------------
/src/configs/experiment/blobgan.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /checkpoint: every_n_train_steps
4 | - /dataset: lsun
5 | checkpoint:
6 | every_n_train_steps: 1500
7 | wandb:
8 | name: BlobGAN
9 | dataset:
10 | category: bedroom
11 | resolution: ${model.resolution}
12 | dataloader:
13 | batch_size: 24
14 | drop_last: true
15 | model:
16 | name: BlobGAN
17 | lr: 0.002
18 | dim: 512
19 | noise_dim: 512
20 | resolution: 256
21 | lambda: # Needed for convenience since can't input λ on command line
22 | D_real: 1
23 | D_fake: 1
24 | D_R1: 50
25 | G: 1
26 | G_path: 2
27 | G_feature_mean: 10
28 | G_feature_variance: 10
29 | discriminator:
30 | name: StyleGANDiscriminator
31 | size: ${model.resolution}
32 | generator:
33 | name: models.networks.layoutstylegan.LayoutStyleGANGenerator
34 | style_dim: ${model.dim}
35 | n_mlp: 8
36 | size_in: 16
37 | c_model: 96
38 | spatial_style: ${model.spatial_style}
39 | size: ${model.resolution}
40 | layout_net:
41 | name: models.networks.layoutnet.LayoutGenerator
42 | n_features_max: ${model.n_features_max}
43 | feature_dim: 768
44 | style_dim: ${model.dim}
45 | noise_dim: ${model.noise_dim}
46 | norm_features: true
47 | mlp_lr_mul: 0.01
48 | mlp_hidden_dim: 1024
49 | spatial_style: ${model.spatial_style}
50 | D_reg_every: 16
51 | G_reg_every: -1
52 | λ: ${.lambda}
53 | log_images_every_n_steps: 1000
54 | n_features_min: ${model.n_features}
55 | n_features_max: ${model.n_features}
56 | n_features: 10
57 | spatial_style: true
58 | trainer:
59 | limit_val_batches: 0
60 | precision: 32
61 | plugins: null
62 | deterministic: false
--------------------------------------------------------------------------------
/src/configs/experiment/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | dataset:
3 | dataloader:
4 | num_workers: 0
5 | persistent_workers: false
6 | batch_size: 4
7 | trainer:
8 | gpus: 1
9 | accelerator: null
10 | plugins: null
11 | overfit_batches: 20
12 | wandb:
13 | group: debug
14 | detect_anomalies: true
15 | logger: false
--------------------------------------------------------------------------------
/src/configs/experiment/gan.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /checkpoint: every_n_train_steps
4 | - /dataset: lsun
5 | wandb:
6 | name: GAN
7 | dataset:
8 | resolution: ${model.resolution}
9 | dataloader:
10 | drop_last: true
11 | model:
12 | name: GAN
13 | lr: 0.002
14 | dim: 512
15 | resolution: 256
16 | lambda: # Needed for convenience since can't input λ on command line
17 | D_real: 1
18 | D_fake: 1
19 | D_R1: 50
20 | G: 1
21 | G_path: 2
22 | discriminator:
23 | name: StyleGANDiscriminator
24 | size: ${model.resolution}
25 | generator:
26 | name: models.networks.stylegan.StyleGANGenerator
27 | style_dim: 512
28 | dim: 512
29 | n_mlp: 8
30 | size: ${model.resolution}
31 | D_reg_every: 16
32 | λ: ${.lambda}
33 | log_images_every_n_steps: 1000
34 | trainer:
35 | limit_val_batches: 0
36 | precision: 32
37 | plugins: null
--------------------------------------------------------------------------------
/src/configs/experiment/invertblobgan.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /dataset: lsun
4 | checkpoint:
5 | save_top_k: 3
6 | save_last: true
7 | monitor: validate/total_loss
8 | every_n_train_steps: 5000
9 | wandb:
10 | name: InvertBlobGAN
11 | dataset:
12 | category: bedroom
13 | resolution: ${model.G.resolution}
14 | dataloader:
15 | batch_size: 16
16 | drop_last: true
17 | model:
18 | name: BlobGANInverter
19 | lr: 0.002
20 | log_images_every_n_steps: 1000
21 | lambda: # Needed for convenience since can't input λ on command line
22 | real_LPIPS: 1
23 | real_MSE: 1
24 | fake_LPIPS: 1
25 | fake_MSE: 1
26 | fake_latents_MSE: 1
27 | λ: ${.lambda}
28 | G_pretrained:
29 | key: state_dict
30 | log_dir: null # Defaults to $PWD/logs
31 | project: ${wandb.project}
32 | generator: ${model.G}
33 | generator_pretrained: ${model.G_pretrained}
34 | inverter:
35 | name: StyleGANDiscriminator
36 | size: ${model.G.resolution}
37 | discriminate_stddev: false
38 | G:
39 | lr: 0.002
40 | dim: 512
41 | noise_dim: 512
42 | resolution: 256
43 | lambda: # Needed for convenience since can't input λ on command line
44 | D_real: 1
45 | D_fake: 1
46 | D_R1: 50
47 | G: 1
48 | G_path: 2
49 | G_feature_mean: 10
50 | G_feature_variance: 10
51 | discriminator:
52 | name: StyleGANDiscriminator
53 | size: ${model.G.resolution}
54 | generator:
55 | name: models.networks.layoutstylegan.LayoutStyleGANGenerator
56 | style_dim: ${model.G.dim}
57 | n_mlp: 8
58 | size_in: 16
59 | c_model: 96
60 | spatial_style: ${model.G.spatial_style}
61 | size: ${model.G.resolution}
62 | layout_net:
63 | name: models.networks.layoutnet.LayoutGenerator
64 | n_features_max: ${model.G.n_features_max}
65 | feature_dim: 768
66 | style_dim: ${model.G.dim}
67 | noise_dim: ${model.G.noise_dim}
68 | norm_features: true
69 | mlp_lr_mul: 0.01
70 | mlp_hidden_dim: 1024
71 | spatial_style: ${model.G.spatial_style}
72 | D_reg_every: 16
73 | G_reg_every: -1
74 | λ: ${.lambda}
75 | log_images_every_n_steps: 1000
76 | n_features_min: ${model.G.n_features}
77 | n_features_max: ${model.G.n_features}
78 | n_features: 10
79 | spatial_style: true
80 | trainer:
81 | precision: 32
82 | plugins: null
83 | deterministic: false
--------------------------------------------------------------------------------
/src/configs/experiment/jitter.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | model:
3 | feature_jitter_xy: 0.04
4 | feature_jitter_shift: 0.5
5 | feature_jitter_angle: 0.1
--------------------------------------------------------------------------------
/src/configs/experiment/local.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | dataset:
3 | basepath: /path/to/your/lsun # Change to your path
4 | trainer:
5 | gpus: YOUR_NGPUS # Change to your number of GPUs
6 | wandb: # Fill in your settings
7 | group: YOUR_GROUP
8 | project: YOUR_PROJECT
9 | entity: YOUR_ENTITY
--------------------------------------------------------------------------------
/src/configs/fit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | hydra:
3 | run:
4 | dir: .
5 | output_subdir: null
6 | resume:
7 | id: null
8 | step: null
9 | epoch: null
10 | last: true
11 | best: false
12 | clobber_hparams: false
13 | project: ${wandb.project}
14 | log_dir: ${wandb.log_dir}
15 | model_only: false
16 | logger: wandb
17 | wandb:
18 | save_code: true
19 | offline: false
20 | log_dir: ./logs
21 | id: ${resume.id}
22 | trainer:
23 | accelerator: ddp
24 | benchmark: false
25 | deterministic: true
26 | gpus: 8
27 | precision: 16
28 | plugins: null
29 | max_steps: 10000000
30 | profiler: simple
31 | num_sanity_val_steps: 0
32 | log_every_n_steps: 200
33 | dataset:
34 | dataloader:
35 | prefetch_factor: 2
36 | pin_memory: true
37 | drop_last: true
38 | persistent_workers: true
39 | mode: fit
40 | seed: 0
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from utils import to_dataclass_cfg
4 | from .nodata import *
5 | from .imagefolder import *
6 | from .multiimagefolder import *
7 |
8 | def get_datamodule(name: str, **kwargs) -> LightningDataModule:
9 | cls = getattr(sys.modules[__name__], name)
10 | return cls(**to_dataclass_cfg(kwargs, cls))
11 |
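12 | # Hypothetical usage sketch (illustrative, not part of the original file):
13 | # resolves a datamodule class by name and filters kwargs down to its
14 | # dataclass fields via to_dataclass_cfg.
15 | #
16 | #     dm = get_datamodule('ImageFolderDataModule', path='/data/lsun/bedroom',
17 | #                         resolution=128, dataloader={'batch_size': 24, 'num_workers': 8})
18 | #     dm.setup()
19 | #     loader = dm.train_dataloader()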
--------------------------------------------------------------------------------
/src/data/imagefolder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from typing import Any, Optional, Union, Dict
6 |
7 | from pytorch_lightning import LightningDataModule
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from torchvision.transforms import InterpolationMode
11 |
12 | from data.nodata import NullIterableDataset
13 | from data.utils import ImageFolderWithFilenames
14 | from utils import print_once
15 |
16 | __all__ = ['ImageFolderDataModule']
17 |
18 |
19 | @dataclass
20 | class ImageFolderDataModule(LightningDataModule):
21 | path: Union[str, Path] # Root
22 | dataloader: Dict[str, Any]
23 | resolution: int = 256 # Image dimension
24 |
25 | def __post_init__(self):
26 | super().__init__()
27 | self.path = Path(self.path)
28 | self.stats = {'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)}
29 | self.transform = transforms.Compose([
30 | t for t in [
31 | transforms.Resize(self.resolution, InterpolationMode.LANCZOS),
32 | transforms.CenterCrop(self.resolution),
33 | transforms.RandomHorizontalFlip(),
34 | transforms.ToTensor(),
35 | transforms.Normalize(self.stats['mean'], self.stats['std'], inplace=True),
36 | ]
37 | ])
38 | self.data = {}
39 |
40 | def setup(self, stage: Optional[str] = None):
41 | for split in ('train', 'validate', 'test'):
42 | path = self.path / split
43 | empty = True
44 | if path.exists():
45 | try:
46 | self.data[split] = ImageFolderWithFilenames(path, transform=self.transform)
47 | empty = False
48 | except FileNotFoundError:
49 | pass
50 | if empty:
51 | print_once(
52 | f'Warning: no images found in {path}. Using empty dataset for split {split}. '
53 | f'Perhaps you set `dataset.path` incorrectly?')
54 | self.data[split] = NullIterableDataset(1)
55 |
56 | def train_dataloader(self) -> DataLoader:
57 | return self._get_dataloader('train')
58 |
59 | def val_dataloader(self) -> DataLoader:
60 | return self._get_dataloader('validate')
61 |
62 | def test_dataloader(self) -> DataLoader:
63 | return self._get_dataloader('test')
64 |
65 | def _get_dataloader(self, split: str):
66 | return DataLoader(self.data[split], **self.dataloader)
67 |
--------------------------------------------------------------------------------
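A minimal sketch of the datamodule lifecycle above, assuming `src/` is on `PYTHONPATH` and a root directory with `train/`, `validate/`, and `test/` subfolders (paths are hypothetical); missing splits fall back to `NullIterableDataset` with a warning:

```python
from data.imagefolder import ImageFolderDataModule  # assumes src/ on PYTHONPATH

dm = ImageFolderDataModule(
    path='/data/bedrooms',                                   # expects train/ validate/ test/ inside
    dataloader={'batch_size': 32, 'num_workers': 4},
    resolution=256,
)
dm.setup()                                                   # builds one dataset per split
images, meta = next(iter(dm.train_dataloader()))             # images: N x 3 x 256 x 256 in [-1, 1]
print(images.shape, meta.keys())                             # meta holds 'labels' and 'filenames'
```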
/src/data/multiimagefolder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 | from typing import Any, Optional, Union, Dict, List, Callable
7 |
8 | from pytorch_lightning import LightningDataModule
9 | from torch.utils.data import DataLoader, Dataset
10 | from torchvision import transforms
11 | from torchvision.transforms import InterpolationMode
12 |
13 | from data.utils import ImageFolderWithFilenames
14 | from utils import print_once
15 |
16 | __all__ = ['MultiImageFolderDataModule']
17 |
18 |
19 | @dataclass
20 | class MultiImageFolderDataModule(LightningDataModule):
21 | basepath: Union[str, Path] # Root
22 | categories: List[str]
23 | dataloader: Dict[str, Any]
24 | resolution: int = 256 # Image dimension
25 |
26 | def __post_init__(self):
27 | super().__init__()
28 | self.path = Path(self.basepath)
29 | self.stats = {'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)}
30 | self.transform = transforms.Compose([
31 | t for t in [
32 | transforms.Resize(self.resolution, InterpolationMode.LANCZOS),
33 | transforms.CenterCrop(self.resolution),
34 | transforms.RandomHorizontalFlip(),
35 | transforms.ToTensor(),
36 | transforms.Normalize(self.stats['mean'], self.stats['std'], inplace=True),
37 | ]
38 | ])
39 | self.data = {}
40 |
41 | def setup(self, stage: Optional[str] = None):
42 | for split in ('train', 'validate', 'test'):
43 | try:
44 | self.data[split] = MultiImageFolderWithFilenames(self.basepath, self.categories, split,
45 | transform=self.transform)
46 | except FileNotFoundError:
47 | print_once(f'Could not create dataset for split {split}')
48 |
49 | def train_dataloader(self) -> DataLoader:
50 | return self._get_dataloader('train')
51 |
52 | def val_dataloader(self) -> DataLoader:
53 | return self._get_dataloader('validate')
54 |
55 | def test_dataloader(self) -> DataLoader:
56 | return self._get_dataloader('test')
57 |
58 | def _get_dataloader(self, split: str):
59 | return DataLoader(self.data[split], **self.dataloader)
60 |
61 |
62 | @dataclass
63 | class MultiImageFolderWithFilenames(Dataset):
64 | basepath: Union[str, Path] # Root
65 | categories: List[str]
66 | split: str
67 | transform: Callable
68 |
69 | def __post_init__(self):
70 | super().__init__()
71 | self.datasets = [ImageFolderWithFilenames(os.path.join(self.basepath, c, self.split), self.transform) for c in
72 | self.categories]
73 | self._n_datasets = len(self.datasets)
74 | self._dataset_lens = [len(d) for d in self.datasets]
75 | self._len = self._n_datasets * max(self._dataset_lens)
76 | print_once(f'Created dataset with {self.categories}. '
77 | f'Lengths are {self._dataset_lens}. Effective dataset length is {self._len}.')
78 |
79 | def __getitem__(self, index):
80 | dataset_idx = index % self._n_datasets
81 | item_idx = (index // self._n_datasets) % self._dataset_lens[dataset_idx]
82 | return self.datasets[dataset_idx][item_idx]
83 |
84 | def __len__(self):
85 | return self._len
86 |
87 |
--------------------------------------------------------------------------------
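The indexing in `MultiImageFolderWithFilenames.__getitem__` round-robins across categories and wraps shorter datasets, so every category is sampled at the same rate regardless of size. A standalone sketch of that index math with two hypothetical dataset lengths:

```python
# Round-robin index math used above, reproduced standalone for illustration.
lens = [3, 5]                          # hypothetical per-category dataset lengths
n = len(lens)
effective_len = n * max(lens)          # 10: the shorter category repeats
for index in range(effective_len):
    dataset_idx = index % n                              # alternate categories
    item_idx = (index // n) % lens[dataset_idx]          # wrap around short datasets
    print(index, '->', (dataset_idx, item_idx))
```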
/src/data/nodata.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from dataclasses import dataclass
3 | from typing import Any, Dict
4 |
5 | from pytorch_lightning import LightningDataModule
6 | from torch.utils.data import DataLoader, IterableDataset
7 |
8 | __all__ = ["NullDataModule"]
9 |
10 |
11 | @dataclass
12 | class NullIterableDataset(IterableDataset):
13 | size: int
14 |
15 | def __post_init__(self):
16 | super().__init__()
17 |
18 | def __iter__(self):
19 | if self.size >= 0:
20 | return iter(range(self.size))
21 | else:
22 | return itertools.count(0, 0)
23 |
24 |
25 | @dataclass
26 | class NullDataModule(LightningDataModule):
27 | dataloader: Dict[str, Any]
28 | train_size: int = -1
29 | validate_size: int = -1
30 | test_size: int = -1
31 |
32 | def __post_init__(self):
33 | super().__init__()
34 |
35 | def train_dataloader(self) -> DataLoader:
36 | return self._get_dataloader(self.train_size)
37 |
38 | def val_dataloader(self) -> DataLoader:
39 | return self._get_dataloader(self.validate_size)
40 |
41 | def test_dataloader(self) -> DataLoader:
42 | return self._get_dataloader(self.test_size)
43 |
44 | def _get_dataloader(self, size: int) -> DataLoader:
45 | return DataLoader(NullIterableDataset(size), **self.dataloader)
46 |
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Tuple, List
3 | from typing import Optional, Callable, Any
4 |
5 | import torch
6 | from torchvision.datasets.folder import default_loader, ImageFolder, make_dataset
7 |
8 | from utils import is_rank_zero, print_once
9 |
10 |
11 | class ImageFolderWithFilenames(ImageFolder):
12 | def __init__(self, root: str, transform: Optional[Callable] = None,
13 | target_transform: Optional[Callable] = None,
14 | loader: Callable[[str], Any] = default_loader,
15 | is_valid_file: Optional[Callable[[str], bool]] = None):
16 | super().__init__(root=root, transform=transform, target_transform=target_transform,
17 | loader=loader, is_valid_file=is_valid_file)
18 |
19 | @staticmethod
20 | def make_dataset(
21 | directory: str,
22 | class_to_idx: Dict[str, int],
23 | extensions: Optional[Tuple[str, ...]] = None,
24 | is_valid_file: Optional[Callable[[str], bool]] = None,
25 | ) -> List[Tuple[str, int]]:
26 | if class_to_idx is None:
27 | # prevent potential bug since make_dataset() would use the class_to_idx logic of the
28 | # find_classes() function, instead of using that of the find_classes() method, which
29 | # is potentially overridden and thus could have a different logic.
30 | raise ValueError(
31 | "The class_to_idx parameter cannot be None."
32 | )
33 | cache_path = os.path.join(directory, 'cache.pt')
34 | try:
35 | dataset = torch.load(cache_path, map_location='cpu')
36 | print_once(f'Loading dataset from cache in {directory}')
37 | except FileNotFoundError:
38 | print_once(f'Creating dataset and saving to cache in {directory}')
39 | dataset = make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
40 | if is_rank_zero():
41 | torch.save(dataset, cache_path)
42 | except EOFError:
43 | print_once(f'Error loading cache from {directory},'
44 | f' likely because dataset is small and read/write were attempted concurrently. '
45 | f'Proceeding by remaking dataset in-memory.')
46 | dataset = make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
47 | print_once(f'{len(dataset)} images in dataset')
48 | return dataset
49 |
50 | def __getitem__(self, i):
51 | x, y = super().__getitem__(i)
52 | return x, {'labels': y, 'filenames': self.imgs[i][0]}
53 |
--------------------------------------------------------------------------------
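The `cache.pt` lookup in `make_dataset` avoids re-walking large image folders on every run; only rank zero writes the cache. A hypothetical one-off script to pre-build that cache for a split directory (paths and extensions are placeholders), producing the same list of `(path, class_idx)` pairs the loader expects:

```python
import os

import torch
from torchvision.datasets.folder import make_dataset

root = '/data/bedrooms/train'                                      # hypothetical split directory
classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
class_to_idx = {c: i for i, c in enumerate(classes)}
samples = make_dataset(root, class_to_idx, extensions=('.png', '.jpg', '.jpeg', '.webp'))
torch.save(samples, os.path.join(root, 'cache.pt'))                # list of (path, class_idx)
print(f'Cached {len(samples)} images')
```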
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Any, Tuple, Dict
3 |
4 | from pytorch_lightning import LightningModule
5 |
6 | from models import networks
7 | from utils import to_dataclass_cfg
8 | # from .segmenter import *
9 | from .blobgan import *
10 | from .gan import *
11 | from .invertblobgan import *
12 |
13 |
14 | def get_model(name: str, return_cfg: bool = False, **kwargs) -> Tuple[LightningModule, Dict[str, Any]]:
15 | cls = getattr(sys.modules[__name__], name)
16 | cfg = to_dataclass_cfg(kwargs, cls)
17 | if return_cfg:
18 | return cls(**cfg), cfg
19 | else:
20 | return cls(**cfg)
21 |
--------------------------------------------------------------------------------
/src/models/base.py:
--------------------------------------------------------------------------------
1 | from itertools import groupby
2 | from numbers import Number
3 | from typing import Union, Any, Optional, Dict, Tuple, List
4 |
5 | import numpy as np
6 | import torch
7 | from einops import rearrange
8 | from pytorch_lightning import LightningModule
9 | from torch import Tensor
10 |
11 | from utils import scalars_to_log_dict, run_at_step, epoch_outputs_to_log_dict, is_rank_zero, get_rank, print_once
12 |
13 |
14 | class BaseModule(LightningModule):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | # Control flow
19 | def training_step(self, batch: Tuple[Tensor, dict], batch_idx: int, optimizer_idx: Optional[int] = None) -> Tensor:
20 | return self.shared_step(batch, batch_idx, optimizer_idx, 'train')
21 |
22 | def validation_step(self, batch: Tuple[Tensor, dict], batch_idx: int):
23 | return self.shared_step(batch, batch_idx, mode='validate')
24 |
25 | def test_step(self, batch: Tuple[Tensor, dict], batch_idx: int):
26 | return self.shared_step(batch, batch_idx, mode='test')
27 |
28 | def valtest_epoch_end(self, outputs: List[Dict[str, Tensor]], mode: str):
29 | if self.logger is None:
30 | return
31 | # Either log each step's output separately (results have been all_gathered in this case)
32 | if self.valtest_log_all:
33 | for image_dict in outputs:
34 | self._log_image_dict(image_dict, mode, commit=True)
35 | # Or just log a random batch worth of images from master process
36 | else:
37 | self._log_image_dict(epoch_outputs_to_log_dict(outputs, n_max="batch", shuffle=True), mode)
38 |
39 | def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]):
40 | self.valtest_epoch_end(outputs, 'validate')
41 |
42 | def test_epoch_end(self, outputs: List[Dict[str, Tensor]]):
43 | self.valtest_epoch_end(outputs, 'test')
44 |
45 | # Utility methods for logging
46 | def gather_tensor(self, t: Tensor) -> Tensor:
47 | return rearrange(self.all_gather(t), "m n c h w -> (m n) c h w")
48 |
49 | def gather_tensor_dict(self, d: Dict[Any, Tensor]) -> Dict[Any, Tensor]:
50 | return {k: rearrange(v.cpu(), "m n c h w -> (m n) c h w") for k, v in self.all_gather(d).items()}
51 |
52 | def log_scalars(self, scalars: Dict[Any, Union[Number, Tensor]], mode: str, **kwargs):
53 | if 'sync_dist' not in kwargs:
54 | kwargs['sync_dist'] = mode != 'train'
55 | self.log_dict(scalars_to_log_dict(scalars, mode), **kwargs)
56 |
57 | def _log_image_dict(self, img_dict: Dict[str, Tensor], mode: str, commit: bool = False, **kwargs):
58 | if self.logger is not None:
59 | for k, v in img_dict.items():
60 | self.logger.log_image_batch(f'{mode}/{k}', v, commit=commit, **kwargs)
61 |
62 | def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
63 | optimizer.zero_grad(set_to_none=True) # Improves performance
64 |
65 | def alert_nan_loss(self, loss: Tensor, batch_idx: int):
66 | if loss != loss:
67 | print(
68 | f'NaN loss in epoch {self.current_epoch}, batch index {batch_idx}, global step {self.global_step}, '
69 | f'local rank {get_rank()}. Skipping.')
70 | return loss != loss
71 |
72 | def _log_profiler(self):
73 | if run_at_step(self.trainer.global_step, self.log_timing_every_n_steps):
74 | report, total_duration = self.trainer.profiler._make_report()
75 | report_log = dict([kv for action, durations, duration_per in report for kv in
76 | [(f'profiler/mean_t/{action}', np.mean(durations)),
77 | (f'profiler/n_calls/{action}', len(durations)),
78 | (f'profiler/total_t/{action}', np.sum(durations)),
79 | (f'profiler/pct/{action}', duration_per)]])
80 | self.log_dict(report_log)
81 | self.logger.save_to_file('profiler_summary.txt', self.trainer.profiler.summary(), unique_filename=False)
82 |
83 | def on_train_start(self):
84 | if self.logger:
85 | self.logger.log_model_summary(self)
86 |
87 | def log_grad_norm(self, grad_norm_dict: Dict[str, torch.Tensor]) -> None:
88 | self.log_dict({'grads/' + k: v for k, v in grad_norm_dict.items()})
89 |
90 | def on_after_backward(self) -> None:
91 | if not getattr(self, 'validate_gradients', False):
92 | return
93 |
94 | valid_gradients = True
95 | invalid_params = []
96 | for name, param in self.named_parameters():
97 | if param.grad is not None:
98 | this_param_valid = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
99 | valid_gradients &= this_param_valid
100 | if not this_param_valid:
101 | invalid_params.append(name)
102 | # if not valid_gradients:
103 | # break
104 |
105 | if not valid_gradients:
106 | depth_two_params = [k for k, _ in groupby(
107 | ['.'.join(n.split('.')[:2]).replace('.weight', '').replace('.bias', '') for n in invalid_params])]
108 | print_once(f'Detected inf/NaN gradients for parameters {", ".join(depth_two_params)}. '
109 | f'Skipping epoch {self.current_epoch}, batch index {self.batch_idx}, global step {self.global_step}.')
110 | self.zero_grad(set_to_none=True)
111 |
--------------------------------------------------------------------------------
/src/models/gan.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | __all__ = ["GAN"]
4 |
5 | from dataclasses import dataclass
6 | from typing import Optional, Union, List, Callable, Tuple
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from cleanfid import fid
12 | from torch import nn, Tensor
13 | from torch.cuda.amp import autocast
14 | from torch.optim import Optimizer
15 |
16 | from models import networks
17 | from models.base import BaseModule
18 | from utils import FromConfig, run_at_step, get_D_stats, G_path_loss, D_R1_loss, freeze, is_rank_zero, accumulate, \
19 | mixing_noise, print_once
20 |
21 |
22 | @dataclass
23 | class Lossλs:
24 | D_real: float = 1
25 | D_fake: float = 1
26 | D_R1: float = 5
27 | G: float = 1
28 | G_path: float = 2
29 |
30 | def __getitem__(self, key):
31 | return super().__getattribute__(key)
32 |
33 |
34 | @dataclass(eq=False)
35 | class GAN(BaseModule):
36 | # Modules
37 | generator: FromConfig[nn.Module]
38 | discriminator: FromConfig[nn.Module]
39 | # Module parameters
40 | dim: int = 256
41 | resolution: int = 128
42 | p_mixing_noise: float = 0.9
43 | n_ema_sample: int = 16
44 | freeze_G: bool = False
45 | # Optimization
46 | lr: float = 1e-3
47 | eps: float = 1e-5
48 | # Regularization
49 | D_reg_every: int = 16
50 | G_reg_every: int = 4
51 | path_len: float = 0
52 | # Loss parameters
53 | λ: FromConfig[Lossλs] = None
54 | # Logging
55 | log_images_every_n_steps: Optional[int] = 500
56 | log_timing_every_n_steps: Optional[int] = -1
57 | log_fid_every_n_steps: Optional[int] = -1
58 | log_fid_every_epoch: bool = True
59 | fid_n_imgs: Optional[int] = 50000
60 | fid_stats_name: Optional[str] = None
61 | fid_num_workers: Optional[int] = 24
62 | valtest_log_all: bool = False
63 | accumulate: bool = True
64 |
65 | def __post_init__(self):
66 | super().__init__()
67 | self.save_hyperparameters()
68 | self.discriminator = networks.get_network(**self.discriminator)
69 | self.generator_ema = networks.get_network(**self.generator)
70 | self.generator = networks.get_network(**self.generator)
71 | if self.freeze_G:
72 | self.generator.eval()
73 | freeze(self.generator)
74 | if self.accumulate:
75 | self.generator_ema.eval()
76 | freeze(self.generator_ema)
77 | accumulate(self.generator_ema, self.generator, 0)
78 | else:
79 | del self.generator_ema
80 | self.λ = Lossλs(**self.λ)
81 | self.register_buffer('sample_z', torch.randn(self.n_ema_sample, self.dim))
82 | # self.sample_z = torch.randn(self.n_ema_sample, self.dim)
83 |
84 | # Initialization and state management
85 | def on_train_start(self):
86 | super().on_train_start()
87 | # Validate parameters w.r.t. trainer (must be done here since trainer is not attached as property yet in init)
88 | assert self.log_images_every_n_steps % self.trainer.log_every_n_steps == 0, \
89 | '`model.log_images_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder'
90 | if self.log_timing_every_n_steps > -1:
91 | assert self.log_timing_every_n_steps % self.trainer.log_every_n_steps == 0, \
92 | '`model.log_timing_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder'
93 | assert self.log_fid_every_n_steps < 0 or self.log_fid_every_n_steps % self.trainer.log_every_n_steps == 0, \
94 | '`model.log_fid_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder'
95 | assert not ((self.log_fid_every_n_steps > -1 or self.log_fid_every_epoch) and (not self.fid_stats_name)), \
96 | 'Cannot compute FID without name of statistics file to use.'
97 |
98 | def configure_optimizers(self) -> Union[optim, List[optim]]:
99 | G_reg_ratio = self.G_reg_every / ((self.G_reg_every + 1) or -1)
100 | D_reg_ratio = self.D_reg_every / ((self.D_reg_every + 1) or -1)
101 | _requires_grad = lambda p: p.requires_grad
102 | G_optim = torch.optim.Adam(filter(_requires_grad, self.generator.parameters()), lr=self.lr * G_reg_ratio,
103 | betas=(0 ** G_reg_ratio, 0.99 ** G_reg_ratio), eps=self.eps)
104 | D_optim = torch.optim.Adam(filter(_requires_grad, self.discriminator.parameters()), lr=self.lr * D_reg_ratio,
105 | betas=(0 ** D_reg_ratio, 0.99 ** D_reg_ratio), eps=self.eps)
106 | if self.freeze_G:
107 | return D_optim
108 | else:
109 | return G_optim, D_optim
110 |
111 | def optimizer_step(
112 | self,
113 | epoch: int = None,
114 | batch_idx: int = None,
115 | optimizer: Optimizer = None,
116 | optimizer_idx: int = None,
117 | optimizer_closure: Optional[Callable] = None,
118 | on_tpu: bool = None,
119 | using_native_amp: bool = None,
120 | using_lbfgs: bool = None,
121 | ):
122 | optimizer.step(closure=optimizer_closure)
123 |
124 | def training_epoch_end(self, *args, **kwargs):
125 | if self.log_fid_every_epoch:
126 | try:
127 | self.log_fid("train")
128 | except Exception:
129 | pass
130 |
131 | def gen(self, z, truncate, ema=True, norm_img=True):
132 | G = self.generator_ema if ema else self.generator
133 | try:
134 | imgs = G([z], return_image_only=True, truncation=1 - truncate,
135 | truncation_latent=self.mean_latent)
136 | except AttributeError:
137 | print_once('Computing mean latent for generation.')
138 | self.get_mean_latent()
139 | imgs = G([z], return_image_only=True, truncation=1 - truncate,
140 | truncation_latent=self.mean_latent)
141 | if norm_img:
142 | imgs = imgs.add_(1).div_(2).mul_(255)
143 | return imgs
144 |
145 | @torch.no_grad()
146 | def log_fid(self, mode, **kwargs):
147 | def gen_fn(z):
148 | if self.accumulate:
149 | out = self.generator_ema([z], return_image_only=True).add_(1).div_(2).mul_(255)
150 | else:
151 | out = self.generator([z], return_image_only=True).add_(1).div_(2).mul_(255)
152 | return out.clamp(min=0, max=255)
153 |
154 | if is_rank_zero():
155 | fid_score = fid.compute_fid(gen=gen_fn, dataset_name=self.fid_stats_name,
156 | dataset_res=256, num_gen=self.fid_n_imgs,
157 | dataset_split="custom", device=self.device,
158 | num_workers=self.fid_num_workers)
159 | else:
160 | fid_score = 0.0
161 | try:
162 | fid_score = self.all_gather(fid_score).max().item()
163 | self.log_scalars({'fid': fid_score}, mode, **kwargs)
164 | except AttributeError:
165 | pass
166 | return fid_score
167 |
168 | def get_mean_latent(self, n_trunc: int = 10000, ema=True):
169 | G = self.generator_ema if ema else self.generator
170 | mean_latent = self.mean_latent = G.mean_latent(n_trunc)
171 | return mean_latent
172 |
173 | # Training and evaluation
174 | def shared_step(self, batch: Tuple[Tensor, dict], batch_idx: int,
175 | optimizer_idx: Optional[int] = None, mode: str = 'train') -> Optional[Union[Tensor, dict]]:
176 | """
177 | Args:
178 | batch: tuple of tensor of shape N x C x H x W of images and a dictionary of batch metadata/labels
179 | batch_idx: pytorch lightning training loop batch index
180 | optimizer_idx: pytorch lightning optimizer index (0 = G, 1 = D)
181 | mode:
182 | `train` returns the total loss and logs losses and images/profiling info.
183 | `validate`/`test` log total loss and return images
184 | Returns: see description for `mode` above
185 | """
186 | # Set up modules and data
187 | train = mode == 'train'
188 | train_G = train and optimizer_idx == 0 and not self.freeze_G
189 | train_D = train and (optimizer_idx == 1 or self.freeze_G)
190 | batch_real, batch_labels = batch
191 | # z = torch.randn(len(batch_real), self.dim).type_as(batch_real)
192 | info = dict()
193 | losses = dict()
194 | z = mixing_noise(batch_real, self.dim, self.p_mixing_noise)
195 |
196 | gen_imgs, latents = self.generator(z, return_latents=True)
197 |
198 | if latents is not None:
199 | if latents.ndim == 3:
200 | latents = latents[:, 0]
201 | info['latent_norm'] = latents.norm(2, 1).mean()
202 | info['latent_stdev'] = latents.std(0).mean()
203 |
204 | # Compute various losses
205 | logits_fake = self.discriminator(gen_imgs)
206 | if train_G or not train:
207 | # Log
208 | losses['G'] = F.softplus(-logits_fake).mean()
209 | if train_D or not train:
210 | # Discriminate real images
211 | logits_real = self.discriminator(batch_real)
212 | # Log
213 | losses['D_real'] = F.softplus(-logits_real).mean()
214 | losses['D_fake'] = F.softplus(logits_fake).mean()
215 | info.update(get_D_stats('fake', logits_fake, gt=False))
216 | info.update(get_D_stats('real', logits_real, gt=True))
217 |
218 | # Save images
219 | imgs = {
220 | 'real_imgs': batch_real,
221 | 'gen_imgs': gen_imgs,
222 | }
223 | imgs = {k: v.clone().detach().float().cpu() for k, v in imgs.items()}
224 |
225 | # Compute train regularization loss
226 | if train_G and run_at_step(batch_idx, self.G_reg_every):
227 | if self.λ.G_path:
228 | z = mixing_noise(batch_real, self.dim, self.p_mixing_noise)
229 | gen_imgs, latents = self.generator(z, return_latents=True)
230 | losses['G_path'], self.path_len, info['G_path_len'] = G_path_loss(gen_imgs, latents, self.path_len)
231 | losses['G_path'] = losses['G_path'] * self.G_reg_every
232 | elif train_D and run_at_step(batch_idx, self.D_reg_every):
233 | if self.λ.D_R1:
234 | with autocast(enabled=False):
235 | batch_real.requires_grad = True
236 | logits_real = self.discriminator(batch_real)
237 | R1 = D_R1_loss(logits_real, batch_real)
238 | info['D_R1_unscaled'] = R1
239 | losses['D_R1'] = R1 * self.D_reg_every
240 |
241 | # Compute final loss and log
242 | losses['total_loss'] = sum(map(lambda k: losses[k] * self.λ[k], losses))
243 | # if losses['total_loss'] > 20 and is_rank_zero():
244 | # import ipdb
245 | # ipdb.set_trace()
246 | if self.alert_nan_loss(losses['total_loss'], batch_idx):
247 | if is_rank_zero():
248 | import ipdb
249 | ipdb.set_trace()
250 | return
251 | self.log_scalars(losses, mode)
252 | self.log_scalars(info, mode)
253 | # Further logging and terminate
254 | if mode == "train":
255 | if train_G and self.accumulate:
256 | accumulate(self.generator_ema, self.generator, 0.5 ** (32 / (10 * 1000)))
257 | if run_at_step(self.trainer.global_step, self.log_images_every_n_steps):
258 | if self.accumulate:
259 | with torch.no_grad():
260 | imgs['gen_imgs_ema'], _ = self.generator_ema([self.sample_z])
261 | self._log_image_dict(imgs, mode, square_grid=False, ncol=len(batch_real))
262 | if run_at_step(self.trainer.global_step, self.log_fid_every_n_steps) and is_rank_zero() and train_G:
263 | self.log_fid(mode)
264 | self._log_profiler()
265 | return losses['total_loss']
266 | else:
267 | if self.valtest_log_all:
268 | imgs = self.gather_tensor_dict(imgs)
269 | return imgs
270 |
--------------------------------------------------------------------------------
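The `*_reg_ratio` factors in `configure_optimizers` implement StyleGAN2-style lazy regularization: since the R1 and path-length terms are only added every `D_reg_every` / `G_reg_every` steps, the learning rate and Adam betas are rescaled so the effective optimizer statistics stay consistent. A worked example with the discriminator default above:

```python
# Lazy-regularization scaling for the discriminator with D_reg_every = 16.
lr, d_reg_every = 1e-3, 16
d_reg_ratio = d_reg_every / (d_reg_every + 1)             # 16/17 ≈ 0.941
print('lr   :', lr * d_reg_ratio)                         # ≈ 9.41e-4
print('betas:', (0 ** d_reg_ratio, 0.99 ** d_reg_ratio))  # ≈ (0.0, 0.9906)
```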
/src/models/invertblobgan.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | __all__ = ["BlobGANInverter"]
4 |
5 | import random
6 | from dataclasses import dataclass
7 | from typing import Optional, Union, List, Callable, Tuple
8 |
9 | import torch
10 | import torch.optim as optim
11 | from PIL import Image
12 | from lpips import LPIPS
13 | from omegaconf import DictConfig
14 | from torch import nn, Tensor
15 | from torch.optim import Optimizer
16 | from torchvision.utils import make_grid
17 |
18 | from models import networks, BlobGAN
19 | from models.base import BaseModule
20 | from utils import FromConfig, run_at_step, freeze, is_rank_zero, load_pretrained_weights, to_dataclass_cfg, print_once
21 |
22 | # SPLAT_KEYS = ['spatial_style', 'xs', 'ys', 'covs', 'sizes']
23 | SPLAT_KEYS = ['spatial_style', 'scores_pyramid']
24 | _ = Image
25 | _ = make_grid
26 |
27 |
28 | @dataclass
29 | class Lossλs:
30 | real_LPIPS: float = 1.
31 | real_MSE: float = 1.
32 | fake_LPIPS: float = 1.
33 | fake_MSE: float = 1.
34 | fake_latents_MSE: float = 1.
35 |
36 | def __getitem__(self, key):
37 | return super().__getattribute__(key)
38 |
39 |
40 | @dataclass(eq=False)
41 | class BlobGANInverter(BaseModule):
42 | # Modules
43 | inverter: FromConfig[nn.Module]
44 | generator: FromConfig[BlobGAN]
45 | # Loss parameters
46 | λ: FromConfig[Lossλs] = None
47 | # Logging
48 | log_images_every_n_steps: Optional[int] = 500
49 | log_timing_every_n_steps: Optional[int] = -1
50 | log_grads_every_n_steps: Optional[int] = -1
51 | valtest_log_all: bool = False
52 | # Resuming
53 | generator_pretrained: Optional[Union[str, DictConfig]] = None
54 | load_only_inverter: bool = False
55 | inverter_d_out: Optional[int] = None
56 | # Optim
57 | lr: float = 0.002
58 | eps: float = 1e-5
59 | # Training
60 | trunc_min: float = 0.0
61 | trunc_max: float = 0.0
62 |
63 | def __post_init__(self):
64 | super().__init__()
65 | self.save_hyperparameters()
66 | cfg = to_dataclass_cfg(self.generator, BlobGAN)
67 | if self.generator_pretrained.log_dir is None:
68 | self.generator_pretrained.log_dir = 'logs/'
69 | if not self.load_only_inverter:
70 | self.generator = load_pretrained_weights('BlobGAN', self.generator_pretrained, BlobGAN(**cfg), strict=False)
71 | del self.generator.discriminator
72 | del self.generator.generator
73 | del self.generator.layout_net
74 | freeze(self.generator)
75 | self.inverter = networks.get_network(**self.inverter,
76 | d_out=self.inverter_d_out or
77 | self.generator.layout_net_ema.mlp[-1].weight.shape[0])
78 | self.L_LPIPS = LPIPS(net='vgg', verbose=False)
79 | freeze(self.L_LPIPS)
80 | self.λ = Lossλs(**self.λ)
81 |
82 | # Initialization and state management
83 | def on_train_start(self):
84 | super().on_train_start()
85 | # Validate parameters w.r.t. trainer (must be done here since trainer is not attached as property yet in init)
86 | assert self.log_images_every_n_steps % self.trainer.log_every_n_steps == 0, \
87 | '`model.log_images_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder. ' \
88 | f'Got {self.log_images_every_n_steps} and {self.trainer.log_every_n_steps}.'
89 | assert self.log_timing_every_n_steps < 0 or self.log_timing_every_n_steps % self.trainer.log_every_n_steps == 0, \
90 | '`model.log_timing_every_n_steps` must be divisible by `trainer.log_every_n_steps` without remainder'
91 |
92 | def configure_optimizers(self) -> Union[optim, List[optim]]:
93 | params = list(self.inverter.parameters())
94 | print_once(f'Optimizing {sum([p.numel() for p in params]) / 1e6:.2f}M params')
95 | return torch.optim.AdamW(params, lr=self.lr, eps=self.eps, weight_decay=0)
96 |
97 | def optimizer_step(
98 | self,
99 | epoch: int = None,
100 | batch_idx: int = None,
101 | optimizer: Optimizer = None,
102 | optimizer_idx: int = None,
103 | optimizer_closure: Optional[Callable] = None,
104 | on_tpu: bool = None,
105 | using_native_amp: bool = None,
106 | using_lbfgs: bool = None,
107 | ):
108 | self.batch_idx = batch_idx
109 | optimizer.step(closure=optimizer_closure)
110 |
111 | def shared_step(self, batch: Tuple[Tensor, dict], batch_idx: int,
112 | optimizer_idx: Optional[int] = None, mode: str = 'train') -> Optional[Union[Tensor, dict]]:
113 | """
114 | Args:
115 | batch: tuple of tensor of shape N x C x H x W of images and a dictionary of batch metadata/labels
116 | batch_idx: pytorch lightning training loop batch index
117 | optimizer_idx: pytorch lightning optimizer index (unused; only one optimizer is configured)
118 | mode:
119 | `train` returns the total loss and logs losses and images/profiling info.
120 | `validate`/`test` log total loss and return images
121 | Returns: see description for `mode` above
122 | """
123 | # Set up modules and data
124 | batch_real, batch_labels = batch
125 | log_images = run_at_step(self.trainer.global_step, self.log_images_every_n_steps)
126 |
127 | z = torch.randn(len(batch_real), self.generator.noise_dim).type_as(batch_real)
128 |
129 | with torch.no_grad():
130 | truncate = self.trunc_min if self.trunc_min == self.trunc_max \
131 | else random.uniform(self.trunc_min, self.trunc_max)
132 | layout_gt_fake, gen_imgs = self.generator.gen(z, truncate=truncate, ema=True, viz=log_images,
133 | ret_layout=True)
134 |
135 | losses = dict()
136 |
137 | z_pred_fake = self.inverter(gen_imgs.detach())
138 |
139 | layout_pred_fake, reconstr_fake = self.generator.gen(z_pred_fake, ema=True, viz=log_images, ret_layout=True,
140 | mlp_idx=len(self.generator.layout_net_ema.mlp))
141 |
142 | losses['fake_MSE'] = (gen_imgs - reconstr_fake).pow(2).mean()
143 | losses['fake_LPIPS'] = self.L_LPIPS(reconstr_fake, gen_imgs).mean()
144 | latent_l2_loss = []
145 | for k in ('xs', 'ys', 'covs', 'sizes', 'features', 'spatial_style'):
146 | latent_l2_loss.append((layout_pred_fake[k] - layout_gt_fake[k].detach()).pow(2).mean())
147 | losses['fake_latents_MSE'] = sum(latent_l2_loss) / len(latent_l2_loss)
148 |
149 | z_pred_real = self.inverter(batch_real.detach())
150 | layout_pred_real, reconstr_real = self.generator.gen(z_pred_real, ema=True, viz=log_images, ret_layout=True,
151 | mlp_idx=len(self.generator.layout_net_ema.mlp))
152 |
153 | losses['real_MSE'] = (batch_real - reconstr_real).pow(2).mean()
154 | losses['real_LPIPS'] = self.L_LPIPS(reconstr_real, batch_real).mean()
155 |
156 | total_loss = 'total_loss'
157 | losses[total_loss] = sum(map(lambda k: losses[k] * self.λ[k], losses))
158 | isnan = self.alert_nan_loss(losses[total_loss], batch_idx)
159 | if self.all_gather(isnan).any():
160 | if getattr(self, 'ipdb_on_nan', False) and is_rank_zero():
161 | import ipdb
162 | ipdb.set_trace()
163 | return
164 | self.log_scalars(losses, mode)
165 |
166 | imgs = {
167 | 'real': batch_real,
168 | 'real_reconstr': reconstr_real,
169 | 'fake': gen_imgs,
170 | 'fake_reconstr': reconstr_fake,
171 | 'real_reconstr_feats': layout_pred_real['feature_img'],
172 | 'fake_reconstr_feats': layout_pred_fake['feature_img'],
173 | 'fake_feats': layout_gt_fake['feature_img']
174 | }
175 | if mode == "train":
176 | if log_images and is_rank_zero():
177 | imgs = {k: v.clone().detach().float().cpu() for k, v in imgs.items()}
178 | self._log_image_dict(imgs, mode, square_grid=False, ncol=len(batch_real))
179 | return losses[total_loss]
180 | else:
181 | if self.valtest_log_all:
182 | imgs = self.gather_tensor_dict(imgs)
183 | return imgs
184 |
--------------------------------------------------------------------------------
/src/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import Optional, Union
3 |
4 | import torch
5 | from omegaconf import DictConfig
6 | from torch import nn
7 |
8 | from utils import import_external, is_rank_zero, get_checkpoint_path, load_pretrained_weights
9 | from .stylegan import *
10 | from .layoutnet import *
11 |
12 |
13 | def get_network(name: str, pretrained: Optional[Union[str, DictConfig]] = None, **kwargs) -> nn.Module:
14 | if '.' in name:
15 | ret = import_external(name, pretrained, **kwargs)
16 | return ret
17 | else:
18 | ret = getattr(sys.modules[__name__], name)(**kwargs)
19 | return load_pretrained_weights(name, pretrained, ret)
20 |
--------------------------------------------------------------------------------
/src/models/networks/layoutnet.py:
--------------------------------------------------------------------------------
1 | # https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
2 | import random
3 | from dataclasses import dataclass
4 | from typing import Optional, Dict
5 |
6 | import torch
7 | from einops import rearrange
8 | from torch import nn, Tensor
9 |
10 | __all__ = ["LayoutGenerator"]
11 |
12 | from models.networks.stylegan import StyleMLP, pixel_norm
13 | from utils import derange_tensor
14 |
15 |
16 | @dataclass(eq=False)
17 | class LayoutGenerator(nn.Module):
18 | noise_dim: int = 512
19 | feature_dim: int = 512
20 | style_dim: int = 512
21 | # MLP options
22 | mlp_n_layers: int = 8
23 | mlp_trunk_n_layers: int = 4
24 | mlp_hidden_dim: int = 1024
25 | n_features_max: int = 5
26 | norm_features: bool = False
27 | # Transformer options
28 | spatial_style: bool = False
29 | # Training options
30 | mlp_lr_mul: float = 0.01
31 | shuffle_features: bool = False
32 | p_swap_style: float = 0.0
33 | feature_jitter_xy: float = 0.0 # Legacy, unused
34 | feature_dropout: float = 0.0
35 |
36 | def __post_init__(self):
37 | super().__init__()
38 | if self.feature_jitter_xy:
39 | print('Warning! This parameter is here only to support loading of old checkpoints, and does not function. '
40 | 'Unless you are loading a model that has this value set, it should not be used. To control jitter, '
41 | 'set model.feature_jitter_xy directly.')
42 | # {x_i, y_i, feature_i, covariance_i}, bg feature, and cluster sizes
43 | maybe_style_dim = int(self.spatial_style) * self.style_dim
44 | ndim = (self.feature_dim + maybe_style_dim + 2 + 4 + 1) * self.n_features_max + \
45 | (maybe_style_dim + self.feature_dim + 1)
46 | self.mlp = StyleMLP(self.mlp_n_layers, self.mlp_hidden_dim, self.mlp_lr_mul, first_dim=self.noise_dim,
47 | last_dim=ndim, last_relu=False)
48 |
49 | def forward(self, noise: Tensor, n_features: int,
50 | mlp_idx: Optional[int] = None) -> Optional[Dict[str, Tensor]]:
51 | """
52 | Args:
53 | noise: [N x noise_dim] or [N x M x noise_dim]
54 | mlp_idx: which IDX to start running MLP from, useful for truncation
55 | n_features: int num features to output
56 | Returns: three tensors x coordinates [N x M], y coordinates [N x M], features [N x M x feature_dim]
57 | """
58 | if mlp_idx is None:
59 | out = self.mlp(noise)
60 | else:
61 | out = self.mlp[mlp_idx:](noise)
62 | sizes, out = out.tensor_split((self.n_features_max + 1,), dim=1)
63 | bg_feat, out = out.tensor_split((self.feature_dim,), dim=1)
64 | if self.spatial_style:
65 | bg_style_feat, out = out.tensor_split((self.style_dim,), dim=1)
66 | out = rearrange(out, 'n (m d) -> n m d', m=self.n_features_max)
67 | if self.shuffle_features:
68 | idxs = torch.randperm(self.n_features_max)[:n_features]
69 | else:
70 | idxs = torch.arange(n_features)
71 | out = out[:, idxs]
72 | sizes = sizes[:, [0] + idxs.add(1).tolist()]
73 | if self.feature_dropout:
74 | keep = torch.rand((out.size(1),)) > self.feature_dropout
75 | if not keep.any():
76 | keep[0] = True
77 | out = out[:, keep]
78 | sizes = sizes[:, [True] + keep.tolist()]
79 | xy = out[..., :2].sigmoid() # .mul(self.max_coord)
80 | ret = {'xs': xy[..., 0], 'ys': xy[..., 1], 'sizes': sizes[:, :n_features + 1], 'covs': out[..., 2:6]}
81 | end_dim = self.feature_dim + 6
82 | features = out[..., 6:end_dim]
83 | features = torch.cat((bg_feat[:, None], features), 1)
84 | ret['features'] = features
85 | # return [xy[..., 0], xy[..., 1], features, covs, sizes[:, :n_features + 1].softmax(-1)]
86 | if self.spatial_style:
87 | style_features = out[..., end_dim:]
88 | style_features = torch.cat((bg_style_feat[:, None], style_features), 1)
89 | ret['spatial_style'] = style_features
90 | # ret['covs'] = ret['covs'].detach()
91 | if self.norm_features:
92 | for k in ('features', 'spatial_style', 'shape_features'):
93 | if k in ret:
94 | ret[k] = pixel_norm(ret[k])
95 | if self.p_swap_style:
96 | if random.random() <= self.p_swap_style:
97 | n = random.randint(0, ret['spatial_style'].size(1) - 1)
98 | shuffle = torch.randperm(ret['spatial_style'].size(1) - 1).add(1)[:n]
99 | ret['spatial_style'][:, shuffle] = derange_tensor(ret['spatial_style'][:, shuffle])
100 | return ret
101 |
--------------------------------------------------------------------------------
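`ndim` in `__post_init__` is the width of the layout MLP output: per blob it packs x and y (2), covariance (4), a size logit (1), a feature vector, and optionally a spatial-style vector, plus one background feature/style and one background size logit. A quick standalone check, with `feature_dim=768` and `n_features_max=10` matching the config excerpt shown earlier and `style_dim=512` assumed (the config only gives `${model.G.dim}`):

```python
# Output width of the layout MLP, computed standalone (style_dim is an assumed value).
feature_dim, style_dim, n_features_max, spatial_style = 768, 512, 10, True
maybe_style = style_dim if spatial_style else 0
ndim = (feature_dim + maybe_style + 2 + 4 + 1) * n_features_max + (maybe_style + feature_dim + 1)
print(ndim)  # 12870 + 1281 = 14151
```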
/src/models/networks/op/__init__.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/rosinality/stylegan2-pytorch/tree/3dee637b8937bf3830991c066ed8d9cc58afd661/op
2 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
3 | from .upfirdn2d import upfirdn2d
4 |
--------------------------------------------------------------------------------
/src/models/networks/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from distutils.version import LooseVersion
3 |
4 | if LooseVersion(torch.__version__) >= LooseVersion('1.11.0'):
5 | # New conv refactoring started at version 1.11, it seems.
6 | from .conv2d_gradfix_111andon import conv2d, conv_transpose2d
7 | else:
8 | from .conv2d_gradfix_pre111 import conv2d, conv_transpose2d
--------------------------------------------------------------------------------
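Both backends selected here expose `conv2d` / `conv_transpose2d` drop-ins whose backward passes remain correct under double backprop, which the R1 and path-length penalties rely on. A self-contained sketch of that call pattern, shown with plain `F.conv2d` (the gradfix versions accept the same arguments):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
y = F.conv2d(x, w, padding=1)

# create_graph=True keeps the graph so a penalty on the gradient can itself be backpropagated.
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
penalty = g.pow(2).sum()      # R1-style penalty on the input gradient
penalty.backward()            # second-order backprop through the convolution
```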
/src/models/networks/op/conv2d_gradfix_111andon.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 | # THANKS https://github.com/pytorch/pytorch/issues/74437 !!!!!
12 | import warnings
13 | import contextlib
14 | import torch
15 | from distutils.version import LooseVersion
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | # ----------------------------------------------------------------------------
22 |
23 | enabled = True # Enable the custom op by setting this to true.
24 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
25 |
26 |
27 | @contextlib.contextmanager
28 | def no_weight_gradients():
29 | global weight_gradients_disabled
30 | old = weight_gradients_disabled
31 | weight_gradients_disabled = True
32 | yield
33 | weight_gradients_disabled = old
34 |
35 |
36 | # ----------------------------------------------------------------------------
37 |
38 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
39 | if _should_use_custom_op(input):
40 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding,
41 | output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
42 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding,
43 | dilation=dilation, groups=groups)
44 |
45 |
46 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
47 | if _should_use_custom_op(input):
48 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding,
49 | output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight,
50 | bias)
51 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding,
52 | output_padding=output_padding, groups=groups, dilation=dilation)
53 |
54 |
55 | # ----------------------------------------------------------------------------
56 |
57 | def _should_use_custom_op(input):
58 | assert isinstance(input, torch.Tensor)
59 | if (not enabled) or (not torch.backends.cudnn.enabled):
60 | return False
61 | if input.device.type != 'cuda':
62 | return False
63 | if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
64 | return True
65 | warnings.warn(
66 | f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
67 | return False
68 |
69 |
70 | def _tuple_of_ints(xs, ndim):
71 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
72 | assert len(xs) == ndim
73 | assert all(isinstance(x, int) for x in xs)
74 | return xs
75 |
76 |
77 | # ----------------------------------------------------------------------------
78 |
79 | _conv2d_gradfix_cache = dict()
80 |
81 |
82 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
83 | # Parse arguments.
84 | ndim = 2
85 | weight_shape = tuple(weight_shape)
86 | stride = _tuple_of_ints(stride, ndim)
87 | padding = _tuple_of_ints(padding, ndim)
88 | output_padding = _tuple_of_ints(output_padding, ndim)
89 | dilation = _tuple_of_ints(dilation, ndim)
90 |
91 | # Lookup from cache.
92 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
93 | if key in _conv2d_gradfix_cache:
94 | return _conv2d_gradfix_cache[key]
95 |
96 | # Validate arguments.
97 | assert groups >= 1
98 | assert len(weight_shape) == ndim + 2
99 | assert all(stride[i] >= 1 for i in range(ndim))
100 | assert all(padding[i] >= 0 for i in range(ndim))
101 | assert all(dilation[i] >= 0 for i in range(ndim))
102 | if not transpose:
103 | assert all(output_padding[i] == 0 for i in range(ndim))
104 | else: # transpose
105 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
106 |
107 | # Helpers.
108 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
109 |
110 | def calc_output_padding(input_shape, output_shape):
111 | if transpose:
112 | return [0, 0]
113 | return [
114 | input_shape[i + 2]
115 | - (output_shape[i + 2] - 1) * stride[i]
116 | - (1 - 2 * padding[i])
117 | - dilation[i] * (weight_shape[i + 2] - 1)
118 | for i in range(ndim)
119 | ]
120 |
121 | # Forward & backward.
122 | class Conv2d(torch.autograd.Function):
123 | @staticmethod
124 | def forward(ctx, input, weight, bias):
125 | assert weight.shape == weight_shape
126 | if not transpose:
127 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
128 | else: # transpose
129 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias,
130 | output_padding=output_padding, **common_kwargs)
131 | ctx.save_for_backward(input, weight, bias)
132 | return output
133 |
134 | @staticmethod
135 | def backward(ctx, grad_output):
136 | input, weight, bias = ctx.saved_tensors
137 | grad_input = None
138 | grad_weight = None
139 | grad_bias = None
140 |
141 | if ctx.needs_input_grad[0]:
142 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
143 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p,
144 | **common_kwargs).apply(grad_output, weight, None)
145 | assert grad_input.shape == input.shape
146 |
147 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
148 | grad_weight = Conv2dGradWeight.apply(grad_output, input, bias)
149 | assert grad_weight.shape == weight_shape
150 |
151 | if ctx.needs_input_grad[2]:
152 | grad_bias = grad_output.sum([0, 2, 3])
153 |
154 | return grad_input, grad_weight, grad_bias
155 |
156 | # Gradient with respect to the weights.
157 | class Conv2dGradWeight(torch.autograd.Function):
158 | @staticmethod
159 | def forward(ctx, grad_output, input, bias):
160 | bias_shape = bias.shape if (bias is not None) else None
161 | # empty_weight = torch.empty(weight_shape, dtype=input.dtype, layout=input.layout, device=input.device)
162 | empty_weight = torch.tensor(0.0, dtype=input.dtype, device=input.device).expand(weight_shape)
163 | grad_weight = \
164 | torch.ops.aten.convolution_backward(grad_output, input, empty_weight, bias_sizes=bias_shape,
165 | stride=stride,
166 | padding=padding, dilation=dilation, transposed=transpose,
167 | output_padding=output_padding, groups=groups,
168 | output_mask=[0, 1, 0])[1]
169 | assert grad_weight.shape == weight_shape
170 | ctx.save_for_backward(grad_output, input)
171 | return grad_weight
172 |
173 | @staticmethod
174 | def backward(ctx, grad2_grad_weight):
175 | grad_output, input = ctx.saved_tensors
176 | grad2_grad_output = None
177 | grad2_input = None
178 |
179 | if ctx.needs_input_grad[0]:
180 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
181 | assert grad2_grad_output.shape == grad_output.shape
182 |
183 | if ctx.needs_input_grad[1]:
184 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
185 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p,
186 | **common_kwargs).apply(grad_output, grad2_grad_weight, None)
187 | assert grad2_input.shape == input.shape
188 |
189 | return grad2_grad_output, grad2_input, None
190 |
191 | _conv2d_gradfix_cache[key] = Conv2d
192 | return Conv2d
193 |
194 | # ----------------------------------------------------------------------------
195 |
--------------------------------------------------------------------------------
/src/models/networks/op/conv2d_gradfix_pre111.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | from utils import is_rank_zero
9 |
10 | enabled = True
11 | weight_gradients_disabled = False
12 |
13 |
14 | @contextlib.contextmanager
15 | def no_weight_gradients():
16 | global weight_gradients_disabled
17 |
18 | old = weight_gradients_disabled
19 | weight_gradients_disabled = True
20 | yield
21 | weight_gradients_disabled = old
22 |
23 |
24 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
25 | if could_use_op(input):
26 | return conv2d_gradfix(
27 | transpose=False,
28 | weight_shape=weight.shape,
29 | stride=stride,
30 | padding=padding,
31 | output_padding=0,
32 | dilation=dilation,
33 | groups=groups,
34 | ).apply(input, weight, bias)
35 |
36 | return F.conv2d(
37 | input=input,
38 | weight=weight,
39 | bias=bias,
40 | stride=stride,
41 | padding=padding,
42 | dilation=dilation,
43 | groups=groups,
44 | )
45 |
46 |
47 | def conv_transpose2d(
48 | input,
49 | weight,
50 | bias=None,
51 | stride=1,
52 | padding=0,
53 | output_padding=0,
54 | groups=1,
55 | dilation=1,
56 | ):
57 | if could_use_op(input):
58 | return conv2d_gradfix(
59 | transpose=True,
60 | weight_shape=weight.shape,
61 | stride=stride,
62 | padding=padding,
63 | output_padding=output_padding,
64 | groups=groups,
65 | dilation=dilation,
66 | ).apply(input, weight, bias)
67 |
68 | return F.conv_transpose2d(
69 | input=input,
70 | weight=weight,
71 | bias=bias,
72 | stride=stride,
73 | padding=padding,
74 | output_padding=output_padding,
75 | dilation=dilation,
76 | groups=groups,
77 | )
78 |
79 |
80 | def could_use_op(input):
81 | if (not enabled) or (not torch.backends.cudnn.enabled) or input.device.type != "cuda":
82 | if is_rank_zero():
83 | warnings.warn("CUDNN disabled, no GPUs, or custom ops otherwise not enabled, so not being used.")
84 | return False
85 |
86 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8.", "1.9.", "1.10.", "1.11"]):
87 | if is_rank_zero():
88 | warnings.warn("Using custom ops")
89 | return True
90 |
91 | if is_rank_zero():
92 | warnings.warn(
93 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
94 | )
95 |
96 | return False
97 |
98 |
99 | def ensure_tuple(xs, ndim):
100 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
101 |
102 | return xs
103 |
104 |
105 | conv2d_gradfix_cache = dict()
106 |
107 |
108 | def conv2d_gradfix(
109 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
110 | ):
111 | ndim = 2
112 | weight_shape = tuple(weight_shape)
113 | stride = ensure_tuple(stride, ndim)
114 | padding = ensure_tuple(padding, ndim)
115 | output_padding = ensure_tuple(output_padding, ndim)
116 | dilation = ensure_tuple(dilation, ndim)
117 |
118 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
119 | if key in conv2d_gradfix_cache:
120 | return conv2d_gradfix_cache[key]
121 |
122 | common_kwargs = dict(
123 | stride=stride, padding=padding, dilation=dilation, groups=groups
124 | )
125 |
126 | def calc_output_padding(input_shape, output_shape):
127 | if transpose:
128 | return [0, 0]
129 |
130 | return [
131 | input_shape[i + 2]
132 | - (output_shape[i + 2] - 1) * stride[i]
133 | - (1 - 2 * padding[i])
134 | - dilation[i] * (weight_shape[i + 2] - 1)
135 | for i in range(ndim)
136 | ]
137 |
138 | class Conv2d(autograd.Function):
139 | @staticmethod
140 | def forward(ctx, input, weight, bias):
141 | if not transpose:
142 | out = F.conv2d(input=input, weight=weight.to(input.dtype),
143 | bias=bias.to(input.dtype) if bias is not None else bias,
144 | **common_kwargs)
145 |
146 | else:
147 | out = F.conv_transpose2d(
148 | input=input,
149 | weight=weight.to(input.dtype),
150 | bias=bias.to(input.dtype) if bias is not None else bias,
151 | output_padding=output_padding,
152 | **common_kwargs,
153 | )
154 |
155 | ctx.save_for_backward(input, weight)
156 |
157 | return out
158 |
159 | @staticmethod
160 | def backward(ctx, grad_output):
161 | input, weight = ctx.saved_tensors
162 | grad_input, grad_weight, grad_bias = None, None, None
163 |
164 | if ctx.needs_input_grad[0]:
165 | p = calc_output_padding(
166 | input_shape=input.shape, output_shape=grad_output.shape
167 | )
168 | grad_input = conv2d_gradfix(
169 | transpose=(not transpose),
170 | weight_shape=weight_shape,
171 | output_padding=p,
172 | **common_kwargs,
173 | ).apply(grad_output, weight, None)
174 |
175 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
176 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
177 |
178 | if ctx.needs_input_grad[2]:
179 | grad_bias = grad_output.sum((0, 2, 3))
180 |
181 | return grad_input, grad_weight, grad_bias
182 |
183 | class Conv2dGradWeight(autograd.Function):
184 | @staticmethod
185 | def forward(ctx, grad_output, input):
186 | op = torch._C._jit_get_operation(
187 | "aten::cudnn_convolution_backward_weight"
188 | if not transpose
189 | else "aten::cudnn_convolution_transpose_backward_weight"
190 | )
191 | flags = [
192 | torch.backends.cudnn.benchmark,
193 | torch.backends.cudnn.deterministic,
194 | torch.backends.cudnn.allow_tf32,
195 | ]
196 | grad_weight = op(
197 | weight_shape,
198 | grad_output,
199 | input.to(grad_output.dtype),
200 | padding,
201 | stride,
202 | dilation,
203 | groups,
204 | *flags,
205 | )
206 | ctx.save_for_backward(grad_output, input)
207 |
208 | return grad_weight
209 |
210 | @staticmethod
211 | def backward(ctx, grad_grad_weight):
212 | grad_output, input = ctx.saved_tensors
213 | grad_grad_output, grad_grad_input = None, None
214 |
215 | if ctx.needs_input_grad[0]:
216 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
217 |
218 | if ctx.needs_input_grad[1]:
219 | p = calc_output_padding(
220 | input_shape=input.shape, output_shape=grad_output.shape
221 | )
222 | grad_grad_input = conv2d_gradfix(
223 | transpose=(not transpose),
224 | weight_shape=weight_shape,
225 | output_padding=p,
226 | **common_kwargs,
227 | ).apply(grad_output, grad_grad_weight, None)
228 |
229 | return grad_grad_output, grad_grad_input
230 |
231 | conv2d_gradfix_cache[key] = Conv2d
232 |
233 | return Conv2d
234 |
--------------------------------------------------------------------------------
/src/models/networks/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 |
61 | ctx.bias = bias is not None
62 |
63 | if bias is None:
64 | bias = empty
65 |
66 | out = fused.fused_bias_act(input.float(), bias, empty.float(), 3, 0, negative_slope, scale).to(input.dtype)
67 | ctx.save_for_backward(out)
68 | ctx.negative_slope = negative_slope
69 | ctx.scale = scale
70 |
71 | return out
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | out, = ctx.saved_tensors
76 |
77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79 | )
80 |
81 | if not ctx.bias:
82 | grad_bias = None
83 |
84 | return grad_input, grad_bias, None, None
85 |
86 |
87 | class FusedLeakyReLU(nn.Module):
88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89 | super().__init__()
90 |
91 | if bias:
92 | self.bias = nn.Parameter(torch.zeros(channel))
93 |
94 | else:
95 | self.bias = None
96 |
97 | self.negative_slope = negative_slope
98 | self.scale = scale
99 |
100 | def forward(self, input):
101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102 |
103 |
104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105 | if input.device.type == "cpu":
106 | if bias is not None:
107 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
108 | return (
109 | F.leaky_relu(
110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
111 | )
112 | * scale
113 | )
114 |
115 | else:
116 | return F.leaky_relu(input, negative_slope=0.2) * scale
117 |
118 | else:
119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
120 |
--------------------------------------------------------------------------------
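A minimal usage sketch for the fused activation above. The import path is an assumption, and importing the module JIT-compiles the CUDA extension via `torch.utils.cpp_extension.load`; with a CPU input tensor, `fused_leaky_relu` then takes the plain `F.leaky_relu` fallback branch.

```python
import torch
from src.models.networks.op.fused_act import FusedLeakyReLU, fused_leaky_relu  # path assumed

x = torch.randn(2, 64, 32, 32)                  # N x C x H x W activations

act = FusedLeakyReLU(channel=64)                # learnable per-channel bias, slope 0.2, scale sqrt(2)
y = act(x)                                      # same shape as x

# Functional form with an explicit bias vector
bias = torch.zeros(64)
y2 = fused_leaky_relu(x, bias, negative_slope=0.2, scale=2 ** 0.5)
```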
/src/models/networks/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/src/models/networks/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
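The switch in the kernel above encodes the operation as `act * 10 + grad`: act 1 is linear and act 3 is leaky ReLU; grad 0 is the forward pass, grad 1 the first derivative (gated by the saved reference activation `ref`), and grad 2 the second derivative, which vanishes for these piecewise-linear functions. A rough PyTorch transcription of that switch, for reference only:

```python
import torch

def fused_bias_act_reference(x, b, ref, act, grad, alpha, scale):
    # Bias is broadcast over the channel dimension, mirroring the (xi / step_b) % size_b indexing
    if b.numel():
        x = x + b.view(1, -1, *([1] * (x.ndim - 2)))
    if act == 3:                                   # leaky ReLU
        if grad == 0:
            y = torch.where(x > 0, x, x * alpha)   # case 30
        elif grad == 1:
            y = torch.where(ref > 0, x, x * alpha) # case 31: gate by saved forward output
        else:
            y = torch.zeros_like(x)                # case 32: second derivative is zero
    else:                                          # act == 1 (linear) / default
        y = x if grad < 2 else torch.zeros_like(x)
    return y * scale
```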
/src/models/networks/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/src/models/networks/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | upfirdn2d_op = load(
12 | "upfirdn2d",
13 | sources=[
14 | os.path.join(module_path, "upfirdn2d.cpp"),
15 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class UpFirDn2dBackward(Function):
21 | @staticmethod
22 | def forward(
23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24 | ):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output.float(),
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | ).to(grad_output.dtype)
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input.float(),
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82 | gradgrad_out = gradgrad_out.view(
83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84 | )
85 |
86 | return gradgrad_out, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UpFirDn2d(Function):
90 | @staticmethod
91 | def forward(ctx, input, kernel, up, down, pad):
92 | up_x, up_y = up
93 | down_x, down_y = down
94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
95 |
96 | kernel_h, kernel_w = kernel.shape
97 | batch, channel, in_h, in_w = input.shape
98 | ctx.in_size = input.shape
99 |
100 | input = input.reshape(-1, in_h, in_w, 1)
101 |
102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103 |
104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106 | ctx.out_size = (out_h, out_w)
107 |
108 | ctx.up = (up_x, up_y)
109 | ctx.down = (down_x, down_y)
110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111 |
112 | g_pad_x0 = kernel_w - pad_x0 - 1
113 | g_pad_y0 = kernel_h - pad_y0 - 1
114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116 |
117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118 |
119 | out = upfirdn2d_op.upfirdn2d(
120 | input.float(), kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121 | ).to(input.dtype)
122 | # out = out.view(major, out_h, out_w, minor)
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = None
132 |
133 | if ctx.needs_input_grad[0]:
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 | if not isinstance(up, abc.Iterable):
151 | up = (up, up)
152 |
153 | if not isinstance(down, abc.Iterable):
154 | down = (down, down)
155 |
156 | if len(pad) == 2:
157 | pad = (pad[0], pad[1], pad[0], pad[1])
158 |
159 | if input.device.type == "cpu":
160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161 |
162 | else:
163 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
164 |
165 | return out
166 |
167 |
168 | def upfirdn2d_native(
169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170 | ):
171 | _, channel, in_h, in_w = input.shape
172 | input = input.reshape(-1, in_h, in_w, 1)
173 |
174 | _, in_h, in_w, minor = input.shape
175 | kernel_h, kernel_w = kernel.shape
176 |
177 | out = input.view(-1, in_h, 1, in_w, 1, minor)
178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180 |
181 | out = F.pad(
182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183 | )
184 | out = out[
185 | :,
186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188 | :,
189 | ]
190 |
191 | out = out.permute(0, 3, 1, 2)
192 | out = out.reshape(
193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194 | )
195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196 | out = F.conv2d(out, w)
197 | out = out.reshape(
198 | -1,
199 | minor,
200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202 | )
203 | out = out.permute(0, 2, 3, 1)
204 | out = out[:, ::down_y, ::down_x, :]
205 |
206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208 |
209 | return out.view(-1, channel, out_h, out_w)
210 |
--------------------------------------------------------------------------------
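A small usage sketch of `upfirdn2d` above, assuming the extension builds on import; with a CPU input tensor the `upfirdn2d_native` path is taken. The kernel and padding values follow the usual StyleGAN2 upsampling convention and are illustrative, not prescribed by this file.

```python
import torch
from src.models.networks.op.upfirdn2d import upfirdn2d  # import path assumed

x = torch.randn(1, 3, 16, 16)

k = torch.tensor([1., 3., 3., 1.])
k = torch.outer(k, k)
k = k / k.sum()                                  # 4x4 separable binomial blur kernel

# 2x upsample then blur; pad=(2, 1) makes the output exactly twice the input:
# out_h = (16*2 + 2 + 1 - 4 + 1) // 1 = 32
out = upfirdn2d(x, k * 4, up=2, down=1, pad=(2, 1))
print(out.shape)                                 # torch.Size([1, 3, 32, 32])
```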
/src/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import hydra
5 | import pytorch_lightning as pl
6 | import torch
7 | from omegaconf import OmegaConf, DictConfig
8 | from pytorch_lightning import seed_everything
9 |
10 | import data
11 | import models
12 | import utils
13 | from utils import scale_logging_rates, print_once, Checkpoint
14 |
15 |
16 | @hydra.main(config_path="configs", config_name="fit")
17 | def run(config: DictConfig):
18 | torch.backends.cudnn.deterministic = config.trainer.deterministic
19 | torch.backends.cudnn.benchmark = config.trainer.benchmark
20 | torch.use_deterministic_algorithms(config.trainer.deterministic)
21 | if config.trainer.deterministic:
22 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
23 |
24 | print_once(OmegaConf.to_yaml(config, resolve=True))
25 |
26 | seed_everything(config.seed, workers=True)
27 |
28 | scale_logging_rates(config, 1 / config.trainer.get('accumulate_grad_batches', 1))
29 |
30 | if config.get('detect_anomalies', False):
31 | print_once('Anomaly detection mode ACTIVATED')
32 | torch.autograd.set_detect_anomaly(True)
33 |
34 | config.resume.id = utils.resolve_resume_id(**config.resume)
35 |
36 | if config.logger:
37 | logger = utils.Logger(**config[config.logger])
38 | logger.log_config(config)
39 | logger.log_code()
40 | else:
41 | logger = False
42 |
43 | datamodule = data.get_datamodule(**config.dataset)
44 |
45 | model, model_cfg = models.get_model(**config.model, return_cfg=True)
46 |
47 | if config.resume.id is not None:
48 | checkpoint = utils.get_checkpoint_path(**config.resume)
49 | if config.mode != 'fit' or config.resume.model_only:
50 | # Automatically load model weights in validate/test mode as opposed to using built-in PL argument to
51 |             # validate or test methods, since we need custom logic, e.g. to remove non-dataclass args
52 | model = model.load_from_checkpoint(checkpoint, **(model_cfg if config.resume.clobber_hparams else {}))
53 | else:
54 | checkpoint = None
55 |
56 | if logger:
57 | if os.environ.get("EXP_LOG_DIR", None) is None:
58 |             # Needed because in distributed training, the logger is not properly initialized on clone processes
59 | # If dirname for the checkpointer is not the same on all processes, training hangs
60 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/5319
61 | os.environ["EXP_LOG_DIR"] = logger.experiment.dir
62 |
63 | callbacks = []
64 | checkpoint_callback = 'checkpoint' in config and config.checkpoint is not None
65 |
66 | if logger and checkpoint_callback:
67 | checkpoint_cb = Checkpoint(**config.checkpoint,
68 | dirpath=Path(os.environ["EXP_LOG_DIR"]) / 'checkpoints')
69 | checkpoint_cb.CHECKPOINT_NAME_LAST = checkpoint_cb.CHECKPOINT_JOIN_CHAR.join(["{epoch}", "{step}", "last"])
70 | callbacks.append(checkpoint_cb)
71 |
72 | trainer = pl.Trainer(
73 | resume_from_checkpoint=None if config.resume.model_only else checkpoint,
74 | logger=logger,
75 | callbacks=callbacks,
76 | checkpoint_callback=checkpoint_callback,
77 | **config.trainer
78 | )
79 |
80 | if config.mode == 'fit':
81 | trainer.fit(model, datamodule=datamodule)
82 | elif config.mode == 'validate':
83 | trainer.validate(model, datamodule=datamodule)
84 | elif config.mode == 'test':
85 | trainer.test(model, datamodule=datamodule)
86 |
87 |
88 | if __name__ == "__main__":
89 | run()
90 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 | from .training import *
3 | from .io import *
4 | from .colab import *
5 | from .distributed import *
6 | from .wandb_logger import *
7 | from .logging import *
8 |
--------------------------------------------------------------------------------
/src/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # From Bill Peebles, thanks Bill!
2 | # https://raw.githubusercontent.com/wpeebles/gangealing/739da2a25de62702d54d83fad6b644646512039c/utils/distributed.py
3 | import torch
4 | from torch import distributed as dist
5 | import os
6 |
7 |
8 | def is_rank_zero():
9 | return get_rank() == 0
10 |
11 |
12 | def print_once(s):
13 | if is_rank_zero():
14 | print(s)
15 |
16 |
17 | def get_rank():
18 | return int(os.environ.get('LOCAL_RANK', 0))
19 |
20 |
21 | def get_rank_colab():
22 | if not dist.is_available():
23 | return 0
24 |
25 | if not dist.is_initialized():
26 | return 0
27 |
28 | return dist.get_rank()
29 |
30 |
31 | def primary():
32 | if not dist.is_available():
33 | return True
34 |
35 | if not dist.is_initialized():
36 | return True
37 |
38 | return get_rank_colab() == 0
39 |
40 |
41 | def synchronize():
42 | if not dist.is_available():
43 | return
44 |
45 | if not dist.is_initialized():
46 | return
47 |
48 | world_size = dist.get_world_size()
49 |
50 | if world_size == 1:
51 | return
52 |
53 | dist.barrier()
54 |
55 |
56 | def get_world_size():
57 | if not dist.is_available():
58 | return 1
59 |
60 | if not dist.is_initialized():
61 | return 1
62 |
63 | return dist.get_world_size()
64 |
65 |
66 | def reduce_sum(tensor):
67 | if not dist.is_available():
68 | return tensor
69 |
70 | if not dist.is_initialized():
71 | return tensor
72 |
73 | tensor = tensor.clone()
74 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
75 |
76 | return tensor
77 |
78 |
79 | def gather_grad(params):
80 | world_size = get_world_size()
81 |
82 | if world_size == 1:
83 | return
84 |
85 | for param in params:
86 | if param.grad is not None:
87 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
88 | param.grad.data.div_(world_size)
89 |
90 |
91 | def all_gather(input, cat=True):
92 | if get_world_size() == 1:
93 | if cat:
94 | return input
95 | else:
96 | return input.unsqueeze(0)
97 | input_list = [torch.zeros_like(input) for _ in range(get_world_size())]
98 | synchronize()
99 | torch.distributed.all_gather(input_list, input, async_op=False)
100 | if cat:
101 | inputs = torch.cat(input_list, dim=0)
102 | else:
103 | inputs = torch.stack(input_list, dim=0)
104 | return inputs
105 |
106 |
107 | def all_gatherv(input, return_boundaries=False):
108 | """Variable-sized all_gather"""
109 |
110 | # Broadcast the number of elements in every process:
111 | num_elements = torch.tensor(input.size(0), device=input.device)
112 | num_elements_per_process = all_gather(num_elements, cat=False)
113 | max_elements = num_elements_per_process.max()
114 | # Add padding so every input is the same size:
115 | difference = max_elements - input.size(0)
116 | if difference > 0:
117 | input = torch.cat([input, torch.zeros(difference, *input.size()[1:], device=input.device, dtype=input.dtype)],
118 | 0)
119 | inputs = all_gather(input, cat=False)
120 | # Remove padding:
121 | inputs = torch.cat([row[:num_ele] for row, num_ele in zip(inputs, num_elements_per_process)], 0)
122 | if return_boundaries:
123 | boundaries = torch.cumsum(num_elements_per_process, dim=0)
124 | boundaries = torch.cat([torch.zeros(1, device=input.device, dtype=torch.int), boundaries], 0)
125 | return inputs, boundaries.long()
126 | else:
127 | return inputs
128 |
129 |
130 | def all_reduce(input, device):
131 | num_local = torch.tensor([input.size(0)], dtype=torch.float, device=device)
132 | input = input.sum(dim=0, keepdim=True).to(device)
133 | num_global = all_gather(num_local).sum()
134 | input = all_gather(input)
135 | input = input.sum(dim=0).div(num_global)
136 | return input
137 |
138 |
139 | def rank0_to_all(input):
140 | input = all_gather(input)
141 | rank0_input = input[0]
142 | return rank0_input
143 |
144 |
145 | def reduce_loss_dict(loss_dict):
146 | world_size = get_world_size()
147 |
148 | if world_size < 2:
149 | return loss_dict
150 |
151 | with torch.no_grad():
152 | keys = []
153 | losses = []
154 |
155 | for k in sorted(loss_dict.keys()):
156 | keys.append(k)
157 | losses.append(loss_dict[k])
158 |
159 | losses = torch.stack(losses, 0)
160 | dist.reduce(losses, dst=0)
161 |
162 | if dist.get_rank() == 0:
163 | losses /= world_size
164 |
165 | reduced_losses = {k: v for k, v in zip(keys, losses)}
166 |
167 | return reduced_losses
168 |
--------------------------------------------------------------------------------
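These helpers are written so single-process runs degrade gracefully: with no process group initialized, `get_world_size()` returns 1 and the collectives become pass-throughs. A quick sketch of that fallback behaviour (import path assumed):

```python
import torch
from src.utils.distributed import all_gather, reduce_loss_dict, get_world_size

assert get_world_size() == 1                     # no torch.distributed process group here

x = torch.randn(3, 4)
assert torch.equal(all_gather(x), x)             # cat=True: returned unchanged
assert all_gather(x, cat=False).shape == (1, 3, 4)

losses = {'d_loss': torch.tensor(0.7), 'g_loss': torch.tensor(1.2)}
assert reduce_loss_dict(losses) is losses        # world_size < 2: dict passed through
```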
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Union, Optional, Tuple, Dict, List
2 | import torch
3 | from numbers import Number
4 |
5 | from omegaconf import DictConfig
6 | from torch import Tensor
7 | from .distributed import print_once
8 |
9 |
10 | def scalars_to_log_dict(scalars: Dict[Any, Union[Number, Tensor]], mode: str) -> Dict[str, Number]:
11 | return {f'{mode}/{k}': (v.item() if isinstance(v, Tensor) else v) for k, v in scalars.items()}
12 |
13 |
14 | def epoch_outputs_to_log_dict(outputs: List[Dict[str, Tensor]],
15 | n_max: Optional[Union[int, str]] = None,
16 | shuffle: bool = False,
17 | reduce: Optional[str] = None) -> Dict[str, Tensor]:
18 | # Converts list of dicts (per-batch return values) into dict of concatenated list element dict values
19 | # Optionally return a tensor of length at most n_max for each key, and shuffle
20 | # If n_max is "batch", return one batch worth of tensors
21 | # Either cat or stack, depending on whether batch output is 0-d tensor (scalar) or not
22 | def merge_fn(v):
23 | return (torch.cat if len(v.shape) else torch.stack) if torch.is_tensor(v) else Tensor
24 |
25 | reduce_fn = lambda x: x
26 | if reduce is not None:
27 | if reduce == 'mean':
28 | reduce_fn = torch.mean
29 | elif reduce == 'sum':
30 | reduce_fn = torch.sum
31 | else:
32 | raise ValueError('reduce must be either `mean` or `sum`')
33 | out_dict = {k: reduce_fn(merge_fn(v)([o[k] for o in outputs])) for k, v in outputs[0].items() if v is not None}
34 | if n_max is not None:
35 | for k, v in out_dict.items():
36 | if shuffle:
37 | v = v[torch.randperm(len(v))]
38 | n_max_ = len(outputs[0][k]) if n_max == "batch" else n_max
39 | out_dict[k] = v[:n_max_]
40 | return out_dict
41 |
42 |
43 | def scale_logging_rates(d: DictConfig, c: Number, strs: Tuple[str] = ('log', 'every_n_steps'), prefix: str = 'config'):
44 | if c == 1:
45 | return
46 | for k, v in d.items():
47 | if all([s in k for s in strs]):
48 | d[k] = type(v)(v * c)
49 | print_once(f'Scaling {prefix}.{k} from {v} to {type(v)(v * c)} due to gradient accumulation')
50 | elif isinstance(v, DictConfig):
51 | scale_logging_rates(v, c, strs, prefix=prefix + '.' + k)
52 |
--------------------------------------------------------------------------------
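A small sketch of the two helpers above (module path assumed): scalars are prefixed with the logging mode, and per-batch outputs are merged by stacking 0-d tensors and concatenating batched ones.

```python
import torch
from src.utils.logging import scalars_to_log_dict, epoch_outputs_to_log_dict

scalars_to_log_dict({'loss': torch.tensor(0.5), 'acc': 92.0}, mode='train')
# -> {'train/loss': 0.5, 'train/acc': 92.0}

outputs = [{'loss': torch.tensor(0.4), 'img': torch.randn(8, 3, 4, 4)},
           {'loss': torch.tensor(0.6), 'img': torch.randn(8, 3, 4, 4)}]
merged = epoch_outputs_to_log_dict(outputs, n_max=8, shuffle=True)
# merged['loss'].shape == (2,)   merged['img'].shape == (8, 3, 4, 4)
```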
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import random
3 | from dataclasses import is_dataclass, fields
4 | from math import pi
5 | from typing import Any, Union, TypeVar, Tuple, Optional, Dict, OrderedDict
6 |
7 | import einops
8 | import numpy as np
9 | import torch
10 | from PIL import Image, ImageDraw
11 | from omegaconf import DictConfig
12 | from torch import Tensor
13 | from torch.nn import functional as F
14 | from torchvision.transforms.functional import to_tensor
15 |
16 | from utils.io import load_pretrained_weights
17 | from .distributed import is_rank_zero
18 |
19 | T = TypeVar('T')
20 | FromConfig = Union[T, Dict[str, Any]]
21 | NTuple = Tuple[T, ...]
22 | StateDict = OrderedDict[str, torch.Tensor]
23 |
24 | TORCH_EINSUM = True
25 | einsum = torch.einsum if TORCH_EINSUM else oe.contract  # NB: oe (opt_einsum) is not imported here; only reachable if TORCH_EINSUM is False
26 |
27 |
28 | def recursive_compare(d1: dict, d2: dict, level: str = 'root') -> str:
29 | ret = []
30 | if isinstance(d1, dict) and isinstance(d2, dict):
31 | if d1.keys() != d2.keys():
32 | s1 = set(d1.keys())
33 | s2 = set(d2.keys())
34 | ret.append('{:<20} - {} + {}'.format(level, ','.join(s1 - s2), ','.join(s2 - s1)))
35 | common_keys = s1 & s2
36 | else:
37 | common_keys = set(d1.keys())
38 |
39 | for k in common_keys:
40 | ret.append(recursive_compare(d1[k], d2[k], level='{}.{}'.format(level, k)))
41 | elif isinstance(d1, list) and isinstance(d2, list):
42 | if len(d1) != len(d2):
43 | ret.append('{:<20} len1={}; len2={}'.format(level, len(d1), len(d2)))
44 | common_len = min(len(d1), len(d2))
45 |
46 | for i in range(common_len):
47 | ret.append(recursive_compare(d1[i], d2[i], level='{}[{}]'.format(level, i)))
48 | else:
49 | if d1 != d2:
50 | ret.append('{:<20} {} -> {}'.format(level, d1, d2))
51 | return '\n'.join(filter(None, ret))
52 |
53 |
54 | def import_external(name: str, pretrained: Optional[Union[str, DictConfig]] = None, **kwargs):
55 | module, name = name.rsplit('.', 1)
56 | ret = getattr(importlib.import_module(module), name)
57 | ret = ret(**to_dataclass_cfg(kwargs, ret))
58 | return load_pretrained_weights(name, pretrained, ret)
59 |
60 |
61 | def run_at_step(step: int, freq: int):
62 | return (freq > 0) and ((step + 1) % freq == 0)
63 |
64 |
65 | def rotation_matrix(theta):
66 | cos = torch.cos(theta)
67 | sin = torch.sin(theta)
68 | return torch.stack([cos, sin, -sin, cos], dim=-1).view(*theta.shape, 2, 2)
69 |
70 |
71 | def gaussian(window_size, sigma):
72 | def gauss_fcn(x):
73 | return -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)
74 |
75 | gauss = torch.stack(
76 | [torch.exp(torch.tensor(gauss_fcn(x))) for x in range(window_size)])
77 | return gauss / gauss.sum()
78 |
79 |
80 | def splat_features_from_scores(scores: Tensor, features: Tensor, size: Optional[int],
81 | channels_last: bool = True) -> Tensor:
82 | """
83 |
84 | Args:
85 | channels_last: expect input with M at end or not, see below
86 | scores: [N, H, W, M] (or [N, M, H, W] if not channels last)
87 | features: [N, M, C]
88 | size: dimension of map to return
89 | Returns: [N, C, H, W]
90 |
91 | """
92 | if size and not (scores.shape[2] == size):
93 | if channels_last:
94 | scores = einops.rearrange(scores, 'n h w m -> n m h w')
95 | scores = F.interpolate(scores, size, mode='bilinear', align_corners=False)
96 | einstr = 'nmhw,nmc->nchw'
97 | else:
98 | einstr = 'nhwm,nmc->nchw' if channels_last else 'nmhw,nmc->nchw'
99 | return einsum(einstr, scores, features).contiguous()
100 |
101 |
102 | def cholesky_to_matrix(covs: Tensor) -> Tensor:
103 | covs[..., ::3] = covs[:, :, ::3].exp()
104 | covs[..., 2] = 0
105 | covs = einops.rearrange(covs, 'n m (x y) -> n m x y', x=2, y=2)
106 | covs = einsum('nmji,nmjk->nmik', covs, covs) # [n, m, 2, 2]
107 | return covs
108 |
109 |
110 | def jitter_image_batch(images: Tensor, dy: int, dx: int) -> Tensor:
111 | # images: N x C x H x W
112 | images = torch.roll(images, (dy, dx), (2, 3))
113 | if dy > 0:
114 | images[:, :, :dy, :] = 0
115 | else:
116 | images[:, :, dy:, :] = 0
117 | if dx > 0:
118 | images[:, :, :, :dx] = 0
119 | else:
120 | images[:, :, :, dx:] = 0
121 | return images
122 |
123 |
124 | DERANGEMENT_WARNED = False
125 |
126 |
127 | def derangement(n: int) -> Tensor:
128 | global DERANGEMENT_WARNED
129 | orig = torch.arange(n)
130 | shuffle = torch.randperm(n)
131 | if n == 1 and not DERANGEMENT_WARNED:
132 | if is_rank_zero():
133 | print('Warning: called derangement with n=1!')
134 | DERANGEMENT_WARNED = True
135 | while (n > 1) and (shuffle == orig).any():
136 | shuffle = torch.randperm(n)
137 | return shuffle
138 |
139 |
140 | def pyramid_resize(img, cutoff):
141 | """
142 |
143 | Args:
144 | img: [N x C x H x W]
145 | cutoff: threshold at which to stop pyramid
146 |
147 | Returns: gaussian pyramid
148 |
149 | """
150 | out = [img]
151 | while img.shape[-1] > cutoff:
152 | img = F.interpolate(img, img.shape[-1] // 2, mode='bilinear', align_corners=False)
153 | out.append(img)
154 | return {i.size(-1): i for i in out}
155 |
156 |
157 | def derange_tensor(x: Tensor, dim: int = 0) -> Tensor:
158 | if dim == 0:
159 | return x[derangement(len(x))]
160 | elif dim == 1:
161 | return x[:, derangement(len(x[0]))]
162 |
163 |
164 | def derange_tensor_n_times(x: Tensor, n: int, dim: int = 0, stack_dim: int = 0) -> Tensor:
165 | return torch.stack([derange_tensor(x, dim) for _ in range(n)], stack_dim)
166 |
167 |
168 | def to_dataclass_cfg(cfg, cls):
169 | """
170 | Can't add **kwargs catch-all to dataclass, so need to strip dict of keys that are not fields
171 | """
172 | if is_dataclass(cls):
173 | return {k: v for k, v in cfg.items() if k in [f.name for f in fields(cls)]}
174 | return cfg
175 |
176 |
177 | def random_polygons(size: int, shape: Union[int, Tuple[int, ...]]):
178 | if type(shape) is int:
179 | shape = (shape,)
180 | n = np.prod(shape)
181 | return torch.stack([random_polygon(size) for _ in range(n)]).view(*shape, 1, size, size)
182 |
183 |
184 | def random_polygon(size: int):
185 | # Logic from Copy Paste GAN
186 | img = Image.new("RGB", (size, size), "black")
187 | f = lambda s: round(size * s)
188 | to_xy = lambda r, θ, p: (p + r * np.array([np.cos(θ), np.sin(θ)]))
189 | c = np.array([random.randint(f(0.1), f(0.9)), random.randint(f(0.1), f(0.9))])
190 | n_vert = random.randint(4, 6)
191 | coords = []
192 | while len(coords) < n_vert:
193 | coord = to_xy(random.uniform(f(0.1), f(0.5)), random.uniform(0, 2 * pi), c)
194 | if coord.min() >= 0:
195 | coords.append(tuple(coord))
196 | ImageDraw.Draw(img).polygon(coords, fill="white")
197 | return to_tensor(img)[:1]
198 |
--------------------------------------------------------------------------------
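A shape-level sketch of `splat_features_from_scores` and `pyramid_resize` above (assumes `src` is on the path so the intra-package imports resolve): per-blob score maps are contracted against per-blob feature vectors to produce a spatial feature grid.

```python
import torch
from src.utils.misc import splat_features_from_scores, pyramid_resize

N, M, C = 2, 5, 16
scores = torch.rand(N, 32, 32, M).softmax(-1)    # [N, H, W, M], channels_last=True
features = torch.randn(N, M, C)                  # one C-dim feature vector per blob
grid = splat_features_from_scores(scores, features, size=64)
print(grid.shape)                                # torch.Size([2, 16, 64, 64])

imgs = torch.randn(N, 3, 64, 64)
pyr = pyramid_resize(imgs, cutoff=16)            # image pyramid keyed by resolution: {64: ..., 32: ..., 16: ...}
```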
/src/utils/training.py:
--------------------------------------------------------------------------------
1 | # Training utilities
2 | import math
3 | import random
4 | from itertools import groupby
5 | from numbers import Number
6 | from typing import Dict, List
7 |
8 | import torch
9 | from torch import nn, Tensor, autograd
10 |
11 | from .distributed import is_rank_zero
12 |
13 |
14 | def make_noise(batch, latent_dim, n_noise):
15 | if n_noise == 1:
16 | return torch.randn(len(batch), latent_dim).type_as(batch)
17 | return torch.randn(n_noise, len(batch), latent_dim).type_as(batch).unbind(0)
18 |
19 |
20 | def mixing_noise(batch, latent_dim, prob):
21 | if prob > 0 and random.random() < prob:
22 | return make_noise(batch, latent_dim, 2)
23 | else:
24 | return [make_noise(batch, latent_dim, 1)]
25 |
26 |
27 | ACCUM_WARN = False
28 |
29 |
30 | def accumulate(model1, model2, decay=0.999):
31 | global ACCUM_WARN
32 | par1 = dict(model1.named_parameters())
33 | par2 = dict(model2.named_parameters())
34 | if len(par1.keys() & par2.keys()) == 0:
35 | if is_rank_zero() and not ACCUM_WARN:
36 | print('Cannot accumulate, likely due to FSDP parameter flattening. Skipping.')
37 | ACCUM_WARN = True
38 | return
39 | device = next(model1.parameters()).device
40 | for k in par1.keys():
41 | par1[k].data.mul_(decay).add_(par2[k].data.to(device), alpha=1 - decay)
42 |
43 |
44 |
45 |
46 | def freeze(model: nn.Module, layers: List[str] = None):
47 | frozen = []
48 | for name, param in model.named_parameters():
49 | if layers is None or any(name.startswith(l) for l in layers):
50 | param.requires_grad = False
51 | frozen.append(name)
52 | if is_rank_zero():
53 | depth_two_params = [k for k, _ in groupby(
54 | ['.'.join(n.split('.')[:2]).replace('.weight', '').replace('.bias', '') for n in frozen])]
55 | print(f'Froze {len(frozen)} parameters - {depth_two_params} - for model of type {model.__class__.__name__}')
56 |
57 |
58 | def requires_grad(model: nn.Module, requires: bool):
59 | for param in model.parameters():
60 | param.requires_grad = requires
61 |
62 |
63 | def unfreeze(model: nn.Module):
64 | for param in model.parameters():
65 | param.requires_grad = True
66 |
67 |
68 | def fill_module_uniform(module, range, blacklist=None):
69 | if blacklist is None: blacklist = []
70 | for n, p in module.named_parameters():
71 | if not any([b in n for b in blacklist]):
72 | nn.init.uniform_(p, -range, range)
73 |
74 |
75 | def zero_module(module):
76 | for p in module.parameters():
77 | p.detach().zero_()
78 | return module
79 |
80 |
81 | def get_D_stats(key: str, scores: Tensor, gt: bool) -> Dict[str, Number]:
82 | acc = 100 * (scores > 0).sum() / len(scores)
83 | if not gt:
84 | acc = 100 - acc
85 | return {
86 | f'score_{key}': scores.mean(),
87 | f'acc_{key}': acc
88 | }
89 |
90 |
91 | # Losses adapted from https://github.com/rosinality/stylegan2-pytorch/blob/master/train.py
92 | def D_R1_loss(real_pred, real_img):
93 | grad_real, = autograd.grad(
94 | outputs=real_pred.sum(), inputs=real_img, create_graph=True
95 | )
96 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
97 | return grad_penalty
98 |
99 |
100 | def G_path_loss(fake_img, latents, mean_path_length, decay=0.01):
101 | noise = torch.randn_like(fake_img) / math.sqrt(
102 | fake_img.shape[2] * fake_img.shape[3]
103 | )
104 | grad, = autograd.grad(
105 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
106 | )
107 |
108 | if grad.ndim == 3: # [N_batch x N_latent x D]
109 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
110 | elif grad.ndim == 2: # [N_batch x D]
111 | path_lengths = torch.sqrt(grad.pow(2).sum(1))
112 |
113 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
114 |
115 | path_penalty = (path_lengths - path_mean).pow(2).mean()
116 |
117 | return path_penalty, path_mean.detach(), path_lengths.mean()
118 |
--------------------------------------------------------------------------------
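A brief sketch of the noise and EMA helpers above; the `nn.Linear` modules stand in for the generator and its EMA copy and are purely illustrative.

```python
import torch
from torch import nn
from src.utils.training import make_noise, mixing_noise, accumulate  # import path assumed

batch = torch.randn(4, 3, 64, 64)
z = make_noise(batch, latent_dim=512, n_noise=1)     # single latent: [4, 512]
zs = mixing_noise(batch, latent_dim=512, prob=0.9)   # one latent, or two for style mixing

g = nn.Linear(512, 512)              # stand-in generator
g_ema = nn.Linear(512, 512)          # stand-in EMA copy
accumulate(g_ema, g, decay=0.999)    # g_ema <- 0.999 * g_ema + 0.001 * g
```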
/src/utils/wandb_logger.py:
--------------------------------------------------------------------------------
1 | # From Tim
2 | from __future__ import annotations
3 |
4 | __all__ = ["Logger"]
5 |
6 | import os
7 | import re
8 | import subprocess
9 | import sys
10 | from ast import literal_eval
11 | from math import sqrt, ceil
12 | from pathlib import Path
13 | from typing import Optional, Union
14 |
15 | import omegaconf
16 | import torch
17 | import torchvision.utils as utils
18 | import wandb
19 | from pytorch_lightning import LightningModule
20 | from pytorch_lightning.core.memory import ModelSummary
21 | from pytorch_lightning.loggers import WandbLogger
22 | from pytorch_lightning.utilities import rank_zero_only
23 | from torch import Tensor
24 | from wandb.sdk.lib.config_util import ConfigError
25 |
26 | from .distributed import is_rank_zero
27 | from .io import yes_or_no
28 | from .misc import recursive_compare
29 |
30 |
31 | class Logger(WandbLogger):
32 | def __init__(
33 | self,
34 | *,
35 | name: str,
36 | project: str,
37 | entity: str,
38 | group: Optional[str] = None,
39 | offline: bool = False,
40 | log_dir: Optional[str] = './logs',
41 | **kwargs
42 | ):
43 | log_dir = str(Path(log_dir).absolute())
44 |
45 | super().__init__(
46 | name=name,
47 | save_dir=log_dir,
48 | offline=offline,
49 | project=project,
50 | log_model=False,
51 | entity=entity,
52 | group=group,
53 | **kwargs
54 | )
55 |
56 | def log_hyperparams(self, *args, **kwargs):
57 | pass
58 |
59 | def _file_exists(self, path: str) -> bool:
60 | try:
61 | self.experiment.restore(path)
62 | return True
63 | except ValueError:
64 | return False
65 |
66 | def _get_unique_fn(self, filename: str, sep: str = '_') -> str:
67 | orig_filename, ext = os.path.splitext(filename)
68 | cfg_ctr = 0
69 | while self._file_exists(filename):
70 | cfg_ctr += 1
71 | filename = f"{orig_filename}{sep}{cfg_ctr}{ext}"
72 | return filename
73 |
74 | @rank_zero_only
75 | def save_to_file(self, filename: str, contents: Union[str, bytes], unique_filename: bool = True) -> str:
76 | if not is_rank_zero():
77 | return
78 | if unique_filename:
79 | filename = self._get_unique_fn(filename)
80 | self.experiment.save(filename)
81 | t = type(contents)
82 | if t is str:
83 | mode = 'w'
84 | elif t is bytes:
85 | mode = 'wb'
86 | else:
87 | raise TypeError('Can only save str or bytes')
88 | (Path(self.experiment.dir) / filename).open(mode).write(contents)
89 | return filename
90 |
91 | @rank_zero_only
92 | def log_config(self, config: omegaconf.DictConfig):
93 | if not is_rank_zero():
94 | return
95 | filename = self.save_to_file("hydra_config.yaml", omegaconf.OmegaConf.to_yaml(config))
96 | params = omegaconf.OmegaConf.to_container(config)
97 | assert isinstance(params, dict)
98 | params.pop("wandb", None)
99 |
100 | try:
101 | self.experiment.config.update(params)
102 | except ConfigError as e:
103 | # Config has changed, so confirm with user that this is okay before proceeding
104 | msg = e.message.split("\n")[0]
105 |
106 | def try_literal_eval(x):
107 | try:
108 | return literal_eval(x)
109 | except ValueError:
110 | return x
111 |
112 | key, old, new = map(try_literal_eval, re.search("key (.*) from (.*) to (.*)", msg).groups())
113 | print(f'Caution! Parameters have changed!')
114 | if not (type(old) == type(new) == dict):
115 | old = {key: old}
116 | new = {key: new}
117 | print(recursive_compare(old, new, level=key))
118 | if yes_or_no('Was this intended?', default=True, timeout=10):
119 | print(f'Saving new parameters to {filename} and updating W and B config.')
120 | self.experiment.config.update(params, allow_val_change=True)
121 | else:
122 | sys.exit(1)
123 |
124 | @rank_zero_only
125 | def log_model_summary(self, model: LightningModule):
126 | if not is_rank_zero():
127 | return
128 | self.save_to_file("model_summary.txt", str(ModelSummary(model, max_depth=-1)))
129 |
130 | @torch.no_grad()
131 | @rank_zero_only
132 | def log_image_batch(self, name: str, images: Tensor, square_grid: bool = True, commit: bool = False,
133 | ncol: Optional[int] = None, **kwargs):
134 | """
135 | Args:
136 | name: Name of key to use for logging
137 | images: N x C x H x W tensor of images
138 | square_grid: whether to render images into a square grid
139 | commit: whether to commit log to wandb or not
140 | ncol: analogous to nrow in make_grid, control how many images are in each column
141 | **kwargs: passed onto make_grid
142 | """
143 | if not is_rank_zero():
144 | return
145 | assert not (square_grid and ncol is not None), "Set either square_grid or ncol"
146 | if square_grid:
147 | kwargs['nrow'] = ceil(sqrt(len(images)))
148 | elif ncol is not None:
149 | kwargs['nrow'] = ceil(len(images) / ncol)
150 | image_grid = utils.make_grid(
151 | images.float(), normalize=True, value_range=(-1, 1), **kwargs
152 | )
153 | wandb_image = wandb.Image(image_grid.float().cpu())
154 | self.experiment.log({name: wandb_image}, commit=commit)
155 |
156 | @rank_zero_only
157 | def log_code(self):
158 | if not is_rank_zero():
159 | return
160 | codetar = subprocess.run(
161 | ['tar', '--exclude=*.pyc', '--exclude=__pycache__', '--exclude=*.pt','--exclude=*.pkl', '-cvJf', '-', 'src'],
162 | stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout
163 | self.save_to_file('code.tar.xz', codetar)
164 |
--------------------------------------------------------------------------------
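A hedged usage sketch of the `Logger` above; the name/project/entity values are placeholders, a working `wandb` installation is assumed, and `offline=True` keeps the run local.

```python
import torch
from src.utils.wandb_logger import Logger  # import path assumed

logger = Logger(name='debug-run', project='my-project', entity='my-team', offline=True)

imgs = torch.rand(16, 3, 64, 64) * 2 - 1    # values in [-1, 1], matching value_range in make_grid
logger.log_image_batch('samples', imgs)     # rendered as a square 4x4 grid
```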