├── .github ├── FUNDING.yml └── workflows │ └── python-publish.yml ├── MANIFEST.in ├── dalle2.png ├── stylegan3 ├── docs │ ├── visualizer_screen0.png │ ├── avg_spectra_screen0.png │ ├── visualizer_screen0_half.png │ ├── avg_spectra_screen0_half.png │ ├── stylegan3-teaser-1920x1006.png │ ├── troubleshooting.md │ ├── dataset-tool-help.txt │ └── train-help.txt ├── metrics │ ├── __init__.py │ ├── inception_score.py │ ├── frechet_inception_distance.py │ ├── kernel_inception_distance.py │ ├── precision_recall.py │ ├── perceptual_path_length.py │ └── metric_main.py ├── viz │ ├── __init__.py │ ├── stylemix_widget.py │ ├── performance_widget.py │ ├── capture_widget.py │ ├── trunc_noise_widget.py │ ├── latent_widget.py │ ├── equivariance_widget.py │ └── pickle_widget.py ├── gui_utils │ ├── __init__.py │ ├── imgui_window.py │ ├── text_utils.py │ ├── imgui_utils.py │ └── glfw_window.py ├── torch_utils │ ├── __init__.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.h │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ ├── filtered_lrelu_ns.cu │ │ ├── upfirdn2d.h │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── filtered_lrelu.h │ │ ├── bias_act.cpp │ │ ├── upfirdn2d.cpp │ │ ├── bias_act.cu │ │ └── conv2d_resample.py │ └── custom_ops.py ├── training │ ├── __init__.py │ └── loss.py ├── environment.yml ├── dnnlib │ └── __init__.py ├── Dockerfile ├── .github │ └── ISSUE_TEMPLATE │ │ └── bug_report.md ├── LICENSE.txt ├── gen_images.py ├── gen_video.py └── calc_metrics.py ├── dalle2_pytorch ├── dataloaders │ ├── __init__.py │ └── decoder_loader.py ├── __init__.py ├── optimizer.py ├── cli.py └── tokenizer.py ├── dalle2-pytorch_LICENSE ├── config.yaml ├── setup.py ├── .gitignore ├── clip2latent └── train_utils.py └── train.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [lucidrains] 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include dalle2_pytorch *.txt 2 | -------------------------------------------------------------------------------- /dalle2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/DALLE2-pytorch/main/dalle2.png -------------------------------------------------------------------------------- /stylegan3/docs/visualizer_screen0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/DALLE2-pytorch/main/stylegan3/docs/visualizer_screen0.png -------------------------------------------------------------------------------- /stylegan3/docs/avg_spectra_screen0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/DALLE2-pytorch/main/stylegan3/docs/avg_spectra_screen0.png -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader -------------------------------------------------------------------------------- /stylegan3/docs/visualizer_screen0_half.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/justinpinkney/DALLE2-pytorch/main/stylegan3/docs/visualizer_screen0_half.png -------------------------------------------------------------------------------- /stylegan3/docs/avg_spectra_screen0_half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/DALLE2-pytorch/main/stylegan3/docs/avg_spectra_screen0_half.png -------------------------------------------------------------------------------- /stylegan3/docs/stylegan3-teaser-1920x1006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/DALLE2-pytorch/main/stylegan3/docs/stylegan3-teaser-1920x1006.png -------------------------------------------------------------------------------- /dalle2_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder 2 | from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter 3 | from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer 4 | 5 | from dalle2_pytorch.vqgan_vae import VQGanVAE 6 | from x_clip import CLIP 7 | -------------------------------------------------------------------------------- /stylegan3/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan3/viz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan3/gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan3/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /stylegan3/environment.yml: -------------------------------------------------------------------------------- 1 | name: stylegan3 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - pip: 20 | - imgui==1.3.0 21 | - glfw==2.2.0 22 | - pyopengl==3.1.5 23 | - imageio-ffmpeg==0.4.3 24 | - pyspng 25 | -------------------------------------------------------------------------------- /stylegan3/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /stylegan3/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | FROM nvcr.io/nvidia/pytorch:21.08-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 15 | 16 | WORKDIR /workspace 17 | 18 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 19 | ENTRYPOINT ["/entry.sh"] 20 | -------------------------------------------------------------------------------- /dalle2_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW, Adam 2 | 3 | def separate_weight_decayable_params(params): 4 | no_wd_params = set([param for param in params if param.ndim < 2]) 5 | wd_params = set(params) - no_wd_params 6 | return wd_params, no_wd_params 7 | 8 | def get_optimizer( 9 | params, 10 | lr = 3e-4, 11 | wd = 1e-2, 12 | betas = (0.9, 0.999), 13 | filter_by_requires_grad = False 14 | ): 15 | if filter_by_requires_grad: 16 | params = list(filter(lambda t: t.requires_grad, params)) 17 | 18 | if wd == 0: 19 | return Adam(params, lr = lr, betas = betas) 20 | 21 | params = set(params) 22 | wd_params, no_wd_params = separate_weight_decayable_params(params) 23 | 24 | param_groups = [ 25 | {'params': list(wd_params)}, 26 | {'params': list(no_wd_params), 'weight_decay': 0}, 27 | ] 28 | 29 | return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas) 30 |
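# Added note (a usage sketch, not part of the original file):
#   opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)
# separate_weight_decayable_params() above routes every parameter with
# ndim < 2 (biases, norm gains) into the group that receives no weight decay,
# while all other parameters are decayed with weight_decay = wd.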
-------------------------------------------------------------------------------- /dalle2-pytorch_LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stylegan3/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. In '...' directory, run command '...' 16 | 2. See error (copy&paste full log, including exceptions and **stacktraces**). 17 | 18 | Please copy&paste text instead of screenshots for better searchability. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Linux Ubuntu 20.04, Windows 10] 28 | - PyTorch version (e.g., pytorch 1.9.0) 29 | - CUDA toolkit version (e.g., CUDA 11.4) 30 | - NVIDIA driver version 31 | - GPU [e.g., Titan V, RTX 3090] 32 | - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3) 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation.
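# Added note (not part of the original workflow): the publish step below
# authenticates as the __token__ user with the PYPI_API_TOKEN repository
# secret, so that secret must be configured in the repository settings
# before a release will upload successfully.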
10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | device: "cuda:0" 2 | model: 3 | network: 4 | dim: 512 5 | num_timesteps: 1000 6 | depth: 6 7 | dim_head: 64 8 | heads: 8 9 | diffusion: 10 | image_embed_dim: 512 11 | timesteps: 1000 12 | cond_drop_prob: 0.2 13 | image_embed_scale: 1.0 14 | beta_schedule: "cosine" 15 | predict_x_start: True 16 | data: 17 | bs: 32 18 | clip_feature_path: "data/sg2-ffhq-1024/clip_features.pt" 19 | latent_path: 'data/sg2-ffhq-1024/w_for_clip.pt' 20 | sg_pkl: 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl' 21 | clip_variant: "ViT-B/32" 22 | val: 23 | n_im_val_samples: 16 24 | text_val_samples: [ 25 | "A photograph of a young man's face with a beard", 26 | "A photograph of an old woman's face with grey hair", 27 | "A photograph of a child at a birthday party", 28 | "A picture of a face outside in bright sun in front of green grass", 29 | "This man has bangs arched eyebrows curly hair and a small nose", 30 | "A photo of Barack Obama", 31 | "An arctic explorer", 32 | "A clown's face covered in make up", 33 | ] 34 | train: 35 | loop: 36 | print_it: 100 37 | max_it: 1000000 38 | val_it: 1000 39 | opt: 40 | lr: 1.0e-4 41 | weight_decay: 1.0e-2 42 | betas: [0.9, 0.999] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'dalle2-pytorch', 5 | packages = find_packages(exclude=[]), 6 | include_package_data = True, 7 | entry_points={ 8 | 'console_scripts': [ 9 | 'dalle2_pytorch = dalle2_pytorch.cli:main', 10 | 'dream = dalle2_pytorch.cli:dream' 11 | ], 12 | }, 13 | version = '0.2.4', 14 | license='MIT', 15 | description = 'DALL-E 2', 16 | author = 'Phil Wang', 17 | author_email = 'lucidrains@gmail.com', 18 | url = 'https://github.com/lucidrains/dalle2-pytorch', 19 | keywords = [ 20 | 'artificial intelligence', 21 | 'deep learning', 22 | 'text to image' 23 | ], 24 | install_requires=[ 25 | 'click', 26 | 'clip-anytorch', 27 | 'coca-pytorch>=0.0.5', 28 | 'einops>=0.4', 29 | 'einops-exts>=0.0.3', 30 | 'embedding-reader', 31 | 'kornia>=0.5.4', 32 | 'pillow', 33 | 'resize-right>=0.0.2', 34 | 'rotary-embedding-torch', 35 | 'torch>=1.10', 36 | 'torchvision', 37 | 'tqdm', 38 | 'vector-quantize-pytorch', 39 | 'x-clip>=0.4.4', 40 | 'youtokentome', 41 | 'webdataset>=0.2.5', 42 | 'fsspec>=2022.1.0' 43 | ], 44 | classifiers=[ 45 | 'Development Status :: 4 - Beta', 46 | 'Intended Audience :: Developers', 47 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 48 | 'License :: OSI Approved :: MIT License', 49 | 'Programming Language :: Python :: 3.6', 50 | ],
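    # Added note (not part of the original file): the console scripts declared
    # in entry_points above resolve to dalle2_pytorch/cli.py, where main() is
    # currently a stub and dream() loads a checkpoint and samples an image
    # from a text prompt.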
51 | ) 52 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
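// Added note (not part of the original file): filtered_lrelu_rd.cu,
// filtered_lrelu_wr.cu and filtered_lrelu_ns.cu are near-identical stubs that
// all #include the shared filtered_lrelu.cu; they differ only in which
// sign-handling variant (read, write, or no signs) they instantiate, which
// lets the heavyweight kernel specializations build as separate units.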
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /dalle2_pytorch/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | import torchvision.transforms as T 4 | from functools import reduce 5 | from pathlib import Path 6 | 7 | from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior 8 | 9 | def safeget(dictionary, keys, default = None): 10 | return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary) 11 | 12 | def simple_slugify(text, max_length = 255): 13 | return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length] 14 | 15 | def get_pkg_version(): 16 | from pkg_resources import get_distribution 17 | return get_distribution('dalle2_pytorch').version 18 | 19 | def main(): 20 | pass 21 | 22 | @click.command() 23 | @click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model') 24 | @click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder') 25 | @click.argument('text') 26 | def dream( 27 | model, 28 | cond_scale, 29 | text 30 | ): 31 | model_path = Path(model) 32 | full_model_path = str(model_path.resolve()) 33 | assert model_path.exists(), f'model not found at {full_model_path}' 34 | loaded = torch.load(str(model_path)) 35 | 36 | version = safeget(loaded, 'version') 37 | print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}') 38 | 39 | prior_init_params = safeget(loaded, 'init_params.prior') 40 | decoder_init_params = safeget(loaded, 'init_params.decoder') 41 | model_params = safeget(loaded, 'model_params') 42 | 43 | prior = DiffusionPrior(**prior_init_params) 44 | decoder = Decoder(**decoder_init_params) 45 | 46 | dalle2 = DALLE2(prior, decoder) 47 | dalle2.load_state_dict(model_params) 48 | 49 | image = dalle2(text, cond_scale = cond_scale) 50 | 51 | pil_image = T.ToPILImage()(image) 52 | return pil_image.save(f'./{simple_slugify(text)}.png') 53 |
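# Added note (a usage sketch, not part of the original file): with the console
# scripts registered in setup.py, a trained checkpoint saved at ./dalle2.pt can
# be sampled from the shell, e.g.
#   $ dream 'a dragonfruit wearing karate belt in the snow'
# which writes ./a_dragonfruit_wearing_karate_belt_in_the_snow.png, the
# filename coming from simple_slugify() above.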
-------------------------------------------------------------------------------- /stylegan3/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /stylegan3/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /stylegan3/docs/troubleshooting.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting 2 | 3 | Our PyTorch code uses custom [CUDA extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) to speed up some of the network layers. Getting these to run can sometimes be a hassle. 4 | 5 | This page aims to give guidance on how to diagnose and fix run-time problems related to these extensions. 6 | 7 | ## Before you start 8 | 9 | 1. Try Docker first! Ensure you can successfully run our models using the recommended Docker image. Follow the instructions in [README.md](/README.md) to get it running. 10 | 2. Can't use Docker? Read on. 11 | 12 | ## Installing dependencies 13 | 14 | Make sure you've installed everything listed in the requirements section in the [README.md](/README.md). The key components w.r.t. custom extensions are: 15 | 16 | - **[CUDA toolkit 11.1](https://developer.nvidia.com/cuda-toolkit)** or later (this is not the same as `cudatoolkit` from Conda). 17 | - PyTorch invokes `nvcc` to compile our CUDA kernels. 18 | - **ninja** 19 | - PyTorch uses [Ninja](https://ninja-build.org/) as its build system. 20 | - **GCC** (Linux) or **Visual Studio** (Windows) 21 | - GCC 7.x or later is required.
Earlier versions such as GCC 6.3 [are known not to work](https://github.com/NVlabs/stylegan3/issues/2). 22 | 23 | #### Why is CUDA toolkit installation necessary? 24 | 25 | The PyTorch package contains the required CUDA toolkit libraries needed to run PyTorch, so why is a separate CUDA toolkit installation required? Our models use custom CUDA kernels to implement operations such as efficient resampling of 2D images. PyTorch code invokes the CUDA compiler at run-time to compile these kernels on first-use. The tools and libraries required for this compilation are not bundled in PyTorch and thus a host CUDA toolkit installation is required. 26 | 27 | ## Things to try 28 | 29 | - Completely remove: `$HOME/.cache/torch_extensions` (Linux) or `C:\Users\<username>\AppData\Local\torch_extensions\torch_extensions\Cache` (Windows) and re-run StyleGAN3 python code. 30 | - Run ninja in `$HOME/.cache/torch_extensions` to see that it builds. 31 | - Inspect the `build.ninja` in the build directories under `$HOME/.cache/torch_extensions` and check CUDA tools and versions are consistent with what you intended to use. 32 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | --------------------------------------------------------------------------------
/stylegan3/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | checkpoints/ 3 | data/ 4 | wandb/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /clip2latent/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm.auto import tqdm 4 | import clip 5 | from PIL import Image 6 | import torchvision 7 | 8 | def make_data_stats(w): 9 | w_mean = w.mean(dim=0) 10 | w_std = w.std(dim=0) 11 | return w_mean, w_std 12 | 13 | def normalise_data(w, w_mean, w_std): 14 | device = w.device 15 | w = w - w_mean.to(device) 16 | w = w / w_std.to(device) 17 | return w 18 | 19 | def denormalise_data(w, w_mean, w_std): 20 | device = w.device 21 | w = w * w_std.to(device) 22 | w = w + w_mean.to(device) 23 | return w 24 | 25 | def make_grid(ims, pil=True): 26 | ims = F.interpolate(ims, size=(256,256)) 27 | grid = torchvision.utils.make_grid( 28 | ims.clamp(-1,1), 29 | normalize=True, 30 | value_range=(-1,1), 31 | nrow=4, 32 | ) 33 | if pil: 34 | grid = Image.fromarray((255*grid).to(torch.uint8).permute(1,2,0).detach().cpu().numpy()) 35 | return grid 36 | 37 | @torch.no_grad() 38 | def make_image_val_data(G, clip_model, n_im_val_samples, device): 39 | clip_features = [] 40 | 41 | zs = torch.randn((n_im_val_samples, 512), device=device) 42 | ws = G.mapping(zs, c=None) 43 | for w in tqdm(ws): 44 | out = G.synthesis(w.unsqueeze(0)) 45 | clip_in = F.interpolate(out, (224,224)) 46 | image_features = clip_model.encode_image(clip_in) 47 | clip_features.append(image_features) 48 | 49 | clip_features = torch.cat(clip_features, dim=0) 50 | val_data = { 51 | "clip_features": clip_features, 52 | "z": zs, 53 | "w": ws, 54 | } 55 | return val_data 56 | 57 | 58 | @torch.no_grad() 59 | def 
make_text_val_data(G, clip_model, text_samples, device): 60 | 61 | text = clip.tokenize(text_samples).to(device) 62 | text_features = clip_model.encode_text(text) 63 | val_data = {"clip_features": text_features,} 64 | return val_data 65 | 66 | @torch.no_grad() 67 | def compute_val(diffusion, val_im, G, clip_model, device, stats, super_cond=0): 68 | diffusion.eval() 69 | images = [] 70 | inp = val_im["clip_features"].to(device) 71 | out = diffusion.p_sample_loop(inp.shape, {"text_embed": inp}, clip_denoised=False, super_cond=super_cond) 72 | 73 | pred_w_clip_features = [] 74 | pred_w = denormalise_data(out, *stats["w"]) 75 | for w in tqdm(pred_w): 76 | out = G.synthesis(w.tile(1,18,1)) 77 | images.append(out) 78 | clip_in = F.interpolate(out, (224,224)) 79 | image_features = clip_model.encode_image(clip_in) 80 | pred_w_clip_features.append(image_features) 81 | 82 | pred_w_clip_features = torch.cat(pred_w_clip_features, dim=0) 83 | images = torch.cat(images, dim=0) 84 | 85 | y = val_im["clip_features"]/val_im["clip_features"].norm(dim=1, keepdim=True) 86 | y_hat = pred_w_clip_features/pred_w_clip_features.norm(dim=1, keepdim=True) 87 | return torch.cosine_similarity(y, y_hat), images 88 | -------------------------------------------------------------------------------- /stylegan3/docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ Load LSUN dataset 9 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset 10 | --source train-images-idx3-ubyte.gz Load MNIST dataset 11 | --source path/ Recursively load all images from path/ 12 | --source dataset.zip Recursively load all images from dataset.zip 13 | 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 18 | 19 | The output dataset format can be either an image folder or an uncompressed 20 | zip archive. Zip archives make it easier to move datasets around file 21 | servers and clusters, and may offer better training performance on network 22 | file systems. 23 | 24 | Images within the dataset archive will be stored as uncompressed PNG. 25 | Uncompressed PNGs can be efficiently decoded in the training loop. 26 | 27 | Class labels are stored in a file called 'dataset.json' that is stored at 28 | the dataset root folder. This file has the following structure: 29 | 30 | { 31 | "labels": [ 32 | ["00000/img00000000.png",6], 33 | ["00000/img00000001.png",9], 34 | ... repeated for every image in the dataset 35 | ["00049/img00049999.png",1] 36 | ] 37 | } 38 | 39 | If the 'dataset.json' file cannot be found, the dataset is interpreted as 40 | not containing class labels. 41 | 42 | Image scale/crop and resolution requirements: 43 | 44 | Output images must be square-shaped and they must all have the same power- 45 | of-two dimensions. 46 | 47 | To scale arbitrary input image size to a specific width and height, use 48 | the --resolution option. Output resolution will be either the original 49 | input resolution (if resolution was not specified) or the one specified 50 | with --resolution option. 51 | 52 | Use the --transform=center-crop or --transform=center-crop-wide options to 53 | apply a center crop transform on the input image.
These options should be 54 | used with the --resolution option. For example: 55 | 56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 57 | --transform=center-crop-wide --resolution=512x384 58 | 59 | Options: 60 | --source PATH Directory or archive name for input dataset 61 | [required] 62 | 63 | --dest PATH Output directory or archive name for output 64 | dataset [required] 65 | 66 | --max-images INTEGER Output only up to `max-images` images 67 | --transform [center-crop|center-crop-wide] 68 | Input crop/resize mode 69 | --resolution WxH Output resolution (e.g., '512x512') 70 | --help Show this message and exit. 71 | -------------------------------------------------------------------------------- /stylegan3/viz/stylemix_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class StyleMixingWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.seed_def = 1000 18 | self.seed = self.seed_def 19 | self.animate = False 20 | self.enables = [] 21 | 22 | @imgui_utils.scoped_by_object_id 23 | def __call__(self, show=True): 24 | viz = self.viz 25 | num_ws = viz.result.get('num_ws', 0) 26 | num_enables = viz.result.get('num_ws', 18) 27 | self.enables += [False] * max(num_enables - len(self.enables), 0) 28 | 29 | if show: 30 | imgui.text('Stylemix') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): 33 | _changed, self.seed = imgui.input_int('##seed', self.seed) 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | with imgui_utils.grayed_out(num_ws == 0): 36 | _clicked, self.animate = imgui.checkbox('Anim', self.animate) 37 | 38 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w 39 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing 40 | pos0 = viz.label_w + viz.font_size * 12 41 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 42 | for idx in range(num_enables): 43 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) 44 | if idx == 0: 45 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) 46 | with imgui_utils.grayed_out(num_ws == 0): 47 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) 48 | if imgui.is_item_hovered(): 49 | imgui.set_tooltip(f'{idx}') 50 | imgui.pop_style_var(1) 51 | 52 | imgui.same_line(pos2) 53 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) 54 | with imgui_utils.grayed_out(num_ws == 0): 55 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))): 56 | self.seed = self.seed_def 57 | self.animate = False 58 | self.enables = [False] * num_enables 59 | 60 | if any(self.enables[:num_ws]): 61 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] 62 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) 63 | if 
self.animate: 64 | self.seed += 1 65 | 66 | #---------------------------------------------------------------------------- 67 | -------------------------------------------------------------------------------- /stylegan3/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train a GAN using the techniques described in the paper "Alias-Free 4 | Generative Adversarial Networks". 5 | 6 | Examples: 7 | 8 | # Train StyleGAN3-T for AFHQv2 using 8 GPUs. 9 | python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \ 10 | --gpus=8 --batch=32 --gamma=8.2 --mirror=1 11 | 12 | # Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle. 13 | python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \ 14 | --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \ 15 | --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl 16 | 17 | # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs. 18 | python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \ 19 | --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug 20 | 21 | Options: 22 | --outdir DIR Where to save the results [required] 23 | --cfg [stylegan3-t|stylegan3-r|stylegan2] 24 | Base configuration [required] 25 | --data [ZIP|DIR] Training data [required] 26 | --gpus INT Number of GPUs to use [required] 27 | --batch INT Total batch size [required] 28 | --gamma FLOAT R1 regularization weight [required] 29 | --cond BOOL Train conditional model [default: False] 30 | --mirror BOOL Enable dataset x-flips [default: False] 31 | --aug [noaug|ada|fixed] Augmentation mode [default: ada] 32 | --resume [PATH|URL] Resume from given network pickle 33 | --freezed INT Freeze first layers of D [default: 0] 34 | --p FLOAT Probability for --aug=fixed [default: 0.2] 35 | --target FLOAT Target value for --aug=ada [default: 0.6] 36 | --batch-gpu INT Limit batch size per GPU 37 | --cbase INT Capacity multiplier [default: 32768] 38 | --cmax INT Max. feature maps [default: 512] 39 | --glr FLOAT G learning rate [default: varies] 40 | --dlr FLOAT D learning rate [default: 0.002] 41 | --map-depth INT Mapping network depth [default: varies] 42 | --mbstd-group INT Minibatch std group size [default: 4] 43 | --desc STR String to include in result dir name 44 | --metrics [NAME|A,B,C|none] Quality metrics [default: fid50k_full] 45 | --kimg KIMG Total training duration [default: 25000] 46 | --tick KIMG How often to print progress [default: 4] 47 | --snap TICKS How often to save snapshots [default: 50] 48 | --seed INT Random seed [default: 0] 49 | --fp32 BOOL Disable mixed-precision [default: False] 50 | --nobench BOOL Disable cuDNN benchmarking [default: False] 51 | --workers INT DataLoader worker processes [default: 3] 52 | -n, --dry-run Print training options and exit 53 | --help Show this message and exit. 54 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- 
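The grid_sample_gradfix module above is a drop-in replacement for torch.nn.functional.grid_sample for cases that need double backpropagation, such as R1-style gradient penalties. A minimal usage sketch, assuming the repo layout above; the tensor shapes and the stand-in weight w are illustrative only, and note that second-order gradients with respect to the grid itself are not supported by the op:

import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True  # opt in to the custom op (disabled by default)

x = torch.randn(1, 3, 8, 8, requires_grad=True)  # NCHW input image
grid = torch.randn(1, 8, 8, 2)                   # NHW2 sampling grid (must not require grad)
w = torch.randn((), requires_grad=True)          # stand-in for downstream network weights

y = grid_sample_gradfix.grid_sample(x, grid) * w
(gx,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # first-order gradient, graph retained
gx.square().sum().backward()                                # second-order pass; populates w.grad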
/stylegan3/viz/performance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import array 10 | import numpy as np 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class PerformanceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.gui_times = [float('nan')] * 60 20 | self.render_times = [float('nan')] * 30 21 | self.fps_limit = 60 22 | self.use_vsync = False 23 | self.is_async = False 24 | self.force_fp32 = False 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | self.gui_times = self.gui_times[1:] + [viz.frame_delta] 30 | if 'render_time' in viz.result: 31 | self.render_times = self.render_times[1:] + [viz.result.render_time] 32 | del viz.result.render_time 33 | 34 | if show: 35 | imgui.text('GUI') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 8): 38 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) 39 | imgui.same_line(viz.label_w + viz.font_size * 9) 40 | t = [x for x in self.gui_times if x > 0] 41 | t = np.mean(t) if len(t) > 0 else 0 42 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 43 | imgui.same_line(viz.label_w + viz.font_size * 14) 44 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 45 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 46 | with imgui_utils.item_width(viz.font_size * 6): 47 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 48 | self.fps_limit = min(max(self.fps_limit, 5), 1000) 49 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 50 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) 51 | 52 | if show: 53 | imgui.text('Render') 54 | imgui.same_line(viz.label_w) 55 | with imgui_utils.item_width(viz.font_size * 8): 56 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) 57 | imgui.same_line(viz.label_w + viz.font_size * 9) 58 | t = [x for x in self.render_times if x > 0] 59 | t = np.mean(t) if len(t) > 0 else 0 60 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 61 | imgui.same_line(viz.label_w + viz.font_size * 14) 62 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 63 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 64 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) 65 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 66 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) 67 | 68 | viz.set_fps_limit(self.fps_limit) 69 | viz.set_vsync(self.use_vsync) 70 | viz.set_async(self.is_async) 71 | viz.args.force_fp32 = self.force_fp32 72 | 73 | #---------------------------------------------------------------------------- 74 | -------------------------------------------------------------------------------- 
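The performance widget above reports latency by averaging rolling windows of timings (60 GUI frame deltas, 30 render times), skipping the NaN placeholders the lists are seeded with, and showing the mean in milliseconds plus its reciprocal as FPS. A minimal sketch of that reduction, with illustrative sample values:

import numpy as np

def summarize(times):
    t = [x for x in times if x > 0]          # NaN/zero placeholders compare False here
    mean = np.mean(t) if len(t) > 0 else 0
    return (f'{mean*1e3:.1f} ms', f'{1/mean:.1f} FPS') if mean > 0 else ('N/A', 'N/A')

print(summarize([float('nan')] * 57 + [0.016, 0.017, 0.015]))  # ('16.0 ms', '62.5 FPS')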
/stylegan3/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 
else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /stylegan3/viz/capture_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import re 11 | import numpy as np 12 | import imgui 13 | import PIL.Image 14 | from gui_utils import imgui_utils 15 | from . import renderer 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class CaptureWidget: 20 | def __init__(self, viz): 21 | self.viz = viz 22 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) 23 | self.dump_image = False 24 | self.dump_gui = False 25 | self.defer_frames = 0 26 | self.disabled_time = 0 27 | 28 | def dump_png(self, image): 29 | viz = self.viz 30 | try: 31 | _height, _width, channels = image.shape 32 | assert channels in [1, 3] 33 | assert image.dtype == np.uint8 34 | os.makedirs(self.path, exist_ok=True) 35 | file_id = 0 36 | for entry in os.scandir(self.path): 37 | if entry.is_file(): 38 | match = re.fullmatch(r'(\d+).*', entry.name) 39 | if match: 40 | file_id = max(file_id, int(match.group(1)) + 1) 41 | if channels == 1: 42 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') 43 | else: 44 | pil_image = PIL.Image.fromarray(image, 'RGB') 45 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) 46 | except: 47 | viz.result.error = renderer.CapturedException() 48 | 49 | @imgui_utils.scoped_by_object_id 50 | def __call__(self, show=True): 51 | viz = self.viz 52 | if show: 53 | with imgui_utils.grayed_out(self.disabled_time != 0): 54 | imgui.text('Capture') 55 | imgui.same_line(viz.label_w) 56 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, 57 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 58 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 59 | help_text='PATH') 60 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': 61 | imgui.set_tooltip(self.path) 62 | imgui.same_line() 63 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): 64 | self.dump_image = True 65 | self.defer_frames = 2 66 | self.disabled_time = 0.5 67 | imgui.same_line() 68 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): 69 | self.dump_gui = True 70 | self.defer_frames = 2 71 | self.disabled_time = 0.5 72 | 73 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) 74 | if self.defer_frames > 0: 75 | self.defer_frames -= 1 76 | elif self.dump_image: 77 | if 'image' in viz.result: 78 | self.dump_png(viz.result.image) 79 | self.dump_image = False 80 | elif self.dump_gui: 81 | viz.capture_next_frame() 82 | self.dump_gui = False 83 | captured_frame = viz.pop_captured_frame() 
84 | if captured_frame is not None: 85 | self.dump_png(captured_frame) 86 | 87 | #---------------------------------------------------------------------------- 88 | -------------------------------------------------------------------------------- /stylegan3/viz/trunc_noise_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class TruncationNoiseWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.prev_num_ws = 0 18 | self.trunc_psi = 1 19 | self.trunc_cutoff = 0 20 | self.noise_enable = True 21 | self.noise_seed = 0 22 | self.noise_anim = False 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | num_ws = viz.result.get('num_ws', 0) 28 | has_noise = viz.result.get('has_noise', False) 29 | if num_ws > 0 and num_ws != self.prev_num_ws: 30 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: 31 | self.trunc_cutoff = num_ws 32 | self.prev_num_ws = num_ws 33 | 34 | if show: 35 | imgui.text('Truncate') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): 38 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') 39 | imgui.same_line() 40 | if num_ws == 0: 41 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) 42 | else: 43 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): 44 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') 45 | if changed: 46 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) 47 | 48 | with imgui_utils.grayed_out(not has_noise): 49 | imgui.same_line() 50 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) 51 | imgui.same_line(round(viz.font_size * 27.7)) 52 | with imgui_utils.grayed_out(not self.noise_enable): 53 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4): 54 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) 55 | imgui.same_line(spacing=0) 56 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) 57 | 58 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) 59 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) 60 | with imgui_utils.grayed_out(is_def_trunc and not has_noise): 61 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 62 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): 63 | self.prev_num_ws = num_ws 64 | self.trunc_psi = 1 65 | self.trunc_cutoff = num_ws 66 | self.noise_enable = True 67 | self.noise_seed = 0 68 | self.noise_anim = False 69 | 70 | if self.noise_anim: 71 | self.noise_seed += 1 72 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, 
random_seed=self.noise_seed) 73 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') 74 | 75 | #---------------------------------------------------------------------------- 76 | -------------------------------------------------------------------------------- /stylegan3/viz/latent_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class LatentWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25) 20 | self.latent_def = dnnlib.EasyDict(self.latent) 21 | self.step_y = 100 22 | 23 | def drag(self, dx, dy): 24 | viz = self.viz 25 | self.latent.x += dx / viz.font_size * 4e-2 26 | self.latent.y += dy / viz.font_size * 4e-2 27 | 28 | @imgui_utils.scoped_by_object_id 29 | def __call__(self, show=True): 30 | viz = self.viz 31 | if show: 32 | imgui.text('Latent') 33 | imgui.same_line(viz.label_w) 34 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y 35 | with imgui_utils.item_width(viz.font_size * 8): 36 | changed, seed = imgui.input_int('##seed', seed) 37 | if changed: 38 | self.latent.x = seed 39 | self.latent.y = 0 40 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 41 | frac_x = self.latent.x - round(self.latent.x) 42 | frac_y = self.latent.y - round(self.latent.y) 43 | with imgui_utils.item_width(viz.font_size * 5): 44 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 45 | if changed: 46 | self.latent.x += new_frac_x - frac_x 47 | self.latent.y += new_frac_y - frac_y 48 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 49 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 50 | if dragging: 51 | self.drag(dx, dy) 52 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 53 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) 54 | imgui.same_line(round(viz.font_size * 27.7)) 55 | with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): 56 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) 57 | if changed: 58 | self.latent.speed = speed 59 | imgui.same_line() 60 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) 61 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): 62 | self.latent = snapped 63 | imgui.same_line() 64 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): 65 | self.latent = dnnlib.EasyDict(self.latent_def) 66 | 67 | if self.latent.anim: 68 | self.latent.x += viz.frame_delta * self.latent.speed 69 | viz.args.w0_seeds = [] # 
[[seed, weight], ...] 70 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: 71 | seed_x = np.floor(self.latent.x) + ofs_x 72 | seed_y = np.floor(self.latent.y) + ofs_y 73 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) 74 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) 75 | if weight > 0: 76 | viz.args.w0_seeds.append([seed, weight]) 77 | 78 | #---------------------------------------------------------------------------- 79 | -------------------------------------------------------------------------------- /stylegan3/gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import imgui 11 | import imgui.integrations.glfw 12 | 13 | from . import glfw_window 14 | from . import imgui_utils 15 | from . import text_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class ImguiWindow(glfw_window.GlfwWindow): 20 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 21 | if font is None: 22 | font = text_utils.get_default_font() 23 | font_sizes = {int(size) for size in font_sizes} 24 | super().__init__(title=title, **glfw_kwargs) 25 | 26 | # Init fields. 27 | self._imgui_context = None 28 | self._imgui_renderer = None 29 | self._imgui_fonts = None 30 | self._cur_font_size = max(font_sizes) 31 | 32 | # Delete leftover imgui.ini to avoid unexpected behavior. 33 | if os.path.isfile('imgui.ini'): 34 | os.remove('imgui.ini') 35 | 36 | # Init ImGui. 37 | self._imgui_context = imgui.create_context() 38 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 39 | self._attach_glfw_callbacks() 40 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 42 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 43 | self._imgui_renderer.refresh_font_texture() 44 | 45 | def close(self): 46 | self.make_context_current() 47 | self._imgui_fonts = None 48 | if self._imgui_renderer is not None: 49 | self._imgui_renderer.shutdown() 50 | self._imgui_renderer = None 51 | if self._imgui_context is not None: 52 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 53 | self._imgui_context = None 54 | super().close() 55 | 56 | def _glfw_key_callback(self, *args): 57 | super()._glfw_key_callback(*args) 58 | self._imgui_renderer.keyboard_callback(*args) 59 | 60 | @property 61 | def font_size(self): 62 | return self._cur_font_size 63 | 64 | @property 65 | def spacing(self): 66 | return round(self._cur_font_size * 0.4) 67 | 68 | def set_font_size(self, target): # Applied on next frame. 69 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 70 | 71 | def begin_frame(self): 72 | # Begin glfw frame. 73 | super().begin_frame() 74 | 75 | # Process imgui events. 
76 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 77 | if self.content_width > 0 and self.content_height > 0: 78 | self._imgui_renderer.process_inputs() 79 | 80 | # Begin imgui frame. 81 | imgui.new_frame() 82 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 83 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 84 | 85 | def end_frame(self): 86 | imgui.pop_font() 87 | imgui.render() 88 | imgui.end_frame() 89 | self._imgui_renderer.render(imgui.get_draw_data()) 90 | super().end_frame() 91 | 92 | #---------------------------------------------------------------------------- 93 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 94 | 95 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.mouse_wheel_multiplier = 1 99 | 100 | def scroll_callback(self, window, x_offset, y_offset): 101 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 102 | 103 | #---------------------------------------------------------------------------- 104 | -------------------------------------------------------------------------------- /stylegan3/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN3 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. 
You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection.
85 | 86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void); 88 | template <bool signWrite> cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr<float>(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size.
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /stylegan3/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 
54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 
117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import logging 3 | from pathlib import Path 4 | import sys 5 | import wandb 6 | import yaml 7 | 8 | import clip 9 | import torch 10 | from tqdm.auto import tqdm 11 | 12 | from dalle2_pytorch import DiffusionPriorNetwork 13 | 14 | sys.path.append("stylegan3") 15 | import dnnlib 16 | import legacy 17 | import torch 18 | 19 | from clip2latent import train_utils 20 | from clip2latent.train_utils import compute_val, make_grid, make_image_val_data, make_text_val_data 21 | from clip2latent.latent_prior import ZWPrior 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class Checkpointer(): 27 | def __init__(self, directory, checkpoint_its): 28 | directory = Path(directory) 29 | self.directory = directory 30 | self.checkpoint_its = checkpoint_its 31 | if not directory.exists(): 32 | directory.mkdir(parents=True) 33 | 34 | def save_checkpoint(self, model, iteration): 35 | if iteration % self.checkpoint_its: 36 | return 37 | 38 | k_it = iteration // 1000 39 | filename = self.directory/f"{k_it:06}.ckpt" 40 | checkpoint = {"state_dict": model.state_dict()} 41 | if hasattr(model, "cfg"): 42 | checkpoint["cfg"] = model.cfg 43 | 44 | print(f"Saving checkpoint to {filename}") 45 | torch.save(checkpoint, filename) 46 | 47 | 48 | 49 | def validation(current_it, device, diffusion_prior, stats, G, clip_model, val_data): 50 | image_result, ims = compute_val(diffusion_prior, val_data["val_im"], G, clip_model, device, stats) 51 | val = image_result.mean() 52 | wandb.log({'val/image-similarity': val}, step=current_it) 53 | 54 | single_im = {"clip_features": val_data["val_im"]["clip_features"][0].tile(8,1)} 55 | image_result, ims = compute_val(diffusion_prior, single_im, G, clip_model, device, stats) 56 | val = image_result.mean() 57 | wandb.log({'val/image-vars': val}, step=current_it) 58 | wandb.log({'val/image/im-variations': wandb.Image(make_grid(ims))}, step=current_it) 59 | 60 | text_result, ims = compute_val(diffusion_prior, val_data["val_text"], G, clip_model, device, stats) 61 | val = text_result.mean() 62 | wandb.log({'val/text': val}, step=current_it) 63 | wandb.log({'val/image/text2im': wandb.Image(make_grid(ims))}, step=current_it) 64 | 65 | text_result, ims = compute_val(diffusion_prior, val_data["val_text"], G, clip_model, device, stats, super_cond=1) 66 | val = text_result.mean() 67 | wandb.log({'val/text-super': val}, step=current_it) 68 | wandb.log({'val/image/text2im-super': wandb.Image(make_grid(ims))}, step=current_it) 69 | 70 | 71 | def train_step(diffusion_prior, device, batch): 72 | diffusion_prior.train() 73 | batch_z, batch_w = batch 74 | batch_z = batch_z.to(device) 75 | batch_w = batch_w.to(device) 76 | 77 | loss = diffusion_prior(batch_z, batch_w) 78 | loss.backward() 79 | return loss 80 | 81 | 82 | def train(diffusion_prior, opt, loader, device, stats, G, clip_model, val_data, val_it, print_it, save_checkpoint, max_it): 83 | 84 | current_it = 0 85 |
current_epoch = 0 86 | 87 | while current_it < max_it: 88 | 89 | wandb.log({'epoch': current_epoch}, step=current_it) 90 | pbar = tqdm(loader) 91 | for batch in pbar: 92 | 93 | if current_it % val_it == 0: 94 | validation(current_it, device, diffusion_prior, stats, G, clip_model, val_data) 95 | 96 | loss = train_step(diffusion_prior, device, batch) 97 | 98 | if (current_it % print_it == 0): 99 | wandb.log({'loss': loss.item()}, step=current_it) 100 | 101 | opt.step() 102 | opt.zero_grad() 103 | current_it += 1 104 | pbar.set_postfix({"epoch": current_epoch, "it": current_it}) 105 | 106 | save_checkpoint(diffusion_prior, current_it) 107 | 108 | current_epoch += 1 109 | 110 | 111 | if __name__ == "__main__": 112 | 113 | with open("config.yaml", "rt") as f: 114 | cfg = yaml.safe_load(f) 115 | 116 | wandb.init( 117 | project="clip2latent", 118 | config=cfg, 119 | entity="justinpinkney", 120 | ) 121 | # Load model 122 | device = cfg["device"] 123 | 124 | prior_network = DiffusionPriorNetwork(**cfg["model"]["network"]).to(device) 125 | diffusion_prior = ZWPrior(prior_network, **cfg["model"]["diffusion"]).to(device) 126 | diffusion_prior.cfg = cfg 127 | 128 | z = torch.load(cfg["data"]["clip_feature_path"]) 129 | w = torch.load(cfg["data"]["latent_path"]) 130 | 131 | stats = { 132 | "w": train_utils.make_data_stats(w), 133 | "clip_features": train_utils.make_data_stats(z), 134 | } 135 | 136 | w_norm = train_utils.normalise_data(w, *stats["w"]) 137 | # Doesn't seem to work well if we norm z 138 | # z_norm = train_utils.normalise_data(z, *stats["clip_features"]) 139 | 140 | # Load eval models 141 | with dnnlib.util.open_url(cfg["data"]["sg_pkl"]) as f: 142 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 143 | clip_model, _ = clip.load(cfg["data"]["clip_variant"], device=device) 144 | 145 | val_data = { 146 | "val_im": make_image_val_data(G, clip_model, cfg["val"]["n_im_val_samples"], device), 147 | "val_text": make_text_val_data(G, clip_model, cfg["val"]["text_val_samples"], device), 148 | } 149 | 150 | ds = torch.utils.data.TensorDataset(z, w_norm) 151 | loader = torch.utils.data.DataLoader(ds, batch_size=cfg["data"]["bs"], shuffle=True, drop_last=True) 152 | opt = torch.optim.AdamW(prior_network.parameters(), **cfg["train"]["opt"]) 153 | 154 | checkpoint_dir = f"checkpoints/{datetime.now():%Y%m%d-%H%M%S}" 155 | checkpointer = Checkpointer(checkpoint_dir, cfg["train"]["loop"]["val_it"]) 156 | 157 | train(diffusion_prior, opt, loader, device, stats, G, clip_model, val_data, **cfg["train"]["loop"], save_checkpoint=checkpointer.save_checkpoint) 158 | -------------------------------------------------------------------------------- /stylegan3/gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import functools 10 | from typing import Optional 11 | 12 | import dnnlib 13 | import numpy as np 14 | import PIL.Image 15 | import PIL.ImageFont 16 | import scipy.ndimage 17 | 18 | from . 
import gl_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def get_default_font(): 23 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 24 | return dnnlib.util.open_url(url, return_filename=True) 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | @functools.lru_cache(maxsize=None) 29 | def get_pil_font(font=None, size=32): 30 | if font is None: 31 | font = get_default_font() 32 | return PIL.ImageFont.truetype(font=font, size=size) 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 37 | if dropshadow_radius is not None: 38 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 39 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 40 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 41 | else: 42 | return _get_array_priv(string, **kwargs) 43 | 44 | @functools.lru_cache(maxsize=10000) 45 | def _get_array_priv( 46 | string: str, *, 47 | size: int = 32, 48 | max_width: Optional[int]=None, 49 | max_height: Optional[int]=None, 50 | min_size=10, 51 | shrink_coef=0.8, 52 | dropshadow_radius: int=None, 53 | offset_x: int=None, 54 | offset_y: int=None, 55 | **kwargs 56 | ): 57 | cur_size = size 58 | array = None 59 | while True: 60 | if dropshadow_radius is not None: 61 | # separate implementation for dropshadow text rendering 62 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 63 | else: 64 | array = _get_array_impl(string, size=cur_size, **kwargs) 65 | height, width, _ = array.shape 66 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 67 | break 68 | cur_size = max(int(cur_size * shrink_coef), min_size) 69 | return array 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | @functools.lru_cache(maxsize=10000) 74 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 75 | pil_font = get_pil_font(font=font, size=size) 76 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 77 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 78 | width = max(line.shape[1] for line in lines) 79 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 80 | line_spacing = line_pad if line_pad is not None else size // 2 81 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 82 | mask = np.concatenate(lines, axis=0) 83 | alpha = mask 84 | if outline > 0: 85 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 86 | alpha = mask.astype(np.float32) / 255 87 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 88 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 89 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 90 | alpha = np.maximum(alpha, mask) 91 | return np.stack([mask, alpha], axis=-1) 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | @functools.lru_cache(maxsize=10000) 96 | def _get_array_impl_dropshadow(string, *, font=None, size=32, 
radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 97 | assert (offset_x > 0) and (offset_y > 0) 98 | pil_font = get_pil_font(font=font, size=size) 99 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 100 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 101 | width = max(line.shape[1] for line in lines) 102 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 103 | line_spacing = line_pad if line_pad is not None else size // 2 104 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 105 | mask = np.concatenate(lines, axis=0) 106 | alpha = mask 107 | 108 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 109 | alpha = mask.astype(np.float32) / 255 110 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 111 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 112 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 113 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 114 | alpha = np.maximum(alpha, mask) 115 | return np.stack([mask, alpha], axis=-1) 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | @functools.lru_cache(maxsize=10000) 120 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 121 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 122 | 123 | #---------------------------------------------------------------------------- 124 | -------------------------------------------------------------------------------- /stylegan3/viz/equivariance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
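The widget below edits a translation/rotation state each frame and hands the renderer a 3x3 input transform via viz.args.input_transform (see the bottom of this file). As a standalone sketch of the same matrix construction — the function name and arguments here are illustrative, not part of the widget's API:

import numpy as np

def input_transform(tx, ty, turns):
    # Rotation is stored in turns (fraction of a full revolution), as in rotate.val below;
    # translation is in the generator's normalized coordinates, as in xlate.x / xlate.y.
    angle = turns * np.pi * 2
    return np.array([
        [ np.cos(angle), np.sin(angle), tx],
        [-np.sin(angle), np.cos(angle), ty],
        [0.0, 0.0, 1.0]])

# e.g. a quarter turn with a small horizontal shift:
# input_transform(0.1, 0.0, 0.25)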
8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class EquivarianceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2) 20 | self.xlate_def = dnnlib.EasyDict(self.xlate) 21 | self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3) 22 | self.rotate_def = dnnlib.EasyDict(self.rotate) 23 | self.opts = dnnlib.EasyDict(untransform=False) 24 | self.opts_def = dnnlib.EasyDict(self.opts) 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | if show: 30 | imgui.text('Translate') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8): 33 | _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f') 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w) 36 | if dragging: 37 | self.xlate.x += dx / viz.font_size * 2e-2 38 | self.xlate.y += dy / viz.font_size * 2e-2 39 | imgui.same_line() 40 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w) 41 | if dragging: 42 | self.xlate.x += dx / viz.font_size * 4e-4 43 | self.xlate.y += dy / viz.font_size * 4e-4 44 | imgui.same_line() 45 | _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim) 46 | imgui.same_line() 47 | _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round) 48 | imgui.same_line() 49 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim): 50 | changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5) 51 | if changed: 52 | self.xlate.speed = speed 53 | imgui.same_line() 54 | if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)): 55 | self.xlate = dnnlib.EasyDict(self.xlate_def) 56 | 57 | if show: 58 | imgui.text('Rotate') 59 | imgui.same_line(viz.label_w) 60 | with imgui_utils.item_width(viz.font_size * 8): 61 | _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f') 62 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 63 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w) 64 | if dragging: 65 | self.rotate.val += dx / viz.font_size * 2e-2 66 | imgui.same_line() 67 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w) 68 | if dragging: 69 | self.rotate.val += dx / viz.font_size * 4e-4 70 | imgui.same_line() 71 | _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim) 72 | imgui.same_line() 73 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.rotate.anim): 74 | changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3) 75 | if changed: 76 | self.rotate.speed = speed 77 | imgui.same_line() 78 | if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)): 79 | self.rotate = dnnlib.EasyDict(self.rotate_def) 80 | 81 | if show: 82 | imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16) 83 | _clicked, self.opts.untransform = 
imgui.checkbox('Untransform', self.opts.untransform) 84 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 85 | if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)): 86 | self.opts = dnnlib.EasyDict(self.opts_def) 87 | 88 | if self.xlate.anim: 89 | c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 90 | t = c.copy() 91 | if np.max(np.abs(t)) < 1e-4: 92 | t += 1 93 | t *= 0.1 / np.hypot(*t) 94 | t += c[::-1] * [1, -1] 95 | d = t - c 96 | d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d) 97 | self.xlate.x += d[0] 98 | self.xlate.y += d[1] 99 | 100 | if self.rotate.anim: 101 | self.rotate.val += viz.frame_delta * self.rotate.speed 102 | 103 | pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 104 | if self.xlate.round and 'img_resolution' in viz.result: 105 | pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution 106 | angle = self.rotate.val * np.pi * 2 107 | 108 | viz.args.input_transform = [ 109 | [np.cos(angle), np.sin(angle), pos[0]], 110 | [-np.sin(angle), np.cos(angle), pos[1]], 111 | [0, 0, 1]] 112 | 113 | viz.args.update(untransform=self.opts.untransform) 114 | 115 | #---------------------------------------------------------------------------- 116 | -------------------------------------------------------------------------------- /stylegan3/metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Main API for computing and reporting quality metrics.""" 10 | 11 | import os 12 | import time 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | from . import metric_utils 18 | from . import frechet_inception_distance 19 | from . import kernel_inception_distance 20 | from . import precision_recall 21 | from . import perceptual_path_length 22 | from . import inception_score 23 | from . import equivariance 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _metric_dict = dict() # name => fn 28 | 29 | def register_metric(fn): 30 | assert callable(fn) 31 | _metric_dict[fn.__name__] = fn 32 | return fn 33 | 34 | def is_valid_metric(metric): 35 | return metric in _metric_dict 36 | 37 | def list_valid_metrics(): 38 | return list(_metric_dict.keys()) 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 43 | assert is_valid_metric(metric) 44 | opts = metric_utils.MetricOptions(**kwargs) 45 | 46 | # Calculate. 47 | start_time = time.time() 48 | results = _metric_dict[metric](opts) 49 | total_time = time.time() - start_time 50 | 51 | # Broadcast results. 52 | for key, value in list(results.items()): 53 | if opts.num_gpus > 1: 54 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 55 | torch.distributed.broadcast(tensor=value, src=0) 56 | value = float(value.cpu()) 57 | results[key] = value 58 | 59 | # Decorate with metadata. 
60 | return dnnlib.EasyDict( 61 | results = dnnlib.EasyDict(results), 62 | metric = metric, 63 | total_time = total_time, 64 | total_time_str = dnnlib.util.format_time(total_time), 65 | num_gpus = opts.num_gpus, 66 | ) 67 | 68 | #---------------------------------------------------------------------------- 69 | 70 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 71 | metric = result_dict['metric'] 72 | assert is_valid_metric(metric) 73 | if run_dir is not None and snapshot_pkl is not None: 74 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 75 | 76 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 77 | print(jsonl_line) 78 | if run_dir is not None and os.path.isdir(run_dir): 79 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 80 | f.write(jsonl_line + '\n') 81 | 82 | #---------------------------------------------------------------------------- 83 | # Recommended metrics. 84 | 85 | @register_metric 86 | def fid50k_full(opts): 87 | opts.dataset_kwargs.update(max_size=None, xflip=False) 88 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 89 | return dict(fid50k_full=fid) 90 | 91 | @register_metric 92 | def kid50k_full(opts): 93 | opts.dataset_kwargs.update(max_size=None, xflip=False) 94 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 95 | return dict(kid50k_full=kid) 96 | 97 | @register_metric 98 | def pr50k3_full(opts): 99 | opts.dataset_kwargs.update(max_size=None, xflip=False) 100 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 101 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 102 | 103 | @register_metric 104 | def ppl2_wend(opts): 105 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 106 | return dict(ppl2_wend=ppl) 107 | 108 | @register_metric 109 | def eqt50k_int(opts): 110 | opts.G_kwargs.update(force_fp32=True) 111 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 112 | return dict(eqt50k_int=psnr) 113 | 114 | @register_metric 115 | def eqt50k_frac(opts): 116 | opts.G_kwargs.update(force_fp32=True) 117 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 118 | return dict(eqt50k_frac=psnr) 119 | 120 | @register_metric 121 | def eqr50k(opts): 122 | opts.G_kwargs.update(force_fp32=True) 123 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 124 | return dict(eqr50k=psnr) 125 | 126 | #---------------------------------------------------------------------------- 127 | # Legacy metrics. 
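# Both the recommended metrics above and the legacy variants below are plain functions
# exposed through @register_metric; calc_metric() wraps whatever dict they return with
# timing, the multi-GPU broadcast, and metadata. A new metric follows the same pattern.
# A minimal sketch — the name fid10k and its sample counts are hypothetical, not part
# of this codebase:
#
#     @register_metric
#     def fid10k(opts):
#         opts.dataset_kwargs.update(max_size=None)
#         fid = frechet_inception_distance.compute_fid(opts, max_real=10000, num_gen=10000)
#         return dict(fid10k=fid)
#
#     result = calc_metric('fid10k', G=G, dataset_kwargs=dataset_kwargs, num_gpus=1, rank=0, device=device)
#     report_metric(result, run_dir=run_dir, snapshot_pkl=snapshot_pkl)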
128 | 129 | @register_metric 130 | def fid50k(opts): 131 | opts.dataset_kwargs.update(max_size=None) 132 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 133 | return dict(fid50k=fid) 134 | 135 | @register_metric 136 | def kid50k(opts): 137 | opts.dataset_kwargs.update(max_size=None) 138 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 139 | return dict(kid50k=kid) 140 | 141 | @register_metric 142 | def pr50k3(opts): 143 | opts.dataset_kwargs.update(max_size=None) 144 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 145 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 146 | 147 | @register_metric 148 | def is50k(opts): 149 | opts.dataset_kwargs.update(max_size=None, xflip=False) 150 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 151 | return dict(is50k_mean=mean, is50k_std=std) 152 | 153 | #---------------------------------------------------------------------------- 154 | -------------------------------------------------------------------------------- /stylegan3/gen_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def parse_range(s: Union[str, List]) -> List[int]: 26 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 27 | 28 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 29 | ''' 30 | if isinstance(s, list): return s 31 | ranges = [] 32 | range_re = re.compile(r'^(\d+)-(\d+)$') 33 | for p in s.split(','): 34 | m = range_re.match(p) 35 | if m: 36 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 37 | else: 38 | ranges.append(int(p)) 39 | return ranges 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: 44 | '''Parse a floating point 2-vector of syntax 'a,b'.
45 | 46 | Example: 47 | '0,1' returns (0,1) 48 | ''' 49 | if isinstance(s, tuple): return s 50 | parts = s.split(',') 51 | if len(parts) == 2: 52 | return (float(parts[0]), float(parts[1])) 53 | raise ValueError(f'cannot parse 2-vector {s}') 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def make_transform(translate: Tuple[float,float], angle: float): 58 | m = np.eye(3) 59 | s = np.sin(angle/360.0*np.pi*2) 60 | c = np.cos(angle/360.0*np.pi*2) 61 | m[0][0] = c 62 | m[0][1] = s 63 | m[0][2] = translate[0] 64 | m[1][0] = -s 65 | m[1][1] = c 66 | m[1][2] = translate[1] 67 | return m 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @click.command() 72 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 73 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 74 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 75 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 76 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 77 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 78 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 79 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 80 | def generate_images( 81 | network_pkl: str, 82 | seeds: List[int], 83 | truncation_psi: float, 84 | noise_mode: str, 85 | outdir: str, 86 | translate: Tuple[float,float], 87 | rotate: float, 88 | class_idx: Optional[int] 89 | ): 90 | """Generate images using pretrained network pickle. 91 | 92 | Examples: 93 | 94 | \b 95 | # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). 96 | python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ 97 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 98 | 99 | \b 100 | # Generate uncurated images with truncation using the MetFaces-U dataset 101 | python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 102 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl 103 | """ 104 | 105 | print('Loading networks from "%s"...' % network_pkl) 106 | device = torch.device('cuda') 107 | with dnnlib.util.open_url(network_pkl) as f: 108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 109 | 110 | os.makedirs(outdir, exist_ok=True) 111 | 112 | # Labels. 113 | label = torch.zeros([1, G.c_dim], device=device) 114 | if G.c_dim != 0: 115 | if class_idx is None: 116 | raise click.ClickException('Must specify class label with --class when using a conditional network') 117 | label[:, class_idx] = 1 118 | else: 119 | if class_idx is not None: 120 | print ('warn: --class=lbl ignored when running on an unconditional network') 121 | 122 | # Generate images. 123 | for seed_idx, seed in enumerate(seeds): 124 | print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) 125 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 126 | 127 | # Construct an inverse rotation/translation matrix and pass to the generator. The 128 | # generator expects this matrix as an inverse to avoid potentially failing numerical 129 | # operations in the network. 130 | if hasattr(G.synthesis, 'input'): 131 | m = make_transform(translate, rotate) 132 | m = np.linalg.inv(m) 133 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 134 | 135 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 136 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 137 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 138 | 139 | 140 | #---------------------------------------------------------------------------- 141 | 142 | if __name__ == "__main__": 143 | generate_images() # pylint: disable=no-value-for-parameter 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ?
x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations.
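// Within the kernel above, the template parameter A selects the activation (matching the
// `act` index that the host code passes in p.act) and G = p.grad selects the derivative
// order (0 = forward, 1 = first derivative, 2 = second derivative), so a single template
// covers the op and both of its gradients. The explicit instantiations below provide the
// three dtypes the dispatcher can request.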
168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files*/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries.
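# (torch.utils.cpp_extension drives the actual build; on Windows it shells out to
# cl.exe, so if `where cl.exe` does not find it on PATH we locate a Visual Studio
# toolchain via _find_compiler_bindir() above and extend PATH before compiling.)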
81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict.
150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /stylegan3/gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import contextlib 10 | import imgui 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 15 | s = imgui.get_style() 16 | s.window_padding = [spacing, spacing] 17 | s.item_spacing = [spacing, spacing] 18 | s.item_inner_spacing = [spacing, spacing] 19 | s.columns_min_spacing = spacing 20 | s.indent_spacing = indent 21 | s.scrollbar_size = scrollbar 22 | s.frame_padding = [4, 3] 23 | s.window_border_size = 1 24 | s.child_border_size = 1 25 | s.popup_border_size = 1 26 | s.frame_border_size = 1 27 | s.window_rounding = 0 28 | s.child_rounding = 0 29 | s.popup_rounding = 3 30 | s.frame_rounding = 3 31 | s.scrollbar_rounding = 3 32 | s.grab_rounding = 3 33 | 34 | getattr(imgui, f'style_colors_{color_scheme}')(s) 35 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 36 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 37 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @contextlib.contextmanager 42 | def grayed_out(cond=True): 43 | if cond: 44 | s = imgui.get_style() 45 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 46 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 47 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 48 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 49 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 50 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 51 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 52 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 53 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 55 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 56 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 58 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 59 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 60 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 61 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 62 | yield 63 | imgui.pop_style_color(14) 64 | else: 65 | yield 66 | 67 | #---------------------------------------------------------------------------- 68 | 69 | @contextlib.contextmanager 70 | def item_width(width=None): 71 | if width is not None: 72 | imgui.push_item_width(width) 73 | yield 74 | 
imgui.pop_item_width() 75 | else: 76 | yield 77 | 78 | #---------------------------------------------------------------------------- 79 | 80 | def scoped_by_object_id(method): 81 | def decorator(self, *args, **kwargs): 82 | imgui.push_id(str(id(self))) 83 | res = method(self, *args, **kwargs) 84 | imgui.pop_id() 85 | return res 86 | return decorator 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | def button(label, width=0, enabled=True): 91 | with grayed_out(not enabled): 92 | clicked = imgui.button(label, width=width) 93 | clicked = clicked and enabled 94 | return clicked 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 99 | expanded = False 100 | if show: 101 | if default: 102 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 103 | if not enabled: 104 | flags |= imgui.TREE_NODE_LEAF 105 | with grayed_out(not enabled): 106 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 107 | expanded = expanded and enabled 108 | return expanded, visible 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def popup_button(label, width=0, enabled=True): 113 | if button(label, width, enabled): 114 | imgui.open_popup(label) 115 | opened = imgui.begin_popup(label) 116 | return opened 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 121 | old_value = value 122 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 123 | if value == '': 124 | color[-1] *= 0.5 125 | with item_width(width): 126 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 127 | value = value if value != '' else help_text 128 | changed, value = imgui.input_text(label, value, buffer_length, flags) 129 | value = value if value != help_text else '' 130 | imgui.pop_style_color(1) 131 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 132 | changed = (value != old_value) 133 | return changed, value 134 | 135 | #---------------------------------------------------------------------------- 136 | 137 | def drag_previous_control(enabled=True): 138 | dragging = False 139 | dx = 0 140 | dy = 0 141 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 142 | if enabled: 143 | dragging = True 144 | dx, dy = imgui.get_mouse_drag_delta() 145 | imgui.reset_mouse_drag_delta() 146 | imgui.end_drag_drop_source() 147 | return dragging, dx, dy 148 | 149 | #---------------------------------------------------------------------------- 150 | 151 | def drag_button(label, width=0, enabled=True): 152 | clicked = button(label, width=width, enabled=enabled) 153 | dragging, dx, dy = drag_previous_control(enabled=enabled) 154 | return clicked, dragging, dx, dy 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | def drag_hidden_window(label, x, y, width, height, enabled=True): 159 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 160 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 161 | imgui.set_next_window_position(x, y) 162 | imgui.set_next_window_size(width, height) 163 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 164 | dragging, dx, dy = drag_previous_control(enabled=enabled) 165 | imgui.end() 166 | 
imgui.pop_style_color(2) 167 | return dragging, dx, dy 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /stylegan3/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 
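# upfirdn2d() fuses the upsampling, padding, and gain into a single resampling pass,
# the plain convolution then runs at the upsampled resolution, and when down > 1 a
# final upfirdn2d() applies the low-pass filter f before decimating.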
137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /dalle2_pytorch/tokenizer.py: -------------------------------------------------------------------------------- 1 | # take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py 2 | # to give users a quick easy start to training DALL-E without doing BPE 3 | 4 | import torch 5 | import youtokentome as yttm 6 | 7 | import html 8 | import os 9 | import ftfy 10 | import regex as re 11 | from functools import lru_cache 12 | from pathlib import Path 13 | 14 | # OpenAI simple tokenizer 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt") 19 | 20 | @lru_cache() 21 | def bytes_to_unicode(): 22 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 23 | cs = bs[:] 24 | n = 0 25 | for b in range(2 ** 8): 26 | if b not in bs: 27 | bs.append(b) 28 | cs.append(2 ** 8 + n) 29 | n += 1 30 | cs = [chr(n) for n in cs] 31 | return dict(zip(bs, cs)) 32 | 33 | def get_pairs(word): 34 | pairs = set() 35 | prev_char = word[0] 36 | for char in word[1:]: 37 | pairs.add((prev_char, char)) 38 | prev_char = char 39 | return pairs 40 | 41 | def basic_clean(text): 42 | text = ftfy.fix_text(text) 43 | text = html.unescape(html.unescape(text)) 44 | return text.strip() 45 | 46 | def whitespace_clean(text): 47 | text = re.sub(r'\s+', ' ', text) 48 | text = text.strip() 49 | return text 50 | 51 | class SimpleTokenizer(object): 52 | def __init__(self, bpe_path = default_bpe()): 53 | self.byte_encoder = bytes_to_unicode() 54 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 55 | merges = Path(bpe_path).read_text(encoding='utf8').split('\n') 56 | merges = merges[1:49152 - 256 - 2 + 1] 57 | merges = [tuple(merge.split()) for merge in merges] 58 | vocab = list(bytes_to_unicode().values()) 59 | vocab = vocab + [v + '</w>' for v in vocab] 60 | for merge in merges: 61 | vocab.append(''.join(merge)) 62 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 63 | 64 | self.vocab_size = 49408 65 | 66 | self.encoder = dict(zip(vocab, range(len(vocab)))) 67 | self.decoder = {v: k for k, v in self.encoder.items()} 68 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 69 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 70 | self.pat = re.compile( 71 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 72 | re.IGNORECASE) 73 | 74 | def bpe(self, token): 75 | if token in self.cache: 76 | return self.cache[token] 77 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 78 | pairs = get_pairs(word) 79 | 80 | if not pairs: 81 | return token + '</w>' 82 | 83 | while True: 84 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 85 | if bigram not in self.bpe_ranks: 86 | break 87 | first, second = bigram 88 | new_word = [] 89 | i = 0 90 | while i < len(word): 91 | try: 92 | j = word.index(first, i) 93 | new_word.extend(word[i:j]) 94 | i = j 95 | except: 96 |
new_word.extend(word[i:]) 97 | break 98 | 99 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 100 | new_word.append(first + second) 101 | i += 2 102 | else: 103 | new_word.append(word[i]) 104 | i += 1 105 | new_word = tuple(new_word) 106 | word = new_word 107 | if len(word) == 1: 108 | break 109 | else: 110 | pairs = get_pairs(word) 111 | word = ' '.join(word) 112 | self.cache[token] = word 113 | return word 114 | 115 | def encode(self, text): 116 | bpe_tokens = [] 117 | text = whitespace_clean(basic_clean(text)).lower() 118 | for token in re.findall(self.pat, text): 119 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 120 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 121 | return bpe_tokens 122 | 123 | def decode(self, tokens, remove_start_end = True, pad_tokens = set()): 124 | if torch.is_tensor(tokens): 125 | tokens = tokens.tolist() 126 | 127 | if remove_start_end: 128 | tokens = [token for token in tokens if token not in (49406, 49407, 0)] 129 | text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens]) 130 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 131 | return text 132 | 133 | def tokenize(self, texts, context_length = 256, truncate_text = False): 134 | if isinstance(texts, str): 135 | texts = [texts] 136 | 137 | all_tokens = [self.encode(text) for text in texts] 138 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 139 | 140 | for i, tokens in enumerate(all_tokens): 141 | if len(tokens) > context_length: 142 | if truncate_text: 143 | tokens = tokens[:context_length] 144 | else: 145 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 146 | result[i, :len(tokens)] = torch.tensor(tokens) 147 | 148 | return result 149 | 150 | tokenizer = SimpleTokenizer() 151 | 152 | # YTTM tokenizer 153 | 154 | class YttmTokenizer: 155 | def __init__(self, bpe_path = None): 156 | bpe_path = Path(bpe_path) 157 | assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist' 158 | 159 | tokenizer = yttm.BPE(model = str(bpe_path)) 160 | self.tokenizer = tokenizer 161 | self.vocab_size = tokenizer.vocab_size() 162 | 163 | def decode(self, tokens, pad_tokens = set()): 164 | if torch.is_tensor(tokens): 165 | tokens = tokens.tolist() 166 | 167 | return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0})) 168 | 169 | def encode(self, texts): 170 | encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID) 171 | return list(map(torch.tensor, encoded)) 172 | 173 | def tokenize(self, texts, context_length = 256, truncate_text = False): 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | all_tokens = self.encode(texts) 178 | 179 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 180 | for i, tokens in enumerate(all_tokens): 181 | if len(tokens) > context_length: 182 | if truncate_text: 183 | tokens = tokens[:context_length] 184 | else: 185 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | -------------------------------------------------------------------------------- /stylegan3/viz/pickle_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import os 11 | import re 12 | 13 | import dnnlib 14 | import imgui 15 | import numpy as np 16 | from gui_utils import imgui_utils 17 | 18 | from . import renderer 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def _locate_results(pattern): 23 | return pattern 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | class PickleWidget: 28 | def __init__(self, viz): 29 | self.viz = viz 30 | self.search_dirs = [] 31 | self.cur_pkl = None 32 | self.user_pkl = '' 33 | self.recent_pkls = [] 34 | self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} 35 | self.browse_refocus = False 36 | self.load('', ignore_errors=True) 37 | 38 | def add_recent(self, pkl, ignore_errors=False): 39 | try: 40 | resolved = self.resolve_pkl(pkl) 41 | if resolved not in self.recent_pkls: 42 | self.recent_pkls.append(resolved) 43 | except: 44 | if not ignore_errors: 45 | raise 46 | 47 | def load(self, pkl, ignore_errors=False): 48 | viz = self.viz 49 | viz.clear_result() 50 | viz.skip_frame() # The input field will change on next frame. 51 | try: 52 | resolved = self.resolve_pkl(pkl) 53 | name = resolved.replace('\\', '/').split('/')[-1] 54 | self.cur_pkl = resolved 55 | self.user_pkl = resolved 56 | viz.result.message = f'Loading {name}...' 57 | viz.defer_rendering() 58 | if resolved in self.recent_pkls: 59 | self.recent_pkls.remove(resolved) 60 | self.recent_pkls.insert(0, resolved) 61 | except: 62 | self.cur_pkl = None 63 | self.user_pkl = pkl 64 | if pkl == '': 65 | viz.result = dnnlib.EasyDict(message='No network pickle loaded') 66 | else: 67 | viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) 68 | if not ignore_errors: 69 | raise 70 | 71 | @imgui_utils.scoped_by_object_id 72 | def __call__(self, show=True): 73 | viz = self.viz 74 | recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] 75 | if show: 76 | imgui.text('Pickle') 77 | imgui.same_line(viz.label_w) 78 | changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024, 79 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 80 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 81 | help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl') 82 | if changed: 83 | self.load(self.user_pkl, ignore_errors=True) 84 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': 85 | imgui.set_tooltip(self.user_pkl) 86 | imgui.same_line() 87 | if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): 88 | imgui.open_popup('recent_pkls_popup') 89 | imgui.same_line() 90 | if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1): 91 | imgui.open_popup('browse_pkls_popup') 92 | self.browse_cache.clear() 93 | self.browse_refocus = True 94 | 95 | if imgui.begin_popup('recent_pkls_popup'): 96 | for pkl in recent_pkls: 97 | clicked, _state = imgui.menu_item(pkl) 98 | if clicked: 99 | self.load(pkl, ignore_errors=True) 100 | imgui.end_popup() 101 | 102 | if imgui.begin_popup('browse_pkls_popup'): 103 | def recurse(parents): 104 | key =
tuple(parents) 105 | items = self.browse_cache.get(key, None) 106 | if items is None: 107 | items = self.list_runs_and_pkls(parents) 108 | self.browse_cache[key] = items 109 | for item in items: 110 | if item.type == 'run' and imgui.begin_menu(item.name): 111 | recurse([item.path]) 112 | imgui.end_menu() 113 | if item.type == 'pkl': 114 | clicked, _state = imgui.menu_item(item.name) 115 | if clicked: 116 | self.load(item.path, ignore_errors=True) 117 | if len(items) == 0: 118 | with imgui_utils.grayed_out(): 119 | imgui.menu_item('No results found') 120 | recurse(self.search_dirs) 121 | if self.browse_refocus: 122 | imgui.set_scroll_here() 123 | viz.skip_frame() # Focus will change on next frame. 124 | self.browse_refocus = False 125 | imgui.end_popup() 126 | 127 | paths = viz.pop_drag_and_drop_paths() 128 | if paths is not None and len(paths) >= 1: 129 | self.load(paths[0], ignore_errors=True) 130 | 131 | viz.args.pkl = self.cur_pkl 132 | 133 | def list_runs_and_pkls(self, parents): 134 | items = [] 135 | run_regex = re.compile(r'\d+-.*') 136 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') 137 | for parent in set(parents): 138 | if os.path.isdir(parent): 139 | for entry in os.scandir(parent): 140 | if entry.is_dir() and run_regex.fullmatch(entry.name): 141 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) 142 | if entry.is_file() and pkl_regex.fullmatch(entry.name): 143 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) 144 | 145 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) 146 | return items 147 | 148 | def resolve_pkl(self, pattern): 149 | assert isinstance(pattern, str) 150 | assert pattern != '' 151 | 152 | # URL => return as is. 153 | if dnnlib.util.is_url(pattern): 154 | return pattern 155 | 156 | # Short-hand pattern => locate. 157 | path = _locate_results(pattern) 158 | 159 | # Run dir => pick the last saved snapshot. 160 | if os.path.isdir(path): 161 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) 162 | if len(pkl_files) == 0: 163 | raise IOError(f'No network pickle found in "{path}"') 164 | path = pkl_files[-1] 165 | 166 | # Normalize. 167 | path = os.path.abspath(path) 168 | return path 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /stylegan3/gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
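gen_interp_video() below gets a seamlessly looping animation by tiling each grid cell's keyframe latents `wraps` times on both sides of the rendered range before fitting a cubic spline, so the curve passes through the same keyframes at both ends of the visible window. A minimal sketch of that trick in isolation — the shapes and names here are illustrative only:

import numpy as np
import scipy.interpolate

num_keyframes, wraps = 4, 2
ws = np.random.randn(num_keyframes, 16, 512)   # stand-in for one cell's W+ keyframes

x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
y = np.tile(ws, [wraps * 2 + 1, 1, 1])         # repeat the keyframes on both sides
interp = scipy.interpolate.interp1d(x, y, kind='cubic', axis=0)

frame_w = interp(1.5)                          # latent halfway between keyframes 1 and 2
assert np.allclose(interp(0.0), interp(float(num_keyframes)))  # the loop closes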
8 | 9 | """Generate lerp videos using pretrained network pickle.""" 10 | 11 | import copy 12 | import os 13 | import re 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import click 17 | import dnnlib 18 | import imageio 19 | import numpy as np 20 | import scipy.interpolate 21 | import torch 22 | from tqdm import tqdm 23 | 24 | import legacy 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): 29 | batch_size, channels, img_h, img_w = img.shape 30 | if grid_w is None: 31 | grid_w = batch_size // grid_h 32 | assert batch_size == grid_w * grid_h 33 | if float_to_uint8: 34 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 35 | img = img.reshape(grid_h, grid_w, channels, img_h, img_w) 36 | img = img.permute(2, 0, 3, 1, 4) 37 | img = img.reshape(channels, grid_h * img_h, grid_w * img_w) 38 | if chw_to_hwc: 39 | img = img.permute(1, 2, 0) 40 | if to_numpy: 41 | img = img.cpu().numpy() 42 | return img 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs): 47 | grid_w = grid_dims[0] 48 | grid_h = grid_dims[1] 49 | 50 | if num_keyframes is None: 51 | if len(seeds) % (grid_w*grid_h) != 0: 52 | raise ValueError('Number of input seeds must be divisible by grid W*H') 53 | num_keyframes = len(seeds) // (grid_w*grid_h) 54 | 55 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) 56 | for idx in range(num_keyframes*grid_h*grid_w): 57 | all_seeds[idx] = seeds[idx % len(seeds)] 58 | 59 | if shuffle_seed is not None: 60 | rng = np.random.RandomState(seed=shuffle_seed) 61 | rng.shuffle(all_seeds) 62 | 63 | zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device) 64 | ws = G.mapping(z=zs, c=None, truncation_psi=psi) 65 | _ = G.synthesis(ws[:1]) # warm up 66 | ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) 67 | 68 | # Interpolation. 69 | grid = [] 70 | for yi in range(grid_h): 71 | row = [] 72 | for xi in range(grid_w): 73 | x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) 74 | y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) 75 | interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) 76 | row.append(interp) 77 | grid.append(row) 78 | 79 | # Render video. 80 | video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) 81 | for frame_idx in tqdm(range(num_keyframes * w_frames)): 82 | imgs = [] 83 | for yi in range(grid_h): 84 | for xi in range(grid_w): 85 | interp = grid[yi][xi] 86 | w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) 87 | img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] 88 | imgs.append(img) 89 | video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) 90 | video_out.close() 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def parse_range(s: Union[str, List[int]]) -> List[int]: 95 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 
96 | 97 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 98 | ''' 99 | if isinstance(s, list): return s 100 | ranges = [] 101 | range_re = re.compile(r'^(\d+)-(\d+)$') 102 | for p in s.split(','): 103 | m = range_re.match(p) 104 | if m: 105 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 106 | else: 107 | ranges.append(int(p)) 108 | return ranges 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: 113 | '''Parse a 'M,N' or 'MxN' integer tuple. 114 | 115 | Example: 116 | '4x2' returns (4,2) 117 | '0,1' returns (0,1) 118 | ''' 119 | if isinstance(s, tuple): return s 120 | m = re.match(r'^(\d+)[x,](\d+)$', s) 121 | if m: 122 | return (int(m.group(1)), int(m.group(2))) 123 | raise ValueError(f'cannot parse tuple {s}') 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | @click.command() 128 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 129 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True) 130 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) 131 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) 132 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, it is determined from the length of the seeds array given by --seeds.', default=None) 133 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) 134 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 135 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') 136 | def generate_images( 137 | network_pkl: str, 138 | seeds: List[int], 139 | shuffle_seed: Optional[int], 140 | truncation_psi: float, 141 | grid: Tuple[int,int], 142 | num_keyframes: Optional[int], 143 | w_frames: int, 144 | output: str 145 | ): 146 | """Render a latent vector interpolation video. 147 | 148 | Examples: 149 | 150 | \b 151 | # Render a 4x2 grid of interpolations for seeds 0 through 31. 152 | python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ 153 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 154 | 155 | Animation length and seed keyframes: 156 | 157 | The animation length is either determined based on the --seeds value or explicitly 158 | specified using the --num-keyframes option. 159 | 160 | When the number of keyframes is specified with --num-keyframes, the output video length 161 | will be 'num_keyframes*w_frames' frames. 162 | 163 | If --num-keyframes is not specified, the number of seeds given with 164 | --seeds must be divisible by grid size W*H (--grid). In this case the 165 | output video length will be '# seeds/(w*h)*w_frames' frames. 166 | """ 167 | 168 | print('Loading networks from "%s"...'
% network_pkl) 169 | device = torch.device('cuda') 170 | with dnnlib.util.open_url(network_pkl) as f: 171 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 172 | 173 | gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi) 174 | 175 | #---------------------------------------------------------------------------- 176 | 177 | if __name__ == "__main__": 178 | generate_images() # pylint: disable=no-value-for-parameter 179 | 180 | #---------------------------------------------------------------------------- 181 | -------------------------------------------------------------------------------- /stylegan3/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Loss functions.""" 10 | 11 | import numpy as np 12 | import torch 13 | from torch_utils import training_stats 14 | from torch_utils.ops import conv2d_gradfix 15 | from torch_utils.ops import upfirdn2d 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class Loss: 20 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass 21 | raise NotImplementedError() 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | class StyleGAN2Loss(Loss): 26 | def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0): 27 | super().__init__() 28 | self.device = device 29 | self.G = G 30 | self.D = D 31 | self.augment_pipe = augment_pipe 32 | self.r1_gamma = r1_gamma 33 | self.style_mixing_prob = style_mixing_prob 34 | self.pl_weight = pl_weight 35 | self.pl_batch_shrink = pl_batch_shrink 36 | self.pl_decay = pl_decay 37 | self.pl_no_weight_grad = pl_no_weight_grad 38 | self.pl_mean = torch.zeros([], device=device) 39 | self.blur_init_sigma = blur_init_sigma 40 | self.blur_fade_kimg = blur_fade_kimg 41 | 42 | def run_G(self, z, c, update_emas=False): 43 | ws = self.G.mapping(z, c, update_emas=update_emas) 44 | if self.style_mixing_prob > 0: 45 | with torch.autograd.profiler.record_function('style_mixing'): 46 | cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) 47 | cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) 48 | ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] 49 | img = self.G.synthesis(ws, update_emas=update_emas) 50 | return img, ws 51 | 52 | def run_D(self, img, c, blur_sigma=0, update_emas=False): 53 | blur_size = np.floor(blur_sigma * 3) 54 | if blur_size > 0: 55 | with torch.autograd.profiler.record_function('blur'): 56 | f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2() 57 | img = upfirdn2d.filter2d(img, f / f.sum()) 58 | if 
self.augment_pipe is not None: 59 | img = self.augment_pipe(img) 60 | logits = self.D(img, c, update_emas=update_emas) 61 | return logits 62 | 63 | def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): 64 | assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] 65 | if self.pl_weight == 0: 66 | phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) 67 | if self.r1_gamma == 0: 68 | phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) 69 | blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 70 | 71 | # Gmain: Maximize logits for generated images. 72 | if phase in ['Gmain', 'Gboth']: 73 | with torch.autograd.profiler.record_function('Gmain_forward'): 74 | gen_img, _gen_ws = self.run_G(gen_z, gen_c) 75 | gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) 76 | training_stats.report('Loss/scores/fake', gen_logits) 77 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 78 | loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) 79 | training_stats.report('Loss/G/loss', loss_Gmain) 80 | with torch.autograd.profiler.record_function('Gmain_backward'): 81 | loss_Gmain.mean().mul(gain).backward() 82 | 83 | # Gpl: Apply path length regularization. 84 | if phase in ['Greg', 'Gboth']: 85 | with torch.autograd.profiler.record_function('Gpl_forward'): 86 | batch_size = gen_z.shape[0] // self.pl_batch_shrink 87 | gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size]) 88 | pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) 89 | with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad): 90 | pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0] 91 | pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() 92 | pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) 93 | self.pl_mean.copy_(pl_mean.detach()) 94 | pl_penalty = (pl_lengths - pl_mean).square() 95 | training_stats.report('Loss/pl_penalty', pl_penalty) 96 | loss_Gpl = pl_penalty * self.pl_weight 97 | training_stats.report('Loss/G/reg', loss_Gpl) 98 | with torch.autograd.profiler.record_function('Gpl_backward'): 99 | loss_Gpl.mean().mul(gain).backward() 100 | 101 | # Dmain: Minimize logits for generated images. 102 | loss_Dgen = 0 103 | if phase in ['Dmain', 'Dboth']: 104 | with torch.autograd.profiler.record_function('Dgen_forward'): 105 | gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True) 106 | gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True) 107 | training_stats.report('Loss/scores/fake', gen_logits) 108 | training_stats.report('Loss/signs/fake', gen_logits.sign()) 109 | loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) 110 | with torch.autograd.profiler.record_function('Dgen_backward'): 111 | loss_Dgen.mean().mul(gain).backward() 112 | 113 | # Dmain: Maximize logits for real images. 114 | # Dr1: Apply R1 regularization. 
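# Background on the R1 term computed below (Mescheder et al., 2018): it penalizes
# the squared gradient norm of the discriminator at real samples,
#     R1 = (r1_gamma / 2) * E_x[ ||grad_x D(x)||^2 ],
# which keeps D smooth around the data manifold. A self-contained sketch of the
# same computation on a toy stand-in for D (illustrative only, not repo code):
#     d = torch.nn.Linear(8, 1)
#     x = torch.randn(4, 8, requires_grad=True)
#     grads = torch.autograd.grad(outputs=[d(x).sum()], inputs=[x], create_graph=True)[0]
#     r1 = grads.square().sum(1) * (10 / 2)   # per-sample penalty; r1_gamma=10 is the default here
# create_graph=True keeps the graph so the penalty itself can be backpropagated.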
115 | if phase in ['Dmain', 'Dreg', 'Dboth']: 116 | name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' 117 | with torch.autograd.profiler.record_function(name + '_forward'): 118 | real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']) 119 | real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) 120 | training_stats.report('Loss/scores/real', real_logits) 121 | training_stats.report('Loss/signs/real', real_logits.sign()) 122 | 123 | loss_Dreal = 0 124 | if phase in ['Dmain', 'Dboth']: 125 | loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) 126 | training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) 127 | 128 | loss_Dr1 = 0 129 | if phase in ['Dreg', 'Dboth']: 130 | with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): 131 | r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] 132 | r1_penalty = r1_grads.square().sum([1,2,3]) 133 | loss_Dr1 = r1_penalty * (self.r1_gamma / 2) 134 | training_stats.report('Loss/r1_penalty', r1_penalty) 135 | training_stats.report('Loss/D/reg', loss_Dr1) 136 | 137 | with torch.autograd.profiler.record_function(name + '_backward'): 138 | (loss_Dreal + loss_Dr1).mean().mul(gain).backward() 139 | 140 | #---------------------------------------------------------------------------- 141 | -------------------------------------------------------------------------------- /stylegan3/gui_utils/glfw_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import time 10 | import glfw 11 | import OpenGL.GL as gl 12 | from . import gl_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class GlfwWindow: # pylint: disable=too-many-public-methods 17 | def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): 18 | self._glfw_window = None 19 | self._drawing_frame = False 20 | self._frame_start_time = None 21 | self._frame_delta = 0 22 | self._fps_limit = None 23 | self._vsync = None 24 | self._skip_frames = 0 25 | self._deferred_show = deferred_show 26 | self._close_on_esc = close_on_esc 27 | self._esc_pressed = False 28 | self._drag_and_drop_paths = None 29 | self._capture_next_frame = False 30 | self._captured_frame = None 31 | 32 | # Create window. 33 | glfw.init() 34 | glfw.window_hint(glfw.VISIBLE, False) 35 | self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) 36 | self._attach_glfw_callbacks() 37 | self.make_context_current() 38 | 39 | # Adjust window. 
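# Note: with deferred_show=True (the default) the window stays hidden until the
# first end_frame() call, avoiding a flash of an uninitialized window, while
# set_window_size() clamps the requested size to the monitor work area and
# maximizes when the request fills it.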
40 | self.set_vsync(False) 41 | self.set_window_size(window_width, window_height) 42 | if not self._deferred_show: 43 | glfw.show_window(self._glfw_window) 44 | 45 | def close(self): 46 | if self._drawing_frame: 47 | self.end_frame() 48 | if self._glfw_window is not None: 49 | glfw.destroy_window(self._glfw_window) 50 | self._glfw_window = None 51 | #glfw.terminate() # Commented out to play it nice with other glfw clients. 52 | 53 | def __del__(self): 54 | try: 55 | self.close() 56 | except: 57 | pass 58 | 59 | @property 60 | def window_width(self): 61 | return self.content_width 62 | 63 | @property 64 | def window_height(self): 65 | return self.content_height + self.title_bar_height 66 | 67 | @property 68 | def content_width(self): 69 | width, _height = glfw.get_window_size(self._glfw_window) 70 | return width 71 | 72 | @property 73 | def content_height(self): 74 | _width, height = glfw.get_window_size(self._glfw_window) 75 | return height 76 | 77 | @property 78 | def title_bar_height(self): 79 | _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) 80 | return top 81 | 82 | @property 83 | def monitor_width(self): 84 | _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 85 | return width 86 | 87 | @property 88 | def monitor_height(self): 89 | _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 90 | return height 91 | 92 | @property 93 | def frame_delta(self): 94 | return self._frame_delta 95 | 96 | def set_title(self, title): 97 | glfw.set_window_title(self._glfw_window, title) 98 | 99 | def set_window_size(self, width, height): 100 | width = min(width, self.monitor_width) 101 | height = min(height, self.monitor_height) 102 | glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) 103 | if width == self.monitor_width and height == self.monitor_height: 104 | self.maximize() 105 | 106 | def set_content_size(self, width, height): 107 | self.set_window_size(width, height + self.title_bar_height) 108 | 109 | def maximize(self): 110 | glfw.maximize_window(self._glfw_window) 111 | 112 | def set_position(self, x, y): 113 | glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) 114 | 115 | def center(self): 116 | self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) 117 | 118 | def set_vsync(self, vsync): 119 | vsync = bool(vsync) 120 | if vsync != self._vsync: 121 | glfw.swap_interval(1 if vsync else 0) 122 | self._vsync = vsync 123 | 124 | def set_fps_limit(self, fps_limit): 125 | self._fps_limit = int(fps_limit) 126 | 127 | def should_close(self): 128 | return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) 129 | 130 | def skip_frame(self): 131 | self.skip_frames(1) 132 | 133 | def skip_frames(self, num): # Do not update window for the next N frames. 134 | self._skip_frames = max(self._skip_frames, int(num)) 135 | 136 | def is_skipping_frames(self): 137 | return self._skip_frames > 0 138 | 139 | def capture_next_frame(self): 140 | self._capture_next_frame = True 141 | 142 | def pop_captured_frame(self): 143 | frame = self._captured_frame 144 | self._captured_frame = None 145 | return frame 146 | 147 | def pop_drag_and_drop_paths(self): 148 | paths = self._drag_and_drop_paths 149 | self._drag_and_drop_paths = None 150 | return paths 151 | 152 | def draw_frame(self): # To be overridden by subclass. 153 | self.begin_frame() 154 | # Rendering code goes here. 
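# A subclass replaces this body wholesale; a hypothetical minimal override:
#     class DemoWindow(GlfwWindow):
#         def draw_frame(self):
#             self.begin_frame()   # sets the viewport and a pixel-space projection
#             ...                  # custom OpenGL / gl_utils draw calls
#             self.end_frame()     # handles frame skipping, capture, buffer swap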
155 | self.end_frame() 156 | 157 | def make_context_current(self): 158 | if self._glfw_window is not None: 159 | glfw.make_context_current(self._glfw_window) 160 | 161 | def begin_frame(self): 162 | # End previous frame. 163 | if self._drawing_frame: 164 | self.end_frame() 165 | 166 | # Apply FPS limit. 167 | if self._frame_start_time is not None and self._fps_limit is not None: 168 | delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit 169 | if delay > 0: 170 | time.sleep(delay) 171 | cur_time = time.perf_counter() 172 | if self._frame_start_time is not None: 173 | self._frame_delta = cur_time - self._frame_start_time 174 | self._frame_start_time = cur_time 175 | 176 | # Process events. 177 | glfw.poll_events() 178 | 179 | # Begin frame. 180 | self._drawing_frame = True 181 | self.make_context_current() 182 | 183 | # Initialize GL state. 184 | gl.glViewport(0, 0, self.content_width, self.content_height) 185 | gl.glMatrixMode(gl.GL_PROJECTION) 186 | gl.glLoadIdentity() 187 | gl.glTranslate(-1, 1, 0) 188 | gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) 189 | gl.glMatrixMode(gl.GL_MODELVIEW) 190 | gl.glLoadIdentity() 191 | gl.glEnable(gl.GL_BLEND) 192 | gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. 193 | 194 | # Clear. 195 | gl.glClearColor(0, 0, 0, 1) 196 | gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) 197 | 198 | def end_frame(self): 199 | assert self._drawing_frame 200 | self._drawing_frame = False 201 | 202 | # Skip frames if requested. 203 | if self._skip_frames > 0: 204 | self._skip_frames -= 1 205 | return 206 | 207 | # Capture frame if requested. 208 | if self._capture_next_frame: 209 | self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) 210 | self._capture_next_frame = False 211 | 212 | # Update window. 213 | if self._deferred_show: 214 | glfw.show_window(self._glfw_window) 215 | self._deferred_show = False 216 | glfw.swap_buffers(self._glfw_window) 217 | 218 | def _attach_glfw_callbacks(self): 219 | glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) 220 | glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) 221 | 222 | def _glfw_key_callback(self, _window, key, _scancode, action, _mods): 223 | if action == glfw.PRESS and key == glfw.KEY_ESCAPE: 224 | self._esc_pressed = True 225 | 226 | def _glfw_drop_callback(self, _window, paths): 227 | self._drag_and_drop_paths = paths 228 | 229 | #---------------------------------------------------------------------------- 230 | -------------------------------------------------------------------------------- /stylegan3/calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
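A hypothetical usage sketch for the GlfwWindow class above (title and sizes are illustrative, not from the repo): the intended event loop simply calls draw_frame() until should_close() reports true.

from gui_utils.glfw_window import GlfwWindow

win = GlfwWindow(title='Demo', window_width=800, window_height=600, deferred_show=False)
while not win.should_close():   # ESC also closes, since close_on_esc=True by default
    win.draw_frame()            # begin_frame() + rendering + end_frame()
win.close()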
8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import click 13 | import json 14 | import tempfile 15 | import copy 16 | import torch 17 | 18 | import dnnlib 19 | import legacy 20 | from metrics import metric_main 21 | from metrics import metric_utils 22 | from torch_utils import training_stats 23 | from torch_utils import custom_ops 24 | from torch_utils import misc 25 | from torch_utils.ops import conv2d_gradfix 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def subprocess_fn(rank, args, temp_dir): 30 | dnnlib.util.Logger(should_flush=True) 31 | 32 | # Init torch.distributed. 33 | if args.num_gpus > 1: 34 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 35 | if os.name == 'nt': 36 | init_method = 'file:///' + init_file.replace('\\', '/') 37 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 38 | else: 39 | init_method = f'file://{init_file}' 40 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 41 | 42 | # Init torch_utils. 43 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 44 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 45 | if rank != 0 or not args.verbose: 46 | custom_ops.verbosity = 'none' 47 | 48 | # Configure torch. 49 | device = torch.device('cuda', rank) 50 | torch.backends.cuda.matmul.allow_tf32 = False 51 | torch.backends.cudnn.allow_tf32 = False 52 | conv2d_gradfix.enabled = True 53 | 54 | # Print network summary. 55 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 56 | if rank == 0 and args.verbose: 57 | z = torch.empty([1, G.z_dim], device=device) 58 | c = torch.empty([1, G.c_dim], device=device) 59 | misc.print_module_summary(G, [z, c]) 60 | 61 | # Calculate each metric. 62 | for metric in args.metrics: 63 | if rank == 0 and args.verbose: 64 | print(f'Calculating {metric}...') 65 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 66 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 67 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) 68 | if rank == 0: 69 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 70 | if rank == 0 and args.verbose: 71 | print() 72 | 73 | # Done. 
74 | if rank == 0 and args.verbose: 75 | print('Exiting...') 76 | 77 | #---------------------------------------------------------------------------- 78 | 79 | def parse_comma_separated_list(s): 80 | if isinstance(s, list): 81 | return s 82 | if s is None or s.lower() == 'none' or s == '': 83 | return [] 84 | return s.split(',') 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | @click.command() 89 | @click.pass_context 90 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 91 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) 92 | @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') 93 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL') 94 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 95 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 96 | 97 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): 98 | """Calculate quality metrics for previous training run or pretrained network pickle. 99 | 100 | Examples: 101 | 102 | \b 103 | # Previous training run: look up options automatically, save result to JSONL file. 104 | python calc_metrics.py --metrics=eqt50k_int,eqr50k \\ 105 | --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl 106 | 107 | \b 108 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 109 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\ 110 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl 111 | 112 | \b 113 | Recommended metrics: 114 | fid50k_full Frechet inception distance against the full dataset. 115 | kid50k_full Kernel inception distance against the full dataset. 116 | pr50k3_full Precision and recall against the full dataset. 117 | ppl2_wend Perceptual path length in W, endpoints, full image. 118 | eqt50k_int Equivariance w.r.t. integer translation (EQ-T). 119 | eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac). 120 | eqr50k Equivariance w.r.t. rotation (EQ-R). 121 | 122 | \b 123 | Legacy metrics: 124 | fid50k Frechet inception distance against 50k real images. 125 | kid50k Kernel inception distance against 50k real images. 126 | pr50k3 Precision and recall against 50k real images. 127 | is50k Inception score for CIFAR-10. 128 | """ 129 | dnnlib.util.Logger(should_flush=True) 130 | 131 | # Validate arguments. 132 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 133 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 134 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 135 | if not args.num_gpus >= 1: 136 | ctx.fail('--gpus must be at least 1') 137 | 138 | # Load network.
139 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 140 | ctx.fail('--network must point to a file or URL') 141 | if args.verbose: 142 | print(f'Loading network from "{network_pkl}"...') 143 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 144 | network_dict = legacy.load_network_pkl(f) 145 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 146 | 147 | # Initialize dataset options. 148 | if data is not None: 149 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 150 | elif network_dict['training_set_kwargs'] is not None: 151 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 152 | else: 153 | ctx.fail('Could not look up dataset options; please specify --data') 154 | 155 | # Finalize dataset options. 156 | args.dataset_kwargs.resolution = args.G.img_resolution 157 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 158 | if mirror is not None: 159 | args.dataset_kwargs.xflip = mirror 160 | 161 | # Print dataset options. 162 | if args.verbose: 163 | print('Dataset options:') 164 | print(json.dumps(args.dataset_kwargs, indent=2)) 165 | 166 | # Locate run dir. 167 | args.run_dir = None 168 | if os.path.isfile(network_pkl): 169 | pkl_dir = os.path.dirname(network_pkl) 170 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 171 | args.run_dir = pkl_dir 172 | 173 | # Launch processes. 174 | if args.verbose: 175 | print('Launching processes...') 176 | torch.multiprocessing.set_start_method('spawn') 177 | with tempfile.TemporaryDirectory() as temp_dir: 178 | if args.num_gpus == 1: 179 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 180 | else: 181 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | if __name__ == "__main__": 186 | calc_metrics() # pylint: disable=no-value-for-parameter 187 | 188 | #---------------------------------------------------------------------------- 189 | -------------------------------------------------------------------------------- /dalle2_pytorch/dataloaders/decoder_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import webdataset as wds 3 | import torch 4 | import numpy as np 5 | import fsspec 6 | 7 | def get_shard(filename): 8 | """ 9 | Filenames with shards in them have a consistent structure that we can take advantage of. 10 | Standard structure: path/to/file/prefix_string_00001.ext 11 | """ 12 | try: 13 | return filename.split("_")[-1].split(".")[0] 14 | except ValueError: 15 | raise RuntimeError(f"Could not find shard for filename {filename}") 16 | 17 | def get_example_file(fs, path, file_format): 18 | """ 19 | Given a file system, a path, and a file format, return an example file 20 | """ 21 | return fs.glob(os.path.join(path, f"*.{file_format}"))[0] 22 | 23 | def embedding_inserter(samples, embeddings_url, shard_width, handler=wds.handlers.reraise_exception): 24 | """Given a datum of {"__key__": str, "__url__": str, ...} adds the corresponding embedding and yields""" 25 | previous_tar_url = None 26 | current_embeddings = None 27 | # Get a reference to an abstract file system where the embeddings are stored 28 | embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url) 29 | example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy") 30 |
example_embedding_shard = get_shard(example_embedding_file) 31 | emb_shard_width = len(example_embedding_shard) 32 | # Easier to get the basename without the shard once than search through for the correct file every time 33 | embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_" 34 | 35 | def load_corresponding_embeds(tar_url): 36 | """Finds and reads the npy file that contains embeddings for the given webdataset tar""" 37 | shard = int(tar_url.split("/")[-1].split(".")[0]) 38 | embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy' 39 | with embeddings_fs.open(embedding_url) as f: 40 | data = np.load(f) 41 | return torch.from_numpy(data) 42 | 43 | for sample in samples: 44 | try: 45 | tar_url = sample["__url__"] 46 | key = sample["__key__"] 47 | if tar_url != previous_tar_url: 48 | # If the tar changed, we need to download new embeddings 49 | # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient. 50 | previous_tar_url = tar_url 51 | current_embeddings = load_corresponding_embeds(tar_url) 52 | 53 | embedding_index = int(key[shard_width:]) 54 | sample["npy"] = current_embeddings[embedding_index] 55 | yield sample 56 | except Exception as exn: # From wds implementation 57 | if handler(exn): 58 | continue 59 | else: 60 | break 61 | insert_embedding = wds.filters.pipelinefilter(embedding_inserter) 62 | 63 | def verify_keys(samples, handler=wds.handlers.reraise_exception): 64 | """ 65 | Requires that both the image and embedding are present in the sample 66 | This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter. 67 | """ 68 | for sample in samples: 69 | try: 70 | assert "jpg" in sample, f"Sample {sample['__key__']} missing image" 71 | assert "npy" in sample, f"Sample {sample['__key__']} missing embedding. Did you set embedding_folder_url?" 72 | yield sample 73 | except Exception as exn: # From wds implementation 74 | if handler(exn): 75 | continue 76 | else: 77 | break 78 | 79 | class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface): 80 | """ 81 | A fluid interface wrapper for DataPipeline that returns image embedding pairs 82 | Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted from the alternate source. 83 | """ 84 | 85 | def __init__( 86 | self, 87 | urls, 88 | embedding_folder_url=None, 89 | shard_width=None, 90 | handler=wds.handlers.reraise_exception, 91 | resample=False, 92 | shuffle_shards=True 93 | ): 94 | """ 95 | Modeled directly off of the WebDataset constructor 96 | 97 | :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar 98 | :param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset. 99 | Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros. 100 | :param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index. 101 | For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index. 102 | :param handler: A webdataset handler.
103 | :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely. 104 | :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true. 105 | """ 106 | super().__init__() 107 | # Add the shard list and randomize or resample if requested 108 | if resample: 109 | assert not shuffle_shards, "Cannot both resample and shuffle" 110 | self.append(wds.ResampledShards(urls)) 111 | else: 112 | self.append(wds.SimpleShardList(urls)) 113 | if shuffle_shards: 114 | self.append(wds.filters.shuffle(1000)) 115 | 116 | self.append(wds.split_by_node) 117 | self.append(wds.split_by_worker) 118 | 119 | self.append(wds.tarfile_to_samples(handler=handler)) 120 | self.append(wds.decode("torchrgb")) 121 | if embedding_folder_url is not None: 122 | assert shard_width is not None, "Reading embeddings separately requires shard width to be given" 123 | self.append(insert_embedding(embeddings_url=embedding_folder_url, shard_width=shard_width, handler=handler)) 124 | self.append(verify_keys) 125 | self.append(wds.to_tuple("jpg", "npy")) 126 | 127 | def create_image_embedding_dataloader( 128 | tar_url, 129 | num_workers, 130 | batch_size, 131 | embeddings_url=None, 132 | shard_width=None, 133 | shuffle_num=None, 134 | shuffle_shards=True, 135 | resample_shards=False, 136 | handler=wds.handlers.warn_and_continue 137 | ): 138 | """ 139 | Convenience function to create an image embedding dataset and dataloader in one line 140 | 141 | :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar 142 | :param num_workers: The number of workers to use for the dataloader 143 | :param batch_size: The batch size to use for the dataloader 144 | :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset. 145 | Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros. 146 | :param shard_width: The number of digits in the shard number. This is used to align the embedding index with the image index. 147 | For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index. 148 | :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling. 149 | :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true. 150 | :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely. 151 | :param handler: A webdataset handler. 152 | """ 153 | ds = ImageEmbeddingDataset( 154 | tar_url, 155 | embeddings_url, 156 | shard_width=shard_width, 157 | shuffle_shards=shuffle_shards, 158 | resample=resample_shards, 159 | handler=handler 160 | ) 161 | if shuffle_num is not None and shuffle_num > 0: 162 | ds.shuffle(shuffle_num) # Use the requested buffer size rather than a hard-coded constant 163 | return wds.WebLoader( 164 | ds, 165 | num_workers=num_workers, 166 | batch_size=batch_size, 167 | prefetch_factor=2, # This might be good to have high so the next npy file is prefetched 168 | pin_memory=True, 169 | shuffle=False 170 | ) --------------------------------------------------------------------------------
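A usage sketch for the loader above (paths, shard_width, and batch settings are hypothetical; adjust them to your dataset layout): create the dataloader in one call, then iterate (image, embedding) batches.

from dalle2_pytorch.dataloaders import create_image_embedding_dataloader

dataloader = create_image_embedding_dataloader(
    tar_url="/data/webdataset/{0000..9999}.tar",  # brace-expanded shard list
    embeddings_url="/data/embeddings",            # folder of prefix_string_####.npy shards
    shard_width=4,                                # digits in each shard number
    num_workers=4,
    batch_size=32,
    shuffle_num=200,                              # post-sampling shuffle buffer
)
for images, embeddings in dataloader:
    # images: (B, 3, H, W) float tensors decoded via wds.decode("torchrgb")
    # embeddings: (B, D) tensors sliced from the matching .npy shard
    break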