├── README.md
├── simple-san
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── assets
│ │ ├── gan_class.png
│ │ ├── gan_nocond.png
│ │ ├── san_class.png
│ │ └── san_nocond.png
│ ├── hparams
│ │ └── params.json
│ ├── models
│ │ ├── discriminator.py
│ │ └── generator.py
│ ├── requirements.txt
│ └── train.py
├── stylesan-xl
│ ├── .gitignore
│ ├── Dockerfile
│ ├── LICENSE
│ ├── README.md
│ ├── calc_metrics.py
│ ├── dataset_tool.py
│ ├── dataset_tool_for_imagenet.py
│ ├── dnnlib
│ │ ├── __init__.py
│ │ └── util.py
│ ├── feature_networks
│ │ ├── clip
│ │ │ ├── __init__.py
│ │ │ ├── bpe_simple_vocab_16e6.txt.gz
│ │ │ ├── clip.py
│ │ │ ├── model.py
│ │ │ └── simple_tokenizer.py
│ │ ├── constants.py
│ │ ├── pretrained_builder.py
│ │ └── vit.py
│ ├── gen_class_samplesheet.py
│ ├── gen_images.py
│ ├── gen_video.py
│ ├── gui_utils
│ │ ├── __init__.py
│ │ ├── gl_utils.py
│ │ ├── glfw_window.py
│ │ ├── imgui_utils.py
│ │ ├── imgui_window.py
│ │ └── text_utils.py
│ ├── in_embeddings
│ │ └── tf_efficientnet_lite0.pkl
│ ├── incl_licenses
│ │ └── LICENSE_1
│ ├── legacy.py
│ ├── media
│ │ └── imagenet_idx2labels.txt
│ ├── metrics
│ │ ├── __init__.py
│ │ ├── equivariance.py
│ │ ├── frechet_inception_distance.py
│ │ ├── inception_score.py
│ │ ├── kernel_inception_distance.py
│ │ ├── metric_main.py
│ │ ├── metric_utils.py
│ │ ├── perceptual_path_length.py
│ │ └── precision_recall.py
│ ├── pg_modules
│ │ ├── blocks.py
│ │ ├── discriminator.py
│ │ ├── projector.py
│ │ └── san_modules.py
│ ├── requirements.txt
│ ├── run_inversion.py
│ ├── run_stylemc.py
│ ├── torch_utils
│ │ ├── __init__.py
│ │ ├── custom_ops.py
│ │ ├── gen_utils.py
│ │ ├── misc.py
│ │ ├── ops
│ │ │ ├── __init__.py
│ │ │ ├── bias_act.cpp
│ │ │ ├── bias_act.cu
│ │ │ ├── bias_act.h
│ │ │ ├── bias_act.py
│ │ │ ├── conv2d_gradfix.py
│ │ │ ├── conv2d_resample.py
│ │ │ ├── filtered_lrelu.cpp
│ │ │ ├── filtered_lrelu.cu
│ │ │ ├── filtered_lrelu.h
│ │ │ ├── filtered_lrelu.py
│ │ │ ├── filtered_lrelu_ns.cu
│ │ │ ├── filtered_lrelu_rd.cu
│ │ │ ├── filtered_lrelu_wr.cu
│ │ │ ├── fma.py
│ │ │ ├── grid_sample_gradfix.py
│ │ │ ├── upfirdn2d.cpp
│ │ │ ├── upfirdn2d.cu
│ │ │ ├── upfirdn2d.h
│ │ │ └── upfirdn2d.py
│ │ ├── persistence.py
│ │ ├── training_stats.py
│ │ └── utils_spectrum.py
│ ├── train.py
│ ├── training
│ │ ├── __init__.py
│ │ ├── augment.py
│ │ ├── dataset.py
│ │ ├── diffaug.py
│ │ ├── loss.py
│ │ ├── networks_fastgan.py
│ │ ├── networks_stylegan2.py
│ │ ├── networks_stylegan3.py
│ │ ├── networks_stylegan3_resetting.py
│ │ └── training_loop.py
│ ├── visualizer.py
│ └── viz
│   ├── __init__.py
│   ├── capture_widget.py
│   ├── equivariance_widget.py
│   ├── latent_widget.py
│   ├── layer_widget.py
│   ├── performance_widget.py
│   ├── pickle_widget.py
│   ├── renderer.py
│   ├── stylemix_widget.py
│   └── trunc_noise_widget.py
└── tutorial
  ├── README.md
  └── assets
    ├── .gitkeep
    └── imagenet256.png
/README.md:
--------------------------------------------------------------------------------
1 | # Slicing Adversarial Network (SAN) [ICLR 2024]
2 |
3 | This repository contains the official PyTorch implementation of **"SAN: Inducing Metrizability of GAN with Discriminative Normalized Linear Layer"** (*[arXiv 2301.12811](https://arxiv.org/abs/2301.12811)*).
4 | Please cite [[1](#citation)] in your work when using this code in your experiments.
5 |
6 | ### [[Project Page]](https://ytakida.github.io/san/)
7 |
8 |
9 | > **Abstract:** Generative adversarial networks (GANs) learn a target probability distribution by optimizing a generator and a discriminator with minimax objectives. This paper addresses the question of whether such optimization actually provides the generator with gradients that make its distribution close to the target distribution. We derive metrizable conditions, sufficient conditions for the discriminator to serve as the distance between the distributions, by connecting the GAN formulation with the concept of sliced optimal transport. Furthermore, by leveraging these theoretical results, we propose a novel GAN training scheme called the Slicing Adversarial Network (SAN). With only simple modifications, a broad class of existing GANs can be converted to SANs. Experiments on synthetic and image datasets support our theoretical results and the effectiveness of SAN as compared to the usual GANs. We also apply SAN to StyleGAN-XL, which leads to a state-of-the-art FID score amongst GANs for class conditional generation on CIFAR10 and ImageNet 256$\times$256.
10 |
11 |
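The modification itself is small: the discriminator's last linear layer is split into a scale and a unit-norm direction, and that direction is trained with its own objective while the feature extractor sees it as fixed. The sketch below illustrates the idea, following `simple-san/models/discriminator.py` in this repository; the function name is illustrative only, not part of a public API.

```python
import torch
import torch.nn.functional as F

def san_last_layer(h, w, flg_train: bool):
    # h: (batch, dim) features from the discriminator body; w: (dim,) last-layer weight.
    direction = F.normalize(w, dim=0)   # unit-norm slicing direction (omega)
    h = h * torch.norm(w)               # keep the original scale on the features
    if flg_train:
        # Same inner product, two gradient paths: 'fun' updates the features only
        # (direction detached), while 'dir' updates the direction only (features detached).
        return dict(fun=h @ direction.detach(), dir=h.detach() @ direction)
    return h @ direction                # generator update / inference
```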
12 | # Citation
13 | [1] Takida, Y., Imaizumi, M., Shibuya, T., Lai, C., Uesaka, T., Murata, N. and Mitsufuji, Y.,
14 | "SAN: Inducing Metrizability of GAN with Discriminative Normalized Linear Layer,"
15 | ICLR 2024.
16 | ```
17 | @inproceedings{takida2024san,
18 | title={{SAN}: Inducing Metrizability of {GAN} with Discriminative Normalized Linear Layer},
19 | author={Takida, Yuhta and Imaizumi, Masaaki and Shibuya, Takashi and Lai, Chieh-Hsin and Uesaka, Toshimitsu and Murata, Naoki and Mitsufuji, Yuki},
20 | booktitle={The Twelfth International Conference on Learning Representations},
21 | year={2024},
22 | url={https://openreview.net/forum?id=eiF7TU1E8E}
23 | }
24 | ```
--------------------------------------------------------------------------------
/simple-san/.gitignore:
--------------------------------------------------------------------------------
1 | # vscode
2 | .vscode/
3 |
4 | # pycache
5 | __pycache__
6 |
7 | # tmp files
8 | *.tmp
9 |
10 | # output
11 | job/
12 | logs/
13 | out/
14 | dataset/
--------------------------------------------------------------------------------
/simple-san/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jens Rahnfeld
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/simple-san/README.md:
--------------------------------------------------------------------------------
1 | # SAN (simple code)
2 |
3 | This repository provides a simple implementation of the Slicing Adversarial Network (SAN) on MNIST for tutorial purposes (*[arXiv 2301.12811](https://arxiv.org/abs/2301.12811)*).
4 | Please cite [[1](#citation)] in your work when using this code in your experiments.
5 |
6 | ### [[Project Page]](https://ytakida.github.io/san/)
7 |
8 |
9 | ## Requirements
10 | This repository builds on the following codebases:
11 | 1. https://github.com/yukara-ikemiya/minimal-san
12 | 2. https://github.com/JensRahnfeld/Simple-GAN-MNIST
13 |
14 | Install the following dependencies:
15 | - Python 3.8.5
16 | - pytorch 1.6.0
17 | - torchvision 0.7.0
18 | - tqdm 4.50.2
19 |
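The pinned versions are also listed in `requirements.txt`, so they can be installed, for example, with:

```bash
pip install -r requirements.txt
```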
20 | ## Training
21 |
22 | ### 1) Hyperparameters
23 | Specify hyperparameters inside a .json file, e.g.:
24 |
25 | ```json
26 | {
27 | "dim_latent": 100,
28 | "batch_size": 128,
29 | "learning_rate": 0.001,
30 | "beta_1": 0.0,
31 | "beta_2": 0.99,
32 | "num_epochs": 200
33 | }
34 | ```
35 |
36 | ### 2) Options
37 |
38 | ```
39 | -h, --help
40 | show this help message and exit
41 | --datadir DATADIR
42 | path to MNIST dataset folder
43 | --params PARAMS
44 | path to hyperparameters
45 | --model MODEL
46 | model's name / 'gan' or 'san'
47 | --enable_class
48 | enable class conditioning
49 | --device DEVICE
50 | gpu device to use
51 | ```
52 |
53 |
54 | ### 3) Train the model
55 |
56 | - Class conditional (hinge) SAN
57 | ```bash
58 | python train.py --datadir <path/to/MNIST> --model 'san' --enable_class
59 | ```
60 |
61 | - Unconditional (hinge) GAN
62 | ```bash
63 | python train.py --datadir <path/to/MNIST> --model 'gan'
64 | ```
65 |
66 |
67 | ## Generated images (after 200 epochs)
68 |
69 | ### Unconditional
70 |
71 | GAN | SAN
72 | :-------------------------:|:-------------------------:
73 | ![](assets/gan_nocond.png) | ![](assets/san_nocond.png)
74 |
75 | ### Class conditional
76 |
77 | GAN | SAN
78 | :-------------------------:|:-------------------------:
79 | ![](assets/gan_class.png) | ![](assets/san_class.png)
80 |
81 |
82 |
83 | # Citation
84 | [1] Takida, Y., Imaizumi, M., Shibuya, T., Lai, C., Uesaka, T., Murata, N. and Mitsufuji, Y.,
85 | "SAN: Inducing Metrizability of GAN with Discriminative Normalized Linear Layer,"
86 | ICLR 2024.
87 | ```
88 | @inproceedings{takida2024san,
89 | title={{SAN}: Inducing Metrizability of {GAN} with Discriminative Normalized Linear Layer},
90 | author={Takida, Yuhta and Imaizumi, Masaaki and Shibuya, Takashi and Lai, Chieh-Hsin and Uesaka, Toshimitsu and Murata, Naoki and Mitsufuji, Yuki},
91 | booktitle={The Twelfth International Conference on Learning Representations},
92 | year={2024},
93 | url={https://openreview.net/forum?id=eiF7TU1E8E}
94 | }
95 | ```
--------------------------------------------------------------------------------
/simple-san/assets/gan_class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/gan_class.png
--------------------------------------------------------------------------------
/simple-san/assets/gan_nocond.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/gan_nocond.png
--------------------------------------------------------------------------------
/simple-san/assets/san_class.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/san_class.png
--------------------------------------------------------------------------------
/simple-san/assets/san_nocond.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/simple-san/assets/san_nocond.png
--------------------------------------------------------------------------------
/simple-san/hparams/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dim_latent": 100,
3 | "batch_size": 128,
4 | "learning_rate": 0.001,
5 | "beta_1": 0.0,
6 | "beta_2": 0.99,
7 | "num_epochs": 200
8 | }
--------------------------------------------------------------------------------
/simple-san/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class BaseDiscriminator(nn.Module):
7 | def __init__(self, num_class=10):
8 | super(BaseDiscriminator, self).__init__()
9 |
10 | # Feature extractor
11 | self.h_function = nn.Sequential(
12 | nn.Conv2d(1, 128, kernel_size=6, stride=2),
13 | nn.LeakyReLU(0.2),
14 | nn.utils.spectral_norm(
15 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
16 | nn.LeakyReLU(0.2),
17 | nn.utils.spectral_norm(
18 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)),
19 | nn.LeakyReLU(0.2)
20 | )
21 |
22 | # Last linear layer
23 | self.use_class = num_class > 0
24 | self.fc_w = nn.Parameter(
25 | torch.randn(num_class if self.use_class else 1, 512 * 3 * 3))
26 |
27 | def forward(self, x, class_ids, flg_train: bool):
28 | h_feature = self.h_function(x)
29 | h_feature = torch.flatten(h_feature, start_dim=1)
30 | weights = self.fc_w[class_ids] if self.use_class else self.fc_w
31 | out = (h_feature * weights).sum(dim=1)
32 |
33 | return out
34 |
35 |
36 | # Modified Discriminator Architecture
37 |
38 | class SanDiscriminator(BaseDiscriminator):
39 | def __init__(self, num_class=10):
40 | super(SanDiscriminator, self).__init__(num_class)
41 |
42 | def forward(self, x, class_ids, flg_train: bool):
43 | h_feature = self.h_function(x)
44 | h_feature = torch.flatten(h_feature, start_dim=1)
45 | weights = self.fc_w[class_ids] if self.use_class else self.fc_w
46 | direction = F.normalize(weights, dim=1) # Normalize the last layer
47 | scale = torch.norm(weights, dim=1).unsqueeze(1)
48 | h_feature = h_feature * scale # keep the original scale of the features
49 | if flg_train: # for discriminator training
50 | out_fun = (h_feature * direction.detach()).sum(dim=1) # updates h only (direction frozen)
51 | out_dir = (h_feature.detach() * direction).sum(dim=1) # updates the direction omega only (h frozen)
52 | out = dict(fun=out_fun, dir=out_dir)
53 | else: # for generator training or inference
54 | out = (h_feature * direction).sum(dim=1)
55 |
56 | return out
57 |
--------------------------------------------------------------------------------
/simple-san/models/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Generator(nn.Module):
7 | def __init__(self, dim_latent=100, num_class=10):
8 | super(Generator, self).__init__()
9 |
10 | self.dim_latent = dim_latent
11 | self.use_class = num_class > 0
12 |
13 | if self.use_class:
14 | self.emb_class = nn.Embedding(num_class, dim_latent)
15 | self.fc = nn.Linear(dim_latent * 2, 512 * 3 * 3)
16 | else:
17 | self.fc = nn.Linear(dim_latent, 512 * 3 * 3)
18 |
19 | self.g_function = nn.Sequential(
20 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
21 | nn.BatchNorm2d(256),
22 | nn.SiLU(),
23 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
24 | nn.BatchNorm2d(128),
25 | nn.SiLU(),
26 | nn.ConvTranspose2d(128, 1, kernel_size=6, stride=2),
27 | nn.Sigmoid()
28 | )
29 |
30 | def forward(self, x, class_ids):
31 | batch_size = x.size(0)
32 |
33 | if self.use_class:
34 | x_class = self.emb_class(class_ids)
35 | x = self.fc(torch.cat((x, x_class), dim=1))
36 | else:
37 | x = self.fc(x)
38 |
39 | x = F.leaky_relu(x)
40 | x = x.view(batch_size, 512, 3, 3)
41 | img = self.g_function(x)
42 |
43 | return img
44 |
--------------------------------------------------------------------------------
/simple-san/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | torchvision==0.7.0
3 | tqdm==4.50.2
4 |
--------------------------------------------------------------------------------
/simple-san/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import sys
5 | import matplotlib.pyplot as plt
6 | import matplotlib.gridspec as gridspec
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | import torchvision.datasets as datasets
13 | import torchvision.transforms as transforms
14 | import tqdm
15 |
16 | from models.discriminator import BaseDiscriminator, SanDiscriminator
17 | from models.generator import Generator
18 | from torch.utils.data import DataLoader
19 |
20 |
21 | def update_discriminator(x, class_ids, discriminator, generator, optimizer, params):
22 | bs = x.size(0)
23 | device = x.device
24 |
25 | optimizer.zero_grad()
26 |
27 | # for data (ground-truth) distribution
28 | disc_real = discriminator(x, class_ids, flg_train=True)
29 | loss_real = eval('compute_loss_'+args.model)(disc_real, loss_type='real')
30 |
31 | # for generator distribution
32 | latent = torch.randn(bs, params["dim_latent"], device=device)
33 | img_fake = generator(latent, class_ids)
34 | disc_fake = discriminator(img_fake.detach(), class_ids, flg_train=True)
35 | loss_fake = eval('compute_loss_'+args.model)(disc_fake, loss_type='fake')
36 |
37 |
38 | loss_d = loss_real + loss_fake
39 | loss_d.backward()
40 | optimizer.step()
41 |
42 |
43 | def update_generator(num_class, discriminator, generator, optimizer, params, device):
44 | optimizer.zero_grad()
45 |
46 | bs = params['batch_size']
47 | latent = torch.randn(bs, params["dim_latent"], device=device)
48 |
49 | class_ids = torch.randint(num_class, size=(bs,), device=device)
50 | batch_fake = generator(latent, class_ids)
51 |
52 | disc_gen = discriminator(batch_fake, class_ids, flg_train=False)
53 | loss_g = - disc_gen.mean()
54 | loss_g.backward()
55 | optimizer.step()
56 |
57 |
58 | def compute_loss_gan(disc, loss_type):
59 | assert (loss_type in ['real', 'fake'])
60 | if 'real' == loss_type:
61 | loss = (1. - disc).relu().mean() # Hinge loss
62 | else: # 'fake' == loss_type
63 | loss = (1. + disc).relu().mean() # Hinge loss
64 |
65 | return loss
66 |
67 |
68 | def compute_loss_san(disc, loss_type):
69 | assert (loss_type in ['real', 'fake'])
70 | if 'real' == loss_type:
71 | loss_fun = (1. - disc['fun']).relu().mean() # Hinge loss for function h
72 | loss_dir = - disc['dir'].mean() # Wasserstein loss for omega
73 | else: # 'fake' == loss_type
74 | loss_fun = (1. + disc['fun']).relu().mean() # Hinge loss for function h
75 | loss_dir = disc['dir'].mean() # Wasserstein loss for omega
76 | loss = loss_fun + loss_dir
77 |
78 | return loss
79 |
80 |
81 | def save_images(imgs, idx, dirname='test'):
82 | import numpy as np
83 | if imgs.shape[1] == 1:
84 | imgs = np.repeat(imgs, 3, axis=1)
85 | fig = plt.figure(figsize=(10, 10))
86 | gs = gridspec.GridSpec(10, 10)
87 | gs.update(wspace=0.05, hspace=0.05)
88 | for i, sample in enumerate(imgs):
89 | ax = plt.subplot(gs[i])
90 | plt.axis('off')
91 | ax.set_xticklabels([])
92 | ax.set_yticklabels([])
93 | ax.set_aspect('equal')
94 | plt.imshow(sample.transpose((1,2,0)))
95 |
96 | if not os.path.exists('out/{}/'.format(dirname)):
97 | os.makedirs('out/{}/'.format(dirname))
98 | plt.savefig('out/{0}/{1}.png'.format(dirname, str(idx).zfill(3)), bbox_inches="tight")
99 | plt.close(fig)
100 |
101 |
102 | def get_args():
103 | parser = argparse.ArgumentParser()
104 | parser.add_argument("--datadir", type=str, required=True, help="path to MNIST dataset folder")
105 | parser.add_argument("--params", type=str, default="./hparams/params.json", help="path to hyperparameters")
106 | parser.add_argument("--model", type=str, default="gan", help="model's name / 'gan' or 'san'")
107 | parser.add_argument('--enable_class', action='store_true', help='enable class conditioning')
108 | parser.add_argument("--logdir", type=str, default="./logs", help="directory storing log files")
109 | parser.add_argument("--device", type=int, default=0, help="gpu device to use")
110 |
111 | return parser.parse_args()
112 |
113 |
114 | def main(args):
115 | with open(args.params, "r") as f:
116 | params = json.load(f)
117 |
118 | device = f'cuda:{args.device}' if args.device is not None else 'cpu'
119 | model_name = args.model
120 | if model_name not in ['gan', 'san']:
121 | raise RuntimeError("The model name must be 'gan' or 'san'.")
122 | experiment_name = model_name + "_cond" if args.enable_class else model_name
123 |
124 | # dataloading
125 | num_class = 10
126 | train_dataset = datasets.MNIST(root=args.datadir, transform=transforms.ToTensor(), train=True, download=True)
127 | train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], num_workers=4,
128 | pin_memory=True, persistent_workers=True, shuffle=True)
129 | test_dataset = datasets.MNIST(root=args.datadir, transform=transforms.ToTensor(), train=False, download=True)
130 | test_loader = DataLoader(test_dataset, batch_size=params["batch_size"], num_workers=4,
131 | pin_memory=True, persistent_workers=True, shuffle=False)
132 |
133 | # model
134 | use_class = args.enable_class
135 | generator = Generator(params["dim_latent"], num_class=num_class if use_class else 0)
136 | if 'gan' == args.model:
137 | discriminator = BaseDiscriminator(num_class=num_class if use_class else 0)
138 | else: # 'san' == args.model
139 | discriminator = SanDiscriminator(num_class=num_class if use_class else 0)
140 | generator = generator.to(device)
141 | discriminator = discriminator.to(device)
142 |
143 | # optimizer
144 | betas = (params["beta_1"], params["beta_2"])
145 | optimizer_G = optim.Adam(generator.parameters(), lr=params["learning_rate"], betas=betas)
146 | optimizer_D = optim.Adam(discriminator.parameters(), lr=params["learning_rate"], betas=betas)
147 |
148 | ckpt_dir = f'{args.logdir}/{experiment_name}/'
149 | if not os.path.exists(args.logdir):
150 | os.mkdir(args.logdir)
151 | if not os.path.exists(ckpt_dir):
152 | os.mkdir(ckpt_dir)
153 |
154 | steps_per_epoch = len(train_loader)
155 |
156 | msg = ["\t{0}: {1}".format(key, val) for key, val in params.items()]
157 | print("hyperparameters: \n" + "\n".join(msg))
158 |
159 | # eval initial states
160 | num_samples_per_class = 10
161 | with torch.no_grad():
162 | latent = torch.randn(num_samples_per_class * num_class, params["dim_latent"], device=device)
163 | class_ids = torch.arange(num_class, dtype=torch.long,
164 | device=device).repeat_interleave(num_samples_per_class)
165 | imgs_fake = generator(latent, class_ids)
166 |
167 | # main training loop
168 | for n in range(params["num_epochs"]):
169 | loader = iter(train_loader)
170 |
171 | print("epoch: {0}/{1}".format(n + 1, params["num_epochs"]))
172 | for i in tqdm.trange(steps_per_epoch):
173 | x, class_ids = next(loader)
174 | x = x.to(device)
175 | class_ids = class_ids.to(device)
176 |
177 | update_discriminator(x, class_ids, discriminator, generator, optimizer_D, params)
178 | update_generator(num_class, discriminator, generator, optimizer_G, params, device)
179 |
180 | torch.save(generator.state_dict(), ckpt_dir + "g." + str(n) + ".tmp")
181 | torch.save(discriminator.state_dict(), ckpt_dir + "d." + str(n) + ".tmp")
182 |
183 | # eval
184 | with torch.no_grad():
185 | latent = torch.randn(num_samples_per_class * num_class, params["dim_latent"], device=device)
186 | class_ids = torch.arange(num_class, dtype=torch.long,
187 | device=device).repeat_interleave(num_samples_per_class)
188 | imgs_fake = generator(latent, class_ids).cpu().data.numpy()
189 | save_images(imgs_fake, n, dirname=experiment_name)
190 |
191 | torch.save(generator.state_dict(), ckpt_dir + "generator.pt")
192 | torch.save(discriminator.state_dict(), ckpt_dir + "discriminator.pt")
193 |
194 |
195 | if __name__ == '__main__':
196 | args = get_args()
197 | main(args)
198 |
--------------------------------------------------------------------------------
/stylesan-xl/.gitignore:
--------------------------------------------------------------------------------
1 | g++
2 | gcc
3 |
4 | data
5 | data/*
6 | !data/.placeholder
7 |
8 | training-runs
9 | training-runs/*
10 | out/*
11 | sample_sheets/*
12 |
13 |
14 | *.zip
15 |
16 | **/__pycache__
17 | __pycache__
18 | .ipynb_checkpoints/
19 | tags
20 | *.swp
21 | *.pth
22 | *.pt
23 | *.npz
24 | *.tar
25 | *.gz
26 | *.pkl
27 | *.mp4
28 | *.pyc
29 |
--------------------------------------------------------------------------------
/stylesan-xl/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.06-py3
2 | COPY ./requirements.txt .
3 | RUN pip install --upgrade pip
4 | RUN pip install -r requirements.txt
5 | RUN pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
6 | RUN apt-get update -y && apt-get install -y --no-install-recommends build-essential gcc libsndfile1
7 |
--------------------------------------------------------------------------------
/stylesan-xl/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Sony Research Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/stylesan-xl/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/stylesan-xl/feature_networks/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/stylesan-xl/feature_networks/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/stylesan-xl/feature_networks/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/stylesan-xl/feature_networks/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/stylesan-xl/feature_networks/constants.py:
--------------------------------------------------------------------------------
1 | TORCHVISION = [
2 | "vgg11_bn",
3 | "vgg13_bn",
4 | "vgg16",
5 | "vgg16_bn",
6 | "vgg19_bn",
7 | "densenet121",
8 | "densenet169",
9 | "densenet201",
10 | "inception_v3",
11 | "resnet18",
12 | "resnet34",
13 | "resnet50",
14 | "resnet101",
15 | "resnet152",
16 | "shufflenet_v2_x0_5",
17 | "mobilenet_v2",
18 | "wide_resnet50_2",
19 | "mnasnet0_5",
20 | "mnasnet1_0",
21 | "ghostnet_100",
22 | "cspresnet50",
23 | "fbnetc_100",
24 | "spnasnet_100",
25 | "resnet50d",
26 | "resnet26",
27 | "resnet26d",
28 | "seresnet50",
29 | "resnetblur50",
30 | "resnetrs50",
31 | "tf_mixnet_s",
32 | "tf_mixnet_m",
33 | "tf_mixnet_l",
34 | "ese_vovnet19b_dw",
35 | "ese_vovnet39b",
36 | "res2next50",
37 | "gernet_s",
38 | "gernet_m",
39 | "repvgg_a2",
40 | "repvgg_b0",
41 | "repvgg_b1",
42 | "repvgg_b1g4",
43 | "revnet",
44 | "dm_nfnet_f1",
45 | "nfnet_l0",
46 | ]
47 |
48 | REGNETS = [
49 | "regnetx_002",
50 | "regnetx_004",
51 | "regnetx_006",
52 | "regnetx_008",
53 | "regnetx_016",
54 | "regnetx_032",
55 | "regnetx_040",
56 | "regnetx_064",
57 | "regnety_002",
58 | "regnety_004",
59 | "regnety_006",
60 | "regnety_008",
61 | "regnety_016",
62 | "regnety_032",
63 | "regnety_040",
64 | "regnety_064",
65 | ]
66 |
67 | EFFNETS_IMAGENET = [
68 | 'tf_efficientnet_b0',
69 | 'tf_efficientnet_b1',
70 | 'tf_efficientnet_b2',
71 | 'tf_efficientnet_b3',
72 | 'tf_efficientnet_b4',
73 | 'tf_efficientnet_b0_ns',
74 | ]
75 |
76 | EFFNETS_INCEPTION = [
77 | 'tf_efficientnet_lite0',
78 | 'tf_efficientnet_lite1',
79 | 'tf_efficientnet_lite2',
80 | 'tf_efficientnet_lite3',
81 | 'tf_efficientnet_lite4',
82 | 'tf_efficientnetv2_b0',
83 | 'tf_efficientnetv2_b1',
84 | 'tf_efficientnetv2_b2',
85 | 'tf_efficientnetv2_b3',
86 | 'efficientnet_b1',
87 | 'efficientnet_b1_pruned',
88 | 'efficientnet_b2_pruned',
89 | 'efficientnet_b3_pruned',
90 | ]
91 |
92 | EFFNETS = EFFNETS_IMAGENET + EFFNETS_INCEPTION
93 |
94 | VITS_IMAGENET = [
95 | 'deit_tiny_distilled_patch16_224',
96 | 'deit_small_distilled_patch16_224',
97 | 'deit_base_distilled_patch16_224',
98 | ]
99 |
100 | VITS_INCEPTION = [
101 | 'vit_base_patch16_224'
102 | ]
103 |
104 | VITS = VITS_IMAGENET + VITS_INCEPTION
105 |
106 | CLIP = [
107 | 'resnet50_clip'
108 | ]
109 |
110 | ALL_MODELS = TORCHVISION + REGNETS + EFFNETS + VITS + CLIP
111 |
112 | # Group according to input normalization
113 |
114 | NORMALIZED_IMAGENET = TORCHVISION + REGNETS + EFFNETS_IMAGENET + VITS_IMAGENET
115 |
116 | NORMALIZED_INCEPTION = EFFNETS_INCEPTION + VITS_INCEPTION
117 |
118 | NORMALIZED_CLIP = CLIP
119 |
--------------------------------------------------------------------------------
/stylesan-xl/gen_class_samplesheet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import PIL.Image
4 | from typing import List
5 | import click
6 | import numpy as np
7 | import torch
8 | from tqdm import tqdm
9 |
10 | import legacy
11 | import dnnlib
12 | from training.training_loop import save_image_grid
13 | from torch_utils import gen_utils
14 | from gen_images import parse_range
15 |
16 | @click.command()
17 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
18 | @click.option('--trunc', 'truncation_psi', help='Truncation psi', type=float, default=1, show_default=True)
19 | @click.option('--seed', help='Random seed', type=int, default=42)
20 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation')
21 | @click.option('--classes', type=parse_range, help='List of classes (e.g., \'0,1,4-6\')', required=True)
22 | @click.option('--samples-per-class', help='Samples per class.', type=int, default=4)
23 | @click.option('--grid-width', help='Total width of image grid', type=int, default=32)
24 | @click.option('--batch-gpu', help='Samples per pass, adapt to fit on GPU', type=int, default=32)
25 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
26 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
27 | def generate_samplesheet(
28 | network_pkl: str,
29 | truncation_psi: float,
30 | seed: int,
31 | centroids_path: str,
32 | classes: List[int],
33 | samples_per_class: int,
34 | batch_gpu: int,
35 | grid_width: int,
36 | outdir: str,
37 | desc: str,
38 | ):
39 | print('Loading networks from "%s"...' % network_pkl)
40 | device = torch.device('cuda')
41 | with dnnlib.util.open_url(network_pkl) as f:
42 | G = legacy.load_network_pkl(f)['G_ema'].to(device).requires_grad_(False)
43 |
44 | # setup
45 | os.makedirs(outdir, exist_ok=True)
46 | desc_full = f'{Path(network_pkl).stem}_trunc_{truncation_psi}'
47 | if desc is not None: desc_full += f'-{desc}'
48 | run_dir = Path(gen_utils.make_run_dir(outdir, desc_full))
49 |
50 | print('Generating latents.')
51 | ws = []
52 | for class_idx in tqdm(classes):
53 | w = gen_utils.get_w_from_seed(G, samples_per_class, device, truncation_psi, seed=seed,
54 | centroids_path=centroids_path, class_idx=class_idx)
55 | ws.append(w)
56 | ws = torch.cat(ws)
57 |
58 | print('Generating samples.')
59 | images = []
60 | for w in tqdm(ws.split(batch_gpu)):
61 | img = gen_utils.w_to_img(G, w, to_np=True)
62 | images.append(img)
63 |
64 | # adjust grid width so that samples of the same class are not split across rows, then save to disk
65 | grid_width = grid_width - grid_width % samples_per_class
66 | images = gen_utils.create_image_grid(np.concatenate(images), grid_size=(grid_width, None))
67 | PIL.Image.fromarray(images, 'RGB').save(run_dir / 'sheet.png')
68 |
69 | if __name__ == "__main__":
70 | generate_samplesheet()
71 |
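72 | # Example invocation (hypothetical paths; all options are defined above):
73 | #   python gen_class_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
74 | #       --samples-per-class=8 --classes=0-9 --grid-width=32 \
75 | #       --network=<path/to/network.pkl>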
--------------------------------------------------------------------------------
/stylesan-xl/gen_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Generate images using pretrained network pickle."""
10 |
11 | import os
12 | import re
13 | from typing import List, Optional, Tuple, Union
14 |
15 | import click
16 | import dnnlib
17 | import numpy as np
18 | import PIL.Image
19 | import torch
20 |
21 | import legacy
22 | from torch_utils import gen_utils
23 |
24 | #----------------------------------------------------------------------------
25 |
26 | def parse_range(s: Union[str, List]) -> List[int]:
27 | '''Parse a comma separated list of numbers or ranges and return a list of ints.
28 |
29 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
30 | '''
31 | if isinstance(s, list): return s
32 | ranges = []
33 | range_re = re.compile(r'^(\d+)-(\d+)$')
34 | for p in s.split(','):
35 | m = range_re.match(p)
36 | if m:
37 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
38 | else:
39 | ranges.append(int(p))
40 | return ranges
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
45 | '''Parse a floating point 2-vector of syntax 'a,b'.
46 |
47 | Example:
48 | '0,1' returns (0,1)
49 | '''
50 | if isinstance(s, tuple): return s
51 | parts = s.split(',')
52 | if len(parts) == 2:
53 | return (float(parts[0]), float(parts[1]))
54 | raise ValueError(f'cannot parse 2-vector {s}')
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | def make_transform(translate: Tuple[float,float], angle: float):
59 | m = np.eye(3)
60 | s = np.sin(angle/360.0*np.pi*2)
61 | c = np.cos(angle/360.0*np.pi*2)
62 | m[0][0] = c
63 | m[0][1] = s
64 | m[0][2] = translate[0]
65 | m[1][0] = -s
66 | m[1][1] = c
67 | m[1][2] = translate[1]
68 | return m
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | @click.command()
73 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
74 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
75 | @click.option('--batch-sz', type=int, help='Batch size per sample', default=1)
76 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
77 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation')
78 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
79 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
80 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
81 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
82 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
83 | def generate_images(
84 | network_pkl: str,
85 | seeds: List[int],
86 | batch_sz: int,
87 | truncation_psi: float,
88 | centroids_path: str,
89 | noise_mode: str,
90 | outdir: str,
91 | translate: Tuple[float,float],
92 | rotate: float,
93 | class_idx: Optional[int]
94 | ):
95 | print('Loading networks from "%s"...' % network_pkl)
96 | device = torch.device('cuda')
97 | with dnnlib.util.open_url(network_pkl) as f:
98 | G = legacy.load_network_pkl(f)['G_ema']
99 | G = G.eval().requires_grad_(False).to(device)
100 |
101 | os.makedirs(outdir, exist_ok=True)
102 |
103 | # Generate images.
104 | for seed_idx, seed in enumerate(seeds):
105 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
106 |
107 | # Construct an inverse rotation/translation matrix and pass to the generator. The
108 | # generator expects this matrix as an inverse to avoid potentially failing numerical
109 | # operations in the network.
110 | if hasattr(G.synthesis, 'input'):
111 | m = make_transform(translate, rotate)
112 | m = np.linalg.inv(m)
113 | G.synthesis.input.transform.copy_(torch.from_numpy(m))
114 |
115 | w = gen_utils.get_w_from_seed(G, batch_sz, device, truncation_psi, seed=seed,
116 | centroids_path=centroids_path, class_idx=class_idx)
117 | img = gen_utils.w_to_img(G, w, to_np=True)
118 | PIL.Image.fromarray(gen_utils.create_image_grid(img), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
119 |
120 |
121 | #----------------------------------------------------------------------------
122 |
123 | if __name__ == "__main__":
124 | generate_images() # pylint: disable=no-value-for-parameter
125 |
126 | #----------------------------------------------------------------------------
127 |
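128 | # Example invocation (hypothetical paths; see the click options above):
129 | #   python gen_images.py --outdir=out --trunc=0.7 --seeds=0-3 --class=1 \
130 | #       --network=<path/to/network.pkl>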
--------------------------------------------------------------------------------
/stylesan-xl/gen_video.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Generate lerp videos using pretrained network pickle."""
10 |
11 | import copy
12 | import os
13 | import re
14 | from typing import List, Optional, Tuple, Union
15 |
16 | import click
17 | import dnnlib
18 | import imageio
19 | import numpy as np
20 | import scipy.interpolate
21 | import torch
22 | from tqdm import tqdm
23 |
24 | import legacy
25 | from torch_utils import gen_utils
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
30 | batch_size, channels, img_h, img_w = img.shape
31 | if grid_w is None:
32 | grid_w = batch_size // grid_h
33 | assert batch_size == grid_w * grid_h
34 | if float_to_uint8:
35 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
36 | img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
37 | img = img.permute(2, 0, 3, 1, 4)
38 | img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
39 | if chw_to_hwc:
40 | img = img.permute(1, 2, 0)
41 | if to_numpy:
42 | img = img.cpu().numpy()
43 | return img
44 |
45 | #----------------------------------------------------------------------------
46 |
47 | def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, truncation_psi=1, device=torch.device('cuda'), centroids_path=None, class_idx=None, **video_kwargs):
48 | grid_w = grid_dims[0]
49 | grid_h = grid_dims[1]
50 |
51 | if num_keyframes is None:
52 | if len(seeds) % (grid_w*grid_h) != 0:
53 | raise ValueError('Number of input seeds must be divisible by grid W*H')
54 | num_keyframes = len(seeds) // (grid_w*grid_h)
55 |
56 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
57 | for idx in range(num_keyframes*grid_h*grid_w):
58 | all_seeds[idx] = seeds[idx % len(seeds)]
59 |
60 | if shuffle_seed is not None:
61 | rng = np.random.RandomState(seed=shuffle_seed)
62 | rng.shuffle(all_seeds)
63 |
64 | if class_idx is None:
65 | class_idx = [None] * len(seeds)
66 | elif len(class_idx) == 1:
67 | class_idx = [class_idx] * len(seeds)
68 | assert len(all_seeds) == len(class_idx), "Seeds and class-idx should have the same length"
69 |
70 | ws = []
71 | for seed, cls in zip(all_seeds, class_idx):
72 | ws.append(
73 | gen_utils.get_w_from_seed(G, 1, device, truncation_psi, seed=seed,
74 | centroids_path=centroids_path, class_idx=cls)
75 | )
76 | ws = torch.cat(ws)
77 |
78 | _ = G.synthesis(ws[:1]) # warm up
79 | ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
80 |
81 | # Interpolation.
82 | grid = []
83 | for yi in range(grid_h):
84 | row = []
85 | for xi in range(grid_w):
86 | x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
87 | y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
88 | interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
89 | row.append(interp)
90 | grid.append(row)
91 |
92 | # Render video.
93 | video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
94 | for frame_idx in tqdm(range(num_keyframes * w_frames)):
95 | imgs = []
96 | for yi in range(grid_h):
97 | for xi in range(grid_w):
98 | interp = grid[yi][xi]
99 | w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
100 | img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
101 | imgs.append(img)
102 | video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
103 | video_out.close()
104 |
105 | #----------------------------------------------------------------------------
106 |
107 | def parse_range(s: Union[str, List[int]]) -> List[int]:
108 | '''Parse a comma separated list of numbers or ranges and return a list of ints.
109 |
110 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
111 | '''
112 | if isinstance(s, list): return s
113 | ranges = []
114 | range_re = re.compile(r'^(\d+)-(\d+)$')
115 | for p in s.split(','):
116 | m = range_re.match(p)
117 | if m:
118 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
119 | else:
120 | ranges.append(int(p))
121 | return ranges
122 |
123 | #----------------------------------------------------------------------------
124 |
125 | def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
126 | '''Parse a 'M,N' or 'MxN' integer tuple.
127 |
128 | Example:
129 | '4x2' returns (4,2)
130 | '0,1' returns (0,1)
131 | '''
132 | if isinstance(s, tuple): return s
133 | m = re.match(r'^(\d+)[x,](\d+)$', s)
134 | if m:
135 | return (int(m.group(1)), int(m.group(2)))
136 | raise ValueError(f'cannot parse tuple {s}')
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | @click.command()
141 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
142 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
143 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
144 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
145 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
146 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
147 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
148 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation')
149 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
150 | @click.option('--class', 'class_idx', type=parse_range, help='Class label (unconditional if not specified)')
151 | def generate_images(
152 | network_pkl: str,
153 | seeds: List[int],
154 | shuffle_seed: Optional[int],
155 | truncation_psi: float,
156 | centroids_path: str,
157 | grid: Tuple[int,int],
158 | num_keyframes: Optional[int],
159 | w_frames: int,
160 | output: str,
161 | class_idx: Optional[List[int]],
162 | ):
163 | """Render a latent vector interpolation video.
164 |
165 | Examples:
166 |
167 | \b
168 | # Render a 4x2 grid of interpolations for seeds 0 through 31.
169 | python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
170 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
171 |
172 | Animation length and seed keyframes:
173 |
174 | The animation length is either determined based on the --seeds value or explicitly
175 | specified using the --num-keyframes option.
176 |
177 | When num keyframes is specified with --num-keyframes, the output video length
178 | will be 'num_keyframes*w_frames' frames.
179 |
180 | If --num-keyframes is not specified, the number of seeds given with
181 | --seeds must be divisible by grid size W*H (--grid). In this case the
182 | output video length will be '# seeds/(w*h)*w_frames' frames.
183 | """
184 |
185 | print('Loading networks from "%s"...' % network_pkl)
186 | device = torch.device('cuda')
187 | with dnnlib.util.open_url(network_pkl) as f:
188 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
189 |
190 | gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, truncation_psi=truncation_psi, centroids_path=centroids_path, class_idx=class_idx)
191 |
192 | #----------------------------------------------------------------------------
193 |
194 | if __name__ == "__main__":
195 | generate_images() # pylint: disable=no-value-for-parameter
196 |
197 | #----------------------------------------------------------------------------
198 |
--------------------------------------------------------------------------------
/stylesan-xl/gui_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/stylesan-xl/gui_utils/glfw_window.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import time
10 | import glfw
11 | import OpenGL.GL as gl
12 | from . import gl_utils
13 |
14 | #----------------------------------------------------------------------------
15 |
16 | class GlfwWindow: # pylint: disable=too-many-public-methods
17 | def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
18 | self._glfw_window = None
19 | self._drawing_frame = False
20 | self._frame_start_time = None
21 | self._frame_delta = 0
22 | self._fps_limit = None
23 | self._vsync = None
24 | self._skip_frames = 0
25 | self._deferred_show = deferred_show
26 | self._close_on_esc = close_on_esc
27 | self._esc_pressed = False
28 | self._drag_and_drop_paths = None
29 | self._capture_next_frame = False
30 | self._captured_frame = None
31 |
32 | # Create window.
33 | glfw.init()
34 | glfw.window_hint(glfw.VISIBLE, False)
35 | self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
36 | self._attach_glfw_callbacks()
37 | self.make_context_current()
38 |
39 | # Adjust window.
40 | self.set_vsync(False)
41 | self.set_window_size(window_width, window_height)
42 | if not self._deferred_show:
43 | glfw.show_window(self._glfw_window)
44 |
45 | def close(self):
46 | if self._drawing_frame:
47 | self.end_frame()
48 | if self._glfw_window is not None:
49 | glfw.destroy_window(self._glfw_window)
50 | self._glfw_window = None
51 | #glfw.terminate() # Commented out to play it nice with other glfw clients.
52 |
53 | def __del__(self):
54 | try:
55 | self.close()
56 | except:
57 | pass
58 |
59 | @property
60 | def window_width(self):
61 | return self.content_width
62 |
63 | @property
64 | def window_height(self):
65 | return self.content_height + self.title_bar_height
66 |
67 | @property
68 | def content_width(self):
69 | width, _height = glfw.get_window_size(self._glfw_window)
70 | return width
71 |
72 | @property
73 | def content_height(self):
74 | _width, height = glfw.get_window_size(self._glfw_window)
75 | return height
76 |
77 | @property
78 | def title_bar_height(self):
79 | _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
80 | return top
81 |
82 | @property
83 | def monitor_width(self):
84 | _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
85 | return width
86 |
87 | @property
88 | def monitor_height(self):
89 | _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
90 | return height
91 |
92 | @property
93 | def frame_delta(self):
94 | return self._frame_delta
95 |
96 | def set_title(self, title):
97 | glfw.set_window_title(self._glfw_window, title)
98 |
99 | def set_window_size(self, width, height):
100 | width = min(width, self.monitor_width)
101 | height = min(height, self.monitor_height)
102 | glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
103 | if width == self.monitor_width and height == self.monitor_height:
104 | self.maximize()
105 |
106 | def set_content_size(self, width, height):
107 | self.set_window_size(width, height + self.title_bar_height)
108 |
109 | def maximize(self):
110 | glfw.maximize_window(self._glfw_window)
111 |
112 | def set_position(self, x, y):
113 | glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
114 |
115 | def center(self):
116 | self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
117 |
118 | def set_vsync(self, vsync):
119 | vsync = bool(vsync)
120 | if vsync != self._vsync:
121 | glfw.swap_interval(1 if vsync else 0)
122 | self._vsync = vsync
123 |
124 | def set_fps_limit(self, fps_limit):
125 | self._fps_limit = int(fps_limit)
126 |
127 | def should_close(self):
128 | return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
129 |
130 | def skip_frame(self):
131 | self.skip_frames(1)
132 |
133 | def skip_frames(self, num): # Do not update window for the next N frames.
134 | self._skip_frames = max(self._skip_frames, int(num))
135 |
136 | def is_skipping_frames(self):
137 | return self._skip_frames > 0
138 |
139 | def capture_next_frame(self):
140 | self._capture_next_frame = True
141 |
142 | def pop_captured_frame(self):
143 | frame = self._captured_frame
144 | self._captured_frame = None
145 | return frame
146 |
147 | def pop_drag_and_drop_paths(self):
148 | paths = self._drag_and_drop_paths
149 | self._drag_and_drop_paths = None
150 | return paths
151 |
152 | def draw_frame(self): # To be overridden by subclass.
153 | self.begin_frame()
154 | # Rendering code goes here.
155 | self.end_frame()
156 |
157 | def make_context_current(self):
158 | if self._glfw_window is not None:
159 | glfw.make_context_current(self._glfw_window)
160 |
161 | def begin_frame(self):
162 | # End previous frame.
163 | if self._drawing_frame:
164 | self.end_frame()
165 |
166 | # Apply FPS limit.
167 | if self._frame_start_time is not None and self._fps_limit is not None:
168 | delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
169 | if delay > 0:
170 | time.sleep(delay)
171 | cur_time = time.perf_counter()
172 | if self._frame_start_time is not None:
173 | self._frame_delta = cur_time - self._frame_start_time
174 | self._frame_start_time = cur_time
175 |
176 | # Process events.
177 | glfw.poll_events()
178 |
179 | # Begin frame.
180 | self._drawing_frame = True
181 | self.make_context_current()
182 |
183 | # Initialize GL state.
184 | gl.glViewport(0, 0, self.content_width, self.content_height)
185 | gl.glMatrixMode(gl.GL_PROJECTION)
186 | gl.glLoadIdentity()
187 | gl.glTranslate(-1, 1, 0)
188 | gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
189 | gl.glMatrixMode(gl.GL_MODELVIEW)
190 | gl.glLoadIdentity()
191 | gl.glEnable(gl.GL_BLEND)
192 | gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
193 |
194 | # Clear.
195 | gl.glClearColor(0, 0, 0, 1)
196 | gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
197 |
198 | def end_frame(self):
199 | assert self._drawing_frame
200 | self._drawing_frame = False
201 |
202 | # Skip frames if requested.
203 | if self._skip_frames > 0:
204 | self._skip_frames -= 1
205 | return
206 |
207 | # Capture frame if requested.
208 | if self._capture_next_frame:
209 | self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
210 | self._capture_next_frame = False
211 |
212 | # Update window.
213 | if self._deferred_show:
214 | glfw.show_window(self._glfw_window)
215 | self._deferred_show = False
216 | glfw.swap_buffers(self._glfw_window)
217 |
218 | def _attach_glfw_callbacks(self):
219 | glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
220 | glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
221 |
222 | def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
223 | if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
224 | self._esc_pressed = True
225 |
226 | def _glfw_drop_callback(self, _window, paths):
227 | self._drag_and_drop_paths = paths
228 |
229 | #----------------------------------------------------------------------------
230 |
--------------------------------------------------------------------------------
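Usage sketch (not a repository file): GlfwWindow is meant to be subclassed, with all drawing wrapped between begin_frame() and end_frame() via draw_frame(). The minimal event loop below is illustrative only; it assumes glfw and PyOpenGL are installed, a display is available, and uses the constructor parameter names referenced above (title, window_width, window_height). The class name and title are made up.

    from gui_utils import glfw_window

    class DemoWindow(glfw_window.GlfwWindow):
        def draw_frame(self):
            self.begin_frame()   # poll events, set up the GL projection, clear the buffer
            # ... drawing via gui_utils.gl_utils would go here ...
            self.end_frame()     # handle frame skipping / capture, then swap buffers

    if __name__ == '__main__':
        win = DemoWindow(title='Demo', window_width=640, window_height=480)
        win.set_fps_limit(60)
        while not win.should_close():   # window close button (and optionally Esc)
            win.draw_frame()
        win.close()

--------------------------------------------------------------------------------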
/stylesan-xl/gui_utils/imgui_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import contextlib
10 | import imgui
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
15 | s = imgui.get_style()
16 | s.window_padding = [spacing, spacing]
17 | s.item_spacing = [spacing, spacing]
18 | s.item_inner_spacing = [spacing, spacing]
19 | s.columns_min_spacing = spacing
20 | s.indent_spacing = indent
21 | s.scrollbar_size = scrollbar
22 | s.frame_padding = [4, 3]
23 | s.window_border_size = 1
24 | s.child_border_size = 1
25 | s.popup_border_size = 1
26 | s.frame_border_size = 1
27 | s.window_rounding = 0
28 | s.child_rounding = 0
29 | s.popup_rounding = 3
30 | s.frame_rounding = 3
31 | s.scrollbar_rounding = 3
32 | s.grab_rounding = 3
33 |
34 | getattr(imgui, f'style_colors_{color_scheme}')(s)
35 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
36 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
37 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
38 |
39 | #----------------------------------------------------------------------------
40 |
41 | @contextlib.contextmanager
42 | def grayed_out(cond=True):
43 | if cond:
44 | s = imgui.get_style()
45 | text = s.colors[imgui.COLOR_TEXT_DISABLED]
46 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
47 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
48 | imgui.push_style_color(imgui.COLOR_TEXT, *text)
49 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
50 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
51 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
52 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
53 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
55 | imgui.push_style_color(imgui.COLOR_BUTTON, *back)
56 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
57 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
58 | imgui.push_style_color(imgui.COLOR_HEADER, *back)
59 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
60 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
61 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
62 | yield
63 | imgui.pop_style_color(14)
64 | else:
65 | yield
66 |
67 | #----------------------------------------------------------------------------
68 |
69 | @contextlib.contextmanager
70 | def item_width(width=None):
71 | if width is not None:
72 | imgui.push_item_width(width)
73 | yield
74 | imgui.pop_item_width()
75 | else:
76 | yield
77 |
78 | #----------------------------------------------------------------------------
79 |
80 | def scoped_by_object_id(method):
81 | def decorator(self, *args, **kwargs):
82 | imgui.push_id(str(id(self)))
83 | res = method(self, *args, **kwargs)
84 | imgui.pop_id()
85 | return res
86 | return decorator
87 |
88 | #----------------------------------------------------------------------------
89 |
90 | def button(label, width=0, enabled=True):
91 | with grayed_out(not enabled):
92 | clicked = imgui.button(label, width=width)
93 | clicked = clicked and enabled
94 | return clicked
95 |
96 | #----------------------------------------------------------------------------
97 |
98 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
99 | expanded = False
100 | if show:
101 | if default:
102 | flags |= imgui.TREE_NODE_DEFAULT_OPEN
103 | if not enabled:
104 | flags |= imgui.TREE_NODE_LEAF
105 | with grayed_out(not enabled):
106 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
107 | expanded = expanded and enabled
108 | return expanded, visible
109 |
110 | #----------------------------------------------------------------------------
111 |
112 | def popup_button(label, width=0, enabled=True):
113 | if button(label, width, enabled):
114 | imgui.open_popup(label)
115 | opened = imgui.begin_popup(label)
116 | return opened
117 |
118 | #----------------------------------------------------------------------------
119 |
120 | def input_text(label, value, buffer_length, flags, width=None, help_text=''):
121 | old_value = value
122 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
123 | if value == '':
124 | color[-1] *= 0.5
125 | with item_width(width):
126 | imgui.push_style_color(imgui.COLOR_TEXT, *color)
127 | value = value if value != '' else help_text
128 | changed, value = imgui.input_text(label, value, buffer_length, flags)
129 | value = value if value != help_text else ''
130 | imgui.pop_style_color(1)
131 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
132 | changed = (value != old_value)
133 | return changed, value
134 |
135 | #----------------------------------------------------------------------------
136 |
137 | def drag_previous_control(enabled=True):
138 | dragging = False
139 | dx = 0
140 | dy = 0
141 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
142 | if enabled:
143 | dragging = True
144 | dx, dy = imgui.get_mouse_drag_delta()
145 | imgui.reset_mouse_drag_delta()
146 | imgui.end_drag_drop_source()
147 | return dragging, dx, dy
148 |
149 | #----------------------------------------------------------------------------
150 |
151 | def drag_button(label, width=0, enabled=True):
152 | clicked = button(label, width=width, enabled=enabled)
153 | dragging, dx, dy = drag_previous_control(enabled=enabled)
154 | return clicked, dragging, dx, dy
155 |
156 | #----------------------------------------------------------------------------
157 |
158 | def drag_hidden_window(label, x, y, width, height, enabled=True):
159 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
160 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
161 | imgui.set_next_window_position(x, y)
162 | imgui.set_next_window_size(width, height)
163 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
164 | dragging, dx, dy = drag_previous_control(enabled=enabled)
165 | imgui.end()
166 | imgui.pop_style_color(2)
167 | return dragging, dx, dy
168 |
169 | #----------------------------------------------------------------------------
170 |
--------------------------------------------------------------------------------
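Usage sketch (not a repository file): the helpers above keep per-widget style handling out of the GUI code. grayed_out() and item_width() are context managers that push/pop style state, scoped_by_object_id makes widget IDs unique per Python object, and button()/input_text() respect an enabled flag and placeholder text. The widget class below is hypothetical and assumes it is called inside an active ImGui frame (see imgui_window.py next).

    import imgui
    from gui_utils import imgui_utils

    class NameWidget:
        def __init__(self):
            self.name = ''
            self.enabled = True

        @imgui_utils.scoped_by_object_id   # push_id(str(id(self))) / pop_id around the call
        def __call__(self):
            with imgui_utils.item_width(200), imgui_utils.grayed_out(not self.enabled):
                _changed, self.name = imgui_utils.input_text(
                    '##name', self.name, buffer_length=256, flags=0, help_text='Type a name...')
            if imgui_utils.button('Clear', width=80, enabled=self.enabled):
                self.name = ''

--------------------------------------------------------------------------------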
/stylesan-xl/gui_utils/imgui_window.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import imgui
11 | import imgui.integrations.glfw
12 |
13 | from . import glfw_window
14 | from . import imgui_utils
15 | from . import text_utils
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | class ImguiWindow(glfw_window.GlfwWindow):
20 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
21 | if font is None:
22 | font = text_utils.get_default_font()
23 | font_sizes = {int(size) for size in font_sizes}
24 | super().__init__(title=title, **glfw_kwargs)
25 |
26 | # Init fields.
27 | self._imgui_context = None
28 | self._imgui_renderer = None
29 | self._imgui_fonts = None
30 | self._cur_font_size = max(font_sizes)
31 |
32 | # Delete leftover imgui.ini to avoid unexpected behavior.
33 | if os.path.isfile('imgui.ini'):
34 | os.remove('imgui.ini')
35 |
36 | # Init ImGui.
37 | self._imgui_context = imgui.create_context()
38 | self._imgui_renderer = _GlfwRenderer(self._glfw_window)
39 | self._attach_glfw_callbacks()
40 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior of the drag helpers in imgui_utils (e.g. drag_button()).
42 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
43 | self._imgui_renderer.refresh_font_texture()
44 |
45 | def close(self):
46 | self.make_context_current()
47 | self._imgui_fonts = None
48 | if self._imgui_renderer is not None:
49 | self._imgui_renderer.shutdown()
50 | self._imgui_renderer = None
51 | if self._imgui_context is not None:
52 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
53 | self._imgui_context = None
54 | super().close()
55 |
56 | def _glfw_key_callback(self, *args):
57 | super()._glfw_key_callback(*args)
58 | self._imgui_renderer.keyboard_callback(*args)
59 |
60 | @property
61 | def font_size(self):
62 | return self._cur_font_size
63 |
64 | @property
65 | def spacing(self):
66 | return round(self._cur_font_size * 0.4)
67 |
68 | def set_font_size(self, target): # Applied on next frame.
69 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
70 |
71 | def begin_frame(self):
72 | # Begin glfw frame.
73 | super().begin_frame()
74 |
75 | # Process imgui events.
76 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
77 | if self.content_width > 0 and self.content_height > 0:
78 | self._imgui_renderer.process_inputs()
79 |
80 | # Begin imgui frame.
81 | imgui.new_frame()
82 | imgui.push_font(self._imgui_fonts[self._cur_font_size])
83 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
84 |
85 | def end_frame(self):
86 | imgui.pop_font()
87 | imgui.render()
88 | imgui.end_frame()
89 | self._imgui_renderer.render(imgui.get_draw_data())
90 | super().end_frame()
91 |
92 | #----------------------------------------------------------------------------
93 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
94 |
95 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
96 | def __init__(self, *args, **kwargs):
97 | super().__init__(*args, **kwargs)
98 | self.mouse_wheel_multiplier = 1
99 |
100 | def scroll_callback(self, window, x_offset, y_offset):
101 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
102 |
103 | #----------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
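Usage sketch (not a repository file): ImguiWindow adds the ImGui context, font handling and renderer on top of GlfwWindow, so a subclass only needs to emit widgets between begin_frame() and end_frame(). The minimal viewer below is illustrative and assumes glfw, PyOpenGL and pyimgui are installed with a display available; the class name and window title are made up.

    import imgui
    from gui_utils import imgui_window

    class DemoViewer(imgui_window.ImguiWindow):
        def draw_frame(self):
            self.begin_frame()              # glfw frame + imgui.new_frame() + default style
            imgui.set_next_window_position(0, 0)
            imgui.set_next_window_size(self.content_width, self.content_height)
            imgui.begin('##main', flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE))
            imgui.text('Frame delta: %.1f ms' % (self.frame_delta * 1e3))
            imgui.end()
            self.end_frame()                # imgui.render() + buffer swap

    if __name__ == '__main__':
        win = DemoViewer(title='ImGui demo', window_width=800, window_height=600)
        while not win.should_close():
            win.draw_frame()
        win.close()

--------------------------------------------------------------------------------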
/stylesan-xl/gui_utils/text_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import functools
10 | from typing import Optional
11 |
12 | import dnnlib
13 | import numpy as np
14 | import PIL.Image
15 | import PIL.ImageFont
16 | import scipy.ndimage
17 |
18 | from . import gl_utils
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | def get_default_font():
23 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
24 | return dnnlib.util.open_url(url, return_filename=True)
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | @functools.lru_cache(maxsize=None)
29 | def get_pil_font(font=None, size=32):
30 | if font is None:
31 | font = get_default_font()
32 | return PIL.ImageFont.truetype(font=font, size=size)
33 |
34 | #----------------------------------------------------------------------------
35 |
36 | def get_array(string, *, dropshadow_radius: int=None, **kwargs):
37 | if dropshadow_radius is not None:
38 | offset_x = int(np.ceil(dropshadow_radius*2/3))
39 | offset_y = int(np.ceil(dropshadow_radius*2/3))
40 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
41 | else:
42 | return _get_array_priv(string, **kwargs)
43 |
44 | @functools.lru_cache(maxsize=10000)
45 | def _get_array_priv(
46 | string: str, *,
47 | size: int = 32,
48 | max_width: Optional[int]=None,
49 | max_height: Optional[int]=None,
50 | min_size=10,
51 | shrink_coef=0.8,
52 | dropshadow_radius: int=None,
53 | offset_x: int=None,
54 | offset_y: int=None,
55 | **kwargs
56 | ):
57 | cur_size = size
58 | array = None
59 | while True:
60 | if dropshadow_radius is not None:
61 | # separate implementation for dropshadow text rendering
62 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
63 | else:
64 | array = _get_array_impl(string, size=cur_size, **kwargs)
65 | height, width, _ = array.shape
66 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
67 | break
68 | cur_size = max(int(cur_size * shrink_coef), min_size)
69 | return array
70 |
71 | #----------------------------------------------------------------------------
72 |
73 | @functools.lru_cache(maxsize=10000)
74 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
75 | pil_font = get_pil_font(font=font, size=size)
76 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
77 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
78 | width = max(line.shape[1] for line in lines)
79 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
80 | line_spacing = line_pad if line_pad is not None else size // 2
81 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
82 | mask = np.concatenate(lines, axis=0)
83 | alpha = mask
84 | if outline > 0:
85 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
86 | alpha = mask.astype(np.float32) / 255
87 | alpha = scipy.ndimage.gaussian_filter(alpha, outline)
88 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
89 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
90 | alpha = np.maximum(alpha, mask)
91 | return np.stack([mask, alpha], axis=-1)
92 |
93 | #----------------------------------------------------------------------------
94 |
95 | @functools.lru_cache(maxsize=10000)
96 | def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
97 | assert (offset_x > 0) and (offset_y > 0)
98 | pil_font = get_pil_font(font=font, size=size)
99 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
100 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
101 | width = max(line.shape[1] for line in lines)
102 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
103 | line_spacing = line_pad if line_pad is not None else size // 2
104 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
105 | mask = np.concatenate(lines, axis=0)
106 | alpha = mask
107 |
108 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
109 | alpha = mask.astype(np.float32) / 255
110 | alpha = scipy.ndimage.gaussian_filter(alpha, radius)
111 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
112 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
113 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
114 | alpha = np.maximum(alpha, mask)
115 | return np.stack([mask, alpha], axis=-1)
116 |
117 | #----------------------------------------------------------------------------
118 |
119 | @functools.lru_cache(maxsize=10000)
120 | def get_texture(string, bilinear=True, mipmap=True, **kwargs):
121 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
122 |
123 | #----------------------------------------------------------------------------
124 |
--------------------------------------------------------------------------------
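Usage sketch (not a repository file): get_array() rasterizes a (possibly multi-line) string into an (H, W, 2) uint8 array of [intensity, alpha], with an optional Gaussian outline or drop shadow, and get_texture() uploads the result as a gl_utils.Texture. Fetching the default Open Sans font requires network access, and get_texture() additionally needs an active OpenGL context, so treat the snippet as illustrative.

    from gui_utils import text_utils

    label = text_utils.get_array('Hello\nSAN', size=48, dropshadow_radius=3)
    print(label.shape, label.dtype)   # (H, W, 2), uint8: [grayscale mask, alpha]

--------------------------------------------------------------------------------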
/stylesan-xl/in_embeddings/tf_efficientnet_lite0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/stylesan-xl/in_embeddings/tf_efficientnet_lite0.pkl
--------------------------------------------------------------------------------
/stylesan-xl/incl_licenses/LICENSE_1:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/stylesan-xl/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/stylesan-xl/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Frechet Inception Distance (FID) from the paper
10 | "GANs trained by a two time-scale update rule converge to a local Nash
11 | equilibrium". Matches the original implementation by Heusel et al. at
12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13 |
14 | import numpy as np
15 | import scipy.linalg
16 | from . import metric_utils
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_fid(opts, max_real, num_gen, sfid=False, rfid=False):
21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24 | if rfid:
25 | detector_url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/feature_networks/inception_rand_full.pkl'
26 | detector_kwargs = {} # random inception network returns features by default
27 |
28 |
29 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, sfid=sfid).get_mean_cov()
32 |
33 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
34 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
35 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, sfid=sfid).get_mean_cov()
36 |
37 | if opts.rank != 0:
38 | return float('nan')
39 |
40 | m = np.square(mu_gen - mu_real).sum()
41 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
42 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
43 | return float(fid)
44 |
45 | #----------------------------------------------------------------------------
46 |
--------------------------------------------------------------------------------
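For reference, the closed form evaluated in lines 40-42 above is the Frechet distance d^2 = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_g Sigma_r)^{1/2}) between Gaussians fitted to the feature statistics. The standalone sketch below (illustrative only, no feature network; the helper name frechet_distance is made up here) applies the same arithmetic to toy Gaussians.

    import numpy as np
    import scipy.linalg

    def frechet_distance(mu_real, sigma_real, mu_gen, sigma_gen):
        # Same computation as the tail of compute_fid(), on explicit Gaussian statistics.
        m = np.square(mu_gen - mu_real).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
        return float(np.real(m + np.trace(sigma_gen + sigma_real - s * 2)))

    rng = np.random.default_rng(0)
    x = rng.normal(size=(10000, 8))           # stand-in for "real" features
    y = rng.normal(size=(10000, 8)) + 0.5     # "generated" features with a shifted mean
    fid = frechet_distance(x.mean(0), np.cov(x, rowvar=False),
                           y.mean(0), np.cov(y, rowvar=False))
    print(fid)   # close to 8 * 0.5**2 = 2.0, since the covariances are nearly identical

--------------------------------------------------------------------------------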
/stylesan-xl/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Inception Score (IS) from the paper "Improved techniques for training
10 | GANs". Matches the original implementation by Salimans et al. at
11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12 |
13 | import numpy as np
14 | from . import metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_is(opts, num_gen, num_splits):
19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22 |
23 | gen_probs = metric_utils.compute_feature_stats_for_generator(
24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 | capture_all=True, max_items=num_gen).get_all()
26 |
27 | if opts.rank != 0:
28 | return float('nan'), float('nan')
29 |
30 | scores = []
31 | for i in range(num_splits):
32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34 | kl = np.mean(np.sum(kl, axis=1))
35 | scores.append(np.exp(kl))
36 | return float(np.mean(scores)), float(np.std(scores))
37 |
38 | #----------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
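The score above is the exponential of the mean KL divergence between per-sample class distributions and their marginal, averaged over splits. A self-contained sketch (illustrative only; the function name is made up) applying the same split-KL formula to random softmax outputs:

    import numpy as np

    def inception_score(probs, num_splits=10):
        # probs: [N, num_classes] softmax outputs; mirrors the loop in compute_is().
        num_gen = probs.shape[0]
        scores = []
        for i in range(num_splits):
            part = probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
            kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
            scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
        return float(np.mean(scores)), float(np.std(scores))

    rng = np.random.default_rng(0)
    logits = rng.normal(size=(5000, 10))
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    print(inception_score(probs))   # well below the 10-class upper bound for uninformative predictions

--------------------------------------------------------------------------------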
/stylesan-xl/metrics/kernel_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10 | GANs". Matches the original implementation by Binkowski et al. at
11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12 |
13 | import numpy as np
14 | from . import metric_utils
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22 |
23 | real_features = metric_utils.compute_feature_stats_for_dataset(
24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26 |
27 | gen_features = metric_utils.compute_feature_stats_for_generator(
28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30 |
31 | if opts.rank != 0:
32 | return float('nan')
33 |
34 | n = real_features.shape[1]
35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36 | t = 0
37 | for _subset_idx in range(num_subsets):
38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41 | b = (x @ y.T / n + 1) ** 3
42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43 | kid = t / num_subsets / m
44 | return float(kid)
45 |
46 | #----------------------------------------------------------------------------
47 |
--------------------------------------------------------------------------------
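compute_kid() estimates the squared maximum mean discrepancy with the polynomial kernel k(x, y) = (x.y / n + 1)^3, averaged over random subsets, using the unbiased estimator in lines 40-42. The toy sketch below (illustrative only; the helper name is made up) runs the same estimator on synthetic features:

    import numpy as np

    def kid(real, gen, num_subsets=100, max_subset_size=1000, seed=0):
        # Unbiased polynomial-kernel MMD^2, mirroring the subset loop in compute_kid().
        rng = np.random.default_rng(seed)
        n = real.shape[1]
        m = min(real.shape[0], gen.shape[0], max_subset_size)
        t = 0.0
        for _ in range(num_subsets):
            x = gen[rng.choice(gen.shape[0], m, replace=False)]
            y = real[rng.choice(real.shape[0], m, replace=False)]
            a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
            b = (x @ y.T / n + 1) ** 3
            t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
        return t / num_subsets / m

    rng = np.random.default_rng(0)
    real = rng.normal(size=(2000, 16))
    print(kid(real, real.copy()))    # ~0 (possibly slightly negative) for identical sets
    print(kid(real, real + 0.3))     # clearly positive once the generated features drift

--------------------------------------------------------------------------------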
/stylesan-xl/metrics/metric_main.py:
--------------------------------------------------------------------------------
1 | # distribution of this software and related documentation without an express
2 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
3 |
4 | """Main API for computing and reporting quality metrics."""
5 |
6 | import os
7 | import time
8 | import json
9 | import torch
10 | import dnnlib
11 |
12 | from . import metric_utils
13 | from . import frechet_inception_distance
14 | from . import kernel_inception_distance
15 | from . import precision_recall
16 | from . import perceptual_path_length
17 | from . import inception_score
18 | from . import equivariance
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | _metric_dict = dict() # name => fn
23 |
24 | def register_metric(fn):
25 | assert callable(fn)
26 | _metric_dict[fn.__name__] = fn
27 | return fn
28 |
29 | def is_valid_metric(metric):
30 | return metric in _metric_dict
31 |
32 | def list_valid_metrics():
33 | return list(_metric_dict.keys())
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
38 | assert is_valid_metric(metric)
39 | opts = metric_utils.MetricOptions(**kwargs)
40 |
41 | # Calculate.
42 | start_time = time.time()
43 | results = _metric_dict[metric](opts)
44 | total_time = time.time() - start_time
45 |
46 | # Broadcast results.
47 | for key, value in list(results.items()):
48 | if opts.num_gpus > 1:
49 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
50 | torch.distributed.broadcast(tensor=value, src=0)
51 | value = float(value.cpu())
52 | results[key] = value
53 |
54 | # Decorate with metadata.
55 | return dnnlib.EasyDict(
56 | results = dnnlib.EasyDict(results),
57 | metric = metric,
58 | total_time = total_time,
59 | total_time_str = dnnlib.util.format_time(total_time),
60 | num_gpus = opts.num_gpus,
61 | )
62 |
63 | #----------------------------------------------------------------------------
64 |
65 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
66 | metric = result_dict['metric']
67 | assert is_valid_metric(metric)
68 | if run_dir is not None and snapshot_pkl is not None:
69 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
70 |
71 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
72 | print(jsonl_line)
73 | if run_dir is not None and os.path.isdir(run_dir):
74 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
75 | f.write(jsonl_line + '\n')
76 |
77 | #----------------------------------------------------------------------------
78 | # Recommended metrics.
79 |
80 | @register_metric
81 | def fid50k_full(opts):
82 | opts.dataset_kwargs.update(max_size=None, xflip=False)
83 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
84 | return dict(fid50k_full=fid)
85 |
86 | @register_metric
87 | def fid10k_full(opts):
88 | opts.dataset_kwargs.update(max_size=None, xflip=False)
89 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000)
90 | return dict(fid10k_full=fid)
91 |
92 | @register_metric
93 | def kid50k_full(opts):
94 | opts.dataset_kwargs.update(max_size=None, xflip=False)
95 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
96 | return dict(kid50k_full=kid)
97 |
98 | @register_metric
99 | def pr50k3_full(opts):
100 | opts.dataset_kwargs.update(max_size=None, xflip=False)
101 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
102 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
103 |
104 | @register_metric
105 | def ppl2_wend(opts):
106 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
107 | return dict(ppl2_wend=ppl)
108 |
109 | @register_metric
110 | def eqt50k_int(opts):
111 | opts.G_kwargs.update(force_fp32=True)
112 | del opts.G_kwargs.truncation_psi
113 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
114 | return dict(eqt50k_int=psnr)
115 |
116 | @register_metric
117 | def eqt50k_frac(opts):
118 | opts.G_kwargs.update(force_fp32=True)
119 | del opts.G_kwargs.truncation_psi
120 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
121 | return dict(eqt50k_frac=psnr)
122 |
123 | @register_metric
124 | def eqr50k(opts):
125 | opts.G_kwargs.update(force_fp32=True)
126 | del opts.G_kwargs.truncation_psi
127 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
128 | return dict(eqr50k=psnr)
129 |
130 | #----------------------------------------------------------------------------
131 | # New Metrics
132 |
133 | def clipfid50k_full(opts):
134 | opts.dataset_kwargs.update(max_size=None, xflip=False)
135 | opts.feature_network = 'resnet50_clip'
136 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
137 | return dict(clipfid50k_full=fid)
138 |
139 | @register_metric
140 | def sfid50k_full(opts):
141 | opts.dataset_kwargs.update(max_size=None, xflip=False)
142 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000, sfid=True)
143 | return dict(sfid50k_full=fid)
144 |
145 | @register_metric
146 | def rfid50k_full(opts):
147 | opts.dataset_kwargs.update(max_size=None, xflip=False)
148 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000, rfid=True)
149 | return dict(rfid50k_full=fid)
150 |
151 | #----------------------------------------------------------------------------
152 | # Legacy metrics.
153 |
154 | @register_metric
155 | def fid50k(opts):
156 | opts.dataset_kwargs.update(max_size=None)
157 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
158 | return dict(fid50k=fid)
159 |
160 | @register_metric
161 | def kid50k(opts):
162 | opts.dataset_kwargs.update(max_size=None)
163 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
164 | return dict(kid50k=kid)
165 |
166 | @register_metric
167 | def pr50k3(opts):
168 | opts.dataset_kwargs.update(max_size=None)
169 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
170 | return dict(pr50k3_precision=precision, pr50k3_recall=recall)
171 |
172 | @register_metric
173 | def pr10k3(opts):
174 | opts.dataset_kwargs.update(max_size=None)
175 | precision, recall = precision_recall.compute_pr(opts, max_real=10000, num_gen=10000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
176 | return dict(pr10k3_precision=precision, pr10k3_recall=recall)
177 |
178 | @register_metric
179 | def is50k(opts):
180 | opts.dataset_kwargs.update(max_size=None, xflip=False)
181 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
182 | return dict(is50k_mean=mean, is50k_std=std)
183 |
--------------------------------------------------------------------------------
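New metrics plug in by decorating a function with @register_metric; the function name becomes the key understood by calc_metric() and, through it, by calc_metrics.py and the training loop. A hedged sketch of such an extension (the metric name fid2k_debug is invented here for illustration):

    from metrics import metric_main, frechet_inception_distance

    @metric_main.register_metric
    def fid2k_debug(opts):
        # Small, fast FID variant for smoke tests.
        opts.dataset_kwargs.update(max_size=None, xflip=False)
        fid = frechet_inception_distance.compute_fid(opts, max_real=2000, num_gen=2000)
        return dict(fid2k_debug=fid)

    assert metric_main.is_valid_metric('fid2k_debug')
    print(metric_main.list_valid_metrics())

--------------------------------------------------------------------------------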
/stylesan-xl/metrics/perceptual_path_length.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10 | Architecture for Generative Adversarial Networks". Matches the original
11 | implementation by Karras et al. at
12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13 |
14 | import copy
15 | import numpy as np
16 | import torch
17 | from . import metric_utils
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | # Spherical interpolation of a batch of vectors.
22 | def slerp(a, b, t):
23 | a = a / a.norm(dim=-1, keepdim=True)
24 | b = b / b.norm(dim=-1, keepdim=True)
25 | d = (a * b).sum(dim=-1, keepdim=True)
26 | p = t * torch.acos(d)
27 | c = b - d * a
28 | c = c / c.norm(dim=-1, keepdim=True)
29 | d = a * torch.cos(p) + c * torch.sin(p)
30 | d = d / d.norm(dim=-1, keepdim=True)
31 | return d
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | class PPLSampler(torch.nn.Module):
36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
37 | assert space in ['z', 'w']
38 | assert sampling in ['full', 'end']
39 | super().__init__()
40 | self.G = copy.deepcopy(G)
41 | self.G_kwargs = G_kwargs
42 | self.epsilon = epsilon
43 | self.space = space
44 | self.sampling = sampling
45 | self.crop = crop
46 | self.vgg16 = copy.deepcopy(vgg16)
47 |
48 | def forward(self, c):
49 | # Generate random latents and interpolation t-values.
50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
52 |
53 | # Interpolate in W or Z.
54 | if self.space == 'w':
55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
58 | else: # space == 'z'
59 | zt0 = slerp(z0, z1, t.unsqueeze(1))
60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
62 |
63 | # Randomize noise buffers.
64 | for name, buf in self.G.named_buffers():
65 | if name.endswith('.noise_const'):
66 | buf.copy_(torch.randn_like(buf))
67 |
68 | # Generate images.
69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
70 |
71 | # Center crop.
72 | if self.crop:
73 | assert img.shape[2] == img.shape[3]
74 | c = img.shape[2] // 8
75 | img = img[:, :, c*3 : c*7, c*2 : c*6]
76 |
77 | # Downsample to 256x256.
78 | factor = self.G.img_resolution // 256
79 | if factor > 1:
80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
81 |
82 | # Scale dynamic range from [-1,1] to [0,255].
83 | img = (img + 1) * (255 / 2)
84 | if self.G.img_channels == 1:
85 | img = img.repeat([1, 3, 1, 1])
86 |
87 | # Evaluate differential LPIPS.
88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
90 | return dist
91 |
92 | #----------------------------------------------------------------------------
93 |
94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
97 |
98 | # Setup sampler and labels.
99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
100 | sampler.eval().requires_grad_(False).to(opts.device)
101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
102 |
103 | # Sampling loop.
104 | dist = []
105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
107 | progress.update(batch_start)
108 | x = sampler(next(c_iter))
109 | for src in range(opts.num_gpus):
110 | y = x.clone()
111 | if opts.num_gpus > 1:
112 | torch.distributed.broadcast(y, src=src)
113 | dist.append(y)
114 | progress.update(num_samples)
115 |
116 | # Compute PPL.
117 | if opts.rank != 0:
118 | return float('nan')
119 | dist = torch.cat(dist)[:num_samples].cpu().numpy()
120 | lo = np.percentile(dist, 1, interpolation='lower')
121 | hi = np.percentile(dist, 99, interpolation='higher')
122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
123 | return float(ppl)
124 |
125 | #----------------------------------------------------------------------------
126 |
--------------------------------------------------------------------------------
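slerp() interpolates along the great circle between two (normalized) latents, so the epsilon-perturbed pair stays on the unit sphere and the Z-space step remains comparable in size to the W-space lerp. A quick standalone check (illustrative only):

    import torch

    def slerp(a, b, t):
        # Same spherical interpolation as in PPLSampler above.
        a = a / a.norm(dim=-1, keepdim=True)
        b = b / b.norm(dim=-1, keepdim=True)
        d = (a * b).sum(dim=-1, keepdim=True)
        p = t * torch.acos(d)
        c = b - d * a
        c = c / c.norm(dim=-1, keepdim=True)
        d = a * torch.cos(p) + c * torch.sin(p)
        return d / d.norm(dim=-1, keepdim=True)

    a, b = torch.randn(4, 512), torch.randn(4, 512)
    z = slerp(a, b, torch.rand(4, 1))
    print(z.norm(dim=-1))   # all ones: the interpolant stays on the unit sphere

--------------------------------------------------------------------------------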
/stylesan-xl/metrics/precision_recall.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall
10 | Metric for Assessing Generative Models". Matches the original implementation
11 | by Kynkaanniemi et al. at
12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13 |
14 | import torch
15 | from . import metric_utils
16 | from tqdm import tqdm
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
21 | assert 0 <= rank < num_gpus
22 | num_cols = col_features.shape[0]
23 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
24 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
25 | dist_batches = []
26 | for col_batch in col_batches[rank :: num_gpus]:
27 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
28 | for src in range(num_gpus):
29 | dist_broadcast = dist_batch.clone()
30 | if num_gpus > 1:
31 | torch.distributed.broadcast(dist_broadcast, src=src)
32 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
33 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
38 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
39 | detector_kwargs = dict(return_features=True)
40 | max_real = max_real // opts.num_gpus
41 |
42 | real_features = metric_utils.compute_feature_stats_for_dataset(
43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=None, shuffle_size=max_real).get_all_torch().to(torch.float16).to(opts.device)
45 |
46 | gen_features = metric_utils.compute_feature_stats_for_generator(
47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
49 |
50 | results = dict()
51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
52 | kth = []
53 | for manifold_batch in tqdm(manifold.split(row_batch_size)):
54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
56 | kth = torch.cat(kth) if opts.rank == 0 else None
57 | pred = []
58 | for probes_batch in tqdm(probes.split(row_batch_size)):
59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
62 | return results['precision'], results['recall']
63 |
64 | #----------------------------------------------------------------------------
65 |
--------------------------------------------------------------------------------
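Precision measures how many generated samples fall inside the k-NN manifold of the real features, and recall the reverse. The toy sketch below (illustrative only; single process, plain torch.cdist, no distributed broadcast; the helper name is made up) reproduces the membership test used above:

    import torch

    def manifold_membership(manifold, probes, nhood_size=3):
        # Fraction of `probes` within the nhood_size-NN radius of some `manifold` point.
        dist_mm = torch.cdist(manifold, manifold)
        kth = dist_mm.kthvalue(nhood_size + 1, dim=1).values        # per-point k-NN radius
        dist_pm = torch.cdist(probes, manifold)
        return (dist_pm <= kth.unsqueeze(0)).any(dim=1).float().mean().item()

    real = torch.randn(2000, 16)
    gen = torch.randn(2000, 16) * 0.5          # a "generator" that under-covers the real spread
    print('precision', manifold_membership(real, gen))   # typically close to 1
    print('recall   ', manifold_membership(gen, real))   # typically lower: outer modes are missed

--------------------------------------------------------------------------------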
/stylesan-xl/pg_modules/projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from feature_networks.vit import forward_vit
5 | from feature_networks.pretrained_builder import _make_pretrained
6 | from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS
7 | from pg_modules.blocks import FeatureFusionBlock
8 |
9 | def get_backbone_normstats(backbone):
10 | if backbone in NORMALIZED_INCEPTION:
11 | return {
12 | 'mean': [0.5, 0.5, 0.5],
13 | 'std': [0.5, 0.5, 0.5],
14 | }
15 |
16 | elif backbone in NORMALIZED_IMAGENET:
17 | return {
18 | 'mean': [0.485, 0.456, 0.406],
19 | 'std': [0.229, 0.224, 0.225],
20 | }
21 |
22 | elif backbone in NORMALIZED_CLIP:
23 | return {
24 | 'mean': [0.48145466, 0.4578275, 0.40821073],
25 | 'std': [0.26862954, 0.26130258, 0.27577711],
26 | }
27 |
28 | else:
29 | raise NotImplementedError
30 |
31 | def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
32 | # shapes
33 | out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
34 |
35 | scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
36 | scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
37 | scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
38 | scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
39 |
40 | scratch.CHANNELS = out_channels
41 |
42 | return scratch
43 |
44 | def _make_scratch_csm(scratch, in_channels, cout, expand):
45 | scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
46 | scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
47 | scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
48 | scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))
49 |
50 | # last refinenet does not expand to save channels in higher dimensions
51 | scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
52 |
53 | return scratch
54 |
55 | def _make_projector(im_res, backbone, cout, proj_type, expand=False):
56 | assert proj_type in [0, 1, 2], "Invalid projection type"
57 |
58 | ### Build pretrained feature network
59 | pretrained = _make_pretrained(backbone)
60 |
61 | # Following Projected GAN
62 | im_res = 256
63 | pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]
64 |
65 | if proj_type == 0: return pretrained, None
66 |
67 | ### Build CCM
68 | scratch = nn.Module()
69 | scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
70 |
71 | pretrained.CHANNELS = scratch.CHANNELS
72 |
73 | if proj_type == 1: return pretrained, scratch
74 |
75 | ### build CSM
76 | scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)
77 |
78 | # CSM upsamples x2 so the feature map resolution doubles
79 | pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
80 | pretrained.CHANNELS = scratch.CHANNELS
81 |
82 | return pretrained, scratch
83 |
84 | class F_Identity(nn.Module):
85 | def forward(self, x):
86 | return x
87 |
88 | class F_RandomProj(nn.Module):
89 | def __init__(
90 | self,
91 | backbone="tf_efficientnet_lite3",
92 | im_res=256,
93 | cout=64,
94 | expand=True,
95 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
96 | **kwargs,
97 | ):
98 | super().__init__()
99 | self.proj_type = proj_type
100 | self.backbone = backbone
101 | self.cout = cout
102 | self.expand = expand
103 | self.normstats = get_backbone_normstats(backbone)
104 |
105 | # build pretrained feature network and random decoder (scratch)
106 | self.pretrained, self.scratch = _make_projector(im_res=im_res, backbone=self.backbone, cout=self.cout,
107 | proj_type=self.proj_type, expand=self.expand)
108 | self.CHANNELS = self.pretrained.CHANNELS
109 | self.RESOLUTIONS = self.pretrained.RESOLUTIONS
110 |
111 | def forward(self, x):
112 | # predict feature maps
113 | if self.backbone in VITS:
114 | out0, out1, out2, out3 = forward_vit(self.pretrained, x)
115 | else:
116 | out0 = self.pretrained.layer0(x)
117 | out1 = self.pretrained.layer1(out0)
118 | out2 = self.pretrained.layer2(out1)
119 | out3 = self.pretrained.layer3(out2)
120 |
121 | # start enumerating at the lowest layer (this is where we put the first discriminator)
122 | out = {
123 | '0': out0,
124 | '1': out1,
125 | '2': out2,
126 | '3': out3,
127 | }
128 |
129 | if self.proj_type == 0: return out
130 |
131 | out0_channel_mixed = self.scratch.layer0_ccm(out['0'])
132 | out1_channel_mixed = self.scratch.layer1_ccm(out['1'])
133 | out2_channel_mixed = self.scratch.layer2_ccm(out['2'])
134 | out3_channel_mixed = self.scratch.layer3_ccm(out['3'])
135 |
136 | out = {
137 | '0': out0_channel_mixed,
138 | '1': out1_channel_mixed,
139 | '2': out2_channel_mixed,
140 | '3': out3_channel_mixed,
141 | }
142 |
143 | if self.proj_type == 1: return out
144 |
145 | # from bottom to top
146 | out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
147 | out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
148 | out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
149 | out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)
150 |
151 | out = {
152 | '0': out0_scale_mixed,
153 | '1': out1_scale_mixed,
154 | '2': out2_scale_mixed,
155 | '3': out3_scale_mixed,
156 | }
157 |
158 | return out
159 |
--------------------------------------------------------------------------------
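F_RandomProj combines a frozen pretrained backbone with randomly initialized cross-channel (CCM) and cross-scale (CSM) mixing; the multi-scale discriminators then attach to the four returned feature maps. A hedged usage sketch (requires timm and downloads backbone weights on first use, so treat it as illustrative; inputs are expected to be normalized with proj.normstats beforehand):

    import torch
    from pg_modules.projector import F_RandomProj

    proj = F_RandomProj(im_res=256, cout=64, expand=True, proj_type=2)   # default backbone
    proj.eval().requires_grad_(False)

    x = torch.randn(2, 3, 256, 256)            # stand-in for a normalized image batch
    with torch.no_grad():
        feats = proj(x)
    for k in ['0', '1', '2', '3']:
        print(k, tuple(feats[k].shape), 'nominal res', proj.RESOLUTIONS[int(k)])

--------------------------------------------------------------------------------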
/stylesan-xl/pg_modules/san_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def _normalize(tensor, dim):
7 | denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-12)
8 | return tensor / denom
9 |
10 |
11 | class SANLinear(nn.Linear):
12 |
13 | def __init__(self,
14 | in_features,
15 | out_features,
16 | bias=True,
17 | device=None,
18 | dtype=None
19 | ):
20 | super(SANLinear, self).__init__(
21 | in_features, out_features, bias=bias, device=device, dtype=dtype)
22 | scale = self.weight.norm(p=2.0, dim=1, keepdim=True).clamp_min(1e-12)
23 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
24 | self.scale = nn.parameter.Parameter(scale.view(out_features))
25 | if bias:
26 | self.bias = nn.parameter.Parameter(torch.zeros(in_features, device=device, dtype=dtype))
27 | else:
28 | self.register_parameter('bias', None)
29 |
30 | def forward(self, input, flg_train=False):
31 | if self.bias is not None:
32 | input = input + self.bias
33 | normalized_weight = self._get_normalized_weight()
34 | scale = self.scale
35 | if flg_train:
36 | out_fun = F.linear(input, normalized_weight.detach(), None)
37 | out_dir = F.linear(input.detach(), normalized_weight, None)
38 | out = [out_fun * scale, out_dir * scale.detach()]
39 | else:
40 | out = F.linear(input, normalized_weight, None)
41 | out = out * scale
42 | return out
43 |
44 | @torch.no_grad()
45 | def normalize_weight(self):
46 | self.weight.data = self._get_normalized_weight()
47 |
48 | def _get_normalized_weight(self):
49 | return _normalize(self.weight, dim=1)
50 |
51 |
52 | class SANConv1d(nn.Conv1d):
53 |
54 | def __init__(self,
55 | in_channels,
56 | out_channels,
57 | kernel_size,
58 | stride=1,
59 | padding=0,
60 | dilation=1,
61 | bias=True,
62 | padding_mode='zeros',
63 | device=None,
64 | dtype=None
65 | ):
66 | super(SANConv1d, self).__init__(
67 | in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation,
68 | groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
69 | scale = self.weight.norm(p=2.0, dim=[1, 2], keepdim=True).clamp_min(1e-12)
70 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
71 | self.scale = nn.parameter.Parameter(scale.view(out_channels))
72 | if bias:
73 | self.bias = nn.parameter.Parameter(torch.zeros(in_channels, device=device, dtype=dtype))
74 | else:
75 | self.register_parameter('bias', None)
76 |
77 | def forward(self, input, flg_train=False):
78 | if self.bias is not None:
79 | input = input + self.bias.view(self.in_channels, 1)
80 | normalized_weight = self._get_normalized_weight()
81 | scale = self.scale.view(self.out_channels, 1)
82 | if flg_train:
83 | out_fun = F.conv1d(input, normalized_weight.detach(), None, self.stride,
84 | self.padding, self.dilation, self.groups)
85 | out_dir = F.conv1d(input.detach(), normalized_weight, None, self.stride,
86 | self.padding, self.dilation, self.groups)
87 | out = [out_fun * scale, out_dir * scale.detach()]
88 | else:
89 | out = F.conv1d(input, normalized_weight, None, self.stride,
90 | self.padding, self.dilation, self.groups)
91 | out = out * scale
92 | return out
93 |
94 | @torch.no_grad()
95 | def normalize_weight(self):
96 | self.weight.data = self._get_normalized_weight()
97 |
98 | def _get_normalized_weight(self):
99 | return _normalize(self.weight, dim=[1, 2])
100 |
101 |
102 | class SANConv2d(nn.Conv2d):
103 |
104 | def __init__(self,
105 | in_channels,
106 | out_channels,
107 | kernel_size,
108 | stride=1,
109 | padding=0,
110 | dilation=1,
111 | bias=True,
112 | padding_mode='zeros',
113 | device=None,
114 | dtype=None
115 | ):
116 | super(SANConv2d, self).__init__(
117 | in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation,
118 | groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype)
119 | scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-12)
120 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
121 | self.scale = nn.parameter.Parameter(scale.view(out_channels))
122 | if bias:
123 | self.bias = nn.parameter.Parameter(torch.zeros(in_channels, device=device, dtype=dtype))
124 | else:
125 | self.register_parameter('bias', None)
126 |
127 | def forward(self, input, flg_train=False):
128 | if self.bias is not None:
129 | input = input + self.bias.view(self.in_channels, 1, 1)
130 | normalized_weight = self._get_normalized_weight()
131 | scale = self.scale.view(self.out_channels, 1, 1)
132 | if flg_train:
133 | out_fun = F.conv2d(input, normalized_weight.detach(), None, self.stride,
134 | self.padding, self.dilation, self.groups)
135 | out_dir = F.conv2d(input.detach(), normalized_weight, None, self.stride,
136 | self.padding, self.dilation, self.groups)
137 | out = [out_fun * scale, out_dir * scale.detach()]
138 | else:
139 | out = F.conv2d(input, normalized_weight, None, self.stride,
140 | self.padding, self.dilation, self.groups)
141 | out = out * scale
142 | return out
143 |
144 | @torch.no_grad()
145 | def normalize_weight(self):
146 | self.weight.data = self._get_normalized_weight()
147 |
148 | def _get_normalized_weight(self):
149 | return _normalize(self.weight, dim=[1, 2, 3])
150 |
151 |
152 | class SANEmbedding(nn.Embedding):
153 |
154 | def __init__(self, num_embeddings, embedding_dim,
155 | scale_grad_by_freq=False,
156 | sparse=False, _weight=None,
157 | device=None, dtype=None):
158 | super(SANEmbedding, self).__init__(
159 | num_embeddings, embedding_dim, padding_idx=None,
160 | max_norm=None, norm_type=2., scale_grad_by_freq=scale_grad_by_freq,
161 | sparse=sparse, _weight=_weight,
162 | device=device, dtype=dtype)
163 | scale = self.weight.norm(p=2.0, dim=1, keepdim=True).clamp_min(1e-12)
164 | self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
165 | self.scale = nn.parameter.Parameter(scale)
166 |
167 | def forward(self, input, flg_train=False):
168 | out = F.embedding(
169 | input, self.weight, self.padding_idx, self.max_norm,
170 | self.norm_type, self.scale_grad_by_freq, self.sparse)
171 | out = _normalize(out, dim=-1)
172 | scale = F.embedding(
173 | input, self.scale, self.padding_idx, self.max_norm,
174 | self.norm_type, self.scale_grad_by_freq, self.sparse)
175 | if flg_train:
176 | out_fun = out.detach()
177 | out_dir = out
178 | out = [out_fun * scale, out_dir * scale.detach()]
179 | else:
180 | out = out * scale
181 | return out
182 |
183 | @torch.no_grad()
184 | def normalize_weight(self):
185 | self.weight.data = _normalize(self.weight, dim=1)
186 |
--------------------------------------------------------------------------------
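The SAN layers above share one convention: with `flg_train=True` the forward pass returns a pair `[out_fun, out_dir]`, where gradients from the first element reach the features and the learnable scale (the normalized weight direction is detached), and gradients from the second element reach only the weight direction (features and scale are detached). The snippet below is a minimal illustrative sketch of that interface, not the repo's actual discriminator or loss; the layer sizes and the hinge-style objectives are assumptions.

```python
import torch
from pg_modules.san_modules import SANConv2d  # module defined above

# Hypothetical 1-channel "slicing" head at the end of a discriminator.
head = SANConv2d(64, 1, kernel_size=4, bias=False)
feats = torch.randn(8, 64, 4, 4)  # features from earlier layers (assumed shape)

# Training: two outputs with complementary gradient paths.
out_fun, out_dir = head(feats, flg_train=True)
loss_fun = torch.relu(1.0 - out_fun).mean()  # updates features and the scale (direction detached)
loss_dir = (-out_dir).mean()                 # updates only the normalized weight direction
(loss_fun + loss_dir).backward()

# Inference: a single tensor, as with a plain nn.Conv2d.
logits = head(feats)
```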
/stylesan-xl/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.20
2 | click>=8.0
3 | pillow>=8.3.1
4 | scipy>=1.7.1
5 | requests>=2.26.0
6 | tqdm>=4.62.2
7 | ninja>=1.10.2
8 | matplotlib>=3.4.2
9 | imageio>=2.9.0
10 | dill>=0.3.4
11 | psutil>=5.8.0
12 | regex>=2022.3.15
13 | pillow>=8.3.1
14 | imgui>=1.3.0
15 | glfw>=2.2.0
16 | pyopengl>=3.1.5
17 | imageio-ffmpeg>=0.4.3
18 | pyspng
19 | ftfy>=6.1.1
20 | timm==0.4.12
21 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import glob
10 | import hashlib
11 | import importlib
12 | import os
13 | import re
14 | import shutil
15 | import uuid
16 |
17 | import torch
18 | import torch.utils.cpp_extension
19 | from torch.utils.file_baton import FileBaton
20 |
21 | #----------------------------------------------------------------------------
22 | # Global options.
23 |
24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25 |
26 | #----------------------------------------------------------------------------
27 | # Internal helper funcs.
28 |
29 | def _find_compiler_bindir():
30 | patterns = [
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35 | ]
36 | for pattern in patterns:
37 | matches = sorted(glob.glob(pattern))
38 | if len(matches):
39 | return matches[-1]
40 | return None
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def _get_mangled_gpu_name():
45 | name = torch.cuda.get_device_name().lower()
46 | out = []
47 | for c in name:
48 | if re.match('[a-z0-9_-]+', c):
49 | out.append(c)
50 | else:
51 | out.append('-')
52 | return ''.join(out)
53 |
54 | #----------------------------------------------------------------------------
55 | # Main entry point for compiling and loading C++/CUDA plugins.
56 |
57 | _cached_plugins = dict()
58 |
59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60 | assert verbosity in ['none', 'brief', 'full']
61 | if headers is None:
62 | headers = []
63 | if source_dir is not None:
64 | sources = [os.path.join(source_dir, fname) for fname in sources]
65 | headers = [os.path.join(source_dir, fname) for fname in headers]
66 |
67 | # Already cached?
68 | if module_name in _cached_plugins:
69 | return _cached_plugins[module_name]
70 |
71 | # Print status.
72 | if verbosity == 'full':
73 | print(f'Setting up PyTorch plugin "{module_name}"...')
74 | elif verbosity == 'brief':
75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76 | verbose_build = (verbosity == 'full')
77 |
78 | # Compile and load.
79 | try: # pylint: disable=too-many-nested-blocks
80 | # Make sure we can find the necessary compiler binaries.
81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82 | compiler_bindir = _find_compiler_bindir()
83 | if compiler_bindir is None:
84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85 | os.environ['PATH'] += ';' + compiler_bindir
86 |
87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88 | # break the build or unnecessarily restrict what's available to nvcc.
89 | # Unset it to let nvcc decide based on what's available on the
90 | # machine.
91 | os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92 |
93 | # Incremental build md5sum trickery. Copies all the input source files
94 | # into a cached build directory under a combined md5 digest of the input
95 | # source files. Copying is done only if the combined digest has changed.
96 | # This keeps input file timestamps and filenames the same as in previous
97 | # extension builds, allowing for fast incremental rebuilds.
98 | #
99 | # This optimization is done only in case all the source files reside in
100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101 | # environment variable is set (we take this as a signal that the user
102 | # actually cares about this.)
103 | #
104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105 | # around the *.cu dependency bug in ninja config.
106 | #
107 | all_source_files = sorted(sources + headers)
108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110 |
111 | # Compute combined hash digest for all source files.
112 | hash_md5 = hashlib.md5()
113 | for src in all_source_files:
114 | with open(src, 'rb') as f:
115 | hash_md5.update(f.read())
116 |
117 | # Select cached build directory name.
118 | source_digest = hash_md5.hexdigest()
119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121 |
122 | if not os.path.isdir(cached_build_dir):
123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124 | os.makedirs(tmpdir)
125 | for src in all_source_files:
126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127 | try:
128 | os.replace(tmpdir, cached_build_dir) # atomic
129 | except OSError:
130 | # source directory already exists, delete tmpdir and its contents.
131 | shutil.rmtree(tmpdir)
132 | if not os.path.isdir(cached_build_dir): raise
133 |
134 | # Compile.
135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137 | verbose=verbose_build, sources=cached_sources, **build_kwargs)
138 | else:
139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140 |
141 | # Load.
142 | module = importlib.import_module(module_name)
143 |
144 | except:
145 | if verbosity == 'brief':
146 | print('Failed!')
147 | raise
148 |
149 | # Print status and add to cache dict.
150 | if verbosity == 'full':
151 | print(f'Done setting up PyTorch plugin "{module_name}".')
152 | elif verbosity == 'brief':
153 | print('Done.')
154 | _cached_plugins[module_name] = module
155 | return module
156 |
157 | #----------------------------------------------------------------------------
158 |
--------------------------------------------------------------------------------
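For reference, this is roughly how the ops in this directory request compilation through `get_plugin()`. The keyword arguments below mirror what `bias_act.py` in this tree does, but treat them as an illustrative sketch rather than a verbatim copy.

```python
import os
from torch_utils import custom_ops

# Compile (or fetch from the hash-keyed cache) the bias_act CUDA extension on first use.
_plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=os.path.dirname(__file__),     # directory holding the .cpp/.cu/.h files
    extra_cuda_cflags=['--use_fast_math'],
)
```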
/stylesan-xl/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType<T>::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
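The C++/CUDA sources above are exposed to Python through `torch_utils/ops/bias_act.py` (listed in the tree but not reproduced here). Below is a minimal usage sketch, assuming the standard StyleGAN-style signature `bias_act(x, b=None, dim=1, act='linear', ...)`; on CPU tensors the wrapper transparently falls back to its reference implementation.

```python
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 8)     # [batch, channels]
b = torch.zeros(8)        # per-channel bias along dim=1

# Fused bias + leaky ReLU + gain + clamping in one call.
y = bias_act.bias_act(x, b, act='lrelu', gain=1.0, clamp=256)
```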
/stylesan-xl/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | if not flip_weight and (kw > 1 or kh > 1):
37 | w = w.flip([2, 3])
38 |
39 | # Execute using conv2d_gradfix.
40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41 | return op(x, w, stride=stride, padding=padding, groups=groups)
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | @misc.profiled_function
46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47 | r"""2D convolution with optional up/downsampling.
48 |
49 | Padding is performed only once at the beginning, not between the operations.
50 |
51 | Args:
52 | x: Input tensor of shape
53 | `[batch_size, in_channels, in_height, in_width]`.
54 | w: Weight tensor of shape
55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57 | calling upfirdn2d.setup_filter(). None = identity (default).
58 | up: Integer upsampling factor (default: 1).
59 | down: Integer downsampling factor (default: 1).
60 | padding: Padding with respect to the upsampled image. Can be a single number
61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62 | (default: 0).
63 | groups: Split input channels into N groups (default: 1).
64 | flip_weight: False = convolution, True = correlation (default: True).
65 | flip_filter: False = convolution, True = correlation (default: False).
66 |
67 | Returns:
68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69 | """
70 | # Validate arguments.
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74 | assert isinstance(up, int) and (up >= 1)
75 | assert isinstance(down, int) and (down >= 1)
76 | assert isinstance(groups, int) and (groups >= 1)
77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78 | fw, fh = _get_filter_size(f)
79 | px0, px1, py0, py1 = _parse_padding(padding)
80 |
81 | # Adjust padding to account for up/downsampling.
82 | if up > 1:
83 | px0 += (fw + up - 1) // 2
84 | px1 += (fw - up) // 2
85 | py0 += (fh + up - 1) // 2
86 | py1 += (fh - up) // 2
87 | if down > 1:
88 | px0 += (fw - down + 1) // 2
89 | px1 += (fw - down) // 2
90 | py0 += (fh - down + 1) // 2
91 | py1 += (fh - down) // 2
92 |
93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97 | return x
98 |
99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103 | return x
104 |
105 | # Fast path: downsampling only => use strided convolution.
106 | if down > 1 and up == 1:
107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109 | return x
110 |
111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112 | if up > 1:
113 | if groups == 1:
114 | w = w.transpose(0, 1)
115 | else:
116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117 | w = w.transpose(1, 2)
118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119 | px0 -= kw - 1
120 | px1 -= kw - up
121 | py0 -= kh - 1
122 | py1 -= kh - up
123 | pxt = max(min(-px0, -px1), 0)
124 | pyt = max(min(-py0, -py1), 0)
125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127 | if down > 1:
128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129 | return x
130 |
131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132 | if up == 1 and down == 1:
133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135 |
136 | # Fallback: Generic reference implementation.
137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139 | if down > 1:
140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141 | return x
142 |
143 | #----------------------------------------------------------------------------
144 |
--------------------------------------------------------------------------------
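A small sketch of the `conv2d_resample()` fast paths, assuming the `upfirdn2d.setup_filter()` helper from the same directory; run on CPU tensors it exercises the reference implementations.

```python
import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 16, 32, 32)              # [batch, in_channels, H, W]
w = torch.randn(32, 16, 3, 3)               # [out_channels, in_channels, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])    # low-pass filter for resampling

# 3x3 convolution fused with 2x upsampling (transpose-conv fast path).
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)  # torch.Size([1, 32, 64, 64])
```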
/stylesan-xl/torch_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct filtered_lrelu_kernel_params
15 | {
16 | // These parameters decide which kernel to use.
17 | int up; // upsampling ratio (1, 2, 4)
18 | int down; // downsampling ratio (1, 2, 4)
19 | int2 fuShape; // [size, 1] | [size, size]
20 | int2 fdShape; // [size, 1] | [size, size]
21 |
22 | int _dummy; // Alignment.
23 |
24 | // Rest of the parameters.
25 | const void* x; // Input tensor.
26 | void* y; // Output tensor.
27 | const void* b; // Bias tensor.
28 | unsigned char* s; // Sign tensor in/out. NULL if unused.
29 | const float* fu; // Upsampling filter.
30 | const float* fd; // Downsampling filter.
31 |
32 | int2 pad0; // Left/top padding.
33 | float gain; // Additional gain factor.
34 | float slope; // Leaky ReLU slope on negative side.
35 | float clamp; // Clamp after nonlinearity.
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 | void* x; // Input/output, modified in-place.
58 | unsigned char* s; // Sign tensor in/out. NULL if unused.
59 |
60 | float gain; // Additional gain factor.
61 | float slope; // Leaky ReLU slope on negative side.
62 | float clamp; // Clamp after nonlinearity.
63 |
64 | int4 xShape; // [width, height, channel, batch]
65 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
73 | struct filtered_lrelu_kernel_spec
74 | {
75 | void* setup; // Function for filter kernel setup.
76 | void* exec; // Function for main operation.
77 | int2 tileOut; // Width/height of launch tile.
78 | int numWarps; // Number of warps per thread block, determines launch block size.
79 | int xrep; // For processing multiple horizontal tiles per thread block.
80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81 | };
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite> cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
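These `.cu` files only instantiate the kernel templates for the three sign modes; the public entry point is `torch_utils/ops/filtered_lrelu.py` (in the tree, not shown). The sketch below assumes its usual signature `filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, ...)` and runs the reference path on CPU.

```python
import torch
from torch_utils.ops import filtered_lrelu, upfirdn2d

x = torch.randn(1, 4, 16, 16)
fu = upfirdn2d.setup_filter([1, 3, 3, 1])   # upsampling low-pass filter
fd = upfirdn2d.setup_filter([1, 3, 3, 1])   # downsampling low-pass filter
b = torch.zeros(4)                          # per-channel bias

# 2x upsample -> biased leaky ReLU -> 2x downsample, i.e. an anti-aliased nonlinearity.
y = filtered_lrelu.filtered_lrelu(x, fu=fu, fd=fd, b=b, up=2, down=2, padding=3)
print(y.shape)  # torch.Size([1, 4, 16, 16])
```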
/stylesan-xl/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
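A quick sketch of `fma()` and its gradient un-broadcasting, with assumed toy shapes:

```python
import torch
from torch_utils.ops import fma

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(4, 3, requires_grad=True)
c = torch.randn(1, 3, requires_grad=True)    # broadcast against a * b

out = fma.fma(a, b, c)                       # same value as a * b + c
assert torch.allclose(out, a * b + c)

out.sum().backward()
print(c.grad.shape)                          # torch.Size([1, 3]): gradient un-broadcast back to c's shape
```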
/stylesan-xl/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import torch
15 | from pkg_resources import parse_version
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def grid_sample(input, grid):
29 | if _should_use_custom_op():
30 | return _GridSample2dForward.apply(input, grid)
31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def _should_use_custom_op():
36 | return enabled
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | class _GridSample2dForward(torch.autograd.Function):
41 | @staticmethod
42 | def forward(ctx, input, grid):
43 | assert input.ndim == 4
44 | assert grid.ndim == 4
45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46 | ctx.save_for_backward(input, grid)
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | input, grid = ctx.saved_tensors
52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53 | return grad_input, grad_grid
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | class _GridSample2dBackward(torch.autograd.Function):
58 | @staticmethod
59 | def forward(ctx, grad_output, input, grid):
60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61 | if _use_pytorch_1_11_api:
62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
64 | else:
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
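A minimal sketch of the drop-in `grid_sample()` wrapper. With `enabled = False` (the default) it simply forwards to `torch.nn.functional.grid_sample`; training code sets the flag to `True` to route through the custom autograd function that supports higher-order gradients. The identity grid below is an assumption for illustration.

```python
import torch
from torch_utils.ops import grid_sample_gradfix

img = torch.randn(2, 3, 8, 8, requires_grad=True)

# Identity sampling grid in [-1, 1], shape [N, H, W, 2] (x-coordinate first).
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 8), torch.linspace(-1, 1, 8), indexing='ij')
grid = torch.stack([xs, ys], dim=-1).unsqueeze(0).repeat(2, 1, 1, 1)

out = grid_sample_gradfix.grid_sample(img, grid)  # enabled=False -> plain F.grid_sample
out.sum().backward()
print(img.grad.shape)                             # torch.Size([2, 3, 8, 8])
```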
/stylesan-xl/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.numel() > 0, "x has zero size");
25 | TORCH_CHECK(f.numel() > 0, "f has zero size");
26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32 |
33 | // Create output tensor.
34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41 |
42 | // Initialize CUDA kernel parameters.
43 | upfirdn2d_kernel_params p;
44 | p.x = x.data_ptr();
45 | p.f = f.data_ptr();
46 | p.y = y.data_ptr();
47 | p.up = make_int2(upx, upy);
48 | p.down = make_int2(downx, downy);
49 | p.pad0 = make_int2(padx0, pady0);
50 | p.flip = (flip) ? 1 : 0;
51 | p.gain = gain;
52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60 |
61 | // Choose CUDA kernel.
62 | upfirdn2d_kernel_spec spec;
63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64 | {
65 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
66 | });
67 |
68 | // Set looping options.
69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70 | p.loopMinor = spec.loopMinor;
71 | p.loopX = spec.loopX;
72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74 |
75 | // Compute grid size.
76 | dim3 blockSize, gridSize;
77 | if (spec.tileOutW < 0) // large
78 | {
79 | blockSize = dim3(4, 32, 1);
80 | gridSize = dim3(
81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83 | p.launchMajor);
84 | }
85 | else // small
86 | {
87 | blockSize = dim3(256, 1, 1);
88 | gridSize = dim3(
89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91 | p.launchMajor);
92 | }
93 |
94 | // Launch CUDA kernel.
95 | void* args[] = {&p};
96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97 | return y;
98 | }
99 |
100 | //------------------------------------------------------------------------
101 |
102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103 | {
104 | m.def("upfirdn2d", &upfirdn2d);
105 | }
106 |
107 | //------------------------------------------------------------------------
108 |
--------------------------------------------------------------------------------
/stylesan-xl/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
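For completeness, a small sketch of the Python wrapper in `torch_utils/ops/upfirdn2d.py` that drives this kernel; `upsample2d()` is one of its convenience helpers, and CPU tensors take the reference path.

```python
import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 16, 16)
f = upfirdn2d.setup_filter([1, 3, 3, 1])     # normalized separable low-pass filter

y = upfirdn2d.upsample2d(x, f, up=2)         # filtered 2x upsampling
print(y.shape)                               # torch.Size([1, 3, 32, 32])
```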
/stylesan-xl/torch_utils/utils_spectrum.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.fft import fftn
3 |
4 |
5 | def roll_quadrants(data, backwards=False):
6 | """
7 | Shift low frequencies to the center of fourier transform, i.e. [-N/2, ..., +N/2] -> [0, ..., N-1]
8 | Args:
9 | data: fourier transform, (NxHxW)
10 | backwards: bool, if True shift high frequencies back to center
11 |
12 | Returns:
13 | Shifted fourier transform.
14 | """
15 | dim = data.ndim - 1
16 |
17 | if dim != 2:
18 | raise AttributeError(f'Data must be 2d but it is {dim}d.')
19 | if any(s % 2 == 0 for s in data.shape[1:]):
20 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')
21 |
22 | # for each dimension swap left and right half
23 | dims = tuple(range(1, dim+1)) # add one for batch dimension
24 | shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd
25 | if backwards:
26 | shifts *= -1
27 | return data.roll(shifts.tolist(), dims=dims)
28 |
29 |
30 | def batch_fft(data, normalize=False):
31 | """
32 | Compute fourier transform of batch.
33 | Args:
34 | data: input tensor, (NxHxW)
35 |
36 | Returns:
37 | Batch fourier transform of input data.
38 | """
39 |
40 | dim = data.ndim - 1 # subtract one for batch dimension
41 | if dim != 2:
42 | raise AttributeError(f'Data must be 2d but it is {dim}d.')
43 |
44 | dims = tuple(range(1, dim + 1)) # add one for batch dimension
45 | if normalize:
46 | norm = 'ortho'
47 | else:
48 | norm = 'backward'
49 |
50 | if not torch.is_complex(data):
51 | data = torch.complex(data, torch.zeros_like(data))
52 | freq = fftn(data, dim=dims, norm=norm)
53 |
54 | return freq
55 |
56 |
57 | def azimuthal_average(image, center=None):
58 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
59 | """
60 | Calculate the azimuthally averaged radial profile.
61 | Requires low frequencies to be at the center of the image.
62 | Args:
63 | image: Batch of 2D images, NxHxW
64 | center: The [x,y] pixel coordinates used as the center. The default is
65 | None, which then uses the center of the image (including
66 | fractional pixels).
67 |
68 | Returns:
69 | Azimuthal average over the image around the center
70 | """
71 | # Check input shapes
72 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
73 | f'(but it is len(center)={len(center)}).'
74 | # Calculate the indices from the image
75 | H, W = image.shape[-2:]
76 | h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
77 |
78 | if center is None:
79 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0])
80 |
81 | # Compute radius for each pixel wrt center
82 | r = torch.stack([w-center[0], h-center[1]]).norm(2, 0)
83 |
84 | # Get sorted radii
85 | r_sorted, ind = r.flatten().sort()
86 | i_sorted = image.flatten(-2, -1)[..., ind]
87 |
88 | # Get the integer part of the radii (bin size = 1)
89 | r_int = r_sorted.long() # attribute to the smaller integer
90 |
91 | # Find all pixels that fall within each radial bin.
92 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii
93 | rind = torch.where(deltar)[0] # location of changed radius
94 |
95 | # compute number of elements in each bin
96 | nind = rind + 1 # number of elements = idx + 1
97 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders
98 | nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius
99 |
100 | # Cumulative sum to figure out sums for each radius bin
101 | if H % 2 == 0:
102 | raise NotImplementedError('Not sure if implementation correct, please check')
103 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders
104 | else:
105 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders
106 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius
107 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]]
108 |     # prepend the center-pixel (radius 0) bin
109 | tbin = torch.cat([csim[:, 0:1], tbin], 1)
110 |
111 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins
112 |
113 | return radial_prof
114 |
115 |
116 | def get_spectrum(data, normalize=False):
117 | dim = data.ndim - 1 # subtract one for batch dimension
118 | if dim != 2:
119 | raise AttributeError(f'Data must be 2d but it is {dim}d.')
120 |
121 | freq = batch_fft(data, normalize=normalize)
122 | power_spec = freq.real ** 2 + freq.imag ** 2
123 | N = data.shape[1]
124 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum
125 | # and is not averaged with the mean value
126 | N_2 = N//2
127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1)
128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2)
129 |
130 | power_spec = roll_quadrants(power_spec)
131 | power_spec = azimuthal_average(power_spec)
132 | return power_spec
133 |
134 |
135 | def plot_std(mean, std, x=None, ax=None, **kwargs):
136 | import matplotlib.pyplot as plt
137 | if ax is None:
138 | fig, ax = plt.subplots(1)
139 |
140 | # plot error margins in same color as line
141 | err_kwargs = {
142 | 'alpha': 0.3
143 | }
144 |
145 | if 'c' in kwargs.keys():
146 | err_kwargs['color'] = kwargs['c']
147 | elif 'color' in kwargs.keys():
148 | err_kwargs['color'] = kwargs['color']
149 |
150 | if x is None:
151 | x = torch.linspace(0, 1, len(mean)) # use normalized x axis
152 | ax.plot(x, mean, **kwargs)
153 | ax.fill_between(x, mean-std, mean+std, **err_kwargs)
154 |
155 | return ax
156 |
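Taken together, `get_spectrum` chains `batch_fft`, `roll_quadrants`, and `azimuthal_average` into a radially averaged power spectrum per image, and `plot_std` visualizes batch statistics. A minimal usage sketch follows; the import path and the random input batch are illustrative assumptions, and odd spatial sizes are used so the even-size guards above are not triggered.

```python
import torch
# assumed import path for this module:
# from torch_utils.utils_spectrum import get_spectrum, plot_std

# hypothetical batch of 8 single-channel images with odd spatial size (65x65)
images = torch.randn(8, 65, 65)

# one radially averaged power-spectrum profile per image, shape (8, n_bins)
spectra = get_spectrum(images, normalize=True)

# plot the batch mean with a +/- one-std band
ax = plot_std(spectra.mean(0), spectra.std(0), c='C0')
ax.set_xlabel('normalized frequency')
ax.set_ylabel('power')
```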
--------------------------------------------------------------------------------
/stylesan-xl/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/stylesan-xl/training/diffaug.py:
--------------------------------------------------------------------------------
1 | # Differentiable Augmentation for Data-Efficient GAN Training
2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3 | # https://arxiv.org/pdf/2006.10738
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def DiffAugment(x, policy='', channels_first=True):
10 | if policy:
11 | if not channels_first:
12 | x = x.permute(0, 3, 1, 2)
13 | for p in policy.split(','):
14 | for f in AUGMENT_FNS[p]:
15 | x = f(x)
16 | if not channels_first:
17 | x = x.permute(0, 2, 3, 1)
18 | x = x.contiguous()
19 | return x
20 |
21 |
22 | def rand_brightness(x):
23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
24 | return x
25 |
26 |
27 | def rand_saturation(x):
28 | x_mean = x.mean(dim=1, keepdim=True)
29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30 | return x
31 |
32 |
33 | def rand_contrast(x):
34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
36 | return x
37 |
38 |
39 | def rand_translation(x, ratio=0.125):
40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
43 | grid_batch, grid_x, grid_y = torch.meshgrid(
44 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
45 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
46 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
47 | )
48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
52 | return x
53 |
54 |
55 | def rand_cutout(x, ratio=0.2):
56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
59 | grid_batch, grid_x, grid_y = torch.meshgrid(
60 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
63 | )
64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
67 | mask[grid_batch, grid_x, grid_y] = 0
68 | x = x * mask.unsqueeze(1)
69 | return x
70 |
71 |
72 | AUGMENT_FNS = {
73 | 'color': [rand_brightness, rand_saturation, rand_contrast],
74 | 'translation': [rand_translation],
75 | 'cutout': [rand_cutout],
76 | }
77 |
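A short usage sketch of `DiffAugment` follows; the import path and the random batch are illustrative assumptions. The policy string selects entries of `AUGMENT_FNS`, and each augmentation is composed of differentiable tensor operations, so gradients reach the generator through the augmented images.

```python
import torch
# assumed import path for this module:
# from training.diffaug import DiffAugment

# hypothetical NCHW batch of generated images
fake = torch.randn(4, 3, 64, 64, requires_grad=True)

# apply brightness/saturation/contrast, random shifts, and cutout in sequence
augmented = DiffAugment(fake, policy='color,translation,cutout')
assert augmented.shape == fake.shape

# gradients flow back through the augmentations to the input
augmented.sum().backward()
assert fake.grad is not None
```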
--------------------------------------------------------------------------------
/stylesan-xl/training/networks_fastgan.py:
--------------------------------------------------------------------------------
1 | # original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py
2 | #
3 | # modified by Axel Sauer for "Projected GANs Converge Faster"
4 | #
5 | import torch.nn as nn
6 | from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d)
7 |
8 |
9 | def normalize_second_moment(x, dim=1, eps=1e-8):
10 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
11 |
12 |
13 | class DummyMapping(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, z, c=None, **kwargs):
18 | return z.unsqueeze(1) # to fit the StyleGAN API
19 |
20 |
21 | class FastganSynthesis(nn.Module):
22 | def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
23 | super().__init__()
24 | self.img_resolution = img_resolution
25 | self.z_dim = z_dim
26 |
27 | # channel multiplier
28 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
29 | 512:0.25, 1024:0.125}
30 | nfc = {}
31 | for k, v in nfc_multi.items():
32 | nfc[k] = int(v*ngf)
33 |
34 | # layers
35 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4)
36 |
37 | UpBlock = UpBlockSmall if lite else UpBlockBig
38 |
39 | self.feat_8 = UpBlock(nfc[4], nfc[8])
40 | self.feat_16 = UpBlock(nfc[8], nfc[16])
41 | self.feat_32 = UpBlock(nfc[16], nfc[32])
42 | self.feat_64 = UpBlock(nfc[32], nfc[64])
43 | self.feat_128 = UpBlock(nfc[64], nfc[128])
44 | self.feat_256 = UpBlock(nfc[128], nfc[256])
45 |
46 | self.se_64 = SEBlock(nfc[4], nfc[64])
47 | self.se_128 = SEBlock(nfc[8], nfc[128])
48 | self.se_256 = SEBlock(nfc[16], nfc[256])
49 |
50 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)
51 |
52 | if img_resolution > 256:
53 | self.feat_512 = UpBlock(nfc[256], nfc[512])
54 | self.se_512 = SEBlock(nfc[32], nfc[512])
55 | if img_resolution > 512:
56 | self.feat_1024 = UpBlock(nfc[512], nfc[1024])
57 |
58 | def forward(self, input, c=None, **kwargs):
59 |         # map noise to hypersphere as in "Progressive Growing of GANs"
60 | input = normalize_second_moment(input[:, 0])
61 |
62 | feat_4 = self.init(input)
63 | feat_8 = self.feat_8(feat_4)
64 | feat_16 = self.feat_16(feat_8)
65 | feat_32 = self.feat_32(feat_16)
66 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32))
67 |
68 | if self.img_resolution >= 64:
69 | feat_last = feat_64
70 |
71 | if self.img_resolution >= 128:
72 | feat_last = self.se_128(feat_8, self.feat_128(feat_last))
73 |
74 | if self.img_resolution >= 256:
75 | feat_last = self.se_256(feat_16, self.feat_256(feat_last))
76 |
77 | if self.img_resolution >= 512:
78 | feat_last = self.se_512(feat_32, self.feat_512(feat_last))
79 |
80 | if self.img_resolution >= 1024:
81 | feat_last = self.feat_1024(feat_last)
82 |
83 | return self.to_big(feat_last)
84 |
85 |
86 | class FastganSynthesisCond(nn.Module):
87 | def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
88 | super().__init__()
89 |
90 | self.z_dim = z_dim
91 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5,
92 | 512:0.25, 1024:0.125, 2048:0.125}
93 | nfc = {}
94 | for k, v in nfc_multi.items():
95 | nfc[k] = int(v*ngf)
96 |
97 | self.img_resolution = img_resolution
98 |
99 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4)
100 |
101 | UpBlock = UpBlockSmallCond if lite else UpBlockBigCond
102 |
103 | self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
104 | self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
105 | self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
106 | self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
107 | self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
108 | self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)
109 |
110 | self.se_64 = SEBlock(nfc[4], nfc[64])
111 | self.se_128 = SEBlock(nfc[8], nfc[128])
112 | self.se_256 = SEBlock(nfc[16], nfc[256])
113 |
114 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)
115 |
116 | if img_resolution > 256:
117 | self.feat_512 = UpBlock(nfc[256], nfc[512])
118 | self.se_512 = SEBlock(nfc[32], nfc[512])
119 | if img_resolution > 512:
120 | self.feat_1024 = UpBlock(nfc[512], nfc[1024])
121 |
122 | self.embed = nn.Embedding(num_classes, z_dim)
123 |
124 | def forward(self, input, c, update_emas=False):
125 | c = self.embed(c.argmax(1))
126 |
127 |         # map noise to hypersphere as in "Progressive Growing of GANs"
128 | input = normalize_second_moment(input[:, 0])
129 |
130 | feat_4 = self.init(input)
131 | feat_8 = self.feat_8(feat_4, c)
132 | feat_16 = self.feat_16(feat_8, c)
133 | feat_32 = self.feat_32(feat_16, c)
134 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))
135 | feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c))
136 |
137 | if self.img_resolution >= 128:
138 | feat_last = feat_128
139 |
140 | if self.img_resolution >= 256:
141 | feat_last = self.se_256(feat_16, self.feat_256(feat_last, c))
142 |
143 | if self.img_resolution >= 512:
144 | feat_last = self.se_512(feat_32, self.feat_512(feat_last, c))
145 |
146 | if self.img_resolution >= 1024:
147 | feat_last = self.feat_1024(feat_last, c)
148 |
149 | return self.to_big(feat_last)
150 |
151 |
152 | class Generator(nn.Module):
153 | def __init__(
154 | self,
155 | z_dim=256,
156 | c_dim=0,
157 | w_dim=0,
158 | img_resolution=256,
159 | img_channels=3,
160 | ngf=128,
161 | cond=0,
162 | mapping_kwargs={},
163 | synthesis_kwargs={},
164 | **kwargs,
165 | ):
166 | super().__init__()
167 | self.z_dim = z_dim
168 | self.c_dim = c_dim
169 | self.w_dim = w_dim
170 | self.img_resolution = img_resolution
171 | self.img_channels = img_channels
172 |
173 | # Mapping and Synthesis Networks
174 | self.mapping = DummyMapping() # to fit the StyleGAN API
175 | Synthesis = FastganSynthesisCond if cond else FastganSynthesis
176 | self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs)
177 |
178 | def forward(self, z, c, **kwargs):
179 | w = self.mapping(z, c)
180 | img = self.synthesis(w, c)
181 | return img
182 |
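A minimal construction sketch for the `Generator` wrapper above; the import path and the unconditional configuration are illustrative assumptions. `DummyMapping` only reshapes `z` to the StyleGAN `(z, c) -> w -> img` calling convention, and `cond=0` selects the unconditional `FastganSynthesis`.

```python
import torch
# assumed import path for this module:
# from training.networks_fastgan import Generator

# unconditional 256x256 generator; cond=0 picks FastganSynthesis
G = Generator(z_dim=256, c_dim=0, img_resolution=256, img_channels=3, ngf=128, cond=0)

z = torch.randn(2, 256)   # latent codes
img = G(z, c=None)        # DummyMapping reshapes z; output shape (2, 3, 256, 256)
```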
--------------------------------------------------------------------------------
/stylesan-xl/viz/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/stylesan-xl/viz/capture_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import re
11 | import numpy as np
12 | import imgui
13 | import PIL.Image
14 | from gui_utils import imgui_utils
15 | from . import renderer
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | class CaptureWidget:
20 | def __init__(self, viz):
21 | self.viz = viz
22 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots'))
23 | self.dump_image = False
24 | self.dump_gui = False
25 | self.defer_frames = 0
26 | self.disabled_time = 0
27 |
28 | def dump_png(self, image):
29 | viz = self.viz
30 | try:
31 | _height, _width, channels = image.shape
32 | assert channels in [1, 3]
33 | assert image.dtype == np.uint8
34 | os.makedirs(self.path, exist_ok=True)
35 | file_id = 0
36 | for entry in os.scandir(self.path):
37 | if entry.is_file():
38 | match = re.fullmatch(r'(\d+).*', entry.name)
39 | if match:
40 | file_id = max(file_id, int(match.group(1)) + 1)
41 | if channels == 1:
42 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L')
43 | else:
44 | pil_image = PIL.Image.fromarray(image, 'RGB')
45 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png'))
46 | except:
47 | viz.result.error = renderer.CapturedException()
48 |
49 | @imgui_utils.scoped_by_object_id
50 | def __call__(self, show=True):
51 | viz = self.viz
52 | if show:
53 | with imgui_utils.grayed_out(self.disabled_time != 0):
54 | imgui.text('Capture')
55 | imgui.same_line(viz.label_w)
56 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024,
57 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
58 | width=(-1 - viz.button_w * 2 - viz.spacing * 2),
59 | help_text='PATH')
60 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '':
61 | imgui.set_tooltip(self.path)
62 | imgui.same_line()
63 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)):
64 | self.dump_image = True
65 | self.defer_frames = 2
66 | self.disabled_time = 0.5
67 | imgui.same_line()
68 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)):
69 | self.dump_gui = True
70 | self.defer_frames = 2
71 | self.disabled_time = 0.5
72 |
73 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
74 | if self.defer_frames > 0:
75 | self.defer_frames -= 1
76 | elif self.dump_image:
77 | if 'image' in viz.result:
78 | self.dump_png(viz.result.image)
79 | self.dump_image = False
80 | elif self.dump_gui:
81 | viz.capture_next_frame()
82 | self.dump_gui = False
83 | captured_frame = viz.pop_captured_frame()
84 | if captured_frame is not None:
85 | self.dump_png(captured_frame)
86 |
87 | #----------------------------------------------------------------------------
88 |
--------------------------------------------------------------------------------
/stylesan-xl/viz/equivariance_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import imgui
11 | import dnnlib
12 | from gui_utils import imgui_utils
13 |
14 | #----------------------------------------------------------------------------
15 |
16 | class EquivarianceWidget:
17 | def __init__(self, viz):
18 | self.viz = viz
19 | self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2)
20 | self.xlate_def = dnnlib.EasyDict(self.xlate)
21 | self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3)
22 | self.rotate_def = dnnlib.EasyDict(self.rotate)
23 | self.opts = dnnlib.EasyDict(untransform=False)
24 | self.opts_def = dnnlib.EasyDict(self.opts)
25 |
26 | @imgui_utils.scoped_by_object_id
27 | def __call__(self, show=True):
28 | viz = self.viz
29 | if show:
30 | imgui.text('Translate')
31 | imgui.same_line(viz.label_w)
32 | with imgui_utils.item_width(viz.font_size * 8):
33 | _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f')
34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing)
35 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w)
36 | if dragging:
37 | self.xlate.x += dx / viz.font_size * 2e-2
38 | self.xlate.y += dy / viz.font_size * 2e-2
39 | imgui.same_line()
40 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w)
41 | if dragging:
42 | self.xlate.x += dx / viz.font_size * 4e-4
43 | self.xlate.y += dy / viz.font_size * 4e-4
44 | imgui.same_line()
45 | _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim)
46 | imgui.same_line()
47 | _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round)
48 | imgui.same_line()
49 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim):
50 | changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5)
51 | if changed:
52 | self.xlate.speed = speed
53 | imgui.same_line()
54 | if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)):
55 | self.xlate = dnnlib.EasyDict(self.xlate_def)
56 |
57 | if show:
58 | imgui.text('Rotate')
59 | imgui.same_line(viz.label_w)
60 | with imgui_utils.item_width(viz.font_size * 8):
61 | _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f')
62 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing)
63 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w)
64 | if dragging:
65 | self.rotate.val += dx / viz.font_size * 2e-2
66 | imgui.same_line()
67 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w)
68 | if dragging:
69 | self.rotate.val += dx / viz.font_size * 4e-4
70 | imgui.same_line()
71 | _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim)
72 | imgui.same_line()
73 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.rotate.anim):
74 | changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3)
75 | if changed:
76 | self.rotate.speed = speed
77 | imgui.same_line()
78 | if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)):
79 | self.rotate = dnnlib.EasyDict(self.rotate_def)
80 |
81 | if show:
82 | imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16)
83 | _clicked, self.opts.untransform = imgui.checkbox('Untransform', self.opts.untransform)
84 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w)
85 | if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)):
86 | self.opts = dnnlib.EasyDict(self.opts_def)
87 |
88 | if self.xlate.anim:
89 | c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64)
90 | t = c.copy()
91 | if np.max(np.abs(t)) < 1e-4:
92 | t += 1
93 | t *= 0.1 / np.hypot(*t)
94 | t += c[::-1] * [1, -1]
95 | d = t - c
96 | d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d)
97 | self.xlate.x += d[0]
98 | self.xlate.y += d[1]
99 |
100 | if self.rotate.anim:
101 | self.rotate.val += viz.frame_delta * self.rotate.speed
102 |
103 | pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64)
104 | if self.xlate.round and 'img_resolution' in viz.result:
105 | pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution
106 | angle = self.rotate.val * np.pi * 2
107 |
108 | viz.args.input_transform = [
109 | [np.cos(angle), np.sin(angle), pos[0]],
110 | [-np.sin(angle), np.cos(angle), pos[1]],
111 | [0, 0, 1]]
112 |
113 | viz.args.update(untransform=self.opts.untransform)
114 |
115 | #----------------------------------------------------------------------------
116 |
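The final block above packs the widget state into a 3x3 homogeneous matrix for `viz.args.input_transform`. The standalone restatement below is a sketch; the function name and arguments are illustrative, with `angle_frac` playing the role of `self.rotate.val` (rotation in full turns) and `(tx, ty)` the translation.

```python
import numpy as np

def make_input_transform(angle_frac, tx, ty):
    # rotation by angle_frac full turns plus a translation (tx, ty),
    # expressed as the same 3x3 homogeneous matrix built in __call__ above
    angle = angle_frac * np.pi * 2
    return np.array([
        [ np.cos(angle), np.sin(angle), tx],
        [-np.sin(angle), np.cos(angle), ty],
        [0.0,            0.0,           1.0]])

# e.g. a quarter turn with no translation:
# make_input_transform(0.25, 0, 0) ~= [[0, 1, 0], [-1, 0, 0], [0, 0, 1]]
```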
--------------------------------------------------------------------------------
/stylesan-xl/viz/latent_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import imgui
11 | import dnnlib
12 | from gui_utils import imgui_utils
13 |
14 | #----------------------------------------------------------------------------
15 |
16 | class LatentWidget:
17 | def __init__(self, viz):
18 | self.viz = viz
19 | self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25)
20 | self.latent_def = dnnlib.EasyDict(self.latent)
21 | self.step_y = 100
22 |
23 | def drag(self, dx, dy):
24 | viz = self.viz
25 | self.latent.x += dx / viz.font_size * 4e-2
26 | self.latent.y += dy / viz.font_size * 4e-2
27 |
28 | @imgui_utils.scoped_by_object_id
29 | def __call__(self, show=True):
30 | viz = self.viz
31 | if show:
32 | imgui.text('Latent')
33 | imgui.same_line(viz.label_w)
34 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y
35 | with imgui_utils.item_width(viz.font_size * 8):
36 | changed, seed = imgui.input_int('##seed', seed)
37 | if changed:
38 | self.latent.x = seed
39 | self.latent.y = 0
40 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing)
41 | frac_x = self.latent.x - round(self.latent.x)
42 | frac_y = self.latent.y - round(self.latent.y)
43 | with imgui_utils.item_width(viz.font_size * 5):
44 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
45 | if changed:
46 | self.latent.x += new_frac_x - frac_x
47 | self.latent.y += new_frac_y - frac_y
48 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2)
49 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w)
50 | if dragging:
51 | self.drag(dx, dy)
52 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3)
53 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim)
54 | imgui.same_line(round(viz.font_size * 27.7))
55 | with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim):
56 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3)
57 | if changed:
58 | self.latent.speed = speed
59 | imgui.same_line()
60 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y))
61 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)):
62 | self.latent = snapped
63 | imgui.same_line()
64 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)):
65 | self.latent = dnnlib.EasyDict(self.latent_def)
66 |
67 | if self.latent.anim:
68 | self.latent.x += viz.frame_delta * self.latent.speed
69 | viz.args.w0_seeds = [] # [[seed, weight], ...]
70 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]:
71 | seed_x = np.floor(self.latent.x) + ofs_x
72 | seed_y = np.floor(self.latent.y) + ofs_y
73 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1)
74 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y))
75 | if weight > 0:
76 | viz.args.w0_seeds.append([seed, weight])
77 |
78 | #----------------------------------------------------------------------------
79 |
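The `w0_seeds` loop above implements a bilinear blend between latent seeds: a fractional position `(x, y)` on the seed grid is weighted across its four surrounding integer grid points, each hashed into a 32-bit seed. A standalone restatement of that logic follows (the function name is illustrative).

```python
import numpy as np

def w0_seeds_for(x, y, step_y=100):
    # same seed/weight computation as in LatentWidget.__call__ above
    seeds = []
    for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]:
        seed_x = np.floor(x) + ofs_x
        seed_y = np.floor(y) + ofs_y
        seed = (int(seed_x) + int(seed_y) * step_y) & ((1 << 32) - 1)
        weight = (1 - abs(x - seed_x)) * (1 - abs(y - seed_y))
        if weight > 0:
            seeds.append([seed, weight])
    return seeds

# e.g. w0_seeds_for(2.25, 0.0) -> [[2, 0.75], [3, 0.25]]
```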
--------------------------------------------------------------------------------
/stylesan-xl/viz/performance_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import array
10 | import numpy as np
11 | import imgui
12 | from gui_utils import imgui_utils
13 |
14 | #----------------------------------------------------------------------------
15 |
16 | class PerformanceWidget:
17 | def __init__(self, viz):
18 | self.viz = viz
19 | self.gui_times = [float('nan')] * 60
20 | self.render_times = [float('nan')] * 30
21 | self.fps_limit = 60
22 | self.use_vsync = False
23 | self.is_async = False
24 | self.force_fp32 = False
25 |
26 | @imgui_utils.scoped_by_object_id
27 | def __call__(self, show=True):
28 | viz = self.viz
29 | self.gui_times = self.gui_times[1:] + [viz.frame_delta]
30 | if 'render_time' in viz.result:
31 | self.render_times = self.render_times[1:] + [viz.result.render_time]
32 | del viz.result.render_time
33 |
34 | if show:
35 | imgui.text('GUI')
36 | imgui.same_line(viz.label_w)
37 | with imgui_utils.item_width(viz.font_size * 8):
38 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0)
39 | imgui.same_line(viz.label_w + viz.font_size * 9)
40 | t = [x for x in self.gui_times if x > 0]
41 | t = np.mean(t) if len(t) > 0 else 0
42 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A')
43 | imgui.same_line(viz.label_w + viz.font_size * 14)
44 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A')
45 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3)
46 | with imgui_utils.item_width(viz.font_size * 6):
47 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
48 | self.fps_limit = min(max(self.fps_limit, 5), 1000)
49 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing)
50 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync)
51 |
52 | if show:
53 | imgui.text('Render')
54 | imgui.same_line(viz.label_w)
55 | with imgui_utils.item_width(viz.font_size * 8):
56 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0)
57 | imgui.same_line(viz.label_w + viz.font_size * 9)
58 | t = [x for x in self.render_times if x > 0]
59 | t = np.mean(t) if len(t) > 0 else 0
60 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A')
61 | imgui.same_line(viz.label_w + viz.font_size * 14)
62 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A')
63 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3)
64 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async)
65 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing)
66 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32)
67 |
68 | viz.set_fps_limit(self.fps_limit)
69 | viz.set_vsync(self.use_vsync)
70 | viz.set_async(self.is_async)
71 | viz.args.force_fp32 = self.force_fp32
72 |
73 | #----------------------------------------------------------------------------
74 |
--------------------------------------------------------------------------------
/stylesan-xl/viz/pickle_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import glob
10 | import os
11 | import re
12 |
13 | import dnnlib
14 | import imgui
15 | import numpy as np
16 | from gui_utils import imgui_utils
17 |
18 | from . import renderer
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | def _locate_results(pattern):
23 | return pattern
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | class PickleWidget:
28 | def __init__(self, viz):
29 | self.viz = viz
30 | self.search_dirs = []
31 | self.cur_pkl = None
32 | self.user_pkl = ''
33 | self.recent_pkls = []
34 | self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...}
35 | self.browse_refocus = False
36 | self.load('', ignore_errors=True)
37 |
38 | def add_recent(self, pkl, ignore_errors=False):
39 | try:
40 | resolved = self.resolve_pkl(pkl)
41 | if resolved not in self.recent_pkls:
42 | self.recent_pkls.append(resolved)
43 | except:
44 | if not ignore_errors:
45 | raise
46 |
47 | def load(self, pkl, ignore_errors=False):
48 | viz = self.viz
49 | viz.clear_result()
50 | viz.skip_frame() # The input field will change on next frame.
51 | try:
52 | resolved = self.resolve_pkl(pkl)
53 | name = resolved.replace('\\', '/').split('/')[-1]
54 | self.cur_pkl = resolved
55 | self.user_pkl = resolved
56 | viz.result.message = f'Loading {name}...'
57 | viz.defer_rendering()
58 | if resolved in self.recent_pkls:
59 | self.recent_pkls.remove(resolved)
60 | self.recent_pkls.insert(0, resolved)
61 | except:
62 | self.cur_pkl = None
63 | self.user_pkl = pkl
64 | if pkl == '':
65 | viz.result = dnnlib.EasyDict(message='No network pickle loaded')
66 | else:
67 | viz.result = dnnlib.EasyDict(error=renderer.CapturedException())
68 | if not ignore_errors:
69 | raise
70 |
71 | @imgui_utils.scoped_by_object_id
72 | def __call__(self, show=True):
73 | viz = self.viz
74 | recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl]
75 | if show:
76 | imgui.text('Pickle')
77 | imgui.same_line(viz.label_w)
78 | changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024,
79 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
80 | width=(-1 - viz.button_w * 2 - viz.spacing * 2),
81 |                 help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl')
82 | if changed:
83 | self.load(self.user_pkl, ignore_errors=True)
84 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '':
85 | imgui.set_tooltip(self.user_pkl)
86 | imgui.same_line()
87 | if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)):
88 | imgui.open_popup('recent_pkls_popup')
89 | imgui.same_line()
90 | if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1):
91 | imgui.open_popup('browse_pkls_popup')
92 | self.browse_cache.clear()
93 | self.browse_refocus = True
94 |
95 | if imgui.begin_popup('recent_pkls_popup'):
96 | for pkl in recent_pkls:
97 | clicked, _state = imgui.menu_item(pkl)
98 | if clicked:
99 | self.load(pkl, ignore_errors=True)
100 | imgui.end_popup()
101 |
102 | if imgui.begin_popup('browse_pkls_popup'):
103 | def recurse(parents):
104 | key = tuple(parents)
105 | items = self.browse_cache.get(key, None)
106 | if items is None:
107 | items = self.list_runs_and_pkls(parents)
108 | self.browse_cache[key] = items
109 | for item in items:
110 | if item.type == 'run' and imgui.begin_menu(item.name):
111 | recurse([item.path])
112 | imgui.end_menu()
113 | if item.type == 'pkl':
114 | clicked, _state = imgui.menu_item(item.name)
115 | if clicked:
116 | self.load(item.path, ignore_errors=True)
117 | if len(items) == 0:
118 | with imgui_utils.grayed_out():
119 | imgui.menu_item('No results found')
120 | recurse(self.search_dirs)
121 | if self.browse_refocus:
122 | imgui.set_scroll_here()
123 | viz.skip_frame() # Focus will change on next frame.
124 | self.browse_refocus = False
125 | imgui.end_popup()
126 |
127 | paths = viz.pop_drag_and_drop_paths()
128 | if paths is not None and len(paths) >= 1:
129 | self.load(paths[0], ignore_errors=True)
130 |
131 | viz.args.pkl = self.cur_pkl
132 |
133 | def list_runs_and_pkls(self, parents):
134 | items = []
135 | run_regex = re.compile(r'\d+-.*')
136 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl')
137 | for parent in set(parents):
138 | if os.path.isdir(parent):
139 | for entry in os.scandir(parent):
140 | if entry.is_dir() and run_regex.fullmatch(entry.name):
141 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name)))
142 | if entry.is_file() and pkl_regex.fullmatch(entry.name):
143 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name)))
144 |
145 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path))
146 | return items
147 |
148 | def resolve_pkl(self, pattern):
149 | assert isinstance(pattern, str)
150 | assert pattern != ''
151 |
152 | # URL => return as is.
153 | if dnnlib.util.is_url(pattern):
154 | return pattern
155 |
156 | # Short-hand pattern => locate.
157 | path = _locate_results(pattern)
158 |
159 | # Run dir => pick the last saved snapshot.
160 | if os.path.isdir(path):
161 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl')))
162 | if len(pkl_files) == 0:
163 | raise IOError(f'No network pickle found in "{path}"')
164 | path = pkl_files[-1]
165 |
166 | # Normalize.
167 | path = os.path.abspath(path)
168 | return path
169 |
170 | #----------------------------------------------------------------------------
171 |
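`resolve_pkl` accepts URLs, run directories, and plain paths. The sketch below restates the same resolution rules as a standalone helper; the function name is illustrative and `dnnlib.util.is_url` is replaced by a simple prefix check.

```python
import glob
import os

def resolve_network_pickle(pattern):
    # URL => return as-is (PickleWidget uses dnnlib.util.is_url here)
    if pattern.startswith(('http://', 'https://')):
        return pattern
    path = pattern
    # run dir => pick the newest network-snapshot-*.pkl
    if os.path.isdir(path):
        pkls = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl')))
        if not pkls:
            raise IOError(f'No network pickle found in "{path}"')
        path = pkls[-1]
    # otherwise normalize to an absolute path
    return os.path.abspath(path)
```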
--------------------------------------------------------------------------------
/stylesan-xl/viz/stylemix_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import imgui
10 | from gui_utils import imgui_utils
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | class StyleMixingWidget:
15 | def __init__(self, viz):
16 | self.viz = viz
17 | self.seed_def = 1000
18 | self.seed = self.seed_def
19 | self.animate = False
20 | self.enables = []
21 |
22 | @imgui_utils.scoped_by_object_id
23 | def __call__(self, show=True):
24 | viz = self.viz
25 | num_ws = viz.result.get('num_ws', 0)
26 | num_enables = viz.result.get('num_ws', 18)
27 | self.enables += [False] * max(num_enables - len(self.enables), 0)
28 |
29 | if show:
30 | imgui.text('Stylemix')
31 | imgui.same_line(viz.label_w)
32 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0):
33 | _changed, self.seed = imgui.input_int('##seed', self.seed)
34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing)
35 | with imgui_utils.grayed_out(num_ws == 0):
36 | _clicked, self.animate = imgui.checkbox('Anim', self.animate)
37 |
38 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w
39 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing
40 | pos0 = viz.label_w + viz.font_size * 12
41 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0])
42 | for idx in range(num_enables):
43 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1))))
44 | if idx == 0:
45 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3)
46 | with imgui_utils.grayed_out(num_ws == 0):
47 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx])
48 | if imgui.is_item_hovered():
49 | imgui.set_tooltip(f'{idx}')
50 | imgui.pop_style_var(1)
51 |
52 | imgui.same_line(pos2)
53 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3)
54 | with imgui_utils.grayed_out(num_ws == 0):
55 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))):
56 | self.seed = self.seed_def
57 | self.animate = False
58 | self.enables = [False] * num_enables
59 |
60 | if any(self.enables[:num_ws]):
61 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable]
62 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1)
63 | if self.animate:
64 | self.seed += 1
65 |
66 | #----------------------------------------------------------------------------
67 |
--------------------------------------------------------------------------------
/stylesan-xl/viz/trunc_noise_widget.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import imgui
10 | from gui_utils import imgui_utils
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | class TruncationNoiseWidget:
15 | def __init__(self, viz):
16 | self.viz = viz
17 | self.prev_num_ws = 0
18 | self.trunc_psi = 1
19 | self.trunc_cutoff = 0
20 | self.noise_enable = True
21 | self.noise_seed = 0
22 | self.noise_anim = False
23 |
24 | @imgui_utils.scoped_by_object_id
25 | def __call__(self, show=True):
26 | viz = self.viz
27 | num_ws = viz.result.get('num_ws', 0)
28 | has_noise = viz.result.get('has_noise', False)
29 | if num_ws > 0 and num_ws != self.prev_num_ws:
30 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws:
31 | self.trunc_cutoff = num_ws
32 | self.prev_num_ws = num_ws
33 |
34 | if show:
35 | imgui.text('Truncate')
36 | imgui.same_line(viz.label_w)
37 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0):
38 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f')
39 | imgui.same_line()
40 | if num_ws == 0:
41 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False)
42 | else:
43 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing):
44 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d')
45 | if changed:
46 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws)
47 |
48 | with imgui_utils.grayed_out(not has_noise):
49 | imgui.same_line()
50 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable)
51 | imgui.same_line(round(viz.font_size * 27.7))
52 | with imgui_utils.grayed_out(not self.noise_enable):
53 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4):
54 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed)
55 | imgui.same_line(spacing=0)
56 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim)
57 |
58 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws)
59 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim)
60 | with imgui_utils.grayed_out(is_def_trunc and not has_noise):
61 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w)
62 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)):
63 | self.prev_num_ws = num_ws
64 | self.trunc_psi = 1
65 | self.trunc_cutoff = num_ws
66 | self.noise_enable = True
67 | self.noise_seed = 0
68 | self.noise_anim = False
69 |
70 | if self.noise_anim:
71 | self.noise_seed += 1
72 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed)
73 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random')
74 |
75 | #----------------------------------------------------------------------------
76 |
--------------------------------------------------------------------------------
/tutorial/assets/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tutorial/assets/imagenet256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sony/san/0e52b1428b2e66ae1c5f6a3586a76497d88c9ea8/tutorial/assets/imagenet256.png
--------------------------------------------------------------------------------