├── .gitignore
├── LICENSE
├── README.md
├── calc_metrics.py
├── data
│   ├── afhq.json
│   ├── carla.json
│   └── ffhq.json
├── dataset_tool.py
├── dist
│   ├── MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl
│   ├── stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl
│   └── svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── environment.yml
├── gen_images.py
├── gen_video.py
├── gfx
│   └── ffhq.gif
├── legacy.py
├── metrics
│   ├── __init__.py
│   ├── frechet_inception_distance.py
│   ├── metric_main.py
│   └── metric_utils.py
├── scripts
│   ├── build_wheels.sh
│   ├── download_pretrained_models.sh
│   └── make_datasets.sh
├── torch_utils
│   ├── __init__.py
│   ├── build_fat_binary.sh
│   ├── custom_ops.py
│   ├── misc.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── bias_act.cpp
│   │   ├── bias_act.cu
│   │   ├── bias_act.h
│   │   ├── bias_act.py
│   │   ├── bias_act_kernel.cu
│   │   ├── conv2d_gradfix.py
│   │   ├── conv2d_resample.py
│   │   ├── filtered_lrelu.cpp
│   │   ├── filtered_lrelu.cu
│   │   ├── filtered_lrelu.h
│   │   ├── filtered_lrelu.py
│   │   ├── filtered_lrelu_ns.cu
│   │   ├── filtered_lrelu_rd.cu
│   │   ├── filtered_lrelu_wr.cu
│   │   ├── fma.py
│   │   ├── grid_sample_gradfix.py
│   │   ├── pybind.cpp
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.cu
│   │   ├── upfirdn2d.h
│   │   ├── upfirdn2d.py
│   │   └── upfirdn2d_kernel.cu
│   ├── persistence.py
│   ├── setup.py
│   └── training_stats.py
├── train.py
├── training
│   ├── __init__.py
│   ├── augment.py
│   ├── dataset.py
│   ├── loss.py
│   ├── networks_stylegan2.py
│   ├── networks_stylegan2_3d.py
│   ├── networks_stylegan3.py
│   ├── networks_voxgraf.py
│   ├── training_loop.py
│   └── virtual_camera_utils.py
└── voxgraf-plenoxels
    ├── LICENSE
    ├── README.md
    ├── environment.yml
    ├── manual_install.sh
    ├── setup.py
    ├── svox2
    │   ├── __init__.py
    │   ├── csrc
    │   │   ├── .ccls
    │   │   ├── CMakeLists.txt
    │   │   ├── include
    │   │   │   ├── cubemap_util.cuh
    │   │   │   ├── cuda_util.cuh
    │   │   │   ├── data_spec.hpp
    │   │   │   ├── data_spec_packed.cuh
    │   │   │   ├── random_util.cuh
    │   │   │   ├── render_util.cuh
    │   │   │   └── util.hpp
    │   │   ├── loss_kernel.cu
    │   │   ├── misc_kernel.cu
    │   │   ├── optim_kernel.cu
    │   │   ├── render_lerp_kernel_cuvol.cu
    │   │   ├── render_lerp_kernel_nvol.cu
    │   │   ├── render_svox1_kernel.cu
    │   │   ├── svox2.cpp
    │   │   └── svox2_kernel.cu
    │   ├── defs.py
    │   ├── svox2.py
    │   ├── utils.py
    │   └── version.py
    └── test
        ├── test_render_gradcheck_alpha.py
        ├── test_render_gradcheck_depth.py
        └── test_render_gradcheck_vardepth.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Added
132 | training-runs/*
133 | training-runs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VoxGRAF
2 |
3 |
4 | ![VoxGRAF samples on FFHQ](gfx/ffhq.gif)
5 |
6 |
7 | This repository contains official code for the paper
8 | [VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids](https://www.cvlibs.net/publications/Schwarz2022NEURIPS.pdf).
9 |
10 | You can find detailed usage instructions for training your own models and using pre-trained models below.
11 |
12 | If you find our code or paper useful, please consider citing
13 |
14 | @inproceedings{Schwarz2022NEURIPS,
15 | title = {VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids},
16 | author = {Schwarz, Katja and Sauer, Axel and Niemeyer, Michael and Liao, Yiyi and Geiger, Andreas},
17 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
18 | year = {2022}
19 | }
20 |
21 | ## Installation
22 | First you have to make sure that you have all dependencies in place.
23 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/).
24 |
25 | You can create and activate an anaconda environment called `voxgraf` using
26 |
27 | ```commandline
28 | conda env create -f environment.yml
29 | conda activate voxgraf
30 | ```
31 |
32 | ### CUDA extension installation
33 |
34 | Install pre-compiled CUDA extensions by running
35 | ```commandline
36 | ./scripts/build_wheels.sh
37 | ```
38 | **Or** install them individually by running
39 | ```commandline
40 | pip install dist/stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl
41 | pip install dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl
42 | pip install dist/MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl # optional, only required when training with minkowski sparse convolutions
43 | ```
44 | In case the wheels do not work for you, you can also install the extensions from source. For this please check the original repos: [Stylegan-3](https://github.com/NVlabs/stylegan3), [Minkowski Engine](https://github.com/NVIDIA/MinkowskiEngine) and for our version of Plenoxels follow the instructions [here](voxgraf-plenoxels/README.md).
45 |
46 | ## Pretrained models
47 | To download the pretrained models run
48 | ```commandline
49 | ./scripts/download_pretrained_models.sh
50 | ```
51 |
52 | ### Evaluate pretrained models
53 | ```commandline
54 | # generate a video with 1x2 samples and interpolations between 2 keyframes each
55 | python gen_video.py --network pretrained_models/ffhq256.pkl --seeds 0-3 --grid 1x2 --num-keyframes 2 --output ffhq_256_samples/video.mp4 --trunc=0.5
56 |
57 | # generate grids of 3x4 samples and their depths
58 | python gen_images.py --network pretrained_models/ffhq256.pkl --seeds 0-23 --grid 3x4 --outdir ffhq_256_samples --save_depth true --trunc=0.5
59 | ```
60 |
61 | ## Train custom models
62 |
63 | ### Download the data
64 | Download [FFHQ](https://github.com/NVlabs/stylegan2), [AFHQ](https://github.com/clovaai/stargan-v2) and [Carla](https://github.com/autonomousvision/graf).
65 |
66 | ### Preparing the data
67 | To prepare the data at the required resolutions you can run
68 | ```commandline
69 | ./scripts/make_datasets.sh /PATH/TO/IMAGES data/{DATASET_NAME}.json data/{DATASET_NAME} 32,64,128,256
70 | ```
71 | This will create the datasets in `data/{DATASET_NAME}_{RES}.zip`.
72 |
73 | ### Train models progressively
74 |
75 | ```commandline
76 | # Train a model on FFHQ progressively starting at image resolution 32x32 with voxel grid resolution 32x32x32
77 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_32.zip --batch 64 --grid-res 32
78 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_64.zip --batch 64 --grid-res 32 --resume /PATH/TO/32-IMG-32-GRID-MODEL # Next stage
79 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_64.zip --batch 64 --grid-res 64 --resume /PATH/TO/64-IMG-32-GRID-MODEL # Next stage
80 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_128.zip --batch 64 --grid-res 64 --lambda_vardepth 1e-3 --resume /PATH/TO/64-IMG-64-GRID-MODEL # Next stage
81 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_128.zip --batch 32 --grid-res 128 --lambda_vardepth 1e-3 --resume /PATH/TO/128-IMG-64-GRID-MODEL # Next stage
82 | python train.py --outdir training-runs --gpus 8 --data data/ffhq_256.zip --batch 32 --grid-res 128 --lambda_vardepth 1e-3 --resume /PATH/TO/128-IMG-128-GRID-MODEL # Next stage
83 |
84 | # Train a model on Carla progressively starting at image resolution 32x32 with voxel grid resolution 32x32x32
85 | python train.py --outdir training-runs --gpus 8 --data data/carla_32.zip --batch 64 --grid-res 32 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8
86 | python train.py --outdir training-runs --gpus 8 --data data/carla_64.zip --batch 64 --grid-res 32 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --resume /PATH/TO/32-IMG-32-GRID-MODEL # Next stage
87 | python train.py --outdir training-runs --gpus 8 --data data/carla_64.zip --batch 64 --grid-res 64 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --resume /PATH/TO/64-IMG-32-GRID-MODEL # Next stage
88 | python train.py --outdir training-runs --gpus 8 --data data/carla_128.zip --batch 64 --grid-res 64 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --lambda_vardepth 1e-3 --resume /PATH/TO/64-IMG-64-GRID-MODEL # Next stage
89 | python train.py --outdir training-runs --gpus 8 --data data/carla_128.zip --batch 32 --grid-res 128 --n-refinement 0 --use_bg False --lambda_sparsity 1e-8 --lambda_vardepth 1e-3 --resume /PATH/TO/128-IMG-64-GRID-MODEL # Next stage
90 | ```
--------------------------------------------------------------------------------
/calc_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Calculate quality metrics for previous training run or pretrained network pickle."""
10 |
11 | import os
12 | import click
13 | import json
14 | import tempfile
15 | import copy
16 | import torch
17 |
18 | import dnnlib
19 | import legacy
20 | from metrics import metric_main
21 | from metrics import metric_utils
22 | from torch_utils import training_stats
23 | from torch_utils import custom_ops
24 | from torch_utils import misc
25 | from torch_utils.ops import conv2d_gradfix
26 |
27 | # environment variables
28 | os.environ['OMP_NUM_THREADS'] = "16"
29 |
30 | #----------------------------------------------------------------------------
31 |
32 | def subprocess_fn(rank, args, temp_dir):
33 | dnnlib.util.Logger(should_flush=True)
34 |
35 | # Init torch.distributed.
36 | if args.num_gpus > 1:
37 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
38 | if os.name == 'nt':
39 | init_method = 'file:///' + init_file.replace('\\', '/')
40 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
41 | else:
42 | init_method = f'file://{init_file}'
43 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
44 |
45 | # Init torch_utils.
46 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
47 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
48 | if rank != 0 or not args.verbose:
49 | custom_ops.verbosity = 'none'
50 |
51 | # Configure torch.
52 | device = torch.device('cuda', rank)
53 | torch.cuda.set_device(device)
54 | torch.backends.cuda.matmul.allow_tf32 = False
55 | torch.backends.cudnn.allow_tf32 = False
56 | conv2d_gradfix.enabled = True
57 |
58 | # Print network summary.
59 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
60 | if rank == 0 and args.verbose:
61 | z = torch.empty([1, G.z_dim], device=device)
62 | c = torch.empty([1, G.c_dim], device=device)
63 | misc.print_module_summary(G, [z, c])
64 |
65 | # Calculate each metric.
66 | for metric in args.metrics:
67 | if rank == 0 and args.verbose:
68 | print(f'Calculating {metric}...')
69 | progress = metric_utils.ProgressMonitor(verbose=args.verbose)
70 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
71 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress, G_kwargs=args.G_kwargs)
72 | if rank == 0:
73 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
74 | if rank == 0 and args.verbose:
75 | print()
76 |
77 | # Done.
78 | if rank == 0 and args.verbose:
79 | print('Exiting...')
80 |
81 | #----------------------------------------------------------------------------
82 |
83 | def parse_comma_separated_list(s):
84 | if isinstance(s, list):
85 | return s
86 | if s is None or s.lower() == 'none' or s == '':
87 | return []
88 | return s.split(',')
89 |
90 | #----------------------------------------------------------------------------
91 |
92 | @click.command()
93 | @click.pass_context
94 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
95 | @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
96 | @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
97 | @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
98 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
99 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
100 |
101 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
102 | """Calculate quality metrics for previous training run or pretrained network pickle.
103 |
104 | Examples:
105 |
106 | \b
107 | # Previous training run: look up options automatically, save result to JSONL file.
108 | python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
109 | --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
110 |
111 | \b
112 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
113 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
114 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
115 |
116 | \b
117 | Recommended metrics:
118 | fid50k_full Frechet inception distance against the full dataset.
119 | kid50k_full Kernel inception distance against the full dataset.
120 | pr50k3_full Precision and recall against the full dataset.
121 | ppl2_wend Perceptual path length in W, endpoints, full image.
122 | eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
123 | eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
124 | eqr50k Equivariance w.r.t. rotation (EQ-R).
125 |
126 | \b
127 | Legacy metrics:
128 | fid50k Frechet inception distance against 50k real images.
129 | kid50k Kernel inception distance against 50k real images.
130 | pr50k3 Precision and recall against 50k real images.
131 | is50k Inception score for CIFAR-10.
132 | """
133 | dnnlib.util.Logger(should_flush=True)
134 |
135 | # Validate arguments.
136 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
137 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
138 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
139 | if not args.num_gpus >= 1:
140 | ctx.fail('--gpus must be at least 1')
141 |
142 | # Load network.
143 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
144 | ctx.fail('--network must point to a file or URL')
145 | if args.verbose:
146 | print(f'Loading network from "{network_pkl}"...')
147 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
148 | network_dict = legacy.load_network_pkl(f)
149 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module
150 | n_img = network_dict.get('progress', {}).get('cur_nimg', None)
151 | if n_img is not None:
152 | args.G_kwargs = {'cur_nimg': n_img.item()}
153 |
154 | # Initialize dataset options.
155 | if data is not None:
156 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
157 | elif network_dict['training_set_kwargs'] is not None:
158 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
159 | else:
160 | ctx.fail('Could not look up dataset options; please specify --data')
161 |
162 | # Finalize dataset options.
163 | args.dataset_kwargs.resolution = args.G.img_resolution
164 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
165 | if mirror is not None:
166 | args.dataset_kwargs.xflip = mirror
167 |
168 | # Print dataset options.
169 | if args.verbose:
170 | print('Dataset options:')
171 | print(json.dumps(args.dataset_kwargs, indent=2))
172 |
173 | # Locate run dir.
174 | args.run_dir = None
175 | if os.path.isfile(network_pkl):
176 | pkl_dir = os.path.dirname(network_pkl)
177 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
178 | args.run_dir = pkl_dir
179 |
180 | # Launch processes.
181 | if args.verbose:
182 | print('Launching processes...')
183 | torch.multiprocessing.set_start_method('spawn')
184 | with tempfile.TemporaryDirectory() as temp_dir:
185 | if args.num_gpus == 1:
186 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
187 | else:
188 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
189 |
190 | #----------------------------------------------------------------------------
191 |
192 | if __name__ == "__main__":
193 | calc_metrics() # pylint: disable=no-value-for-parameter
194 |
195 | #----------------------------------------------------------------------------
196 |
--------------------------------------------------------------------------------
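For reference, a minimal invocation of the script above; this is an illustrative example assuming the pretrained model from `scripts/download_pretrained_models.sh` and a dataset zip produced by `scripts/make_datasets.sh`:

```commandline
python calc_metrics.py --metrics=fid50k_full --network=pretrained_models/ffhq256.pkl --data=data/ffhq_256.zip --gpus=1
```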
/dist/MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/dist/MinkowskiEngine-0.5.4-cp39-cp39-linux_x86_64.whl
--------------------------------------------------------------------------------
/dist/stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/dist/stylegan3_cuda-0.0.0-cp39-cp39-linux_x86_64.whl
--------------------------------------------------------------------------------
/dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: voxgraf
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - anaconda
6 | dependencies:
7 | - python >= 3.8
8 | - pip
9 | - numpy=1.22
10 | - click>=8.0
11 | - pillow=8.3.1
12 | - scipy=1.7.1
13 | - pytorch==1.11.0
14 | - cudatoolkit=11.1
15 | - requests=2.26.0
16 | - tqdm=4.62.2
17 | - ninja=1.10.2
18 | - matplotlib=3.4.2
19 | - imageio=2.9.0
20 | - openblas-devel
21 | - pip:
22 | - imgui==1.3.0
23 | - glfw==2.2.0
24 | - pyopengl==3.1.5
25 | - imageio-ffmpeg==0.4.3
26 | - pyspng
27 | - dill
28 | - psutil
29 | - opencv-python
30 | - tensorboard
31 | - plyfile
32 | - torchvision==0.12.0
--------------------------------------------------------------------------------
/gen_video.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | #
9 | # Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | #
11 |
12 | """Generate lerp videos using pretrained network pickle."""
13 |
14 | from typing import List, Optional, Tuple, Union
15 | import os
16 | import click
17 | import imageio
18 | import numpy as np
19 | import torch
20 | from training.virtual_camera_utils import uv2RT
21 | from gen_images import parse_range, parse_tuple, load_model, make_inputs, generate_grids
22 |
23 | # environment variables
24 | os.environ['OMP_NUM_THREADS'] = "16"
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def sample_oval(n_points, a, b):
29 | "a is axes length in x-direction, b in y-direction"
30 | angle_range = np.linspace(0, 2*np.pi, n_points)
31 | x = a*np.sin(angle_range)
32 | y = b*np.cos(angle_range)
33 | return x, y
34 |
35 | def sample_line(n_points, a):
36 | "a is axes length in x-direction, b in y-direction"
37 | angle_range = np.linspace(-1, 1, n_points)
38 | x = a*angle_range
39 | y = np.zeros(n_points)
40 | return x, y
41 |
42 | def make_poses(G, label, pose_sampling, w_frames, num_keyframes=1, range_azim=(180, 90), range_polar=(180, 90)):
43 | assert len(label) % num_keyframes == 0
44 | nsamples = len(label) // num_keyframes
45 | # set up the camera trajectory
46 | if pose_sampling == 'oval':
47 | azim, polar = sample_oval(w_frames, range_azim[1], range_polar[1])
48 | elif pose_sampling == 'line':
49 | azim, polar = sample_line(w_frames, range_azim[1])
50 | else:
51 | raise AttributeError
52 | azim += range_azim[0]
53 | polar += range_polar[0]
54 |
55 | azim2u = lambda x: torch.deg2rad(torch.tensor(x)) / (2 * np.pi)
56 | polar2v = lambda x: 0.5 * (1 - torch.cos(torch.deg2rad(torch.tensor(x))))
57 | us, vs = azim2u(azim), polar2v(polar)
58 |
59 | RTs = []
60 | for u, v in zip(us, vs):
61 | RT = uv2RT(u, v, 1) # radius is set later
62 | RTs.append(RT)
63 | poses = torch.stack(RTs).to(torch.float32)
64 | poses = poses.view(1, 1, *poses.shape).repeat(num_keyframes, nsamples, 1, 1, 1)
65 |
66 | if G.c_dim != 0: # adjust the radius of the samples to match their pose
67 | c = label.view(num_keyframes, nsamples, 1, 3, 4).expand(-1, -1, w_frames, -1, -1)
68 | cradius = c[:, :, :, :3, 3].norm(dim=3, keepdim=True).min(dim=0, keepdim=True).values # chooses smallest radius of all keyframes
69 | poses[:, :, :, :3, 3] *= cradius
70 |
71 | return poses.flatten(0, 1) # (num_keyframes x nsamples) x wframes x 4 x 4
72 |
73 | @click.command()
74 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
75 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
76 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
77 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
78 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
79 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) # 60*x
80 | @click.option('--range-azim', type=parse_tuple, help='Mean and std of pose distribution azimuth angle', default=(180, 20))
81 | @click.option('--range-polar', type=parse_tuple, help='Mean and std of pose distribution polar angle', default=(90, 10))
82 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
83 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
84 | @click.option('--pose-sampling', help='Camera trajectory', type=click.Choice(['oval', 'line']), default='oval', show_default=True)
85 | @click.option('--bg-decay', help='Reduce background visibility', type=bool, default=False, show_default=True)
86 | @click.option('--no-bg', help='Remove background', type=bool, default=False, show_default=True)
87 |
88 | def generate_videos(
89 | network_pkl: str,
90 | seeds: List[int],
91 | shuffle_seed: Optional[int],
92 | truncation_psi: float,
93 | grid: Tuple[int,int],
94 | num_keyframes: Optional[int],
95 | w_frames: int,
96 | range_azim: Optional[int],
97 | range_polar: Optional[int],
98 | output: str,
99 | pose_sampling: str,
100 | bg_decay: bool,
101 | no_bg: bool,
102 | ):
103 | """Render a latent vector interpolation video.
104 |
105 | The output video length
106 | will be 'num_keyframes*w_frames' frames.
107 | """
108 | os.makedirs(os.path.dirname(output), exist_ok=True)
109 | assert len(seeds) == grid[0]*grid[1]*num_keyframes, f'need gw({grid[0]})*gh({grid[1]})*num_keyframes({num_keyframes})={grid[0]*grid[1]*num_keyframes} seeds but have {len(seeds)}'
110 | G = load_model(network_pkl)
111 | z, c = make_inputs(seeds, G)
112 | p = make_poses(G, c, pose_sampling, w_frames, num_keyframes=num_keyframes, range_azim=range_azim, range_polar=range_polar)
113 |
114 | if bg_decay:
115 | bg_color = 1
116 | n_start = w_frames // 4
117 | n_end = num_keyframes * w_frames - (w_frames // 4)
118 | get_bg_weight = lambda n: 1 - min(1, max(0, (n - n_start) / (n_end - n_start)))
119 |
120 | video_out = imageio.get_writer(output, mode='I', fps=60, codec='libx264', bitrate='12M')
121 | keyframe_idx = 0
122 | for ret_dict in generate_grids(G=G, truncation_psi=truncation_psi, grid=grid, latents=z, poses=p, conditions=c, ret_alpha=bg_decay, ret_depth=False, ret_bg_fg=(bg_decay or no_bg), interpolate=True):
123 | g = ret_dict['rgb']
124 |
125 | if no_bg:
126 | g = ret_dict['fg']
127 | if bg_decay:
128 | g_decay = []
129 | for i, (g_i, alpha_i) in enumerate(zip(g, ret_dict['alpha'])):
130 | weight = get_bg_weight(keyframe_idx*w_frames + i)
131 | print(weight)
132 | g_decay.append(alpha_i * g_i + (1 - alpha_i) * weight * g_i + (1 - alpha_i) * (1 - weight) * torch.full_like(alpha_i, fill_value=bg_color))
133 | g = g_decay
134 |
135 | for pose_idx, g_p in enumerate(g):
136 | video_out.append_data((g_p.permute(1,2,0)*255).to(torch.uint8).cpu().numpy())
137 | keyframe_idx += 1
138 | video_out.close()
139 |
140 | #----------------------------------------------------------------------------
141 |
142 | if __name__ == "__main__":
143 | generate_videos() # pylint: disable=no-value-for-parameter
144 |
145 | #----------------------------------------------------------------------------
146 |
--------------------------------------------------------------------------------
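As a side note on `make_poses` above: reading the `azim2u`/`polar2v` lambdas directly, the azimuth and polar angles (in degrees) are mapped to the `(u, v)` coordinates expected by `uv2RT` as

```latex
u = \frac{\varphi_{\mathrm{azim}}}{360^{\circ}}, \qquad
v = \tfrac{1}{2}\bigl(1 - \cos\theta_{\mathrm{polar}}\bigr)
```

so `u` parameterizes one full turn around the vertical axis, while the cosine mapping for `v` makes uniform steps in `v` correspond to area-uniform steps in the polar angle on the sphere. `uv2RT` is evaluated at unit radius, and the camera distance is rescaled afterwards from the conditioning labels.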
/gfx/ffhq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/voxgraf/4468343786c6c3c2fea402509a71b4fffc378dc1/gfx/ffhq.gif
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/metrics/frechet_inception_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Frechet Inception Distance (FID) from the paper
10 | "GANs trained by a two time-scale update rule converge to a local Nash
11 | equilibrium". Matches the original implementation by Heusel et al. at
12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13 |
14 | import numpy as np
15 | import scipy.linalg
16 | from torch_utils import misc
17 | from . import metric_utils
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def compute_fid(opts, max_real, num_gen):
22 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
23 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
24 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
25 |
26 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
27 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
28 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
29 |
30 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
31 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
32 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
33 |
34 | if opts.rank != 0:
35 | return float('nan')
36 |
37 | m = np.square(mu_gen - mu_real).sum()
38 | try:
39 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
40 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
41 | except ValueError: # nan
42 | fid = 9999
43 | return float(fid)
44 |
45 | #----------------------------------------------------------------------------
46 |
--------------------------------------------------------------------------------
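For reference, `compute_fid` above evaluates the closed-form Fréchet distance between the Gaussian feature statistics of real and generated images,

```latex
\mathrm{FID} \;=\; \lVert \mu_{\mathrm{real}} - \mu_{\mathrm{gen}} \rVert_2^2
\;+\; \operatorname{Tr}\!\bigl(\Sigma_{\mathrm{real}} + \Sigma_{\mathrm{gen}}
- 2\,(\Sigma_{\mathrm{gen}}\Sigma_{\mathrm{real}})^{1/2}\bigr)
```

which matches the squared mean difference plus trace term computed in the code; a failed matrix square root is mapped to the sentinel value 9999.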
/metrics/metric_main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Main API for computing and reporting quality metrics."""
10 |
11 | import os
12 | import time
13 | import json
14 | import torch
15 | import dnnlib
16 |
17 | from . import metric_utils
18 | from . import frechet_inception_distance
19 |
20 | os.environ['MKL_NUM_THREADS'] = "1" # needed for FID computation on cluster
21 |
22 | #----------------------------------------------------------------------------
23 |
24 | _metric_dict = dict() # name => fn
25 |
26 | def register_metric(fn):
27 | assert callable(fn)
28 | _metric_dict[fn.__name__] = fn
29 | return fn
30 |
31 | def is_valid_metric(metric):
32 | return metric in _metric_dict
33 |
34 | def list_valid_metrics():
35 | return list(_metric_dict.keys())
36 |
37 | #----------------------------------------------------------------------------
38 |
39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40 | assert is_valid_metric(metric)
41 | opts = metric_utils.MetricOptions(**kwargs)
42 |
43 | # Calculate.
44 | start_time = time.time()
45 | results = _metric_dict[metric](opts)
46 | total_time = time.time() - start_time
47 |
48 | # Broadcast results.
49 | for key, value in list(results.items()):
50 | if opts.num_gpus > 1:
51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52 | torch.distributed.broadcast(tensor=value, src=0)
53 | value = float(value.cpu())
54 | results[key] = value
55 |
56 | # Decorate with metadata.
57 | return dnnlib.EasyDict(
58 | results = dnnlib.EasyDict(results),
59 | metric = metric,
60 | total_time = total_time,
61 | total_time_str = dnnlib.util.format_time(total_time),
62 | num_gpus = opts.num_gpus,
63 | )
64 |
65 | #----------------------------------------------------------------------------
66 |
67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68 | metric = result_dict['metric']
69 | assert is_valid_metric(metric)
70 | if run_dir is not None and snapshot_pkl is not None:
71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72 |
73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74 | print(jsonl_line)
75 | if run_dir is not None and os.path.isdir(run_dir):
76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77 | f.write(jsonl_line + '\n')
78 |
79 | #----------------------------------------------------------------------------
80 | # Recommended metrics.
81 |
82 | @register_metric
83 | def fid50k_full(opts):
84 | opts.dataset_kwargs.update(max_size=None, xflip=False)
85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86 | return dict(fid50k_full=fid)
87 |
88 | #----------------------------------------------------------------------------
89 | # Legacy metrics.
90 |
91 | @register_metric
92 | def fid50k(opts):
93 | opts.dataset_kwargs.update(max_size=None)
94 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
95 | return dict(fid50k=fid)
96 |
97 | #----------------------------------------------------------------------------
98 | # Added
99 |
100 | @register_metric
101 | def fid10k_full(opts):
102 | opts.dataset_kwargs.update(max_size=None, xflip=False)
103 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=10000)
104 | return dict(fid10k_full=fid)
105 |
106 | @register_metric
107 | def fid20k_full(opts):
108 | opts.dataset_kwargs.update(max_size=None, xflip=False)
109 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=20000)
110 | return dict(fid20k_full=fid)
--------------------------------------------------------------------------------
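A hypothetical sketch (not part of the repository) of how an additional metric could be registered through the decorator API above; the name `fid5k_full` and the package-level imports are illustrative:

```python
# Hypothetical example: register an extra FID variant via metric_main.register_metric.
from metrics import metric_main, frechet_inception_distance

@metric_main.register_metric
def fid5k_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)  # full dataset, no x-flips
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=5000)
    return dict(fid5k_full=fid)  # the key becomes the reported metric name
```

Once registered, the new name is accepted by `--metrics` in `calc_metrics.py`, since metric names are resolved through `is_valid_metric`/`calc_metric`.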
/scripts/build_wheels.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export TMPDIR=~/tmp
3 | conda run -n voxgraf pip install dist/*
4 |
--------------------------------------------------------------------------------
/scripts/download_pretrained_models.sh:
--------------------------------------------------------------------------------
1 | mkdir pretrained_models
2 | # FFHQ 256
3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/voxgraf/models/ffhq256.pkl -P pretrained_models
4 | # Carla 128
5 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/voxgraf/models/carla128.pkl -P pretrained_models
--------------------------------------------------------------------------------
/scripts/make_datasets.sh:
--------------------------------------------------------------------------------
1 | SRC=$1
2 | META=$2
3 | DST=${3%.zip}
4 | IFS=',' read -ra RES <<< $4
5 |
6 | for res in ${RES[@]}
7 | do
8 | OUT=${DST}_${res}.zip
9 | if test -f "$OUT"; then
10 | continue
11 | fi
12 | echo python dataset_tool.py --source $SRC --meta $META --dest $OUT --resolution ${res}x${res} --mirror-aug True
13 | python dataset_tool.py --source $SRC --meta $META --dest $OUT --resolution ${res}x${res} --mirror-aug True
14 | done
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/build_fat_binary.sh:
--------------------------------------------------------------------------------
1 | export TORCH_CUDA_ARCH_LIST="3.5+PTX;3.7+PTX;5.0+PTX;5.2+PTX;5.3+PTX;6.0+PTX;6.1+PTX;6.2+PTX;7.0+PTX;7.2+PTX;7.5+PTX;8.0+PTX;8.6+PTX"
2 |
3 | #python setup.py install
4 | python setup.py bdist_wheel
5 |
--------------------------------------------------------------------------------
/torch_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import glob
10 | import hashlib
11 | import importlib
12 | import os
13 | import re
14 | import shutil
15 | import uuid
16 |
17 | import torch
18 | import torch.utils.cpp_extension
19 | from torch.utils.file_baton import FileBaton
20 |
21 | #----------------------------------------------------------------------------
22 | # Global options.
23 |
24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25 |
26 | #----------------------------------------------------------------------------
27 | # Internal helper funcs.
28 |
29 | def _find_compiler_bindir():
30 | patterns = [
31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35 | ]
36 | for pattern in patterns:
37 | matches = sorted(glob.glob(pattern))
38 | if len(matches):
39 | return matches[-1]
40 | return None
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def _get_mangled_gpu_name():
45 | name = torch.cuda.get_device_name().lower()
46 | out = []
47 | for c in name:
48 | if re.match('[a-z0-9_-]+', c):
49 | out.append(c)
50 | else:
51 | out.append('-')
52 | return ''.join(out)
53 |
54 | #----------------------------------------------------------------------------
55 | # Main entry point for compiling and loading C++/CUDA plugins.
56 |
57 | _cached_plugins = dict()
58 |
59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60 | assert verbosity in ['none', 'brief', 'full']
61 | if headers is None:
62 | headers = []
63 | if source_dir is not None:
64 | sources = [os.path.join(source_dir, fname) for fname in sources]
65 | headers = [os.path.join(source_dir, fname) for fname in headers]
66 |
67 | # Already cached?
68 | if module_name in _cached_plugins:
69 | return _cached_plugins[module_name]
70 |
71 | # Print status.
72 | if verbosity == 'full':
73 | print(f'Setting up PyTorch plugin "{module_name}"...')
74 | elif verbosity == 'brief':
75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76 | verbose_build = (verbosity == 'full')
77 |
78 | # Compile and load.
79 | try: # pylint: disable=too-many-nested-blocks
80 | # Make sure we can find the necessary compiler binaries.
81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82 | compiler_bindir = _find_compiler_bindir()
83 | if compiler_bindir is None:
84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85 | os.environ['PATH'] += ';' + compiler_bindir
86 |
87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88 | # break the build or unnecessarily restrict what's available to nvcc.
89 | # Unset it to let nvcc decide based on what's available on the
90 | # machine.
91 | os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92 |
93 | # Incremental build md5sum trickery. Copies all the input source files
94 | # into a cached build directory under a combined md5 digest of the input
95 | # source files. Copying is done only if the combined digest has changed.
96 | # This keeps input file timestamps and filenames the same as in previous
97 | # extension builds, allowing for fast incremental rebuilds.
98 | #
99 | # This optimization is done only in case all the source files reside in
100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101 | # environment variable is set (we take this as a signal that the user
102 | # actually cares about this.)
103 | #
104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105 | # around the *.cu dependency bug in ninja config.
106 | #
107 | all_source_files = sorted(sources + headers)
108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110 |
111 | # Compute combined hash digest for all source files.
112 | hash_md5 = hashlib.md5()
113 | for src in all_source_files:
114 | with open(src, 'rb') as f:
115 | hash_md5.update(f.read())
116 |
117 | # Select cached build directory name.
118 | source_digest = hash_md5.hexdigest()
119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121 |
122 | if not os.path.isdir(cached_build_dir):
123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124 | os.makedirs(tmpdir)
125 | for src in all_source_files:
126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127 | try:
128 | os.replace(tmpdir, cached_build_dir) # atomic
129 | except OSError:
130 | # source directory already exists, delete tmpdir and its contents.
131 | shutil.rmtree(tmpdir)
132 | if not os.path.isdir(cached_build_dir): raise
133 |
134 | # Compile.
135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137 | verbose=verbose_build, sources=cached_sources, **build_kwargs)
138 | else:
139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140 |
141 | # Load.
142 | module = importlib.import_module(module_name)
143 |
144 | except:
145 | if verbosity == 'brief':
146 | print('Failed!')
147 | raise
148 |
149 | # Print status and add to cache dict.
150 | if verbosity == 'full':
151 | print(f'Done setting up PyTorch plugin "{module_name}".')
152 | elif verbosity == 'brief':
153 | print('Done.')
154 | _cached_plugins[module_name] = module
155 | return module
156 |
157 | #----------------------------------------------------------------------------
158 |
--------------------------------------------------------------------------------
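A minimal usage sketch for `get_plugin`, assuming the `bias_act` sources in `torch_utils/ops`; the module name and compiler flags are illustrative of how the ops in this repository typically compile their CUDA extensions:

```python
# Illustrative sketch: compile and load a C++/CUDA extension through custom_ops.get_plugin.
import os
from torch_utils import custom_ops

_plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',            # name of the resulting Python module
    sources=['bias_act.cpp', 'bias_act.cu'],  # compiled sources, resolved relative to source_dir
    headers=['bias_act.h'],                   # included in the md5 digest for incremental rebuilds
    source_dir=os.path.join(os.path.dirname(custom_ops.__file__), 'ops'),
    extra_cuda_cflags=['--use_fast_math'],    # forwarded to torch.utils.cpp_extension.load
)
```

The first call compiles into a cached build directory keyed by the source digest and GPU name; subsequent calls with the same module name return the cached module.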
/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | //
9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | //
11 |
12 | #include <torch/extension.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <c10/cuda/CUDAGuard.h>
15 | #include "bias_act.h"
16 |
17 | //------------------------------------------------------------------------
18 |
19 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
20 | {
21 | if (x.dim() != y.dim())
22 | return false;
23 | for (int64_t i = 0; i < x.dim(); i++)
24 | {
25 | if (x.size(i) != y.size(i))
26 | return false;
27 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
28 | return false;
29 | }
30 | return true;
31 | }
32 |
33 | //------------------------------------------------------------------------
34 |
35 | torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
36 | {
37 | // Validate arguments.
38 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
39 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
40 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
41 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
42 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
43 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
44 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
45 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
46 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
47 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
48 |
49 | // Validate layout.
50 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
51 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
52 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
53 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
54 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
55 |
56 | // Create output tensor.
57 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
58 | torch::Tensor y = torch::empty_like(x);
59 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
60 |
61 | // Initialize CUDA kernel parameters.
62 | bias_act_kernel_params p;
63 | p.x = x.data_ptr();
64 | p.b = (b.numel()) ? b.data_ptr() : NULL;
65 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
66 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
67 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
68 | p.y = y.data_ptr();
69 | p.grad = grad;
70 | p.act = act;
71 | p.alpha = alpha;
72 | p.gain = gain;
73 | p.clamp = clamp;
74 | p.sizeX = (int)x.numel();
75 | p.sizeB = (int)b.numel();
76 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
77 |
78 | // Choose CUDA kernel.
79 | void* kernel;
80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
81 | {
82 | kernel = choose_bias_act_kernel(p);
83 | });
84 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
85 |
86 | // Launch CUDA kernel.
87 | p.loopX = 4;
88 | int blockSize = 4 * 32;
89 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
90 | void* args[] = {&p};
91 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
92 | return y;
93 | }
94 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType<T>::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | #
9 | # Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | #
11 |
12 |
13 | """Custom PyTorch ops for efficient bias and activation."""
14 |
15 | import os
16 | import numpy as np
17 | import torch
18 | import dnnlib
19 |
20 | from .. import custom_ops
21 | from .. import misc
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | activation_funcs = {
26 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
27 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
28 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
29 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
30 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
31 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
32 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
33 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
34 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
35 | }
36 |
37 | #----------------------------------------------------------------------------
38 |
39 | _plugin = None
40 | _null_tensor = torch.empty([0])
41 |
42 | def _init():
43 | global _plugin
44 | if _plugin is None:
45 | if False:
46 | _plugin = custom_ops.get_plugin(
47 | module_name='bias_act_plugin',
48 | sources=['bias_act.cpp', 'bias_act.cu'],
49 | headers=['bias_act.h'],
50 | source_dir=os.path.dirname(__file__),
51 | extra_cuda_cflags=['--use_fast_math'],
52 | )
53 | import stylegan3_cuda as _plugin
54 | return True
55 |
56 | #----------------------------------------------------------------------------
57 |
58 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
59 | r"""Fused bias and activation function.
60 |
61 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
62 | and scales the result by `gain`. Each of the steps is optional. In most cases,
63 | the fused op is considerably more efficient than performing the same calculation
64 | using standard PyTorch ops. It supports first and second order gradients,
65 | but not third order gradients.
66 |
67 | Args:
68 | x: Input activation tensor. Can be of any shape.
69 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
70 | as `x`. The shape must be known, and it must match the dimension of `x`
71 | corresponding to `dim`.
72 | dim: The dimension in `x` corresponding to the elements of `b`.
73 | The value of `dim` is ignored if `b` is not specified.
74 | act: Name of the activation function to evaluate, or `"linear"` to disable.
75 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
76 | See `activation_funcs` for a full list. `None` is not allowed.
77 | alpha: Shape parameter for the activation function, or `None` to use the default.
78 | gain: Scaling factor for the output tensor, or `None` to use the default.
79 | See `activation_funcs` for the default scaling of each activation function.
80 | If unsure, consider specifying 1.
81 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
82 | the clamping (default).
83 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
84 |
85 | Returns:
86 | Tensor of the same shape and datatype as `x`.
87 | """
88 | assert isinstance(x, torch.Tensor)
89 | assert impl in ['ref', 'cuda']
90 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
91 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
92 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
93 |
94 | #----------------------------------------------------------------------------
95 |
96 | @misc.profiled_function
97 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
98 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
99 | """
100 | assert isinstance(x, torch.Tensor)
101 | assert clamp is None or clamp >= 0
102 | spec = activation_funcs[act]
103 | alpha = float(alpha if alpha is not None else spec.def_alpha)
104 | gain = float(gain if gain is not None else spec.def_gain)
105 | clamp = float(clamp if clamp is not None else -1)
106 |
107 | # Add bias.
108 | if b is not None:
109 | assert isinstance(b, torch.Tensor) and b.ndim == 1
110 | assert 0 <= dim < x.ndim
111 | assert b.shape[0] == x.shape[dim]
112 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
113 |
114 | # Evaluate activation function.
115 | alpha = float(alpha)
116 | x = spec.func(x, alpha=alpha)
117 |
118 | # Scale by gain.
119 | gain = float(gain)
120 | if gain != 1:
121 | x = x * gain
122 |
123 | # Clamp.
124 | if clamp >= 0:
125 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
126 | return x
127 |
128 | #----------------------------------------------------------------------------
129 |
130 | _bias_act_cuda_cache = dict()
131 |
132 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
133 | """Fast CUDA implementation of `bias_act()` using custom ops.
134 | """
135 | # Parse arguments.
136 | assert clamp is None or clamp >= 0
137 | spec = activation_funcs[act]
138 | alpha = float(alpha if alpha is not None else spec.def_alpha)
139 | gain = float(gain if gain is not None else spec.def_gain)
140 | clamp = float(clamp if clamp is not None else -1)
141 |
142 | # Lookup from cache.
143 | key = (dim, act, alpha, gain, clamp)
144 | if key in _bias_act_cuda_cache:
145 | return _bias_act_cuda_cache[key]
146 |
147 | # Forward op.
148 | class BiasActCuda(torch.autograd.Function):
149 | @staticmethod
150 | def forward(ctx, x, b): # pylint: disable=arguments-differ
151 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
152 | x = x.contiguous(memory_format=ctx.memory_format)
153 | b = b.contiguous() if b is not None else _null_tensor
154 | y = x
155 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
156 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
157 | ctx.save_for_backward(
158 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
159 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
160 | y if 'y' in spec.ref else _null_tensor)
161 | return y
162 |
163 | @staticmethod
164 | def backward(ctx, dy): # pylint: disable=arguments-differ
165 | dy = dy.contiguous(memory_format=ctx.memory_format)
166 | x, b, y = ctx.saved_tensors
167 | dx = None
168 | db = None
169 |
170 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
171 | dx = dy
172 | if act != 'linear' or gain != 1 or clamp >= 0:
173 | dx = BiasActCudaGrad.apply(dy, x, b, y)
174 |
175 | if ctx.needs_input_grad[1]:
176 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
177 |
178 | return dx, db
179 |
180 | # Backward op.
181 | class BiasActCudaGrad(torch.autograd.Function):
182 | @staticmethod
183 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
184 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
185 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
186 | ctx.save_for_backward(
187 | dy if spec.has_2nd_grad else _null_tensor,
188 | x, b, y)
189 | return dx
190 |
191 | @staticmethod
192 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
193 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
194 | dy, x, b, y = ctx.saved_tensors
195 | d_dy = None
196 | d_x = None
197 | d_b = None
198 | d_y = None
199 |
200 | if ctx.needs_input_grad[0]:
201 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
202 |
203 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
204 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
205 |
206 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
207 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
208 |
209 | return d_dy, d_x, d_b, d_y
210 |
211 | # Add to cache.
212 | _bias_act_cuda_cache[key] = BiasActCuda
213 | return BiasActCuda
214 |
215 | #----------------------------------------------------------------------------
216 |
--------------------------------------------------------------------------------
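A minimal usage sketch for the `bias_act()` op listed above (illustrative only, not part of the repository; it assumes the repository root is on PYTHONPATH and uses the pure-PyTorch reference path so the prebuilt `stylegan3_cuda` plugin is not needed):

    import torch
    from torch_utils.ops import bias_act

    x = torch.randn(4, 8, 16, 16)   # [N, C, H, W] activations
    b = torch.zeros(8)              # per-channel bias along dim=1
    # Leaky ReLU with an explicit gain and clamping; impl='ref' avoids the CUDA plugin.
    y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=1.0, clamp=256, impl='ref')
    assert y.shape == x.shape and y.dtype == x.dtype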
/torch_utils/ops/bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | //
9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | //
11 |
12 | #include <c10/util/Half.h>
13 | #include "bias_act.h"
14 |
15 | //------------------------------------------------------------------------
16 | // Helpers.
17 |
18 | template <class T> struct InternalType;
19 | template <> struct InternalType<double>    { typedef double scalar_t; };
20 | template <> struct InternalType<float>     { typedef float  scalar_t; };
21 | template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
22 |
23 | //------------------------------------------------------------------------
24 | // CUDA kernel.
25 |
26 | template <class T, int A>
27 | __global__ void bias_act_kernel(bias_act_kernel_params p)
28 | {
29 | typedef typename InternalType<T>::scalar_t scalar_t;
30 | int G = p.grad;
31 | scalar_t alpha = (scalar_t)p.alpha;
32 | scalar_t gain = (scalar_t)p.gain;
33 | scalar_t clamp = (scalar_t)p.clamp;
34 | scalar_t one = (scalar_t)1;
35 | scalar_t two = (scalar_t)2;
36 | scalar_t expRange = (scalar_t)80;
37 | scalar_t halfExpRange = (scalar_t)40;
38 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
39 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
40 |
41 | // Loop over elements.
42 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
43 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
44 | {
45 | // Load.
46 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
47 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
48 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
49 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
50 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
51 | scalar_t yy = (gain != 0) ? yref / gain : 0;
52 | scalar_t y = 0;
53 |
54 | // Apply bias.
55 | ((G == 0) ? x : xref) += b;
56 |
57 | // linear
58 | if (A == 1)
59 | {
60 | if (G == 0) y = x;
61 | if (G == 1) y = x;
62 | }
63 |
64 | // relu
65 | if (A == 2)
66 | {
67 | if (G == 0) y = (x > 0) ? x : 0;
68 | if (G == 1) y = (yy > 0) ? x : 0;
69 | }
70 |
71 | // lrelu
72 | if (A == 3)
73 | {
74 | if (G == 0) y = (x > 0) ? x : x * alpha;
75 | if (G == 1) y = (yy > 0) ? x : x * alpha;
76 | }
77 |
78 | // tanh
79 | if (A == 4)
80 | {
81 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
82 | if (G == 1) y = x * (one - yy * yy);
83 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
84 | }
85 |
86 | // sigmoid
87 | if (A == 5)
88 | {
89 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
90 | if (G == 1) y = x * yy * (one - yy);
91 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
92 | }
93 |
94 | // elu
95 | if (A == 6)
96 | {
97 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
98 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
99 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
100 | }
101 |
102 | // selu
103 | if (A == 7)
104 | {
105 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
106 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
107 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
108 | }
109 |
110 | // softplus
111 | if (A == 8)
112 | {
113 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
114 | if (G == 1) y = x * (one - exp(-yy));
115 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
116 | }
117 |
118 | // swish
119 | if (A == 9)
120 | {
121 | if (G == 0)
122 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
123 | else
124 | {
125 | scalar_t c = exp(xref);
126 | scalar_t d = c + one;
127 | if (G == 1)
128 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
129 | else
130 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
131 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
132 | }
133 | }
134 |
135 | // Apply gain.
136 | y *= gain * dy;
137 |
138 | // Clamp.
139 | if (clamp >= 0)
140 | {
141 | if (G == 0)
142 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
143 | else
144 | y = (yref > -clamp & yref < clamp) ? y : 0;
145 | }
146 |
147 | // Store.
148 | ((T*)p.y)[xi] = (T)y;
149 | }
150 | }
151 |
152 | //------------------------------------------------------------------------
153 | // CUDA kernel selection.
154 |
155 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
156 | {
157 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
158 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
159 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
160 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
161 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
162 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
163 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
164 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
165 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
166 | return NULL;
167 | }
168 |
169 | //------------------------------------------------------------------------
170 | // Template specializations.
171 |
172 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
173 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
174 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
175 |
176 | //------------------------------------------------------------------------
177 |
--------------------------------------------------------------------------------
/torch_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import contextlib
13 | import torch
14 | from pkg_resources import parse_version
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25 |
26 | @contextlib.contextmanager
27 | def no_weight_gradients(disable=True):
28 | global weight_gradients_disabled
29 | old = weight_gradients_disabled
30 | if disable:
31 | weight_gradients_disabled = True
32 | yield
33 | weight_gradients_disabled = old
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
38 | if _should_use_custom_op(input):
39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
41 |
42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
43 | if _should_use_custom_op(input):
44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _should_use_custom_op(input):
50 | assert isinstance(input, torch.Tensor)
51 | if (not enabled) or (not torch.backends.cudnn.enabled):
52 | return False
53 | if _use_pytorch_1_11_api:
54 | # The work-around code doesn't work on PyTorch 1.11.0 onwards
55 | return False
56 | if input.device.type != 'cuda':
57 | return False
58 | return True
59 |
60 | def _tuple_of_ints(xs, ndim):
61 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
62 | assert len(xs) == ndim
63 | assert all(isinstance(x, int) for x in xs)
64 | return xs
65 |
66 | #----------------------------------------------------------------------------
67 |
68 | _conv2d_gradfix_cache = dict()
69 | _null_tensor = torch.empty([0])
70 |
71 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
72 | # Parse arguments.
73 | ndim = 2
74 | weight_shape = tuple(weight_shape)
75 | stride = _tuple_of_ints(stride, ndim)
76 | padding = _tuple_of_ints(padding, ndim)
77 | output_padding = _tuple_of_ints(output_padding, ndim)
78 | dilation = _tuple_of_ints(dilation, ndim)
79 |
80 | # Lookup from cache.
81 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
82 | if key in _conv2d_gradfix_cache:
83 | return _conv2d_gradfix_cache[key]
84 |
85 | # Validate arguments.
86 | assert groups >= 1
87 | assert len(weight_shape) == ndim + 2
88 | assert all(stride[i] >= 1 for i in range(ndim))
89 | assert all(padding[i] >= 0 for i in range(ndim))
90 | assert all(dilation[i] >= 0 for i in range(ndim))
91 | if not transpose:
92 | assert all(output_padding[i] == 0 for i in range(ndim))
93 | else: # transpose
94 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
95 |
96 | # Helpers.
97 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
98 | def calc_output_padding(input_shape, output_shape):
99 | if transpose:
100 | return [0, 0]
101 | return [
102 | input_shape[i + 2]
103 | - (output_shape[i + 2] - 1) * stride[i]
104 | - (1 - 2 * padding[i])
105 | - dilation[i] * (weight_shape[i + 2] - 1)
106 | for i in range(ndim)
107 | ]
108 |
109 | # Forward & backward.
110 | class Conv2d(torch.autograd.Function):
111 | @staticmethod
112 | def forward(ctx, input, weight, bias):
113 | assert weight.shape == weight_shape
114 | ctx.save_for_backward(
115 | input if weight.requires_grad else _null_tensor,
116 | weight if input.requires_grad else _null_tensor,
117 | )
118 | ctx.input_shape = input.shape
119 |
120 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
121 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
122 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
123 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
124 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
125 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
126 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
127 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
128 |
129 | # General case => cuDNN.
130 | if transpose:
131 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
132 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
133 |
134 | @staticmethod
135 | def backward(ctx, grad_output):
136 | input, weight = ctx.saved_tensors
137 | input_shape = ctx.input_shape
138 | grad_input = None
139 | grad_weight = None
140 | grad_bias = None
141 |
142 | if ctx.needs_input_grad[0]:
143 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
144 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
145 | grad_input = op.apply(grad_output, weight, None)
146 | assert grad_input.shape == input_shape
147 |
148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
149 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
150 | assert grad_weight.shape == weight_shape
151 |
152 | if ctx.needs_input_grad[2]:
153 | grad_bias = grad_output.sum([0, 2, 3])
154 |
155 | return grad_input, grad_weight, grad_bias
156 |
157 | # Gradient with respect to the weights.
158 | class Conv2dGradWeight(torch.autograd.Function):
159 | @staticmethod
160 | def forward(ctx, grad_output, input):
161 | ctx.save_for_backward(
162 | grad_output if input.requires_grad else _null_tensor,
163 | input if grad_output.requires_grad else _null_tensor,
164 | )
165 | ctx.grad_output_shape = grad_output.shape
166 | ctx.input_shape = input.shape
167 |
168 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
169 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
170 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
171 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
172 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
173 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
174 |
175 | # General case => cuDNN.
176 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
177 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
178 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
179 |
180 | @staticmethod
181 | def backward(ctx, grad2_grad_weight):
182 | grad_output, input = ctx.saved_tensors
183 | grad_output_shape = ctx.grad_output_shape
184 | input_shape = ctx.input_shape
185 | grad2_grad_output = None
186 | grad2_input = None
187 |
188 | if ctx.needs_input_grad[0]:
189 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
190 | assert grad2_grad_output.shape == grad_output_shape
191 |
192 | if ctx.needs_input_grad[1]:
193 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
194 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
195 | grad2_input = op.apply(grad_output, grad2_grad_weight, None)
196 | assert grad2_input.shape == input_shape
197 |
198 | return grad2_grad_output, grad2_input
199 |
200 | _conv2d_gradfix_cache[key] = Conv2d
201 | return Conv2d
202 |
203 | #----------------------------------------------------------------------------
204 |
--------------------------------------------------------------------------------
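A short sketch of how conv2d_gradfix is typically driven (illustrative only; the shapes are assumptions, and on CPU or PyTorch >= 1.11 the call transparently falls back to torch.nn.functional.conv2d):

    import torch
    from torch_utils.ops import conv2d_gradfix

    # The custom op is used only when `enabled` is True, cuDNN is available, the
    # input is on a CUDA device, and PyTorch is older than 1.11; otherwise the
    # fallback produces identical results with standard ops.
    conv2d_gradfix.enabled = True
    x = torch.randn(1, 3, 32, 32)
    w = torch.randn(8, 3, 3, 3)
    y = conv2d_gradfix.conv2d(x, w, padding=1)   # [1, 8, 32, 32]

    # On the custom-op path, this context skips gradients w.r.t. the weights,
    # which is useful for regularization passes that only need input gradients.
    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)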
/torch_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | if not flip_weight and (kw > 1 or kh > 1):
37 | w = w.flip([2, 3])
38 |
39 | # Execute using conv2d_gradfix.
40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41 | return op(x, w, stride=stride, padding=padding, groups=groups)
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | @misc.profiled_function
46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47 | r"""2D convolution with optional up/downsampling.
48 |
49 | Padding is performed only once at the beginning, not between the operations.
50 |
51 | Args:
52 | x: Input tensor of shape
53 | `[batch_size, in_channels, in_height, in_width]`.
54 | w: Weight tensor of shape
55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57 | calling upfirdn2d.setup_filter(). None = identity (default).
58 | up: Integer upsampling factor (default: 1).
59 | down: Integer downsampling factor (default: 1).
60 | padding: Padding with respect to the upsampled image. Can be a single number
61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62 | (default: 0).
63 | groups: Split input channels into N groups (default: 1).
64 | flip_weight: False = convolution, True = correlation (default: True).
65 | flip_filter: False = convolution, True = correlation (default: False).
66 |
67 | Returns:
68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69 | """
70 | # Validate arguments.
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74 | assert isinstance(up, int) and (up >= 1)
75 | assert isinstance(down, int) and (down >= 1)
76 | assert isinstance(groups, int) and (groups >= 1)
77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78 | fw, fh = _get_filter_size(f)
79 | px0, px1, py0, py1 = _parse_padding(padding)
80 |
81 | # Adjust padding to account for up/downsampling.
82 | if up > 1:
83 | px0 += (fw + up - 1) // 2
84 | px1 += (fw - up) // 2
85 | py0 += (fh + up - 1) // 2
86 | py1 += (fh - up) // 2
87 | if down > 1:
88 | px0 += (fw - down + 1) // 2
89 | px1 += (fw - down) // 2
90 | py0 += (fh - down + 1) // 2
91 | py1 += (fh - down) // 2
92 |
93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97 | return x
98 |
99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103 | return x
104 |
105 | # Fast path: downsampling only => use strided convolution.
106 | if down > 1 and up == 1:
107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109 | return x
110 |
111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112 | if up > 1:
113 | if groups == 1:
114 | w = w.transpose(0, 1)
115 | else:
116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117 | w = w.transpose(1, 2)
118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119 | px0 -= kw - 1
120 | px1 -= kw - up
121 | py0 -= kh - 1
122 | py1 -= kh - up
123 | pxt = max(min(-px0, -px1), 0)
124 | pyt = max(min(-py0, -py1), 0)
125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127 | if down > 1:
128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129 | return x
130 |
131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132 | if up == 1 and down == 1:
133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135 |
136 | # Fallback: Generic reference implementation.
137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139 | if down > 1:
140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141 | return x
142 |
143 | #----------------------------------------------------------------------------
144 |
--------------------------------------------------------------------------------
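A small sketch of `conv2d_resample()` for a 2x upsampling convolution (illustrative only; the filter taps [1, 3, 3, 1] and the tensor shapes are assumptions):

    import torch
    from torch_utils.ops import upfirdn2d, conv2d_resample

    x = torch.randn(2, 16, 32, 32)             # [N, C, H, W]
    w = torch.randn(32, 16, 3, 3)              # [out_channels, in_channels, kh, kw]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])   # separable low-pass filter, prepared once
    y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
    print(y.shape)                             # torch.Size([2, 32, 64, 64])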
/torch_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct filtered_lrelu_kernel_params
15 | {
16 | // These parameters decide which kernel to use.
17 | int up; // upsampling ratio (1, 2, 4)
18 | int down; // downsampling ratio (1, 2, 4)
19 | int2 fuShape; // [size, 1] | [size, size]
20 | int2 fdShape; // [size, 1] | [size, size]
21 |
22 | int _dummy; // Alignment.
23 |
24 | // Rest of the parameters.
25 | const void* x; // Input tensor.
26 | void* y; // Output tensor.
27 | const void* b; // Bias tensor.
28 | unsigned char* s; // Sign tensor in/out. NULL if unused.
29 | const float* fu; // Upsampling filter.
30 | const float* fd; // Downsampling filter.
31 |
32 | int2 pad0; // Left/top padding.
33 | float gain; // Additional gain factor.
34 | float slope; // Leaky ReLU slope on negative side.
35 | float clamp; // Clamp after nonlinearity.
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 | void* x; // Input/output, modified in-place.
58 | unsigned char* s; // Sign tensor in/out. NULL if unused.
59 |
60 | float gain; // Additional gain factor.
61 | float slope; // Leaky ReLU slope on negative side.
62 | float clamp; // Clamp after nonlinearity.
63 |
64 | int4 xShape; // [width, height, channel, batch]
65 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
73 | struct filtered_lrelu_kernel_spec
74 | {
75 | void* setup; // Function for filter kernel setup.
76 | void* exec; // Function for main operation.
77 | int2 tileOut; // Width/height of launch tile.
78 | int numWarps; // Number of warps per thread block, determines launch block size.
79 | int xrep; // For processing multiple horizontal tiles per thread block.
80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81 | };
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/torch_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/torch_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
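A minimal sketch of `fma()` with a broadcast addend (illustrative shapes; gradients are un-broadcast back to each input's original shape, which is the point of `_unbroadcast()` above):

    import torch
    from torch_utils.ops import fma

    a = torch.randn(4, 8, requires_grad=True)
    b = torch.randn(4, 8, requires_grad=True)
    c = torch.randn(1, 8, requires_grad=True)   # broadcasts over the first dimension
    out = fma.fma(a, b, c)                      # a * b + c
    out.sum().backward()
    assert c.grad.shape == c.shape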
/torch_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import torch
15 | from pkg_resources import parse_version
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def grid_sample(input, grid):
29 | if _should_use_custom_op():
30 | return _GridSample2dForward.apply(input, grid)
31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def _should_use_custom_op():
36 | return enabled
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | class _GridSample2dForward(torch.autograd.Function):
41 | @staticmethod
42 | def forward(ctx, input, grid):
43 | assert input.ndim == 4
44 | assert grid.ndim == 4
45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46 | ctx.save_for_backward(input, grid)
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | input, grid = ctx.saved_tensors
52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53 | return grad_input, grad_grid
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | class _GridSample2dBackward(torch.autograd.Function):
58 | @staticmethod
59 | def forward(ctx, grad_output, input, grid):
60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61 | if _use_pytorch_1_11_api:
62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
64 | else:
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
--------------------------------------------------------------------------------
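An illustrative call to the `grid_sample()` wrapper (with the default `enabled = False` it simply delegates to torch.nn.functional.grid_sample with bilinear sampling, zero padding, and align_corners=False; the shapes are assumptions):

    import torch
    from torch_utils.ops import grid_sample_gradfix

    image = torch.randn(1, 3, 16, 16)                   # [N, C, H, W]
    grid = torch.rand(1, 8, 8, 2) * 2 - 1               # sampling locations in [-1, 1], [N, H_out, W_out, 2]
    out = grid_sample_gradfix.grid_sample(image, grid)  # [1, 3, 8, 8]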
/torch_utils/ops/pybind.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | //
9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | //
11 |
12 | #include <torch/extension.h>
13 | #include "bias_act.h"
14 | #include "filtered_lrelu.h"
15 | #include "upfirdn2d.h"
16 |
17 | torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp);
18 | torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns);
19 | std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
20 | torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
21 | int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns);
22 | torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain);
23 |
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
25 | {
26 | m.def("bias_act", &bias_act);
27 | m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
28 | m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
29 | m.def("upfirdn2d", &upfirdn2d);
30 | }
31 |
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | //
9 | // Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | //
11 |
12 | #include <torch/extension.h>
13 | #include <ATen/cuda/CUDAContext.h>
14 | #include <c10/cuda/CUDAGuard.h>
15 | #include "upfirdn2d.h"
16 |
17 | //------------------------------------------------------------------------
18 |
19 | torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
20 | {
21 | // Validate arguments.
22 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
23 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
24 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
25 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
26 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
27 | TORCH_CHECK(x.numel() > 0, "x has zero size");
28 | TORCH_CHECK(f.numel() > 0, "f has zero size");
29 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
31 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
32 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
33 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
34 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
35 |
36 | // Create output tensor.
37 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
38 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
39 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
40 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
41 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
42 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
43 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
44 |
45 | // Initialize CUDA kernel parameters.
46 | upfirdn2d_kernel_params p;
47 | p.x = x.data_ptr();
48 | p.f = f.data_ptr();
49 | p.y = y.data_ptr();
50 | p.up = make_int2(upx, upy);
51 | p.down = make_int2(downx, downy);
52 | p.pad0 = make_int2(padx0, pady0);
53 | p.flip = (flip) ? 1 : 0;
54 | p.gain = gain;
55 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
56 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
57 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
58 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
59 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
60 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
61 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
62 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
63 |
64 | // Choose CUDA kernel.
65 | upfirdn2d_kernel_spec spec;
66 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
67 | {
68 | spec = choose_upfirdn2d_kernel<scalar_t>(p);
69 | });
70 |
71 | // Set looping options.
72 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
73 | p.loopMinor = spec.loopMinor;
74 | p.loopX = spec.loopX;
75 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
76 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
77 |
78 | // Compute grid size.
79 | dim3 blockSize, gridSize;
80 | if (spec.tileOutW < 0) // large
81 | {
82 | blockSize = dim3(4, 32, 1);
83 | gridSize = dim3(
84 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
85 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
86 | p.launchMajor);
87 | }
88 | else // small
89 | {
90 | blockSize = dim3(256, 1, 1);
91 | gridSize = dim3(
92 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
93 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
94 | p.launchMajor);
95 | }
96 |
97 | // Launch CUDA kernel.
98 | void* args[] = {&p};
99 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
100 | return y;
101 | }
102 |
103 |
--------------------------------------------------------------------------------
/torch_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Facilities for pickling Python code alongside other data.
10 |
11 | The pickled code is automatically imported into a separate Python module
12 | during unpickling. This way, any previously exported pickles will remain
13 | usable even if the original code is no longer available, or if the current
14 | version of the code is not consistent with what was originally pickled."""
15 |
16 | import sys
17 | import pickle
18 | import io
19 | import inspect
20 | import copy
21 | import uuid
22 | import types
23 | import dnnlib
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | _version = 6 # internal version number
28 | _decorators = set() # {decorator_class, ...}
29 | _import_hooks = [] # [hook_function, ...]
30 | _module_to_src_dict = dict() # {module: src, ...}
31 | _src_to_module_dict = dict() # {src: module, ...}
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def persistent_class(orig_class):
36 | r"""Class decorator that extends a given class to save its source code
37 | when pickled.
38 |
39 | Example:
40 |
41 | from torch_utils import persistence
42 |
43 | @persistence.persistent_class
44 | class MyNetwork(torch.nn.Module):
45 | def __init__(self, num_inputs, num_outputs):
46 | super().__init__()
47 | self.fc = MyLayer(num_inputs, num_outputs)
48 | ...
49 |
50 | @persistence.persistent_class
51 | class MyLayer(torch.nn.Module):
52 | ...
53 |
54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55 | source code alongside other internal state (e.g., parameters, buffers,
56 | and submodules). This way, any previously exported pickle will remain
57 | usable even if the class definitions have been modified or are no
58 | longer available.
59 |
60 | The decorator saves the source code of the entire Python module
61 | containing the decorated class. It does *not* save the source code of
62 | any imported modules. Thus, the imported modules must be available
63 |     during unpickling, including `torch_utils.persistence` itself.
64 |
65 | It is ok to call functions defined in the same module from the
66 | decorated class. However, if the decorated class depends on other
67 | classes defined in the same module, they must be decorated as well.
68 | This is illustrated in the above example in the case of `MyLayer`.
69 |
70 | It is also possible to employ the decorator just-in-time before
71 | calling the constructor. For example:
72 |
73 | cls = MyLayer
74 | if want_to_make_it_persistent:
75 | cls = persistence.persistent_class(cls)
76 | layer = cls(num_inputs, num_outputs)
77 |
78 | As an additional feature, the decorator also keeps track of the
79 | arguments that were used to construct each instance of the decorated
80 | class. The arguments can be queried via `obj.init_args` and
81 | `obj.init_kwargs`, and they are automatically pickled alongside other
82 | object state. A typical use case is to first unpickle a previous
83 | instance of a persistent class, and then upgrade it to use the latest
84 | version of the source code:
85 |
86 | with open('old_pickle.pkl', 'rb') as f:
87 | old_net = pickle.load(f)
88 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 | """
91 | assert isinstance(orig_class, type)
92 | if is_persistent(orig_class):
93 | return orig_class
94 |
95 | assert orig_class.__module__ in sys.modules
96 | orig_module = sys.modules[orig_class.__module__]
97 | orig_module_src = _module_to_src(orig_module)
98 |
99 | class Decorator(orig_class):
100 | _orig_module_src = orig_module_src
101 | _orig_class_name = orig_class.__name__
102 |
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 | self._init_args = copy.deepcopy(args)
106 | self._init_kwargs = copy.deepcopy(kwargs)
107 | assert orig_class.__name__ in orig_module.__dict__
108 | _check_pickleable(self.__reduce__())
109 |
110 | @property
111 | def init_args(self):
112 | return copy.deepcopy(self._init_args)
113 |
114 | @property
115 | def init_kwargs(self):
116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 |
118 | def __reduce__(self):
119 | fields = list(super().__reduce__())
120 | fields += [None] * max(3 - len(fields), 0)
121 | if fields[0] is not _reconstruct_persistent_obj:
122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 | fields[0] = _reconstruct_persistent_obj # reconstruct func
124 | fields[1] = (meta,) # reconstruct args
125 | fields[2] = None # state dict
126 | return tuple(fields)
127 |
128 | Decorator.__name__ = orig_class.__name__
129 | _decorators.add(Decorator)
130 | return Decorator
131 |
132 | #----------------------------------------------------------------------------
133 |
134 | def is_persistent(obj):
135 | r"""Test whether the given object or class is persistent, i.e.,
136 | whether it will save its source code when pickled.
137 | """
138 | try:
139 | if obj in _decorators:
140 | return True
141 | except TypeError:
142 | pass
143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 |
145 | #----------------------------------------------------------------------------
146 |
147 | def import_hook(hook):
148 | r"""Register an import hook that is called whenever a persistent object
149 | is being unpickled. A typical use case is to patch the pickled source
150 | code to avoid errors and inconsistencies when the API of some imported
151 | module has changed.
152 |
153 | The hook should have the following signature:
154 |
155 | hook(meta) -> modified meta
156 |
157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 |
159 | type: Type of the persistent object, e.g. `'class'`.
160 | version: Internal version number of `torch_utils.persistence`.
161 |         module_src: Original source code of the Python module.
162 | class_name: Class name in the original Python module.
163 | state: Internal state of the object.
164 |
165 | Example:
166 |
167 | @persistence.import_hook
168 | def wreck_my_network(meta):
169 | if meta.class_name == 'MyNetwork':
170 | print('MyNetwork is being imported. I will wreck it!')
171 | meta.module_src = meta.module_src.replace("True", "False")
172 | return meta
173 | """
174 | assert callable(hook)
175 | _import_hooks.append(hook)
176 |
177 | #----------------------------------------------------------------------------
178 |
179 | def _reconstruct_persistent_obj(meta):
180 | r"""Hook that is called internally by the `pickle` module to unpickle
181 | a persistent object.
182 | """
183 | meta = dnnlib.EasyDict(meta)
184 | meta.state = dnnlib.EasyDict(meta.state)
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 | return obj
203 |
204 | #----------------------------------------------------------------------------
205 |
206 | def _module_to_src(module):
207 | r"""Query the source code of a given Python module.
208 | """
209 | src = _module_to_src_dict.get(module, None)
210 | if src is None:
211 | src = inspect.getsource(module)
212 | _module_to_src_dict[module] = src
213 | _src_to_module_dict[src] = module
214 | return src
215 |
216 | def _src_to_module(src):
217 | r"""Get or create a Python module for the given source code.
218 | """
219 | module = _src_to_module_dict.get(src, None)
220 | if module is None:
221 | module_name = "_imported_module_" + uuid.uuid4().hex
222 | module = types.ModuleType(module_name)
223 | sys.modules[module_name] = module
224 | _module_to_src_dict[module] = src
225 | _src_to_module_dict[src] = module
226 | exec(src, module.__dict__) # pylint: disable=exec-used
227 | return module
228 |
229 | #----------------------------------------------------------------------------
230 |
231 | def _check_pickleable(obj):
232 | r"""Check that the given object is pickleable, raising an exception if
233 | it is not. This function is expected to be considerably more efficient
234 | than actually pickling the object.
235 | """
236 | def recurse(obj):
237 | if isinstance(obj, (list, tuple, set)):
238 | return [recurse(x) for x in obj]
239 | if isinstance(obj, dict):
240 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242 | return None # Python primitive types are pickleable.
243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
244 | return None # NumPy arrays and PyTorch tensors are pickleable.
245 | if is_persistent(obj):
246 | return None # Persistent objects are pickleable, by virtue of the constructor check.
247 | return obj
248 | with io.BytesIO() as f:
249 | pickle.dump(recurse(obj), f)
250 |
251 | #----------------------------------------------------------------------------
252 |
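253 | #----------------------------------------------------------------------------
254 | # Hypothetical usage sketch (illustration only, not part of the original
255 | # module): round-trip a decorated network through pickle and read back its
256 | # constructor arguments via `init_kwargs`, as described in the docstring of
257 | # `persistent_class` above. `_TinyNet` is an illustrative class; run this
258 | # file directly to try it.
259 |
260 | if __name__ == "__main__":
261 |     import torch
262 |
263 |     @persistent_class
264 |     class _TinyNet(torch.nn.Module):
265 |         def __init__(self, num_inputs, num_outputs):
266 |             super().__init__()
267 |             self.fc = torch.nn.Linear(num_inputs, num_outputs)
268 |
269 |     net = _TinyNet(num_inputs=4, num_outputs=2)
270 |     data = pickle.dumps(net)           # module source is stored alongside the state
271 |     restored = pickle.loads(data)      # usable even if _TinyNet changes later
272 |     print(restored.init_kwargs)        # -> {'num_inputs': 4, 'num_outputs': 2}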
--------------------------------------------------------------------------------
/torch_utils/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 | import os
4 |
5 | extension = CUDAExtension('stylegan3_cuda',
6 | sources=[os.path.join('./ops', x) for x in ['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu',
7 | 'filtered_lrelu_ns.cu', 'bias_act.cpp', 'bias_act_kernel.cu', 'pybind.cpp', 'upfirdn2d.cpp', 'upfirdn2d_kernel.cu']],
8 | headers=[os.path.join('./ops', x) for x in ['filtered_lrelu.h', 'filtered_lrelu.cu', 'bias_act.h', 'upfirdn2d.h']],
9 | extra_compile_args = {'cxx': [], 'nvcc': ['--use_fast_math']}
10 | )
11 |
12 | setup(
13 | name='stylegan3_cuda',
14 | ext_modules=[extension],
15 | cmdclass={
16 | 'build_ext': BuildExtension
17 | })
18 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/training/networks_voxgraf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | #
9 | # Modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids
10 | #
11 |
12 | import torch
13 | from torch_utils import persistence
14 | from training.networks_stylegan2 import SynthesisNetwork as SG2_syn, MappingNetwork as SingleMappingNetwork
15 | from training.networks_stylegan3 import SynthesisLayer as SG3_SynthesisLayer
16 | from training.networks_stylegan2_3d import SynthesisNetwork as Synthesis3D
17 | from training.virtual_camera_utils import PoseSampler, Renderer
18 |
19 | @persistence.persistent_class
20 | class MappingNetwork(torch.nn.Module):
21 | """Wrapper class for fg and bg mapping network."""
22 | def __init__(self,
23 | z_dim, # Input latent (Z) dimensionality.
24 | c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
25 | w_dim, # Intermediate latent (W) dimensionality.
26 | num_ws_fg, # Number of intermediate latents to output for foreground.
27 | num_ws_bg, # Number of intermediate latents to output for background.
28 | use_bg=True, # Model background with 2D GAN
29 | **mapping_kwargs
30 | ):
31 | super().__init__()
32 | self.fg = SingleMappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=num_ws_fg, **mapping_kwargs)
33 | self.use_bg = use_bg
34 | if use_bg:
35 | self.bg = SingleMappingNetwork(z_dim=z_dim, c_dim=0, w_dim=w_dim, num_ws=num_ws_bg, **mapping_kwargs) # bg is not conditioned on pose
36 |
37 | def forward(self, z_fg, z_bg, c_fg, c_bg, **kwargs):
38 | w_fg = self.fg(z_fg, c_fg, **kwargs)
39 | w_bg = None
40 | if self.use_bg:
41 | w_bg = self.bg(z_bg, c_bg, **kwargs)
42 | return w_fg, w_bg
43 |
44 |
45 | @persistence.persistent_class
46 | class RefinementNetwork(torch.nn.Module):
47 | def __init__(self, w_dim, input_resolution, input_channels, output_channels, dhidden, num_layers):
48 | super().__init__()
49 | assert num_layers > 0
50 | last_cutoff = input_resolution / 2
51 | refinement = []
52 | for i in range(num_layers):
53 | is_last = (i == num_layers-1)
54 |
55 | refinement.append(SG3_SynthesisLayer(
56 | w_dim=w_dim, is_torgb=is_last, is_critically_sampled=True, use_fp16=False,
57 | in_channels=input_channels if i==0 else dhidden,
58 | out_channels=dhidden if not is_last else output_channels,
59 | in_size=input_resolution, out_size=input_resolution,
60 | in_sampling_rate=input_resolution, out_sampling_rate=input_resolution,
61 | in_cutoff=last_cutoff, out_cutoff=last_cutoff,
62 | in_half_width=0, out_half_width=0
63 | ))
64 |
65 | self.layers = torch.nn.ModuleList(refinement)
66 |
67 | def forward(self, x, w, update_emas=False):
68 | for layer in self.layers:
69 | x = layer(x, w, update_emas=update_emas)
70 | return x
71 |
72 |
73 | @persistence.persistent_class
74 | class SynthesisNetwork(torch.nn.Module):
75 | def __init__(self,
76 | w_dim, # Intermediate latent (W) dimensionality.
77 | img_resolution, # Output image resolution.
78 | img_channels, # Number of color channels.
79 | use_bg=True, # Model background with 2D GAN
80 | sigma_mpl = 1, # Rescaling factor for density
81 | pose_kwargs={},
82 | render_kwargs={},
83 | fg_kwargs={},
84 | bg_kwargs={},
85 | refinement_kwargs={},
86 | ):
87 | super().__init__()
88 | if img_channels != 3:
89 | raise NotImplementedError
90 |
91 | self.img_resolution = img_resolution
92 | self.use_bg = use_bg
93 | self.use_refinement = refinement_kwargs.get('num_layers', 0) > 0
94 | self.sigma_mpl = sigma_mpl
95 |
96 | self.pose_sampler = PoseSampler(**pose_kwargs)
97 | self.renderer = Renderer(**render_kwargs)
98 |
99 | self.generator_fg = Synthesis3D(w_dim=w_dim, img_channels=img_channels*self.renderer.basis_dim+1, architecture='skip', renderer=self.renderer, **fg_kwargs)
100 | if self.use_bg:
101 | self.generator_bg = SG2_syn(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **bg_kwargs)
102 |
103 | if self.use_refinement:
104 | self.__setattr__(f'r{self.img_resolution}', RefinementNetwork(w_dim, self.img_resolution, input_channels=3, output_channels=img_channels, **refinement_kwargs))
105 |
106 | def __repr__(self):
107 | return (
108 | f"SynthesisNetwork(img_resolution={self.img_resolution}, use_bg={self.use_bg}, use_refinement={self.use_refinement})"
109 | )
110 |
111 | def _get_sparse_grad_indexer(self, device):
112 |         indexer = None  # self.sparse_grad_indexer  # TODO: can we reuse it?
113 | if indexer is None:
114 | indexer = torch.empty((0,), dtype=torch.bool, device=device)
115 | return indexer
116 |
117 | def _get_sparse_sh_grad_indexer(self, device):
118 |         indexer = None  # self.sparse_sh_grad_indexer  # TODO: can we reuse it?
119 | if indexer is None:
120 | indexer = torch.empty((0,), dtype=torch.bool, device=device)
121 | return indexer
122 |
123 | def forward(self, ws, ws_bg, pose=None, noise_mode='none', update_emas=None, return_3d=False, n_views=1, render_alpha=False, render_depth=False, render_vardepth=False, raw_noise_std_sigma=0):
124 | B = ws.shape[0]
125 |
126 | # Get camera poses
127 | if pose is None or torch.isnan(pose).all():
128 | pose = torch.stack([self.pose_sampler.sample_from_poses() for _ in range(B*n_views)])
129 |
130 | # Generate sparse voxel grid
131 | density_data_list, sh_data_list, coords = self.generator_fg(ws, pose=pose, noise_mode=noise_mode, update_emas=update_emas)
132 |
133 | if raw_noise_std_sigma > 0:
134 | for i in range(B):
135 | density_data_list[i] = density_data_list[i] + torch.randn(density_data_list[i].shape, device=ws.device) * raw_noise_std_sigma
136 | if self.sigma_mpl != 1:
137 | for i in range(B):
138 | density_data_list[i] = density_data_list[i] * self.sigma_mpl
139 |
140 | # Render foreground
141 | pred = self.renderer(pose, density_data_list, sh_data_list, coords, img_resolution=self.img_resolution, grid_resolution=self.generator_fg.grid_resolution, render_alpha=(render_alpha or self.use_bg), render_depth=render_depth, render_vardepth=render_vardepth)
142 |         pred = pred.view(B*n_views, self.img_resolution, self.img_resolution, 6).permute(0, 3, 1, 2)  # (B*H*W)xC -> BxCxHxW
143 | rgb, alpha, depth, vardepth = pred[:, :3], pred[:, 3:4], pred[:, 4:5], pred[:, 5:6]
144 |
145 | if self.use_bg:
146 | rgb_bg = self.generator_bg(ws_bg, update_emas=update_emas, noise_mode=noise_mode)
147 |
148 | # alpha compositing
149 | rgb_bg = rgb_bg / 2 + 0.5 # [-1, 1] -> [0, 1]
150 | assert alpha.min() >= 0 and alpha.max() <= 1+1e-3 # add some offset due to precision in alpha computation
151 | rgb = alpha * rgb + (1 - alpha) * rgb_bg
152 |
153 | # [0,1] -> [-1, 1]
154 | rgb = rgb * 2 - 1
155 |
156 | # Refine composited image
157 | if self.use_refinement:
158 | w_last = ws[:, -1] # simply reuse last w for refinement layers
159 | rgb = self.__getattr__(f'r{self.img_resolution}')(rgb, w_last, update_emas)
160 |
161 | out = {'rgb': rgb}
162 | if render_alpha:
163 | assert alpha.min() >= 0 and alpha.max() <= 1+1e-3 # add some offset due to precision in alpha computation
164 | out['alpha'] = alpha
165 | if render_depth:
166 | out['depth'] = depth
167 | if render_vardepth:
168 | out['vardepth'] = vardepth
169 | if return_3d:
170 | out['density'] = density_data_list
171 | out['sh'] = sh_data_list
172 | out['coords'] = coords
173 | return out
174 |
175 |
176 | @persistence.persistent_class
177 | class Generator(torch.nn.Module):
178 | def __init__(self,
179 | z_dim, # Input latent (Z) dimensionality.
180 | c_dim, # Conditioning label (C) dimensionality.
181 | w_dim, # Intermediate latent (W) dimensionality.
182 | img_resolution, # Output resolution.
183 | img_channels, # Number of output color channels.
184 | use_bg = True, # Model background with 2D GAN
185 | pose_conditioning = True, # Condition generator on pose
186 | mapping_kwargs = {}, # Arguments for MappingNetwork.
187 | **synthesis_kwargs, # Arguments for SynthesisNetwork.
188 | ):
189 | super().__init__()
190 | self.z_dim = z_dim
191 | self.c_dim = c_dim
192 | self.w_dim = w_dim
193 | self.img_resolution = img_resolution
194 | self.img_channels = img_channels
195 | self.pose_conditioning = pose_conditioning
196 | self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, use_bg=use_bg, **synthesis_kwargs)
197 | self.num_ws = self.synthesis.generator_fg.num_ws
198 | self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws_fg=self.synthesis.generator_fg.num_ws, num_ws_bg=0 if not use_bg else self.synthesis.generator_bg.num_ws, use_bg=use_bg, **mapping_kwargs)
199 |
200 | def forward(self, z, c, pose, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
201 | ws, ws_bg = self.mapping(z, z, c, torch.empty_like(c[:, :0]), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
202 | img = self.synthesis(ws, update_emas=update_emas, ws_bg=ws_bg, pose=pose, **synthesis_kwargs)
203 | return img
204 |
205 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2021, the Plenoxels authors
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 |
28 | This variant of plenoxels was modified by Katja Schwarz for VoxGRAF: Fast 3D-Aware Image Synthesis with Sparse Voxel Grids.
29 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/README.md:
--------------------------------------------------------------------------------
1 | We adapted the CUDA kernels from [Plenoxels: Radiance Fields without Neural Networks](https://arxiv.org/abs/2112.05131).
2 | Try installing the pre-compiled CUDA extension:
3 | ```commandline
4 | pip install ../dist/svox2-voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast-cp39-cp39-linux_x86_64.whl
5 | ```
6 | Or install from this directory by running:
7 | ```commandline
8 | pip install .
9 | ```
10 | For more details, please refer to the [Plenoxels GitHub repository](https://github.com/sxyu/svox2).
11 |
12 | The modifications include:
13 | * rendering alpha
14 | * rendering depth
15 | * rendering variance of depth
16 |
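17 | A minimal usage sketch of the added outputs (assuming a VoxGRAF generator `G` built from `training/networks_voxgraf.py`; `ws`, `ws_bg`, and `pose` are illustrative inputs from its mapping network and pose sampler):
18 | ```python
19 | # Request the extra render targets via the synthesis network's forward flags.
20 | out = G.synthesis(ws, ws_bg, pose=pose,
21 |                   render_alpha=True, render_depth=True, render_vardepth=True)
22 | rgb, alpha = out['rgb'], out['alpha']            # composited image, foreground alpha
23 | depth, vardepth = out['depth'], out['vardepth']  # per-pixel depth and its variance
24 | ```
25 |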
--------------------------------------------------------------------------------
/voxgraf-plenoxels/environment.yml:
--------------------------------------------------------------------------------
1 | # run: conda env create -f environment.yml
2 | name: plenoxel
3 | channels:
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - python=3.8.8
8 | - numpy>=1.16.4,<1.19.0
9 | - pip
10 | - pip:
11 | - imageio
12 | - imageio-ffmpeg
13 | - ipdb
14 | - lpips
15 | - opencv-python>=4.4.0
16 | - Pillow>=7.2.0
17 | - pyyaml>=5.3.1
18 | - tensorboard>=2.4.0
19 | - imageio
20 | - imageio-ffmpeg
21 | - pymcubes
22 | - moviepy
23 | - matplotlib
24 | - scipy>=1.6.0
25 | - pytorch
26 | - torchvision
27 | - cudatoolkit
28 | - tqdm
29 |
30 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/manual_install.sh:
--------------------------------------------------------------------------------
1 | cp svox2/svox2.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/svox2.py
2 | cp svox2/utils.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/utils.py
3 | cp svox2/version.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/version.py
4 | cp svox2/defs.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/defs.py
5 | cp svox2/__init__.py ~/miniconda3/envs/plenoctree/lib/python3.8/site-packages/svox2/__init__.py
6 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import os
3 | import os.path as osp
4 | import warnings
5 |
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | ROOT_DIR = osp.dirname(osp.abspath(__file__))
9 |
10 | __version__ = None
11 | exec(open('svox2/version.py', 'r').read())
12 |
13 | CUDA_FLAGS = []
14 | INSTALL_REQUIREMENTS = []
15 | include_dirs = [osp.join(ROOT_DIR, "svox2", "csrc", "include")]
16 |
17 | # From PyTorch3D
18 | cub_home = os.environ.get("CUB_HOME", None)
19 | if cub_home is None:
20 | prefix = os.environ.get("CONDA_PREFIX", None)
21 | if prefix is not None and os.path.isdir(prefix + "/include/cub"):
22 | cub_home = prefix + "/include"
23 |
24 | if cub_home is None:
25 | warnings.warn(
26 | "The environment variable `CUB_HOME` was not found."
27 | "Installation will fail if your system CUDA toolkit version is less than 11."
28 | "NVIDIA CUB can be downloaded "
29 | "from `https://github.com/NVIDIA/cub/releases`. You can unpack "
30 | "it to a location of your choice and set the environment variable "
31 | "`CUB_HOME` to the folder containing the `CMakeListst.txt` file."
32 | )
33 | else:
34 | include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " "))
35 |
36 | try:
37 | ext_modules = [
38 | CUDAExtension('svox2.csrc', [
39 | 'svox2/csrc/svox2.cpp',
40 | 'svox2/csrc/svox2_kernel.cu',
41 | 'svox2/csrc/render_lerp_kernel_cuvol.cu',
42 | 'svox2/csrc/render_lerp_kernel_nvol.cu',
43 | 'svox2/csrc/render_svox1_kernel.cu',
44 | 'svox2/csrc/misc_kernel.cu',
45 | 'svox2/csrc/loss_kernel.cu',
46 | 'svox2/csrc/optim_kernel.cu',
47 | ], include_dirs=include_dirs,
48 | optional=False),
49 | ]
50 | except Exception:
51 |     # `warnings` is already imported at the top of this file.
52 |     warnings.warn("Failed to build CUDA extension")
53 | ext_modules = []
54 |
55 | setup(
56 | name='svox2',
57 | version=__version__,
58 | author='Alex Yu / Modified by Katja Schwarz',
59 | author_email='alexyu99126@gmail.com / katja.schwarz@uni-tuebingen.de',
60 |     description='PyTorch sparse voxel volume extension with custom CUDA kernels; some operations adapted for VoxGRAF',
61 |     long_description='PyTorch sparse voxel volume extension with custom CUDA kernels; some operations adapted for VoxGRAF',
62 | ext_modules=ext_modules,
63 | setup_requires=['pybind11>=2.5.0'],
64 | packages=['svox2', 'svox2.csrc'],
65 | cmdclass={'build_ext': BuildExtension},
66 | zip_safe=False,
67 | )
68 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/__init__.py:
--------------------------------------------------------------------------------
1 | from .defs import *
2 | from .svox2 import SparseGrid, Camera, Rays, RenderOptions
3 | from .version import __version__
4 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/.ccls:
--------------------------------------------------------------------------------
1 | %compile_commands.json
2 | %cu -x cuda
3 | %cu --cuda-gpu-arch=sm_61
4 | %cu --cuda-path=/usr/local/cuda-11.2
5 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright 2021 PlenOctree Authors.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions are met:
5 | #
6 | # 1. Redistributions of source code must retain the above copyright notice,
7 | # this list of conditions and the following disclaimer.
8 | #
9 | # 2. Redistributions in binary form must reproduce the above copyright notice,
10 | # this list of conditions and the following disclaimer in the documentation
11 | # and/or other materials provided with the distribution.
12 | #
13 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
23 | # POSSIBILITY OF SUCH DAMAGE.
24 |
25 | # NOTE: This CMakeLists is for development purposes only
26 | # (To check CUDA compile errors)
27 | # It is NOT necessary to use this for installation. Just use pip install .
28 | cmake_minimum_required( VERSION 3.3 )
29 |
30 | if(NOT CMAKE_BUILD_TYPE)
31 | set(CMAKE_BUILD_TYPE Release)
32 | endif()
33 | if (POLICY CMP0048)
34 | cmake_policy(SET CMP0048 NEW)
35 | endif (POLICY CMP0048)
36 | if (POLICY CMP0069)
37 | cmake_policy(SET CMP0069 NEW)
38 | endif (POLICY CMP0069)
39 | if (POLICY CMP0072)
40 | cmake_policy(SET CMP0072 NEW)
41 | endif (POLICY CMP0072)
42 |
43 | project( svox2 )
44 |
45 | set(CMAKE_CXX_STANDARD 14)
46 | enable_language(CUDA)
47 | message(STATUS "CUDA enabled")
48 | set( CMAKE_CUDA_STANDARD 14 )
49 | set( CMAKE_CUDA_STANDARD_REQUIRED ON)
50 | set( CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g -Xcudafe \"--display_error_number --diag_suppress=3057 --diag_suppress=3058 --diag_suppress=3059 --diag_suppress=3060\" -lineinfo -arch=sm_75 ")
51 | # -Xptxas=\"-v\"
52 |
53 | set( INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include" )
54 |
55 | if( MSVC )
56 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
57 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT /GLT /Ox")
58 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler=\"/MT\"" )
59 | endif()
60 |
61 | file(GLOB SOURCES
62 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
63 | ${CMAKE_CURRENT_SOURCE_DIR}/*.cu)
64 |
65 | find_package(pybind11 REQUIRED)
66 | find_package(Torch REQUIRED)
67 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
68 |
69 | include_directories (${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
70 |
71 | pybind11_add_module(svox2-test SHARED ${SOURCES})
72 | target_link_libraries(svox2-test PRIVATE "${TORCH_LIBRARIES}")
73 | target_include_directories(svox2-test PRIVATE "${INCLUDE_DIR}")
74 |
75 | if (MSVC)
76 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
77 | add_custom_command(TARGET svox2-test
78 | POST_BUILD
79 | COMMAND ${CMAKE_COMMAND} -E copy_if_different
80 | ${TORCH_DLLS}
81 |             $<TARGET_FILE_DIR:svox2-test>)
82 | endif (MSVC)
83 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/include/cubemap_util.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "cuda_util.cuh"
3 | #include <cmath>
4 | #include <cstdint>
5 |
6 | #define _AXIS(x) (x>>1)
7 | #define _ORI(x) (x&1)
8 | #define _FACE(axis, ori) uint8_t((axis << 1) | ori)
9 |
10 | namespace {
11 | namespace device {
12 |
13 | struct CubemapCoord {
14 | uint8_t face;
15 | float uv[2];
16 | };
17 |
18 | struct CubemapLocation {
19 | uint8_t face;
20 | int16_t uv[2];
21 | };
22 |
23 | struct CubemapBilerpQuery {
24 | CubemapLocation ptr[2][2];
25 | float duv[2];
26 | };
27 |
28 | __device__ __inline__ void
29 | invert_cubemap(int u, int v, float r,
30 | int reso,
31 | float* __restrict__ out) {
32 | const float u_norm = (u + 0.5f) / reso * 2 - 1;
33 | const float v_norm = (v + 0.5f) / reso * 2 - 1;
34 | // EAC
35 | const float tx = tanf((M_PI / 4) * u_norm);
36 | const float ty = tanf((M_PI / 4) * v_norm);
37 | const float common = r * rnorm3df(1.f, tx, ty);
38 | out[0] = tx * common;
39 | out[1] = ty * common;
40 | out[2] = common;
41 | }
42 |
43 | __device__ __inline__ void
44 | invert_cubemap_traditional(int u, int v, float r,
45 | int reso,
46 | float* __restrict__ out) {
47 | const float u_norm = (u + 0.5f) / reso * 2 - 1;
48 | const float v_norm = (v + 0.5f) / reso * 2 - 1;
49 | const float common = r * rnorm3df(1.f, u_norm, v_norm);
50 | out[0] = u_norm * common;
51 | out[1] = v_norm * common;
52 | out[2] = common;
53 | }
54 |
55 | __device__ __host__ __inline__ CubemapCoord
56 | dir_to_cubemap_coord(const float* __restrict__ xyz_o,
57 | int face_reso,
58 | bool eac = true) {
59 | float maxv;
60 | int ax;
61 | float xyz[3] = {xyz_o[0], xyz_o[1], xyz_o[2]};
62 | if (fabsf(xyz[0]) >= fabsf(xyz[1]) && fabsf(xyz[0]) >= fabsf(xyz[2])) {
63 | ax = 0; maxv = xyz[0];
64 | } else if (fabsf(xyz[1]) >= fabsf(xyz[2])) {
65 | ax = 1; maxv = xyz[1];
66 | } else {
67 | ax = 2; maxv = xyz[2];
68 | }
69 | const float recip = 1.f / fabsf(maxv);
70 | xyz[0] *= recip;
71 | xyz[1] *= recip;
72 | xyz[2] *= recip;
73 |
74 | if (eac) {
75 | #pragma unroll 3
76 | for (int i = 0; i < 3; ++i) {
77 | xyz[i] = atanf(xyz[i]) * (4 * M_1_PI);
78 | }
79 | }
80 |
81 | CubemapCoord idx;
82 | idx.uv[0] = ((xyz[(ax ^ 1) & 1] + 1) * face_reso - 1) * 0.5;
83 | idx.uv[1] = ((xyz[(ax ^ 2) & 2] + 1) * face_reso - 1) * 0.5;
84 | const int ori = xyz[ax] >= 0;
85 | idx.face = _FACE(ax, ori);
86 |
87 | return idx;
88 | }
89 |
90 | __device__ __host__ __inline__ CubemapBilerpQuery
91 | cubemap_build_query(
92 | const CubemapCoord& idx,
93 | int face_reso) {
94 | const int uv_idx[2] ={ (int)floorf(idx.uv[0]), (int)floorf(idx.uv[1]) };
95 |
96 | bool m[2][2];
97 | m[0][0] = uv_idx[0] < 0;
98 | m[0][1] = uv_idx[0] > face_reso - 2;
99 | m[1][0] = uv_idx[1] < 0;
100 | m[1][1] = uv_idx[1] > face_reso - 2;
101 |
102 | const int face = idx.face;
103 | const int ax = _AXIS(face);
104 | const int ori = _ORI(face);
105 | // if ax is one of {0, 1, 2}, this trick gets the 2
106 | // of {0, 1, 2} other than ax
107 | const int uvd[2] = {((ax ^ 1) & 1), ((ax ^ 2) & 2)};
108 | int uv_ori[2];
109 |
110 | CubemapBilerpQuery result;
111 | result.duv[0] = idx.uv[0] - uv_idx[0];
112 | result.duv[1] = idx.uv[1] - uv_idx[1];
113 |
114 | #pragma unroll 2
115 | for (uv_ori[0] = 0; uv_ori[0] < 2; ++uv_ori[0]) {
116 | #pragma unroll 2
117 | for (uv_ori[1] = 0; uv_ori[1] < 2; ++uv_ori[1]) {
118 | CubemapLocation& nidx = result.ptr[uv_ori[0]][uv_ori[1]];
119 | nidx.face = face;
120 | nidx.uv[0] = uv_idx[0] + uv_ori[0];
121 | nidx.uv[1] = uv_idx[1] + uv_ori[1];
122 |
123 | const bool mu = m[0][uv_ori[0]];
124 | const bool mv = m[1][uv_ori[1]];
125 |
126 | int edge_idx = -1;
127 | if (mu) {
128 | // Crosses edge in u-axis
129 | if (mv) {
130 | // FIXME: deal with corners properly, right now
131 | // just clamps, resulting in a little artifact
132 | // at each cube corner
133 | nidx.uv[0] = min(max(nidx.uv[0], 0), face_reso - 1);
134 | nidx.uv[1] = min(max(nidx.uv[1], 0), face_reso - 1);
135 | } else {
136 | edge_idx = 0;
137 | }
138 | } else if (mv) {
139 | // Crosses edge in v-axis
140 | edge_idx = 1;
141 | }
142 | if (~edge_idx) {
143 | const int nax = uvd[edge_idx];
144 | const int16_t other_coord = nidx.uv[1 - edge_idx];
145 |
146 | // Determine directions in the new face
147 | const int nud = (nax ^ 1) & 1;
148 | // const int nvd = (nax ^ 2) & 2;
149 |
150 | if (nud == ax) {
151 | nidx.uv[0] = ori ? (face_reso - 1) : 0;
152 | nidx.uv[1] = other_coord;
153 | } else {
154 | nidx.uv[0] = other_coord;
155 | nidx.uv[1] = ori ? (face_reso - 1) : 0;
156 | }
157 |
158 | nidx.face = _FACE(nax, uv_ori[edge_idx]);
159 | }
160 | // Interior point: nothing needs to be done
161 |
162 | }
163 | }
164 |
165 | return result;
166 | }
167 |
168 | __device__ __host__ __inline__ float
169 | cubemap_sample(
170 | const float* __restrict__ cubemap, // (6, face_reso, face_reso, n_channels)
171 | const CubemapBilerpQuery& query,
172 | int face_reso,
173 | int n_channels,
174 | int chnl_id) {
175 |
176 | // NOTE: assuming address will fit in int32
177 | const int stride1 = face_reso * n_channels;
178 | const int stride0 = face_reso * stride1;
179 | const CubemapLocation& p00 = query.ptr[0][0];
180 | const float v00 = cubemap[p00.face * stride0 + p00.uv[0] * stride1 + p00.uv[1] * n_channels + chnl_id];
181 | const CubemapLocation& p01 = query.ptr[0][1];
182 | const float v01 = cubemap[p01.face * stride0 + p01.uv[0] * stride1 + p01.uv[1] * n_channels + chnl_id];
183 | const CubemapLocation& p10 = query.ptr[1][0];
184 | const float v10 = cubemap[p10.face * stride0 + p10.uv[0] * stride1 + p10.uv[1] * n_channels + chnl_id];
185 | const CubemapLocation& p11 = query.ptr[1][1];
186 | const float v11 = cubemap[p11.face * stride0 + p11.uv[0] * stride1 + p11.uv[1] * n_channels + chnl_id];
187 |
188 | const float val0 = lerp(v00, v01, query.duv[1]);
189 | const float val1 = lerp(v10, v11, query.duv[1]);
190 |
191 | return lerp(val0, val1, query.duv[0]);
192 | }
193 |
194 | __device__ __inline__ void
195 | cubemap_sample_backward(
196 | float* __restrict__ cubemap_grad, // (6, face_reso, face_reso, n_channels)
197 | const CubemapBilerpQuery& query,
198 | int face_reso,
199 | int n_channels,
200 | float grad_out,
201 | int chnl_id,
202 | bool* __restrict__ mask_out = nullptr) {
203 |
204 | // NOTE: assuming address will fit in int32
205 | const float bu = query.duv[0], bv = query.duv[1];
206 | const float au = 1.f - bu, av = 1.f - bv;
207 |
208 | #define _ADD_CUBEVERT(i, j, val) { \
209 | const CubemapLocation& p00 = query.ptr[i][j]; \
210 | const int idx = (p00.face * face_reso + p00.uv[0]) * face_reso + p00.uv[1]; \
211 | float* __restrict__ v00 = &cubemap_grad[idx * n_channels + chnl_id]; \
212 | atomicAdd(v00, val); \
213 | if (mask_out != nullptr) { \
214 | mask_out[idx] = true; \
215 | } \
216 | }
217 |
218 | _ADD_CUBEVERT(0, 0, au * av * grad_out);
219 | _ADD_CUBEVERT(0, 1, au * bv * grad_out);
220 | _ADD_CUBEVERT(1, 0, bu * av * grad_out);
221 | _ADD_CUBEVERT(1, 1, bu * bv * grad_out);
222 | #undef _ADD_CUBEVERT
223 |
224 | }
225 |
226 | __device__ __host__ __inline__ float
227 | multi_cubemap_sample(
228 | const float* __restrict__ cubemap1, // (6, face_reso, face_reso, n_channels)
229 | const float* __restrict__ cubemap2, // (6, face_reso, face_reso, n_channels)
230 | const CubemapBilerpQuery& query,
231 | float interp_wt,
232 | int face_reso,
233 | int n_channels,
234 | int chnl_id) {
235 | const float val1 = cubemap_sample(cubemap1,
236 | query,
237 | face_reso,
238 | n_channels,
239 | chnl_id);
240 | const float val2 = cubemap_sample(cubemap2,
241 | query,
242 | face_reso,
243 | n_channels,
244 | chnl_id);
245 | return lerp(val1, val2, interp_wt);
246 | }
247 |
248 | __device__ __inline__ void
249 | multi_cubemap_sample_backward(
250 | float* __restrict__ cubemap_grad1, // (6, face_reso, face_reso, n_channels)
251 | float* __restrict__ cubemap_grad2, // (6, face_reso, face_reso, n_channels)
252 | const CubemapBilerpQuery& query,
253 | float interp_wt,
254 | int face_reso,
255 | int n_channels,
256 | float grad_out,
257 | int chnl_id,
258 | bool* __restrict__ mask_out1 = nullptr,
259 | bool* __restrict__ mask_out2 = nullptr) {
260 | if (cubemap_grad1 == nullptr) return;
261 | cubemap_sample_backward(cubemap_grad1,
262 | query,
263 | face_reso,
264 | n_channels,
265 | grad_out * (1.f - interp_wt),
266 | chnl_id,
267 | mask_out1);
268 | cubemap_sample_backward(cubemap_grad2,
269 | query,
270 | face_reso,
271 | n_channels,
272 | grad_out * interp_wt,
273 | chnl_id,
274 | mask_out1 == nullptr ? nullptr : mask_out2);
275 | }
276 |
277 |
278 | } // namespace device
279 | } // namespace
280 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/include/cuda_util.cuh:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | #pragma once
3 | #include <cuda_runtime.h>
4 | #include <cstdint>
5 | #include <c10/cuda/CUDAGuard.h>
6 | #include <ATen/cuda/CUDAContext.h>
7 | #include "util.hpp"
8 |
9 |
10 | #define DEVICE_GUARD(_ten) \
11 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));
12 |
13 | #define CUDA_GET_THREAD_ID(tid, Q) const int tid = blockIdx.x * blockDim.x + threadIdx.x; \
14 | if (tid >= Q) return
15 | #define CUDA_GET_THREAD_ID_U64(tid, Q) const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; \
16 | if (tid >= Q) return
17 | #define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1)
18 | #define CUDA_CHECK_ERRORS \
19 | cudaError_t err = cudaGetLastError(); \
20 | if (err != cudaSuccess) \
21 | printf("Error in svox2.%s : %s\n", __FUNCTION__, cudaGetErrorString(err))
22 |
23 | #define CUDA_MAX_THREADS at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock
24 |
25 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
26 | #else
27 | __device__ inline double atomicAdd(double* address, double val){
28 | unsigned long long int* address_as_ull = (unsigned long long int*)address;
29 | unsigned long long int old = *address_as_ull, assumed;
30 | do {
31 | assumed = old;
32 | old = atomicCAS(address_as_ull, assumed,
33 | __double_as_longlong(val + __longlong_as_double(assumed)));
34 | } while (assumed != old);
35 | return __longlong_as_double(old);
36 | }
37 | #endif
38 |
39 | __device__ inline void atomicMax(float* result, float value){
40 | unsigned* result_as_u = (unsigned*)result;
41 | unsigned old = *result_as_u, assumed;
42 | do {
43 | assumed = old;
44 | old = atomicCAS(result_as_u, assumed,
45 | __float_as_int(fmaxf(value, __int_as_float(assumed))));
46 | } while (old != assumed);
47 | return;
48 | }
49 |
50 | __device__ inline void atomicMax(double* result, double value){
51 | unsigned long long int* result_as_ull = (unsigned long long int*)result;
52 | unsigned long long int old = *result_as_ull, assumed;
53 | do {
54 | assumed = old;
55 | old = atomicCAS(result_as_ull, assumed,
56 | __double_as_longlong(fmaxf(value, __longlong_as_double(assumed))));
57 | } while (old != assumed);
58 | return;
59 | }
60 |
61 | __device__ __inline__ void transform_coord(float* __restrict__ point,
62 | const float* __restrict__ scaling,
63 | const float* __restrict__ offset) {
64 | point[0] = fmaf(point[0], scaling[0], offset[0]); // a*b + c
65 | point[1] = fmaf(point[1], scaling[1], offset[1]); // a*b + c
66 | point[2] = fmaf(point[2], scaling[2], offset[2]); // a*b + c
67 | }
68 |
69 | // Linear interp
70 | // Subtract and fused multiply-add
71 | // (1-w) a + w b
72 | template <class T>
73 | __host__ __device__ __inline__ T lerp(T a, T b, T w) {
74 | return fmaf(w, b - a, a);
75 | }
76 |
77 | __device__ __inline__ static float _norm(
78 | const float* __restrict__ dir) {
79 | // return sqrtf(dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2]);
80 | return norm3df(dir[0], dir[1], dir[2]);
81 | }
82 |
83 | __device__ __inline__ static float _rnorm(
84 | const float* __restrict__ dir) {
85 | // return 1.f / _norm(dir);
86 | return rnorm3df(dir[0], dir[1], dir[2]);
87 | }
88 |
89 | __host__ __device__ __inline__ static void xsuby3d(
90 | float* __restrict__ x,
91 | const float* __restrict__ y) {
92 | x[0] -= y[0];
93 | x[1] -= y[1];
94 | x[2] -= y[2];
95 | }
96 |
97 | __host__ __device__ __inline__ static float _dot(
98 | const float* __restrict__ x,
99 | const float* __restrict__ y) {
100 | return x[0] * y[0] + x[1] * y[1] + x[2] * y[2];
101 | }
102 |
103 | __host__ __device__ __inline__ static void _cross(
104 | const float* __restrict__ a,
105 | const float* __restrict__ b,
106 | float* __restrict__ out) {
107 | out[0] = a[1] * b[2] - a[2] * b[1];
108 | out[1] = a[2] * b[0] - a[0] * b[2];
109 | out[2] = a[0] * b[1] - a[1] * b[0];
110 | }
111 |
112 | __device__ __inline__ static float _dist_ray_to_origin(
113 | const float* __restrict__ origin,
114 | const float* __restrict__ dir) {
115 | // dir must be unit vector
116 | float tmp[3];
117 | _cross(origin, dir, tmp);
118 | return _norm(tmp);
119 | }
120 |
121 | #define int_div2_ceil(x) ((((x) - 1) >> 1) + 1)
122 |
123 | __host__ __inline__ cudaError_t cuda_assert(
124 | const cudaError_t code, const char* const file,
125 | const int line, const bool abort) {
126 | if (code != cudaSuccess) {
127 | fprintf(stderr, "cuda_assert: %s %s %s %d\n", cudaGetErrorName(code) ,cudaGetErrorString(code),
128 | file, line);
129 |
130 | if (abort) {
131 | cudaDeviceReset();
132 | exit(code);
133 | }
134 | }
135 |
136 | return code;
137 | }
138 |
139 | #define cuda(...) cuda_assert((cuda##__VA_ARGS__), __FILE__, __LINE__, true);
140 |
141 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/include/data_spec.hpp:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | #pragma once
3 | #include "util.hpp"
4 | #include <torch/extension.h>
5 |
6 | using torch::Tensor;
7 |
8 | enum BasisType {
9 | // For svox 1 compatibility
10 | // BASIS_TYPE_RGBA = 0
11 | BASIS_TYPE_SH = 1,
12 | // BASIS_TYPE_SG = 2
13 | // BASIS_TYPE_ASG = 3
14 | BASIS_TYPE_3D_TEXTURE = 4,
15 | BASIS_TYPE_MLP = 255,
16 | };
17 |
18 | struct SparseGridSpec {
19 | Tensor density_data;
20 | Tensor sh_data;
21 | Tensor links;
22 | Tensor _offset;
23 | Tensor _scaling;
24 |
25 | Tensor background_links;
26 | Tensor background_data;
27 |
28 | int basis_dim;
29 | uint8_t basis_type;
30 | Tensor basis_data;
31 |
32 | inline void check() {
33 | CHECK_INPUT(density_data);
34 | CHECK_INPUT(sh_data);
35 | CHECK_INPUT(links);
36 | if (background_links.defined()) {
37 | CHECK_INPUT(background_links);
38 | CHECK_INPUT(background_data);
39 | TORCH_CHECK(background_links.ndimension() ==
40 | 2); // (H, W) -> [N] \cup {-1}
41 | TORCH_CHECK(background_data.ndimension() == 3); // (N, D, C) -> R
42 | }
43 | if (basis_data.defined()) {
44 | CHECK_INPUT(basis_data);
45 | }
46 | CHECK_CPU_INPUT(_offset);
47 | CHECK_CPU_INPUT(_scaling);
48 | TORCH_CHECK(density_data.ndimension() == 2);
49 | TORCH_CHECK(sh_data.ndimension() == 2);
50 | TORCH_CHECK(links.ndimension() == 3);
51 | }
52 | };
53 |
54 | struct GridOutputGrads {
55 | torch::Tensor grad_density_out;
56 | torch::Tensor grad_sh_out;
57 | torch::Tensor grad_basis_out;
58 | torch::Tensor grad_background_out;
59 |
60 | torch::Tensor mask_out;
61 | torch::Tensor mask_background_out;
62 | inline void check() {
63 | if (grad_density_out.defined()) {
64 | CHECK_INPUT(grad_density_out);
65 | }
66 | if (grad_sh_out.defined()) {
67 | CHECK_INPUT(grad_sh_out);
68 | }
69 | if (grad_basis_out.defined()) {
70 | CHECK_INPUT(grad_basis_out);
71 | }
72 | if (grad_background_out.defined()) {
73 | CHECK_INPUT(grad_background_out);
74 | }
75 | if (mask_out.defined() && mask_out.size(0) > 0) {
76 | CHECK_INPUT(mask_out);
77 | }
78 | if (mask_background_out.defined() && mask_background_out.size(0) > 0) {
79 | CHECK_INPUT(mask_background_out);
80 | }
81 | }
82 | };
83 |
84 | struct CameraSpec {
85 | torch::Tensor c2w;
86 | float fx;
87 | float fy;
88 | float cx;
89 | float cy;
90 | int width;
91 | int height;
92 |
93 | float ndc_coeffx;
94 | float ndc_coeffy;
95 |
96 | inline void check() {
97 | CHECK_INPUT(c2w);
98 | TORCH_CHECK(c2w.is_floating_point());
99 | TORCH_CHECK(c2w.ndimension() == 2);
100 | TORCH_CHECK(c2w.size(1) == 4);
101 | }
102 | };
103 |
104 | struct RaysSpec {
105 | Tensor origins;
106 | Tensor dirs;
107 | inline void check() {
108 | CHECK_INPUT(origins);
109 | CHECK_INPUT(dirs);
110 | TORCH_CHECK(origins.is_floating_point());
111 | TORCH_CHECK(dirs.is_floating_point());
112 | }
113 | };
114 |
115 | struct RenderOptions {
116 | float background_brightness;
117 | // float step_epsilon;
118 | float step_size;
119 | float sigma_thresh;
120 | float stop_thresh;
121 |
122 | float near_clip;
123 | bool use_spheric_clip;
124 |
125 | bool last_sample_opaque;
126 |
127 | // bool randomize;
128 | // float random_sigma_std;
129 | // float random_sigma_std_background;
130 | // 32-bit RNG state masks
131 | // uint32_t _m1, _m2, _m3;
132 |
133 | // int msi_start_layer = 0;
134 | // int msi_end_layer = 66;
135 | };
136 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/include/data_spec_packed.cuh:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | #pragma once
3 | #include <torch/extension.h>
4 | #include "data_spec.hpp"
5 | #include "cuda_util.cuh"
6 | #include "random_util.cuh"
7 |
8 | namespace {
9 | namespace device {
10 |
11 | struct PackedSparseGridSpec {
12 | PackedSparseGridSpec(SparseGridSpec& spec)
13 | :
14 |           density_data(spec.density_data.data_ptr<float>()),
15 |           sh_data(spec.sh_data.data_ptr<float>()),
16 |           links(spec.links.data_ptr<int32_t>()),
17 |           basis_type(spec.basis_type),
18 |           basis_data(spec.basis_data.defined() ? spec.basis_data.data_ptr<float>() : nullptr),
19 |           background_links(spec.background_links.defined() ?
20 |                            spec.background_links.data_ptr<int32_t>() :
21 |                            nullptr),
22 |           background_data(spec.background_data.defined() ?
23 |                           spec.background_data.data_ptr<float>() :
24 |                           nullptr),
25 |           size{(int)spec.links.size(0),
26 |                (int)spec.links.size(1),
27 |                (int)spec.links.size(2)},
28 |           stride_x{(int)spec.links.stride(0)},
29 |           background_reso{
30 |               spec.background_links.defined() ? (int)spec.background_links.size(1) : 0,
31 |           },
32 |           background_nlayers{
33 |               spec.background_data.defined() ? (int)spec.background_data.size(1) : 0
34 |           },
35 |           basis_dim(spec.basis_dim),
36 |           sh_data_dim((int)spec.sh_data.size(1)),
37 |           basis_reso(spec.basis_data.defined() ? spec.basis_data.size(0) : 0),
38 |           _offset{spec._offset.data_ptr<float>()[0],
39 |                   spec._offset.data_ptr<float>()[1],
40 |                   spec._offset.data_ptr<float>()[2]},
41 |           _scaling{spec._scaling.data_ptr<float>()[0],
42 |                    spec._scaling.data_ptr<float>()[1],
43 |                    spec._scaling.data_ptr<float>()[2]} {
44 | }
45 |
46 | float* __restrict__ density_data;
47 | float* __restrict__ sh_data;
48 | const int32_t* __restrict__ links;
49 |
50 | const uint8_t basis_type;
51 | float* __restrict__ basis_data;
52 |
53 | const int32_t* __restrict__ background_links;
54 | float* __restrict__ background_data;
55 |
56 | const int size[3], stride_x;
57 | const int background_reso, background_nlayers;
58 |
59 | const int basis_dim, sh_data_dim, basis_reso;
60 | const float _offset[3];
61 | const float _scaling[3];
62 | };
63 |
64 | struct PackedGridOutputGrads {
65 | PackedGridOutputGrads(GridOutputGrads& grads) :
66 |         grad_density_out(grads.grad_density_out.defined() ? grads.grad_density_out.data_ptr<float>() : nullptr),
67 |         grad_sh_out(grads.grad_sh_out.defined() ? grads.grad_sh_out.data_ptr<float>() : nullptr),
68 |         grad_basis_out(grads.grad_basis_out.defined() ? grads.grad_basis_out.data_ptr<float>() : nullptr),
69 |         grad_background_out(grads.grad_background_out.defined() ? grads.grad_background_out.data_ptr<float>() : nullptr),
70 |         mask_out((grads.mask_out.defined() && grads.mask_out.size(0) > 0) ? grads.mask_out.data_ptr<bool>() : nullptr),
71 |         mask_background_out((grads.mask_background_out.defined() && grads.mask_background_out.size(0) > 0) ? grads.mask_background_out.data_ptr<bool>() : nullptr)
72 | {}
73 | float* __restrict__ grad_density_out;
74 | float* __restrict__ grad_sh_out;
75 | float* __restrict__ grad_basis_out;
76 | float* __restrict__ grad_background_out;
77 |
78 | bool* __restrict__ mask_out;
79 | bool* __restrict__ mask_background_out;
80 | };
81 |
82 | struct PackedCameraSpec {
83 | PackedCameraSpec(CameraSpec& cam) :
84 |         c2w(cam.c2w.packed_accessor32<float, 2, torch::RestrictPtrTraits>()),
85 | fx(cam.fx), fy(cam.fy),
86 | cx(cam.cx), cy(cam.cy),
87 | width(cam.width), height(cam.height),
88 | ndc_coeffx(cam.ndc_coeffx), ndc_coeffy(cam.ndc_coeffy) {}
89 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits>
90 |         c2w;
91 | float fx;
92 | float fy;
93 | float cx;
94 | float cy;
95 | int width;
96 | int height;
97 |
98 | float ndc_coeffx;
99 | float ndc_coeffy;
100 | };
101 |
102 | struct PackedRaysSpec {
103 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> origins;
104 |     const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> dirs;
105 |     PackedRaysSpec(RaysSpec& spec) :
106 |         origins(spec.origins.packed_accessor32<float, 2, torch::RestrictPtrTraits>()),
107 |         dirs(spec.dirs.packed_accessor32<float, 2, torch::RestrictPtrTraits>())
108 | { }
109 | };
110 |
111 | struct SingleRaySpec {
112 | SingleRaySpec() = default;
113 | __device__ SingleRaySpec(const float* __restrict__ origin, const float* __restrict__ dir)
114 | : origin{origin[0], origin[1], origin[2]},
115 | dir{dir[0], dir[1], dir[2]} {}
116 | __device__ void set(const float* __restrict__ origin, const float* __restrict__ dir) {
117 | #pragma unroll 3
118 | for (int i = 0; i < 3; ++i) {
119 | this->origin[i] = origin[i];
120 | this->dir[i] = dir[i];
121 | }
122 | }
123 |
124 | float origin[3];
125 | float dir[3];
126 | float tmin, tmax, world_step;
127 |
128 | float pos[3];
129 | int32_t l[3];
130 | RandomEngine32 rng;
131 | };
132 |
133 | } // namespace device
134 | } // namespace
135 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/include/random_util.cuh:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | #pragma once
3 | #include <cmath>
4 | #include <cstdint>
5 |
6 | // A custom xorshift random generator
7 | // Maybe replace with some CUDA internal stuff?
8 | struct RandomEngine32 {
9 | uint32_t x, y, z;
10 |
11 | // Inclusive on both ends
12 | __host__ __device__
13 | uint32_t randint(uint32_t lo, uint32_t hi) {
14 | if (hi <= lo) return lo;
15 | uint32_t z = (*this)();
16 | return z % (hi - lo + 1) + lo;
17 | }
18 |
19 | __host__ __device__
20 | void rand2(float* out1, float* out2) {
21 | const uint32_t z = (*this)();
22 | const uint32_t fmax = (1 << 16);
23 | const uint32_t z1 = z >> 16;
24 | const uint32_t z2 = z & (fmax - 1);
25 | const float ifmax = 1.f / fmax;
26 |
27 | *out1 = z1 * ifmax;
28 | *out2 = z2 * ifmax;
29 | }
30 |
31 | __host__ __device__
32 | float rand() {
33 | uint32_t z = (*this)();
34 | return float(z) / (1LL << 32);
35 | }
36 |
37 |
38 | __host__ __device__
39 | void randn2(float* out1, float* out2) {
40 | rand2(out1, out2);
41 | // Box-Muller transform
42 | const float srlog = sqrtf(-2 * logf(*out1 + 1e-32f));
43 | *out2 *= 2 * M_PI;
44 | *out1 = srlog * cosf(*out2);
45 | *out2 = srlog * sinf(*out2);
46 | }
47 |
48 | __host__ __device__
49 | float randn() {
50 | float x, y;
51 | rand2(&x, &y);
52 | // Box-Muller transform
53 | return sqrtf(-2 * logf(x + 1e-32f))* cosf(2 * M_PI * y);
54 | }
55 |
56 | __host__ __device__
57 | uint32_t operator()() {
58 | uint32_t t;
59 | x ^= x << 16;
60 | x ^= x >> 5;
61 | x ^= x << 1;
62 | t = x;
63 | x = y;
64 | y = z;
65 | z = t ^ x ^ y;
66 | return z;
67 | }
68 | };
69 |
--------------------------------------------------------------------------------
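Not part of the repository, but for orientation: a minimal host-side sketch of how the `RandomEngine32` xorshift generator above can be exercised, assuming `random_util.cuh` is on the include path and the file is compiled with `nvcc` so the `__host__ __device__` qualifiers are accepted. The file name and seed values are hypothetical.

```cpp
// sketch_random_util.cu (hypothetical driver) -- build with: nvcc sketch_random_util.cu
#include <cstdio>
#include "random_util.cuh"

int main() {
    // Arbitrary non-zero seed triple for the xorshift state (x, y, z).
    RandomEngine32 rng{0x9E3779B9u, 0x85EBCA6Bu, 0xC2B2AE35u};

    float u1, u2, n1, n2;
    rng.rand2(&u1, &u2);   // two uniforms in [0, 1) from a single 32-bit draw
    rng.randn2(&n1, &n2);  // two standard normals via the Box-Muller transform

    std::printf("uniform: %f %f  normal: %f %f  randint[0,9]: %u\n",
                u1, u2, n1, n2, rng.randint(0u, 9u));
    return 0;
}
```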
/voxgraf-plenoxels/svox2/csrc/include/util.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | // Changed from x.type().is_cuda() due to deprecation
3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
4 | #define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
5 | #define CHECK_CONTIGUOUS(x) \
6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
7 | #define CHECK_INPUT(x) \
8 | CHECK_CUDA(x); \
9 | CHECK_CONTIGUOUS(x)
10 | #define CHECK_CPU_INPUT(x) \
11 | CHECK_CPU(x); \
12 | CHECK_CONTIGUOUS(x)
13 |
14 | #if defined(__CUDACC__)
15 | // #define _EXP(x) expf(x) // SLOW EXP
16 | #define _EXP(x) __expf(x) // FAST EXP
17 | #define _SIGMOID(x) (1 / (1 + _EXP(-(x))))
18 |
19 | #else
20 |
21 | #define _EXP(x) expf(x)
22 | #define _SIGMOID(x) (1 / (1 + expf(-(x))))
23 | #endif
24 | #define _SQR(x) ((x) * (x))
25 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/csrc/optim_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 | // Optimizer-related kernels
3 |
4 | #include <torch/extension.h>
5 | #include "cuda_util.cuh"
6 |
7 | namespace {
8 |
9 | const int RMSPROP_STEP_CUDA_THREADS = 256;
10 | const int MIN_BLOCKS_PER_SM = 4;
11 |
12 | namespace device {
13 |
14 | // RMSPROP
15 | __inline__ __device__ void rmsprop_once(
16 | float* __restrict__ ptr_data,
17 | float* __restrict__ ptr_rms,
18 | float* __restrict__ ptr_grad,
19 | const float beta, const float lr, const float epsilon, float minval) {
20 | float rms = *ptr_rms;
21 | rms = rms == 0.f ? _SQR(*ptr_grad) : lerp(_SQR(*ptr_grad), rms, beta);
22 | *ptr_rms = rms;
23 | *ptr_data = fmaxf(*ptr_data - lr * (*ptr_grad) / (sqrtf(rms) + epsilon), minval);
24 | *ptr_grad = 0.f;
25 | }
26 |
27 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
28 | __global__ void rmsprop_step_kernel(
29 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
30 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
31 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
32 | float beta,
33 | float lr,
34 | float epsilon,
35 | float minval,
36 | float lr_last) {
37 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
38 | int32_t chnl = tid % all_data.size(1);
39 | rmsprop_once(all_data.data() + tid,
40 | all_rms.data() + tid,
41 | all_grad.data() + tid,
42 | beta,
43 | (chnl == all_data.size(1) - 1) ? lr_last : lr,
44 | epsilon,
45 | minval);
46 | }
47 |
48 |
49 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
50 | __global__ void rmsprop_mask_step_kernel(
51 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
52 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
53 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
54 | const bool* __restrict__ mask,
55 | float beta,
56 | float lr,
57 | float epsilon,
58 | float minval,
59 | float lr_last) {
60 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
61 | if (mask[tid / all_data.size(1)] == false) return;
62 | int32_t chnl = tid % all_data.size(1);
63 | rmsprop_once(all_data.data() + tid,
64 | all_rms.data() + tid,
65 | all_grad.data() + tid,
66 | beta,
67 | (chnl == all_data.size(1) - 1) ? lr_last : lr,
68 | epsilon,
69 | minval);
70 | }
71 |
72 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
73 | __global__ void rmsprop_index_step_kernel(
74 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
75 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_rms,
76 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
77 | torch::PackedTensorAccessor32<int32_t, 1, torch::RestrictPtrTraits> indices,
78 | float beta,
79 | float lr,
80 | float epsilon,
81 | float minval,
82 | float lr_last) {
83 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1));
84 | int32_t i = indices[tid / all_data.size(1)];
85 | int32_t chnl = tid % all_data.size(1);
86 | size_t off = i * all_data.size(1) + chnl;
87 | rmsprop_once(all_data.data() + off, all_rms.data() + off,
88 | all_grad.data() + off,
89 | beta,
90 | (chnl == all_data.size(1) - 1) ? lr_last : lr,
91 | epsilon,
92 | minval);
93 | }
94 |
95 |
96 | // SGD
97 | __inline__ __device__ void sgd_once(
98 | float* __restrict__ ptr_data,
99 | float* __restrict__ ptr_grad,
100 | const float lr) {
101 | *ptr_data -= lr * (*ptr_grad);
102 | *ptr_grad = 0.f;
103 | }
104 |
105 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
106 | __global__ void sgd_step_kernel(
107 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
108 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
109 | float lr,
110 | float lr_last) {
111 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
112 | int32_t chnl = tid % all_data.size(1);
113 | sgd_once(all_data.data() + tid,
114 | all_grad.data() + tid,
115 | (chnl == all_data.size(1) - 1) ? lr_last : lr);
116 | }
117 |
118 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
119 | __global__ void sgd_mask_step_kernel(
120 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
121 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
122 | const bool* __restrict__ mask,
123 | float lr,
124 | float lr_last) {
125 | CUDA_GET_THREAD_ID(tid, all_data.size(0) * all_data.size(1));
126 | if (mask[tid / all_data.size(1)] == false) return;
127 | int32_t chnl = tid % all_data.size(1);
128 | sgd_once(all_data.data() + tid,
129 | all_grad.data() + tid,
130 | (chnl == all_data.size(1) - 1) ? lr_last : lr);
131 | }
132 |
133 | __launch_bounds__(RMSPROP_STEP_CUDA_THREADS, MIN_BLOCKS_PER_SM)
134 | __global__ void sgd_index_step_kernel(
135 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_data,
136 | torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> all_grad,
137 | torch::PackedTensorAccessor32<int32_t, 1, torch::RestrictPtrTraits> indices,
138 | float lr,
139 | float lr_last) {
140 | CUDA_GET_THREAD_ID(tid, indices.size(0) * all_data.size(1));
141 | int32_t i = indices[tid / all_data.size(1)];
142 | int32_t chnl = tid % all_data.size(1);
143 | size_t off = i * all_data.size(1) + chnl;
144 | sgd_once(all_data.data() + off,
145 | all_grad.data() + off,
146 | (chnl == all_data.size(1) - 1) ? lr_last : lr);
147 | }
148 |
149 |
150 |
151 | } // namespace device
152 | } // namespace
153 |
154 | void rmsprop_step(
155 | torch::Tensor data,
156 | torch::Tensor rms,
157 | torch::Tensor grad,
158 | torch::Tensor indexer,
159 | float beta,
160 | float lr,
161 | float epsilon,
162 | float minval,
163 | float lr_last) {
164 |
165 | DEVICE_GUARD(data);
166 | CHECK_INPUT(data);
167 | CHECK_INPUT(rms);
168 | CHECK_INPUT(grad);
169 | CHECK_INPUT(indexer);
170 |
171 | if (lr_last < 0.f) lr_last = lr;
172 |
173 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS;
174 |
175 | if (indexer.dim() == 0) {
176 | const size_t Q = data.size(0) * data.size(1);
177 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
178 | device::rmsprop_step_kernel<<<blocks, cuda_n_threads>>>(
179 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
180 | rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
181 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
182 | beta,
183 | lr,
184 | epsilon,
185 | minval,
186 | lr_last);
187 | } else if (indexer.size(0) == 0) {
188 | // Skip
189 | } else if (indexer.scalar_type() == at::ScalarType::Bool) {
190 | const size_t Q = data.size(0) * data.size(1);
191 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
192 | device::rmsprop_mask_step_kernel<<<blocks, cuda_n_threads>>>(
193 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
194 | rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
195 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
196 | indexer.data_ptr<bool>(),
197 | beta,
198 | lr,
199 | epsilon,
200 | minval,
201 | lr_last);
202 | } else {
203 | const size_t Q = indexer.size(0) * data.size(1);
204 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
205 | device::rmsprop_index_step_kernel<<<blocks, cuda_n_threads>>>(
206 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
207 | rms.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
208 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
209 | indexer.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
210 | beta,
211 | lr,
212 | epsilon,
213 | minval,
214 | lr_last);
215 | }
216 |
217 | CUDA_CHECK_ERRORS;
218 | }
219 |
220 | void sgd_step(
221 | torch::Tensor data,
222 | torch::Tensor grad,
223 | torch::Tensor indexer,
224 | float lr,
225 | float lr_last) {
226 |
227 | DEVICE_GUARD(data);
228 | CHECK_INPUT(data);
229 | CHECK_INPUT(grad);
230 | CHECK_INPUT(indexer);
231 |
232 | if (lr_last < 0.f) lr_last = lr;
233 |
234 | const int cuda_n_threads = RMSPROP_STEP_CUDA_THREADS;
235 |
236 | if (indexer.dim() == 0) {
237 | const size_t Q = data.size(0) * data.size(1);
238 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
239 | device::sgd_step_kernel<<<blocks, cuda_n_threads>>>(
240 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
241 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
242 | lr,
243 | lr_last);
244 | } else if (indexer.size(0) == 0) {
245 | // Skip
246 | } else if (indexer.scalar_type() == at::ScalarType::Bool) {
247 | const size_t Q = data.size(0) * data.size(1);
248 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
249 | device::sgd_mask_step_kernel<<<blocks, cuda_n_threads>>>(
250 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
251 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
252 | indexer.data_ptr<bool>(),
253 | lr,
254 | lr_last);
255 | } else {
256 | const size_t Q = indexer.size(0) * data.size(1);
257 | const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
258 | device::sgd_index_step_kernel<<<blocks, cuda_n_threads>>>(
259 | data.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
260 | grad.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
261 | indexer.packed_accessor32<int32_t, 1, torch::RestrictPtrTraits>(),
262 | lr,
263 | lr_last);
264 | }
265 |
266 | CUDA_CHECK_ERRORS;
267 | }
268 |
--------------------------------------------------------------------------------
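For reference, a hedged restatement (not text from the repository): `rmsprop_once` above applies the standard RMSProp update with a lower bound `minval` on the parameter value. Assuming `lerp(a, b, t)` from `cuda_util.cuh` is the usual `a + t·(b − a)`, the per-element update is

```latex
\begin{aligned}
v &\leftarrow \begin{cases} g^{2} & \text{if } v = 0 \ \text{(uninitialized)} \\ \beta\, v + (1-\beta)\, g^{2} & \text{otherwise} \end{cases} \\
\theta &\leftarrow \max\!\left(\theta - \mathrm{lr}\,\frac{g}{\sqrt{v}+\epsilon},\ \mathrm{minval}\right), \qquad g \leftarrow 0,
\end{aligned}
```

with `lr` replaced by `lr_last` for the last channel of each row, and the `_mask_` / `_index_` kernel variants restricting the update to selected rows.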
/voxgraf-plenoxels/svox2/csrc/svox2.cpp:
--------------------------------------------------------------------------------
1 | // Copyright 2021 Alex Yu
2 |
3 | // This file contains only Python bindings
4 | #include "data_spec.hpp"
5 | #include <cstdint>
6 | #include <torch/extension.h>
7 | #include <tuple>
8 |
9 | using torch::Tensor;
10 |
11 | std::tuple<torch::Tensor, torch::Tensor> sample_grid(SparseGridSpec &, Tensor,
12 | bool);
13 | void sample_grid_backward(SparseGridSpec &, Tensor, Tensor, Tensor, Tensor,
14 | Tensor, bool);
15 |
16 | // ** NeRF rendering formula (trilerp)
17 | Tensor volume_render_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &);
18 | Tensor volume_render_w_alpha_depth_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &, Tensor, Tensor, Tensor, Tensor, bool);
19 | Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &,
20 | RenderOptions &);
21 | void volume_render_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
22 | Tensor, Tensor, GridOutputGrads &);
23 | void volume_render_w_alpha_depth_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
24 | Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, GridOutputGrads &);
25 | Tensor volume_render_reg_cuvol(SparseGridSpec &, RaysSpec &, RenderOptions &, float, Tensor);
26 | void volume_render_reg_cuvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &, float, float, Tensor, Tensor, Tensor, GridOutputGrads &);
27 | void volume_render_cuvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
28 | Tensor, float, float, Tensor, GridOutputGrads &);
29 | // Expected termination (depth) rendering
30 | torch::Tensor volume_render_expected_term(SparseGridSpec &, RaysSpec &,
31 | RenderOptions &);
32 | // Depth rendering based on sigma-threshold as in Dex-NeRF
33 | torch::Tensor volume_render_sigma_thresh(SparseGridSpec &, RaysSpec &,
34 | RenderOptions &, float);
35 |
36 | // ** NV rendering formula (trilerp)
37 | Tensor volume_render_nvol(SparseGridSpec &, RaysSpec &, RenderOptions &);
38 | void volume_render_nvol_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
39 | Tensor, Tensor, GridOutputGrads &);
40 | void volume_render_nvol_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
41 | Tensor, float, float, Tensor, GridOutputGrads &);
42 |
43 | // ** NeRF rendering formula (nearest-neighbor, infinitely many steps)
44 | Tensor volume_render_svox1(SparseGridSpec &, RaysSpec &, RenderOptions &);
45 | void volume_render_svox1_backward(SparseGridSpec &, RaysSpec &, RenderOptions &,
46 | Tensor, Tensor, GridOutputGrads &);
47 | void volume_render_svox1_fused(SparseGridSpec &, RaysSpec &, RenderOptions &,
48 | Tensor, float, float, Tensor, GridOutputGrads &);
49 |
50 | // Tensor volume_render_cuvol_image(SparseGridSpec &, CameraSpec &,
51 | // RenderOptions &);
52 | //
53 | // void volume_render_cuvol_image_backward(SparseGridSpec &, CameraSpec &,
54 | // RenderOptions &, Tensor, Tensor,
55 | // GridOutputGrads &);
56 |
57 | // Misc
58 | Tensor dilate(Tensor);
59 | void accel_dist_prop(Tensor);
60 | void grid_weight_render(Tensor, CameraSpec &, float, float, bool, Tensor,
61 | Tensor, Tensor);
62 | // void sample_cubemap(Tensor, Tensor, bool, Tensor);
63 |
64 | // Loss
65 | Tensor tv(Tensor, Tensor, int, int, bool, float, bool, float, float);
66 | void tv_grad(Tensor, Tensor, int, int, float, bool, float, bool, float, float,
67 | Tensor);
68 | void tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, int, int, float, bool,
69 | float, bool, bool, float, float, Tensor);
70 | void msi_tv_grad_sparse(Tensor, Tensor, Tensor, Tensor, float, float, Tensor);
71 | void lumisphere_tv_grad_sparse(SparseGridSpec &, Tensor, Tensor, Tensor, float,
72 | float, float, float, GridOutputGrads &);
73 |
74 | // Optim
75 | void rmsprop_step(Tensor, Tensor, Tensor, Tensor, float, float, float, float,
76 | float);
77 | void sgd_step(Tensor, Tensor, Tensor, float, float);
78 |
79 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
80 | #define _REG_FUNC(funname) m.def(#funname, &funname)
81 | _REG_FUNC(sample_grid);
82 | _REG_FUNC(sample_grid_backward);
83 | _REG_FUNC(volume_render_cuvol);
84 | _REG_FUNC(volume_render_cuvol_image);
85 | _REG_FUNC(volume_render_cuvol_backward);
86 | _REG_FUNC(volume_render_reg_cuvol);
87 | _REG_FUNC(volume_render_reg_cuvol_backward);
88 | _REG_FUNC(volume_render_cuvol_fused);
89 | _REG_FUNC(volume_render_expected_term);
90 | _REG_FUNC(volume_render_sigma_thresh);
91 |
92 | _REG_FUNC(volume_render_nvol);
93 | _REG_FUNC(volume_render_nvol_backward);
94 | _REG_FUNC(volume_render_nvol_fused);
95 |
96 | _REG_FUNC(volume_render_svox1);
97 | _REG_FUNC(volume_render_svox1_backward);
98 | _REG_FUNC(volume_render_svox1_fused);
99 |
100 | _REG_FUNC(volume_render_w_alpha_depth_cuvol);
101 | _REG_FUNC(volume_render_w_alpha_depth_cuvol_backward);
102 | // _REG_FUNC(volume_render_cuvol_image);
103 | // _REG_FUNC(volume_render_cuvol_image_backward);
104 |
105 | // Loss
106 | _REG_FUNC(tv);
107 | _REG_FUNC(tv_grad);
108 | _REG_FUNC(tv_grad_sparse);
109 | _REG_FUNC(msi_tv_grad_sparse);
110 | _REG_FUNC(lumisphere_tv_grad_sparse);
111 |
112 | // Misc
113 | _REG_FUNC(dilate);
114 | _REG_FUNC(accel_dist_prop);
115 | _REG_FUNC(grid_weight_render);
116 | // _REG_FUNC(sample_cubemap);
117 |
118 | // Optimizer
119 | _REG_FUNC(rmsprop_step);
120 | _REG_FUNC(sgd_step);
121 | #undef _REG_FUNC
122 |
123 | py::class_(m, "SparseGridSpec")
124 | .def(py::init<>())
125 | .def_readwrite("density_data", &SparseGridSpec::density_data)
126 | .def_readwrite("sh_data", &SparseGridSpec::sh_data)
127 | .def_readwrite("links", &SparseGridSpec::links)
128 | .def_readwrite("_offset", &SparseGridSpec::_offset)
129 | .def_readwrite("_scaling", &SparseGridSpec::_scaling)
130 | .def_readwrite("basis_dim", &SparseGridSpec::basis_dim)
131 | .def_readwrite("basis_type", &SparseGridSpec::basis_type)
132 | .def_readwrite("basis_data", &SparseGridSpec::basis_data)
133 | .def_readwrite("background_links", &SparseGridSpec::background_links)
134 | .def_readwrite("background_data", &SparseGridSpec::background_data);
135 |
136 | py::class_(m, "CameraSpec")
137 | .def(py::init<>())
138 | .def_readwrite("c2w", &CameraSpec::c2w)
139 | .def_readwrite("fx", &CameraSpec::fx)
140 | .def_readwrite("fy", &CameraSpec::fy)
141 | .def_readwrite("cx", &CameraSpec::cx)
142 | .def_readwrite("cy", &CameraSpec::cy)
143 | .def_readwrite("width", &CameraSpec::width)
144 | .def_readwrite("height", &CameraSpec::height)
145 | .def_readwrite("ndc_coeffx", &CameraSpec::ndc_coeffx)
146 | .def_readwrite("ndc_coeffy", &CameraSpec::ndc_coeffy);
147 |
148 | py::class_(m, "RaysSpec")
149 | .def(py::init<>())
150 | .def_readwrite("origins", &RaysSpec::origins)
151 | .def_readwrite("dirs", &RaysSpec::dirs);
152 |
153 | py::class_(m, "RenderOptions")
154 | .def(py::init<>())
155 | .def_readwrite("background_brightness",
156 | &RenderOptions::background_brightness)
157 | .def_readwrite("step_size", &RenderOptions::step_size)
158 | .def_readwrite("sigma_thresh", &RenderOptions::sigma_thresh)
159 | .def_readwrite("stop_thresh", &RenderOptions::stop_thresh)
160 | .def_readwrite("near_clip", &RenderOptions::near_clip)
161 | .def_readwrite("use_spheric_clip", &RenderOptions::use_spheric_clip)
162 | .def_readwrite("last_sample_opaque", &RenderOptions::last_sample_opaque);
163 | // .def_readwrite("randomize", &RenderOptions::randomize)
164 | // .def_readwrite("random_sigma_std", &RenderOptions::random_sigma_std)
165 | // .def_readwrite("random_sigma_std_background",
166 | // &RenderOptions::random_sigma_std_background)
167 | // .def_readwrite("_m1", &RenderOptions::_m1)
168 | // .def_readwrite("_m2", &RenderOptions::_m2)
169 | // .def_readwrite("_m3", &RenderOptions::_m3);
170 |
171 | py::class_(m, "GridOutputGrads")
172 | .def(py::init<>())
173 | .def_readwrite("grad_density_out", &GridOutputGrads::grad_density_out)
174 | .def_readwrite("grad_sh_out", &GridOutputGrads::grad_sh_out)
175 | .def_readwrite("grad_basis_out", &GridOutputGrads::grad_basis_out)
176 | .def_readwrite("grad_background_out",
177 | &GridOutputGrads::grad_background_out)
178 | .def_readwrite("mask_out", &GridOutputGrads::mask_out)
179 | .def_readwrite("mask_background_out",
180 | &GridOutputGrads::mask_background_out);
181 | }
182 |
--------------------------------------------------------------------------------
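The `volume_render_cuvol*` declarations bound above implement the NeRF rendering formula with trilinear interpolation into the sparse grid. As a hedged reference in standard NeRF/Plenoxels notation (not taken from this codebase), the discretized quantity these kernels accumulate along each ray is

```latex
\hat{C}(\mathbf{r}) \;=\; \sum_{i=1}^{N} T_i \,\bigl(1 - e^{-\sigma_i \delta_i}\bigr)\, \mathbf{c}_i,
\qquad T_i \;=\; \exp\!\Bigl(-\sum_{j<i} \sigma_j \delta_j\Bigr),
```

where σ_i is the interpolated density, c_i the SH-derived color, and δ_i the step length between samples along the ray.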
/voxgraf-plenoxels/svox2/defs.py:
--------------------------------------------------------------------------------
1 | # Basis types (copied from C++ data_spec.hpp)
2 | BASIS_TYPE_SH = 1
3 | BASIS_TYPE_3D_TEXTURE = 4
4 | BASIS_TYPE_MLP = 255
5 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/svox2/version.py:
--------------------------------------------------------------------------------
1 | __version__ = 'voxgraf-0.0.1.dev0+sphtexcub.lincolor.fast'
2 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/test/test_render_gradcheck_alpha.py:
--------------------------------------------------------------------------------
1 | import svox2
2 | import torch
3 | import torch.nn.functional as F
4 | from svox2.utils import Timing
5 |
6 | torch.random.manual_seed(2)
7 | # torch.random.manual_seed(8289)
8 |
9 | device = 'cuda:0'
10 | torch.cuda.set_device(device)
11 | dtype = torch.float32
12 | grid = svox2.SparseGrid(
13 | reso=128,
14 | center=[0.0, 0.0, 0.0],
15 | radius=[1.0, 1.0, 1.0],
16 | basis_dim=9,
17 | use_z_order=True,
18 | device=device,
19 | background_nlayers=0,
20 | basis_type=svox2.BASIS_TYPE_SH)
21 | grid.opt.backend = 'cuvol'
22 | grid.opt.sigma_thresh = 0.0
23 | grid.opt.stop_thresh = 0.0
24 | grid.opt.background_brightness = 0.0
25 |
26 | print(grid.sh_data.shape)
27 |
28 | # Render using a white color to obtain the equivalent of alpha rendering
29 | from svox2.utils import SH_C0
30 | grid.sh_data.data = torch.zeros_like(grid.sh_data.data.view(-1, 3, grid.basis_dim))
31 | grid.sh_data.data[..., 0] = 0.5 / SH_C0
32 | grid.sh_data.data = grid.sh_data.data.flatten(1, 2)
33 | grid.density_data.data[:] = 1.0
34 |
35 | if grid.use_background:
36 | grid.background_data.data[..., -1] = 0.5
37 | grid.background_data.data[..., :-1] = torch.randn_like(
38 | grid.background_data.data[..., :-1]) * 0.01
39 |
40 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
41 | grid.basis_data.data.normal_()
42 | grid.basis_data.data += 1.0
43 |
44 | ENABLE_TORCH_CHECK = True
45 | # N_RAYS = 5000 #200 * 200
46 | N_RAYS = 200 * 200
47 | origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3
48 | dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype)
49 |
50 | # from training.networks_plenoxel import RaySampler, PoseSampler
51 | # img_res = 32
52 | # ps = PoseSampler(range_azim=(180, 45), range_polar=(90, 15), radius=8, dist='normal')
53 | # rs = RaySampler(img_resolution=img_res, fov=30)
54 | # pose = ps.sample()
55 | # origins, dirs = rs.sample_at(pose[None, :3, :4], rs.img_resolution)
56 | # origins = origins.cuda()
57 | # dirs = dirs.cuda()
58 |
59 | # origins = torch.clip(origins, -0.8, 0.8)
60 |
61 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype)
62 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype)
63 | dirs /= torch.norm(dirs, dim=-1, keepdim=True)
64 |
65 | # start = 71
66 | # end = 72
67 | # origins = origins[start:end]
68 | # dirs = dirs[start:end]
69 | # print(origins.tolist(), dirs.tolist())
70 |
71 | # breakpoint()
72 | rays = svox2.Rays(origins, dirs)
73 |
74 | rgb_gt = torch.zeros((origins.size(0), 4), device=device, dtype=dtype)
75 |
76 | # grid.requires_grad_(True)
77 |
78 | # samps = grid.volume_render(rays, use_kernel=True)
79 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False)
80 |
81 | with Timing("ours"):
82 | samps = grid.volume_render(rays, use_kernel=True, render_alpha=True)
83 | assert all([(a == samps[:, 0]).all() for a in samps.T[1:]]), 'Alpha map and rendered color images do not match.'
84 |
85 | # Check internal gradients, i.e. white color vs alpha
86 | density_grads_s = []
87 | grid.sh_data.grad = None
88 | grid.density_data.grad = None
89 | for i in range(4):
90 | a = grid.volume_render(rays, use_kernel=True, render_alpha=True)[:, i]
91 | s = F.mse_loss(a, rgb_gt[:, i])
92 | with Timing("ours_backward"):
93 | s.backward()
94 | density_grads_s.append(grid.density_data.grad.clone().cpu())
95 | grid.sh_data.grad = None
96 | grid.density_data.grad = None
97 | density_grads_s = torch.stack(density_grads_s)
98 | print('Error gradients color/alpha', [(g - density_grads_s[0]).abs().max().item() for g in density_grads_s[1:]])
99 |
100 | # Check gradients wrt torch implementation
101 | s = F.mse_loss(samps, rgb_gt)
102 |
103 | print(s)
104 | print('bkwd..')
105 | with Timing("ours_backward"):
106 | s.backward()
107 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu()
108 | grid_density_grad_s = grid.density_data.grad.clone().cpu()
109 | grid.sh_data.grad = None
110 | grid.density_data.grad = None
111 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
112 | grid_basis_grad_s = grid.basis_data.grad.clone().cpu()
113 | grid.basis_data.grad = None
114 | if grid.use_background:
115 | grid_bg_grad_s = grid.background_data.grad.clone().cpu()
116 | grid.background_data.grad = None
117 |
118 | if ENABLE_TORCH_CHECK:
119 | with Timing("torch"):
120 | sampt = grid.volume_render(rays, use_kernel=False, render_alpha=True)
121 | assert all([(a == sampt[:, 0]).all() for a in sampt.T[1:]]), 'Alpha map and rendered color images do not match.'
122 |
123 | print('Do ours and torch output match?', torch.isclose(samps, sampt).all().item())
124 |
125 | s = F.mse_loss(sampt, rgb_gt)
126 | with Timing("torch_backward"):
127 | s.backward()
128 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s)
129 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s)
130 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
131 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu()
132 | if grid.use_background:
133 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s)
134 |
135 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t)
136 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t)
137 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
138 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t)
139 | if grid.use_background:
140 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t)
141 | print('err', torch.abs(samps - sampt).max())
142 | print('err_sh_grad\n', E.max())
143 | print(' mean\n', E.mean())
144 | print('err_density_grad\n', Ed.max())
145 | print(' mean\n', Ed.mean())
146 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
147 | print('err_basis_grad\n', Eb.max())
148 | print(' mean\n', Eb.mean())
149 | if grid.use_background:
150 | print('err_background_grad\n', Ebg.max())
151 | print(' mean\n', Ebg.mean())
152 | print()
153 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max())
154 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max())
155 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max())
156 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max())
157 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
158 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max())
159 | print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max())
160 | if grid.use_background:
161 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max())
162 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max())
163 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/test/test_render_gradcheck_depth.py:
--------------------------------------------------------------------------------
1 | import svox2
2 | import torch
3 | import torch.nn.functional as F
4 | from svox2.utils import Timing
5 |
6 | torch.random.manual_seed(2)
7 | # torch.random.manual_seed(8289)
8 |
9 | device = 'cuda:0'
10 | torch.cuda.set_device(device)
11 | dtype = torch.float32
12 | grid_res = 128
13 | grid = svox2.SparseGrid(
14 | reso=grid_res,
15 | center=[0.0, 0.0, 0.0],
16 | radius=[1.0, 1.0, 1.0],
17 | basis_dim=9,
18 | use_z_order=True,
19 | device=device,
20 | background_nlayers=0,
21 | basis_type=svox2.BASIS_TYPE_SH)
22 | grid.opt.backend = 'cuvol'
23 | grid.opt.sigma_thresh = 0.0
24 | grid.opt.step_size = 0.5
25 | grid.opt.stop_thresh = 0.0
26 | grid.opt.background_brightness = 1.0
27 |
28 | print(grid.sh_data.shape)
29 | # grid.sh_data.data.normal_()
30 | grid.sh_data.data[..., 0] = 0.5
31 | grid.sh_data.data[..., 1:].normal_(std=0.1)
32 | grid.density_data.data[:] = 1.0
33 |
34 | if grid.use_background:
35 | grid.background_data.data[..., -1] = 0.5
36 | grid.background_data.data[..., :-1] = torch.randn_like(
37 | grid.background_data.data[..., :-1]) * 0.01
38 |
39 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
40 | grid.basis_data.data.normal_()
41 | grid.basis_data.data += 1.0
42 |
43 | ENABLE_TORCH_CHECK = True
44 | # N_RAYS = 5000 #200 * 200
45 | N_RAYS = 200 * 200
46 | origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3
47 | dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype)
48 |
49 | # from training.networks_plenoxel import RaySampler, PoseSampler
50 | # img_res = 32
51 | # ps = PoseSampler(range_azim=(180, 45), range_polar=(90, 15), radius=8, dist='normal')
52 | # rs = RaySampler(img_resolution=img_res, fov=30)
53 | # pose = ps.sample()
54 | # origins, dirs = rs.sample_at(pose[None, :3, :4], rs.img_resolution)
55 | # origins = origins.cuda()
56 | # dirs = dirs.cuda()
57 |
58 | # origins = torch.clip(origins, -0.8, 0.8)
59 |
60 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype)
61 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype)
62 | dirs /= torch.norm(dirs, dim=-1, keepdim=True)
63 |
64 | # start = 71
65 | # end = 72
66 | # origins = origins[start:end]
67 | # dirs = dirs[start:end]
68 | # print(origins.tolist(), dirs.tolist())
69 |
70 | # breakpoint()
71 | rays = svox2.Rays(origins, dirs)
72 |
73 | dims = 5
74 | rgb_gt = torch.zeros((origins.size(0), dims), device=device, dtype=dtype)
75 |
76 | # grid.requires_grad_(True)
77 |
78 | # samps = grid.volume_render(rays, use_kernel=True)
79 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False)
80 |
81 | with Timing("ours"):
82 | samps = grid.volume_render(rays, use_kernel=True, render_alpha=dims==5, render_depth=True)[:, -dims:]
83 | # samps_original = grid.volume_render_depth(rays)[:, None]#, use_kernel=True, render_alpha=False, render_depth=True)
84 | # assert (samps == samps_original).all(), 'Depth implementation does not match original plenoxel implementation' # original implementation does not feature background depth
85 |
86 | # Check gradients wrt torch implementation
87 | s = F.mse_loss(samps, rgb_gt)
88 |
89 | print(s)
90 | print('bkwd..')
91 | with Timing("ours_backward"):
92 | s.backward()
93 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu()
94 | grid_density_grad_s = grid.density_data.grad.clone().cpu()
95 | grid.sh_data.grad = None
96 | grid.density_data.grad = None
97 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
98 | grid_basis_grad_s = grid.basis_data.grad.clone().cpu()
99 | grid.basis_data.grad = None
100 | if grid.use_background:
101 | grid_bg_grad_s = grid.background_data.grad.clone().cpu()
102 | grid.background_data.grad = None
103 |
104 | if ENABLE_TORCH_CHECK:
105 | with Timing("torch"):
106 | sampt = grid.volume_render(rays, use_kernel=False, render_alpha=dims==5, render_depth=True)[:, -dims:]
107 | print('Do ours and torch output match?', torch.isclose(samps, sampt).all().item())
108 |
109 | s = F.mse_loss(sampt, rgb_gt)
110 | with Timing("torch_backward"):
111 | s.backward()
112 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s)
113 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s)
114 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
115 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu()
116 | if grid.use_background:
117 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s)
118 |
119 | grid_density_grad_t = grid_density_grad_t.view(grid_res, grid_res, grid_res)
120 | grid_density_grad_s = grid_density_grad_s.view(grid_res, grid_res, grid_res)
121 |
122 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t)
123 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t)
124 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
125 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t)
126 | if grid.use_background:
127 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t)
128 | print('err', torch.abs(samps - sampt).max())
129 | print('err_sh_grad\n', E.max())
130 | print(' mean\n', E.mean())
131 | print('err_density_grad\n', Ed.max())
132 | print(' mean\n', Ed.mean())
133 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
134 | print('err_basis_grad\n', Eb.max())
135 | print(' mean\n', Eb.mean())
136 | if grid.use_background:
137 | print('err_background_grad\n', Ebg.max())
138 | print(' mean\n', Ebg.mean())
139 | print()
140 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max())
141 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max())
142 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max())
143 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max())
144 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
145 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max())
146 | print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max())
147 | if grid.use_background:
148 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max())
149 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max())
150 |
--------------------------------------------------------------------------------
/voxgraf-plenoxels/test/test_render_gradcheck_vardepth.py:
--------------------------------------------------------------------------------
1 | import svox2
2 | import torch
3 | import torch.nn.functional as F
4 | from svox2.utils import Timing
5 |
6 | torch.random.manual_seed(2)
7 | # torch.random.manual_seed(8289)
8 |
9 | device = 'cuda:0'
10 | torch.cuda.set_device(device)
11 | dtype = torch.float32
12 | grid_res = 64
13 | grid = svox2.SparseGrid(
14 | reso=grid_res,
15 | center=[0.0, 0.0, 0.0],
16 | radius=[1.0, 1.0, 1.0],
17 | basis_dim=1,
18 | use_z_order=True,
19 | device=device,
20 | background_nlayers=0,
21 | basis_type=svox2.BASIS_TYPE_SH)
22 | grid.opt.backend = 'cuvol'
23 | grid.opt.sigma_thresh = 0.0
24 | grid.opt.step_size = 0.5
25 | grid.opt.stop_thresh = 0.0
26 | grid.opt.background_brightness = 1.0
27 |
28 | feat_dim = 3*grid.basis_dim
29 | grid.sh_data.data = torch.empty([grid.sh_data.shape[0], feat_dim], device=grid.sh_data.device, dtype=grid.sh_data.dtype)
30 | print(grid.sh_data.shape)
31 | # grid.sh_data.data.normal_()
32 | grid.sh_data.data[..., 0] = 0.5
33 | grid.sh_data.data[..., 1:].normal_(std=0.1)
34 | grid.density_data.data[:] = 1.0
35 |
36 | if grid.use_background:
37 | grid.background_data.data[..., -1] = 0.5
38 | grid.background_data.data[..., :-1] = torch.randn_like(
39 | grid.background_data.data[..., :-1]) * 0.01
40 |
41 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
42 | grid.basis_data.data.normal_()
43 | grid.basis_data.data += 1.0
44 |
45 | ENABLE_TORCH_CHECK = True
46 | # # N_RAYS = 5000 #200 * 200
47 | # N_RAYS = 200 * 200
48 | # origins = torch.randn((N_RAYS, 3), device=device, dtype=dtype) * 3
49 | # dirs = torch.randn((N_RAYS, 3), device=device, dtype=dtype)
50 |
51 | from training.virtual_camera_utils import RaySampler, PoseSampler
52 | img_res = 32
53 | ps = PoseSampler(range_azim=(180, 45), range_polar=(90, 15), radius=8, dist='normal')
54 | rs = RaySampler(img_resolution=img_res, fov=30)
55 | pose = ps.sample_from_dist()
56 | origins, dirs = rs.sample_at(pose[None, :3, :4], rs.img_resolution, fov=30)
57 | origins = origins.cuda()
58 | dirs = dirs.cuda()
59 |
60 | # origins = torch.clip(origins, -0.8, 0.8)
61 |
62 | # origins = torch.tensor([[-0.6747068762779236, -0.752697229385376, -0.800000011920929]], device=device, dtype=dtype)
63 | # dirs = torch.tensor([[0.6418760418891907, -0.37417781352996826, 0.6693176627159119]], device=device, dtype=dtype)
64 | dirs /= torch.norm(dirs, dim=-1, keepdim=True)
65 |
66 | # start = 71
67 | # end = 72
68 | # origins = origins[start:end]
69 | # dirs = dirs[start:end]
70 | # print(origins.tolist(), dirs.tolist())
71 |
72 | # breakpoint()
73 | rays = svox2.Rays(origins, dirs)
74 |
75 | dims = feat_dim + 3
76 | rgb_gt = torch.zeros((origins.size(0), dims), device=device, dtype=dtype)
77 |
78 | # grid.requires_grad_(True)
79 |
80 | # samps = grid.volume_render(rays, use_kernel=True)
81 | # sampt = grid.volume_render(grid, origins, dirs, use_kernel=False)
82 |
83 | with Timing("ours"):
84 | samps = grid.volume_render(rays, use_kernel=True, render_alpha=dims==(feat_dim + 3), render_depth=True, render_vardepth=True)[:, -dims:]
85 |
86 | # Check gradients wrt torch implementation
87 | s = F.mse_loss(samps, rgb_gt)
88 |
89 | print(s)
90 | print('bkwd..')
91 | with Timing("ours_backward"):
92 | s.backward()
93 | grid_sh_grad_s = grid.sh_data.grad.clone().cpu()
94 | grid_density_grad_s = grid.density_data.grad.clone().cpu()
95 | grid.sh_data.grad = None
96 | grid.density_data.grad = None
97 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
98 | grid_basis_grad_s = grid.basis_data.grad.clone().cpu()
99 | grid.basis_data.grad = None
100 | if grid.use_background:
101 | grid_bg_grad_s = grid.background_data.grad.clone().cpu()
102 | grid.background_data.grad = None
103 |
104 | if ENABLE_TORCH_CHECK:
105 | with Timing("torch"):
106 | sampt = grid.volume_render(rays, use_kernel=False, render_alpha=dims==(feat_dim + 3), render_depth=True, render_vardepth=True)[:, -dims:]
107 | print('Do ours and torch output match?', torch.isclose(samps, sampt).all().item())
108 |
109 | s = F.mse_loss(sampt, rgb_gt)
110 | with Timing("torch_backward"):
111 | s.backward()
112 | grid_sh_grad_t = grid.sh_data.grad.clone().cpu() if grid.sh_data.grad is not None else torch.zeros_like(grid_sh_grad_s)
113 | grid_density_grad_t = grid.density_data.grad.clone().cpu() if grid.density_data.grad is not None else torch.zeros_like(grid_density_grad_s)
114 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
115 | grid_basis_grad_t = grid.basis_data.grad.clone().cpu()
116 | if grid.use_background:
117 | grid_bg_grad_t = grid.background_data.grad.clone().cpu() if grid.background_data.grad is not None else torch.zeros_like(grid_bg_grad_s)
118 |
119 | grid_density_grad_t = grid_density_grad_t.view(grid_res, grid_res, grid_res)
120 | grid_density_grad_s = grid_density_grad_s.view(grid_res, grid_res, grid_res)
121 |
122 | E = torch.abs(grid_sh_grad_s-grid_sh_grad_t)
123 | Ed = torch.abs(grid_density_grad_s-grid_density_grad_t)
124 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
125 | Eb = torch.abs(grid_basis_grad_s-grid_basis_grad_t)
126 | if grid.use_background:
127 | Ebg = torch.abs(grid_bg_grad_s-grid_bg_grad_t)
128 | print('err', torch.abs(samps - sampt).max())
129 | print('err_sh_grad\n', E.max())
130 | print(' mean\n', E.mean())
131 | print('err_density_grad\n', Ed.max())
132 | print(' mean\n', Ed.mean())
133 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
134 | print('err_basis_grad\n', Eb.max())
135 | print(' mean\n', Eb.mean())
136 | if grid.use_background:
137 | print('err_background_grad\n', Ebg.max())
138 | print(' mean\n', Ebg.mean())
139 | print()
140 | print('g_ours sh min/max\n', grid_sh_grad_s.min(), grid_sh_grad_s.max())
141 | print('g_torch sh min/max\n', grid_sh_grad_t.min(), grid_sh_grad_t.max())
142 | print('g_ours sigma min/max\n', grid_density_grad_s.min(), grid_density_grad_s.max())
143 | print('g_torch sigma min/max\n', grid_density_grad_t.min(), grid_density_grad_t.max())
144 | if grid.basis_type == svox2.BASIS_TYPE_3D_TEXTURE:
145 | print('g_ours basis min/max\n', grid_basis_grad_s.min(), grid_basis_grad_s.max())
146 | print('g_torch basis min/max\n', grid_basis_grad_t.min(), grid_basis_grad_t.max())
147 | if grid.use_background:
148 | print('g_ours bg min/max\n', grid_bg_grad_s.min(), grid_bg_grad_s.max())
149 | print('g_torch bg min/max\n', grid_bg_grad_t.min(), grid_bg_grad_t.max())
150 |
--------------------------------------------------------------------------------