├── README.md ├── simple-san ├── .gitignore ├── LICENSE ├── README.md ├── assets │ ├── gan_class.png │ ├── gan_nocond.png │ ├── san_class.png │ └── san_nocond.png ├── hparams │ └── params.json ├── models │ ├── discriminator.py │ └── generator.py ├── requirements.txt └── train.py ├── stylesan-xl ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── calc_metrics.py ├── dataset_tool.py ├── dataset_tool_for_imagenet.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── feature_networks │ ├── clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── clip.py │ │ ├── model.py │ │ └── simple_tokenizer.py │ ├── constants.py │ ├── pretrained_builder.py │ └── vit.py ├── gen_class_samplesheet.py ├── gen_images.py ├── gen_video.py ├── gui_utils │ ├── __init__.py │ ├── gl_utils.py │ ├── glfw_window.py │ ├── imgui_utils.py │ ├── imgui_window.py │ └── text_utils.py ├── in_embeddings │ └── tf_efficientnet_lite0.pkl ├── incl_licenses │ └── LICENSE_1 ├── legacy.py ├── media │ └── imagenet_idx2labels.txt ├── metrics │ ├── __init__.py │ ├── equivariance.py │ ├── frechet_inception_distance.py │ ├── inception_score.py │ ├── kernel_inception_distance.py │ ├── metric_main.py │ ├── metric_utils.py │ ├── perceptual_path_length.py │ └── precision_recall.py ├── pg_modules │ ├── blocks.py │ ├── discriminator.py │ ├── projector.py │ └── san_modules.py ├── requirements.txt ├── run_inversion.py ├── run_stylemc.py ├── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── gen_utils.py │ ├── misc.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── filtered_lrelu.cpp │ │ ├── filtered_lrelu.cu │ │ ├── filtered_lrelu.h │ │ ├── filtered_lrelu.py │ │ ├── filtered_lrelu_ns.cu │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ ├── persistence.py │ ├── training_stats.py │ └── utils_spectrum.py ├── train.py ├── training │ ├── __init__.py │ ├── augment.py │ ├── dataset.py │ ├── diffaug.py │ ├── loss.py │ ├── networks_fastgan.py │ ├── networks_stylegan2.py │ ├── networks_stylegan3.py │ ├── networks_stylegan3_resetting.py │ └── training_loop.py ├── visualizer.py └── viz │ ├── __init__.py │ ├── capture_widget.py │ ├── equivariance_widget.py │ ├── latent_widget.py │ ├── layer_widget.py │ ├── performance_widget.py │ ├── pickle_widget.py │ ├── renderer.py │ ├── stylemix_widget.py │ └── trunc_noise_widget.py └── tutorial ├── README.md └── assets ├── .gitkeep └── imagenet256.png /README.md: -------------------------------------------------------------------------------- 1 | # Slicing Adversarial Network (SAN) [ICLR 2024] 2 | 3 | This repository contains the official PyTorch implementation of **"SAN: Inducing Metrizability of GAN with Discriminative Normalized Linear Layer"** (*[arXiv 2301.12811](https://arxiv.org/abs/2301.12811)*). 4 | Please cite [[1](#citation)] in your work when using this code in your experiments. 5 | 6 | ### [[Project Page]](https://ytakida.github.io/san/) 7 | 8 | 9 | > **Abstract:** Generative adversarial networks (GANs) learn a target probability distribution by optimizing a generator and a discriminator with minimax objectives. This paper addresses the question of whether such optimization actually provides the generator with gradients that make its distribution close to the target distribution. 
We derive metrizable conditions, sufficient conditions for the discriminator to serve as the distance between the distributions, by connecting the GAN formulation with the concept of sliced optimal transport. Furthermore, by leveraging these theoretical results, we propose a novel GAN training scheme called the Slicing Adversarial Network (SAN). With only simple modifications, a broad class of existing GANs can be converted to SANs. Experiments on synthetic and image datasets support our theoretical results and the effectiveness of SAN as compared to the usual GANs. We also apply SAN to StyleGAN-XL, which leads to a state-of-the-art FID score amongst GANs for class conditional generation on CIFAR10 and ImageNet 256$\times$256. 10 | 11 | 12 | # Citation 13 | [1] Takida, Y., Imaizumi, M., Shibuya, T., Lai, C., Uesaka, T., Murata, N. and Mitsufuji, Y., 14 | "SAN: Inducing Metrizability of GAN with Discriminative Normalized Linear Layer," 15 | ICLR 2024. 16 | ``` 17 | @inproceedings{takida2024san, 18 | title={{SAN}: Inducing Metrizability of {GAN} with Discriminative Normalized Linear Layer}, 19 | author={Takida, Yuhta and Imaizumi, Masaaki and Shibuya, Takashi and Lai, Chieh-Hsin and Uesaka, Toshimitsu and Murata, Naoki and Mitsufuji, Yuki}, 20 | booktitle={The Twelfth International Conference on Learning Representations}, 21 | year={2024}, 22 | url={https://openreview.net/forum?id=eiF7TU1E8E} 23 | } 24 | ``` -------------------------------------------------------------------------------- /simple-san/.gitignore: -------------------------------------------------------------------------------- 1 | # vscode 2 | .vscode/ 3 | 4 | # pycache 5 | __pycache__ 6 | 7 | # tmp files 8 | *.tmp 9 | 10 | # output 11 | job/ 12 | logs/ 13 | out/ 14 | dataset/ -------------------------------------------------------------------------------- /simple-san/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jens Rahnfeld 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /simple-san/README.md: -------------------------------------------------------------------------------- 1 | # SAN (simple code) 2 | 3 | This repository provides a simple implementation of the Slicing Adversarial Network (SAN) on MNIST for tutorial purposes (*[arXiv 2301.12811](https://arxiv.org/abs/2301.12811)*).
4 | Please cite [[1](#citation)] in your work when using this code in your experiments. 5 | 6 | ### [[Project Page]](https://ytakida.github.io/san/) 7 | 8 | 9 | ## Requirements 10 | This repository builds on the codebase of 11 | 1. https://github.com/yukara-ikemiya/minimal-san 12 | 2. https://github.com/JensRahnfeld/Simple-GAN-MNIST 13 | 14 | Install the following dependencies: 15 | - Python 3.8.5 16 | - pytorch 1.6.0 17 | - torchvision 0.7.0 18 | - tqdm 4.50.2 19 | 20 | ## Training 21 | 22 | ### 1) Hyperparameters 23 | Specify hyperparameters inside a .json file, e.g.: 24 | 25 | ```json 26 | { 27 | "dim_latent": 100, 28 | "batch_size": 128, 29 | "learning_rate": 0.001, 30 | "beta_1": 0.0, 31 | "beta_2": 0.99, 32 | "num_epochs": 200 33 | } 34 | ``` 35 | 36 | ### 2) Options 37 | 38 | ``` 39 | -h, --help 40 | show this help message and exit
41 | --datadir DATADIR 42 | path to MNIST dataset folder 43 | --params PARAMS 44 | path to hyperparameters 45 | --model MODEL 46 | model's name / 'gan' or 'san' 47 | --enable_class 48 | enable class conditioning 49 | --device DEVICE 50 | gpu device to use 51 | ``` 52 | 53 | 54 | ### 3) Train the model 55 | 56 | - Class conditional (hinge) SAN 57 | ```bash 58 | python train.py --datadir <path_to_dataset> --model 'san' --enable_class 59 | ``` 60 | 61 | - Unconditional (hinge) GAN 62 | ```bash 63 | python train.py --datadir <path_to_dataset> --model 'gan' 64 | ``` 65 | 66 | 67 | ## Generated images (after 200 epochs) 68 | 69 | ### Unconditional 70 | 71 | GAN | SAN 72 | :-------------------------:|:-------------------------: 73 | ![](assets/gan_nocond.png) | ![](assets/san_nocond.png) 74 | 75 | ### Class conditional 76 | 77 | GAN | SAN 78 | :-------------------------:|:-------------------------: 79 | ![](assets/gan_class.png) | ![](assets/san_class.png) 80 | 81 | 82 | 83 | # Citation 84 | [1] Takida, Y., Imaizumi, M., Shibuya, T., Lai, C., Uesaka, T., Murata, N. and Mitsufuji, Y., 85 | "SAN: Inducing Metrizability of GAN with Discriminative Normalized Linear Layer," 86 | ICLR 2024. 87 | ``` 88 | @inproceedings{takida2024san, 89 | title={{SAN}: Inducing Metrizability of {GAN} with Discriminative Normalized Linear Layer}, 90 | author={Takida, Yuhta and Imaizumi, Masaaki and Shibuya, Takashi and Lai, Chieh-Hsin and Uesaka, Toshimitsu and Murata, Naoki and Mitsufuji, Yuki}, 91 | booktitle={The Twelfth International Conference on Learning Representations}, 92 | year={2024}, 93 | url={https://openreview.net/forum?id=eiF7TU1E8E} 94 | } 95 | ``` -------------------------------------------------------------------------------- /simple-san/assets/gan_class.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/gan_class.png -------------------------------------------------------------------------------- /simple-san/assets/gan_nocond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/gan_nocond.png -------------------------------------------------------------------------------- /simple-san/assets/san_class.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/san_class.png -------------------------------------------------------------------------------- /simple-san/assets/san_nocond.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/san_nocond.png -------------------------------------------------------------------------------- /simple-san/hparams/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim_latent": 100, 3 | "batch_size": 128, 4 | "learning_rate": 0.001, 5 | "beta_1": 0.0, 6 | "beta_2": 0.99, 7 | "num_epochs": 200 8 | } -------------------------------------------------------------------------------- /simple-san/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BaseDiscriminator(nn.Module): 7 | def
__init__(self, num_class=10): 8 | super(BaseDiscriminator, self).__init__() 9 | 10 | # Feature extractor 11 | self.h_function = nn.Sequential( 12 | nn.Conv2d(1, 128, kernel_size=6, stride=2), 13 | nn.LeakyReLU(0.2), 14 | nn.utils.spectral_norm( 15 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)), 16 | nn.LeakyReLU(0.2), 17 | nn.utils.spectral_norm( 18 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)), 19 | nn.LeakyReLU(0.2) 20 | ) 21 | 22 | # Last linear layer 23 | self.use_class = num_class > 0 24 | self.fc_w = nn.Parameter( 25 | torch.randn(num_class if self.use_class else 1, 512 * 3 * 3)) 26 | 27 | def forward(self, x, class_ids, flg_train: bool): 28 | h_feature = self.h_function(x) 29 | h_feature = torch.flatten(h_feature, start_dim=1) 30 | weights = self.fc_w[class_ids] if self.use_class else self.fc_w 31 | out = (h_feature * weights).sum(dim=1) 32 | 33 | return out 34 | 35 | 36 | # Modified Discriminator Architecture 37 | 38 | class SanDiscriminator(BaseDiscriminator): 39 | def __init__(self, num_class=10): 40 | super(SanDiscriminator, self).__init__(num_class) 41 | 42 | def forward(self, x, class_ids, flg_train: bool): 43 | h_feature = self.h_function(x) 44 | h_feature = torch.flatten(h_feature, start_dim=1) 45 | weights = self.fc_w[class_ids] if self.use_class else self.fc_w 46 | direction = F.normalize(weights, dim=1) # Normalize the last layer 47 | scale = torch.norm(weights, dim=1).unsqueeze(1) 48 | h_feature = h_feature * scale # To keep the scale 49 | if flg_train: # for discriminator training 50 | out_fun = (h_feature * direction.detach()).sum(dim=1) 51 | out_dir = (h_feature.detach() * direction).sum(dim=1) 52 | out = dict(fun=out_fun, dir=out_dir) 53 | else: # for generator training or inference 54 | out = (h_feature * direction).sum(dim=1) 55 | 56 | return out 57 | -------------------------------------------------------------------------------- /simple-san/models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Generator(nn.Module): 7 | def __init__(self, dim_latent=100, num_class=10): 8 | super(Generator, self).__init__() 9 | 10 | self.dim_latent = dim_latent 11 | self.use_class = num_class > 0 12 | 13 | if self.use_class: 14 | self.emb_class = nn.Embedding(num_class, dim_latent) 15 | self.fc = nn.Linear(dim_latent * 2, 512 * 3 * 3) 16 | else: 17 | self.fc = nn.Linear(dim_latent, 512 * 3 * 3) 18 | 19 | self.g_function = nn.Sequential( 20 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), 21 | nn.BatchNorm2d(256), 22 | nn.SiLU(), 23 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), 24 | nn.BatchNorm2d(128), 25 | nn.SiLU(), 26 | nn.ConvTranspose2d(128, 1, kernel_size=6, stride=2), 27 | nn.Sigmoid() 28 | ) 29 | 30 | def forward(self, x, class_ids): 31 | batch_size = x.size(0) 32 | 33 | if self.use_class: 34 | x_class = self.emb_class(class_ids) 35 | x = self.fc(torch.cat((x, x_class), dim=1)) 36 | else: 37 | x = self.fc(x) 38 | 39 | x = F.leaky_relu(x) 40 | x = x.view(batch_size, 512, 3, 3) 41 | img = self.g_function(x) 42 | 43 | return img 44 | -------------------------------------------------------------------------------- /simple-san/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | tqdm==4.50.2 4 | --------------------------------------------------------------------------------
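Editor's note (not part of the original repository): the training script below, simple-san/train.py, saves generator and discriminator checkpoints under `<logdir>/<experiment_name>/`. The following is a minimal sketch of how such a checkpoint could be loaded for sampling; it assumes the default `--logdir ./logs`, the `san_cond` experiment name used for a class-conditional SAN run, and the hyperparameters in `hparams/params.json`. The file name `sample.py` and the batch size of 16 are illustrative only.

```python
# sample.py -- illustrative sketch, not a file in this repository.
import torch

from models.generator import Generator

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dim_latent, num_class = 100, 10  # match hparams/params.json and MNIST's 10 classes

# Load the generator checkpoint written at the end of training (assumed path).
generator = Generator(dim_latent, num_class=num_class).to(device)
generator.load_state_dict(torch.load('logs/san_cond/generator.pt', map_location=device))
generator.eval()

# Sample 16 class-conditional images; the Sigmoid output lies in [0, 1].
with torch.no_grad():
    latent = torch.randn(16, dim_latent, device=device)
    class_ids = torch.randint(num_class, size=(16,), device=device)
    imgs = generator(latent, class_ids)  # shape (16, 1, 28, 28)
```

The resulting tensor can be converted with `.cpu().numpy()` and passed to the `save_images` helper in train.py, or written out with `torchvision.utils.save_image`.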
/simple-san/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import matplotlib.pyplot as plt 6 | import matplotlib.gridspec as gridspec 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import torchvision.datasets as datasets 13 | import torchvision.transforms as transforms 14 | import tqdm 15 | 16 | from models.discriminator import BaseDiscriminator, SanDiscriminator 17 | from models.generator import Generator 18 | from torch.utils.data import DataLoader 19 | 20 | 21 | def update_discriminator(x, class_ids, discriminator, generator, optimizer, params): 22 | bs = x.size(0) 23 | device = x.device 24 | 25 | optimizer.zero_grad() 26 | 27 | # for data (ground-truth) distribution 28 | disc_real = discriminator(x, class_ids, flg_train=True) 29 | loss_real = eval('compute_loss_'+args.model)(disc_real, loss_type='real') 30 | 31 | # for generator distribution 32 | latent = torch.randn(bs, params["dim_latent"], device=device) 33 | img_fake = generator(latent, class_ids) 34 | disc_fake = discriminator(img_fake.detach(), class_ids, flg_train=True) 35 | loss_fake = eval('compute_loss_'+args.model)(disc_fake, loss_type='fake') 36 | 37 | 38 | loss_d = loss_real + loss_fake 39 | loss_d.backward() 40 | optimizer.step() 41 | 42 | 43 | def update_generator(num_class, discriminator, generator, optimizer, params, device): 44 | optimizer.zero_grad() 45 | 46 | bs = params['batch_size'] 47 | latent = torch.randn(bs, params["dim_latent"], device=device) 48 | 49 | class_ids = torch.randint(num_class, size=(bs,), device=device) 50 | batch_fake = generator(latent, class_ids) 51 | 52 | disc_gen = discriminator(batch_fake, class_ids, flg_train=False) 53 | loss_g = - disc_gen.mean() 54 | loss_g.backward() 55 | optimizer.step() 56 | 57 | 58 | def compute_loss_gan(disc, loss_type): 59 | assert (loss_type in ['real', 'fake']) 60 | if 'real' == loss_type: 61 | loss = (1. - disc).relu().mean() # Hinge loss 62 | else: # 'fake' == loss_type 63 | loss = (1. + disc).relu().mean() # Hinge loss 64 | 65 | return loss 66 | 67 | 68 | def compute_loss_san(disc, loss_type): 69 | assert (loss_type in ['real', 'fake']) 70 | if 'real' == loss_type: 71 | loss_fun = (1. - disc['fun']).relu().mean() # Hinge loss for function h 72 | loss_dir = - disc['dir'].mean() # Wasserstein loss for omega 73 | else: # 'fake' == loss_type 74 | loss_fun = (1. 
+ disc['fun']).relu().mean() # Hinge loss for function h 75 | loss_dir = disc['dir'].mean() # Wasserstein loss for omega 76 | loss = loss_fun + loss_dir 77 | 78 | return loss 79 | 80 | 81 | def save_images(imgs, idx, dirname='test'): 82 | import numpy as np 83 | if imgs.shape[1] == 1: 84 | imgs = np.repeat(imgs, 3, axis=1) 85 | fig = plt.figure(figsize=(10, 10)) 86 | gs = gridspec.GridSpec(10, 10) 87 | gs.update(wspace=0.05, hspace=0.05) 88 | for i, sample in enumerate(imgs): 89 | ax = plt.subplot(gs[i]) 90 | plt.axis('off') 91 | ax.set_xticklabels([]) 92 | ax.set_yticklabels([]) 93 | ax.set_aspect('equal') 94 | plt.imshow(sample.transpose((1,2,0))) 95 | 96 | if not os.path.exists('out/{}/'.format(dirname)): 97 | os.makedirs('out/{}/'.format(dirname)) 98 | plt.savefig('out/{0}/{1}.png'.format(dirname, str(idx).zfill(3)), bbox_inches="tight") 99 | plt.close(fig) 100 | 101 | 102 | def get_args(): 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--datadir", type=str, required=True, help="path to MNIST dataset folder") 105 | parser.add_argument("--params", type=str, default="./hparams/params.json", help="path to hyperparameters") 106 | parser.add_argument("--model", type=str, default="gan", help="model's name / 'gan' or 'san'") 107 | parser.add_argument('--enable_class', action='store_true', help='enable class conditioning') 108 | parser.add_argument("--logdir", type=str, default="./logs", help="directory storing log files") 109 | parser.add_argument("--device", type=int, default=0, help="gpu device to use") 110 | 111 | return parser.parse_args() 112 | 113 | 114 | def main(args): 115 | with open(args.params, "r") as f: 116 | params = json.load(f) 117 | 118 | device = f'cuda:{args.device}' if args.device is not None else 'cpu' 119 | model_name = args.model 120 | if not model_name in ['gan', 'san']: 121 | raise RuntimeError("A model name have to be 'gan' or 'san'.") 122 | experiment_name = model_name + "_cond" if args.enable_class else model_name 123 | 124 | # dataloading 125 | num_class = 10 126 | train_dataset = datasets.MNIST(root=args.datadir, transform=transforms.ToTensor(), train=True, download=True) 127 | train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], num_workers=4, 128 | pin_memory=True, persistent_workers=True, shuffle=True) 129 | test_dataset = datasets.MNIST(root=args.datadir, transform=transforms.ToTensor(), train=False, download=True) 130 | test_loader = DataLoader(test_dataset, batch_size=params["batch_size"], num_workers=4, 131 | pin_memory=True, persistent_workers=True, shuffle=False) 132 | 133 | # model 134 | use_class = args.enable_class 135 | generator = Generator(params["dim_latent"], num_class=num_class if use_class else 0) 136 | if 'gan' == args.model: 137 | discriminator = BaseDiscriminator(num_class=num_class if use_class else 0) 138 | else: # 'san' == args.model 139 | discriminator = SanDiscriminator(num_class=num_class if use_class else 0) 140 | generator = generator.to(device) 141 | discriminator = discriminator.to(device) 142 | 143 | # optimizer 144 | betas = (params["beta_1"], params["beta_2"]) 145 | optimizer_G = optim.Adam(generator.parameters(), lr=params["learning_rate"], betas=betas) 146 | optimizer_D = optim.Adam(discriminator.parameters(), lr=params["learning_rate"], betas=betas) 147 | 148 | ckpt_dir = f'{args.logdir}/{experiment_name}/' 149 | if not os.path.exists(args.logdir): 150 | os.mkdir(args.logdir) 151 | if not os.path.exists(ckpt_dir): 152 | os.mkdir(ckpt_dir) 153 | 154 | steps_per_epoch = 
len(train_loader) 155 | 156 | msg = ["\t{0}: {1}".format(key, val) for key, val in params.items()] 157 | print("hyperparameters: \n" + "\n".join(msg)) 158 | 159 | # eval initial states 160 | num_samples_per_class = 10 161 | with torch.no_grad(): 162 | latent = torch.randn(num_samples_per_class * num_class, params["dim_latent"]).cuda() 163 | class_ids = torch.arange(num_class, dtype=torch.long, 164 | device=device).repeat_interleave(num_samples_per_class) 165 | imgs_fake = generator(latent, class_ids) 166 | 167 | # main training loop 168 | for n in range(params["num_epochs"]): 169 | loader = iter(train_loader) 170 | 171 | print("epoch: {0}/{1}".format(n + 1, params["num_epochs"])) 172 | for i in tqdm.trange(steps_per_epoch): 173 | x, class_ids = next(loader) 174 | x = x.to(device) 175 | class_ids = class_ids.to(device) 176 | 177 | update_discriminator(x, class_ids, discriminator, generator, optimizer_D, params) 178 | update_generator(num_class, discriminator, generator, optimizer_G, params, device) 179 | 180 | torch.save(generator.state_dict(), ckpt_dir + "g." + str(n) + ".tmp") 181 | torch.save(discriminator.state_dict(), ckpt_dir + "d." + str(n) + ".tmp") 182 | 183 | # eval 184 | with torch.no_grad(): 185 | latent = torch.randn(num_samples_per_class * num_class, params["dim_latent"]).cuda() 186 | class_ids = torch.arange(num_class, dtype=torch.long, 187 | device=device).repeat_interleave(num_samples_per_class) 188 | imgs_fake = generator(latent, class_ids).cpu().data.numpy() 189 | save_images(imgs_fake, n, dirname=experiment_name) 190 | 191 | torch.save(generator.state_dict(), ckpt_dir + "generator.pt") 192 | torch.save(discriminator.state_dict(), ckpt_dir + "discriminator.pt") 193 | 194 | 195 | if __name__ == '__main__': 196 | args = get_args() 197 | main(args) 198 | -------------------------------------------------------------------------------- /stylesan-xl/.gitignore: -------------------------------------------------------------------------------- 1 | g++ 2 | gcc 3 | 4 | data 5 | data/* 6 | !data/.placeholder 7 | 8 | training-runs 9 | training-runs/* 10 | out/* 11 | sample_sheets/* 12 | 13 | 14 | *.zip 15 | 16 | **/__pycache__ 17 | __pycache__ 18 | .ipynb_checkpoints/ 19 | tags 20 | *.swp 21 | *.pth 22 | *.pt 23 | *.npz 24 | *.tar 25 | *.gz 26 | *.pkl 27 | *.mp4 28 | *.pyc 29 | -------------------------------------------------------------------------------- /stylesan-xl/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.06-py3 2 | COPY ./requirements.txt . 3 | RUN pip install --upgrade pip 4 | RUN pip install -r requirements.txt 5 | RUN pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121 6 | RUN apt-get update -y && apt-get install -y --no-install-recommends build-essential gcc libsndfile1 7 | -------------------------------------------------------------------------------- /stylesan-xl/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sony Research Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stylesan-xl/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /stylesan-xl/feature_networks/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /stylesan-xl/feature_networks/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/stylesan-xl/feature_networks/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /stylesan-xl/feature_networks/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 
24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | 
-------------------------------------------------------------------------------- /stylesan-xl/feature_networks/constants.py: -------------------------------------------------------------------------------- 1 | TORCHVISION = [ 2 | "vgg11_bn", 3 | "vgg13_bn", 4 | "vgg16", 5 | "vgg16_bn", 6 | "vgg19_bn", 7 | "densenet121", 8 | "densenet169", 9 | "densenet201", 10 | "inception_v3", 11 | "resnet18", 12 | "resnet34", 13 | "resnet50", 14 | "resnet101", 15 | "resnet152", 16 | "shufflenet_v2_x0_5", 17 | "mobilenet_v2", 18 | "wide_resnet50_2", 19 | "mnasnet0_5", 20 | "mnasnet1_0", 21 | "ghostnet_100", 22 | "cspresnet50", 23 | "fbnetc_100", 24 | "spnasnet_100", 25 | "resnet50d", 26 | "resnet26", 27 | "resnet26d", 28 | "seresnet50", 29 | "resnetblur50", 30 | "resnetrs50", 31 | "tf_mixnet_s", 32 | "tf_mixnet_m", 33 | "tf_mixnet_l", 34 | "ese_vovnet19b_dw", 35 | "ese_vovnet39b", 36 | "res2next50", 37 | "gernet_s", 38 | "gernet_m", 39 | "repvgg_a2", 40 | "repvgg_b0", 41 | "repvgg_b1", 42 | "repvgg_b1g4", 43 | "revnet", 44 | "dm_nfnet_f1", 45 | "nfnet_l0", 46 | ] 47 | 48 | REGNETS = [ 49 | "regnetx_002", 50 | "regnetx_004", 51 | "regnetx_006", 52 | "regnetx_008", 53 | "regnetx_016", 54 | "regnetx_032", 55 | "regnetx_040", 56 | "regnetx_064", 57 | "regnety_002", 58 | "regnety_004", 59 | "regnety_006", 60 | "regnety_008", 61 | "regnety_016", 62 | "regnety_032", 63 | "regnety_040", 64 | "regnety_064", 65 | ] 66 | 67 | EFFNETS_IMAGENET = [ 68 | 'tf_efficientnet_b0', 69 | 'tf_efficientnet_b1', 70 | 'tf_efficientnet_b2', 71 | 'tf_efficientnet_b3', 72 | 'tf_efficientnet_b4', 73 | 'tf_efficientnet_b0_ns', 74 | ] 75 | 76 | EFFNETS_INCEPTION = [ 77 | 'tf_efficientnet_lite0', 78 | 'tf_efficientnet_lite1', 79 | 'tf_efficientnet_lite2', 80 | 'tf_efficientnet_lite3', 81 | 'tf_efficientnet_lite4', 82 | 'tf_efficientnetv2_b0', 83 | 'tf_efficientnetv2_b1', 84 | 'tf_efficientnetv2_b2', 85 | 'tf_efficientnetv2_b3', 86 | 'efficientnet_b1', 87 | 'efficientnet_b1_pruned', 88 | 'efficientnet_b2_pruned', 89 | 'efficientnet_b3_pruned', 90 | ] 91 | 92 | EFFNETS = EFFNETS_IMAGENET + EFFNETS_INCEPTION 93 | 94 | VITS_IMAGENET = [ 95 | 'deit_tiny_distilled_patch16_224', 96 | 'deit_small_distilled_patch16_224', 97 | 'deit_base_distilled_patch16_224', 98 | ] 99 | 100 | VITS_INCEPTION = [ 101 | 'vit_base_patch16_224' 102 | ] 103 | 104 | VITS = VITS_IMAGENET + VITS_INCEPTION 105 | 106 | CLIP = [ 107 | 'resnet50_clip' 108 | ] 109 | 110 | ALL_MODELS = TORCHVISION + REGNETS + EFFNETS + VITS + CLIP 111 | 112 | # Group according to input normalization 113 | 114 | NORMALIZED_IMAGENET = TORCHVISION + REGNETS + EFFNETS_IMAGENET + VITS_IMAGENET 115 | 116 | NORMALIZED_INCEPTION = EFFNETS_INCEPTION + VITS_INCEPTION 117 | 118 | NORMALIZED_CLIP = CLIP 119 | -------------------------------------------------------------------------------- /stylesan-xl/gen_class_samplesheet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import PIL.Image 4 | from typing import List 5 | import click 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import legacy 11 | import dnnlib 12 | from training.training_loop import save_image_grid 13 | from torch_utils import gen_utils 14 | from gen_images import parse_range 15 | 16 | @click.command() 17 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 18 | @click.option('--trunc', 'truncation_psi', help='Truncation psi', type=float, default=1, show_default=True) 19 | 
@click.option('--seed', help='Random seed', type=int, default=42) 20 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation') 21 | @click.option('--classes', type=parse_range, help='List of classes (e.g., \'0,1,4-6\')', required=True) 22 | @click.option('--samples-per-class', help='Samples per class.', type=int, default=4) 23 | @click.option('--grid-width', help='Total width of image grid', type=int, default=32) 24 | @click.option('--batch-gpu', help='Samples per pass, adapt to fit on GPU', type=int, default=32) 25 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 26 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) 27 | def generate_samplesheet( 28 | network_pkl: str, 29 | truncation_psi: float, 30 | seed: int, 31 | centroids_path: str, 32 | classes: List[int], 33 | samples_per_class: int, 34 | batch_gpu: int, 35 | grid_width: int, 36 | outdir: str, 37 | desc: str, 38 | ): 39 | print('Loading networks from "%s"...' % network_pkl) 40 | device = torch.device('cuda') 41 | with dnnlib.util.open_url(network_pkl) as f: 42 | G = legacy.load_network_pkl(f)['G_ema'].to(device).requires_grad_(False) 43 | 44 | # setup 45 | os.makedirs(outdir, exist_ok=True) 46 | desc_full = f'{Path(network_pkl).stem}_trunc_{truncation_psi}' 47 | if desc is not None: desc_full += f'-{desc}' 48 | run_dir = Path(gen_utils.make_run_dir(outdir, desc_full)) 49 | 50 | print('Generating latents.') 51 | ws = [] 52 | for class_idx in tqdm(classes): 53 | w = gen_utils.get_w_from_seed(G, samples_per_class, device, truncation_psi, seed=seed, 54 | centroids_path=centroids_path, class_idx=class_idx) 55 | ws.append(w) 56 | ws = torch.cat(ws) 57 | 58 | print('Generating samples.') 59 | images = [] 60 | for w in tqdm(ws.split(batch_gpu)): 61 | img = gen_utils.w_to_img(G, w, to_np=True) 62 | images.append(img) 63 | 64 | # adjust grid widht to prohibit folding between same class then save to disk 65 | grid_width = grid_width - grid_width % samples_per_class 66 | images = gen_utils.create_image_grid(np.concatenate(images), grid_size=(grid_width, None)) 67 | PIL.Image.fromarray(images, 'RGB').save(run_dir / 'sheet.png') 68 | 69 | if __name__ == "__main__": 70 | generate_samplesheet() 71 | -------------------------------------------------------------------------------- /stylesan-xl/gen_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | from torch_utils import gen_utils 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def parse_range(s: Union[str, List]) -> List[int]: 27 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 28 | 29 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 30 | ''' 31 | if isinstance(s, list): return s 32 | ranges = [] 33 | range_re = re.compile(r'^(\d+)-(\d+)$') 34 | for p in s.split(','): 35 | m = range_re.match(p) 36 | if m: 37 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 38 | else: 39 | ranges.append(int(p)) 40 | return ranges 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: 45 | '''Parse a floating point 2-vector of syntax 'a,b'. 46 | 47 | Example: 48 | '0,1' returns (0,1) 49 | ''' 50 | if isinstance(s, tuple): return s 51 | parts = s.split(',') 52 | if len(parts) == 2: 53 | return (float(parts[0]), float(parts[1])) 54 | raise ValueError(f'cannot parse 2-vector {s}') 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | def make_transform(translate: Tuple[float,float], angle: float): 59 | m = np.eye(3) 60 | s = np.sin(angle/360.0*np.pi*2) 61 | c = np.cos(angle/360.0*np.pi*2) 62 | m[0][0] = c 63 | m[0][1] = s 64 | m[0][2] = translate[0] 65 | m[1][0] = -s 66 | m[1][1] = c 67 | m[1][2] = translate[1] 68 | return m 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | @click.command() 73 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 74 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 75 | @click.option('--batch-sz', type=int, help='Batch size per sample', default=1) 76 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 77 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation') 78 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 79 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 80 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 81 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 82 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 83 | def generate_images( 84 | network_pkl: str, 85 | seeds: List[int], 86 | batch_sz: int, 87 | truncation_psi: float, 88 | centroids_path: str, 89 | noise_mode: str, 90 | outdir: str, 91 | translate: Tuple[float,float], 92 | rotate: float, 93 | class_idx: Optional[int] 94 | ): 95 | print('Loading networks from "%s"...' 
% network_pkl) 96 | device = torch.device('cuda') 97 | with dnnlib.util.open_url(network_pkl) as f: 98 | G = legacy.load_network_pkl(f)['G_ema'] 99 | G = G.eval().requires_grad_(False).to(device) 100 | 101 | os.makedirs(outdir, exist_ok=True) 102 | 103 | # Generate images. 104 | for seed_idx, seed in enumerate(seeds): 105 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) 106 | 107 | # Construct an inverse rotation/translation matrix and pass to the generator. The 108 | # generator expects this matrix as an inverse to avoid potentially failing numerical 109 | # operations in the network. 110 | if hasattr(G.synthesis, 'input'): 111 | m = make_transform(translate, rotate) 112 | m = np.linalg.inv(m) 113 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 114 | 115 | w = gen_utils.get_w_from_seed(G, batch_sz, device, truncation_psi, seed=seed, 116 | centroids_path=centroids_path, class_idx=class_idx) 117 | img = gen_utils.w_to_img(G, w, to_np=True) 118 | PIL.Image.fromarray(gen_utils.create_image_grid(img), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 119 | 120 | 121 | #---------------------------------------------------------------------------- 122 | 123 | if __name__ == "__main__": 124 | generate_images() # pylint: disable=no-value-for-parameter 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /stylesan-xl/gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Generate lerp videos using pretrained network pickle.""" 10 | 11 | import copy 12 | import os 13 | import re 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import click 17 | import dnnlib 18 | import imageio 19 | import numpy as np 20 | import scipy.interpolate 21 | import torch 22 | from tqdm import tqdm 23 | 24 | import legacy 25 | from torch_utils import gen_utils 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): 30 | batch_size, channels, img_h, img_w = img.shape 31 | if grid_w is None: 32 | grid_w = batch_size // grid_h 33 | assert batch_size == grid_w * grid_h 34 | if float_to_uint8: 35 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 36 | img = img.reshape(grid_h, grid_w, channels, img_h, img_w) 37 | img = img.permute(2, 0, 3, 1, 4) 38 | img = img.reshape(channels, grid_h * img_h, grid_w * img_w) 39 | if chw_to_hwc: 40 | img = img.permute(1, 2, 0) 41 | if to_numpy: 42 | img = img.cpu().numpy() 43 | return img 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, truncation_psi=1, device=torch.device('cuda'), centroids_path=None, class_idx=None, **video_kwargs): 48 | grid_w = grid_dims[0] 49 | grid_h = grid_dims[1] 50 | 51 | if num_keyframes is None: 52 | if len(seeds) % (grid_w*grid_h) != 0: 53 | raise ValueError('Number of input seeds must be divisible by grid W*H') 54 | num_keyframes = len(seeds) // (grid_w*grid_h) 55 | 56 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) 57 | for idx in range(num_keyframes*grid_h*grid_w): 58 | all_seeds[idx] = seeds[idx % len(seeds)] 59 | 60 | if shuffle_seed is not None: 61 | rng = np.random.RandomState(seed=shuffle_seed) 62 | rng.shuffle(all_seeds) 63 | 64 | if class_idx is None: 65 | class_idx = [None] * len(seeds) 66 | elif len(class_idx) == 1: 67 | class_idx = [class_idx] * len(seeds) 68 | assert len(all_seeds) == len(class_idx), "Seeds and class-idx should have the same length" 69 | 70 | ws = [] 71 | for seed, cls in zip(all_seeds, class_idx): 72 | ws.append( 73 | gen_utils.get_w_from_seed(G, 1, device, truncation_psi, seed=seed, 74 | centroids_path=centroids_path, class_idx=cls) 75 | ) 76 | ws = torch.cat(ws) 77 | 78 | _ = G.synthesis(ws[:1]) # warm up 79 | ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) 80 | 81 | # Interpolation. 82 | grid = [] 83 | for yi in range(grid_h): 84 | row = [] 85 | for xi in range(grid_w): 86 | x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) 87 | y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) 88 | interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) 89 | row.append(interp) 90 | grid.append(row) 91 | 92 | # Render video. 
93 | video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) 94 | for frame_idx in tqdm(range(num_keyframes * w_frames)): 95 | imgs = [] 96 | for yi in range(grid_h): 97 | for xi in range(grid_w): 98 | interp = grid[yi][xi] 99 | w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) 100 | img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] 101 | imgs.append(img) 102 | video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) 103 | video_out.close() 104 | 105 | #---------------------------------------------------------------------------- 106 | 107 | def parse_range(s: Union[str, List[int]]) -> List[int]: 108 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 109 | 110 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 111 | ''' 112 | if isinstance(s, list): return s 113 | ranges = [] 114 | range_re = re.compile(r'^(\d+)-(\d+)$') 115 | for p in s.split(','): 116 | m = range_re.match(p) 117 | if m: 118 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 119 | else: 120 | ranges.append(int(p)) 121 | return ranges 122 | 123 | #---------------------------------------------------------------------------- 124 | 125 | def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: 126 | '''Parse a 'M,N' or 'MxN' integer tuple. 127 | 128 | Example: 129 | '4x2' returns (4,2) 130 | '0,1' returns (0,1) 131 | ''' 132 | if isinstance(s, tuple): return s 133 | m = re.match(r'^(\d+)[x,](\d+)$', s) 134 | if m: 135 | return (int(m.group(1)), int(m.group(2))) 136 | raise ValueError(f'cannot parse tuple {s}') 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | @click.command() 141 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 142 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True) 143 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) 144 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) 145 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None) 146 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) 147 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 148 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation') 149 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') 150 | @click.option('--class', 'class_idx', type=parse_range, help='Class label (unconditional if not specified)') 151 | def generate_images( 152 | network_pkl: str, 153 | seeds: List[int], 154 | shuffle_seed: Optional[int], 155 | truncation_psi: float, 156 | centroids_path: str, 157 | grid: Tuple[int,int], 158 | num_keyframes: Optional[int], 159 | w_frames: int, 160 | output: str, 161 | class_idx: Optional[List[int]], 162 | ): 163 | """Render a latent vector interpolation video. 164 | 165 | Examples: 166 | 167 | \b 168 | # Render a 4x2 grid of interpolations for seeds 0 through 31. 
169 | python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ 170 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 171 | 172 | Animation length and seed keyframes: 173 | 174 | The animation length is either determined based on the --seeds value or explicitly 175 | specified using the --num-keyframes option. 176 | 177 | When num keyframes is specified with --num-keyframes, the output video length 178 | will be 'num_keyframes*w_frames' frames. 179 | 180 | If --num-keyframes is not specified, the number of seeds given with 181 | --seeds must be divisible by grid size W*H (--grid). In this case the 182 | output video length will be '# seeds/(w*h)*w_frames' frames. 183 | """ 184 | 185 | print('Loading networks from "%s"...' % network_pkl) 186 | device = torch.device('cuda') 187 | with dnnlib.util.open_url(network_pkl) as f: 188 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 189 | 190 | gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, truncation_psi=truncation_psi, centroids_path=centroids_path, class_idx=class_idx) 191 | 192 | #---------------------------------------------------------------------------- 193 | 194 | if __name__ == "__main__": 195 | generate_images() # pylint: disable=no-value-for-parameter 196 | 197 | #---------------------------------------------------------------------------- 198 | -------------------------------------------------------------------------------- /stylesan-xl/gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylesan-xl/gui_utils/glfw_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import time 10 | import glfw 11 | import OpenGL.GL as gl 12 | from . 
import gl_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class GlfwWindow: # pylint: disable=too-many-public-methods 17 | def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): 18 | self._glfw_window = None 19 | self._drawing_frame = False 20 | self._frame_start_time = None 21 | self._frame_delta = 0 22 | self._fps_limit = None 23 | self._vsync = None 24 | self._skip_frames = 0 25 | self._deferred_show = deferred_show 26 | self._close_on_esc = close_on_esc 27 | self._esc_pressed = False 28 | self._drag_and_drop_paths = None 29 | self._capture_next_frame = False 30 | self._captured_frame = None 31 | 32 | # Create window. 33 | glfw.init() 34 | glfw.window_hint(glfw.VISIBLE, False) 35 | self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) 36 | self._attach_glfw_callbacks() 37 | self.make_context_current() 38 | 39 | # Adjust window. 40 | self.set_vsync(False) 41 | self.set_window_size(window_width, window_height) 42 | if not self._deferred_show: 43 | glfw.show_window(self._glfw_window) 44 | 45 | def close(self): 46 | if self._drawing_frame: 47 | self.end_frame() 48 | if self._glfw_window is not None: 49 | glfw.destroy_window(self._glfw_window) 50 | self._glfw_window = None 51 | #glfw.terminate() # Commented out to play it nice with other glfw clients. 52 | 53 | def __del__(self): 54 | try: 55 | self.close() 56 | except: 57 | pass 58 | 59 | @property 60 | def window_width(self): 61 | return self.content_width 62 | 63 | @property 64 | def window_height(self): 65 | return self.content_height + self.title_bar_height 66 | 67 | @property 68 | def content_width(self): 69 | width, _height = glfw.get_window_size(self._glfw_window) 70 | return width 71 | 72 | @property 73 | def content_height(self): 74 | _width, height = glfw.get_window_size(self._glfw_window) 75 | return height 76 | 77 | @property 78 | def title_bar_height(self): 79 | _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) 80 | return top 81 | 82 | @property 83 | def monitor_width(self): 84 | _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 85 | return width 86 | 87 | @property 88 | def monitor_height(self): 89 | _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 90 | return height 91 | 92 | @property 93 | def frame_delta(self): 94 | return self._frame_delta 95 | 96 | def set_title(self, title): 97 | glfw.set_window_title(self._glfw_window, title) 98 | 99 | def set_window_size(self, width, height): 100 | width = min(width, self.monitor_width) 101 | height = min(height, self.monitor_height) 102 | glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) 103 | if width == self.monitor_width and height == self.monitor_height: 104 | self.maximize() 105 | 106 | def set_content_size(self, width, height): 107 | self.set_window_size(width, height + self.title_bar_height) 108 | 109 | def maximize(self): 110 | glfw.maximize_window(self._glfw_window) 111 | 112 | def set_position(self, x, y): 113 | glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) 114 | 115 | def center(self): 116 | self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) 117 | 118 | def set_vsync(self, vsync): 119 | vsync = bool(vsync) 120 | if vsync != self._vsync: 121 | glfw.swap_interval(1 if vsync else 0) 122 | 
self._vsync = vsync 123 | 124 | def set_fps_limit(self, fps_limit): 125 | self._fps_limit = int(fps_limit) 126 | 127 | def should_close(self): 128 | return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) 129 | 130 | def skip_frame(self): 131 | self.skip_frames(1) 132 | 133 | def skip_frames(self, num): # Do not update window for the next N frames. 134 | self._skip_frames = max(self._skip_frames, int(num)) 135 | 136 | def is_skipping_frames(self): 137 | return self._skip_frames > 0 138 | 139 | def capture_next_frame(self): 140 | self._capture_next_frame = True 141 | 142 | def pop_captured_frame(self): 143 | frame = self._captured_frame 144 | self._captured_frame = None 145 | return frame 146 | 147 | def pop_drag_and_drop_paths(self): 148 | paths = self._drag_and_drop_paths 149 | self._drag_and_drop_paths = None 150 | return paths 151 | 152 | def draw_frame(self): # To be overridden by subclass. 153 | self.begin_frame() 154 | # Rendering code goes here. 155 | self.end_frame() 156 | 157 | def make_context_current(self): 158 | if self._glfw_window is not None: 159 | glfw.make_context_current(self._glfw_window) 160 | 161 | def begin_frame(self): 162 | # End previous frame. 163 | if self._drawing_frame: 164 | self.end_frame() 165 | 166 | # Apply FPS limit. 167 | if self._frame_start_time is not None and self._fps_limit is not None: 168 | delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit 169 | if delay > 0: 170 | time.sleep(delay) 171 | cur_time = time.perf_counter() 172 | if self._frame_start_time is not None: 173 | self._frame_delta = cur_time - self._frame_start_time 174 | self._frame_start_time = cur_time 175 | 176 | # Process events. 177 | glfw.poll_events() 178 | 179 | # Begin frame. 180 | self._drawing_frame = True 181 | self.make_context_current() 182 | 183 | # Initialize GL state. 184 | gl.glViewport(0, 0, self.content_width, self.content_height) 185 | gl.glMatrixMode(gl.GL_PROJECTION) 186 | gl.glLoadIdentity() 187 | gl.glTranslate(-1, 1, 0) 188 | gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) 189 | gl.glMatrixMode(gl.GL_MODELVIEW) 190 | gl.glLoadIdentity() 191 | gl.glEnable(gl.GL_BLEND) 192 | gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. 193 | 194 | # Clear. 195 | gl.glClearColor(0, 0, 0, 1) 196 | gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) 197 | 198 | def end_frame(self): 199 | assert self._drawing_frame 200 | self._drawing_frame = False 201 | 202 | # Skip frames if requested. 203 | if self._skip_frames > 0: 204 | self._skip_frames -= 1 205 | return 206 | 207 | # Capture frame if requested. 208 | if self._capture_next_frame: 209 | self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) 210 | self._capture_next_frame = False 211 | 212 | # Update window. 
213 | if self._deferred_show: 214 | glfw.show_window(self._glfw_window) 215 | self._deferred_show = False 216 | glfw.swap_buffers(self._glfw_window) 217 | 218 | def _attach_glfw_callbacks(self): 219 | glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) 220 | glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) 221 | 222 | def _glfw_key_callback(self, _window, key, _scancode, action, _mods): 223 | if action == glfw.PRESS and key == glfw.KEY_ESCAPE: 224 | self._esc_pressed = True 225 | 226 | def _glfw_drop_callback(self, _window, paths): 227 | self._drag_and_drop_paths = paths 228 | 229 | #---------------------------------------------------------------------------- 230 | -------------------------------------------------------------------------------- /stylesan-xl/gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import contextlib 10 | import imgui 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 15 | s = imgui.get_style() 16 | s.window_padding = [spacing, spacing] 17 | s.item_spacing = [spacing, spacing] 18 | s.item_inner_spacing = [spacing, spacing] 19 | s.columns_min_spacing = spacing 20 | s.indent_spacing = indent 21 | s.scrollbar_size = scrollbar 22 | s.frame_padding = [4, 3] 23 | s.window_border_size = 1 24 | s.child_border_size = 1 25 | s.popup_border_size = 1 26 | s.frame_border_size = 1 27 | s.window_rounding = 0 28 | s.child_rounding = 0 29 | s.popup_rounding = 3 30 | s.frame_rounding = 3 31 | s.scrollbar_rounding = 3 32 | s.grab_rounding = 3 33 | 34 | getattr(imgui, f'style_colors_{color_scheme}')(s) 35 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 36 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 37 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @contextlib.contextmanager 42 | def grayed_out(cond=True): 43 | if cond: 44 | s = imgui.get_style() 45 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 46 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 47 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 48 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 49 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 50 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 51 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 52 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 53 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 55 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 56 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 58 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 59 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 60 | 
imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 61 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 62 | yield 63 | imgui.pop_style_color(14) 64 | else: 65 | yield 66 | 67 | #---------------------------------------------------------------------------- 68 | 69 | @contextlib.contextmanager 70 | def item_width(width=None): 71 | if width is not None: 72 | imgui.push_item_width(width) 73 | yield 74 | imgui.pop_item_width() 75 | else: 76 | yield 77 | 78 | #---------------------------------------------------------------------------- 79 | 80 | def scoped_by_object_id(method): 81 | def decorator(self, *args, **kwargs): 82 | imgui.push_id(str(id(self))) 83 | res = method(self, *args, **kwargs) 84 | imgui.pop_id() 85 | return res 86 | return decorator 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | def button(label, width=0, enabled=True): 91 | with grayed_out(not enabled): 92 | clicked = imgui.button(label, width=width) 93 | clicked = clicked and enabled 94 | return clicked 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 99 | expanded = False 100 | if show: 101 | if default: 102 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 103 | if not enabled: 104 | flags |= imgui.TREE_NODE_LEAF 105 | with grayed_out(not enabled): 106 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 107 | expanded = expanded and enabled 108 | return expanded, visible 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def popup_button(label, width=0, enabled=True): 113 | if button(label, width, enabled): 114 | imgui.open_popup(label) 115 | opened = imgui.begin_popup(label) 116 | return opened 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 121 | old_value = value 122 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 123 | if value == '': 124 | color[-1] *= 0.5 125 | with item_width(width): 126 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 127 | value = value if value != '' else help_text 128 | changed, value = imgui.input_text(label, value, buffer_length, flags) 129 | value = value if value != help_text else '' 130 | imgui.pop_style_color(1) 131 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 132 | changed = (value != old_value) 133 | return changed, value 134 | 135 | #---------------------------------------------------------------------------- 136 | 137 | def drag_previous_control(enabled=True): 138 | dragging = False 139 | dx = 0 140 | dy = 0 141 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 142 | if enabled: 143 | dragging = True 144 | dx, dy = imgui.get_mouse_drag_delta() 145 | imgui.reset_mouse_drag_delta() 146 | imgui.end_drag_drop_source() 147 | return dragging, dx, dy 148 | 149 | #---------------------------------------------------------------------------- 150 | 151 | def drag_button(label, width=0, enabled=True): 152 | clicked = button(label, width=width, enabled=enabled) 153 | dragging, dx, dy = drag_previous_control(enabled=enabled) 154 | return clicked, dragging, dx, dy 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | def drag_hidden_window(label, x, y, width, height, enabled=True): 159 | 
imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 160 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 161 | imgui.set_next_window_position(x, y) 162 | imgui.set_next_window_size(width, height) 163 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 164 | dragging, dx, dy = drag_previous_control(enabled=enabled) 165 | imgui.end() 166 | imgui.pop_style_color(2) 167 | return dragging, dx, dy 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /stylesan-xl/gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import imgui 11 | import imgui.integrations.glfw 12 | 13 | from . import glfw_window 14 | from . import imgui_utils 15 | from . import text_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class ImguiWindow(glfw_window.GlfwWindow): 20 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 21 | if font is None: 22 | font = text_utils.get_default_font() 23 | font_sizes = {int(size) for size in font_sizes} 24 | super().__init__(title=title, **glfw_kwargs) 25 | 26 | # Init fields. 27 | self._imgui_context = None 28 | self._imgui_renderer = None 29 | self._imgui_fonts = None 30 | self._cur_font_size = max(font_sizes) 31 | 32 | # Delete leftover imgui.ini to avoid unexpected behavior. 33 | if os.path.isfile('imgui.ini'): 34 | os.remove('imgui.ini') 35 | 36 | # Init ImGui. 37 | self._imgui_context = imgui.create_context() 38 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 39 | self._attach_glfw_callbacks() 40 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 42 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 43 | self._imgui_renderer.refresh_font_texture() 44 | 45 | def close(self): 46 | self.make_context_current() 47 | self._imgui_fonts = None 48 | if self._imgui_renderer is not None: 49 | self._imgui_renderer.shutdown() 50 | self._imgui_renderer = None 51 | if self._imgui_context is not None: 52 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 53 | self._imgui_context = None 54 | super().close() 55 | 56 | def _glfw_key_callback(self, *args): 57 | super()._glfw_key_callback(*args) 58 | self._imgui_renderer.keyboard_callback(*args) 59 | 60 | @property 61 | def font_size(self): 62 | return self._cur_font_size 63 | 64 | @property 65 | def spacing(self): 66 | return round(self._cur_font_size * 0.4) 67 | 68 | def set_font_size(self, target): # Applied on next frame. 
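        # Note: fonts are rasterized only for the sizes passed to __init__, so the
        # requested size is snapped to the closest available one by minimizing
        # abs(key - target) over the keys of self._imgui_fonts.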
69 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 70 | 71 | def begin_frame(self): 72 | # Begin glfw frame. 73 | super().begin_frame() 74 | 75 | # Process imgui events. 76 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 77 | if self.content_width > 0 and self.content_height > 0: 78 | self._imgui_renderer.process_inputs() 79 | 80 | # Begin imgui frame. 81 | imgui.new_frame() 82 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 83 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 84 | 85 | def end_frame(self): 86 | imgui.pop_font() 87 | imgui.render() 88 | imgui.end_frame() 89 | self._imgui_renderer.render(imgui.get_draw_data()) 90 | super().end_frame() 91 | 92 | #---------------------------------------------------------------------------- 93 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 94 | 95 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.mouse_wheel_multiplier = 1 99 | 100 | def scroll_callback(self, window, x_offset, y_offset): 101 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 102 | 103 | #---------------------------------------------------------------------------- 104 | -------------------------------------------------------------------------------- /stylesan-xl/gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import functools 10 | from typing import Optional 11 | 12 | import dnnlib 13 | import numpy as np 14 | import PIL.Image 15 | import PIL.ImageFont 16 | import scipy.ndimage 17 | 18 | from . 
import gl_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def get_default_font(): 23 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 24 | return dnnlib.util.open_url(url, return_filename=True) 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | @functools.lru_cache(maxsize=None) 29 | def get_pil_font(font=None, size=32): 30 | if font is None: 31 | font = get_default_font() 32 | return PIL.ImageFont.truetype(font=font, size=size) 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 37 | if dropshadow_radius is not None: 38 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 39 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 40 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 41 | else: 42 | return _get_array_priv(string, **kwargs) 43 | 44 | @functools.lru_cache(maxsize=10000) 45 | def _get_array_priv( 46 | string: str, *, 47 | size: int = 32, 48 | max_width: Optional[int]=None, 49 | max_height: Optional[int]=None, 50 | min_size=10, 51 | shrink_coef=0.8, 52 | dropshadow_radius: int=None, 53 | offset_x: int=None, 54 | offset_y: int=None, 55 | **kwargs 56 | ): 57 | cur_size = size 58 | array = None 59 | while True: 60 | if dropshadow_radius is not None: 61 | # separate implementation for dropshadow text rendering 62 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 63 | else: 64 | array = _get_array_impl(string, size=cur_size, **kwargs) 65 | height, width, _ = array.shape 66 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 67 | break 68 | cur_size = max(int(cur_size * shrink_coef), min_size) 69 | return array 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | @functools.lru_cache(maxsize=10000) 74 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 75 | pil_font = get_pil_font(font=font, size=size) 76 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 77 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 78 | width = max(line.shape[1] for line in lines) 79 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 80 | line_spacing = line_pad if line_pad is not None else size // 2 81 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 82 | mask = np.concatenate(lines, axis=0) 83 | alpha = mask 84 | if outline > 0: 85 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 86 | alpha = mask.astype(np.float32) / 255 87 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 88 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 89 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 90 | alpha = np.maximum(alpha, mask) 91 | return np.stack([mask, alpha], axis=-1) 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | @functools.lru_cache(maxsize=10000) 96 | def _get_array_impl_dropshadow(string, *, font=None, size=32, 
radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 97 | assert (offset_x > 0) and (offset_y > 0) 98 | pil_font = get_pil_font(font=font, size=size) 99 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 100 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 101 | width = max(line.shape[1] for line in lines) 102 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 103 | line_spacing = line_pad if line_pad is not None else size // 2 104 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 105 | mask = np.concatenate(lines, axis=0) 106 | alpha = mask 107 | 108 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 109 | alpha = mask.astype(np.float32) / 255 110 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 111 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 112 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 113 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 114 | alpha = np.maximum(alpha, mask) 115 | return np.stack([mask, alpha], axis=-1) 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | @functools.lru_cache(maxsize=10000) 120 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 121 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 122 | 123 | #---------------------------------------------------------------------------- 124 | -------------------------------------------------------------------------------- /stylesan-xl/in_embeddings/tf_efficientnet_lite0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/stylesan-xl/in_embeddings/tf_efficientnet_lite0.pkl -------------------------------------------------------------------------------- /stylesan-xl/incl_licenses/LICENSE_1: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen, sfid=False, rfid=False): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
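    # Note: the feature statistics gathered below are plugged into the standard FID formula,
    #   FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)),
    # where (mu, Sigma) are the mean and covariance of the detector features for real and
    # generated images (see the m / s / fid lines at the end of this function).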
24 | if rfid: 25 | detector_url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/feature_networks/inception_rand_full.pkl' 26 | detector_kwargs = {} # random inception network returns features by default 27 | 28 | 29 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, sfid=sfid).get_mean_cov() 32 | 33 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 34 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 35 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, sfid=sfid).get_mean_cov() 36 | 37 | if opts.rank != 0: 38 | return float('nan') 39 | 40 | m = np.square(mu_gen - mu_real).sum() 41 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 42 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 43 | return float(fid) 44 | 45 | #---------------------------------------------------------------------------- 46 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 
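    # Note: the softmax outputs collected below are used to evaluate the Inception Score,
    #   IS = exp( E_x [ KL( p(y|x) || p(y) ) ] ),
    # estimated separately on num_splits disjoint chunks to report a mean and std.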
22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
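    # Note: the features collected below feed an unbiased MMD^2 estimator with the
    # polynomial kernel k(x, y) = (x^T y / n + 1)^3, where n is the feature dimensionality;
    # the estimate is averaged over num_subsets random subsets of size m (see the loop
    # at the end of this function).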
22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # distribution of this software and related documentation without an express 2 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 3 | 4 | """Main API for computing and reporting quality metrics.""" 5 | 6 | import os 7 | import time 8 | import json 9 | import torch 10 | import dnnlib 11 | 12 | from . import metric_utils 13 | from . import frechet_inception_distance 14 | from . import kernel_inception_distance 15 | from . import precision_recall 16 | from . import perceptual_path_length 17 | from . import inception_score 18 | from . import equivariance 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _metric_dict = dict() # name => fn 23 | 24 | def register_metric(fn): 25 | assert callable(fn) 26 | _metric_dict[fn.__name__] = fn 27 | return fn 28 | 29 | def is_valid_metric(metric): 30 | return metric in _metric_dict 31 | 32 | def list_valid_metrics(): 33 | return list(_metric_dict.keys()) 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 38 | assert is_valid_metric(metric) 39 | opts = metric_utils.MetricOptions(**kwargs) 40 | 41 | # Calculate. 42 | start_time = time.time() 43 | results = _metric_dict[metric](opts) 44 | total_time = time.time() - start_time 45 | 46 | # Broadcast results. 47 | for key, value in list(results.items()): 48 | if opts.num_gpus > 1: 49 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 50 | torch.distributed.broadcast(tensor=value, src=0) 51 | value = float(value.cpu()) 52 | results[key] = value 53 | 54 | # Decorate with metadata. 
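    # Note: the raw results are wrapped together with the metric name, wall-clock time,
    # and GPU count so that report_metric() can serialize the record as a JSONL line.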
55 | return dnnlib.EasyDict( 56 | results = dnnlib.EasyDict(results), 57 | metric = metric, 58 | total_time = total_time, 59 | total_time_str = dnnlib.util.format_time(total_time), 60 | num_gpus = opts.num_gpus, 61 | ) 62 | 63 | #---------------------------------------------------------------------------- 64 | 65 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 66 | metric = result_dict['metric'] 67 | assert is_valid_metric(metric) 68 | if run_dir is not None and snapshot_pkl is not None: 69 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 70 | 71 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 72 | print(jsonl_line) 73 | if run_dir is not None and os.path.isdir(run_dir): 74 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 75 | f.write(jsonl_line + '\n') 76 | 77 | #---------------------------------------------------------------------------- 78 | # Recommended metrics. 79 | 80 | @register_metric 81 | def fid50k_full(opts): 82 | opts.dataset_kwargs.update(max_size=None, xflip=False) 83 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 84 | return dict(fid50k_full=fid) 85 | 86 | @register_metric 87 | def fid10k_full(opts): 88 | opts.dataset_kwargs.update(max_size=None, xflip=False) 89 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000) 90 | return dict(fid10k_full=fid) 91 | 92 | @register_metric 93 | def kid50k_full(opts): 94 | opts.dataset_kwargs.update(max_size=None, xflip=False) 95 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 96 | return dict(kid50k_full=kid) 97 | 98 | @register_metric 99 | def pr50k3_full(opts): 100 | opts.dataset_kwargs.update(max_size=None, xflip=False) 101 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 102 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 103 | 104 | @register_metric 105 | def ppl2_wend(opts): 106 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 107 | return dict(ppl2_wend=ppl) 108 | 109 | @register_metric 110 | def eqt50k_int(opts): 111 | opts.G_kwargs.update(force_fp32=True) 112 | del opts.G_kwargs.truncation_psi 113 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 114 | return dict(eqt50k_int=psnr) 115 | 116 | @register_metric 117 | def eqt50k_frac(opts): 118 | opts.G_kwargs.update(force_fp32=True) 119 | del opts.G_kwargs.truncation_psi 120 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 121 | return dict(eqt50k_frac=psnr) 122 | 123 | @register_metric 124 | def eqr50k(opts): 125 | opts.G_kwargs.update(force_fp32=True) 126 | del opts.G_kwargs.truncation_psi 127 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 128 | return dict(eqr50k=psnr) 129 | 130 | #---------------------------------------------------------------------------- 131 | # New Metrics 132 | 133 | def clipfid50k_full(opts): 134 | opts.dataset_kwargs.update(max_size=None, xflip=False) 135 | opts.feature_network = 'resnet50_clip' 136 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 137 | return dict(clipfid50k_full=fid) 138 | 139 | 
@register_metric 140 | def sfid50k_full(opts): 141 | opts.dataset_kwargs.update(max_size=None, xflip=False) 142 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000, sfid=True) 143 | return dict(sfid50k_full=fid) 144 | 145 | @register_metric 146 | def rfid50k_full(opts): 147 | opts.dataset_kwargs.update(max_size=None, xflip=False) 148 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000, rfid=True) 149 | return dict(rfid50k_full=fid) 150 | 151 | #---------------------------------------------------------------------------- 152 | # Legacy metrics. 153 | 154 | @register_metric 155 | def fid50k(opts): 156 | opts.dataset_kwargs.update(max_size=None) 157 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 158 | return dict(fid50k=fid) 159 | 160 | @register_metric 161 | def kid50k(opts): 162 | opts.dataset_kwargs.update(max_size=None) 163 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 164 | return dict(kid50k=kid) 165 | 166 | @register_metric 167 | def pr50k3(opts): 168 | opts.dataset_kwargs.update(max_size=None) 169 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 170 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 171 | 172 | @register_metric 173 | def pr10k3(opts): 174 | opts.dataset_kwargs.update(max_size=None) 175 | precision, recall = precision_recall.compute_pr(opts, max_real=10000, num_gen=10000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 176 | return dict(pr10k3_precision=precision, pr10k3_recall=recall) 177 | 178 | @register_metric 179 | def is50k(opts): 180 | opts.dataset_kwargs.update(max_size=None, xflip=False) 181 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 182 | return dict(is50k_mean=mean, is50k_std=std) 183 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 
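# Note: given t in [0, 1] and unit-normalized endpoints a_hat, b_hat with angle
# theta = arccos(a_hat . b_hat), the function below returns
#   cos(t*theta) * a_hat + sin(t*theta) * c_hat,
# where c_hat is the unit component of b_hat orthogonal to a_hat; this is equivalent to
# the usual slerp formula (sin((1-t)*theta) * a_hat + sin(t*theta) * b_hat) / sin(theta).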
22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 
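    # Note: the per-batch distances (already divided by epsilon^2 inside PPLSampler) are
    # broadcast from every rank so the full set is available everywhere; rank 0 then trims
    # values outside the 1st-99th percentile range and averages the rest to obtain PPL.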
104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /stylesan-xl/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . 
import metric_utils 16 | from tqdm import tqdm 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 21 | assert 0 <= rank < num_gpus 22 | num_cols = col_features.shape[0] 23 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 24 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 25 | dist_batches = [] 26 | for col_batch in col_batches[rank :: num_gpus]: 27 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 28 | for src in range(num_gpus): 29 | dist_broadcast = dist_batch.clone() 30 | if num_gpus > 1: 31 | torch.distributed.broadcast(dist_broadcast, src=src) 32 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 33 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 38 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 39 | detector_kwargs = dict(return_features=True) 40 | max_real = max_real // opts.num_gpus 41 | 42 | real_features = metric_utils.compute_feature_stats_for_dataset( 43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=None, shuffle_size=max_real).get_all_torch().to(torch.float16).to(opts.device) 45 | 46 | gen_features = metric_utils.compute_feature_stats_for_generator( 47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 49 | 50 | results = dict() 51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 52 | kth = [] 53 | for manifold_batch in tqdm(manifold.split(row_batch_size)): 54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 56 | kth = torch.cat(kth) if opts.rank == 0 else None 57 | pred = [] 58 | for probes_batch in tqdm(probes.split(row_batch_size)): 59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 62 | return results['precision'], results['recall'] 63 | 64 | #---------------------------------------------------------------------------- 65 | -------------------------------------------------------------------------------- /stylesan-xl/pg_modules/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from feature_networks.vit import forward_vit 5 | from feature_networks.pretrained_builder import _make_pretrained 6 | from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS 7 | from pg_modules.blocks import FeatureFusionBlock 8 | 9 | def 
get_backbone_normstats(backbone): 10 | if backbone in NORMALIZED_INCEPTION: 11 | return { 12 | 'mean': [0.5, 0.5, 0.5], 13 | 'std': [0.5, 0.5, 0.5], 14 | } 15 | 16 | elif backbone in NORMALIZED_IMAGENET: 17 | return { 18 | 'mean': [0.485, 0.456, 0.406], 19 | 'std': [0.229, 0.224, 0.225], 20 | } 21 | 22 | elif backbone in NORMALIZED_CLIP: 23 | return { 24 | 'mean': [0.48145466, 0.4578275, 0.40821073], 25 | 'std': [0.26862954, 0.26130258, 0.27577711], 26 | } 27 | 28 | else: 29 | raise NotImplementedError 30 | 31 | def _make_scratch_ccm(scratch, in_channels, cout, expand=False): 32 | # shapes 33 | out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4 34 | 35 | scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True) 36 | scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True) 37 | scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True) 38 | scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True) 39 | 40 | scratch.CHANNELS = out_channels 41 | 42 | return scratch 43 | 44 | def _make_scratch_csm(scratch, in_channels, cout, expand): 45 | scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True) 46 | scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand) 47 | scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand) 48 | scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False)) 49 | 50 | # last refinenet does not expand to save channels in higher dimensions 51 | scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4 52 | 53 | return scratch 54 | 55 | def _make_projector(im_res, backbone, cout, proj_type, expand=False): 56 | assert proj_type in [0, 1, 2], "Invalid projection type" 57 | 58 | ### Build pretrained feature network 59 | pretrained = _make_pretrained(backbone) 60 | 61 | # Following Projected GAN 62 | im_res = 256 63 | pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32] 64 | 65 | if proj_type == 0: return pretrained, None 66 | 67 | ### Build CCM 68 | scratch = nn.Module() 69 | scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand) 70 | 71 | pretrained.CHANNELS = scratch.CHANNELS 72 | 73 | if proj_type == 1: return pretrained, scratch 74 | 75 | ### build CSM 76 | scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand) 77 | 78 | # CSM upsamples x2 so the feature map resolution doubles 79 | pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS] 80 | pretrained.CHANNELS = scratch.CHANNELS 81 | 82 | return pretrained, scratch 83 | 84 | class F_Identity(nn.Module): 85 | def forward(self, x): 86 | return x 87 | 88 | class F_RandomProj(nn.Module): 89 | def __init__( 90 | self, 91 | backbone="tf_efficientnet_lite3", 92 | im_res=256, 93 | cout=64, 94 | expand=True, 95 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing 96 | **kwargs, 97 | ): 98 | super().__init__() 99 | self.proj_type = proj_type 100 | self.backbone = backbone 101 | self.cout = cout 102 | self.expand = expand 103 | self.normstats = get_backbone_normstats(backbone) 104 | 105 | # build pretrained feature network and random decoder (scratch) 106 | self.pretrained, self.scratch = _make_projector(im_res=im_res, backbone=self.backbone, 
cout=self.cout, 107 | proj_type=self.proj_type, expand=self.expand) 108 | self.CHANNELS = self.pretrained.CHANNELS 109 | self.RESOLUTIONS = self.pretrained.RESOLUTIONS 110 | 111 | def forward(self, x): 112 | # predict feature maps 113 | if self.backbone in VITS: 114 | out0, out1, out2, out3 = forward_vit(self.pretrained, x) 115 | else: 116 | out0 = self.pretrained.layer0(x) 117 | out1 = self.pretrained.layer1(out0) 118 | out2 = self.pretrained.layer2(out1) 119 | out3 = self.pretrained.layer3(out2) 120 | 121 | # start enumerating at the lowest layer (this is where we put the first discriminator) 122 | out = { 123 | '0': out0, 124 | '1': out1, 125 | '2': out2, 126 | '3': out3, 127 | } 128 | 129 | if self.proj_type == 0: return out 130 | 131 | out0_channel_mixed = self.scratch.layer0_ccm(out['0']) 132 | out1_channel_mixed = self.scratch.layer1_ccm(out['1']) 133 | out2_channel_mixed = self.scratch.layer2_ccm(out['2']) 134 | out3_channel_mixed = self.scratch.layer3_ccm(out['3']) 135 | 136 | out = { 137 | '0': out0_channel_mixed, 138 | '1': out1_channel_mixed, 139 | '2': out2_channel_mixed, 140 | '3': out3_channel_mixed, 141 | } 142 | 143 | if self.proj_type == 1: return out 144 | 145 | # from bottom to top 146 | out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed) 147 | out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed) 148 | out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed) 149 | out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed) 150 | 151 | out = { 152 | '0': out0_scale_mixed, 153 | '1': out1_scale_mixed, 154 | '2': out2_scale_mixed, 155 | '3': out3_scale_mixed, 156 | } 157 | 158 | return out 159 | -------------------------------------------------------------------------------- /stylesan-xl/pg_modules/san_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def _normalize(tensor, dim): 7 | denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-12) 8 | return tensor / denom 9 | 10 | 11 | class SANLinear(nn.Linear): 12 | 13 | def __init__(self, 14 | in_features, 15 | out_features, 16 | bias=True, 17 | device=None, 18 | dtype=None 19 | ): 20 | super(SANLinear, self).__init__( 21 | in_features, out_features, bias=bias, device=device, dtype=dtype) 22 | scale = self.weight.norm(p=2.0, dim=1, keepdim=True).clamp_min(1e-12) 23 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) 24 | self.scale = nn.parameter.Parameter(scale.view(out_features)) 25 | if bias: 26 | self.bias = nn.parameter.Parameter(torch.zeros(in_features, device=device, dtype=dtype)) 27 | else: 28 | self.register_parameter('bias', None) 29 | 30 | def forward(self, input, flg_train=False): 31 | if self.bias is not None: 32 | input = input + self.bias 33 | normalized_weight = self._get_normalized_weight() 34 | scale = self.scale 35 | if flg_train: 36 | out_fun = F.linear(input, normalized_weight.detach(), None) 37 | out_dir = F.linear(input.detach(), normalized_weight, None) 38 | out = [out_fun * scale, out_dir * scale.detach()] 39 | else: 40 | out = F.linear(input, normalized_weight, None) 41 | out = out * scale 42 | return out 43 | 44 | @torch.no_grad() 45 | def normalize_weight(self): 46 | self.weight.data = self._get_normalized_weight() 47 | 48 | def _get_normalized_weight(self): 49 | return _normalize(self.weight, dim=1) 50 | 51 | 52 | class SANConv1d(nn.Conv1d): 53 
| 54 | def __init__(self, 55 | in_channels, 56 | out_channels, 57 | kernel_size, 58 | stride=1, 59 | padding=0, 60 | dilation=1, 61 | bias=True, 62 | padding_mode='zeros', 63 | device=None, 64 | dtype=None 65 | ): 66 | super(SANConv1d, self).__init__( 67 | in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation, 68 | groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype) 69 | scale = self.weight.norm(p=2.0, dim=[1, 2], keepdim=True).clamp_min(1e-12) 70 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) 71 | self.scale = nn.parameter.Parameter(scale.view(out_channels)) 72 | if bias: 73 | self.bias = nn.parameter.Parameter(torch.zeros(in_channels, device=device, dtype=dtype)) 74 | else: 75 | self.register_parameter('bias', None) 76 | 77 | def forward(self, input, flg_train=False): 78 | if self.bias is not None: 79 | input = input + self.bias.view(self.in_channels, 1) 80 | normalized_weight = self._get_normalized_weight() 81 | scale = self.scale.view(self.out_channels, 1) 82 | if flg_train: 83 | out_fun = F.conv1d(input, normalized_weight.detach(), None, self.stride, 84 | self.padding, self.dilation, self.groups) 85 | out_dir = F.conv1d(input.detach(), normalized_weight, None, self.stride, 86 | self.padding, self.dilation, self.groups) 87 | out = [out_fun * scale, out_dir * scale.detach()] 88 | else: 89 | out = F.conv1d(input, normalized_weight, None, self.stride, 90 | self.padding, self.dilation, self.groups) 91 | out = out * scale 92 | return out 93 | 94 | @torch.no_grad() 95 | def normalize_weight(self): 96 | self.weight.data = self._get_normalized_weight() 97 | 98 | def _get_normalized_weight(self): 99 | return _normalize(self.weight, dim=[1, 2]) 100 | 101 | 102 | class SANConv2d(nn.Conv2d): 103 | 104 | def __init__(self, 105 | in_channels, 106 | out_channels, 107 | kernel_size, 108 | stride=1, 109 | padding=0, 110 | dilation=1, 111 | bias=True, 112 | padding_mode='zeros', 113 | device=None, 114 | dtype=None 115 | ): 116 | super(SANConv2d, self).__init__( 117 | in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation, 118 | groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype) 119 | scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-12) 120 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) 121 | self.scale = nn.parameter.Parameter(scale.view(out_channels)) 122 | if bias: 123 | self.bias = nn.parameter.Parameter(torch.zeros(in_channels, device=device, dtype=dtype)) 124 | else: 125 | self.register_parameter('bias', None) 126 | 127 | def forward(self, input, flg_train=False): 128 | if self.bias is not None: 129 | input = input + self.bias.view(self.in_channels, 1, 1) 130 | normalized_weight = self._get_normalized_weight() 131 | scale = self.scale.view(self.out_channels, 1, 1) 132 | if flg_train: 133 | out_fun = F.conv2d(input, normalized_weight.detach(), None, self.stride, 134 | self.padding, self.dilation, self.groups) 135 | out_dir = F.conv2d(input.detach(), normalized_weight, None, self.stride, 136 | self.padding, self.dilation, self.groups) 137 | out = [out_fun * scale, out_dir * scale.detach()] 138 | else: 139 | out = F.conv2d(input, normalized_weight, None, self.stride, 140 | self.padding, self.dilation, self.groups) 141 | out = out * scale 142 | return out 143 | 144 | @torch.no_grad() 145 | def normalize_weight(self): 146 | self.weight.data = self._get_normalized_weight() 147 | 148 | def 
_get_normalized_weight(self): 149 | return _normalize(self.weight, dim=[1, 2, 3]) 150 | 151 | 152 | class SANEmbedding(nn.Embedding): 153 | 154 | def __init__(self, num_embeddings, embedding_dim, 155 | scale_grad_by_freq=False, 156 | sparse=False, _weight=None, 157 | device=None, dtype=None): 158 | super(SANEmbedding, self).__init__( 159 | num_embeddings, embedding_dim, padding_idx=None, 160 | max_norm=None, norm_type=2., scale_grad_by_freq=scale_grad_by_freq, 161 | sparse=sparse, _weight=_weight, 162 | device=device, dtype=dtype) 163 | scale = self.weight.norm(p=2.0, dim=1, keepdim=True).clamp_min(1e-12) 164 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) 165 | self.scale = nn.parameter.Parameter(scale) 166 | 167 | def forward(self, input, flg_train=False): 168 | out = F.embedding( 169 | input, self.weight, self.padding_idx, self.max_norm, 170 | self.norm_type, self.scale_grad_by_freq, self.sparse) 171 | out = _normalize(out, dim=-1) 172 | scale = F.embedding( 173 | input, self.scale, self.padding_idx, self.max_norm, 174 | self.norm_type, self.scale_grad_by_freq, self.sparse) 175 | if flg_train: 176 | out_fun = out.detach() 177 | out_dir = out 178 | out = [out_fun * scale, out_dir * scale.detach()] 179 | else: 180 | out = out * scale 181 | return out 182 | 183 | @torch.no_grad() 184 | def normalize_weight(self): 185 | self.weight.data = _normalize(self.weight, dim=1) 186 | -------------------------------------------------------------------------------- /stylesan-xl/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.20 2 | click>=8.0 3 | pillow>=8.3.1 4 | scipy>=1.7.1 5 | requests>=2.26.0 6 | tqdm>=4.62.2 7 | ninja>=1.10.2 8 | matplotlib>=3.4.2 9 | imageio>=2.9.0 10 | dill>=0.3.4 11 | psutil>=5.8.0 12 | regex>=2022.3.15 13 | pillow>=8.3.1 14 | imgui>=1.3.0 15 | glfw>=2.2.0 16 | pyopengl>=3.1.5 17 | imageio-ffmpeg>=0.4.3 18 | pyspng 19 | ftfy>=6.1.1 20 | timm==0.4.12 21 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
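# Note: this module JIT-compiles the C++/CUDA extensions under torch_utils/ops via
# torch.utils.cpp_extension.load(); get_plugin() below adds an md5-keyed build cache
# so unchanged sources are not recompiled between runs.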
8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 
98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? 
(int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? 
x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 
168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 
35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 
106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 
36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- 
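A minimal usage sketch for the custom grid_sample op above — illustrative; the tensor shapes are assumptions:

import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True                   # opt in to the custom double-backward path
x = torch.randn(4, 3, 64, 64, requires_grad=True)    # NCHW input (assumed shape)
grid = torch.rand(4, 64, 64, 2) * 2 - 1              # sampling grid in [-1, 1], shape NxHxWx2
y = grid_sample_gradfix.grid_sample(x, grid)         # bilinear, zero padding, align_corners=False
y.mean().backward()                                  # gradients flow back to x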
/stylesan-xl/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? 
p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 
56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /stylesan-xl/torch_utils/utils_spectrum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fftn 3 | 4 | 5 | def roll_quadrants(data, backwards=False): 6 | """ 7 | Shift low frequencies to the center of fourier transform, i.e. [-N/2, ..., +N/2] -> [0, ..., N-1] 8 | Args: 9 | data: fourier transform, (NxHxW) 10 | backwards: bool, if True shift high frequencies back to center 11 | 12 | Returns: 13 | Shifted fourier transform. 14 | """ 15 | dim = data.ndim - 1 16 | 17 | if dim != 2: 18 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 19 | if any(s % 2 == 0 for s in data.shape[1:]): 20 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.') 21 | 22 | # for each dimension swap left and right half 23 | dims = tuple(range(1, dim+1)) # add one for batch dimension 24 | shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd 25 | if backwards: 26 | shifts *= -1 27 | return data.roll(shifts.tolist(), dims=dims) 28 | 29 | 30 | def batch_fft(data, normalize=False): 31 | """ 32 | Compute fourier transform of batch. 33 | Args: 34 | data: input tensor, (NxHxW) 35 | 36 | Returns: 37 | Batch fourier transform of input data. 38 | """ 39 | 40 | dim = data.ndim - 1 # subtract one for batch dimension 41 | if dim != 2: 42 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 43 | 44 | dims = tuple(range(1, dim + 1)) # add one for batch dimension 45 | if normalize: 46 | norm = 'ortho' 47 | else: 48 | norm = 'backward' 49 | 50 | if not torch.is_complex(data): 51 | data = torch.complex(data, torch.zeros_like(data)) 52 | freq = fftn(data, dim=dims, norm=norm) 53 | 54 | return freq 55 | 56 | 57 | def azimuthal_average(image, center=None): 58 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/ 59 | """ 60 | Calculate the azimuthally averaged radial profile. 61 | Requires low frequencies to be at the center of the image. 62 | Args: 63 | image: Batch of 2D images, NxHxW 64 | center: The [x,y] pixel coordinates used as the center. The default is 65 | None, which then uses the center of the image (including 66 | fracitonal pixels). 67 | 68 | Returns: 69 | Azimuthal average over the image around the center 70 | """ 71 | # Check input shapes 72 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \ 73 | f'(but it is len(center)={len(center)}.' 74 | # Calculate the indices from the image 75 | H, W = image.shape[-2:] 76 | h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 77 | 78 | if center is None: 79 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0]) 80 | 81 | # Compute radius for each pixel wrt center 82 | r = torch.stack([w-center[0], h-center[1]]).norm(2, 0) 83 | 84 | # Get sorted radii 85 | r_sorted, ind = r.flatten().sort() 86 | i_sorted = image.flatten(-2, -1)[..., ind] 87 | 88 | # Get the integer part of the radii (bin size = 1) 89 | r_int = r_sorted.long() # attribute to the smaller integer 90 | 91 | # Find all pixels that fall within each radial bin. 
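# (Illustrative note with assumed values: if the sorted integer radii are
#  r_int = [0, 0, 1, 1, 1, 2], then deltar = [0, 1, 0, 0, 1] and rind = [1, 4],
#  i.e. the last flat index of every bin except the outermost one; the nind/nr
#  computation below turns these change points into per-bin counts [2, 3, 1].)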
92 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii 93 | rind = torch.where(deltar)[0] # location of changed radius 94 | 95 | # compute number of elements in each bin 96 | nind = rind + 1 # number of elements = idx + 1 97 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders 98 | nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius 99 | 100 | # Cumulative sum to figure out sums for each radius bin 101 | if H % 2 == 0: 102 | raise NotImplementedError('Not sure if implementation correct, please check') 103 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders 104 | else: 105 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders 106 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius 107 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]] 108 | # add mean 109 | tbin = torch.cat([csim[:, 0:1], tbin], 1) 110 | 111 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins 112 | 113 | return radial_prof 114 | 115 | 116 | def get_spectrum(data, normalize=False): 117 | dim = data.ndim - 1 # subtract one for batch dimension 118 | if dim != 2: 119 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 120 | 121 | freq = batch_fft(data, normalize=normalize) 122 | power_spec = freq.real ** 2 + freq.imag ** 2 123 | N = data.shape[1] 124 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum 125 | # and is not averaged with the mean value 126 | N_2 = N//2 127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1) 128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2) 129 | 130 | power_spec = roll_quadrants(power_spec) 131 | power_spec = azimuthal_average(power_spec) 132 | return power_spec 133 | 134 | 135 | def plot_std(mean, std, x=None, ax=None, **kwargs): 136 | import matplotlib.pyplot as plt 137 | if ax is None: 138 | fig, ax = plt.subplots(1) 139 | 140 | # plot error margins in same color as line 141 | err_kwargs = { 142 | 'alpha': 0.3 143 | } 144 | 145 | if 'c' in kwargs.keys(): 146 | err_kwargs['color'] = kwargs['c'] 147 | elif 'color' in kwargs.keys(): 148 | err_kwargs['color'] = kwargs['color'] 149 | 150 | if x is None: 151 | x = torch.linspace(0, 1, len(mean)) # use normalized x axis 152 | ax.plot(x, mean, **kwargs) 153 | ax.fill_between(x, mean-std, mean+std, **err_kwargs) 154 | 155 | return ax 156 | -------------------------------------------------------------------------------- /stylesan-xl/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
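A short sketch of how the spectrum utilities in torch_utils/utils_spectrum.py above are typically invoked — illustrative; the batch shape and the plotting step are assumptions:

import torch
from torch_utils.utils_spectrum import get_spectrum, plot_std

imgs = torch.randn(16, 64, 64)               # assumed batch of HxW grayscale images
spec = get_spectrum(imgs, normalize=True)    # (16, n_bins) azimuthally averaged power spectrum
ax = plot_std(spec.mean(0), spec.std(0))     # mean radial profile with a shaded std band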
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylesan-xl/training/diffaug.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.2): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 | } 77 | -------------------------------------------------------------------------------- /stylesan-xl/training/networks_fastgan.py: -------------------------------------------------------------------------------- 1 | # original 
implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py 2 | # 3 | # modified by Axel Sauer for "Projected GANs Converge Faster" 4 | # 5 | import torch.nn as nn 6 | from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d) 7 | 8 | 9 | def normalize_second_moment(x, dim=1, eps=1e-8): 10 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 11 | 12 | 13 | class DummyMapping(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, z, c=None, **kwargs): 18 | return z.unsqueeze(1) # to fit the StyleGAN API 19 | 20 | 21 | class FastganSynthesis(nn.Module): 22 | def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False): 23 | super().__init__() 24 | self.img_resolution = img_resolution 25 | self.z_dim = z_dim 26 | 27 | # channel multiplier 28 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, 29 | 512:0.25, 1024:0.125} 30 | nfc = {} 31 | for k, v in nfc_multi.items(): 32 | nfc[k] = int(v*ngf) 33 | 34 | # layers 35 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4) 36 | 37 | UpBlock = UpBlockSmall if lite else UpBlockBig 38 | 39 | self.feat_8 = UpBlock(nfc[4], nfc[8]) 40 | self.feat_16 = UpBlock(nfc[8], nfc[16]) 41 | self.feat_32 = UpBlock(nfc[16], nfc[32]) 42 | self.feat_64 = UpBlock(nfc[32], nfc[64]) 43 | self.feat_128 = UpBlock(nfc[64], nfc[128]) 44 | self.feat_256 = UpBlock(nfc[128], nfc[256]) 45 | 46 | self.se_64 = SEBlock(nfc[4], nfc[64]) 47 | self.se_128 = SEBlock(nfc[8], nfc[128]) 48 | self.se_256 = SEBlock(nfc[16], nfc[256]) 49 | 50 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) 51 | 52 | if img_resolution > 256: 53 | self.feat_512 = UpBlock(nfc[256], nfc[512]) 54 | self.se_512 = SEBlock(nfc[32], nfc[512]) 55 | if img_resolution > 512: 56 | self.feat_1024 = UpBlock(nfc[512], nfc[1024]) 57 | 58 | def forward(self, input, c=None, **kwargs): 59 | # map noise to hypersphere as in "Progressive Growing of GANS" 60 | input = normalize_second_moment(input[:, 0]) 61 | 62 | feat_4 = self.init(input) 63 | feat_8 = self.feat_8(feat_4) 64 | feat_16 = self.feat_16(feat_8) 65 | feat_32 = self.feat_32(feat_16) 66 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32)) 67 | 68 | if self.img_resolution >= 64: 69 | feat_last = feat_64 70 | 71 | if self.img_resolution >= 128: 72 | feat_last = self.se_128(feat_8, self.feat_128(feat_last)) 73 | 74 | if self.img_resolution >= 256: 75 | feat_last = self.se_256(feat_16, self.feat_256(feat_last)) 76 | 77 | if self.img_resolution >= 512: 78 | feat_last = self.se_512(feat_32, self.feat_512(feat_last)) 79 | 80 | if self.img_resolution >= 1024: 81 | feat_last = self.feat_1024(feat_last) 82 | 83 | return self.to_big(feat_last) 84 | 85 | 86 | class FastganSynthesisCond(nn.Module): 87 | def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False): 88 | super().__init__() 89 | 90 | self.z_dim = z_dim 91 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, 92 | 512:0.25, 1024:0.125, 2048:0.125} 93 | nfc = {} 94 | for k, v in nfc_multi.items(): 95 | nfc[k] = int(v*ngf) 96 | 97 | self.img_resolution = img_resolution 98 | 99 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4) 100 | 101 | UpBlock = UpBlockSmallCond if lite else UpBlockBigCond 102 | 103 | self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim) 104 | self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim) 105 | self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim) 106 | self.feat_64 = UpBlock(nfc[32], 
nfc[64], z_dim) 107 | self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim) 108 | self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim) 109 | 110 | self.se_64 = SEBlock(nfc[4], nfc[64]) 111 | self.se_128 = SEBlock(nfc[8], nfc[128]) 112 | self.se_256 = SEBlock(nfc[16], nfc[256]) 113 | 114 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) 115 | 116 | if img_resolution > 256: 117 | self.feat_512 = UpBlock(nfc[256], nfc[512]) 118 | self.se_512 = SEBlock(nfc[32], nfc[512]) 119 | if img_resolution > 512: 120 | self.feat_1024 = UpBlock(nfc[512], nfc[1024]) 121 | 122 | self.embed = nn.Embedding(num_classes, z_dim) 123 | 124 | def forward(self, input, c, update_emas=False): 125 | c = self.embed(c.argmax(1)) 126 | 127 | # map noise to hypersphere as in "Progressive Growing of GANS" 128 | input = normalize_second_moment(input[:, 0]) 129 | 130 | feat_4 = self.init(input) 131 | feat_8 = self.feat_8(feat_4, c) 132 | feat_16 = self.feat_16(feat_8, c) 133 | feat_32 = self.feat_32(feat_16, c) 134 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c)) 135 | feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c)) 136 | 137 | if self.img_resolution >= 128: 138 | feat_last = feat_128 139 | 140 | if self.img_resolution >= 256: 141 | feat_last = self.se_256(feat_16, self.feat_256(feat_last, c)) 142 | 143 | if self.img_resolution >= 512: 144 | feat_last = self.se_512(feat_32, self.feat_512(feat_last, c)) 145 | 146 | if self.img_resolution >= 1024: 147 | feat_last = self.feat_1024(feat_last, c) 148 | 149 | return self.to_big(feat_last) 150 | 151 | 152 | class Generator(nn.Module): 153 | def __init__( 154 | self, 155 | z_dim=256, 156 | c_dim=0, 157 | w_dim=0, 158 | img_resolution=256, 159 | img_channels=3, 160 | ngf=128, 161 | cond=0, 162 | mapping_kwargs={}, 163 | synthesis_kwargs={}, 164 | **kwargs, 165 | ): 166 | super().__init__() 167 | self.z_dim = z_dim 168 | self.c_dim = c_dim 169 | self.w_dim = w_dim 170 | self.img_resolution = img_resolution 171 | self.img_channels = img_channels 172 | 173 | # Mapping and Synthesis Networks 174 | self.mapping = DummyMapping() # to fit the StyleGAN API 175 | Synthesis = FastganSynthesisCond if cond else FastganSynthesis 176 | self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs) 177 | 178 | def forward(self, z, c, **kwargs): 179 | w = self.mapping(z, c) 180 | img = self.synthesis(w, c) 181 | return img 182 | -------------------------------------------------------------------------------- /stylesan-xl/viz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylesan-xl/viz/capture_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import re 11 | import numpy as np 12 | import imgui 13 | import PIL.Image 14 | from gui_utils import imgui_utils 15 | from . import renderer 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class CaptureWidget: 20 | def __init__(self, viz): 21 | self.viz = viz 22 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) 23 | self.dump_image = False 24 | self.dump_gui = False 25 | self.defer_frames = 0 26 | self.disabled_time = 0 27 | 28 | def dump_png(self, image): 29 | viz = self.viz 30 | try: 31 | _height, _width, channels = image.shape 32 | assert channels in [1, 3] 33 | assert image.dtype == np.uint8 34 | os.makedirs(self.path, exist_ok=True) 35 | file_id = 0 36 | for entry in os.scandir(self.path): 37 | if entry.is_file(): 38 | match = re.fullmatch(r'(\d+).*', entry.name) 39 | if match: 40 | file_id = max(file_id, int(match.group(1)) + 1) 41 | if channels == 1: 42 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') 43 | else: 44 | pil_image = PIL.Image.fromarray(image, 'RGB') 45 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) 46 | except: 47 | viz.result.error = renderer.CapturedException() 48 | 49 | @imgui_utils.scoped_by_object_id 50 | def __call__(self, show=True): 51 | viz = self.viz 52 | if show: 53 | with imgui_utils.grayed_out(self.disabled_time != 0): 54 | imgui.text('Capture') 55 | imgui.same_line(viz.label_w) 56 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, 57 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 58 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 59 | help_text='PATH') 60 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': 61 | imgui.set_tooltip(self.path) 62 | imgui.same_line() 63 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): 64 | self.dump_image = True 65 | self.defer_frames = 2 66 | self.disabled_time = 0.5 67 | imgui.same_line() 68 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): 69 | self.dump_gui = True 70 | self.defer_frames = 2 71 | self.disabled_time = 0.5 72 | 73 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) 74 | if self.defer_frames > 0: 75 | self.defer_frames -= 1 76 | elif self.dump_image: 77 | if 'image' in viz.result: 78 | self.dump_png(viz.result.image) 79 | self.dump_image = False 80 | elif self.dump_gui: 81 | viz.capture_next_frame() 82 | self.dump_gui = False 83 | captured_frame = viz.pop_captured_frame() 84 | if captured_frame is not None: 85 | self.dump_png(captured_frame) 86 | 87 | #---------------------------------------------------------------------------- 88 | -------------------------------------------------------------------------------- /stylesan-xl/viz/equivariance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class EquivarianceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2) 20 | self.xlate_def = dnnlib.EasyDict(self.xlate) 21 | self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3) 22 | self.rotate_def = dnnlib.EasyDict(self.rotate) 23 | self.opts = dnnlib.EasyDict(untransform=False) 24 | self.opts_def = dnnlib.EasyDict(self.opts) 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | if show: 30 | imgui.text('Translate') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8): 33 | _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f') 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w) 36 | if dragging: 37 | self.xlate.x += dx / viz.font_size * 2e-2 38 | self.xlate.y += dy / viz.font_size * 2e-2 39 | imgui.same_line() 40 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w) 41 | if dragging: 42 | self.xlate.x += dx / viz.font_size * 4e-4 43 | self.xlate.y += dy / viz.font_size * 4e-4 44 | imgui.same_line() 45 | _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim) 46 | imgui.same_line() 47 | _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round) 48 | imgui.same_line() 49 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim): 50 | changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5) 51 | if changed: 52 | self.xlate.speed = speed 53 | imgui.same_line() 54 | if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)): 55 | self.xlate = dnnlib.EasyDict(self.xlate_def) 56 | 57 | if show: 58 | imgui.text('Rotate') 59 | imgui.same_line(viz.label_w) 60 | with imgui_utils.item_width(viz.font_size * 8): 61 | _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f') 62 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 63 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w) 64 | if dragging: 65 | self.rotate.val += dx / viz.font_size * 2e-2 66 | imgui.same_line() 67 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w) 68 | if dragging: 69 | self.rotate.val += dx / viz.font_size * 4e-4 70 | imgui.same_line() 71 | _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim) 72 | imgui.same_line() 73 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.rotate.anim): 74 | changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3) 
75 | if changed: 76 | self.rotate.speed = speed 77 | imgui.same_line() 78 | if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)): 79 | self.rotate = dnnlib.EasyDict(self.rotate_def) 80 | 81 | if show: 82 | imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16) 83 | _clicked, self.opts.untransform = imgui.checkbox('Untransform', self.opts.untransform) 84 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 85 | if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)): 86 | self.opts = dnnlib.EasyDict(self.opts_def) 87 | 88 | if self.xlate.anim: 89 | c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 90 | t = c.copy() 91 | if np.max(np.abs(t)) < 1e-4: 92 | t += 1 93 | t *= 0.1 / np.hypot(*t) 94 | t += c[::-1] * [1, -1] 95 | d = t - c 96 | d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d) 97 | self.xlate.x += d[0] 98 | self.xlate.y += d[1] 99 | 100 | if self.rotate.anim: 101 | self.rotate.val += viz.frame_delta * self.rotate.speed 102 | 103 | pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 104 | if self.xlate.round and 'img_resolution' in viz.result: 105 | pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution 106 | angle = self.rotate.val * np.pi * 2 107 | 108 | viz.args.input_transform = [ 109 | [np.cos(angle), np.sin(angle), pos[0]], 110 | [-np.sin(angle), np.cos(angle), pos[1]], 111 | [0, 0, 1]] 112 | 113 | viz.args.update(untransform=self.opts.untransform) 114 | 115 | #---------------------------------------------------------------------------- 116 | -------------------------------------------------------------------------------- /stylesan-xl/viz/latent_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class LatentWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25) 20 | self.latent_def = dnnlib.EasyDict(self.latent) 21 | self.step_y = 100 22 | 23 | def drag(self, dx, dy): 24 | viz = self.viz 25 | self.latent.x += dx / viz.font_size * 4e-2 26 | self.latent.y += dy / viz.font_size * 4e-2 27 | 28 | @imgui_utils.scoped_by_object_id 29 | def __call__(self, show=True): 30 | viz = self.viz 31 | if show: 32 | imgui.text('Latent') 33 | imgui.same_line(viz.label_w) 34 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y 35 | with imgui_utils.item_width(viz.font_size * 8): 36 | changed, seed = imgui.input_int('##seed', seed) 37 | if changed: 38 | self.latent.x = seed 39 | self.latent.y = 0 40 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 41 | frac_x = self.latent.x - round(self.latent.x) 42 | frac_y = self.latent.y - round(self.latent.y) 43 | with imgui_utils.item_width(viz.font_size * 5): 44 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 45 | if changed: 46 | self.latent.x += new_frac_x - frac_x 47 | self.latent.y += new_frac_y - frac_y 48 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 49 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 50 | if dragging: 51 | self.drag(dx, dy) 52 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 53 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) 54 | imgui.same_line(round(viz.font_size * 27.7)) 55 | with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): 56 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) 57 | if changed: 58 | self.latent.speed = speed 59 | imgui.same_line() 60 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) 61 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): 62 | self.latent = snapped 63 | imgui.same_line() 64 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): 65 | self.latent = dnnlib.EasyDict(self.latent_def) 66 | 67 | if self.latent.anim: 68 | self.latent.x += viz.frame_delta * self.latent.speed 69 | viz.args.w0_seeds = [] # [[seed, weight], ...] 70 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: 71 | seed_x = np.floor(self.latent.x) + ofs_x 72 | seed_y = np.floor(self.latent.y) + ofs_y 73 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) 74 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) 75 | if weight > 0: 76 | viz.args.w0_seeds.append([seed, weight]) 77 | 78 | #---------------------------------------------------------------------------- 79 | -------------------------------------------------------------------------------- /stylesan-xl/viz/performance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import array 10 | import numpy as np 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class PerformanceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.gui_times = [float('nan')] * 60 20 | self.render_times = [float('nan')] * 30 21 | self.fps_limit = 60 22 | self.use_vsync = False 23 | self.is_async = False 24 | self.force_fp32 = False 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | self.gui_times = self.gui_times[1:] + [viz.frame_delta] 30 | if 'render_time' in viz.result: 31 | self.render_times = self.render_times[1:] + [viz.result.render_time] 32 | del viz.result.render_time 33 | 34 | if show: 35 | imgui.text('GUI') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 8): 38 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) 39 | imgui.same_line(viz.label_w + viz.font_size * 9) 40 | t = [x for x in self.gui_times if x > 0] 41 | t = np.mean(t) if len(t) > 0 else 0 42 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 43 | imgui.same_line(viz.label_w + viz.font_size * 14) 44 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 45 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 46 | with imgui_utils.item_width(viz.font_size * 6): 47 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 48 | self.fps_limit = min(max(self.fps_limit, 5), 1000) 49 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 50 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) 51 | 52 | if show: 53 | imgui.text('Render') 54 | imgui.same_line(viz.label_w) 55 | with imgui_utils.item_width(viz.font_size * 8): 56 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) 57 | imgui.same_line(viz.label_w + viz.font_size * 9) 58 | t = [x for x in self.render_times if x > 0] 59 | t = np.mean(t) if len(t) > 0 else 0 60 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 61 | imgui.same_line(viz.label_w + viz.font_size * 14) 62 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 63 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 64 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) 65 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 66 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) 67 | 68 | viz.set_fps_limit(self.fps_limit) 69 | viz.set_vsync(self.use_vsync) 70 | viz.set_async(self.is_async) 71 | viz.args.force_fp32 = self.force_fp32 72 | 73 | #---------------------------------------------------------------------------- 74 | -------------------------------------------------------------------------------- /stylesan-xl/viz/pickle_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import os 11 | import re 12 | 13 | import dnnlib 14 | import imgui 15 | import numpy as np 16 | from gui_utils import imgui_utils 17 | 18 | from . import renderer 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def _locate_results(pattern): 23 | return pattern 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | class PickleWidget: 28 | def __init__(self, viz): 29 | self.viz = viz 30 | self.search_dirs = [] 31 | self.cur_pkl = None 32 | self.user_pkl = '' 33 | self.recent_pkls = [] 34 | self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} 35 | self.browse_refocus = False 36 | self.load('', ignore_errors=True) 37 | 38 | def add_recent(self, pkl, ignore_errors=False): 39 | try: 40 | resolved = self.resolve_pkl(pkl) 41 | if resolved not in self.recent_pkls: 42 | self.recent_pkls.append(resolved) 43 | except: 44 | if not ignore_errors: 45 | raise 46 | 47 | def load(self, pkl, ignore_errors=False): 48 | viz = self.viz 49 | viz.clear_result() 50 | viz.skip_frame() # The input field will change on next frame. 51 | try: 52 | resolved = self.resolve_pkl(pkl) 53 | name = resolved.replace('\\', '/').split('/')[-1] 54 | self.cur_pkl = resolved 55 | self.user_pkl = resolved 56 | viz.result.message = f'Loading {name}...' 57 | viz.defer_rendering() 58 | if resolved in self.recent_pkls: 59 | self.recent_pkls.remove(resolved) 60 | self.recent_pkls.insert(0, resolved) 61 | except: 62 | self.cur_pkl = None 63 | self.user_pkl = pkl 64 | if pkl == '': 65 | viz.result = dnnlib.EasyDict(message='No network pickle loaded') 66 | else: 67 | viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) 68 | if not ignore_errors: 69 | raise 70 | 71 | @imgui_utils.scoped_by_object_id 72 | def __call__(self, show=True): 73 | viz = self.viz 74 | recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] 75 | if show: 76 | imgui.text('Pickle') 77 | imgui.same_line(viz.label_w) 78 | changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024, 79 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 80 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 81 | help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl') 82 | if changed: 83 | self.load(self.user_pkl, ignore_errors=True) 84 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': 85 | imgui.set_tooltip(self.user_pkl) 86 | imgui.same_line() 87 | if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): 88 | imgui.open_popup('recent_pkls_popup') 89 | imgui.same_line() 90 | if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1): 91 | imgui.open_popup('browse_pkls_popup') 92 | self.browse_cache.clear() 93 | self.browse_refocus = True 94 | 95 | if imgui.begin_popup('recent_pkls_popup'): 96 | for pkl in recent_pkls: 97 | clicked, _state = imgui.menu_item(pkl) 98 | if clicked: 99 | self.load(pkl, ignore_errors=True) 100 | imgui.end_popup() 101 | 102 | if imgui.begin_popup('browse_pkls_popup'): 103 | def recurse(parents): 104 | key = 
tuple(parents) 105 | items = self.browse_cache.get(key, None) 106 | if items is None: 107 | items = self.list_runs_and_pkls(parents) 108 | self.browse_cache[key] = items 109 | for item in items: 110 | if item.type == 'run' and imgui.begin_menu(item.name): 111 | recurse([item.path]) 112 | imgui.end_menu() 113 | if item.type == 'pkl': 114 | clicked, _state = imgui.menu_item(item.name) 115 | if clicked: 116 | self.load(item.path, ignore_errors=True) 117 | if len(items) == 0: 118 | with imgui_utils.grayed_out(): 119 | imgui.menu_item('No results found') 120 | recurse(self.search_dirs) 121 | if self.browse_refocus: 122 | imgui.set_scroll_here() 123 | viz.skip_frame() # Focus will change on next frame. 124 | self.browse_refocus = False 125 | imgui.end_popup() 126 | 127 | paths = viz.pop_drag_and_drop_paths() 128 | if paths is not None and len(paths) >= 1: 129 | self.load(paths[0], ignore_errors=True) 130 | 131 | viz.args.pkl = self.cur_pkl 132 | 133 | def list_runs_and_pkls(self, parents): 134 | items = [] 135 | run_regex = re.compile(r'\d+-.*') 136 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') 137 | for parent in set(parents): 138 | if os.path.isdir(parent): 139 | for entry in os.scandir(parent): 140 | if entry.is_dir() and run_regex.fullmatch(entry.name): 141 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) 142 | if entry.is_file() and pkl_regex.fullmatch(entry.name): 143 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) 144 | 145 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) 146 | return items 147 | 148 | def resolve_pkl(self, pattern): 149 | assert isinstance(pattern, str) 150 | assert pattern != '' 151 | 152 | # URL => return as is. 153 | if dnnlib.util.is_url(pattern): 154 | return pattern 155 | 156 | # Short-hand pattern => locate. 157 | path = _locate_results(pattern) 158 | 159 | # Run dir => pick the last saved snapshot. 160 | if os.path.isdir(path): 161 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) 162 | if len(pkl_files) == 0: 163 | raise IOError(f'No network pickle found in "{path}"') 164 | path = pkl_files[-1] 165 | 166 | # Normalize. 167 | path = os.path.abspath(path) 168 | return path 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /stylesan-xl/viz/stylemix_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class StyleMixingWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.seed_def = 1000 18 | self.seed = self.seed_def 19 | self.animate = False 20 | self.enables = [] 21 | 22 | @imgui_utils.scoped_by_object_id 23 | def __call__(self, show=True): 24 | viz = self.viz 25 | num_ws = viz.result.get('num_ws', 0) 26 | num_enables = viz.result.get('num_ws', 18) 27 | self.enables += [False] * max(num_enables - len(self.enables), 0) 28 | 29 | if show: 30 | imgui.text('Stylemix') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): 33 | _changed, self.seed = imgui.input_int('##seed', self.seed) 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | with imgui_utils.grayed_out(num_ws == 0): 36 | _clicked, self.animate = imgui.checkbox('Anim', self.animate) 37 | 38 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w 39 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing 40 | pos0 = viz.label_w + viz.font_size * 12 41 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 42 | for idx in range(num_enables): 43 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) 44 | if idx == 0: 45 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) 46 | with imgui_utils.grayed_out(num_ws == 0): 47 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) 48 | if imgui.is_item_hovered(): 49 | imgui.set_tooltip(f'{idx}') 50 | imgui.pop_style_var(1) 51 | 52 | imgui.same_line(pos2) 53 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) 54 | with imgui_utils.grayed_out(num_ws == 0): 55 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))): 56 | self.seed = self.seed_def 57 | self.animate = False 58 | self.enables = [False] * num_enables 59 | 60 | if any(self.enables[:num_ws]): 61 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] 62 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) 63 | if self.animate: 64 | self.seed += 1 65 | 66 | #---------------------------------------------------------------------------- 67 | -------------------------------------------------------------------------------- /stylesan-xl/viz/trunc_noise_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class TruncationNoiseWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.prev_num_ws = 0 18 | self.trunc_psi = 1 19 | self.trunc_cutoff = 0 20 | self.noise_enable = True 21 | self.noise_seed = 0 22 | self.noise_anim = False 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | num_ws = viz.result.get('num_ws', 0) 28 | has_noise = viz.result.get('has_noise', False) 29 | if num_ws > 0 and num_ws != self.prev_num_ws: 30 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: 31 | self.trunc_cutoff = num_ws 32 | self.prev_num_ws = num_ws 33 | 34 | if show: 35 | imgui.text('Truncate') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): 38 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') 39 | imgui.same_line() 40 | if num_ws == 0: 41 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) 42 | else: 43 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): 44 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') 45 | if changed: 46 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) 47 | 48 | with imgui_utils.grayed_out(not has_noise): 49 | imgui.same_line() 50 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) 51 | imgui.same_line(round(viz.font_size * 27.7)) 52 | with imgui_utils.grayed_out(not self.noise_enable): 53 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4): 54 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) 55 | imgui.same_line(spacing=0) 56 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) 57 | 58 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) 59 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) 60 | with imgui_utils.grayed_out(is_def_trunc and not has_noise): 61 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 62 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): 63 | self.prev_num_ws = num_ws 64 | self.trunc_psi = 1 65 | self.trunc_cutoff = num_ws 66 | self.noise_enable = True 67 | self.noise_seed = 0 68 | self.noise_anim = False 69 | 70 | if self.noise_anim: 71 | self.noise_seed += 1 72 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed) 73 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') 74 | 75 | #---------------------------------------------------------------------------- 76 | -------------------------------------------------------------------------------- /tutorial/assets/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tutorial/assets/imagenet256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/tutorial/assets/imagenet256.png --------------------------------------------------------------------------------