├── .gitignore ├── LICENSE ├── README.md ├── calc_metrics.py ├── data ├── afhq.json ├── carla.json └── ffhq.json ├── dataset_tool.py ├── dist ├── MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl ├── stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl └── svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl ├── dnnlib ├── __init__.py └── util.py ├── environment.yml ├── gen_images.py ├── gen_video.py ├── gfx └── ffhq.gif ├── legacy.py ├── metrics ├── __init__.py ├── frechet_inception_distance.py ├── metric_main.py └── metric_utils.py ├── scripts ├── build_wheels.sh ├── download_pretrained_models.sh └── make_datasets.sh ├── torch_utils ├── __init__.py ├── build_fat_binary.sh ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── bias_act_kernel.cu │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── pybind.cpp │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── persistence.py ├── setup.py └── training_stats.py ├── train.py ├── training ├── __init__.py ├── augment.py ├── dataset.py ├── loss.py ├── networks_stylegan2.py ├── networks_stylegan2_3d.py ├── networks_stylegan3.py ├── networks_voxgraf.py ├── training_loop.py └── virtual_camera_utils.py └── voxgraf-plenoxels ├── LICENSE ├── README.md ├── environment.yml ├── manual_install.sh ├── setup.py ├── svox2 ├── __init__.py ├── csrc │ ├── .ccls │ ├── CMakeLists.txt │ ├── include │ │ ├── cubemap_util.cuh │ │ ├── cuda_util.cuh │ │ ├── data_spec.hpp │ │ ├── data_spec_packed.cuh │ │ ├── random_util.cuh │ │ ├── render_util.cuh │ │ └── util.hpp │ ├── loss_kernel.cu │ ├── misc_kernel.cu │ ├── optim_kernel.cu │ ├── render_lerp_kernel_cuvol.cu │ ├── render_lerp_kernel_nvol.cu │ ├── render_svox1_kernel.cu │ ├── svox2.cpp │ └── svox2_kernel.cu ├── defs.py ├── svox2.py ├── utils.py └── version.py └── test ├── test_render_gradcheck_alpha.py ├── test_render_gradcheck_depth.py └── test_render_gradcheck_vardepth.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Added 132 | training-runs/* 133 | training-runs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoxGRAF 2 | 3 |
4 | <img src="./gfx/ffhq.gif" alt="Generated FFHQ samples" width="512"/>
5 |
6 | 7 | This repository contains official code for the paper 8 | [VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids](https://www.cvlibs.net/publications/Schwarz2022NEURIPS.pdf). 9 | 10 | You can find detailed usage instructions for training your own models and using pre-trained models below. 11 | 12 | If you find our code or paper useful, please consider citing 13 | 14 | @inproceedings{Schwarz2022NEURIPS, 15 | title = {VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids}, 16 | author = {Schwarz, Katja and Sauer, Axel and Niemeyer, Michael and Liao, Yiyi and Geiger, Andreas}, 17 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 18 | year = {2022} 19 | } 20 | 21 | ## Installation 22 | First you have to make sure that you have all dependencies in place. 23 | The simplest way to do so, is to use [anaconda](https://www.anaconda.com/). 24 | 25 | You can create and activate an anaconda environment called `voxgraf` using 26 | 27 | ```commandline 28 | conda env create -f environment.yml 29 | conda activate voxgraf 30 | ``` 31 | 32 | ### CUDA extension installation 33 | 34 | Install pre-compiled CUDA extensions by running 35 | ```commandline 36 | ./scripts/build_wheels.sh 37 | ``` 38 | **Or** install them individually by running 39 | ```commandline 40 | pip install dist/stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl 41 | pip install dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl 42 | pip install dist/MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl # optional, only required when training with minkowski sparse convolutions 43 | ``` 44 | In case the wheels do not work for you, you can also install the extensions from source. For this please check the original repos: [Stylegan-3](https://github.com/NVlabs/stylegan3), [Minkowski Engine](https://github.com/NVIDIA/MinkowskiEngine) and for our version of Plenoxels follow the instructions [here](voxgraf-plenoxels/README.md). 45 | 46 | ## Pretrained models 47 | To download the pretrained models run 48 | ```commandline 49 | ./scripts/download_pretrained_models.sh 50 | ``` 51 | 52 | ### Evaluate pretrained models 53 | ```commandline 54 | # generate a video with 1x2 samples and interpolations between 2 keyframes each 55 | python gen_video.py --network pretrained_models/ffhq256.pkl --seeds 0-3 --grid 1x2 --num-keyframes 2 --output ffhq_256_samples/video.mp4 --trunc=0.5 56 | 57 | # generate grids of 3x4 samples and their depths 58 | python gen_images.py --network pretrained_models/ffhq256.pkl --seeds 0-23 --grid 3x4 --outdir ffhq_256_samples --save_depth true --trunc=0.5 59 | ``` 60 | 61 | ## Train custom models 62 | 63 | ### Download the data 64 | Download [FFHQ](https://github.com/NVlabs/stylegan2), [AFHQ](https://github.com/clovaai/stargan-v2) and [Carla](https://github.com/autonomousvision/graf). 65 | 66 | ### Preparing the data 67 | To prepare the data at the required resolutions you can run 68 | ```commandline 69 | ./scripts/make_dataset.sh /PATH/TO/IMAGES data/{DATASET_NAME}.json data/{DATASET_NAME} 32,64,128,256 70 | ``` 71 | This will create the datasets in `data/{DATASET_NAME}_{RES}.zip`. 
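If you prefer to drive the conversion from Python rather than the shell script, the loop below is a minimal sketch of the same calls that `scripts/make_datasets.sh` issues (paths are placeholders; the `dataset_tool.py` flags are taken from that script):

```python
# Sketch: prepare one dataset at several resolutions by invoking dataset_tool.py.
# SRC / META / DST are placeholders -- substitute your own paths.
import os
import subprocess

SRC = "/PATH/TO/IMAGES"
META = "data/ffhq.json"
DST = "data/ffhq"

for res in (32, 64, 128, 256):
    out = f"{DST}_{res}.zip"
    if os.path.exists(out):   # skip resolutions that were already built
        continue
    subprocess.run([
        "python", "dataset_tool.py",
        "--source", SRC, "--meta", META, "--dest", out,
        "--resolution", f"{res}x{res}", "--mirror-aug", "True",
    ], check=True)
```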
72 | 73 | ### Train models progressively 74 | 75 | ```commandline 76 | # Train a model on FFHQ progressively starting at image resolution 32x32 with voxel grid resolution 32x32x32 77 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_32.zip --batch 64 --grid-res 32 78 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_64.zip --batch 64 --grid-res 32 --resume /PATH/TO/32-IMG-32-GRID-MODEL # Next stage 79 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_64.zip --batch 64 --grid-res 64 --resume /PATH/TO/64-IMG-32-GRID-MODEL # Next stage 80 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_128.zip --batch 64 --grid-res 64 --lambda_vardepth 1e-3 --resume /PATH/TO/64-IMG-64-GRID-MODEL # Next stage 81 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_128.zip --batch 32 --grid-res 128 --lambda_vardepth 1e-3 --resume /PATH/TO/128-IMG-64-GRID-MODEL # Next stage 82 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_256.zip --batch 32 --grid-res 128 --lambda_vardepth 1e-3 --resume /PATH/TO/128-IMG-128-GRID-MODEL # Next stage 83 | 84 | # Train a model on Carla at image resolution 32x32 with voxel grid resolution 32x32x32 85 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_32.zip --batch 64 --grid-res 32 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 86 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_64.zip --batch 64 --grid-res 32 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --resume /PATH/TO/32-IMG-32-GRID-MODEL # Next stage 87 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_64.zip --batch 64 --grid-res 64 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --resume /PATH/TO/64-IMG-32-GRID-MODEL # Next stage 88 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_128.zip --batch 64 --grid-res 64 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --lambda_vardepth 1e-3 --resume /PATH/TO/64-IMG-64-GRID-MODEL # Next stage 89 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_128.zip --batch 32 --grid-res 128 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --lambda_vardepth 1e-3 --resume /PATH/TO/128-IMG-64-GRID-MODEL # Next stage 90 | ``` -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
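#
# Example invocation with the released FFHQ checkpoint (paths follow the README
# setup; see the click options below for the full set of flags):
#
#   python calc_metrics.py --metrics=fid50k_full --gpus=1 \
#       --network=pretrained_models/ffhq256.pkl --data=data/ffhq_256.zip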
8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import click 13 | import json 14 | import tempfile 15 | import copy 16 | import torch 17 | 18 | import dnnlib 19 | import legacy 20 | from metrics import metric_main 21 | from metrics import metric_utils 22 | from torch_utils import training_stats 23 | from torch_utils import custom_ops 24 | from torch_utils import misc 25 | from torch_utils.ops import conv2d_gradfix 26 | 27 | # environment variables 28 | os.environ['OMP_NUM_THREADS'] = "16" 29 | 30 | #---------------------------------------------------------------------------- 31 | 32 | def subprocess_fn(rank, args, temp_dir): 33 | dnnlib.util.Logger(should_flush=True) 34 | 35 | # Init torch.distributed. 36 | if args.num_gpus > 1: 37 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 38 | if os.name == 'nt': 39 | init_method = 'file:///' + init_file.replace('\\', '/') 40 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 41 | else: 42 | init_method = f'file://{init_file}' 43 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 44 | 45 | # Init torch_utils. 46 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 47 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 48 | if rank != 0 or not args.verbose: 49 | custom_ops.verbosity = 'none' 50 | 51 | # Configure torch. 52 | device = torch.device('cuda', rank) 53 | torch.cuda.set_device(device) 54 | torch.backends.cuda.matmul.allow_tf32 = False 55 | torch.backends.cudnn.allow_tf32 = False 56 | conv2d_gradfix.enabled = True 57 | 58 | # Print network summary. 59 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 60 | if rank == 0 and args.verbose: 61 | z = torch.empty([1, G.z_dim], device=device) 62 | c = torch.empty([1, G.c_dim], device=device) 63 | misc.print_module_summary(G, [z, c]) 64 | 65 | # Calculate each metric. 66 | for metric in args.metrics: 67 | if rank == 0 and args.verbose: 68 | print(f'Calculating {metric}...') 69 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 70 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 71 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress, G_kwargs=args.G_kwargs) 72 | if rank == 0: 73 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 74 | if rank == 0 and args.verbose: 75 | print() 76 | 77 | # Done. 
78 | if rank == 0 and args.verbose: 79 | print('Exiting...') 80 | 81 | #---------------------------------------------------------------------------- 82 | 83 | def parse_comma_separated_list(s): 84 | if isinstance(s, list): 85 | return s 86 | if s is None or s.lower() == 'none' or s == '': 87 | return [] 88 | return s.split(',') 89 | 90 | #---------------------------------------------------------------------------- 91 | 92 | @click.command() 93 | @click.pass_context 94 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 95 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True) 96 | @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]') 97 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL') 98 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 99 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 100 | 101 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): 102 | """Calculate quality metrics for previous training run or pretrained network pickle. 103 | 104 | Examples: 105 | 106 | \b 107 | # Previous training run: look up options automatically, save result to JSONL file. 108 | python calc_metrics.py --metrics=eqt50k_int,eqr50k \\ 109 | --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl 110 | 111 | \b 112 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 113 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\ 114 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl 115 | 116 | \b 117 | Recommended metrics: 118 | fid50k_full Frechet inception distance against the full dataset. 119 | kid50k_full Kernel inception distance against the full dataset. 120 | pr50k3_full Precision and recall againt the full dataset. 121 | ppl2_wend Perceptual path length in W, endpoints, full image. 122 | eqt50k_int Equivariance w.r.t. integer translation (EQ-T). 123 | eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac). 124 | eqr50k Equivariance w.r.t. rotation (EQ-R). 125 | 126 | \b 127 | Legacy metrics: 128 | fid50k Frechet inception distance against 50k real images. 129 | kid50k Kernel inception distance against 50k real images. 130 | pr50k3 Precision and recall against 50k real images. 131 | is50k Inception score for CIFAR-10. 132 | """ 133 | dnnlib.util.Logger(should_flush=True) 134 | 135 | # Validate arguments. 136 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 137 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 138 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 139 | if not args.num_gpus >= 1: 140 | ctx.fail('--gpus must be at least 1') 141 | 142 | # Load network. 
143 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 144 | ctx.fail('--network must point to a file or URL') 145 | if args.verbose: 146 | print(f'Loading network from "{network_pkl}"...') 147 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 148 | network_dict = legacy.load_network_pkl(f) 149 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 150 | n_img = network_dict.get('progress', {}).get('cur_nimg', None) 151 | if n_img is not None: 152 | args.G_kwargs = {'cur_nimg': n_img.item()} 153 | 154 | # Initialize dataset options. 155 | if data is not None: 156 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 157 | elif network_dict['training_set_kwargs'] is not None: 158 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 159 | else: 160 | ctx.fail('Could not look up dataset options; please specify --data') 161 | 162 | # Finalize dataset options. 163 | args.dataset_kwargs.resolution = args.G.img_resolution 164 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 165 | if mirror is not None: 166 | args.dataset_kwargs.xflip = mirror 167 | 168 | # Print dataset options. 169 | if args.verbose: 170 | print('Dataset options:') 171 | print(json.dumps(args.dataset_kwargs, indent=2)) 172 | 173 | # Locate run dir. 174 | args.run_dir = None 175 | if os.path.isfile(network_pkl): 176 | pkl_dir = os.path.dirname(network_pkl) 177 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 178 | args.run_dir = pkl_dir 179 | 180 | # Launch processes. 181 | if args.verbose: 182 | print('Launching processes...') 183 | torch.multiprocessing.set_start_method('spawn') 184 | with tempfile.TemporaryDirectory() as temp_dir: 185 | if args.num_gpus == 1: 186 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 187 | else: 188 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 189 | 190 | #---------------------------------------------------------------------------- 191 | 192 | if __name__ == "__main__": 193 | calc_metrics() # pylint: disable=no-value-for-parameter 194 | 195 | #---------------------------------------------------------------------------- 196 | -------------------------------------------------------------------------------- /dist/MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/dist/MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl -------------------------------------------------------------------------------- /dist/stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/dist/stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl -------------------------------------------------------------------------------- /dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl -------------------------------------------------------------------------------- /dnnlib/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: voxgraf 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - anaconda 6 | dependencies: 7 | - python >= 3.8 8 | - pip 9 | - numpy=1.22 10 | - click>=8.0 11 | - pillow=8.3.1 12 | - scipy=1.7.1 13 | - pytorch==1.11.0 14 | - cudatoolkit=11.1 15 | - requests=2.26.0 16 | - tqdm=4.62.2 17 | - ninja=1.10.2 18 | - matplotlib=3.4.2 19 | - imageio=2.9.0 20 | - openblas-devel 21 | - pip: 22 | - imgui==1.3.0 23 | - glfw==2.2.0 24 | - pyopengl==3.1.5 25 | - imageio-ffmpeg==0.4.3 26 | - pyspng 27 | - dill 28 | - psutil 29 | - opencv-python 30 | - tensorboard 31 | - plyfile 32 | - torchvision==0.12.0 -------------------------------------------------------------------------------- /gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | # 9 | # Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids 10 | # 11 | 12 | """Generate lerp videos using pretrained network pickle.""" 13 | 14 | from typing import List, Optional, Tuple, Union 15 | import os 16 | import click 17 | import imageio 18 | import numpy as np 19 | import torch 20 | from training.virtual_camera_utils import uv2RT 21 | from gen_images import parse_range, parse_tuple, load_model, make_inputs, generate_grids 22 | 23 | # environment variables 24 | os.environ['OMP_NUM_THREADS'] = "16" 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def sample_oval(n_points, a, b): 29 | "a is axes length in x-direction, b in y-direction" 30 | angle_range = np.linspace(0, 2*np.pi, n_points) 31 | x = a*np.sin(angle_range) 32 | y = b*np.cos(angle_range) 33 | return x, y 34 | 35 | def sample_line(n_points, a): 36 | "a is axes length in x-direction, b in y-direction" 37 | angle_range = np.linspace(-1, 1, n_points) 38 | x = a*angle_range 39 | y = np.zeros(n_points) 40 | return x, y 41 | 42 | def make_poses(G, label, pose_sampling, w_frames, num_keyframes=1, range_azim=(180, 90), range_polar=(180, 90)): 43 | assert len(label) % num_keyframes == 0 44 | nsamples = len(label) // num_keyframes 45 | # set up the camera trajectory 46 | if pose_sampling == 'oval': 47 | azim, polar = sample_oval(w_frames, range_azim[1], range_polar[1]) 48 | elif pose_sampling == 'line': 49 | azim, polar = sample_line(w_frames, range_azim[1]) 50 | else: 51 | raise AttributeError 52 | azim += range_azim[0] 53 | polar += range_polar[0] 54 | 55 | azim2u = lambda x: torch.deg2rad(torch.tensor(x)) / (2 * np.pi) 56 | polar2v = lambda x: 0.5 * (1 - torch.cos(torch.deg2rad(torch.tensor(x)))) 57 | us, vs = azim2u(azim), polar2v(polar) 58 | 59 | RTs = [] 60 | for u, v in zip(us, vs): 61 | RT = uv2RT(u, v, 1) # radius is set later 62 | RTs.append(RT) 63 | poses = torch.stack(RTs).to(torch.float32) 64 | poses = poses.view(1, 1, *poses.shape).repeat(num_keyframes, nsamples, 1, 1, 1) 65 | 66 | if G.c_dim != 0: # adjust the radius of the samples to match their pose 67 | c = label.view(num_keyframes, nsamples, 1, 3, 4).expand(-1, -1, w_frames, -1, -1) 68 | cradius = c[:, :, :, :3, 3].norm(dim=3, keepdim=True).min(dim=0, keepdim=True).values # chooses smalles radius of all keyframes 69 | poses[:, :, :, :3, 3] *= cradius 70 | 71 | return poses.flatten(0, 1) # (num_keyframes x nsamples) x wframes x 4 x 4 72 | 73 | @click.command() 74 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 75 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True) 76 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) 77 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) 78 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. 
If not specified, determine based on the length of the seeds array given by --seeds.', default=None) 79 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) # 60*x 80 | @click.option('--range-azim', type=parse_tuple, help='Mean and std of pose distribution azimuth angle', default=(180, 20)) 81 | @click.option('--range-polar', type=parse_tuple, help='Mean and std of pose distribution polar angle', default=(90, 10)) 82 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 83 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') 84 | @click.option('--pose-sampling', help='Camera trajectory', type=click.Choice(['oval', 'line']), default='oval', show_default=True) 85 | @click.option('--bg-decay', help='Reduce background visibility', type=bool, default=False, show_default=True) 86 | @click.option('--no-bg', help='Remove background', type=bool, default=False, show_default=True) 87 | 88 | def generate_videos( 89 | network_pkl: str, 90 | seeds: List[int], 91 | shuffle_seed: Optional[int], 92 | truncation_psi: float, 93 | grid: Tuple[int,int], 94 | num_keyframes: Optional[int], 95 | w_frames: int, 96 | range_azim: Optional[int], 97 | range_polar: Optional[int], 98 | output: str, 99 | pose_sampling: str, 100 | bg_decay: bool, 101 | no_bg: bool, 102 | ): 103 | """Render a latent vector interpolation video. 104 | 105 | The output video length 106 | will be 'num_keyframes*w_frames' frames. 107 | """ 108 | os.makedirs(os.path.dirname(output), exist_ok=True) 109 | assert len(seeds) == grid[0]*grid[1]*num_keyframes, f'need gw({grid[0]})*gh({grid[1]})*num_keyframes({num_keyframes})={grid[0]*grid[1]*num_keyframes} seeds but have {len(seeds)}' 110 | G = load_model(network_pkl) 111 | z, c = make_inputs(seeds, G) 112 | p = make_poses(G, c, pose_sampling, w_frames, num_keyframes=num_keyframes, range_azim=range_azim, range_polar=range_polar) 113 | 114 | if bg_decay: 115 | bg_color = 1 116 | n_start = w_frames // 4 117 | n_end = num_keyframes * w_frames - (w_frames // 4) 118 | get_bg_weight = lambda n: 1 - min(1, max(0, (n - n_start) / (n_end - n_start))) 119 | 120 | video_out = imageio.get_writer(output, mode='I', fps=60, codec='libx264', bitrate='12M') 121 | keyframe_idx = 0 122 | for ret_dict in generate_grids(G=G, truncation_psi=truncation_psi, grid=grid, latents=z, poses=p, conditions=c, ret_alpha=bg_decay, ret_depth=False, ret_bg_fg=(bg_decay or no_bg), interpolate=True): 123 | g = ret_dict['rgb'] 124 | 125 | if no_bg: 126 | g = ret_dict['fg'] 127 | if bg_decay: 128 | g_decay = [] 129 | for i, (g_i, alpha_i) in enumerate(zip(g, ret_dict['alpha'])): 130 | weight = get_bg_weight(keyframe_idx*w_frames + i) 131 | print(weight) 132 | g_decay.append(alpha_i * g_i + (1 - alpha_i) * weight * g_i + (1 - alpha_i) * (1 - weight) * torch.full_like(alpha_i, fill_value=bg_color)) 133 | g = g_decay 134 | 135 | for pose_idx, g_p in enumerate(g): 136 | video_out.append_data((g_p.permute(1,2,0)*255).to(torch.uint8).cpu().numpy()) 137 | keyframe_idx += 1 138 | video_out.close() 139 | 140 | #---------------------------------------------------------------------------- 141 | 142 | if __name__ == "__main__": 143 | generate_videos() # pylint: disable=no-value-for-parameter 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /gfx/ffhq.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/gfx/ffhq.gif -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from torch_utils import misc 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def compute_fid(opts, max_real, num_gen): 22 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 23 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 24 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
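    # The two helpers below fit a Gaussian (mean, covariance) to the Inception
    # features of real and generated images; FID is then the closed-form Frechet
    # distance between the two Gaussians:
    #   FID = ||mu_real - mu_gen||^2 + Tr(sigma_real + sigma_gen - 2*sqrtm(sigma_real @ sigma_gen))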
25 | 26 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 27 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 28 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 29 | 30 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 31 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 32 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 33 | 34 | if opts.rank != 0: 35 | return float('nan') 36 | 37 | m = np.square(mu_gen - mu_real).sum() 38 | try: 39 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 40 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 41 | except ValueError: # nan 42 | fid = 9999 43 | return float(fid) 44 | 45 | #---------------------------------------------------------------------------- 46 | -------------------------------------------------------------------------------- /metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Main API for computing and reporting quality metrics.""" 10 | 11 | import os 12 | import time 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | from . import metric_utils 18 | from . import frechet_inception_distance 19 | 20 | os.environ['MKL_NUM_THREADS'] = "1" # needed for FID computation on cluster 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | _metric_dict = dict() # name => fn 25 | 26 | def register_metric(fn): 27 | assert callable(fn) 28 | _metric_dict[fn.__name__] = fn 29 | return fn 30 | 31 | def is_valid_metric(metric): 32 | return metric in _metric_dict 33 | 34 | def list_valid_metrics(): 35 | return list(_metric_dict.keys()) 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 40 | assert is_valid_metric(metric) 41 | opts = metric_utils.MetricOptions(**kwargs) 42 | 43 | # Calculate. 44 | start_time = time.time() 45 | results = _metric_dict[metric](opts) 46 | total_time = time.time() - start_time 47 | 48 | # Broadcast results. 49 | for key, value in list(results.items()): 50 | if opts.num_gpus > 1: 51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 52 | torch.distributed.broadcast(tensor=value, src=0) 53 | value = float(value.cpu()) 54 | results[key] = value 55 | 56 | # Decorate with metadata. 
57 | return dnnlib.EasyDict( 58 | results = dnnlib.EasyDict(results), 59 | metric = metric, 60 | total_time = total_time, 61 | total_time_str = dnnlib.util.format_time(total_time), 62 | num_gpus = opts.num_gpus, 63 | ) 64 | 65 | #---------------------------------------------------------------------------- 66 | 67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 68 | metric = result_dict['metric'] 69 | assert is_valid_metric(metric) 70 | if run_dir is not None and snapshot_pkl is not None: 71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 72 | 73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 74 | print(jsonl_line) 75 | if run_dir is not None and os.path.isdir(run_dir): 76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 77 | f.write(jsonl_line + '\n') 78 | 79 | #---------------------------------------------------------------------------- 80 | # Recommended metrics. 81 | 82 | @register_metric 83 | def fid50k_full(opts): 84 | opts.dataset_kwargs.update(max_size=None, xflip=False) 85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 86 | return dict(fid50k_full=fid) 87 | 88 | #---------------------------------------------------------------------------- 89 | # Legacy metrics. 90 | 91 | @register_metric 92 | def fid50k(opts): 93 | opts.dataset_kwargs.update(max_size=None) 94 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 95 | return dict(fid50k=fid) 96 | 97 | #---------------------------------------------------------------------------- 98 | # Added 99 | 100 | @register_metric 101 | def fid10k_full(opts): 102 | opts.dataset_kwargs.update(max_size=None, xflip=False) 103 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000) 104 | return dict(fid10k_full=fid) 105 | 106 | @register_metric 107 | def fid20k_full(opts): 108 | opts.dataset_kwargs.update(max_size=None, xflip=False) 109 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=20000) 110 | return dict(fid10k_full=fid) -------------------------------------------------------------------------------- /scripts/build_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export TMPDIR=~/tmp 3 | conda run -n voxgraf pip install dist/* 4 | -------------------------------------------------------------------------------- /scripts/download_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | mkdir pretrained_models 2 | # FFHQ 256 3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/voxgraf/models/ffhq256.pkl -P pretrained_models 4 | # Carla 128 5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/voxgraf/models/carla128.pkl -P pretrained_models -------------------------------------------------------------------------------- /scripts/make_datasets.sh: -------------------------------------------------------------------------------- 1 | SRC=$1 2 | META=$2 3 | DST=${3%.zip} 4 | IFS=',' read -ra RES <<< $4 5 | 6 | for res in ${RES[@]} 7 | do 8 | OUT=${DST}_${res}.zip 9 | if test -f "$OUT"; then 10 | continue 11 | fi 12 | echo python dataset_tool.py --source $SRC --meta $META --dest $OUT --resolution ${res}x${res} --mirror-aug True 13 | python dataset_tool.py --source $SRC --meta $META --dest $OUT --resolution ${res}x${res} --mirror-aug True 14 | done 
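The metric registry in `metrics/metric_main.py` above is extensible: any function decorated with `@register_metric` becomes selectable through `--metrics` in `calc_metrics.py`. As a minimal sketch, a hypothetical `fid5k_full` variant (name and sample count chosen purely for illustration) added to that file would look like:

```python
# Sketch of an additional metric for metrics/metric_main.py.
# "fid5k_full" is an illustrative name, not part of the repository.
@register_metric
def fid5k_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=5000)
    return dict(fid5k_full=fid)
```

It could then be requested with `python calc_metrics.py --metrics=fid5k_full ...` just like the built-in metrics.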
-------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/build_fat_binary.sh: -------------------------------------------------------------------------------- 1 | export TORCH_CUDA_ARCH_LIST="3.5+PTX;3.7+PTX;5.0+PTX;5.2+PTX;5.3+PTX;6.0+PTX;6.1+PTX;6.2+PTX;7.0+PTX;7.2+PTX;7.5+PTX;8.0+PTX;8.6+PTX" 2 | 3 | #python setup.py install 4 | python setup.py bdist_wheel 5 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 
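#
# Usage sketch (mirroring the call in torch_utils/ops/bias_act.py): compile and
# load a CUDA extension whose sources live next to the calling module:
#
#   _plugin = get_plugin(
#       module_name='bias_act_plugin',
#       sources=['bias_act.cpp', 'bias_act.cu'],
#       headers=['bias_act.h'],
#       source_dir=os.path.dirname(__file__),
#       extra_cuda_cflags=['--use_fast_math'],
#   )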
56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 
131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | // 9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids 10 | // 11 | 12 | #include 13 | #include 14 | #include 15 | #include "bias_act.h" 16 | 17 | //------------------------------------------------------------------------ 18 | 19 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 20 | { 21 | if (x.dim() != y.dim()) 22 | return false; 23 | for (int64_t i = 0; i < x.dim(); i++) 24 | { 25 | if (x.size(i) != y.size(i)) 26 | return false; 27 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 28 | return false; 29 | } 30 | return true; 31 | } 32 | 33 | //------------------------------------------------------------------------ 34 | 35 | torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 36 | { 37 | // Validate arguments. 
38 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 39 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 40 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 41 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 42 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 43 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 44 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 45 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 46 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 47 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 48 | 49 | // Validate layout. 50 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 51 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 52 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 53 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 54 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 55 | 56 | // Create output tensor. 57 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 58 | torch::Tensor y = torch::empty_like(x); 59 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 60 | 61 | // Initialize CUDA kernel parameters. 62 | bias_act_kernel_params p; 63 | p.x = x.data_ptr(); 64 | p.b = (b.numel()) ? b.data_ptr() : NULL; 65 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 66 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 67 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 68 | p.y = y.data_ptr(); 69 | p.grad = grad; 70 | p.act = act; 71 | p.alpha = alpha; 72 | p.gain = gain; 73 | p.clamp = clamp; 74 | p.sizeX = (int)x.numel(); 75 | p.sizeB = (int)b.numel(); 76 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 77 | 78 | // Choose CUDA kernel. 79 | void* kernel; 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 81 | { 82 | kernel = choose_bias_act_kernel(p); 83 | }); 84 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 85 | 86 | // Launch CUDA kernel. 87 | p.loopX = 4; 88 | int blockSize = 4 * 32; 89 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 90 | void* args[] = {&p}; 91 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 92 | return y; 93 | } 94 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? 
x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 
35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | # 9 | # Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids 10 | # 11 | 12 | 13 | """Custom PyTorch ops for efficient bias and activation.""" 14 | 15 | import os 16 | import numpy as np 17 | import torch 18 | import dnnlib 19 | 20 | from .. import custom_ops 21 | from .. import misc 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | activation_funcs = { 26 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 27 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 28 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 29 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 30 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 31 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 32 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 33 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 34 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 35 | } 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | _plugin = None 40 | _null_tensor = torch.empty([0]) 41 | 42 | def _init(): 43 | global _plugin 44 | if _plugin is None: 45 | if False: 46 | _plugin = custom_ops.get_plugin( 47 | module_name='bias_act_plugin', 48 | sources=['bias_act.cpp', 'bias_act.cu'], 49 | headers=['bias_act.h'], 50 | source_dir=os.path.dirname(__file__), 51 | extra_cuda_cflags=['--use_fast_math'], 52 | ) 53 | import stylegan3_cuda as _plugin 54 | return True 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 59 | r"""Fused bias and activation function. 60 | 61 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 62 | and scales the result by `gain`. Each of the steps is optional. 
In most cases, 63 | the fused op is considerably more efficient than performing the same calculation 64 | using standard PyTorch ops. It supports first and second order gradients, 65 | but not third order gradients. 66 | 67 | Args: 68 | x: Input activation tensor. Can be of any shape. 69 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 70 | as `x`. The shape must be known, and it must match the dimension of `x` 71 | corresponding to `dim`. 72 | dim: The dimension in `x` corresponding to the elements of `b`. 73 | The value of `dim` is ignored if `b` is not specified. 74 | act: Name of the activation function to evaluate, or `"linear"` to disable. 75 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 76 | See `activation_funcs` for a full list. `None` is not allowed. 77 | alpha: Shape parameter for the activation function, or `None` to use the default. 78 | gain: Scaling factor for the output tensor, or `None` to use default. 79 | See `activation_funcs` for the default scaling of each activation function. 80 | If unsure, consider specifying 1. 81 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 82 | the clamping (default). 83 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 84 | 85 | Returns: 86 | Tensor of the same shape and datatype as `x`. 87 | """ 88 | assert isinstance(x, torch.Tensor) 89 | assert impl in ['ref', 'cuda'] 90 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 91 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 92 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 93 | 94 | #---------------------------------------------------------------------------- 95 | 96 | @misc.profiled_function 97 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 98 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 99 | """ 100 | assert isinstance(x, torch.Tensor) 101 | assert clamp is None or clamp >= 0 102 | spec = activation_funcs[act] 103 | alpha = float(alpha if alpha is not None else spec.def_alpha) 104 | gain = float(gain if gain is not None else spec.def_gain) 105 | clamp = float(clamp if clamp is not None else -1) 106 | 107 | # Add bias. 108 | if b is not None: 109 | assert isinstance(b, torch.Tensor) and b.ndim == 1 110 | assert 0 <= dim < x.ndim 111 | assert b.shape[0] == x.shape[dim] 112 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 113 | 114 | # Evaluate activation function. 115 | alpha = float(alpha) 116 | x = spec.func(x, alpha=alpha) 117 | 118 | # Scale by gain. 119 | gain = float(gain) 120 | if gain != 1: 121 | x = x * gain 122 | 123 | # Clamp. 124 | if clamp >= 0: 125 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 126 | return x 127 | 128 | #---------------------------------------------------------------------------- 129 | 130 | _bias_act_cuda_cache = dict() 131 | 132 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 133 | """Fast CUDA implementation of `bias_act()` using custom ops. 134 | """ 135 | # Parse arguments. 136 | assert clamp is None or clamp >= 0 137 | spec = activation_funcs[act] 138 | alpha = float(alpha if alpha is not None else spec.def_alpha) 139 | gain = float(gain if gain is not None else spec.def_gain) 140 | clamp = float(clamp if clamp is not None else -1) 141 | 142 | # Lookup from cache. 
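# (Editor's note: _bias_act_cuda() builds one autograd Function class per unique
# (dim, act, alpha, gain, clamp) configuration and memoizes it in _bias_act_cuda_cache below,
# so repeated bias_act() calls with the same settings reuse the same class instead of
# re-creating it on every forward pass.)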
143 | key = (dim, act, alpha, gain, clamp) 144 | if key in _bias_act_cuda_cache: 145 | return _bias_act_cuda_cache[key] 146 | 147 | # Forward op. 148 | class BiasActCuda(torch.autograd.Function): 149 | @staticmethod 150 | def forward(ctx, x, b): # pylint: disable=arguments-differ 151 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 152 | x = x.contiguous(memory_format=ctx.memory_format) 153 | b = b.contiguous() if b is not None else _null_tensor 154 | y = x 155 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 156 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 157 | ctx.save_for_backward( 158 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 159 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 160 | y if 'y' in spec.ref else _null_tensor) 161 | return y 162 | 163 | @staticmethod 164 | def backward(ctx, dy): # pylint: disable=arguments-differ 165 | dy = dy.contiguous(memory_format=ctx.memory_format) 166 | x, b, y = ctx.saved_tensors 167 | dx = None 168 | db = None 169 | 170 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 171 | dx = dy 172 | if act != 'linear' or gain != 1 or clamp >= 0: 173 | dx = BiasActCudaGrad.apply(dy, x, b, y) 174 | 175 | if ctx.needs_input_grad[1]: 176 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 177 | 178 | return dx, db 179 | 180 | # Backward op. 181 | class BiasActCudaGrad(torch.autograd.Function): 182 | @staticmethod 183 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 184 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 185 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 186 | ctx.save_for_backward( 187 | dy if spec.has_2nd_grad else _null_tensor, 188 | x, b, y) 189 | return dx 190 | 191 | @staticmethod 192 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 193 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 194 | dy, x, b, y = ctx.saved_tensors 195 | d_dy = None 196 | d_x = None 197 | d_b = None 198 | d_y = None 199 | 200 | if ctx.needs_input_grad[0]: 201 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 202 | 203 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 204 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 205 | 206 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 207 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 208 | 209 | return d_dy, d_x, d_b, d_y 210 | 211 | # Add to cache. 212 | _bias_act_cuda_cache[key] = BiasActCuda 213 | return BiasActCuda 214 | 215 | #---------------------------------------------------------------------------- 216 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | // 9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids 10 | // 11 | 12 | #include 13 | #include "bias_act.h" 14 | 15 | //------------------------------------------------------------------------ 16 | // Helpers. 17 | 18 | template struct InternalType; 19 | template <> struct InternalType { typedef double scalar_t; }; 20 | template <> struct InternalType { typedef float scalar_t; }; 21 | template <> struct InternalType { typedef float scalar_t; }; 22 | 23 | //------------------------------------------------------------------------ 24 | // CUDA kernel. 25 | 26 | template 27 | __global__ void bias_act_kernel(bias_act_kernel_params p) 28 | { 29 | typedef typename InternalType::scalar_t scalar_t; 30 | int G = p.grad; 31 | scalar_t alpha = (scalar_t)p.alpha; 32 | scalar_t gain = (scalar_t)p.gain; 33 | scalar_t clamp = (scalar_t)p.clamp; 34 | scalar_t one = (scalar_t)1; 35 | scalar_t two = (scalar_t)2; 36 | scalar_t expRange = (scalar_t)80; 37 | scalar_t halfExpRange = (scalar_t)40; 38 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 39 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 40 | 41 | // Loop over elements. 42 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 43 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 44 | { 45 | // Load. 46 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 47 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 48 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 49 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 50 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 51 | scalar_t yy = (gain != 0) ? yref / gain : 0; 52 | scalar_t y = 0; 53 | 54 | // Apply bias. 55 | ((G == 0) ? x : xref) += b; 56 | 57 | // linear 58 | if (A == 1) 59 | { 60 | if (G == 0) y = x; 61 | if (G == 1) y = x; 62 | } 63 | 64 | // relu 65 | if (A == 2) 66 | { 67 | if (G == 0) y = (x > 0) ? x : 0; 68 | if (G == 1) y = (yy > 0) ? x : 0; 69 | } 70 | 71 | // lrelu 72 | if (A == 3) 73 | { 74 | if (G == 0) y = (x > 0) ? x : x * alpha; 75 | if (G == 1) y = (yy > 0) ? x : x * alpha; 76 | } 77 | 78 | // tanh 79 | if (A == 4) 80 | { 81 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 82 | if (G == 1) y = x * (one - yy * yy); 83 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 84 | } 85 | 86 | // sigmoid 87 | if (A == 5) 88 | { 89 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 90 | if (G == 1) y = x * yy * (one - yy); 91 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 92 | } 93 | 94 | // elu 95 | if (A == 6) 96 | { 97 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 98 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 99 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 100 | } 101 | 102 | // selu 103 | if (A == 7) 104 | { 105 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 106 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 107 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 108 | } 109 | 110 | // softplus 111 | if (A == 8) 112 | { 113 | if (G == 0) y = (x > expRange) ? 
x : log(exp(x) + one); 114 | if (G == 1) y = x * (one - exp(-yy)); 115 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 116 | } 117 | 118 | // swish 119 | if (A == 9) 120 | { 121 | if (G == 0) 122 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 123 | else 124 | { 125 | scalar_t c = exp(xref); 126 | scalar_t d = c + one; 127 | if (G == 1) 128 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 129 | else 130 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 131 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 132 | } 133 | } 134 | 135 | // Apply gain. 136 | y *= gain * dy; 137 | 138 | // Clamp. 139 | if (clamp >= 0) 140 | { 141 | if (G == 0) 142 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 143 | else 144 | y = (yref > -clamp & yref < clamp) ? y : 0; 145 | } 146 | 147 | // Store. 148 | ((T*)p.y)[xi] = (T)y; 149 | } 150 | } 151 | 152 | //------------------------------------------------------------------------ 153 | // CUDA kernel selection. 154 | 155 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 156 | { 157 | if (p.act == 1) return (void*)bias_act_kernel; 158 | if (p.act == 2) return (void*)bias_act_kernel; 159 | if (p.act == 3) return (void*)bias_act_kernel; 160 | if (p.act == 4) return (void*)bias_act_kernel; 161 | if (p.act == 5) return (void*)bias_act_kernel; 162 | if (p.act == 6) return (void*)bias_act_kernel; 163 | if (p.act == 7) return (void*)bias_act_kernel; 164 | if (p.act == 8) return (void*)bias_act_kernel; 165 | if (p.act == 9) return (void*)bias_act_kernel; 166 | return NULL; 167 | } 168 | 169 | //------------------------------------------------------------------------ 170 | // Template specializations. 171 | 172 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 173 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 174 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 175 | 176 | //------------------------------------------------------------------------ 177 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | from pkg_resources import parse_version 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
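# (Editor's illustrative sketch, not part of the original file: typical usage of this module,
# assuming the caller opts in explicitly since `enabled` defaults to False.
#
#     from torch_utils.ops import conv2d_gradfix
#     conv2d_gradfix.enabled = True                           # route conv2d() through the custom op
#     y = conv2d_gradfix.conv2d(x, w, padding=1)              # drop-in replacement for F.conv2d
#     with conv2d_gradfix.no_weight_gradients():              # e.g. while backpropagating an R1-style penalty
#         (grad,) = torch.autograd.grad(y.sum(), [x], create_graph=True)
# )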
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | @contextlib.contextmanager 27 | def no_weight_gradients(disable=True): 28 | global weight_gradients_disabled 29 | old = weight_gradients_disabled 30 | if disable: 31 | weight_gradients_disabled = True 32 | yield 33 | weight_gradients_disabled = old 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 38 | if _should_use_custom_op(input): 39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 41 | 42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 43 | if _should_use_custom_op(input): 44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _should_use_custom_op(input): 50 | assert isinstance(input, torch.Tensor) 51 | if (not enabled) or (not torch.backends.cudnn.enabled): 52 | return False 53 | if _use_pytorch_1_11_api: 54 | # The work-around code doesn't work on PyTorch 1.11.0 onwards 55 | return False 56 | if input.device.type != 'cuda': 57 | return False 58 | return True 59 | 60 | def _tuple_of_ints(xs, ndim): 61 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 62 | assert len(xs) == ndim 63 | assert all(isinstance(x, int) for x in xs) 64 | return xs 65 | 66 | #---------------------------------------------------------------------------- 67 | 68 | _conv2d_gradfix_cache = dict() 69 | _null_tensor = torch.empty([0]) 70 | 71 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 72 | # Parse arguments. 73 | ndim = 2 74 | weight_shape = tuple(weight_shape) 75 | stride = _tuple_of_ints(stride, ndim) 76 | padding = _tuple_of_ints(padding, ndim) 77 | output_padding = _tuple_of_ints(output_padding, ndim) 78 | dilation = _tuple_of_ints(dilation, ndim) 79 | 80 | # Lookup from cache. 81 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 82 | if key in _conv2d_gradfix_cache: 83 | return _conv2d_gradfix_cache[key] 84 | 85 | # Validate arguments. 86 | assert groups >= 1 87 | assert len(weight_shape) == ndim + 2 88 | assert all(stride[i] >= 1 for i in range(ndim)) 89 | assert all(padding[i] >= 0 for i in range(ndim)) 90 | assert all(dilation[i] >= 0 for i in range(ndim)) 91 | if not transpose: 92 | assert all(output_padding[i] == 0 for i in range(ndim)) 93 | else: # transpose 94 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 95 | 96 | # Helpers. 
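# (Editor's note: calc_output_padding() below derives the output_padding that makes the
# transposed convolution used for grad_input reproduce the exact spatial size of the original
# input; this is needed because several input sizes can map to the same convolution output size.)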
97 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 98 | def calc_output_padding(input_shape, output_shape): 99 | if transpose: 100 | return [0, 0] 101 | return [ 102 | input_shape[i + 2] 103 | - (output_shape[i + 2] - 1) * stride[i] 104 | - (1 - 2 * padding[i]) 105 | - dilation[i] * (weight_shape[i + 2] - 1) 106 | for i in range(ndim) 107 | ] 108 | 109 | # Forward & backward. 110 | class Conv2d(torch.autograd.Function): 111 | @staticmethod 112 | def forward(ctx, input, weight, bias): 113 | assert weight.shape == weight_shape 114 | ctx.save_for_backward( 115 | input if weight.requires_grad else _null_tensor, 116 | weight if input.requires_grad else _null_tensor, 117 | ) 118 | ctx.input_shape = input.shape 119 | 120 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 121 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 122 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 123 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 124 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 125 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 126 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 127 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 128 | 129 | # General case => cuDNN. 130 | if transpose: 131 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 132 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | input, weight = ctx.saved_tensors 137 | input_shape = ctx.input_shape 138 | grad_input = None 139 | grad_weight = None 140 | grad_bias = None 141 | 142 | if ctx.needs_input_grad[0]: 143 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 144 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 145 | grad_input = op.apply(grad_output, weight, None) 146 | assert grad_input.shape == input_shape 147 | 148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 149 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 150 | assert grad_weight.shape == weight_shape 151 | 152 | if ctx.needs_input_grad[2]: 153 | grad_bias = grad_output.sum([0, 2, 3]) 154 | 155 | return grad_input, grad_weight, grad_bias 156 | 157 | # Gradient with respect to the weights. 158 | class Conv2dGradWeight(torch.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, grad_output, input): 161 | ctx.save_for_backward( 162 | grad_output if input.requires_grad else _null_tensor, 163 | input if grad_output.requires_grad else _null_tensor, 164 | ) 165 | ctx.grad_output_shape = grad_output.shape 166 | ctx.input_shape = input.shape 167 | 168 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 
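# (Editor's note: for a 1x1 kernel with unit stride/dilation and zero padding, the weight
# gradient reduces to a per-group matrix product of the flattened grad_output and input,
# so it is computed below with a plain matmul instead of the cuDNN backward-weight op.)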
169 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 170 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 171 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 172 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 173 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 174 | 175 | # General case => cuDNN. 176 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 177 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 178 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 179 | 180 | @staticmethod 181 | def backward(ctx, grad2_grad_weight): 182 | grad_output, input = ctx.saved_tensors 183 | grad_output_shape = ctx.grad_output_shape 184 | input_shape = ctx.input_shape 185 | grad2_grad_output = None 186 | grad2_input = None 187 | 188 | if ctx.needs_input_grad[0]: 189 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 190 | assert grad2_grad_output.shape == grad_output_shape 191 | 192 | if ctx.needs_input_grad[1]: 193 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 194 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 195 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 196 | assert grad2_input.shape == input_shape 197 | 198 | return grad2_grad_output, grad2_input 199 | 200 | _conv2d_gradfix_cache[key] = Conv2d 201 | return Conv2d 202 | 203 | #---------------------------------------------------------------------------- 204 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 
35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 
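# (Editor's note: in this branch the low-pass filter is applied first via upfirdn2d() using the
# adjusted padding, and the downsampling factor is then folded into the stride of the
# following convolution instead of being performed as a separate resampling step.)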
106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 | void* x; // Input/output, modified in-place.
58 | unsigned char* s; // Sign tensor in/out. NULL if unused.
59 |
60 | float gain; // Additional gain factor.
61 | float slope; // Leaky ReLU slope on negative side.
62 | float clamp; // Clamp after nonlinearity.
63 |
64 | int4 xShape; // [width, height, channel, batch]
65 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
73 | struct filtered_lrelu_kernel_spec
74 | {
75 | void* setup; // Function for filter kernel setup.
76 | void* exec; // Function for main operation.
77 | int2 tileOut; // Width/height of launch tile.
78 | int numWarps; // Number of warps per thread block, determines launch block size.
79 | int xrep; // For processing multiple horizontal tiles per thread block.
80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81 | };
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
-------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
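// (Editor's note: the explicit instantiations below, together with the parallel
// filtered_lrelu_rd.cu and filtered_lrelu_wr.cu files, cover the three sign-buffer modes
// (no signs / sign read / sign write); splitting them into separate translation units is a
// common way to keep per-file compile times manageable.)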
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
28 |
-------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 |
-------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 |
-------------------------------------------------------------------------------- /torch_utils/ops/fma.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
-------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/pybind.cpp: 
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | //
9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | //
11 |
12 | #include <torch/extension.h>
13 | #include "bias_act.h"
14 | #include "filtered_lrelu.h"
15 | #include "upfirdn2d.h"
16 |
17 | torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp);
18 | torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns);
19 | std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
20 | torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
21 | int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns);
22 | torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain);
23 |
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
25 | {
26 | m.def("bias_act", &bias_act);
27 | m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
28 | m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
29 | m.def("upfirdn2d", &upfirdn2d);
30 | }
31 |
-------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | //
9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | //
11 |
12 | #include <torch/extension.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <c10/cuda/CUDAGuard.h>
15 | #include "upfirdn2d.h"
16 |
17 | //------------------------------------------------------------------------
18 |
19 | torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
20 | {
21 | // Validate arguments.
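// (Editor's note: besides the rank/dtype/device checks, the asserts below also bound each
// tensor's element count and memory footprint by INT_MAX so that the kernel can safely use
// 32-bit indexing arithmetic.)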
22 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 23 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 24 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 25 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 26 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 27 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 28 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 29 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 30 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 31 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 32 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 33 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 34 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 35 | 36 | // Create output tensor. 37 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 38 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 39 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 40 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 41 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 42 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 43 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 44 | 45 | // Initialize CUDA kernel parameters. 46 | upfirdn2d_kernel_params p; 47 | p.x = x.data_ptr(); 48 | p.f = f.data_ptr(); 49 | p.y = y.data_ptr(); 50 | p.up = make_int2(upx, upy); 51 | p.down = make_int2(downx, downy); 52 | p.pad0 = make_int2(padx0, pady0); 53 | p.flip = (flip) ? 1 : 0; 54 | p.gain = gain; 55 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 56 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 57 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 58 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 59 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 60 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 61 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 62 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 63 | 64 | // Choose CUDA kernel. 65 | upfirdn2d_kernel_spec spec; 66 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 67 | { 68 | spec = choose_upfirdn2d_kernel(p); 69 | }); 70 | 71 | // Set looping options. 72 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 73 | p.loopMinor = spec.loopMinor; 74 | p.loopX = spec.loopX; 75 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 76 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 77 | 78 | // Compute grid size. 
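// (Editor's note: the specialized "small" kernels process fixed output tiles with 256-thread
// blocks, while the generic "large" fallback (tileOutW < 0) uses a 4x32 block that loops over
// the output; in both cases the batch/channel extent goes into the third grid dimension via
// launchMajor.)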
79 | dim3 blockSize, gridSize; 80 | if (spec.tileOutW < 0) // large 81 | { 82 | blockSize = dim3(4, 32, 1); 83 | gridSize = dim3( 84 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 85 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 86 | p.launchMajor); 87 | } 88 | else // small 89 | { 90 | blockSize = dim3(256, 1, 1); 91 | gridSize = dim3( 92 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 93 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 94 | p.launchMajor); 95 | } 96 | 97 | // Launch CUDA kernel. 98 | void* args[] = {&p}; 99 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 100 | return y; 101 | } 102 | 103 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. 
This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | import os 4 | 5 | extension = CUDAExtension('stylegan3_cuda', 6 | sources=[os.path.join('./ops', x) for x in ['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 7 | 'filtered_lrelu_ns.cu', 'bias_act.cpp', 'bias_act_kernel.cu', 'pybind.cpp', 'upfirdn2d.cpp', 'upfirdn2d_kernel.cu']], 8 | headers=[os.path.join('./ops', x) for x in ['filtered_lrelu.h', 'filtered_lrelu.cu', 'bias_act.h', 'upfirdn2d.h']], 9 | extra_compile_args = {'cxx': [], 'nvcc': ['--use_fast_math']} 10 | ) 11 | 12 | setup( 13 | name='stylegan3_cuda', 14 | ext_modules=[extension], 15 | cmdclass={ 16 | 'build_ext': BuildExtension 17 | }) 18 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/networks_voxgraf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | # 9 | # Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids 10 | # 11 | 12 | import torch 13 | from torch_utils import persistence 14 | from training.networks_stylegan2 import SynthesisNetwork as SG2_syn, MappingNetwork as SingleMappingNetwork 15 | from training.networks_stylegan3 import SynthesisLayer as SG3_SynthesisLayer 16 | from training.networks_stylegan2_3d import SynthesisNetwork as Synthesis3D 17 | from training.virtual_camera_utils import PoseSampler, Renderer 18 | 19 | @persistence.persistent_class 20 | class MappingNetwork(torch.nn.Module): 21 | """Wrapper class for fg and bg mapping network.""" 22 | def __init__(self, 23 | z_dim, # Input latent (Z) dimensionality. 24 | c_dim, # Conditioning label (C) dimensionality, 0 = no labels. 25 | w_dim, # Intermediate latent (W) dimensionality. 26 | num_ws_fg, # Number of intermediate latents to output for foreground. 27 | num_ws_bg, # Number of intermediate latents to output for background. 
28 | use_bg=True, # Model background with 2D GAN 29 | **mapping_kwargs 30 | ): 31 | super().__init__() 32 | self.fg = SingleMappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=num_ws_fg, **mapping_kwargs) 33 | self.use_bg = use_bg 34 | if use_bg: 35 | self.bg = SingleMappingNetwork(z_dim=z_dim, c_dim=0, w_dim=w_dim, num_ws=num_ws_bg, **mapping_kwargs) # bg is not conditioned on pose 36 | 37 | def forward(self, z_fg, z_bg, c_fg, c_bg, **kwargs): 38 | w_fg = self.fg(z_fg, c_fg, **kwargs) 39 | w_bg = None 40 | if self.use_bg: 41 | w_bg = self.bg(z_bg, c_bg, **kwargs) 42 | return w_fg, w_bg 43 | 44 | 45 | @persistence.persistent_class 46 | class RefinementNetwork(torch.nn.Module): 47 | def __init__(self, w_dim, input_resolution, input_channels, output_channels, dhidden, num_layers): 48 | super().__init__() 49 | assert num_layers > 0 50 | last_cutoff = input_resolution / 2 51 | refinement = [] 52 | for i in range(num_layers): 53 | is_last = (i == num_layers-1) 54 | 55 | refinement.append(SG3_SynthesisLayer( 56 | w_dim=w_dim, is_torgb=is_last, is_critically_sampled=True, use_fp16=False, 57 | in_channels=input_channels if i==0 else dhidden, 58 | out_channels=dhidden if not is_last else output_channels, 59 | in_size=input_resolution, out_size=input_resolution, 60 | in_sampling_rate=input_resolution, out_sampling_rate=input_resolution, 61 | in_cutoff=last_cutoff, out_cutoff=last_cutoff, 62 | in_half_width=0, out_half_width=0 63 | )) 64 | 65 | self.layers = torch.nn.ModuleList(refinement) 66 | 67 | def forward(self, x, w, update_emas=False): 68 | for layer in self.layers: 69 | x = layer(x, w, update_emas=update_emas) 70 | return x 71 | 72 | 73 | @persistence.persistent_class 74 | class SynthesisNetwork(torch.nn.Module): 75 | def __init__(self, 76 | w_dim, # Intermediate latent (W) dimensionality. 77 | img_resolution, # Output image resolution. 78 | img_channels, # Number of color channels. 79 | use_bg=True, # Model background with 2D GAN 80 | sigma_mpl = 1, # Rescaling factor for density 81 | pose_kwargs={}, 82 | render_kwargs={}, 83 | fg_kwargs={}, 84 | bg_kwargs={}, 85 | refinement_kwargs={}, 86 | ): 87 | super().__init__() 88 | if img_channels != 3: 89 | raise NotImplementedError 90 | 91 | self.img_resolution = img_resolution 92 | self.use_bg = use_bg 93 | self.use_refinement = refinement_kwargs.get('num_layers', 0) > 0 94 | self.sigma_mpl = sigma_mpl 95 | 96 | self.pose_sampler = PoseSampler(**pose_kwargs) 97 | self.renderer = Renderer(**render_kwargs) 98 | 99 | self.generator_fg = Synthesis3D(w_dim=w_dim, img_channels=img_channels*self.renderer.basis_dim+1, architecture='skip', renderer=self.renderer, **fg_kwargs) 100 | if self.use_bg: 101 | self.generator_bg = SG2_syn(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **bg_kwargs) 102 | 103 | if self.use_refinement: 104 | self.__setattr__(f'r{self.img_resolution}', RefinementNetwork(w_dim, self.img_resolution, input_channels=3, output_channels=img_channels, **refinement_kwargs)) 105 | 106 | def __repr__(self): 107 | return ( 108 | f"SynthesisNetwork(img_resolution={self.img_resolution}, use_bg={self.use_bg}, use_refinement={self.use_refinement})" 109 | ) 110 | 111 | def _get_sparse_grad_indexer(self, device): 112 | indexer = None#self.sparse_grad_indexer # TODO: can we reuse it? 
113 | if indexer is None: 114 | indexer = torch.empty((0,), dtype=torch.bool, device=device) 115 | return indexer 116 | 117 | def _get_sparse_sh_grad_indexer(self, device): 118 | indexer = None# self.sparse_sh_grad_indexer # TODO: can we reuse it? 119 | if indexer is None: 120 | indexer = torch.empty((0,), dtype=torch.bool, device=device) 121 | return indexer 122 | 123 | def forward(self, ws, ws_bg, pose=None, noise_mode='none', update_emas=None, return_3d=False, n_views=1, render_alpha=False, render_depth=False, render_vardepth=False, raw_noise_std_sigma=0): 124 | B = ws.shape[0] 125 | 126 | # Get camera poses 127 | if pose is None or torch.isnan(pose).all(): 128 | pose = torch.stack([self.pose_sampler.sample_from_poses() for _ in range(B*n_views)]) 129 | 130 | # Generate sparse voxel grid 131 | density_data_list, sh_data_list, coords = self.generator_fg(ws, pose=pose, noise_mode=noise_mode, update_emas=update_emas) 132 | 133 | if raw_noise_std_sigma > 0: 134 | for i in range(B): 135 | density_data_list[i] = density_data_list[i] + torch.randn(density_data_list[i].shape, device=ws.device) * raw_noise_std_sigma 136 | if self.sigma_mpl != 1: 137 | for i in range(B): 138 | density_data_list[i] = density_data_list[i] * self.sigma_mpl 139 | 140 | # Render foreground 141 | pred = self.renderer(pose, density_data_list, sh_data_list, coords, img_resolution=self.img_resolution, grid_resolution=self.generator_fg.grid_resolution, render_alpha=(render_alpha or self.use_bg), render_depth=render_depth, render_vardepth=render_vardepth) 142 | pred = pred.view(B*n_views, self.img_resolution, self.img_resolution, 6).permute(0, 3, 1, 2) # (BxHxW)xC -> # BxCxHxW 143 | rgb, alpha, depth, vardepth = pred[:, :3], pred[:, 3:4], pred[:, 4:5], pred[:, 5:6] 144 | 145 | if self.use_bg: 146 | rgb_bg = self.generator_bg(ws_bg, update_emas=update_emas, noise_mode=noise_mode) 147 | 148 | # alpha compositing 149 | rgb_bg = rgb_bg / 2 + 0.5 # [-1, 1] -> [0, 1] 150 | assert alpha.min() >= 0 and alpha.max() <= 1+1e-3 # add some offset due to precision in alpha computation 151 | rgb = alpha * rgb + (1 - alpha) * rgb_bg 152 | 153 | # [0,1] -> [-1, 1] 154 | rgb = rgb * 2 - 1 155 | 156 | # Refine composited image 157 | if self.use_refinement: 158 | w_last = ws[:, -1] # simply reuse last w for refinement layers 159 | rgb = self.__getattr__(f'r{self.img_resolution}')(rgb, w_last, update_emas) 160 | 161 | out = {'rgb': rgb} 162 | if render_alpha: 163 | assert alpha.min() >= 0 and alpha.max() <= 1+1e-3 # add some offset due to precision in alpha computation 164 | out['alpha'] = alpha 165 | if render_depth: 166 | out['depth'] = depth 167 | if render_vardepth: 168 | out['vardepth'] = vardepth 169 | if return_3d: 170 | out['density'] = density_data_list 171 | out['sh'] = sh_data_list 172 | out['coords'] = coords 173 | return out 174 | 175 | 176 | @persistence.persistent_class 177 | class Generator(torch.nn.Module): 178 | def __init__(self, 179 | z_dim, # Input latent (Z) dimensionality. 180 | c_dim, # Conditioning label (C) dimensionality. 181 | w_dim, # Intermediate latent (W) dimensionality. 182 | img_resolution, # Output resolution. 183 | img_channels, # Number of output color channels. 184 | use_bg = True, # Model background with 2D GAN 185 | pose_conditioning = True, # Condition generator on pose 186 | mapping_kwargs = {}, # Arguments for MappingNetwork. 187 | **synthesis_kwargs, # Arguments for SynthesisNetwork. 
188 | ): 189 | super().__init__() 190 | self.z_dim = z_dim 191 | self.c_dim = c_dim 192 | self.w_dim = w_dim 193 | self.img_resolution = img_resolution 194 | self.img_channels = img_channels 195 | self.pose_conditioning = pose_conditioning 196 | self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, use_bg=use_bg, **synthesis_kwargs) 197 | self.num_ws = self.synthesis.generator_fg.num_ws 198 | self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws_fg=self.synthesis.generator_fg.num_ws, num_ws_bg=0 if not use_bg else self.synthesis.generator_bg.num_ws, use_bg=use_bg, **mapping_kwargs) 199 | 200 | def forward(self, z, c, pose, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): 201 | ws, ws_bg = self.mapping(z, z, c, torch.empty_like(c[:, :0]), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) 202 | img = self.synthesis(ws, update_emas=update_emas, ws_bg=ws_bg, pose=pose, **synthesis_kwargs) 203 | return img 204 | 205 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, the Plenoxels authors 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | 28 | This variant of plenoxels was modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids. 29 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/README.md: -------------------------------------------------------------------------------- 1 | We adapted the CUDA kernels from [Plenoxels: Radiance Fields without Neural Networks](https://arxiv.org/abs/2112.05131). 2 | Try installing pre-compiled CUDA extension: 3 | ```commandline 4 | pip install ../dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl 5 | ``` 6 | Or install from this directory by running 7 | ```commandline 8 | pip install . 9 | ``` 10 | For more details please refer to [the Plenoxels github repository](https://github.com/sxyu/svox2). 
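If the install succeeded, the extension should be importable. A minimal sanity check (the package is named `svox2` and exposes `__version__`, so this simply prints the installed version):
```commandline
python -c "import svox2; print(svox2.__version__)"
```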
11 | 12 | Modifications include:
13 | * rendering alpha 14 | * rendering depth 15 | * rendering variance of depth 16 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/environment.yml: -------------------------------------------------------------------------------- 1 | # run: conda env create -f environment.yml 2 | name: plenoxel 3 | channels: 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python=3.8.8 8 | - numpy>=1.16.4,<1.19.0 9 | - pip 10 | - pip: 11 | - imageio 12 | - imageio-ffmpeg 13 | - ipdb 14 | - lpips 15 | - opencv-python>=4.4.0 16 | - Pillow>=7.2.0 17 | - pyyaml>=5.3.1 18 | - tensorboard>=2.4.0 19 | - imageio 20 | - imageio-ffmpeg 21 | - pymcubes 22 | - moviepy 23 | - matplotlib 24 | - scipy>=1.6.0 25 | - pytorch 26 | - torchvision 27 | - cudatoolkit 28 | - tqdm 29 | 30 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/manual_install.sh: -------------------------------------------------------------------------------- 1 | cp svox2/svox2.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/svox2.py 2 | cp svox2/utils.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/utils.py 3 | cp svox2/version.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/version.py 4 | cp svox2/defs.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/defs.py 5 | cp svox2/__init__.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/__init__.py 6 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import os 3 | import os.path as osp 4 | import warnings 5 | 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 9 | 10 | __version__ = None 11 | exec(open('svox2/version.py', 'r').read()) 12 | 13 | CUDA_FLAGS = [] 14 | INSTALL_REQUIREMENTS = [] 15 | include_dirs = [osp.join(ROOT_DIR, "svox2", "csrc", "include")] 16 | 17 | # From PyTorch3D 18 | cub_home = os.environ.get("CUB_HOME", None) 19 | if cub_home is None: 20 | prefix = os.environ.get("CONDA_PREFIX", None) 21 | if prefix is not None and os.path.isdir(prefix + "/include/cub"): 22 | cub_home = prefix + "/include" 23 | 24 | if cub_home is None: 25 | warnings.warn( 26 | "The environment variable `CUB_HOME` was not found." 27 | "Installation will fail if your system CUDA toolkit version is less than 11." 28 | "NVIDIA CUB can be downloaded " 29 | "from `https://github.com/NVIDIA/cub/releases`. You can unpack " 30 | "it to a location of your choice and set the environment variable " 31 | "`CUB_HOME` to the folder containing the `CMakeListst.txt` file." 
32 | ) 33 | else: 34 | include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " ")) 35 | 36 | try: 37 | ext_modules = [ 38 | CUDAExtension('svox2.csrc', [ 39 | 'svox2/csrc/svox2.cpp', 40 | 'svox2/csrc/svox2_kernel.cu', 41 | 'svox2/csrc/render_lerp_kernel_cuvol.cu', 42 | 'svox2/csrc/render_lerp_kernel_nvol.cu', 43 | 'svox2/csrc/render_svox1_kernel.cu', 44 | 'svox2/csrc/misc_kernel.cu', 45 | 'svox2/csrc/loss_kernel.cu', 46 | 'svox2/csrc/optim_kernel.cu', 47 | ], include_dirs=include_dirs, 48 | optional=False), 49 | ] 50 | except: 51 | import warnings 52 | warnings.warn("Failed to build CUDA extension") 53 | ext_modules = [] 54 | 55 | setup( 56 | name='svox2', 57 | version=__version__, 58 | author='Alex Yu / Modified by Katja Schwarz', 59 | author_email='alexyu99126@gmail.com / katja.schwarz@uni-tuebingen.de', 60 | description='PyTorch sparse voxel volume extension, including custom CUDA kernels, adapted some operations to VoxGRAF', 61 | long_description='PyTorch sparse voxel volume extension, including custom CUDA kernels, adapted some operations to VoxGRAF', 62 | ext_modules=ext_modules, 63 | setup_requires=['pybind11>=2.5.0'], 64 | packages=['svox2', 'svox2.csrc'], 65 | cmdclass={'build_ext': BuildExtension}, 66 | zip_safe=False, 67 | ) 68 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/__init__.py: -------------------------------------------------------------------------------- 1 | from .defs import * 2 | from .svox2 import SparseGrid, Camera, Rays, RenderOptions 3 | from .version import __version__ 4 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/.ccls: -------------------------------------------------------------------------------- 1 | %compile_commands.json 2 | %cu -x cuda 3 | %cu --cuda-gpu-arch=sm_61 4 | %cu --cuda-path=/usr/local/cuda-11.2 5 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2021 PlenOctree Authors. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions are met: 5 | # 6 | # 1. Redistributions of source code must retain the above copyright notice, 7 | # this list of conditions and the following disclaimer. 8 | # 9 | # 2. Redistributions in binary form must reproduce the above copyright notice, 10 | # this list of conditions and the following disclaimer in the documentation 11 | # and/or other materials provided with the distribution. 12 | # 13 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 14 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 15 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 16 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 17 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 18 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 19 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 20 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 21 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 22 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 23 | # POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | # NOTE: This CMakeLists is for development purposes only 26 | # (To check CUDA compile errors) 27 | # It is NOT necessary to use this for installation. Just use pip install . 28 | cmake_minimum_required( VERSION 3.3 ) 29 | 30 | if(NOT CMAKE_BUILD_TYPE) 31 | set(CMAKE_BUILD_TYPE Release) 32 | endif() 33 | if (POLICY CMP0048) 34 | cmake_policy(SET CMP0048 NEW) 35 | endif (POLICY CMP0048) 36 | if (POLICY CMP0069) 37 | cmake_policy(SET CMP0069 NEW) 38 | endif (POLICY CMP0069) 39 | if (POLICY CMP0072) 40 | cmake_policy(SET CMP0072 NEW) 41 | endif (POLICY CMP0072) 42 | 43 | project( svox2 ) 44 | 45 | set(CMAKE_CXX_STANDARD 14) 46 | enable_language(CUDA) 47 | message(STATUS "CUDA enabled") 48 | set( CMAKE_CUDA_STANDARD 14 ) 49 | set( CMAKE_CUDA_STANDARD_REQUIRED ON) 50 | set( CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -Xcudafe \"--display_error_number --diag_suppress=3057 --diag_suppress=3058 --diag_suppress=3059 --diag_suppress=3060\" -lineinfo -arch=sm_75 ") 51 | # -Xptxas=\"-v\" 52 | 53 | set( INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" ) 54 | 55 | if( MSVC ) 56 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd") 57 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /GLT /Ox") 58 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler=\"/MT\"" ) 59 | endif() 60 | 61 | file(GLOB SOURCES 62 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp 63 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cu) 64 | 65 | find_package(pybind11 REQUIRED) 66 | find_package(Torch REQUIRED) 67 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 68 | 69 | include_directories (${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) 70 | 71 | pybind11_add_module(svox2-test SHARED ${SOURCES}) 72 | target_link_libraries(svox2-test PRIVATE "${TORCH_LIBRARIES}") 73 | target_include_directories(svox2-test PRIVATE "${INCLUDE_DIR}") 74 | 75 | if (MSVC) 76 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") 77 | add_custom_command(TARGET svox2-test 78 | POST_BUILD 79 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 80 | ${TORCH_DLLS} 81 | $) 82 | endif (MSVC) 83 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/include/cubemap_util.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cuda_util.cuh" 3 | #include 4 | #include 5 | 6 | #define _AXIS(x) (x>>1) 7 | #define _ORI(x) (x&1) 8 | #define _FACE(axis, ori) uint8_t((axis << 1) | ori) 9 | 10 | namespace { 11 | namespace device { 12 | 13 | struct CubemapCoord { 14 | uint8_t face; 15 | float uv[2]; 16 | }; 17 | 18 | struct CubemapLocation { 19 | uint8_t face; 20 | int16_t uv[2]; 21 | }; 22 | 23 | struct CubemapBilerpQuery { 24 | CubemapLocation ptr[2][2]; 25 | float duv[2]; 26 | }; 27 | 28 | __device__ __inline__ void 29 | invert_cubemap(int u, int v, float r, 30 | int reso, 31 | float* __restrict__ out) { 32 | const float u_norm = (u + 0.5f) / reso * 2 - 1; 33 | const float v_norm = (v + 0.5f) / reso * 2 - 1; 34 | // EAC 35 | const float tx = tanf((M_PI / 4) * u_norm); 36 | const float ty = tanf((M_PI / 4) * v_norm); 37 | const float common = r * rnorm3df(1.f, tx, ty); 38 | out[0] = tx * common; 39 | out[1] = ty * common; 40 | out[2] = common; 41 | } 42 | 43 | __device__ __inline__ void 44 | invert_cubemap_traditional(int u, int v, float r, 45 | int reso, 46 | float* __restrict__ out) { 47 | const float u_norm = (u + 0.5f) / reso * 2 - 1; 48 | const float v_norm = (v + 0.5f) / reso * 2 - 1; 49 | const float common = r * 
rnorm3df(1.f, u_norm, v_norm); 50 | out[0] = u_norm * common; 51 | out[1] = v_norm * common; 52 | out[2] = common; 53 | } 54 | 55 | __device__ __host__ __inline__ CubemapCoord 56 | dir_to_cubemap_coord(const float* __restrict__ xyz_o, 57 | int face_reso, 58 | bool eac = true) { 59 | float maxv; 60 | int ax; 61 | float xyz[3] = {xyz_o[0], xyz_o[1], xyz_o[2]}; 62 | if (fabsf(xyz[0]) >= fabsf(xyz[1]) && fabsf(xyz[0]) >= fabsf(xyz[2])) { 63 | ax = 0; maxv = xyz[0]; 64 | } else if (fabsf(xyz[1]) >= fabsf(xyz[2])) { 65 | ax = 1; maxv = xyz[1]; 66 | } else { 67 | ax = 2; maxv = xyz[2]; 68 | } 69 | const float recip = 1.f / fabsf(maxv); 70 | xyz[0] *= recip; 71 | xyz[1] *= recip; 72 | xyz[2] *= recip; 73 | 74 | if (eac) { 75 | #pragma unroll 3 76 | for (int i = 0; i < 3; ++i) { 77 | xyz[i] = atanf(xyz[i]) * (4 * M_1_PI); 78 | } 79 | } 80 | 81 | CubemapCoord idx; 82 | idx.uv[0] = ((xyz[(ax ^ 1) & 1] + 1) * face_reso - 1) * 0.5; 83 | idx.uv[1] = ((xyz[(ax ^ 2) & 2] + 1) * face_reso - 1) * 0.5; 84 | const int ori = xyz[ax] >= 0; 85 | idx.face = _FACE(ax, ori); 86 | 87 | return idx; 88 | } 89 | 90 | __device__ __host__ __inline__ CubemapBilerpQuery 91 | cubemap_build_query( 92 | const CubemapCoord& idx, 93 | int face_reso) { 94 | const int uv_idx[2] ={ (int)floorf(idx.uv[0]), (int)floorf(idx.uv[1]) }; 95 | 96 | bool m[2][2]; 97 | m[0][0] = uv_idx[0] < 0; 98 | m[0][1] = uv_idx[0] > face_reso - 2; 99 | m[1][0] = uv_idx[1] < 0; 100 | m[1][1] = uv_idx[1] > face_reso - 2; 101 | 102 | const int face = idx.face; 103 | const int ax = _AXIS(face); 104 | const int ori = _ORI(face); 105 | // if ax is one of {0, 1, 2}, this trick gets the 2 106 | // of {0, 1, 2} other than ax 107 | const int uvd[2] = {((ax ^ 1) & 1), ((ax ^ 2) & 2)}; 108 | int uv_ori[2]; 109 | 110 | CubemapBilerpQuery result; 111 | result.duv[0] = idx.uv[0] - uv_idx[0]; 112 | result.duv[1] = idx.uv[1] - uv_idx[1]; 113 | 114 | #pragma unroll 2 115 | for (uv_ori[0] = 0; uv_ori[0] < 2; ++uv_ori[0]) { 116 | #pragma unroll 2 117 | for (uv_ori[1] = 0; uv_ori[1] < 2; ++uv_ori[1]) { 118 | CubemapLocation& nidx = result.ptr[uv_ori[0]][uv_ori[1]]; 119 | nidx.face = face; 120 | nidx.uv[0] = uv_idx[0] + uv_ori[0]; 121 | nidx.uv[1] = uv_idx[1] + uv_ori[1]; 122 | 123 | const bool mu = m[0][uv_ori[0]]; 124 | const bool mv = m[1][uv_ori[1]]; 125 | 126 | int edge_idx = -1; 127 | if (mu) { 128 | // Crosses edge in u-axis 129 | if (mv) { 130 | // FIXME: deal with corners properly, right now 131 | // just clamps, resulting in a little artifact 132 | // at each cube corner 133 | nidx.uv[0] = min(max(nidx.uv[0], 0), face_reso - 1); 134 | nidx.uv[1] = min(max(nidx.uv[1], 0), face_reso - 1); 135 | } else { 136 | edge_idx = 0; 137 | } 138 | } else if (mv) { 139 | // Crosses edge in v-axis 140 | edge_idx = 1; 141 | } 142 | if (~edge_idx) { 143 | const int nax = uvd[edge_idx]; 144 | const int16_t other_coord = nidx.uv[1 - edge_idx]; 145 | 146 | // Determine directions in the new face 147 | const int nud = (nax ^ 1) & 1; 148 | // const int nvd = (nax ^ 2) & 2; 149 | 150 | if (nud == ax) { 151 | nidx.uv[0] = ori ? (face_reso - 1) : 0; 152 | nidx.uv[1] = other_coord; 153 | } else { 154 | nidx.uv[0] = other_coord; 155 | nidx.uv[1] = ori ? 
(face_reso - 1) : 0; 156 | } 157 | 158 | nidx.face = _FACE(nax, uv_ori[edge_idx]); 159 | } 160 | // Interior point: nothing needs to be done 161 | 162 | } 163 | } 164 | 165 | return result; 166 | } 167 | 168 | __device__ __host__ __inline__ float 169 | cubemap_sample( 170 | const float* __restrict__ cubemap, // (6, face_reso, face_reso, n_channels) 171 | const CubemapBilerpQuery& query, 172 | int face_reso, 173 | int n_channels, 174 | int chnl_id) { 175 | 176 | // NOTE: assuming address will fit in int32 177 | const int stride1 = face_reso * n_channels; 178 | const int stride0 = face_reso * stride1; 179 | const CubemapLocation& p00 = query.ptr[0][0]; 180 | const float v00 = cubemap[p00.face * stride0 + p00.uv[0] * stride1 + p00.uv[1] * n_channels + chnl_id]; 181 | const CubemapLocation& p01 = query.ptr[0][1]; 182 | const float v01 = cubemap[p01.face * stride0 + p01.uv[0] * stride1 + p01.uv[1] * n_channels + chnl_id]; 183 | const CubemapLocation& p10 = query.ptr[1][0]; 184 | const float v10 = cubemap[p10.face * stride0 + p10.uv[0] * stride1 + p10.uv[1] * n_channels + chnl_id]; 185 | const CubemapLocation& p11 = query.ptr[1][1]; 186 | const float v11 = cubemap[p11.face * stride0 + p11.uv[0] * stride1 + p11.uv[1] * n_channels + chnl_id]; 187 | 188 | const float val0 = lerp(v00, v01, query.duv[1]); 189 | const float val1 = lerp(v10, v11, query.duv[1]); 190 | 191 | return lerp(val0, val1, query.duv[0]); 192 | } 193 | 194 | __device__ __inline__ void 195 | cubemap_sample_backward( 196 | float* __restrict__ cubemap_grad, // (6, face_reso, face_reso, n_channels) 197 | const CubemapBilerpQuery& query, 198 | int face_reso, 199 | int n_channels, 200 | float grad_out, 201 | int chnl_id, 202 | bool* __restrict__ mask_out = nullptr) { 203 | 204 | // NOTE: assuming address will fit in int32 205 | const float bu = query.duv[0], bv = query.duv[1]; 206 | const float au = 1.f - bu, av = 1.f - bv; 207 | 208 | #define _ADD_CUBEVERT(i, j, val) { \ 209 | const CubemapLocation& p00 = query.ptr[i][j]; \ 210 | const int idx = (p00.face * face_reso + p00.uv[0]) * face_reso + p00.uv[1]; \ 211 | float* __restrict__ v00 = &cubemap_grad[idx * n_channels + chnl_id]; \ 212 | atomicAdd(v00, val); \ 213 | if (mask_out != nullptr) { \ 214 | mask_out[idx] = true; \ 215 | } \ 216 | } 217 | 218 | _ADD_CUBEVERT(0, 0, au * av * grad_out); 219 | _ADD_CUBEVERT(0, 1, au * bv * grad_out); 220 | _ADD_CUBEVERT(1, 0, bu * av * grad_out); 221 | _ADD_CUBEVERT(1, 1, bu * bv * grad_out); 222 | #undef _ADD_CUBEVERT 223 | 224 | } 225 | 226 | __device__ __host__ __inline__ float 227 | multi_cubemap_sample( 228 | const float* __restrict__ cubemap1, // (6, face_reso, face_reso, n_channels) 229 | const float* __restrict__ cubemap2, // (6, face_reso, face_reso, n_channels) 230 | const CubemapBilerpQuery& query, 231 | float interp_wt, 232 | int face_reso, 233 | int n_channels, 234 | int chnl_id) { 235 | const float val1 = cubemap_sample(cubemap1, 236 | query, 237 | face_reso, 238 | n_channels, 239 | chnl_id); 240 | const float val2 = cubemap_sample(cubemap2, 241 | query, 242 | face_reso, 243 | n_channels, 244 | chnl_id); 245 | return lerp(val1, val2, interp_wt); 246 | } 247 | 248 | __device__ __inline__ void 249 | multi_cubemap_sample_backward( 250 | float* __restrict__ cubemap_grad1, // (6, face_reso, face_reso, n_channels) 251 | float* __restrict__ cubemap_grad2, // (6, face_reso, face_reso, n_channels) 252 | const CubemapBilerpQuery& query, 253 | float interp_wt, 254 | int face_reso, 255 | int n_channels, 256 | float grad_out, 257 | int chnl_id, 
258 | bool* __restrict__ mask_out1 = nullptr, 259 | bool* __restrict__ mask_out2 = nullptr) { 260 | if (cubemap_grad1 == nullptr) return; 261 | cubemap_sample_backward(cubemap_grad1, 262 | query, 263 | face_reso, 264 | n_channels, 265 | grad_out * (1.f - interp_wt), 266 | chnl_id, 267 | mask_out1); 268 | cubemap_sample_backward(cubemap_grad2, 269 | query, 270 | face_reso, 271 | n_channels, 272 | grad_out * interp_wt, 273 | chnl_id, 274 | mask_out1 == nullptr ? nullptr : mask_out2); 275 | } 276 | 277 | 278 | } // namespace device 279 | } // namespace 280 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/include/cuda_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "util.hpp" 8 | 9 | 10 | #define DEVICE_GUARD(_ten) \ 11 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); 12 | 13 | #define CUDA_GET_THREAD_ID(tid, Q) const int tid = blockIdx.x * blockDim.x + threadIdx.x; \ 14 | if (tid >= Q) return 15 | #define CUDA_GET_THREAD_ID_U64(tid, Q) const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \ 16 | if (tid >= Q) return 17 | #define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1) 18 | #define CUDA_CHECK_ERRORS \ 19 | cudaError_t err = cudaGetLastError(); \ 20 | if (err != cudaSuccess) \ 21 | printf("Error in svox2.%s : %s\n", __FUNCTION__, cudaGetErrorString(err)) 22 | 23 | #define CUDA_MAX_THREADS at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock 24 | 25 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 26 | #else 27 | __device__ inline double atomicAdd(double* address, double val){ 28 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 29 | unsigned long long int old = *address_as_ull, assumed; 30 | do { 31 | assumed = old; 32 | old = atomicCAS(address_as_ull, assumed, 33 | __double_as_longlong(val + __longlong_as_double(assumed))); 34 | } while (assumed != old); 35 | return __longlong_as_double(old); 36 | } 37 | #endif 38 | 39 | __device__ inline void atomicMax(float* result, float value){ 40 | unsigned* result_as_u = (unsigned*)result; 41 | unsigned old = *result_as_u, assumed; 42 | do { 43 | assumed = old; 44 | old = atomicCAS(result_as_u, assumed, 45 | __float_as_int(fmaxf(value, __int_as_float(assumed)))); 46 | } while (old != assumed); 47 | return; 48 | } 49 | 50 | __device__ inline void atomicMax(double* result, double value){ 51 | unsigned long long int* result_as_ull = (unsigned long long int*)result; 52 | unsigned long long int old = *result_as_ull, assumed; 53 | do { 54 | assumed = old; 55 | old = atomicCAS(result_as_ull, assumed, 56 | __double_as_longlong(fmaxf(value, __longlong_as_double(assumed)))); 57 | } while (old != assumed); 58 | return; 59 | } 60 | 61 | __device__ __inline__ void transform_coord(float* __restrict__ point, 62 | const float* __restrict__ scaling, 63 | const float* __restrict__ offset) { 64 | point[0] = fmaf(point[0], scaling[0], offset[0]); // a*b + c 65 | point[1] = fmaf(point[1], scaling[1], offset[1]); // a*b + c 66 | point[2] = fmaf(point[2], scaling[2], offset[2]); // a*b + c 67 | } 68 | 69 | // Linear interp 70 | // Subtract and fused multiply-add 71 | // (1-w) a + w b 72 | template 73 | __host__ __device__ __inline__ T lerp(T a, T b, T w) { 74 | return fmaf(w, b - a, a); 75 | } 76 | 77 | __device__ __inline__ static float _norm( 78 | const float* 
__restrict__ dir) { 79 | // return sqrtf(dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2]); 80 | return norm3df(dir[0], dir[1], dir[2]); 81 | } 82 | 83 | __device__ __inline__ static float _rnorm( 84 | const float* __restrict__ dir) { 85 | // return 1.f / _norm(dir); 86 | return rnorm3df(dir[0], dir[1], dir[2]); 87 | } 88 | 89 | __host__ __device__ __inline__ static void xsuby3d( 90 | float* __restrict__ x, 91 | const float* __restrict__ y) { 92 | x[0] -= y[0]; 93 | x[1] -= y[1]; 94 | x[2] -= y[2]; 95 | } 96 | 97 | __host__ __device__ __inline__ static float _dot( 98 | const float* __restrict__ x, 99 | const float* __restrict__ y) { 100 | return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]; 101 | } 102 | 103 | __host__ __device__ __inline__ static void _cross( 104 | const float* __restrict__ a, 105 | const float* __restrict__ b, 106 | float* __restrict__ out) { 107 | out[0] = a[1] * b[2] - a[2] * b[1]; 108 | out[1] = a[2] * b[0] - a[0] * b[2]; 109 | out[2] = a[0] * b[1] - a[1] * b[0]; 110 | } 111 | 112 | __device__ __inline__ static float _dist_ray_to_origin( 113 | const float* __restrict__ origin, 114 | const float* __restrict__ dir) { 115 | // dir must be unit vector 116 | float tmp[3]; 117 | _cross(origin, dir, tmp); 118 | return _norm(tmp); 119 | } 120 | 121 | #define int_div2_ceil(x) ((((x) - 1) >> 1) + 1) 122 | 123 | __host__ __inline__ cudaError_t cuda_assert( 124 | const cudaError_t code, const char* const file, 125 | const int line, const bool abort) { 126 | if (code != cudaSuccess) { 127 | fprintf(stderr, "cuda_assert: %s %s %s %d\n", cudaGetErrorName(code) ,cudaGetErrorString(code), 128 | file, line); 129 | 130 | if (abort) { 131 | cudaDeviceReset(); 132 | exit(code); 133 | } 134 | } 135 | 136 | return code; 137 | } 138 | 139 | #define cuda(...) 
cuda_assert((cuda##__VA_ARGS__), __FILE__, __LINE__, true); 140 | 141 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/include/data_spec.hpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include "util.hpp" 4 | #include 5 | 6 | using torch::Tensor; 7 | 8 | enum BasisType { 9 | // For svox 1 compatibility 10 | // BASIS_TYPE_RGBA = 0 11 | BASIS_TYPE_SH = 1, 12 | // BASIS_TYPE_SG = 2 13 | // BASIS_TYPE_ASG = 3 14 | BASIS_TYPE_3D_TEXTURE = 4, 15 | BASIS_TYPE_MLP = 255, 16 | }; 17 | 18 | struct SparseGridSpec { 19 | Tensor density_data; 20 | Tensor sh_data; 21 | Tensor links; 22 | Tensor _offset; 23 | Tensor _scaling; 24 | 25 | Tensor background_links; 26 | Tensor background_data; 27 | 28 | int basis_dim; 29 | uint8_t basis_type; 30 | Tensor basis_data; 31 | 32 | inline void check() { 33 | CHECK_INPUT(density_data); 34 | CHECK_INPUT(sh_data); 35 | CHECK_INPUT(links); 36 | if (background_links.defined()) { 37 | CHECK_INPUT(background_links); 38 | CHECK_INPUT(background_data); 39 | TORCH_CHECK(background_links.ndimension() == 40 | 2); // (H, W) -> [N] \cup {-1} 41 | TORCH_CHECK(background_data.ndimension() == 3); // (N, D, C) -> R 42 | } 43 | if (basis_data.defined()) { 44 | CHECK_INPUT(basis_data); 45 | } 46 | CHECK_CPU_INPUT(_offset); 47 | CHECK_CPU_INPUT(_scaling); 48 | TORCH_CHECK(density_data.ndimension() == 2); 49 | TORCH_CHECK(sh_data.ndimension() == 2); 50 | TORCH_CHECK(links.ndimension() == 3); 51 | } 52 | }; 53 | 54 | struct GridOutputGrads { 55 | torch::Tensor grad_density_out; 56 | torch::Tensor grad_sh_out; 57 | torch::Tensor grad_basis_out; 58 | torch::Tensor grad_background_out; 59 | 60 | torch::Tensor mask_out; 61 | torch::Tensor mask_background_out; 62 | inline void check() { 63 | if (grad_density_out.defined()) { 64 | CHECK_INPUT(grad_density_out); 65 | } 66 | if (grad_sh_out.defined()) { 67 | CHECK_INPUT(grad_sh_out); 68 | } 69 | if (grad_basis_out.defined()) { 70 | CHECK_INPUT(grad_basis_out); 71 | } 72 | if (grad_background_out.defined()) { 73 | CHECK_INPUT(grad_background_out); 74 | } 75 | if (mask_out.defined() && mask_out.size(0) > 0) { 76 | CHECK_INPUT(mask_out); 77 | } 78 | if (mask_background_out.defined() && mask_background_out.size(0) > 0) { 79 | CHECK_INPUT(mask_background_out); 80 | } 81 | } 82 | }; 83 | 84 | struct CameraSpec { 85 | torch::Tensor c2w; 86 | float fx; 87 | float fy; 88 | float cx; 89 | float cy; 90 | int width; 91 | int height; 92 | 93 | float ndc_coeffx; 94 | float ndc_coeffy; 95 | 96 | inline void check() { 97 | CHECK_INPUT(c2w); 98 | TORCH_CHECK(c2w.is_floating_point()); 99 | TORCH_CHECK(c2w.ndimension() == 2); 100 | TORCH_CHECK(c2w.size(1) == 4); 101 | } 102 | }; 103 | 104 | struct RaysSpec { 105 | Tensor origins; 106 | Tensor dirs; 107 | inline void check() { 108 | CHECK_INPUT(origins); 109 | CHECK_INPUT(dirs); 110 | TORCH_CHECK(origins.is_floating_point()); 111 | TORCH_CHECK(dirs.is_floating_point()); 112 | } 113 | }; 114 | 115 | struct RenderOptions { 116 | float background_brightness; 117 | // float step_epsilon; 118 | float step_size; 119 | float sigma_thresh; 120 | float stop_thresh; 121 | 122 | float near_clip; 123 | bool use_spheric_clip; 124 | 125 | bool last_sample_opaque; 126 | 127 | // bool randomize; 128 | // float random_sigma_std; 129 | // float random_sigma_std_background; 130 | // 32-bit RNG state masks 131 | // uint32_t _m1, _m2, _m3; 132 | 133 | // int msi_start_layer = 0; 134 | 
// int msi_end_layer = 66; 135 | }; 136 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/include/data_spec_packed.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include "data_spec.hpp" 5 | #include "cuda_util.cuh" 6 | #include "random_util.cuh" 7 | 8 | namespace { 9 | namespace device { 10 | 11 | struct PackedSparseGridSpec { 12 | PackedSparseGridSpec(SparseGridSpec& spec) 13 | : 14 | density_data(spec.density_data.data_ptr()), 15 | sh_data(spec.sh_data.data_ptr()), 16 | links(spec.links.data_ptr()), 17 | basis_type(spec.basis_type), 18 | basis_data(spec.basis_data.defined() ? spec.basis_data.data_ptr() : nullptr), 19 | background_links(spec.background_links.defined() ? 20 | spec.background_links.data_ptr() : 21 | nullptr), 22 | background_data(spec.background_data.defined() ? 23 | spec.background_data.data_ptr() : 24 | nullptr), 25 | size{(int)spec.links.size(0), 26 | (int)spec.links.size(1), 27 | (int)spec.links.size(2)}, 28 | stride_x{(int)spec.links.stride(0)}, 29 | background_reso{ 30 | spec.background_links.defined() ? (int)spec.background_links.size(1) : 0, 31 | }, 32 | background_nlayers{ 33 | spec.background_data.defined() ? (int)spec.background_data.size(1) : 0 34 | }, 35 | basis_dim(spec.basis_dim), 36 | sh_data_dim((int)spec.sh_data.size(1)), 37 | basis_reso(spec.basis_data.defined() ? spec.basis_data.size(0) : 0), 38 | _offset{spec._offset.data_ptr()[0], 39 | spec._offset.data_ptr()[1], 40 | spec._offset.data_ptr()[2]}, 41 | _scaling{spec._scaling.data_ptr()[0], 42 | spec._scaling.data_ptr()[1], 43 | spec._scaling.data_ptr()[2]} { 44 | } 45 | 46 | float* __restrict__ density_data; 47 | float* __restrict__ sh_data; 48 | const int32_t* __restrict__ links; 49 | 50 | const uint8_t basis_type; 51 | float* __restrict__ basis_data; 52 | 53 | const int32_t* __restrict__ background_links; 54 | float* __restrict__ background_data; 55 | 56 | const int size[3], stride_x; 57 | const int background_reso, background_nlayers; 58 | 59 | const int basis_dim, sh_data_dim, basis_reso; 60 | const float _offset[3]; 61 | const float _scaling[3]; 62 | }; 63 | 64 | struct PackedGridOutputGrads { 65 | PackedGridOutputGrads(GridOutputGrads& grads) : 66 | grad_density_out(grads.grad_density_out.defined() ? grads.grad_density_out.data_ptr() : nullptr), 67 | grad_sh_out(grads.grad_sh_out.defined() ? grads.grad_sh_out.data_ptr() : nullptr), 68 | grad_basis_out(grads.grad_basis_out.defined() ? grads.grad_basis_out.data_ptr() : nullptr), 69 | grad_background_out(grads.grad_background_out.defined() ? grads.grad_background_out.data_ptr() : nullptr), 70 | mask_out((grads.mask_out.defined() && grads.mask_out.size(0) > 0) ? grads.mask_out.data_ptr() : nullptr), 71 | mask_background_out((grads.mask_background_out.defined() && grads.mask_background_out.size(0) > 0) ? 
grads.mask_background_out.data_ptr() : nullptr) 72 | {} 73 | float* __restrict__ grad_density_out; 74 | float* __restrict__ grad_sh_out; 75 | float* __restrict__ grad_basis_out; 76 | float* __restrict__ grad_background_out; 77 | 78 | bool* __restrict__ mask_out; 79 | bool* __restrict__ mask_background_out; 80 | }; 81 | 82 | struct PackedCameraSpec { 83 | PackedCameraSpec(CameraSpec& cam) : 84 | c2w(cam.c2w.packed_accessor32()), 85 | fx(cam.fx), fy(cam.fy), 86 | cx(cam.cx), cy(cam.cy), 87 | width(cam.width), height(cam.height), 88 | ndc_coeffx(cam.ndc_coeffx), ndc_coeffy(cam.ndc_coeffy) {} 89 | const torch::PackedTensorAccessor32 90 | c2w; 91 | float fx; 92 | float fy; 93 | float cx; 94 | float cy; 95 | int width; 96 | int height; 97 | 98 | float ndc_coeffx; 99 | float ndc_coeffy; 100 | }; 101 | 102 | struct PackedRaysSpec { 103 | const torch::PackedTensorAccessor32 origins; 104 | const torch::PackedTensorAccessor32 dirs; 105 | PackedRaysSpec(RaysSpec& spec) : 106 | origins(spec.origins.packed_accessor32()), 107 | dirs(spec.dirs.packed_accessor32()) 108 | { } 109 | }; 110 | 111 | struct SingleRaySpec { 112 | SingleRaySpec() = default; 113 | __device__ SingleRaySpec(const float* __restrict__ origin, const float* __restrict__ dir) 114 | : origin{origin[0], origin[1], origin[2]}, 115 | dir{dir[0], dir[1], dir[2]} {} 116 | __device__ void set(const float* __restrict__ origin, const float* __restrict__ dir) { 117 | #pragma unroll 3 118 | for (int i = 0; i < 3; ++i) { 119 | this->origin[i] = origin[i]; 120 | this->dir[i] = dir[i]; 121 | } 122 | } 123 | 124 | float origin[3]; 125 | float dir[3]; 126 | float tmin, tmax, world_step; 127 | 128 | float pos[3]; 129 | int32_t l[3]; 130 | RandomEngine32 rng; 131 | }; 132 | 133 | } // namespace device 134 | } // namespace 135 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/include/random_util.cuh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | #pragma once 3 | #include 4 | #include 5 | 6 | // A custom xorshift random generator 7 | // Maybe replace with some CUDA internal stuff? 
8 | struct RandomEngine32 { 9 | uint32_t x, y, z; 10 | 11 | // Inclusive both 12 | __host__ __device__ 13 | uint32_t randint(uint32_t lo, uint32_t hi) { 14 | if (hi <= lo) return lo; 15 | uint32_t z = (*this)(); 16 | return z % (hi - lo + 1) + lo; 17 | } 18 | 19 | __host__ __device__ 20 | void rand2(float* out1, float* out2) { 21 | const uint32_t z = (*this)(); 22 | const uint32_t fmax = (1 << 16); 23 | const uint32_t z1 = z >> 16; 24 | const uint32_t z2 = z & (fmax - 1); 25 | const float ifmax = 1.f / fmax; 26 | 27 | *out1 = z1 * ifmax; 28 | *out2 = z2 * ifmax; 29 | } 30 | 31 | __host__ __device__ 32 | float rand() { 33 | uint32_t z = (*this)(); 34 | return float(z) / (1LL << 32); 35 | } 36 | 37 | 38 | __host__ __device__ 39 | void randn2(float* out1, float* out2) { 40 | rand2(out1, out2); 41 | // Box-Muller transform 42 | const float srlog = sqrtf(-2 * logf(*out1 + 1e-32f)); 43 | *out2 *= 2 * M_PI; 44 | *out1 = srlog * cosf(*out2); 45 | *out2 = srlog * sinf(*out2); 46 | } 47 | 48 | __host__ __device__ 49 | float randn() { 50 | float x, y; 51 | rand2(&x, &y); 52 | // Box-Muller transform 53 | return sqrtf(-2 * logf(x + 1e-32f))* cosf(2 * M_PI * y); 54 | } 55 | 56 | __host__ __device__ 57 | uint32_t operator()() { 58 | uint32_t t; 59 | x ^= x << 16; 60 | x ^= x >> 5; 61 | x ^= x << 1; 62 | t = x; 63 | x = y; 64 | y = z; 65 | z = t ^ x ^ y; 66 | return z; 67 | } 68 | }; 69 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/include/util.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // Changed from x.type().is_cuda() due to deprecation 3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 4 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor") 5 | #define CHECK_CONTIGUOUS(x) \ 6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) \ 8 | CHECK_CUDA(x); \ 9 | CHECK_CONTIGUOUS(x) 10 | #define CHECK_CPU_INPUT(x) \ 11 | CHECK_CPU(x); \ 12 | CHECK_CONTIGUOUS(x) 13 | 14 | #if defined(__CUDACC__) 15 | // #define _EXP(x) expf(x) // SLOW EXP 16 | #define _EXP(x) __expf(x) // FAST EXP 17 | #define _SIGMOID(x) (1 / (1 + _EXP(-(x)))) 18 | 19 | #else 20 | 21 | #define _EXP(x) expf(x) 22 | #define _SIGMOID(x) (1 / (1 + expf(-(x)))) 23 | #endif 24 | #define _SQR(x) ((x) * (x)) 25 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/optim_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | // Optimizer-related kernels 3 | 4 | #include 5 | #include "cuda_util.cuh" 6 | 7 | namespace { 8 | 9 | const int RMSPROP_STEP_CUDA_THREADS = 256; 10 | const int MIN_BLOCKS_PER_SM = 4; 11 | 12 | namespace device { 13 | 14 | // RMSPROP 15 | __inline__ __device__ void rmsprop_once( 16 | float* __restrict__ ptr_data, 17 | float* __restrict__ ptr_rms, 18 | float* __restrict__ ptr_grad, 19 | const float beta, const float lr, const float epsilon, float minval) { 20 | float rms = *ptr_rms; 21 | rms = rms == 0.f ? 
_SQR(*ptr_grad) : lerp(_SQR(*ptr_grad), rms, beta); 22 | *ptr_rms = rms; 23 | *ptr_data = fmaxf(*ptr_data - lr * (*ptr_grad) / (sqrtf(rms) + epsilon), minval); 24 | *ptr_grad = 0.f; 25 | } 26 | 27 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 28 | __global__ void rmsprop_step_kernel( 29 | torch::PackedTensorAccessor64 all_data, 30 | torch::PackedTensorAccessor64 all_rms, 31 | torch::PackedTensorAccessor64 all_grad, 32 | float beta, 33 | float lr, 34 | float epsilon, 35 | float minval, 36 | float lr_last) { 37 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 38 | int32_t chnl = tid % all_data.size(1); 39 | rmsprop_once(all_data.data() + tid, 40 | all_rms.data() + tid, 41 | all_grad.data() + tid, 42 | beta, 43 | (chnl == all_data.size(1) - 1) ? lr_last : lr, 44 | epsilon, 45 | minval); 46 | } 47 | 48 | 49 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 50 | __global__ void rmsprop_mask_step_kernel( 51 | torch::PackedTensorAccessor64 all_data, 52 | torch::PackedTensorAccessor64 all_rms, 53 | torch::PackedTensorAccessor64 all_grad, 54 | const bool* __restrict__ mask, 55 | float beta, 56 | float lr, 57 | float epsilon, 58 | float minval, 59 | float lr_last) { 60 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 61 | if (mask[tid / all_data.size(1)] == false) return; 62 | int32_t chnl = tid % all_data.size(1); 63 | rmsprop_once(all_data.data() + tid, 64 | all_rms.data() + tid, 65 | all_grad.data() + tid, 66 | beta, 67 | (chnl == all_data.size(1) - 1) ? lr_last : lr, 68 | epsilon, 69 | minval); 70 | } 71 | 72 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 73 | __global__ void rmsprop_index_step_kernel( 74 | torch::PackedTensorAccessor64 all_data, 75 | torch::PackedTensorAccessor64 all_rms, 76 | torch::PackedTensorAccessor64 all_grad, 77 | torch::PackedTensorAccessor32 indices, 78 | float beta, 79 | float lr, 80 | float epsilon, 81 | float minval, 82 | float lr_last) { 83 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1)); 84 | int32_t i = indices[tid / all_data.size(1)]; 85 | int32_t chnl = tid % all_data.size(1); 86 | size_t off = i * all_data.size(1) + chnl; 87 | rmsprop_once(all_data.data() + off, all_rms.data() + off, 88 | all_grad.data() + off, 89 | beta, 90 | (chnl == all_data.size(1) - 1) ? lr_last : lr, 91 | epsilon, 92 | minval); 93 | } 94 | 95 | 96 | // SGD 97 | __inline__ __device__ void sgd_once( 98 | float* __restrict__ ptr_data, 99 | float* __restrict__ ptr_grad, 100 | const float lr) { 101 | *ptr_data -= lr * (*ptr_grad); 102 | *ptr_grad = 0.f; 103 | } 104 | 105 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 106 | __global__ void sgd_step_kernel( 107 | torch::PackedTensorAccessor64 all_data, 108 | torch::PackedTensorAccessor64 all_grad, 109 | float lr, 110 | float lr_last) { 111 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 112 | int32_t chnl = tid % all_data.size(1); 113 | sgd_once(all_data.data() + tid, 114 | all_grad.data() + tid, 115 | (chnl == all_data.size(1) - 1) ? 
lr_last : lr); 116 | } 117 | 118 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 119 | __global__ void sgd_mask_step_kernel( 120 | torch::PackedTensorAccessor64 all_data, 121 | torch::PackedTensorAccessor64 all_grad, 122 | const bool* __restrict__ mask, 123 | float lr, 124 | float lr_last) { 125 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1)); 126 | if (mask[tid / all_data.size(1)] == false) return; 127 | int32_t chnl = tid % all_data.size(1); 128 | sgd_once(all_data.data() + tid, 129 | all_grad.data() + tid, 130 | (chnl == all_data.size(1) - 1) ? lr_last : lr); 131 | } 132 | 133 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM) 134 | __global__ void sgd_index_step_kernel( 135 | torch::PackedTensorAccessor64 all_data, 136 | torch::PackedTensorAccessor64 all_grad, 137 | torch::PackedTensorAccessor32 indices, 138 | float lr, 139 | float lr_last) { 140 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1)); 141 | int32_t i = indices[tid / all_data.size(1)]; 142 | int32_t chnl = tid % all_data.size(1); 143 | size_t off = i * all_data.size(1) + chnl; 144 | sgd_once(all_data.data() + off, 145 | all_grad.data() + off, 146 | (chnl == all_data.size(1) - 1) ? lr_last : lr); 147 | } 148 | 149 | 150 | 151 | } // namespace device 152 | } // namespace 153 | 154 | void rmsprop_step( 155 | torch::Tensor data, 156 | torch::Tensor rms, 157 | torch::Tensor grad, 158 | torch::Tensor indexer, 159 | float beta, 160 | float lr, 161 | float epsilon, 162 | float minval, 163 | float lr_last) { 164 | 165 | DEVICE_GUARD(data); 166 | CHECK_INPUT(data); 167 | CHECK_INPUT(rms); 168 | CHECK_INPUT(grad); 169 | CHECK_INPUT(indexer); 170 | 171 | if (lr_last < 0.f) lr_last = lr; 172 | 173 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS; 174 | 175 | if (indexer.dim() == 0) { 176 | const size_t Q = data.size(0) * data.size(1); 177 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 178 | device::rmsprop_step_kernel<<>>( 179 | data.packed_accessor64(), 180 | rms.packed_accessor64(), 181 | grad.packed_accessor64(), 182 | beta, 183 | lr, 184 | epsilon, 185 | minval, 186 | lr_last); 187 | } else if (indexer.size(0) == 0) { 188 | // Skip 189 | } else if (indexer.scalar_type() == at::ScalarType::Bool) { 190 | const size_t Q = data.size(0) * data.size(1); 191 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 192 | device::rmsprop_mask_step_kernel<<>>( 193 | data.packed_accessor64(), 194 | rms.packed_accessor64(), 195 | grad.packed_accessor64(), 196 | indexer.data_ptr(), 197 | beta, 198 | lr, 199 | epsilon, 200 | minval, 201 | lr_last); 202 | } else { 203 | const size_t Q = indexer.size(0) * data.size(1); 204 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 205 | device::rmsprop_index_step_kernel<<>>( 206 | data.packed_accessor64(), 207 | rms.packed_accessor64(), 208 | grad.packed_accessor64(), 209 | indexer.packed_accessor32(), 210 | beta, 211 | lr, 212 | epsilon, 213 | minval, 214 | lr_last); 215 | } 216 | 217 | CUDA_CHECK_ERRORS; 218 | } 219 | 220 | void sgd_step( 221 | torch::Tensor data, 222 | torch::Tensor grad, 223 | torch::Tensor indexer, 224 | float lr, 225 | float lr_last) { 226 | 227 | DEVICE_GUARD(data); 228 | CHECK_INPUT(data); 229 | CHECK_INPUT(grad); 230 | CHECK_INPUT(indexer); 231 | 232 | if (lr_last < 0.f) lr_last = lr; 233 | 234 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS; 235 | 236 | if (indexer.dim() == 0) { 237 | const size_t Q = data.size(0) * data.size(1); 238 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, 
cuda_n_threads); 239 | device::sgd_step_kernel<<>>( 240 | data.packed_accessor64(), 241 | grad.packed_accessor64(), 242 | lr, 243 | lr_last); 244 | } else if (indexer.size(0) == 0) { 245 | // Skip 246 | } else if (indexer.scalar_type() == at::ScalarType::Bool) { 247 | const size_t Q = data.size(0) * data.size(1); 248 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 249 | device::sgd_mask_step_kernel<<>>( 250 | data.packed_accessor64(), 251 | grad.packed_accessor64(), 252 | indexer.data_ptr(), 253 | lr, 254 | lr_last); 255 | } else { 256 | const size_t Q = indexer.size(0) * data.size(1); 257 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads); 258 | device::sgd_index_step_kernel<<>>( 259 | data.packed_accessor64(), 260 | grad.packed_accessor64(), 261 | indexer.packed_accessor32(), 262 | lr, 263 | lr_last); 264 | } 265 | 266 | CUDA_CHECK_ERRORS; 267 | } 268 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/csrc/svox2.cpp: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Alex Yu 2 | 3 | // This file contains only Python bindings 4 | #include "data_spec.hpp" 5 | #include 6 | #include 7 | #include 8 | 9 | using torch::Tensor; 10 | 11 | std::tuple sample_grid(SparseGridSpec &, Tensor, 12 | bool); 13 | void sample_grid_backward(SparseGridSpec &, Tensor, Tensor, Tensor, Tensor, 14 | Tensor, bool); 15 | 16 | // ** NeRF rendering formula (trilerp) 17 | Tensor volume_render_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &); 18 | Tensor volume_render_w_alpha_depth_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &, Tensor, Tensor, Tensor, Tensor, bool); 19 | Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &, 20 | RenderOptions &); 21 | void volume_render_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 22 | Tensor, Tensor, GridOutputGrads &); 23 | void volume_render_w_alpha_depth_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 24 | Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, GridOutputGrads &); 25 | Tensor volume_render_reg_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &, float, Tensor); 26 | void volume_render_reg_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, float, float, Tensor, Tensor, Tensor, GridOutputGrads &); 27 | void volume_render_cuvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &, 28 | Tensor, float, float, Tensor, GridOutputGrads &); 29 | // Expected termination (depth) rendering 30 | torch::Tensor volume_render_expected_term(SparseGridSpec &, RaysSpec &, 31 | RenderOptions &); 32 | // Depth rendering based on sigma-threshold as in Dex-NeRF 33 | torch::Tensor volume_render_sigma_thresh(SparseGridSpec &, RaysSpec &, 34 | RenderOptions &, float); 35 | 36 | // ** NV rendering formula (trilerp) 37 | Tensor volume_render_nvol(SparseGridSpec &, RaysSpec &, RenderOptions &); 38 | void volume_render_nvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 39 | Tensor, Tensor, GridOutputGrads &); 40 | void volume_render_nvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &, 41 | Tensor, float, float, Tensor, GridOutputGrads &); 42 | 43 | // ** NeRF rendering formula (nearest-neighbor, infinitely many steps) 44 | Tensor volume_render_svox1(SparseGridSpec &, RaysSpec &, RenderOptions &); 45 | void volume_render_svox1_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, 46 | Tensor, Tensor, GridOutputGrads &); 47 | void volume_render_svox1_fused(SparseGridSpec 
&, RaysSpec &, RenderOptions &, 48 | Tensor, float, float, Tensor, GridOutputGrads &); 49 | 50 | // Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &, 51 | // RenderOptions &); 52 | // 53 | // void volume_render_cuvol_image_backward(SparseGridSpec &, CameraSpec &, 54 | // RenderOptions &, Tensor, Tensor, 55 | // GridOutputGrads &); 56 | 57 | // Misc 58 | Tensor dilate(Tensor); 59 | void accel_dist_prop(Tensor); 60 | void grid_weight_render(Tensor, CameraSpec &, float, float, bool, Tensor, 61 | Tensor, Tensor); 62 | // void sample_cubemap(Tensor, Tensor, bool, Tensor); 63 | 64 | // Loss 65 | Tensor tv(Tensor, Tensor, int, int, bool, float, bool, float, float); 66 | void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float, 67 | Tensor); 68 | void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool, 69 | float, bool, bool, float, float, Tensor); 70 | void msi_tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, float, float, Tensor); 71 | void lumisphere_tv_grad_sparse(SparseGridSpec &, Tensor, Tensor, Tensor, float, 72 | float, float, float, GridOutputGrads &); 73 | 74 | // Optim 75 | void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float, 76 | float); 77 | void sgd_step(Tensor, Tensor, Tensor, float, float); 78 | 79 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 80 | #define _REG_FUNC(funname) m.def(#funname, &funname) 81 | _REG_FUNC(sample_grid); 82 | _REG_FUNC(sample_grid_backward); 83 | _REG_FUNC(volume_render_cuvol); 84 | _REG_FUNC(volume_render_cuvol_image); 85 | _REG_FUNC(volume_render_cuvol_backward); 86 | _REG_FUNC(volume_render_reg_cuvol); 87 | _REG_FUNC(volume_render_reg_cuvol_backward); 88 | _REG_FUNC(volume_render_cuvol_fused); 89 | _REG_FUNC(volume_render_expected_term); 90 | _REG_FUNC(volume_render_sigma_thresh); 91 | 92 | _REG_FUNC(volume_render_nvol); 93 | _REG_FUNC(volume_render_nvol_backward); 94 | _REG_FUNC(volume_render_nvol_fused); 95 | 96 | _REG_FUNC(volume_render_svox1); 97 | _REG_FUNC(volume_render_svox1_backward); 98 | _REG_FUNC(volume_render_svox1_fused); 99 | 100 | _REG_FUNC(volume_render_w_alpha_depth_cuvol); 101 | _REG_FUNC(volume_render_w_alpha_depth_cuvol_backward); 102 | // _REG_FUNC(volume_render_cuvol_image); 103 | // _REG_FUNC(volume_render_cuvol_image_backward); 104 | 105 | // Loss 106 | _REG_FUNC(tv); 107 | _REG_FUNC(tv_grad); 108 | _REG_FUNC(tv_grad_sparse); 109 | _REG_FUNC(msi_tv_grad_sparse); 110 | _REG_FUNC(lumisphere_tv_grad_sparse); 111 | 112 | // Misc 113 | _REG_FUNC(dilate); 114 | _REG_FUNC(accel_dist_prop); 115 | _REG_FUNC(grid_weight_render); 116 | // _REG_FUNC(sample_cubemap); 117 | 118 | // Optimizer 119 | _REG_FUNC(rmsprop_step); 120 | _REG_FUNC(sgd_step); 121 | #undef _REG_FUNC 122 | 123 | py::class_(m, "SparseGridSpec") 124 | .def(py::init<>()) 125 | .def_readwrite("density_data", &SparseGridSpec::density_data) 126 | .def_readwrite("sh_data", &SparseGridSpec::sh_data) 127 | .def_readwrite("links", &SparseGridSpec::links) 128 | .def_readwrite("_offset", &SparseGridSpec::_offset) 129 | .def_readwrite("_scaling", &SparseGridSpec::_scaling) 130 | .def_readwrite("basis_dim", &SparseGridSpec::basis_dim) 131 | .def_readwrite("basis_type", &SparseGridSpec::basis_type) 132 | .def_readwrite("basis_data", &SparseGridSpec::basis_data) 133 | .def_readwrite("background_links", &SparseGridSpec::background_links) 134 | .def_readwrite("background_data", &SparseGridSpec::background_data); 135 | 136 | py::class_(m, "CameraSpec") 137 | .def(py::init<>()) 138 | .def_readwrite("c2w", 
&CameraSpec::c2w) 139 | .def_readwrite("fx", &CameraSpec::fx) 140 | .def_readwrite("fy", &CameraSpec::fy) 141 | .def_readwrite("cx", &CameraSpec::cx) 142 | .def_readwrite("cy", &CameraSpec::cy) 143 | .def_readwrite("width", &CameraSpec::width) 144 | .def_readwrite("height", &CameraSpec::height) 145 | .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx) 146 | .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy); 147 | 148 | py::class_(m, "RaysSpec") 149 | .def(py::init<>()) 150 | .def_readwrite("origins", &RaysSpec::origins) 151 | .def_readwrite("dirs", &RaysSpec::dirs); 152 | 153 | py::class_(m, "RenderOptions") 154 | .def(py::init<>()) 155 | .def_readwrite("background_brightness", 156 | &RenderOptions::background_brightness) 157 | .def_readwrite("step_size", &RenderOptions::step_size) 158 | .def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh) 159 | .def_readwrite("stop_thresh", &RenderOptions::stop_thresh) 160 | .def_readwrite("near_clip", &RenderOptions::near_clip) 161 | .def_readwrite("use_spheric_clip", &RenderOptions::use_spheric_clip) 162 | .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque); 163 | // .def_readwrite("randomize", &RenderOptions::randomize) 164 | // .def_readwrite("random_sigma_std", &RenderOptions::random_sigma_std) 165 | // .def_readwrite("random_sigma_std_background", 166 | // &RenderOptions::random_sigma_std_background) 167 | // .def_readwrite("_m1", &RenderOptions::_m1) 168 | // .def_readwrite("_m2", &RenderOptions::_m2) 169 | // .def_readwrite("_m3", &RenderOptions::_m3); 170 | 171 | py::class_(m, "GridOutputGrads") 172 | .def(py::init<>()) 173 | .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out) 174 | .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out) 175 | .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out) 176 | .def_readwrite("grad_background_out", 177 | &GridOutputGrads::grad_background_out) 178 | .def_readwrite("mask_out", &GridOutputGrads::mask_out) 179 | .def_readwrite("mask_background_out", 180 | &GridOutputGrads::mask_background_out); 181 | } 182 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/defs.py: -------------------------------------------------------------------------------- 1 | # Basis types (copied from C++ data_spec.hpp) 2 | BASIS_TYPE_SH = 1 3 | BASIS_TYPE_3D_TEXTURE = 4 4 | BASIS_TYPE_MLP = 255 5 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/svox2/version.py: -------------------------------------------------------------------------------- 1 | __version__ = 'voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast' 2 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/test/test_render_gradcheck_alpha.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import torch.nn.functional as F 4 | from svox2.utils import Timing 5 | 6 | torch.random.manual_seed(2) 7 | # torch.random.manual_seed(8289) 8 | 9 | device = 'cuda:0' 10 | torch.cuda.set_device(device) 11 | dtype = torch.float32 12 | grid = svox2.SparseGrid( 13 | reso=128, 14 | center=[0.0, 0.0, 0.0], 15 | radius=[1.0, 1.0, 1.0], 16 | basis_dim=9, 17 | use_z_order=True, 18 | device=device, 19 | background_nlayers=0, 20 | basis_type=svox2.BASIS_TYPE_SH) 21 | grid.opt.backend = 'cuvol' 22 | grid.opt.sigma_thresh = 0.0 23 | grid.opt.stop_thresh = 0.0 24 | grid.opt.background_brightness = 0.0 25 | 26 | 
print(grid.sh_data.shape) 27 | 28 | # Render using a white color to obtain equivalent to alpha rendering 29 | from svox2.utils import SH_C0 30 | grid.sh_data.data = torch.zeros_like(grid.sh_data.data.view(-1, 3, grid.basis_dim)) 31 | grid.sh_data.data[..., 0] = 0.5 / SH_C0 32 | grid.sh_data.data = grid.sh_data.data.flatten(1, 2) 33 | grid.density_data.data[:] = 1.0 34 | 35 | if grid.use_background: 36 | grid.background_data.data[..., -1] = 0.5 37 | grid.background_data.data[..., :-1] = torch.randn_like( 38 | grid.background_data.data[..., :-1]) * 0.01 39 | 40 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 41 | grid.basis_data.data.normal_() 42 | grid.basis_data.data += 1.0 43 | 44 | ENABLE_TORCH_CHECK = True 45 | # N_RAYS = 5000 #200 * 200 46 | N_RAYS = 200 * 200 47 | origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3 48 | dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype) 49 | 50 | # from training.networks_plenoxel import RaySampler, PoseSampler 51 | # img_res = 32 52 | # ps = PoseSampler(range_azim=(180, 45), range_polar=(90, 15), radius=8, dist='normal') 53 | # rs = RaySampler(img_resolution=img_res, fov=30) 54 | # pose = ps.sample() 55 | # origins, dirs = rs.sample_at(pose[None, :3, :4], rs.img_resolution) 56 | # origins = origins.cuda() 57 | # dirs = dirs.cuda() 58 | 59 | # origins = torch.clip(origins, -0.8, 0.8) 60 | 61 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype) 62 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype) 63 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 64 | 65 | # start = 71 66 | # end = 72 67 | # origins = origins[start:end] 68 | # dirs = dirs[start:end] 69 | # print(origins.tolist(), dirs.tolist()) 70 | 71 | # breakpoint() 72 | rays = svox2.Rays(origins, dirs) 73 | 74 | rgb_gt = torch.zeros((origins.size(0), 4), device=device, dtype=dtype) 75 | 76 | # grid.requires_grad_(True) 77 | 78 | # samps = grid.volume_render(rays, use_kernel=True) 79 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False) 80 | 81 | with Timing("ours"): 82 | samps = grid.volume_render(rays, use_kernel=True, render_alpha=True) 83 | assert all([(a == samps[:, 0]).all() for a in samps.T[1:]]), 'Alpha map and rendered color images do not match.' 84 | 85 | # Check internal gradients, i.e. 
white color vs alpha 86 | density_grads_s = [] 87 | grid.sh_data.grad = None 88 | grid.density_data.grad = None 89 | for i in range(4): 90 | a = grid.volume_render(rays, use_kernel=True, render_alpha=True)[:, i] 91 | s = F.mse_loss(a, rgb_gt[:, i]) 92 | with Timing("ours_backward"): 93 | s.backward() 94 | density_grads_s.append(grid.density_data.grad.clone().cpu()) 95 | grid.sh_data.grad = None 96 | grid.density_data.grad = None 97 | density_grads_s = torch.stack(density_grads_s) 98 | print('Error gradients color/alpha', [(g - density_grads_s[0]).abs().max().item() for g in density_grads_s[1:]]) 99 | 100 | # Check gradients wrt torch implementation 101 | s = F.mse_loss(samps, rgb_gt) 102 | 103 | print(s) 104 | print('bkwd..') 105 | with Timing("ours_backward"): 106 | s.backward() 107 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu() 108 | grid_density_grad_s = grid.density_data.grad.clone().cpu() 109 | grid.sh_data.grad = None 110 | grid.density_data.grad = None 111 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 112 | grid_basis_grad_s = grid.basis_data.grad.clone().cpu() 113 | grid.basis_data.grad = None 114 | if grid.use_background: 115 | grid_bg_grad_s = grid.background_data.grad.clone().cpu() 116 | grid.background_data.grad = None 117 | 118 | if ENABLE_TORCH_CHECK: 119 | with Timing("torch"): 120 | sampt = grid.volume_render(rays, use_kernel=False, render_alpha=True) 121 | assert all([(a == sampt[:, 0]).all() for a in sampt.T[1:]]), 'Alpha map and rendered color images do not match.' 122 | 123 | print('Do ours and torch output match?', torch.isclose(samps, sampt).all().item()) 124 | 125 | s = F.mse_loss(sampt, rgb_gt) 126 | with Timing("torch_backward"): 127 | s.backward() 128 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s) 129 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s) 130 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 131 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu() 132 | if grid.use_background: 133 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s) 134 | 135 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t) 136 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t) 137 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 138 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t) 139 | if grid.use_background: 140 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t) 141 | print('err', torch.abs(samps - sampt).max()) 142 | print('err_sh_grad\n', E.max()) 143 | print(' mean\n', E.mean()) 144 | print('err_density_grad\n', Ed.max()) 145 | print(' mean\n', Ed.mean()) 146 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 147 | print('err_basis_grad\n', Eb.max()) 148 | print(' mean\n', Eb.mean()) 149 | if grid.use_background: 150 | print('err_background_grad\n', Ebg.max()) 151 | print(' mean\n', Ebg.mean()) 152 | print() 153 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max()) 154 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max()) 155 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max()) 156 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max()) 157 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 158 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max()) 159 | 
print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max()) 160 | if grid.use_background: 161 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max()) 162 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max()) 163 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/test/test_render_gradcheck_depth.py: -------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import torch.nn.functional as F 4 | from svox2.utils import Timing 5 | 6 | torch.random.manual_seed(2) 7 | # torch.random.manual_seed(8289) 8 | 9 | device = 'cuda:0' 10 | torch.cuda.set_device(device) 11 | dtype = torch.float32 12 | grid_res = 128 13 | grid = svox2.SparseGrid( 14 | reso=grid_res, 15 | center=[0.0, 0.0, 0.0], 16 | radius=[1.0, 1.0, 1.0], 17 | basis_dim=9, 18 | use_z_order=True, 19 | device=device, 20 | background_nlayers=0, 21 | basis_type=svox2.BASIS_TYPE_SH) 22 | grid.opt.backend = 'cuvol' 23 | grid.opt.sigma_thresh = 0.0 24 | grid.opt.step_size = 0.5 25 | grid.opt.stop_thresh = 0.0 26 | grid.opt.background_brightness = 1.0 27 | 28 | print(grid.sh_data.shape) 29 | # grid.sh_data.data.normal_() 30 | grid.sh_data.data[..., 0] = 0.5 31 | grid.sh_data.data[..., 1:].normal_(std=0.1) 32 | grid.density_data.data[:] = 1.0 33 | 34 | if grid.use_background: 35 | grid.background_data.data[..., -1] = 0.5 36 | grid.background_data.data[..., :-1] = torch.randn_like( 37 | grid.background_data.data[..., :-1]) * 0.01 38 | 39 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 40 | grid.basis_data.data.normal_() 41 | grid.basis_data.data += 1.0 42 | 43 | ENABLE_TORCH_CHECK = True 44 | # N_RAYS = 5000 #200 * 200 45 | N_RAYS = 200 * 200 46 | origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3 47 | dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype) 48 | 49 | # from training.networks_plenoxel import RaySampler, PoseSampler 50 | # img_res = 32 51 | # ps = PoseSampler(range_azim=(180, 45), range_polar=(90, 15), radius=8, dist='normal') 52 | # rs = RaySampler(img_resolution=img_res, fov=30) 53 | # pose = ps.sample() 54 | # origins, dirs = rs.sample_at(pose[None, :3, :4], rs.img_resolution) 55 | # origins = origins.cuda() 56 | # dirs = dirs.cuda() 57 | 58 | # origins = torch.clip(origins, -0.8, 0.8) 59 | 60 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype) 61 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype) 62 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 63 | 64 | # start = 71 65 | # end = 72 66 | # origins = origins[start:end] 67 | # dirs = dirs[start:end] 68 | # print(origins.tolist(), dirs.tolist()) 69 | 70 | # breakpoint() 71 | rays = svox2.Rays(origins, dirs) 72 | 73 | dims = 5 74 | rgb_gt = torch.zeros((origins.size(0), dims), device=device, dtype=dtype) 75 | 76 | # grid.requires_grad_(True) 77 | 78 | # samps = grid.volume_render(rays, use_kernel=True) 79 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False) 80 | 81 | with Timing("ours"): 82 | samps = grid.volume_render(rays, use_kernel=True, render_alpha=dims==5, render_depth=True)[:, -dims:] 83 | # samps_original = grid.volume_render_depth(rays)[:, None]#, use_kernel=True, render_alpha=False, render_depth=True) 84 | # assert (samps == samps_original).all(), 'Depth implementation does not match original plenoxel implementation' # 
original implementation does not feature background depth 85 | 86 | # Check gradients wrt torch implementation 87 | s = F.mse_loss(samps, rgb_gt) 88 | 89 | print(s) 90 | print('bkwd..') 91 | with Timing("ours_backward"): 92 | s.backward() 93 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu() 94 | grid_density_grad_s = grid.density_data.grad.clone().cpu() 95 | grid.sh_data.grad = None 96 | grid.density_data.grad = None 97 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 98 | grid_basis_grad_s = grid.basis_data.grad.clone().cpu() 99 | grid.basis_data.grad = None 100 | if grid.use_background: 101 | grid_bg_grad_s = grid.background_data.grad.clone().cpu() 102 | grid.background_data.grad = None 103 | 104 | if ENABLE_TORCH_CHECK: 105 | with Timing("torch"): 106 | sampt = grid.volume_render(rays, use_kernel=False, render_alpha=dims==5, render_depth=True)[:, -dims:] 107 | print('Do ours and torch output match?', torch.isclose(samps, sampt).all().item()) 108 | 109 | s = F.mse_loss(sampt, rgb_gt) 110 | with Timing("torch_backward"): 111 | s.backward() 112 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s) 113 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s) 114 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 115 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu() 116 | if grid.use_background: 117 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s) 118 | 119 | grid_density_grad_t = grid_density_grad_t.view(grid_res, grid_res, grid_res) 120 | grid_density_grad_s = grid_density_grad_s.view(grid_res, grid_res, grid_res) 121 | 122 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t) 123 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t) 124 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 125 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t) 126 | if grid.use_background: 127 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t) 128 | print('err', torch.abs(samps - sampt).max()) 129 | print('err_sh_grad\n', E.max()) 130 | print(' mean\n', E.mean()) 131 | print('err_density_grad\n', Ed.max()) 132 | print(' mean\n', Ed.mean()) 133 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 134 | print('err_basis_grad\n', Eb.max()) 135 | print(' mean\n', Eb.mean()) 136 | if grid.use_background: 137 | print('err_background_grad\n', Ebg.max()) 138 | print(' mean\n', Ebg.mean()) 139 | print() 140 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max()) 141 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max()) 142 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max()) 143 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max()) 144 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 145 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max()) 146 | print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max()) 147 | if grid.use_background: 148 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max()) 149 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max()) 150 | -------------------------------------------------------------------------------- /voxgraf-plenoxels/test/test_render_gradcheck_vardepth.py: 
-------------------------------------------------------------------------------- 1 | import svox2 2 | import torch 3 | import torch.nn.functional as F 4 | from svox2.utils import Timing 5 | 6 | torch.random.manual_seed(2) 7 | # torch.random.manual_seed(8289) 8 | 9 | device = 'cuda:0' 10 | torch.cuda.set_device(device) 11 | dtype = torch.float32 12 | grid_res = 64 13 | grid = svox2.SparseGrid( 14 | reso=grid_res, 15 | center=[0.0, 0.0, 0.0], 16 | radius=[1.0, 1.0, 1.0], 17 | basis_dim=1, 18 | use_z_order=True, 19 | device=device, 20 | background_nlayers=0, 21 | basis_type=svox2.BASIS_TYPE_SH) 22 | grid.opt.backend = 'cuvol' 23 | grid.opt.sigma_thresh = 0.0 24 | grid.opt.step_size = 0.5 25 | grid.opt.stop_thresh = 0.0 26 | grid.opt.background_brightness = 1.0 27 | 28 | feat_dim = 3*grid.basis_dim 29 | grid.sh_data.data = torch.empty([grid.sh_data.shape[0], feat_dim], device=grid.sh_data.device, dtype=grid.sh_data.dtype) 30 | print(grid.sh_data.shape) 31 | # grid.sh_data.data.normal_() 32 | grid.sh_data.data[..., 0] = 0.5 33 | grid.sh_data.data[..., 1:].normal_(std=0.1) 34 | grid.density_data.data[:] = 1.0 35 | 36 | if grid.use_background: 37 | grid.background_data.data[..., -1] = 0.5 38 | grid.background_data.data[..., :-1] = torch.randn_like( 39 | grid.background_data.data[..., :-1]) * 0.01 40 | 41 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 42 | grid.basis_data.data.normal_() 43 | grid.basis_data.data += 1.0 44 | 45 | ENABLE_TORCH_CHECK = True 46 | # # N_RAYS = 5000 #200 * 200 47 | # N_RAYS = 200 * 200 48 | # origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3 49 | # dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype) 50 | 51 | from training.virtual_camera_utils import RaySampler, PoseSampler 52 | img_res = 32 53 | ps = PoseSampler(range_azim=(180, 45), range_polar=(90, 15), radius=8, dist='normal') 54 | rs = RaySampler(img_resolution=img_res, fov=30) 55 | pose = ps.sample_from_dist() 56 | origins, dirs = rs.sample_at(pose[None, :3, :4], rs.img_resolution, fov=30) 57 | origins = origins.cuda() 58 | dirs = dirs.cuda() 59 | 60 | # origins = torch.clip(origins, -0.8, 0.8) 61 | 62 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype) 63 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype) 64 | dirs /= torch.norm(dirs, dim=-1, keepdim=True) 65 | 66 | # start = 71 67 | # end = 72 68 | # origins = origins[start:end] 69 | # dirs = dirs[start:end] 70 | # print(origins.tolist(), dirs.tolist()) 71 | 72 | # breakpoint() 73 | rays = svox2.Rays(origins, dirs) 74 | 75 | dims = feat_dim + 3 76 | rgb_gt = torch.zeros((origins.size(0), dims), device=device, dtype=dtype) 77 | 78 | # grid.requires_grad_(True) 79 | 80 | # samps = grid.volume_render(rays, use_kernel=True) 81 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False) 82 | 83 | with Timing("ours"): 84 | samps = grid.volume_render(rays, use_kernel=True, render_alpha=dims==(feat_dim + 3), render_depth=True, render_vardepth=True)[:, -dims:] 85 | 86 | # Check gradients wrt torch implementation 87 | s = F.mse_loss(samps, rgb_gt) 88 | 89 | print(s) 90 | print('bkwd..') 91 | with Timing("ours_backward"): 92 | s.backward() 93 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu() 94 | grid_density_grad_s = grid.density_data.grad.clone().cpu() 95 | grid.sh_data.grad = None 96 | grid.density_data.grad = None 97 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 98 | grid_basis_grad_s = 
grid.basis_data.grad.clone().cpu() 99 | grid.basis_data.grad = None 100 | if grid.use_background: 101 | grid_bg_grad_s = grid.background_data.grad.clone().cpu() 102 | grid.background_data.grad = None 103 | 104 | if ENABLE_TORCH_CHECK: 105 | with Timing("torch"): 106 | sampt = grid.volume_render(rays, use_kernel=False, render_alpha=dims==(feat_dim + 3), render_depth=True, render_vardepth=True)[:, -dims:] 107 | print('Do ours and torch output match?', torch.isclose(samps, sampt).all().item()) 108 | 109 | s = F.mse_loss(sampt, rgb_gt) 110 | with Timing("torch_backward"): 111 | s.backward() 112 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s) 113 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s) 114 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 115 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu() 116 | if grid.use_background: 117 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s) 118 | 119 | grid_density_grad_t = grid_density_grad_t.view(grid_res, grid_res, grid_res) 120 | grid_density_grad_s = grid_density_grad_s.view(grid_res, grid_res, grid_res) 121 | 122 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t) 123 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t) 124 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 125 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t) 126 | if grid.use_background: 127 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t) 128 | print('err', torch.abs(samps - sampt).max()) 129 | print('err_sh_grad\n', E.max()) 130 | print(' mean\n', E.mean()) 131 | print('err_density_grad\n', Ed.max()) 132 | print(' mean\n', Ed.mean()) 133 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 134 | print('err_basis_grad\n', Eb.max()) 135 | print(' mean\n', Eb.mean()) 136 | if grid.use_background: 137 | print('err_background_grad\n', Ebg.max()) 138 | print(' mean\n', Ebg.mean()) 139 | print() 140 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max()) 141 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max()) 142 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max()) 143 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max()) 144 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE: 145 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max()) 146 | print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max()) 147 | if grid.use_background: 148 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max()) 149 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max()) 150 | --------------------------------------------------------------------------------
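
Note (not part of the repository dump above): the fused rmsprop_step / rmsprop_once kernels in optim_kernel.cu implement a per-element RMSProp update with a separate learning rate for the last data channel. As a reading aid, here is a minimal pure-PyTorch sketch of the same per-row semantics, assuming lerp(a, b, t) in cuda_util.cuh means a + t * (b - a); the function name rmsprop_step_reference and its argument layout are illustrative only, not an API of this codebase.

import torch

def rmsprop_step_reference(data, rms, grad, beta, lr, epsilon, minval, lr_last=-1.0):
    """In-place RMSProp step over an (N, C) tensor, mirroring rmsprop_once in optim_kernel.cu."""
    if lr_last < 0:
        lr_last = lr  # the host wrapper falls back to lr when lr_last < 0
    g2 = grad * grad
    # Running second moment: initialised to grad^2 where rms is still zero,
    # otherwise beta * rms + (1 - beta) * grad^2 (the lerp in the kernel).
    new_rms = torch.where(rms == 0, g2, beta * rms + (1.0 - beta) * g2)
    rms.copy_(new_rms)
    # Per-channel learning rate: the last channel uses lr_last, all others lr.
    ch_lr = torch.full((data.size(1),), lr, device=data.device, dtype=data.dtype)
    ch_lr[-1] = lr_last
    # Update and clamp from below at minval, as in fmaxf(..., minval).
    data.copy_(torch.clamp_min(data - ch_lr * grad / (new_rms.sqrt() + epsilon), minval))
    grad.zero_()  # the CUDA kernel clears gradients in place after the step

The masked and indexed kernel variants (rmsprop_mask_step_kernel, rmsprop_index_step_kernel) apply exactly this update, but only to the rows selected by a boolean mask or an index tensor.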