├── .gitignore
├── 2d.py
├── LICENSE
├── README.md
├── assets
│   ├── group_photo.png
│   ├── group_photo_fan.png
│   ├── nopgc.png
│   └── pgc.png
├── data
│   ├── 0000.png
│   ├── background.png
│   └── panda_spear_shield.obj
├── encoding.py
├── freqencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── freq.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── freqencoder.cu
│       └── freqencoder.h
├── gridencoder
│   ├── __init__.py
│   ├── backend.py
│   ├── grid.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── gridencoder.cu
│       └── gridencoder.h
├── guidance
│   ├── clip_utils.py
│   ├── guidance_utils.py
│   ├── sd.py
│   ├── sdcontrolnet.py
│   ├── sdxl.py
│   ├── sdxl_controlnet.py
│   ├── sdxl_vsd.py
│   └── zero123_utils.py
├── ldm
│   ├── extras.py
│   ├── guidance.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── autoencoder.py
│   │   └── diffusion
│   │       ├── __init__.py
│   │       ├── classifier.py
│   │       ├── ddim.py
│   │       ├── ddpm.py
│   │       ├── plms.py
│   │       └── sampling_util.py
│   ├── modules
│   │   ├── attention.py
│   │   ├── diffusionmodules
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── distributions
│   │   │   ├── __init__.py
│   │   │   └── distributions.py
│   │   ├── ema.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   └── modules.py
│   │   ├── evaluate
│   │   │   ├── adm_evaluator.py
│   │   │   ├── evaluate_perceptualsim.py
│   │   │   ├── frechet_video_distance.py
│   │   │   ├── ssim.py
│   │   │   └── torch_frechet_video_distance.py
│   │   ├── image_degradation
│   │   │   ├── __init__.py
│   │   │   ├── bsrgan.py
│   │   │   ├── bsrgan_light.py
│   │   │   ├── utils
│   │   │   │   └── test.png
│   │   │   └── utils_image.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── contperceptual.py
│   │   │   └── vqperceptual.py
│   │   └── x_transformer.py
│   ├── thirdp
│   │   └── psp
│   │       ├── helpers.py
│   │       ├── id_loss.py
│   │       └── model_irse.py
│   └── util.py
├── main.py
├── nerf
│   ├── fine
│   │   ├── dmtet.py
│   │   ├── dmtet_utils.py
│   │   ├── generate_tets.py
│   │   ├── neural_render.py
│   │   ├── perspective_camera.py
│   │   ├── rast_utils.py
│   │   ├── rasterizer.py
│   │   ├── rasterizer_mesh.py
│   │   ├── rasterizer_pbr.py
│   │   ├── tets
│   │   │   └── 256_compress.npz
│   │   └── trainer.py
│   ├── provider.py
│   └── utils.py
├── others
│   └── nvdiffrec
│       ├── data
│       │   └── irrmaps
│       │       ├── README.txt
│       │       ├── aerodynamics_workshop_2k.hdr
│       │       ├── blocky_photo_studio_2k.hdr
│       │       ├── bsdf_256_256.bin
│       │       └── mud_road_puresky_4k.hdr
│       └── render
│           ├── light.py
│           ├── material.py
│           ├── mesh.py
│           ├── mlptexture.py
│           ├── obj.py
│           ├── regularizer.py
│           ├── render.py
│           ├── renderutils
│           │   ├── __init__.py
│           │   ├── bsdf.py
│           │   ├── c_src
│           │   │   ├── bsdf.cu
│           │   │   ├── bsdf.h
│           │   │   ├── common.cpp
│           │   │   ├── common.h
│           │   │   ├── cubemap.cu
│           │   │   ├── cubemap.h
│           │   │   ├── loss.cu
│           │   │   ├── loss.h
│           │   │   ├── mesh.cu
│           │   │   ├── mesh.h
│           │   │   ├── normal.cu
│           │   │   ├── normal.h
│           │   │   ├── tensor.h
│           │   │   ├── torch_bindings.cpp
│           │   │   ├── vec3f.h
│           │   │   └── vec4f.h
│           │   ├── loss.py
│           │   ├── ops.py
│           │   └── tests
│           │       ├── test_bsdf.py
│           │       ├── test_loss.py
│           │       ├── test_mesh.py
│           │       └── test_perf.py
│           ├── texture.py
│           └── util.py
├── raymarching
│   ├── __init__.py
│   ├── backend.py
│   ├── raymarching.py
│   ├── setup.py
│   └── src
│       ├── bindings.cpp
│       ├── raymarching.cu
│       └── raymarching.h
├── requirements.txt
├── scripts
│   ├── blender.py
│   ├── install_ext.sh
│   └── linear_app.py
└── shencoder
    ├── __init__.py
    ├── backend.py
    ├── setup.py
    ├── sphere_harmonics.py
    └── src
        ├── bindings.cpp
        ├── shencoder.cu
        └── shencoder.h
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | build/
3 | *.egg-info/
4 | *.so
5 |
6 | tmp*
7 | trial*/
8 | train*/
9 | .vs/
10 | run.sh
11 | TOKEN
12 | exp/
13 | *.pth
14 | *.out
15 | .idea
16 | 2d
17 | results/
18 | scripts/debug.py
19 | scripts/run*.sh
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping
2 | ### [[Paper]](https://arxiv.org/abs/2310.12474) | [[Project]](https://fudan-zvg.github.io/PGC-3D/)
3 |
4 | > [**Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping**](https://arxiv.org/abs/2310.12474),
5 | > Zijie Pan, [Jiachen Lu](https://victorllu.github.io/), [Xiatian Zhu](https://surrey-uplab.github.io/), [Li Zhang](https://lzrobots.github.io)
6 | > **ICLR 2024**
7 |
8 | **Official implementation of "Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping".**
9 |
10 |
11 | **PGC** (Pixel-wise Gradient Clipping) introduces a refined method to adapt traditional gradient clipping. By focusing on pixel-wise gradient magnitudes, it retains vital texture details. This approach acts as a versatile **plug-in**, seamlessly complementing existing **SDS and LDM-based 3D generative models**. The result is a marked improvement in high-resolution 3D texture synthesis.
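A minimal sketch of the idea (illustrative only; the repository's actual implementation and threshold may differ):

```python
import torch

def pixelwise_grad_clip(grad: torch.Tensor, max_norm: float = 0.1) -> torch.Tensor:
    # grad: [B, C, H, W] gradient w.r.t. a rendered image or latent.
    # Clip each pixel's gradient vector (over channels) to a maximum L2 norm:
    # pixels below the threshold are left untouched, larger ones are rescaled.
    norm = grad.norm(dim=1, keepdim=True)               # [B, 1, H, W] per-pixel magnitude
    scale = (max_norm / (norm + 1e-8)).clamp(max=1.0)   # shrink only where the norm exceeds max_norm
    return grad * scale
```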
12 |
13 | With PGC, users can:
14 | - Address and mitigate gradient-related challenges common in LDM, elevating the quality of 3D generation.
15 | - Employ the [**SDXL**](https://github.com/Stability-AI/generative-models) approach, which previously could not be readily adapted for 3D generation.
16 |
17 | This repo also offers a unified implementation of mesh optimization and reproductions of many SDS variants.
18 |
19 | # Install
20 |
21 | ```bash
22 | git clone https://github.com/fudan-zvg/PGC-3D.git
23 | cd PGC-3D
24 | ```
25 |
26 | **Huggingface token**:
27 | Some newer models require an access token.
28 | Create a file called `TOKEN` in this directory (i.e., `PGC-3D/TOKEN`)
29 | and copy your [token](https://huggingface.co/settings/tokens) into it.
30 |
31 | ### Install pytorch3d
32 |
33 | We use the pytorch3d implementation of the normal consistency loss, so pytorch3d needs to be installed first.
34 | ```bash
35 | conda create -n pgc python=3.9
36 | conda activate pgc
37 | conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
38 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
39 | conda install pytorch3d -c pytorch3d
40 | ```
41 |
42 | ### Install with pip and build extension
43 |
44 | ```bash
45 | pip install -r requirements.txt
46 | bash scripts/install_ext.sh
47 | ```
48 |
49 | ### Tested environments
50 | * Ubuntu 20 with torch 1.13 & CUDA 11.7 on two A6000.
51 |
52 | # Usage
53 |
54 | This repo supports mesh optimization (Stage 2 of Magic3D/Fantasia3D) with many methods, including the SDS, VSD, BGT, CSD and SSD losses, as well as PBR material modeling.
55 | If you find an error or bug in any of the reproductions, feel free to raise an issue.
56 | We assume a triangle mesh is provided.
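For reference, the plain SDS gradient that these variants build on can be sketched as follows (illustrative only; `unet`, `scheduler` and `text_emb` stand for diffusers-style objects and are assumptions, not this repo's exact API):

```python
import torch

def sds_grad(unet, scheduler, latents, text_emb):
    # Sample a timestep, noise the latents, and let the frozen diffusion model denoise;
    # the weighted difference between predicted and true noise is the SDS gradient,
    # which is then injected through autograd (e.g. SpecifyGradient in guidance/guidance_utils.py).
    t = torch.randint(20, 980, (latents.shape[0],), device=latents.device)
    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, t)
    with torch.no_grad():
        noise_pred = unet(noisy_latents, t, encoder_hidden_states=text_emb).sample
    w = (1 - scheduler.alphas_cumprod.to(latents.device)[t]).view(-1, 1, 1, 1)
    return w * (noise_pred - noise)
```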
57 |
58 | ```bash
59 | ### basic options
60 | # -O0 pixel-wise gradient clipping
61 | # -O1 Stage2 of Magic3D/Fantasia3D using DMTet
62 | # -O2 with mesh fixed
63 | # -O3 using an image as reference and Zero123 (not fully tested)
64 | # -O4 using SDXL
65 | # --pbr PBR modeling as nvdiffrec/Fantasia3D
66 | # --vsd VSD loss proposed by ProlificDreamer (only implemented for SDXL, so -O4 is needed)
67 | # --bgt BGT+ loss proposed by HiFA
68 | # --csd Classifier Score Distillation
69 | # --ssd Stable Score Distillation
70 | # --dreamtime an annealing time schedule proposed by DreamTime
71 | # --latent training in latent space as Latent-NeRF
72 | # --cfg_rescale suggested by AvatarStudio
73 | # --init_texture initialize texture if mesh has vertex color
74 | # --bs 4 batch size 4 per gpu
75 |
76 | ### sample commands
77 | # all the time is tested on two A6000
78 | # text-to-texture using SDXL with depth controlnet (46G, 36min)
79 | python main.py -O0 -O2 -O4 --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/sdxl --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
80 |
81 | # PBR modeling (48G, 36min)
82 | python main.py -O0 -O2 -O4 --pbr --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/pbr --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
83 |
84 | # VSD loss (42G, 60min)
85 | python main.py -O0 -O2 -O4 --vsd --guidance_scale 7.5 --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/vsd --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
86 |
87 | # optimize the shape using normal-SDS, as in Fantasia3D (46G, 33min)
88 | python main.py -O0 -O4 --only_normal --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/normal_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
89 | # when fine-tuning, add depth control to stabilize the shape (46G, 36min)
90 | python main.py -O0 -O4 --only_normal --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/normal_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
91 |
92 | # optimize both shape and texture using RGB-SDS, as in Magic3D (46G, 36min)
93 | python main.py -O0 -O4 --no_normal --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/rgb_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
94 |
95 | # optimize both shape and texture using both normal-SDS and RGB-SDS (47G, 74min)
96 | python main.py -O0 -O4 --guidance "controlnet" --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/nrm_rgb_sds --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
97 |
98 | # We also support multiple controlnets, e.g. depth + shuffle control with a reference image (Stable Diffusion v1.5 only)
99 | # 20G, 22min
100 | python main.py -O0 -O2 --ref_path "data/panda_reference.png" --control_type "depth" "shuffle" --control_scale 0.7 0.3 --text "A panda is dressed in armor, holding a spear in one hand and a shield in the other, realistic" --workspace panda/sd_depth_shuffle --gpus "0,1" --mesh_path "data/panda_spear_shield.obj"
101 | ```
102 |
103 | ## Results
104 | Incorporating **PGC** into SDXL guidance leads to noticeably better 3D generation results.
105 |
106 | #### [Fantasia3D](https://github.com/Gorilla-Lab-SCUT/Fantasia3D)
107 |
108 |
109 | #### Ours
110 |
111 |
112 | ## Tips for SDXL
113 | - Use [sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) to keep the SDXL VAE numerically stable in fp16
114 | - The depth [Controlnet](https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0) can be important for making training converge more easily (see the loading sketch below)
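A minimal diffusers-style loading sketch combining both tips (not this repo's guidance code; the SDXL base checkpoint name is assumed):

```python
import torch
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline

# fp16-safe VAE and depth controlnet, as recommended above
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae, controlnet=controlnet, torch_dtype=torch.float16,
).to("cuda")
```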
115 |
116 | ## BibTeX
117 | If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
118 | ```bibtex
119 | @inproceedings{pan2024enhancing,
120 | title={Enhancing High-Resolution 3D Generation through Pixel-wise Gradient Clipping},
121 | author={Pan, Zijie and Lu, Jiachen and Zhu, Xiatian and Zhang, Li},
122 | booktitle={International Conference on Learning Representations (ICLR)},
123 | year={2024}
124 | }
125 | ```
126 |
--------------------------------------------------------------------------------
/assets/group_photo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/group_photo.png
--------------------------------------------------------------------------------
/assets/group_photo_fan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/group_photo_fan.png
--------------------------------------------------------------------------------
/assets/nopgc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/nopgc.png
--------------------------------------------------------------------------------
/assets/pgc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/assets/pgc.png
--------------------------------------------------------------------------------
/data/0000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/data/0000.png
--------------------------------------------------------------------------------
/data/background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/data/background.png
--------------------------------------------------------------------------------
/encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class FreqEncoder_torch(nn.Module):
6 | def __init__(self, input_dim, max_freq_log2, N_freqs,
7 | log_sampling=True, include_input=True,
8 | periodic_fns=(torch.sin, torch.cos)):
9 |
10 | super().__init__()
11 |
12 | self.input_dim = input_dim
13 | self.include_input = include_input
14 | self.periodic_fns = periodic_fns
15 |
16 | self.output_dim = 0
17 | if self.include_input:
18 | self.output_dim += self.input_dim
19 |
20 | self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
21 |
22 | if log_sampling:
23 | self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
24 | else:
25 | self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)
26 |
27 | self.freq_bands = self.freq_bands.numpy().tolist()
28 |
29 | def forward(self, input, **kwargs):
30 |
31 | out = []
32 | if self.include_input:
33 | out.append(input)
34 |
35 | for i in range(len(self.freq_bands)):
36 | freq = self.freq_bands[i]
37 | for p_fn in self.periodic_fns:
38 | out.append(p_fn(input * freq))
39 |
40 | out = torch.cat(out, dim=-1)
41 |
42 | return out
43 |
44 | def get_encoder(encoding, input_dim=3,
45 | multires=6,
46 | degree=4,
47 | num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
48 | **kwargs):
49 |
50 | if encoding == 'None':
51 | return lambda x, **kwargs: x, input_dim
52 |
53 | elif encoding == 'frequency_torch':
54 | encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
55 |
56 | elif encoding == 'frequency': # CUDA implementation, faster than torch.
57 | from freqencoder import FreqEncoder
58 | encoder = FreqEncoder(input_dim=input_dim, degree=multires)
59 |
60 | elif encoding == 'sphere_harmonics':
61 | from shencoder import SHEncoder
62 | encoder = SHEncoder(input_dim=input_dim, degree=degree)
63 |
64 | elif encoding == 'hashgrid':
65 | from gridencoder import GridEncoder
66 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
67 |
68 | elif encoding == 'tiledgrid':
69 | from gridencoder import GridEncoder
70 | encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
71 |
72 | else:
73 | raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
74 |
75 | return encoder, encoder.output_dim
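
# Usage sketch (illustrative): build a hash-grid encoder and query it with 3D coordinates
# (the CUDA kernels expect inputs mapped into [0, 1]); `get_encoder` returns the module
# together with its output feature width.
#   encoder, out_dim = get_encoder('hashgrid', input_dim=3)
#   feats = encoder(torch.rand(1024, 3, device='cuda'))  # -> [1024, out_dim]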
--------------------------------------------------------------------------------
/freqencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .freq import FreqEncoder
--------------------------------------------------------------------------------
/freqencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | '-use_fast_math'
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | _backend = load(name='_freqencoder',
33 | extra_cflags=c_flags,
34 | extra_cuda_cflags=nvcc_flags,
35 | sources=[os.path.join(_src_path, 'src', f) for f in [
36 | 'freqencoder.cu',
37 | 'bindings.cpp',
38 | ]],
39 | )
40 |
41 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/freqencoder/freq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | from torch.autograd.function import once_differentiable
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | try:
10 | import _freqencoder as _backend
11 | except ImportError:
12 | from .backend import _backend
13 |
14 |
15 | class _freq_encoder(Function):
16 | @staticmethod
17 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
18 | def forward(ctx, inputs, degree, output_dim):
19 | # inputs: [B, input_dim], float
20 | # RETURN: [B, F], float
21 |
22 | if not inputs.is_cuda: inputs = inputs.cuda()
23 | inputs = inputs.contiguous()
24 |
25 | B, input_dim = inputs.shape # batch size, coord dim
26 |
27 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
28 |
29 | _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
30 |
31 | ctx.save_for_backward(inputs, outputs)
32 | ctx.dims = [B, input_dim, degree, output_dim]
33 |
34 | return outputs
35 |
36 | @staticmethod
37 | #@once_differentiable
38 | @custom_bwd
39 | def backward(ctx, grad):
40 | # grad: [B, C * C]
41 |
42 | grad = grad.contiguous()
43 | inputs, outputs = ctx.saved_tensors
44 | B, input_dim, degree, output_dim = ctx.dims
45 |
46 | grad_inputs = torch.zeros_like(inputs)
47 | _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
48 |
49 | return grad_inputs, None, None
50 |
51 |
52 | freq_encode = _freq_encoder.apply
53 |
54 |
55 | class FreqEncoder(nn.Module):
56 | def __init__(self, input_dim=3, degree=4):
57 | super().__init__()
58 |
59 | self.input_dim = input_dim
60 | self.degree = degree
61 | self.output_dim = input_dim + input_dim * 2 * degree
62 |
63 | def __repr__(self):
64 | return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
65 |
66 | def forward(self, inputs, **kwargs):
67 | # inputs: [..., input_dim]
68 | # return: [..., output_dim]
69 |
70 | prefix_shape = list(inputs.shape[:-1])
71 | inputs = inputs.reshape(-1, self.input_dim)
72 |
73 | outputs = freq_encode(inputs, self.degree, self.output_dim)
74 |
75 | outputs = outputs.reshape(prefix_shape + [self.output_dim])
76 |
77 | return outputs
--------------------------------------------------------------------------------
/freqencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | '-use_fast_math'
11 | ]
12 |
13 | if os.name == "posix":
14 | c_flags = ['-O3', '-std=c++14']
15 | elif os.name == "nt":
16 | c_flags = ['/O2', '/std:c++17']
17 |
18 | # find cl.exe
19 | def find_cl_path():
20 | import glob
21 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
22 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
23 | if paths:
24 | return paths[0]
25 |
26 | # If cl.exe is not on path, try to find it.
27 | if os.system("where cl.exe >nul 2>nul") != 0:
28 | cl_path = find_cl_path()
29 | if cl_path is None:
30 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
31 | os.environ["PATH"] += ";" + cl_path
32 |
33 | setup(
34 | name='freqencoder', # package name, import this to use python API
35 | ext_modules=[
36 | CUDAExtension(
37 | name='_freqencoder', # extension name, import this to use CUDA API
38 | sources=[os.path.join(_src_path, 'src', f) for f in [
39 | 'freqencoder.cu',
40 | 'bindings.cpp',
41 | ]],
42 | extra_compile_args={
43 | 'cxx': c_flags,
44 | 'nvcc': nvcc_flags,
45 | }
46 | ),
47 | ],
48 | cmdclass={
49 | 'build_ext': BuildExtension,
50 | }
51 | )
--------------------------------------------------------------------------------
/freqencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "freqencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
7 | m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
8 | }
--------------------------------------------------------------------------------
/freqencoder/src/freqencoder.cu:
--------------------------------------------------------------------------------
1 | #include <stdint.h>
2 | 
3 | #include <cuda.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_runtime.h>
6 | 
7 | #include <ATen/cuda/CUDAContext.h>
8 | #include <torch/torch.h>
9 | 
10 | #include <algorithm>
11 | #include <stdexcept>
12 | 
13 | #include <cstdio>
14 |
15 |
16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
17 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
18 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
19 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
20 |
21 | inline constexpr __device__ float PI() { return 3.141592653589793f; }
22 |
23 | template <typename T>
24 | __host__ __device__ T div_round_up(T val, T divisor) {
25 | return (val + divisor - 1) / divisor;
26 | }
27 |
28 | // inputs: [B, D]
29 | // outputs: [B, C], C = D + D * deg * 2
30 | __global__ void kernel_freq(
31 | const float * __restrict__ inputs,
32 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
33 | float * outputs
34 | ) {
35 | // parallel on per-element
36 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
37 | if (t >= B * C) return;
38 |
39 | // get index
40 | const uint32_t b = t / C;
41 | const uint32_t c = t - b * C; // t % C;
42 |
43 | // locate
44 | inputs += b * D;
45 | outputs += t;
46 |
47 | // write self
48 | if (c < D) {
49 | outputs[0] = inputs[c];
50 | // write freq
51 | } else {
52 | const uint32_t col = c / D - 1;
53 | const uint32_t d = c % D;
54 | const uint32_t freq = col / 2;
55 | const float phase_shift = (col % 2) * (PI() / 2);
56 | outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
57 | }
58 | }
59 |
60 | // grad: [B, C], C = D + D * deg * 2
61 | // outputs: [B, C]
62 | // grad_inputs: [B, D]
63 | __global__ void kernel_freq_backward(
64 | const float * __restrict__ grad,
65 | const float * __restrict__ outputs,
66 | uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
67 | float * grad_inputs
68 | ) {
69 | // parallel on per-element
70 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
71 | if (t >= B * D) return;
72 |
73 | const uint32_t b = t / D;
74 | const uint32_t d = t - b * D; // t % D;
75 |
76 | // locate
77 | grad += b * C;
78 | outputs += b * C;
79 | grad_inputs += t;
80 |
81 | // register
82 | float result = grad[d];
83 | grad += D;
84 | outputs += D;
85 |
86 | for (uint32_t f = 0; f < deg; f++) {
87 | result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
88 | grad += 2 * D;
89 | outputs += 2 * D;
90 | }
91 |
92 | // write
93 | grad_inputs[0] = result;
94 | }
95 |
96 |
97 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
98 | CHECK_CUDA(inputs);
99 | CHECK_CUDA(outputs);
100 |
101 | CHECK_CONTIGUOUS(inputs);
102 | CHECK_CONTIGUOUS(outputs);
103 |
104 | CHECK_IS_FLOATING(inputs);
105 | CHECK_IS_FLOATING(outputs);
106 |
107 | static constexpr uint32_t N_THREADS = 128;
108 |
109 | kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
110 | }
111 |
112 |
113 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
114 | CHECK_CUDA(grad);
115 | CHECK_CUDA(outputs);
116 | CHECK_CUDA(grad_inputs);
117 |
118 | CHECK_CONTIGUOUS(grad);
119 | CHECK_CONTIGUOUS(outputs);
120 | CHECK_CONTIGUOUS(grad_inputs);
121 |
122 | CHECK_IS_FLOATING(grad);
123 | CHECK_IS_FLOATING(outputs);
124 | CHECK_IS_FLOATING(grad_inputs);
125 |
126 | static constexpr uint32_t N_THREADS = 128;
127 |
128 | kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
129 | }
--------------------------------------------------------------------------------
/freqencoder/src/freqencoder.h:
--------------------------------------------------------------------------------
1 | # pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 | // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
7 | void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
8 |
9 | // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
10 | void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
--------------------------------------------------------------------------------
/gridencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .grid import GridEncoder
--------------------------------------------------------------------------------
/gridencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21 | if paths:
22 | return paths[0]
23 |
24 | # If cl.exe is not on path, try to find it.
25 | if os.system("where cl.exe >nul 2>nul") != 0:
26 | cl_path = find_cl_path()
27 | if cl_path is None:
28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29 | os.environ["PATH"] += ";" + cl_path
30 |
31 | _backend = load(name='_grid_encoder',
32 | extra_cflags=c_flags,
33 | extra_cuda_cflags=nvcc_flags,
34 | sources=[os.path.join(_src_path, 'src', f) for f in [
35 | 'gridencoder.cu',
36 | 'bindings.cpp',
37 | ]],
38 | )
39 |
40 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/gridencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | setup(
33 | name='gridencoder', # package name, import this to use python API
34 | ext_modules=[
35 | CUDAExtension(
36 | name='_gridencoder', # extension name, import this to use CUDA API
37 | sources=[os.path.join(_src_path, 'src', f) for f in [
38 | 'gridencoder.cu',
39 | 'bindings.cpp',
40 | ]],
41 | extra_compile_args={
42 | 'cxx': c_flags,
43 | 'nvcc': nvcc_flags,
44 | }
45 | ),
46 | ],
47 | cmdclass={
48 | 'build_ext': BuildExtension,
49 | }
50 | )
--------------------------------------------------------------------------------
/gridencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "gridencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
8 | m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
9 | m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
10 | }
--------------------------------------------------------------------------------
/gridencoder/src/gridencoder.h:
--------------------------------------------------------------------------------
1 | #ifndef _HASH_ENCODE_H
2 | #define _HASH_ENCODE_H
3 |
4 | #include <stdint.h>
5 | #include <torch/torch.h>
6 |
7 | // inputs: [B, D], float, in [0, 1]
8 | // embeddings: [sO, C], float
9 | // offsets: [L + 1], uint32_t
10 | // outputs: [B, L * C], float
11 | // H: base resolution
12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
13 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
14 |
15 | void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
16 | void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
17 |
18 | #endif
--------------------------------------------------------------------------------
/guidance/clip_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torchvision.transforms as T
6 | import torchvision.transforms.functional as TF
7 |
8 | import clip
9 |
10 |
11 | class CLIP(nn.Module):
12 | def __init__(self, device, **kwargs):
13 | super().__init__()
14 |
15 | self.device = device
16 |
17 | self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
18 |
19 | # image augmentation
20 | self.aug = T.Compose([
21 | T.Resize((224, 224)),
22 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
23 | ])
24 |
25 | # self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))
26 |
27 | def get_text_embeds(self, prompt, negative_prompt, **kwargs):
28 | # NOTE: negative_prompt is ignored for CLIP.
29 |
30 | text = clip.tokenize(prompt).to(self.device)
31 | text_z = self.clip_model.encode_text(text)
32 | text_z = text_z / text_z.norm(dim=-1, keepdim=True)
33 |
34 | return text_z
35 |
36 | def get_image_embeds(self, pred_rgb):
37 | pred_rgb = self.aug(pred_rgb)
38 |
39 | image_z = self.clip_model.encode_image(pred_rgb)
40 | image_z = image_z / image_z.norm(dim=-1, keepdim=True)
41 |
42 | return image_z
43 |
44 | def train_step(self, text_z, pred_rgb, **kwargs):
45 | pred_rgb = self.aug(pred_rgb)
46 |
47 | image_z = self.clip_model.encode_image(pred_rgb)
48 | image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
49 |
50 | loss = - (image_z * text_z).sum(-1).mean()
51 |
52 | return loss
53 |
54 |
55 | if __name__ == '__main__':
56 | import os
57 | import numpy as np
58 | import torch
59 | from skimage.io import imread, imsave
60 | import argparse
61 |
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--input_path', type=str, default='others/data/compare/spiderman')
64 | parser.add_argument('--data_path', type=str, default='others/output/compare/spiderman')
65 | parser.add_argument('--long_image', action='store_true')
66 | args = parser.parse_args()
67 |
68 | model = CLIP("cuda")
69 | image_size = 256
70 |
71 | input_images = []
72 | for img_path in os.listdir(args.input_path):
73 | img = imread(os.path.join(args.input_path, img_path)) / 255
74 | if img.shape[-1] == 4:
75 | mask = img[..., 3:]
76 | img = img[..., :3] * mask + (1 - mask)
77 | img = torch.tensor(img[None, ..., :3], dtype=torch.float)
78 | input_images.append(img)
79 |
80 | gen_images = []
81 | for img_path in os.listdir(args.data_path):
82 | if '.png' in img_path:
83 | img = imread(os.path.join(args.data_path, img_path)) / 255
84 | if args.long_image:
85 | for index in range(0, 16):
86 | rgb = np.copy(img[:, index * image_size:(index + 1) * image_size, :])
87 | rgb = torch.tensor(rgb[None, ..., :3], dtype=torch.float)
88 | gen_images.append(rgb)
89 | else:
90 | rgb = np.copy(img)
91 | rgb = torch.tensor(rgb[None, ..., :3], dtype=torch.float)
92 | rgb = F.interpolate(
93 | rgb.permute(0, 3, 1, 2), (256, 256), mode="bilinear", align_corners=False
94 | ).permute(0, 2, 3, 1)
95 | gen_images.append(rgb)
96 |
97 | input_images = torch.cat(input_images, dim=0).permute(0, 3, 1, 2).cuda()
98 | gen_images = torch.cat(gen_images, dim=0).permute(0, 3, 1, 2).cuda()
99 |
100 | input_z = model.get_image_embeds(input_images)
101 | gen_z = model.get_image_embeds(gen_images)
102 | similarity = input_z[:, None, ...] * gen_z[None, :, ...]
103 | print(similarity.sum(-1).mean())
104 |
--------------------------------------------------------------------------------
/guidance/guidance_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.cuda.amp import custom_bwd, custom_fwd
3 | import numpy as np
4 |
5 |
6 | class SpecifyGradient(torch.autograd.Function):
7 | @staticmethod
8 | @custom_fwd
9 | def forward(ctx, input_tensor, gt_grad):
10 | ctx.save_for_backward(gt_grad)
11 | # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
12 | return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
13 |
14 | @staticmethod
15 | @custom_bwd
16 | def backward(ctx, grad_scale):
17 | gt_grad, = ctx.saved_tensors
18 | gt_grad = gt_grad * grad_scale
19 | return gt_grad, None
20 |
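# Usage sketch: `loss = SpecifyGradient.apply(latents, grad); loss.backward()` back-propagates
# the precomputed `grad` (rescaled by AMP's loss scale) into `latents`; this is the usual way
# SDS-style guidance injects a hand-computed gradient through autograd.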
21 |
22 | def seed_everything(seed):
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed(seed)
25 | # torch.backends.cudnn.deterministic = True
26 | # torch.backends.cudnn.benchmark = True
27 |
28 |
29 | def w_star(t, m1=800, m2=500, s1=300, s2=100):
30 | # max time 1000
31 | r = np.ones_like(t) * 1.0
32 | r[t > m1] = np.exp(-((t[t > m1] - m1) ** 2) / (2 * s1 * s1))
33 | r[t < m2] = np.exp(-((t[t < m2] - m2) ** 2) / (2 * s2 * s2))
34 | return r
35 |
36 |
37 | def precompute_prior(T=1000, min_t=200, max_t=800):
38 | ts = np.arange(T)
39 | prior = w_star(ts)[min_t:max_t]
40 | prior = prior / prior.sum()
41 | prior = prior[::-1].cumsum()[::-1]
42 | return prior, min_t
43 |
44 |
45 | def time_prioritize(step_ratio, time_prior, min_t=200):
46 | return np.abs(time_prior - step_ratio).argmin() + min_t
47 |
48 |
49 | def noise_norm(eps):
50 | # [B, 3, H, W]
51 | return torch.sqrt(torch.square(eps).sum(dim=[1, 2, 3]))
52 |
53 |
54 | if __name__ == '__main__':
55 | prior, _ = precompute_prior()
56 | t = time_prioritize(0.5, prior)
57 |
58 |
59 |
--------------------------------------------------------------------------------
/ldm/extras.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from omegaconf import OmegaConf
3 | import torch
4 | from ldm.util import instantiate_from_config
5 | import logging
6 | from contextlib import contextmanager
7 |
8 | from contextlib import contextmanager
9 | import logging
10 |
11 | @contextmanager
12 | def all_logging_disabled(highest_level=logging.CRITICAL):
13 | """
14 | A context manager that will prevent any logging messages
15 | triggered during the body from being processed.
16 |
17 | :param highest_level: the maximum logging level in use.
18 | This would only need to be changed if a custom level greater than CRITICAL
19 | is defined.
20 |
21 | https://gist.github.com/simon-weber/7853144
22 | """
23 | # two kind-of hacks here:
24 | # * can't get the highest logging level in effect => delegate to the user
25 | # * can't get the current module-level override => use an undocumented
26 | # (but non-private!) interface
27 |
28 | previous_level = logging.root.manager.disable
29 |
30 | logging.disable(highest_level)
31 |
32 | try:
33 | yield
34 | finally:
35 | logging.disable(previous_level)
36 |
37 | def load_training_dir(train_dir, device, epoch="last"):
38 | """Load a checkpoint and config from training directory"""
39 | train_dir = Path(train_dir)
40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
42 | config = list(train_dir.rglob(f"*-project.yaml"))
43 | assert len(config) > 0, f"didn't find any config in {train_dir}"
44 | if len(config) > 1:
45 | print(f"found {len(config)} matching config files")
46 | config = sorted(config)[-1]
47 | print(f"selecting {config}")
48 | else:
49 | config = config[0]
50 |
51 |
52 | config = OmegaConf.load(config)
53 | return load_model_from_config(config, ckpt[0], device)
54 |
55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False):
56 | """Loads a model from config and a ckpt
57 | if config is a path will use omegaconf to load
58 | """
59 | if isinstance(config, (str, Path)):
60 | config = OmegaConf.load(config)
61 |
62 | with all_logging_disabled():
63 | print(f"Loading model from {ckpt}")
64 | pl_sd = torch.load(ckpt, map_location="cpu")
65 | global_step = pl_sd["global_step"]
66 | sd = pl_sd["state_dict"]
67 | model = instantiate_from_config(config.model)
68 | m, u = model.load_state_dict(sd, strict=False)
69 | if len(m) > 0 and verbose:
70 | print("missing keys:")
71 | print(m)
72 | if len(u) > 0 and verbose:
73 | print("unexpected keys:")
74 | model.to(device)
75 | model.eval()
76 | model.cond_stage_model.device = device
77 | return model
--------------------------------------------------------------------------------
/ldm/guidance.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from scipy import interpolate
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from IPython.display import clear_output
7 | import abc
8 |
9 |
10 | class GuideModel(torch.nn.Module, abc.ABC):
11 | def __init__(self) -> None:
12 | super().__init__()
13 |
14 | @abc.abstractmethod
15 | def preprocess(self, x_img):
16 | pass
17 |
18 | @abc.abstractmethod
19 | def compute_loss(self, inp):
20 | pass
21 |
22 |
23 | class Guider(torch.nn.Module):
24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
25 | """Apply classifier guidance
26 |
27 | Specify a guidance scale as either a scalar
28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g.
29 | [(0, 10), (0.5, 20), (1, 50)]
30 | """
31 | super().__init__()
32 | self.sampler = sampler
33 | self.index = 0
34 | self.show = verbose
35 | self.guide_model = guide_model
36 | self.history = []
37 |
38 | if isinstance(scale, (Tuple, List)):
39 | times = np.array([x[0] for x in scale])
40 | values = np.array([x[1] for x in scale])
41 | self.scale_schedule = {"times": times, "values": values}
42 | else:
43 | self.scale_schedule = float(scale)
44 |
45 | self.ddim_timesteps = sampler.ddim_timesteps
46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
47 |
48 |
49 | def get_scales(self):
50 | if isinstance(self.scale_schedule, float):
51 | return len(self.ddim_timesteps)*[self.scale_schedule]
52 |
53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
55 | return interpolater(fractional_steps)
56 |
57 | def modify_score(self, model, e_t, x, t, c):
58 |
59 | # TODO look up index by t
60 | scale = self.get_scales()[self.index]
61 |
62 | if (scale == 0):
63 | return e_t
64 |
65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
66 | with torch.enable_grad():
67 | x_in = x.detach().requires_grad_(True)
68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
70 |
71 | inp = self.guide_model.preprocess(x_img)
72 | loss = self.guide_model.compute_loss(inp)
73 | grads = torch.autograd.grad(loss.sum(), x_in)[0]
74 | correction = grads * scale
75 |
76 | if self.show:
77 | clear_output(wait=True)
78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
81 | plt.axis('off')
82 | plt.show()
83 | plt.imshow(correction[0][0].detach().cpu())
84 | plt.axis('off')
85 | plt.show()
86 |
87 |
88 | e_t_mod = e_t - sqrt_1ma*correction
89 | if self.show:
90 | fig, axs = plt.subplots(1, 3)
91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
94 | plt.show()
95 | self.index += 1
96 | return e_t_mod
--------------------------------------------------------------------------------
/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
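# Usage sketch (values are placeholders; pair with an optimizer built with base lr 1.0, per the note above):
#   sched = LambdaWarmUpCosineScheduler(warm_up_steps=100, lr_min=1e-3, lr_max=1.0,
#                                       lr_start=1e-6, max_decay_steps=10000)
#   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched)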
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from einops import rearrange
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def renorm_thresholding(x0, value):
15 | # renorm
16 | pred_max = x0.max()
17 | pred_min = x0.min()
18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
20 |
21 | s = torch.quantile(
22 | rearrange(pred_x0, 'b ... -> b (...)').abs(),
23 | value,
24 | dim=-1
25 | )
26 | s.clamp_(min=1.0)
27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
28 |
29 | # clip by threshold
30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
31 |
32 | # temporary hack: numpy on cpu
33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
34 | pred_x0 = torch.tensor(pred_x0).to(x0.device)
35 |
36 | # re.renorm
37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
39 | return pred_x0
40 |
41 |
42 | def norm_thresholding(x0, value):
43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
44 | return x0 * (value / s)
45 |
46 |
47 | def spatial_norm_thresholding(x0, value):
48 | # b c h w
49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
50 | return x0 * (value / s)
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
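# Usage sketch: ema = LitEma(model); call ema(model) after each optimizer step to update the
# shadow weights; for EMA evaluation, use ema.store(model.parameters()), ema.copy_to(model),
# run validation, then ema.restore(model.parameters()).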
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/evaluate/frechet_video_distance.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python2, python3
17 | """Minimal Reference implementation for the Frechet Video Distance (FVD).
18 |
19 | FVD is a metric for the quality of video generation models. It is inspired by
20 | the FID (Frechet Inception Distance) used for images, but uses a different
21 | embedding to be better suitable for videos.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 |
29 | import six
30 | import tensorflow.compat.v1 as tf
31 | import tensorflow_gan as tfgan
32 | import tensorflow_hub as hub
33 |
34 |
35 | def preprocess(videos, target_resolution):
36 | """Runs some preprocessing on the videos for I3D model.
37 |
38 | Args:
39 | videos: [batch_size, num_frames, height, width, depth] The videos to be
 40 |       preprocessed. The specific dtype of the videos does not matter; it can
41 | be anything that tf.image.resize_bilinear accepts. Values are expected to
42 | be in the range 0-255.
43 | target_resolution: (width, height): target video resolution
44 |
45 | Returns:
46 | videos: [batch_size, num_frames, height, width, depth]
47 | """
48 | videos_shape = list(videos.shape)
49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
52 | output_videos = tf.reshape(resized_videos, target_shape)
53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
54 | return scaled_videos
55 |
56 |
57 | def _is_in_graph(tensor_name):
58 | """Checks whether a given tensor does exists in the graph."""
59 | try:
60 | tf.get_default_graph().get_tensor_by_name(tensor_name)
61 | except KeyError:
62 | return False
63 | return True
64 |
65 |
 66 | def create_id3_embedding(videos, warmup=False, batch_size=16):
 67 |   """Embeds the given videos using the Inflated 3D Convolution network.
68 |
69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the
70 | first call.
71 |
72 | Args:
73 | videos: [batch_size, num_frames, height=224, width=224, depth=3].
74 | Expected range is [-1, 1].
75 |
76 | Returns:
77 | embedding: [batch_size, embedding_size]. embedding_size depends
78 | on the model used.
79 |
80 | Raises:
81 | ValueError: when a provided embedding_layer is not supported.
82 | """
83 |
84 | # batch_size = 16
85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
86 |
87 |
88 | # Making sure that we import the graph separately for
89 | # each different input video tensor.
90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
91 | videos.name).replace(":", "_")
92 |
93 |
94 |
95 | assert_ops = [
96 | tf.Assert(
97 | tf.reduce_max(videos) <= 1.001,
98 | ["max value in frame is > 1", videos]),
99 | tf.Assert(
100 | tf.reduce_min(videos) >= -1.001,
101 | ["min value in frame is < -1", videos]),
102 | tf.assert_equal(
103 | tf.shape(videos)[0],
104 | batch_size, ["invalid frame batch size: ",
105 | tf.shape(videos)],
106 | summarize=6),
107 | ]
108 | with tf.control_dependencies(assert_ops):
109 | videos = tf.identity(videos)
110 |
111 | module_scope = "%s_apply_default/" % module_name
112 |
113 | # To check whether the module has already been loaded into the graph, we look
114 | # for a given tensor name. If this tensor name exists, we assume the function
115 | # has been called before and the graph was imported. Otherwise we import it.
116 | # Note: in theory, the tensor could exist, but have wrong shapes.
117 | # This will happen if create_id3_embedding is called with a frames_placeholder
118 | # of wrong size/batch size, because even though that will throw a tf.Assert
119 | # on graph-execution time, it will insert the tensor (with wrong shape) into
120 | # the graph. This is why we need the following assert.
121 | if warmup:
122 | video_batch_size = int(videos.shape[0])
123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"
124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
125 | if not _is_in_graph(tensor_name):
126 | i3d_model = hub.Module(module_spec, name=module_name)
127 | i3d_model(videos)
128 |
129 | # gets the kinetics-i3d-400-logits layer
130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
132 | return tensor
133 |
134 |
135 | def calculate_fvd(real_activations,
136 | generated_activations):
137 |   """Returns an op that computes the FVD from the given activations.
138 |
139 | Args:
140 | real_activations: [num_samples, embedding_size]
141 | generated_activations: [num_samples, embedding_size]
142 |
143 | Returns:
144 | A scalar that contains the requested FVD.
145 | """
146 | return tfgan.eval.frechet_classifier_distance_from_activations(
147 | real_activations, generated_activations)
148 |
--------------------------------------------------------------------------------
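
A hedged usage sketch combining the helpers above, following the reference FVD example. It assumes a TF1-compatible stack (tensorflow_gan, tensorflow_hub) and uses toy video tensors; the batch of 16 matches the default batch_size, and the I3D hub module is downloaded on first run:

import tensorflow.compat.v1 as tf
from ldm.modules.evaluate.frechet_video_distance import (
    preprocess, create_id3_embedding, calculate_fvd)

with tf.Graph().as_default():
    # illustrative inputs: [batch, frames, H, W, 3] with values in [0, 255]
    real = tf.zeros([16, 15, 64, 64, 3])
    fake = tf.ones([16, 15, 64, 64, 3]) * 255.0

    fvd = calculate_fvd(
        create_id3_embedding(preprocess(real, (224, 224)), batch_size=16),
        create_id3_embedding(preprocess(fake, (224, 224)), batch_size=16))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        print("FVD: %.2f" % sess.run(fvd))
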
/ldm/modules/evaluate/ssim.py:
--------------------------------------------------------------------------------
1 | # MIT Licence
2 |
3 | # Methods to predict the SSIM, taken from
4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
5 |
6 | from math import exp
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 |
12 | def gaussian(window_size, sigma):
13 | gauss = torch.Tensor(
14 | [
15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))
16 | for x in range(window_size)
17 | ]
18 | )
19 | return gauss / gauss.sum()
20 |
21 |
22 | def create_window(window_size, channel):
23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25 | window = Variable(
26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()
27 | )
28 | return window
29 |
30 |
31 | def _ssim(
32 | img1, img2, window, window_size, channel, mask=None, size_average=True
33 | ):
34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
36 |
37 | mu1_sq = mu1.pow(2)
38 | mu2_sq = mu2.pow(2)
39 | mu1_mu2 = mu1 * mu2
40 |
41 | sigma1_sq = (
42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
43 | - mu1_sq
44 | )
45 | sigma2_sq = (
46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
47 | - mu2_sq
48 | )
49 | sigma12 = (
50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
51 | - mu1_mu2
52 | )
53 |
54 | C1 = (0.01) ** 2
55 | C2 = (0.03) ** 2
56 |
57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
59 | )
60 |
 61 |     if mask is not None:
62 | b = mask.size(0)
63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
65 | dim=1
66 | ).clamp(min=1)
67 | return ssim_map
68 |
 69 |
 70 |
 71 |
72 |
73 | if size_average:
74 | return ssim_map.mean()
75 | else:
76 | return ssim_map.mean(1).mean(1).mean(1)
77 |
78 |
79 | class SSIM(torch.nn.Module):
80 | def __init__(self, window_size=11, size_average=True):
81 | super(SSIM, self).__init__()
82 | self.window_size = window_size
83 | self.size_average = size_average
84 | self.channel = 1
85 | self.window = create_window(window_size, self.channel)
86 |
87 | def forward(self, img1, img2, mask=None):
88 | (_, channel, _, _) = img1.size()
89 |
90 | if (
91 | channel == self.channel
92 | and self.window.data.type() == img1.data.type()
93 | ):
94 | window = self.window
95 | else:
96 | window = create_window(self.window_size, channel)
97 |
98 | if img1.is_cuda:
99 | window = window.cuda(img1.get_device())
100 | window = window.type_as(img1)
101 |
102 | self.window = window
103 | self.channel = channel
104 |
105 | return _ssim(
106 | img1,
107 | img2,
108 | window,
109 | self.window_size,
110 | channel,
111 | mask,
112 | self.size_average,
113 | )
114 |
115 |
116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True):
117 | (_, channel, _, _) = img1.size()
118 | window = create_window(window_size, channel)
119 |
120 | if img1.is_cuda:
121 | window = window.cuda(img1.get_device())
122 | window = window.type_as(img1)
123 |
124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average)
125 |
--------------------------------------------------------------------------------
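
A small illustrative check of the functional interface above (random tensors only, so the absolute values are meaningless):

import torch
from ldm.modules.evaluate.ssim import ssim, SSIM

img1 = torch.rand(2, 3, 64, 64)
img2 = img1.clone()

print(ssim(img1, img2))                   # identical images -> ~1.0
print(ssim(img1, torch.rand_like(img1)))  # unrelated images -> well below 1.0

criterion = SSIM(window_size=11)
print(criterion(img1, img2))              # module interface, same result
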
/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
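
For reference, calculate_adaptive_weight above balances the GAN term against the reconstruction term by comparing gradient norms at the decoder's last layer: roughly d_weight = ||grad of L_rec|| / (||grad of L_GAN|| + 1e-4), clamped to [0, 1e4] and scaled by the configured discriminator weight. A minimal sketch of the idea on a toy module (hypothetical names, not the training loop used in this repo):

import torch
from torch import nn

decoder = nn.Linear(4, 4)                 # stand-in for the decoder's last layer
x = torch.randn(8, 4)
out = decoder(x)

nll_loss = (out - x).abs().mean()         # stand-in reconstruction loss
g_loss = -out.mean()                      # stand-in generator (GAN) loss

nll_grads = torch.autograd.grad(nll_loss, decoder.weight, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, decoder.weight, retain_graph=True)[0]

d_weight = (nll_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, 1e4).detach()
print(d_weight)
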
/ldm/thirdp/psp/helpers.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 |
3 | from collections import namedtuple
4 | import torch
5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
6 |
7 | """
8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
9 | """
10 |
11 |
12 | class Flatten(Module):
13 | def forward(self, input):
14 | return input.view(input.size(0), -1)
15 |
16 |
17 | def l2_norm(input, axis=1):
18 | norm = torch.norm(input, 2, axis, True)
19 | output = torch.div(input, norm)
20 | return output
21 |
22 |
23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
24 | """ A named tuple describing a ResNet block. """
25 |
26 |
27 | def get_block(in_channel, depth, num_units, stride=2):
28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
29 |
30 |
31 | def get_blocks(num_layers):
32 | if num_layers == 50:
33 | blocks = [
34 | get_block(in_channel=64, depth=64, num_units=3),
35 | get_block(in_channel=64, depth=128, num_units=4),
36 | get_block(in_channel=128, depth=256, num_units=14),
37 | get_block(in_channel=256, depth=512, num_units=3)
38 | ]
39 | elif num_layers == 100:
40 | blocks = [
41 | get_block(in_channel=64, depth=64, num_units=3),
42 | get_block(in_channel=64, depth=128, num_units=13),
43 | get_block(in_channel=128, depth=256, num_units=30),
44 | get_block(in_channel=256, depth=512, num_units=3)
45 | ]
46 | elif num_layers == 152:
47 | blocks = [
48 | get_block(in_channel=64, depth=64, num_units=3),
49 | get_block(in_channel=64, depth=128, num_units=8),
50 | get_block(in_channel=128, depth=256, num_units=36),
51 | get_block(in_channel=256, depth=512, num_units=3)
52 | ]
53 | else:
54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
55 | return blocks
56 |
57 |
58 | class SEModule(Module):
59 | def __init__(self, channels, reduction):
60 | super(SEModule, self).__init__()
61 | self.avg_pool = AdaptiveAvgPool2d(1)
62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
63 | self.relu = ReLU(inplace=True)
64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
65 | self.sigmoid = Sigmoid()
66 |
67 | def forward(self, x):
68 | module_input = x
69 | x = self.avg_pool(x)
70 | x = self.fc1(x)
71 | x = self.relu(x)
72 | x = self.fc2(x)
73 | x = self.sigmoid(x)
74 | return module_input * x
75 |
76 |
77 | class bottleneck_IR(Module):
78 | def __init__(self, in_channel, depth, stride):
79 | super(bottleneck_IR, self).__init__()
80 | if in_channel == depth:
81 | self.shortcut_layer = MaxPool2d(1, stride)
82 | else:
83 | self.shortcut_layer = Sequential(
84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
85 | BatchNorm2d(depth)
86 | )
87 | self.res_layer = Sequential(
88 | BatchNorm2d(in_channel),
89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
91 | )
92 |
93 | def forward(self, x):
94 | shortcut = self.shortcut_layer(x)
95 | res = self.res_layer(x)
96 | return res + shortcut
97 |
98 |
99 | class bottleneck_IR_SE(Module):
100 | def __init__(self, in_channel, depth, stride):
101 | super(bottleneck_IR_SE, self).__init__()
102 | if in_channel == depth:
103 | self.shortcut_layer = MaxPool2d(1, stride)
104 | else:
105 | self.shortcut_layer = Sequential(
106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
107 | BatchNorm2d(depth)
108 | )
109 | self.res_layer = Sequential(
110 | BatchNorm2d(in_channel),
111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
112 | PReLU(depth),
113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
114 | BatchNorm2d(depth),
115 | SEModule(depth, 16)
116 | )
117 |
118 | def forward(self, x):
119 | shortcut = self.shortcut_layer(x)
120 | res = self.res_layer(x)
121 | return res + shortcut
--------------------------------------------------------------------------------
/ldm/thirdp/psp/id_loss.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 | import torch
3 | from torch import nn
4 | from ldm.thirdp.psp.model_irse import Backbone
5 |
6 |
7 | class IDFeatures(nn.Module):
8 | def __init__(self, model_path):
9 | super(IDFeatures, self).__init__()
10 | print('Loading ResNet ArcFace')
11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14 | self.facenet.eval()
15 |
16 | def forward(self, x, crop=False):
17 | # Not sure of the image range here
18 | if crop:
19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
20 | x = x[:, :, 35:223, 32:220]
21 | x = self.face_pool(x)
22 | x_feats = self.facenet(x)
23 | return x_feats
24 |
--------------------------------------------------------------------------------
/ldm/thirdp/psp/model_irse.py:
--------------------------------------------------------------------------------
1 | # https://github.com/eladrich/pixel2style2pixel
2 |
3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
5 |
6 | """
7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
8 | """
9 |
10 |
11 | class Backbone(Module):
12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
13 | super(Backbone, self).__init__()
14 | assert input_size in [112, 224], "input_size should be 112 or 224"
15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
17 | blocks = get_blocks(num_layers)
18 | if mode == 'ir':
19 | unit_module = bottleneck_IR
20 | elif mode == 'ir_se':
21 | unit_module = bottleneck_IR_SE
22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
23 | BatchNorm2d(64),
24 | PReLU(64))
25 | if input_size == 112:
26 | self.output_layer = Sequential(BatchNorm2d(512),
27 | Dropout(drop_ratio),
28 | Flatten(),
29 | Linear(512 * 7 * 7, 512),
30 | BatchNorm1d(512, affine=affine))
31 | else:
32 | self.output_layer = Sequential(BatchNorm2d(512),
33 | Dropout(drop_ratio),
34 | Flatten(),
35 | Linear(512 * 14 * 14, 512),
36 | BatchNorm1d(512, affine=affine))
37 |
38 | modules = []
39 | for block in blocks:
40 | for bottleneck in block:
41 | modules.append(unit_module(bottleneck.in_channel,
42 | bottleneck.depth,
43 | bottleneck.stride))
44 | self.body = Sequential(*modules)
45 |
46 | def forward(self, x):
47 | x = self.input_layer(x)
48 | x = self.body(x)
49 | x = self.output_layer(x)
50 | return l2_norm(x)
51 |
52 |
53 | def IR_50(input_size):
54 |     """Constructs an ir-50 model."""
55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
56 | return model
57 |
58 |
59 | def IR_101(input_size):
60 |     """Constructs an ir-101 model."""
61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
62 | return model
63 |
64 |
65 | def IR_152(input_size):
66 |     """Constructs an ir-152 model."""
67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
68 | return model
69 |
70 |
71 | def IR_SE_50(input_size):
72 |     """Constructs an ir_se-50 model."""
73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
74 | return model
75 |
76 |
77 | def IR_SE_101(input_size):
78 |     """Constructs an ir_se-101 model."""
79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
80 | return model
81 |
82 |
83 | def IR_SE_152(input_size):
84 |     """Constructs an ir_se-152 model."""
85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
86 | return model
--------------------------------------------------------------------------------
/nerf/fine/dmtet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8 |
9 | import torch
10 |
11 |
12 | def get_center_boundary_index(verts):
13 | length_ = torch.sum(verts ** 2, dim=-1)
14 | center_idx = torch.argmin(length_)
15 | boundary_neg = verts == verts.max()
16 | boundary_pos = verts == verts.min()
17 | boundary = torch.bitwise_or(boundary_pos, boundary_neg)
18 | boundary = torch.sum(boundary.float(), dim=-1)
19 | boundary_idx = torch.nonzero(boundary)
20 | return center_idx, boundary_idx.squeeze(dim=-1)
21 |
--------------------------------------------------------------------------------
/nerf/fine/generate_tets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8 |
9 | import os
10 | import numpy as np
11 |
12 | '''
13 | This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
14 | to generate a tet grid
15 | 1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
16 | 2) Run the function below to generate a file `cube_32_tet.tet`
17 | '''
18 |
19 |
20 | def generate_tetrahedron_grid_file(res=32, root='..'):
21 | frac = 1.0 / res
22 | command = 'cd %s/quartet; ' % (root) + \
23 | './quartet meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res)
24 | os.system(command)
25 |
26 |
27 | '''
28 | This code segment shows how to convert from a quartet .tet file to compressed npz file
29 | '''
30 |
31 |
32 | def generate_tetrahedrons(res=50, root='..'):
33 | tetrahedrons = []
34 | vertices = []
35 | if res > 1.0:
36 |         res = 1.0 / res  # convert an integer grid resolution to the tet edge-length fraction
37 |
38 | root_path = os.path.join(root, 'quartet/meshes')
39 | file_name = os.path.join(root_path, 'cube_%f_tet.tet' % (res))
40 |
41 |     # generate the tetrahedron file if it does not exist
42 | if not os.path.exists(file_name):
43 | command = 'cd %s/quartet; ' % (root) + \
44 | './quartet meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (res, res, res)
45 | os.system(command)
46 |
47 | with open(file_name, 'r') as f:
48 | line = f.readline()
49 | line = line.strip().split(' ')
50 | n_vert = int(line[1])
51 | n_t = int(line[2])
52 | for i in range(n_vert):
53 | line = f.readline()
54 | line = line.strip().split(' ')
55 | assert len(line) == 3
56 | vertices.append([float(v) for v in line])
57 | for i in range(n_t):
58 | line = f.readline()
59 | line = line.strip().split(' ')
60 | assert len(line) == 4
61 | tetrahedrons.append([int(v) for v in line])
62 |
63 | assert len(tetrahedrons) == n_t
64 | assert len(vertices) == n_vert
65 | vertices = np.asarray(vertices)
66 | vertices[vertices <= (0 + res / 4.0)] = 0 # determine the boundary point
67 | vertices[vertices >= (1 - res / 4.0)] = 1 # determine the boundary point
68 |     np.savez_compressed('%d_compress' % round(1.0 / res), vertices=vertices, tets=tetrahedrons)  # name by the integer resolution, e.g. 256_compress
69 | return vertices, tetrahedrons
70 |
--------------------------------------------------------------------------------
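
Assuming Quartet has been built next to this repository as described in the comments above, a compressed tet grid like the bundled nerf/fine/tets/256_compress.npz can be regenerated along these lines:

from nerf.fine.generate_tets import generate_tetrahedrons

# Runs ../quartet/quartet on meshes/cube.obj if the .tet file is missing, then
# writes <res>_compress.npz (vertices + tets) in the current working directory.
vertices, tets = generate_tetrahedrons(res=256, root='..')
print(len(vertices), len(tets))
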
/nerf/fine/neural_render.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | import nvdiffrast.torch as dr
12 |
13 | _FG_LUT = None
14 |
15 |
16 | def interpolate(attr, rast, attr_idx, rast_db=None):
17 | return dr.interpolate(
18 | attr.contiguous(), rast, attr_idx, rast_db=rast_db,
19 | diff_attrs=None if rast_db is None else 'all')
20 |
21 |
22 | def xfm_points(points, matrix, use_python=True):
23 | '''Transform points.
24 | Args:
25 | points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
26 | matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
27 | use_python: Use PyTorch's torch.matmul (for validation)
28 | Returns:
29 | Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
30 | '''
31 | out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
32 | if torch.is_anomaly_enabled():
33 | assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
34 | return out
35 |
36 |
37 | class NeuralRender:
38 | def __init__(self, device='cuda', camera_model=None):
39 | super(NeuralRender, self).__init__()
40 | self.device = device
41 | self.ctx = None
42 | self.projection_mtx = None
43 | self.camera = camera_model
44 |
45 | def render_mesh(
46 | self,
47 | mesh_v_pos_bxnx3,
48 | mesh_t_pos_idx_fx3,
49 | camera_mv_bx4x4,
50 | mesh_v_feat_bxnxd,
51 | resolution=256,
52 | spp=1,
53 | device='cuda',
54 | hierarchical_mask=False
55 | ):
56 | assert not hierarchical_mask
57 | if self.ctx is None:
58 | # self.ctx = dr.RasterizeGLContext(device=self.device)
59 | self.ctx = dr.RasterizeCudaContext(device=self.device)
60 |
61 | mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
62 | # v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
63 | mtx_in = mtx_in.unsqueeze(dim=1)
64 | v_pos = mesh_v_pos_bxnx3 - mtx_in[..., :3, 3]
65 | v_pos = v_pos @ mtx_in[:, 0, :3, :3]
66 | v_pos = torch.nn.functional.pad(v_pos, pad=(0, 1), mode='constant', value=1.0)
67 | v_pos_clip = self.camera.project(v_pos).float() # Projection in the camera
68 |
69 | # Render the image,
70 | # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
71 | num_layers = 1
72 | mask_pyramid = None
73 | assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
74 |         mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd, v_pos], dim=-1)  # Concatenate the position to compute the supervision
75 |
76 | with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
77 | for _ in range(num_layers):
78 | rast, db = peeler.rasterize_next_layer()
79 | gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
80 |
81 | hard_mask = torch.clamp(rast[..., -1:], 0, 1)
82 | antialias_mask = dr.antialias(
83 | hard_mask.clone().contiguous(), rast, v_pos_clip,
84 | mesh_t_pos_idx_fx3)
85 |
86 | albedo = gb_feat[..., 3:6]
87 | antialias_albedo = dr.antialias(
88 | albedo.clone().contiguous(), rast, v_pos_clip,
89 | mesh_t_pos_idx_fx3)
90 |
91 | depth = gb_feat[..., -2:-1]
92 | ori_mesh_feature = gb_feat[..., :-4]
93 | return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, antialias_albedo
94 |
--------------------------------------------------------------------------------
/nerf/fine/perspective_camera.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8 |
9 | import torch
10 | import numpy as np
11 |
12 |
13 | def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
14 | if near_plane is None:
15 | near_plane = n
16 | # return np.array(
17 | # [[n / x, 0, 0, 0],
18 | # [0, n / -x, 0, 0],
19 | # [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
20 | # [0, 0, -1, 0]]).astype(np.float32)
21 | return np.array(
22 | [[n / x, 0, 0, 0],
23 | [0, n / x, 0, 0],
24 | [0, 0, f / (f - near_plane), -(f * near_plane) / (f - near_plane)],
25 | [0, 0, 1, 0]]).astype(np.float32)
26 |
27 | class PerspectiveCamera:
28 | def __init__(self, focal=None, fovy=49.0, device='cuda'):
29 | super(PerspectiveCamera, self).__init__()
30 | self.device = device
31 | if focal is None:
32 | focal = np.tan(fovy / 180.0 * np.pi * 0.5)
33 | self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
34 |
35 | def project(self, points_bxnx4):
36 | out = torch.matmul(
37 | points_bxnx4,
38 | torch.transpose(self.proj_mtx, 1, 2))
39 | return out
40 |
--------------------------------------------------------------------------------
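
As a quick numeric check of the defaults above: with fovy = 49 degrees the implied focal is tan(0.5 * 49 * pi / 180) ≈ 0.456, so the first two diagonal entries of the projection matrix are n / x = 1 / 0.456 ≈ 2.19 (with n = 1.0, f = 1000.0, near_plane = 0.1):

import numpy as np

focal = np.tan(49.0 / 180.0 * np.pi * 0.5)
print(focal, 1.0 / focal)   # ~0.4557, ~2.194
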
/nerf/fine/rast_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from sklearn.neighbors import NearestNeighbors
7 | from sklearn.mixture import GaussianMixture
8 |
9 | import xatlas
10 | from encoding import get_encoder
11 |
12 | import others.nvdiffrec.render.renderutils as ru
13 | from others.nvdiffrec.render import obj
14 | from others.nvdiffrec.render import material
15 | from others.nvdiffrec.render import util
16 | from others.nvdiffrec.render import mesh
17 | from others.nvdiffrec.render import texture
18 | #from others.nvdiffrec.render import mlptexture
19 | from others.nvdiffrec.render import light
20 | from others.nvdiffrec.render import render
21 |
22 |
23 | class MLP(nn.Module):
24 | def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
25 | super().__init__()
26 | self.dim_in = dim_in
27 | self.dim_out = dim_out
28 | self.dim_hidden = dim_hidden
29 | self.num_layers = num_layers
30 |
31 | net = []
32 | for l in range(num_layers):
33 | net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden,
34 | self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
35 |
36 | self.net = nn.ModuleList(net)
37 |
38 | def forward(self, x):
39 | for l in range(self.num_layers):
40 | x = self.net[l](x)
41 | if l != self.num_layers - 1:
42 | x = F.relu(x, inplace=True)
43 | return x
44 |
45 |
46 | class Material_mlp(nn.Module):
47 | def __init__(self, min_max):
48 | super().__init__()
49 | self.min_max = min_max
50 | self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3,
51 | log2_hashmap_size=16,
52 | desired_resolution=2 ** 12,
53 | base_resolution=2 ** 4,
54 | level_dim=4)
55 | self.texture_MLP = MLP(self.in_dim, 9, 32, 2, bias=True)
56 |
57 | def sample(self, texc):
58 | _texc = texc.view(-1, 3)
59 | p_enc = self.encoder(_texc.contiguous())
60 | out = self.texture_MLP(p_enc)
61 |
62 | # Sigmoid limit and scale to the allowed range
63 | out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
64 |
65 | return out.view(*texc.shape[:-1], 9)
66 |
67 |
68 | def safe_normalize(x, eps=1e-20):
69 | return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
70 |
71 |
72 | def compute_targets(v, vn, n_targets, vc=None):
73 | knn = NearestNeighbors(n_neighbors=20, algorithm='kd_tree').fit(v)
74 | _, indices = knn.kneighbors(v)
75 | vn_neighbors = vn[indices]
76 | if vc is not None:
77 | vc_neighbors = vc[indices]
78 |
79 | v_num = v.shape[0]
80 | normal_weights = np.ones(v_num)
81 | rgb_weights = np.ones(v_num)
82 | for i in range(v_num):
83 | normal_weights[i] = np.linalg.norm(np.cov(vn_neighbors[i].T), 'fro')
84 | if vc is not None:
85 | rgb_weights[i] = np.linalg.norm(np.cov(vc_neighbors[i].T), 'fro')
86 |
87 | v_norm = np.sqrt((v ** 2).sum(1))
88 | unit_v = v / v_norm[..., None]
89 | k = n_targets
90 | gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=0)
91 | gmm.fit(unit_v)
92 | pred_label = gmm.predict(unit_v)
93 | pred_prob = gmm.predict_proba(unit_v).max(1)
94 | weights_ = pred_prob * (v_norm ** 4) * normal_weights * rgb_weights
95 | targets = np.zeros([k, 3])
96 | radius = np.ones([k])
97 | weights = gmm.weights_
98 | for i in range(k):
99 | index = pred_label == i
100 | w_sum = weights_[index].sum()
101 | targets[i] = (v[index] * weights_[index][..., None]).sum(0) / w_sum
102 | radius[i] = np.quantile(np.sqrt(((v[index] - targets[i]) ** 2).sum(1)), 0.3)
103 | weights[i] = weights[i] * w_sum
104 |
105 | weights = weights / weights.sum()
106 | return targets, weights, radius, weights_
107 |
108 |
109 | # for saving the mesh
110 | @torch.no_grad()
111 | def xatlas_uvmap(glctx, eval_mesh, mat, device,
112 | kd_min_max, ks_min_max, nrm_min_max):
113 |
114 | # Create uvs with xatlas
115 | v_pos = eval_mesh.v_pos.detach().cpu().numpy()
116 | t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
117 | vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)
118 |
119 | # Convert to tensors
120 | indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
121 |
122 | uvs = torch.tensor(uvs, dtype=torch.float32, device=device)
123 | faces = torch.tensor(indices_int64, dtype=torch.int64, device=device)
124 |
125 | new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)
126 |
127 | mask, kd, ks, normal = render.render_uv(glctx, new_mesh, [2048, 2048], eval_mesh.material['kd_ks_normal'])
128 |
129 | # if FLAGS.layers > 1:
130 | # kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)
131 |
132 | new_mesh.material = material.Material({
133 | 'bsdf': mat['bsdf'],
134 | 'kd': texture.Texture2D(kd, min_max=kd_min_max),
135 | 'ks': texture.Texture2D(ks, min_max=ks_min_max),
136 | 'normal': texture.Texture2D(normal, min_max=nrm_min_max),
137 | 'no_perturbed_nrm': mat['no_perturbed_nrm']
138 | })
139 |
140 | return new_mesh
141 |
--------------------------------------------------------------------------------
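
compute_targets above clusters the normalized vertex directions with a Gaussian mixture and returns per-cluster camera targets, mixture weights, and radii; the per-vertex score combines GMM confidence, distance from the origin, and the local covariance of normals (and colors, if given). A minimal sketch on a random point cloud (illustrative inputs only; importing nerf.fine.rast_utils also pulls in the repo's rendering dependencies):

import numpy as np
from nerf.fine.rast_utils import compute_targets

v = np.random.randn(500, 3)                        # fake vertex positions
vn = v / np.linalg.norm(v, axis=1, keepdims=True)  # fake unit normals

targets, weights, radius, per_vertex = compute_targets(v, vn, n_targets=3)
print(targets.shape, weights.shape, radius.shape)  # (3, 3), (3,), (3,)
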
/nerf/fine/tets/256_compress.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/nerf/fine/tets/256_compress.npz
--------------------------------------------------------------------------------
/nerf/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import tqdm
4 | import math
5 | import imageio
6 | import random
7 | import warnings
8 | import tensorboardX
9 |
10 | import numpy as np
11 | import pandas as pd
12 |
13 | import time
14 | from datetime import datetime
15 |
16 | import cv2
17 | import matplotlib.pyplot as plt
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.optim as optim
22 | import torch.nn.functional as F
23 | import torch.distributed as dist
24 | from torch.utils.data import Dataset, DataLoader
25 |
26 | import trimesh
27 | from rich.console import Console
28 | from torch_ema import ExponentialMovingAverage
29 |
30 | from packaging import version as pver
31 |
32 |
33 | def custom_meshgrid(*args):
34 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
35 | if pver.parse(torch.__version__) < pver.parse('1.10'):
36 | return torch.meshgrid(*args)
37 | else:
38 | return torch.meshgrid(*args, indexing='ij')
39 |
40 |
41 | def safe_normalize(x, eps=1e-20):
42 | return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
43 |
44 |
45 | @torch.cuda.amp.autocast(enabled=False)
46 | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
47 | ''' get rays
48 | Args:
49 | poses: [B, 4, 4], cam2world
50 | intrinsics: [4]
51 | H, W, N: int
52 | error_map: [B, 128 * 128], sample probability based on training error
53 | Returns:
54 | rays_o, rays_d: [B, N, 3]
55 | inds: [B, N]
56 | '''
57 |
58 | device = poses.device
59 | B = poses.shape[0]
60 | fx, fy, cx, cy = intrinsics
61 |
62 | i, j = custom_meshgrid(torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H, device=device))
63 | i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
64 | j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
65 |
66 | results = {}
67 |
68 | if N > 0:
69 | N = min(N, H * W)
70 |
71 | if error_map is None:
72 | inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate
73 | inds = inds.expand([B, N])
74 | else:
75 |
76 | # weighted sample on a low-reso grid
77 | inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)
78 |
79 | # map to the original resolution with random perturb.
80 | inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway.
81 | sx, sy = H / 128, W / 128
82 | inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
83 | inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
84 | inds = inds_x * W + inds_y
85 |
86 | results['inds_coarse'] = inds_coarse # need this when updating error_map
87 |
88 | i = torch.gather(i, -1, inds)
89 | j = torch.gather(j, -1, inds)
90 |
91 | results['inds'] = inds
92 |
93 | else:
94 | inds = torch.arange(H * W, device=device).expand([B, H * W])
95 |
96 | zs = torch.ones_like(i)
97 | xs = (i - cx) / fx * zs
98 | ys = (j - cy) / fy * zs
99 | directions = torch.stack((xs, ys, zs), dim=-1)
100 | directions = safe_normalize(directions)
101 | rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
102 |
103 | rays_o = poses[..., :3, 3] # [B, 3]
104 | rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
105 |
106 | results['rays_o'] = rays_o
107 | results['rays_d'] = rays_d
108 |
109 | return results
110 |
111 |
112 | def seed_everything(seed):
113 | random.seed(seed)
114 | os.environ['PYTHONHASHSEED'] = str(seed)
115 | np.random.seed(seed)
116 | torch.manual_seed(seed)
117 | torch.cuda.manual_seed(seed)
118 | torch.cuda.manual_seed_all(seed)
119 | # torch.backends.cudnn.deterministic = True
120 | # torch.backends.cudnn.benchmark = True
121 |
122 |
123 | def torch_vis_2d(x, renormalize=False):
124 | # x: [3, H, W] or [1, H, W] or [H, W]
125 | import matplotlib.pyplot as plt
126 | import numpy as np
127 | import torch
128 |
129 | if isinstance(x, torch.Tensor):
130 | if len(x.shape) == 3:
131 | x = x.permute(1, 2, 0).squeeze()
132 | x = x.detach().cpu().numpy()
133 |
134 | print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
135 |
136 | x = x.astype(np.float32)
137 |
138 | # renormalize
139 | if renormalize:
140 | x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
141 |
142 | plt.imshow(x)
143 | plt.show()
144 |
145 |
146 | @torch.jit.script
147 | def linear_to_srgb(x):
148 | return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
149 |
150 |
151 | @torch.jit.script
152 | def srgb_to_linear(x):
153 | return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
154 |
--------------------------------------------------------------------------------
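
A small sketch of how `get_rays` is typically driven; the identity camera-to-world pose and square intrinsics here are purely for illustration:

import torch
from nerf.utils import get_rays

H = W = 64
fx = fy = 64.0
cx, cy = W / 2, H / 2

poses = torch.eye(4).unsqueeze(0)                  # [B=1, 4, 4] cam2world
rays = get_rays(poses, (fx, fy, cx, cy), H, W, N=-1)

print(rays['rays_o'].shape, rays['rays_d'].shape)  # torch.Size([1, 4096, 3]) twice
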
/others/nvdiffrec/data/irrmaps/README.txt:
--------------------------------------------------------------------------------
1 | The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
2 | CC0 License.
3 |
4 |
--------------------------------------------------------------------------------
/others/nvdiffrec/data/irrmaps/aerodynamics_workshop_2k.hdr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/aerodynamics_workshop_2k.hdr
--------------------------------------------------------------------------------
/others/nvdiffrec/data/irrmaps/blocky_photo_studio_2k.hdr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/blocky_photo_studio_2k.hdr
--------------------------------------------------------------------------------
/others/nvdiffrec/data/irrmaps/bsdf_256_256.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/bsdf_256_256.bin
--------------------------------------------------------------------------------
/others/nvdiffrec/data/irrmaps/mud_road_puresky_4k.hdr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fudan-zvg/PGC-3D/14b96533a2711d421be2d3bbe4f35ae13d435e7d/others/nvdiffrec/data/irrmaps/mud_road_puresky_4k.hdr
--------------------------------------------------------------------------------
/others/nvdiffrec/render/light.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import os
11 | import numpy as np
12 | import torch
13 | import nvdiffrast.torch as dr
14 |
15 | from . import util
16 | from . import renderutils as ru
17 |
18 | ######################################################################################
19 | # Utility functions
20 | ######################################################################################
21 |
22 | class cubemap_mip(torch.autograd.Function):
23 | @staticmethod
24 | def forward(ctx, cubemap):
25 | return util.avg_pool_nhwc(cubemap, (2,2))
26 |
27 | @staticmethod
28 | def backward(ctx, dout):
29 | res = dout.shape[1] * 2
30 | out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
31 | for s in range(6):
32 | gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
33 | torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
34 | indexing='ij')
35 | v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
36 | out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
37 | return out
38 |
39 | ######################################################################################
40 | # Split-sum environment map light source with automatic mipmap generation
41 | ######################################################################################
42 |
43 | class EnvironmentLight(torch.nn.Module):
44 | LIGHT_MIN_RES = 16
45 |
46 | MIN_ROUGHNESS = 0.08
47 | MAX_ROUGHNESS = 0.5
48 |
49 | def __init__(self, base):
50 | super(EnvironmentLight, self).__init__()
51 | self.mtx = None
52 | self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=False)
53 | self.register_parameter('env_base', self.base)
54 |
55 | def xfm(self, mtx):
56 | self.mtx = mtx
57 |
58 | def clone(self):
59 | return EnvironmentLight(self.base.clone().detach())
60 |
61 | def clamp_(self, min=None, max=None):
62 | self.base.clamp_(min, max)
63 |
64 | def get_mip(self, roughness):
65 | return torch.where(roughness < self.MAX_ROUGHNESS
66 | , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2)
67 | , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2)
68 |
69 | def build_mips(self, cutoff=0.99):
70 | self.specular = [self.base]
71 | while self.specular[-1].shape[1] > self.LIGHT_MIN_RES:
72 | self.specular += [cubemap_mip.apply(self.specular[-1])]
73 |
74 | self.diffuse = ru.diffuse_cubemap(self.specular[-1])
75 |
76 | for idx in range(len(self.specular) - 1):
77 | roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS
78 | self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff)
79 | self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff)
80 |
81 | def regularizer(self):
82 | white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0
83 | return torch.mean(torch.abs(self.base - white))
84 |
85 | def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
86 | wo = util.safe_normalize(view_pos - gb_pos)
87 |
88 | if specular:
89 | roughness = ks[..., 1:2] # y component
90 | metallic = ks[..., 2:3] # z component
91 | spec_col = (1.0 - metallic)*0.04 + kd * metallic
92 | diff_col = kd * (1.0 - metallic)
93 | else:
94 | diff_col = kd
95 | spec_col = None
96 |
97 | reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
98 | nrmvec = gb_normal
99 | if self.mtx is not None: # Rotate lookup
100 | mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
101 | reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
102 | nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
103 |
104 | # Diffuse lookup
105 | diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
106 | shaded_col = diffuse * diff_col
107 |
108 | if specular:
109 | # Lookup FG term from lookup texture
110 | NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
111 | fg_uv = torch.cat((NdotV, roughness), dim=-1)
112 | if not hasattr(self, '_FG_LUT'):
113 | self._FG_LUT = torch.as_tensor(np.fromfile('others/nvdiffrec/data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda')
114 | fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')
115 |
116 | # Roughness adjusted specular env lookup
117 | miplevel = self.get_mip(roughness)
118 | spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')
119 |
120 | # Compute aggregate lighting
121 | reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
122 | shaded_col += spec * reflectance
123 |
124 | return shaded_col * (1.0 - ks[..., 0:1]), spec_col # Modulate by hemisphere visibility
125 |
126 | ######################################################################################
127 | # Load and store
128 | ######################################################################################
129 |
130 | # Load from latlong .HDR file
131 | def _load_env_hdr(fn, scale=1.0):
132 | latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
133 | cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
134 |
135 | l = EnvironmentLight(cubemap)
136 | l.build_mips()
137 |
138 | return l
139 |
140 | def load_env(fn, scale=1.0):
141 | if os.path.splitext(fn)[1].lower() == ".hdr":
142 | return _load_env_hdr(fn, scale)
143 | else:
144 | assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1]
145 |
146 | def save_env_map(fn, light):
147 | assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently"
148 | if isinstance(light, EnvironmentLight):
149 | color = util.cubemap_to_latlong(light.base, [512, 1024])
150 | util.save_image_raw(fn, color.detach().cpu().numpy())
151 |
152 | ######################################################################################
153 | # Create trainable env map with random initialization
154 | ######################################################################################
155 |
156 | def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
157 | base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias
158 | return EnvironmentLight(base)
159 |
160 |
--------------------------------------------------------------------------------
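
A brief sketch of loading one of the bundled probes (requires a CUDA device and nvdiffrast, since the loader hard-codes device='cuda' and builds the cubemap mip chain on load):

from others.nvdiffrec.render import light

env = light.load_env('others/nvdiffrec/data/irrmaps/mud_road_puresky_4k.hdr', scale=1.0)
print(env.base.shape)            # torch.Size([6, 512, 512, 3])
print(float(env.regularizer()))  # mean deviation from a gray (white-balanced) probe

# Alternatively, start from a random trainable probe:
env_rnd = light.create_trainable_env_rnd(base_res=512, scale=0.5, bias=0.25)
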
/others/nvdiffrec/render/material.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import os
11 | import numpy as np
12 | import torch
13 |
14 | from . import util
15 | from . import texture
16 |
17 | ######################################################################################
18 | # Wrapper to make materials behave like a python dict, but register textures as
19 | # torch.nn.Module parameters.
20 | ######################################################################################
21 | class Material(torch.nn.Module):
22 | def __init__(self, mat_dict):
23 | super(Material, self).__init__()
24 | self.mat_keys = set()
25 | for key in mat_dict.keys():
26 | self.mat_keys.add(key)
27 | self[key] = mat_dict[key]
28 |
29 | def __contains__(self, key):
30 | return hasattr(self, key)
31 |
32 | def __getitem__(self, key):
33 | return getattr(self, key)
34 |
35 | def __setitem__(self, key, val):
36 | self.mat_keys.add(key)
37 | setattr(self, key, val)
38 |
39 | def __delitem__(self, key):
40 | self.mat_keys.remove(key)
41 | delattr(self, key)
42 |
43 | def keys(self):
44 | return self.mat_keys
45 |
46 | ######################################################################################
47 | # .mtl material format loading / storing
48 | ######################################################################################
49 | @torch.no_grad()
50 | def load_mtl(fn, clear_ks=True):
51 | import re
52 | mtl_path = os.path.dirname(fn)
53 |
54 | # Read file
55 | with open(fn, 'r') as f:
56 | lines = f.readlines()
57 |
58 | # Parse materials
59 | materials = []
60 | for line in lines:
61 | split_line = re.split(' +|\t+|\n+', line.strip())
62 | prefix = split_line[0].lower()
63 | data = split_line[1:]
64 | if 'newmtl' in prefix:
65 | material = Material({'name' : data[0]})
66 | materials += [material]
67 | elif materials:
68 | if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
69 | material[prefix] = data[0]
70 | else:
71 | material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
72 |
73 | # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
74 | for mat in materials:
75 | if not 'bsdf' in mat:
76 | mat['bsdf'] = 'pbr'
77 |
78 | if 'map_kd' in mat:
79 | mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
80 | else:
81 | mat['kd'] = texture.Texture2D(mat['kd'])
82 |
83 | if 'map_ks' in mat:
84 | mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
85 | else:
86 | mat['ks'] = texture.Texture2D(mat['ks'])
87 |
88 | if 'bump' in mat:
89 | mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
90 |
91 | # Convert Kd from sRGB to linear RGB
92 | mat['kd'] = texture.srgb_to_rgb(mat['kd'])
93 |
94 | if clear_ks:
95 | # Override ORM occlusion (red) channel by zeros. We hijack this channel
96 | for mip in mat['ks'].getMips():
97 | mip[..., 0] = 0.0
98 |
99 | return materials
100 |
101 | @torch.no_grad()
102 | def save_mtl(fn, material):
103 | folder = os.path.dirname(fn)
104 | with open(fn, "w") as f:
105 | f.write('newmtl defaultMat\n')
106 | if material is not None:
107 | f.write('bsdf %s\n' % material['bsdf'])
108 | if 'kd' in material.keys():
109 | f.write('map_Kd texture_kd.png\n')
110 | texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd']))
111 | texture.save_texture2D(os.path.join(folder, 'texture_kd_rgb.png'), material['kd'])
112 | if 'ks' in material.keys():
113 | f.write('map_Ks texture_ks.png\n')
114 | texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks'])
115 | if 'normal' in material.keys() and not ('no_perturbed_nrm' in material and material['no_perturbed_nrm']):
116 | f.write('bump texture_n.png\n')
117 | texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)
118 | else:
119 | f.write('Kd 1 1 1\n')
120 | f.write('Ks 0 0 0\n')
121 | f.write('Ka 0 0 0\n')
122 | f.write('Tf 1 1 1\n')
123 | f.write('Ni 1\n')
124 | f.write('Ns 0\n')
125 |
126 | ######################################################################################
127 | # Merge multiple materials into a single uber-material
128 | ######################################################################################
129 |
130 | def _upscale_replicate(x, full_res):
131 | x = x.permute(0, 3, 1, 2)
132 | x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
133 | return x.permute(0, 2, 3, 1).contiguous()
134 |
135 | def merge_materials(materials, texcoords, tfaces, mfaces):
136 | assert len(materials) > 0
137 | for mat in materials:
138 | assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)"
139 | assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled"
140 |
141 | uber_material = Material({
142 | 'name' : 'uber_material',
143 | 'bsdf' : materials[0]['bsdf'],
144 | })
145 |
146 | textures = ['kd', 'ks', 'normal']
147 |
148 | # Find maximum texture resolution across all materials and textures
149 | max_res = None
150 | for mat in materials:
151 | for tex in textures:
152 | tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])
153 | max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res
154 |
155 |     # Compute size of compound texture and round up to nearest PoT
156 |     full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(int) # np.int (removed in NumPy 1.24) was an alias for builtin int
157 |
158 | # Normalize texture resolution across all materials & combine into a single large texture
159 | for tex in textures:
160 | if tex in materials[0]:
161 | tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x
162 | tex_data = _upscale_replicate(tex_data, full_res)
163 | uber_material[tex] = texture.Texture2D(tex_data)
164 |
165 | # Compute scaling values for used / unused texture area
166 | s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]
167 |
168 |     # Recompute texture coordinates to coincide with new composite texture
169 | new_tverts = {}
170 | new_tverts_data = []
171 | for fi in range(len(tfaces)):
172 | matIdx = mfaces[fi]
173 | for vi in range(3):
174 | ti = tfaces[fi][vi]
175 | if not (ti in new_tverts):
176 | new_tverts[ti] = {}
177 | if not (matIdx in new_tverts[ti]): # create new vertex
178 |                 new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coordinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indices swap here
179 | new_tverts[ti][matIdx] = len(new_tverts_data) - 1
180 | tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex
181 |
182 | return uber_material, new_tverts_data, tfaces
183 |
184 |
--------------------------------------------------------------------------------
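A minimal usage sketch of the `Material` wrapper above. The import path and the CUDA device are assumptions, as is the detail that `texture.Texture2D` stores its data as an `nn.Parameter` (as it does in `texture.py`), which is what makes the textures visible to `.parameters()`:

```python
import torch
from others.nvdiffrec.render import material, texture

# Dict-style construction; every value becomes an attribute of the nn.Module.
mat = material.Material({
    'name': 'demo',
    'bsdf': 'pbr',
    'kd': texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
})
mat['ks'] = texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))

print('bsdf' in mat)        # True: __contains__ just checks for the attribute
print(sorted(mat.keys()))   # ['bsdf', 'kd', 'ks', 'name'], tracked in mat_keys
for p in mat.parameters():  # the texture tensors show up as regular module parameters
    print(p.shape)
```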
/others/nvdiffrec/render/mlptexture.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import torch
11 | import tinycudann as tcnn
12 | import numpy as np
13 |
14 | #######################################################################################################################################################
15 | # Small MLP using PyTorch primitives, internal helper class
16 | #######################################################################################################################################################
17 |
18 | class _MLP(torch.nn.Module):
19 | def __init__(self, cfg, loss_scale=1.0):
20 | super(_MLP, self).__init__()
21 | self.loss_scale = loss_scale
22 | net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
23 | for i in range(cfg['n_hidden_layers']-1):
24 | net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
25 | net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)
26 | self.net = torch.nn.Sequential(*net).cuda()
27 |
28 | self.net.apply(self._init_weights)
29 |
30 | if self.loss_scale != 1.0:
31 | self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))
32 |
33 | def forward(self, x):
34 | return self.net(x.to(torch.float32))
35 |
36 | @staticmethod
37 | def _init_weights(m):
38 | if type(m) == torch.nn.Linear:
39 | torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
40 | if hasattr(m.bias, 'data'):
41 | m.bias.data.fill_(0.0)
42 |
43 | #######################################################################################################################################################
44 | # Outward visible MLP class
45 | #######################################################################################################################################################
46 |
47 | class MLPTexture3D(torch.nn.Module):
48 | def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None):
49 | super(MLPTexture3D, self).__init__()
50 |
51 | self.channels = channels
52 | self.internal_dims = internal_dims
53 | self.AABB = AABB
54 | self.min_max = min_max
55 |
56 | # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details
57 | desired_resolution = 4096
58 | base_grid_resolution = 16
59 | num_levels = 16
60 | per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1))
61 |
62 | enc_cfg = {
63 | "otype": "HashGrid",
64 | "n_levels": num_levels,
65 | "n_features_per_level": 2,
66 | "log2_hashmap_size": 19,
67 | "base_resolution": base_grid_resolution,
68 | "per_level_scale" : per_level_scale
69 | }
70 |
71 | gradient_scaling = 128.0
72 | self.encoder = tcnn.Encoding(3, enc_cfg)
73 | self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, ))
74 |
75 | # Setup MLP
76 | mlp_cfg = {
77 | "n_input_dims" : self.encoder.n_output_dims,
78 | "n_output_dims" : self.channels,
79 | "n_hidden_layers" : hidden,
80 | "n_neurons" : self.internal_dims
81 | }
82 | self.net = _MLP(mlp_cfg, gradient_scaling)
83 | print("Encoder output: %d dims" % (self.encoder.n_output_dims))
84 |
85 | # Sample texture at a given location
86 | def sample(self, texc):
87 | _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...])
88 | _texc = torch.clamp(_texc, min=0, max=1)
89 |
90 | p_enc = self.encoder(_texc.contiguous())
91 | out = self.net.forward(p_enc)
92 |
93 | # Sigmoid limit and scale to the allowed range
94 | out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
95 |
96 | return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c]
97 |
98 | # In-place clamp with no derivative to make sure values are in valid range after training
99 | def clamp_(self):
100 | pass
101 |
102 | def cleanup(self):
103 | tcnn.free_temporary_memory()
104 |
105 |
--------------------------------------------------------------------------------
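A small sketch of querying `MLPTexture3D`. It assumes tiny-cuda-nn and a CUDA device are available, and the AABB and `min_max` ranges below are made-up values:

```python
import torch
from others.nvdiffrec.render.mlptexture import MLPTexture3D

aabb = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]], device='cuda')    # world-space bounds
min_max = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], device='cuda')    # per-channel output range
tex = MLPTexture3D(aabb, channels=3, min_max=min_max)

pts = torch.rand(2, 16, 16, 3, device='cuda') * 2.0 - 1.0   # [n, h, w, 3] points inside the AABB
rgb = tex.sample(pts)                                        # [n, h, w, 3], sigmoid-limited to min_max
print(rgb.shape)
tex.cleanup()                                                # release tiny-cuda-nn scratch memory
```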
/others/nvdiffrec/render/obj.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import os
11 | import torch
12 |
13 | from . import texture
14 | from . import mesh
15 | from . import material
16 |
17 | ######################################################################################
18 | # Utility functions
19 | ######################################################################################
20 |
21 | def _find_mat(materials, name):
22 | for mat in materials:
23 | if mat['name'] == name:
24 | return mat
25 |     return materials[0] # Material 0 is the default
26 |
27 | ######################################################################################
28 | # Create mesh object from objfile
29 | ######################################################################################
30 |
31 | def load_obj(filename, clear_ks=True, mtl_override=None):
32 | obj_path = os.path.dirname(filename)
33 |
34 | # Read entire file
35 | with open(filename, 'r') as f:
36 | lines = f.readlines()
37 |
38 | # Load materials
39 | all_materials = [
40 | {
41 | 'name' : '_default_mat',
42 | 'bsdf' : 'pbr',
43 | 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
44 | 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
45 | }
46 | ]
47 | if mtl_override is None:
48 | for line in lines:
49 | if len(line.split()) == 0:
50 | continue
51 | if line.split()[0] == 'mtllib':
52 | all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library
53 | else:
54 | all_materials += material.load_mtl(mtl_override)
55 |
56 | # load vertices
57 | vertices, texcoords, normals = [], [], []
58 | for line in lines:
59 | if len(line.split()) == 0:
60 | continue
61 |
62 | prefix = line.split()[0].lower()
63 | if prefix == 'v':
64 | vertices.append([float(v) for v in line.split()[1:]])
65 | elif prefix == 'vt':
66 | val = [float(v) for v in line.split()[1:]]
67 | texcoords.append([val[0], 1.0 - val[1]])
68 | elif prefix == 'vn':
69 | normals.append([float(v) for v in line.split()[1:]])
70 |
71 | # load faces
72 | activeMatIdx = None
73 | used_materials = []
74 | faces, tfaces, nfaces, mfaces = [], [], [], []
75 | for line in lines:
76 | if len(line.split()) == 0:
77 | continue
78 |
79 | prefix = line.split()[0].lower()
80 | if prefix == 'usemtl': # Track used materials
81 | mat = _find_mat(all_materials, line.split()[1])
82 | if not mat in used_materials:
83 | used_materials.append(mat)
84 | activeMatIdx = used_materials.index(mat)
85 | elif prefix == 'f': # Parse face
86 | vs = line.split()[1:]
87 | nv = len(vs)
88 | vv = vs[0].split('/')
89 | v0 = int(vv[0]) - 1
90 | t0 = int(vv[1]) - 1 if vv[1] != "" else -1
91 | n0 = int(vv[2]) - 1 if vv[2] != "" else -1
92 | for i in range(nv - 2): # Triangulate polygons
93 | vv = vs[i + 1].split('/')
94 | v1 = int(vv[0]) - 1
95 | t1 = int(vv[1]) - 1 if vv[1] != "" else -1
96 | n1 = int(vv[2]) - 1 if vv[2] != "" else -1
97 | vv = vs[i + 2].split('/')
98 | v2 = int(vv[0]) - 1
99 | t2 = int(vv[1]) - 1 if vv[1] != "" else -1
100 | n2 = int(vv[2]) - 1 if vv[2] != "" else -1
101 | mfaces.append(activeMatIdx)
102 | faces.append([v0, v1, v2])
103 | tfaces.append([t0, t1, t2])
104 | nfaces.append([n0, n1, n2])
105 |     assert len(tfaces) == len(faces) and len(nfaces) == len(faces)
106 |
107 | # Create an "uber" material by combining all textures into a larger texture
108 | if len(used_materials) > 1:
109 | uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
110 | else:
111 | uber_material = used_materials[0]
112 |
113 | vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
114 | texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
115 | normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
116 |
117 | faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
118 | tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
119 | nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
120 |
121 | return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
122 |
123 | ######################################################################################
124 | # Save mesh object to objfile
125 | ######################################################################################
126 |
127 | def write_obj(folder, mesh, save_material=True):
128 | obj_file = os.path.join(folder, 'mesh.obj')
129 | print("Writing mesh: ", obj_file)
130 | with open(obj_file, "w") as f:
131 | f.write("mtllib mesh.mtl\n")
132 | f.write("g default\n")
133 |
134 | v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None
135 | v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None
136 | v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None
137 |
138 | t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None
139 | t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
140 | t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None
141 |
142 | print(" writing %d vertices" % len(v_pos))
143 | for v in v_pos:
144 | f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
145 |
146 | if v_tex is not None:
147 | print(" writing %d texcoords" % len(v_tex))
148 | assert(len(t_pos_idx) == len(t_tex_idx))
149 | for v in v_tex:
150 | f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
151 |
152 | if v_nrm is not None:
153 | print(" writing %d normals" % len(v_nrm))
154 | assert(len(t_pos_idx) == len(t_nrm_idx))
155 | for v in v_nrm:
156 | f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
157 |
158 | # faces
159 | f.write("s 1 \n")
160 | f.write("g pMesh1\n")
161 | f.write("usemtl defaultMat\n")
162 |
163 | # Write faces
164 | print(" writing %d faces" % len(t_pos_idx))
165 | for i in range(len(t_pos_idx)):
166 | f.write("f ")
167 | for j in range(3):
168 | f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
169 | f.write("\n")
170 |
171 | if save_material:
172 | mtl_file = os.path.join(folder, 'mesh.mtl')
173 | print("Writing material: ", mtl_file)
174 | material.save_mtl(mtl_file, mesh.material)
175 |
176 | print("Done exporting mesh")
177 |
--------------------------------------------------------------------------------
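A sketch of the OBJ round trip, assuming a CUDA device and an OBJ that references a material library; the asset path points at the sample mesh shipped in this repo's `data/` folder:

```python
import os
from others.nvdiffrec.render import obj

m = obj.load_obj('data/panda_spear_shield.obj')      # mesh with uber material, buffers on 'cuda'
print(m.v_pos.shape, m.t_pos_idx.shape)

out_dir = 'tmp_export'
os.makedirs(out_dir, exist_ok=True)
obj.write_obj(out_dir, m, save_material=True)        # writes mesh.obj, mesh.mtl and texture PNGs
```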
/others/nvdiffrec/render/regularizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import torch
11 | import nvdiffrast.torch as dr
12 |
13 | from . import util
14 | from . import mesh
15 |
16 | ######################################################################################
17 | # Computes the image gradient, useful for kd/ks smoothness losses
18 | ######################################################################################
19 | def image_grad(buf, std=0.01):
20 | t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"),
21 | torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"),
22 | indexing='ij')
23 | tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...]
24 | tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')
25 | return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]
26 |
27 | ######################################################################################
28 | # Computes the average edge length of a mesh.
29 | # Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
30 | ######################################################################################
31 | def avg_edge_length(v_pos, t_pos_idx):
32 | e_pos_idx = mesh.compute_edges(t_pos_idx)
33 | edge_len = util.length(v_pos[e_pos_idx[:, 0]] - v_pos[e_pos_idx[:, 1]])
34 | return torch.mean(edge_len)
35 |
36 | ######################################################################################
37 | # Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
38 | # https://mgarland.org/class/geom04/material/smoothing.pdf
39 | ######################################################################################
40 | def laplace_regularizer_const(v_pos, t_pos_idx):
41 | term = torch.zeros_like(v_pos)
42 | norm = torch.zeros_like(v_pos[..., 0:1])
43 |
44 | v0 = v_pos[t_pos_idx[:, 0], :]
45 | v1 = v_pos[t_pos_idx[:, 1], :]
46 | v2 = v_pos[t_pos_idx[:, 2], :]
47 |
48 | term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
49 | term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
50 | term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))
51 |
52 | two = torch.ones_like(v0) * 2.0
53 | norm.scatter_add_(0, t_pos_idx[:, 0:1], two)
54 | norm.scatter_add_(0, t_pos_idx[:, 1:2], two)
55 | norm.scatter_add_(0, t_pos_idx[:, 2:3], two)
56 |
57 | term = term / torch.clamp(norm, min=1.0)
58 |
59 | return torch.mean(term**2)
60 |
61 | ######################################################################################
62 | # Normal consistency: penalize differences between normals of faces sharing an edge
63 | ######################################################################################
64 | def normal_consistency(v_pos, t_pos_idx):
65 | # Compute face normals
66 | v0 = v_pos[t_pos_idx[:, 0], :]
67 | v1 = v_pos[t_pos_idx[:, 1], :]
68 | v2 = v_pos[t_pos_idx[:, 2], :]
69 |
70 | face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
71 |
72 | tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)
73 |
74 |     # Fetch normals for both faces sharing an edge
75 | n0 = face_normals[tris_per_edge[:, 0], :]
76 | n1 = face_normals[tris_per_edge[:, 1], :]
77 |
78 | # Compute error metric based on normal difference
79 | term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
80 | term = (1.0 - term) * 0.5
81 |
82 | return torch.mean(torch.abs(term))
83 |
--------------------------------------------------------------------------------
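A toy check of the Laplacian term above on a tetrahedron. `laplace_regularizer_const` itself is plain PyTorch, but importing the module still assumes nvdiffrast and the sibling `mesh` module import cleanly:

```python
import torch
from others.nvdiffrec.render import regularizer

v_pos = torch.tensor([[0., 0., 0.],
                      [1., 0., 0.],
                      [0., 1., 0.],
                      [0., 0., 1.]], device='cuda', requires_grad=True)
t_pos_idx = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], device='cuda')

loss = regularizer.laplace_regularizer_const(v_pos, t_pos_idx)
loss.backward()                                 # pulls each vertex toward the mean of its neighbours
print(loss.item(), v_pos.grad.shape)
```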
/others/nvdiffrec/render/renderutils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
11 | __all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
12 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/bsdf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import math
11 | import torch
12 |
13 | NORMAL_THRESHOLD = 0.1
14 |
15 | ################################################################################
16 | # Vector utility functions
17 | ################################################################################
18 |
19 | def _dot(x, y):
20 | return torch.sum(x*y, -1, keepdim=True)
21 |
22 | def _reflect(x, n):
23 | return 2*_dot(x, n)*n - x
24 |
25 | def _safe_normalize(x):
26 | return torch.nn.functional.normalize(x, dim = -1)
27 |
28 | def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
29 | # Swap normal direction for backfacing surfaces
30 | if two_sided_shading:
31 | smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
32 | geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
33 |
34 | t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
35 | return torch.lerp(geom_nrm, smooth_nrm, t)
36 |
37 |
38 | def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
39 | smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
40 | if opengl:
41 | shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
42 | else:
43 | shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
44 | return _safe_normalize(shading_nrm)
45 |
46 | def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
47 | smooth_nrm = _safe_normalize(smooth_nrm)
48 | smooth_tng = _safe_normalize(smooth_tng)
49 | view_vec = _safe_normalize(view_pos - pos)
50 | shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
51 | return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
52 |
53 | ################################################################################
54 | # Simple lambertian diffuse BSDF
55 | ################################################################################
56 |
57 | def bsdf_lambert(nrm, wi):
58 | return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
59 |
60 | ################################################################################
61 | # Frostbite diffuse
62 | ################################################################################
63 |
64 | def bsdf_frostbite(nrm, wi, wo, linearRoughness):
65 | wiDotN = _dot(wi, nrm)
66 | woDotN = _dot(wo, nrm)
67 |
68 | h = _safe_normalize(wo + wi)
69 | wiDotH = _dot(wi, h)
70 |
71 | energyBias = 0.5 * linearRoughness
72 | energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
73 | f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
74 | f0 = 1.0
75 |
76 | wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
77 | woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
78 | res = wiScatter * woScatter * energyFactor
79 | return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
80 |
81 | ################################################################################
82 | # Phong specular, loosely based on mitsuba implementation
83 | ################################################################################
84 |
85 | def bsdf_phong(nrm, wo, wi, N):
86 | dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
87 | dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
88 | return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
89 |
90 | ################################################################################
91 | # PBR's implementation of GGX specular
92 | ################################################################################
93 |
94 | specular_epsilon = 1e-4
95 |
96 | def bsdf_fresnel_shlick(f0, f90, cosTheta):
97 | _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
98 | return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
99 |
100 | def bsdf_ndf_ggx(alphaSqr, cosTheta):
101 | _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
102 | d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
103 | return alphaSqr / (d * d * math.pi)
104 |
105 | def bsdf_lambda_ggx(alphaSqr, cosTheta):
106 | _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
107 | cosThetaSqr = _cosTheta * _cosTheta
108 | tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
109 | res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
110 | return res
111 |
112 | def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
113 | lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
114 | lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
115 | return 1 / (1 + lambdaI + lambdaO)
116 |
117 | def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
118 | _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
119 | alphaSqr = _alpha * _alpha
120 |
121 | h = _safe_normalize(wo + wi)
122 | woDotN = _dot(wo, nrm)
123 | wiDotN = _dot(wi, nrm)
124 | woDotH = _dot(wo, h)
125 | nDotH = _dot(nrm, h)
126 |
127 | D = bsdf_ndf_ggx(alphaSqr, nDotH)
128 | G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
129 | F = bsdf_fresnel_shlick(col, 1, woDotH)
130 |
131 | w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
132 |
133 | frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
134 | return torch.where(frontfacing, w, torch.zeros_like(w))
135 |
136 | def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
137 | wo = _safe_normalize(view_pos - pos)
138 | wi = _safe_normalize(light_pos - pos)
139 |
140 | spec_str = arm[..., 0:1] # x component
141 | roughness = arm[..., 1:2] # y component
142 | metallic = arm[..., 2:3] # z component
143 | ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
144 | kd = kd * (1.0 - metallic)
145 |
146 | if BSDF == 0:
147 | diffuse = kd * bsdf_lambert(nrm, wi)
148 | else:
149 | diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
150 | specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
151 | return diffuse + specular
152 |
--------------------------------------------------------------------------------
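The functions above form the PyTorch reference path mirrored by the fused kernels in `c_src/`. A sketch shading a single point (the import path is an assumption; the functions themselves are plain tensor math, and `BSDF=0` selects Lambert diffuse, `BSDF=1` the Frostbite term):

```python
import torch
from others.nvdiffrec.render.renderutils import bsdf

kd        = torch.tensor([[0.8, 0.2, 0.2]])   # base colour
arm       = torch.tensor([[0.0, 0.5, 0.0]])   # (specular attenuation, roughness, metallic)
pos       = torch.tensor([[0.0, 0.0, 0.0]])
nrm       = torch.tensor([[0.0, 0.0, 1.0]])
view_pos  = torch.tensor([[0.0, 0.0, 2.0]])
light_pos = torch.tensor([[1.0, 1.0, 2.0]])

rgb = bsdf.bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, BSDF=0)
print(rgb)   # Lambert diffuse + GGX specular for a unit-strength point light
```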
/others/nvdiffrec/render/renderutils/c_src/bsdf.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "common.h"
15 |
16 | struct LambertKernelParams
17 | {
18 | Tensor nrm;
19 | Tensor wi;
20 | Tensor out;
21 | dim3 gridSize;
22 | };
23 |
24 | struct FrostbiteDiffuseKernelParams
25 | {
26 | Tensor nrm;
27 | Tensor wi;
28 | Tensor wo;
29 | Tensor linearRoughness;
30 | Tensor out;
31 | dim3 gridSize;
32 | };
33 |
34 | struct FresnelShlickKernelParams
35 | {
36 | Tensor f0;
37 | Tensor f90;
38 | Tensor cosTheta;
39 | Tensor out;
40 | dim3 gridSize;
41 | };
42 |
43 | struct NdfGGXParams
44 | {
45 | Tensor alphaSqr;
46 | Tensor cosTheta;
47 | Tensor out;
48 | dim3 gridSize;
49 | };
50 |
51 | struct MaskingSmithParams
52 | {
53 | Tensor alphaSqr;
54 | Tensor cosThetaI;
55 | Tensor cosThetaO;
56 | Tensor out;
57 | dim3 gridSize;
58 | };
59 |
60 | struct PbrSpecular
61 | {
62 | Tensor col;
63 | Tensor nrm;
64 | Tensor wo;
65 | Tensor wi;
66 | Tensor alpha;
67 | Tensor out;
68 | dim3 gridSize;
69 | float min_roughness;
70 | };
71 |
72 | struct PbrBSDF
73 | {
74 | Tensor kd;
75 | Tensor arm;
76 | Tensor pos;
77 | Tensor nrm;
78 | Tensor view_pos;
79 | Tensor light_pos;
80 | Tensor out;
81 | dim3 gridSize;
82 | float min_roughness;
83 | int BSDF;
84 | };
85 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/common.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #include <cuda_runtime.h>
13 | #include <algorithm>
14 |
15 | //------------------------------------------------------------------------
16 | // Block and grid size calculators for kernel launches.
17 |
18 | dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
19 | {
20 | int maxThreads = maxWidth * maxHeight;
21 | if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
22 | return dim3(1, 1, 1); // Degenerate.
23 |
24 | // Start from max size.
25 | int bw = maxWidth;
26 | int bh = maxHeight;
27 |
28 | // Optimizations for weirdly sized buffers.
29 | if (dims.x < bw)
30 | {
31 | // Decrease block width to smallest power of two that covers the buffer width.
32 | while ((bw >> 1) >= dims.x)
33 | bw >>= 1;
34 |
35 | // Maximize height.
36 | bh = maxThreads / bw;
37 | if (bh > dims.y)
38 | bh = dims.y;
39 | }
40 | else if (dims.y < bh)
41 | {
42 |         // Halve height and double width until it fits completely inside the buffer vertically.
43 | while (bh > dims.y)
44 | {
45 | bh >>= 1;
46 | if (bw < dims.x)
47 | bw <<= 1;
48 | }
49 | }
50 |
51 | // Done.
52 | return dim3(bw, bh, 1);
53 | }
54 |
55 | // returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
56 | dim3 getWarpSize(dim3 blockSize)
57 | {
58 | return dim3(
59 | std::min(blockSize.x, 32u),
60 | std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
61 | std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
62 | );
63 | }
64 |
65 | dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
66 | {
67 | dim3 gridSize;
68 | gridSize.x = (dims.x - 1) / blockSize.x + 1;
69 | gridSize.y = (dims.y - 1) / blockSize.y + 1;
70 | gridSize.z = (dims.z - 1) / blockSize.z + 1;
71 | return gridSize;
72 | }
73 |
74 | //------------------------------------------------------------------------
75 |
--------------------------------------------------------------------------------
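For reference, the grid-size helper above just rounds up so the grid covers the whole buffer on each axis,

$$
\text{gridSize}_i = \left\lceil \frac{\text{dims}_i}{\text{blockSize}_i} \right\rceil
= \frac{\text{dims}_i - 1}{\text{blockSize}_i} + 1 \quad\text{(integer division)},
$$

while `getLaunchBlockSize` shrinks a maximal block to a power-of-two width that still covers narrow buffers.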
/others/nvdiffrec/render/renderutils/c_src/common.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 | #include <cuda.h>
14 | #include <stdint.h>
15 |
16 | #include "vec3f.h"
17 | #include "vec4f.h"
18 | #include "tensor.h"
19 |
20 | dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
21 | dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
22 |
23 | #ifdef __CUDACC__
24 |
25 | #ifdef _MSC_VER
26 | #define M_PI 3.14159265358979323846f
27 | #endif
28 |
29 | __host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
30 | {
31 | return dim3(
32 | min(blockSize.x, 32u),
33 | min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
34 | min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
35 | );
36 | }
37 |
38 | __device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
39 | #else
40 | dim3 getWarpSize(dim3 blockSize);
41 | #endif
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/cubemap.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "common.h"
15 |
16 | struct DiffuseCubemapKernelParams
17 | {
18 | Tensor cubemap;
19 | Tensor out;
20 | dim3 gridSize;
21 | };
22 |
23 | struct SpecularCubemapKernelParams
24 | {
25 | Tensor cubemap;
26 | Tensor bounds;
27 | Tensor out;
28 | dim3 gridSize;
29 | float costheta_cutoff;
30 | float roughness;
31 | };
32 |
33 | struct SpecularBoundsKernelParams
34 | {
35 | float costheta_cutoff;
36 | Tensor out;
37 | dim3 gridSize;
38 | };
39 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/loss.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #include <cuda.h>
13 |
14 | #include "common.h"
15 | #include "loss.h"
16 |
17 | //------------------------------------------------------------------------
18 | // Utils
19 |
20 | __device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }
21 |
22 | __device__ float warpSum(float val) {
23 | for (int i = 1; i < 32; i *= 2)
24 | val += __shfl_xor_sync(0xFFFFFFFF, val, i);
25 | return val;
26 | }
27 |
28 | //------------------------------------------------------------------------
29 | // Tonemapping
30 |
31 | __device__ inline float fwdSRGB(float x)
32 | {
33 | return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
34 | }
35 |
36 | __device__ inline void bwdSRGB(float x, float &d_x, float d_out)
37 | {
38 | if (x > 0.0031308f)
39 | d_x += d_out * 0.439583f / powf(x, 0.583333f);
40 | else if (x > 0.0f)
41 | d_x += d_out * 12.92f;
42 | }
43 |
44 | __device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
45 | {
46 | return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
47 | }
48 |
49 | __device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
50 | {
51 | if (x.x > 0.0f && x.x < 65535.0f)
52 | {
53 | bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
54 | d_x.x *= 1 / (x.x + 1.0f);
55 | }
56 | if (x.y > 0.0f && x.y < 65535.0f)
57 | {
58 | bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
59 | d_x.y *= 1 / (x.y + 1.0f);
60 | }
61 | if (x.z > 0.0f && x.z < 65535.0f)
62 | {
63 | bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
64 | d_x.z *= 1 / (x.z + 1.0f);
65 | }
66 | }
67 |
68 | __device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
69 | {
70 | return (img - target) * (img - target) / (img * img + target * target + eps);
71 | }
72 |
73 | __device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
74 | {
75 | float denom = (target * target + img * img + eps);
76 | d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
77 | d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
78 | }
79 |
80 | __device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
81 | {
82 | return abs(img - target) / (img + target + eps);
83 | }
84 |
85 | __device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
86 | {
87 | float denom = (target + img + eps);
88 | d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
89 | d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
90 | }
91 |
92 | //------------------------------------------------------------------------
93 | // Kernels
94 |
95 | __global__ void imgLossFwdKernel(LossKernelParams p)
96 | {
97 | // Calculate pixel position.
98 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
99 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
100 | unsigned int pz = blockIdx.z;
101 |
102 | float floss = 0.0f;
103 | if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
104 | {
105 | vec3f img = p.img.fetch3(px, py, pz);
106 | vec3f target = p.target.fetch3(px, py, pz);
107 |
108 | img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
109 | target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
110 |
111 | if (p.tonemapper == TONEMAPPER_LOG_SRGB)
112 | {
113 | img = fwdTonemapLogSRGB(img);
114 | target = fwdTonemapLogSRGB(target);
115 | }
116 |
117 | vec3f vloss(0);
118 | if (p.loss == LOSS_MSE)
119 | vloss = (img - target) * (img - target);
120 | else if (p.loss == LOSS_RELMSE)
121 | vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
122 | else if (p.loss == LOSS_SMAPE)
123 | vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
124 | else
125 | vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
126 |
127 | floss = sum(vloss) / 3.0f;
128 | }
129 |
130 | floss = warpSum(floss);
131 |
132 | dim3 warpSize = getWarpSize(blockDim);
133 | if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
134 | p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
135 | }
136 |
137 | __global__ void imgLossBwdKernel(LossKernelParams p)
138 | {
139 | // Calculate pixel position.
140 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
141 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
142 | unsigned int pz = blockIdx.z;
143 |
144 | if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
145 | return;
146 |
147 | dim3 warpSize = getWarpSize(blockDim);
148 |
149 | vec3f _img = p.img.fetch3(px, py, pz);
150 | vec3f _target = p.target.fetch3(px, py, pz);
151 | float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);
152 |
153 | /////////////////////////////////////////////////////////////////////
154 | // FWD
155 |
156 | vec3f img = _img, target = _target;
157 | if (p.tonemapper == TONEMAPPER_LOG_SRGB)
158 | {
159 | img = fwdTonemapLogSRGB(img);
160 | target = fwdTonemapLogSRGB(target);
161 | }
162 |
163 | /////////////////////////////////////////////////////////////////////
164 | // BWD
165 |
166 | vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;
167 |
168 | vec3f d_img(0), d_target(0);
169 | if (p.loss == LOSS_MSE)
170 | {
171 |         d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.z * 2 * (img.z - target.z));
172 | d_target = -d_img;
173 | }
174 | else if (p.loss == LOSS_RELMSE)
175 | {
176 | bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
177 | bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
178 | bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
179 | }
180 | else if (p.loss == LOSS_SMAPE)
181 | {
182 | bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
183 | bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
184 | bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
185 | }
186 | else
187 | {
188 | d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
189 | d_target = -d_img;
190 | }
191 |
192 |
193 | if (p.tonemapper == TONEMAPPER_LOG_SRGB)
194 | {
195 | vec3f d__img(0), d__target(0);
196 | bwdTonemapLogSRGB(_img, d__img, d_img);
197 | bwdTonemapLogSRGB(_target, d__target, d_target);
198 | d_img = d__img; d_target = d__target;
199 | }
200 |
201 | if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
202 | if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
203 | if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
204 | if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
205 | if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
206 | if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;
207 |
208 | p.img.store_grad(px, py, pz, d_img);
209 | p.target.store_grad(px, py, pz, d_target);
210 | }
--------------------------------------------------------------------------------
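The constants in `bwdSRGB` are simply the derivative of the forward transfer curve in `fwdSRGB` (not stated in the source, but implied by it):

$$
\frac{d}{dx}\left(1.055\,x^{1/2.4} - 0.055\right)
= \frac{1.055}{2.4}\,x^{\frac{1}{2.4}-1}
\approx \frac{0.439583}{x^{0.583333}}
\quad\text{for } x > 0.0031308,
$$

and the linear segment below the threshold has constant slope $12.92$.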
/others/nvdiffrec/render/renderutils/c_src/loss.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "common.h"
15 |
16 | enum TonemapperType
17 | {
18 | TONEMAPPER_NONE = 0,
19 | TONEMAPPER_LOG_SRGB = 1
20 | };
21 |
22 | enum LossType
23 | {
24 | LOSS_L1 = 0,
25 | LOSS_MSE = 1,
26 | LOSS_RELMSE = 2,
27 | LOSS_SMAPE = 3
28 | };
29 |
30 | struct LossKernelParams
31 | {
32 | Tensor img;
33 | Tensor target;
34 | Tensor out;
35 | dim3 gridSize;
36 | TonemapperType tonemapper;
37 | LossType loss;
38 | };
39 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/mesh.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #include <cuda.h>
13 | #include <stdio.h>
14 |
15 | #include "common.h"
16 | #include "mesh.h"
17 |
18 |
19 | //------------------------------------------------------------------------
20 | // Kernels
21 |
22 | __global__ void xfmPointsFwdKernel(XfmKernelParams p)
23 | {
24 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
25 | unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
26 |
27 | __shared__ float mtx[4][4];
28 | if (threadIdx.x < 16)
29 | mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
30 | __syncthreads();
31 |
32 | if (px >= p.gridSize.x)
33 | return;
34 |
35 | vec3f pos(
36 | p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
37 | p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
38 | p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
39 | );
40 |
41 | if (p.isPoints)
42 | {
43 | p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
44 | p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
45 | p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
46 | p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
47 | }
48 | else
49 | {
50 | p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
51 | p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
52 | p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
53 | }
54 | }
55 |
56 | __global__ void xfmPointsBwdKernel(XfmKernelParams p)
57 | {
58 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
59 | unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
60 |
61 | __shared__ float mtx[4][4];
62 | if (threadIdx.x < 16)
63 | mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
64 | __syncthreads();
65 |
66 | if (px >= p.gridSize.x)
67 | return;
68 |
69 | vec3f pos(
70 | p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
71 | p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
72 | p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
73 | );
74 |
75 | vec4f d_out(
76 | p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
77 | p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
78 | p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
79 | p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
80 | );
81 |
82 | if (p.isPoints)
83 | {
84 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
85 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
86 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
87 | }
88 | else
89 | {
90 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
91 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
92 | p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
93 | }
94 | }
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/mesh.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "common.h"
15 |
16 | struct XfmKernelParams
17 | {
18 | bool isPoints;
19 | Tensor points;
20 | Tensor matrix;
21 | Tensor out;
22 | dim3 gridSize;
23 | };
24 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/normal.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #include "common.h"
13 | #include "normal.h"
14 |
15 | #define NORMAL_THRESHOLD 0.1f
16 |
17 | //------------------------------------------------------------------------
18 | // Perturb shading normal by tangent frame
19 |
20 | __device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
21 | {
22 | vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
23 | vec3f smooth_bitng = safeNormalize(_smooth_bitng);
24 | vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
25 | return safeNormalize(_shading_nrm);
26 | }
27 |
28 | __device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
29 | {
30 | ////////////////////////////////////////////////////////////////////////
31 | // FWD
32 | vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
33 | vec3f smooth_bitng = safeNormalize(_smooth_bitng);
34 | vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
35 |
36 | ////////////////////////////////////////////////////////////////////////
37 | // BWD
38 | vec3f d_shading_nrm(0);
39 | bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);
40 |
41 | vec3f d_smooth_bitng(0);
42 |
43 | if (perturbed_nrm.z > 0.0f)
44 | {
45 | d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
46 | d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
47 | }
48 |
49 | d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
50 | d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);
51 |
52 | d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
53 | d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);
54 |
55 | vec3f d__smooth_bitng(0);
56 | bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);
57 |
58 | bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
59 | }
60 |
61 | //------------------------------------------------------------------------
62 | #define bent_nrm_eps 0.001f
63 |
64 | __device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
65 | {
66 | float dp = dot(view_vec, smooth_nrm);
67 | float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
68 | return geom_nrm * (1.0f - t) + smooth_nrm * t;
69 | }
70 |
71 | __device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
72 | {
73 | ////////////////////////////////////////////////////////////////////////
74 | // FWD
75 | float dp = dot(view_vec, smooth_nrm);
76 | float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
77 |
78 | ////////////////////////////////////////////////////////////////////////
79 | // BWD
80 | if (dp > NORMAL_THRESHOLD)
81 | d_smooth_nrm += d_out;
82 | else
83 | {
84 | // geom_nrm * (1.0f - t) + smooth_nrm * t;
85 | d_geom_nrm += d_out * (1.0f - t);
86 | d_smooth_nrm += d_out * t;
87 | float d_t = sum(d_out * (smooth_nrm - geom_nrm));
88 |
89 | float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;
90 |
91 | bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
92 | }
93 | }
94 |
95 | //------------------------------------------------------------------------
96 | // Kernels
97 |
98 | __global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
99 | {
100 | // Calculate pixel position.
101 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
102 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
103 | unsigned int pz = blockIdx.z;
104 | if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
105 | return;
106 |
107 | vec3f pos = p.pos.fetch3(px, py, pz);
108 | vec3f view_pos = p.view_pos.fetch3(px, py, pz);
109 | vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
110 | vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
111 | vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
112 | vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
113 |
114 | vec3f smooth_nrm = safeNormalize(_smooth_nrm);
115 | vec3f smooth_tng = safeNormalize(_smooth_tng);
116 | vec3f view_vec = safeNormalize(view_pos - pos);
117 | vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
118 |
119 | vec3f res;
120 | if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
121 | res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);
122 | else
123 | res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);
124 |
125 | p.out.store(px, py, pz, res);
126 | }
127 |
128 | __global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
129 | {
130 | // Calculate pixel position.
131 | unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
132 | unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
133 | unsigned int pz = blockIdx.z;
134 | if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
135 | return;
136 |
137 | vec3f pos = p.pos.fetch3(px, py, pz);
138 | vec3f view_pos = p.view_pos.fetch3(px, py, pz);
139 | vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
140 | vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
141 | vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
142 | vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
143 | vec3f d_out = p.out.fetch3(px, py, pz);
144 |
145 | ///////////////////////////////////////////////////////////////////////////////////////////////////
146 | // FWD
147 |
148 | vec3f smooth_nrm = safeNormalize(_smooth_nrm);
149 | vec3f smooth_tng = safeNormalize(_smooth_tng);
150 | vec3f _view_vec = view_pos - pos;
151 | vec3f view_vec = safeNormalize(view_pos - pos);
152 |
153 | vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
154 |
155 | ///////////////////////////////////////////////////////////////////////////////////////////////////
156 | // BWD
157 |
158 | vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
159 | if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
160 | {
161 | bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
162 | d_shading_nrm = -d_shading_nrm;
163 | d_geom_nrm = -d_geom_nrm;
164 | }
165 | else
166 | bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
167 |
168 | vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
169 | bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);
170 |
171 | vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
172 | bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
173 | bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
174 | bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);
175 |
176 | p.pos.store_grad(px, py, pz, -d__view_vec);
177 | p.view_pos.store_grad(px, py, pz, d__view_vec);
178 | p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
179 | p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
180 | p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
181 | p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
182 | }
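
As a cross-check on the hand-derived backward above, the bend-normal forward can be restated in PyTorch and differentiated with autograd; the resulting gradients should match what bwdBendNormal accumulates into d_view_vec, d_smooth_nrm and d_geom_nrm. This is only an illustrative sketch: the NORMAL_THRESHOLD value is assumed here (the real constant is defined earlier in normal.cu, outside this excerpt), and bend_normal is my name for the restated function.

import torch

NORMAL_THRESHOLD = 0.1  # assumed value for illustration; see the definition earlier in normal.cu

def bend_normal(view_vec, smooth_nrm, geom_nrm):
    # restatement of fwdBendNormal: blend the geometric and smooth normals by how much
    # the smooth normal faces the viewer, with the blend factor clamped to [0, 1]
    t = torch.clamp(torch.dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, 0.0, 1.0)
    return geom_nrm * (1.0 - t) + smooth_nrm * t

view = torch.randn(3, dtype=torch.float64, requires_grad=True)
smooth = torch.randn(3, dtype=torch.float64, requires_grad=True)
geom = torch.randn(3, dtype=torch.float64, requires_grad=True)
d_out = torch.randn(3, dtype=torch.float64)
bend_normal(view, smooth, geom).backward(d_out)
print(view.grad, smooth.grad, geom.grad)  # compare against d_view_vec / d_smooth_nrm / d_geom_nrm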
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/normal.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "common.h"
15 |
16 | struct PrepareShadingNormalKernelParams
17 | {
18 | Tensor pos;
19 | Tensor view_pos;
20 | Tensor perturbed_nrm;
21 | Tensor smooth_nrm;
22 | Tensor smooth_tng;
23 | Tensor geom_nrm;
24 | Tensor out;
25 | dim3 gridSize;
26 | bool two_sided_shading, opengl;
27 | };
28 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/tensor.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 | #if defined(__CUDACC__) && defined(BFLOAT16)
14 | #include <cuda_bf16.h> // bfloat16 is float32 compatible with fewer mantissa bits
15 | #endif
16 |
17 | //---------------------------------------------------------------------------------
18 | // CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16
19 |
20 | struct Tensor
21 | {
22 | void* val;
23 | void* d_val;
24 | int dims[4], _dims[4];
25 | int strides[4];
26 | bool fp16;
27 |
28 | #if defined(__CUDA__) && !defined(__CUDA_ARCH__)
29 | Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
30 | #endif
31 |
32 | #ifdef __CUDACC__
33 | // Helpers to index and read/write a single element
34 | __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }
35 | __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }
36 | __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; }
37 | #ifdef BFLOAT16
38 | __device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
39 | __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
40 | __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
41 | #else
42 | __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
43 | __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
44 | __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
45 | #endif
46 |
47 | //////////////////////////////////////////////////////////////////////////////////////////
48 | // Fetch, use broadcasting for tensor dimensions of size 1
49 | __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
50 | {
51 | return fetch(nhwcIndex(z, y, x, 0));
52 | }
53 |
54 | __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
55 | {
56 | return vec3f(
57 | fetch(nhwcIndex(z, y, x, 0)),
58 | fetch(nhwcIndex(z, y, x, 1)),
59 | fetch(nhwcIndex(z, y, x, 2))
60 | );
61 | }
62 |
63 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////
64 | // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
65 | __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
66 | {
67 | store(_nhwcIndex(z, y, x, 0), _val);
68 | }
69 |
70 | __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
71 | {
72 | store(_nhwcIndex(z, y, x, 0), _val.x);
73 | store(_nhwcIndex(z, y, x, 1), _val.y);
74 | store(_nhwcIndex(z, y, x, 2), _val.z);
75 | }
76 |
77 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////
78 | // Store gradient, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
79 | __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
80 | {
81 | store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
82 | }
83 |
84 | __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
85 | {
86 | store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
87 | store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
88 | store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);
89 | }
90 | #endif
91 |
92 | };
93 |
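
The nhwcIndex helper above implements broadcasting by dropping the stride contribution of any dimension whose size is 1, so a [1, H, W, C] tensor can be read as if it were batched. A small pure-Python sketch of the same index arithmetic (names are mine, for illustration only):

def nhwc_index(dims, strides, n, h, w, c):
    # mirrors Tensor::nhwcIndex: a dimension of size 1 contributes nothing to the offset
    idx = 0
    for size, stride, i in zip(dims, strides, (n, h, w, c)):
        idx += 0 if size == 1 else i * stride
    return idx

# a contiguous [1, 4, 4, 3] tensor read as if it were batched: the batch index is ignored
dims, strides = (1, 4, 4, 3), (48, 12, 3, 1)
print(nhwc_index(dims, strides, n=5, h=2, w=1, c=0))  # 27, independent of n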
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/vec3f.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | struct vec3f
15 | {
16 | float x, y, z;
17 |
18 | #ifdef __CUDACC__
19 | __device__ vec3f() { }
20 | __device__ vec3f(float v) { x = v; y = v; z = v; }
21 | __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
22 | __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }
23 |
24 | __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
25 | __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
26 | __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
27 | __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
28 | #endif
29 | };
30 |
31 | #ifdef __CUDACC__
32 | __device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
33 | __device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
34 | __device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
35 | __device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
36 | __device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }
37 |
38 | __device__ static inline float sum(vec3f a)
39 | {
40 | return a.x + a.y + a.z;
41 | }
42 |
43 | __device__ static inline vec3f cross(vec3f a, vec3f b)
44 | {
45 | vec3f out;
46 | out.x = a.y * b.z - a.z * b.y;
47 | out.y = a.z * b.x - a.x * b.z;
48 | out.z = a.x * b.y - a.y * b.x;
49 | return out;
50 | }
51 |
52 | __device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
53 | {
54 | d_a.x += d_out.z * b.y - d_out.y * b.z;
55 | d_a.y += d_out.x * b.z - d_out.z * b.x;
56 | d_a.z += d_out.y * b.x - d_out.x * b.y;
57 |
58 | d_b.x += d_out.y * a.z - d_out.z * a.y;
59 | d_b.y += d_out.z * a.x - d_out.x * a.z;
60 | d_b.z += d_out.x * a.y - d_out.y * a.x;
61 | }
62 |
63 | __device__ static inline float dot(vec3f a, vec3f b)
64 | {
65 | return a.x * b.x + a.y * b.y + a.z * b.z;
66 | }
67 |
68 | __device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
69 | {
70 | d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
71 | d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
72 | }
73 |
74 | __device__ static inline vec3f reflect(vec3f x, vec3f n)
75 | {
76 | return n * 2.0f * dot(n, x) - x;
77 | }
78 |
79 | __device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
80 | {
81 | d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
82 | d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
83 | d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
84 |
85 | d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
86 | d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
87 | d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
88 | }
89 |
90 | __device__ static inline vec3f safeNormalize(vec3f v)
91 | {
92 | float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
93 | return l > 0.0f ? (v / l) : vec3f(0.0f);
94 | }
95 |
96 | __device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
97 | {
98 |
99 | float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
100 | if (l > 0.0f)
101 | {
102 | float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
103 | d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
104 | d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
105 | d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
106 | }
107 | }
108 |
109 | #endif
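
The hand-written adjoints above (bwdDot, bwdReflect, bwdSafeNormalize) can be sanity-checked against autograd. A short PyTorch sketch for the safeNormalize case, restating the same vector-Jacobian product in double precision so the comparison is tight (the function name is mine):

import torch

def bwd_safe_normalize(v, d_out):
    # analytic VJP of v / ||v||, matching bwdSafeNormalize above
    fac = 1.0 / (v * v).sum().pow(1.5)
    x, y, z = v
    gx, gy, gz = d_out
    return torch.stack([
        (gx * (y * y + z * z) - gy * (x * y) - gz * (x * z)) * fac,
        (gy * (x * x + z * z) - gx * (y * x) - gz * (y * z)) * fac,
        (gz * (x * x + y * y) - gx * (z * x) - gy * (z * y)) * fac,
    ])

v = torch.randn(3, dtype=torch.float64, requires_grad=True)
d_out = torch.randn(3, dtype=torch.float64)
(v / v.norm()).backward(d_out)
print(torch.allclose(v.grad, bwd_safe_normalize(v.detach(), d_out)))  # expected: True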
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/c_src/vec4f.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5 | * property and proprietary rights in and to this material, related
6 | * documentation and any modifications thereto. Any use, reproduction,
7 | * disclosure or distribution of this material and related documentation
8 | * without an express license agreement from NVIDIA CORPORATION or
9 | * its affiliates is strictly prohibited.
10 | */
11 |
12 | #pragma once
13 |
14 | struct vec4f
15 | {
16 | float x, y, z, w;
17 |
18 | #ifdef __CUDACC__
19 | __device__ vec4f() { }
20 | __device__ vec4f(float v) { x = v; y = v; z = v; w = v; }
21 | __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }
22 | __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }
23 | #endif
24 | };
25 |
26 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import torch
11 |
12 | #----------------------------------------------------------------------------
13 | # HDR image losses
14 | #----------------------------------------------------------------------------
15 |
16 | def _tonemap_srgb(f):
17 | return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
18 |
19 | def _SMAPE(img, target, eps=0.01):
20 | nom = torch.abs(img - target)
21 | denom = torch.abs(img) + torch.abs(target) + 0.01
22 | return torch.mean(nom / denom)
23 |
24 | def _RELMSE(img, target, eps=0.1):
25 | nom = (img - target) * (img - target)
26 | denom = img * img + target * target + 0.1
27 | return torch.mean(nom / denom)
28 |
29 | def image_loss_fn(img, target, loss, tonemapper):
30 | if tonemapper == 'log_srgb':
31 | img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))
32 | target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1))
33 |
34 | if loss == 'mse':
35 | return torch.nn.functional.mse_loss(img, target)
36 | elif loss == 'smape':
37 | return _SMAPE(img, target)
38 | elif loss == 'relmse':
39 | return _RELMSE(img, target)
40 | else:
41 | return torch.nn.functional.l1_loss(img, target)
42 |
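
A quick usage sketch for image_loss_fn above (assumes the functions in this file are in scope; the random HDR-style inputs are only for illustration). Note that any tonemapper string other than 'log_srgb' leaves the images untouched, and any unrecognized loss name falls through to L1:

import torch

img = torch.rand(1, 64, 64, 3)
target = torch.rand(1, 64, 64, 3)
print(image_loss_fn(img, target, loss='smape', tonemapper='none'))    # SMAPE on raw values
print(image_loss_fn(img, target, loss='mse', tonemapper='log_srgb'))  # MSE after log + sRGB tonemap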
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/tests/test_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import torch
11 |
12 | import os
13 | import sys
14 | sys.path.insert(0, os.path.join(sys.path[0], '../..'))
15 | import renderutils as ru
16 |
17 | RES = 8
18 | DTYPE = torch.float32
19 |
20 | def tonemap_srgb(f):
21 | return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
22 |
23 | def l1(output, target):
24 | x = torch.clamp(output, min=0, max=65535)
25 | r = torch.clamp(target, min=0, max=65535)
26 | x = tonemap_srgb(torch.log(x + 1))
27 | r = tonemap_srgb(torch.log(r + 1))
28 | return torch.nn.functional.l1_loss(x,r)
29 |
30 | def relative_loss(name, ref, cuda):
31 | ref = ref.float()
32 | cuda = cuda.float()
33 | print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
34 |
35 | def test_loss(loss, tonemapper):
36 | img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
37 | img_ref = img_cuda.clone().detach().requires_grad_(True)
38 | target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
39 | target_ref = target_cuda.clone().detach().requires_grad_(True)
40 |
41 | ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True)
42 | ref_loss.backward()
43 |
44 | cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper)
45 | cuda_loss.backward()
46 |
47 | print("-------------------------------------------------------------")
48 | print(" Loss: %s, %s" % (loss, tonemapper))
49 | print("-------------------------------------------------------------")
50 |
51 | relative_loss("res:", ref_loss, cuda_loss)
52 | relative_loss("img:", img_ref.grad, img_cuda.grad)
53 | relative_loss("target:", target_ref.grad, target_cuda.grad)
54 |
55 |
56 | test_loss('l1', 'none')
57 | test_loss('l1', 'log_srgb')
58 | test_loss('mse', 'log_srgb')
59 | test_loss('smape', 'none')
60 | test_loss('relmse', 'none')
61 | test_loss('mse', 'none')
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/tests/test_mesh.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import torch
11 |
12 | import os
13 | import sys
14 | sys.path.insert(0, os.path.join(sys.path[0], '../..'))
15 | import renderutils as ru
16 |
17 | BATCH = 8
18 | RES = 1024
19 | DTYPE = torch.float32
20 |
21 | torch.manual_seed(0)
22 |
23 | def tonemap_srgb(f):
24 | return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
25 |
26 | def l1(output, target):
27 | x = torch.clamp(output, min=0, max=65535)
28 | r = torch.clamp(target, min=0, max=65535)
29 | x = tonemap_srgb(torch.log(x + 1))
30 | r = tonemap_srgb(torch.log(r + 1))
31 | return torch.nn.functional.l1_loss(x,r)
32 |
33 | def relative_loss(name, ref, cuda):
34 | ref = ref.float()
35 | cuda = cuda.float()
36 | print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item())
37 |
38 | def test_xfm_points():
39 | points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
40 | points_ref = points_cuda.clone().detach().requires_grad_(True)
41 | mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
42 | mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
43 | target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)
44 |
45 | ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True)
46 | ref_loss = torch.nn.MSELoss()(ref_out, target)
47 | ref_loss.backward()
48 |
49 | cuda_out = ru.xfm_points(points_cuda, mtx_cuda)
50 | cuda_loss = torch.nn.MSELoss()(cuda_out, target)
51 | cuda_loss.backward()
52 |
53 | print("-------------------------------------------------------------")
54 |
55 | relative_loss("res:", ref_out, cuda_out)
56 | relative_loss("points:", points_ref.grad, points_cuda.grad)
57 |
58 | def test_xfm_vectors():
59 | points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
60 | points_ref = points_cuda.clone().detach().requires_grad_(True)
61 | points_cuda_p = points_cuda.clone().detach().requires_grad_(True)
62 | points_ref_p = points_cuda.clone().detach().requires_grad_(True)
63 | mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)
64 | mtx_ref = mtx_cuda.clone().detach().requires_grad_(True)
65 | target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)
66 |
67 | ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True)
68 | ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3])
69 | ref_loss.backward()
70 |
71 | cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda)
72 | cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3])
73 | cuda_loss.backward()
74 |
75 | ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True)
76 | ref_loss_p = torch.nn.MSELoss()(ref_out_p, target)
77 | ref_loss_p.backward()
78 |
79 | cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda)
80 | cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target)
81 | cuda_loss_p.backward()
82 |
83 | print("-------------------------------------------------------------")
84 |
85 | relative_loss("res:", ref_out, cuda_out)
86 | relative_loss("points:", points_ref.grad, points_cuda.grad)
87 | relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad)
88 |
89 | test_xfm_points()
90 | test_xfm_vectors()
91 |
--------------------------------------------------------------------------------
/others/nvdiffrec/render/renderutils/tests/test_perf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4 | # property and proprietary rights in and to this material, related
5 | # documentation and any modifications thereto. Any use, reproduction,
6 | # disclosure or distribution of this material and related documentation
7 | # without an express license agreement from NVIDIA CORPORATION or
8 | # its affiliates is strictly prohibited.
9 |
10 | import torch
11 |
12 | import os
13 | import sys
14 | sys.path.insert(0, os.path.join(sys.path[0], '../..'))
15 | import renderutils as ru
16 |
17 | DTYPE=torch.float32
18 |
19 | def test_bsdf(BATCH, RES, ITR):
20 | kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
21 | kd_ref = kd_cuda.clone().detach().requires_grad_(True)
22 | arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
23 | arm_ref = arm_cuda.clone().detach().requires_grad_(True)
24 | pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
25 | pos_ref = pos_cuda.clone().detach().requires_grad_(True)
26 | nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
27 | nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
28 | view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
29 | view_ref = view_cuda.clone().detach().requires_grad_(True)
30 | light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
31 | light_ref = light_cuda.clone().detach().requires_grad_(True)
32 | target = torch.rand(BATCH, RES, RES, 3, device='cuda')
33 |
34 | start = torch.cuda.Event(enable_timing=True)
35 | end = torch.cuda.Event(enable_timing=True)
36 |
37 | ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
38 |
39 | print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES))
40 |
41 | start.record()
42 | for i in range(ITR):
43 | ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)
44 | end.record()
45 | torch.cuda.synchronize()
46 | print("Pbr BSDF python:", start.elapsed_time(end))
47 |
48 | start.record()
49 | for i in range(ITR):
50 | cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
51 | end.record()
52 | torch.cuda.synchronize()
53 | print("Pbr BSDF cuda:", start.elapsed_time(end))
54 |
55 | test_bsdf(1, 512, 1000)
56 | test_bsdf(16, 512, 1000)
57 | test_bsdf(1, 2048, 1000)
58 |
--------------------------------------------------------------------------------
/raymarching/__init__.py:
--------------------------------------------------------------------------------
1 | from .raymarching import *
--------------------------------------------------------------------------------
/raymarching/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21 | if paths:
22 | return paths[0]
23 |
24 | # If cl.exe is not on path, try to find it.
25 | if os.system("where cl.exe >nul 2>nul") != 0:
26 | cl_path = find_cl_path()
27 | if cl_path is None:
28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29 | os.environ["PATH"] += ";" + cl_path
30 |
31 | _backend = load(name='_raymarching',
32 | extra_cflags=c_flags,
33 | extra_cuda_cflags=nvcc_flags,
34 | sources=[os.path.join(_src_path, 'src', f) for f in [
35 | 'raymarching.cu',
36 | 'bindings.cpp',
37 | ]],
38 | )
39 |
40 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/raymarching/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | '''
33 | Usage:
34 |
35 | python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
36 |
37 | python setup.py install # build extensions and install (copy) to PATH.
38 | pip install . # ditto but better (e.g., dependency & metadata handling)
39 |
40 | python setup.py develop # build extensions and install (symbolic) to PATH.
41 | pip install -e . # ditto but better (e.g., dependency & metadata handling)
42 |
43 | '''
44 | setup(
45 | name='raymarching', # package name, import this to use python API
46 | ext_modules=[
47 | CUDAExtension(
48 | name='_raymarching', # extension name, import this to use CUDA API
49 | sources=[os.path.join(_src_path, 'src', f) for f in [
50 | 'raymarching.cu',
51 | 'bindings.cpp',
52 | ]],
53 | extra_compile_args={
54 | 'cxx': c_flags,
55 | 'nvcc': nvcc_flags,
56 | }
57 | ),
58 | ],
59 | cmdclass={
60 | 'build_ext': BuildExtension,
61 | }
62 | )
--------------------------------------------------------------------------------
/raymarching/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "raymarching.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | // utils
7 | m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)");
8 | m.def("packbits", &packbits, "packbits (CUDA)");
9 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
10 | m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
11 | m.def("morton3D", &morton3D, "morton3D (CUDA)");
12 | m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
13 | // train
14 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
15 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
16 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
17 | // infer
18 | m.def("march_rays", &march_rays, "march rays (CUDA)");
19 | m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
20 | }
--------------------------------------------------------------------------------
/raymarching/src/raymarching.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 |
7 | void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8 | void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9 | void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10 | void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11 | void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12 | void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res);
13 |
14 | void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises);
15 | void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
16 | void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
17 |
18 | void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises);
19 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
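
morton3D and morton3D_invert convert integer grid coordinates to and from Z-order (Morton) indices, which the occupancy-grid utilities use so that cells that are close in 3D stay close in the flattened grid. As a point of reference only (not the kernel's code), this is the standard 10-bit-per-axis interleaving in pure Python:

def expand_bits(v: int) -> int:
    # spread the low 10 bits of v so two zero bits sit between consecutive bits
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3d(x: int, y: int, z: int) -> int:
    # interleave the bits of x, y, z into a single Z-order index
    return (expand_bits(x) << 2) | (expand_bits(y) << 1) | expand_bits(z)

print(morton3d(1, 0, 0), morton3d(0, 1, 0), morton3d(0, 0, 1))  # 4 2 1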
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | rich
3 | ninja
4 | numpy
5 | pandas
6 | scipy
7 | scikit-learn
8 | matplotlib
9 | opencv-python
10 | imageio
11 | imageio-ffmpeg
12 |
13 | torch
14 | torch-ema
15 | einops
16 | tensorboard
17 | tensorboardX
18 |
19 | # for grid_tcnn
20 | # git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
21 |
22 | # for stable-diffusion
23 | huggingface_hub
24 | diffusers==0.20.0
25 | accelerate
26 | transformers
27 | git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
28 | triton
29 |
30 | # for dmtet and mesh export
31 | xatlas
32 | trimesh
33 | PyMCubes
34 | pymeshlab
35 | git+https://github.com/NVlabs/nvdiffrast/
36 | open3d
37 |
38 | # for zero123
39 | carvekit-colab
40 | omegaconf
41 | pytorch-lightning
42 | taming-transformers-rom1504
43 | kornia
44 | git+https://github.com/openai/CLIP.git
45 |
--------------------------------------------------------------------------------
/scripts/blender.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | # Simple script to show how to load our assets in Blender
11 | #
12 | # Open Blender 3.x, click the scripting tab and open this script
13 | # Run the script (Alt P)
14 | # Under the shading tab, you will see the shading network and environment probe (World node)
15 | # You can then render the model using the Cycles renderer
16 | import os
17 | import bpy
18 | import numpy as np
19 |
20 | # path to your mesh
21 | MESH_PATH = "path2mesh"
22 |
23 | RESOLUTION = 512
24 | SAMPLES = 64
25 |
26 | ################### Renderer settings ###################
27 | bpy.ops.file.pack_all()
28 | scene = bpy.context.scene
29 | scene.world.use_nodes = True
30 | scene.render.engine = 'CYCLES'
31 | scene.render.film_transparent = True
32 | scene.cycles.device = 'GPU'
33 | scene.cycles.samples = SAMPLES
34 | scene.cycles.max_bounces = 0
35 | scene.cycles.diffuse_bounces = 0
36 | scene.cycles.glossy_bounces = 0
37 | scene.cycles.transmission_bounces = 0
38 | scene.cycles.volume_bounces = 0
39 | scene.cycles.transparent_max_bounces = 8
40 | scene.cycles.use_denoising = True
41 | scene.render.resolution_x = RESOLUTION
42 | scene.render.resolution_y = RESOLUTION
43 | scene.render.resolution_percentage = 100
44 |
45 | ################### Image output ###################
46 |
47 | # PNG output with sRGB tonemapping
48 | scene.display_settings.display_device = 'None'
49 | scene.view_settings.view_transform = 'Standard'
50 | scene.view_settings.exposure = 0.0
51 | scene.view_settings.gamma = 1.0
52 | scene.render.image_settings.file_format = 'PNG' # set output format to .png
53 |
54 | # OpenEXR output, no tonemapping applied
55 | # scene.display_settings.display_device = 'None'
56 | # scene.view_settings.view_transform = 'Standard'
57 | # scene.view_settings.exposure = 0.0
58 | # scene.view_settings.gamma = 1.0
59 | # scene.render.image_settings.file_format = 'OPEN_EXR'
60 |
61 |
62 | ################### Import obj mesh ###################
63 |
64 | imported_object = bpy.ops.import_scene.obj(filepath=os.path.join(MESH_PATH, "mesh.obj"), axis_forward = '-Z', axis_up = 'Y')
65 | obj_object = bpy.context.selected_objects[0]
66 |
67 | ################### Fix material graph ###################
68 |
69 | # Get material node tree, find BSDF and specular texture
70 | material = obj_object.active_material
71 | bsdf = material.node_tree.nodes["Principled BSDF"]
72 | image_node_ks = bsdf.inputs["Specular"].links[0].from_node
73 |
74 | # Split the specular texture into metalness and roughness
75 | separate_node = material.node_tree.nodes.new(type="ShaderNodeSeparateRGB")
76 | separate_node.name="SeparateKs"
77 | material.node_tree.links.new(image_node_ks.outputs[0], separate_node.inputs[0])
78 | material.node_tree.links.new(separate_node.outputs[2], bsdf.inputs["Metallic"])
79 | material.node_tree.links.new(separate_node.outputs[1], bsdf.inputs["Roughness"])
80 |
81 | if True:
82 | normal_map_node = bsdf.inputs["Normal"].links[0].from_node
83 | texture_n_node = normal_map_node.inputs["Color"].links[0].from_node
84 | material.node_tree.links.remove(normal_map_node.inputs["Color"].links[0])
85 | normal_separate_node = material.node_tree.nodes.new(type="ShaderNodeSeparateRGB")
86 | normal_separate_node.name="SeparateNormal"
87 | normal_combine_node = material.node_tree.nodes.new(type="ShaderNodeCombineRGB")
88 | normal_combine_node.name="CombineNormal"
89 |
90 | normal_invert_node = material.node_tree.nodes.new(type="ShaderNodeMath")
91 | normal_invert_node.name="InvertNormal"
92 | normal_invert_node.operation='SUBTRACT'
93 | normal_invert_node.inputs[0].default_value = 1.0
94 |
95 | material.node_tree.links.new(texture_n_node.outputs[0], normal_separate_node.inputs['Image'])
96 | material.node_tree.links.new(normal_separate_node.outputs['R'], normal_combine_node.inputs['R'])
97 | material.node_tree.links.new(normal_separate_node.outputs['G'], normal_invert_node.inputs[1])
98 | material.node_tree.links.new(normal_invert_node.outputs[0], normal_combine_node.inputs['G'])
99 | material.node_tree.links.new(normal_separate_node.outputs['B'], normal_combine_node.inputs['B'])
100 | material.node_tree.links.new(normal_combine_node.outputs[0], normal_map_node.inputs["Color"])
101 |
102 | material.node_tree.links.remove(bsdf.inputs["Specular"].links[0])
103 |
104 | # Set default values
105 | bsdf.inputs["Specular"].default_value = 0.5
106 | bsdf.inputs["Specular Tint"].default_value = 0.0
107 | bsdf.inputs["Sheen Tint"].default_value = 0.0
108 | bsdf.inputs["Clearcoat Roughness"].default_value = 0.0
109 |
110 | ################### Load HDR probe ###################
111 |
112 | texcoord = scene.world.node_tree.nodes.new(type="ShaderNodeTexCoord")
113 | mapping = scene.world.node_tree.nodes.new(type="ShaderNodeMapping")
114 | mapping.inputs['Rotation'].default_value = [0, 0, -np.pi*0.5]
115 | envmap = scene.world.node_tree.nodes.new(type="ShaderNodeTexEnvironment")
116 | envmap.image = bpy.data.images.load(os.path.join(MESH_PATH, "probe.hdr"))
117 |
118 | scene.world.node_tree.links.new(envmap.outputs['Color'], scene.world.node_tree.nodes['Background'].inputs['Color'])
119 | scene.world.node_tree.links.new(texcoord.outputs['Generated'], mapping.inputs['Vector'])
120 | scene.world.node_tree.links.new(mapping.outputs['Vector'], envmap.inputs['Vector'])
121 |
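
Outside Blender, the two texture fix-ups that the node graph performs can be stated in a few lines of NumPy, which can help when debugging exported assets. This sketch assumes the packed specular texture follows the (occlusion, roughness, metallic) channel layout implied by the routing above (G to Roughness, B to Metallic) and that the normal map's green channel is flipped exactly as the InvertNormal node does:

import numpy as np

# packed specular texture as float [H, W, 3]; assumed layout: (occlusion, roughness, metallic)
ks = np.random.rand(8, 8, 3).astype(np.float32)
roughness, metallic = ks[..., 1], ks[..., 2]   # same routing as the SeparateKs node

# normal map: flip the green channel (1 - G), matching the InvertNormal math node
nrm = np.random.rand(8, 8, 3).astype(np.float32)
nrm_flipped = nrm.copy()
nrm_flipped[..., 1] = 1.0 - nrm[..., 1]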
--------------------------------------------------------------------------------
/scripts/install_ext.sh:
--------------------------------------------------------------------------------
1 | pip install ./raymarching
2 | pip install ./shencoder
3 | pip install ./freqencoder
4 | pip install ./gridencoder
--------------------------------------------------------------------------------
/scripts/linear_app.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import cv2
4 | import os
5 | from diffusers import (
6 | AutoencoderKL,
7 | StableDiffusionPipeline,
8 | )
9 | from tqdm import tqdm
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 |
13 | path = "path to images folder"
14 | imgs_path = os.listdir(path)
15 |
16 | imgs_128 = []
17 | imgs_1024 = []
18 | latents = []
19 |
20 | vae = AutoencoderKL.from_pretrained(
21 | "madebyollin/sdxl-vae-fp16-fix",
22 | # "pretrained/SDXL/vae_fp16",
23 | # local_files_only=True,
24 | use_safetensors=True,
25 |     torch_dtype=torch.float16  # 'precision_t' was undefined in this script; fp16 matches the fp16-fix VAE
26 | ).to('cuda')  # move the VAE to the GPU to match the CUDA image tensors created below
27 |
28 | def encode_imgs(imgs):
29 | # imgs: [B, 3, H, W]
30 | imgs = 2 * imgs - 1
31 | posterior = vae.encode(imgs).latent_dist
32 | latents = posterior.sample() * vae.config.scaling_factor
33 | return latents
34 |
35 |
36 | with torch.no_grad():
37 | for i in tqdm(range(1000)):
38 | img = cv2.imread(os.path.join(path, imgs_path[i]))
39 | img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB) / 255
40 | img = torch.tensor(img, dtype=torch.float32, device='cuda')[None, ...]
41 | img = img.permute(0, 3, 1, 2)
42 | img_1024 = F.interpolate(img, (1024, 1024), mode='bilinear', align_corners=False)
43 | img_128 = F.interpolate(img, (128, 128), mode='bilinear', align_corners=False)
44 | # imgs_1024.append(img_1024)
45 | imgs_128.append(img_128)
46 |
47 | with torch.cuda.amp.autocast(enabled=True):
48 | latent = encode_imgs(img_1024)
49 |             latents.append(latent.float())  # cast to fp32 (autocast can yield fp16 latents) for the linear solve below
50 | torch.cuda.empty_cache()
51 |
52 | data = torch.cat(imgs_128, dim=0).permute(0, 2, 3, 1).reshape(-1, 3) # [N, 3]
53 | target = torch.cat(latents, dim=0).permute(0, 2, 3, 1).reshape(-1, 4) # [N, 4]
54 | data_ = torch.cat([data, torch.ones(data.shape[0], 1).cuda()], dim=1)
55 | target_ = torch.cat([target, torch.ones(target.shape[0], 1).cuda()], dim=1)
56 |
57 | coef_r2l = torch.linalg.inv(data_.T @ data_ + 0.001 * torch.eye(4).cuda()) @ (data_.T @ target)
58 | coef_l2r = (torch.linalg.inv(target_.T @ target_ + 0.001 * torch.eye(5).cuda())) @ (target_.T @ data)
59 |
60 | def normalize(img):
61 | img = (img - img.min()) / (img.max() - img.min())
62 | return img
63 |
64 |
65 | save_path = "2d/vae_linear_approx/sdxl"
66 | os.makedirs(os.path.join(save_path, 'l2r'), exist_ok=True)
67 | os.makedirs(os.path.join(save_path, 'r2l'), exist_ok=True)
68 | for i in range(100):
69 | x = latents[i][0].permute(1, 2, 0) @ coef_l2r[:4, :] + coef_l2r[4, :]
70 | rgb_fit = (np.clip((x.detach().cpu().numpy()) * 255, 0, 255)).astype(np.uint8)
71 | rgb_gt = (imgs_128[i][0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
72 | # plt.imshow(rgb_fit);plt.show()
73 | # plt.imshow(imgs_128[i][0].permute(1, 2, 0).detach().cpu().numpy());plt.show()
74 |
75 | y = imgs_128[i][0].permute(1, 2, 0) @ coef_r2l[:3, :] + coef_r2l[3, :]
76 | latent_fit = (normalize(y.detach().cpu().numpy()) * 255).astype(np.uint8)
77 | latent_gt = (normalize(latents[i][0].permute(1, 2, 0).detach().cpu().numpy()) * 255).astype(np.uint8)
78 | # plt.imshow(normalize(y.detach().cpu().numpy()));plt.show()
79 | # plt.imshow(normalize(latents[i][0].permute(1, 2, 0).detach().cpu().numpy()));plt.show()
80 |
81 | cv2.imwrite(os.path.join(save_path, 'l2r', f'{i:04d}_fit.png'),
82 | cv2.cvtColor(rgb_fit, cv2.COLOR_RGB2BGR))
83 | cv2.imwrite(os.path.join(save_path, 'l2r', f'{i:04d}_gt.png'),
84 | cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2BGR))
85 | cv2.imwrite(os.path.join(save_path, 'r2l', f'{i:04d}_fit.png'),
86 | cv2.cvtColor(latent_fit, cv2.COLOR_RGB2BGR))
87 | cv2.imwrite(os.path.join(save_path, 'r2l', f'{i:04d}_gt.png'),
88 | cv2.cvtColor(latent_gt, cv2.COLOR_RGB2BGR))
89 |
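
The fitted affine maps can be reused after this script finishes, e.g. as a cheap latent-to-RGB preview during training. A minimal follow-up sketch; the file name and helper below are mine, not part of the script:

# persist the fitted maps alongside the comparison images
torch.save({'l2r': coef_l2r.cpu(), 'r2l': coef_r2l.cpu()}, os.path.join(save_path, 'vae_linear_approx.pth'))

def latent_to_rgb(latent, coef_l2r):
    # latent: [4, H, W] -> approximate RGB [H, W, 3]; the last row of coef_l2r is the bias term
    x = latent.permute(1, 2, 0) @ coef_l2r[:4, :] + coef_l2r[4, :]
    return x.clamp(0, 1)

preview = latent_to_rgb(latents[0][0], coef_l2r)  # [128, 128, 3] in [0, 1]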
--------------------------------------------------------------------------------
/shencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .sphere_harmonics import SHEncoder
--------------------------------------------------------------------------------
/shencoder/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.cpp_extension import load
3 |
4 | _src_path = os.path.dirname(os.path.abspath(__file__))
5 |
6 | nvcc_flags = [
7 | '-O3', '-std=c++14',
8 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9 | ]
10 |
11 | if os.name == "posix":
12 | c_flags = ['-O3', '-std=c++14']
13 | elif os.name == "nt":
14 | c_flags = ['/O2', '/std:c++17']
15 |
16 | # find cl.exe
17 | def find_cl_path():
18 | import glob
19 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21 | if paths:
22 | return paths[0]
23 |
24 | # If cl.exe is not on path, try to find it.
25 | if os.system("where cl.exe >nul 2>nul") != 0:
26 | cl_path = find_cl_path()
27 | if cl_path is None:
28 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29 | os.environ["PATH"] += ";" + cl_path
30 |
31 | _backend = load(name='_sh_encoder',
32 | extra_cflags=c_flags,
33 | extra_cuda_cflags=nvcc_flags,
34 | sources=[os.path.join(_src_path, 'src', f) for f in [
35 | 'shencoder.cu',
36 | 'bindings.cpp',
37 | ]],
38 | )
39 |
40 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/shencoder/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | '-O3', '-std=c++14',
9 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10 | ]
11 |
12 | if os.name == "posix":
13 | c_flags = ['-O3', '-std=c++14']
14 | elif os.name == "nt":
15 | c_flags = ['/O2', '/std:c++17']
16 |
17 | # find cl.exe
18 | def find_cl_path():
19 | import glob
20 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21 | paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22 | if paths:
23 | return paths[0]
24 |
25 | # If cl.exe is not on path, try to find it.
26 | if os.system("where cl.exe >nul 2>nul") != 0:
27 | cl_path = find_cl_path()
28 | if cl_path is None:
29 | raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30 | os.environ["PATH"] += ";" + cl_path
31 |
32 | setup(
33 | name='shencoder', # package name, import this to use python API
34 | ext_modules=[
35 | CUDAExtension(
36 | name='_shencoder', # extension name, import this to use CUDA API
37 | sources=[os.path.join(_src_path, 'src', f) for f in [
38 | 'shencoder.cu',
39 | 'bindings.cpp',
40 | ]],
41 | extra_compile_args={
42 | 'cxx': c_flags,
43 | 'nvcc': nvcc_flags,
44 | }
45 | ),
46 | ],
47 | cmdclass={
48 | 'build_ext': BuildExtension,
49 | }
50 | )
--------------------------------------------------------------------------------
/shencoder/sphere_harmonics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Function
6 | from torch.autograd.function import once_differentiable
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | try:
10 | import _shencoder as _backend
11 | except ImportError:
12 | from .backend import _backend
13 |
14 | class _sh_encoder(Function):
15 | @staticmethod
16 | @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
17 | def forward(ctx, inputs, degree, calc_grad_inputs=False):
18 | # inputs: [B, input_dim], float in [-1, 1]
19 | # RETURN: [B, F], float
20 |
21 | inputs = inputs.contiguous()
22 | B, input_dim = inputs.shape # batch size, coord dim
23 | output_dim = degree ** 2
24 |
25 | outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
26 |
27 | if calc_grad_inputs:
28 | dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
29 | else:
30 | dy_dx = None
31 |
32 | _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)
33 |
34 | ctx.save_for_backward(inputs, dy_dx)
35 | ctx.dims = [B, input_dim, degree]
36 |
37 | return outputs
38 |
39 | @staticmethod
40 | #@once_differentiable
41 | @custom_bwd
42 | def backward(ctx, grad):
43 | # grad: [B, C * C]
44 |
45 | inputs, dy_dx = ctx.saved_tensors
46 |
47 | if dy_dx is not None:
48 | grad = grad.contiguous()
49 | B, input_dim, degree = ctx.dims
50 | grad_inputs = torch.zeros_like(inputs)
51 | _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
52 | return grad_inputs, None, None
53 | else:
54 | return None, None, None
55 |
56 |
57 |
58 | sh_encode = _sh_encoder.apply
59 |
60 |
61 | class SHEncoder(nn.Module):
62 | def __init__(self, input_dim=3, degree=4):
63 | super().__init__()
64 |
65 | self.input_dim = input_dim # coord dims, must be 3
66 |         self.degree = degree # SH degree, 1 ~ 8 (see the assert below)
67 | self.output_dim = degree ** 2
68 |
69 |         assert self.input_dim == 3, "SH encoder only supports input dim == 3"
70 | assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"
71 |
72 | def __repr__(self):
73 | return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"
74 |
75 | def forward(self, inputs, size=1):
76 | # inputs: [..., input_dim], normalized real world positions in [-size, size]
77 | # return: [..., degree^2]
78 |
79 | inputs = inputs / size # [-1, 1]
80 |
81 | prefix_shape = list(inputs.shape[:-1])
82 | inputs = inputs.reshape(-1, self.input_dim)
83 |
84 | outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
85 | outputs = outputs.reshape(prefix_shape + [self.output_dim])
86 |
87 | return outputs
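
A minimal usage sketch for SHEncoder (assumes the _shencoder CUDA extension has been built, e.g. via scripts/install_ext.sh, and a GPU is available; the random directions are only for illustration):

import torch

enc = SHEncoder(input_dim=3, degree=4)
dirs = torch.nn.functional.normalize(torch.randn(8, 3, device='cuda'), dim=-1)  # unit vectors, so in [-1, 1]
feats = enc(dirs)   # [8, 16], since output_dim = degree ** 2
print(feats.shape)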
--------------------------------------------------------------------------------
/shencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "shencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
7 | m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
8 | }
--------------------------------------------------------------------------------
/shencoder/src/shencoder.h:
--------------------------------------------------------------------------------
1 | # pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 | // inputs: [B, D], float, in [-1, 1]
7 | // outputs: [B, F], float
8 |
9 | void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
10 | void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);
--------------------------------------------------------------------------------