├── stylegan_xl ├── feature_networks │ ├── clip │ │ ├── __init__.py │ │ └── simple_tokenizer.py │ └── constants.py ├── media │ ├── banner.png │ ├── system.png │ ├── teaser.png │ ├── editing_banner.png │ ├── no_truncation.png │ ├── unimodal_truncation.png │ └── multimodal_truncation.png ├── .gitignore ├── viz │ ├── __init__.py │ ├── stylemix_widget.py │ ├── performance_widget.py │ ├── capture_widget.py │ ├── trunc_noise_widget.py │ ├── latent_widget.py │ ├── equivariance_widget.py │ └── pickle_widget.py ├── gui_utils │ ├── __init__.py │ ├── imgui_window.py │ ├── text_utils.py │ └── imgui_utils.py ├── metrics │ ├── __init__.py │ ├── inception_score.py │ ├── kernel_inception_distance.py │ ├── frechet_inception_distance.py │ ├── precision_recall.py │ └── perceptual_path_length.py ├── training │ ├── __init__.py │ ├── diffaug.py │ └── networks_fastgan.py ├── torch_utils │ ├── __init__.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.h │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── filtered_lrelu_ns.cu │ │ ├── upfirdn2d.h │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── filtered_lrelu.h │ │ ├── bias_act.cpp │ │ ├── upfirdn2d.cpp │ │ ├── bias_act.cu │ │ └── conv2d_resample.py │ ├── utils_spectrum.py │ └── custom_ops.py ├── dnnlib │ └── __init__.py ├── environment.yml ├── LICENSE.txt ├── gen_random_models.py ├── sg_forward.py ├── gen_class_samplesheet.py ├── gen_images.py └── pg_modules │ └── projector.py ├── networks ├── __init__.py ├── alexnet.py ├── alexnet_cifar.py ├── vgg.py ├── vgg_cifar.py ├── conv_gap.py ├── dnfr.py ├── conv.py ├── vit_cifar.py ├── vit.py └── resnet_cifar.py ├── environment.yml ├── readme.md ├── baseline_methods.py ├── reparam_module.py └── fed_baseline.py /stylegan_xl/feature_networks/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /stylegan_xl/media/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/banner.png -------------------------------------------------------------------------------- /stylegan_xl/media/system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/system.png -------------------------------------------------------------------------------- /stylegan_xl/media/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/teaser.png -------------------------------------------------------------------------------- /stylegan_xl/media/editing_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/editing_banner.png -------------------------------------------------------------------------------- /stylegan_xl/media/no_truncation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/no_truncation.png -------------------------------------------------------------------------------- /stylegan_xl/media/unimodal_truncation.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/unimodal_truncation.png -------------------------------------------------------------------------------- /stylegan_xl/media/multimodal_truncation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedDG23/FedDG-main/HEAD/stylegan_xl/media/multimodal_truncation.png -------------------------------------------------------------------------------- /stylegan_xl/.gitignore: -------------------------------------------------------------------------------- 1 | g++ 2 | gcc 3 | 4 | data 5 | data/* 6 | !data/.placeholder 7 | 8 | training-runs 9 | training-runs/* 10 | out/* 11 | sample_sheets/* 12 | 13 | 14 | *.zip 15 | 16 | **/__pycache__ 17 | __pycache__ 18 | .ipynb_checkpoints/ 19 | tags 20 | *.swp 21 | *.pth 22 | *.pt 23 | *.npz 24 | *.tar 25 | *.gz 26 | *.pkl 27 | *.mp4 28 | *.pyc 29 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import AlexNet 2 | from .conv import ConvNet 3 | from .dnfr import DNFR 4 | from .resnet import ResNet, ResNet18, ResNet18ImageNet 5 | from .vgg import VGG, VGG11, VGG13, VGG16, VGG19, VGG11BN 6 | from .vit import ViT 7 | from .conv_gap import ConvNetGAP 8 | from .alexnet_cifar import AlexNetCIFAR 9 | from .resnet_cifar import ResNet18CIFAR 10 | from .vgg_cifar import VGG11CIFAR 11 | from .vit_cifar import ViTCIFAR -------------------------------------------------------------------------------- /stylegan_xl/viz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan_xl/gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan_xl/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan_xl/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan_xl/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: feddg 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - dill=0.3.4 20 | - psutil=5.8.0 21 | - regex=2022.3.15 22 | - pip: 23 | - imgui==1.3.0 24 | - glfw==2.2.0 25 | - pyopengl==3.1.5 26 | - imageio-ffmpeg==0.4.3 27 | - pyspng 28 | - ftfy==6.1.1 29 | - timm==0.4.12 30 | - wandb==0.15.11 31 | - kornia==0.7.0 32 | - ema-pytorch==0.2.3 33 | - einops==0.6.1 34 | -------------------------------------------------------------------------------- /stylegan_xl/environment.yml: -------------------------------------------------------------------------------- 1 | name: glad 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - dill=0.3.4 20 | - psutil=5.8.0 21 | - regex=2022.3.15 22 | - pip: 23 | - imgui==1.3.0 24 | - glfw==2.2.0 25 | - pyopengl==3.1.5 26 | - imageio-ffmpeg==0.4.3 27 | - pyspng 28 | - ftfy==6.1.1 29 | - timm==0.4.12 30 | - wandb==0.15.11 31 | - kornia==0.7.0 32 | - ema-pytorch==0.2.3 33 | - einops==0.6.1 34 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # FedDG-main 2 | 3 | To setup an environment, please run 4 | 5 | ```bash 6 | conda env create -f environment.yml 7 | ``` 8 | 9 | To train on CIFAR-10, please use the following command: 10 | ```bash 11 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 fed_main.py --dataset=CIFAR10 --space=wp --layer=5 --data_path=./data --eval_mode CIFAR --beta 0.5 --batch_real 256 --batch_train 256 --batch_test 128 --ipc=10 --nworkers 10 --round 20 --Iteration 100 --Iteration_g 20 12 | ``` 13 | 14 | To train on ImageFruit, please use the following command: 15 | ```bash 16 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 fed_main.py --dataset=imagenet-fruits --space=wp --layer=12 --data_path=./data --eval_mode imagenet --beta 0.5 --batch_real 32 --batch_train 32 --batch_test 32 --ipc=10 --nworkers 10 --round 20 --Iteration 100 --Iteration_g 20 17 | ``` -------------------------------------------------------------------------------- /stylegan_xl/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | 
copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /networks/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | ''' AlexNet ''' 5 | class AlexNet(nn.Module): 6 | def __init__(self, channel, num_classes, im_size, **kwargs): 7 | super(AlexNet, self).__init__() 8 | self.features = nn.Sequential( 9 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 10 | nn.ReLU(inplace=True), 11 | nn.MaxPool2d(kernel_size=2, stride=2), 12 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 13 | nn.ReLU(inplace=True), 14 | nn.MaxPool2d(kernel_size=2, stride=2), 15 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | ) 23 | self.fc = nn.Linear(192 * im_size[0]//8 * im_size[1]//8, num_classes) 24 | 25 | def forward(self, x): 26 | x = self.features(x) 27 | feat_fc = x.view(x.size(0), -1) 28 | x = self.fc(feat_fc) 29 | 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /networks/alexnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class AlexNetCIFAR(nn.Module): 4 | def __init__(self, channel, num_classes): 5 | super(AlexNetCIFAR, self).__init__() 6 | self.features = nn.Sequential( 7 | nn.Conv2d(channel, 128, kernel_size=5, stride=1, padding=4 if channel==1 else 2), 8 | nn.ReLU(inplace=True), 9 | nn.MaxPool2d(kernel_size=2, stride=2), 10 | nn.Conv2d(128, 192, kernel_size=5, padding=2), 11 | nn.ReLU(inplace=True), 12 | nn.MaxPool2d(kernel_size=2, stride=2), 13 | nn.Conv2d(192, 256, kernel_size=3, padding=1), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(256, 192, kernel_size=3, padding=1), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(192, 192, kernel_size=3, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(kernel_size=2, stride=2), 20 | ) 21 | self.fc = nn.Linear(192 * 4 * 4, num_classes) 22 | 23 | def forward(self, x): 24 | x = self.features(x) 25 | x = x.view(x.size(0), -1) 26 | x = self.fc(x) 27 | return x 28 | 29 | def embed(self, x): 30 | x = self.features(x) 31 | x = x.view(x.size(0), -1) 32 | return x -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /stylegan_xl/gen_random_models.py: -------------------------------------------------------------------------------- 1 | import dnnlib 2 | import torch 3 | import copy 4 | import pickle 5 | 6 | # common_kwargs = dict(c_dim=100, img_resolution=32, img_channels=3, ) 7 | # G_kwargs = dict(class_name = "training.networks_stylegan3.Generator", z_dim= 64, w_dim= 512, num_layers=10) 8 | # G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False) 9 | # 10 | # snapshot_data = dict(G=G) 11 | # for key, value in snapshot_data.items(): 12 | # if isinstance(value, torch.nn.Module): 13 | # value = copy.deepcopy(value).eval().requires_grad_(False) 14 | # snapshot_data[key] = value.cpu() 15 | # del value # conserve memory 16 | # snapshot_pkl = "random_conditional_256.pkl" 17 | # with open(snapshot_pkl, 'wb') as f: 18 | # pickle.dump(snapshot_data, f) 19 | 20 | 21 | 22 | common_kwargs = dict(c_dim=0, img_resolution=256, img_channels=3, ) 23 | G_kwargs = dict(class_name = "training.networks_stylegan3.Generator", z_dim= 64, w_dim= 512, num_layers=17) 24 | G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False) 25 | 26 | snapshot_data = dict(G=G) 27 | for key, value in snapshot_data.items(): 28 | if isinstance(value, torch.nn.Module): 29 | value = copy.deepcopy(value).eval().requires_grad_(False) 30 | snapshot_data[key] = value.cpu() 31 | del value # conserve memory 32 | snapshot_pkl = "random_unconditional_256.pkl" 33 | with open(snapshot_pkl, 'wb') as f: 34 | pickle.dump(snapshot_data, f) -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 
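// Note: filtered_lrelu_rd.cu, filtered_lrelu_wr.cu and filtered_lrelu_ns.cu all
// #include the same filtered_lrelu.cu template implementation and differ only in
// which sign-buffer mode (read / write / no signs) they instantiate below,
// presumably so the heavy kernel specializations are split across translation units.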
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylegan_xl/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /networks/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | ''' VGG ''' 7 | cfg_vgg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | class VGG(nn.Module): 14 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 15 | super(VGG, self).__init__() 16 | self.channel = channel 17 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 18 | self.classifier = nn.Linear(512*7*7 if vgg_name != 'VGGS' else 128, num_classes) 19 | 20 | def forward(self, x): 21 | x = self.features(x) 22 | feat_fc = x.view(x.size(0), -1) 23 | x = self.classifier(feat_fc) 24 | 25 | return x 26 | 27 | def _make_layers(self, cfg, norm): 28 | layers = [] 29 | in_channels = self.channel 30 | for ic, x in enumerate(cfg): 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 35 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x) if norm == 'batch' else nn.Identity(), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AdaptiveMaxPool2d((7, 7))] 39 | return nn.Sequential(*layers) 40 | 41 | 42 | def VGG11(channel, num_classes, **kwargs): 43 | return VGG('VGG11', channel, num_classes) 44 | def VGG11BN(channel, num_classes): 45 | return 
VGG('VGG11', channel, num_classes, norm='batchnorm') 46 | def VGG13(channel, num_classes): 47 | return VGG('VGG13', channel, num_classes) 48 | def VGG16(channel, num_classes): 49 | return VGG('VGG16', channel, num_classes) 50 | def VGG19(channel, num_classes): 51 | return VGG('VGG19', channel, num_classes) -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /networks/vgg_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | ''' VGG ''' 4 | cfg_vgg = { 5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 9 | } 10 | class VGGCIFAR(nn.Module): 11 | def __init__(self, vgg_name, channel, num_classes, norm='instancenorm'): 12 | super(VGGCIFAR, self).__init__() 13 | self.channel = channel 14 | self.features = self._make_layers(cfg_vgg[vgg_name], norm) 15 | self.classifier = nn.Linear(512 if 
vgg_name != 'VGGS' else 128, num_classes) 16 | 17 | def forward(self, x): 18 | x = self.features(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.classifier(x) 21 | return x 22 | 23 | def embed(self, x): 24 | x = self.features(x) 25 | x = x.view(x.size(0), -1) 26 | return x 27 | 28 | def _make_layers(self, cfg, norm): 29 | layers = [] 30 | in_channels = self.channel 31 | for ic, x in enumerate(cfg): 32 | if x == 'M': 33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 34 | else: 35 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=3 if self.channel==1 and ic==0 else 1), 36 | nn.GroupNorm(x, x, affine=True) if norm=='instancenorm' else nn.BatchNorm2d(x), 37 | nn.ReLU(inplace=True)] 38 | in_channels = x 39 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 40 | return nn.Sequential(*layers) 41 | 42 | 43 | def VGG11CIFAR(channel, num_classes): 44 | return VGGCIFAR('VGG11', channel, num_classes) 45 | def VGG11BNCIFAR(channel, num_classes): 46 | return VGGCIFAR('VGG11', channel, num_classes, norm='batchnorm') 47 | def VGG13CIFAR(channel, num_classes): 48 | return VGGCIFAR('VGG13', channel, num_classes) 49 | def VGG16CIFAR(channel, num_classes): 50 | return VGGCIFAR('VGG16', channel, num_classes) 51 | def VGG19CIFAR(channel, num_classes): 52 | return VGGCIFAR('VGG19', channel, num_classes) -------------------------------------------------------------------------------- /stylegan_xl/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
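    # The loop below is the subset-based MMD^2 estimator from the KID paper cited above:
    # features are compared with the cubic polynomial kernel k(x, y) = (x.y / d + 1)^3
    # (d = feature dimension, `n` below); each random subset of size m yields one MMD^2
    # estimate, and KID is the mean of those estimates over num_subsets.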
22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /stylegan_xl/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen, sfid=False, rfid=False): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
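    # The mean/covariance statistics gathered below plug into the closed-form Frechet
    # distance between two Gaussians,
    #   FID = ||mu_real - mu_gen||^2 + Tr(sigma_real + sigma_gen - 2 (sigma_real sigma_gen)^(1/2)),
    # which is what the final `m + np.trace(...)` expression computes via scipy's sqrtm.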
24 | if rfid: 25 | detector_url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/feature_networks/inception_rand_full.pkl' 26 | detector_kwargs = {} # random inception network returns features by default 27 | 28 | 29 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, sfid=sfid).get_mean_cov() 32 | 33 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 34 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 35 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, sfid=sfid).get_mean_cov() 36 | 37 | if opts.rank != 0: 38 | return float('nan') 39 | 40 | m = np.square(mu_gen - mu_real).sum() 41 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 42 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 43 | return float(fid) 44 | 45 | #---------------------------------------------------------------------------- 46 | -------------------------------------------------------------------------------- /stylegan_xl/sg_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_utils import misc 3 | 4 | class StyleGAN_Wrapper(torch.nn.Module): 5 | 6 | def __init__(self, G): 7 | super(StyleGAN_Wrapper, self).__init__() 8 | self.G = G 9 | self.syn = G.synthesis 10 | self.mapping = G.mapping 11 | 12 | def forward(self, ws=None, f_latents=None, f_layer=0, mode="wp"): 13 | 14 | if ws is None and f_latents is None: 15 | # print("I am here") 16 | return torch.zeros(0,0,0,0).to(self.G.device) 17 | 18 | if mode == "wp": 19 | return self.forward_wp(ws) 20 | elif mode == "from_f": 21 | return self.forward_from_f(ws, f_latents, f_layer) 22 | elif mode == "to_f": 23 | return self.forward_to_f(ws, f_layer) 24 | 25 | def forward_wp(self, ws, **layer_kwargs): 26 | misc.assert_shape(ws, [None, self.syn.num_ws, self.syn.w_dim]) 27 | ws = ws.to(torch.float32).unbind(dim=1) 28 | 29 | # Execute layers. 30 | x = self.syn.input(ws[0]) 31 | for name, w in zip(self.syn.layer_names, ws[1:]): 32 | x = getattr(self.syn, name)(x, w, **layer_kwargs) 33 | if self.syn.output_scale != 1: 34 | x = x * self.syn.output_scale 35 | 36 | # Ensure correct shape and dtype. 37 | misc.assert_shape(x, [None, self.syn.img_channels, self.syn.img_resolution, self.syn.img_resolution]) 38 | x = x.to(torch.float32) 39 | return x 40 | 41 | def forward_from_f(self, ws, f_latents, f_layer, **layer_kwargs): 42 | misc.assert_shape(ws, [None, self.syn.num_ws, self.syn.w_dim]) 43 | ws = ws.to(torch.float32).unbind(dim=1) 44 | 45 | x = f_latents 46 | for name, w in zip(self.syn.layer_names[f_layer:], ws[1+f_layer:]): 47 | x = getattr(self.syn, name)(x, w, **layer_kwargs) 48 | # print(i, name, x.shape) 49 | if self.syn.output_scale != 1: 50 | x = x * self.syn.output_scale 51 | 52 | # Ensure correct shape and dtype. 53 | misc.assert_shape(x, [None, self.syn.img_channels, self.syn.img_resolution, self.syn.img_resolution]) 54 | x = x.to(torch.float32) 55 | return x 56 | 57 | def forward_to_f(self, ws, f_layer, **layer_kwargs): 58 | misc.assert_shape(ws, [None, self.syn.num_ws, self.syn.w_dim]) 59 | ws = ws.to(torch.float32).unbind(dim=1) 60 | 61 | 62 | # Execute layers. 
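        # Only the input layer and the first `f_layer` synthesis layers are executed
        # here; the intermediate activation returned is what forward_from_f() expects
        # as `f_latents` when resuming synthesis from that same layer.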
63 | x = self.syn.input(ws[0]) 64 | for name, w in zip(self.syn.layer_names[:f_layer], ws[1:1+f_layer]): 65 | x = getattr(self.syn, name)(x, w, **layer_kwargs) 66 | 67 | # print(i, name, x.shape) 68 | # print(x.shape) 69 | return x -------------------------------------------------------------------------------- /stylegan_xl/feature_networks/constants.py: -------------------------------------------------------------------------------- 1 | TORCHVISION = [ 2 | "vgg11_bn", 3 | "vgg13_bn", 4 | "vgg16", 5 | "vgg16_bn", 6 | "vgg19_bn", 7 | "densenet121", 8 | "densenet169", 9 | "densenet201", 10 | "inception_v3", 11 | "resnet18", 12 | "resnet34", 13 | "resnet50", 14 | "resnet101", 15 | "resnet152", 16 | "shufflenet_v2_x0_5", 17 | "mobilenet_v2", 18 | "wide_resnet50_2", 19 | "mnasnet0_5", 20 | "mnasnet1_0", 21 | "ghostnet_100", 22 | "cspresnet50", 23 | "fbnetc_100", 24 | "spnasnet_100", 25 | "resnet50d", 26 | "resnet26", 27 | "resnet26d", 28 | "seresnet50", 29 | "resnetblur50", 30 | "resnetrs50", 31 | "tf_mixnet_s", 32 | "tf_mixnet_m", 33 | "tf_mixnet_l", 34 | "ese_vovnet19b_dw", 35 | "ese_vovnet39b", 36 | "res2next50", 37 | "gernet_s", 38 | "gernet_m", 39 | "repvgg_a2", 40 | "repvgg_b0", 41 | "repvgg_b1", 42 | "repvgg_b1g4", 43 | "revnet", 44 | "dm_nfnet_f1", 45 | "nfnet_l0", 46 | ] 47 | 48 | REGNETS = [ 49 | "regnetx_002", 50 | "regnetx_004", 51 | "regnetx_006", 52 | "regnetx_008", 53 | "regnetx_016", 54 | "regnetx_032", 55 | "regnetx_040", 56 | "regnetx_064", 57 | "regnety_002", 58 | "regnety_004", 59 | "regnety_006", 60 | "regnety_008", 61 | "regnety_016", 62 | "regnety_032", 63 | "regnety_040", 64 | "regnety_064", 65 | ] 66 | 67 | EFFNETS_IMAGENET = [ 68 | 'tf_efficientnet_b0', 69 | 'tf_efficientnet_b1', 70 | 'tf_efficientnet_b2', 71 | 'tf_efficientnet_b3', 72 | 'tf_efficientnet_b4', 73 | 'tf_efficientnet_b0_ns', 74 | ] 75 | 76 | EFFNETS_INCEPTION = [ 77 | 'tf_efficientnet_lite0', 78 | 'tf_efficientnet_lite1', 79 | 'tf_efficientnet_lite2', 80 | 'tf_efficientnet_lite3', 81 | 'tf_efficientnet_lite4', 82 | 'tf_efficientnetv2_b0', 83 | 'tf_efficientnetv2_b1', 84 | 'tf_efficientnetv2_b2', 85 | 'tf_efficientnetv2_b3', 86 | 'efficientnet_b1', 87 | 'efficientnet_b1_pruned', 88 | 'efficientnet_b2_pruned', 89 | 'efficientnet_b3_pruned', 90 | ] 91 | 92 | EFFNETS = EFFNETS_IMAGENET + EFFNETS_INCEPTION 93 | 94 | VITS_IMAGENET = [ 95 | 'deit_tiny_distilled_patch16_224', 96 | 'deit_small_distilled_patch16_224', 97 | 'deit_base_distilled_patch16_224', 98 | ] 99 | 100 | VITS_INCEPTION = [ 101 | 'vit_base_patch16_224' 102 | ] 103 | 104 | VITS = VITS_IMAGENET + VITS_INCEPTION 105 | 106 | CLIP = [ 107 | 'resnet50_clip' 108 | ] 109 | 110 | ALL_MODELS = TORCHVISION + REGNETS + EFFNETS + VITS + CLIP 111 | 112 | # Group according to input normalization 113 | 114 | NORMALIZED_IMAGENET = TORCHVISION + REGNETS + EFFNETS_IMAGENET + VITS_IMAGENET 115 | 116 | NORMALIZED_INCEPTION = EFFNETS_INCEPTION + VITS_INCEPTION 117 | 118 | NORMALIZED_CLIP = CLIP 119 | -------------------------------------------------------------------------------- /stylegan_xl/gen_class_samplesheet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import PIL.Image 4 | from typing import List 5 | import click 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import legacy 11 | import dnnlib 12 | from training.training_loop import save_image_grid 13 | from torch_utils import gen_utils 14 | from gen_images import parse_range 15 | 16 | 
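# Pipeline implemented by the CLI below: load the generator pickle, draw
# `samples-per-class` w latents for every requested class (optionally truncated
# towards precomputed centroids via --centroids-path), synthesize them in chunks of
# --batch-gpu, and tile the results into a single sheet.png inside a new run dir.
# Illustrative invocation (the pickle path and values are placeholders, not from this repo):
#   python gen_class_samplesheet.py --network=G.pkl --classes=0-9 \
#       --samples-per-class=4 --outdir=sample_sheets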
@click.command() 17 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 18 | @click.option('--trunc', 'truncation_psi', help='Truncation psi', type=float, default=1, show_default=True) 19 | @click.option('--seed', help='Random seed', type=int, default=42) 20 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation') 21 | @click.option('--classes', type=parse_range, help='List of classes (e.g., \'0,1,4-6\')', required=True) 22 | @click.option('--samples-per-class', help='Samples per class.', type=int, default=4) 23 | @click.option('--grid-width', help='Total width of image grid', type=int, default=32) 24 | @click.option('--batch-gpu', help='Samples per pass, adapt to fit on GPU', type=int, default=32) 25 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 26 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) 27 | def generate_samplesheet( 28 | network_pkl: str, 29 | truncation_psi: float, 30 | seed: int, 31 | centroids_path: str, 32 | classes: List[int], 33 | samples_per_class: int, 34 | batch_gpu: int, 35 | grid_width: int, 36 | outdir: str, 37 | desc: str, 38 | ): 39 | print('Loading networks from "%s"...' % network_pkl) 40 | device = torch.device('cuda') 41 | with dnnlib.util.open_url(network_pkl) as f: 42 | G = legacy.load_network_pkl(f)['G_ema'].to(device).requires_grad_(False) 43 | 44 | # setup 45 | os.makedirs(outdir, exist_ok=True) 46 | desc_full = f'{Path(network_pkl).stem}_trunc_{truncation_psi}' 47 | if desc is not None: desc_full += f'-{desc}' 48 | run_dir = Path(gen_utils.make_run_dir(outdir, desc_full)) 49 | 50 | print('Generating latents.') 51 | ws = [] 52 | for class_idx in tqdm(classes): 53 | w = gen_utils.get_w_from_seed(G, samples_per_class, device, truncation_psi, seed=seed, 54 | centroids_path=centroids_path, class_idx=class_idx) 55 | ws.append(w) 56 | ws = torch.cat(ws) 57 | 58 | print('Generating samples.') 59 | images = [] 60 | for w in tqdm(ws.split(batch_gpu)): 61 | img = gen_utils.w_to_img(G, w, to_np=True) 62 | images.append(img) 63 | 64 | # adjust grid widht to prohibit folding between same class then save to disk 65 | grid_width = grid_width - grid_width % samples_per_class 66 | images = gen_utils.create_image_grid(np.concatenate(images), grid_size=(grid_width, None)) 67 | PIL.Image.fromarray(images, 'RGB').save(run_dir / 'sheet.png') 68 | 69 | if __name__ == "__main__": 70 | generate_samplesheet() 71 | -------------------------------------------------------------------------------- /stylegan_xl/training/diffaug.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) 
* (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.2): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 | } 77 | -------------------------------------------------------------------------------- /stylegan_xl/viz/stylemix_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class StyleMixingWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.seed_def = 1000 18 | self.seed = self.seed_def 19 | self.animate = False 20 | self.enables = [] 21 | 22 | @imgui_utils.scoped_by_object_id 23 | def __call__(self, show=True): 24 | viz = self.viz 25 | num_ws = viz.result.get('num_ws', 0) 26 | num_enables = viz.result.get('num_ws', 18) 27 | self.enables += [False] * max(num_enables - len(self.enables), 0) 28 | 29 | if show: 30 | imgui.text('Stylemix') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): 33 | _changed, self.seed = imgui.input_int('##seed', self.seed) 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | with imgui_utils.grayed_out(num_ws == 0): 36 | _clicked, self.animate = imgui.checkbox('Anim', self.animate) 37 | 38 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w 39 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing 40 | pos0 = viz.label_w + viz.font_size * 12 41 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 42 | for idx in range(num_enables): 43 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) 44 | if idx == 0: 45 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) 46 | with imgui_utils.grayed_out(num_ws == 0): 47 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) 48 | if imgui.is_item_hovered(): 49 | imgui.set_tooltip(f'{idx}') 50 | imgui.pop_style_var(1) 51 | 52 | imgui.same_line(pos2) 53 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) 54 | with imgui_utils.grayed_out(num_ws == 0): 55 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))): 56 | self.seed = self.seed_def 57 | self.animate = False 58 | self.enables = [False] * num_enables 59 | 60 | if any(self.enables[:num_ws]): 61 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] 62 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) 63 | if self.animate: 64 | self.seed += 1 65 | 66 | #---------------------------------------------------------------------------- 67 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 
11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /networks/conv_gap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ''' ConvNetGAP ''' 5 | class ConvNetGAP(nn.Module): 6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size = (32,32)): 7 | super(ConvNetGAP, self).__init__() 8 | 9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 10 | num_feat = shape_feat[0] 11 | self.classifier = nn.Linear(num_feat, num_classes) 12 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 13 | 14 | def forward(self, x): 15 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 16 | out = self.features(x) 17 | out = self.pool(out) 18 | out = out.view(out.size(0), -1) 19 | out = self.classifier(out) 20 | return out 21 | 22 | def 
_get_activation(self, net_act): 23 | if net_act == 'sigmoid': 24 | return nn.Sigmoid() 25 | elif net_act == 'relu': 26 | return nn.ReLU(inplace=True) 27 | elif net_act == 'leakyrelu': 28 | return nn.LeakyReLU(negative_slope=0.01) 29 | else: 30 | exit('unknown activation function: %s'%net_act) 31 | 32 | def _get_pooling(self, net_pooling): 33 | if net_pooling == 'maxpooling': 34 | return nn.MaxPool2d(kernel_size=2, stride=2) 35 | elif net_pooling == 'avgpooling': 36 | return nn.AvgPool2d(kernel_size=2, stride=2) 37 | elif net_pooling == 'none': 38 | return None 39 | else: 40 | exit('unknown net_pooling: %s'%net_pooling) 41 | 42 | def _get_normlayer(self, net_norm, shape_feat): 43 | # shape_feat = (c*h*w) 44 | if net_norm == 'batchnorm': 45 | return nn.BatchNorm2d(shape_feat[0], affine=True) 46 | elif net_norm == 'layernorm': 47 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 48 | elif net_norm == 'instancenorm': 49 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 50 | elif net_norm == 'groupnorm': 51 | return nn.GroupNorm(4, shape_feat[0], affine=True) 52 | elif net_norm == 'none': 53 | return None 54 | else: 55 | exit('unknown net_norm: %s'%net_norm) 56 | 57 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 58 | layers = [] 59 | in_channels = channel 60 | if im_size[0] == 28: 61 | im_size = (32, 32) 62 | shape_feat = [in_channels, im_size[0], im_size[1]] 63 | for d in range(net_depth): 64 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 65 | shape_feat[0] = net_width 66 | if net_norm != 'none': 67 | layers += [self._get_normlayer(net_norm, shape_feat)] 68 | layers += [self._get_activation(net_act)] 69 | in_channels = net_width 70 | if net_pooling != 'none': 71 | layers += [self._get_pooling(net_pooling)] 72 | shape_feat[1] //= 2 73 | shape_feat[2] //= 2 74 | net_width *= 2 75 | 76 | 77 | return nn.Sequential(*layers), shape_feat -------------------------------------------------------------------------------- /networks/dnfr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ''' DNFR_Net ''' 5 | class DNFR(nn.Module): 6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size=(32,32)): 7 | super(DNFR, self).__init__() 8 | 9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 10 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 11 | self.classifier = nn.Linear(num_feat, num_classes) 12 | 13 | def forward(self, x): 14 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 15 | out = self.features(x) 16 | out = out.view(out.size(0), -1) 17 | out = self.classifier(out) 18 | return out 19 | 20 | def _get_activation(self, net_act): 21 | if net_act == 'sigmoid': 22 | return nn.Sigmoid() 23 | elif net_act == 'relu': 24 | return nn.ReLU(inplace=True) 25 | elif net_act == 'leakyrelu': 26 | return nn.LeakyReLU(negative_slope=0.01) 27 | else: 28 | exit('unknown activation function: %s'%net_act) 29 | 30 | def _get_pooling(self, net_pooling): 31 | if net_pooling == 'maxpooling': 32 | return nn.MaxPool2d(kernel_size=2, stride=2) 33 | elif net_pooling == 'avgpooling': 34 | return nn.AvgPool2d(kernel_size=2, stride=2) 35 | elif net_pooling == 'none': 36 | return None 37 | else: 38 | exit('unknown net_pooling: 
%s'%net_pooling) 39 | 40 | def _get_normlayer(self, net_norm, shape_feat): 41 | # shape_feat = (c*h*w) 42 | if net_norm == 'batchnorm': 43 | return nn.BatchNorm2d(shape_feat[0], affine=True) 44 | elif net_norm == 'layernorm': 45 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 46 | elif net_norm == 'instancenorm': 47 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 48 | elif net_norm == 'groupnorm': 49 | return nn.GroupNorm(4, shape_feat[0], affine=True) 50 | elif net_norm == 'none': 51 | return None 52 | else: 53 | exit('unknown net_norm: %s'%net_norm) 54 | 55 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 56 | layers = [] 57 | in_channels = channel 58 | if im_size[0] == 28: 59 | im_size = (32, 32) 60 | shape_feat = [in_channels, im_size[0], im_size[1]] 61 | for d in range(net_depth): 62 | if net_norm != 'none': 63 | if d == 0 and net_norm == 'groupnorm': 64 | layers += [self._get_normlayer('instancenorm', shape_feat)] 65 | else: 66 | layers += [self._get_normlayer(net_norm, shape_feat)] 67 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 68 | shape_feat[0] = net_width 69 | layers += [self._get_activation(net_act)] 70 | in_channels = net_width 71 | if net_pooling != 'none': 72 | layers += [self._get_pooling(net_pooling)] 73 | shape_feat[1] //= 2 74 | shape_feat[2] //= 2 75 | 76 | 77 | return nn.Sequential(*layers), shape_feat -------------------------------------------------------------------------------- /stylegan_xl/viz/performance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
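# PerformanceWidget (below) keeps two fixed-length sliding windows -- the last 60 GUI
# frame deltas and the last 30 render times -- appends the newest sample each frame,
# and reports the mean of the positive samples as milliseconds and FPS. The rolling-
# window pattern it relies on, reduced to a minimal sketch (names here are
# illustrative, not part of this file):
#
#     times = [float('nan')] * 60
#     times = times[1:] + [new_sample]               # drop oldest, append newest
#     valid = [t for t in times if t > 0]
#     mean_s = sum(valid) / len(valid) if valid else 0
#     label = f'{mean_s * 1e3:.1f} ms / {1 / mean_s:.1f} FPS' if mean_s > 0 else 'N/A'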
8 | 9 | import array 10 | import numpy as np 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class PerformanceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.gui_times = [float('nan')] * 60 20 | self.render_times = [float('nan')] * 30 21 | self.fps_limit = 60 22 | self.use_vsync = False 23 | self.is_async = False 24 | self.force_fp32 = False 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | self.gui_times = self.gui_times[1:] + [viz.frame_delta] 30 | if 'render_time' in viz.result: 31 | self.render_times = self.render_times[1:] + [viz.result.render_time] 32 | del viz.result.render_time 33 | 34 | if show: 35 | imgui.text('GUI') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 8): 38 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) 39 | imgui.same_line(viz.label_w + viz.font_size * 9) 40 | t = [x for x in self.gui_times if x > 0] 41 | t = np.mean(t) if len(t) > 0 else 0 42 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 43 | imgui.same_line(viz.label_w + viz.font_size * 14) 44 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 45 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 46 | with imgui_utils.item_width(viz.font_size * 6): 47 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 48 | self.fps_limit = min(max(self.fps_limit, 5), 1000) 49 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 50 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) 51 | 52 | if show: 53 | imgui.text('Render') 54 | imgui.same_line(viz.label_w) 55 | with imgui_utils.item_width(viz.font_size * 8): 56 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) 57 | imgui.same_line(viz.label_w + viz.font_size * 9) 58 | t = [x for x in self.render_times if x > 0] 59 | t = np.mean(t) if len(t) > 0 else 0 60 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 61 | imgui.same_line(viz.label_w + viz.font_size * 14) 62 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 63 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 64 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) 65 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 66 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) 67 | 68 | viz.set_fps_limit(self.fps_limit) 69 | viz.set_vsync(self.use_vsync) 70 | viz.set_async(self.is_async) 71 | viz.args.force_fp32 = self.force_fp32 72 | 73 | #---------------------------------------------------------------------------- 74 | -------------------------------------------------------------------------------- /stylegan_xl/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
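# compute_pr() below embeds real and generated images with a pre-trained VGG16 feature
# extractor and approximates each manifold by the distance from every feature vector to
# its nhood_size-th nearest neighbour (kthvalue(nhood_size + 1), since the closest match
# of a point within its own set is the point itself). Precision is the fraction of
# generated features that fall inside at least one real sample's neighbourhood; recall
# swaps the two roles. A minimal single-GPU sketch of the membership test (illustrative
# only, not the distributed/batched code in this file):
#
#     d_manifold = torch.cdist(real, real)                  # [N, N] pairwise distances
#     radii = d_manifold.kthvalue(k + 1, dim=1).values      # per-point k-NN radius
#     d_probe = torch.cdist(fake, real)                     # [M, N]
#     precision = (d_probe <= radii.unsqueeze(0)).any(dim=1).float().mean()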
8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . import metric_utils 16 | from tqdm import tqdm 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 21 | assert 0 <= rank < num_gpus 22 | num_cols = col_features.shape[0] 23 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 24 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 25 | dist_batches = [] 26 | for col_batch in col_batches[rank :: num_gpus]: 27 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 28 | for src in range(num_gpus): 29 | dist_broadcast = dist_batch.clone() 30 | if num_gpus > 1: 31 | torch.distributed.broadcast(dist_broadcast, src=src) 32 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 33 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 38 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 39 | detector_kwargs = dict(return_features=True) 40 | max_real = max_real // opts.num_gpus 41 | 42 | real_features = metric_utils.compute_feature_stats_for_dataset( 43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=None, shuffle_size=max_real).get_all_torch().to(torch.float16).to(opts.device) 45 | 46 | gen_features = metric_utils.compute_feature_stats_for_generator( 47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 49 | 50 | results = dict() 51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 52 | kth = [] 53 | for manifold_batch in tqdm(manifold.split(row_batch_size)): 54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 56 | kth = torch.cat(kth) if opts.rank == 0 else None 57 | pred = [] 58 | for probes_batch in tqdm(probes.split(row_batch_size)): 59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 62 | return results['precision'], results['recall'] 63 | 64 | #---------------------------------------------------------------------------- 65 | -------------------------------------------------------------------------------- /stylegan_xl/viz/capture_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import re 11 | import numpy as np 12 | import imgui 13 | import PIL.Image 14 | from gui_utils import imgui_utils 15 | from . import renderer 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class CaptureWidget: 20 | def __init__(self, viz): 21 | self.viz = viz 22 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) 23 | self.dump_image = False 24 | self.dump_gui = False 25 | self.defer_frames = 0 26 | self.disabled_time = 0 27 | 28 | def dump_png(self, image): 29 | viz = self.viz 30 | try: 31 | _height, _width, channels = image.shape 32 | assert channels in [1, 3] 33 | assert image.dtype == np.uint8 34 | os.makedirs(self.path, exist_ok=True) 35 | file_id = 0 36 | for entry in os.scandir(self.path): 37 | if entry.is_file(): 38 | match = re.fullmatch(r'(\d+).*', entry.name) 39 | if match: 40 | file_id = max(file_id, int(match.group(1)) + 1) 41 | if channels == 1: 42 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') 43 | else: 44 | pil_image = PIL.Image.fromarray(image, 'RGB') 45 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) 46 | except: 47 | viz.result.error = renderer.CapturedException() 48 | 49 | @imgui_utils.scoped_by_object_id 50 | def __call__(self, show=True): 51 | viz = self.viz 52 | if show: 53 | with imgui_utils.grayed_out(self.disabled_time != 0): 54 | imgui.text('Capture') 55 | imgui.same_line(viz.label_w) 56 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, 57 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 58 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 59 | help_text='PATH') 60 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': 61 | imgui.set_tooltip(self.path) 62 | imgui.same_line() 63 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): 64 | self.dump_image = True 65 | self.defer_frames = 2 66 | self.disabled_time = 0.5 67 | imgui.same_line() 68 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): 69 | self.dump_gui = True 70 | self.defer_frames = 2 71 | self.disabled_time = 0.5 72 | 73 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) 74 | if self.defer_frames > 0: 75 | self.defer_frames -= 1 76 | elif self.dump_image: 77 | if 'image' in viz.result: 78 | self.dump_png(viz.result.image) 79 | self.dump_image = False 80 | elif self.dump_gui: 81 | viz.capture_next_frame() 82 | self.dump_gui = False 83 | captured_frame = viz.pop_captured_frame() 84 | if captured_frame is not None: 85 | self.dump_png(captured_frame) 86 | 87 | #---------------------------------------------------------------------------- 88 | -------------------------------------------------------------------------------- /stylegan_xl/viz/trunc_noise_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class TruncationNoiseWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.prev_num_ws = 0 18 | self.trunc_psi = 1 19 | self.trunc_cutoff = 0 20 | self.noise_enable = True 21 | self.noise_seed = 0 22 | self.noise_anim = False 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | num_ws = viz.result.get('num_ws', 0) 28 | has_noise = viz.result.get('has_noise', False) 29 | if num_ws > 0 and num_ws != self.prev_num_ws: 30 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: 31 | self.trunc_cutoff = num_ws 32 | self.prev_num_ws = num_ws 33 | 34 | if show: 35 | imgui.text('Truncate') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): 38 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') 39 | imgui.same_line() 40 | if num_ws == 0: 41 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) 42 | else: 43 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): 44 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') 45 | if changed: 46 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) 47 | 48 | with imgui_utils.grayed_out(not has_noise): 49 | imgui.same_line() 50 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) 51 | imgui.same_line(round(viz.font_size * 27.7)) 52 | with imgui_utils.grayed_out(not self.noise_enable): 53 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4): 54 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) 55 | imgui.same_line(spacing=0) 56 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) 57 | 58 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) 59 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) 60 | with imgui_utils.grayed_out(is_def_trunc and not has_noise): 61 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 62 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): 63 | self.prev_num_ws = num_ws 64 | self.trunc_psi = 1 65 | self.trunc_cutoff = num_ws 66 | self.noise_enable = True 67 | self.noise_seed = 0 68 | self.noise_anim = False 69 | 70 | if self.noise_anim: 71 | self.noise_seed += 1 72 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed) 73 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') 74 | 75 | #---------------------------------------------------------------------------- 76 | -------------------------------------------------------------------------------- /networks/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 
| ''' ConvNet ''' 5 | class ConvNet(nn.Module): 6 | def __init__(self, channel, num_classes, net_width, net_depth, net_act, net_norm, net_pooling='avgpooling', im_size = (32,32)): 7 | super(ConvNet, self).__init__() 8 | 9 | self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size) 10 | num_feat = shape_feat[0]*shape_feat[1]*shape_feat[2] 11 | self.classifier = nn.Linear(num_feat, num_classes) 12 | 13 | def forward(self, x): 14 | # print("MODEL DATA ON: ", x.get_device(), "MODEL PARAMS ON: ", self.classifier.weight.data.get_device()) 15 | out = self.features(x) 16 | out = out.view(out.size(0), -1) 17 | out = self.classifier(out) 18 | return out 19 | 20 | 21 | def forward_cafe(self, x): 22 | f_maps = [] 23 | x = self.features[0:3](x) 24 | f_maps.append(x) 25 | 26 | for i in range(start=3, stop=-1, step=4): 27 | x = self.features[i:i+4](x) 28 | f_maps.append(x) 29 | 30 | out = self.features[-1:](x) 31 | f_maps.append(out) 32 | out = out.view(out.size(0), -1) 33 | out_final = self.classifier(out) 34 | return out_final, out, f_maps 35 | 36 | def embed(self, x): 37 | out = self.features(x) 38 | out = out.view(out.size(0), -1) 39 | return out 40 | 41 | def logit_layer(self, x): 42 | out = self.features(x) 43 | out = out.view(out.size(0), -1) 44 | out = self.classifier(out) 45 | return out 46 | 47 | def _get_activation(self, net_act): 48 | if net_act == 'sigmoid': 49 | return nn.Sigmoid() 50 | elif net_act == 'relu': 51 | return nn.ReLU(inplace=True) 52 | elif net_act == 'leakyrelu': 53 | return nn.LeakyReLU(negative_slope=0.01) 54 | else: 55 | exit('unknown activation function: %s'%net_act) 56 | 57 | def _get_pooling(self, net_pooling): 58 | if net_pooling == 'maxpooling': 59 | return nn.MaxPool2d(kernel_size=2, stride=2) 60 | elif net_pooling == 'avgpooling': 61 | return nn.AvgPool2d(kernel_size=2, stride=2) 62 | elif net_pooling == 'none': 63 | return None 64 | else: 65 | exit('unknown net_pooling: %s'%net_pooling) 66 | 67 | def _get_normlayer(self, net_norm, shape_feat): 68 | # shape_feat = (c*h*w) 69 | if net_norm == 'batchnorm': 70 | return nn.BatchNorm2d(shape_feat[0], affine=True) 71 | elif net_norm == 'layernorm': 72 | return nn.LayerNorm(shape_feat, elementwise_affine=True) 73 | elif net_norm == 'instancenorm': 74 | return nn.GroupNorm(shape_feat[0], shape_feat[0], affine=True) 75 | elif net_norm == 'groupnorm': 76 | return nn.GroupNorm(4, shape_feat[0], affine=True) 77 | elif net_norm == 'none': 78 | return None 79 | else: 80 | exit('unknown net_norm: %s'%net_norm) 81 | 82 | def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size): 83 | layers = [] 84 | in_channels = channel 85 | if im_size[0] == 28: 86 | im_size = (32, 32) 87 | shape_feat = [in_channels, im_size[0], im_size[1]] 88 | for d in range(net_depth): 89 | layers += [nn.Conv2d(in_channels, net_width, kernel_size=3, padding=3 if channel == 1 and d == 0 else 1)] 90 | shape_feat[0] = net_width 91 | if net_norm != 'none': 92 | layers += [self._get_normlayer(net_norm, shape_feat)] 93 | layers += [self._get_activation(net_act)] 94 | in_channels = net_width 95 | if net_pooling != 'none': 96 | layers += [self._get_pooling(net_pooling)] 97 | shape_feat[1] //= 2 98 | shape_feat[2] //= 2 99 | 100 | 101 | return nn.Sequential(*layers), shape_feat -------------------------------------------------------------------------------- /stylegan_xl/viz/latent_widget.py: -------------------------------------------------------------------------------- 1 | 
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class LatentWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25) 20 | self.latent_def = dnnlib.EasyDict(self.latent) 21 | self.step_y = 100 22 | 23 | def drag(self, dx, dy): 24 | viz = self.viz 25 | self.latent.x += dx / viz.font_size * 4e-2 26 | self.latent.y += dy / viz.font_size * 4e-2 27 | 28 | @imgui_utils.scoped_by_object_id 29 | def __call__(self, show=True): 30 | viz = self.viz 31 | if show: 32 | imgui.text('Latent') 33 | imgui.same_line(viz.label_w) 34 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y 35 | with imgui_utils.item_width(viz.font_size * 8): 36 | changed, seed = imgui.input_int('##seed', seed) 37 | if changed: 38 | self.latent.x = seed 39 | self.latent.y = 0 40 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 41 | frac_x = self.latent.x - round(self.latent.x) 42 | frac_y = self.latent.y - round(self.latent.y) 43 | with imgui_utils.item_width(viz.font_size * 5): 44 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 45 | if changed: 46 | self.latent.x += new_frac_x - frac_x 47 | self.latent.y += new_frac_y - frac_y 48 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 49 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 50 | if dragging: 51 | self.drag(dx, dy) 52 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 53 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) 54 | imgui.same_line(round(viz.font_size * 27.7)) 55 | with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): 56 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) 57 | if changed: 58 | self.latent.speed = speed 59 | imgui.same_line() 60 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) 61 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): 62 | self.latent = snapped 63 | imgui.same_line() 64 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): 65 | self.latent = dnnlib.EasyDict(self.latent_def) 66 | 67 | if self.latent.anim: 68 | self.latent.x += viz.frame_delta * self.latent.speed 69 | viz.args.w0_seeds = [] # [[seed, weight], ...] 
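        # The fractional latent position (self.latent.x, self.latent.y) is mapped to the
        # four surrounding integer grid points; each becomes a seed (x + y * step_y,
        # wrapped to 32 bits) with a bilinear weight, so dragging or animating the latent
        # blends smoothly between neighbouring seeds instead of jumping. For example,
        # x = 3.25, y = 0.0 yields seeds 3 and 4 with weights 0.75 and 0.25.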
70 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: 71 | seed_x = np.floor(self.latent.x) + ofs_x 72 | seed_y = np.floor(self.latent.y) + ofs_y 73 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) 74 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) 75 | if weight > 0: 76 | viz.args.w0_seeds.append([seed, weight]) 77 | 78 | #---------------------------------------------------------------------------- 79 | -------------------------------------------------------------------------------- /stylegan_xl/gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import imgui 11 | import imgui.integrations.glfw 12 | 13 | from . import glfw_window 14 | from . import imgui_utils 15 | from . import text_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class ImguiWindow(glfw_window.GlfwWindow): 20 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 21 | if font is None: 22 | font = text_utils.get_default_font() 23 | font_sizes = {int(size) for size in font_sizes} 24 | super().__init__(title=title, **glfw_kwargs) 25 | 26 | # Init fields. 27 | self._imgui_context = None 28 | self._imgui_renderer = None 29 | self._imgui_fonts = None 30 | self._cur_font_size = max(font_sizes) 31 | 32 | # Delete leftover imgui.ini to avoid unexpected behavior. 33 | if os.path.isfile('imgui.ini'): 34 | os.remove('imgui.ini') 35 | 36 | # Init ImGui. 37 | self._imgui_context = imgui.create_context() 38 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 39 | self._attach_glfw_callbacks() 40 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 42 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 43 | self._imgui_renderer.refresh_font_texture() 44 | 45 | def close(self): 46 | self.make_context_current() 47 | self._imgui_fonts = None 48 | if self._imgui_renderer is not None: 49 | self._imgui_renderer.shutdown() 50 | self._imgui_renderer = None 51 | if self._imgui_context is not None: 52 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 53 | self._imgui_context = None 54 | super().close() 55 | 56 | def _glfw_key_callback(self, *args): 57 | super()._glfw_key_callback(*args) 58 | self._imgui_renderer.keyboard_callback(*args) 59 | 60 | @property 61 | def font_size(self): 62 | return self._cur_font_size 63 | 64 | @property 65 | def spacing(self): 66 | return round(self._cur_font_size * 0.4) 67 | 68 | def set_font_size(self, target): # Applied on next frame. 69 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 70 | 71 | def begin_frame(self): 72 | # Begin glfw frame. 73 | super().begin_frame() 74 | 75 | # Process imgui events. 
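        # Scroll speed is tied to the current font size so mouse-wheel behaviour stays
        # consistent when the UI is scaled; the multiplier set here is consumed by the
        # _GlfwRenderer.scroll_callback() wrapper at the bottom of this file.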
76 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 77 | if self.content_width > 0 and self.content_height > 0: 78 | self._imgui_renderer.process_inputs() 79 | 80 | # Begin imgui frame. 81 | imgui.new_frame() 82 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 83 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 84 | 85 | def end_frame(self): 86 | imgui.pop_font() 87 | imgui.render() 88 | imgui.end_frame() 89 | self._imgui_renderer.render(imgui.get_draw_data()) 90 | super().end_frame() 91 | 92 | #---------------------------------------------------------------------------- 93 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 94 | 95 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.mouse_wheel_multiplier = 1 99 | 100 | def scroll_callback(self, window, x_offset, y_offset): 101 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 102 | 103 | #---------------------------------------------------------------------------- 104 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 
49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /networks/vit_cifar.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | # classes 15 | 16 | class PreNorm(nn.Module): 17 | def __init__(self, dim, fn): 18 | super().__init__() 19 | self.norm = nn.LayerNorm(dim) 20 | self.fn = fn 21 | def forward(self, x, **kwargs): 22 | return self.fn(self.norm(x), **kwargs) 23 | 24 | class FeedForward(nn.Module): 25 | def __init__(self, dim, hidden_dim, dropout = 0.): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Linear(dim, hidden_dim), 29 | nn.GELU(), 30 | nn.Dropout(dropout), 31 | nn.Linear(hidden_dim, dim), 32 | nn.Dropout(dropout) 33 | ) 34 | def forward(self, x): 35 | return self.net(x) 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | project_out = not (heads == 1 and dim_head == dim) 42 | 43 | self.heads = heads 44 | self.scale = dim_head ** -0.5 45 | 46 | self.attend = nn.Softmax(dim = -1) 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | 49 | self.to_out = nn.Sequential( 50 | nn.Linear(inner_dim, dim), 51 | nn.Dropout(dropout) 52 | ) if project_out else nn.Identity() 53 | 54 | def forward(self, x): 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 
self.heads), qkv) 57 | 58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 59 | 60 | attn = self.attend(dots) 61 | 62 | out = torch.matmul(attn, v) 63 | out = rearrange(out, 'b h n d -> b n (h d)') 64 | return self.to_out(out) 65 | 66 | class Transformer(nn.Module): 67 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 68 | super().__init__() 69 | self.layers = nn.ModuleList([]) 70 | for _ in range(depth): 71 | self.layers.append(nn.ModuleList([ 72 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 73 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 74 | ])) 75 | def forward(self, x): 76 | for attn, ff in self.layers: 77 | x = attn(x) + x 78 | x = ff(x) + x 79 | return x 80 | 81 | class ViTCIFAR(nn.Module): 82 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 83 | super().__init__() 84 | image_height, image_width = pair(image_size) 85 | patch_height, patch_width = pair(patch_size) 86 | 87 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 88 | 89 | num_patches = (image_height // patch_height) * (image_width // patch_width) 90 | patch_dim = channels * patch_height * patch_width 91 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 92 | 93 | self.to_patch_embedding = nn.Sequential( 94 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 95 | nn.Linear(patch_dim, dim), 96 | ) 97 | 98 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 99 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 100 | self.dropout = nn.Dropout(emb_dropout) 101 | 102 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 103 | 104 | self.pool = pool 105 | self.to_latent = nn.Identity() 106 | 107 | self.mlp_head = nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, num_classes) 110 | ) 111 | 112 | def forward(self, img): 113 | x = self.to_patch_embedding(img) 114 | b, n, _ = x.shape 115 | 116 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 117 | x = torch.cat((cls_tokens, x), dim=1) 118 | x += self.pos_embedding[:, :(n + 1)] 119 | x = self.dropout(x) 120 | 121 | x = self.transformer(x) 122 | 123 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 124 | 125 | x = self.to_latent(x) 126 | return self.mlp_head(x) -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
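// bias_act() below validates that every optional tensor (b, xref, yref, dy) matches x
// in dtype, device and layout, fills a bias_act_kernel_params struct, dispatches on the
// floating-point dtype to select a specialised CUDA kernel, and launches it on the
// current stream. The TORCH_CHECK / OptionalCUDAGuard / getCurrentCUDAStream calls it
// uses are typically provided by <torch/extension.h>, <ATen/cuda/CUDAContext.h> and
// <c10/cuda/CUDAGuard.h> (stated as an assumption; the include targets are not visible
// in the lines above).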
8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 
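    // Grid sizing: each block of 128 threads covers loopX * blockSize = 512 elements
    // (the kernel presumably iterates loopX times per thread), so gridSize works out to
    // ceil(sizeX / 512).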
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /networks/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | # classes 15 | 16 | class PreNorm(nn.Module): 17 | def __init__(self, dim, fn): 18 | super().__init__() 19 | # self.norm = nn.LayerNorm(dim) 20 | self.fn = fn 21 | def forward(self, x, **kwargs): 22 | # return self.fn(self.norm(x), **kwargs) 23 | return self.fn(x, **kwargs) 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim, dropout = 0.): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.Linear(dim, hidden_dim), 30 | nn.GELU(), 31 | nn.Dropout(dropout), 32 | nn.Linear(hidden_dim, dim), 33 | nn.Dropout(dropout) 34 | ) 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 40 | super().__init__() 41 | inner_dim = dim_head * heads 42 | project_out = not (heads == 1 and dim_head == dim) 43 | 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | 47 | self.attend = nn.Softmax(dim = -1) 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | 50 | self.to_out = nn.Sequential( 51 | nn.Linear(inner_dim, dim), 52 | nn.Dropout(dropout) 53 | ) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 58 | 59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 60 | 61 | attn = self.attend(dots) 62 | 63 | out = torch.matmul(attn, v) 64 | out = rearrange(out, 'b h n d -> b n (h d)') 65 | return self.to_out(out) 66 | 67 | class Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([ 73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | ])) 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | return x 81 | 82 | class ViT(nn.Module): 83 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 84 | super().__init__() 85 | image_height, image_width = pair(image_size) 86 | patch_height, patch_width = pair(patch_size) 87 | 88 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch 
size.' 89 | 90 | num_patches = (image_height // patch_height) * (image_width // patch_width) 91 | patch_dim = channels * patch_height * patch_width 92 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 93 | 94 | self.to_patch_embedding = nn.Sequential( 95 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 96 | nn.Linear(patch_dim, dim), 97 | ) 98 | 99 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 100 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 101 | self.dropout = nn.Dropout(emb_dropout) 102 | 103 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 104 | 105 | self.pool = pool 106 | self.to_latent = nn.Identity() 107 | 108 | self.mlp_head = nn.Sequential( 109 | # nn.LayerNorm(dim), 110 | nn.Linear(dim, num_classes) 111 | ) 112 | 113 | def forward(self, img): 114 | x = self.to_patch_embedding(img) 115 | b, n, _ = x.shape 116 | 117 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 118 | x = torch.cat((cls_tokens, x), dim=1) 119 | x += self.pos_embedding[:, :(n + 1)] 120 | x = self.dropout(x) 121 | 122 | x = self.transformer(x) 123 | 124 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 125 | 126 | x = self.to_latent(x) 127 | 128 | feat_fc = x 129 | 130 | x = self.mlp_head(x) 131 | 132 | return x -------------------------------------------------------------------------------- /baseline_methods.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import logging 5 | from utils import epoch_global, get_time 6 | 7 | 8 | def fedavg(global_net, clients, round, testloader, args, total_data_points=None): 9 | logging.info("="*50 + " Round: " + str(round) + " " + "="*50) 10 | for each_worker in range(args.nworkers): 11 | clients[each_worker].train_net(local_round=args.Iteration, batch_size=args.batch_train, device=args.device, args=args) 12 | loss_avg, acc_avg = epoch_global('test', testloader, clients[each_worker].model, clients[each_worker].optimizer, 13 | clients[each_worker].criterion, args, aug=False) 14 | clients[each_worker].model = clients[each_worker].model.cpu() 15 | logging.info('%s Evaluate_%02d_%02d: val loss = %.6f val acc = %.4f' % (get_time(), round, each_worker, loss_avg, acc_avg)) 16 | 17 | fed_avg_freqs = [len(clients[each_worker].each_worker_data) / total_data_points for each_worker in range(args.nworkers)] 18 | global_para = global_net.cpu().state_dict() 19 | 20 | for each_worker in range(args.nworkers): 21 | net_para = clients[each_worker].model.cpu().state_dict() 22 | if each_worker == 0: 23 | for key in net_para: 24 | global_para[key] = net_para[key] * fed_avg_freqs[each_worker] 25 | else: 26 | for key in net_para: 27 | global_para[key] += net_para[key] * fed_avg_freqs[each_worker] 28 | global_net.load_state_dict(global_para) 29 | global_net.to(args.device) 30 | return global_net 31 | 32 | 33 | def fedprox(global_net, clients, round, testloader, args, total_data_points=None): 34 | logging.info("="*50 + " Round: " + str(round) + " " + "="*50) 35 | for each_worker in range(args.nworkers): 36 | clients[each_worker].train_net_fedprox(local_round=args.Iteration, batch_size=args.batch_train, device=args.device, args=args, global_net=global_net) 37 | loss_avg, acc_avg = epoch_global('test', testloader, clients[each_worker].model, clients[each_worker].optimizer, 38 | clients[each_worker].criterion, args, aug=False) 39 | 
clients[each_worker].model = clients[each_worker].model.cpu() 40 | logging.info('%s Evaluate_%02d_%02d: val loss = %.6f val acc = %.4f' % (get_time(), round, each_worker, loss_avg, acc_avg)) 41 | 42 | fed_avg_freqs = [len(clients[each_worker].each_worker_data) / total_data_points for each_worker in range(args.nworkers)] 43 | global_para = global_net.cpu().state_dict() 44 | 45 | for each_worker in range(args.nworkers): 46 | net_para = clients[each_worker].model.cpu().state_dict() 47 | if each_worker == 0: 48 | for key in net_para: 49 | global_para[key] = net_para[key] * fed_avg_freqs[each_worker] 50 | else: 51 | for key in net_para: 52 | global_para[key] += net_para[key] * fed_avg_freqs[each_worker] 53 | global_net.load_state_dict(global_para) 54 | global_net.to(args.device) 55 | return global_net 56 | 57 | 58 | def fednova(global_net, clients, round, testloader, args, total_data_points=None): 59 | logging.info("=" * 50 + " Round: " + str(round) + " " + "=" * 50) 60 | a_list = [] 61 | d_list = [] 62 | n_list = [] 63 | for each_worker in range(args.nworkers): 64 | train_loss, a_i, d_i, n_i = clients[each_worker].train_net_fednova(local_round=args.Iteration, 65 | batch_size=args.batch_train, device=args.device, args=args, global_net=global_net) 66 | a_list.append(a_i) 67 | d_list.append(d_i) 68 | # n_list.append(n_i) 69 | loss_avg, acc_avg = epoch_global('test', testloader, clients[each_worker].model, clients[each_worker].optimizer, 70 | clients[each_worker].criterion, args, aug=False) 71 | clients[each_worker].model = clients[each_worker].model.cpu() 72 | logging.info('%s Evaluate_%02d_%02d: val loss = %.6f val acc = %.4f' % (get_time(), round, each_worker, loss_avg, acc_avg)) 73 | 74 | fed_avg_freqs = [len(clients[each_worker].each_worker_data) / total_data_points for each_worker in range(args.nworkers)] 75 | global_para = global_net.state_dict() 76 | 77 | d_total_round = copy.deepcopy(global_net.state_dict()) 78 | for key in d_total_round: 79 | d_total_round[key] = 0.0 80 | for each_worker in range(args.nworkers): 81 | for key in d_list[each_worker]: 82 | d_total_round[key] += d_list[each_worker][key] * fed_avg_freqs[each_worker] 83 | 84 | coeff = 0.0 85 | for each_worker in range(args.nworkers): 86 | coeff += a_list[each_worker] * fed_avg_freqs[each_worker] 87 | 88 | for key in global_para: 89 | if global_para[key].type() == 'torch.LongTensor': 90 | global_para[key] -= (coeff * d_total_round[key]).type(torch.LongTensor) 91 | elif global_para[key].type() == 'torch.cuda.LongTensor': 92 | global_para[key] -= (coeff * d_total_round[key]).type(torch.cuda.LongTensor) 93 | else: 94 | global_para[key] -= coeff * d_total_round[key] 95 | global_net.load_state_dict(global_para) 96 | global_net.to(args.device) 97 | return global_net 98 | -------------------------------------------------------------------------------- /stylegan_xl/feature_networks/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in 
token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr<float>(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ?
1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /stylegan_xl/gen_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | from torch_utils import gen_utils 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def parse_range(s: Union[str, List]) -> List[int]: 27 | '''Parse a comma separated list of numbers or ranges and return a list of ints.
28 | 29 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 30 | ''' 31 | if isinstance(s, list): return s 32 | ranges = [] 33 | range_re = re.compile(r'^(\d+)-(\d+)$') 34 | for p in s.split(','): 35 | m = range_re.match(p) 36 | if m: 37 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 38 | else: 39 | ranges.append(int(p)) 40 | return ranges 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: 45 | '''Parse a floating point 2-vector of syntax 'a,b'. 46 | 47 | Example: 48 | '0,1' returns (0,1) 49 | ''' 50 | if isinstance(s, tuple): return s 51 | parts = s.split(',') 52 | if len(parts) == 2: 53 | return (float(parts[0]), float(parts[1])) 54 | raise ValueError(f'cannot parse 2-vector {s}') 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | def make_transform(translate: Tuple[float,float], angle: float): 59 | m = np.eye(3) 60 | s = np.sin(angle/360.0*np.pi*2) 61 | c = np.cos(angle/360.0*np.pi*2) 62 | m[0][0] = c 63 | m[0][1] = s 64 | m[0][2] = translate[0] 65 | m[1][0] = -s 66 | m[1][1] = c 67 | m[1][2] = translate[1] 68 | return m 69 | 70 | #---------------------------------------------------------------------------- 71 | 72 | @click.command() 73 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 74 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 75 | @click.option('--batch-sz', type=int, help='Batch size per sample', default=1) 76 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 77 | @click.option('--centroids-path', type=str, help='Pass path to precomputed centroids to enable multimodal truncation') 78 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 79 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 80 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 81 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 82 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 83 | def generate_images( 84 | network_pkl: str, 85 | seeds: List[int], 86 | batch_sz: int, 87 | truncation_psi: float, 88 | centroids_path: str, 89 | noise_mode: str, 90 | outdir: str, 91 | translate: Tuple[float,float], 92 | rotate: float, 93 | class_idx: Optional[int] 94 | ): 95 | print('Loading networks from "%s"...' % network_pkl) 96 | device = torch.device('cuda') 97 | with dnnlib.util.open_url(network_pkl) as f: 98 | G = legacy.load_network_pkl(f)['G_ema'] 99 | G = G.eval().requires_grad_(False).to(device) 100 | 101 | os.makedirs(outdir, exist_ok=True) 102 | 103 | # Generate images. 104 | for seed_idx, seed in enumerate(seeds): 105 | print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds))) 106 | 107 | # Construct an inverse rotation/translation matrix and pass to the generator. The 108 | # generator expects this matrix as an inverse to avoid potentially failing numerical 109 | # operations in the network. 
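A minimal illustration of the step described in the comment above, reusing the make_transform helper defined earlier in this file; the translate/rotate values are hypothetical:

    import numpy as np
    m = make_transform(translate=(0.3, 1.0), angle=15.0)   # 3x3 rotation/translation matrix in image space
    m_inv = np.linalg.inv(m)                                # the generator consumes the inverse, as noted above
    # the loop below copies m_inv into G.synthesis.input.transform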
110 | if hasattr(G.synthesis, 'input'): 111 | m = make_transform(translate, rotate) 112 | m = np.linalg.inv(m) 113 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 114 | 115 | w = gen_utils.get_w_from_seed(G, batch_sz, device, truncation_psi, seed=seed, 116 | centroids_path=centroids_path, class_idx=class_idx) 117 | img = gen_utils.w_to_img(G, w, to_np=True) 118 | PIL.Image.fromarray(gen_utils.create_image_grid(img), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 119 | 120 | 121 | #---------------------------------------------------------------------------- 122 | 123 | if __name__ == "__main__": 124 | generate_images() # pylint: disable=no-value-for-parameter 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /stylegan_xl/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 
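As an aside, the branch below always evaluates the generator at two nearby interpolation points, t and t + epsilon. A minimal sketch of the Z-space case, reusing the slerp helper defined at the top of this file (batch size, latent width, and epsilon are hypothetical; 1e-4 is the value commonly used for PPL):

    import torch
    a, b = torch.randn(8, 512), torch.randn(8, 512)   # a hypothetical batch of latent pairs
    t = torch.rand(8, 1)
    z_t0 = slerp(a, b, t)                             # first endpoint of the finite difference
    z_t1 = slerp(a, b, t + 1e-4)                      # epsilon-shifted endpoint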
54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 
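    # Note: each entry of dist accumulated above is the squared finite-difference estimate
    #   d = || lpips(G(w_t)) - lpips(G(w_{t+eps})) ||^2 / eps^2,
    # so the value returned below is the mean of d over samples inside the 1st-99th percentile
    # band; points outside that band are treated as outliers and dropped.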
117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /networks/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | ''' ResNet ''' 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 10 | super(BasicBlock, self).__init__() 11 | self.norm = norm 12 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 16 | 17 | self.shortcut = nn.Sequential() 18 | if stride != 1 or in_planes != self.expansion*planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 21 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 22 | ) 23 | 24 | def forward(self, x): 25 | out = F.relu(self.bn1(self.conv1(x))) 26 | out = self.bn2(self.conv2(out)) 27 | out += self.shortcut(x) 28 | out = F.relu(out) 29 | return out 30 | 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, in_planes, planes, stride=1, norm='instancenorm'): 36 | super(Bottleneck, self).__init__() 37 | self.norm = norm 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn1 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn2 = nn.GroupNorm(planes, planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 43 | self.bn3 = nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 44 | 45 | self.shortcut = nn.Sequential() 46 | if stride != 1 or in_planes != self.expansion*planes: 47 | self.shortcut = nn.Sequential( 48 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 49 | nn.GroupNorm(self.expansion*planes, self.expansion*planes, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(self.expansion*planes) 50 | ) 51 | 52 | def forward(self, x): 53 | out = F.relu(self.bn1(self.conv1(x))) 54 | out = F.relu(self.bn2(self.conv2(out))) 55 | out = self.bn3(self.conv3(out)) 56 | out += self.shortcut(x) 57 | out = F.relu(out) 58 | return out 59 | 60 | 61 | class ResNetCIFAR(nn.Module): 62 | def __init__(self, block, num_blocks, channel=3, num_classes=10, norm='instancenorm'): 63 | super(ResNetCIFAR, self).__init__() 64 | self.in_planes = 64 
65 | self.norm = norm 66 | 67 | self.conv1 = nn.Conv2d(channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 68 | self.bn1 = nn.GroupNorm(64, 64, affine=True) if self.norm == 'instancenorm' else nn.BatchNorm2d(64) 69 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 70 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 71 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 72 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 73 | self.classifier = nn.Linear(512*block.expansion, num_classes) 74 | 75 | def _make_layer(self, block, planes, num_blocks, stride): 76 | strides = [stride] + [1]*(num_blocks-1) 77 | layers = [] 78 | for stride in strides: 79 | layers.append(block(self.in_planes, planes, stride, self.norm)) 80 | self.in_planes = planes * block.expansion 81 | return nn.Sequential(*layers) 82 | 83 | def forward(self, x): 84 | out = F.relu(self.bn1(self.conv1(x))) 85 | out = self.layer1(out) 86 | out = self.layer2(out) 87 | out = self.layer3(out) 88 | out = self.layer4(out) 89 | out = F.avg_pool2d(out, 4) 90 | out = out.view(out.size(0), -1) 91 | out = self.classifier(out) 92 | return out 93 | 94 | def embed(self, x): 95 | out = F.relu(self.bn1(self.conv1(x))) 96 | out = self.layer1(out) 97 | out = self.layer2(out) 98 | out = self.layer3(out) 99 | out = self.layer4(out) 100 | out = F.avg_pool2d(out, 4) 101 | out = out.view(out.size(0), -1) 102 | return out 103 | 104 | 105 | def ResNet18BNCIFAR(channel, num_classes): 106 | return ResNetCIFAR(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes, norm='batchnorm') 107 | 108 | def ResNet18CIFAR(channel, num_classes): 109 | return ResNetCIFAR(BasicBlock, [2,2,2,2], channel=channel, num_classes=num_classes) 110 | 111 | def ResNet34CIFAR(channel, num_classes): 112 | return ResNetCIFAR(BasicBlock, [3,4,6,3], channel=channel, num_classes=num_classes) 113 | 114 | def ResNet50CIFAR(channel, num_classes): 115 | return ResNetCIFAR(Bottleneck, [3,4,6,3], channel=channel, num_classes=num_classes) 116 | 117 | def ResNet101CIFAR(channel, num_classes): 118 | return ResNetCIFAR(Bottleneck, [3,4,23,3], channel=channel, num_classes=num_classes) 119 | 120 | def ResNet152CIFAR(channel, num_classes): 121 | return ResNetCIFAR(Bottleneck, [3,8,36,3], channel=channel, num_classes=num_classes) 122 | -------------------------------------------------------------------------------- /stylegan_xl/gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import functools 10 | from typing import Optional 11 | 12 | import dnnlib 13 | import numpy as np 14 | import PIL.Image 15 | import PIL.ImageFont 16 | import scipy.ndimage 17 | 18 | from . 
import gl_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def get_default_font(): 23 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 24 | return dnnlib.util.open_url(url, return_filename=True) 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | @functools.lru_cache(maxsize=None) 29 | def get_pil_font(font=None, size=32): 30 | if font is None: 31 | font = get_default_font() 32 | return PIL.ImageFont.truetype(font=font, size=size) 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 37 | if dropshadow_radius is not None: 38 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 39 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 40 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 41 | else: 42 | return _get_array_priv(string, **kwargs) 43 | 44 | @functools.lru_cache(maxsize=10000) 45 | def _get_array_priv( 46 | string: str, *, 47 | size: int = 32, 48 | max_width: Optional[int]=None, 49 | max_height: Optional[int]=None, 50 | min_size=10, 51 | shrink_coef=0.8, 52 | dropshadow_radius: int=None, 53 | offset_x: int=None, 54 | offset_y: int=None, 55 | **kwargs 56 | ): 57 | cur_size = size 58 | array = None 59 | while True: 60 | if dropshadow_radius is not None: 61 | # separate implementation for dropshadow text rendering 62 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 63 | else: 64 | array = _get_array_impl(string, size=cur_size, **kwargs) 65 | height, width, _ = array.shape 66 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 67 | break 68 | cur_size = max(int(cur_size * shrink_coef), min_size) 69 | return array 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | @functools.lru_cache(maxsize=10000) 74 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 75 | pil_font = get_pil_font(font=font, size=size) 76 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 77 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 78 | width = max(line.shape[1] for line in lines) 79 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 80 | line_spacing = line_pad if line_pad is not None else size // 2 81 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 82 | mask = np.concatenate(lines, axis=0) 83 | alpha = mask 84 | if outline > 0: 85 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 86 | alpha = mask.astype(np.float32) / 255 87 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 88 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 89 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 90 | alpha = np.maximum(alpha, mask) 91 | return np.stack([mask, alpha], axis=-1) 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | @functools.lru_cache(maxsize=10000) 96 | def _get_array_impl_dropshadow(string, *, font=None, size=32, 
radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 97 | assert (offset_x > 0) and (offset_y > 0) 98 | pil_font = get_pil_font(font=font, size=size) 99 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 100 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 101 | width = max(line.shape[1] for line in lines) 102 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 103 | line_spacing = line_pad if line_pad is not None else size // 2 104 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 105 | mask = np.concatenate(lines, axis=0) 106 | alpha = mask 107 | 108 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 109 | alpha = mask.astype(np.float32) / 255 110 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 111 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 112 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 113 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 114 | alpha = np.maximum(alpha, mask) 115 | return np.stack([mask, alpha], axis=-1) 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | @functools.lru_cache(maxsize=10000) 120 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 121 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 122 | 123 | #---------------------------------------------------------------------------- 124 | -------------------------------------------------------------------------------- /stylegan_xl/viz/equivariance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class EquivarianceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2) 20 | self.xlate_def = dnnlib.EasyDict(self.xlate) 21 | self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3) 22 | self.rotate_def = dnnlib.EasyDict(self.rotate) 23 | self.opts = dnnlib.EasyDict(untransform=False) 24 | self.opts_def = dnnlib.EasyDict(self.opts) 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | if show: 30 | imgui.text('Translate') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8): 33 | _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f') 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w) 36 | if dragging: 37 | self.xlate.x += dx / viz.font_size * 2e-2 38 | self.xlate.y += dy / viz.font_size * 2e-2 39 | imgui.same_line() 40 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w) 41 | if dragging: 42 | self.xlate.x += dx / viz.font_size * 4e-4 43 | self.xlate.y += dy / viz.font_size * 4e-4 44 | imgui.same_line() 45 | _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim) 46 | imgui.same_line() 47 | _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round) 48 | imgui.same_line() 49 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim): 50 | changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5) 51 | if changed: 52 | self.xlate.speed = speed 53 | imgui.same_line() 54 | if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)): 55 | self.xlate = dnnlib.EasyDict(self.xlate_def) 56 | 57 | if show: 58 | imgui.text('Rotate') 59 | imgui.same_line(viz.label_w) 60 | with imgui_utils.item_width(viz.font_size * 8): 61 | _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f') 62 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 63 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w) 64 | if dragging: 65 | self.rotate.val += dx / viz.font_size * 2e-2 66 | imgui.same_line() 67 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w) 68 | if dragging: 69 | self.rotate.val += dx / viz.font_size * 4e-4 70 | imgui.same_line() 71 | _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim) 72 | imgui.same_line() 73 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.rotate.anim): 74 | changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3) 75 | if changed: 76 | self.rotate.speed = speed 77 | imgui.same_line() 78 | if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)): 79 | self.rotate = dnnlib.EasyDict(self.rotate_def) 80 | 81 | if show: 82 | imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16) 83 | _clicked, self.opts.untransform = 
imgui.checkbox('Untransform', self.opts.untransform) 84 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 85 | if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)): 86 | self.opts = dnnlib.EasyDict(self.opts_def) 87 | 88 | if self.xlate.anim: 89 | c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 90 | t = c.copy() 91 | if np.max(np.abs(t)) < 1e-4: 92 | t += 1 93 | t *= 0.1 / np.hypot(*t) 94 | t += c[::-1] * [1, -1] 95 | d = t - c 96 | d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d) 97 | self.xlate.x += d[0] 98 | self.xlate.y += d[1] 99 | 100 | if self.rotate.anim: 101 | self.rotate.val += viz.frame_delta * self.rotate.speed 102 | 103 | pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 104 | if self.xlate.round and 'img_resolution' in viz.result: 105 | pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution 106 | angle = self.rotate.val * np.pi * 2 107 | 108 | viz.args.input_transform = [ 109 | [np.cos(angle), np.sin(angle), pos[0]], 110 | [-np.sin(angle), np.cos(angle), pos[1]], 111 | [0, 0, 1]] 112 | 113 | viz.args.update(untransform=self.opts.untransform) 114 | 115 | #---------------------------------------------------------------------------- 116 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/utils_spectrum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fftn 3 | 4 | 5 | def roll_quadrants(data, backwards=False): 6 | """ 7 | Shift low frequencies to the center of fourier transform, i.e. [-N/2, ..., +N/2] -> [0, ..., N-1] 8 | Args: 9 | data: fourier transform, (NxHxW) 10 | backwards: bool, if True shift high frequencies back to center 11 | 12 | Returns: 13 | Shifted fourier transform. 14 | """ 15 | dim = data.ndim - 1 16 | 17 | if dim != 2: 18 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 19 | if any(s % 2 == 0 for s in data.shape[1:]): 20 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.') 21 | 22 | # for each dimension swap left and right half 23 | dims = tuple(range(1, dim+1)) # add one for batch dimension 24 | shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd 25 | if backwards: 26 | shifts *= -1 27 | return data.roll(shifts.tolist(), dims=dims) 28 | 29 | 30 | def batch_fft(data, normalize=False): 31 | """ 32 | Compute fourier transform of batch. 33 | Args: 34 | data: input tensor, (NxHxW) 35 | 36 | Returns: 37 | Batch fourier transform of input data. 38 | """ 39 | 40 | dim = data.ndim - 1 # subtract one for batch dimension 41 | if dim != 2: 42 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 43 | 44 | dims = tuple(range(1, dim + 1)) # add one for batch dimension 45 | if normalize: 46 | norm = 'ortho' 47 | else: 48 | norm = 'backward' 49 | 50 | if not torch.is_complex(data): 51 | data = torch.complex(data, torch.zeros_like(data)) 52 | freq = fftn(data, dim=dims, norm=norm) 53 | 54 | return freq 55 | 56 | 57 | def azimuthal_average(image, center=None): 58 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/ 59 | """ 60 | Calculate the azimuthally averaged radial profile. 61 | Requires low frequencies to be at the center of the image. 62 | Args: 63 | image: Batch of 2D images, NxHxW 64 | center: The [x,y] pixel coordinates used as the center. 
The default is 65 | None, which then uses the center of the image (including 66 | fracitonal pixels). 67 | 68 | Returns: 69 | Azimuthal average over the image around the center 70 | """ 71 | # Check input shapes 72 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \ 73 | f'(but it is len(center)={len(center)}.' 74 | # Calculate the indices from the image 75 | H, W = image.shape[-2:] 76 | h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 77 | 78 | if center is None: 79 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0]) 80 | 81 | # Compute radius for each pixel wrt center 82 | r = torch.stack([w-center[0], h-center[1]]).norm(2, 0) 83 | 84 | # Get sorted radii 85 | r_sorted, ind = r.flatten().sort() 86 | i_sorted = image.flatten(-2, -1)[..., ind] 87 | 88 | # Get the integer part of the radii (bin size = 1) 89 | r_int = r_sorted.long() # attribute to the smaller integer 90 | 91 | # Find all pixels that fall within each radial bin. 92 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii 93 | rind = torch.where(deltar)[0] # location of changed radius 94 | 95 | # compute number of elements in each bin 96 | nind = rind + 1 # number of elements = idx + 1 97 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders 98 | nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius 99 | 100 | # Cumulative sum to figure out sums for each radius bin 101 | if H % 2 == 0: 102 | raise NotImplementedError('Not sure if implementation correct, please check') 103 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders 104 | else: 105 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders 106 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius 107 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]] 108 | # add mean 109 | tbin = torch.cat([csim[:, 0:1], tbin], 1) 110 | 111 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins 112 | 113 | return radial_prof 114 | 115 | 116 | def get_spectrum(data, normalize=False): 117 | dim = data.ndim - 1 # subtract one for batch dimension 118 | if dim != 2: 119 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 120 | 121 | freq = batch_fft(data, normalize=normalize) 122 | power_spec = freq.real ** 2 + freq.imag ** 2 123 | N = data.shape[1] 124 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum 125 | # and is not averaged with the mean value 126 | N_2 = N//2 127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1) 128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2) 129 | 130 | power_spec = roll_quadrants(power_spec) 131 | power_spec = azimuthal_average(power_spec) 132 | return power_spec 133 | 134 | 135 | def plot_std(mean, std, x=None, ax=None, **kwargs): 136 | import matplotlib.pyplot as plt 137 | if ax is None: 138 | fig, ax = plt.subplots(1) 139 | 140 | # plot error margins in same color as line 141 | err_kwargs = { 142 | 'alpha': 0.3 143 | } 144 | 145 | if 'c' in kwargs.keys(): 146 | err_kwargs['color'] = kwargs['c'] 147 | elif 'color' in kwargs.keys(): 148 | err_kwargs['color'] = kwargs['color'] 149 | 150 | if x is None: 151 | x = torch.linspace(0, 1, len(mean)) # use normalized x axis 152 | ax.plot(x, mean, 
**kwargs) 153 | ax.fill_between(x, mean-std, mean+std, **err_kwargs) 154 | 155 | return ax 156 | -------------------------------------------------------------------------------- /stylegan_xl/pg_modules/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from feature_networks.vit import forward_vit 5 | from feature_networks.pretrained_builder import _make_pretrained 6 | from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS 7 | from pg_modules.blocks import FeatureFusionBlock 8 | 9 | def get_backbone_normstats(backbone): 10 | if backbone in NORMALIZED_INCEPTION: 11 | return { 12 | 'mean': [0.5, 0.5, 0.5], 13 | 'std': [0.5, 0.5, 0.5], 14 | } 15 | 16 | elif backbone in NORMALIZED_IMAGENET: 17 | return { 18 | 'mean': [0.485, 0.456, 0.406], 19 | 'std': [0.229, 0.224, 0.225], 20 | } 21 | 22 | elif backbone in NORMALIZED_CLIP: 23 | return { 24 | 'mean': [0.48145466, 0.4578275, 0.40821073], 25 | 'std': [0.26862954, 0.26130258, 0.27577711], 26 | } 27 | 28 | else: 29 | raise NotImplementedError 30 | 31 | def _make_scratch_ccm(scratch, in_channels, cout, expand=False): 32 | # shapes 33 | out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4 34 | 35 | scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True) 36 | scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True) 37 | scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True) 38 | scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True) 39 | 40 | scratch.CHANNELS = out_channels 41 | 42 | return scratch 43 | 44 | def _make_scratch_csm(scratch, in_channels, cout, expand): 45 | scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True) 46 | scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand) 47 | scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand) 48 | scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False)) 49 | 50 | # last refinenet does not expand to save channels in higher dimensions 51 | scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4 52 | 53 | return scratch 54 | 55 | def _make_projector(im_res, backbone, cout, proj_type, expand=False): 56 | assert proj_type in [0, 1, 2], "Invalid projection type" 57 | 58 | ### Build pretrained feature network 59 | pretrained = _make_pretrained(backbone) 60 | 61 | # Following Projected GAN 62 | im_res = 256 63 | pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32] 64 | 65 | if proj_type == 0: return pretrained, None 66 | 67 | ### Build CCM 68 | scratch = nn.Module() 69 | scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand) 70 | 71 | pretrained.CHANNELS = scratch.CHANNELS 72 | 73 | if proj_type == 1: return pretrained, scratch 74 | 75 | ### build CSM 76 | scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand) 77 | 78 | # CSM upsamples x2 so the feature map resolution doubles 79 | pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS] 80 | pretrained.CHANNELS = scratch.CHANNELS 81 | 82 | return pretrained, scratch 83 | 84 | class F_Identity(nn.Module): 85 | def forward(self, x): 86 | 
return x 87 | 88 | class F_RandomProj(nn.Module): 89 | def __init__( 90 | self, 91 | backbone="tf_efficientnet_lite3", 92 | im_res=256, 93 | cout=64, 94 | expand=True, 95 | proj_type=2, # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing 96 | **kwargs, 97 | ): 98 | super().__init__() 99 | self.proj_type = proj_type 100 | self.backbone = backbone 101 | self.cout = cout 102 | self.expand = expand 103 | self.normstats = get_backbone_normstats(backbone) 104 | 105 | # build pretrained feature network and random decoder (scratch) 106 | self.pretrained, self.scratch = _make_projector(im_res=im_res, backbone=self.backbone, cout=self.cout, 107 | proj_type=self.proj_type, expand=self.expand) 108 | self.CHANNELS = self.pretrained.CHANNELS 109 | self.RESOLUTIONS = self.pretrained.RESOLUTIONS 110 | 111 | def forward(self, x): 112 | # predict feature maps 113 | if self.backbone in VITS: 114 | out0, out1, out2, out3 = forward_vit(self.pretrained, x) 115 | else: 116 | out0 = self.pretrained.layer0(x) 117 | out1 = self.pretrained.layer1(out0) 118 | out2 = self.pretrained.layer2(out1) 119 | out3 = self.pretrained.layer3(out2) 120 | 121 | # start enumerating at the lowest layer (this is where we put the first discriminator) 122 | out = { 123 | '0': out0, 124 | '1': out1, 125 | '2': out2, 126 | '3': out3, 127 | } 128 | 129 | if self.proj_type == 0: return out 130 | 131 | out0_channel_mixed = self.scratch.layer0_ccm(out['0']) 132 | out1_channel_mixed = self.scratch.layer1_ccm(out['1']) 133 | out2_channel_mixed = self.scratch.layer2_ccm(out['2']) 134 | out3_channel_mixed = self.scratch.layer3_ccm(out['3']) 135 | 136 | out = { 137 | '0': out0_channel_mixed, 138 | '1': out1_channel_mixed, 139 | '2': out2_channel_mixed, 140 | '3': out3_channel_mixed, 141 | } 142 | 143 | if self.proj_type == 1: return out 144 | 145 | # from bottom to top 146 | out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed) 147 | out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed) 148 | out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed) 149 | out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed) 150 | 151 | out = { 152 | '0': out0_scale_mixed, 153 | '1': out1_scale_mixed, 154 | '2': out2_scale_mixed, 155 | '3': out3_scale_mixed, 156 | } 157 | 158 | return out 159 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 
22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? 
clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /stylegan_xl/training/networks_fastgan.py: -------------------------------------------------------------------------------- 1 | # original implementation: https://github.com/odegeasslbc/FastGAN-pytorch/blob/main/models.py 2 | # 3 | # modified by Axel Sauer for "Projected GANs Converge Faster" 4 | # 5 | import torch.nn as nn 6 | from pg_modules.blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d) 7 | 8 | 9 | def normalize_second_moment(x, dim=1, eps=1e-8): 10 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 11 | 12 | 13 | class DummyMapping(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, z, c=None, **kwargs): 18 | return z.unsqueeze(1) # to fit the StyleGAN API 19 | 20 | 21 | class FastganSynthesis(nn.Module): 22 | def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False): 23 | super().__init__() 24 | self.img_resolution = img_resolution 25 | self.z_dim = z_dim 26 | 27 | # channel multiplier 28 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, 29 | 512:0.25, 1024:0.125} 30 | nfc = {} 31 | for k, v in nfc_multi.items(): 32 | nfc[k] = int(v*ngf) 33 | 34 | # layers 35 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4) 36 | 37 | UpBlock = UpBlockSmall if lite else UpBlockBig 38 | 39 | self.feat_8 = UpBlock(nfc[4], nfc[8]) 40 | self.feat_16 = UpBlock(nfc[8], nfc[16]) 41 | self.feat_32 = UpBlock(nfc[16], nfc[32]) 42 | self.feat_64 = UpBlock(nfc[32], nfc[64]) 43 | self.feat_128 = UpBlock(nfc[64], nfc[128]) 44 | self.feat_256 = UpBlock(nfc[128], nfc[256]) 45 | 46 | self.se_64 = SEBlock(nfc[4], nfc[64]) 47 | self.se_128 = SEBlock(nfc[8], nfc[128]) 48 | self.se_256 = SEBlock(nfc[16], nfc[256]) 49 | 50 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) 51 | 52 | if img_resolution > 256: 53 | self.feat_512 = UpBlock(nfc[256], nfc[512]) 54 | self.se_512 = SEBlock(nfc[32], nfc[512]) 55 | if img_resolution > 512: 56 | self.feat_1024 = UpBlock(nfc[512], nfc[1024]) 57 | 58 | def forward(self, input, c=None, **kwargs): 59 | # map noise to hypersphere as in "Progressive 
Growing of GANS" 60 | input = normalize_second_moment(input[:, 0]) 61 | 62 | feat_4 = self.init(input) 63 | feat_8 = self.feat_8(feat_4) 64 | feat_16 = self.feat_16(feat_8) 65 | feat_32 = self.feat_32(feat_16) 66 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32)) 67 | 68 | if self.img_resolution >= 64: 69 | feat_last = feat_64 70 | 71 | if self.img_resolution >= 128: 72 | feat_last = self.se_128(feat_8, self.feat_128(feat_last)) 73 | 74 | if self.img_resolution >= 256: 75 | feat_last = self.se_256(feat_16, self.feat_256(feat_last)) 76 | 77 | if self.img_resolution >= 512: 78 | feat_last = self.se_512(feat_32, self.feat_512(feat_last)) 79 | 80 | if self.img_resolution >= 1024: 81 | feat_last = self.feat_1024(feat_last) 82 | 83 | return self.to_big(feat_last) 84 | 85 | 86 | class FastganSynthesisCond(nn.Module): 87 | def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False): 88 | super().__init__() 89 | 90 | self.z_dim = z_dim 91 | nfc_multi = {2: 16, 4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, 92 | 512:0.25, 1024:0.125, 2048:0.125} 93 | nfc = {} 94 | for k, v in nfc_multi.items(): 95 | nfc[k] = int(v*ngf) 96 | 97 | self.img_resolution = img_resolution 98 | 99 | self.init = InitLayer(z_dim, channel=nfc[2], sz=4) 100 | 101 | UpBlock = UpBlockSmallCond if lite else UpBlockBigCond 102 | 103 | self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim) 104 | self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim) 105 | self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim) 106 | self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim) 107 | self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim) 108 | self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim) 109 | 110 | self.se_64 = SEBlock(nfc[4], nfc[64]) 111 | self.se_128 = SEBlock(nfc[8], nfc[128]) 112 | self.se_256 = SEBlock(nfc[16], nfc[256]) 113 | 114 | self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True) 115 | 116 | if img_resolution > 256: 117 | self.feat_512 = UpBlock(nfc[256], nfc[512]) 118 | self.se_512 = SEBlock(nfc[32], nfc[512]) 119 | if img_resolution > 512: 120 | self.feat_1024 = UpBlock(nfc[512], nfc[1024]) 121 | 122 | self.embed = nn.Embedding(num_classes, z_dim) 123 | 124 | def forward(self, input, c, update_emas=False): 125 | c = self.embed(c.argmax(1)) 126 | 127 | # map noise to hypersphere as in "Progressive Growing of GANS" 128 | input = normalize_second_moment(input[:, 0]) 129 | 130 | feat_4 = self.init(input) 131 | feat_8 = self.feat_8(feat_4, c) 132 | feat_16 = self.feat_16(feat_8, c) 133 | feat_32 = self.feat_32(feat_16, c) 134 | feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c)) 135 | feat_128 = self.se_128(feat_8, self.feat_128(feat_64, c)) 136 | 137 | if self.img_resolution >= 128: 138 | feat_last = feat_128 139 | 140 | if self.img_resolution >= 256: 141 | feat_last = self.se_256(feat_16, self.feat_256(feat_last, c)) 142 | 143 | if self.img_resolution >= 512: 144 | feat_last = self.se_512(feat_32, self.feat_512(feat_last, c)) 145 | 146 | if self.img_resolution >= 1024: 147 | feat_last = self.feat_1024(feat_last, c) 148 | 149 | return self.to_big(feat_last) 150 | 151 | 152 | class Generator(nn.Module): 153 | def __init__( 154 | self, 155 | z_dim=256, 156 | c_dim=0, 157 | w_dim=0, 158 | img_resolution=256, 159 | img_channels=3, 160 | ngf=128, 161 | cond=0, 162 | mapping_kwargs={}, 163 | synthesis_kwargs={}, 164 | **kwargs, 165 | ): 166 | super().__init__() 167 | self.z_dim = z_dim 168 | self.c_dim = c_dim 169 | self.w_dim = w_dim 170 | self.img_resolution = img_resolution 171 | self.img_channels = img_channels 
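A hypothetical usage sketch of the Generator class assembled in this file, to show how it mirrors the StyleGAN mapping/synthesis interface; shapes follow the defaults above, and pg_modules and its blocks are assumed importable:

    import torch
    G = Generator(z_dim=256, img_resolution=256, img_channels=3)  # unconditional FastGAN generator
    z = torch.randn(4, G.z_dim)                                   # batch of 4 latent codes
    img = G(z, c=None)                                            # expected shape: (4, 3, 256, 256)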
172 | 173 | # Mapping and Synthesis Networks 174 | self.mapping = DummyMapping() # to fit the StyleGAN API 175 | Synthesis = FastganSynthesisCond if cond else FastganSynthesis 176 | self.synthesis = Synthesis(ngf=ngf, z_dim=z_dim, nc=img_channels, img_resolution=img_resolution, **synthesis_kwargs) 177 | 178 | def forward(self, z, c, **kwargs): 179 | w = self.mapping(z, c) 180 | img = self.synthesis(w, c) 181 | return img 182 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 
79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 
150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /stylegan_xl/gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import contextlib 10 | import imgui 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 15 | s = imgui.get_style() 16 | s.window_padding = [spacing, spacing] 17 | s.item_spacing = [spacing, spacing] 18 | s.item_inner_spacing = [spacing, spacing] 19 | s.columns_min_spacing = spacing 20 | s.indent_spacing = indent 21 | s.scrollbar_size = scrollbar 22 | s.frame_padding = [4, 3] 23 | s.window_border_size = 1 24 | s.child_border_size = 1 25 | s.popup_border_size = 1 26 | s.frame_border_size = 1 27 | s.window_rounding = 0 28 | s.child_rounding = 0 29 | s.popup_rounding = 3 30 | s.frame_rounding = 3 31 | s.scrollbar_rounding = 3 32 | s.grab_rounding = 3 33 | 34 | getattr(imgui, f'style_colors_{color_scheme}')(s) 35 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 36 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 37 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @contextlib.contextmanager 42 | def grayed_out(cond=True): 43 | if cond: 44 | s = imgui.get_style() 45 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 46 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 47 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 48 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 49 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 50 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 51 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 52 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 53 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 55 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 56 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 58 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 59 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 60 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 61 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 62 | yield 63 | imgui.pop_style_color(14) 64 | else: 65 | yield 66 | 67 | #---------------------------------------------------------------------------- 68 | 69 | @contextlib.contextmanager 70 | def item_width(width=None): 71 | if width is not None: 72 | imgui.push_item_width(width) 73 | yield 74 | 
imgui.pop_item_width() 75 | else: 76 | yield 77 | 78 | #---------------------------------------------------------------------------- 79 | 80 | def scoped_by_object_id(method): 81 | def decorator(self, *args, **kwargs): 82 | imgui.push_id(str(id(self))) 83 | res = method(self, *args, **kwargs) 84 | imgui.pop_id() 85 | return res 86 | return decorator 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | def button(label, width=0, enabled=True): 91 | with grayed_out(not enabled): 92 | clicked = imgui.button(label, width=width) 93 | clicked = clicked and enabled 94 | return clicked 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 99 | expanded = False 100 | if show: 101 | if default: 102 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 103 | if not enabled: 104 | flags |= imgui.TREE_NODE_LEAF 105 | with grayed_out(not enabled): 106 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 107 | expanded = expanded and enabled 108 | return expanded, visible 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def popup_button(label, width=0, enabled=True): 113 | if button(label, width, enabled): 114 | imgui.open_popup(label) 115 | opened = imgui.begin_popup(label) 116 | return opened 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 121 | old_value = value 122 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 123 | if value == '': 124 | color[-1] *= 0.5 125 | with item_width(width): 126 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 127 | value = value if value != '' else help_text 128 | changed, value = imgui.input_text(label, value, buffer_length, flags) 129 | value = value if value != help_text else '' 130 | imgui.pop_style_color(1) 131 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 132 | changed = (value != old_value) 133 | return changed, value 134 | 135 | #---------------------------------------------------------------------------- 136 | 137 | def drag_previous_control(enabled=True): 138 | dragging = False 139 | dx = 0 140 | dy = 0 141 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 142 | if enabled: 143 | dragging = True 144 | dx, dy = imgui.get_mouse_drag_delta() 145 | imgui.reset_mouse_drag_delta() 146 | imgui.end_drag_drop_source() 147 | return dragging, dx, dy 148 | 149 | #---------------------------------------------------------------------------- 150 | 151 | def drag_button(label, width=0, enabled=True): 152 | clicked = button(label, width=width, enabled=enabled) 153 | dragging, dx, dy = drag_previous_control(enabled=enabled) 154 | return clicked, dragging, dx, dy 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | def drag_hidden_window(label, x, y, width, height, enabled=True): 159 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 160 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 161 | imgui.set_next_window_position(x, y) 162 | imgui.set_next_window_size(width, height) 163 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 164 | dragging, dx, dy = drag_previous_control(enabled=enabled) 165 | imgui.end() 166 | 
imgui.pop_style_color(2) 167 | return dragging, dx, dy 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /stylegan_xl/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 
137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /reparam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | import types 5 | from collections import namedtuple 6 | from contextlib import contextmanager 7 | 8 | 9 | class ReparamModule(nn.Module): 10 | def _get_module_from_name(self, mn): 11 | if mn == '': 12 | return self 13 | m = self 14 | for p in mn.split('.'): 15 | m = getattr(m, p) 16 | return m 17 | 18 | def __init__(self, module): 19 | super(ReparamModule, self).__init__() 20 | self.module = module 21 | 22 | param_infos = [] # (module name/path, param name) 23 | shared_param_memo = {} 24 | shared_param_infos = [] # (module name/path, param name, src module name/path, src param_name) 25 | params = [] 26 | param_numels = [] 27 | param_shapes = [] 28 | for mn, m in self.named_modules(): 29 | for n, p in m.named_parameters(recurse=False): 30 | if p is not None: 31 | if p in shared_param_memo: 32 | shared_mn, shared_n = shared_param_memo[p] 33 | shared_param_infos.append((mn, n, shared_mn, shared_n)) 34 | else: 35 | shared_param_memo[p] = (mn, n) 36 | param_infos.append((mn, n)) 37 | params.append(p.detach()) 38 | param_numels.append(p.numel()) 39 | param_shapes.append(p.size()) 40 | 41 | assert len(set(p.dtype for p in params)) <= 1, \ 42 | "expects all parameters in module to have same dtype" 43 | 44 | # store the info for unflatten 45 | self._param_infos = tuple(param_infos) 46 | self._shared_param_infos = tuple(shared_param_infos) 47 | self._param_numels = tuple(param_numels) 48 | self._param_shapes = tuple(param_shapes) 49 | 50 | # flatten 51 | flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0)) 52 | self.register_parameter('flat_param', flat_param) 53 | self.param_numel = flat_param.numel() 54 | del params 55 | del shared_param_memo 56 | 57 | # deregister the names as parameters 58 | for mn, n in self._param_infos: 59 | delattr(self._get_module_from_name(mn), n) 60 | for mn, n, _, _ in self._shared_param_infos: 61 | delattr(self._get_module_from_name(mn), n) 62 | 63 | # register the views as plain attributes 64 | self._unflatten_param(self.flat_param) 65 | 66 | # now buffers 67 | # they are not reparametrized. 
just store info as (module, name, buffer) 68 | buffer_infos = [] 69 | for mn, m in self.named_modules(): 70 | for n, b in m.named_buffers(recurse=False): 71 | if b is not None: 72 | buffer_infos.append((mn, n, b)) 73 | 74 | self._buffer_infos = tuple(buffer_infos) 75 | self._traced_self = None 76 | 77 | def trace(self, example_input, **trace_kwargs): 78 | assert self._traced_self is None, 'This ReparamModule is already traced' 79 | 80 | if isinstance(example_input, torch.Tensor): 81 | example_input = (example_input,) 82 | example_input = tuple(example_input) 83 | example_param = (self.flat_param.detach().clone(),) 84 | example_buffers = (tuple(b.detach().clone() for _, _, b in self._buffer_infos),) 85 | 86 | self._traced_self = torch.jit.trace_module( 87 | self, 88 | inputs=dict( 89 | _forward_with_param=example_param + example_input, 90 | _forward_with_param_and_buffers=example_param + example_buffers + example_input, 91 | ), 92 | **trace_kwargs, 93 | ) 94 | 95 | # replace forwards with traced versions 96 | self._forward_with_param = self._traced_self._forward_with_param 97 | self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers 98 | return self 99 | 100 | def clear_views(self): 101 | for mn, n in self._param_infos: 102 | setattr(self._get_module_from_name(mn), n, None) # This will set as plain attr 103 | 104 | def _apply(self, *args, **kwargs): 105 | if self._traced_self is not None: 106 | self._traced_self._apply(*args, **kwargs) 107 | return self 108 | return super(ReparamModule, self)._apply(*args, **kwargs) 109 | 110 | def _unflatten_param(self, flat_param): 111 | ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes)) 112 | for (mn, n), p in zip(self._param_infos, ps): 113 | setattr(self._get_module_from_name(mn), n, p) # This will set as plain attr 114 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 115 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 116 | 117 | @contextmanager 118 | def unflattened_param(self, flat_param): 119 | saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos] 120 | self._unflatten_param(flat_param) 121 | yield 122 | # Why not just `self._unflatten_param(self.flat_param)`? 123 | # 1. because of https://github.com/pytorch/pytorch/issues/17583 124 | # 2. 
slightly faster since it does not require reconstruct the split+view 125 | # graph 126 | for (mn, n), p in zip(self._param_infos, saved_views): 127 | setattr(self._get_module_from_name(mn), n, p) 128 | for (mn, n, shared_mn, shared_n) in self._shared_param_infos: 129 | setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n)) 130 | 131 | @contextmanager 132 | def replaced_buffers(self, buffers): 133 | for (mn, n, _), new_b in zip(self._buffer_infos, buffers): 134 | setattr(self._get_module_from_name(mn), n, new_b) 135 | yield 136 | for mn, n, old_b in self._buffer_infos: 137 | setattr(self._get_module_from_name(mn), n, old_b) 138 | 139 | def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs): 140 | with self.unflattened_param(flat_param): 141 | with self.replaced_buffers(buffers): 142 | return self.module(*inputs, **kwinputs) 143 | 144 | def _forward_with_param(self, flat_param, *inputs, **kwinputs): 145 | with self.unflattened_param(flat_param): 146 | return self.module(*inputs, **kwinputs) 147 | 148 | def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs): 149 | flat_param = torch.squeeze(flat_param) 150 | # print("PARAMS ON DEVICE: ", flat_param.get_device(), flat_param.shape) 151 | # print("DATA ON DEVICE: ", inputs[0].get_device(), inputs[0].shape) 152 | # flat_param.to("cuda:{}".format(inputs[0].get_device())) 153 | # self.module.to("cuda:{}".format(inputs[0].get_device())) 154 | if flat_param is None: 155 | flat_param = self.flat_param 156 | if buffers is None: 157 | return self._forward_with_param(flat_param, *inputs, **kwinputs) 158 | else: 159 | return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs) -------------------------------------------------------------------------------- /fed_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.distributed as dist 6 | import numpy as np 7 | import sys 8 | import random 9 | from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, match_loss, get_time, \ 10 | TensorDataset, epoch, DiffAugment, ParamDiffAug, epoch_global, get_logger 11 | from clients import Client 12 | from distill_client import client_update 13 | from glad_utils import * 14 | from baseline_methods import * 15 | 16 | 17 | def get_fname(args): 18 | parserString = "Baseline_Method"+str(args.baseline_method)+"_Dataset"+str(args.dataset)+"_Iteration"+\ 19 | str(args.Iteration)+"_nworkers"+str(args.nworkers)+"_beta"+str(args.beta)+"_lrNet"+\ 20 | str(args.lr_net)+"_round"+str(args.round) 21 | return parserString 22 | 23 | 24 | def main(args): 25 | 26 | input_str = ' '.join(sys.argv) 27 | print(input_str) 28 | print(args.local_rank) 29 | logger = get_logger('log/' + get_fname(args) + time.strftime("%Y%m%d-%H%M%S") + ".log", distributed_rank=args.local_rank) 30 | 31 | for k, v in sorted(vars(args).items()): 32 | logger.info(str(k) + '=' + str(v)) 33 | logger.info(input_str) 34 | dist.init_process_group(backend='nccl') 35 | torch.cuda.set_device(args.local_rank) 36 | args.device = torch.device('cuda', args.local_rank) if torch.cuda.is_available() else 'cpu' 37 | args.dsa_param = ParamDiffAug() 38 | args.dsa = False if args.dsa_strategy in ['none', 'None'] else True 39 | 40 | seed = 0 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 
torch.backends.cudnn.deterministic = True 46 | 47 | if not os.path.exists(args.data_path): 48 | os.mkdir(args.data_path) 49 | 50 | run_dir = "{}-{}".format(time.strftime("%Y%m%d-%H%M%S"), get_fname(args)) 51 | 52 | args.save_path = os.path.join(args.save_path, "dm", run_dir) 53 | 54 | if not os.path.exists(args.save_path): 55 | os.makedirs(args.save_path, exist_ok=True) 56 | 57 | channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset( 58 | args.dataset, args.data_path, args.batch_real, args.res, args=args) 59 | 60 | model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model) 61 | 62 | args.distributed = torch.cuda.device_count() > 1 63 | 64 | each_worker_data, each_worker_label, indices_cl_classes = build_client_dataset_dirichlet(dst_train, class_map, num_classes, args) 65 | total_data_points = len(dst_train) 66 | fed_avg_freqs = [len(each_worker_data[each_worker]) / total_data_points for each_worker in 67 | range(args.nworkers)] 68 | 69 | for each_client in range(args.nworkers): 70 | print(each_client, " ", len(each_worker_data[each_client])) 71 | print([len(indices_cl_classes[each_client][c]) for c in range(10)]) 72 | print("="*100) 73 | print(fed_avg_freqs, sum(fed_avg_freqs)) 74 | 75 | def get_images(c, n, each_worker): # get random n images of class c from client each_worker 76 | idx_shuffle = np.random.permutation(indices_cl_classes[each_worker][c])[:n] 77 | return each_worker_data[each_worker][idx_shuffle].cuda(non_blocking=True) 78 | 79 | global_model_acc = [] 80 | global_model_loss = [] 81 | 82 | """initialize clients""" 83 | Clients = [] 84 | for each_worker in range(args.nworkers): 85 | net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width, args=args).to( 86 | args.device) # get a random model 87 | if not args.single_gpu: 88 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank]) 89 | net.train() 90 | if args.dataset == "CIFAR10": 91 | optimizer_net = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr_net, 92 | momentum=0.9, 93 | weight_decay=args.reg) 94 | else: 95 | optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer_img for synthetic data 96 | criterion = nn.CrossEntropyLoss().cuda() 97 | client = Client(each_worker, optimizer=optimizer_net, criterion=criterion, 98 | each_worker_data=each_worker_data[each_worker], 99 | each_worker_label=each_worker_label[each_worker], 100 | indices_cl_class=indices_cl_classes[each_worker], 101 | model=net, args=args) 102 | Clients.append(client) 103 | 104 | """initialize global model""" 105 | net = get_network(args.model, channel, num_classes, im_size, depth=args.depth, width=args.width, args=args).to(args.device) # get a random model 106 | if not args.single_gpu: 107 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank]) 108 | net.train() 109 | if args.dataset == "CIFAR10": 110 | optimizer_net = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr_net, momentum=0.9, 111 | weight_decay=args.reg) 112 | else: 113 | optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net) # optimizer_img for synthetic data 114 | optimizer_net.zero_grad() 115 | criterion = nn.CrossEntropyLoss().cuda() 116 | 117 | """federated training""" 118 | for i in range(args.round): 119 | global_para = net.cpu().state_dict() 120 | for each_worker in range(args.nworkers): 121 | 
Clients[each_worker].model.load_state_dict(global_para) 122 | if args.baseline_method == "fedavg": 123 | net = fedavg(net, Clients, i, testloader, args, total_data_points=total_data_points) 124 | elif args.baseline_method == 'fedprox': 125 | net = fedprox(net, Clients, i, testloader, args, total_data_points=total_data_points) 126 | elif args.baseline_method == 'fednova': 127 | net = fednova(net, Clients, i, testloader, args, total_data_points=total_data_points) 128 | else: 129 | raise NotImplementedError 130 | with torch.no_grad(): 131 | loss_avg, acc_avg = epoch_global('test', testloader, net, optimizer_net, criterion, args, 132 | aug=True if args.dsa else False) 133 | global_model_acc.append(acc_avg) 134 | global_model_loss.append(loss_avg) 135 | logger.info('%s Evaluate: val loss = %.6f val acc = %.4f' % (get_time(), loss_avg, acc_avg)) 136 | 137 | logger.info("Test Loss: " + str(global_model_loss)) 138 | logger.info("Test Acc: " + str(global_model_acc)) 139 | logger.info(input_str) 140 | 141 | 142 | if __name__ == '__main__': 143 | import shared_args 144 | 145 | parser = shared_args.baseline_args() 146 | 147 | parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters') 148 | parser.add_argument('--inner_loop', type=int, default=500, help='inner loop') 149 | parser.add_argument('--outer_loop', type=int, default=1, help='outer loop') 150 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric') 151 | 152 | args = parser.parse_args() 153 | main(args) 154 | -------------------------------------------------------------------------------- /stylegan_xl/viz/pickle_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import os 11 | import re 12 | 13 | import dnnlib 14 | import imgui 15 | import numpy as np 16 | from gui_utils import imgui_utils 17 | 18 | from . import renderer 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def _locate_results(pattern): 23 | return pattern 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | class PickleWidget: 28 | def __init__(self, viz): 29 | self.viz = viz 30 | self.search_dirs = [] 31 | self.cur_pkl = None 32 | self.user_pkl = '' 33 | self.recent_pkls = [] 34 | self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} 35 | self.browse_refocus = False 36 | self.load('', ignore_errors=True) 37 | 38 | def add_recent(self, pkl, ignore_errors=False): 39 | try: 40 | resolved = self.resolve_pkl(pkl) 41 | if resolved not in self.recent_pkls: 42 | self.recent_pkls.append(resolved) 43 | except: 44 | if not ignore_errors: 45 | raise 46 | 47 | def load(self, pkl, ignore_errors=False): 48 | viz = self.viz 49 | viz.clear_result() 50 | viz.skip_frame() # The input field will change on next frame. 
51 | try: 52 | resolved = self.resolve_pkl(pkl) 53 | name = resolved.replace('\\', '/').split('/')[-1] 54 | self.cur_pkl = resolved 55 | self.user_pkl = resolved 56 | viz.result.message = f'Loading {name}...' 57 | viz.defer_rendering() 58 | if resolved in self.recent_pkls: 59 | self.recent_pkls.remove(resolved) 60 | self.recent_pkls.insert(0, resolved) 61 | except: 62 | self.cur_pkl = None 63 | self.user_pkl = pkl 64 | if pkl == '': 65 | viz.result = dnnlib.EasyDict(message='No network pickle loaded') 66 | else: 67 | viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) 68 | if not ignore_errors: 69 | raise 70 | 71 | @imgui_utils.scoped_by_object_id 72 | def __call__(self, show=True): 73 | viz = self.viz 74 | recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] 75 | if show: 76 | imgui.text('Pickle') 77 | imgui.same_line(viz.label_w) 78 | changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024, 79 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 80 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 81 | help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl') 82 | if changed: 83 | self.load(self.user_pkl, ignore_errors=True) 84 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': 85 | imgui.set_tooltip(self.user_pkl) 86 | imgui.same_line() 87 | if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): 88 | imgui.open_popup('recent_pkls_popup') 89 | imgui.same_line() 90 | if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1): 91 | imgui.open_popup('browse_pkls_popup') 92 | self.browse_cache.clear() 93 | self.browse_refocus = True 94 | 95 | if imgui.begin_popup('recent_pkls_popup'): 96 | for pkl in recent_pkls: 97 | clicked, _state = imgui.menu_item(pkl) 98 | if clicked: 99 | self.load(pkl, ignore_errors=True) 100 | imgui.end_popup() 101 | 102 | if imgui.begin_popup('browse_pkls_popup'): 103 | def recurse(parents): 104 | key = tuple(parents) 105 | items = self.browse_cache.get(key, None) 106 | if items is None: 107 | items = self.list_runs_and_pkls(parents) 108 | self.browse_cache[key] = items 109 | for item in items: 110 | if item.type == 'run' and imgui.begin_menu(item.name): 111 | recurse([item.path]) 112 | imgui.end_menu() 113 | if item.type == 'pkl': 114 | clicked, _state = imgui.menu_item(item.name) 115 | if clicked: 116 | self.load(item.path, ignore_errors=True) 117 | if len(items) == 0: 118 | with imgui_utils.grayed_out(): 119 | imgui.menu_item('No results found') 120 | recurse(self.search_dirs) 121 | if self.browse_refocus: 122 | imgui.set_scroll_here() 123 | viz.skip_frame() # Focus will change on next frame.
124 | self.browse_refocus = False 125 | imgui.end_popup() 126 | 127 | paths = viz.pop_drag_and_drop_paths() 128 | if paths is not None and len(paths) >= 1: 129 | self.load(paths[0], ignore_errors=True) 130 | 131 | viz.args.pkl = self.cur_pkl 132 | 133 | def list_runs_and_pkls(self, parents): 134 | items = [] 135 | run_regex = re.compile(r'\d+-.*') 136 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') 137 | for parent in set(parents): 138 | if os.path.isdir(parent): 139 | for entry in os.scandir(parent): 140 | if entry.is_dir() and run_regex.fullmatch(entry.name): 141 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) 142 | if entry.is_file() and pkl_regex.fullmatch(entry.name): 143 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) 144 | 145 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) 146 | return items 147 | 148 | def resolve_pkl(self, pattern): 149 | assert isinstance(pattern, str) 150 | assert pattern != '' 151 | 152 | # URL => return as is. 153 | if dnnlib.util.is_url(pattern): 154 | return pattern 155 | 156 | # Short-hand pattern => locate. 157 | path = _locate_results(pattern) 158 | 159 | # Run dir => pick the last saved snapshot. 160 | if os.path.isdir(path): 161 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) 162 | if len(pkl_files) == 0: 163 | raise IOError(f'No network pickle found in "{path}"') 164 | path = pkl_files[-1] 165 | 166 | # Normalize. 167 | path = os.path.abspath(path) 168 | return path 169 | 170 | #---------------------------------------------------------------------------- 171 | --------------------------------------------------------------------------------
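A minimal usage sketch for ReparamModule (defined in reparam_module.py above): it flattens every parameter of a wrapped network into the single vector net.flat_param and exposes views of it as plain attributes, so a forward pass can be driven by an externally supplied parameter tensor. The toy Sequential model, tensor shapes, and variable names below are illustrative assumptions, not part of the repository. Note that forward() applies torch.squeeze(flat_param) before checking whether flat_param is None, so callers should always pass an explicit flat parameter vector.

import torch
import torch.nn as nn
from reparam_module import ReparamModule

# Toy model; any nn.Module whose parameters share a single dtype works.
base = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(16 * 32 * 32, 10))

net = ReparamModule(base)                # all parameters flattened into net.flat_param
theta = net.flat_param.detach().clone()  # one 1-D tensor of length net.param_numel
x = torch.randn(8, 3, 32, 32)
y = net(x, flat_param=theta)             # forward pass using the external parameter vector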