├── .github
│   └── ISSUE_TEMPLATE
│       └── bug_report.md
├── .gitignore
├── Dockerfile
├── LICENSE.txt
├── README.md
├── avg_spectra.py
├── calc_metrics.py
├── dataset_tool.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── docs
│   ├── avg_spectra_screen0.png
│   ├── avg_spectra_screen0_half.png
│   ├── configs.md
│   ├── dataset-tool-help.txt
│   ├── stylegan3-teaser-1920x1006.png
│   ├── train-help.txt
│   ├── troubleshooting.md
│   ├── visualizer_screen0.png
│   └── visualizer_screen0_half.png
├── environment.yml
├── example.png
├── ffhq_euler.txt
├── for_debug.py
├── gen_images.py
├── gen_video.py
├── gui_utils
│   ├── __init__.py
│   ├── gl_utils.py
│   ├── glfw_window.py
│   ├── imgui_utils.py
│   ├── imgui_window.py
│   └── text_utils.py
├── legacy.py
├── marching_cube.py
├── metrics
│   ├── __init__.py
│   ├── equivariance.py
│   ├── frechet_inception_distance.py
│   ├── inception_score.py
│   ├── kernel_inception_distance.py
│   ├── metric_main.py
│   ├── metric_utils.py
│   ├── perceptual_path_length.py
│   └── precision_recall.py
├── test_eg3d.py
├── test_eg3d_gen_video.py
├── test_eg3d_new.py
├── test_gen_images.png
├── test_thre_images.png
├── torch_utils
│   ├── __init__.py
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── filtered_lrelu.cpp
│   │   ├── filtered_lrelu.cu
│   │   ├── filtered_lrelu.h
│   │   ├── filtered_lrelu.py
│   │   ├── filtered_lrelu_ns.cu
│   │   ├── filtered_lrelu_rd.cu
│   │   ├── filtered_lrelu_wr.cu
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   └── upfirdn2d.py
│   ├── persistence.py
│   └── training_stats.py
├── train.py
├── train_eg3d.py
├── train_eg3d_full.py
├── training
│   ├── EG3d.py
│   ├── EG3d_v10.py
│   ├── EG3d_v11.py
│   ├── EG3d_v12.py
│   ├── EG3d_v14.py
│   ├── EG3d_v16.py
│   ├── EG3d_v17.py
│   ├── EG3d_v18.py
│   ├── EG3d_v2.py
│   ├── EG3d_v3.py
│   ├── EG3d_v4.py
│   ├── EG3d_v5.py
│   ├── EG3d_v6.py
│   ├── EG3d_v7.py
│   ├── EG3d_v8.py
│   ├── __init__.py
│   ├── augment.py
│   ├── camera_utils.py
│   ├── cips_camera_utils.py
│   ├── cips_camera_utils_v2.py
│   ├── dataset.py
│   ├── dataset_eg3d.py
│   ├── loss.py
│   ├── loss_full_v16.py
│   ├── loss_full_v17.py
│   ├── loss_full_v17_2.py
│   ├── networks_stylegan2.py
│   ├── networks_stylegan3.py
│   ├── pigan_utils.py
│   ├── training_loop.py
│   ├── training_loop_eg3d.py
│   ├── training_loop_eg3d_full.py
│   └── v16.py
├── visualizer.py
└── viz
    ├── __init__.py
    ├── capture_widget.py
    ├── equivariance_widget.py
    ├── latent_widget.py
    ├── layer_widget.py
    ├── performance_widget.py
    ├── pickle_widget.py
    ├── renderer.py
    ├── stylemix_widget.py
    └── trunc_noise_widget.py

/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. In '...' directory, run command '...'
16 | 2. See error (copy&paste full log, including exceptions and **stacktraces**).
17 |
18 | Please copy&paste text instead of screenshots for better searchability.
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g.
Linux Ubuntu 20.04, Windows 10] 28 | - PyTorch version (e.g., pytorch 1.9.0) 29 | - CUDA toolkit version (e.g., CUDA 11.4) 30 | - NVIDIA driver version 31 | - GPU [e.g., Titan V, RTX 3090] 32 | - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3) 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | .idea/ 3 | env.md 4 | datasets/ 5 | ~/ 6 | training-runs/ 7 | gen_video_examples/ 8 | *.jpg 9 | *.png 10 | *.mp4 11 | march_file.npy 12 | 13 | # system ignore 14 | .DS_Store 15 | ._.DS_Store 16 | Thumbs.db 17 | 18 | # temp file 19 | *.log 20 | *.cache 21 | *.diff 22 | *.patch 23 | *.tmp 24 | *.swap 25 | *.swp 26 | *.bk 27 | *.bak 28 | tmp* 29 | 30 | #Java from WangTianzhou 31 | *.class 32 | .idea/ 33 | *.iml 34 | 35 | # Mobile Tools for Java (J2ME) 36 | .mtj.tmp/ 37 | 38 | #package file 39 | *.war 40 | *.ear 41 | *.zip 42 | *.tar.gz 43 | *.rar 44 | *.jar 45 | 46 | #maven ignore 47 | target/ 48 | build/ 49 | 50 | #eclipse ignore 51 | .settings/ 52 | .project/ 53 | .classpatch 54 | 55 | #Intellij idea 56 | .idea/ 57 | /idea/ 58 | *.ipr 59 | *.iml 60 | *.iws 61 | 62 | 63 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 64 | hs_err_pid* 65 | target/ 66 | 67 | 68 | .README.md.html 69 | 70 | pom.xml.releaseBackup 71 | release.properties 72 | 73 | # C++ from WengCiJie 74 | *.diff 75 | .classpath 76 | *.ipr 77 | *.iws 78 | *.ids 79 | .tags* 80 | build_info.properties 81 | *.pyc 82 | test.* 83 | /**/thrift-java/ 84 | *.log 85 | release 86 | .DS_Store 87 | .checkstyle 88 | MANIFEST.MF 89 | deploy-lsd.sh 90 | bazel-* 91 | output 92 | cmake-build-debug 93 | tags 94 | bazel 95 | bin/bazel 96 | 97 | # Python from liyong 98 | *.egg 99 | *.egg-info 100 | *.pyc 101 | *.so 102 | *.tar.gz 103 | .cache 104 | .coverage 105 | .eggs 106 | .idea 107 | .mypy_cache 108 | .pytest_cache 109 | .python-version 110 | .tox 111 | .venv 112 | venv/ 113 | .vscode 114 | MANIFEST 115 | __pycache__ 116 | build/ 117 | dist/ 118 | htmlcov 119 | test/__pycache__ 120 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | FROM nvcr.io/nvidia/pytorch:21.08-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 15 | 16 | WORKDIR /workspace 17 | 18 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 19 | ENTRYPOINT ["/entry.sh"] 20 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. 
2 | 3 | 4 | NVIDIA Source Code License for StyleGAN3 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 
78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import click 13 | import json 14 | import tempfile 15 | import copy 16 | import torch 17 | 18 | import dnnlib 19 | import legacy 20 | from metrics import metric_main 21 | from metrics import metric_utils 22 | from torch_utils import training_stats 23 | from torch_utils import custom_ops 24 | from torch_utils import misc 25 | from torch_utils.ops import conv2d_gradfix 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def subprocess_fn(rank, args, temp_dir): 30 | dnnlib.util.Logger(should_flush=True) 31 | 32 | # Init torch.distributed. 33 | if args.num_gpus > 1: 34 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 35 | if os.name == 'nt': 36 | init_method = 'file:///' + init_file.replace('\\', '/') 37 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 38 | else: 39 | init_method = f'file://{init_file}' 40 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 41 | 42 | # Init torch_utils. 43 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 44 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 45 | if rank != 0 or not args.verbose: 46 | custom_ops.verbosity = 'none' 47 | 48 | # Configure torch. 49 | device = torch.device('cuda', rank) 50 | torch.backends.cuda.matmul.allow_tf32 = False 51 | torch.backends.cudnn.allow_tf32 = False 52 | conv2d_gradfix.enabled = True 53 | 54 | # Print network summary. 
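# (Note: the deep copy below is deliberate — the copy is switched to eval mode
# with gradients disabled, so metric evaluation cannot perturb the module that
# was passed in via args.)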
55 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 56 | if rank == 0 and args.verbose: 57 | z = torch.empty([1, G.z_dim], device=device) 58 | c = torch.empty([1, G.c_dim], device=device) 59 | misc.print_module_summary(G, [z, c]) 60 | 61 | # Calculate each metric. 62 | for metric in args.metrics: 63 | if rank == 0 and args.verbose: 64 | print(f'Calculating {metric}...') 65 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 66 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 67 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) 68 | if rank == 0: 69 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 70 | if rank == 0 and args.verbose: 71 | print() 72 | 73 | # Done. 74 | if rank == 0 and args.verbose: 75 | print('Exiting...') 76 | 77 | #---------------------------------------------------------------------------- 78 | 79 | def parse_comma_separated_list(s): 80 | if isinstance(s, list): 81 | return s 82 | if s is None or s.lower() == 'none' or s == '': 83 | return [] 84 | return s.split(',') 85 | 86 | #---------------------------------------------------------------------------- 87 | 88 | @click.command() 89 | @click.pass_context 90 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 91 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) 92 | @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') 93 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL') 94 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 95 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 96 | 97 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): 98 | """Calculate quality metrics for previous training run or pretrained network pickle. 99 | 100 | Examples: 101 | 102 | \b 103 | # Previous training run: look up options automatically, save result to JSONL file. 104 | python calc_metrics.py --metrics=eqt50k_int,eqr50k \\ 105 | --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl 106 | 107 | \b 108 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 109 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\ 110 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl 111 | 112 | \b 113 | Recommended metrics: 114 | fid50k_full Frechet inception distance against the full dataset. 115 | kid50k_full Kernel inception distance against the full dataset. 116 | pr50k3_full Precision and recall againt the full dataset. 117 | ppl2_wend Perceptual path length in W, endpoints, full image. 118 | eqt50k_int Equivariance w.r.t. integer translation (EQ-T). 119 | eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac). 120 | eqr50k Equivariance w.r.t. rotation (EQ-R). 121 | 122 | \b 123 | Legacy metrics: 124 | fid50k Frechet inception distance against 50k real images. 125 | kid50k Kernel inception distance against 50k real images. 126 | pr50k3 Precision and recall against 50k real images. 
127 | is50k Inception score for CIFAR-10. 128 | """ 129 | dnnlib.util.Logger(should_flush=True) 130 | 131 | # Validate arguments. 132 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 133 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 134 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 135 | if not args.num_gpus >= 1: 136 | ctx.fail('--gpus must be at least 1') 137 | 138 | # Load network. 139 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 140 | ctx.fail('--network must point to a file or URL') 141 | if args.verbose: 142 | print(f'Loading network from "{network_pkl}"...') 143 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 144 | network_dict = legacy.load_network_pkl(f) 145 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 146 | 147 | # Initialize dataset options. 148 | if data is not None: 149 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 150 | elif network_dict['training_set_kwargs'] is not None: 151 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 152 | else: 153 | ctx.fail('Could not look up dataset options; please specify --data') 154 | 155 | # Finalize dataset options. 156 | args.dataset_kwargs.resolution = args.G.img_resolution 157 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 158 | if mirror is not None: 159 | args.dataset_kwargs.xflip = mirror 160 | 161 | # Print dataset options. 162 | if args.verbose: 163 | print('Dataset options:') 164 | print(json.dumps(args.dataset_kwargs, indent=2)) 165 | 166 | # Locate run dir. 167 | args.run_dir = None 168 | if os.path.isfile(network_pkl): 169 | pkl_dir = os.path.dirname(network_pkl) 170 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 171 | args.run_dir = pkl_dir 172 | 173 | # Launch processes. 174 | if args.verbose: 175 | print('Launching processes...') 176 | torch.multiprocessing.set_start_method('spawn') 177 | with tempfile.TemporaryDirectory() as temp_dir: 178 | if args.num_gpus == 1: 179 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 180 | else: 181 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | if __name__ == "__main__": 186 | calc_metrics() # pylint: disable=no-value-for-parameter 187 | 188 | #---------------------------------------------------------------------------- 189 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
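# EasyDict (defined in util.py) behaves like a regular dict but also allows
# attribute-style access (d.key is d['key']); it backs the *_kwargs option
# bundles used throughout this codebase, e.g. args.dataset_kwargs.resolution.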
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/docs/avg_spectra_screen0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/docs/avg_spectra_screen0.png
--------------------------------------------------------------------------------
/docs/avg_spectra_screen0_half.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/docs/avg_spectra_screen0_half.png
--------------------------------------------------------------------------------
/docs/dataset-tool-help.txt:
--------------------------------------------------------------------------------
1 | Usage: dataset_tool.py [OPTIONS]
2 |
3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
4 | PyTorch.
5 |
6 | The input dataset format is guessed from the --source argument:
7 |
8 | --source *_lmdb/                    Load LSUN dataset
9 | --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
10 | --source train-images-idx3-ubyte.gz Load MNIST dataset
11 | --source path/                      Recursively load all images from path/
12 | --source dataset.zip                Recursively load all images from dataset.zip
13 |
14 | Specifying the output format and path:
15 |
16 | --dest /path/to/dir                 Save output files under /path/to/dir
17 | --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip
18 |
19 | The output dataset format can be either an image folder or an uncompressed
20 | zip archive. Zip archives make it easier to move datasets around file
21 | servers and clusters, and may offer better training performance on network
22 | file systems.
23 |
24 | Images within the dataset archive will be stored as uncompressed PNG.
25 | Uncompressed PNGs can be efficiently decoded in the training loop.
26 |
27 | Class labels are stored in a file called 'dataset.json' that is stored at
28 | the dataset root folder. This file has the following structure:
29 |
30 | {
31 |     "labels": [
32 |         ["00000/img00000000.png",6],
33 |         ["00000/img00000001.png",9],
34 |         ... repeated for every image in the dataset
35 |         ["00049/img00049999.png",1]
36 |     ]
37 | }
38 |
39 | If the 'dataset.json' file cannot be found, the dataset is interpreted as
40 | not containing class labels.
41 |
42 | Image scale/crop and resolution requirements:
43 |
44 | Output images must be square-shaped and they must all have the same power-
45 | of-two dimensions.
46 |
47 | To scale arbitrary input image size to a specific width and height, use
48 | the --resolution option. Output resolution will be either the original
49 | input resolution (if resolution was not specified) or the one specified
50 | with --resolution option.
51 |
52 | Use the --transform=center-crop or --transform=center-crop-wide options to
53 | apply a center crop transform on the input image. These options should be
54 | used with the --resolution option.
For example: 55 | 56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 57 | --transform=center-crop-wide --resolution=512x384 58 | 59 | Options: 60 | --source PATH Directory or archive name for input dataset 61 | [required] 62 | 63 | --dest PATH Output directory or archive name for output 64 | dataset [required] 65 | 66 | --max-images INTEGER Output only up to `max-images` images 67 | --transform [center-crop|center-crop-wide] 68 | Input crop/resize mode 69 | --resolution WxH Output resolution (e.g., '512x512') 70 | --help Show this message and exit. 71 | -------------------------------------------------------------------------------- /docs/stylegan3-teaser-1920x1006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/docs/stylegan3-teaser-1920x1006.png -------------------------------------------------------------------------------- /docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train a GAN using the techniques described in the paper "Alias-Free 4 | Generative Adversarial Networks". 5 | 6 | Examples: 7 | 8 | # Train StyleGAN3-T for AFHQv2 using 8 GPUs. 9 | python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \ 10 | --gpus=8 --batch=32 --gamma=8.2 --mirror=1 11 | 12 | # Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle. 13 | python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \ 14 | --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \ 15 | --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl 16 | 17 | # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs. 18 | python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \ 19 | --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug 20 | 21 | Options: 22 | --outdir DIR Where to save the results [required] 23 | --cfg [stylegan3-t|stylegan3-r|stylegan2] 24 | Base configuration [required] 25 | --data [ZIP|DIR] Training data [required] 26 | --gpus INT Number of GPUs to use [required] 27 | --batch INT Total batch size [required] 28 | --gamma FLOAT R1 regularization weight [required] 29 | --cond BOOL Train conditional model [default: False] 30 | --mirror BOOL Enable dataset x-flips [default: False] 31 | --aug [noaug|ada|fixed] Augmentation mode [default: ada] 32 | --resume [PATH|URL] Resume from given network pickle 33 | --freezed INT Freeze first layers of D [default: 0] 34 | --p FLOAT Probability for --aug=fixed [default: 0.2] 35 | --target FLOAT Target value for --aug=ada [default: 0.6] 36 | --batch-gpu INT Limit batch size per GPU 37 | --cbase INT Capacity multiplier [default: 32768] 38 | --cmax INT Max. 
feature maps [default: 512]
39 |   --glr FLOAT                  G learning rate  [default: varies]
40 |   --dlr FLOAT                  D learning rate  [default: 0.002]
41 |   --map-depth INT              Mapping network depth  [default: varies]
42 |   --mbstd-group INT            Minibatch std group size  [default: 4]
43 |   --desc STR                   String to include in result dir name
44 |   --metrics [NAME|A,B,C|none]  Quality metrics  [default: fid50k_full]
45 |   --kimg KIMG                  Total training duration  [default: 25000]
46 |   --tick KIMG                  How often to print progress  [default: 4]
47 |   --snap TICKS                 How often to save snapshots  [default: 50]
48 |   --seed INT                   Random seed  [default: 0]
49 |   --fp32 BOOL                  Disable mixed-precision  [default: False]
50 |   --nobench BOOL               Disable cuDNN benchmarking  [default: False]
51 |   --workers INT                DataLoader worker processes  [default: 3]
52 |   -n, --dry-run                Print training options and exit
53 |   --help                       Show this message and exit.
54 |
--------------------------------------------------------------------------------
/docs/troubleshooting.md:
--------------------------------------------------------------------------------
1 | # Troubleshooting
2 |
3 | Our PyTorch code uses custom [CUDA extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) to speed up some of the network layers. Getting these to run can sometimes be a hassle.
4 |
5 | This page aims to give guidance on how to diagnose and fix run-time problems related to these extensions.
6 |
7 | ## Before you start
8 |
9 | 1. Try Docker first! Ensure you can successfully run our models using the recommended Docker image. Follow the instructions in [README.md](/README.md) to get it running.
10 | 2. Can't use Docker? Read on...
11 |
12 | ## Installing dependencies
13 |
14 | Make sure you've installed everything listed in the requirements section of [README.md](/README.md). The key components w.r.t. custom extensions are:
15 |
16 | - **[CUDA toolkit 11.1](https://developer.nvidia.com/cuda-toolkit)** or later (this is not the same as `cudatoolkit` from Conda).
17 |   - PyTorch invokes `nvcc` to compile our CUDA kernels.
18 | - **ninja**
19 |   - PyTorch uses [Ninja](https://ninja-build.org/) as its build system.
20 | - **GCC** (Linux) or **Visual Studio** (Windows)
21 |   - GCC 7.x or later is required. Earlier versions such as GCC 6.3 [are known not to work](https://github.com/NVlabs/stylegan3/issues/2).
22 |
23 | #### Why is CUDA toolkit installation necessary?
24 |
25 | The PyTorch package contains the required CUDA toolkit libraries needed to run PyTorch, so why is a separate CUDA toolkit installation required? Our models use custom CUDA kernels to implement operations such as efficient resampling of 2D images. PyTorch code invokes the CUDA compiler at run-time to compile these kernels on first use. The tools and libraries required for this compilation are not bundled in PyTorch, so a host CUDA toolkit installation is required.
26 |
27 | ## Things to try
28 |
29 | - Completely remove `$HOME/.cache/torch_extensions` (Linux) or `C:\Users\<username>\AppData\Local\torch_extensions\torch_extensions\Cache` (Windows) and re-run the StyleGAN3 Python code.
30 | - Run ninja in `$HOME/.cache/torch_extensions` to see that it builds.
31 | - Inspect `build.ninja` in the build directories under `$HOME/.cache/torch_extensions` and check that the CUDA tools and versions are consistent with what you intended to use.
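To confirm that the extensions actually compile and run, a minimal smoke test along these lines can help (a sketch, assuming the repository root is on `PYTHONPATH` and a CUDA device is visible):

```python
import torch
from torch_utils.ops import bias_act, upfirdn2d

x = torch.randn(4, 8, 32, 32, device='cuda')   # dummy feature maps
b = torch.zeros(8, device='cuda')              # per-channel bias
y = bias_act.bias_act(x, b, act='lrelu')       # triggers JIT build of bias_act
f = torch.ones([4, 4], device='cuda') / 16     # simple 4x4 box filter
z = upfirdn2d.upfirdn2d(x, f, up=2)            # triggers JIT build of upfirdn2d
print(y.shape, z.shape)                        # both builds succeeded if this prints
```

If an import or call fails, the error message usually names the missing tool (`nvcc`, `ninja`, or the host compiler) directly.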
32 | -------------------------------------------------------------------------------- /docs/visualizer_screen0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/docs/visualizer_screen0.png -------------------------------------------------------------------------------- /docs/visualizer_screen0_half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/docs/visualizer_screen0_half.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: stylegan3 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - pip: 20 | - imgui==1.3.0 21 | - glfw==2.2.0 22 | - pyopengl==3.1.5 23 | - imageio-ffmpeg==0.4.3 24 | - pyspng 25 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/example.png -------------------------------------------------------------------------------- /for_debug.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from training.training_loop_eg3d import save_image_grid 4 | # from training.EG3d_v2 import Generator 5 | from training.EG3d_v16 import Generator 6 | torch.set_grad_enabled(False) 7 | import dnnlib 8 | import legacy 9 | from torch_utils import misc 10 | import os 11 | from torch import nn 12 | 13 | 14 | 15 | root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v16/00000-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-001800.pkl' 16 | 17 | nerf_init_args = {} 18 | nerf_init_args['img_size'] = 64 19 | 20 | 21 | # hyper params 22 | G_kwargs = dnnlib.EasyDict( 23 | z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), 24 | use_noise=False, # 关闭noise 25 | nerf_decoder_kwargs=dnnlib.EasyDict( 26 | in_c=32, 27 | mid_c=64, 28 | out_c=32, 29 | ) 30 | ) 31 | 32 | G_kwargs.mapping_kwargs.num_layers = 8 33 | G_kwargs.fused_modconv_default = 'inference_only' 34 | G_kwargs.conv_clamp = None 35 | common_kwargs = dict(c_dim= 16, #12, 36 | img_resolution=512, 37 | img_channels= 96, 38 | backbone_resolution=256, 39 | rank=0, 40 | ) 41 | 42 | G = Generator(**G_kwargs, **common_kwargs) 43 | G.cuda() 44 | G.eval() 45 | G.requires_grad_(False) 46 | if root: 47 | print(f'Resuming from "{root}"') 48 | cwd = os.getcwd() 49 | os.chdir('./training') 50 | with dnnlib.util.open_url(root) as f: 51 | resume_data = legacy.load_network_pkl(f) 52 | for name, module in [('G', G)]: # G_ema 53 | misc.copy_params_and_buffers(resume_data[name], module, require_all=True) 54 | os.chdir(cwd) 55 | for n, p in G.named_parameters(): 56 | if torch.any(torch.isnan(p)): 57 | print(n) 58 | for n, b in G.named_buffers(): 59 | if torch.any(torch.isnan(b)): 60 | print(n) 61 | exit() 62 | 63 | # device = torch.device('cuda') 64 | # with dnnlib.util.open_url(root) as f: 65 | # 
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 66 | # G.eval() 67 | # G.requires_grad_(False) 68 | 69 | 70 | gh = 4 71 | grid_z = torch.randn(gh, 512).cuda() 72 | grid_c = torch.tensor([0,0,0]).reshape(1, -1).expand(gh, -1).cuda() # 4, 3 73 | 74 | meta_data = {'noise_mode': 'const'} 75 | with torch.no_grad(): 76 | gen_imgs = G(grid_z, grid_c, nerf_init_args=nerf_init_args, **meta_data)[:, :3] # gh, 3, h, w 77 | # print(gen_imgs.shape, gen_imgs.max(), bgen_imgs.min()) 78 | save_image_grid(gen_imgs.cpu().numpy(), 'example.png', (-1, 1), (1, gh)) 79 | 80 | -------------------------------------------------------------------------------- /gen_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def parse_range(s: Union[str, List]) -> List[int]: 26 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 27 | 28 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 29 | ''' 30 | if isinstance(s, list): return s 31 | ranges = [] 32 | range_re = re.compile(r'^(\d+)-(\d+)$') 33 | for p in s.split(','): 34 | m = range_re.match(p) 35 | if m: 36 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 37 | else: 38 | ranges.append(int(p)) 39 | return ranges 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: 44 | '''Parse a floating point 2-vector of syntax 'a,b'. 
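Tuples that are already parsed are passed through unchanged.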
45 | 46 | Example: 47 | '0,1' returns (0,1) 48 | ''' 49 | if isinstance(s, tuple): return s 50 | parts = s.split(',') 51 | if len(parts) == 2: 52 | return (float(parts[0]), float(parts[1])) 53 | raise ValueError(f'cannot parse 2-vector {s}') 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def make_transform(translate: Tuple[float,float], angle: float): 58 | m = np.eye(3) 59 | s = np.sin(angle/360.0*np.pi*2) 60 | c = np.cos(angle/360.0*np.pi*2) 61 | m[0][0] = c 62 | m[0][1] = s 63 | m[0][2] = translate[0] 64 | m[1][0] = -s 65 | m[1][1] = c 66 | m[1][2] = translate[1] 67 | return m 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @click.command() 72 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 73 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 74 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 75 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 76 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 77 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 78 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 79 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 80 | def generate_images( 81 | network_pkl: str, 82 | seeds: List[int], 83 | truncation_psi: float, 84 | noise_mode: str, 85 | outdir: str, 86 | translate: Tuple[float,float], 87 | rotate: float, 88 | class_idx: Optional[int] 89 | ): 90 | """Generate images using pretrained network pickle. 91 | 92 | Examples: 93 | 94 | \b 95 | # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). 96 | python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ 97 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 98 | 99 | \b 100 | # Generate uncurated images with truncation using the MetFaces-U dataset 101 | python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 102 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl 103 | """ 104 | 105 | print('Loading networks from "%s"...' % network_pkl) 106 | device = torch.device('cuda') 107 | with dnnlib.util.open_url(network_pkl) as f: 108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 109 | 110 | os.makedirs(outdir, exist_ok=True) 111 | 112 | # Labels. 113 | label = torch.zeros([1, G.c_dim], device=device) 114 | if G.c_dim != 0: 115 | if class_idx is None: 116 | raise click.ClickException('Must specify class label with --class when using a conditional network') 117 | label[:, class_idx] = 1 118 | else: 119 | if class_idx is not None: 120 | print ('warn: --class=lbl ignored when running on an unconditional network') 121 | 122 | # Generate images. 123 | for seed_idx, seed in enumerate(seeds): 124 | print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) 125 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 126 | 127 | # Construct an inverse rotation/translation matrix and pass to the generator. The 128 | # generator expects this matrix as an inverse to avoid potentially failing numerical 129 | # operations in the network. 130 | if hasattr(G.synthesis, 'input'): 131 | m = make_transform(translate, rotate) 132 | m = np.linalg.inv(m) 133 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 134 | 135 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 136 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 137 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 138 | 139 | 140 | #---------------------------------------------------------------------------- 141 | 142 | if __name__ == "__main__": 143 | generate_images() # pylint: disable=no-value-for-parameter 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate lerp videos using pretrained network pickle.""" 10 | 11 | import copy 12 | import os 13 | import re 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import click 17 | import dnnlib 18 | import imageio 19 | import numpy as np 20 | import scipy.interpolate 21 | import torch 22 | from tqdm import tqdm 23 | 24 | import legacy 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): 29 | batch_size, channels, img_h, img_w = img.shape 30 | if grid_w is None: 31 | grid_w = batch_size // grid_h 32 | assert batch_size == grid_w * grid_h 33 | if float_to_uint8: 34 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 35 | img = img.reshape(grid_h, grid_w, channels, img_h, img_w) 36 | img = img.permute(2, 0, 3, 1, 4) 37 | img = img.reshape(channels, grid_h * img_h, grid_w * img_w) 38 | if chw_to_hwc: 39 | img = img.permute(1, 2, 0) 40 | if to_numpy: 41 | img = img.cpu().numpy() 42 | return img 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs): 47 | grid_w = grid_dims[0] 48 | grid_h = grid_dims[1] 49 | 50 | if num_keyframes is None: 51 | if len(seeds) % (grid_w*grid_h) != 0: 52 | raise ValueError('Number of input seeds must be divisible by grid W*H') 53 | num_keyframes = len(seeds) // (grid_w*grid_h) 54 | 55 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) 56 | for idx in range(num_keyframes*grid_h*grid_w): 57 | all_seeds[idx] = seeds[idx % len(seeds)] 58 | 59 | if shuffle_seed is not None: 60 | rng = 
np.random.RandomState(seed=shuffle_seed) 61 | rng.shuffle(all_seeds) 62 | 63 | zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device) 64 | ws = G.mapping(z=zs, c=None, truncation_psi=psi) 65 | _ = G.synthesis(ws[:1]) # warm up 66 | ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) 67 | 68 | # Interpolation. 69 | grid = [] 70 | for yi in range(grid_h): 71 | row = [] 72 | for xi in range(grid_w): 73 | x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) 74 | y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) 75 | interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) 76 | row.append(interp) 77 | grid.append(row) 78 | 79 | # Render video. 80 | video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) 81 | for frame_idx in tqdm(range(num_keyframes * w_frames)): 82 | imgs = [] 83 | for yi in range(grid_h): 84 | for xi in range(grid_w): 85 | interp = grid[yi][xi] 86 | w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) 87 | img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] 88 | imgs.append(img) 89 | video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) 90 | video_out.close() 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def parse_range(s: Union[str, List[int]]) -> List[int]: 95 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 96 | 97 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 98 | ''' 99 | if isinstance(s, list): return s 100 | ranges = [] 101 | range_re = re.compile(r'^(\d+)-(\d+)$') 102 | for p in s.split(','): 103 | m = range_re.match(p) 104 | if m: 105 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 106 | else: 107 | ranges.append(int(p)) 108 | return ranges 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: 113 | '''Parse a 'M,N' or 'MxN' integer tuple. 114 | 115 | Example: 116 | '4x2' returns (4,2) 117 | '0,1' returns (0,1) 118 | ''' 119 | if isinstance(s, tuple): return s 120 | m = re.match(r'^(\d+)[x,](\d+)$', s) 121 | if m: 122 | return (int(m.group(1)), int(m.group(2))) 123 | raise ValueError(f'cannot parse tuple {s}') 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | @click.command() 128 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 129 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True) 130 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) 131 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) 132 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. 
If not specified, determine based on the length of the seeds array given by --seeds.', default=None) 133 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) 134 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 135 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') 136 | def generate_images( 137 | network_pkl: str, 138 | seeds: List[int], 139 | shuffle_seed: Optional[int], 140 | truncation_psi: float, 141 | grid: Tuple[int,int], 142 | num_keyframes: Optional[int], 143 | w_frames: int, 144 | output: str 145 | ): 146 | """Render a latent vector interpolation video. 147 | 148 | Examples: 149 | 150 | \b 151 | # Render a 4x2 grid of interpolations for seeds 0 through 31. 152 | python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ 153 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 154 | 155 | Animation length and seed keyframes: 156 | 157 | The animation length is either determined based on the --seeds value or explicitly 158 | specified using the --num-keyframes option. 159 | 160 | When num keyframes is specified with --num-keyframes, the output video length 161 | will be 'num_keyframes*w_frames' frames. 162 | 163 | If --num-keyframes is not specified, the number of seeds given with 164 | --seeds must be divisible by grid size W*H (--grid). In this case the 165 | output video length will be '# seeds/(w*h)*w_frames' frames. 166 | """ 167 | 168 | print('Loading networks from "%s"...' % network_pkl) 169 | device = torch.device('cuda') 170 | with dnnlib.util.open_url(network_pkl) as f: 171 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 172 | 173 | gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi) 174 | 175 | #---------------------------------------------------------------------------- 176 | 177 | if __name__ == "__main__": 178 | generate_images() # pylint: disable=no-value-for-parameter 179 | 180 | #---------------------------------------------------------------------------- 181 | -------------------------------------------------------------------------------- /gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /gui_utils/glfw_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import time 10 | import glfw 11 | import OpenGL.GL as gl 12 | from . import gl_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class GlfwWindow: # pylint: disable=too-many-public-methods 17 | def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True): 18 | self._glfw_window = None 19 | self._drawing_frame = False 20 | self._frame_start_time = None 21 | self._frame_delta = 0 22 | self._fps_limit = None 23 | self._vsync = None 24 | self._skip_frames = 0 25 | self._deferred_show = deferred_show 26 | self._close_on_esc = close_on_esc 27 | self._esc_pressed = False 28 | self._drag_and_drop_paths = None 29 | self._capture_next_frame = False 30 | self._captured_frame = None 31 | 32 | # Create window. 33 | glfw.init() 34 | glfw.window_hint(glfw.VISIBLE, False) 35 | self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None) 36 | self._attach_glfw_callbacks() 37 | self.make_context_current() 38 | 39 | # Adjust window. 40 | self.set_vsync(False) 41 | self.set_window_size(window_width, window_height) 42 | if not self._deferred_show: 43 | glfw.show_window(self._glfw_window) 44 | 45 | def close(self): 46 | if self._drawing_frame: 47 | self.end_frame() 48 | if self._glfw_window is not None: 49 | glfw.destroy_window(self._glfw_window) 50 | self._glfw_window = None 51 | #glfw.terminate() # Commented out to play it nice with other glfw clients. 52 | 53 | def __del__(self): 54 | try: 55 | self.close() 56 | except: 57 | pass 58 | 59 | @property 60 | def window_width(self): 61 | return self.content_width 62 | 63 | @property 64 | def window_height(self): 65 | return self.content_height + self.title_bar_height 66 | 67 | @property 68 | def content_width(self): 69 | width, _height = glfw.get_window_size(self._glfw_window) 70 | return width 71 | 72 | @property 73 | def content_height(self): 74 | _width, height = glfw.get_window_size(self._glfw_window) 75 | return height 76 | 77 | @property 78 | def title_bar_height(self): 79 | _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window) 80 | return top 81 | 82 | @property 83 | def monitor_width(self): 84 | _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 85 | return width 86 | 87 | @property 88 | def monitor_height(self): 89 | _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor()) 90 | return height 91 | 92 | @property 93 | def frame_delta(self): 94 | return self._frame_delta 95 | 96 | def set_title(self, title): 97 | glfw.set_window_title(self._glfw_window, title) 98 | 99 | def set_window_size(self, width, height): 100 | width = min(width, self.monitor_width) 101 | height = min(height, self.monitor_height) 102 | glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0)) 103 | if width == self.monitor_width and height == self.monitor_height: 104 | self.maximize() 105 | 106 | def set_content_size(self, width, height): 107 | self.set_window_size(width, height + self.title_bar_height) 108 | 109 | def maximize(self): 110 | glfw.maximize_window(self._glfw_window) 111 | 112 | def set_position(self, x, y): 113 | glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height) 114 | 115 | def center(self): 
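# Center the window within the primary monitor's work area.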
116 | self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2) 117 | 118 | def set_vsync(self, vsync): 119 | vsync = bool(vsync) 120 | if vsync != self._vsync: 121 | glfw.swap_interval(1 if vsync else 0) 122 | self._vsync = vsync 123 | 124 | def set_fps_limit(self, fps_limit): 125 | self._fps_limit = int(fps_limit) 126 | 127 | def should_close(self): 128 | return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed) 129 | 130 | def skip_frame(self): 131 | self.skip_frames(1) 132 | 133 | def skip_frames(self, num): # Do not update window for the next N frames. 134 | self._skip_frames = max(self._skip_frames, int(num)) 135 | 136 | def is_skipping_frames(self): 137 | return self._skip_frames > 0 138 | 139 | def capture_next_frame(self): 140 | self._capture_next_frame = True 141 | 142 | def pop_captured_frame(self): 143 | frame = self._captured_frame 144 | self._captured_frame = None 145 | return frame 146 | 147 | def pop_drag_and_drop_paths(self): 148 | paths = self._drag_and_drop_paths 149 | self._drag_and_drop_paths = None 150 | return paths 151 | 152 | def draw_frame(self): # To be overridden by subclass. 153 | self.begin_frame() 154 | # Rendering code goes here. 155 | self.end_frame() 156 | 157 | def make_context_current(self): 158 | if self._glfw_window is not None: 159 | glfw.make_context_current(self._glfw_window) 160 | 161 | def begin_frame(self): 162 | # End previous frame. 163 | if self._drawing_frame: 164 | self.end_frame() 165 | 166 | # Apply FPS limit. 167 | if self._frame_start_time is not None and self._fps_limit is not None: 168 | delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit 169 | if delay > 0: 170 | time.sleep(delay) 171 | cur_time = time.perf_counter() 172 | if self._frame_start_time is not None: 173 | self._frame_delta = cur_time - self._frame_start_time 174 | self._frame_start_time = cur_time 175 | 176 | # Process events. 177 | glfw.poll_events() 178 | 179 | # Begin frame. 180 | self._drawing_frame = True 181 | self.make_context_current() 182 | 183 | # Initialize GL state. 184 | gl.glViewport(0, 0, self.content_width, self.content_height) 185 | gl.glMatrixMode(gl.GL_PROJECTION) 186 | gl.glLoadIdentity() 187 | gl.glTranslate(-1, 1, 0) 188 | gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1) 189 | gl.glMatrixMode(gl.GL_MODELVIEW) 190 | gl.glLoadIdentity() 191 | gl.glEnable(gl.GL_BLEND) 192 | gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha. 193 | 194 | # Clear. 195 | gl.glClearColor(0, 0, 0, 1) 196 | gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT) 197 | 198 | def end_frame(self): 199 | assert self._drawing_frame 200 | self._drawing_frame = False 201 | 202 | # Skip frames if requested. 203 | if self._skip_frames > 0: 204 | self._skip_frames -= 1 205 | return 206 | 207 | # Capture frame if requested. 208 | if self._capture_next_frame: 209 | self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height) 210 | self._capture_next_frame = False 211 | 212 | # Update window. 
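# (A window created with deferred_show=True stays hidden until its first
# complete frame has been rendered, so the user never sees an empty window.)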
213 | if self._deferred_show: 214 | glfw.show_window(self._glfw_window) 215 | self._deferred_show = False 216 | glfw.swap_buffers(self._glfw_window) 217 | 218 | def _attach_glfw_callbacks(self): 219 | glfw.set_key_callback(self._glfw_window, self._glfw_key_callback) 220 | glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback) 221 | 222 | def _glfw_key_callback(self, _window, key, _scancode, action, _mods): 223 | if action == glfw.PRESS and key == glfw.KEY_ESCAPE: 224 | self._esc_pressed = True 225 | 226 | def _glfw_drop_callback(self, _window, paths): 227 | self._drag_and_drop_paths = paths 228 | 229 | #---------------------------------------------------------------------------- 230 | -------------------------------------------------------------------------------- /gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import contextlib 10 | import imgui 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 15 | s = imgui.get_style() 16 | s.window_padding = [spacing, spacing] 17 | s.item_spacing = [spacing, spacing] 18 | s.item_inner_spacing = [spacing, spacing] 19 | s.columns_min_spacing = spacing 20 | s.indent_spacing = indent 21 | s.scrollbar_size = scrollbar 22 | s.frame_padding = [4, 3] 23 | s.window_border_size = 1 24 | s.child_border_size = 1 25 | s.popup_border_size = 1 26 | s.frame_border_size = 1 27 | s.window_rounding = 0 28 | s.child_rounding = 0 29 | s.popup_rounding = 3 30 | s.frame_rounding = 3 31 | s.scrollbar_rounding = 3 32 | s.grab_rounding = 3 33 | 34 | getattr(imgui, f'style_colors_{color_scheme}')(s) 35 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 36 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 37 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @contextlib.contextmanager 42 | def grayed_out(cond=True): 43 | if cond: 44 | s = imgui.get_style() 45 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 46 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 47 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 48 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 49 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 50 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 51 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 52 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 53 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 55 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 56 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 58 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 59 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 60 | 
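# (Together with the two pushes below, 14 style colors are active in this
# branch, matching the pop_style_color(14) call after the yield.)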
imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 61 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 62 | yield 63 | imgui.pop_style_color(14) 64 | else: 65 | yield 66 | 67 | #---------------------------------------------------------------------------- 68 | 69 | @contextlib.contextmanager 70 | def item_width(width=None): 71 | if width is not None: 72 | imgui.push_item_width(width) 73 | yield 74 | imgui.pop_item_width() 75 | else: 76 | yield 77 | 78 | #---------------------------------------------------------------------------- 79 | 80 | def scoped_by_object_id(method): 81 | def decorator(self, *args, **kwargs): 82 | imgui.push_id(str(id(self))) 83 | res = method(self, *args, **kwargs) 84 | imgui.pop_id() 85 | return res 86 | return decorator 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | def button(label, width=0, enabled=True): 91 | with grayed_out(not enabled): 92 | clicked = imgui.button(label, width=width) 93 | clicked = clicked and enabled 94 | return clicked 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 99 | expanded = False 100 | if show: 101 | if default: 102 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 103 | if not enabled: 104 | flags |= imgui.TREE_NODE_LEAF 105 | with grayed_out(not enabled): 106 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 107 | expanded = expanded and enabled 108 | return expanded, visible 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def popup_button(label, width=0, enabled=True): 113 | if button(label, width, enabled): 114 | imgui.open_popup(label) 115 | opened = imgui.begin_popup(label) 116 | return opened 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 121 | old_value = value 122 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 123 | if value == '': 124 | color[-1] *= 0.5 125 | with item_width(width): 126 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 127 | value = value if value != '' else help_text 128 | changed, value = imgui.input_text(label, value, buffer_length, flags) 129 | value = value if value != help_text else '' 130 | imgui.pop_style_color(1) 131 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 132 | changed = (value != old_value) 133 | return changed, value 134 | 135 | #---------------------------------------------------------------------------- 136 | 137 | def drag_previous_control(enabled=True): 138 | dragging = False 139 | dx = 0 140 | dy = 0 141 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 142 | if enabled: 143 | dragging = True 144 | dx, dy = imgui.get_mouse_drag_delta() 145 | imgui.reset_mouse_drag_delta() 146 | imgui.end_drag_drop_source() 147 | return dragging, dx, dy 148 | 149 | #---------------------------------------------------------------------------- 150 | 151 | def drag_button(label, width=0, enabled=True): 152 | clicked = button(label, width=width, enabled=enabled) 153 | dragging, dx, dy = drag_previous_control(enabled=enabled) 154 | return clicked, dragging, dx, dy 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | def drag_hidden_window(label, x, y, width, height, enabled=True): 159 | 
imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 160 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 161 | imgui.set_next_window_position(x, y) 162 | imgui.set_next_window_size(width, height) 163 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 164 | dragging, dx, dy = drag_previous_control(enabled=enabled) 165 | imgui.end() 166 | imgui.pop_style_color(2) 167 | return dragging, dx, dy 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import imgui 11 | import imgui.integrations.glfw 12 | 13 | from . import glfw_window 14 | from . import imgui_utils 15 | from . import text_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class ImguiWindow(glfw_window.GlfwWindow): 20 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 21 | if font is None: 22 | font = text_utils.get_default_font() 23 | font_sizes = {int(size) for size in font_sizes} 24 | super().__init__(title=title, **glfw_kwargs) 25 | 26 | # Init fields. 27 | self._imgui_context = None 28 | self._imgui_renderer = None 29 | self._imgui_fonts = None 30 | self._cur_font_size = max(font_sizes) 31 | 32 | # Delete leftover imgui.ini to avoid unexpected behavior. 33 | if os.path.isfile('imgui.ini'): 34 | os.remove('imgui.ini') 35 | 36 | # Init ImGui. 37 | self._imgui_context = imgui.create_context() 38 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 39 | self._attach_glfw_callbacks() 40 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 42 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 43 | self._imgui_renderer.refresh_font_texture() 44 | 45 | def close(self): 46 | self.make_context_current() 47 | self._imgui_fonts = None 48 | if self._imgui_renderer is not None: 49 | self._imgui_renderer.shutdown() 50 | self._imgui_renderer = None 51 | if self._imgui_context is not None: 52 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 53 | self._imgui_context = None 54 | super().close() 55 | 56 | def _glfw_key_callback(self, *args): 57 | super()._glfw_key_callback(*args) 58 | self._imgui_renderer.keyboard_callback(*args) 59 | 60 | @property 61 | def font_size(self): 62 | return self._cur_font_size 63 | 64 | @property 65 | def spacing(self): 66 | return round(self._cur_font_size * 0.4) 67 | 68 | def set_font_size(self, target): # Applied on next frame. 69 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 70 | 71 | def begin_frame(self): 72 | # Begin glfw frame. 
73 | super().begin_frame() 74 | 75 | # Process imgui events. 76 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 77 | if self.content_width > 0 and self.content_height > 0: 78 | self._imgui_renderer.process_inputs() 79 | 80 | # Begin imgui frame. 81 | imgui.new_frame() 82 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 83 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 84 | 85 | def end_frame(self): 86 | imgui.pop_font() 87 | imgui.render() 88 | imgui.end_frame() 89 | self._imgui_renderer.render(imgui.get_draw_data()) 90 | super().end_frame() 91 | 92 | #---------------------------------------------------------------------------- 93 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 94 | 95 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.mouse_wheel_multiplier = 1 99 | 100 | def scroll_callback(self, window, x_offset, y_offset): 101 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 102 | 103 | #---------------------------------------------------------------------------- 104 | -------------------------------------------------------------------------------- /gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import functools 10 | from typing import Optional 11 | 12 | import dnnlib 13 | import numpy as np 14 | import PIL.Image 15 | import PIL.ImageFont 16 | import scipy.ndimage 17 | 18 | from . 
import gl_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def get_default_font(): 23 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 24 | return dnnlib.util.open_url(url, return_filename=True) 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | @functools.lru_cache(maxsize=None) 29 | def get_pil_font(font=None, size=32): 30 | if font is None: 31 | font = get_default_font() 32 | return PIL.ImageFont.truetype(font=font, size=size) 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 37 | if dropshadow_radius is not None: 38 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 39 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 40 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 41 | else: 42 | return _get_array_priv(string, **kwargs) 43 | 44 | @functools.lru_cache(maxsize=10000) 45 | def _get_array_priv( 46 | string: str, *, 47 | size: int = 32, 48 | max_width: Optional[int]=None, 49 | max_height: Optional[int]=None, 50 | min_size=10, 51 | shrink_coef=0.8, 52 | dropshadow_radius: int=None, 53 | offset_x: int=None, 54 | offset_y: int=None, 55 | **kwargs 56 | ): 57 | cur_size = size 58 | array = None 59 | while True: 60 | if dropshadow_radius is not None: 61 | # separate implementation for dropshadow text rendering 62 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 63 | else: 64 | array = _get_array_impl(string, size=cur_size, **kwargs) 65 | height, width, _ = array.shape 66 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 67 | break 68 | cur_size = max(int(cur_size * shrink_coef), min_size) 69 | return array 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | @functools.lru_cache(maxsize=10000) 74 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 75 | pil_font = get_pil_font(font=font, size=size) 76 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 77 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 78 | width = max(line.shape[1] for line in lines) 79 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 80 | line_spacing = line_pad if line_pad is not None else size // 2 81 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 82 | mask = np.concatenate(lines, axis=0) 83 | alpha = mask 84 | if outline > 0: 85 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 86 | alpha = mask.astype(np.float32) / 255 87 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 88 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 89 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 90 | alpha = np.maximum(alpha, mask) 91 | return np.stack([mask, alpha], axis=-1) 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | @functools.lru_cache(maxsize=10000) 96 | def _get_array_impl_dropshadow(string, *, font=None, size=32, 
radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 97 | assert (offset_x > 0) and (offset_y > 0) 98 | pil_font = get_pil_font(font=font, size=size) 99 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 100 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 101 | width = max(line.shape[1] for line in lines) 102 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 103 | line_spacing = line_pad if line_pad is not None else size // 2 104 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 105 | mask = np.concatenate(lines, axis=0) 106 | alpha = mask 107 | 108 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 109 | alpha = mask.astype(np.float32) / 255 110 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 111 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 112 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 113 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 114 | alpha = np.maximum(alpha, mask) 115 | return np.stack([mask, alpha], axis=-1) 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | @functools.lru_cache(maxsize=10000) 120 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 121 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 122 | 123 | #---------------------------------------------------------------------------- 124 | -------------------------------------------------------------------------------- /marching_cube.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from training.training_loop_eg3d import save_image_grid 3 | # from training.EG3d_v2 import Generator 4 | # from training.EG3d_v7 import Generator 5 | # from training.EG3d_v16 import Generator 6 | from training.EG3d_v12 import Generator 7 | # from training.EG3d_v3 import Generator 8 | import dnnlib 9 | import legacy 10 | from torch_utils import misc 11 | import os 12 | from torch import nn 13 | from training.training_loop_eg3d import setup_snapshot_image_grid, save_image_grid 14 | import numpy as np 15 | import imageio 16 | from tqdm import tqdm 17 | 18 | 19 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v8/00001-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-012096.pkl' 20 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v16/00004-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-002400.pkl' 21 | root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v12/00006-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-016329.pkl' 22 | 23 | G_kwargs = dnnlib.EasyDict( 24 | z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), 25 | use_noise=False, # disable noise 26 | nerf_decoder_kwargs=dnnlib.EasyDict( 27 | in_c=32, 28 | mid_c=64, 29 | out_c=32, 30 | ) 31 | ) 32 | 33 | G_kwargs.mapping_kwargs.num_layers = 8 34 | G_kwargs.fused_modconv_default = 'inference_only' 35 | G_kwargs.conv_clamp = None 36 | common_kwargs = dict(c_dim= 16, #12, 37 | img_resolution=512, 38 | img_channels= 96, 39 | backbone_resolution=256, 40 | rank=0, 41 | ) 42 | G = Generator(**G_kwargs, **common_kwargs) 43 | G.cuda() 44 | G.eval() 45 | G.requires_grad_(False) 46 | if root: 47 | print(f'Resuming from "{root}"') 48 | cwd = 
os.getcwd() 49 | os.chdir('./training') 50 | with dnnlib.util.open_url(root) as f: 51 | resume_data = legacy.load_network_pkl(f) 52 | for name, module in [('G_ema', G)]: 53 | misc.copy_params_and_buffers(resume_data[name], module, require_all=True) 54 | os.chdir(cwd) 55 | 56 | 57 | device = torch.device('cuda') 58 | # with dnnlib.util.open_url(root) as f: 59 | # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 60 | # G.eval() 61 | 62 | 63 | meta_data = {'noise_mode': 'const'} 64 | 65 | grid_z = torch.randn(16, G.z_dim, device=device) 66 | sigmas = G.get_sigma(grid_z, **meta_data).cpu().numpy() 67 | threshold = 50. 68 | 69 | 70 | 71 | 72 | print(sigmas.shape) 73 | print('fraction occupied', np.mean(sigmas > threshold)) 74 | np.save('march_file.npy', sigmas) 75 | 76 | 77 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
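    # The two calls below fit a Gaussian (mean, covariance) to the detector
    # features of the real and generated images; the final score is the
    # Frechet distance between those Gaussians,
    #     FID = ||mu_real - mu_gen||^2 + Tr(Sigma_real + Sigma_gen - 2*sqrtm(Sigma_gen @ Sigma_real)),
    # evaluated at the end of this function via scipy.linalg.sqrtm.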
24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | # print('**************************************') 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | # print('**************************************') 33 | if opts.rank != 0: 34 | return float('nan') 35 | # print('**************************************') 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Main API for computing and reporting quality metrics.""" 10 | 11 | import os 12 | import time 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | from . import metric_utils 18 | from . import frechet_inception_distance 19 | from . import kernel_inception_distance 20 | from . import precision_recall 21 | from . import perceptual_path_length 22 | from . import inception_score 23 | from . 
import equivariance 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _metric_dict = dict() # name => fn 28 | 29 | def register_metric(fn): 30 | assert callable(fn) 31 | _metric_dict[fn.__name__] = fn 32 | return fn 33 | 34 | def is_valid_metric(metric): 35 | return metric in _metric_dict 36 | 37 | def list_valid_metrics(): 38 | return list(_metric_dict.keys()) 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 43 | assert is_valid_metric(metric) 44 | opts = metric_utils.MetricOptions(**kwargs) 45 | 46 | # Calculate. 47 | start_time = time.time() 48 | results = _metric_dict[metric](opts) 49 | total_time = time.time() - start_time 50 | 51 | # Broadcast results. 52 | for key, value in list(results.items()): 53 | if opts.num_gpus > 1: 54 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 55 | torch.distributed.broadcast(tensor=value, src=0) 56 | value = float(value.cpu()) 57 | results[key] = value 58 | 59 | # Decorate with metadata. 60 | return dnnlib.EasyDict( 61 | results = dnnlib.EasyDict(results), 62 | metric = metric, 63 | total_time = total_time, 64 | total_time_str = dnnlib.util.format_time(total_time), 65 | num_gpus = opts.num_gpus, 66 | ) 67 | 68 | #---------------------------------------------------------------------------- 69 | 70 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 71 | metric = result_dict['metric'] 72 | assert is_valid_metric(metric) 73 | if run_dir is not None and snapshot_pkl is not None: 74 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 75 | 76 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 77 | print(jsonl_line) 78 | if run_dir is not None and os.path.isdir(run_dir): 79 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 80 | f.write(jsonl_line + '\n') 81 | 82 | #---------------------------------------------------------------------------- 83 | # Recommended metrics. 
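# A new metric is added by registering a function that returns a dict of named
# results. A minimal sketch (the name fid1k and its sample counts are
# hypothetical, not part of this repo):
#
#   @register_metric
#   def fid1k(opts):
#       opts.dataset_kwargs.update(max_size=None, xflip=False)
#       fid = frechet_inception_distance.compute_fid(opts, max_real=1000, num_gen=1000)
#       return dict(fid1k=fid)
#
# After registration, calc_metric('fid1k', ...) dispatches to it like the
# built-in metrics below.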
84 | 85 | @register_metric 86 | def fid50k_full(opts): 87 | opts.dataset_kwargs.update(max_size=None, xflip=False) 88 | # fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 89 | fid = frechet_inception_distance.compute_fid(opts, max_real=20, num_gen=20) 90 | return dict(fid50k_full=fid) 91 | 92 | @register_metric 93 | def kid50k_full(opts): 94 | opts.dataset_kwargs.update(max_size=None, xflip=False) 95 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 96 | return dict(kid50k_full=kid) 97 | 98 | @register_metric 99 | def pr50k3_full(opts): 100 | opts.dataset_kwargs.update(max_size=None, xflip=False) 101 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 102 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 103 | 104 | @register_metric 105 | def ppl2_wend(opts): 106 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 107 | return dict(ppl2_wend=ppl) 108 | 109 | @register_metric 110 | def eqt50k_int(opts): 111 | opts.G_kwargs.update(force_fp32=True) 112 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 113 | return dict(eqt50k_int=psnr) 114 | 115 | @register_metric 116 | def eqt50k_frac(opts): 117 | opts.G_kwargs.update(force_fp32=True) 118 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 119 | return dict(eqt50k_frac=psnr) 120 | 121 | @register_metric 122 | def eqr50k(opts): 123 | opts.G_kwargs.update(force_fp32=True) 124 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 125 | return dict(eqr50k=psnr) 126 | 127 | #---------------------------------------------------------------------------- 128 | # Legacy metrics. 129 | 130 | @register_metric 131 | def fid50k(opts): 132 | opts.dataset_kwargs.update(max_size=None) 133 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 134 | return dict(fid50k=fid) 135 | 136 | @register_metric 137 | def kid50k(opts): 138 | opts.dataset_kwargs.update(max_size=None) 139 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 140 | return dict(kid50k=kid) 141 | 142 | @register_metric 143 | def pr50k3(opts): 144 | opts.dataset_kwargs.update(max_size=None) 145 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 146 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 147 | 148 | @register_metric 149 | def is50k(opts): 150 | opts.dataset_kwargs.update(max_size=None, xflip=False) 151 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 152 | return dict(is50k_mean=mean, is50k_std=std) 153 | 154 | #---------------------------------------------------------------------------- 155 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 
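        # Each pair contributes a finite-difference path length in LPIPS
        # space: dist = ||lpips(img_t) - lpips(img_{t + eps})||^2 / eps^2,
        # computed below from the two LPIPS embeddings of the paired images.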
88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . 
import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /test_eg3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from training.training_loop_eg3d import save_image_grid 3 | # from training.EG3d_v2 import Generator 4 | # from training.EG3d_v7 import Generator 5 | # from training.EG3d_v8 import Generator 6 | # from training.EG3d_v9 import Generator 7 | from training.EG3d_v12 import Generator 8 | import dnnlib 9 | import legacy 10 | from torch_utils import misc 11 | import os 12 | from torch import nn 13 | from training.training_loop_eg3d import setup_snapshot_image_grid, 
save_image_grid 14 | import numpy as np 15 | 16 | 17 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v12/00003-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-004032.pkl' 18 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v13/00003-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-000400.pkl' 19 | root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v12/00006-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-016329.pkl' 20 | gen_num = 16 21 | 22 | 23 | G_kwargs = dnnlib.EasyDict( 24 | z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), 25 | # init_point_kwargs=dnnlib.EasyDict( 26 | # nerf_resolution=128, 27 | # fov=12, 28 | # d_range=(0.88, 1.12), 29 | # ), 30 | use_noise=False, # 关闭noise 31 | nerf_decoder_kwargs=dnnlib.EasyDict( 32 | in_c=32, 33 | mid_c=64, 34 | out_c=32, 35 | ) 36 | ) 37 | nerf_init_args = {} 38 | nerf_init_args['num_steps'] = 36 39 | nerf_init_args['img_size'] = 64 40 | nerf_init_args['fov'] = 12 41 | nerf_init_args['nerf_noise'] = 0 42 | nerf_init_args['ray_start'] = 0.88 43 | nerf_init_args['ray_end'] = 1.12 44 | G_kwargs.mapping_kwargs.num_layers = 8 45 | G_kwargs.fused_modconv_default = 'inference_only' 46 | G_kwargs.conv_clamp = None 47 | common_kwargs = dict(c_dim= 16, #12, 48 | img_resolution=256, 49 | img_channels= 96, 50 | backbone_resolution=128, 51 | rank=0, 52 | ) 53 | # G = Generator(**G_kwargs, **common_kwargs) 54 | # G.cuda() 55 | # G.eval() 56 | # G.requires_grad_(False) 57 | # if root: 58 | # print(f'Resuming from "{root}"') 59 | # cwd = os.getcwd() 60 | # os.chdir('./training') 61 | # with dnnlib.util.open_url(root) as f: 62 | # resume_data = legacy.load_network_pkl(f) 63 | # for name, module in [('G_ema', G)]: 64 | # misc.copy_params_and_buffers(resume_data[name], module, require_all=True) 65 | # os.chdir(cwd) 66 | 67 | device = torch.device('cuda') 68 | with dnnlib.util.open_url(root) as f: 69 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 70 | G.eval() 71 | 72 | 73 | grid_size, grid_c = setup_snapshot_image_grid(gen_num, device=torch.device('cuda')) 74 | grid_z = torch.randn([gen_num, G.z_dim], device=device) 75 | 76 | meta_data = {'noise_mode': 'const',} # 'trans_x':torch.tensor([0.5, 0, 0]).to(device) } 77 | total_imgs = [] 78 | for c in grid_c: 79 | images = G(z=grid_z, angles=c, nerf_init_args=nerf_init_args, **meta_data)[:, :3] # b,3,h,w 80 | # images = G(z=grid_z, c=c, **meta_data)[:, :3] # b,3,h,w 81 | res = images.shape[-1] 82 | # save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size) 83 | total_imgs.append(images) 84 | total_imgs = torch.stack(total_imgs, dim=1).reshape(grid_size[0]*grid_size[1], 3, res, res).cpu().numpy() # b*5, 3,h,w 85 | thre_imgs = np.where((total_imgs > 1) + (total_imgs < -1), 1, 0) # 溢出位置为1,否则为0 86 | 87 | 88 | save_image_grid(total_imgs, 'test_gen_images.png', drange=[-1,1], grid_size=grid_size) 89 | save_image_grid(thre_imgs, 'test_thre_images.png', drange=[-1,1], grid_size=grid_size) 90 | 91 | 92 | -------------------------------------------------------------------------------- /test_eg3d_gen_video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from training.training_loop_eg3d import save_image_grid 3 | # from training.EG3d_v2 import Generator 4 | # from training.EG3d_v7 import Generator 5 | # from training.EG3d_v12 import Generator 6 | # from training.EG3d_v3 import Generator 7 | import dnnlib 8 | import legacy 
9 | from torch_utils import misc 10 | import os 11 | from torch import nn 12 | from training.training_loop_eg3d import setup_snapshot_image_grid, save_image_grid 13 | import numpy as np 14 | import imageio 15 | from tqdm import tqdm 16 | import math 17 | 18 | def make_camera_trajectory(fps): 19 | yaws = torch.linspace(-30, 30, fps) 20 | # pitchs = torch.linspace(-5, 5, fps // 2) 21 | pitchs = torch.zeros_like(yaws) 22 | rolls = torch.zeros_like(pitchs) 23 | angles = torch.stack([pitchs, yaws, rolls], dim=-1) # fps/2, 3 24 | # angles = torch.cat(angles, angles[::-1, ...]) 25 | return angles 26 | 27 | 28 | def make_camera_circle_trajectory(num_samples): 29 | fps = num_samples 30 | max_pitch = 15 31 | max_yaw = 30 32 | pitch1 = torch.linspace(max_pitch, 0, fps // 4) 33 | yaws1 = torch.linspace(0, max_yaw, fps // 4) 34 | 35 | pitch2 = torch.linspace(0, -max_pitch, fps // 4) 36 | yaws2 = torch.linspace(max_yaw, 0, fps//4) 37 | 38 | pitch3 = torch.linspace(-max_pitch, 0, fps//4) 39 | yaws3 = torch.linspace(0, -max_yaw, fps//4) 40 | 41 | pitch4 = torch.linspace(0, max_pitch, fps // 4) 42 | yaws4 = torch.linspace(-max_yaw, 0, fps//4) 43 | 44 | pitch = torch.cat([pitch1, pitch2, pitch3, pitch4]) 45 | yaws = torch.cat([yaws1, yaws2, yaws3, yaws4]) 46 | angles = torch.stack( 47 | [ 48 | pitch, yaws, torch.zeros_like(yaws) 49 | ], dim=-1 50 | ) 51 | return angles 52 | 53 | def trans_to_img(img, drange): 54 | lo, hi = drange 55 | img = np.asarray(img, dtype=np.float32) 56 | img = (img - lo) * (255 / (hi - lo)) 57 | img = np.rint(img).clip(0, 255).astype(np.uint8) 58 | img = img.transpose(1, 2, 0) 59 | return img 60 | 61 | 62 | root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v12/00006-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-016329.pkl' 63 | save_root = './gen_video_examples/cicle' 64 | os.makedirs(save_root, exist_ok=True) 65 | fps = 30 * 4 66 | img_size = 256 67 | 68 | device = torch.device('cuda') 69 | G_kwargs = dnnlib.EasyDict( 70 | z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), 71 | # init_point_kwargs=dnnlib.EasyDict( 72 | # nerf_resolution=128, 73 | # fov=12, 74 | # d_range=(0.88, 1.12), 75 | # ), 76 | use_noise=False, # disable noise 77 | nerf_decoder_kwargs=dnnlib.EasyDict( 78 | in_c=32, 79 | mid_c=64, 80 | out_c=32, 81 | ) 82 | ) 83 | nerf_init_args = {} 84 | nerf_init_args['num_steps'] = 36 85 | nerf_init_args['img_size'] = 64 86 | nerf_init_args['fov'] = 12 87 | nerf_init_args['nerf_noise'] = 0.5 88 | nerf_init_args['ray_start'] = 0.88 89 | nerf_init_args['ray_end'] = 1.12 90 | 91 | G_kwargs.mapping_kwargs.num_layers = 8 92 | G_kwargs.fused_modconv_default = 'inference_only' 93 | G_kwargs.conv_clamp = None 94 | common_kwargs = dict(c_dim= 16, #12, 95 | img_resolution=256, 96 | img_channels= 96, 97 | backbone_resolution=128, 98 | rank=0, 99 | ) 100 | 101 | # G = Generator(**G_kwargs, **common_kwargs) 102 | # G.cuda() 103 | # G.eval() 104 | # G.requires_grad_(False) 105 | # if root: 106 | # print(f'Resuming from "{root}"') 107 | # cwd = os.getcwd() 108 | # os.chdir('./training') 109 | # with dnnlib.util.open_url(root) as f: 110 | # resume_data = legacy.load_network_pkl(f) 111 | # for name, module in [('G_ema', G)]: 112 | # misc.copy_params_and_buffers(resume_data[name], module, require_all=True) 113 | # os.chdir(cwd) 114 | 115 | with dnnlib.util.open_url(root) as f: 116 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 117 | G.eval() 118 | 119 | meta_data = {'noise_mode': 'const'} 120 | 121 | 122 | video_num = 9 123 | grid_z = 
torch.randn([video_num, G.z_dim], device=device) 124 | # angles = make_camera_trajectory(fps) 125 | angles = make_camera_circle_trajectory(num_samples=fps) 126 | angles = angles.cuda() 127 | print(angles) 128 | for idx in range(video_num): 129 | z = grid_z[idx:idx+1] 130 | z = z.expand(fps, -1) 131 | images = [] 132 | for f_idx in tqdm(range(fps)): 133 | zz = z[f_idx:f_idx+1] 134 | image = G(z=zz, angles=angles[f_idx:f_idx+1], nerf_init_args=nerf_init_args, **meta_data)[:, :3].squeeze() # 3,h,w 135 | image = image.cpu().numpy() 136 | image = trans_to_img(image, (-1, 1)) 137 | images.append(image) 138 | 139 | save_dir = os.path.join(save_root, f'{idx}') 140 | os.makedirs(save_dir, exist_ok=True) 141 | for f_idx in tqdm(range(fps)): 142 | image = images[f_idx] 143 | imageio.imwrite(os.path.join(save_dir, f'{f_idx}.png'), image) 144 | 145 | 146 | 147 | # exit() 148 | # make_video 149 | import cv2 150 | # select_idx = [1,3,4,6,8,9,10,15,17,18,19,20,22,23,25,27] 151 | select_idx = list(range(video_num)) 152 | n_row = int(math.sqrt(video_num)) 153 | n_col = int(math.sqrt(video_num)) 154 | 155 | fourcc = cv2.VideoWriter_fourcc(*"XVID") 156 | writer = cv2.VideoWriter( 157 | './video_example.mp4', fourcc, 158 | 30, (n_col*img_size, n_row*img_size)) 159 | for f_idx in range(fps): 160 | imgs = np.zeros(shape=(n_row*img_size, n_col*img_size, 3), dtype=np.uint8) 161 | for v_idx in range(len(select_idx)): 162 | img_path = os.path.join(save_root, f'{select_idx[v_idx]}', f'{f_idx}.png') 163 | img = cv2.imread(img_path) 164 | row = v_idx // n_col 165 | col = v_idx % n_col 166 | imgs[row*img_size:(row+1)*img_size, col*img_size:(col+1)*img_size] = img 167 | writer.write(imgs) 168 | writer.release() 169 | -------------------------------------------------------------------------------- /test_eg3d_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from training.training_loop_eg3d import save_image_grid 3 | # from training.EG3d_v14 import Generator 4 | from training.EG3d_v16 import Generator 5 | # from training.EG3d_v12 import Generator 6 | import dnnlib 7 | import legacy 8 | from torch_utils import misc 9 | import os 10 | from torch import nn 11 | from training.training_loop_eg3d import setup_snapshot_image_grid, save_image_grid 12 | import numpy as np 13 | torch.set_grad_enabled(False) 14 | from tqdm import tqdm 15 | import math 16 | import imageio 17 | 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | def make_camera_circle_trajectory(num_samples): 22 | fps = num_samples 23 | max_pitch = 15 24 | max_yaw = 30 25 | pitch1 = torch.linspace(max_pitch, 0, fps // 4) 26 | yaws1 = torch.linspace(0, max_yaw, fps // 4) 27 | 28 | pitch2 = torch.linspace(0, -max_pitch, fps // 4) 29 | yaws2 = torch.linspace(max_yaw, 0, fps//4) 30 | 31 | pitch3 = torch.linspace(-max_pitch, 0, fps//4) 32 | yaws3 = torch.linspace(0, -max_yaw, fps//4) 33 | 34 | pitch4 = torch.linspace(0, max_pitch, fps // 4) 35 | yaws4 = torch.linspace(-max_yaw, 0, fps//4) 36 | 37 | pitch = torch.cat([pitch1, pitch2, pitch3, pitch4]) 38 | yaws = torch.cat([yaws1, yaws2, yaws3, yaws4]) 39 | angles = torch.stack( 40 | [ 41 | pitch, yaws, torch.zeros_like(yaws) 42 | ], dim=-1 43 | ) 44 | return angles 45 | 46 | def trans_to_img(img, drange): 47 | lo, hi = drange 48 | img = np.asarray(img, dtype=np.float32) 49 | img = (img - lo) * (255 / (hi - lo)) 50 | img = np.rint(img).clip(0, 255).astype(np.uint8) 51 | img = img.transpose(1, 2, 0) 52 | return img 53 | 
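# Quick sanity check for trans_to_img (hypothetical values, shown only to
# illustrate the [-1, 1] float CHW -> [0, 255] uint8 HWC mapping above):
#
#   x = np.full((3, 4, 4), -1.0, dtype=np.float32)
#   y = trans_to_img(x, (-1, 1))
#   assert y.shape == (4, 4, 3) and y.max() == 0  # the lo end maps to 0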
54 | 55 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v12/00003-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-004032.pkl' 56 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v14/00006-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-004000.pkl' 57 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v14/00009-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-003600.pkl' 58 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v16/00004-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-002400.pkl' 59 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v12/00010-stylegan2-images1024x1024-gpus8-batch64-gamma1/network-snapshot-017539.pkl' 60 | # root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v14/00009-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-025000.pkl' 61 | root = '/home/yangjie08/stylegan3-main/training-runs/EG3d_v16/00004-stylegan2-images1024x1024-gpus8-batch32-gamma1/network-snapshot-012800.pkl' 62 | 63 | mode = 'video' # 'video' 'img' 64 | gen_num = 4 # img mode 65 | video_num = 4 # video mode 66 | fps = 30 * 4 67 | # hyper params 68 | img_size = 512 # 256 69 | nerf_init_args = {'img_size': 64} # 64 70 | 71 | # no need to modify these 72 | G_kwargs = dnnlib.EasyDict( 73 | z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), 74 | use_noise=False, # disable noise 75 | nerf_decoder_kwargs=dnnlib.EasyDict( 76 | in_c=32, 77 | mid_c=64, 78 | out_c=32, 79 | ) 80 | ) 81 | device = torch.device('cuda') 82 | G_kwargs.mapping_kwargs.num_layers = 8 83 | G_kwargs.fused_modconv_default = 'inference_only' 84 | G_kwargs.conv_clamp = None 85 | common_kwargs = dict(c_dim=16, #12, 86 | img_resolution=img_size, 87 | img_channels= 96, 88 | backbone_resolution=None, 89 | rank=0, 90 | ) 91 | G = Generator(**G_kwargs, **common_kwargs) 92 | G.cuda() 93 | G.eval() 94 | if root: 95 | print(f'Resuming from "{root}"') 96 | cwd = os.getcwd() 97 | os.chdir('./training') 98 | with dnnlib.util.open_url(root) as f: 99 | resume_data = legacy.load_network_pkl(f) 100 | for name, module in [('G_ema', G)]: 101 | misc.copy_params_and_buffers(resume_data[name], module, require_all=True) 102 | os.chdir(cwd) 103 | 104 | # with dnnlib.util.open_url(root) as f: 105 | # G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 106 | # G.eval() 107 | 108 | G.requires_grad_(False) 109 | meta_data = {'noise_mode': 'const',} # 'trans_x':torch.tensor([0.5, 0, 0]).to(device) } 110 | if mode == 'img': 111 | grid_size, grid_c = setup_snapshot_image_grid(gen_num, device=torch.device('cuda')) 112 | grid_z = torch.randn([gen_num, G.z_dim], device=device) 113 | h = w = img_size 114 | total_imgs = [] 115 | for c in grid_c: 116 | cond = torch.zeros_like(c) 117 | images = G(z=grid_z, angles=c, cond=cond, nerf_init_args=nerf_init_args, **meta_data).reshape(-1, 3, h, w) # b,3,h,w 118 | # images = G(z=grid_z, angles=c, nerf_init_args=nerf_init_args, **meta_data)[:, :3] # b,3,h,w 119 | res = images.shape[-1] 120 | # save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size) 121 | total_imgs.append(images) 122 | if images.shape[0] == 2 * gen_num: 123 | grid_size[0] = 2 * gen_num 124 | total_imgs = torch.stack(total_imgs, dim=1).reshape(grid_size[0]*grid_size[1], 3, res, res).cpu().numpy() # b*5, 3,h,w 125 | 126 | thre_imgs = np.where((total_imgs > 1) + (total_imgs < -1), 1, 0) # 1 where values overflow [-1, 1], 0 elsewhere 127 | 128 | save_image_grid(total_imgs, 'test_gen_images.png', drange=[-1,1], 
grid_size=grid_size) 129 | save_image_grid(thre_imgs, 'test_thre_images.png', drange=[-1,1], grid_size=grid_size) 130 | elif mode == 'video': 131 | save_root = './gen_video_examples/cicle' 132 | grid_z = torch.randn([video_num, G.z_dim], device=device) 133 | # angles = make_camera_trajectory(fps) 134 | angles = make_camera_circle_trajectory(num_samples=fps) 135 | angles = angles.cuda() 136 | for idx in range(video_num): 137 | z = grid_z[idx:idx+1] 138 | z = z.expand(fps, -1) 139 | images = [] 140 | for f_idx in tqdm(range(fps)): 141 | zz = z[f_idx:f_idx+1] 142 | cond = torch.zeros_like(angles[f_idx:f_idx+1]) 143 | # image = G(z=zz, angles=angles[f_idx:f_idx+1], nerf_init_args=nerf_init_args, **meta_data)[:, :3].squeeze() # 3,h,w 144 | image = G(z=zz, angles=angles[f_idx:f_idx+1], cond=cond, nerf_init_args=nerf_init_args, **meta_data)[:, :3].squeeze() # b,3,h,w 145 | image = image.cpu().numpy() 146 | image = trans_to_img(image, (-1, 1)) 147 | images.append(image) 148 | 149 | save_dir = os.path.join(save_root, f'{idx}') 150 | os.makedirs(save_dir, exist_ok=True) 151 | for f_idx in tqdm(range(fps)): # TODO: write frames with multiple processes 152 | image = images[f_idx] 153 | imageio.imwrite(os.path.join(save_dir, f'{f_idx}.png'), image) 154 | 155 | import cv2 156 | select_idx = list(range(video_num)) 157 | n_row = int(math.sqrt(video_num)) 158 | n_col = int(math.sqrt(video_num)) 159 | # print(n_row, n_col) 160 | # exit() 161 | fourcc = cv2.VideoWriter_fourcc(*"XVID") 162 | writer = cv2.VideoWriter( 163 | './video_example.mp4', fourcc, 164 | 30, (n_col*img_size, n_row*img_size)) 165 | for f_idx in range(fps): 166 | imgs = np.zeros(shape=(n_row*img_size, n_col*img_size, 3), dtype=np.uint8) 167 | for v_idx in range(len(select_idx)): 168 | img_path = os.path.join(save_root, f'{select_idx[v_idx]}', f'{f_idx}.png') 169 | img = cv2.imread(img_path) 170 | row = v_idx // n_col 171 | col = v_idx % n_col 172 | imgs[row*img_size:(row+1)*img_size, col*img_size:(col+1)*img_size] = img 173 | writer.write(imgs) 174 | writer.release() 175 | -------------------------------------------------------------------------------- /test_gen_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/test_gen_images.png -------------------------------------------------------------------------------- /test_thre_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bruinxiong/EG3D-pytorch/dd21da2aa73d8a8dcf248b33746779ed6182d314/test_thre_images.png -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. 
Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict.
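The md5 "trickery" above is the core of the incremental-build path: every input file is hashed into one combined digest, and that digest names the cached build directory, so unchanged sources map to a directory ninja has already built. A standalone sketch of just the digest step (the helper name is illustrative, not part of this file):

import hashlib

def combined_source_digest(paths):
    # One digest over the bytes of every input file, in sorted order,
    # mirroring how cached_build_dir is named in custom_ops.py.
    hash_md5 = hashlib.md5()
    for p in sorted(paths):
        with open(p, 'rb') as f:
            hash_md5.update(f.read())
    return hash_md5.hexdigest()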
150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ?
0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters.
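For orientation, the per-element math of the kernel above in the default StyleGAN case (act = lrelu, i.e. A == 3, forward pass G == 0) reduces to bias add, leaky ReLU, gain, and optional clamp. A PyTorch sketch of that reference behavior (math only, not the actual dispatch path; the helper name is illustrative):

import torch

def bias_act_lrelu_ref(x, b=None, dim=1, alpha=0.2, gain=1.0, clamp=None):
    # Equivalent of the A == 3, G == 0 branch: y = clamp(gain * lrelu(x + b)).
    if b is not None:
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    y = torch.nn.functional.leaky_relu(x, negative_slope=alpha) * gain
    if clamp is not None and clamp >= 0:
        y = y.clamp(-clamp, clamp)
    return y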
11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1).
60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 
132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity.
63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void); 88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
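Aside on the three translation units (_ns/_rd/_wr): they instantiate the same kernels for the sign-buffer modes, no signs (inference), sign read (reusing stored signs in the backward pass), and sign write (recording signs in the forward pass). Functionally the op is bias, upsample, leaky ReLU with gain (and optional clamp), downsample. A crude PyTorch sketch of that composition, ignoring the FIR filters and the sign tensor (bilinear interpolation and average pooling stand in for the real fu/fd filters):

import torch.nn.functional as F

def filtered_lrelu_sketch(x, b, up=2, down=2, slope=0.2, gain=2 ** 0.5, clamp=None):
    # Bias -> upsample -> leaky ReLU with gain (and optional clamp) -> downsample.
    x = x + b.reshape(1, -1, 1, 1)
    x = F.interpolate(x, scale_factor=up, mode='bilinear', align_corners=False)
    x = F.leaky_relu(x, negative_slope=slope) * gain
    if clamp is not None:
        x = x.clamp(-clamp, clamp)
    return F.avg_pool2d(x, down)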
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
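A quick usage check for fma() and its shape-restoring backward from fma.py above (shapes are arbitrary and broadcasting is exercised on purpose; the import path assumes the repo root is on PYTHONPATH):

import torch
from torch_utils.ops.fma import fma

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 5, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)
out = fma(a, b, c)                       # same values as a * b + c
assert torch.allclose(out, a * b + c)
out.sum().backward()
# _unbroadcast folds the (4, 5, 8) upstream gradient back to each operand's shape.
assert a.grad.shape == a.shape and b.grad.shape == b.shape and c.grad.shape == c.shape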
23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
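grid_sample_gradfix above matters when training differentiates through grid_sample twice, e.g. an R1-style gradient penalty on warped (ADA-augmented) images. A minimal second-order sketch against the PyTorch version this repo targets (the grid is built from a non-differentiable affine transform, matching the assert in the backward pass):

import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True  # opt in to the custom op
img = torch.randn(1, 3, 8, 8, requires_grad=True)
theta = torch.eye(2, 3).unsqueeze(0)  # identity warp, no grad w.r.t. the grid
grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)
out = grid_sample_gradfix.grid_sample(img, grid)
(g_img,) = torch.autograd.grad(out.square().sum(), img, create_graph=True)
g_img.square().sum().backward()  # second differentiation through grid_sample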
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr<float>(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size.
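The output-size formula above, outW = (inW * upx + padx0 + padx1 - filterW + downx) / downx, is easiest to trust with concrete numbers. A worked Python check for a typical 2x upsample with a 4-tap filter, using the padding adjustment conv2d_resample.py applies for up > 1:

# Worked example of the upfirdn2d output-size formula (values are illustrative).
in_w, up, down, fw = 64, 2, 1, 4
pad0 = (fw + up - 1) // 2            # = 2
pad1 = (fw - up) // 2                # = 1
out_w = (in_w * up + pad0 + pad1 - fw + down) // down
assert out_w == 128                  # 128 + 2 + 1 - 4 + 1 = 128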
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/cips_camera_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import math 4 | import logging 5 | from PIL import Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange, repeat 11 | 12 | 13 | def sample_pdf(bins, 14 | weights, 15 | N_importance, 16 | det=False, 17 | eps=1e-5): 18 | """ 19 | Sample @N_importance samples from @bins with distribution defined by @weights. 20 | Inputs: 21 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" 22 | weights: (N_rays, N_samples_) 23 | N_importance: the number of samples to draw from the distribution 24 | det: deterministic or not 25 | eps: a small number to prevent division by zero 26 | Outputs: 27 | samples: (N_rays, N_importance), the sampled samples 28 | Source: https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py 29 | """ 30 | N_rays, N_samples_ = weights.shape 31 | weights = weights + eps # prevent division by zero (don't do inplace op!) 32 | # (N_rays, N_samples_) 33 | pdf = weights / torch.sum(weights, -1, keepdim=True) 34 | # (N_rays, N_samples), cumulative distribution function 35 | cdf = torch.cumsum(pdf, -1) 36 | # (N_rays, N_samples_+1) 37 | cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1) 38 | # padded to 0~1 inclusive 39 | 40 | if det: 41 | u = torch.linspace(0, 1, N_importance, device=bins.device) 42 | u = u.expand(N_rays, N_importance) 43 | else: 44 | u = torch.rand(N_rays, N_importance, device=bins.device) 45 | u = u.contiguous() 46 | 47 | inds = torch.searchsorted(cdf, u) 48 | below = torch.clamp_min(inds - 1, 0) 49 | above = torch.clamp_max(inds, N_samples_) 50 | 51 | inds_sampled = torch.stack( 52 | [below, above], -1).view(N_rays, 2 * N_importance) 53 | cdf_g = torch.gather(cdf, 1, inds_sampled) 54 | cdf_g = cdf_g.view(N_rays, N_importance, 2) 55 | bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) 56 | 57 | denom = cdf_g[..., 1] - cdf_g[..., 0] 58 | # denom equals 0 means a bin has weight 0, in which case it will not be sampled 59 | denom[denom < eps] = 1 60 | # anyway, therefore any value for it is fine (set to 1 here) 61 | 62 | samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / \ 63 | denom * (bins_g[..., 1] - bins_g[..., 0]) 64 | return samples 65 | 66 | def fancy_integration(rgb_sigma, 67 | z_vals, 68 | device, 69 | noise_std=0.5, 70 | last_back=False, 71 | white_back=False, 72 | fill_mode=None): 73 | """ 74 | Performs NeRF volumetric rendering. 
75 | 76 | :param rgb_sigma: (b, h x w, num_samples, dim_rgb + dim_sigma) 77 | :param z_vals: (b, h x w, num_samples, 1) 78 | :param device: 79 | :param dim_rgb: rgb feature dim 80 | :param noise_std: 81 | :param last_back: 82 | :param white_back: 83 | :param clamp_mode: 84 | :param fill_mode: 85 | :return: 86 | - rgb_final: (b, h x w, dim_rgb) 87 | - depth_final: (b, h x w, 1) 88 | - weights: (b, h x w, num_samples, 1) 89 | """ 90 | 91 | rgbs = rgb_sigma[..., :-1] # (b, h x w, num_samples, 32) 92 | sigmas = rgb_sigma[..., -1:] # (b, h x w, num_samples, 1) 93 | 94 | # (b, h x w, num_samples - 1, 1) 95 | deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1] 96 | delta_inf = 1e10 * torch.ones_like(deltas[:, :, :1]) # (b, h x w, 1, 1) 97 | deltas = torch.cat([deltas, delta_inf], -2) # (b, h x w, num_samples, 1) 98 | 99 | noise = torch.randn(sigmas.shape, device=device) * \ 100 | noise_std # (b, h x w, num_samples, 1) 101 | 102 | # if clamp_mode == 'softplus': 103 | # alphas = 1 - torch.exp(-deltas * (F.softplus(sigmas + noise))) 104 | # elif clamp_mode == 'relu': 105 | # # (b, h x w, num_samples, 1) 106 | # print(deltas.shape, sigmas.shape, noise.shape) 107 | alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise))) 108 | # else: 109 | # assert 0, "Need to choose clamp mode" 110 | 111 | alphas_shifted = torch.cat([torch.ones_like( 112 | alphas[:, :, :1]), 1 - alphas + 1e-10], -2) # (b, h x w, num_samples + 1, 1) 113 | # (b, h x w, num_samples, 1) 114 | weights = alphas * torch.cumprod(alphas_shifted, -2)[:, :, :-1] 115 | weights_sum = weights.sum(2) 116 | 117 | if last_back: 118 | weights[:, :, -1] += (1 - weights_sum) 119 | 120 | rgb_final = torch.sum(weights * rgbs, -2) # (b, h x w, num_samples, 3) 121 | depth_final = torch.sum(weights * z_vals, -2) # (b, h x w, num_samples, 1) 122 | 123 | if white_back: 124 | rgb_final = rgb_final + 1 - weights_sum 125 | 126 | if fill_mode == 'debug': 127 | rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor( 128 | [1., 0, 0], device=rgb_final.device) 129 | elif fill_mode == 'weight': 130 | rgb_final = weights_sum.expand_as(rgb_final) 131 | 132 | return rgb_final, depth_final, weights 133 | 134 | 135 | @torch.no_grad() 136 | def get_fine_points_and_direction( 137 | coarse_output, 138 | z_vals, 139 | nerf_noise, 140 | num_steps, 141 | transformed_ray_origins, 142 | transformed_ray_directions, 143 | device, 144 | ): 145 | """ 146 | 147 | :param coarse_output: (b, h x w, num_samples, rgb_sigma) 148 | :param z_vals: (b, h x w, num_samples, 1) 149 | :param clamp_mode: 150 | :param nerf_noise: 151 | :param num_steps: 152 | :param transformed_ray_origins: (b, h x w, 3) 153 | :param transformed_ray_directions: (b, h x w, 3) 154 | :return: 155 | - fine_points: (b, h x w x num_steps, 3) 156 | - fine_z_vals: (b, h x w, num_steps, 1) 157 | """ 158 | 159 | batch_size = coarse_output.shape[0] 160 | 161 | _, _, weights = fancy_integration( 162 | rgb_sigma=coarse_output, 163 | z_vals=z_vals, 164 | device=device, 165 | # clamp_mode=clamp_mode, 166 | noise_std=nerf_noise) 167 | 168 | # weights = weights.reshape(batch_size * img_size * img_size, num_steps) + 1e-5 169 | weights = rearrange(weights, "b hw s 1 -> (b hw) s") + 1e-5 170 | 171 | # Start new importance sampling 172 | # z_vals = z_vals.reshape(batch_size * img_size * img_size, num_steps) 173 | z_vals = rearrange(z_vals, "b hw s 1 -> (b hw) s") 174 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) 175 | # z_vals = z_vals.reshape(batch_size, img_size * img_size, num_steps, 1) 176 | # z_vals = rearrange(z_vals, "(b hw) 
s -> b hw s 1", b=batch_size) 177 | fine_z_vals = sample_pdf(bins=z_vals_mid, 178 | weights=weights[:, 1:-1], 179 | N_importance=num_steps, 180 | det=False).detach() 181 | # fine_z_vals = fine_z_vals.reshape(batch_size, img_size * img_size, num_steps, 1) 182 | fine_z_vals = rearrange(fine_z_vals, "(b hw) s -> b hw s 1", b=batch_size) 183 | 184 | fine_points = transformed_ray_origins.unsqueeze(2).contiguous() + \ 185 | transformed_ray_directions.unsqueeze(2).contiguous() * \ 186 | fine_z_vals.expand(-1, -1, -1, 3).contiguous() 187 | # fine_points = fine_points.reshape(batch_size, img_size * img_size * num_steps, 3) 188 | # fine_points = rearrange(fine_points, "b hw s c -> b (hw s) c") 189 | 190 | # if lock_view_dependence: 191 | # transformed_ray_directions_expanded = torch.zeros_like(transformed_ray_directions_expanded) 192 | # transformed_ray_directions_expanded[..., -1] = -1 193 | # end new importance sampling 194 | return fine_points, fine_z_vals 195 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Streaming images and labels from datasets created with dataset_tool.py.""" 10 | 11 | import os 12 | import numpy as np 13 | import zipfile 14 | import PIL.Image 15 | import json 16 | import torch 17 | import dnnlib 18 | 19 | try: 20 | import pyspng 21 | except ImportError: 22 | pyspng = None 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | ): 35 | self._name = name 36 | self._raw_shape = list(raw_shape) 37 | self._use_labels = use_labels 38 | self._raw_labels = None 39 | self._label_shape = None 40 | 41 | # Apply max_size. 42 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 43 | if (max_size is not None) and (self._raw_idx.size > max_size): 44 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 45 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 46 | 47 | # Apply xflip. 
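get_fine_points_and_direction above is the standard NeRF coarse-to-fine step: integrate the coarse pass with fancy_integration to get per-sample weights, then draw new depths from those weights with sample_pdf and re-project them onto the rays. A shape-level sketch of a call (all tensor sizes are illustrative; random inputs stand in for a real coarse pass):

import torch
import torch.nn.functional as F

b, hw, n_steps = 2, 64 * 64, 12
coarse_rgb_sigma = torch.rand(b, hw, n_steps, 33)  # 32 feature channels + 1 sigma
z_vals = torch.linspace(0.88, 1.12, n_steps).view(1, 1, n_steps, 1).repeat(b, hw, 1, 1)
ray_o = torch.zeros(b, hw, 3)
ray_d = F.normalize(torch.rand(b, hw, 3), dim=-1)
fine_points, fine_z_vals = get_fine_points_and_direction(
    coarse_output=coarse_rgb_sigma, z_vals=z_vals, nerf_noise=0.0,
    num_steps=n_steps, transformed_ray_origins=ray_o,
    transformed_ray_directions=ray_d, device='cpu')
# fine_points: (b, hw, n_steps, 3) -- evaluated by the fine pass, then merged
# with the coarse samples (sorted by depth) before the final integration.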
48 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 49 | if xflip: 50 | self._raw_idx = np.tile(self._raw_idx, 2) 51 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 52 | 53 | def _get_raw_labels(self): 54 | if self._raw_labels is None: 55 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 56 | if self._raw_labels is None: 57 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 58 | assert isinstance(self._raw_labels, np.ndarray) 59 | assert self._raw_labels.shape[0] == self._raw_shape[0] 60 | assert self._raw_labels.dtype in [np.float32, np.int64] 61 | if self._raw_labels.dtype == np.int64: 62 | assert self._raw_labels.ndim == 1 63 | assert np.all(self._raw_labels >= 0) 64 | return self._raw_labels 65 | 66 | def close(self): # to be overridden by subclass 67 | pass 68 | 69 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 70 | raise NotImplementedError 71 | 72 | def _load_raw_labels(self): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def __getstate__(self): 76 | return dict(self.__dict__, _raw_labels=None) 77 | 78 | def __del__(self): 79 | try: 80 | self.close() 81 | except: 82 | pass 83 | 84 | def __len__(self): 85 | return self._raw_idx.size 86 | 87 | def __getitem__(self, idx): 88 | image = self._load_raw_image(self._raw_idx[idx]) 89 | assert isinstance(image, np.ndarray) 90 | assert list(image.shape) == self.image_shape 91 | assert image.dtype == np.uint8 92 | if self._xflip[idx]: 93 | assert image.ndim == 3 # CHW 94 | image = image[:, :, ::-1] 95 | return image.copy(), self.get_label(idx) 96 | 97 | def get_label(self, idx): 98 | label = self._get_raw_labels()[self._raw_idx[idx]] 99 | if label.dtype == np.int64: 100 | onehot = np.zeros(self.label_shape, dtype=np.float32) 101 | onehot[label] = 1 102 | label = onehot 103 | return label.copy() 104 | 105 | def get_details(self, idx): 106 | d = dnnlib.EasyDict() 107 | d.raw_idx = int(self._raw_idx[idx]) 108 | d.xflip = (int(self._xflip[idx]) != 0) 109 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 110 | return d 111 | 112 | @property 113 | def name(self): 114 | return self._name 115 | 116 | @property 117 | def image_shape(self): 118 | return list(self._raw_shape[1:]) 119 | 120 | @property 121 | def num_channels(self): 122 | assert len(self.image_shape) == 3 # CHW 123 | return self.image_shape[0] 124 | 125 | @property 126 | def resolution(self): 127 | assert len(self.image_shape) == 3 # CHW 128 | assert self.image_shape[1] == self.image_shape[2] 129 | return self.image_shape[1] 130 | 131 | @property 132 | def label_shape(self): 133 | if self._label_shape is None: 134 | raw_labels = self._get_raw_labels() 135 | if raw_labels.dtype == np.int64: 136 | self._label_shape = [int(np.max(raw_labels)) + 1] 137 | else: 138 | self._label_shape = raw_labels.shape[1:] 139 | return list(self._label_shape) 140 | 141 | @property 142 | def label_dim(self): 143 | assert len(self.label_shape) == 1 144 | return self.label_shape[0] 145 | 146 | @property 147 | def has_labels(self): 148 | return any(x != 0 for x in self.label_shape) 149 | 150 | @property 151 | def has_onehot_labels(self): 152 | return self._get_raw_labels().dtype == np.int64 153 | 154 | #---------------------------------------------------------------------------- 155 | 156 | class ImageFolderDataset(Dataset): 157 | def __init__(self, 158 | path, # Path to directory or zip. 159 | resolution = None, # Ensure specific resolution, None = highest available. 
160 | **super_kwargs, # Additional arguments for the Dataset base class. 161 | ): 162 | self._path = path 163 | self._zipfile = None 164 | 165 | if os.path.isdir(self._path): 166 | self._type = 'dir' 167 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 168 | elif self._file_ext(self._path) == '.zip': 169 | self._type = 'zip' 170 | self._all_fnames = set(self._get_zipfile().namelist()) 171 | else: 172 | raise IOError('Path must point to a directory or zip') 173 | 174 | PIL.Image.init() 175 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 176 | if len(self._image_fnames) == 0: 177 | raise IOError('No image files found in the specified path') 178 | 179 | name = os.path.splitext(os.path.basename(self._path))[0] 180 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 181 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 182 | raise IOError('Image files do not match the specified resolution') 183 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 184 | 185 | @staticmethod 186 | def _file_ext(fname): 187 | return os.path.splitext(fname)[1].lower() 188 | 189 | def _get_zipfile(self): 190 | assert self._type == 'zip' 191 | if self._zipfile is None: 192 | self._zipfile = zipfile.ZipFile(self._path) 193 | return self._zipfile 194 | 195 | def _open_file(self, fname): 196 | if self._type == 'dir': 197 | return open(os.path.join(self._path, fname), 'rb') 198 | if self._type == 'zip': 199 | return self._get_zipfile().open(fname, 'r') 200 | return None 201 | 202 | def close(self): 203 | try: 204 | if self._zipfile is not None: 205 | self._zipfile.close() 206 | finally: 207 | self._zipfile = None 208 | 209 | def __getstate__(self): 210 | return dict(super().__getstate__(), _zipfile=None) 211 | 212 | def _load_raw_image(self, raw_idx): 213 | fname = self._image_fnames[raw_idx] 214 | with self._open_file(fname) as f: 215 | if pyspng is not None and self._file_ext(fname) == '.png': 216 | image = pyspng.load(f.read()) 217 | else: 218 | image = np.array(PIL.Image.open(f)) 219 | if image.ndim == 2: 220 | image = image[:, :, np.newaxis] # HW => HWC 221 | image = image.transpose(2, 0, 1) # HWC => CHW 222 | return image 223 | 224 | def _load_raw_labels(self): 225 | fname = 'dataset.json' 226 | if fname not in self._all_fnames: 227 | return None 228 | with self._open_file(fname) as f: 229 | labels = json.load(f)['labels'] 230 | if labels is None: 231 | return None 232 | labels = dict(labels) 233 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 234 | labels = np.array(labels) 235 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 236 | return labels 237 | 238 | #---------------------------------------------------------------------------- 239 | -------------------------------------------------------------------------------- /viz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /viz/capture_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import re 11 | import numpy as np 12 | import imgui 13 | import PIL.Image 14 | from gui_utils import imgui_utils 15 | from . import renderer 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class CaptureWidget: 20 | def __init__(self, viz): 21 | self.viz = viz 22 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) 23 | self.dump_image = False 24 | self.dump_gui = False 25 | self.defer_frames = 0 26 | self.disabled_time = 0 27 | 28 | def dump_png(self, image): 29 | viz = self.viz 30 | try: 31 | _height, _width, channels = image.shape 32 | assert channels in [1, 3] 33 | assert image.dtype == np.uint8 34 | os.makedirs(self.path, exist_ok=True) 35 | file_id = 0 36 | for entry in os.scandir(self.path): 37 | if entry.is_file(): 38 | match = re.fullmatch(r'(\d+).*', entry.name) 39 | if match: 40 | file_id = max(file_id, int(match.group(1)) + 1) 41 | if channels == 1: 42 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') 43 | else: 44 | pil_image = PIL.Image.fromarray(image, 'RGB') 45 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) 46 | except: 47 | viz.result.error = renderer.CapturedException() 48 | 49 | @imgui_utils.scoped_by_object_id 50 | def __call__(self, show=True): 51 | viz = self.viz 52 | if show: 53 | with imgui_utils.grayed_out(self.disabled_time != 0): 54 | imgui.text('Capture') 55 | imgui.same_line(viz.label_w) 56 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, 57 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 58 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 59 | help_text='PATH') 60 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': 61 | imgui.set_tooltip(self.path) 62 | imgui.same_line() 63 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): 64 | self.dump_image = True 65 | self.defer_frames = 2 66 | self.disabled_time = 0.5 67 | imgui.same_line() 68 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): 69 | self.dump_gui = True 70 | self.defer_frames = 2 71 | self.disabled_time = 0.5 72 | 73 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) 74 | if self.defer_frames > 0: 75 | self.defer_frames -= 1 76 | elif self.dump_image: 77 | if 'image' in viz.result: 78 | self.dump_png(viz.result.image) 79 | self.dump_image = False 80 | elif self.dump_gui: 81 | viz.capture_next_frame() 82 | self.dump_gui = False 83 | captured_frame = viz.pop_captured_frame() 84 | if captured_frame is not None: 85 | 
self.dump_png(captured_frame) 86 | 87 | #---------------------------------------------------------------------------- 88 | -------------------------------------------------------------------------------- /viz/equivariance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class EquivarianceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2) 20 | self.xlate_def = dnnlib.EasyDict(self.xlate) 21 | self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3) 22 | self.rotate_def = dnnlib.EasyDict(self.rotate) 23 | self.opts = dnnlib.EasyDict(untransform=False) 24 | self.opts_def = dnnlib.EasyDict(self.opts) 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | if show: 30 | imgui.text('Translate') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8): 33 | _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f') 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w) 36 | if dragging: 37 | self.xlate.x += dx / viz.font_size * 2e-2 38 | self.xlate.y += dy / viz.font_size * 2e-2 39 | imgui.same_line() 40 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w) 41 | if dragging: 42 | self.xlate.x += dx / viz.font_size * 4e-4 43 | self.xlate.y += dy / viz.font_size * 4e-4 44 | imgui.same_line() 45 | _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim) 46 | imgui.same_line() 47 | _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round) 48 | imgui.same_line() 49 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim): 50 | changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5) 51 | if changed: 52 | self.xlate.speed = speed 53 | imgui.same_line() 54 | if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)): 55 | self.xlate = dnnlib.EasyDict(self.xlate_def) 56 | 57 | if show: 58 | imgui.text('Rotate') 59 | imgui.same_line(viz.label_w) 60 | with imgui_utils.item_width(viz.font_size * 8): 61 | _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f') 62 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 63 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w) 64 | if dragging: 65 | self.rotate.val += dx / viz.font_size * 2e-2 66 | imgui.same_line() 67 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w) 68 | if dragging: 69 | 
self.rotate.val += dx / viz.font_size * 4e-4 70 | imgui.same_line() 71 | _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim) 72 | imgui.same_line() 73 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.rotate.anim): 74 | changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3) 75 | if changed: 76 | self.rotate.speed = speed 77 | imgui.same_line() 78 | if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)): 79 | self.rotate = dnnlib.EasyDict(self.rotate_def) 80 | 81 | if show: 82 | imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16) 83 | _clicked, self.opts.untransform = imgui.checkbox('Untransform', self.opts.untransform) 84 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 85 | if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)): 86 | self.opts = dnnlib.EasyDict(self.opts_def) 87 | 88 | if self.xlate.anim: 89 | c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 90 | t = c.copy() 91 | if np.max(np.abs(t)) < 1e-4: 92 | t += 1 93 | t *= 0.1 / np.hypot(*t) 94 | t += c[::-1] * [1, -1] 95 | d = t - c 96 | d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d) 97 | self.xlate.x += d[0] 98 | self.xlate.y += d[1] 99 | 100 | if self.rotate.anim: 101 | self.rotate.val += viz.frame_delta * self.rotate.speed 102 | 103 | pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 104 | if self.xlate.round and 'img_resolution' in viz.result: 105 | pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution 106 | angle = self.rotate.val * np.pi * 2 107 | 108 | viz.args.input_transform = [ 109 | [np.cos(angle), np.sin(angle), pos[0]], 110 | [-np.sin(angle), np.cos(angle), pos[1]], 111 | [0, 0, 1]] 112 | 113 | viz.args.update(untransform=self.opts.untransform) 114 | 115 | #---------------------------------------------------------------------------- 116 | -------------------------------------------------------------------------------- /viz/latent_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
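   | # Note (annotation): the widget below maps a fractional latent position
   | # (x, y) onto a grid of seeds, seed = int(x) + int(y) * step_y (mod 2**32),
   | # and blends the up-to-four surrounding grid points with bilinear weights
   | # (1 - |x - seed_x|) * (1 - |y - seed_y|); the resulting [seed, weight]
   | # pairs are handed to the renderer via viz.args.w0_seeds.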
8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class LatentWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25) 20 | self.latent_def = dnnlib.EasyDict(self.latent) 21 | self.step_y = 100 22 | 23 | def drag(self, dx, dy): 24 | viz = self.viz 25 | self.latent.x += dx / viz.font_size * 4e-2 26 | self.latent.y += dy / viz.font_size * 4e-2 27 | 28 | @imgui_utils.scoped_by_object_id 29 | def __call__(self, show=True): 30 | viz = self.viz 31 | if show: 32 | imgui.text('Latent') 33 | imgui.same_line(viz.label_w) 34 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y 35 | with imgui_utils.item_width(viz.font_size * 8): 36 | changed, seed = imgui.input_int('##seed', seed) 37 | if changed: 38 | self.latent.x = seed 39 | self.latent.y = 0 40 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 41 | frac_x = self.latent.x - round(self.latent.x) 42 | frac_y = self.latent.y - round(self.latent.y) 43 | with imgui_utils.item_width(viz.font_size * 5): 44 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 45 | if changed: 46 | self.latent.x += new_frac_x - frac_x 47 | self.latent.y += new_frac_y - frac_y 48 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 49 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 50 | if dragging: 51 | self.drag(dx, dy) 52 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 53 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) 54 | imgui.same_line(round(viz.font_size * 27.7)) 55 | with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): 56 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) 57 | if changed: 58 | self.latent.speed = speed 59 | imgui.same_line() 60 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) 61 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): 62 | self.latent = snapped 63 | imgui.same_line() 64 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): 65 | self.latent = dnnlib.EasyDict(self.latent_def) 66 | 67 | if self.latent.anim: 68 | self.latent.x += viz.frame_delta * self.latent.speed 69 | viz.args.w0_seeds = [] # [[seed, weight], ...] 70 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: 71 | seed_x = np.floor(self.latent.x) + ofs_x 72 | seed_y = np.floor(self.latent.y) + ofs_y 73 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) 74 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) 75 | if weight > 0: 76 | viz.args.w0_seeds.append([seed, weight]) 77 | 78 | #---------------------------------------------------------------------------- 79 | -------------------------------------------------------------------------------- /viz/layer_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class LayerWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.prev_layers = None 18 | self.cur_layer = None 19 | self.sel_channels = 3 20 | self.base_channel = 0 21 | self.img_scale_db = 0 22 | self.img_normalize = False 23 | self.fft_show = False 24 | self.fft_all = True 25 | self.fft_range_db = 50 26 | self.fft_beta = 8 27 | self.refocus = False 28 | 29 | @imgui_utils.scoped_by_object_id 30 | def __call__(self, show=True): 31 | viz = self.viz 32 | layers = viz.result.get('layers', []) 33 | if self.prev_layers != layers: 34 | self.prev_layers = layers 35 | self.refocus = True 36 | layer = ([layer for layer in layers if layer.name == self.cur_layer] + [None])[0] 37 | if layer is None and len(layers) > 0: 38 | layer = layers[-1] 39 | self.cur_layer = layer.name 40 | num_channels = layer.shape[1] if layer is not None else 0 41 | base_channel_max = max(num_channels - self.sel_channels, 0) 42 | 43 | if show: 44 | bg_color = [0.16, 0.29, 0.48, 0.2] 45 | dim_color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 46 | dim_color[-1] *= 0.5 47 | 48 | # Begin list. 49 | width = viz.font_size * 28 50 | height = imgui.get_text_line_height_with_spacing() * 12 + viz.spacing 51 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 52 | imgui.push_style_color(imgui.COLOR_CHILD_BACKGROUND, *bg_color) 53 | imgui.push_style_color(imgui.COLOR_HEADER, 0, 0, 0, 0) 54 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, 0.16, 0.29, 0.48, 0.5) 55 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, 0.16, 0.29, 0.48, 0.9) 56 | imgui.begin_child('##list', width=width, height=height, border=True, flags=imgui.WINDOW_ALWAYS_VERTICAL_SCROLLBAR) 57 | 58 | # List items. 59 | for layer in layers: 60 | selected = (self.cur_layer == layer.name) 61 | _opened, selected = imgui.selectable(f'##{layer.name}_selectable', selected) 62 | imgui.same_line(viz.spacing) 63 | _clicked, selected = imgui.checkbox(f'{layer.name}##radio', selected) 64 | if selected: 65 | self.cur_layer = layer.name 66 | if self.refocus: 67 | imgui.set_scroll_here() 68 | viz.skip_frame() # Focus will change on next frame. 69 | self.refocus = False 70 | imgui.same_line(width - viz.font_size * 13) 71 | imgui.text_colored('x'.join(str(x) for x in layer.shape[2:]), *dim_color) 72 | imgui.same_line(width - viz.font_size * 8) 73 | imgui.text_colored(str(layer.shape[1]), *dim_color) 74 | imgui.same_line(width - viz.font_size * 5) 75 | imgui.text_colored(layer.dtype, *dim_color) 76 | 77 | # End list. 78 | if len(layers) == 0: 79 | imgui.text_colored('No layers found', *dim_color) 80 | imgui.end_child() 81 | imgui.pop_style_color(4) 82 | imgui.pop_style_var(1) 83 | 84 | # Begin options. 85 | imgui.same_line() 86 | imgui.begin_child('##options', width=-1, height=height, border=False) 87 | 88 | # RGB & normalize. 
89 | rgb = (self.sel_channels == 3) 90 | _clicked, rgb = imgui.checkbox('RGB', rgb) 91 | self.sel_channels = 3 if rgb else 1 92 | imgui.same_line(viz.font_size * 4) 93 | _clicked, self.img_normalize = imgui.checkbox('Normalize', self.img_normalize) 94 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 95 | if imgui_utils.button('Reset##img_flags', width=-1, enabled=(self.sel_channels != 3 or self.img_normalize)): 96 | self.sel_channels = 3 97 | self.img_normalize = False 98 | 99 | # Image scale. 100 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing): 101 | _changed, self.img_scale_db = imgui.slider_float('##scale', self.img_scale_db, min_value=-40, max_value=40, format='Scale %+.1f dB') 102 | imgui.same_line() 103 | if imgui_utils.button('Reset##scale', width=-1, enabled=(self.img_scale_db != 0)): 104 | self.img_scale_db = 0 105 | 106 | # Base channel. 107 | self.base_channel = min(max(self.base_channel, 0), base_channel_max) 108 | narrow_w = imgui.get_text_line_height_with_spacing() 109 | with imgui_utils.grayed_out(base_channel_max == 0): 110 | with imgui_utils.item_width(-1 - viz.button_w - narrow_w * 2 - viz.spacing * 3): 111 | _changed, self.base_channel = imgui.drag_int('##channel', self.base_channel, change_speed=0.05, min_value=0, max_value=base_channel_max, format=f'Channel %d/{num_channels}') 112 | imgui.same_line() 113 | if imgui_utils.button('-##channel', width=narrow_w): 114 | self.base_channel -= 1 115 | imgui.same_line() 116 | if imgui_utils.button('+##channel', width=narrow_w): 117 | self.base_channel += 1 118 | imgui.same_line() 119 | self.base_channel = min(max(self.base_channel, 0), base_channel_max) 120 | if imgui_utils.button('Reset##channel', width=-1, enabled=(self.base_channel != 0 and base_channel_max > 0)): 121 | self.base_channel = 0 122 | 123 | # Stats. 124 | stats = viz.result.get('stats', None) 125 | stats = [f'{stats[idx]:g}' if stats is not None else 'N/A' for idx in range(6)] 126 | rows = [ 127 | ['Statistic', 'All channels', 'Selected'], 128 | ['Mean', stats[0], stats[1]], 129 | ['Std', stats[2], stats[3]], 130 | ['Max', stats[4], stats[5]], 131 | ] 132 | height = imgui.get_text_line_height_with_spacing() * len(rows) + viz.spacing 133 | imgui.push_style_color(imgui.COLOR_CHILD_BACKGROUND, *bg_color) 134 | imgui.begin_child('##stats', width=-1, height=height, border=True) 135 | for y, cols in enumerate(rows): 136 | for x, col in enumerate(cols): 137 | if x != 0: 138 | imgui.same_line(viz.font_size * (4 + (x - 1) * 6)) 139 | if x == 0 or y == 0: 140 | imgui.text_colored(col, *dim_color) 141 | else: 142 | imgui.text(col) 143 | imgui.end_child() 144 | imgui.pop_style_color(1) 145 | 146 | # FFT & all. 147 | _clicked, self.fft_show = imgui.checkbox('FFT', self.fft_show) 148 | imgui.same_line(viz.font_size * 4) 149 | with imgui_utils.grayed_out(not self.fft_show or base_channel_max == 0): 150 | _clicked, self.fft_all = imgui.checkbox('All channels', self.fft_all) 151 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 152 | with imgui_utils.grayed_out(not self.fft_show): 153 | if imgui_utils.button('Reset##fft_flags', width=-1, enabled=(self.fft_show or not self.fft_all)): 154 | self.fft_show = False 155 | self.fft_all = True 156 | 157 | # FFT range. 
158 | with imgui_utils.grayed_out(not self.fft_show): 159 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing): 160 | _changed, self.fft_range_db = imgui.slider_float('##fft_range_db', self.fft_range_db, min_value=0.1, max_value=100, format='Range +-%.1f dB') 161 | imgui.same_line() 162 | if imgui_utils.button('Reset##fft_range_db', width=-1, enabled=(self.fft_range_db != 50)): 163 | self.fft_range_db = 50 164 | 165 | # FFT beta. 166 | with imgui_utils.grayed_out(not self.fft_show): 167 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing): 168 | _changed, self.fft_beta = imgui.slider_float('##fft_beta', self.fft_beta, min_value=0, max_value=50, format='Kaiser beta %.2f', power=2.63) 169 | imgui.same_line() 170 | if imgui_utils.button('Reset##fft_beta', width=-1, enabled=(self.fft_beta != 8)): 171 | self.fft_beta = 8 172 | 173 | # End options. 174 | imgui.end_child() 175 | 176 | self.base_channel = min(max(self.base_channel, 0), base_channel_max) 177 | viz.args.layer_name = self.cur_layer if len(layers) > 0 and self.cur_layer != layers[-1].name else None 178 | viz.args.update(sel_channels=self.sel_channels, base_channel=self.base_channel, img_scale_db=self.img_scale_db, img_normalize=self.img_normalize) 179 | viz.args.fft_show = self.fft_show 180 | if self.fft_show: 181 | viz.args.update(fft_all=self.fft_all, fft_range_db=self.fft_range_db, fft_beta=self.fft_beta) 182 | 183 | #---------------------------------------------------------------------------- 184 | -------------------------------------------------------------------------------- /viz/performance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
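   | # Note (annotation): the widget below keeps rolling windows of the last
   | # 60 GUI frame times and 30 render times; the displayed figures are the
   | # mean t of the positive samples, shown as t*1e3 ms and 1/t FPS.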
8 | 9 | import array 10 | import numpy as np 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class PerformanceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.gui_times = [float('nan')] * 60 20 | self.render_times = [float('nan')] * 30 21 | self.fps_limit = 60 22 | self.use_vsync = False 23 | self.is_async = False 24 | self.force_fp32 = False 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | self.gui_times = self.gui_times[1:] + [viz.frame_delta] 30 | if 'render_time' in viz.result: 31 | self.render_times = self.render_times[1:] + [viz.result.render_time] 32 | del viz.result.render_time 33 | 34 | if show: 35 | imgui.text('GUI') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 8): 38 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) 39 | imgui.same_line(viz.label_w + viz.font_size * 9) 40 | t = [x for x in self.gui_times if x > 0] 41 | t = np.mean(t) if len(t) > 0 else 0 42 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 43 | imgui.same_line(viz.label_w + viz.font_size * 14) 44 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 45 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 46 | with imgui_utils.item_width(viz.font_size * 6): 47 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 48 | self.fps_limit = min(max(self.fps_limit, 5), 1000) 49 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 50 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) 51 | 52 | if show: 53 | imgui.text('Render') 54 | imgui.same_line(viz.label_w) 55 | with imgui_utils.item_width(viz.font_size * 8): 56 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) 57 | imgui.same_line(viz.label_w + viz.font_size * 9) 58 | t = [x for x in self.render_times if x > 0] 59 | t = np.mean(t) if len(t) > 0 else 0 60 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 61 | imgui.same_line(viz.label_w + viz.font_size * 14) 62 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 63 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 64 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) 65 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 66 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) 67 | 68 | viz.set_fps_limit(self.fps_limit) 69 | viz.set_vsync(self.use_vsync) 70 | viz.set_async(self.is_async) 71 | viz.args.force_fp32 = self.force_fp32 72 | 73 | #---------------------------------------------------------------------------- 74 | -------------------------------------------------------------------------------- /viz/pickle_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
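   | # Note (annotation): resolve_pkl() below returns URLs unchanged, while a
   | # local run directory is resolved to its newest network-snapshot-*.pkl.
   | # Hypothetical example (names illustrative only):
   | #   resolve_pkl('training-runs/00012-mydataset')
   | #   -> '.../training-runs/00012-mydataset/network-snapshot-025000.pkl'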
8 | 
9 | import glob
10 | import os
11 | import re
12 | 
13 | import dnnlib
14 | import imgui
15 | import numpy as np
16 | from gui_utils import imgui_utils
17 | 
18 | from . import renderer
19 | 
20 | #----------------------------------------------------------------------------
21 | 
22 | def _locate_results(pattern):
23 |     return pattern
24 | 
25 | #----------------------------------------------------------------------------
26 | 
27 | class PickleWidget:
28 |     def __init__(self, viz):
29 |         self.viz = viz
30 |         self.search_dirs = []
31 |         self.cur_pkl = None
32 |         self.user_pkl = ''
33 |         self.recent_pkls = []
34 |         self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...}
35 |         self.browse_refocus = False
36 |         self.load('', ignore_errors=True)
37 | 
38 |     def add_recent(self, pkl, ignore_errors=False):
39 |         try:
40 |             resolved = self.resolve_pkl(pkl)
41 |             if resolved not in self.recent_pkls:
42 |                 self.recent_pkls.append(resolved)
43 |         except:
44 |             if not ignore_errors:
45 |                 raise
46 | 
47 |     def load(self, pkl, ignore_errors=False):
48 |         viz = self.viz
49 |         viz.clear_result()
50 |         viz.skip_frame() # The input field will change on next frame.
51 |         try:
52 |             resolved = self.resolve_pkl(pkl)
53 |             name = resolved.replace('\\', '/').split('/')[-1]
54 |             self.cur_pkl = resolved
55 |             self.user_pkl = resolved
56 |             viz.result.message = f'Loading {name}...'
57 |             viz.defer_rendering()
58 |             if resolved in self.recent_pkls:
59 |                 self.recent_pkls.remove(resolved)
60 |             self.recent_pkls.insert(0, resolved)
61 |         except:
62 |             self.cur_pkl = None
63 |             self.user_pkl = pkl
64 |             if pkl == '':
65 |                 viz.result = dnnlib.EasyDict(message='No network pickle loaded')
66 |             else:
67 |                 viz.result = dnnlib.EasyDict(error=renderer.CapturedException())
68 |             if not ignore_errors:
69 |                 raise
70 | 
71 |     @imgui_utils.scoped_by_object_id
72 |     def __call__(self, show=True):
73 |         viz = self.viz
74 |         recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl]
75 |         if show:
76 |             imgui.text('Pickle')
77 |             imgui.same_line(viz.label_w)
78 |             changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024,
79 |                 flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
80 |                 width=(-1 - viz.button_w * 2 - viz.spacing * 2),
81 |                 help_text='<PATH> | <URL> | <RUN_DIR> | <RUN_ID> | <RUN_ID>/<KIMG>.pkl')
82 |             if changed:
83 |                 self.load(self.user_pkl, ignore_errors=True)
84 |             if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '':
85 |                 imgui.set_tooltip(self.user_pkl)
86 |             imgui.same_line()
87 |             if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)):
88 |                 imgui.open_popup('recent_pkls_popup')
89 |             imgui.same_line()
90 |             if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1):
91 |                 imgui.open_popup('browse_pkls_popup')
92 |                 self.browse_cache.clear()
93 |                 self.browse_refocus = True
94 | 
95 |         if imgui.begin_popup('recent_pkls_popup'):
96 |             for pkl in recent_pkls:
97 |                 clicked, _state = imgui.menu_item(pkl)
98 |                 if clicked:
99 |                     self.load(pkl, ignore_errors=True)
100 |             imgui.end_popup()
101 | 
102 |         if imgui.begin_popup('browse_pkls_popup'):
103 |             def recurse(parents):
104 |                 key = tuple(parents)
105 |                 items = self.browse_cache.get(key, None)
106 |                 if items is None:
107 |                     items = self.list_runs_and_pkls(parents)
108 |                     self.browse_cache[key] = items
109 |                 for item in items:
110 |                     if item.type == 'run' and imgui.begin_menu(item.name):
111 |                         recurse([item.path])
112 |                         imgui.end_menu()
113 |                     if item.type == 'pkl':
114 |                         clicked, _state = imgui.menu_item(item.name)
115 
| if clicked: 116 | self.load(item.path, ignore_errors=True) 117 | if len(items) == 0: 118 | with imgui_utils.grayed_out(): 119 | imgui.menu_item('No results found') 120 | recurse(self.search_dirs) 121 | if self.browse_refocus: 122 | imgui.set_scroll_here() 123 | viz.skip_frame() # Focus will change on next frame. 124 | self.browse_refocus = False 125 | imgui.end_popup() 126 | 127 | paths = viz.pop_drag_and_drop_paths() 128 | if paths is not None and len(paths) >= 1: 129 | self.load(paths[0], ignore_errors=True) 130 | 131 | viz.args.pkl = self.cur_pkl 132 | 133 | def list_runs_and_pkls(self, parents): 134 | items = [] 135 | run_regex = re.compile(r'\d+-.*') 136 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') 137 | for parent in set(parents): 138 | if os.path.isdir(parent): 139 | for entry in os.scandir(parent): 140 | if entry.is_dir() and run_regex.fullmatch(entry.name): 141 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) 142 | if entry.is_file() and pkl_regex.fullmatch(entry.name): 143 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) 144 | 145 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) 146 | return items 147 | 148 | def resolve_pkl(self, pattern): 149 | assert isinstance(pattern, str) 150 | assert pattern != '' 151 | 152 | # URL => return as is. 153 | if dnnlib.util.is_url(pattern): 154 | return pattern 155 | 156 | # Short-hand pattern => locate. 157 | path = _locate_results(pattern) 158 | 159 | # Run dir => pick the last saved snapshot. 160 | if os.path.isdir(path): 161 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) 162 | if len(pkl_files) == 0: 163 | raise IOError(f'No network pickle found in "{path}"') 164 | path = pkl_files[-1] 165 | 166 | # Normalize. 167 | path = os.path.abspath(path) 168 | return path 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /viz/stylemix_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
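   | # Note (annotation): the per-layer checkboxes below fill
   | # viz.args.stylemix_idx with the enabled w indices and
   | # viz.args.stylemix_seed with the mixing seed; the renderer is then
   | # expected to overwrite those w entries with ones drawn from the mixing
   | # seed, as in standard StyleGAN style mixing.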
8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class StyleMixingWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.seed_def = 1000 18 | self.seed = self.seed_def 19 | self.animate = False 20 | self.enables = [] 21 | 22 | @imgui_utils.scoped_by_object_id 23 | def __call__(self, show=True): 24 | viz = self.viz 25 | num_ws = viz.result.get('num_ws', 0) 26 | num_enables = viz.result.get('num_ws', 18) 27 | self.enables += [False] * max(num_enables - len(self.enables), 0) 28 | 29 | if show: 30 | imgui.text('Stylemix') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): 33 | _changed, self.seed = imgui.input_int('##seed', self.seed) 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | with imgui_utils.grayed_out(num_ws == 0): 36 | _clicked, self.animate = imgui.checkbox('Anim', self.animate) 37 | 38 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w 39 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing 40 | pos0 = viz.label_w + viz.font_size * 12 41 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 42 | for idx in range(num_enables): 43 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) 44 | if idx == 0: 45 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) 46 | with imgui_utils.grayed_out(num_ws == 0): 47 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) 48 | if imgui.is_item_hovered(): 49 | imgui.set_tooltip(f'{idx}') 50 | imgui.pop_style_var(1) 51 | 52 | imgui.same_line(pos2) 53 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) 54 | with imgui_utils.grayed_out(num_ws == 0): 55 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))): 56 | self.seed = self.seed_def 57 | self.animate = False 58 | self.enables = [False] * num_enables 59 | 60 | if any(self.enables[:num_ws]): 61 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] 62 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) 63 | if self.animate: 64 | self.seed += 1 65 | 66 | #---------------------------------------------------------------------------- 67 | -------------------------------------------------------------------------------- /viz/trunc_noise_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
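   | # Note (annotation): trunc_psi/trunc_cutoff below parameterize the usual
   | # StyleGAN truncation trick, w = w_avg + psi * (w - w_avg), applied to the
   | # first trunc_cutoff entries of w; psi = 1 with cutoff = num_ws leaves w
   | # untouched, which is why that combination counts as the default.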
8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class TruncationNoiseWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.prev_num_ws = 0 18 | self.trunc_psi = 1 19 | self.trunc_cutoff = 0 20 | self.noise_enable = True 21 | self.noise_seed = 0 22 | self.noise_anim = False 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | num_ws = viz.result.get('num_ws', 0) 28 | has_noise = viz.result.get('has_noise', False) 29 | if num_ws > 0 and num_ws != self.prev_num_ws: 30 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: 31 | self.trunc_cutoff = num_ws 32 | self.prev_num_ws = num_ws 33 | 34 | if show: 35 | imgui.text('Truncate') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): 38 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') 39 | imgui.same_line() 40 | if num_ws == 0: 41 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) 42 | else: 43 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): 44 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') 45 | if changed: 46 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) 47 | 48 | with imgui_utils.grayed_out(not has_noise): 49 | imgui.same_line() 50 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) 51 | imgui.same_line(round(viz.font_size * 27.7)) 52 | with imgui_utils.grayed_out(not self.noise_enable): 53 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4): 54 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) 55 | imgui.same_line(spacing=0) 56 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) 57 | 58 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) 59 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) 60 | with imgui_utils.grayed_out(is_def_trunc and not has_noise): 61 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 62 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): 63 | self.prev_num_ws = num_ws 64 | self.trunc_psi = 1 65 | self.trunc_cutoff = num_ws 66 | self.noise_enable = True 67 | self.noise_seed = 0 68 | self.noise_anim = False 69 | 70 | if self.noise_anim: 71 | self.noise_seed += 1 72 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed) 73 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') 74 | 75 | #---------------------------------------------------------------------------- 76 | --------------------------------------------------------------------------------
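Appendix (illustrative, not part of the repository): a minimal, self-contained
sketch of the alpha compositing performed by fancy_integration() in the NeRF
rendering code above, assuming ReLU density clamping and the
(b, h x w, num_samples, dim) tensor layouts documented in its docstring. All
names here are illustrative only.

import torch
import torch.nn.functional as F

def composite(rgb_sigma, z_vals, noise_std=0.0):
    # rgb_sigma: (b, hw, s, c + 1), z_vals: (b, hw, s, 1), sorted along s.
    rgbs, sigmas = rgb_sigma[..., :-1], rgb_sigma[..., -1:]
    # Distances between consecutive samples; the last interval is "infinite".
    deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1]
    deltas = torch.cat([deltas, 1e10 * torch.ones_like(deltas[:, :, :1])], dim=-2)
    noise = torch.randn_like(sigmas) * noise_std
    # alpha_i = 1 - exp(-delta_i * relu(sigma_i + noise_i))
    alphas = 1 - torch.exp(-deltas * F.relu(sigmas + noise))
    # Transmittance T_i = prod_{j<i} (1 - alpha_j); weights w_i = alpha_i * T_i.
    ones = torch.ones_like(alphas[:, :, :1])
    trans = torch.cumprod(torch.cat([ones, 1 - alphas + 1e-10], dim=-2), dim=-2)
    weights = alphas * trans[:, :, :-1]                # (b, hw, s, 1)
    rgb = torch.sum(weights * rgbs, dim=-2)            # (b, hw, c)
    depth = torch.sum(weights * z_vals, dim=-2)        # (b, hw, 1)
    return rgb, depth, weights

# Smoke test with random inputs: b=2 batches, hw=16 rays, s=12 samples, c=3.
rgb_sigma = torch.randn(2, 16, 12, 4)
z_vals, _ = torch.sort(torch.rand(2, 16, 12, 1), dim=2)
rgb, depth, weights = composite(rgb_sigma, z_vals)
print(rgb.shape, depth.shape, weights.shape)           # (2,16,3) (2,16,1) (2,16,12,1)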