├── .gitignore ├── 2d.py ├── LICENSE ├── README.md ├── assets ├── group_photo.png ├── group_photo_fan.png ├── nopgc.png └── pgc.png ├── data ├── 0000.png ├── background.png └── panda_spear_shield.obj ├── encoding.py ├── freqencoder ├── __init__.py ├── backend.py ├── freq.py ├── setup.py └── src │ ├── bindings.cpp │ ├── freqencoder.cu │ └── freqencoder.h ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── guidance ├── clip_utils.py ├── guidance_utils.py ├── sd.py ├── sdcontrolnet.py ├── sdxl.py ├── sdxl_controlnet.py ├── sdxl_vsd.py └── zero123_utils.py ├── ldm ├── extras.py ├── guidance.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── evaluate │ │ ├── adm_evaluator.py │ │ ├── evaluate_perceptualsim.py │ │ ├── frechet_video_distance.py │ │ ├── ssim.py │ │ └── torch_frechet_video_distance.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py ├── thirdp │ └── psp │ │ ├── helpers.py │ │ ├── id_loss.py │ │ └── model_irse.py └── util.py ├── main.py ├── nerf ├── fine │ ├── dmtet.py │ ├── dmtet_utils.py │ ├── generate_tets.py │ ├── neural_render.py │ ├── perspective_camera.py │ ├── rast_utils.py │ ├── rasterizer.py │ ├── rasterizer_mesh.py │ ├── rasterizer_pbr.py │ ├── tets │ │ └── 256_compress.npz │ └── trainer.py ├── provider.py └── utils.py ├── others └── nvdiffrec │ ├── data │ └── irrmaps │ │ ├── README.txt │ │ ├── aerodynamics_workshop_2k.hdr │ │ ├── blocky_photo_studio_2k.hdr │ │ ├── bsdf_256_256.bin │ │ └── mud_road_puresky_4k.hdr │ └── render │ ├── light.py │ ├── material.py │ ├── mesh.py │ ├── mlptexture.py │ ├── obj.py │ ├── regularizer.py │ ├── render.py │ ├── renderutils │ ├── __init__.py │ ├── bsdf.py │ ├── c_src │ │ ├── bsdf.cu │ │ ├── bsdf.h │ │ ├── common.cpp │ │ ├── common.h │ │ ├── cubemap.cu │ │ ├── cubemap.h │ │ ├── loss.cu │ │ ├── loss.h │ │ ├── mesh.cu │ │ ├── mesh.h │ │ ├── normal.cu │ │ ├── normal.h │ │ ├── tensor.h │ │ ├── torch_bindings.cpp │ │ ├── vec3f.h │ │ └── vec4f.h │ ├── loss.py │ ├── ops.py │ └── tests │ │ ├── test_bsdf.py │ │ ├── test_loss.py │ │ ├── test_mesh.py │ │ └── test_perf.py │ ├── texture.py │ └── util.py ├── raymarching ├── __init__.py ├── backend.py ├── raymarching.py ├── setup.py └── src │ ├── bindings.cpp │ ├── raymarching.cu │ └── raymarching.h ├── requirements.txt ├── scripts ├── blender.py ├── install_ext.sh └── linear_app.py └── shencoder ├── __init__.py ├── backend.py ├── setup.py ├── sphere_harmonics.py └── src ├── bindings.cpp ├── shencoder.cu └── shencoder.h /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | *.egg-info/ 4 | *.so 5 | 6 | tmp* 7 | trial*/ 8 | train*/ 9 | .vs/ 10 | run.sh 11 | TOKEN 12 | exp/ 13 | *.pth 14 | *.out 15 | .idea 16 | 2d 17 | results/ 18 | scripts/debug.py 19 | scripts/run*.sh 20 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping 2 | ### [[Paper]](https://arxiv.org/abs/2310.12474) | [[Project]](https://fudan-zvg.github.io/PGC-3D/) 3 | 4 | > [**Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping**](https://arxiv.org/abs/2310.12474), 5 | > Zijie Pan, [Jiachen Lu](https://victorllu.github.io/), [Xiatian Zhu](https://surrey-uplab.github.io/), [Li Zhang](https://lzrobots.github.io) 6 | > **ICLR 2024** 7 | 8 | **Official implementation of "Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping".** 9 | 10 | 11 | **PGC** (Pixel-wise Gradient Clipping) adapts traditional gradient clipping to operate on pixel-wise gradient magnitudes, which preserves vital texture details. It acts as a versatile **plug-in** that complements existing **SDS- and LDM-based 3D generative models**, yielding a marked improvement in high-resolution 3D texture synthesis. 12 | 13 | With PGC, users can: 14 | - Address and mitigate gradient-related challenges common in LDMs, elevating the quality of 3D generation. 15 | - Employ the [**SDXL**](https://github.com/Stability-AI/generative-models) approach, previously not applicable to 3D generation. 16 | 17 | This repo also offers a unified implementation for mesh optimization and reproductions of many SDS variants. 18 | 19 | # Install 20 | 21 | ```bash 22 | git clone https://github.com/fudan-zvg/PGC-3D.git 23 | cd PGC-3D 24 | ``` 25 | 26 | **Hugging Face token**: 27 | Some newer models require an access token. 28 | Create a file called `TOKEN` under this directory (i.e., `PGC-3D/TOKEN`) 29 | and copy your [token](https://huggingface.co/settings/tokens) into it. 30 | 31 | ### Install pytorch3d 32 | 33 | We use the pytorch3d implementation of the normal consistency loss, so pytorch3d must be installed first. 34 | ```bash 35 | conda create -n pgc python=3.9 36 | conda activate pgc 37 | conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia 38 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 39 | conda install pytorch3d -c pytorch3d 40 | ``` 41 | 42 | ### Install with pip and build extensions 43 | 44 | ```bash 45 | pip install -r requirements.txt 46 | bash scripts/install_ext.sh 47 | ``` 48 | 49 | ### Tested environments 50 | * Ubuntu 20 with torch 1.13 & CUDA 11.7 on two A6000 GPUs. 51 | 52 | # Usage 53 | 54 | This repo supports mesh optimization (Stage 2 of Magic3D/Fantasia3D) with many methods, including SDS, VSD, BGT+, CSD, and SSD losses, as well as PBR material modeling. 55 | If you find an error or bug in these reproductions, feel free to raise an issue. 56 | We assume a triangle mesh is provided.
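For reference, here is a minimal, illustrative sketch of the pixel-wise gradient clipping idea described above: an identity operation in the forward pass that, in the backward pass, rescales each pixel's gradient so that its channel-wise norm never exceeds a threshold. This is only a sketch of the concept, not necessarily the exact implementation used in this repo, and the threshold `tau` is a hypothetical hyperparameter.

```python
import torch

class PixelWiseGradClip(torch.autograd.Function):
    # identity in the forward pass; clip per-pixel gradient magnitudes in the backward pass
    @staticmethod
    def forward(ctx, img, tau):
        ctx.tau = tau
        return img.view_as(img)

    @staticmethod
    def backward(ctx, grad):
        # grad: [B, C, H, W] -- channel-wise L2 norm per pixel
        norm = grad.norm(dim=1, keepdim=True)
        # shrink only those pixels whose gradient norm exceeds tau
        scale = (ctx.tau / (norm + 1e-8)).clamp(max=1.0)
        return grad * scale, None

# e.g. clipped_rgb = PixelWiseGradClip.apply(rendered_rgb, 0.1) before computing the SDS/VSD loss
```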
57 | 58 | ```bash 59 | ### basic options 60 | # -O0 pixel-wise gradient clipping 61 | # -O1 Stage 2 of Magic3D/Fantasia3D using DMTet 62 | # -O2 with mesh fixed 63 | # -O3 using an image as reference and Zero123 (not fully tested) 64 | # -O4 using SDXL 65 | # --pbr PBR material modeling as in nvdiffrec/Fantasia3D 66 | # --vsd VSD loss proposed by ProlificDreamer (only implemented for SDXL, so -O4 is needed) 67 | # --bgt BGT+ loss proposed by HiFA 68 | # --csd Classifier Score Distillation 69 | # --ssd Stable Score Distillation 70 | # --dreamtime an annealing time schedule proposed by DreamTime 71 | # --latent training in latent space as in Latent-NeRF 72 | # --cfg_rescale CFG rescaling suggested by AvatarStudio 73 | # --init_texture initialize the texture if the mesh has vertex colors 74 | # --bs 4 batch size 4 per GPU 75 | 76 | ### sample commands 77 | # all timings were measured on two A6000 GPUs 78 | # text-to-texture using SDXL with depth controlnet (46G, 36min) 79 | python main.py -O0 -O2 -O4 --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/sdxl --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 80 | 81 | # PBR modeling (48G, 36min) 82 | python main.py -O0 -O2 -O4 --pbr --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/pbr --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 83 | 84 | # VSD loss (42G, 60min) 85 | python main.py -O0 -O2 -O4 --vsd --guidance_scale 7.5 --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/vsd --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 86 | 87 | # optimize the shape using normal-SDS as in Fantasia3D (46G, 33min) 88 | python main.py -O0 -O4 --only_normal --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/normal_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 89 | # when fine-tuning, add depth control to stabilize the shape (46G, 36min) 90 | python main.py -O0 -O4 --only_normal --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/normal_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 91 | 92 | # optimize both shape and texture using RGB-SDS as in Magic3D (46G, 36min) 93 | python main.py -O0 -O4 --no_normal --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/rgb_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 94 | 95 | # optimize both shape and texture using both normal-SDS and RGB-SDS (47G, 74min) 96 | python main.py -O0 -O4 --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/nrm_rgb_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 97 | 98 | # We also support multiple ControlNets, e.g.
depth + shuffle control with a reference image (only Stable Diffusion v1.5) 99 | # 20G, 22min 100 | python main.py -O0 -O2 --ref_path "data/panda_reference.png" --control_type "depth" "shuffle" --control_scale 0.7 0.3 --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/sd_depth_shuffle --gpus "0,1" --mesh_path "data/panda_spear_shield.obj" 101 | ``` 102 | 103 | ## Results 104 | Incorporating the proficient and potent **PGC** implementation into SDXL guidance has led to notable advancements in 3D generation results. 105 | 106 | #### [Fantasia3D](https://github.com/Gorilla-Lab-SCUT/Fantasia3D) 107 | photo 108 | 109 | #### Ours 110 | photo 111 | 112 | ## Tips for SDXL 113 | - Using [sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) 114 | - [Controlnet](https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0) may be important to make training easier to converge 115 | 116 | ## BibTeX 117 | If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work: 118 | ```bibtex 119 | @inproceedings{pan2024enhancing, 120 | title={Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping}, 121 | author={Pan, Zijie and Lu, Jiachen and Zhu, Xiatian and Zhang, Li}, 122 | booktitle={International Conference on Learning Representations (ICLR)}, 123 | year={2024} 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /assets/group_photo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/group_photo.png -------------------------------------------------------------------------------- /assets/group_photo_fan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/group_photo_fan.png -------------------------------------------------------------------------------- /assets/nopgc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/nopgc.png -------------------------------------------------------------------------------- /assets/pgc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/pgc.png -------------------------------------------------------------------------------- /data/0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/data/0000.png -------------------------------------------------------------------------------- /data/background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/data/background.png -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FreqEncoder_torch(nn.Module): 6 | def __init__(self, input_dim, 
max_freq_log2, N_freqs, 7 | log_sampling=True, include_input=True, 8 | periodic_fns=(torch.sin, torch.cos)): 9 | 10 | super().__init__() 11 | 12 | self.input_dim = input_dim 13 | self.include_input = include_input 14 | self.periodic_fns = periodic_fns 15 | 16 | self.output_dim = 0 17 | if self.include_input: 18 | self.output_dim += self.input_dim 19 | 20 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) 21 | 22 | if log_sampling: 23 | self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs) 24 | else: 25 | self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs) 26 | 27 | self.freq_bands = self.freq_bands.numpy().tolist() 28 | 29 | def forward(self, input, **kwargs): 30 | 31 | out = [] 32 | if self.include_input: 33 | out.append(input) 34 | 35 | for i in range(len(self.freq_bands)): 36 | freq = self.freq_bands[i] 37 | for p_fn in self.periodic_fns: 38 | out.append(p_fn(input * freq)) 39 | 40 | out = torch.cat(out, dim=-1) 41 | 42 | return out 43 | 44 | def get_encoder(encoding, input_dim=3, 45 | multires=6, 46 | degree=4, 47 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, 48 | **kwargs): 49 | 50 | if encoding == 'None': 51 | return lambda x, **kwargs: x, input_dim 52 | 53 | elif encoding == 'frequency_torch': 54 | encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) 55 | 56 | elif encoding == 'frequency': # CUDA implementation, faster than torch. 57 | from freqencoder import FreqEncoder 58 | encoder = FreqEncoder(input_dim=input_dim, degree=multires) 59 | 60 | elif encoding == 'sphere_harmonics': 61 | from shencoder import SHEncoder 62 | encoder = SHEncoder(input_dim=input_dim, degree=degree) 63 | 64 | elif encoding == 'hashgrid': 65 | from gridencoder import GridEncoder 66 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) 67 | 68 | elif encoding == 'tiledgrid': 69 | from gridencoder import GridEncoder 70 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) 71 | 72 | else: 73 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') 74 | 75 | return encoder, encoder.output_dim -------------------------------------------------------------------------------- /freqencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .freq import FreqEncoder -------------------------------------------------------------------------------- /freqencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | '-use_fast_math' 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for 
edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | _backend = load(name='_freqencoder', 33 | extra_cflags=c_flags, 34 | extra_cuda_cflags=nvcc_flags, 35 | sources=[os.path.join(_src_path, 'src', f) for f in [ 36 | 'freqencoder.cu', 37 | 'bindings.cpp', 38 | ]], 39 | ) 40 | 41 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /freqencoder/freq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _freqencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | 15 | class _freq_encoder(Function): 16 | @staticmethod 17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 18 | def forward(ctx, inputs, degree, output_dim): 19 | # inputs: [B, input_dim], float 20 | # RETURN: [B, F], float 21 | 22 | if not inputs.is_cuda: inputs = inputs.cuda() 23 | inputs = inputs.contiguous() 24 | 25 | B, input_dim = inputs.shape # batch size, coord dim 26 | 27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 28 | 29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 30 | 31 | ctx.save_for_backward(inputs, outputs) 32 | ctx.dims = [B, input_dim, degree, output_dim] 33 | 34 | return outputs 35 | 36 | @staticmethod 37 | #@once_differentiable 38 | @custom_bwd 39 | def backward(ctx, grad): 40 | # grad: [B, C * C] 41 | 42 | grad = grad.contiguous() 43 | inputs, outputs = ctx.saved_tensors 44 | B, input_dim, degree, output_dim = ctx.dims 45 | 46 | grad_inputs = torch.zeros_like(inputs) 47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 48 | 49 | return grad_inputs, None, None 50 | 51 | 52 | freq_encode = _freq_encoder.apply 53 | 54 | 55 | class FreqEncoder(nn.Module): 56 | def __init__(self, input_dim=3, degree=4): 57 | super().__init__() 58 | 59 | self.input_dim = input_dim 60 | self.degree = degree 61 | self.output_dim = input_dim + input_dim * 2 * degree 62 | 63 | def __repr__(self): 64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" 65 | 66 | def forward(self, inputs, **kwargs): 67 | # inputs: [..., input_dim] 68 | # return: [..., ] 69 | 70 | prefix_shape = list(inputs.shape[:-1]) 71 | inputs = inputs.reshape(-1, self.input_dim) 72 | 73 | outputs = freq_encode(inputs, self.degree, self.output_dim) 74 | 75 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 76 | 77 | return outputs -------------------------------------------------------------------------------- /freqencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import 
BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | '-use_fast_math' 11 | ] 12 | 13 | if os.name == "posix": 14 | c_flags = ['-O3', '-std=c++14'] 15 | elif os.name == "nt": 16 | c_flags = ['/O2', '/std:c++17'] 17 | 18 | # find cl.exe 19 | def find_cl_path(): 20 | import glob 21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 23 | if paths: 24 | return paths[0] 25 | 26 | # If cl.exe is not on path, try to find it. 27 | if os.system("where cl.exe >nul 2>nul") != 0: 28 | cl_path = find_cl_path() 29 | if cl_path is None: 30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 31 | os.environ["PATH"] += ";" + cl_path 32 | 33 | setup( 34 | name='freqencoder', # package name, import this to use python API 35 | ext_modules=[ 36 | CUDAExtension( 37 | name='_freqencoder', # extension name, import this to use CUDA API 38 | sources=[os.path.join(_src_path, 'src', f) for f in [ 39 | 'freqencoder.cu', 40 | 'bindings.cpp', 41 | ]], 42 | extra_compile_args={ 43 | 'cxx': c_flags, 44 | 'nvcc': nvcc_flags, 45 | } 46 | ), 47 | ], 48 | cmdclass={ 49 | 'build_ext': BuildExtension, 50 | } 51 | ) -------------------------------------------------------------------------------- /freqencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "freqencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); 7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 20 | 21 | inline constexpr __device__ float PI() { return 3.141592653589793f; } 22 | 23 | template 24 | __host__ __device__ T div_round_up(T val, T divisor) { 25 | return (val + divisor - 1) / divisor; 26 | } 27 | 28 | // inputs: [B, D] 29 | // outputs: [B, C], C = D + D * deg * 2 30 | __global__ void kernel_freq( 31 | const float * __restrict__ inputs, 32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 33 | float * outputs 34 | ) { 35 | // parallel on per-element 36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 37 | if (t >= B * C) return; 38 | 39 | // get index 40 | const uint32_t b = t / C; 41 | const uint32_t c = t - b * C; // t % C; 42 | 43 | // locate 44 | inputs += b * D; 45 | outputs += t; 46 | 
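// per-sample output layout (C = D + 2 * deg * D):
//   [ x_0 .. x_{D-1},
//     sin(2^0 * x_0 .. x_{D-1}), cos(2^0 * x_0 .. x_{D-1}),
//     ...,
//     sin(2^{deg-1} * x_0 .. x_{D-1}), cos(2^{deg-1} * x_0 .. x_{D-1}) ]
// the cosine terms are computed as sin(x + pi/2) via the phase shift below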
47 | // write self 48 | if (c < D) { 49 | outputs[0] = inputs[c]; 50 | // write freq 51 | } else { 52 | const uint32_t col = c / D - 1; 53 | const uint32_t d = c % D; 54 | const uint32_t freq = col / 2; 55 | const float phase_shift = (col % 2) * (PI() / 2); 56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); 57 | } 58 | } 59 | 60 | // grad: [B, C], C = D + D * deg * 2 61 | // outputs: [B, C] 62 | // grad_inputs: [B, D] 63 | __global__ void kernel_freq_backward( 64 | const float * __restrict__ grad, 65 | const float * __restrict__ outputs, 66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C, 67 | float * grad_inputs 68 | ) { 69 | // parallel on per-element 70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 71 | if (t >= B * D) return; 72 | 73 | const uint32_t b = t / D; 74 | const uint32_t d = t - b * D; // t % D; 75 | 76 | // locate 77 | grad += b * C; 78 | outputs += b * C; 79 | grad_inputs += t; 80 | 81 | // register 82 | float result = grad[d]; 83 | grad += D; 84 | outputs += D; 85 | 86 | for (uint32_t f = 0; f < deg; f++) { 87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); 88 | grad += 2 * D; 89 | outputs += 2 * D; 90 | } 91 | 92 | // write 93 | grad_inputs[0] = result; 94 | } 95 | 96 | 97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { 98 | CHECK_CUDA(inputs); 99 | CHECK_CUDA(outputs); 100 | 101 | CHECK_CONTIGUOUS(inputs); 102 | CHECK_CONTIGUOUS(outputs); 103 | 104 | CHECK_IS_FLOATING(inputs); 105 | CHECK_IS_FLOATING(outputs); 106 | 107 | static constexpr uint32_t N_THREADS = 128; 108 | 109 | kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); 110 | } 111 | 112 | 113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { 114 | CHECK_CUDA(grad); 115 | CHECK_CUDA(outputs); 116 | CHECK_CUDA(grad_inputs); 117 | 118 | CHECK_CONTIGUOUS(grad); 119 | CHECK_CONTIGUOUS(outputs); 120 | CHECK_CONTIGUOUS(grad_inputs); 121 | 122 | CHECK_IS_FLOATING(grad); 123 | CHECK_IS_FLOATING(outputs); 124 | CHECK_IS_FLOATING(grad_inputs); 125 | 126 | static constexpr uint32_t N_THREADS = 128; 127 | 128 | kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); 129 | } -------------------------------------------------------------------------------- /freqencoder/src/freqencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include 4 | #include 5 | 6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) 7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); 8 | 9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) 10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder -------------------------------------------------------------------------------- /gridencoder/backend.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_grid_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'gridencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 
26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='gridencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_gridencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'gridencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); 9 | m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)"); 10 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); 14 | 15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); 16 | void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L); 17 | 18 | #endif -------------------------------------------------------------------------------- /guidance/clip_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision.transforms as T 6 | import torchvision.transforms.functional as TF 7 | 8 | import clip 9 | 10 | 11 | class CLIP(nn.Module): 12 | def __init__(self, device, **kwargs): 13 | 
super().__init__() 14 | 15 | self.device = device 16 | 17 | self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False) 18 | 19 | # image augmentation 20 | self.aug = T.Compose([ 21 | T.Resize((224, 224)), 22 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 23 | ]) 24 | 25 | # self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10)) 26 | 27 | def get_text_embeds(self, prompt, negative_prompt, **kwargs): 28 | # NOTE: negative_prompt is ignored for CLIP. 29 | 30 | text = clip.tokenize(prompt).to(self.device) 31 | text_z = self.clip_model.encode_text(text) 32 | text_z = text_z / text_z.norm(dim=-1, keepdim=True) 33 | 34 | return text_z 35 | 36 | def get_image_embeds(self, pred_rgb): 37 | pred_rgb = self.aug(pred_rgb) 38 | 39 | image_z = self.clip_model.encode_image(pred_rgb) 40 | image_z = image_z / image_z.norm(dim=-1, keepdim=True) 41 | 42 | return image_z 43 | 44 | def train_step(self, text_z, pred_rgb, **kwargs): 45 | pred_rgb = self.aug(pred_rgb) 46 | 47 | image_z = self.clip_model.encode_image(pred_rgb) 48 | image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features 49 | 50 | loss = - (image_z * text_z).sum(-1).mean() 51 | 52 | return loss 53 | 54 | 55 | if __name__ == '__main__': 56 | import os 57 | import numpy as np 58 | import torch 59 | from skimage.io import imread, imsave 60 | import argparse 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--input_path', type=str, default='others/data/compare/spiderman') 64 | parser.add_argument('--data_path', type=str, default='others/output/compare/spiderman') 65 | parser.add_argument('--long_image', action='store_true') 66 | args = parser.parse_args() 67 | 68 | model = CLIP("cuda") 69 | image_size = 256 70 | 71 | input_images = [] 72 | for img_path in os.listdir(args.input_path): 73 | img = imread(os.path.join(args.input_path, img_path)) / 255 74 | if img.shape[-1] == 4: 75 | mask = img[..., 3:] 76 | img = img[..., :3] * mask + (1 - mask) 77 | img = torch.tensor(img[None, ..., :3], dtype=torch.float) 78 | input_images.append(img) 79 | 80 | gen_images = [] 81 | for img_path in os.listdir(args.data_path): 82 | if '.png' in img_path: 83 | img = imread(os.path.join(args.data_path, img_path)) / 255 84 | if args.long_image: 85 | for index in range(0, 16): 86 | rgb = np.copy(img[:, index * image_size:(index + 1) * image_size, :]) 87 | rgb = torch.tensor(rgb[None, ..., :3], dtype=torch.float) 88 | gen_images.append(rgb) 89 | else: 90 | rgb = np.copy(img) 91 | rgb = torch.tensor(rgb[None, ..., :3], dtype=torch.float) 92 | rgb = F.interpolate( 93 | rgb.permute(0, 3, 1, 2), (256, 256), mode="bilinear", align_corners=False 94 | ).permute(0, 2, 3, 1) 95 | gen_images.append(rgb) 96 | 97 | input_images = torch.cat(input_images, dim=0).permute(0, 3, 1, 2).cuda() 98 | gen_images = torch.cat(gen_images, dim=0).permute(0, 3, 1, 2).cuda() 99 | 100 | input_z = model.get_image_embeds(input_images) 101 | gen_z = model.get_image_embeds(gen_images) 102 | similarity = input_z[:, None, ...] * gen_z[None, :, ...] 
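    # both embedding sets are L2-normalized, so summing this elementwise product over
    # the feature dimension yields the cosine similarity of every (input, generated)
    # image pair; the mean below is the average CLIP similarity across all pairs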
103 | print(similarity.sum(-1).mean()) 104 | -------------------------------------------------------------------------------- /guidance/guidance_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda.amp import custom_bwd, custom_fwd 3 | import numpy as np 4 | 5 | 6 | class SpecifyGradient(torch.autograd.Function): 7 | @staticmethod 8 | @custom_fwd 9 | def forward(ctx, input_tensor, gt_grad): 10 | ctx.save_for_backward(gt_grad) 11 | # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. 12 | return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) 13 | 14 | @staticmethod 15 | @custom_bwd 16 | def backward(ctx, grad_scale): 17 | gt_grad, = ctx.saved_tensors 18 | gt_grad = gt_grad * grad_scale 19 | return gt_grad, None 20 | 21 | 22 | def seed_everything(seed): 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | # torch.backends.cudnn.deterministic = True 26 | # torch.backends.cudnn.benchmark = True 27 | 28 | 29 | def w_star(t, m1=800, m2=500, s1=300, s2=100): 30 | # max time 1000 31 | r = np.ones_like(t) * 1.0 32 | r[t > m1] = np.exp(-((t[t > m1] - m1) ** 2) / (2 * s1 * s1)) 33 | r[t < m2] = np.exp(-((t[t < m2] - m2) ** 2) / (2 * s2 * s2)) 34 | return r 35 | 36 | 37 | def precompute_prior(T=1000, min_t=200, max_t=800): 38 | ts = np.arange(T) 39 | prior = w_star(ts)[min_t:max_t] 40 | prior = prior / prior.sum() 41 | prior = prior[::-1].cumsum()[::-1] 42 | return prior, min_t 43 | 44 | 45 | def time_prioritize(step_ratio, time_prior, min_t=200): 46 | return np.abs(time_prior - step_ratio).argmin() + min_t 47 | 48 | 49 | def noise_norm(eps): 50 | # [B, 3, H, W] 51 | return torch.sqrt(torch.square(eps).sum(dim=[1, 2, 3])) 52 | 53 | 54 | if __name__ == '__main__': 55 | prior, _ = precompute_prior() 56 | t = time_prioritize(0.5, prior) 57 | 58 | 59 | -------------------------------------------------------------------------------- /ldm/extras.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from ldm.util import instantiate_from_config 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | from contextlib import contextmanager 9 | import logging 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) 
interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | def load_training_dir(train_dir, device, epoch="last"): 38 | """Load a checkpoint and config from training directory""" 39 | train_dir = Path(train_dir) 40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 42 | config = list(train_dir.rglob(f"*-project.yaml")) 43 | assert len(ckpt) > 0, f"didn't find any config in {train_dir}" 44 | if len(config) > 1: 45 | print(f"found {len(config)} matching config files") 46 | config = sorted(config)[-1] 47 | print(f"selecting {config}") 48 | else: 49 | config = config[0] 50 | 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 56 | """Loads a model from config and a ckpt 57 | if config is a path will use omegaconf to load 58 | """ 59 | if isinstance(config, (str, Path)): 60 | config = OmegaConf.load(config) 61 | 62 | with all_logging_disabled(): 63 | print(f"Loading model from {ckpt}") 64 | pl_sd = torch.load(ckpt, map_location="cpu") 65 | global_step = pl_sd["global_step"] 66 | sd = pl_sd["state_dict"] 67 | model = instantiate_from_config(config.model) 68 | m, u = model.load_state_dict(sd, strict=False) 69 | if len(m) > 0 and verbose: 70 | print("missing keys:") 71 | print(m) 72 | if len(u) > 0 and verbose: 73 | print("unexpected keys:") 74 | model.to(device) 75 | model.eval() 76 | model.cond_stage_model.device = device 77 | return model -------------------------------------------------------------------------------- /ldm/guidance.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from scipy import interpolate 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import abc 8 | 9 | 10 | class GuideModel(torch.nn.Module, abc.ABC): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | @abc.abstractmethod 15 | def preprocess(self, x_img): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def compute_loss(self, inp): 20 | pass 21 | 22 | 23 | class Guider(torch.nn.Module): 24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 25 | """Apply classifier guidance 26 | 27 | Specify a guidance scale as either a scalar 28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 
29 | [(0, 10), (0.5, 20), (1, 50)] 30 | """ 31 | super().__init__() 32 | self.sampler = sampler 33 | self.index = 0 34 | self.show = verbose 35 | self.guide_model = guide_model 36 | self.history = [] 37 | 38 | if isinstance(scale, (Tuple, List)): 39 | times = np.array([x[0] for x in scale]) 40 | values = np.array([x[1] for x in scale]) 41 | self.scale_schedule = {"times": times, "values": values} 42 | else: 43 | self.scale_schedule = float(scale) 44 | 45 | self.ddim_timesteps = sampler.ddim_timesteps 46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 47 | 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps)*[self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) 54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps 55 | return interpolater(fractional_steps) 56 | 57 | def modify_score(self, model, e_t, x, t, c): 58 | 59 | # TODO look up index by t 60 | scale = self.get_scales()[self.index] 61 | 62 | if (scale == 0): 63 | return e_t 64 | 65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 66 | with torch.enable_grad(): 67 | x_in = x.detach().requires_grad_(True) 68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) 70 | 71 | inp = self.guide_model.preprocess(x_img) 72 | loss = self.guide_model.compute_loss(inp) 73 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 74 | correction = grads * scale 75 | 76 | if self.show: 77 | clear_output(wait=True) 78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) 79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) 80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) 81 | plt.axis('off') 82 | plt.show() 83 | plt.imshow(correction[0][0].detach().cpu()) 84 | plt.axis('off') 85 | plt.show() 86 | 87 | 88 | e_t_mod = e_t - sqrt_1ma*correction 89 | if self.show: 90 | fig, axs = plt.subplots(1, 3) 91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 94 | plt.show() 95 | self.index += 1 96 | return e_t_mod -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- 
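The warm-up/decay schedulers in `ldm/lr_scheduler.py` above return a learning-rate multiplier for a given global step. A minimal usage sketch (not taken from this repo; the stand-in model, optimizer, and schedule values below are hypothetical) pairs one of them with `torch.optim.lr_scheduler.LambdaLR` and a base learning rate of 1.0, as the docstrings note:

```python
import torch
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(8, 8)                               # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)   # base lr of 1.0; the schedule supplies the real value

# hypothetical schedule: 100 warm-up steps, then a linear decay over one 10k-step cycle
schedule = LambdaLinearScheduler(
    warm_up_steps=[100],
    f_min=[1e-6],
    f_max=[1e-4],
    f_start=[1e-6],
    cycle_lengths=[10_000],
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(1_000):
    optimizer.zero_grad()
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    lr_scheduler.step()   # queries schedule(step) for the next multiplier
```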
/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... -> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = 
self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * 
(shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/evaluate/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow_gan as tfgan 32 | import tensorflow_hub as hub 33 | 34 | 35 | def preprocess(videos, target_resolution): 36 | """Runs some preprocessing on the videos for I3D model. 37 | 38 | Args: 39 | videos: [batch_size, num_frames, height, width, depth] The videos to be 40 | preprocessed. We don't care about the specific dtype of the videos, it can 41 | be anything that tf.image.resize_bilinear accepts. Values are expected to 42 | be in the range 0-255. 
43 | target_resolution: (width, height): target video resolution 44 | 45 | Returns: 46 | videos: [batch_size, num_frames, height, width, depth] 47 | """ 48 | videos_shape = list(videos.shape) 49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 52 | output_videos = tf.reshape(resized_videos, target_shape) 53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 54 | return scaled_videos 55 | 56 | 57 | def _is_in_graph(tensor_name): 58 | """Checks whether a given tensor does exists in the graph.""" 59 | try: 60 | tf.get_default_graph().get_tensor_by_name(tensor_name) 61 | except KeyError: 62 | return False 63 | return True 64 | 65 | 66 | def create_id3_embedding(videos,warmup=False,batch_size=16): 67 | """Embeds the given videos using the Inflated 3D Convolution ne twork. 68 | 69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 70 | first call. 71 | 72 | Args: 73 | videos: [batch_size, num_frames, height=224, width=224, depth=3]. 74 | Expected range is [-1, 1]. 75 | 76 | Returns: 77 | embedding: [batch_size, embedding_size]. embedding_size depends 78 | on the model used. 79 | 80 | Raises: 81 | ValueError: when a provided embedding_layer is not supported. 82 | """ 83 | 84 | # batch_size = 16 85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 86 | 87 | 88 | # Making sure that we import the graph separately for 89 | # each different input video tensor. 90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 91 | videos.name).replace(":", "_") 92 | 93 | 94 | 95 | assert_ops = [ 96 | tf.Assert( 97 | tf.reduce_max(videos) <= 1.001, 98 | ["max value in frame is > 1", videos]), 99 | tf.Assert( 100 | tf.reduce_min(videos) >= -1.001, 101 | ["min value in frame is < -1", videos]), 102 | tf.assert_equal( 103 | tf.shape(videos)[0], 104 | batch_size, ["invalid frame batch size: ", 105 | tf.shape(videos)], 106 | summarize=6), 107 | ] 108 | with tf.control_dependencies(assert_ops): 109 | videos = tf.identity(videos) 110 | 111 | module_scope = "%s_apply_default/" % module_name 112 | 113 | # To check whether the module has already been loaded into the graph, we look 114 | # for a given tensor name. If this tensor name exists, we assume the function 115 | # has been called before and the graph was imported. Otherwise we import it. 116 | # Note: in theory, the tensor could exist, but have wrong shapes. 117 | # This will happen if create_id3_embedding is called with a frames_placehoder 118 | # of wrong size/batch size, because even though that will throw a tf.Assert 119 | # on graph-execution time, it will insert the tensor (with wrong shape) into 120 | # the graph. This is why we need the following assert. 
121 | if warmup: 122 | video_batch_size = int(videos.shape[0]) 123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" 124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 125 | if not _is_in_graph(tensor_name): 126 | i3d_model = hub.Module(module_spec, name=module_name) 127 | i3d_model(videos) 128 | 129 | # gets the kinetics-i3d-400-logits layer 130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 132 | return tensor 133 | 134 | 135 | def calculate_fvd(real_activations, 136 | generated_activations): 137 | """Returns a list of ops that compute metrics as funcs of activations. 138 | 139 | Args: 140 | real_activations: [num_samples, embedding_size] 141 | generated_activations: [num_samples, embedding_size] 142 | 143 | Returns: 144 | A scalar that contains the requested FVD. 145 | """ 146 | return tfgan.eval.frechet_classifier_distance_from_activations( 147 | real_activations, generated_activations) 148 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not (mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = size_average 84 | 
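# Note on the two lines below: the Gaussian window is built here for a single
# channel and is lazily rebuilt in forward() whenever the incoming images have a
# different channel count or dtype/device, so the module handles grayscale and
# RGB inputs without reconfiguration.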
self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, _, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
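# Usage sketch for the loss defined below (illustrative only; `ae`, `x` and `step`
# are assumed stand-ins for an AutoencoderKL-style model, an input batch and the
# global step -- they are not defined in this file):
#
#   loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)
#   posterior = ae.encode(x)                     # DiagonalGaussianDistribution
#   recon = ae.decode(posterior.sample())
#   # optimizer_idx=0 -> generator/NLL/KL branch, optimizer_idx=1 -> discriminator branch
#   g_loss, g_log = loss_fn(x, recon, posterior, 0, step, last_layer=ae.get_last_layer())
#   d_loss, d_log = loss_fn(x, recon, posterior, 1, step, last_layer=ae.get_last_layer())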
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | import torch 5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. """ 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut -------------------------------------------------------------------------------- /ldm/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import torch 3 | from torch import nn 4 | from ldm.thirdp.psp.model_irse import Backbone 5 | 6 | 7 | class IDFeatures(nn.Module): 8 | def __init__(self, model_path): 9 | super(IDFeatures, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def forward(self, x, crop=False): 17 | # Not sure of the image range here 18 | if crop: 19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 20 | x = x[:, :, 35:223, 32:220] 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 4 | from 
ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 5 | 6 | """ 7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Backbone(Module): 12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 13 | super(Backbone, self).__init__() 14 | assert input_size in [112, 224], "input_size should be 112 or 224" 15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 17 | blocks = get_blocks(num_layers) 18 | if mode == 'ir': 19 | unit_module = bottleneck_IR 20 | elif mode == 'ir_se': 21 | unit_module = bottleneck_IR_SE 22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 23 | BatchNorm2d(64), 24 | PReLU(64)) 25 | if input_size == 112: 26 | self.output_layer = Sequential(BatchNorm2d(512), 27 | Dropout(drop_ratio), 28 | Flatten(), 29 | Linear(512 * 7 * 7, 512), 30 | BatchNorm1d(512, affine=affine)) 31 | else: 32 | self.output_layer = Sequential(BatchNorm2d(512), 33 | Dropout(drop_ratio), 34 | Flatten(), 35 | Linear(512 * 14 * 14, 512), 36 | BatchNorm1d(512, affine=affine)) 37 | 38 | modules = [] 39 | for block in blocks: 40 | for bottleneck in block: 41 | modules.append(unit_module(bottleneck.in_channel, 42 | bottleneck.depth, 43 | bottleneck.stride)) 44 | self.body = Sequential(*modules) 45 | 46 | def forward(self, x): 47 | x = self.input_layer(x) 48 | x = self.body(x) 49 | x = self.output_layer(x) 50 | return l2_norm(x) 51 | 52 | 53 | def IR_50(input_size): 54 | """Constructs a ir-50 model.""" 55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 56 | return model 57 | 58 | 59 | def IR_101(input_size): 60 | """Constructs a ir-101 model.""" 61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 62 | return model 63 | 64 | 65 | def IR_152(input_size): 66 | """Constructs a ir-152 model.""" 67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 68 | return model 69 | 70 | 71 | def IR_SE_50(input_size): 72 | """Constructs a ir_se-50 model.""" 73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 74 | return model 75 | 76 | 77 | def IR_SE_101(input_size): 78 | """Constructs a ir_se-101 model.""" 79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 80 | return model 81 | 82 | 83 | def IR_SE_152(input_size): 84 | """Constructs a ir_se-152 model.""" 85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 86 | return model -------------------------------------------------------------------------------- /nerf/fine/dmtet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
8 | 9 | import torch 10 | 11 | 12 | def get_center_boundary_index(verts): 13 | length_ = torch.sum(verts ** 2, dim=-1) 14 | center_idx = torch.argmin(length_) 15 | boundary_neg = verts == verts.max() 16 | boundary_pos = verts == verts.min() 17 | boundary = torch.bitwise_or(boundary_pos, boundary_neg) 18 | boundary = torch.sum(boundary.float(), dim=-1) 19 | boundary_idx = torch.nonzero(boundary) 20 | return center_idx, boundary_idx.squeeze(dim=-1) 21 | -------------------------------------------------------------------------------- /nerf/fine/generate_tets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | 12 | ''' 13 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, 14 | to generate a tet grid 15 | 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet` 16 | 2) Run the function below to generate a file `cube_32_tet.tet` 17 | ''' 18 | 19 | 20 | def generate_tetrahedron_grid_file(res=32, root='..'): 21 | frac = 1.0 / res 22 | command = 'cd %s/quartet; ' % (root) + \ 23 | './quartet meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res) 24 | os.system(command) 25 | 26 | 27 | ''' 28 | This code segment shows how to convert from a quartet .tet file to compressed npz file 29 | ''' 30 | 31 | 32 | def generate_tetrahedrons(res=50, root='..'): 33 | tetrahedrons = [] 34 | vertices = [] 35 | if res > 1.0: 36 | res = 1.0 / res 37 | 38 | root_path = os.path.join(root, 'quartet/meshes') 39 | file_name = os.path.join(root_path, 'cube_%f_tet.tet' % (res)) 40 | 41 | # generate tetrahedron is not exist files 42 | if not os.path.exists(file_name): 43 | command = 'cd %s/quartet; ' % (root) + \ 44 | './quartet meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (res, res, res) 45 | os.system(command) 46 | 47 | with open(file_name, 'r') as f: 48 | line = f.readline() 49 | line = line.strip().split(' ') 50 | n_vert = int(line[1]) 51 | n_t = int(line[2]) 52 | for i in range(n_vert): 53 | line = f.readline() 54 | line = line.strip().split(' ') 55 | assert len(line) == 3 56 | vertices.append([float(v) for v in line]) 57 | for i in range(n_t): 58 | line = f.readline() 59 | line = line.strip().split(' ') 60 | assert len(line) == 4 61 | tetrahedrons.append([int(v) for v in line]) 62 | 63 | assert len(tetrahedrons) == n_t 64 | assert len(vertices) == n_vert 65 | vertices = np.asarray(vertices) 66 | vertices[vertices <= (0 + res / 4.0)] = 0 # determine the boundary point 67 | vertices[vertices >= (1 - res / 4.0)] = 1 # determine the boundary point 68 | np.savez_compressed('%d_compress' % (res), vertices=vertices, tets=tetrahedrons) 69 | return vertices, tetrahedrons 70 | -------------------------------------------------------------------------------- /nerf/fine/neural_render.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import nvdiffrast.torch as dr 12 | 13 | _FG_LUT = None 14 | 15 | 16 | def interpolate(attr, rast, attr_idx, rast_db=None): 17 | return dr.interpolate( 18 | attr.contiguous(), rast, attr_idx, rast_db=rast_db, 19 | diff_attrs=None if rast_db is None else 'all') 20 | 21 | 22 | def xfm_points(points, matrix, use_python=True): 23 | '''Transform points. 24 | Args: 25 | points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] 26 | matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] 27 | use_python: Use PyTorch's torch.matmul (for validation) 28 | Returns: 29 | Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. 30 | ''' 31 | out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) 32 | if torch.is_anomaly_enabled(): 33 | assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" 34 | return out 35 | 36 | 37 | class NeuralRender: 38 | def __init__(self, device='cuda', camera_model=None): 39 | super(NeuralRender, self).__init__() 40 | self.device = device 41 | self.ctx = None 42 | self.projection_mtx = None 43 | self.camera = camera_model 44 | 45 | def render_mesh( 46 | self, 47 | mesh_v_pos_bxnx3, 48 | mesh_t_pos_idx_fx3, 49 | camera_mv_bx4x4, 50 | mesh_v_feat_bxnxd, 51 | resolution=256, 52 | spp=1, 53 | device='cuda', 54 | hierarchical_mask=False 55 | ): 56 | assert not hierarchical_mask 57 | if self.ctx is None: 58 | # self.ctx = dr.RasterizeGLContext(device=self.device) 59 | self.ctx = dr.RasterizeCudaContext(device=self.device) 60 | 61 | mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 62 | # v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates 63 | mtx_in = mtx_in.unsqueeze(dim=1) 64 | v_pos = mesh_v_pos_bxnx3 - mtx_in[..., :3, 3] 65 | v_pos = v_pos @ mtx_in[:, 0, :3, :3] 66 | v_pos = torch.nn.functional.pad(v_pos, pad=(0, 1), mode='constant', value=1.0) 67 | v_pos_clip = self.camera.project(v_pos).float() # Projection in the camera 68 | 69 | # Render the image, 70 | # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render 71 | num_layers = 1 72 | mask_pyramid = None 73 | assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes 74 | mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd, v_pos], dim=-1) # Concatenate the pos compute the supervision 75 | 76 | with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: 77 | for _ in range(num_layers): 78 | rast, db = peeler.rasterize_next_layer() 79 | gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) 80 | 81 | hard_mask = torch.clamp(rast[..., -1:], 0, 1) 82 | antialias_mask = dr.antialias( 83 | hard_mask.clone().contiguous(), rast, v_pos_clip, 84 | mesh_t_pos_idx_fx3) 85 | 86 | albedo = gb_feat[..., 3:6] 87 | 
antialias_albedo = dr.antialias( 88 | albedo.clone().contiguous(), rast, v_pos_clip, 89 | mesh_t_pos_idx_fx3) 90 | 91 | depth = gb_feat[..., -2:-1] 92 | ori_mesh_feature = gb_feat[..., :-4] 93 | return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, antialias_albedo 94 | -------------------------------------------------------------------------------- /nerf/fine/perspective_camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import numpy as np 11 | 12 | 13 | def projection(x=0.1, n=1.0, f=50.0, near_plane=None): 14 | if near_plane is None: 15 | near_plane = n 16 | # return np.array( 17 | # [[n / x, 0, 0, 0], 18 | # [0, n / -x, 0, 0], 19 | # [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], 20 | # [0, 0, -1, 0]]).astype(np.float32) 21 | return np.array( 22 | [[n / x, 0, 0, 0], 23 | [0, n / x, 0, 0], 24 | [0, 0, f / (f - near_plane), -(f * near_plane) / (f - near_plane)], 25 | [0, 0, 1, 0]]).astype(np.float32) 26 | 27 | class PerspectiveCamera: 28 | def __init__(self, focal=None, fovy=49.0, device='cuda'): 29 | super(PerspectiveCamera, self).__init__() 30 | self.device = device 31 | if focal is None: 32 | focal = np.tan(fovy / 180.0 * np.pi * 0.5) 33 | self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) 34 | 35 | def project(self, points_bxnx4): 36 | out = torch.matmul( 37 | points_bxnx4, 38 | torch.transpose(self.proj_mtx, 1, 2)) 39 | return out 40 | -------------------------------------------------------------------------------- /nerf/fine/rast_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from sklearn.neighbors import NearestNeighbors 7 | from sklearn.mixture import GaussianMixture 8 | 9 | import xatlas 10 | from encoding import get_encoder 11 | 12 | import others.nvdiffrec.render.renderutils as ru 13 | from others.nvdiffrec.render import obj 14 | from others.nvdiffrec.render import material 15 | from others.nvdiffrec.render import util 16 | from others.nvdiffrec.render import mesh 17 | from others.nvdiffrec.render import texture 18 | #from others.nvdiffrec.render import mlptexture 19 | from others.nvdiffrec.render import light 20 | from others.nvdiffrec.render import render 21 | 22 | 23 | class MLP(nn.Module): 24 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): 25 | super().__init__() 26 | self.dim_in = dim_in 27 | self.dim_out = dim_out 28 | self.dim_hidden = dim_hidden 29 | self.num_layers = num_layers 30 | 31 | net = [] 32 | for l in range(num_layers): 33 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, 34 | self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) 35 | 36 | self.net = nn.ModuleList(net) 37 | 38 | def forward(self, x): 39 | for l in range(self.num_layers): 40 | x = 
self.net[l](x) 41 | if l != self.num_layers - 1: 42 | x = F.relu(x, inplace=True) 43 | return x 44 | 45 | 46 | class Material_mlp(nn.Module): 47 | def __init__(self, min_max): 48 | super().__init__() 49 | self.min_max = min_max 50 | self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, 51 | log2_hashmap_size=16, 52 | desired_resolution=2 ** 12, 53 | base_resolution=2 ** 4, 54 | level_dim=4) 55 | self.texture_MLP = MLP(self.in_dim, 9, 32, 2, bias=True) 56 | 57 | def sample(self, texc): 58 | _texc = texc.view(-1, 3) 59 | p_enc = self.encoder(_texc.contiguous()) 60 | out = self.texture_MLP(p_enc) 61 | 62 | # Sigmoid limit and scale to the allowed range 63 | out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] 64 | 65 | return out.view(*texc.shape[:-1], 9) 66 | 67 | 68 | def safe_normalize(x, eps=1e-20): 69 | return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) 70 | 71 | 72 | def compute_targets(v, vn, n_targets, vc=None): 73 | knn = NearestNeighbors(n_neighbors=20, algorithm='kd_tree').fit(v) 74 | _, indices = knn.kneighbors(v) 75 | vn_neighbors = vn[indices] 76 | if vc is not None: 77 | vc_neighbors = vc[indices] 78 | 79 | v_num = v.shape[0] 80 | normal_weights = np.ones(v_num) 81 | rgb_weights = np.ones(v_num) 82 | for i in range(v_num): 83 | normal_weights[i] = np.linalg.norm(np.cov(vn_neighbors[i].T), 'fro') 84 | if vc is not None: 85 | rgb_weights[i] = np.linalg.norm(np.cov(vc_neighbors[i].T), 'fro') 86 | 87 | v_norm = np.sqrt((v ** 2).sum(1)) 88 | unit_v = v / v_norm[..., None] 89 | k = n_targets 90 | gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=0) 91 | gmm.fit(unit_v) 92 | pred_label = gmm.predict(unit_v) 93 | pred_prob = gmm.predict_proba(unit_v).max(1) 94 | weights_ = pred_prob * (v_norm ** 4) * normal_weights * rgb_weights 95 | targets = np.zeros([k, 3]) 96 | radius = np.ones([k]) 97 | weights = gmm.weights_ 98 | for i in range(k): 99 | index = pred_label == i 100 | w_sum = weights_[index].sum() 101 | targets[i] = (v[index] * weights_[index][..., None]).sum(0) / w_sum 102 | radius[i] = np.quantile(np.sqrt(((v[index] - targets[i]) ** 2).sum(1)), 0.3) 103 | weights[i] = weights[i] * w_sum 104 | 105 | weights = weights / weights.sum() 106 | return targets, weights, radius, weights_ 107 | 108 | 109 | # for save mesh 110 | @torch.no_grad() 111 | def xatlas_uvmap(glctx, eval_mesh, mat, device, 112 | kd_min_max, ks_min_max, nrm_min_max): 113 | 114 | # Create uvs with xatlas 115 | v_pos = eval_mesh.v_pos.detach().cpu().numpy() 116 | t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy() 117 | vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx) 118 | 119 | # Convert to tensors 120 | indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) 121 | 122 | uvs = torch.tensor(uvs, dtype=torch.float32, device=device) 123 | faces = torch.tensor(indices_int64, dtype=torch.int64, device=device) 124 | 125 | new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh) 126 | 127 | mask, kd, ks, normal = render.render_uv(glctx, new_mesh, [2048, 2048], eval_mesh.material['kd_ks_normal']) 128 | 129 | # if FLAGS.layers > 1: 130 | # kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1) 131 | 132 | new_mesh.material = material.Material({ 133 | 'bsdf': mat['bsdf'], 134 | 'kd': texture.Texture2D(kd, min_max=kd_min_max), 135 | 'ks': texture.Texture2D(ks, min_max=ks_min_max), 136 | 'normal': texture.Texture2D(normal, min_max=nrm_min_max), 137 | 
'no_perturbed_nrm': mat['no_perturbed_nrm'] 138 | }) 139 | 140 | return new_mesh 141 | -------------------------------------------------------------------------------- /nerf/fine/tets/256_compress.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/nerf/fine/tets/256_compress.npz -------------------------------------------------------------------------------- /nerf/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | import math 5 | import imageio 6 | import random 7 | import warnings 8 | import tensorboardX 9 | 10 | import numpy as np 11 | import pandas as pd 12 | 13 | import time 14 | from datetime import datetime 15 | 16 | import cv2 17 | import matplotlib.pyplot as plt 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.optim as optim 22 | import torch.nn.functional as F 23 | import torch.distributed as dist 24 | from torch.utils.data import Dataset, DataLoader 25 | 26 | import trimesh 27 | from rich.console import Console 28 | from torch_ema import ExponentialMovingAverage 29 | 30 | from packaging import version as pver 31 | 32 | 33 | def custom_meshgrid(*args): 34 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid 35 | if pver.parse(torch.__version__) < pver.parse('1.10'): 36 | return torch.meshgrid(*args) 37 | else: 38 | return torch.meshgrid(*args, indexing='ij') 39 | 40 | 41 | def safe_normalize(x, eps=1e-20): 42 | return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) 43 | 44 | 45 | @torch.cuda.amp.autocast(enabled=False) 46 | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None): 47 | ''' get rays 48 | Args: 49 | poses: [B, 4, 4], cam2world 50 | intrinsics: [4] 51 | H, W, N: int 52 | error_map: [B, 128 * 128], sample probability based on training error 53 | Returns: 54 | rays_o, rays_d: [B, N, 3] 55 | inds: [B, N] 56 | ''' 57 | 58 | device = poses.device 59 | B = poses.shape[0] 60 | fx, fy, cx, cy = intrinsics 61 | 62 | i, j = custom_meshgrid(torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H, device=device)) 63 | i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 64 | j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 65 | 66 | results = {} 67 | 68 | if N > 0: 69 | N = min(N, H * W) 70 | 71 | if error_map is None: 72 | inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate 73 | inds = inds.expand([B, N]) 74 | else: 75 | 76 | # weighted sample on a low-reso grid 77 | inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128) 78 | 79 | # map to the original resolution with random perturb. 80 | inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway. 
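# The error map lives on a fixed 128x128 grid: inds_coarse indexes that grid,
# and the rescaling plus uniform jitter below maps each coarse cell to a random
# pixel inside the corresponding (H/128) x (W/128) region of the full image, so
# high-error regions are sampled more often without storing a per-pixel
# probability map.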
81 | sx, sy = H / 128, W / 128 82 | inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1) 83 | inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1) 84 | inds = inds_x * W + inds_y 85 | 86 | results['inds_coarse'] = inds_coarse # need this when updating error_map 87 | 88 | i = torch.gather(i, -1, inds) 89 | j = torch.gather(j, -1, inds) 90 | 91 | results['inds'] = inds 92 | 93 | else: 94 | inds = torch.arange(H * W, device=device).expand([B, H * W]) 95 | 96 | zs = torch.ones_like(i) 97 | xs = (i - cx) / fx * zs 98 | ys = (j - cy) / fy * zs 99 | directions = torch.stack((xs, ys, zs), dim=-1) 100 | directions = safe_normalize(directions) 101 | rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) 102 | 103 | rays_o = poses[..., :3, 3] # [B, 3] 104 | rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] 105 | 106 | results['rays_o'] = rays_o 107 | results['rays_d'] = rays_d 108 | 109 | return results 110 | 111 | 112 | def seed_everything(seed): 113 | random.seed(seed) 114 | os.environ['PYTHONHASHSEED'] = str(seed) 115 | np.random.seed(seed) 116 | torch.manual_seed(seed) 117 | torch.cuda.manual_seed(seed) 118 | torch.cuda.manual_seed_all(seed) 119 | # torch.backends.cudnn.deterministic = True 120 | # torch.backends.cudnn.benchmark = True 121 | 122 | 123 | def torch_vis_2d(x, renormalize=False): 124 | # x: [3, H, W] or [1, H, W] or [H, W] 125 | import matplotlib.pyplot as plt 126 | import numpy as np 127 | import torch 128 | 129 | if isinstance(x, torch.Tensor): 130 | if len(x.shape) == 3: 131 | x = x.permute(1, 2, 0).squeeze() 132 | x = x.detach().cpu().numpy() 133 | 134 | print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}') 135 | 136 | x = x.astype(np.float32) 137 | 138 | # renormalize 139 | if renormalize: 140 | x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) 141 | 142 | plt.imshow(x) 143 | plt.show() 144 | 145 | 146 | @torch.jit.script 147 | def linear_to_srgb(x): 148 | return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) 149 | 150 | 151 | @torch.jit.script 152 | def srgb_to_linear(x): 153 | return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) 154 | -------------------------------------------------------------------------------- /others/nvdiffrec/data/irrmaps/README.txt: -------------------------------------------------------------------------------- 1 | The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop 2 | CC0 License. 
3 | 4 | -------------------------------------------------------------------------------- /others/nvdiffrec/data/irrmaps/aerodynamics_workshop_2k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/aerodynamics_workshop_2k.hdr -------------------------------------------------------------------------------- /others/nvdiffrec/data/irrmaps/blocky_photo_studio_2k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/blocky_photo_studio_2k.hdr -------------------------------------------------------------------------------- /others/nvdiffrec/data/irrmaps/bsdf_256_256.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/bsdf_256_256.bin -------------------------------------------------------------------------------- /others/nvdiffrec/data/irrmaps/mud_road_puresky_4k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/mud_road_puresky_4k.hdr -------------------------------------------------------------------------------- /others/nvdiffrec/render/light.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | import numpy as np 12 | import torch 13 | import nvdiffrast.torch as dr 14 | 15 | from . import util 16 | from . import renderutils as ru 17 | 18 | ###################################################################################### 19 | # Utility functions 20 | ###################################################################################### 21 | 22 | class cubemap_mip(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, cubemap): 25 | return util.avg_pool_nhwc(cubemap, (2,2)) 26 | 27 | @staticmethod 28 | def backward(ctx, dout): 29 | res = dout.shape[1] * 2 30 | out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda") 31 | for s in range(6): 32 | gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), 33 | torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"), 34 | indexing='ij') 35 | v = util.safe_normalize(util.cube_to_dir(s, gx, gy)) 36 | out[s, ...] = dr.texture(dout[None, ...] 
* 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube') 37 | return out 38 | 39 | ###################################################################################### 40 | # Split-sum environment map light source with automatic mipmap generation 41 | ###################################################################################### 42 | 43 | class EnvironmentLight(torch.nn.Module): 44 | LIGHT_MIN_RES = 16 45 | 46 | MIN_ROUGHNESS = 0.08 47 | MAX_ROUGHNESS = 0.5 48 | 49 | def __init__(self, base): 50 | super(EnvironmentLight, self).__init__() 51 | self.mtx = None 52 | self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=False) 53 | self.register_parameter('env_base', self.base) 54 | 55 | def xfm(self, mtx): 56 | self.mtx = mtx 57 | 58 | def clone(self): 59 | return EnvironmentLight(self.base.clone().detach()) 60 | 61 | def clamp_(self, min=None, max=None): 62 | self.base.clamp_(min, max) 63 | 64 | def get_mip(self, roughness): 65 | return torch.where(roughness < self.MAX_ROUGHNESS 66 | , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2) 67 | , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2) 68 | 69 | def build_mips(self, cutoff=0.99): 70 | self.specular = [self.base] 71 | while self.specular[-1].shape[1] > self.LIGHT_MIN_RES: 72 | self.specular += [cubemap_mip.apply(self.specular[-1])] 73 | 74 | self.diffuse = ru.diffuse_cubemap(self.specular[-1]) 75 | 76 | for idx in range(len(self.specular) - 1): 77 | roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS 78 | self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff) 79 | self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff) 80 | 81 | def regularizer(self): 82 | white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0 83 | return torch.mean(torch.abs(self.base - white)) 84 | 85 | def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): 86 | wo = util.safe_normalize(view_pos - gb_pos) 87 | 88 | if specular: 89 | roughness = ks[..., 1:2] # y component 90 | metallic = ks[..., 2:3] # z component 91 | spec_col = (1.0 - metallic)*0.04 + kd * metallic 92 | diff_col = kd * (1.0 - metallic) 93 | else: 94 | diff_col = kd 95 | spec_col = None 96 | 97 | reflvec = util.safe_normalize(util.reflect(wo, gb_normal)) 98 | nrmvec = gb_normal 99 | if self.mtx is not None: # Rotate lookup 100 | mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda') 101 | reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape) 102 | nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape) 103 | 104 | # Diffuse lookup 105 | diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube') 106 | shaded_col = diffuse * diff_col 107 | 108 | if specular: 109 | # Lookup FG term from lookup texture 110 | NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4) 111 | fg_uv = torch.cat((NdotV, roughness), dim=-1) 112 | if not hasattr(self, '_FG_LUT'): 113 | self._FG_LUT = torch.as_tensor(np.fromfile('others/nvdiffrec/data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda') 
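# bsdf_256_256.bin holds the precomputed split-sum environment-BRDF table: a
# 256x256x2 texture indexed by (N.V, roughness) whose two channels are the
# scale and bias applied to the specular color in the reflectance term below.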
114 | fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp') 115 | 116 | # Roughness adjusted specular env lookup 117 | miplevel = self.get_mip(roughness) 118 | spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube') 119 | 120 | # Compute aggregate lighting 121 | reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2] 122 | shaded_col += spec * reflectance 123 | 124 | return shaded_col * (1.0 - ks[..., 0:1]), spec_col # Modulate by hemisphere visibility 125 | 126 | ###################################################################################### 127 | # Load and store 128 | ###################################################################################### 129 | 130 | # Load from latlong .HDR file 131 | def _load_env_hdr(fn, scale=1.0): 132 | latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale 133 | cubemap = util.latlong_to_cubemap(latlong_img, [512, 512]) 134 | 135 | l = EnvironmentLight(cubemap) 136 | l.build_mips() 137 | 138 | return l 139 | 140 | def load_env(fn, scale=1.0): 141 | if os.path.splitext(fn)[1].lower() == ".hdr": 142 | return _load_env_hdr(fn, scale) 143 | else: 144 | assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1] 145 | 146 | def save_env_map(fn, light): 147 | assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently" 148 | if isinstance(light, EnvironmentLight): 149 | color = util.cubemap_to_latlong(light.base, [512, 1024]) 150 | util.save_image_raw(fn, color.detach().cpu().numpy()) 151 | 152 | ###################################################################################### 153 | # Create trainable env map with random initialization 154 | ###################################################################################### 155 | 156 | def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25): 157 | base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias 158 | return EnvironmentLight(base) 159 | 160 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/material.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | import numpy as np 12 | import torch 13 | 14 | from . import util 15 | from . import texture 16 | 17 | ###################################################################################### 18 | # Wrapper to make materials behave like a python dict, but register textures as 19 | # torch.nn.Module parameters. 
20 | ###################################################################################### 21 | class Material(torch.nn.Module): 22 | def __init__(self, mat_dict): 23 | super(Material, self).__init__() 24 | self.mat_keys = set() 25 | for key in mat_dict.keys(): 26 | self.mat_keys.add(key) 27 | self[key] = mat_dict[key] 28 | 29 | def __contains__(self, key): 30 | return hasattr(self, key) 31 | 32 | def __getitem__(self, key): 33 | return getattr(self, key) 34 | 35 | def __setitem__(self, key, val): 36 | self.mat_keys.add(key) 37 | setattr(self, key, val) 38 | 39 | def __delitem__(self, key): 40 | self.mat_keys.remove(key) 41 | delattr(self, key) 42 | 43 | def keys(self): 44 | return self.mat_keys 45 | 46 | ###################################################################################### 47 | # .mtl material format loading / storing 48 | ###################################################################################### 49 | @torch.no_grad() 50 | def load_mtl(fn, clear_ks=True): 51 | import re 52 | mtl_path = os.path.dirname(fn) 53 | 54 | # Read file 55 | with open(fn, 'r') as f: 56 | lines = f.readlines() 57 | 58 | # Parse materials 59 | materials = [] 60 | for line in lines: 61 | split_line = re.split(' +|\t+|\n+', line.strip()) 62 | prefix = split_line[0].lower() 63 | data = split_line[1:] 64 | if 'newmtl' in prefix: 65 | material = Material({'name' : data[0]}) 66 | materials += [material] 67 | elif materials: 68 | if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix: 69 | material[prefix] = data[0] 70 | else: 71 | material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda') 72 | 73 | # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps 74 | for mat in materials: 75 | if not 'bsdf' in mat: 76 | mat['bsdf'] = 'pbr' 77 | 78 | if 'map_kd' in mat: 79 | mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd'])) 80 | else: 81 | mat['kd'] = texture.Texture2D(mat['kd']) 82 | 83 | if 'map_ks' in mat: 84 | mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3) 85 | else: 86 | mat['ks'] = texture.Texture2D(mat['ks']) 87 | 88 | if 'bump' in mat: 89 | mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3) 90 | 91 | # Convert Kd from sRGB to linear RGB 92 | mat['kd'] = texture.srgb_to_rgb(mat['kd']) 93 | 94 | if clear_ks: 95 | # Override ORM occlusion (red) channel by zeros. 
We hijack this channel 96 | for mip in mat['ks'].getMips(): 97 | mip[..., 0] = 0.0 98 | 99 | return materials 100 | 101 | @torch.no_grad() 102 | def save_mtl(fn, material): 103 | folder = os.path.dirname(fn) 104 | with open(fn, "w") as f: 105 | f.write('newmtl defaultMat\n') 106 | if material is not None: 107 | f.write('bsdf %s\n' % material['bsdf']) 108 | if 'kd' in material.keys(): 109 | f.write('map_Kd texture_kd.png\n') 110 | texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd'])) 111 | texture.save_texture2D(os.path.join(folder, 'texture_kd_rgb.png'), material['kd']) 112 | if 'ks' in material.keys(): 113 | f.write('map_Ks texture_ks.png\n') 114 | texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks']) 115 | if 'normal' in material.keys() and not ('no_perturbed_nrm' in material and material['no_perturbed_nrm']): 116 | f.write('bump texture_n.png\n') 117 | texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5) 118 | else: 119 | f.write('Kd 1 1 1\n') 120 | f.write('Ks 0 0 0\n') 121 | f.write('Ka 0 0 0\n') 122 | f.write('Tf 1 1 1\n') 123 | f.write('Ni 1\n') 124 | f.write('Ns 0\n') 125 | 126 | ###################################################################################### 127 | # Merge multiple materials into a single uber-material 128 | ###################################################################################### 129 | 130 | def _upscale_replicate(x, full_res): 131 | x = x.permute(0, 3, 1, 2) 132 | x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate') 133 | return x.permute(0, 2, 3, 1).contiguous() 134 | 135 | def merge_materials(materials, texcoords, tfaces, mfaces): 136 | assert len(materials) > 0 137 | for mat in materials: 138 | assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" 139 | assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" 140 | 141 | uber_material = Material({ 142 | 'name' : 'uber_material', 143 | 'bsdf' : materials[0]['bsdf'], 144 | }) 145 | 146 | textures = ['kd', 'ks', 'normal'] 147 | 148 | # Find maximum texture resolution across all materials and textures 149 | max_res = None 150 | for mat in materials: 151 | for tex in textures: 152 | tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) 153 | max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res 154 | 155 | # Compute size of compund texture and round up to nearest PoT 156 | full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int) 157 | 158 | # Normalize texture resolution across all materials & combine into a single large texture 159 | for tex in textures: 160 | if tex in materials[0]: 161 | tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x 162 | tex_data = _upscale_replicate(tex_data, full_res) 163 | uber_material[tex] = texture.Texture2D(tex_data) 164 | 165 | # Compute scaling values for used / unused texture area 166 | s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] 167 | 168 | # Recompute texture coordinates to cooincide with new composite texture 169 | new_tverts = {} 170 | new_tverts_data = [] 171 | for fi in range(len(tfaces)): 172 | matIdx = mfaces[fi] 173 | for vi in range(3): 174 | ti = 
tfaces[fi][vi] 175 | if not (ti in new_tverts): 176 | new_tverts[ti] = {} 177 | if not (matIdx in new_tverts[ti]): # create new vertex 178 | new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here 179 | new_tverts[ti][matIdx] = len(new_tverts_data) - 1 180 | tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex 181 | 182 | return uber_material, new_tverts_data, tfaces 183 | 184 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/mlptexture.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import torch 11 | import tinycudann as tcnn 12 | import numpy as np 13 | 14 | ####################################################################################################################################################### 15 | # Small MLP using PyTorch primitives, internal helper class 16 | ####################################################################################################################################################### 17 | 18 | class _MLP(torch.nn.Module): 19 | def __init__(self, cfg, loss_scale=1.0): 20 | super(_MLP, self).__init__() 21 | self.loss_scale = loss_scale 22 | net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU()) 23 | for i in range(cfg['n_hidden_layers']-1): 24 | net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU()) 25 | net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),) 26 | self.net = torch.nn.Sequential(*net).cuda() 27 | 28 | self.net.apply(self._init_weights) 29 | 30 | if self.loss_scale != 1.0: 31 | self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, )) 32 | 33 | def forward(self, x): 34 | return self.net(x.to(torch.float32)) 35 | 36 | @staticmethod 37 | def _init_weights(m): 38 | if type(m) == torch.nn.Linear: 39 | torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu') 40 | if hasattr(m.bias, 'data'): 41 | m.bias.data.fill_(0.0) 42 | 43 | ####################################################################################################################################################### 44 | # Outward visible MLP class 45 | ####################################################################################################################################################### 46 | 47 | class MLPTexture3D(torch.nn.Module): 48 | def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None): 49 | super(MLPTexture3D, self).__init__() 50 | 51 | self.channels = channels 52 | self.internal_dims = internal_dims 53 | self.AABB = AABB 54 | self.min_max = min_max 55 | 56 | # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details 57 | desired_resolution = 4096 58 
| base_grid_resolution = 16 59 | num_levels = 16 60 | per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1)) 61 | 62 | enc_cfg = { 63 | "otype": "HashGrid", 64 | "n_levels": num_levels, 65 | "n_features_per_level": 2, 66 | "log2_hashmap_size": 19, 67 | "base_resolution": base_grid_resolution, 68 | "per_level_scale" : per_level_scale 69 | } 70 | 71 | gradient_scaling = 128.0 72 | self.encoder = tcnn.Encoding(3, enc_cfg) 73 | self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, )) 74 | 75 | # Setup MLP 76 | mlp_cfg = { 77 | "n_input_dims" : self.encoder.n_output_dims, 78 | "n_output_dims" : self.channels, 79 | "n_hidden_layers" : hidden, 80 | "n_neurons" : self.internal_dims 81 | } 82 | self.net = _MLP(mlp_cfg, gradient_scaling) 83 | print("Encoder output: %d dims" % (self.encoder.n_output_dims)) 84 | 85 | # Sample texture at a given location 86 | def sample(self, texc): 87 | _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...]) 88 | _texc = torch.clamp(_texc, min=0, max=1) 89 | 90 | p_enc = self.encoder(_texc.contiguous()) 91 | out = self.net.forward(p_enc) 92 | 93 | # Sigmoid limit and scale to the allowed range 94 | out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] 95 | 96 | return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c] 97 | 98 | # In-place clamp with no derivative to make sure values are in valid range after training 99 | def clamp_(self): 100 | pass 101 | 102 | def cleanup(self): 103 | tcnn.free_temporary_memory() 104 | 105 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/obj.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import os 11 | import torch 12 | 13 | from . import texture 14 | from . import mesh 15 | from . 
import material 16 | 17 | ###################################################################################### 18 | # Utility functions 19 | ###################################################################################### 20 | 21 | def _find_mat(materials, name): 22 | for mat in materials: 23 | if mat['name'] == name: 24 | return mat 25 | return materials[0] # Materials 0 is the default 26 | 27 | ###################################################################################### 28 | # Create mesh object from objfile 29 | ###################################################################################### 30 | 31 | def load_obj(filename, clear_ks=True, mtl_override=None): 32 | obj_path = os.path.dirname(filename) 33 | 34 | # Read entire file 35 | with open(filename, 'r') as f: 36 | lines = f.readlines() 37 | 38 | # Load materials 39 | all_materials = [ 40 | { 41 | 'name' : '_default_mat', 42 | 'bsdf' : 'pbr', 43 | 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), 44 | 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) 45 | } 46 | ] 47 | if mtl_override is None: 48 | for line in lines: 49 | if len(line.split()) == 0: 50 | continue 51 | if line.split()[0] == 'mtllib': 52 | all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library 53 | else: 54 | all_materials += material.load_mtl(mtl_override) 55 | 56 | # load vertices 57 | vertices, texcoords, normals = [], [], [] 58 | for line in lines: 59 | if len(line.split()) == 0: 60 | continue 61 | 62 | prefix = line.split()[0].lower() 63 | if prefix == 'v': 64 | vertices.append([float(v) for v in line.split()[1:]]) 65 | elif prefix == 'vt': 66 | val = [float(v) for v in line.split()[1:]] 67 | texcoords.append([val[0], 1.0 - val[1]]) 68 | elif prefix == 'vn': 69 | normals.append([float(v) for v in line.split()[1:]]) 70 | 71 | # load faces 72 | activeMatIdx = None 73 | used_materials = [] 74 | faces, tfaces, nfaces, mfaces = [], [], [], [] 75 | for line in lines: 76 | if len(line.split()) == 0: 77 | continue 78 | 79 | prefix = line.split()[0].lower() 80 | if prefix == 'usemtl': # Track used materials 81 | mat = _find_mat(all_materials, line.split()[1]) 82 | if not mat in used_materials: 83 | used_materials.append(mat) 84 | activeMatIdx = used_materials.index(mat) 85 | elif prefix == 'f': # Parse face 86 | vs = line.split()[1:] 87 | nv = len(vs) 88 | vv = vs[0].split('/') 89 | v0 = int(vv[0]) - 1 90 | t0 = int(vv[1]) - 1 if vv[1] != "" else -1 91 | n0 = int(vv[2]) - 1 if vv[2] != "" else -1 92 | for i in range(nv - 2): # Triangulate polygons 93 | vv = vs[i + 1].split('/') 94 | v1 = int(vv[0]) - 1 95 | t1 = int(vv[1]) - 1 if vv[1] != "" else -1 96 | n1 = int(vv[2]) - 1 if vv[2] != "" else -1 97 | vv = vs[i + 2].split('/') 98 | v2 = int(vv[0]) - 1 99 | t2 = int(vv[1]) - 1 if vv[1] != "" else -1 100 | n2 = int(vv[2]) - 1 if vv[2] != "" else -1 101 | mfaces.append(activeMatIdx) 102 | faces.append([v0, v1, v2]) 103 | tfaces.append([t0, t1, t2]) 104 | nfaces.append([n0, n1, n2]) 105 | assert len(tfaces) == len(faces) and len(nfaces) == len (faces) 106 | 107 | # Create an "uber" material by combining all textures into a larger texture 108 | if len(used_materials) > 1: 109 | uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) 110 | else: 111 | uber_material = used_materials[0] 112 | 113 | vertices = torch.tensor(vertices, dtype=torch.float32, 
device='cuda') 114 | texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None 115 | normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None 116 | 117 | faces = torch.tensor(faces, dtype=torch.int64, device='cuda') 118 | tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None 119 | nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None 120 | 121 | return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material) 122 | 123 | ###################################################################################### 124 | # Save mesh object to objfile 125 | ###################################################################################### 126 | 127 | def write_obj(folder, mesh, save_material=True): 128 | obj_file = os.path.join(folder, 'mesh.obj') 129 | print("Writing mesh: ", obj_file) 130 | with open(obj_file, "w") as f: 131 | f.write("mtllib mesh.mtl\n") 132 | f.write("g default\n") 133 | 134 | v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None 135 | v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None 136 | v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None 137 | 138 | t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None 139 | t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None 140 | t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None 141 | 142 | print(" writing %d vertices" % len(v_pos)) 143 | for v in v_pos: 144 | f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) 145 | 146 | if v_tex is not None: 147 | print(" writing %d texcoords" % len(v_tex)) 148 | assert(len(t_pos_idx) == len(t_tex_idx)) 149 | for v in v_tex: 150 | f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) 151 | 152 | if v_nrm is not None: 153 | print(" writing %d normals" % len(v_nrm)) 154 | assert(len(t_pos_idx) == len(t_nrm_idx)) 155 | for v in v_nrm: 156 | f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) 157 | 158 | # faces 159 | f.write("s 1 \n") 160 | f.write("g pMesh1\n") 161 | f.write("usemtl defaultMat\n") 162 | 163 | # Write faces 164 | print(" writing %d faces" % len(t_pos_idx)) 165 | for i in range(len(t_pos_idx)): 166 | f.write("f ") 167 | for j in range(3): 168 | f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) 169 | f.write("\n") 170 | 171 | if save_material: 172 | mtl_file = os.path.join(folder, 'mesh.mtl') 173 | print("Writing material: ", mtl_file) 174 | material.save_mtl(mtl_file, mesh.material) 175 | 176 | print("Done exporting mesh") 177 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/regularizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. 
Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import torch 11 | import nvdiffrast.torch as dr 12 | 13 | from . import util 14 | from . import mesh 15 | 16 | ###################################################################################### 17 | # Computes the image gradient, useful for kd/ks smoothness losses 18 | ###################################################################################### 19 | def image_grad(buf, std=0.01): 20 | t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"), 21 | torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"), 22 | indexing='ij') 23 | tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...] 24 | tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp') 25 | return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:] 26 | 27 | ###################################################################################### 28 | # Computes the average edge length of a mesh. 29 | # Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients 30 | ###################################################################################### 31 | def avg_edge_length(v_pos, t_pos_idx): 32 | e_pos_idx = mesh.compute_edges(t_pos_idx) 33 | edge_len = util.length(v_pos[e_pos_idx[:, 0]] - v_pos[e_pos_idx[:, 1]]) 34 | return torch.mean(edge_len) 35 | 36 | ###################################################################################### 37 | # Laplacian regularization using umbrella operator (Fujiwara / Desbrun). 
38 | # https://mgarland.org/class/geom04/material/smoothing.pdf 39 | ###################################################################################### 40 | def laplace_regularizer_const(v_pos, t_pos_idx): 41 | term = torch.zeros_like(v_pos) 42 | norm = torch.zeros_like(v_pos[..., 0:1]) 43 | 44 | v0 = v_pos[t_pos_idx[:, 0], :] 45 | v1 = v_pos[t_pos_idx[:, 1], :] 46 | v2 = v_pos[t_pos_idx[:, 2], :] 47 | 48 | term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0)) 49 | term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1)) 50 | term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2)) 51 | 52 | two = torch.ones_like(v0) * 2.0 53 | norm.scatter_add_(0, t_pos_idx[:, 0:1], two) 54 | norm.scatter_add_(0, t_pos_idx[:, 1:2], two) 55 | norm.scatter_add_(0, t_pos_idx[:, 2:3], two) 56 | 57 | term = term / torch.clamp(norm, min=1.0) 58 | 59 | return torch.mean(term**2) 60 | 61 | ###################################################################################### 62 | # Smooth vertex normals 63 | ###################################################################################### 64 | def normal_consistency(v_pos, t_pos_idx): 65 | # Compute face normals 66 | v0 = v_pos[t_pos_idx[:, 0], :] 67 | v1 = v_pos[t_pos_idx[:, 1], :] 68 | v2 = v_pos[t_pos_idx[:, 2], :] 69 | 70 | face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0)) 71 | 72 | tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx) 73 | 74 | # Fetch normals for both faces sharind an edge 75 | n0 = face_normals[tris_per_edge[:, 0], :] 76 | n1 = face_normals[tris_per_edge[:, 1], :] 77 | 78 | # Compute error metric based on normal difference 79 | term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0) 80 | term = (1.0 - term) * 0.5 81 | 82 | return torch.mean(torch.abs(term)) 83 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith 11 | __all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ] 12 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/bsdf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. 
Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import math 11 | import torch 12 | 13 | NORMAL_THRESHOLD = 0.1 14 | 15 | ################################################################################ 16 | # Vector utility functions 17 | ################################################################################ 18 | 19 | def _dot(x, y): 20 | return torch.sum(x*y, -1, keepdim=True) 21 | 22 | def _reflect(x, n): 23 | return 2*_dot(x, n)*n - x 24 | 25 | def _safe_normalize(x): 26 | return torch.nn.functional.normalize(x, dim = -1) 27 | 28 | def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading): 29 | # Swap normal direction for backfacing surfaces 30 | if two_sided_shading: 31 | smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm) 32 | geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm) 33 | 34 | t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1) 35 | return torch.lerp(geom_nrm, smooth_nrm, t) 36 | 37 | 38 | def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl): 39 | smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm)) 40 | if opengl: 41 | shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) 42 | else: 43 | shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) 44 | return _safe_normalize(shading_nrm) 45 | 46 | def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): 47 | smooth_nrm = _safe_normalize(smooth_nrm) 48 | smooth_tng = _safe_normalize(smooth_tng) 49 | view_vec = _safe_normalize(view_pos - pos) 50 | shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl) 51 | return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading) 52 | 53 | ################################################################################ 54 | # Simple lambertian diffuse BSDF 55 | ################################################################################ 56 | 57 | def bsdf_lambert(nrm, wi): 58 | return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi 59 | 60 | ################################################################################ 61 | # Frostbite diffuse 62 | ################################################################################ 63 | 64 | def bsdf_frostbite(nrm, wi, wo, linearRoughness): 65 | wiDotN = _dot(wi, nrm) 66 | woDotN = _dot(wo, nrm) 67 | 68 | h = _safe_normalize(wo + wi) 69 | wiDotH = _dot(wi, h) 70 | 71 | energyBias = 0.5 * linearRoughness 72 | energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness 73 | f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness 74 | f0 = 1.0 75 | 76 | wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN) 77 | woScatter = bsdf_fresnel_shlick(f0, f90, woDotN) 78 | res = wiScatter * woScatter * energyFactor 79 | return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res)) 80 | 81 | ################################################################################ 82 | # Phong specular, loosely based on mitsuba implementation 83 | ################################################################################ 84 | 85 | def bsdf_phong(nrm, wo, wi, N): 86 
| dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0) 87 | dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0) 88 | return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi) 89 | 90 | ################################################################################ 91 | # PBR's implementation of GGX specular 92 | ################################################################################ 93 | 94 | specular_epsilon = 1e-4 95 | 96 | def bsdf_fresnel_shlick(f0, f90, cosTheta): 97 | _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) 98 | return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0 99 | 100 | def bsdf_ndf_ggx(alphaSqr, cosTheta): 101 | _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) 102 | d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1 103 | return alphaSqr / (d * d * math.pi) 104 | 105 | def bsdf_lambda_ggx(alphaSqr, cosTheta): 106 | _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon) 107 | cosThetaSqr = _cosTheta * _cosTheta 108 | tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr 109 | res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0) 110 | return res 111 | 112 | def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO): 113 | lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI) 114 | lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO) 115 | return 1 / (1 + lambdaI + lambdaO) 116 | 117 | def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08): 118 | _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0) 119 | alphaSqr = _alpha * _alpha 120 | 121 | h = _safe_normalize(wo + wi) 122 | woDotN = _dot(wo, nrm) 123 | wiDotN = _dot(wi, nrm) 124 | woDotH = _dot(wo, h) 125 | nDotH = _dot(nrm, h) 126 | 127 | D = bsdf_ndf_ggx(alphaSqr, nDotH) 128 | G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN) 129 | F = bsdf_fresnel_shlick(col, 1, woDotH) 130 | 131 | w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon) 132 | 133 | frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon) 134 | return torch.where(frontfacing, w, torch.zeros_like(w)) 135 | 136 | def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): 137 | wo = _safe_normalize(view_pos - pos) 138 | wi = _safe_normalize(light_pos - pos) 139 | 140 | spec_str = arm[..., 0:1] # x component 141 | roughness = arm[..., 1:2] # y component 142 | metallic = arm[..., 2:3] # z component 143 | ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str) 144 | kd = kd * (1.0 - metallic) 145 | 146 | if BSDF == 0: 147 | diffuse = kd * bsdf_lambert(nrm, wi) 148 | else: 149 | diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness) 150 | specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness) 151 | return diffuse + specular 152 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/bsdf.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. 
Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "common.h" 15 | 16 | struct LambertKernelParams 17 | { 18 | Tensor nrm; 19 | Tensor wi; 20 | Tensor out; 21 | dim3 gridSize; 22 | }; 23 | 24 | struct FrostbiteDiffuseKernelParams 25 | { 26 | Tensor nrm; 27 | Tensor wi; 28 | Tensor wo; 29 | Tensor linearRoughness; 30 | Tensor out; 31 | dim3 gridSize; 32 | }; 33 | 34 | struct FresnelShlickKernelParams 35 | { 36 | Tensor f0; 37 | Tensor f90; 38 | Tensor cosTheta; 39 | Tensor out; 40 | dim3 gridSize; 41 | }; 42 | 43 | struct NdfGGXParams 44 | { 45 | Tensor alphaSqr; 46 | Tensor cosTheta; 47 | Tensor out; 48 | dim3 gridSize; 49 | }; 50 | 51 | struct MaskingSmithParams 52 | { 53 | Tensor alphaSqr; 54 | Tensor cosThetaI; 55 | Tensor cosThetaO; 56 | Tensor out; 57 | dim3 gridSize; 58 | }; 59 | 60 | struct PbrSpecular 61 | { 62 | Tensor col; 63 | Tensor nrm; 64 | Tensor wo; 65 | Tensor wi; 66 | Tensor alpha; 67 | Tensor out; 68 | dim3 gridSize; 69 | float min_roughness; 70 | }; 71 | 72 | struct PbrBSDF 73 | { 74 | Tensor kd; 75 | Tensor arm; 76 | Tensor pos; 77 | Tensor nrm; 78 | Tensor view_pos; 79 | Tensor light_pos; 80 | Tensor out; 81 | dim3 gridSize; 82 | float min_roughness; 83 | int BSDF; 84 | }; 85 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/common.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #include 13 | #include 14 | 15 | //------------------------------------------------------------------------ 16 | // Block and grid size calculators for kernel launches. 17 | 18 | dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims) 19 | { 20 | int maxThreads = maxWidth * maxHeight; 21 | if (maxThreads <= 1 || (dims.x * dims.y) <= 1) 22 | return dim3(1, 1, 1); // Degenerate. 23 | 24 | // Start from max size. 25 | int bw = maxWidth; 26 | int bh = maxHeight; 27 | 28 | // Optimizations for weirdly sized buffers. 29 | if (dims.x < bw) 30 | { 31 | // Decrease block width to smallest power of two that covers the buffer width. 32 | while ((bw >> 1) >= dims.x) 33 | bw >>= 1; 34 | 35 | // Maximize height. 36 | bh = maxThreads / bw; 37 | if (bh > dims.y) 38 | bh = dims.y; 39 | } 40 | else if (dims.y < bh) 41 | { 42 | // Halve height and double width until fits completely inside buffer vertically. 43 | while (bh > dims.y) 44 | { 45 | bh >>= 1; 46 | if (bw < dims.x) 47 | bw <<= 1; 48 | } 49 | } 50 | 51 | // Done. 52 | return dim3(bw, bh, 1); 53 | } 54 | 55 | // returns the size of a block that can be reduced using horizontal SIMD operations (e.g. 
__shfl_xor_sync) 56 | dim3 getWarpSize(dim3 blockSize) 57 | { 58 | return dim3( 59 | std::min(blockSize.x, 32u), 60 | std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), 61 | std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z)) 62 | ); 63 | } 64 | 65 | dim3 getLaunchGridSize(dim3 blockSize, dim3 dims) 66 | { 67 | dim3 gridSize; 68 | gridSize.x = (dims.x - 1) / blockSize.x + 1; 69 | gridSize.y = (dims.y - 1) / blockSize.y + 1; 70 | gridSize.z = (dims.z - 1) / blockSize.z + 1; 71 | return gridSize; 72 | } 73 | 74 | //------------------------------------------------------------------------ 75 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/common.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #pragma once 13 | #include 14 | #include 15 | 16 | #include "vec3f.h" 17 | #include "vec4f.h" 18 | #include "tensor.h" 19 | 20 | dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims); 21 | dim3 getLaunchGridSize(dim3 blockSize, dim3 dims); 22 | 23 | #ifdef __CUDACC__ 24 | 25 | #ifdef _MSC_VER 26 | #define M_PI 3.14159265358979323846f 27 | #endif 28 | 29 | __host__ __device__ static inline dim3 getWarpSize(dim3 blockSize) 30 | { 31 | return dim3( 32 | min(blockSize.x, 32u), 33 | min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)), 34 | min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z)) 35 | ); 36 | } 37 | 38 | __device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); } 39 | #else 40 | dim3 getWarpSize(dim3 blockSize); 41 | #endif -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/cubemap.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 
10 | */ 11 | 12 | #pragma once 13 | 14 | #include "common.h" 15 | 16 | struct DiffuseCubemapKernelParams 17 | { 18 | Tensor cubemap; 19 | Tensor out; 20 | dim3 gridSize; 21 | }; 22 | 23 | struct SpecularCubemapKernelParams 24 | { 25 | Tensor cubemap; 26 | Tensor bounds; 27 | Tensor out; 28 | dim3 gridSize; 29 | float costheta_cutoff; 30 | float roughness; 31 | }; 32 | 33 | struct SpecularBoundsKernelParams 34 | { 35 | float costheta_cutoff; 36 | Tensor out; 37 | dim3 gridSize; 38 | }; 39 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/loss.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #include 13 | 14 | #include "common.h" 15 | #include "loss.h" 16 | 17 | //------------------------------------------------------------------------ 18 | // Utils 19 | 20 | __device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; } 21 | 22 | __device__ float warpSum(float val) { 23 | for (int i = 1; i < 32; i *= 2) 24 | val += __shfl_xor_sync(0xFFFFFFFF, val, i); 25 | return val; 26 | } 27 | 28 | //------------------------------------------------------------------------ 29 | // Tonemapping 30 | 31 | __device__ inline float fwdSRGB(float x) 32 | { 33 | return x > 0.0031308f ? 
powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f); 34 | } 35 | 36 | __device__ inline void bwdSRGB(float x, float &d_x, float d_out) 37 | { 38 | if (x > 0.0031308f) 39 | d_x += d_out * 0.439583f / powf(x, 0.583333f); 40 | else if (x > 0.0f) 41 | d_x += d_out * 12.92f; 42 | } 43 | 44 | __device__ inline vec3f fwdTonemapLogSRGB(vec3f x) 45 | { 46 | return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f))); 47 | } 48 | 49 | __device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out) 50 | { 51 | if (x.x > 0.0f && x.x < 65535.0f) 52 | { 53 | bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x); 54 | d_x.x *= 1 / (x.x + 1.0f); 55 | } 56 | if (x.y > 0.0f && x.y < 65535.0f) 57 | { 58 | bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y); 59 | d_x.y *= 1 / (x.y + 1.0f); 60 | } 61 | if (x.z > 0.0f && x.z < 65535.0f) 62 | { 63 | bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z); 64 | d_x.z *= 1 / (x.z + 1.0f); 65 | } 66 | } 67 | 68 | __device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f) 69 | { 70 | return (img - target) * (img - target) / (img * img + target * target + eps); 71 | } 72 | 73 | __device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f) 74 | { 75 | float denom = (target * target + img * img + eps); 76 | d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom); 77 | d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom); 78 | } 79 | 80 | __device__ inline float fwdSMAPE(float img, float target, float eps=0.01f) 81 | { 82 | return abs(img - target) / (img + target + eps); 83 | } 84 | 85 | __device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f) 86 | { 87 | float denom = (target + img + eps); 88 | d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom); 89 | d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom); 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | // Kernels 94 | 95 | __global__ void imgLossFwdKernel(LossKernelParams p) 96 | { 97 | // Calculate pixel position. 
98 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; 99 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; 100 | unsigned int pz = blockIdx.z; 101 | 102 | float floss = 0.0f; 103 | if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z) 104 | { 105 | vec3f img = p.img.fetch3(px, py, pz); 106 | vec3f target = p.target.fetch3(px, py, pz); 107 | 108 | img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f)); 109 | target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f)); 110 | 111 | if (p.tonemapper == TONEMAPPER_LOG_SRGB) 112 | { 113 | img = fwdTonemapLogSRGB(img); 114 | target = fwdTonemapLogSRGB(target); 115 | } 116 | 117 | vec3f vloss(0); 118 | if (p.loss == LOSS_MSE) 119 | vloss = (img - target) * (img - target); 120 | else if (p.loss == LOSS_RELMSE) 121 | vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z)); 122 | else if (p.loss == LOSS_SMAPE) 123 | vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z)); 124 | else 125 | vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z)); 126 | 127 | floss = sum(vloss) / 3.0f; 128 | } 129 | 130 | floss = warpSum(floss); 131 | 132 | dim3 warpSize = getWarpSize(blockDim); 133 | if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0) 134 | p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss); 135 | } 136 | 137 | __global__ void imgLossBwdKernel(LossKernelParams p) 138 | { 139 | // Calculate pixel position. 140 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; 141 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; 142 | unsigned int pz = blockIdx.z; 143 | 144 | if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) 145 | return; 146 | 147 | dim3 warpSize = getWarpSize(blockDim); 148 | 149 | vec3f _img = p.img.fetch3(px, py, pz); 150 | vec3f _target = p.target.fetch3(px, py, pz); 151 | float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z); 152 | 153 | ///////////////////////////////////////////////////////////////////// 154 | // FWD 155 | 156 | vec3f img = _img, target = _target; 157 | if (p.tonemapper == TONEMAPPER_LOG_SRGB) 158 | { 159 | img = fwdTonemapLogSRGB(img); 160 | target = fwdTonemapLogSRGB(target); 161 | } 162 | 163 | ///////////////////////////////////////////////////////////////////// 164 | // BWD 165 | 166 | vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f; 167 | 168 | vec3f d_img(0), d_target(0); 169 | if (p.loss == LOSS_MSE) 170 | { 171 | d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z)); 172 | d_target = -d_img; 173 | } 174 | else if (p.loss == LOSS_RELMSE) 175 | { 176 | bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); 177 | bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); 178 | bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); 179 | } 180 | else if (p.loss == LOSS_SMAPE) 181 | { 182 | bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); 183 | bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); 184 | bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); 185 | } 186 | else 187 | { 188 | d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z)); 189 | d_target = 
-d_img; 190 | } 191 | 192 | 193 | if (p.tonemapper == TONEMAPPER_LOG_SRGB) 194 | { 195 | vec3f d__img(0), d__target(0); 196 | bwdTonemapLogSRGB(_img, d__img, d_img); 197 | bwdTonemapLogSRGB(_target, d__target, d_target); 198 | d_img = d__img; d_target = d__target; 199 | } 200 | 201 | if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0; 202 | if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0; 203 | if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0; 204 | if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0; 205 | if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0; 206 | if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0; 207 | 208 | p.img.store_grad(px, py, pz, d_img); 209 | p.target.store_grad(px, py, pz, d_target); 210 | } -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/loss.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "common.h" 15 | 16 | enum TonemapperType 17 | { 18 | TONEMAPPER_NONE = 0, 19 | TONEMAPPER_LOG_SRGB = 1 20 | }; 21 | 22 | enum LossType 23 | { 24 | LOSS_L1 = 0, 25 | LOSS_MSE = 1, 26 | LOSS_RELMSE = 2, 27 | LOSS_SMAPE = 3 28 | }; 29 | 30 | struct LossKernelParams 31 | { 32 | Tensor img; 33 | Tensor target; 34 | Tensor out; 35 | dim3 gridSize; 36 | TonemapperType tonemapper; 37 | LossType loss; 38 | }; 39 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/mesh.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 
10 | */ 11 | 12 | #include 13 | #include 14 | 15 | #include "common.h" 16 | #include "mesh.h" 17 | 18 | 19 | //------------------------------------------------------------------------ 20 | // Kernels 21 | 22 | __global__ void xfmPointsFwdKernel(XfmKernelParams p) 23 | { 24 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; 25 | unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; 26 | 27 | __shared__ float mtx[4][4]; 28 | if (threadIdx.x < 16) 29 | mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); 30 | __syncthreads(); 31 | 32 | if (px >= p.gridSize.x) 33 | return; 34 | 35 | vec3f pos( 36 | p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), 37 | p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), 38 | p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) 39 | ); 40 | 41 | if (p.isPoints) 42 | { 43 | p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]); 44 | p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]); 45 | p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]); 46 | p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]); 47 | } 48 | else 49 | { 50 | p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]); 51 | p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]); 52 | p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]); 53 | } 54 | } 55 | 56 | __global__ void xfmPointsBwdKernel(XfmKernelParams p) 57 | { 58 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; 59 | unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z; 60 | 61 | __shared__ float mtx[4][4]; 62 | if (threadIdx.x < 16) 63 | mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0)); 64 | __syncthreads(); 65 | 66 | if (px >= p.gridSize.x) 67 | return; 68 | 69 | vec3f pos( 70 | p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)), 71 | p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)), 72 | p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0)) 73 | ); 74 | 75 | vec4f d_out( 76 | p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)), 77 | p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)), 78 | p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)), 79 | p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0)) 80 | ); 81 | 82 | if (p.isPoints) 83 | { 84 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]); 85 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]); 86 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]); 87 | } 88 | else 89 | { 90 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]); 91 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]); 92 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]); 93 | } 94 | } 
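For reference, the point branch of `xfmPointsFwdKernel` above is just a homogeneous transform of each position by the per-batch 4x4 matrix (and `xfmPointsBwdKernel` only back-propagates into the points, not the matrix). Below is a minimal PyTorch sketch of the same math; `xfm_points_reference` is a hypothetical name used only for illustration here, while the repo's actual Python entry point is `xfm_points` in `renderutils/ops.py`.

```python
import torch

def xfm_points_reference(points, matrix):
    # points: [N, V, 3] positions, matrix: [N, 4, 4] transforms.
    # The CUDA kernel computes out_c = sum_j p_j * M[c, j] + M[c, 3],
    # i.e. M applied to the homogeneous point; with row-stacked points
    # that is hom @ M^T.
    hom = torch.nn.functional.pad(points, (0, 1), value=1.0)  # append w = 1 -> [N, V, 4]
    return torch.matmul(hom, matrix.transpose(-2, -1))        # [N, V, 4] homogeneous output
```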
-------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/mesh.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "common.h" 15 | 16 | struct XfmKernelParams 17 | { 18 | bool isPoints; 19 | Tensor points; 20 | Tensor matrix; 21 | Tensor out; 22 | dim3 gridSize; 23 | }; 24 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/normal.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #include "common.h" 13 | #include "normal.h" 14 | 15 | #define NORMAL_THRESHOLD 0.1f 16 | 17 | //------------------------------------------------------------------------ 18 | // Perturb shading normal by tangent frame 19 | 20 | __device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl) 21 | { 22 | vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); 23 | vec3f smooth_bitng = safeNormalize(_smooth_bitng); 24 | vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); 25 | return safeNormalize(_shading_nrm); 26 | } 27 | 28 | __device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl) 29 | { 30 | //////////////////////////////////////////////////////////////////////// 31 | // FWD 32 | vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm); 33 | vec3f smooth_bitng = safeNormalize(_smooth_bitng); 34 | vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f); 35 | 36 | //////////////////////////////////////////////////////////////////////// 37 | // BWD 38 | vec3f d_shading_nrm(0); 39 | bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out); 40 | 41 | vec3f d_smooth_bitng(0); 42 | 43 | if (perturbed_nrm.z > 0.0f) 44 | { 45 | d_smooth_nrm += d_shading_nrm * perturbed_nrm.z; 46 | d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm); 47 | } 48 | 49 | d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y; 50 | d_perturbed_nrm.y += (opengl ? 
-1 : 1) * sum(d_shading_nrm * smooth_bitng); 51 | 52 | d_smooth_tng += d_shading_nrm * perturbed_nrm.x; 53 | d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng); 54 | 55 | vec3f d__smooth_bitng(0); 56 | bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng); 57 | 58 | bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng); 59 | } 60 | 61 | //------------------------------------------------------------------------ 62 | #define bent_nrm_eps 0.001f 63 | 64 | __device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm) 65 | { 66 | float dp = dot(view_vec, smooth_nrm); 67 | float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); 68 | return geom_nrm * (1.0f - t) + smooth_nrm * t; 69 | } 70 | 71 | __device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out) 72 | { 73 | //////////////////////////////////////////////////////////////////////// 74 | // FWD 75 | float dp = dot(view_vec, smooth_nrm); 76 | float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f); 77 | 78 | //////////////////////////////////////////////////////////////////////// 79 | // BWD 80 | if (dp > NORMAL_THRESHOLD) 81 | d_smooth_nrm += d_out; 82 | else 83 | { 84 | // geom_nrm * (1.0f - t) + smooth_nrm * t; 85 | d_geom_nrm += d_out * (1.0f - t); 86 | d_smooth_nrm += d_out * t; 87 | float d_t = sum(d_out * (smooth_nrm - geom_nrm)); 88 | 89 | float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD; 90 | 91 | bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp); 92 | } 93 | } 94 | 95 | //------------------------------------------------------------------------ 96 | // Kernels 97 | 98 | __global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p) 99 | { 100 | // Calculate pixel position. 101 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; 102 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; 103 | unsigned int pz = blockIdx.z; 104 | if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) 105 | return; 106 | 107 | vec3f pos = p.pos.fetch3(px, py, pz); 108 | vec3f view_pos = p.view_pos.fetch3(px, py, pz); 109 | vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); 110 | vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); 111 | vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); 112 | vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); 113 | 114 | vec3f smooth_nrm = safeNormalize(_smooth_nrm); 115 | vec3f smooth_tng = safeNormalize(_smooth_tng); 116 | vec3f view_vec = safeNormalize(view_pos - pos); 117 | vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); 118 | 119 | vec3f res; 120 | if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) 121 | res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm); 122 | else 123 | res = fwdBendNormal(view_vec, shading_nrm, geom_nrm); 124 | 125 | p.out.store(px, py, pz, res); 126 | } 127 | 128 | __global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p) 129 | { 130 | // Calculate pixel position. 
131 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; 132 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; 133 | unsigned int pz = blockIdx.z; 134 | if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) 135 | return; 136 | 137 | vec3f pos = p.pos.fetch3(px, py, pz); 138 | vec3f view_pos = p.view_pos.fetch3(px, py, pz); 139 | vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz); 140 | vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz); 141 | vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz); 142 | vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz); 143 | vec3f d_out = p.out.fetch3(px, py, pz); 144 | 145 | /////////////////////////////////////////////////////////////////////////////////////////////////// 146 | // FWD 147 | 148 | vec3f smooth_nrm = safeNormalize(_smooth_nrm); 149 | vec3f smooth_tng = safeNormalize(_smooth_tng); 150 | vec3f _view_vec = view_pos - pos; 151 | vec3f view_vec = safeNormalize(view_pos - pos); 152 | 153 | vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl); 154 | 155 | /////////////////////////////////////////////////////////////////////////////////////////////////// 156 | // BWD 157 | 158 | vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0); 159 | if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f) 160 | { 161 | bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); 162 | d_shading_nrm = -d_shading_nrm; 163 | d_geom_nrm = -d_geom_nrm; 164 | } 165 | else 166 | bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out); 167 | 168 | vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0); 169 | bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl); 170 | 171 | vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0); 172 | bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec); 173 | bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm); 174 | bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng); 175 | 176 | p.pos.store_grad(px, py, pz, -d__view_vec); 177 | p.view_pos.store_grad(px, py, pz, d__view_vec); 178 | p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm); 179 | p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm); 180 | p.smooth_tng.store_grad(px, py, pz, d__smooth_tng); 181 | p.geom_nrm.store_grad(px, py, pz, d_geom_nrm); 182 | } -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/normal.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 
10 | */ 11 | 12 | #pragma once 13 | 14 | #include "common.h" 15 | 16 | struct PrepareShadingNormalKernelParams 17 | { 18 | Tensor pos; 19 | Tensor view_pos; 20 | Tensor perturbed_nrm; 21 | Tensor smooth_nrm; 22 | Tensor smooth_tng; 23 | Tensor geom_nrm; 24 | Tensor out; 25 | dim3 gridSize; 26 | bool two_sided_shading, opengl; 27 | }; 28 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/tensor.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #pragma once 13 | #if defined(__CUDACC__) && defined(BFLOAT16) 14 | #include // bfloat16 is float32 compatible with less mantissa bits 15 | #endif 16 | 17 | //--------------------------------------------------------------------------------- 18 | // CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16 19 | 20 | struct Tensor 21 | { 22 | void* val; 23 | void* d_val; 24 | int dims[4], _dims[4]; 25 | int strides[4]; 26 | bool fp16; 27 | 28 | #if defined(__CUDA__) && !defined(__CUDA_ARCH__) 29 | Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {} 30 | #endif 31 | 32 | #ifdef __CUDACC__ 33 | // Helpers to index and read/write a single element 34 | __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; } 35 | __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); } 36 | __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; } 37 | #ifdef BFLOAT16 38 | __device__ inline float fetch(unsigned int idx) const { return fp16 ? 
__bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; } 39 | __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; } 40 | __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; } 41 | #else 42 | __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; } 43 | __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; } 44 | __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; } 45 | #endif 46 | 47 | ////////////////////////////////////////////////////////////////////////////////////////// 48 | // Fetch, use broadcasting for tensor dimensions of size 1 49 | __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const 50 | { 51 | return fetch(nhwcIndex(z, y, x, 0)); 52 | } 53 | 54 | __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const 55 | { 56 | return vec3f( 57 | fetch(nhwcIndex(z, y, x, 0)), 58 | fetch(nhwcIndex(z, y, x, 1)), 59 | fetch(nhwcIndex(z, y, x, 2)) 60 | ); 61 | } 62 | 63 | ///////////////////////////////////////////////////////////////////////////////////////////////////////////// 64 | // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside 65 | __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val) 66 | { 67 | store(_nhwcIndex(z, y, x, 0), _val); 68 | } 69 | 70 | __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val) 71 | { 72 | store(_nhwcIndex(z, y, x, 0), _val.x); 73 | store(_nhwcIndex(z, y, x, 1), _val.y); 74 | store(_nhwcIndex(z, y, x, 2), _val.z); 75 | } 76 | 77 | ///////////////////////////////////////////////////////////////////////////////////////////////////////////// 78 | // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside 79 | __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val) 80 | { 81 | store_grad(nhwcIndexContinuous(z, y, x, 0), _val); 82 | } 83 | 84 | __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val) 85 | { 86 | store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x); 87 | store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y); 88 | store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z); 89 | } 90 | #endif 91 | 92 | }; 93 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/vec3f.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 
10 | */ 11 | 12 | #pragma once 13 | 14 | struct vec3f 15 | { 16 | float x, y, z; 17 | 18 | #ifdef __CUDACC__ 19 | __device__ vec3f() { } 20 | __device__ vec3f(float v) { x = v; y = v; z = v; } 21 | __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; } 22 | __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; } 23 | 24 | __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; } 25 | __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; } 26 | __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; } 27 | __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; } 28 | #endif 29 | }; 30 | 31 | #ifdef __CUDACC__ 32 | __device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); } 33 | __device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); } 34 | __device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); } 35 | __device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); } 36 | __device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); } 37 | 38 | __device__ static inline float sum(vec3f a) 39 | { 40 | return a.x + a.y + a.z; 41 | } 42 | 43 | __device__ static inline vec3f cross(vec3f a, vec3f b) 44 | { 45 | vec3f out; 46 | out.x = a.y * b.z - a.z * b.y; 47 | out.y = a.z * b.x - a.x * b.z; 48 | out.z = a.x * b.y - a.y * b.x; 49 | return out; 50 | } 51 | 52 | __device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out) 53 | { 54 | d_a.x += d_out.z * b.y - d_out.y * b.z; 55 | d_a.y += d_out.x * b.z - d_out.z * b.x; 56 | d_a.z += d_out.y * b.x - d_out.x * b.y; 57 | 58 | d_b.x += d_out.y * a.z - d_out.z * a.y; 59 | d_b.y += d_out.z * a.x - d_out.x * a.z; 60 | d_b.z += d_out.x * a.y - d_out.y * a.x; 61 | } 62 | 63 | __device__ static inline float dot(vec3f a, vec3f b) 64 | { 65 | return a.x * b.x + a.y * b.y + a.z * b.z; 66 | } 67 | 68 | __device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out) 69 | { 70 | d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z; 71 | d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z; 72 | } 73 | 74 | __device__ static inline vec3f reflect(vec3f x, vec3f n) 75 | { 76 | return n * 2.0f * dot(n, x) - x; 77 | } 78 | 79 | __device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out) 80 | { 81 | d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z); 82 | d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z); 83 | d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1); 84 | 85 | d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x); 86 | d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y); 87 | d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z)); 88 | } 89 | 90 | __device__ static inline vec3f safeNormalize(vec3f v) 91 | { 92 | float l = sqrtf(v.x * v.x + 
v.y * v.y + v.z * v.z); 93 | return l > 0.0f ? (v / l) : vec3f(0.0f); 94 | } 95 | 96 | __device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out) 97 | { 98 | 99 | float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z); 100 | if (l > 0.0f) 101 | { 102 | float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f); 103 | d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac; 104 | d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac; 105 | d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac; 106 | } 107 | } 108 | 109 | #endif -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/c_src/vec4f.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | * property and proprietary rights in and to this material, related 6 | * documentation and any modifications thereto. Any use, reproduction, 7 | * disclosure or distribution of this material and related documentation 8 | * without an express license agreement from NVIDIA CORPORATION or 9 | * its affiliates is strictly prohibited. 10 | */ 11 | 12 | #pragma once 13 | 14 | struct vec4f 15 | { 16 | float x, y, z, w; 17 | 18 | #ifdef __CUDACC__ 19 | __device__ vec4f() { } 20 | __device__ vec4f(float v) { x = v; y = v; z = v; w = v; } 21 | __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; } 22 | __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; } 23 | #endif 24 | }; 25 | 26 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 
9 | 10 | import torch 11 | 12 | #---------------------------------------------------------------------------- 13 | # HDR image losses 14 | #---------------------------------------------------------------------------- 15 | 16 | def _tonemap_srgb(f): 17 | return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) 18 | 19 | def _SMAPE(img, target, eps=0.01): 20 | nom = torch.abs(img - target) 21 | denom = torch.abs(img) + torch.abs(target) + 0.01 22 | return torch.mean(nom / denom) 23 | 24 | def _RELMSE(img, target, eps=0.1): 25 | nom = (img - target) * (img - target) 26 | denom = img * img + target * target + 0.1 27 | return torch.mean(nom / denom) 28 | 29 | def image_loss_fn(img, target, loss, tonemapper): 30 | if tonemapper == 'log_srgb': 31 | img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1)) 32 | target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1)) 33 | 34 | if loss == 'mse': 35 | return torch.nn.functional.mse_loss(img, target) 36 | elif loss == 'smape': 37 | return _SMAPE(img, target) 38 | elif loss == 'relmse': 39 | return _RELMSE(img, target) 40 | else: 41 | return torch.nn.functional.l1_loss(img, target) 42 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/tests/test_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 
9 | 10 | import torch 11 | 12 | import os 13 | import sys 14 | sys.path.insert(0, os.path.join(sys.path[0], '../..')) 15 | import renderutils as ru 16 | 17 | RES = 8 18 | DTYPE = torch.float32 19 | 20 | def tonemap_srgb(f): 21 | return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) 22 | 23 | def l1(output, target): 24 | x = torch.clamp(output, min=0, max=65535) 25 | r = torch.clamp(target, min=0, max=65535) 26 | x = tonemap_srgb(torch.log(x + 1)) 27 | r = tonemap_srgb(torch.log(r + 1)) 28 | return torch.nn.functional.l1_loss(x,r) 29 | 30 | def relative_loss(name, ref, cuda): 31 | ref = ref.float() 32 | cuda = cuda.float() 33 | print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) 34 | 35 | def test_loss(loss, tonemapper): 36 | img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 37 | img_ref = img_cuda.clone().detach().requires_grad_(True) 38 | target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 39 | target_ref = target_cuda.clone().detach().requires_grad_(True) 40 | 41 | ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True) 42 | ref_loss.backward() 43 | 44 | cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper) 45 | cuda_loss.backward() 46 | 47 | print("-------------------------------------------------------------") 48 | print(" Loss: %s, %s" % (loss, tonemapper)) 49 | print("-------------------------------------------------------------") 50 | 51 | relative_loss("res:", ref_loss, cuda_loss) 52 | relative_loss("img:", img_ref.grad, img_cuda.grad) 53 | relative_loss("target:", target_ref.grad, target_cuda.grad) 54 | 55 | 56 | test_loss('l1', 'none') 57 | test_loss('l1', 'log_srgb') 58 | test_loss('mse', 'log_srgb') 59 | test_loss('smape', 'none') 60 | test_loss('relmse', 'none') 61 | test_loss('mse', 'none') -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/tests/test_mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 
9 | 10 | import torch 11 | 12 | import os 13 | import sys 14 | sys.path.insert(0, os.path.join(sys.path[0], '../..')) 15 | import renderutils as ru 16 | 17 | BATCH = 8 18 | RES = 1024 19 | DTYPE = torch.float32 20 | 21 | torch.manual_seed(0) 22 | 23 | def tonemap_srgb(f): 24 | return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) 25 | 26 | def l1(output, target): 27 | x = torch.clamp(output, min=0, max=65535) 28 | r = torch.clamp(target, min=0, max=65535) 29 | x = tonemap_srgb(torch.log(x + 1)) 30 | r = tonemap_srgb(torch.log(r + 1)) 31 | return torch.nn.functional.l1_loss(x,r) 32 | 33 | def relative_loss(name, ref, cuda): 34 | ref = ref.float() 35 | cuda = cuda.float() 36 | print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item()) 37 | 38 | def test_xfm_points(): 39 | points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 40 | points_ref = points_cuda.clone().detach().requires_grad_(True) 41 | mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) 42 | mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) 43 | target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) 44 | 45 | ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True) 46 | ref_loss = torch.nn.MSELoss()(ref_out, target) 47 | ref_loss.backward() 48 | 49 | cuda_out = ru.xfm_points(points_cuda, mtx_cuda) 50 | cuda_loss = torch.nn.MSELoss()(cuda_out, target) 51 | cuda_loss.backward() 52 | 53 | print("-------------------------------------------------------------") 54 | 55 | relative_loss("res:", ref_out, cuda_out) 56 | relative_loss("points:", points_ref.grad, points_cuda.grad) 57 | 58 | def test_xfm_vectors(): 59 | points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 60 | points_ref = points_cuda.clone().detach().requires_grad_(True) 61 | points_cuda_p = points_cuda.clone().detach().requires_grad_(True) 62 | points_ref_p = points_cuda.clone().detach().requires_grad_(True) 63 | mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) 64 | mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) 65 | target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) 66 | 67 | ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True) 68 | ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3]) 69 | ref_loss.backward() 70 | 71 | cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda) 72 | cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3]) 73 | cuda_loss.backward() 74 | 75 | ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True) 76 | ref_loss_p = torch.nn.MSELoss()(ref_out_p, target) 77 | ref_loss_p.backward() 78 | 79 | cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda) 80 | cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target) 81 | cuda_loss_p.backward() 82 | 83 | print("-------------------------------------------------------------") 84 | 85 | relative_loss("res:", ref_out, cuda_out) 86 | relative_loss("points:", points_ref.grad, points_cuda.grad) 87 | relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad) 88 | 89 | test_xfm_points() 90 | test_xfm_vectors() 91 | -------------------------------------------------------------------------------- /others/nvdiffrec/render/renderutils/tests/test_perf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 4 | # property and proprietary rights in and to this material, related 5 | # documentation and any modifications thereto. Any use, reproduction, 6 | # disclosure or distribution of this material and related documentation 7 | # without an express license agreement from NVIDIA CORPORATION or 8 | # its affiliates is strictly prohibited. 9 | 10 | import torch 11 | 12 | import os 13 | import sys 14 | sys.path.insert(0, os.path.join(sys.path[0], '../..')) 15 | import renderutils as ru 16 | 17 | DTYPE=torch.float32 18 | 19 | def test_bsdf(BATCH, RES, ITR): 20 | kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 21 | kd_ref = kd_cuda.clone().detach().requires_grad_(True) 22 | arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 23 | arm_ref = arm_cuda.clone().detach().requires_grad_(True) 24 | pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 25 | pos_ref = pos_cuda.clone().detach().requires_grad_(True) 26 | nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 27 | nrm_ref = nrm_cuda.clone().detach().requires_grad_(True) 28 | view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 29 | view_ref = view_cuda.clone().detach().requires_grad_(True) 30 | light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) 31 | light_ref = light_cuda.clone().detach().requires_grad_(True) 32 | target = torch.rand(BATCH, RES, RES, 3, device='cuda') 33 | 34 | start = torch.cuda.Event(enable_timing=True) 35 | end = torch.cuda.Event(enable_timing=True) 36 | 37 | ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) 38 | 39 | print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES)) 40 | 41 | start.record() 42 | for i in range(ITR): 43 | ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True) 44 | end.record() 45 | torch.cuda.synchronize() 46 | print("Pbr BSDF python:", start.elapsed_time(end)) 47 | 48 | start.record() 49 | for i in range(ITR): 50 | cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda) 51 | end.record() 52 | torch.cuda.synchronize() 53 | print("Pbr BSDF cuda:", start.elapsed_time(end)) 54 | 55 | test_bsdf(1, 512, 1000) 56 | test_bsdf(16, 512, 1000) 57 | test_bsdf(1, 2048, 1000) 58 | -------------------------------------------------------------------------------- /raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * -------------------------------------------------------------------------------- /raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = sorted(glob.glob(r"C:\\Program Files 
(x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_raymarching', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'raymarching.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /raymarching/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | ''' 33 | Usage: 34 | 35 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) 36 | 37 | python setup.py install # build extensions and install (copy) to PATH. 38 | pip install . # ditto but better (e.g., dependency & metadata handling) 39 | 40 | python setup.py develop # build extensions and install (symbolic) to PATH. 41 | pip install -e . 
# ditto but better (e.g., dependency & metadata handling) 42 | 43 | ''' 44 | setup( 45 | name='raymarching', # package name, import this to use python API 46 | ext_modules=[ 47 | CUDAExtension( 48 | name='_raymarching', # extension name, import this to use CUDA API 49 | sources=[os.path.join(_src_path, 'src', f) for f in [ 50 | 'raymarching.cu', 51 | 'bindings.cpp', 52 | ]], 53 | extra_compile_args={ 54 | 'cxx': c_flags, 55 | 'nvcc': nvcc_flags, 56 | } 57 | ), 58 | ], 59 | cmdclass={ 60 | 'build_ext': BuildExtension, 61 | } 62 | ) -------------------------------------------------------------------------------- /raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)"); 8 | m.def("packbits", &packbits, "packbits (CUDA)"); 9 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 10 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); 11 | m.def("morton3D", &morton3D, "morton3D (CUDA)"); 12 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); 13 | // train 14 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 15 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 16 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 17 | // infer 18 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 19 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); 20 | } -------------------------------------------------------------------------------- /raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | 7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); 9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); 10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); 11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); 12 | void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res); 13 | 14 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises); 15 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 16 | void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor
rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 17 | 18 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises); 19 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | rich 3 | ninja 4 | numpy 5 | pandas 6 | scipy 7 | scikit-learn 8 | matplotlib 9 | opencv-python 10 | imageio 11 | imageio-ffmpeg 12 | 13 | torch 14 | torch-ema 15 | einops 16 | tensorboard 17 | tensorboardX 18 | 19 | # for grid_tcnn 20 | # git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 21 | 22 | # for stable-diffusion 23 | huggingface_hub 24 | diffusers==0.20.0 25 | accelerate 26 | transformers 27 | git+https://github.com/facebookresearch/xformers.git@main#egg=xformers 28 | triton 29 | 30 | # for dmtet and mesh export 31 | xatlas 32 | trimesh 33 | PyMCubes 34 | pymeshlab 35 | git+https://github.com/NVlabs/nvdiffrast/ 36 | open3d 37 | 38 | # for zero123 39 | carvekit-colab 40 | omegaconf 41 | pytorch-lightning 42 | taming-transformers-rom1504 43 | kornia 44 | git+https://github.com/openai/CLIP.git 45 | -------------------------------------------------------------------------------- /scripts/blender.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | 10 | # Simple script to show how to load our assets in Blender 11 | # 12 | # Open Blender 3.x, click the scripting tab and open this script 13 | # Run the script (Alt P) 14 | # Under the shading tab, you will see the shading network and environent probe (World node) 15 | # You can then render the model using the Cycles renderer 16 | import os 17 | import bpy 18 | import numpy as np 19 | 20 | # path to your mesh 21 | MESH_PATH = "path2mesh" 22 | 23 | RESOLUTION = 512 24 | SAMPLES = 64 25 | 26 | ################### Renderer settings ################### 27 | bpy.ops.file.pack_all() 28 | scene = bpy.context.scene 29 | scene.world.use_nodes = True 30 | scene.render.engine = 'CYCLES' 31 | scene.render.film_transparent = True 32 | scene.cycles.device = 'GPU' 33 | scene.cycles.samples = SAMPLES 34 | scene.cycles.max_bounces = 0 35 | scene.cycles.diffuse_bounces = 0 36 | scene.cycles.glossy_bounces = 0 37 | scene.cycles.transmission_bounces = 0 38 | scene.cycles.volume_bounces = 0 39 | scene.cycles.transparent_max_bounces = 8 40 | scene.cycles.use_denoising = True 41 | scene.render.resolution_x = RESOLUTION 42 | scene.render.resolution_y = RESOLUTION 43 | scene.render.resolution_percentage = 100 44 | 45 | ################### Image output ################### 46 | 47 | # PNG output with sRGB tonemapping 48 | scene.display_settings.display_device = 'None' 49 | scene.view_settings.view_transform = 'Standard' 50 | scene.view_settings.exposure = 0.0 51 | scene.view_settings.gamma = 1.0 52 | scene.render.image_settings.file_format = 'PNG' # set output format to .png 53 | 54 | # OpenEXR output, no tonemapping applied 55 | # scene.display_settings.display_device = 'None' 56 | # scene.view_settings.view_transform = 'Standard' 57 | # scene.view_settings.exposure = 0.0 58 | # scene.view_settings.gamma = 1.0 59 | # scene.render.image_settings.file_format = 'OPEN_EXR' 60 | 61 | 62 | ################### Import obj mesh ################### 63 | 64 | imported_object = bpy.ops.import_scene.obj(filepath=os.path.join(MESH_PATH, "mesh.obj"), axis_forward = '-Z', axis_up = 'Y') 65 | obj_object = bpy.context.selected_objects[0] 66 | 67 | ################### Fix material graph ################### 68 | 69 | # Get material node tree, find BSDF and specular texture 70 | material = obj_object.active_material 71 | bsdf = material.node_tree.nodes["Principled BSDF"] 72 | image_node_ks = bsdf.inputs["Specular"].links[0].from_node 73 | 74 | # Split the specular texture into metalness and roughness 75 | separate_node = material.node_tree.nodes.new(type="ShaderNodeSeparateRGB") 76 | separate_node.name="SeparateKs" 77 | material.node_tree.links.new(image_node_ks.outputs[0], separate_node.inputs[0]) 78 | material.node_tree.links.new(separate_node.outputs[2], bsdf.inputs["Metallic"]) 79 | material.node_tree.links.new(separate_node.outputs[1], bsdf.inputs["Roughness"]) 80 | 81 | if True: 82 | normal_map_node = bsdf.inputs["Normal"].links[0].from_node 83 | texture_n_node = normal_map_node.inputs["Color"].links[0].from_node 84 | material.node_tree.links.remove(normal_map_node.inputs["Color"].links[0]) 85 | normal_separate_node = material.node_tree.nodes.new(type="ShaderNodeSeparateRGB") 86 | normal_separate_node.name="SeparateNormal" 87 | normal_combine_node = material.node_tree.nodes.new(type="ShaderNodeCombineRGB") 88 | normal_combine_node.name="CombineNormal" 89 | 90 | normal_invert_node = material.node_tree.nodes.new(type="ShaderNodeMath") 91 | normal_invert_node.name="InvertNormal" 92 | normal_invert_node.operation='SUBTRACT' 
93 | normal_invert_node.inputs[0].default_value = 1.0 94 | 95 | material.node_tree.links.new(texture_n_node.outputs[0], normal_separate_node.inputs['Image']) 96 | material.node_tree.links.new(normal_separate_node.outputs['R'], normal_combine_node.inputs['R']) 97 | material.node_tree.links.new(normal_separate_node.outputs['G'], normal_invert_node.inputs[1]) 98 | material.node_tree.links.new(normal_invert_node.outputs[0], normal_combine_node.inputs['G']) 99 | material.node_tree.links.new(normal_separate_node.outputs['B'], normal_combine_node.inputs['B']) 100 | material.node_tree.links.new(normal_combine_node.outputs[0], normal_map_node.inputs["Color"]) 101 | 102 | material.node_tree.links.remove(bsdf.inputs["Specular"].links[0]) 103 | 104 | # Set default values 105 | bsdf.inputs["Specular"].default_value = 0.5 106 | bsdf.inputs["Specular Tint"].default_value = 0.0 107 | bsdf.inputs["Sheen Tint"].default_value = 0.0 108 | bsdf.inputs["Clearcoat Roughness"].default_value = 0.0 109 | 110 | ################### Load HDR probe ################### 111 | 112 | texcoord = scene.world.node_tree.nodes.new(type="ShaderNodeTexCoord") 113 | mapping = scene.world.node_tree.nodes.new(type="ShaderNodeMapping") 114 | mapping.inputs['Rotation'].default_value = [0, 0, -np.pi*0.5] 115 | envmap = scene.world.node_tree.nodes.new(type="ShaderNodeTexEnvironment") 116 | envmap.image = bpy.data.images.load(os.path.join(MESH_PATH, "probe.hdr")) 117 | 118 | scene.world.node_tree.links.new(envmap.outputs['Color'], scene.world.node_tree.nodes['Background'].inputs['Color']) 119 | scene.world.node_tree.links.new(texcoord.outputs['Generated'], mapping.inputs['Vector']) 120 | scene.world.node_tree.links.new(mapping.outputs['Vector'], envmap.inputs['Vector']) 121 | -------------------------------------------------------------------------------- /scripts/install_ext.sh: -------------------------------------------------------------------------------- 1 | pip install ./raymarching 2 | pip install ./shencoder 3 | pip install ./freqencoder 4 | pip install ./gridencoder -------------------------------------------------------------------------------- /scripts/linear_app.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 4 | import os 5 | from diffusers import ( 6 | AutoencoderKL, 7 | StableDiffusionPipeline, 8 | ) 9 | from tqdm import tqdm 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | path = "path to images folder" 14 | imgs_path = os.listdir(path) 15 | 16 | imgs_128 = [] 17 | imgs_1024 = [] 18 | latents = [] 19 | 20 | vae = AutoencoderKL.from_pretrained( 21 | "madebyollin/sdxl-vae-fp16-fix", 22 | # "pretrained/SDXL/vae_fp16", 23 | # local_files_only=True, 24 | use_safetensors=True, 25 | torch_dtype=torch.float16 # half precision, assumed to match the fp16-fix VAE 26 | ).to('cuda') # move the VAE to GPU to match the cuda image tensors below 27 | 28 | def encode_imgs(imgs): 29 | # imgs: [B, 3, H, W] 30 | imgs = 2 * imgs - 1 31 | posterior = vae.encode(imgs).latent_dist 32 | latents = posterior.sample() * vae.config.scaling_factor 33 | return latents 34 | 35 | 36 | with torch.no_grad(): 37 | for i in tqdm(range(1000)): 38 | img = cv2.imread(os.path.join(path, imgs_path[i])) 39 | img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB) / 255 40 | img = torch.tensor(img, dtype=torch.float32, device='cuda')[None, ...]
41 | img = img.permute(0, 3, 1, 2) 42 | img_1024 = F.interpolate(img, (1024, 1024), mode='bilinear', align_corners=False) 43 | img_128 = F.interpolate(img, (128, 128), mode='bilinear', align_corners=False) 44 | # imgs_1024.append(img_1024) 45 | imgs_128.append(img_128) 46 | 47 | with torch.cuda.amp.autocast(enabled=True): 48 | latent = encode_imgs(img_1024) 49 | latents.append(latent) 50 | torch.cuda.empty_cache() 51 | 52 | data = torch.cat(imgs_128, dim=0).permute(0, 2, 3, 1).reshape(-1, 3) # [N, 3] 53 | target = torch.cat(latents, dim=0).permute(0, 2, 3, 1).reshape(-1, 4) # [N, 4] 54 | data_ = torch.cat([data, torch.ones(data.shape[0], 1).cuda()], dim=1) 55 | target_ = torch.cat([target, torch.ones(target.shape[0], 1).cuda()], dim=1) 56 | 57 | coef_r2l = torch.linalg.inv(data_.T @ data_ + 0.001 * torch.eye(4).cuda()) @ (data_.T @ target) 58 | coef_l2r = (torch.linalg.inv(target_.T @ target_ + 0.001 * torch.eye(5).cuda())) @ (target_.T @ data) 59 | 60 | def normalize(img): 61 | img = (img - img.min()) / (img.max() - img.min()) 62 | return img 63 | 64 | 65 | save_path = "2d/vae_linear_approx/sdxl" 66 | os.makedirs(os.path.join(save_path, 'l2r'), exist_ok=True) 67 | os.makedirs(os.path.join(save_path, 'r2l'), exist_ok=True) 68 | for i in range(100): 69 | x = latents[i][0].permute(1, 2, 0) @ coef_l2r[:4, :] + coef_l2r[4, :] 70 | rgb_fit = (np.clip((x.detach().cpu().numpy()) * 255, 0, 255)).astype(np.uint8) 71 | rgb_gt = (imgs_128[i][0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8) 72 | # plt.imshow(rgb_fit);plt.show() 73 | # plt.imshow(imgs_128[i][0].permute(1, 2, 0).detach().cpu().numpy());plt.show() 74 | 75 | y = imgs_128[i][0].permute(1, 2, 0) @ coef_r2l[:3, :] + coef_r2l[3, :] 76 | latent_fit = (normalize(y.detach().cpu().numpy()) * 255).astype(np.uint8) 77 | latent_gt = (normalize(latents[i][0].permute(1, 2, 0).detach().cpu().numpy()) * 255).astype(np.uint8) 78 | # plt.imshow(normalize(y.detach().cpu().numpy()));plt.show() 79 | # plt.imshow(normalize(latents[i][0].permute(1, 2, 0).detach().cpu().numpy()));plt.show() 80 | 81 | cv2.imwrite(os.path.join(save_path, 'l2r', f'{i:04d}_fit.png'), 82 | cv2.cvtColor(rgb_fit, cv2.COLOR_RGB2BGR)) 83 | cv2.imwrite(os.path.join(save_path, 'l2r', f'{i:04d}_gt.png'), 84 | cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2BGR)) 85 | cv2.imwrite(os.path.join(save_path, 'r2l', f'{i:04d}_fit.png'), 86 | cv2.cvtColor(latent_fit, cv2.COLOR_RGB2BGR)) 87 | cv2.imwrite(os.path.join(save_path, 'r2l', f'{i:04d}_gt.png'), 88 | cv2.cvtColor(latent_gt, cv2.COLOR_RGB2BGR)) 89 | -------------------------------------------------------------------------------- /shencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .sphere_harmonics import SHEncoder -------------------------------------------------------------------------------- /shencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | '-O3', '-std=c++14', 8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 9 | ] 10 | 11 | if os.name == "posix": 12 | c_flags = ['-O3', '-std=c++14'] 13 | elif os.name == "nt": 14 | c_flags = ['/O2', '/std:c++17'] 15 | 16 | # find cl.exe 17 | def find_cl_path(): 18 | import glob 19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 20 | paths = 
sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 21 | if paths: 22 | return paths[0] 23 | 24 | # If cl.exe is not on path, try to find it. 25 | if os.system("where cl.exe >nul 2>nul") != 0: 26 | cl_path = find_cl_path() 27 | if cl_path is None: 28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 29 | os.environ["PATH"] += ";" + cl_path 30 | 31 | _backend = load(name='_sh_encoder', 32 | extra_cflags=c_flags, 33 | extra_cuda_cflags=nvcc_flags, 34 | sources=[os.path.join(_src_path, 'src', f) for f in [ 35 | 'shencoder.cu', 36 | 'bindings.cpp', 37 | ]], 38 | ) 39 | 40 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /shencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | '-O3', '-std=c++14', 9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 10 | ] 11 | 12 | if os.name == "posix": 13 | c_flags = ['-O3', '-std=c++14'] 14 | elif os.name == "nt": 15 | c_flags = ['/O2', '/std:c++17'] 16 | 17 | # find cl.exe 18 | def find_cl_path(): 19 | import glob 20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) 22 | if paths: 23 | return paths[0] 24 | 25 | # If cl.exe is not on path, try to find it. 26 | if os.system("where cl.exe >nul 2>nul") != 0: 27 | cl_path = find_cl_path() 28 | if cl_path is None: 29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") 30 | os.environ["PATH"] += ";" + cl_path 31 | 32 | setup( 33 | name='shencoder', # package name, import this to use python API 34 | ext_modules=[ 35 | CUDAExtension( 36 | name='_shencoder', # extension name, import this to use CUDA API 37 | sources=[os.path.join(_src_path, 'src', f) for f in [ 38 | 'shencoder.cu', 39 | 'bindings.cpp', 40 | ]], 41 | extra_compile_args={ 42 | 'cxx': c_flags, 43 | 'nvcc': nvcc_flags, 44 | } 45 | ), 46 | ], 47 | cmdclass={ 48 | 'build_ext': BuildExtension, 49 | } 50 | ) -------------------------------------------------------------------------------- /shencoder/sphere_harmonics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _shencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | class _sh_encoder(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision 17 | def forward(ctx, inputs, degree, calc_grad_inputs=False): 18 | # inputs: [B, input_dim], float in [-1, 1] 19 | # RETURN: [B, F], float 20 | 21 | inputs = inputs.contiguous() 22 | B, input_dim = inputs.shape # batch size, coord dim 23 | output_dim = degree ** 2 24 | 25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) 26 | 27 | if calc_grad_inputs: 28 | dy_dx = torch.empty(B, 
input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) 29 | else: 30 | dy_dx = None 31 | 32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) 33 | 34 | ctx.save_for_backward(inputs, dy_dx) 35 | ctx.dims = [B, input_dim, degree] 36 | 37 | return outputs 38 | 39 | @staticmethod 40 | #@once_differentiable 41 | @custom_bwd 42 | def backward(ctx, grad): 43 | # grad: [B, C * C] 44 | 45 | inputs, dy_dx = ctx.saved_tensors 46 | 47 | if dy_dx is not None: 48 | grad = grad.contiguous() 49 | B, input_dim, degree = ctx.dims 50 | grad_inputs = torch.zeros_like(inputs) 51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) 52 | return grad_inputs, None, None 53 | else: 54 | return None, None, None 55 | 56 | 57 | 58 | sh_encode = _sh_encoder.apply 59 | 60 | 61 | class SHEncoder(nn.Module): 62 | def __init__(self, input_dim=3, degree=4): 63 | super().__init__() 64 | 65 | self.input_dim = input_dim # coord dims, must be 3 66 | self.degree = degree # 0 ~ 4 67 | self.output_dim = degree ** 2 68 | 69 | assert self.input_dim == 3, "SH encoder only support input dim == 3" 70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" 71 | 72 | def __repr__(self): 73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" 74 | 75 | def forward(self, inputs, size=1): 76 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 77 | # return: [..., degree^2] 78 | 79 | inputs = inputs / size # [-1, 1] 80 | 81 | prefix_shape = list(inputs.shape[:-1]) 82 | inputs = inputs.reshape(-1, self.input_dim) 83 | 84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad) 85 | outputs = outputs.reshape(prefix_shape + [self.output_dim]) 86 | 87 | return outputs -------------------------------------------------------------------------------- /shencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #include "shencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); 7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /shencoder/src/shencoder.h: -------------------------------------------------------------------------------- 1 | # pragma once 2 | 3 | #include <stdint.h> 4 | #include <torch/torch.h> 5 | 6 | // inputs: [B, D], float, in [-1, 1] 7 | // outputs: [B, F], float 8 | 9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx); 10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); --------------------------------------------------------------------------------
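As a quick orientation for the `shencoder` module listed above, below is a minimal usage sketch of `SHEncoder` (spherical-harmonics encoding). It assumes the CUDA extension has been built via `scripts/install_ext.sh` (or `pip install ./shencoder`), that the snippet runs from the repository root so the `shencoder` package is importable, and that a GPU is available; the batch size and `degree=4` are illustrative only.

```python
# Minimal sketch: encode unit view directions with the SHEncoder from
# shencoder/sphere_harmonics.py. Assumes a CUDA GPU and a built _shencoder extension.
import torch
from shencoder import SHEncoder

encoder = SHEncoder(input_dim=3, degree=4)  # output_dim = degree**2 = 16

# Inputs must lie in [-1, 1]; normalizing random vectors to unit length satisfies this.
dirs = torch.randn(4096, 3, device='cuda')
dirs = dirs / dirs.norm(dim=-1, keepdim=True)
dirs.requires_grad_(True)  # triggers the dy_dx path so gradients reach the inputs

feats = encoder(dirs)        # [4096, 16] spherical-harmonics features
feats.sum().backward()       # backward runs through sh_encode_backward
print(feats.shape, dirs.grad.shape)
```

When `inputs.requires_grad` is False, the forward pass skips allocating the `dy_dx` buffer and the backward pass returns no input gradient, which is the cheaper path for pure inference.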