├── .gitignore ├── .gitmodules ├── LICENSE ├── ablation_study.sh ├── activation.py ├── assets ├── bg_model.jpg ├── ccnerf.jpg ├── fox.jpg ├── gallery.md ├── llff.jpg ├── truck.jpg └── update_logs.md ├── dec_48b_whit.torchscript.pt ├── dnerf ├── gui.py ├── network.py ├── network_basis.py ├── network_hyper.py ├── provider.py ├── renderer.py └── utils.py ├── encoding.py ├── ffmlp ├── __init__.py ├── backend.py ├── ffmlp.py ├── setup.py └── src │ ├── bindings.cpp │ ├── cutlass_matmul.h │ ├── ffmlp.cu │ ├── ffmlp.h │ └── utils.h ├── freqencoder ├── __init__.py ├── backend.py ├── freq.py ├── setup.py └── src │ ├── bindings.cpp │ ├── freqencoder.cu │ └── freqencoder.h ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── loss.py ├── main_nerf.py ├── nerf ├── clip_utils.py ├── gui.py ├── network.py ├── network_ff.py ├── network_tcnn.py ├── optimizer.py ├── provider.py ├── renderer.py └── utils.py ├── raymarching ├── __init__.py ├── backend.py ├── raymarching.py ├── setup.py └── src │ ├── bindings.cpp │ ├── raymarching.cu │ └── raymarching.h ├── readme.md ├── requirements.txt ├── scripts ├── colmap2nerf.py ├── hyper2nerf.py ├── install_ext.sh ├── llff2nerf.py ├── run_ccnerf.sh ├── run_dnerf.sh ├── run_gui_nerf.sh ├── run_gui_nerf_clip.sh ├── run_gui_tensoRF.sh ├── run_nerf.sh ├── run_sdf.sh ├── run_tensoRF.sh └── tanks2nerf.py ├── sdf ├── netowrk.py ├── netowrk_ff.py ├── network_tcnn.py ├── provider.py └── utils.py ├── shencoder ├── __init__.py ├── backend.py ├── setup.py ├── sphere_harmonics.py └── src │ ├── bindings.cpp │ ├── shencoder.cu │ └── shencoder.h ├── tensoRF ├── network.py ├── network_cc.py ├── network_cp.py └── utils.py ├── testing ├── test_ffmlp.py ├── test_hashencoder.py ├── test_hashgrid_grad.py ├── test_raymarching.py └── test_shencoder.py └── utils_img.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | 6 | tmp* 7 | data/ 8 | trial*/ 9 | .vs/ 10 | 11 | #**dnerf* 12 | #dnerf/ 13 | dnerf/network_rf.py 14 | dnerf/network_basis_perpoint.py 15 | 16 | **neus* 17 | neus/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ffmlp/dependencies/cutlass"] 2 | path = ffmlp/dependencies/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 hawkey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ablation_study.sh: -------------------------------------------------------------------------------- 1 | 2 | ####################################################### 3 | # Only Global Rendering 4 | 5 | 6 | # python main_nerf.py data/nerf_synthetic/lego --workspace out/local/lego_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 7 | # python main_nerf.py data/nerf_synthetic/drums --workspace out/local/drums_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 8 | # python main_nerf.py data/nerf_synthetic/chair --workspace out/local/chair_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 9 | 10 | # python main_nerf.py data/nerf_synthetic/lego --workspace out/local/lego_0.01 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.01 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 11 | # python main_nerf.py data/nerf_synthetic/drums --workspace out/local/drums_0.01 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.01 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 12 | # python main_nerf.py data/nerf_synthetic/chair --workspace out/local/chair_0.01 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.01 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 13 | 14 | 15 | # python main_nerf.py data/nerf_synthetic/lego --workspace out/local/lego_0.1 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.1 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 16 | # python main_nerf.py data/nerf_synthetic/drums --workspace out/local/drums_0.1 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.1 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 17 | # python main_nerf.py data/nerf_synthetic/chair --workspace out/local/chair_0.1 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.1 --iters 10000 --global_sample 0 --ml_sample 0 --lr 4e-2 18 | 19 | 20 | # ####################################################### 21 | # Only Global Rendering 22 | # python main_nerf.py data/nerf_synthetic/lego --workspace out/global/lego_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --local_sample False 23 | # # PSNR = 31.344405 24 | # # SSIM = 0.955633 25 | # # LPIPS (alex) = 0.025490 26 | # # BitAcc = 0.881354 27 | # python main_nerf.py data/nerf_synthetic/drums --workspace out/global/drums_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --local_sample False 28 | # python main_nerf.py data/nerf_synthetic/chair --workspace out/global/chair_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --local_sample False 29 | 30 | 31 | 32 | ################################################################## 33 | # Local w/ Global Rendering 34 | # python main_nerf.py data/nerf_synthetic/lego --workspace 
out/local_w_global/lego_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 35 | # # PSNR = 32.917592 36 | # # SSIM = 0.969716 37 | # # LPIPS (alex) = 0.016145 38 | # # BitAcc = 0.865417 39 | # python main_nerf.py data/nerf_synthetic/drums --workspace out/local_w_global/drums_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 40 | # python main_nerf.py data/nerf_synthetic/chair --workspace out/local_w_global/chair_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 41 | 42 | 43 | 44 | ############################################################################# 45 | 46 | # Local w/ Mul Global 47 | # python main_nerf.py data/nerf_synthetic/lego --workspace out/local_w_ml_global/lego_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 48 | # python main_nerf.py data/nerf_synthetic/drums --workspace out/local_w_ml_global/drums_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 49 | # python main_nerf.py data/nerf_synthetic/chair --workspace out/local_w_ml_global/chair_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 50 | # python main_nerf.py data/nerf_synthetic/hotdog --workspace out/local_w_ml_global/hotdog_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 51 | 52 | # python main_nerf.py data/nerf_synthetic/ficus --workspace out/local_w_ml_global/ficus_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 53 | 54 | # python scripts/llff2nerf.py data/llff/flower --images images_8 --downscale 8 55 | # python scripts/llff2nerf.py data/llff/fern --images images_8 --downscale 8 56 | # python scripts/llff2nerf.py data/llff/fortress --images images_8 --downscale 8 57 | # python scripts/llff2nerf.py data/llff/horns --images images_8 --downscale 8 58 | 59 | # python main_nerf.py data/llff/flower --workspace out/local_w_ml_global/flower_0.001 -O --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 60 | # python main_nerf.py data/llff/fern --workspace out/local_w_ml_global/fern_0.001 -O --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 61 | # python main_nerf.py data/llff/fortress --workspace out/local_w_ml_global/fortress_0.001 -O --lambda1 0.001 --iters 8000 --lr 4e-2 --ml_sample True 62 | # python main_nerf.py data/llff/horns --workspace out/local_w_ml_global/horns_0.001 -O --lambda1 0.001 --iters 2000 --lr 4e-2 --ml_sample True 63 | ############################################################################# 64 | 65 | ### test 66 | # out/global/lego_0.001/checkpoints/ngp_ep0100.pth 67 | 68 | 69 | 70 | #### ab for L 71 | python main_nerf.py data/nerf_synthetic/lego --workspace out/final/lego_0.1 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.1 --iters 10000 --lr 4e-2 72 | python main_nerf.py data/nerf_synthetic/lego --workspace out/final/lego_0.01 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.01 --iters 10000 --lr 4e-2 73 | python main_nerf.py data/nerf_synthetic/lego --workspace out/final/lego_0.005 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.005 --iters 10000 --lr 4e-2 74 | python main_nerf.py data/nerf_synthetic/lego --workspace out/final/lego_0.001 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 75 | python main_nerf.py data/nerf_synthetic/lego --workspace out/final/lego_0.0001 
-O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.0001 --iters 10000 --lr 4e-2 -------------------------------------------------------------------------------- /activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.cuda.amp import custom_bwd, custom_fwd 4 | 5 | class _trunc_exp(Function): 6 | @staticmethod 7 | @custom_fwd(cast_inputs=torch.float32) # cast to float32 8 | def forward(ctx, x): 9 | ctx.save_for_backward(x) 10 | return torch.exp(x) 11 | 12 | @staticmethod 13 | @custom_bwd 14 | def backward(ctx, g): 15 | x = ctx.saved_tensors[0] 16 | return g * torch.exp(x.clamp(-15, 15)) 17 | 18 | trunc_exp = _trunc_exp.apply -------------------------------------------------------------------------------- /assets/bg_model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsong2001/NeRFProtector-code/df95cb2598172b3af660e583ae0343c35897ea2c/assets/bg_model.jpg -------------------------------------------------------------------------------- /assets/ccnerf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsong2001/NeRFProtector-code/df95cb2598172b3af660e583ae0343c35897ea2c/assets/ccnerf.jpg -------------------------------------------------------------------------------- /assets/fox.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsong2001/NeRFProtector-code/df95cb2598172b3af660e583ae0343c35897ea2c/assets/fox.jpg -------------------------------------------------------------------------------- /assets/gallery.md: -------------------------------------------------------------------------------- 1 | # Gallery 2 | 3 | ## D-NeRF 4 | 5 | https://user-images.githubusercontent.com/25863658/175821784-63ba79f6-29be-47b5-b3fc-dab5282fce7a.mp4 6 | 7 | 8 | ## Instant-ngp NeRF 9 | 10 | Fox: 11 | 12 | ![fox](fox.jpg) 13 | 14 | LLFF: 15 | 16 | ![llff](llff.jpg) 17 | 18 | Tanks&Temples: 19 | 20 | ![truck](truck.jpg) 21 | 22 | ## CCNeRF 23 | 24 | Composition example: 25 | 26 | ![ccnerf](ccnerf.jpg) 27 | -------------------------------------------------------------------------------- /assets/llff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsong2001/NeRFProtector-code/df95cb2598172b3af660e583ae0343c35897ea2c/assets/llff.jpg -------------------------------------------------------------------------------- /assets/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsong2001/NeRFProtector-code/df95cb2598172b3af660e583ae0343c35897ea2c/assets/truck.jpg -------------------------------------------------------------------------------- /assets/update_logs.md: -------------------------------------------------------------------------------- 1 | ## Update logs 2 | 3 | * 7.28: support saving video at test. 4 | * 7.26: add a CUDA-based freqencoder (though not used by default), add LPIPS metric. 5 | * 7.16: add temporal basis based dynamic nerf (experimental). It trains much faster compared to the deformation based dynamic nerf, but performance is much worse for now... 6 | * 6.29: add support for HyperNeRF's dataset. 7 | * we use a simplified pinhole camera model, may introduce bias. 8 | * 6.26: add support for D-NeRF. 
9 | * issue: to enable the `--cuda_ray` in a dynamic scene, we have to record a different density grid for each time step. This leads to much slower `update_extra_status` and a much larger `density_grid` since there is an additional time dimension. Current workarounds: (1) only use 64 time intervals, (2) update it every 100 steps (compared to the 16 steps in static nerf), (3) stop updating after 100 updates since the grid should be stable by then. 10 | * 6.16: add support for CCNeRF. 11 | * 6.15: fixed a bug in raymarching, improved PSNR. Density thresh is directly applied on sigmas now (removed the empirical scaling factor). 12 | * 6.6: fix gridencoder to always use more accurate float32 inputs (coords), slightly improved performance (matched with tcnn). 13 | * 6.3: implement morton3D, misc improvements. 14 | * 5.29: fix a random bg color issue, add color_space option, better results for blender dataset. 15 | * 5.28: add a background model (set bg_radius > 0), which can suppress noise for real-world 360 datasets. 16 | * 5.21: expose more parameters to control, implement packbits. 17 | * 4.30: performance improvement (better lr_scheduler). 18 | * 4.25: add Tanks&Temples dataset support. 19 | * 4.18: add some experimental utils for random pose sampling and combined training with CLIP. 20 | * 4.13: add LLFF dataset support. 21 | * 4.13: also implemented a tiled grid encoder according to this [issue](https://github.com/NVlabs/instant-ngp/issues/97). 22 | * 4.12: optimized dataloader, add error_map sampling (experimental, will slow down training since it only samples hard rays...) 23 | * 4.10: add Windows support. 24 | * 4.9: use 6D AABB instead of a single `bound` for more flexible rendering. More options in GUI to control the AABB and `dt_gamma` for adaptive ray marching. 25 | * 4.9: implemented multi-res density grid (cascade) and adaptive ray marching. Now the fox renders much faster! 26 | * 4.6: fixed TensorCP hyper-parameters. 27 | * 4.3: add `mark_untrained_grid` to prevent training on out-of-camera regions. Add custom dataset instructions. 28 | * 3.31: better compatibility for lower pytorch versions. 29 | * 3.29: fix training speed for the fox dataset (balanced speed with performance...). 30 | * 3.27: major update: improved performance and added support for the tensoRF model. 31 | * 3.22: reverted from pre-generating rays as it takes too much CPU memory; the PSNR for Lego can still reach ~33 now. 32 | * 3.14: fixed the precision-related issue for `fp16` mode, and it renders with much better quality. Added PSNR metric for NeRF. 33 | * 3.14: linearly scale `desired_resolution` with `bound` according to https://github.com/ashawkey/torch-ngp/issues/23. 34 | * 3.11: raymarching now supports supervising weights_sum (pixel alpha, or mask) directly, and bg_color is separated from CUDA to make it more flexible. Add an option to preload data into GPU. 35 | * 3.9: add fov for gui. 36 | * 3.1: add type='all' for blender dataset (load train + val + test data), which is the default behavior of instant-ngp. 37 | * 2.28: density_grid now stores density on the voxel center (with randomness), instead of on the grid. This should improve the rendering quality, such as the black strips in the lego scene. 38 | * 2.23: better support for the blender dataset. 39 | * 2.22: add GUI for NeRF training. 40 | * 2.21: add GUI for NeRF visualizing. 41 | * 2.20: cuda raymarching is finally stable now! 42 | * 2.15: add the official [tinycudann](https://github.com/NVlabs/tiny-cuda-nn) as an alternative backend.
43 | * 2.10: add cuda_ray, can train/infer faster, but performance is worse currently. 44 | * 2.6: add support for RGBA image. 45 | * 1.30: fixed atomicAdd() to use __half2 in HashGrid Encoder's backward, now the training speed with fp16 is as expected! 46 | * 1.29: finished an experimental binding of fully-fused MLP. replace SHEncoder with a CUDA implementation. 47 | * 1.26: add fp16 support for HashGrid Encoder (requires CUDA >= 10 and GPU ARCH >= 70 for now...). -------------------------------------------------------------------------------- /dec_48b_whit.torchscript.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qsong2001/NeRFProtector-code/df95cb2598172b3af660e583ae0343c35897ea2c/dec_48b_whit.torchscript.pt -------------------------------------------------------------------------------- /dnerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | 10 | class NeRFNetwork(NeRFRenderer): 11 | def __init__(self, 12 | encoding="tiledgrid", 13 | encoding_dir="sphere_harmonics", 14 | encoding_time="frequency", 15 | encoding_deform="frequency", # "hashgrid" seems worse 16 | encoding_bg="hashgrid", 17 | num_layers=2, 18 | hidden_dim=64, 19 | geo_feat_dim=15, 20 | num_layers_color=3, 21 | hidden_dim_color=64, 22 | num_layers_bg=2, 23 | hidden_dim_bg=64, 24 | num_layers_deform=5, # a deeper MLP is very necessary for performance. 25 | hidden_dim_deform=128, 26 | bound=1, 27 | **kwargs, 28 | ): 29 | super().__init__(bound, **kwargs) 30 | 31 | # deformation network 32 | self.num_layers_deform = num_layers_deform 33 | self.hidden_dim_deform = hidden_dim_deform 34 | self.encoder_deform, self.in_dim_deform = get_encoder(encoding_deform, multires=10) 35 | self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6) 36 | 37 | 38 | deform_net = [] 39 | for l in range(num_layers_deform): 40 | if l == 0: 41 | in_dim = self.in_dim_deform + self.in_dim_time # grid dim + time 42 | else: 43 | in_dim = hidden_dim_deform 44 | 45 | if l == num_layers_deform - 1: 46 | out_dim = 3 # deformation for xyz 47 | else: 48 | out_dim = hidden_dim_deform 49 | 50 | deform_net.append(nn.Linear(in_dim, out_dim, bias=False)) 51 | 52 | self.deform_net = nn.ModuleList(deform_net) 53 | 54 | 55 | # sigma network 56 | self.num_layers = num_layers 57 | self.hidden_dim = hidden_dim 58 | self.geo_feat_dim = geo_feat_dim 59 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) 60 | 61 | sigma_net = [] 62 | for l in range(num_layers): 63 | if l == 0: 64 | in_dim = self.in_dim + self.in_dim_time + self.in_dim_deform # concat everything 65 | else: 66 | in_dim = hidden_dim 67 | 68 | if l == num_layers - 1: 69 | out_dim = 1 + self.geo_feat_dim # 1 sigma + features for color 70 | else: 71 | out_dim = hidden_dim 72 | 73 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 74 | 75 | self.sigma_net = nn.ModuleList(sigma_net) 76 | 77 | # color network 78 | self.num_layers_color = num_layers_color 79 | self.hidden_dim_color = hidden_dim_color 80 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) 81 | 82 | color_net = [] 83 | for l in range(num_layers_color): 84 | if l == 0: 85 | in_dim = self.in_dim_dir + self.geo_feat_dim 86 | else: 87 | in_dim = 
hidden_dim_color 88 | 89 | if l == num_layers_color - 1: 90 | out_dim = 3 # 3 rgb 91 | else: 92 | out_dim = hidden_dim_color 93 | 94 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 95 | 96 | self.color_net = nn.ModuleList(color_net) 97 | 98 | # background network 99 | if self.bg_radius > 0: 100 | self.num_layers_bg = num_layers_bg 101 | self.hidden_dim_bg = hidden_dim_bg 102 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid 103 | 104 | bg_net = [] 105 | for l in range(num_layers_bg): 106 | if l == 0: 107 | in_dim = self.in_dim_bg + self.in_dim_dir 108 | else: 109 | in_dim = hidden_dim_bg 110 | 111 | if l == num_layers_bg - 1: 112 | out_dim = 3 # 3 rgb 113 | else: 114 | out_dim = hidden_dim_bg 115 | 116 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) 117 | 118 | self.bg_net = nn.ModuleList(bg_net) 119 | else: 120 | self.bg_net = None 121 | 122 | 123 | def forward(self, x, d, t): 124 | # x: [N, 3], in [-bound, bound] 125 | # d: [N, 3], nomalized in [-1, 1] 126 | # t: [1, 1], in [0, 1] 127 | 128 | # deform 129 | enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C] 130 | enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] 131 | if enc_t.shape[0] == 1: 132 | enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C'] 133 | 134 | deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C'] 135 | for l in range(self.num_layers_deform): 136 | deform = self.deform_net[l](deform) 137 | if l != self.num_layers_deform - 1: 138 | deform = F.relu(deform, inplace=True) 139 | 140 | x = x + deform 141 | 142 | # sigma 143 | x = self.encoder(x, bound=self.bound) 144 | h = torch.cat([x, enc_ori_x, enc_t], dim=1) 145 | for l in range(self.num_layers): 146 | h = self.sigma_net[l](h) 147 | if l != self.num_layers - 1: 148 | h = F.relu(h, inplace=True) 149 | 150 | #sigma = F.relu(h[..., 0]) 151 | sigma = trunc_exp(h[..., 0]) 152 | geo_feat = h[..., 1:] 153 | 154 | # color 155 | d = self.encoder_dir(d) 156 | h = torch.cat([d, geo_feat], dim=-1) 157 | for l in range(self.num_layers_color): 158 | h = self.color_net[l](h) 159 | if l != self.num_layers_color - 1: 160 | h = F.relu(h, inplace=True) 161 | 162 | # sigmoid activation for rgb 163 | rgbs = torch.sigmoid(h) 164 | 165 | return sigma, rgbs, deform 166 | 167 | def density(self, x, t): 168 | # x: [N, 3], in [-bound, bound] 169 | # t: [1, 1], in [0, 1] 170 | 171 | results = {} 172 | 173 | # deformation 174 | enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C] 175 | enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] 176 | if enc_t.shape[0] == 1: 177 | enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C'] 178 | 179 | deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C'] 180 | for l in range(self.num_layers_deform): 181 | deform = self.deform_net[l](deform) 182 | if l != self.num_layers_deform - 1: 183 | deform = F.relu(deform, inplace=True) 184 | 185 | x = x + deform 186 | results['deform'] = deform 187 | 188 | # sigma 189 | x = self.encoder(x, bound=self.bound) 190 | h = torch.cat([x, enc_ori_x, enc_t], dim=1) 191 | for l in range(self.num_layers): 192 | h = self.sigma_net[l](h) 193 | if l != self.num_layers - 1: 194 | h = F.relu(h, inplace=True) 195 | 196 | #sigma = F.relu(h[..., 0]) 197 | sigma = trunc_exp(h[..., 0]) 198 | geo_feat = h[..., 1:] 199 | 200 | results['sigma'] = sigma 201 | results['geo_feat'] = geo_feat 202 | 203 | return results 204 | 205 | def background(self, x, d): 206 | # x: [N, 2], in [-1, 1] 
207 | 208 | h = self.encoder_bg(x) # [N, C] 209 | d = self.encoder_dir(d) 210 | 211 | h = torch.cat([d, h], dim=-1) 212 | for l in range(self.num_layers_bg): 213 | h = self.bg_net[l](h) 214 | if l != self.num_layers_bg - 1: 215 | h = F.relu(h, inplace=True) 216 | 217 | # sigmoid activation for rgb 218 | rgbs = torch.sigmoid(h) 219 | 220 | return rgbs 221 | 222 | # allow masked inference 223 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 224 | # x: [N, 3] in [-bound, bound] 225 | # t: [1, 1], in [0, 1] 226 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 227 | 228 | if mask is not None: 229 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 230 | # in case of empty mask 231 | if not mask.any(): 232 | return rgbs 233 | x = x[mask] 234 | d = d[mask] 235 | geo_feat = geo_feat[mask] 236 | 237 | d = self.encoder_dir(d) 238 | h = torch.cat([d, geo_feat], dim=-1) 239 | for l in range(self.num_layers_color): 240 | h = self.color_net[l](h) 241 | if l != self.num_layers_color - 1: 242 | h = F.relu(h, inplace=True) 243 | 244 | # sigmoid activation for rgb 245 | h = torch.sigmoid(h) 246 | 247 | if mask is not None: 248 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 249 | else: 250 | rgbs = h 251 | 252 | return rgbs 253 | 254 | # optimizer utils 255 | def get_params(self, lr, lr_net): 256 | 257 | params = [ 258 | {'params': self.encoder.parameters(), 'lr': lr}, 259 | {'params': self.sigma_net.parameters(), 'lr': lr_net}, 260 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 261 | {'params': self.color_net.parameters(), 'lr': lr_net}, 262 | {'params': self.encoder_deform.parameters(), 'lr': lr}, 263 | {'params': self.encoder_time.parameters(), 'lr': lr}, 264 | {'params': self.deform_net.parameters(), 'lr': lr_net}, 265 | ] 266 | if self.bg_radius > 0: 267 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 268 | params.append({'params': self.bg_net.parameters(), 'lr': lr_net}) 269 | 270 | return params 271 | -------------------------------------------------------------------------------- /dnerf/network_basis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | 10 | class NeRFNetwork(NeRFRenderer): 11 | def __init__(self, 12 | encoding="tiledgrid", 13 | encoding_dir="sphere_harmonics", 14 | encoding_time="frequency", 15 | encoding_bg="hashgrid", 16 | num_layers=2, 17 | hidden_dim=64, 18 | geo_feat_dim=32, 19 | num_layers_color=3, 20 | hidden_dim_color=64, 21 | num_layers_bg=2, 22 | hidden_dim_bg=64, 23 | sigma_basis_dim=32, 24 | color_basis_dim=8, 25 | num_layers_basis=5, 26 | hidden_dim_basis=128, 27 | bound=1, 28 | **kwargs, 29 | ): 30 | super().__init__(bound, **kwargs) 31 | 32 | # basis network 33 | self.num_layers_basis = num_layers_basis 34 | self.hidden_dim_basis = hidden_dim_basis 35 | self.sigma_basis_dim = sigma_basis_dim 36 | self.color_basis_dim = color_basis_dim 37 | self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6) 38 | 39 | basis_net = [] 40 | for l in range(num_layers_basis): 41 | if l == 0: 42 | in_dim = self.in_dim_time 43 | else: 44 | in_dim = hidden_dim_basis 45 | 46 | if l == num_layers_basis - 1: 47 | out_dim = self.sigma_basis_dim + self.color_basis_dim 48 | else: 49 | out_dim = hidden_dim_basis 50 | 51 | 
basis_net.append(nn.Linear(in_dim, out_dim, bias=False)) 52 | 53 | self.basis_net = nn.ModuleList(basis_net) 54 | 55 | # sigma network 56 | self.num_layers = num_layers 57 | self.hidden_dim = hidden_dim 58 | self.geo_feat_dim = geo_feat_dim 59 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) 60 | 61 | sigma_net = [] 62 | for l in range(num_layers): 63 | if l == 0: 64 | in_dim = self.in_dim 65 | else: 66 | in_dim = hidden_dim 67 | 68 | if l == num_layers - 1: 69 | out_dim = self.sigma_basis_dim + self.geo_feat_dim # SB sigma + features for color 70 | else: 71 | out_dim = hidden_dim 72 | 73 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 74 | 75 | self.sigma_net = nn.ModuleList(sigma_net) 76 | 77 | # color network 78 | self.num_layers_color = num_layers_color 79 | self.hidden_dim_color = hidden_dim_color 80 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) 81 | 82 | color_net = [] 83 | for l in range(num_layers_color): 84 | if l == 0: 85 | in_dim = self.in_dim_dir + self.geo_feat_dim 86 | else: 87 | in_dim = hidden_dim_color 88 | 89 | if l == num_layers_color - 1: 90 | out_dim = 3 * self.color_basis_dim # 3 * CB rgb 91 | else: 92 | out_dim = hidden_dim_color 93 | 94 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 95 | 96 | self.color_net = nn.ModuleList(color_net) 97 | 98 | # background network 99 | if self.bg_radius > 0: 100 | self.num_layers_bg = num_layers_bg 101 | self.hidden_dim_bg = hidden_dim_bg 102 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid 103 | 104 | bg_net = [] 105 | for l in range(num_layers_bg): 106 | if l == 0: 107 | in_dim = self.in_dim_bg + self.in_dim_dir 108 | else: 109 | in_dim = hidden_dim_bg 110 | 111 | if l == num_layers_bg - 1: 112 | out_dim = 3 # 3 rgb 113 | else: 114 | out_dim = hidden_dim_bg 115 | 116 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) 117 | 118 | self.bg_net = nn.ModuleList(bg_net) 119 | else: 120 | self.bg_net = None 121 | 122 | 123 | def forward(self, x, d, t): 124 | # x: [N, 3], in [-bound, bound] 125 | # d: [N, 3], nomalized in [-1, 1] 126 | # t: [1, 1], in [0, 1] 127 | 128 | # time --> basis 129 | enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] 130 | h = enc_t 131 | for l in range(self.num_layers_basis): 132 | h = self.basis_net[l](h) 133 | if l != self.num_layers_basis - 1: 134 | h = F.relu(h, inplace=True) 135 | 136 | sigma_basis = h[0, :self.sigma_basis_dim] 137 | color_basis = h[0, self.sigma_basis_dim:] 138 | 139 | # sigma 140 | x = self.encoder(x, bound=self.bound) 141 | h = x 142 | for l in range(self.num_layers): 143 | h = self.sigma_net[l](h) 144 | if l != self.num_layers - 1: 145 | h = F.relu(h, inplace=True) 146 | 147 | sigma = trunc_exp(h[..., :self.sigma_basis_dim] @ sigma_basis) 148 | geo_feat = h[..., self.sigma_basis_dim:] 149 | 150 | # color 151 | d = self.encoder_dir(d) 152 | h = torch.cat([d, geo_feat], dim=-1) 153 | for l in range(self.num_layers_color): 154 | h = self.color_net[l](h) 155 | if l != self.num_layers_color - 1: 156 | h = F.relu(h, inplace=True) 157 | 158 | # sigmoid activation for rgb 159 | rgbs = torch.sigmoid(h.view(-1, 3, self.color_basis_dim) @ color_basis) 160 | 161 | return sigma, rgbs, None 162 | 163 | def density(self, x, t): 164 | # x: [N, 3], in [-bound, bound] 165 | # t: [1, 1], in [0, 1] 166 | 167 | results = {} 168 | 169 | # time --> basis 170 | enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] 171 | h = 
enc_t 172 | for l in range(self.num_layers_basis): 173 | h = self.basis_net[l](h) 174 | if l != self.num_layers_basis - 1: 175 | h = F.relu(h, inplace=True) 176 | 177 | sigma_basis = h[0, :self.sigma_basis_dim] 178 | color_basis = h[0, self.sigma_basis_dim:] 179 | 180 | # sigma 181 | x = self.encoder(x, bound=self.bound) 182 | h = x 183 | for l in range(self.num_layers): 184 | h = self.sigma_net[l](h) 185 | if l != self.num_layers - 1: 186 | h = F.relu(h, inplace=True) 187 | 188 | sigma = trunc_exp(h[..., :self.sigma_basis_dim] @ sigma_basis) 189 | geo_feat = h[..., self.sigma_basis_dim:] 190 | 191 | results['sigma'] = sigma 192 | results['geo_feat'] = geo_feat 193 | # results['color_basis'] = color_basis 194 | 195 | return results 196 | 197 | def background(self, x, d): 198 | # x: [N, 2], in [-1, 1] 199 | 200 | h = self.encoder_bg(x) # [N, C] 201 | d = self.encoder_dir(d) 202 | 203 | h = torch.cat([d, h], dim=-1) 204 | for l in range(self.num_layers_bg): 205 | h = self.bg_net[l](h) 206 | if l != self.num_layers_bg - 1: 207 | h = F.relu(h, inplace=True) 208 | 209 | # sigmoid activation for rgb 210 | rgbs = torch.sigmoid(h) 211 | 212 | return rgbs 213 | 214 | # TODO: non cuda-ray mode is broken for now... (how to pass color_basis to self.color()) 215 | # # allow masked inference 216 | # def color(self, x, d, mask=None, geo_feat=None, **kwargs): 217 | # # x: [N, 3] in [-bound, bound] 218 | # # t: [1, 1], in [0, 1] 219 | # # mask: [N,], bool, indicates where we actually needs to compute rgb. 220 | 221 | # if mask is not None: 222 | # rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 223 | # # in case of empty mask 224 | # if not mask.any(): 225 | # return rgbs 226 | # x = x[mask] 227 | # d = d[mask] 228 | # geo_feat = geo_feat[mask] 229 | 230 | # d = self.encoder_dir(d) 231 | # h = torch.cat([d, geo_feat], dim=-1) 232 | # for l in range(self.num_layers_color): 233 | # h = self.color_net[l](h) 234 | # if l != self.num_layers_color - 1: 235 | # h = F.relu(h, inplace=True) 236 | 237 | # # sigmoid activation for rgb 238 | # h = torch.sigmoid(h) 239 | 240 | # if mask is not None: 241 | # rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 242 | # else: 243 | # rgbs = h 244 | 245 | # return rgbs 246 | 247 | # optimizer utils 248 | def get_params(self, lr, lr_net): 249 | 250 | params = [ 251 | {'params': self.encoder.parameters(), 'lr': lr}, 252 | {'params': self.sigma_net.parameters(), 'lr': lr_net}, 253 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 254 | {'params': self.color_net.parameters(), 'lr': lr_net}, 255 | {'params': self.encoder_time.parameters(), 'lr': lr}, 256 | {'params': self.basis_net.parameters(), 'lr': lr_net}, 257 | ] 258 | if self.bg_radius > 0: 259 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 260 | params.append({'params': self.bg_net.parameters(), 'lr': lr_net}) 261 | 262 | return params 263 | -------------------------------------------------------------------------------- /dnerf/network_hyper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | 10 | class NeRFNetwork(NeRFRenderer): 11 | def __init__(self, 12 | encoding="tiledgrid", 13 | encoding_dir="sphere_harmonics", 14 | encoding_time="frequency", 15 | encoding_bg="hashgrid", 16 | num_layers=2, 17 | hidden_dim=64, 18 | geo_feat_dim=32, 19 
| num_layers_color=3, 20 | hidden_dim_color=64, 21 | num_layers_bg=2, 22 | hidden_dim_bg=64, 23 | num_layers_ambient=5, 24 | hidden_dim_ambient=128, 25 | ambient_dim=1, 26 | bound=1, 27 | **kwargs, 28 | ): 29 | super().__init__(bound, **kwargs) 30 | 31 | # ambient network 32 | self.num_layers_ambient = num_layers_ambient 33 | self.hidden_dim_ambient = hidden_dim_ambient 34 | self.ambient_dim = ambient_dim 35 | self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6) 36 | 37 | ambient_net = [] 38 | for l in range(num_layers_ambient): 39 | if l == 0: 40 | in_dim = self.in_dim_time 41 | else: 42 | in_dim = hidden_dim_ambient 43 | 44 | if l == num_layers_ambient - 1: 45 | out_dim = self.ambient_dim 46 | else: 47 | out_dim = hidden_dim_ambient 48 | 49 | ambient_net.append(nn.Linear(in_dim, out_dim, bias=False)) 50 | 51 | self.ambient_net = nn.ModuleList(ambient_net) 52 | 53 | # sigma network 54 | self.num_layers = num_layers 55 | self.hidden_dim = hidden_dim 56 | self.geo_feat_dim = geo_feat_dim 57 | self.encoder, self.in_dim = get_encoder(encoding, input_dim=3+self.ambient_dim, desired_resolution=2048 * bound) 58 | 59 | sigma_net = [] 60 | for l in range(num_layers): 61 | if l == 0: 62 | in_dim = self.in_dim 63 | else: 64 | in_dim = hidden_dim 65 | 66 | if l == num_layers - 1: 67 | out_dim = 1 + self.geo_feat_dim # 1 sigma + features for color 68 | else: 69 | out_dim = hidden_dim 70 | 71 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 72 | 73 | self.sigma_net = nn.ModuleList(sigma_net) 74 | 75 | # color network 76 | self.num_layers_color = num_layers_color 77 | self.hidden_dim_color = hidden_dim_color 78 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) 79 | 80 | color_net = [] 81 | for l in range(num_layers_color): 82 | if l == 0: 83 | in_dim = self.in_dim_dir + self.geo_feat_dim 84 | else: 85 | in_dim = hidden_dim_color 86 | 87 | if l == num_layers_color - 1: 88 | out_dim = 3 # 3 rgb 89 | else: 90 | out_dim = hidden_dim_color 91 | 92 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 93 | 94 | self.color_net = nn.ModuleList(color_net) 95 | 96 | # background network 97 | if self.bg_radius > 0: 98 | self.num_layers_bg = num_layers_bg 99 | self.hidden_dim_bg = hidden_dim_bg 100 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid 101 | 102 | bg_net = [] 103 | for l in range(num_layers_bg): 104 | if l == 0: 105 | in_dim = self.in_dim_bg + self.in_dim_dir 106 | else: 107 | in_dim = hidden_dim_bg 108 | 109 | if l == num_layers_bg - 1: 110 | out_dim = 3 # 3 rgb 111 | else: 112 | out_dim = hidden_dim_bg 113 | 114 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) 115 | 116 | self.bg_net = nn.ModuleList(bg_net) 117 | else: 118 | self.bg_net = None 119 | 120 | 121 | def forward(self, x, d, t): 122 | # x: [N, 3], in [-bound, bound] 123 | # d: [N, 3], nomalized in [-1, 1] 124 | # t: [1, 1], in [0, 1] 125 | 126 | # time --> ambient 127 | enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] 128 | # if enc_t.shape[0] == 1: 129 | # enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C'] 130 | ambient = enc_t 131 | for l in range(self.num_layers_ambient): 132 | ambient = self.ambient_net[l](ambient) 133 | if l != self.num_layers_ambient - 1: 134 | ambient = F.relu(ambient, inplace=True) 135 | 136 | ambient = F.tanh(ambient) * self.bound 137 | x = torch.cat([x, ambient.repeat(x.shape[0], 1)], dim=1) 138 | 139 | # sigma 140 | x = 
self.encoder(x, bound=self.bound) 141 | h = x 142 | for l in range(self.num_layers): 143 | h = self.sigma_net[l](h) 144 | if l != self.num_layers - 1: 145 | h = F.relu(h, inplace=True) 146 | 147 | sigma = trunc_exp(h[..., 0]) 148 | geo_feat = h[..., 1:] 149 | 150 | # color 151 | d = self.encoder_dir(d) 152 | h = torch.cat([d, geo_feat], dim=-1) 153 | for l in range(self.num_layers_color): 154 | h = self.color_net[l](h) 155 | if l != self.num_layers_color - 1: 156 | h = F.relu(h, inplace=True) 157 | 158 | # sigmoid activation for rgb 159 | rgbs = torch.sigmoid(h) 160 | 161 | return sigma, rgbs, None 162 | 163 | def density(self, x, t): 164 | # x: [N, 3], in [-bound, bound] 165 | # t: [1, 1], in [0, 1] 166 | 167 | results = {} 168 | 169 | # time --> ambient 170 | enc_t = self.encoder_time(t) # [1, 1] --> [1, C'] 171 | ambient = enc_t 172 | for l in range(self.num_layers_ambient): 173 | ambient = self.ambient_net[l](ambient) 174 | if l != self.num_layers_ambient - 1: 175 | ambient = F.relu(ambient, inplace=True) 176 | 177 | ambient = F.tanh(ambient) * self.bound 178 | x = torch.cat([x, ambient.repeat(x.shape[0], 1)], dim=1) 179 | 180 | # sigma 181 | x = self.encoder(x, bound=self.bound) 182 | h = x 183 | for l in range(self.num_layers): 184 | h = self.sigma_net[l](h) 185 | if l != self.num_layers - 1: 186 | h = F.relu(h, inplace=True) 187 | 188 | sigma = trunc_exp(h[..., 0]) 189 | geo_feat = h[..., 1:] 190 | 191 | results['sigma'] = sigma 192 | results['geo_feat'] = geo_feat 193 | 194 | return results 195 | 196 | def background(self, x, d): 197 | # x: [N, 2], in [-1, 1] 198 | 199 | h = self.encoder_bg(x) # [N, C] 200 | d = self.encoder_dir(d) 201 | 202 | h = torch.cat([d, h], dim=-1) 203 | for l in range(self.num_layers_bg): 204 | h = self.bg_net[l](h) 205 | if l != self.num_layers_bg - 1: 206 | h = F.relu(h, inplace=True) 207 | 208 | # sigmoid activation for rgb 209 | rgbs = torch.sigmoid(h) 210 | 211 | return rgbs 212 | 213 | 214 | # allow masked inference 215 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 216 | # x: [N, 3] in [-bound, bound] 217 | # t: [1, 1], in [0, 1] 218 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 
219 | 220 | if mask is not None: 221 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 222 | # in case of empty mask 223 | if not mask.any(): 224 | return rgbs 225 | x = x[mask] 226 | d = d[mask] 227 | geo_feat = geo_feat[mask] 228 | 229 | d = self.encoder_dir(d) 230 | h = torch.cat([d, geo_feat], dim=-1) 231 | for l in range(self.num_layers_color): 232 | h = self.color_net[l](h) 233 | if l != self.num_layers_color - 1: 234 | h = F.relu(h, inplace=True) 235 | 236 | # sigmoid activation for rgb 237 | h = torch.sigmoid(h) 238 | 239 | if mask is not None: 240 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 241 | else: 242 | rgbs = h 243 | 244 | return rgbs 245 | 246 | # optimizer utils 247 | def get_params(self, lr, lr_net): 248 | 249 | params = [ 250 | {'params': self.encoder.parameters(), 'lr': lr}, 251 | {'params': self.sigma_net.parameters(), 'lr': lr_net}, 252 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 253 | {'params': self.color_net.parameters(), 'lr': lr_net}, 254 | {'params': self.encoder_time.parameters(), 'lr': lr}, 255 | {'params': self.ambient_net.parameters(), 'lr': lr_net}, 256 | ] 257 | if self.bg_radius > 0: 258 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 259 | params.append({'params': self.bg_net.parameters(), 'lr': lr_net}) 260 | 261 | return params 262 | -------------------------------------------------------------------------------- /dnerf/utils.py: -------------------------------------------------------------------------------- 1 | from nerf.utils import * 2 | from nerf.utils import Trainer as _Trainer 3 | 4 | 5 | class Trainer(_Trainer): 6 | def __init__(self, 7 | name, # name of this experiment 8 | opt, # extra conf 9 | model, # network 10 | criterion=None, # loss function, if None, assume inline implementation in train_step 11 | optimizer=None, # optimizer 12 | ema_decay=None, # if use EMA, set the decay 13 | lr_scheduler=None, # scheduler 14 | metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. 15 | local_rank=0, # which GPU am I 16 | world_size=1, # total num of GPUs 17 | device=None, # device to use, usually setting to None is OK. 
(auto choose device) 18 | mute=False, # whether to mute all print 19 | fp16=False, # amp optimize level 20 | eval_interval=1, # eval once every $ epoch 21 | max_keep_ckpt=2, # max num of saved ckpts in disk 22 | workspace='workspace', # workspace to save logs & ckpts 23 | best_mode='min', # the smaller/larger result, the better 24 | use_loss_as_metric=True, # use loss as the first metric 25 | report_metric_at_train=False, # also report metrics at training 26 | use_checkpoint="latest", # which ckpt to use at init time 27 | use_tensorboardX=True, # whether to use tensorboard for logging 28 | scheduler_update_every_step=False, # whether to call scheduler.step() after every train step 29 | ): 30 | 31 | self.optimizer_fn = optimizer 32 | self.lr_scheduler_fn = lr_scheduler 33 | 34 | super().__init__(name, opt, model, criterion, optimizer, ema_decay, lr_scheduler, metrics, local_rank, world_size, device, mute, fp16, eval_interval, max_keep_ckpt, workspace, best_mode, use_loss_as_metric, report_metric_at_train, use_checkpoint, use_tensorboardX, scheduler_update_every_step) 35 | 36 | ### ------------------------------ 37 | 38 | def train_step(self, data): 39 | 40 | rays_o = data['rays_o'] # [B, N, 3] 41 | rays_d = data['rays_d'] # [B, N, 3] 42 | time = data['time'] # [B, 1] 43 | 44 | # if there is no gt image, we train with CLIP loss. 45 | if 'images' not in data: 46 | 47 | B, N = rays_o.shape[:2] 48 | H, W = data['H'], data['W'] 49 | 50 | # currently fix white bg, MUST force all rays! 51 | outputs = self.model.render(rays_o, rays_d, time, staged=False, bg_color=None, perturb=True, force_all_rays=True, **vars(self.opt)) 52 | pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() 53 | 54 | # [debug] uncomment to plot the images used in train_step 55 | #torch_vis_2d(pred_rgb[0]) 56 | 57 | loss = self.clip_loss(pred_rgb) 58 | 59 | return pred_rgb, None, loss 60 | 61 | images = data['images'] # [B, N, 3/4] 62 | 63 | B, N, C = images.shape 64 | 65 | if self.opt.color_space == 'linear': 66 | images[..., :3] = srgb_to_linear(images[..., :3]) 67 | 68 | if C == 3 or self.model.bg_radius > 0: 69 | bg_color = 1 70 | # train with random background color if not using a bg model and has alpha channel. 71 | else: 72 | #bg_color = torch.ones(3, device=self.device) # [3], fixed white background 73 | #bg_color = torch.rand(3, device=self.device) # [3], frame-wise random. 74 | bg_color = torch.rand_like(images[..., :3]) # [N, 3], pixel-wise random. 75 | 76 | if C == 4: 77 | gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) 78 | else: 79 | gt_rgb = images 80 | 81 | outputs = self.model.render(rays_o, rays_d, time, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **vars(self.opt)) 82 | 83 | pred_rgb = outputs['image'] 84 | 85 | loss = self.criterion(pred_rgb, gt_rgb).mean(-1) # [B, N, 3] --> [B, N] 86 | 87 | # special case for CCNeRF's rank-residual training 88 | if len(loss.shape) == 3: # [K, B, N] 89 | loss = loss.mean(0) 90 | 91 | # update error_map 92 | if self.error_map is not None: 93 | index = data['index'] # [B] 94 | inds = data['inds_coarse'] # [B, N] 95 | 96 | # take out, this is an advanced indexing and the copy is unavoidable. 
97 | error_map = self.error_map[index] # [B, H * W] 98 | 99 | # [debug] uncomment to save and visualize error map 100 | # if self.global_step % 1001 == 0: 101 | # tmp = error_map[0].view(128, 128).cpu().numpy() 102 | # print(f'[write error map] {tmp.shape} {tmp.min()} ~ {tmp.max()}') 103 | # tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min()) 104 | # cv2.imwrite(os.path.join(self.workspace, f'{self.global_step}.jpg'), (tmp * 255).astype(np.uint8)) 105 | 106 | error = loss.detach().to(error_map.device) # [B, N], already in [0, 1] 107 | 108 | # ema update 109 | ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error 110 | error_map.scatter_(1, inds, ema_error) 111 | 112 | # put back 113 | self.error_map[index] = error_map 114 | 115 | loss = loss.mean() 116 | 117 | # deform regularization 118 | if 'deform' in outputs and outputs['deform'] is not None: 119 | loss = loss + 1e-3 * outputs['deform'].abs().mean() 120 | 121 | return pred_rgb, gt_rgb, loss 122 | 123 | def eval_step(self, data): 124 | 125 | rays_o = data['rays_o'] # [B, N, 3] 126 | rays_d = data['rays_d'] # [B, N, 3] 127 | time = data['time'] # [B, 1] 128 | images = data['images'] # [B, H, W, 3/4] 129 | B, H, W, C = images.shape 130 | 131 | if self.opt.color_space == 'linear': 132 | images[..., :3] = srgb_to_linear(images[..., :3]) 133 | 134 | # eval with fixed background color 135 | bg_color = 1 136 | if C == 4: 137 | gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) 138 | else: 139 | gt_rgb = images 140 | 141 | outputs = self.model.render(rays_o, rays_d, time, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt)) 142 | 143 | pred_rgb = outputs['image'].reshape(B, H, W, 3) 144 | pred_depth = outputs['depth'].reshape(B, H, W) 145 | 146 | loss = self.criterion(pred_rgb, gt_rgb).mean() 147 | 148 | return pred_rgb, pred_depth, gt_rgb, loss 149 | 150 | # moved out bg_color and perturb for more flexible control... 151 | def test_step(self, data, bg_color=None, perturb=False): 152 | 153 | rays_o = data['rays_o'] # [B, N, 3] 154 | rays_d = data['rays_d'] # [B, N, 3] 155 | time = data['time'] # [B, 1] 156 | H, W = data['H'], data['W'] 157 | 158 | if bg_color is not None: 159 | bg_color = bg_color.to(self.device) 160 | 161 | outputs = self.model.render(rays_o, rays_d, time, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt)) 162 | 163 | pred_rgb = outputs['image'].reshape(-1, H, W, 3) 164 | pred_depth = outputs['depth'].reshape(-1, H, W) 165 | 166 | return pred_rgb, pred_depth 167 | 168 | # [GUI] test on a single image 169 | def test_gui(self, pose, intrinsics, W, H, time=0, bg_color=None, spp=1, downscale=1): 170 | 171 | # render resolution (may need downscale to for better frame rate) 172 | rH = int(H * downscale) 173 | rW = int(W * downscale) 174 | intrinsics = intrinsics * downscale 175 | 176 | pose = torch.from_numpy(pose).unsqueeze(0).to(self.device) 177 | 178 | rays = get_rays(pose, intrinsics, rH, rW, -1) 179 | 180 | data = { 181 | 'time': torch.FloatTensor([[time]]).to(self.device), # from scalar to [1, 1] tensor. 182 | 'rays_o': rays['rays_o'], 183 | 'rays_d': rays['rays_d'], 184 | 'H': rH, 185 | 'W': rW, 186 | } 187 | 188 | self.model.eval() 189 | 190 | if self.ema is not None: 191 | self.ema.store() 192 | self.ema.copy_to() 193 | 194 | with torch.no_grad(): 195 | with torch.cuda.amp.autocast(enabled=self.fp16): 196 | # here spp is used as perturb random seed! 
197 | preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp) 198 | 199 | if self.ema is not None: 200 | self.ema.restore() 201 | 202 | # interpolation to the original resolution 203 | if downscale != 1: 204 | # TODO: have to permute twice with torch... 205 | preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous() 206 | preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1) 207 | 208 | if self.opt.color_space == 'linear': 209 | preds = linear_to_srgb(preds) 210 | 211 | pred = preds[0].detach().cpu().numpy() 212 | pred_depth = preds_depth[0].detach().cpu().numpy() 213 | 214 | outputs = { 215 | 'image': pred, 216 | 'depth': pred_depth, 217 | } 218 | 219 | return outputs 220 | 221 | def save_mesh(self, time, save_path=None, resolution=256, threshold=10): 222 | # time: scalar in [0, 1] 223 | time = torch.FloatTensor([[time]]).to(self.device) 224 | 225 | if save_path is None: 226 | save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply') 227 | 228 | self.log(f"==> Saving mesh to {save_path}") 229 | 230 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 231 | 232 | def query_func(pts): 233 | with torch.no_grad(): 234 | with torch.cuda.amp.autocast(enabled=self.fp16): 235 | sigma = self.model.density(pts.to(self.device), time)['sigma'] 236 | return sigma 237 | 238 | vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func) 239 | 240 | mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... 241 | mesh.export(save_path) 242 | 243 | self.log(f"==> Finished saving mesh.") -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FreqEncoder(nn.Module): 6 | def __init__(self, input_dim, max_freq_log2, N_freqs, 7 | log_sampling=True, include_input=True, 8 | periodic_fns=(torch.sin, torch.cos)): 9 | 10 | super().__init__() 11 | 12 | self.input_dim = input_dim 13 | self.include_input = include_input 14 | self.periodic_fns = periodic_fns 15 | 16 | self.output_dim = 0 17 | if self.include_input: 18 | self.output_dim += self.input_dim 19 | 20 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 21 | 22 | if log_sampling: 23 | self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs) 24 | else: 25 | self.freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq_log2, N_freqs) 26 | 27 | self.freq_bands = self.freq_bands.numpy().tolist() 28 | 29 | def forward(self, input, **kwargs): 30 | 31 | out = [] 32 | if self.include_input: 33 | out.append(input) 34 | 35 | for i in range(len(self.freq_bands)): 36 | freq = self.freq_bands[i] 37 | for p_fn in self.periodic_fns: 38 | out.append(p_fn(input * freq)) 39 | 40 | out = torch.cat(out, dim=-1) 41 | 42 | 43 | return out 44 | 45 | def get_encoder(encoding, input_dim=3, 46 | multires=6, 47 | degree=4, 48 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, 49 | **kwargs): 50 | 51 | if encoding == 'None': 52 | return lambda x, **kwargs: x, input_dim 53 | 54 | elif encoding == 'frequency': 55 | #encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 56 | from freqencoder import FreqEncoder 57 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 58 | 59 | elif encoding == 'sphere_harmonics': 60 | from shencoder import SHEncoder 61 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 62 | 63 | elif encoding == 'hashgrid': 64 | from gridencoder import GridEncoder 65 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) 66 | 67 | elif encoding == 'tiledgrid': 68 | from gridencoder import GridEncoder 69 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) 70 | 71 | elif encoding == 'ash': 72 | from ashencoder import AshEncoder 73 | encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) 74 | 75 | else: 76 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') 77 | 78 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /ffmlp/__init__.py: -------------------------------------------------------------------------------- 1 | from .ffmlp import FFMLP -------------------------------------------------------------------------------- /ffmlp/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '--expt-extended-lambda', '--expt-relaxed-constexpr', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | nvcc_flags += ['-Xcompiler=-mf16c', '-Xcompiler=-Wno-float-conversion', '-Xcompiler=-fno-strict-aliasing'] 14 | c_flags = ['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 
27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | _backend = load(name='_ffmlp', 34 | extra_cflags=c_flags, 35 | extra_cuda_cflags=nvcc_flags, 36 | extra_include_paths=[ 37 | os.path.join(_src_path, 'dependencies/cutlass/include'), 38 | os.path.join(_src_path, 'dependencies/cutlass/tools/util/include'), 39 | ], 40 | sources=[os.path.join(_src_path, 'src', f) for f in [ 41 | 'ffmlp.cu', 42 | 'bindings.cpp', 43 | ]], 44 | ) 45 | 46 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /ffmlp/ffmlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | import atexit 9 | 10 | try: 11 | import _ffmlp as _backend 12 | except ImportError: 13 | from .backend import _backend 14 | 15 | class _ffmlp_forward(Function): 16 | 17 | @staticmethod 18 | @custom_fwd(cast_inputs=torch.half) 19 | def forward(ctx, inputs, weights, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, inference=False, calc_grad_inputs=False): 20 | 21 | B = inputs.shape[0] 22 | 23 | inputs = inputs.contiguous() 24 | weights = weights.contiguous() 25 | 26 | # print('[inputs]', torch.any(torch.isnan(inputs)), inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 27 | # print('[weights]', torch.any(torch.isnan(weights)), weights.shape, weights.dtype, weights.min().item(), weights.max().item()) 28 | 29 | # allocate output 30 | outputs = torch.empty(B, output_dim, device=inputs.device, dtype=inputs.dtype) 31 | 32 | if not inference: 33 | forward_buffer = torch.empty(num_layers, B, hidden_dim, device=inputs.device, dtype=inputs.dtype) 34 | _backend.ffmlp_forward(inputs, weights, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, forward_buffer, outputs) 35 | ctx.save_for_backward(inputs, weights, outputs, forward_buffer) 36 | ctx.dims = (input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs) 37 | 38 | # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 39 | # print('[forward_buffer]', torch.any(torch.isnan(forward_buffer)), forward_buffer.shape, forward_buffer.dtype, forward_buffer.min().item(), forward_buffer.max().item()) 40 | else: 41 | inference_buffer = torch.empty(B, hidden_dim, device=inputs.device, dtype=inputs.dtype) 42 | _backend.ffmlp_inference(inputs, weights, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, inference_buffer, outputs) 43 | 44 | # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 45 | # print('[inference_buffer]', torch.any(torch.isnan(inference_buffer)), inference_buffer.shape, inference_buffer.dtype, inference_buffer.min().item(), inference_buffer.max().item()) 46 | 47 | 48 | return outputs 49 | 50 | @staticmethod 51 | @custom_bwd 52 | def backward(ctx, grad): 53 | # grad: [B, output_dim] 54 | 55 | B = grad.shape[0] 56 | 57 | grad = grad.contiguous() 58 | 59 | # print('[grad]', torch.any(torch.isnan(grad)), grad.shape, grad.dtype,
grad.min().item(), grad.max().item()) 60 | # print(grad) 61 | 62 | inputs, weights, outputs, forward_buffer = ctx.saved_tensors 63 | 64 | input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs = ctx.dims 65 | 66 | # allocate outputs 67 | if calc_grad_inputs: 68 | grad_inputs = torch.zeros_like(inputs) 69 | else: 70 | grad_inputs = torch.zeros(1, device=grad.device, dtype=grad.dtype) # dummy 71 | 72 | grad_weights = torch.zeros_like(weights) 73 | backward_buffer = torch.zeros(num_layers, B, hidden_dim, device=grad.device, dtype=grad.dtype) 74 | 75 | _backend.ffmlp_backward(grad, inputs, weights, forward_buffer, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs, backward_buffer, grad_inputs, grad_weights) 76 | 77 | # print('[grad_inputs]', grad_inputs.shape, grad_inputs.dtype, grad_inputs.min().item(), grad_inputs.max().item()) 78 | # print('[grad_weights]', grad_weights.shape, grad_weights.dtype, grad_weights.min().item(), grad_weights.max().item()) 79 | # print('[backward_buffer]', backward_buffer.shape, backward_buffer.dtype, backward_buffer.min().item(), backward_buffer.max().item()) 80 | if calc_grad_inputs: 81 | return grad_inputs, grad_weights, None, None, None, None, None, None, None, None 82 | else: 83 | return None, grad_weights, None, None, None, None, None, None, None, None 84 | 85 | 86 | ffmlp_forward = _ffmlp_forward.apply 87 | 88 | 89 | def convert_activation(act): 90 | if act == 'relu': return 0 91 | elif act == 'exponential': return 1 92 | elif act == 'sine': return 2 93 | elif act == 'sigmoid': return 3 94 | elif act == 'squareplus': return 4 95 | elif act == 'softplus': return 5 96 | else: return 6 97 | 98 | 99 | class FFMLP(nn.Module): 100 | def __init__(self, input_dim, output_dim, hidden_dim, num_layers, activation='relu'): 101 | super().__init__() 102 | 103 | self.input_dim = input_dim 104 | self.output_dim = output_dim 105 | self.hidden_dim = hidden_dim 106 | self.num_layers = num_layers 107 | self.activation = convert_activation(activation) 108 | self.output_activation = convert_activation('none') # not supported currently 109 | 110 | self.tensorcore_width = 16 111 | 112 | assert hidden_dim in [16, 32, 64, 128, 256], f"FFMLP only support hidden_dim in [16, 32, 64, 128, 256], but got {hidden_dim}" 113 | assert input_dim > 0 and input_dim % 16 == 0, f"FFMLP input_dim should be 16 * m (m > 0), but got {input_dim}" 114 | assert output_dim <= 16, f"FFMLP current only supports output dim <= 16, but got {output_dim}" 115 | assert num_layers >= 2, f"FFMLP num_layers should be larger than 2 (3 matmuls), but got {num_layers}" 116 | 117 | # pad output 118 | self.padded_output_dim = int(math.ceil(output_dim / 16)) * 16 119 | 120 | # parameters (continuous in memory) 121 | self.num_parameters = hidden_dim * (input_dim + hidden_dim * (num_layers - 1) + self.padded_output_dim) 122 | self.weights = nn.Parameter(torch.zeros(self.num_parameters)) 123 | self.reset_parameters() 124 | 125 | # allocate streams 126 | _backend.allocate_splitk(self.num_layers + 1) 127 | 128 | # register destructor 129 | #atexit.register(self.cleanup) # how to correctly clean? 
this gives CUDA Error: cudaEventDestroy(events[i]) failed with error context is destroyed 130 | 131 | 132 | def cleanup(self): 133 | # destroy streams 134 | _backend.free_splitk() 135 | 136 | 137 | def __repr__(self): 138 | return f"FFMLP: input_dim={self.input_dim} output_dim={self.output_dim} hidden_dim={self.hidden_dim} num_layers={self.num_layers} activation={self.activation}" 139 | 140 | 141 | def reset_parameters(self): 142 | torch.manual_seed(42) 143 | std = math.sqrt(3 / self.hidden_dim) 144 | self.weights.data.uniform_(-std, std) 145 | 146 | 147 | def forward(self, inputs): 148 | # inputs: [B, input_dim] 149 | # return: [B, outupt_dim] 150 | 151 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item(), inputs.requires_grad) 152 | 153 | B, C = inputs.shape 154 | #assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}." 155 | 156 | # pad input 157 | pad = 128 - (B % 128) 158 | if pad > 0: 159 | inputs = torch.cat([inputs, torch.zeros(pad, C, dtype=inputs.dtype, device=inputs.device)], dim=0) 160 | 161 | outputs = ffmlp_forward(inputs, self.weights, self.input_dim, self.padded_output_dim, self.hidden_dim, self.num_layers, self.activation, self.output_activation, not self.training, inputs.requires_grad) 162 | 163 | # unpad output 164 | if B != outputs.shape[0] or self.padded_output_dim != self.output_dim: 165 | outputs = outputs[:B, :self.output_dim] 166 | 167 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 168 | 169 | return outputs -------------------------------------------------------------------------------- /ffmlp/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '--expt-extended-lambda', '--expt-relaxed-constexpr', 10 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 11 | ] 12 | 13 | if os.name == "posix": 14 | nvcc_flags += ['-Xcompiler=-mf16c', '-Xcompiler=-Wno-float-conversion', '-Xcompiler=-fno-strict-aliasing'] 15 | c_flags = ['-O3', '-std=c++14'] 16 | elif os.name == "nt": 17 | c_flags = ['/O2', '/std:c++17'] 18 | 19 | # find cl.exe 20 | def find_cl_path(): 21 | import glob 22 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 23 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 24 | if paths: 25 | return paths[0] 26 | 27 | # If cl.exe is not on path, try to find it. 
28 | if os.system("where cl.exe >nul 2>nul") != 0: 29 | cl_path = find_cl_path() 30 | if cl_path is None: 31 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 32 | os.environ["PATH"] += ";" + cl_path 33 | 34 | setup( 35 | name='ffmlp', # package name, import this to use python API 36 | ext_modules=[ 37 | CUDAExtension( 38 | name='_ffmlp', # extension name, import this to use CUDA API 39 | sources=[os.path.join(_src_path, 'src', f) for f in [ 40 | 'ffmlp.cu', 41 | 'bindings.cpp', 42 | ]], 43 | extra_compile_args={ 44 | 'cxx': c_flags, 45 | 'nvcc': nvcc_flags, 46 | }, 47 | include_dirs=[ 48 | os.path.join(_src_path, 'dependencies/cutlass/include'), 49 | os.path.join(_src_path, 'dependencies/cutlass/tools/util/include'), 50 | ], 51 | ), 52 | ], 53 | cmdclass={ 54 | 'build_ext': BuildExtension, 55 | } 56 | ) -------------------------------------------------------------------------------- /ffmlp/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "ffmlp.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("ffmlp_forward", &ffmlp_forward, "ffmlp_forward (CUDA)"); 7 | m.def("ffmlp_inference", &ffmlp_inference, "ffmlp_inference (CUDA)"); 8 | m.def("ffmlp_backward", &ffmlp_backward, "ffmlp_backward (CUDA)"); 9 | m.def("allocate_splitk", &allocate_splitk, "allocate_splitk (CUDA)"); 10 | m.def("free_splitk", &free_splitk, "free_splitk (CUDA)"); 11 | } -------------------------------------------------------------------------------- /ffmlp/src/ffmlp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | 7 | // activation: should have been enum, here we just use int. 
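// the integer codes appear to follow ffmlp/ffmlp.py::convert_activation:
//   0 = relu, 1 = exponential, 2 = sine, 3 = sigmoid, 4 = squareplus, 5 = softplus, 6 = none (fallback)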
8 | void ffmlp_forward(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor forward_buffer, at::Tensor outputs); 9 | void ffmlp_inference(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor inference_buffer, at::Tensor outputs); 10 | 11 | void ffmlp_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor weights, const at::Tensor forward_buffer, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation, const uint32_t output_activation, const bool calc_grad_inputs, at::Tensor backward_buffer, at::Tensor grad_inputs, at::Tensor grad_weights); 12 | 13 | void allocate_splitk(size_t size); 14 | void free_splitk(); -------------------------------------------------------------------------------- /freqencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .freq import FreqEncoder -------------------------------------------------------------------------------- /freqencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | '-use_fast_math' 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | _backend = load(name='_freqencoder', 33 | extra_cflags=c_flags, 34 | extra_cuda_cflags=nvcc_flags, 35 | sources=[os.path.join(_src_path, 'src', f) for f in [ 36 | 'freqencoder.cu', 37 | 'bindings.cpp', 38 | ]], 39 | ) 40 | 41 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /freqencoder/freq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _freqencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | class _freq_encoder(Function): 16 | @staticmethod 17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 18 | def forward(ctx, inputs, degree, output_dim): 19 | # inputs: [B, input_dim], float 20 | # RETURN: [B, F], float 21 | 22 | if not inputs.is_cuda: inputs = inputs.cuda() 23 | inputs = inputs.contiguous() 24 | 25 | B, input_dim = inputs.shape # batch size, coord dim 26 | 27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 28 | 29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 30 | 31 | ctx.save_for_backward(inputs, outputs) 32 | ctx.dims = [B, input_dim, degree, output_dim] 33 | 34 | return outputs 35 | 36 | @staticmethod 37 | #@once_differentiable 38 | @custom_bwd 39 | def backward(ctx, grad): 40 | # grad: [B, C * C] 41 | 42 | grad = grad.contiguous() 43 | inputs, outputs = ctx.saved_tensors 44 | B, input_dim, degree, output_dim = ctx.dims 45 | 46 | grad_inputs = torch.zeros_like(inputs) 47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 48 | 49 | return grad_inputs, None, None 50 | 51 | 52 | freq_encode = _freq_encoder.apply 53 | 54 | 55 | class FreqEncoder(nn.Module): 56 | def __init__(self, input_dim=3, degree=4): 57 | super().__init__() 58 | 59 | self.input_dim = input_dim 60 | self.degree = degree 61 | self.output_dim = input_dim + input_dim * 2 * degree 62 | 63 | def __repr__(self): 64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" 65 | 66 | def forward(self, inputs, **kwargs): 67 | # inputs: [..., input_dim] 68 | # return: [..., ] 69 | 70 | prefix_shape = list(inputs.shape[:-1]) 71 | inputs = inputs.reshape(-1, self.input_dim) 72 | 73 | outputs = freq_encode(inputs, self.degree, self.output_dim) 74 | 75 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 76 | 77 | return outputs -------------------------------------------------------------------------------- /freqencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | '-use_fast_math' 11 | ] 12 | 13 | if os.name == "posix": 14 | c_flags = 
['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | setup( 34 | name='freqencoder', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='_freqencoder', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | 'freqencoder.cu', 40 | 'bindings.cpp', 41 | ]], 42 | extra_compile_args={ 43 | 'cxx': c_flags, 44 | 'nvcc': nvcc_flags, 45 | } 46 | ), 47 | ], 48 | cmdclass={ 49 | 'build_ext': BuildExtension, 50 | } 51 | ) -------------------------------------------------------------------------------- /freqencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "freqencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); 7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 20 | 21 | inline constexpr __device__ float PI() { return 3.141592653589793f; } 22 | 23 | template 24 | __host__ __device__ T div_round_up(T val, T divisor) { 25 | return (val + divisor - 1) / divisor; 26 | } 27 | 28 | // inputs: [B, D] 29 | // outputs: [B, C], C = D + D * deg * 2 30 | __global__ void kernel_freq( 31 | const float * __restrict__ inputs, 32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 33 | float * outputs 34 | ) { 35 | // parallel on per-element 36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 37 | if (t >= B * C) return; 38 | 39 | // get index 40 | const uint32_t b = t / C; 41 | const uint32_t c = t - b * C; // t % C; 42 | 43 | // locate 44 | inputs += b * D; 45 | outputs += t; 46 | 47 | // write self 48 | if (c < D) { 49 | outputs[0] = inputs[c]; 50 | // write freq 51 | } else { 52 | const uint32_t col = c / D - 1; 53 | const uint32_t d = c % D; 54 | const uint32_t freq = col / 2; 55 | const float phase_shift = (col % 2) * (PI() / 2); 56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + 
phase_shift); 57 | } 58 | } 59 | 60 | // grad: [B, C], C = D + D * deg * 2 61 | // outputs: [B, C] 62 | // grad_inputs: [B, D] 63 | __global__ void kernel_freq_backward( 64 | const float * __restrict__ grad, 65 | const float * __restrict__ outputs, 66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 67 | float * grad_inputs 68 | ) { 69 | // parallel on per-element 70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 71 | if (t >= B * D) return; 72 | 73 | const uint32_t b = t / D; 74 | const uint32_t d = t - b * D; // t % D; 75 | 76 | // locate 77 | grad += b * C; 78 | outputs += b * C; 79 | grad_inputs += t; 80 | 81 | // register 82 | float result = grad[d]; 83 | grad += D; 84 | outputs += D; 85 | 86 | for (uint32_t f = 0; f < deg; f++) { 87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); 88 | grad += 2 * D; 89 | outputs += 2 * D; 90 | } 91 | 92 | // write 93 | grad_inputs[0] = result; 94 | } 95 | 96 | 97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { 98 | CHECK_CUDA(inputs); 99 | CHECK_CUDA(outputs); 100 | 101 | CHECK_CONTIGUOUS(inputs); 102 | CHECK_CONTIGUOUS(outputs); 103 | 104 | CHECK_IS_FLOATING(inputs); 105 | CHECK_IS_FLOATING(outputs); 106 | 107 | static constexpr uint32_t N_THREADS = 128; 108 | 109 | kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); 110 | } 111 | 112 | 113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { 114 | CHECK_CUDA(grad); 115 | CHECK_CUDA(outputs); 116 | CHECK_CUDA(grad_inputs); 117 | 118 | CHECK_CONTIGUOUS(grad); 119 | CHECK_CONTIGUOUS(outputs); 120 | CHECK_CONTIGUOUS(grad_inputs); 121 | 122 | CHECK_IS_FLOATING(grad); 123 | CHECK_IS_FLOATING(outputs); 124 | CHECK_IS_FLOATING(grad_inputs); 125 | 126 | static constexpr uint32_t N_THREADS = 128; 127 | 128 | kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); 129 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | 6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); 8 | 9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if 
os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | 'hash': 0, 16 | 'tiled': 1, 17 | } 18 | 19 | _interp_to_id = { 20 | 'linear': 0, 21 | 'smoothstep': 1, 22 | } 23 | 24 | class _grid_encode(Function): 25 | @staticmethod 26 | @custom_fwd 27 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): 28 | # inputs: [B, D], float in [0, 1] 29 | # embeddings: [sO, C], float 30 | # offsets: [L + 1], int 31 | # RETURN: [B, F], float 32 | 33 | inputs = inputs.contiguous() 34 | 35 | B, D = inputs.shape # batch size, coord dim 36 | L = offsets.shape[0] - 1 # level 37 | C = embeddings.shape[1] # embedding dim for each level 38 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 39 | H = base_resolution # base resolution 40 | 41 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 42 | # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
43 | if torch.is_autocast_enabled() and C % 2 == 0: 44 | embeddings = embeddings.to(torch.half) 45 | 46 | # L first, optimize cache for cuda kernel, but needs an extra permute later 47 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 48 | 49 | if calc_grad_inputs: 50 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) 51 | else: 52 | dy_dx = None 53 | 54 | _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) 55 | 56 | # permute back to [B, L * C] 57 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 58 | 59 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 60 | ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] 61 | ctx.align_corners = align_corners 62 | 63 | return outputs 64 | 65 | @staticmethod 66 | #@once_differentiable 67 | @custom_bwd 68 | def backward(ctx, grad): 69 | 70 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 71 | B, D, C, L, S, H, gridtype, interpolation = ctx.dims 72 | align_corners = ctx.align_corners 73 | 74 | # grad: [B, L * C] --> [L, B, C] 75 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 76 | 77 | grad_embeddings = torch.zeros_like(embeddings) 78 | 79 | if dy_dx is not None: 80 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 81 | else: 82 | grad_inputs = None 83 | 84 | _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) 85 | 86 | if dy_dx is not None: 87 | grad_inputs = grad_inputs.to(inputs.dtype) 88 | 89 | return grad_inputs, grad_embeddings, None, None, None, None, None, None, None 90 | 91 | 92 | 93 | grid_encode = _grid_encode.apply 94 | 95 | 96 | class GridEncoder(nn.Module): 97 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): 98 | super().__init__() 99 | 100 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 101 | if desired_resolution is not None: 102 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) 103 | 104 | self.input_dim = input_dim # coord dims, 2 or 3 105 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 106 | self.level_dim = level_dim # encode channels per level 107 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
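        # worked example with the defaults passed in by encoding.get_encoder (base_resolution=16,
        # num_levels=16, desired_resolution=2048): per_level_scale = 2 ** (log2(2048 / 16) / 15)
        # = 2 ** (7 / 15) ≈ 1.382, i.e. each level grows the grid resolution by roughly 1.38x
        # until the last level reaches about 2048.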
108 | self.log2_hashmap_size = log2_hashmap_size 109 | self.base_resolution = base_resolution 110 | self.output_dim = num_levels * level_dim 111 | self.gridtype = gridtype 112 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 113 | self.interpolation = interpolation 114 | self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" 115 | self.align_corners = align_corners 116 | 117 | # allocate parameters 118 | offsets = [] 119 | offset = 0 120 | self.max_params = 2 ** log2_hashmap_size 121 | for i in range(num_levels): 122 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 123 | params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number 124 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 125 | offsets.append(offset) 126 | offset += params_in_level 127 | offsets.append(offset) 128 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 129 | self.register_buffer('offsets', offsets) 130 | 131 | self.n_params = offsets[-1] * level_dim 132 | 133 | # parameters 134 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 135 | 136 | self.reset_parameters() 137 | 138 | def reset_parameters(self): 139 | std = 1e-4 140 | self.embeddings.data.uniform_(-std, std) 141 | 142 | def __repr__(self): 143 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" 144 | 145 | def forward(self, inputs, bound=1): 146 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 147 | # return: [..., num_levels * level_dim] 148 | 149 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 150 | 151 | #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 152 | 153 | prefix_shape = list(inputs.shape[:-1]) 154 | inputs = inputs.view(-1, self.input_dim) 155 | 156 | outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) 157 | outputs = outputs.view(prefix_shape + [self.output_dim]) 158 | 159 | #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 160 | 161 | return outputs 162 | 163 | # always run in float precision! 164 | @torch.cuda.amp.autocast(enabled=False) 165 | def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): 166 | # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. 
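        # typical call pattern (a sketch; the weight value is illustrative, and the backend call
        # below presumably accumulates the TV gradient directly into self.embeddings.grad):
        #   loss.backward()
        #   encoder.grad_total_variation(weight=1e-7, bound=opt.bound)
        #   optimizer.step()
        # the ValueError below enforces exactly this ordering: gradients must already exist.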
167 | 168 | D = self.input_dim 169 | C = self.embeddings.shape[1] # embedding dim for each level 170 | L = self.offsets.shape[0] - 1 # level 171 | S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 172 | H = self.base_resolution # base resolution 173 | 174 | if inputs is None: 175 | # randomized in [0, 1] 176 | inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) 177 | else: 178 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 179 | inputs = inputs.view(-1, self.input_dim) 180 | B = inputs.shape[0] 181 | 182 | if self.embeddings.grad is None: 183 | raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') 184 | 185 | _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | 17 | #endif -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | def mape_loss(pred, target, reduction='mean'): 8 | # pred, target: [B, 1], torch tenspr 9 | difference = (pred - target).abs() 10 | scale = 1 / (target.abs() + 1e-2) 11 | loss = difference * scale 12 | 13 | if reduction == 'mean': 14 | loss = loss.mean() 15 | 16 | return loss 17 | 18 | def huber_loss(pred, target, delta=0.1, reduction='mean'): 19 | rel = (pred - target).abs() 20 | sqr = 0.5 / delta * rel * rel 21 | loss = torch.where(rel > delta, rel - 0.5 * delta, sqr) 22 | 23 | if reduction 
== 'mean': 24 | loss = loss.mean() 25 | 26 | return loss 27 | 28 | 29 | # ref: https://github.com/sunset1995/torch_efficient_distloss/blob/main/torch_efficient_distloss/eff_distloss.py 30 | class EffDistLoss(torch.autograd.Function): 31 | @staticmethod 32 | def forward(ctx, w, m, interval): 33 | ''' 34 | Efficient O(N) realization of distortion loss. 35 | There are B rays each with N sampled points. 36 | w: Float tensor in shape [B,N]. Volume rendering weights of each point. 37 | m: Float tensor in shape [B,N]. Midpoint distance to camera of each point. 38 | interval: Scalar or float tensor in shape [B,N]. The query interval of each point. 39 | ''' 40 | n_rays = np.prod(w.shape[:-1]) 41 | wm = (w * m) 42 | w_cumsum = w.cumsum(dim=-1) 43 | wm_cumsum = wm.cumsum(dim=-1) 44 | 45 | w_total = w_cumsum[..., [-1]] 46 | wm_total = wm_cumsum[..., [-1]] 47 | w_prefix = torch.cat([torch.zeros_like(w_total), w_cumsum[..., :-1]], dim=-1) 48 | wm_prefix = torch.cat([torch.zeros_like(wm_total), wm_cumsum[..., :-1]], dim=-1) 49 | loss_uni = (1/3) * interval * w.pow(2) 50 | loss_bi = 2 * w * (m * w_prefix - wm_prefix) 51 | if torch.is_tensor(interval): 52 | ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval) 53 | ctx.interval = None 54 | else: 55 | ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total) 56 | ctx.interval = interval 57 | ctx.n_rays = n_rays 58 | return (loss_bi.sum() + loss_uni.sum()) / n_rays 59 | 60 | @staticmethod 61 | @torch.autograd.function.once_differentiable 62 | def backward(ctx, grad_back): 63 | interval = ctx.interval 64 | n_rays = ctx.n_rays 65 | if interval is None: 66 | w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval = ctx.saved_tensors 67 | else: 68 | w, m, wm, w_prefix, w_total, wm_prefix, wm_total = ctx.saved_tensors 69 | grad_uni = (1/3) * interval * 2 * w 70 | w_suffix = w_total - (w_prefix + w) 71 | wm_suffix = wm_total - (wm_prefix + wm) 72 | grad_bi = 2 * (m * (w_prefix - w_suffix) + (wm_suffix - wm_prefix)) 73 | grad = grad_back * (grad_bi + grad_uni) / n_rays 74 | return grad, None, None, None 75 | 76 | eff_distloss = EffDistLoss.apply 77 | -------------------------------------------------------------------------------- /main_nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | from nerf.provider import NeRFDataset 5 | # from nerf.gui import NeRFGUI 6 | from nerf.utils import * 7 | 8 | from functools import partial 9 | from loss import huber_loss 10 | 11 | #torch.autograd.set_detect_anomaly(True) 12 | 13 | if __name__ == '__main__': 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('path', type=str) 17 | parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload") 18 | parser.add_argument('--test', action='store_true', help="test mode") 19 | parser.add_argument('--workspace', type=str, default='workspace') 20 | parser.add_argument('--seed', type=int, default=0) 21 | 22 | ### training options 23 | parser.add_argument('--iters', type=int, default=3000, help="training iters") 24 | parser.add_argument('--lr', type=float, default=4e-2, help="initial learning rate") 25 | parser.add_argument('--ckpt', type=str, default='latest') 26 | parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step") 27 | parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") 28 | 
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)") 29 | parser.add_argument('--num_steps', type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") 30 | parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") 31 | parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") 32 | parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") 33 | parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") 34 | 35 | ### network backbone options 36 | parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") 37 | parser.add_argument('--ff', action='store_true', help="use fully-fused MLP") 38 | parser.add_argument('--tcnn', action='store_true', help="use TCNN backend") 39 | 40 | parser.add_argument('--eval_interval', default=5, help="eval_interval") 41 | ### dataset options 42 | parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") 43 | parser.add_argument('--preload', action='store_true', help="preload all data into GPU, accelerate training but use more GPU memory") 44 | # (the default value is for the fox dataset) 45 | parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") 46 | parser.add_argument('--scale', type=float, default=0.33, help="scale camera location into box[-bound, bound]^3") 47 | parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") 48 | parser.add_argument('--dt_gamma', type=float, default=1/128, help="dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") 49 | parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera") 50 | parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied") 51 | parser.add_argument('--bg_radius', type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)") 52 | 53 | ### experimental 54 | parser.add_argument('--error_map', action='store_true', help="use error map to sample rays") 55 | parser.add_argument('--clip_text', type=str, default='', help="text input for CLIP guidance") 56 | parser.add_argument('--rand_pose', type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses") 57 | 58 | 59 | ###### NeRFProtector 60 | parser.add_argument('--lambda1', type=float, default=0.001, help="Weight of watermarking loss") 61 | parser.add_argument('--local_sample', default=True, help="Ordinary NeRF sampling (local)") 62 | parser.add_argument('--global_sample', default=False, type=bool, help="single-layer view sampling (global)") 63 | parser.add_argument('--ml_sample', default=False, type=bool, help="multi-layer views sampling") 64 | 65 | parser.add_argument('--random_pose', default=False, type=int, help="Sample one random view") 66 | parser.add_argument('--extractor', default='dec_48b_whit.torchscript.pt', type=str, help="extractor") 67 | parser.add_argument('--bit', default=48, type=int, help="extractor") 68 | parser.add_argument('--init_ckpt', default=None, type=str, help="Path to ckpt of NeRF") 69 | opt = parser.parse_args() 70 | 71 | 72 | if opt.O: 73 | opt.fp16 = True 74 | opt.cuda_ray = True 75 | opt.preload = True 76 | 77 | if opt.ml_sample: 78 | opt.global_sample = False 79 | 80 | if opt.patch_size > 1: 81 | opt.error_map = False # do not use error_map if use patch-based training 82 | # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." 83 | assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." 84 | 85 | 86 | if opt.ff: 87 | opt.fp16 = True 88 | assert opt.bg_radius <= 0, "background model is not implemented for --ff" 89 | from nerf.network_ff import NeRFNetwork 90 | elif opt.tcnn: 91 | opt.fp16 = True 92 | assert opt.bg_radius <= 0, "background model is not implemented for --tcnn" 93 | from nerf.network_tcnn import NeRFNetwork 94 | else: 95 | from nerf.network import NeRFNetwork 96 | 97 | print(opt) 98 | 99 | seed_everything(opt.seed) 100 | 101 | model = NeRFNetwork( 102 | encoding="hashgrid", 103 | bound=opt.bound, 104 | cuda_ray=opt.cuda_ray, 105 | density_scale=1, 106 | min_near=opt.min_near, 107 | density_thresh=opt.density_thresh, 108 | bg_radius=opt.bg_radius, 109 | ) 110 | 111 | print(model) 112 | 113 | criterion = torch.nn.MSELoss() 114 | 115 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 116 | 117 | if opt.test: 118 | 119 | metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device), BitAccMeter(net = opt.extractor, device=device)] 120 | 121 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) 122 | 123 | test_loader = NeRFDataset(opt, device=device, type='test',downscale=2).dataloader() 124 | 125 | if test_loader.has_gt: 126 | trainer.evaluate(test_loader) # blender has gt, so evaluate it. 
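        # remainder of the test branch: render the test split without writing a video, then
        # extract and save a mesh; the threshold of 10 matches the --density_thresh default above.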
127 | 128 | trainer.test(test_loader, write_video=False) # test and save video 129 | 130 | trainer.save_mesh(resolution=256, threshold=10) 131 | 132 | else: 133 | 134 | train_loader = NeRFDataset(opt, device=device, type='train', downscale=2).dataloader() 135 | 136 | optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15) 137 | 138 | 139 | 140 | # decay to 0.1 * init_lr at last iter step 141 | scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / (opt.iters), 1)) 142 | 143 | metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device), BitAccMeter(device=device)] 144 | 145 | trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval) 146 | 147 | 148 | # else: 149 | valid_loader = NeRFDataset(opt, device=device, type='val', downscale=2).dataloader() 150 | 151 | 152 | max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) 153 | trainer.train(train_loader, valid_loader, max_epoch) 154 | 155 | # also test 156 | test_loader = NeRFDataset(opt, device=device, type='test', downscale=2).dataloader() 157 | 158 | if test_loader.has_gt: 159 | trainer.evaluate(test_loader) # blender has gt, so evaluate it. 160 | 161 | trainer.test(test_loader, write_video=True) # test and save video 162 | 163 | -------------------------------------------------------------------------------- /nerf/clip_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import torchvision.transforms as T 7 | import torchvision.transforms.functional as TF 8 | 9 | import clip 10 | 11 | class CLIPLoss: 12 | def __init__(self, device, name='ViT-B/16'): 13 | self.device = device 14 | self.name = name 15 | self.clip_model, self.transform_PIL = clip.load(self.name, device=self.device, jit=False) 16 | 17 | # disable training 18 | self.clip_model.eval() 19 | for p in self.clip_model.parameters(): 20 | p.requires_grad = False 21 | 22 | # image augmentation 23 | self.transform = T.Compose([ 24 | T.Resize((224, 224)), 25 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 26 | ]) 27 | 28 | # placeholder 29 | self.text_zs = None 30 | self.image_zs = None 31 | 32 | def normalize(self, x): 33 | return x / x.norm(dim=-1, keepdim=True) 34 | 35 | # image-text (e.g., dreamfields) 36 | def prepare_text(self, texts): 37 | # texts: list of strings. 38 | texts = clip.tokenize(texts).to(self.device) 39 | self.text_zs = self.normalize(self.clip_model.encode_text(texts)) 40 | print(f'[INFO] prepared CLIP text feature: {self.text_zs.shape}') 41 | 42 | def __call__(self, images, mode='text'): 43 | 44 | images = self.transform(images) 45 | image_zs = self.normalize(self.clip_model.encode_image(images)) 46 | 47 | if mode == 'text': 48 | # if more than one string, randomly choose one. 
49 | if self.text_zs.shape[0] > 1: 50 | idx = random.randint(0, self.text_zs.shape[0] - 1) 51 | text_zs = self.text_zs[[idx]] 52 | else: 53 | text_zs = self.text_zs 54 | # broadcast text_zs to all image_zs 55 | loss = - (image_zs * text_zs).sum(-1).mean() 56 | else: 57 | raise NotImplementedError 58 | 59 | return loss 60 | 61 | # image-image (e.g., diet-nerf) 62 | def prepare_image(self, dataset): 63 | # images: a nerf dataset (we need both poses and images!) 64 | pass -------------------------------------------------------------------------------- /nerf/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from .renderer import NeRFRenderer 8 | 9 | 10 | class NeRFNetwork(NeRFRenderer): 11 | def __init__(self, 12 | encoding="hashgrid", 13 | encoding_dir="sphere_harmonics", 14 | encoding_bg="hashgrid", 15 | num_layers=2, 16 | hidden_dim=64, 17 | geo_feat_dim=15, 18 | num_layers_color=3, 19 | hidden_dim_color=64, 20 | num_layers_bg=2, 21 | hidden_dim_bg=64, 22 | bound=1, 23 | **kwargs, 24 | ): 25 | super().__init__(bound, **kwargs) 26 | 27 | # sigma network 28 | self.num_layers = num_layers 29 | self.hidden_dim = hidden_dim 30 | self.geo_feat_dim = geo_feat_dim 31 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) 32 | 33 | sigma_net = [] 34 | for l in range(num_layers): 35 | if l == 0: 36 | in_dim = self.in_dim 37 | else: 38 | in_dim = hidden_dim 39 | 40 | if l == num_layers - 1: 41 | out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color 42 | else: 43 | out_dim = hidden_dim 44 | 45 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 46 | 47 | self.sigma_net = nn.ModuleList(sigma_net) 48 | 49 | # color network 50 | self.num_layers_color = num_layers_color 51 | self.hidden_dim_color = hidden_dim_color 52 | self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir) 53 | 54 | color_net = [] 55 | for l in range(num_layers_color): 56 | if l == 0: 57 | in_dim = self.in_dim_dir + self.geo_feat_dim 58 | else: 59 | in_dim = hidden_dim_color 60 | 61 | if l == num_layers_color - 1: 62 | out_dim = 3 # 3 rgb 63 | else: 64 | out_dim = hidden_dim_color 65 | 66 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 67 | 68 | self.color_net = nn.ModuleList(color_net) 69 | 70 | # background network 71 | if self.bg_radius > 0: 72 | self.num_layers_bg = num_layers_bg 73 | self.hidden_dim_bg = hidden_dim_bg 74 | self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid 75 | 76 | bg_net = [] 77 | for l in range(num_layers_bg): 78 | if l == 0: 79 | in_dim = self.in_dim_bg + self.in_dim_dir 80 | else: 81 | in_dim = hidden_dim_bg 82 | 83 | if l == num_layers_bg - 1: 84 | out_dim = 3 # 3 rgb 85 | else: 86 | out_dim = hidden_dim_bg 87 | 88 | bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) 89 | 90 | self.bg_net = nn.ModuleList(bg_net) 91 | else: 92 | self.bg_net = None 93 | 94 | 95 | def forward(self, x, d): 96 | # x: [N, 3], in [-bound, bound] 97 | # d: [N, 3], nomalized in [-1, 1] 98 | 99 | # sigma 100 | x = self.encoder(x, bound=self.bound) 101 | 102 | h = x 103 | for l in range(self.num_layers): 104 | h = self.sigma_net[l](h) 105 | if l != self.num_layers - 1: 106 | h = F.relu(h, inplace=True) 107 | 108 | #sigma = F.relu(h[..., 0]) 109 | sigma = 
trunc_exp(h[..., 0]) 110 | geo_feat = h[..., 1:] 111 | 112 | # color 113 | 114 | d = self.encoder_dir(d) 115 | h = torch.cat([d, geo_feat], dim=-1) 116 | for l in range(self.num_layers_color): 117 | h = self.color_net[l](h) 118 | if l != self.num_layers_color - 1: 119 | h = F.relu(h, inplace=True) 120 | 121 | # sigmoid activation for rgb 122 | color = torch.sigmoid(h) 123 | 124 | return sigma, color 125 | 126 | def density(self, x): 127 | # x: [N, 3], in [-bound, bound] 128 | 129 | x = self.encoder(x, bound=self.bound) 130 | h = x 131 | for l in range(self.num_layers): 132 | h = self.sigma_net[l](h) 133 | if l != self.num_layers - 1: 134 | h = F.relu(h, inplace=True) 135 | 136 | #sigma = F.relu(h[..., 0]) 137 | sigma = trunc_exp(h[..., 0]) 138 | geo_feat = h[..., 1:] 139 | 140 | return { 141 | 'sigma': sigma, 142 | 'geo_feat': geo_feat, 143 | } 144 | 145 | def background(self, x, d): 146 | # x: [N, 2], in [-1, 1] 147 | 148 | h = self.encoder_bg(x) # [N, C] 149 | d = self.encoder_dir(d) 150 | 151 | h = torch.cat([d, h], dim=-1) 152 | for l in range(self.num_layers_bg): 153 | h = self.bg_net[l](h) 154 | if l != self.num_layers_bg - 1: 155 | h = F.relu(h, inplace=True) 156 | 157 | # sigmoid activation for rgb 158 | rgbs = torch.sigmoid(h) 159 | 160 | return rgbs 161 | 162 | # allow masked inference 163 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 164 | # x: [N, 3] in [-bound, bound] 165 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 166 | 167 | if mask is not None: 168 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 169 | # in case of empty mask 170 | if not mask.any(): 171 | return rgbs 172 | x = x[mask] 173 | d = d[mask] 174 | geo_feat = geo_feat[mask] 175 | 176 | d = self.encoder_dir(d) 177 | h = torch.cat([d, geo_feat], dim=-1) 178 | for l in range(self.num_layers_color): 179 | h = self.color_net[l](h) 180 | if l != self.num_layers_color - 1: 181 | h = F.relu(h, inplace=True) 182 | 183 | # sigmoid activation for rgb 184 | h = torch.sigmoid(h) 185 | 186 | if mask is not None: 187 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 188 | else: 189 | rgbs = h 190 | 191 | return rgbs 192 | 193 | # optimizer utils 194 | def get_params(self, lr): 195 | 196 | params = [ 197 | {'params': self.encoder.parameters(), 'lr': lr}, 198 | {'params': self.sigma_net.parameters(), 'lr': lr}, 199 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 200 | {'params': self.color_net.parameters(), 'lr': lr}, 201 | ] 202 | if self.bg_radius > 0: 203 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 204 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 205 | 206 | return params 207 | -------------------------------------------------------------------------------- /nerf/network_ff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from activation import trunc_exp 7 | from ffmlp import FFMLP 8 | 9 | from .renderer import NeRFRenderer 10 | 11 | class NeRFNetwork(NeRFRenderer): 12 | def __init__(self, 13 | encoding="hashgrid", 14 | encoding_dir="sphere_harmonics", 15 | num_layers=2, 16 | hidden_dim=64, 17 | geo_feat_dim=15, 18 | num_layers_color=3, 19 | hidden_dim_color=64, 20 | bound=1, 21 | **kwargs 22 | ): 23 | super().__init__(bound, **kwargs) 24 | 25 | # sigma network 26 | self.num_layers = num_layers 27 | self.hidden_dim = hidden_dim 28 | self.geo_feat_dim 
= geo_feat_dim 29 | self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound) 30 | 31 | self.sigma_net = FFMLP( 32 | input_dim=self.in_dim, 33 | output_dim=1 + self.geo_feat_dim, 34 | hidden_dim=self.hidden_dim, 35 | num_layers=self.num_layers, 36 | ) 37 | 38 | # color network 39 | self.num_layers_color = num_layers_color 40 | self.hidden_dim_color = hidden_dim_color 41 | self.encoder_dir, self.in_dim_color = get_encoder(encoding_dir) 42 | self.in_dim_color += self.geo_feat_dim + 1 # a manual fixing to make it 32, as done in nerf_network.h#178 43 | 44 | self.color_net = FFMLP( 45 | input_dim=self.in_dim_color, 46 | output_dim=3, 47 | hidden_dim=self.hidden_dim_color, 48 | num_layers=self.num_layers_color, 49 | ) 50 | 51 | def forward(self, x, d): 52 | # x: [N, 3], in [-bound, bound] 53 | # d: [N, 3], nomalized in [-1, 1] 54 | 55 | # sigma 56 | x = self.encoder(x, bound=self.bound) 57 | h = self.sigma_net(x) 58 | 59 | #sigma = F.relu(h[..., 0]) 60 | sigma = trunc_exp(h[..., 0]) 61 | geo_feat = h[..., 1:] 62 | 63 | # color 64 | d = self.encoder_dir(d) 65 | 66 | # TODO: preallocate space and avoid this cat? 67 | p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 68 | h = torch.cat([d, geo_feat, p], dim=-1) 69 | h = self.color_net(h) 70 | 71 | # sigmoid activation for rgb 72 | rgb = torch.sigmoid(h) 73 | 74 | return sigma, rgb 75 | 76 | def density(self, x): 77 | # x: [N, 3], in [-bound, bound] 78 | 79 | x = self.encoder(x, bound=self.bound) 80 | h = self.sigma_net(x) 81 | 82 | #sigma = F.relu(h[..., 0]) 83 | sigma = trunc_exp(h[..., 0]) 84 | geo_feat = h[..., 1:] 85 | 86 | return { 87 | 'sigma': sigma, 88 | 'geo_feat': geo_feat, 89 | } 90 | 91 | # allow masked inference 92 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 93 | # x: [N, 3] in [-bound, bound] 94 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 
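        # rationale for the mask (an inference from how this method is used, not documented here):
        # during CUDA ray marching many sampled points carry negligible weight, so the renderer
        # presumably passes a boolean mask and rgb is only computed where needed, with zeros
        # written for the rest.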
95 | 96 | #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 97 | #starter.record() 98 | 99 | if mask is not None: 100 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 101 | # in case of empty mask 102 | if not mask.any(): 103 | return rgbs 104 | x = x[mask] 105 | d = d[mask] 106 | geo_feat = geo_feat[mask] 107 | 108 | #print(x.shape, rgbs.shape) 109 | 110 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'mask = {curr_time}') 111 | #starter.record() 112 | 113 | d = self.encoder_dir(d) 114 | 115 | p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 116 | h = torch.cat([d, geo_feat, p], dim=-1) 117 | 118 | h = self.color_net(h) 119 | 120 | # sigmoid activation for rgb 121 | h = torch.sigmoid(h) 122 | 123 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'call = {curr_time}') 124 | #starter.record() 125 | 126 | if mask is not None: 127 | rgbs[mask] = h.to(rgbs.dtype) 128 | else: 129 | rgbs = h 130 | 131 | #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'unmask = {curr_time}') 132 | #starter.record() 133 | 134 | return rgbs 135 | 136 | # optimizer utils 137 | def get_params(self, lr): 138 | 139 | params = [ 140 | {'params': self.encoder.parameters(), 'lr': lr}, 141 | {'params': self.sigma_net.parameters(), 'lr': lr}, 142 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 143 | {'params': self.color_net.parameters(), 'lr': lr}, 144 | ] 145 | if self.bg_radius > 0: 146 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 147 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 148 | 149 | return params -------------------------------------------------------------------------------- /nerf/network_tcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | import tinycudann as tcnn 8 | from activation import trunc_exp 9 | from .renderer import NeRFRenderer 10 | 11 | 12 | class NeRFNetwork(NeRFRenderer): 13 | def __init__(self, 14 | encoding="HashGrid", 15 | encoding_dir="SphericalHarmonics", 16 | num_layers=2, 17 | hidden_dim=64, 18 | geo_feat_dim=15, 19 | num_layers_color=3, 20 | hidden_dim_color=64, 21 | bound=1, 22 | **kwargs 23 | ): 24 | super().__init__(bound, **kwargs) 25 | 26 | # sigma network 27 | self.num_layers = num_layers 28 | self.hidden_dim = hidden_dim 29 | self.geo_feat_dim = geo_feat_dim 30 | 31 | per_level_scale = np.exp2(np.log2(2048 * bound / 16) / (16 - 1)) 32 | 33 | self.encoder = tcnn.Encoding( 34 | n_input_dims=3, 35 | encoding_config={ 36 | "otype": "HashGrid", 37 | "n_levels": 16, 38 | "n_features_per_level": 2, 39 | "log2_hashmap_size": 19, 40 | "base_resolution": 16, 41 | "per_level_scale": per_level_scale, 42 | }, 43 | ) 44 | 45 | self.sigma_net = tcnn.Network( 46 | n_input_dims=32, 47 | n_output_dims=1 + self.geo_feat_dim, 48 | network_config={ 49 | "otype": "FullyFusedMLP", 50 | "activation": "ReLU", 51 | "output_activation": "None", 52 | "n_neurons": hidden_dim, 53 | "n_hidden_layers": num_layers - 1, 54 | }, 55 | ) 56 | 57 | # color network 58 | self.num_layers_color = num_layers_color 59 | self.hidden_dim_color = hidden_dim_color 60 | 61 | self.encoder_dir = tcnn.Encoding( 62 | n_input_dims=3, 63 | encoding_config={ 64 | "otype": "SphericalHarmonics", 65 | "degree": 4, 66 | }, 67 | ) 
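        # degree-4 spherical harmonics should produce 16 output dims, so in_dim_color below is
        # expected to be 16 + geo_feat_dim = 31.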
68 | 69 | self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim 70 | 71 | self.color_net = tcnn.Network( 72 | n_input_dims=self.in_dim_color, 73 | n_output_dims=3, 74 | network_config={ 75 | "otype": "FullyFusedMLP", 76 | "activation": "ReLU", 77 | "output_activation": "None", 78 | "n_neurons": hidden_dim_color, 79 | "n_hidden_layers": num_layers_color - 1, 80 | }, 81 | ) 82 | 83 | 84 | def forward(self, x, d): 85 | # x: [N, 3], in [-bound, bound] 86 | # d: [N, 3], nomalized in [-1, 1] 87 | 88 | 89 | # sigma 90 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 91 | x = self.encoder(x) 92 | h = self.sigma_net(x) 93 | 94 | #sigma = F.relu(h[..., 0]) 95 | sigma = trunc_exp(h[..., 0]) 96 | geo_feat = h[..., 1:] 97 | 98 | # color 99 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] 100 | d = self.encoder_dir(d) 101 | 102 | #p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 103 | h = torch.cat([d, geo_feat], dim=-1) 104 | h = self.color_net(h) 105 | 106 | # sigmoid activation for rgb 107 | color = torch.sigmoid(h) 108 | 109 | return sigma, color 110 | 111 | def density(self, x): 112 | # x: [N, 3], in [-bound, bound] 113 | 114 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 115 | x = self.encoder(x) 116 | h = self.sigma_net(x) 117 | 118 | #sigma = F.relu(h[..., 0]) 119 | sigma = trunc_exp(h[..., 0]) 120 | geo_feat = h[..., 1:] 121 | 122 | return { 123 | 'sigma': sigma, 124 | 'geo_feat': geo_feat, 125 | } 126 | 127 | # allow masked inference 128 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 129 | # x: [N, 3] in [-bound, bound] 130 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 131 | 132 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 133 | 134 | if mask is not None: 135 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 136 | # in case of empty mask 137 | if not mask.any(): 138 | return rgbs 139 | x = x[mask] 140 | d = d[mask] 141 | geo_feat = geo_feat[mask] 142 | 143 | # color 144 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] 145 | d = self.encoder_dir(d) 146 | 147 | h = torch.cat([d, geo_feat], dim=-1) 148 | h = self.color_net(h) 149 | 150 | # sigmoid activation for rgb 151 | h = torch.sigmoid(h) 152 | 153 | if mask is not None: 154 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 155 | else: 156 | rgbs = h 157 | 158 | return rgbs 159 | 160 | # optimizer utils 161 | def get_params(self, lr): 162 | 163 | params = [ 164 | {'params': self.encoder.parameters(), 'lr': lr}, 165 | {'params': self.sigma_net.parameters(), 'lr': lr}, 166 | {'params': self.encoder_dir.parameters(), 'lr': lr}, 167 | {'params': self.color_net.parameters(), 'lr': lr}, 168 | ] 169 | if self.bg_radius > 0: 170 | params.append({'params': self.encoder_bg.parameters(), 'lr': lr}) 171 | params.append({'params': self.bg_net.parameters(), 'lr': lr}) 172 | 173 | return params -------------------------------------------------------------------------------- /raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', 
'-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_raymarching', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'raymarching.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | ''' 33 | Usage: 34 | 35 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) 36 | 37 | python setup.py install # build extensions and install (copy) to PATH. 38 | pip install . # ditto but better (e.g., dependency & metadata handling) 39 | 40 | python setup.py develop # build extensions and install (symbolic) to PATH. 41 | pip install -e . 
# ditto but better (e.g., dependency & metadata handling) 42 | 43 | ''' 44 | setup( 45 | name='raymarching', # package name, import this to use python API 46 | ext_modules=[ 47 | CUDAExtension( 48 | name='_raymarching', # extension name, import this to use CUDA API 49 | sources=[os.path.join(_src_path, 'src', f) for f in [ 50 | 'raymarching.cu', 51 | 'bindings.cpp', 52 | ]], 53 | extra_compile_args={ 54 | 'cxx': c_flags, 55 | 'nvcc': nvcc_flags, 56 | } 57 | ), 58 | ], 59 | cmdclass={ 60 | 'build_ext': BuildExtension, 61 | } 62 | ) -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("packbits", &packbits, "packbits (CUDA)"); 8 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 9 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); 10 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 11 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 12 | // train 13 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 14 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 15 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 16 | // infer 17 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 18 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 19 | } -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | 7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | 13 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); 14 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 15 | void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 16 | 17 | void march_rays(const uint32_t n_alive, const 
uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); 18 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # NeRFProtector 2 | ## This is the official implementation of Protecting NeRFs’ Copyright via Plug-And-Play Watermarking Base Model (ECCV 2024) [[Arxiv]](https://arxiv.org/abs/2407.07735) 3 | 4 | # Install 5 | ```bash 6 | git clone --recursive https://github.com/qsong2001/NeRFProtector-code.git 7 | cd NeRFProtector-code 8 | ``` 9 | 10 | ### Install with conda and pip 11 | ```bash 12 | conda create -n ngp python==3.11 13 | conda activate ngp 14 | 15 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 16 | 17 | pip install -r requirements.txt 18 | 19 | ``` 20 | 21 | # Usage 22 | 23 | We implement our NeRFProtector on torch-ngp. 24 | 25 | Please refer to [torch-ngp](https://github.com/ashawkey/torch-ngp) for more details of usage. 26 | 27 | For example, in the Lego scene, we can use the script below: 28 | ```bash 29 | python main_nerf.py data/nerf_synthetic/lego --workspace out/lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 --lambda1 0.001 --iters 10000 --lr 4e-2 --ml_sample True 30 | ``` 31 | 32 | 33 | ### Other related projects 34 | 35 | * [torch-ngp](https://github.com/ashawkey/torch-ngp): PyTorch-based Instant NGP 36 | 37 | * [stable signature](https://github.com/facebookresearch/stable_signature): Rooting Watermarks in Latent Diffusion Models 38 | 39 | 40 | # Citation 41 | 42 | If you find this work useful, a citation will be appreciated via: 43 | ``` 44 | @inproceedings{song2024protecting, 45 | title={Protecting NeRFs' Copyright via Plug-And-Play Watermarking Base Model}, 46 | author={Song, Qi and Luo, Ziyuan and Cheung, Ka Chun and See, Simon and Wan, Renjie}, 47 | booktitle={ECCV}, 48 | year={2024} 49 | } 50 | ``` 51 | 52 | # Acknowledgement 53 | 54 | * Credits to [Thomas Müller](https://tom94.net/) for the amazing [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) and [instant-ngp](https://github.com/NVlabs/instant-ngp): 55 | ``` 56 | @misc{tiny-cuda-nn, 57 | Author = {Thomas M\"uller}, 58 | Year = {2021}, 59 | Note = {https://github.com/nvlabs/tiny-cuda-nn}, 60 | Title = {Tiny {CUDA} Neural Network Framework} 61 | } 62 | 63 | @article{mueller2022instant, 64 | title = {Instant Neural Graphics Primitives with a Multiresolution Hash Encoding}, 65 | author = {Thomas M\"uller and Alex Evans and Christoph Schied and Alexander Keller}, 66 | journal = {arXiv:2201.05989}, 67 | year = {2022}, 68 | month = jan 69 | } 70 | 71 | @misc{torch-ngp, 72 | Author = {Jiaxiang Tang}, 73 | Year = {2022}, 74 | Note = {https://github.com/ashawkey/torch-ngp}, 75 | Title = {Torch-ngp: a PyTorch implementation of instant-ngp} 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch-ema 2 | ninja 3 | trimesh 4 | opencv-python 5 | tensorboardX 6 | numpy 7 | pandas 8 | tqdm 9 | matplotlib 10 | PyMCubes 11 | rich 12 | pysdf 13 | dearpygui 14 | packaging 15 | scipy 16 | lpips 17 | imageio 18 | torchmetrics 19 | augly 20 | -------------------------------------------------------------------------------- /scripts/hyper2nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import json 5 | import argparse 6 | import trimesh 7 | 8 | 9 | def visualize_poses(poses, size=0.1): 10 | # poses: [B, 4, 4] 11 | 12 | axes = trimesh.creation.axis(axis_length=4) 13 | box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() 14 | box.colors = np.array([[128, 128, 128]] * len(box.entities)) 15 | objects = [axes, box] 16 | 17 | for pose in poses: 18 | # a camera is visualized with 8 line segments. 19 | pos = pose[:3, 3] 20 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 21 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 22 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 23 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 24 | 25 | dir = (a + b + c + d) / 4 - pos 26 | dir = dir / (np.linalg.norm(dir) + 1e-8) 27 | o = pos + dir * 3 28 | 29 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) 30 | segs = trimesh.load_path(segs) 31 | objects.append(segs) 32 | 33 | trimesh.Scene(objects).show() 34 | 35 | # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 36 | def closest_point_2_lines(oa, da, ob, db): 37 | da = da / np.linalg.norm(da) 38 | db = db / np.linalg.norm(db) 39 | c = np.cross(da, db) 40 | denom = np.linalg.norm(c)**2 41 | t = ob - oa 42 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 43 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 44 | if ta > 0: 45 | ta = 0 46 | if tb > 0: 47 | tb = 0 48 | return (oa+ta*da+ob+tb*db) * 0.5, denom 49 | 50 | def rotmat(a, b): 51 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 52 | v = np.cross(a, b) 53 | c = np.dot(a, b) 54 | # handle exception for the opposite direction input 55 | if c < -1 + 1e-10: 56 | return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) 57 | s = np.linalg.norm(v) 58 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 59 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 60 | 61 | if __name__ == '__main__': 62 | 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('path', type=str, help="root directory to the HyperNeRF dataset (contains camera/, rgb/, dataset.json, scene.json)") 65 | parser.add_argument('--downscale', type=int, default=2, help="image size down scale, choose from [2, 4, 8, 16], e.g., 8") 66 | parser.add_argument('--interval', type=int, default=4, help="used for interp dataset's train/val split, should > 2 and be even") 67 | 68 | opt = parser.parse_args() 69 | 70 | print(f'[INFO] process {opt.path}') 71 | 72 | # load data 73 | with open(os.path.join(opt.path, 'dataset.json'), 'r') as f: 74 | json_dataset = json.load(f) 75 | 76 | names = json_dataset['ids'] 77 | val_names = json_dataset['val_ids'] 78 | 79 | # data split mode following hypernerf (vrig / interp) 80 | if len(val_names) > 0: 81 | train_names = json_dataset['train_ids'] 82 | val_ids = [] 83 | 
train_ids = [] 84 | for i, name in enumerate(names): 85 | if name in val_names: 86 | val_ids.append(i) 87 | elif name in train_names: 88 | train_ids.append(i) 89 | else: 90 | all_ids = np.arange(len(names)) 91 | train_ids = all_ids[::opt.interval] 92 | val_ids = (train_ids[:-1] + train_ids[1:]) // 2 93 | 94 | print(f'[INFO] train_ids: {len(train_ids)}, val_ids: {len(val_ids)}') 95 | 96 | with open(os.path.join(opt.path, 'scene.json'), 'r') as f: 97 | json_scene = json.load(f) 98 | 99 | scale = json_scene['scale'] 100 | center = json_scene['center'] 101 | 102 | with open(os.path.join(opt.path, 'metadata.json'), 'r') as f: 103 | json_meta = json.load(f) 104 | 105 | images = [] 106 | times = [] 107 | poses = [] 108 | H, W, f, cx, cy = None, None, None, None, None 109 | 110 | for name in names: 111 | 112 | # load image 113 | images.append(os.path.join('rgb', f'{opt.downscale}x', f'{name}.png')) 114 | 115 | # load time 116 | times.append(json_meta[name]['time_id']) 117 | 118 | # load pose 119 | with open(os.path.join(opt.path, 'camera', f'{name}.json'), 'r') as f: 120 | cam = json.load(f) 121 | 122 | # TODO: we use a simplified pinhole camera model rather than the original openCV camera model... hope it won't influence results seriously... 123 | 124 | pose = np.eye(4, 4) 125 | pose[:3, :3] = np.array(cam['orientation']).T # it works... 126 | #pose[:3, 3] = (np.array(cam['position']) - center) * scale * 4 127 | pose[:3, 3] = np.array(cam['position']) 128 | 129 | # CHECK: simply assume all intrinsic are same ? 130 | W, H = cam['image_size'] # before scale 131 | cx, cy = cam['principal_point'] 132 | fl = cam['focal_length'] 133 | 134 | poses.append(pose) 135 | 136 | poses = np.stack(poses, axis=0) # [N, 4, 4] 137 | times = np.asarray(times, dtype=np.float32) # [N] 138 | times = times / times.max() # normalize to [0, 1] 139 | 140 | N = len(images) 141 | 142 | W = W // opt.downscale 143 | H = H // opt.downscale 144 | cx = cx / opt.downscale 145 | cy = cy / opt.downscale 146 | fl = fl / opt.downscale 147 | 148 | print(f'[INFO] H = {H}, W = {W}, fl = {fl} (downscale = {opt.downscale})') 149 | 150 | # visualize_poses(poses) 151 | 152 | # the following stuff are from colmap2nerf... 
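    # Concretely: flip the camera-local y/z axes, permute and negate the world axes, rotate the
    # averaged "up" direction onto +z, recenter the scene on the point where the camera optical
    # axes converge (closest_point_2_lines), and rescale so the average camera distance from that
    # point is 4.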
153 | poses[:, 0:3, 1] *= -1 154 | poses[:, 0:3, 2] *= -1 155 | poses = poses[:, [1, 0, 2, 3], :] # swap y and z 156 | poses[:, 2, :] *= -1 # flip whole world upside down 157 | 158 | up = poses[:, 0:3, 1].sum(0) 159 | up = up / np.linalg.norm(up) 160 | R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] 161 | R = np.pad(R, [0, 1]) 162 | R[-1, -1] = 1 163 | 164 | poses = R @ poses 165 | 166 | totw = 0.0 167 | totp = np.array([0.0, 0.0, 0.0]) 168 | for i in range(N): 169 | mf = poses[i, :3, :] 170 | for j in range(i + 1, N): 171 | mg = poses[j, :3, :] 172 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 173 | #print(i, j, p, w) 174 | if w > 0.01: 175 | totp += p * w 176 | totw += w 177 | totp /= totw 178 | print(f'[INFO] totp = {totp}') 179 | poses[:, :3, 3] -= totp 180 | avglen = np.linalg.norm(poses[:, :3, 3], axis=-1).mean() 181 | poses[:, :3, 3] *= 4.0 / avglen 182 | print(f'[INFO] average radius = {avglen}') 183 | 184 | # visualize_poses(poses) 185 | 186 | # construct frames 187 | frames_train = [] 188 | for i in train_ids: 189 | frames_train.append({ 190 | 'file_path': images[i], 191 | 'time': float(times[i]), 192 | 'transform_matrix': poses[i].tolist(), 193 | }) 194 | 195 | frames_val = [] 196 | for i in val_ids: 197 | frames_val.append({ 198 | 'file_path': images[i], 199 | 'time': float(times[i]), 200 | 'transform_matrix': poses[i].tolist(), 201 | }) 202 | 203 | def write_json(filename, frames): 204 | 205 | # construct a transforms.json 206 | out = { 207 | 'w': W, 208 | 'h': H, 209 | 'fl_x': fl, 210 | 'fl_y': fl, 211 | 'cx': cx, 212 | 'cy': cy, 213 | 'frames': frames, 214 | } 215 | 216 | # write 217 | output_path = os.path.join(opt.path, filename) 218 | print(f'[INFO] write {len(frames)} images to {output_path}') 219 | with open(output_path, 'w') as f: 220 | json.dump(out, f, indent=2) 221 | 222 | write_json('transforms_train.json', frames_train) 223 | write_json('transforms_val.json', frames_val[::10]) 224 | write_json('transforms_test.json', frames_val) -------------------------------------------------------------------------------- /scripts/install_ext.sh: -------------------------------------------------------------------------------- 1 | pip install ./raymarching 2 | 3 | pip install ./gridencoder 4 | 5 | pip install ./shencoder 6 | 7 | pip install ./freqencoder 8 | 9 | # turned off by default, very slow to compile, and performance is not good enough. 
10 | #pip install ./ffmlp -------------------------------------------------------------------------------- /scripts/llff2nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import math 5 | import json 6 | import trimesh 7 | import argparse 8 | 9 | # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 10 | def closest_point_2_lines(oa, da, ob, db): 11 | da = da / np.linalg.norm(da) 12 | db = db / np.linalg.norm(db) 13 | c = np.cross(da, db) 14 | denom = np.linalg.norm(c)**2 15 | t = ob - oa 16 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 17 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 18 | if ta > 0: 19 | ta = 0 20 | if tb > 0: 21 | tb = 0 22 | return (oa+ta*da+ob+tb*db) * 0.5, denom 23 | 24 | def rotmat(a, b): 25 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 26 | v = np.cross(a, b) 27 | c = np.dot(a, b) 28 | # handle exception for the opposite direction input 29 | if c < -1 + 1e-10: 30 | return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) 31 | s = np.linalg.norm(v) 32 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 33 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 34 | 35 | 36 | def visualize_poses(poses, size=0.1): 37 | # poses: [B, 4, 4] 38 | 39 | axes = trimesh.creation.axis(axis_length=4) 40 | box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() 41 | box.colors = np.array([[128, 128, 128]] * len(box.entities)) 42 | objects = [axes, box] 43 | 44 | for pose in poses: 45 | # a camera is visualized with 8 line segments. 46 | pos = pose[:3, 3] 47 | a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 48 | b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] 49 | c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 50 | d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] 51 | 52 | dir = (a + b + c + d) / 4 - pos 53 | dir = dir / (np.linalg.norm(dir) + 1e-8) 54 | o = pos + dir * 3 55 | 56 | segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) 57 | segs = trimesh.load_path(segs) 58 | objects.append(segs) 59 | 60 | trimesh.Scene(objects).show() 61 | 62 | if __name__ == '__main__': 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('path', type=str, help="root directory to the LLFF dataset (contains images/ and pose_bounds.npy)") 66 | parser.add_argument('--images', type=str, default='images_8', help="images folder (do not include full path, e.g., just use `images_4`)") 67 | parser.add_argument('--downscale', type=float, default=8, help="image size down scale, e.g., 4") 68 | parser.add_argument('--hold', type=int, default=8, help="hold out for validation every $ images") 69 | 70 | opt = parser.parse_args() 71 | print(f'[INFO] process {opt.path}') 72 | 73 | # path must end with / to make sure image path is relative 74 | if opt.path[-1] != '/': 75 | opt.path += '/' 76 | 77 | # load data 78 | images = [f[len(opt.path):] for f in sorted(glob.glob(os.path.join(opt.path, opt.images, "*"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')] 79 | 80 | poses_bounds = np.load(os.path.join(opt.path, 'poses_bounds.npy')) 81 | N = poses_bounds.shape[0] 82 | 83 | print(f'[INFO] loaded {len(images)} images, {N} poses_bounds as {poses_bounds.shape}') 84 | 85 | assert N == len(images) 86 | 87 | poses = 
poses_bounds[:, :15].reshape(-1, 3, 5) # (N, 3, 5) 88 | bounds = poses_bounds[:, -2:] # (N, 2) 89 | 90 | H, W, fl = poses[0, :, -1] 91 | 92 | H = H // opt.downscale 93 | W = W // opt.downscale 94 | fl = fl / opt.downscale 95 | 96 | print(f'[INFO] H = {H}, W = {W}, fl = {fl} (downscale = {opt.downscale})') 97 | 98 | # inversion of this: https://github.com/Fyusion/LLFF/blob/c6e27b1ee59cb18f054ccb0f87a90214dbe70482/llff/poses/pose_utils.py#L51 99 | poses = np.concatenate([poses[..., 1:2], poses[..., 0:1], -poses[..., 2:3], poses[..., 3:4]], -1) # (N, 3, 4) 100 | 101 | # to homogeneous 102 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N, 1, 4) 103 | poses = np.concatenate([poses, last_row], axis=1) # (N, 4, 4) 104 | 105 | # visualize_poses(poses) 106 | 107 | # the following stuff are from colmap2nerf... [flower fails, the camera must be in-ward...] 108 | poses[:, 0:3, 1] *= -1 109 | poses[:, 0:3, 2] *= -1 110 | poses = poses[:, [1, 0, 2, 3], :] # swap y and z 111 | poses[:, 2, :] *= -1 # flip whole world upside down 112 | 113 | up = poses[:, 0:3, 1].sum(0) 114 | up = up / np.linalg.norm(up) 115 | R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] 116 | R = np.pad(R, [0, 1]) 117 | R[-1, -1] = 1 118 | 119 | poses = R @ poses 120 | 121 | totw = 0.0 122 | totp = np.array([0.0, 0.0, 0.0]) 123 | for i in range(N): 124 | mf = poses[i, :3, :] 125 | for j in range(i + 1, N): 126 | mg = poses[j, :3, :] 127 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 128 | #print(i, j, p, w) 129 | if w > 0.01: 130 | totp += p * w 131 | totw += w 132 | totp /= totw 133 | print(f'[INFO] totp = {totp}') 134 | poses[:, :3, 3] -= totp 135 | avglen = np.linalg.norm(poses[:, :3, 3], axis=-1).mean() 136 | poses[:, :3, 3] *= 4.0 / avglen 137 | print(f'[INFO] average radius = {avglen}') 138 | 139 | # visualize_poses(poses) 140 | 141 | # construct frames 142 | 143 | all_ids = np.arange(N) 144 | test_ids = all_ids[::opt.hold] 145 | train_ids = np.array([i for i in all_ids if i not in test_ids]) 146 | 147 | frames_train = [] 148 | frames_test = [] 149 | for i in train_ids: 150 | frames_train.append({ 151 | 'file_path': images[i], 152 | 'transform_matrix': poses[i].tolist(), 153 | }) 154 | for i in test_ids: 155 | frames_test.append({ 156 | 'file_path': images[i], 157 | 'transform_matrix': poses[i].tolist(), 158 | }) 159 | 160 | def write_json(filename, frames): 161 | 162 | # construct a transforms.json 163 | out = { 164 | 'w': W, 165 | 'h': H, 166 | 'fl_x': fl, 167 | 'fl_y': fl, 168 | 'cx': W // 2, 169 | 'cy': H // 2, 170 | 'aabb_scale': 2, 171 | 'frames': frames, 172 | } 173 | 174 | # write 175 | output_path = os.path.join(opt.path, filename) 176 | print(f'[INFO] write {len(frames)} images to {output_path}') 177 | with open(output_path, 'w') as f: 178 | json.dump(out, f, indent=2) 179 | 180 | write_json('transforms_train.json', frames_train) 181 | write_json('transforms_val.json', frames_test[::10]) 182 | write_json('transforms_test.json', frames_test) 183 | 184 | -------------------------------------------------------------------------------- /scripts/run_ccnerf.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # train single objects 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/ficus --workspace trial_cc_ficus -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/chair --workspace trial_cc_chair -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map 6 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 1.0 --scale 0.67 --dt_gamma 0 --error_map 7 | 8 | # compose 9 | # use more samples per ray (--max_steps) and a larger bound for better results 10 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=7 python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O --bound 2.0 --scale 0.67 --dt_gamma 0 --max_steps 2048 --test --compose -------------------------------------------------------------------------------- /scripts/run_dnerf.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=6 python main_dnerf.py data/dnerf/bouncingballs --workspace trial_dnerf_bouncingballs -O --bound 1 --scale 0.8 --dt_gamma 0 #--gui --test 4 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=6 python main_dnerf.py data/dnerf/bouncingballs --workspace trial_dnerf_basis_bouncingballs -O --bound 1 --scale 0.8 --dt_gamma 0 --basis #--gui --test 5 | 6 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=6 python main_dnerf.py data/dnerf/standup --workspace trial_dnerf_standup -O --bound 1 --scale 0.8 --dt_gamma 0 #--gui --test 7 | 8 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies -O --bound 1 --scale 0.3 #--gui --test 9 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies_ncr --preload --fp16 --bound 1 --scale 0.3 #--gui --test 10 | 11 | # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=4 python main_dnerf.py data/vrig-3dprinter/ --workspace trial_dnerf_printer -O --bound 2 --scale 0.33 #--gui --test -------------------------------------------------------------------------------- /scripts/run_gui_nerf.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_fox -O --gui #--error_map 4 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui #--error_map 5 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_llff_data/orchids --workspace trial_nerf_orchids -O --gui --bound 2.0 --scale 0.6 6 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/TanksAndTemple/Family --workspace trial_nerf_family -O --bound 1.0 --scale 0.33 --dt_gamma 0 --gui 7 | 8 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/figure --workspace trial_nerf_fig -O --gui --bound 1.0 --scale 0.3 --bg_radius 128 9 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=5 python main_nerf.py data/vasedeck --workspace trial_nerf_vase -O --gui --bound 4.0 --scale 0.3 -------------------------------------------------------------------------------- /scripts/run_gui_nerf_clip.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1.0 --scale 0.67 --dt_gamma 0 --gui 4 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1.0 --scale 0.67 --dt_gamma 0 --gui --rand_pose 0 --clip_text "red" --lr 1e-3 --ckpt latest_model 5 | 6 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_llff_data/orchids --workspace trial_nerf_orchids -O --gui --bound 2.0 --scale 0.6 7 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_llff_data/orchids --workspace trial_nerf_orchids -O --gui --bound 2.0 --scale 0.6 --rand_pose 0 --clip_text "blue flower" --lr 1e-3 --ckpt latest_model -------------------------------------------------------------------------------- /scripts/run_gui_tensoRF.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensoRF_fox -O --gui 4 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui 5 | 6 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensorCP_fox --cp -O --gui 7 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensorCP_lego --cp -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui 8 | 9 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/figure --workspace trial_tensoRF_fig -O --gui --scale 0.33 --bound 1.0 --bg_radius 32 -------------------------------------------------------------------------------- /scripts/run_nerf.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/fox --workspace trial_nerf_fox -O 4 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego -O --bound 1 --scale 0.8 --dt_gamma 0 5 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_emap -O --bound 1 --scale 0.8 --dt_gamma 0 --error_map 6 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/TanksAndTemple/Barn --workspace trial_nerf_barn -O --bound 1.0 --scale 0.33 --dt_gamma 0 7 | 8 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/firekeeper --workspace trial_nerf_firekeeper_bg_32 -O --bound 1.0 --scale 0.33 --bg_radius 32 #--gui #--test 9 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/garden --workspace trial_nerf_garden_bound_16 --cuda_ray --fp16 --bound 16.0 --scale 0.33 #--gui --test 10 | 11 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/vasedeck --workspace trial_nerf_vasedeck -O --bound 4.0 --scale 0.33 #--gui #--test 12 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/vasedeck --workspace trial_nerf_vasedeck_bg_32 -O --bound 4.0 --scale 0.33 --bg_radius 32 #--gui #--test -------------------------------------------------------------------------------- /scripts/run_sdf.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 4 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/armadillo.obj --workspace trial_sdf_ff --fp16 --ff 5 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/armadillo.obj --workspace trial_sdf_tcnn --fp16 --tcnn 6 | 7 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_sdf.py data/lucy.obj --workspace trial_sdf --fp16 -------------------------------------------------------------------------------- /scripts/run_tensoRF.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensoRF_fox -O --error_map 5 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego -O --bound 1.0 --scale 0.8 --dt_gamma 0 6 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego_emap -O --bound 1.0 --scale 0.8 --dt_gamma 0 --error_map 7 | 8 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensorCP_fox -O --cp --resolution1 500 --error_map 9 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensorCP_lego --cp --resolution1 500 -O --bound 1.0 --scale 0.8 --dt_gamma 0 --error_map 10 | 11 | #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/figure --workspace trial_tensoRF_fig -O --scale 0.33 --bound 1.0 -------------------------------------------------------------------------------- /scripts/tanks2nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import json 5 | 6 | import argparse 7 | 8 | # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 9 | def closest_point_2_lines(oa, da, ob, db): 10 | da = da / np.linalg.norm(da) 11 | db = db / np.linalg.norm(db) 12 | c = np.cross(da, db) 13 | denom = np.linalg.norm(c)**2 14 | t = ob - oa 15 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 16 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 17 | if ta > 0: 18 | ta = 0 19 | if tb > 0: 20 | tb = 0 21 | return (oa+ta*da+ob+tb*db) * 0.5, denom 22 | 23 | def rotmat(a, b): 24 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 25 | v = np.cross(a, b) 26 | c = np.dot(a, b) 27 | # handle exception for the opposite direction input 28 | if c < -1 + 1e-10: 29 | return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) 30 | s = np.linalg.norm(v) 31 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 32 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 33 | 34 | if __name__ == '__main__': 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('path', type=str, help="root directory to the Tanks&Temple dataset (contains rgb/, pose/, intrinsics.txt)") 38 | 39 | opt = parser.parse_args() 40 | print(f'[INFO] process {opt.path}') 41 | 42 | # load data 43 | 44 | intrinsics = np.loadtxt(os.path.join(opt.path, "intrinsics.txt")) 45 | fl_x = intrinsics[0, 0] 46 | fl_y = intrinsics[1, 1] 47 | cx = intrinsics[0, 2] 48 | cy = intrinsics[1, 2] 49 | H = 1080 50 | W = 1920 51 | 52 | pose_files = sorted(os.listdir(os.path.join(opt.path, 'pose'))) 53 | img_files = 
sorted(os.listdir(os.path.join(opt.path, 'rgb'))) 54 | 55 | # read in all poses, and do transform 56 | poses = [] 57 | for pose_f in pose_files: 58 | pose = np.loadtxt(os.path.join(opt.path, 'pose', pose_f)) # [4, 4] 59 | poses.append(pose) 60 | 61 | poses = np.stack(poses, axis=0) # [N, 4, 4] 62 | N = poses.shape[0] 63 | 64 | # the following stuff are from colmap2nerf... 65 | poses[:, 0:3, 1] *= -1 66 | poses[:, 0:3, 2] *= -1 67 | poses = poses[:, [1, 0, 2, 3], :] # swap y and z 68 | poses[:, 2, :] *= -1 # flip whole world upside down 69 | 70 | up = poses[:, 0:3, 1].sum(0) 71 | up = up / np.linalg.norm(up) 72 | R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] 73 | R = np.pad(R, [0, 1]) 74 | R[-1, -1] = 1 75 | 76 | poses = R @ poses 77 | 78 | totw = 0.0 79 | totp = np.array([0.0, 0.0, 0.0]) 80 | for i in range(N): 81 | mf = poses[i, :3, :] 82 | for j in range(i + 1, N): 83 | mg = poses[j, :3, :] 84 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 85 | #print(i, j, p, w) 86 | if w > 0.01: 87 | totp += p * w 88 | totw += w 89 | totp /= totw 90 | print(f'[INFO] totp = {totp}') 91 | poses[:, :3, 3] -= totp 92 | 93 | avglen = np.linalg.norm(poses[:, :3, 3], axis=-1).mean() 94 | 95 | poses[:, :3, 3] *= 4.0 / avglen 96 | 97 | print(f'[INFO] average radius = {avglen}') 98 | 99 | # process three splits 100 | for split, prefix in zip(['train', 'val', 'test'], ['0_', '1_', '2_']): 101 | 102 | print(f'[INFO] process split = {split}') 103 | 104 | split_poses = [poses[i] for i, x in enumerate(pose_files) if x.startswith(prefix)] 105 | split_images = [x for x in img_files if x.startswith(prefix)] 106 | 107 | if len(split_poses) == 0: 108 | print(f'[INFO] No test data found, use valid as test') 109 | split_poses = [poses[i] for i, x in enumerate(pose_files) if x.startswith('1_')] 110 | split_images = [x for x in img_files if x.startswith('1_')] 111 | 112 | print(f'[INFO] loaded {len(split_images)} images, {len(split_poses)} poses.') 113 | 114 | assert len(split_poses) == len(split_images) 115 | 116 | # construct a transforms.json 117 | frames = [] 118 | for image, pose in zip(split_images, split_poses): 119 | frames.append({ 120 | 'file_path': os.path.join('rgb', image), 121 | 'transform_matrix': pose.tolist(), 122 | }) 123 | 124 | transforms = { 125 | 'w': W, 126 | 'h': H, 127 | 'fl_x': fl_x, 128 | 'fl_y': fl_y, 129 | 'cx': cx, 130 | 'cy': cy, 131 | 'aabb_scale': 2, 132 | 'frames': frames, 133 | } 134 | 135 | # write 136 | output_path = os.path.join(opt.path, f'transforms_{split}.json') 137 | print(f'[INFO] write to {output_path}') 138 | with open(output_path, 'w') as f: 139 | json.dump(transforms, f, indent=2) 140 | 141 | -------------------------------------------------------------------------------- /sdf/netowrk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | 7 | 8 | class SDFNetwork(nn.Module): 9 | def __init__(self, 10 | encoding="hashgrid", 11 | num_layers=3, 12 | skips=[], 13 | hidden_dim=64, 14 | clip_sdf=None, 15 | ): 16 | super().__init__() 17 | 18 | 19 | self.num_layers = num_layers 20 | self.skips = skips 21 | self.hidden_dim = hidden_dim 22 | self.clip_sdf = clip_sdf 23 | 24 | self.encoder, self.in_dim = get_encoder(encoding) 25 | 26 | backbone = [] 27 | 28 | for l in range(num_layers): 29 | if l == 0: 30 | in_dim = self.in_dim 31 | elif l in self.skips: 32 | in_dim = self.hidden_dim + self.in_dim 33 | else: 
34 | in_dim = self.hidden_dim 35 | 36 | if l == num_layers - 1: 37 | out_dim = 1 38 | else: 39 | out_dim = self.hidden_dim 40 | 41 | backbone.append(nn.Linear(in_dim, out_dim, bias=False)) 42 | 43 | self.backbone = nn.ModuleList(backbone) 44 | 45 | 46 | def forward(self, x): 47 | # x: [B, 3] 48 | 49 | x = self.encoder(x) 50 | 51 | h = x 52 | for l in range(self.num_layers): 53 | if l in self.skips: 54 | h = torch.cat([h, x], dim=-1) 55 | h = self.backbone[l](h) 56 | if l != self.num_layers - 1: 57 | h = F.relu(h, inplace=True) 58 | 59 | if self.clip_sdf is not None: 60 | h = h.clamp(-self.clip_sdf, self.clip_sdf) 61 | 62 | return h -------------------------------------------------------------------------------- /sdf/netowrk_ff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from encoding import get_encoder 6 | from ffmlp import FFMLP 7 | 8 | 9 | class SDFNetwork(nn.Module): 10 | def __init__(self, 11 | encoding="hashgrid", 12 | num_layers=3, 13 | skips=[], 14 | hidden_dim=64, 15 | clip_sdf=None, 16 | ): 17 | super().__init__() 18 | 19 | 20 | self.num_layers = num_layers 21 | self.skips = skips 22 | self.hidden_dim = hidden_dim 23 | self.clip_sdf = clip_sdf 24 | 25 | assert self.skips == [], 'FFMLP does not support concatenating inside, please use skips=[].' 26 | 27 | self.encoder, self.in_dim = get_encoder(encoding) 28 | 29 | self.backbone = FFMLP( 30 | input_dim=self.in_dim, 31 | output_dim=1, 32 | hidden_dim=self.hidden_dim, 33 | num_layers=self.num_layers, 34 | ) 35 | 36 | 37 | def forward(self, x): 38 | # x: [B, 3] 39 | 40 | x = self.encoder(x) 41 | 42 | h = self.backbone(x) 43 | 44 | if self.clip_sdf is not None: 45 | h = h.clamp(-self.clip_sdf, self.clip_sdf) 46 | 47 | return h -------------------------------------------------------------------------------- /sdf/network_tcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import tinycudann as tcnn 6 | 7 | class SDFNetwork(nn.Module): 8 | def __init__(self, 9 | encoding="hashgrid", 10 | num_layers=3, 11 | skips=[], 12 | hidden_dim=64, 13 | clip_sdf=None, 14 | ): 15 | super().__init__() 16 | 17 | 18 | self.num_layers = num_layers 19 | self.skips = skips 20 | self.hidden_dim = hidden_dim 21 | self.clip_sdf = clip_sdf 22 | 23 | assert self.skips == [], 'TCNN does not support concatenating inside, please use skips=[].' 
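        # FullyFusedMLP runs all hidden layers at a fixed width inside a single fused CUDA kernel,
        # so there is nowhere to concatenate the encoder output back in mid-network; skip
        # connections would have to be handled outside the fused MLP (the same restriction as for
        # the FFMLP backbone above).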
24 | 25 | self.encoder = tcnn.Encoding( 26 | n_input_dims=3, 27 | encoding_config={ 28 | "otype": "HashGrid", 29 | "n_levels": 16, 30 | "n_features_per_level": 2, 31 | "log2_hashmap_size": 19, 32 | "base_resolution": 16, 33 | "per_level_scale": 1.3819, 34 | }, 35 | ) 36 | 37 | self.backbone = tcnn.Network( 38 | n_input_dims=32, 39 | n_output_dims=1, 40 | network_config={ 41 | "otype": "FullyFusedMLP", 42 | "activation": "ReLU", 43 | "output_activation": "None", 44 | "n_neurons": hidden_dim, 45 | "n_hidden_layers": num_layers - 1, 46 | }, 47 | ) 48 | 49 | 50 | def forward(self, x): 51 | # x: [B, 3] 52 | 53 | x = (x + 1) / 2 # to [0, 1] 54 | x = self.encoder(x) 55 | h = self.backbone(x) 56 | 57 | if self.clip_sdf is not None: 58 | h = h.clamp(-self.clip_sdf, self.clip_sdf) 59 | 60 | return h -------------------------------------------------------------------------------- /sdf/provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | import trimesh 7 | import pysdf 8 | 9 | def map_color(value, cmap_name='viridis', vmin=None, vmax=None): 10 | # value: [N], float 11 | # return: RGB, [N, 3], float in [0, 1] 12 | import matplotlib.cm as cm 13 | if vmin is None: vmin = value.min() 14 | if vmax is None: vmax = value.max() 15 | value = (value - vmin) / (vmax - vmin) # range in [0, 1] 16 | cmap = cm.get_cmap(cmap_name) 17 | rgb = cmap(value)[:, :3] # will return rgba, we take only first 3 so we get rgb 18 | return rgb 19 | 20 | def plot_pointcloud(pc, sdfs): 21 | # pc: [N, 3] 22 | # sdfs: [N, 1] 23 | color = map_color(sdfs.squeeze(1)) 24 | pc = trimesh.PointCloud(pc, color) 25 | trimesh.Scene([pc]).show() 26 | 27 | # SDF dataset 28 | class SDFDataset(Dataset): 29 | def __init__(self, path, size=100, num_samples=2**18, clip_sdf=None): 30 | super().__init__() 31 | self.path = path 32 | 33 | # load obj 34 | self.mesh = trimesh.load(path, force='mesh') 35 | 36 | # normalize to [-1, 1] (different from instant-sdf where is [0, 1]) 37 | vs = self.mesh.vertices 38 | vmin = vs.min(0) 39 | vmax = vs.max(0) 40 | v_center = (vmin + vmax) / 2 41 | v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95 42 | vs = (vs - v_center[None, :]) * v_scale 43 | self.mesh.vertices = vs 44 | 45 | print(f"[INFO] mesh: {self.mesh.vertices.shape} {self.mesh.faces.shape}") 46 | 47 | if not self.mesh.is_watertight: 48 | print(f"[WARN] mesh is not watertight! SDF maybe incorrect.") 49 | #trimesh.Scene([self.mesh]).show() 50 | 51 | self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces) 52 | 53 | self.num_samples = num_samples 54 | assert self.num_samples % 8 == 0, "num_samples must be divisible by 8." 
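        # The divisibility by 8 matches the online sampling in __getitem__: 7/8 of each batch is
        # drawn on the mesh surface and 1/8 uniformly in [-1, 1]^3; surface points past the halfway
        # mark of the batch (3/8 of all samples) get Gaussian noise of std 0.01, and SDF values are
        # only queried via pysdf for that second half, while the exact on-surface first half keeps
        # sdf = 0. With the default num_samples = 2**18 this is 229376 surface + 32768 uniform points.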
55 | self.clip_sdf = clip_sdf 56 | 57 | self.size = size 58 | 59 | 60 | def __len__(self): 61 | return self.size 62 | 63 | def __getitem__(self, _): 64 | 65 | # online sampling 66 | sdfs = np.zeros((self.num_samples, 1)) 67 | # surface 68 | points_surface = self.mesh.sample(self.num_samples * 7 // 8) 69 | # perturb surface 70 | points_surface[self.num_samples // 2:] += 0.01 * np.random.randn(self.num_samples * 3 // 8, 3) 71 | # random 72 | points_uniform = np.random.rand(self.num_samples // 8, 3) * 2 - 1 73 | points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32) 74 | 75 | sdfs[self.num_samples // 2:] = -self.sdf_fn(points[self.num_samples // 2:])[:,None].astype(np.float32) 76 | 77 | # clip sdf 78 | if self.clip_sdf is not None: 79 | sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf) 80 | 81 | results = { 82 | 'sdfs': sdfs, 83 | 'points': points, 84 | } 85 | 86 | #plot_pointcloud(points, sdfs) 87 | 88 | return results 89 | -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 
25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_sh_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'shencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='shencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_shencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'shencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = None 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | 
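        # dy_dx, when requested, caches the per-sample derivatives of the encoding w.r.t. its
        # inputs, flattened to [B, input_dim * degree**2]; backward() feeds it to
        # sh_encode_backward to produce grad_inputs, and returns None for the inputs when it was
        # never computed.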
ctx.dims = [B, input_dim, degree] 36 | 37 | return outputs 38 | 39 | @staticmethod 40 | #@once_differentiable 41 | @custom_bwd 42 | def backward(ctx, grad): 43 | # grad: [B, C * C] 44 | 45 | inputs, dy_dx = ctx.saved_tensors 46 | 47 | if dy_dx is not None: 48 | grad = grad.contiguous() 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 0 ~ 4 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only support input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 | 87 | return outputs -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx); 10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /tensoRF/network_cp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | from encoding import get_encoder 8 | from activation import trunc_exp 9 | from nerf.renderer import NeRFRenderer 10 | 11 | import raymarching 12 | 13 | 14 | class NeRFNetwork(NeRFRenderer): 15 | def __init__(self, 16 | resolution=[128] * 3, 17 | sigma_rank=[96] * 3, # ref: https://github.com/apchenstu/TensoRF/commit/7f505875a9f321fa8439a8d5c6a15fc7d2f17303 18 | color_rank=[288] * 3, 19 | color_feat_dim=27, 20 | num_layers=3, 21 | hidden_dim=128, 22 | bound=1, 23 | **kwargs 24 | ): 25 | super().__init__(bound, **kwargs) 26 | 27 | self.resolution = resolution 28 | 29 | # vector-matrix decomposition 30 | 
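        # In this CP variant only the per-axis line factors are used: the density feature is the
        # sum over ranks of the product of the three 1-D factors sampled along x, y and z
        # (get_sigma_feat), while the color feature applies the same rank products followed by the
        # small basis_mat linear layer that mixes them into a 27-dim feature (get_color_feat).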
self.sigma_rank = sigma_rank 31 | self.color_rank = color_rank 32 | self.color_feat_dim = color_feat_dim 33 | 34 | self.mat_ids = [[0, 1], [0, 2], [1, 2]] 35 | self.vec_ids = [2, 1, 0] 36 | 37 | self.sigma_vec = self.init_one_svd(self.sigma_rank, self.resolution) 38 | self.color_vec = self.init_one_svd(self.color_rank, self.resolution) 39 | self.basis_mat = nn.Linear(self.color_rank[0], self.color_feat_dim, bias=False) 40 | 41 | # render module (default to freq feat + freq dir) 42 | self.num_layers = num_layers 43 | self.hidden_dim = hidden_dim 44 | 45 | self.encoder, enc_dim = get_encoder('frequency', input_dim=color_feat_dim, multires=2) 46 | self.encoder_dir, enc_dim_dir = get_encoder('frequency', input_dim=3, multires=2) 47 | 48 | self.in_dim = enc_dim + enc_dim_dir 49 | 50 | color_net = [] 51 | for l in range(num_layers): 52 | if l == 0: 53 | in_dim = self.in_dim 54 | else: 55 | in_dim = self.hidden_dim 56 | 57 | if l == num_layers - 1: 58 | out_dim = 3 # rgb 59 | else: 60 | out_dim = self.hidden_dim 61 | 62 | color_net.append(nn.Linear(in_dim, out_dim, bias=False)) 63 | 64 | self.color_net = nn.ModuleList(color_net) 65 | 66 | 67 | def init_one_svd(self, n_component, resolution, scale=0.2): 68 | 69 | vec = [] 70 | 71 | for i in range(len(self.vec_ids)): 72 | vec_id = self.vec_ids[i] 73 | vec.append(torch.nn.Parameter(scale * torch.randn((1, n_component[i], resolution[vec_id], 1)))) # [1, R, D, 1] (fake 2d to use grid_sample) 74 | 75 | return torch.nn.ParameterList(vec) 76 | 77 | 78 | def get_sigma_feat(self, x): 79 | # x: [N, 3], in [-1, 1] 80 | 81 | N = x.shape[0] 82 | 83 | # line basis 84 | vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) 85 | vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord 86 | 87 | vec_feat = F.grid_sample(self.sigma_vec[0], vec_coord[[0]], align_corners=True).view(-1, N) * \ 88 | F.grid_sample(self.sigma_vec[1], vec_coord[[1]], align_corners=True).view(-1, N) * \ 89 | F.grid_sample(self.sigma_vec[2], vec_coord[[2]], align_corners=True).view(-1, N) # [R, N] 90 | 91 | sigma_feat = torch.sum(vec_feat, dim=0) 92 | 93 | return sigma_feat 94 | 95 | 96 | def get_color_feat(self, x): 97 | # x: [N, 3], in [-1, 1] 98 | 99 | N = x.shape[0] 100 | 101 | # line basis 102 | vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])) 103 | vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord 104 | 105 | vec_feat = F.grid_sample(self.color_vec[0], vec_coord[[0]], align_corners=True).view(-1, N) * \ 106 | F.grid_sample(self.color_vec[1], vec_coord[[1]], align_corners=True).view(-1, N) * \ 107 | F.grid_sample(self.color_vec[2], vec_coord[[2]], align_corners=True).view(-1, N) # [R, N] 108 | 109 | color_feat = self.basis_mat(vec_feat.T) # [N, R] --> [N, color_feat_dim] 110 | 111 | return color_feat 112 | 113 | 114 | def forward(self, x, d): 115 | # x: [N, 3], in [-bound, bound] 116 | # d: [N, 3], nomalized in [-1, 1] 117 | 118 | # normalize to [-1, 1] inside aabb_train 119 | x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 120 | 121 | # sigma 122 | sigma_feat = self.get_sigma_feat(x) 123 | sigma = trunc_exp(sigma_feat) 124 | 125 | # rgb 126 | color_feat = self.get_color_feat(x) 127 | enc_color_feat = self.encoder(color_feat) 128 | enc_d = self.encoder_dir(d) 129 | 130 | h = torch.cat([enc_color_feat, enc_d], dim=-1) 
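        # (note) a small bias-free MLP (Linear + ReLU, with a sigmoid applied after the last layer below)
        # maps the frequency-encoded appearance feature concatenated with the encoded view direction to RGB.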
131 | for l in range(self.num_layers): 132 | h = self.color_net[l](h) 133 | if l != self.num_layers - 1: 134 | h = F.relu(h, inplace=True) 135 | 136 | # sigmoid activation for rgb 137 | rgb = torch.sigmoid(h) 138 | 139 | return sigma, rgb 140 | 141 | 142 | def density(self, x): 143 | # x: [N, 3], in [-bound, bound] 144 | 145 | # normalize to [-1, 1] inside aabb_train 146 | x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 147 | 148 | sigma_feat = self.get_sigma_feat(x) 149 | sigma = trunc_exp(sigma_feat) 150 | 151 | return { 152 | 'sigma': sigma, 153 | } 154 | 155 | # allow masked inference 156 | def color(self, x, d, mask=None, **kwargs): 157 | # x: [N, 3] in [-bound, bound] 158 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 159 | 160 | # normalize to [-1, 1] inside aabb_train 161 | x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 162 | 163 | if mask is not None: 164 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] 165 | # in case of empty mask 166 | if not mask.any(): 167 | return rgbs 168 | x = x[mask] 169 | d = d[mask] 170 | 171 | color_feat = self.get_color_feat(x) 172 | color_feat = self.encoder(color_feat) 173 | d = self.encoder_dir(d) 174 | 175 | h = torch.cat([color_feat, d], dim=-1) 176 | for l in range(self.num_layers): 177 | h = self.color_net[l](h) 178 | if l != self.num_layers - 1: 179 | h = F.relu(h, inplace=True) 180 | 181 | # sigmoid activation for rgb 182 | h = torch.sigmoid(h) 183 | 184 | if mask is not None: 185 | rgbs[mask] = h.to(rgbs.dtype) 186 | else: 187 | rgbs = h 188 | 189 | return rgbs 190 | 191 | 192 | # L1 penalty for loss 193 | def density_loss(self): 194 | loss = 0 195 | for i in range(len(self.sigma_vec)): 196 | loss = loss + torch.mean(torch.abs(self.sigma_vec[i])) 197 | return loss 198 | 199 | # upsample utils 200 | @torch.no_grad() 201 | def upsample_params(self, vec, resolution): 202 | 203 | for i in range(len(self.vec_ids)): 204 | vec_id = self.vec_ids[i] 205 | vec[i] = torch.nn.Parameter(F.interpolate(vec[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=True)) 206 | 207 | 208 | @torch.no_grad() 209 | def upsample_model(self, resolution): 210 | self.upsample_params(self.sigma_vec, resolution) 211 | self.upsample_params(self.color_vec, resolution) 212 | self.resolution = resolution 213 | 214 | @torch.no_grad() 215 | def shrink_model(self): 216 | 217 | half_grid_size = self.bound / self.grid_size 218 | thresh = min(self.density_thresh, self.mean_density) 219 | 220 | # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?) 221 | valid_grid = self.density_grid[self.cascade - 1] > thresh # [N] 222 | valid_pos = raymarching.morton3D_invert(torch.nonzero(valid_grid)) # [Nz] --> [Nz, 3], in [0, H - 1] 223 | 224 | #plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf... 
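        # (note) below, the integer cell indices (in [0, grid_size - 1]) are mapped back to world-space
        # cell centres inside the current bound, and a tight axis-aligned box around the occupied cells
        # is taken, padded by half a grid cell on each side.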
225 | valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (self.bound - half_grid_size) # [Nz, 3], in [-b+hgs, b-hgs] 226 | min_pos = valid_pos.amin(0) - half_grid_size # [3] 227 | max_pos = valid_pos.amax(0) + half_grid_size # [3] 228 | 229 | # shrink model 230 | reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) 231 | units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso 232 | tl = (min_pos - self.aabb_train[:3]) / units 233 | br = (max_pos - self.aabb_train[:3]) / units 234 | tl = torch.round(tl).long().clamp(min=0) 235 | br = torch.minimum(torch.round(br).long(), reso) 236 | 237 | for i in range(len(self.vec_ids)): 238 | vec_id = self.vec_ids[i] 239 | 240 | self.sigma_vec[i] = nn.Parameter(self.sigma_vec[i].data[..., tl[vec_id]:br[vec_id], :]) 241 | self.color_vec[i] = nn.Parameter(self.color_vec[i].data[..., tl[vec_id]:br[vec_id], :]) 242 | 243 | self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] 244 | 245 | print(f'[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}') 246 | print(f'[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}') 247 | 248 | # optimizer utils 249 | def get_params(self, lr1, lr2): 250 | return [ 251 | {'params': self.sigma_vec, 'lr': lr1}, 252 | {'params': self.color_vec, 'lr': lr1}, 253 | {'params': self.basis_mat.parameters(), 'lr': lr2}, 254 | {'params': self.color_net.parameters(), 'lr': lr2}, 255 | ] 256 | -------------------------------------------------------------------------------- /testing/test_ffmlp.py: -------------------------------------------------------------------------------- 1 | from matplotlib.animation import AVConvBase 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ffmlp import FFMLP 7 | import math 8 | 9 | import tinycudann as tcnn 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, input_dim, output_dim, hidden_dim, num_layers, activation=F.relu): 13 | super().__init__() 14 | 15 | self.num_layers = num_layers 16 | self.hidden_dim = hidden_dim 17 | self.activation = activation 18 | 19 | self.net = nn.ModuleList() 20 | self.net.append(nn.Linear(input_dim, hidden_dim, bias=False)) 21 | for i in range(num_layers - 1): 22 | self.net.append(nn.Linear(hidden_dim, hidden_dim, bias=False)) 23 | self.net.append(nn.Linear(hidden_dim, output_dim, bias=False)) 24 | 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | torch.manual_seed(42) 29 | for p in self.parameters(): 30 | #nn.init.constant_(p.data, 1) 31 | std = math.sqrt(3 / self.hidden_dim) 32 | p.data.uniform_(-std, std) 33 | #torch.manual_seed(42) 34 | #nn.init.uniform_(p.data, 0, 1) 35 | #nn.init.eye_(p.data) 36 | 37 | 38 | def forward(self, x): 39 | for i in range(self.num_layers + 1): 40 | x = self.net[i](x) 41 | if i != self.num_layers: 42 | x = self.activation(x) 43 | return x 44 | 45 | # ################################## 46 | # # Functionality 47 | # ################################## 48 | 49 | # BATCH_SIZE = 1280000 # 1048576 # 128 # the least batch to lauch a full block ! 50 | # INPUT_DIM = 16 # 16 # != (16 * m) has bug... 51 | # OUTPUT_DIM = 1 # 16 # > 16 still has bug... 
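# (note, assumption) FFMLP evaluates each layer as fused 16x16 TensorCore fragments, so INPUT_DIM and
# OUTPUT_DIM are expected to be padded to multiples of 16 (and HIDDEN_DIM to a supported width such as
# 64); the bugs noted above are hit when that padding assumption is violated.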
52 | # HIDDEN_DIM = 64 # 16 53 | # NUM_LAYERS = 3 # 2 54 | 55 | 56 | # net0 = FFMLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() 57 | # net1 = MLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() 58 | 59 | # # print(net0.weights) 60 | # # print(net1.net[0].weight) 61 | 62 | # for _ in range(5): 63 | 64 | # x0 = torch.randn(BATCH_SIZE, INPUT_DIM).cuda() * 1 65 | # x1 = x0.detach().clone() 66 | # x0.requires_grad_(True) 67 | # x1.requires_grad_(True) 68 | 69 | # # print('===== x =====') 70 | # # print(x0) 71 | # # print(x1) 72 | 73 | # with torch.cuda.amp.autocast(enabled=True): 74 | # y1 = net1(x1) 75 | # y0 = net0(x0) 76 | 77 | 78 | # print('===== y1 =====') 79 | # print(y1) 80 | 81 | # print('===== y0 =====') 82 | # print(y0) 83 | 84 | # (y1.sum() * 1).backward() 85 | # print('===== grad w1 =====') 86 | # print(net1.net[0].weight.grad.dtype, torch.cat([net1.net[0].weight.grad.view(-1), net1.net[1].weight.grad.view(-1), net1.net[2].weight.grad.view(-1)], dim=0)) 87 | # print(x1.grad.dtype, x1.grad) 88 | 89 | # (y0.sum() * 1).backward() 90 | # print('===== grad w0 =====') 91 | # print(net0.weights.grad.dtype, net0.weights.grad) 92 | # print(x0.grad.dtype, x0.grad) 93 | 94 | 95 | 96 | # ################################## 97 | # # Speed 98 | # ################################## 99 | 100 | BATCH_SIZE = 2**21 101 | INPUT_DIM = 16 102 | OUTPUT_DIM = 16 103 | HIDDEN_DIM = 64 104 | NUM_LAYERS = 2 105 | 106 | net0 = FFMLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() 107 | net1 = MLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() 108 | net2 = tcnn.Network(n_input_dims=INPUT_DIM, n_output_dims=OUTPUT_DIM, network_config={ 109 | "otype": "FullyFusedMLP", 110 | "activation": "ReLU", 111 | "output_activation": "None", 112 | "n_neurons": HIDDEN_DIM, 113 | "n_hidden_layers": NUM_LAYERS, 114 | }) 115 | 116 | x = torch.rand(BATCH_SIZE, INPUT_DIM).cuda() * 10 117 | x1 = x.detach().clone() 118 | x2 = x.detach().clone() 119 | x3 = x.detach().clone() 120 | 121 | 122 | 123 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 124 | 125 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 126 | starter.record() 127 | y2 = net1(x2) 128 | ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'pytorch MLP (fp32 train) = {curr_time}') 129 | 130 | starter.record() 131 | y2.sum().backward() 132 | ender.record() 133 | torch.cuda.synchronize() 134 | curr_time = starter.elapsed_time(ender) 135 | print(f'pytorch MLP (fp32 back) = {curr_time}') 136 | 137 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 138 | 139 | with torch.cuda.amp.autocast(enabled=True): 140 | 141 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 142 | starter.record() 143 | y0 = net0(x) 144 | ender.record() 145 | torch.cuda.synchronize() 146 | curr_time = starter.elapsed_time(ender) 147 | print(f'FFMLP (forward) = {curr_time}') 148 | 149 | starter.record() 150 | y0.sum().backward() 151 | ender.record() 152 | torch.cuda.synchronize() 153 | curr_time = starter.elapsed_time(ender) 154 | print(f'FFMLP (backward) = {curr_time}') 155 | 156 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 157 | 158 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 159 | starter.record() 160 | y1 
= net1(x1) 161 | ender.record() 162 | torch.cuda.synchronize() 163 | curr_time = starter.elapsed_time(ender) 164 | print(f'pytorch MLP (forward) = {curr_time}') 165 | 166 | starter.record() 167 | y1.sum().backward() 168 | ender.record() 169 | torch.cuda.synchronize() 170 | curr_time = starter.elapsed_time(ender) 171 | print(f'pytorch MLP (backward) = {curr_time}') 172 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 173 | 174 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 175 | starter.record() 176 | y3 = net2(x3) 177 | ender.record() 178 | torch.cuda.synchronize() 179 | curr_time = starter.elapsed_time(ender) 180 | print(f'TCNN (forward) = {curr_time}') 181 | 182 | starter.record() 183 | y3.sum().backward() 184 | ender.record() 185 | torch.cuda.synchronize() 186 | curr_time = starter.elapsed_time(ender) 187 | print(f'TCNN (backward) = {curr_time}') 188 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 189 | 190 | with torch.no_grad(): 191 | 192 | starter.record() 193 | y1 = net1(x) 194 | ender.record() 195 | torch.cuda.synchronize() 196 | curr_time = starter.elapsed_time(ender) 197 | print(f'pytorch MLP (fp32 infer) = {curr_time}') 198 | 199 | with torch.cuda.amp.autocast(enabled=True): 200 | 201 | 202 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 203 | 204 | starter.record() 205 | y0 = net0(x) 206 | ender.record() 207 | torch.cuda.synchronize() 208 | curr_time = starter.elapsed_time(ender) 209 | print(f'FFMLP (infer) = {curr_time}') 210 | 211 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 212 | 213 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 214 | 215 | starter.record() 216 | y1 = net1(x) 217 | ender.record() 218 | torch.cuda.synchronize() 219 | curr_time = starter.elapsed_time(ender) 220 | print(f'pytorch MLP (infer) = {curr_time}') 221 | 222 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 223 | 224 | #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: 225 | 226 | starter.record() 227 | y2 = net2(x) 228 | ender.record() 229 | torch.cuda.synchronize() 230 | curr_time = starter.elapsed_time(ender) 231 | print(f'TCNN (infer) = {curr_time}') 232 | 233 | #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) 234 | 235 | 236 | # print(y0) 237 | # print(y1) 238 | -------------------------------------------------------------------------------- /testing/test_hashencoder.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from gridencoder import GridEncoder 5 | 6 | B = 1 7 | D = 2 8 | 9 | enc = GridEncoder(D=D, L=2, C=1, base_resolution=4, log2_hashmap_size=5).cuda() 10 | #enc = GridEncoder(D=D, L=16, C=2, base_resolution=16).cuda() 11 | 12 | print(f"=== enc ===") 13 | print(enc.embeddings.shape) 14 | print(enc.embeddings) 15 | 16 | #x = torch.rand(B, D).cuda() * 2 - 1 # in [-1, 1] 17 | x = torch.FloatTensor(np.array([ 18 | #[-1, 1], 19 | #[1, 1], 20 | [0, 0], 21 | #[-1, -1], 22 | #[1, -1], 23 | ])).cuda() 24 | 25 | #x.requires_grad_(True) 26 | 27 | print(f"=== x ===") 28 | print(x) 29 | print(x.shape) 30 | 31 | y = enc(x, calc_grad_inputs=False) 32 | 33 | print(f"=== y 
===") 34 | print(y.shape) 35 | print(y) 36 | 37 | y.sum().backward() 38 | 39 | print(f"=== grad enc ===") 40 | print(enc.embeddings.grad.shape) 41 | print(enc.embeddings.grad) 42 | 43 | #print(x.grad.shape) 44 | #print(x.grad) -------------------------------------------------------------------------------- /testing/test_hashgrid_grad.py: -------------------------------------------------------------------------------- 1 | # we need check the grad_hash_grid; 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import gradcheck 5 | import numpy as np 6 | from gridencoder.grid import _grid_encode 7 | import random 8 | import os 9 | # import torch.random as random 10 | device=torch.device(0) 11 | input_dim=3 # 2 12 | num_levels=4 # 1 13 | level_dim=2 # 1 14 | per_level_scale=2 15 | base_resolution=4 # 2 16 | log2_hashmap_size=8 # 4 17 | # inputs , embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False 18 | 19 | output_dim = num_levels * level_dim 20 | 21 | if level_dim % 2 != 0: 22 | print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward is also enabled fp16! (maybe fix later)') 23 | 24 | # allocate parameters 25 | offsets = [] 26 | offset = 0 27 | max_params = 2 ** log2_hashmap_size 28 | for i in range(num_levels): 29 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 30 | params_in_level = min(max_params, (resolution + 1) ** input_dim) # limit max number 31 | #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible 32 | offsets.append(offset) 33 | offset += params_in_level 34 | offsets.append(offset) 35 | 36 | print(offsets) 37 | 38 | def seed_torch(seed=42): 39 | random.seed(seed) 40 | os.environ['PYTHONHASHSEED'] = str(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed(seed) 44 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
45 | torch.backends.cudnn.benchmark = False 46 | torch.backends.cudnn.deterministic = True 47 | 48 | #seed_torch() 49 | 50 | # parameters 51 | inputs = torch.rand(1, input_dim, dtype= torch.float64, requires_grad=False).to(device) 52 | 53 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)).to(device) 54 | embeddings = torch.randn(offset, level_dim, dtype=torch.float64, requires_grad=True).to(device) * 0.1 55 | 56 | print(inputs) 57 | print(embeddings) 58 | 59 | 60 | Inputs = (inputs, embeddings, offsets, per_level_scale, base_resolution, inputs.requires_grad) 61 | check_results1 = torch.autograd.gradcheck(_grid_encode.apply, Inputs, eps=1e-2, atol=1e-3, rtol=0.01, fast_mode=False) 62 | print("check_results1", check_results1) 63 | -------------------------------------------------------------------------------- /testing/test_raymarching.py: -------------------------------------------------------------------------------- 1 | import raymarching -------------------------------------------------------------------------------- /testing/test_shencoder.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from shencoder import SHEncoder 6 | 7 | 8 | class SHEncoder_torch(nn.Module): 9 | def __init__(self, input_dim=3, degree=4): 10 | 11 | super().__init__() 12 | 13 | self.input_dim = input_dim 14 | self.degree = degree 15 | 16 | assert self.input_dim == 3 17 | assert self.degree >= 1 and self.degree <= 5 18 | 19 | self.output_dim = degree ** 2 20 | 21 | self.C0 = 0.28209479177387814 22 | self.C1 = 0.4886025119029199 23 | self.C2 = [ 24 | 1.0925484305920792, 25 | -1.0925484305920792, 26 | 0.31539156525252005, 27 | -1.0925484305920792, 28 | 0.5462742152960396 29 | ] 30 | self.C3 = [ 31 | -0.5900435899266435, 32 | 2.890611442640554, 33 | -0.4570457994644658, 34 | 0.3731763325901154, 35 | -0.4570457994644658, 36 | 1.445305721320277, 37 | -0.5900435899266435 38 | ] 39 | self.C4 = [ 40 | 2.5033429417967046, 41 | -1.7701307697799304, 42 | 0.9461746957575601, 43 | -0.6690465435572892, 44 | 0.10578554691520431, 45 | -0.6690465435572892, 46 | 0.47308734787878004, 47 | -1.7701307697799304, 48 | 0.6258357354491761 49 | ] 50 | 51 | def forward(self, input, **kwargs): 52 | 53 | result = torch.empty((*input.shape[:-1], self.output_dim), dtype=input.dtype, device=input.device) 54 | x, y, z = input.unbind(-1) 55 | 56 | result[..., 0] = self.C0 57 | if self.degree > 1: 58 | result[..., 1] = -self.C1 * y 59 | result[..., 2] = self.C1 * z 60 | result[..., 3] = -self.C1 * x 61 | if self.degree > 2: 62 | xx, yy, zz = x * x, y * y, z * z 63 | xy, yz, xz = x * y, y * z, x * z 64 | result[..., 4] = self.C2[0] * xy 65 | result[..., 5] = self.C2[1] * yz 66 | #result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy) 67 | result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting... 
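            # (note) the two formulas for this band agree only for unit-norm inputs (xx + yy + zz == 1);
            # substituting the constraint before differentiating changes the gradient, which is why the
            # backward results differ between the two variants.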
68 | result[..., 7] = self.C2[3] * xz 69 | result[..., 8] = self.C2[4] * (xx - yy) 70 | if self.degree > 3: 71 | result[..., 9] = self.C3[0] * y * (3 * xx - yy) 72 | result[..., 10] = self.C3[1] * xy * z 73 | result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy) 74 | result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy) 75 | result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy) 76 | result[..., 14] = self.C3[5] * z * (xx - yy) 77 | result[..., 15] = self.C3[6] * x * (xx - 3 * yy) 78 | if self.degree > 4: 79 | result[..., 16] = self.C4[0] * xy * (xx - yy) 80 | result[..., 17] = self.C4[1] * yz * (3 * xx - yy) 81 | result[..., 18] = self.C4[2] * xy * (7 * zz - 1) 82 | result[..., 19] = self.C4[3] * yz * (7 * zz - 3) 83 | result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3) 84 | result[..., 21] = self.C4[5] * xz * (7 * zz - 3) 85 | result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1) 86 | result[..., 23] = self.C4[7] * xz * (xx - 3 * yy) 87 | result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) 88 | 89 | return result 90 | 91 | B = 25600 92 | C = 3 93 | degree = 4 94 | 95 | enc1 = SHEncoder_torch(degree=degree).cuda() 96 | enc2 = SHEncoder(degree=degree).cuda() 97 | 98 | x1 = torch.rand(B, 3).cuda() * 2 - 1 # in [-1, 1] 99 | x1 = x1 / (torch.norm(x1, dim=-1, keepdim=True) + 1e-8) 100 | x1.requires_grad_(True) 101 | 102 | x2 = x1.detach().clone() 103 | x2.requires_grad_(True) 104 | 105 | print(f"=== x ===") 106 | print(x1) 107 | 108 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 109 | 110 | with torch.no_grad(): 111 | with torch.cuda.amp.autocast(enabled=True): 112 | 113 | starter.record() 114 | y1 = enc1(x1) 115 | ender.record() 116 | torch.cuda.synchronize() 117 | curr_time = starter.elapsed_time(ender) 118 | print(f'time 1 = {curr_time}') 119 | 120 | starter.record() 121 | y2 = enc2(x2) 122 | ender.record() 123 | torch.cuda.synchronize() 124 | curr_time = starter.elapsed_time(ender) 125 | print(f'time 2 = {curr_time}') 126 | 127 | print(f"=== y ===") 128 | print(y1) 129 | print(y2) 130 | 131 | # starter.record() 132 | # y1.sum().backward() 133 | # ender.record() 134 | # torch.cuda.synchronize() 135 | # curr_time = starter.elapsed_time(ender) 136 | # print(f'time 1 (back) = {curr_time}') 137 | 138 | # starter.record() 139 | # y2.sum().backward() 140 | # ender.record() 141 | # torch.cuda.synchronize() 142 | # curr_time = starter.elapsed_time(ender) 143 | # print(f'time 2 (back) = {curr_time}') 144 | 145 | # print(f"=== grad x ===") 146 | # print(x1.grad) 147 | # print(x2.grad) -------------------------------------------------------------------------------- /utils_img.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
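# Batched, mostly differentiable image transforms (crop, resize, rotate, color jitter, gamma,
# sharpness, text overlay, JPEG compression) plus a PSNR helper, all operating on normalized tensors.
#
# Minimal usage sketch (assumes a batch `imgs` of shape [B, 3, H, W] in ImageNet-normalized 'img' space;
# the angle and quality values below are arbitrary examples):
#   imgs_aug = jpeg_compress(rotate(imgs, angle=25), quality_factor=50)
#   print(psnr(imgs_aug, imgs, img_space='img'))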
6 | 7 | # pyright: reportMissingModuleSource=false 8 | 9 | import numpy as np 10 | from augly.image import functional as aug_functional 11 | import torch 12 | from torchvision import transforms 13 | from torchvision.transforms import functional 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | default_transform = transforms.Compose([ 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 20 | ]) 21 | 22 | normalize_vqgan = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize (x - 0.5) / 0.5 23 | unnormalize_vqgan = transforms.Normalize(mean=[-1, -1, -1], std=[1/0.5, 1/0.5, 1/0.5]) # Unnormalize (x * 0.5) + 0.5 24 | normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize (x - mean) / std 25 | unnormalize_img = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]) # Unnormalize (x * std) + mean 26 | 27 | def psnr(x, y, img_space='vqgan'): 28 | """ 29 | Return PSNR 30 | Args: 31 | x: Image tensor with values approx. between [-1,1] 32 | y: Image tensor with values approx. between [-1,1], ex: original image 33 | """ 34 | if img_space == 'vqgan': 35 | delta = torch.clamp(unnormalize_vqgan(x), 0, 1) - torch.clamp(unnormalize_vqgan(y), 0, 1) 36 | elif img_space == 'img': 37 | delta = torch.clamp(unnormalize_img(x), 0, 1) - torch.clamp(unnormalize_img(y), 0, 1) 38 | else: 39 | delta = x - y 40 | delta = 255 * delta 41 | delta = delta.reshape(-1, x.shape[-3], x.shape[-2], x.shape[-1]) # BxCxHxW 42 | psnr = 20*np.log10(255) - 10*torch.log10(torch.mean(delta**2, dim=(1,2,3))) # B 43 | return psnr 44 | 45 | def center_crop(x, scale): 46 | """ Perform center crop such that the target area of the crop is at a given scale 47 | Args: 48 | x: PIL image 49 | scale: target area scale 50 | """ 51 | scale = np.sqrt(scale) 52 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] 53 | return functional.center_crop(x, new_edges_size) 54 | 55 | def resize(x, scale): 56 | """ Perform center crop such that the target area of the crop is at a given scale 57 | Args: 58 | x: PIL image 59 | scale: target area scale 60 | """ 61 | scale = np.sqrt(scale) 62 | new_edges_size = [int(s*scale) for s in x.shape[-2:]][::-1] 63 | return functional.resize(x, new_edges_size) 64 | 65 | def rotate(x, angle): 66 | """ Rotate image by angle 67 | Args: 68 | x: image (PIl or tensor) 69 | angle: angle in degrees 70 | """ 71 | return functional.rotate(x, angle) 72 | 73 | def adjust_brightness(x, brightness_factor): 74 | """ Adjust brightness of an image 75 | Args: 76 | x: PIL image 77 | brightness_factor: brightness factor 78 | """ 79 | return normalize_img(functional.adjust_brightness(unnormalize_img(x), brightness_factor)) 80 | 81 | def adjust_contrast(x, contrast_factor): 82 | """ Adjust contrast of an image 83 | Args: 84 | x: PIL image 85 | contrast_factor: contrast factor 86 | """ 87 | return normalize_img(functional.adjust_contrast(unnormalize_img(x), contrast_factor)) 88 | 89 | def adjust_saturation(x, saturation_factor): 90 | """ Adjust saturation of an image 91 | Args: 92 | x: PIL image 93 | saturation_factor: saturation factor 94 | """ 95 | return normalize_img(functional.adjust_saturation(unnormalize_img(x), saturation_factor)) 96 | 97 | def adjust_hue(x, hue_factor): 98 | """ Adjust hue of an image 99 | Args: 100 | x: PIL image 101 | hue_factor: hue factor 102 | """ 103 | return 
normalize_img(functional.adjust_hue(unnormalize_img(x), hue_factor)) 104 | 105 | def adjust_gamma(x, gamma, gain=1): 106 | """ Adjust gamma of an image 107 | Args: 108 | x: PIL image 109 | gamma: gamma factor 110 | gain: gain factor 111 | """ 112 | return normalize_img(functional.adjust_gamma(unnormalize_img(x), gamma, gain)) 113 | 114 | def adjust_sharpness(x, sharpness_factor): 115 | """ Adjust sharpness of an image 116 | Args: 117 | x: PIL image 118 | sharpness_factor: sharpness factor 119 | """ 120 | return normalize_img(functional.adjust_sharpness(unnormalize_img(x), sharpness_factor)) 121 | 122 | def overlay_text(x, text='Lorem Ipsum'): 123 | """ Overlay text on image 124 | Args: 125 | x: PIL image 126 | text: text to overlay 127 | font_path: path to font 128 | font_size: font size 129 | color: text color 130 | position: text position 131 | """ 132 | to_pil = transforms.ToPILImage() 133 | to_tensor = transforms.ToTensor() 134 | img_aug = torch.zeros_like(x, device=x.device) 135 | for ii,img in enumerate(x): 136 | pil_img = to_pil(unnormalize_img(img)) 137 | img_aug[ii] = to_tensor(aug_functional.overlay_text(pil_img, text=text)) 138 | return normalize_img(img_aug) 139 | 140 | def jpeg_compress(x, quality_factor): 141 | """ Apply jpeg compression to image 142 | Args: 143 | x: PIL image 144 | quality_factor: quality factor 145 | """ 146 | to_pil = transforms.ToPILImage() 147 | to_tensor = transforms.ToTensor() 148 | img_aug = torch.zeros_like(x, device=x.device) 149 | for ii,img in enumerate(x): 150 | pil_img = to_pil(unnormalize_img(img)) 151 | img_aug[ii] = to_tensor(aug_functional.encoding_quality(pil_img, quality=quality_factor)) 152 | return normalize_img(img_aug) 153 | --------------------------------------------------------------------------------