├── .gitignore
├── LICENSE
├── README.md
├── confs
│   └── womask_iron.conf
├── create_env.sh
├── download_data.sh
├── evaluation
│   ├── eval_image_folder.py
│   └── eval_mesh.py
├── models
│   ├── dataset.py
│   ├── embedder.py
│   ├── export_materials.py
│   ├── export_mesh.py
│   ├── export_uv.py
│   ├── fields.py
│   ├── ggx
│   │   ├── ext_mts_rtrans_data.txt
│   │   └── int_mts_diff_rtrans_data.txt
│   ├── image_losses.py
│   ├── raytracer.py
│   ├── renderer.py
│   └── renderer_ggx.py
├── readme_resources
│   ├── assets_lowres.png
│   └── inputs_outputs.png
├── render_surface.py
├── render_synthetic_data
│   ├── render_rgb_flash_mat.py
│   └── rgb_flash_hdr_mat.xml
├── render_volume.py
├── singleview
│   ├── 12.png
│   └── cam_dict_norm.json
├── test_mitsuba
│   ├── render_rgb_envmap_mat.py
│   ├── render_rgb_flash_mat.py
│   ├── rgb_envmap_hdr_mat.xml
│   └── rgb_flash_hdr_mat.xml
├── tests
│   ├── data_singleview
│   │   ├── 12.png
│   │   └── cam_dict_norm.json
│   ├── test_raytracer.py
│   ├── test_singleview.py
│   └── test_viewsynthesis.py
└── train_scene.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *,cover
48 | .hypothesis/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # IPython Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # dotenv
81 | .env
82 |
83 | # virtualenv
84 | venv/
85 | ENV/
86 |
87 | # Spyder project settings
88 | .spyderproject
89 |
90 | # Rope project settings
91 | .ropeproject
92 | ### VirtualEnv template
93 | # Virtualenv
94 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
95 | .Python
96 | [Bb]in
97 | [Ii]nclude
98 | [Ll]ib
99 | [Ll]ib64
100 | [Ll]ocal
101 | [Ss]cripts
102 | pyvenv.cfg
103 | .venv
104 | pip-selfcheck.json
105 | ### JetBrains template
106 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
107 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
108 |
109 | # User-specific stuff:
110 | .idea/workspace.xml
111 | .idea/tasks.xml
112 | .idea/dictionaries
113 | .idea/vcs.xml
114 | .idea/jsLibraryMappings.xml
115 |
116 | # Sensitive or high-churn files:
117 | .idea/dataSources.ids
118 | .idea/dataSources.xml
119 | .idea/dataSources.local.xml
120 | .idea/sqlDataSources.xml
121 | .idea/dynamic.xml
122 | .idea/uiDesigner.xml
123 |
124 | # Gradle:
125 | .idea/gradle.xml
126 | .idea/libraries
127 |
128 | # Mongo Explorer plugin:
129 | .idea/mongoSettings.xml
130 |
131 | .idea/
132 |
133 | ## File-based project format:
134 | *.iws
135 |
136 | ## Plugin-specific files:
137 |
138 | # IntelliJ
139 | /out/
140 |
141 | # mpeltonen/sbt-idea plugin
142 | .idea_modules/
143 |
144 | # JIRA plugin
145 | atlassian-ide-plugin.xml
146 |
147 | # Crashlytics plugin (for Android Studio and IntelliJ)
148 | com_crashlytics_export_strings.xml
149 | crashlytics.properties
150 | crashlytics-build.properties
151 | fabric.properties
152 |
153 | data
154 | public_data
155 | exp
156 | tmp
157 |
158 | */debug_raytracer*
159 | */debug_singleview*
160 | */debug_multiview*
161 | */debug_inverse_rendering*
162 | */debug_viewsynthesis*
163 | */*/.DS_Store
164 | */*/test_buddha
165 | */*/test_kitty
166 | exp_iron*
167 | blender*
168 | data_flashlight*
169 | */*/*.log
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2022, Kai Zhang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # IRON: Inverse Rendering by Optimizing Neural SDFs and Materials from Photometric Images
2 |
3 | Note: this repo is still under construction.
4 |
5 | Project page:
6 |
7 | ![inputs and outputs](./readme_resources/inputs_outputs.png)
8 |
9 | ## Usage
10 |
11 | ### Create environment
12 |
13 | ```shell
14 | git clone https://github.com/Kai-46/iron.git && cd iron && . ./create_env.sh
15 | ```
16 |
17 | ### Download data
18 |
19 | ```shell
20 | . ./download_data.sh
21 | ```
22 |
23 | ### Training and testing
24 |
25 | ```shell
26 | . ./train_scene.sh drv/dragon
27 | ```
28 |
29 | Once training is done, you will find the recovered mesh and materials under the folder ```./exp_iron_stage2/drv/dragon/mesh_and_materials_50000/```, and the rendered test images under the folder ```./exp_iron_stage2/drv/dragon/render_test_50000/```.
30 |
31 | ### Relight the 3D assets using envmaps
32 |
33 | Check ```test_mitsuba/render_rgb_envmap_mat.py```.
34 |
35 | ### Evaluation
36 |
37 | Check ```evaluation/eval_mesh.py``` and ```evaluation/eval_image_folder.py```.
38 |
39 | ### Render synthetic data using Mitsuba
40 |
41 | Check ```render_synthetic_data/render_rgb_flash_mat.py```. To make the renderings shinier, try scaling up the specular albedo and scaling down the specular roughness; to make them more diffuse, do the opposite.
42 |
43 | ### Camera parameters convention
44 |
45 | We use the OpenCV camera convention, just like [NeRF++](https://github.com/Kai-46/nerfplusplus); you may want to use the camera visualization and debugging tools in that codebase to check whether there are any issues with the camera parameters. Note that we also assume the objects are inside the unit sphere.
46 |
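The sketch below (not part of the repo) shows what one entry of `singleview/cam_dict_norm.json` looks like, based on how `models/dataset.py` reads it: each image name maps to a 4x4 intrinsic matrix `K` and a 4x4 world-to-camera matrix `W2C`, both flattened row-major and expressed in the OpenCV convention (x right, y down, z forward). The focal length, principal point, and camera distance below are made-up values.

```python
# Hypothetical cam_dict_norm.json entry; models/dataset.py reshapes both "K" and
# "W2C" back to 4x4 via np.array(...).reshape((4, 4)).
import json
import numpy as np

K = np.array([[800.0,   0.0, 256.0, 0.0],   # fx,  0, cx
              [  0.0, 800.0, 256.0, 0.0],   #  0, fy, cy
              [  0.0,   0.0,   1.0, 0.0],
              [  0.0,   0.0,   0.0, 1.0]])
W2C = np.eye(4)        # world-to-camera extrinsics (OpenCV: x right, y down, z forward)
W2C[2, 3] = 3.0        # camera 3 units from the origin, object inside the unit sphere

cam_dict = {"12.png": {"K": K.flatten().tolist(), "W2C": W2C.flatten().tolist()}}
with open("cam_dict_norm.json", "w") as fp:
    json.dump(cam_dict, fp, indent=2)
```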
47 | ## Citations
48 |
49 | ```
50 | @inproceedings{iron-2022,
51 | title={IRON: Inverse Rendering by Optimizing Neural SDFs and Materials from Photometric Images},
52 | author={Zhang, Kai and Luan, Fujun and Li, Zhengqi and Snavely, Noah},
53 | booktitle={IEEE Conf. Comput. Vis. Pattern Recog.},
54 | year={2022}
55 | }
56 | ```
57 |
58 | ## Example results
59 |
60 |
61 |
62 | ![example results](./readme_resources/assets_lowres.png)
63 |
64 | ## Acknowledgements
65 |
66 | We would like to thank the authors of [IDR](https://github.com/lioryariv/idr) and [NeuS](https://github.com/Totoro97/NeuS) for open-sourcing their projects.
67 |
--------------------------------------------------------------------------------
/confs/womask_iron.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = ./exp_iron_stage1/CASE_NAME/
3 | recording = [
4 | ./,
5 | ./models
6 | ]
7 | }
8 |
9 | dataset {
10 | data_dir = ./data_flashlight/CASE_NAME/train/
11 | render_cameras_name = cameras_sphere.npz
12 | object_cameras_name = cameras_sphere.npz
13 | }
14 |
15 | train {
16 | learning_rate = 5e-4
17 | learning_rate_alpha = 0.05
18 | end_iter = 100001
19 |
20 | batch_size = 512
21 | validate_resolution_level = 4
22 | warm_up_end = 5000
23 | anneal_end = 50000
24 | use_white_bkgd = False
25 |
26 | save_freq = 10000
27 | val_freq = 2500
28 | val_mesh_freq = 5000
29 | report_freq = 100
30 |
31 | igr_weight = 0.1
32 | mask_weight = 0.0
33 | }
34 |
35 | model {
36 | nerf {
37 | D = 8,
38 | d_in = 4,
39 | d_in_view = 3,
40 | W = 256,
41 | multires = 10,
42 | multires_view = 4,
43 | output_ch = 4,
44 | skips=[4],
45 | use_viewdirs=True
46 | }
47 |
48 | sdf_network {
49 | d_out = 257
50 | d_in = 3
51 | d_hidden = 256
52 | n_layers = 8
53 | skip_in = [4]
54 | multires = 6
55 | bias = 0.5
56 | scale = 1.0
57 | geometric_init = True
58 | weight_norm = True
59 | }
60 |
61 | variance_network {
62 | init_val = 0.3
63 | }
64 |
65 | rendering_network {
66 | d_feature = 256
67 | mode = idr
68 | d_in = 9
69 | d_out = 3
70 | d_hidden = 256
71 | n_layers = 8
72 | skip_in = [4]
73 | weight_norm = True
74 | multires = 10
75 | multires_view = 4
76 | squeeze_out = True
77 | }
78 |
79 | neus_renderer {
80 | n_samples = 64
81 | n_importance = 64
82 | n_outside = 32
83 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling
84 | perturb = 1.0
85 | }
86 | }
87 |
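Not part of the repo: a minimal sketch of how a HOCON config like this is typically consumed, following the NeuS-style convention that this file mirrors, where the `CASE_NAME` placeholder is substituted with the scene name before parsing. The scene name below is just the example from the README.

```python
# Sketch (assumption): substitute CASE_NAME, then parse with pyhocon
# (pyhocon is installed by create_env.sh).
from pyhocon import ConfigFactory

case = "drv/dragon"  # example scene name from the README
with open("./confs/womask_iron.conf") as f:
    conf_text = f.read().replace("CASE_NAME", case)
conf = ConfigFactory.parse_string(conf_text)

print(conf.get_string("general.base_exp_dir"))  # ./exp_iron_stage1/drv/dragon/
print(conf.get_int("train.batch_size"))         # 512
```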
--------------------------------------------------------------------------------
/create_env.sh:
--------------------------------------------------------------------------------
1 | conda create -y -n iron python=3.8 && conda activate iron
2 | pip install numpy scipy trimesh opencv_python scikit-image imageio imageio-ffmpeg pyhocon PyMCubes tqdm icecream configargparse
3 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
4 | pip install tensorboard kornia
5 | conda install -c conda-forge igl
6 |
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | pip install gdown
2 |
3 | echo "Downloading image indices for Bi et al 2020: Deep Reflectance Volumes: Relightable Reconstructions from Multi-View Photometric Images"
4 | echo "Please ask the authors of this work for data, and then split the data using the image indices"
5 | gdown 1BThZgEnHgsL7dgyVTQuSFYZjAkZzQozx
6 | unzip "Bi et al 2020-image_indices.zip"
7 |
8 | echo "Downloading real data captured by Luan et al 2021: Unified Shape and SVBRDF Recovery using Differentiable Monte Carlo Rendering"
9 | echo "Please credit the original paper if you use this data"
10 | gdown 1BO6XZjUm8PhHof5RZ7O0Y3C815loBlqj
11 | unzip "Luan et al 2021.zip"
12 |
13 | echo "Downloading synthetic assets for creating synthetic data with Mitsuba"
14 | gdown 1EhDI06NsluXsC98ZErvB7UN_TPI1_6sn
15 | unzip "synthetic_assets.zip"
16 |
--------------------------------------------------------------------------------
/evaluation/eval_image_folder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import os
4 | from skimage.metrics import structural_similarity
5 | import lpips
6 | import torch
7 | import glob
8 |
9 |
10 | def skimage_ssim(pred_im, trgt_im):
11 | ssim = 0.
12 | for ch in range(3):
13 | ssim += structural_similarity(trgt_im[:, :, ch], pred_im[:, :, ch],
14 | data_range=1.0, win_size=11, sigma=1.5,
15 | use_sample_covariance=False, k1=0.01, k2=0.03)
16 | ssim /= 3.
17 | return ssim
18 |
19 | def read_image(fpath):
20 | return imageio.imread(fpath).astype(np.float32) / 255.
21 |
22 | mse2psnr = lambda x: -10. * np.log(x+1e-10) / np.log(10.)
23 |
24 | import sys
25 | folder = sys.argv[1]
26 |
27 | all_psnr = []
28 | all_ssim = []
29 | all_lpips = []
30 |
31 | loss_fn_alex = lpips.LPIPS(net='alex').cuda() # best forward scores
32 | # loss_fn_vgg = lpips.LPIPS(net='vgg') # closer to "traditional" perceptual loss, when used for optimization
33 |
34 | with open(os.path.join(folder, '../metrics.txt'), 'w') as fp:
35 | fp.write('img_name\tpsnr\tssim\tlpips\n')
36 | for _, fpath in enumerate(glob.glob(os.path.join(folder, '*_truth.png'))):
37 | name = os.path.basename(fpath)
38 | idx = name.find('_')
39 | idx = int(name[:idx])
40 |
41 | pred_im = read_image(os.path.join(folder, '{}_prediction.png'.format(idx)))
42 | trgt_im = read_image(os.path.join(folder, '{}_truth.png'.format(idx)))
43 |
44 | psnr = mse2psnr(np.mean((pred_im - trgt_im) ** 2))
45 |
46 | ssim = skimage_ssim(trgt_im, pred_im)
47 |
48 | pred_im = torch.from_numpy(pred_im).permute(2, 0, 1).unsqueeze(0) * 2. - 1.
49 | trgt_im = torch.from_numpy(trgt_im).permute(2, 0, 1).unsqueeze(0) * 2. - 1.
50 | d = loss_fn_alex(trgt_im.cuda(), pred_im.cuda()).item()
51 |
52 | fp.write('{}_prediction.png\t{:.3f}\t{:.3f}\t{:.4f}\n'.format(idx, psnr, ssim, d))
53 |
54 | all_psnr.append(psnr)
55 | all_ssim.append(ssim)
56 | all_lpips.append(d)
57 | fp.write('\nAverage\t{:.3f}\t{:.3f}\t{:.4f}\n'.format(np.mean(all_psnr), np.mean(all_ssim), np.mean(all_lpips)))
58 |
59 |
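A usage sketch (not from the repo): the script takes a single argument, a folder holding `<idx>_truth.png` / `<idx>_prediction.png` pairs such as the `render_test_*` folder mentioned in the README, and writes `metrics.txt` one level above that folder; a CUDA GPU is needed because LPIPS is moved to the GPU. The path below is an assumption based on the README's training example.

```python
# Hypothetical invocation of the evaluation script.
import subprocess

subprocess.run(
    ["python", "evaluation/eval_image_folder.py",
     "./exp_iron_stage2/drv/dragon/render_test_50000"],
    check=True,
)
# Per-image PSNR / SSIM / LPIPS plus their averages end up in
# ./exp_iron_stage2/drv/dragon/metrics.txt
```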
--------------------------------------------------------------------------------
/evaluation/eval_mesh.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import igl
4 |
5 |
6 | def cal_mesh_err(va, fa, vb, fb):
7 | sqrD1, _, _ = igl.point_mesh_squared_distance(va, vb, fb)
8 | sqrD2, _, _ = igl.point_mesh_squared_distance(vb, va, fa)
9 | D1 = np.sqrt(sqrD1)
10 | D2 = np.sqrt(sqrD2)
11 | ret = (D1.mean() + D2.mean()) * 0.5
12 | return ret
13 |
14 |
15 | def eval_obj_meshes(pred_mesh_fpath, trgt_mesh_fpath):
16 | v1, _, n1, f1, _, _ = igl.read_obj(pred_mesh_fpath)
17 | v4, _, n4, f4, _, _ = igl.read_obj(trgt_mesh_fpath)
18 |
19 | return cal_mesh_err(v1, f1, v4, f4)
20 |
21 |
22 | import sys
23 | pred_mesh_fpath = sys.argv[1]
24 | trgt_mesh_fpath = sys.argv[2]
25 | dist_bidirectional = eval_obj_meshes(pred_mesh_fpath, trgt_mesh_fpath)
26 | print('\tChamfer_dist: ', dist_bidirectional)
27 |
--------------------------------------------------------------------------------
/models/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import cv2 as cv
4 | import numpy as np
5 | import os
6 | from glob import glob
7 | from icecream import ic
8 | from scipy.spatial.transform import Rotation as Rot
9 | from scipy.spatial.transform import Slerp
10 | import traceback
11 |
12 |
13 | # This function is borrowed from IDR: https://github.com/lioryariv/idr
14 | def load_K_Rt_from_P(filename, P=None):
15 | if P is None:
16 | lines = open(filename).read().splitlines()
17 | if len(lines) == 4:
18 | lines = lines[1:]
19 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
20 | P = np.asarray(lines).astype(np.float32).squeeze()
21 |
22 | out = cv.decomposeProjectionMatrix(P)
23 | K = out[0]
24 | R = out[1]
25 | t = out[2]
26 |
27 | K = K / K[2, 2]
28 | intrinsics = np.eye(4)
29 | intrinsics[:3, :3] = K
30 |
31 | pose = np.eye(4, dtype=np.float32)
32 | pose[:3, :3] = R.transpose()
33 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
34 |
35 | return intrinsics, pose
36 |
37 |
38 | class Dataset:
39 | def __init__(self, conf):
40 | super(Dataset, self).__init__()
41 | print("Load data: Begin")
42 | self.device = torch.device("cuda")
43 | self.conf = conf
44 |
45 | self.data_dir = conf.get_string("data_dir")
46 | self.render_cameras_name = conf.get_string("render_cameras_name")
47 | self.object_cameras_name = conf.get_string("object_cameras_name")
48 |
49 | self.camera_outside_sphere = conf.get_bool("camera_outside_sphere", default=True)
50 | # self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1) # not used
51 |
52 | import json
53 |
54 | camera_dict = json.load(open(os.path.join(self.data_dir, "cam_dict_norm.json")))
55 | for x in list(camera_dict.keys()):
56 | x = x[:-4] + ".png"
57 | camera_dict[x]["K"] = np.array(camera_dict[x]["K"]).reshape((4, 4))
58 | camera_dict[x]["W2C"] = np.array(camera_dict[x]["W2C"]).reshape((4, 4))
59 |
60 | self.camera_dict = camera_dict
61 |
62 | try:
63 | self.images_lis = sorted(glob(os.path.join(self.data_dir, "image/*.png")))
64 | self.n_images = len(self.images_lis)
65 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 255.0
66 | except:
67 | # traceback.print_exc()
68 |
69 | print("Loading png images failed; try loading exr images")
70 | import pyexr
71 |
72 | self.images_lis = sorted(glob(os.path.join(self.data_dir, "image/*.exr")))
73 | self.n_images = len(self.images_lis)
74 | self.images_np = np.clip(
75 | np.power(np.stack([pyexr.open(im_name).get()[:, :, ::-1] for im_name in self.images_lis]), 1.0 / 2.2),
76 | 0.0,
77 | 1.0,
78 | )
79 |
80 | no_mask = True
81 | if no_mask:
82 | print("Not using masks")
83 | self.masks_lis = None
84 | self.masks_np = np.ones_like(self.images_np)
85 | else:
86 | try:
87 | self.masks_lis = sorted(glob(os.path.join(self.data_dir, "mask/*.png")))
88 | self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 255.0
89 | except:
90 | # traceback.print_exc()
91 |
92 | print("Loading mask images failed; try not using masks")
93 | self.masks_lis = None
94 | self.masks_np = np.ones_like(self.images_np)
95 |
96 | self.images_np = self.images_np[..., :3]
97 | self.masks_np = self.masks_np[..., :3]
98 |
99 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
100 | self.scale_mats_np = [np.eye(4).astype(np.float32) for idx in range(self.n_images)]
101 |
102 | self.intrinsics_all = []
103 | self.pose_all = []
104 | self.world_mats_np = []
105 | for x in self.images_lis:
106 | x = os.path.basename(x)[:-4] + ".png"
107 | K = self.camera_dict[x]["K"].astype(np.float32)
108 | W2C = self.camera_dict[x]["W2C"].astype(np.float32)
109 | C2W = np.linalg.inv(self.camera_dict[x]["W2C"]).astype(np.float32)
110 | self.intrinsics_all.append(torch.from_numpy(K))
111 | self.pose_all.append(torch.from_numpy(C2W))
112 | self.world_mats_np.append(W2C)
113 |
114 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3]
115 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu() # [n_images, H, W, 3]
116 | print("image shape, mask shape: ", self.images.shape, self.masks.shape)
117 | print("image pixel range: ", self.images.min().item(), self.images.max().item())
118 |
119 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4]
120 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4]
121 | self.focal = self.intrinsics_all[0][0, 0]
122 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4]
123 | self.H, self.W = self.images.shape[1], self.images.shape[2]
124 | self.image_pixels = self.H * self.W
125 |
126 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
127 | object_bbox_max = np.array([1.01, 1.01, 1.01, 1.0])
128 | # Object scale mat: region of interest to **extract mesh**
129 | object_scale_mat = np.eye(4).astype(np.float32)
130 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
131 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
132 | self.object_bbox_min = object_bbox_min[:3, 0]
133 | self.object_bbox_max = object_bbox_max[:3, 0]
134 |
135 | print("Load data: End")
136 |
137 | def gen_rays_at(self, img_idx, resolution_level=1):
138 | """
139 | Generate rays at world space from one camera.
140 | """
141 | l = resolution_level
142 | tx = torch.linspace(0, self.W - 1, self.W // l)
143 | ty = torch.linspace(0, self.H - 1, self.H // l)
144 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
145 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
146 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
147 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
148 | rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
149 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
150 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
151 |
152 | def gen_random_rays_at(self, img_idx, batch_size):
153 | """
154 | Generate random rays at world space from one camera.
155 | """
156 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size])
157 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size])
158 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
159 | mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
160 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
161 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3
162 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
163 | rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3
164 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3
165 | return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda() # batch_size, 10
166 |
167 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
168 | """
169 | Interpolate pose between two cameras.
170 | """
171 | l = resolution_level
172 | tx = torch.linspace(0, self.W - 1, self.W // l)
173 | ty = torch.linspace(0, self.H - 1, self.H // l)
174 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
175 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
176 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
177 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
178 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
179 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
180 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
181 | pose_0 = np.linalg.inv(pose_0)
182 | pose_1 = np.linalg.inv(pose_1)
183 | rot_0 = pose_0[:3, :3]
184 | rot_1 = pose_1[:3, :3]
185 | rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
186 | key_times = [0, 1]
187 | slerp = Slerp(key_times, rots)
188 | rot = slerp(ratio)
189 | pose = np.diag([1.0, 1.0, 1.0, 1.0])
190 | pose = pose.astype(np.float32)
191 | pose[:3, :3] = rot.as_matrix()
192 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
193 | pose = np.linalg.inv(pose)
194 | rot = torch.from_numpy(pose[:3, :3]).cuda()
195 | trans = torch.from_numpy(pose[:3, 3]).cuda()
196 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
197 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
198 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
199 |
200 | def near_far_from_sphere(self, rays_o, rays_d):
201 | a = torch.sum(rays_d**2, dim=-1, keepdim=True)
202 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
203 | mid = 0.5 * (-b) / a
204 | near = mid - 1.0
205 | far = mid + 1.0
206 | return near, far
207 |
208 | def image_at(self, idx, resolution_level):
209 | if self.images_lis[idx].endswith(".exr"):
210 | import pyexr
211 |
212 | img = np.power(pyexr.open(self.images_lis[idx]).get()[:, :, ::-1], 1.0 / 2.2) * 255.0
213 | else:
214 | img = cv.imread(self.images_lis[idx])
215 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255).astype(np.uint8)
216 |
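Not part of the repo: a small sketch of constructing `Dataset` from the `dataset` block of `confs/womask_iron.conf` and drawing a training batch; it assumes the scene data (an `image/` folder plus `cam_dict_norm.json`) has been downloaded to the configured `data_dir` and that a CUDA device is available.

```python
# Sketch under the assumptions above; the scene name is the README example.
from pyhocon import ConfigFactory
from models.dataset import Dataset

conf_text = open("./confs/womask_iron.conf").read().replace("CASE_NAME", "drv/dragon")
conf = ConfigFactory.parse_string(conf_text)

dataset = Dataset(conf["dataset"])
rays = dataset.gen_random_rays_at(img_idx=0, batch_size=512)   # [512, 10] = origin, direction, RGB, mask
near, far = dataset.near_far_from_sphere(rays[:, :3], rays[:, 3:6])
```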
--------------------------------------------------------------------------------
/models/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
6 | class Embedder:
7 | def __init__(self, **kwargs):
8 | self.kwargs = kwargs
9 | self.create_embedding_fn()
10 |
11 | def create_embedding_fn(self):
12 | embed_fns = []
13 | d = self.kwargs["input_dims"]
14 | out_dim = 0
15 | if self.kwargs["include_input"]:
16 | embed_fns.append(lambda x: x)
17 | out_dim += d
18 |
19 | max_freq = self.kwargs["max_freq_log2"]
20 | N_freqs = self.kwargs["num_freqs"]
21 |
22 | if self.kwargs["log_sampling"]:
23 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
24 | else:
25 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
26 |
27 | for freq in freq_bands:
28 | for p_fn in self.kwargs["periodic_fns"]:
29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 |
39 | def get_embedder(multires, input_dims=3):
40 | embed_kwargs = {
41 | "include_input": True,
42 | "input_dims": input_dims,
43 | "max_freq_log2": multires - 1,
44 | "num_freqs": multires,
45 | "log_sampling": True,
46 | "periodic_fns": [torch.sin, torch.cos],
47 | }
48 |
49 | embedder_obj = Embedder(**embed_kwargs)
50 |
51 | def embed(x, eo=embedder_obj):
52 | return eo.embed(x)
53 |
54 | return embed, embedder_obj.out_dim
55 |
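Not from the repo: a short worked example of `get_embedder`. With `include_input=True`, a `d`-dimensional input is mapped to `d + 2 * d * multires` features (the identity plus sin and cos at `multires` frequencies), so the SDF network's `multires = 6` on 3-D points gives 39 features and the rendering network's `multires = 10` gives 63.

```python
# Worked example: multires=10 on 3-D points -> 3 + 2*3*10 = 63 features.
import torch
from models.embedder import get_embedder

embed_fn, out_dim = get_embedder(multires=10, input_dims=3)
print(out_dim)                # 63

pts = torch.rand(4, 3)        # a few random 3-D points
print(embed_fn(pts).shape)    # torch.Size([4, 63])
```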
--------------------------------------------------------------------------------
/models/export_materials.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import igl
4 | import trimesh
5 | import os
6 | import shutil
7 | import torch
8 |
9 |
10 | to8b = lambda x: np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8)
11 |
12 |
13 | def sample_surface(vertices, face_vertices, texturecoords, face_texturecoords, n_samples):
14 | """
15 | Samples a point cloud on the surface of the model defined by vertices and
16 | faces. This function uses vectorized operations, so it is fast at the cost of
17 | some memory.
18 | """
19 | vec_cross = np.cross(
20 | vertices[face_vertices[:, 0], :] - vertices[face_vertices[:, 2], :],
21 | vertices[face_vertices[:, 1], :] - vertices[face_vertices[:, 2], :],
22 | )
23 | face_areas = np.sqrt(np.sum(vec_cross**2, 1))
24 | face_areas = face_areas / np.sum(face_areas)
25 |
26 | # Sample exactly n_samples. First, oversample points and remove redundant
27 | # Error fix by Yangyan (yangyan.lee@gmail.com) 2017-Aug-7
28 | n_samples_per_face = np.ceil(n_samples * face_areas).astype(int)
29 | floor_num = np.sum(n_samples_per_face) - n_samples
30 | if floor_num > 0:
31 | indices = np.where(n_samples_per_face > 0)[0]
32 | floor_indices = np.random.choice(indices, floor_num, replace=True)
33 | n_samples_per_face[floor_indices] -= 1
34 |
35 | n_samples = np.sum(n_samples_per_face)
36 |
37 | # Create a vector that contains the face indices
38 | sample_face_idx = np.zeros((n_samples,), dtype=int)
39 | acc = 0
40 | for face_idx, _n_sample in enumerate(n_samples_per_face):
41 | sample_face_idx[acc : acc + _n_sample] = face_idx
42 | acc += _n_sample
43 |
44 | r = np.random.rand(n_samples, 2)
45 |
46 | A = vertices[face_vertices[sample_face_idx, 0], :]
47 | B = vertices[face_vertices[sample_face_idx, 1], :]
48 | C = vertices[face_vertices[sample_face_idx, 2], :]
49 | P = (1 - np.sqrt(r[:, 0:1])) * A + np.sqrt(r[:, 0:1]) * (1 - r[:, 1:]) * B + np.sqrt(r[:, 0:1]) * r[:, 1:] * C
50 |
51 | A = texturecoords[face_texturecoords[sample_face_idx, 0], :]
52 | B = texturecoords[face_texturecoords[sample_face_idx, 1], :]
53 | C = texturecoords[face_texturecoords[sample_face_idx, 2], :]
54 | P_uv = (1 - np.sqrt(r[:, 0:1])) * A + np.sqrt(r[:, 0:1]) * (1 - r[:, 1:]) * B + np.sqrt(r[:, 0:1]) * r[:, 1:] * C
55 |
56 | return P.astype(np.float32), P_uv.astype(np.float32)
57 |
58 |
59 | class Groupby(object):
60 | def __init__(self, keys):
61 | """note keys are assumed to by integer"""
62 | super().__init__()
63 |
64 | self.unique_keys, self.keys_as_int = np.unique(keys, return_inverse=True)
65 | self.n_keys = len(self.unique_keys)
66 | self.indices = [[] for i in range(self.n_keys)]
67 | for i, k in enumerate(self.keys_as_int):
68 | self.indices[k].append(i)
69 | self.indices = [np.array(elt) for elt in self.indices]
70 |
71 | def apply(self, function, vector):
72 | assert len(vector.shape) <= 2
73 | if len(vector.shape) == 2:
74 | result = np.zeros((self.n_keys, vector.shape[-1]))
75 | else:
76 | result = np.zeros((self.n_keys,))
77 |
78 | for k, idx in enumerate(self.indices):
79 | result[k] = function(vector[idx], axis=0)
80 |
81 | return result
82 |
83 |
84 | def accumulate_splat_material(xyz_image, material_image, weight_image, pcd, uv, material):
85 | H, W = material_image.shape[:2]
86 |
87 | xyz_image = xyz_image.reshape((H * W, -1))
88 | material_image = material_image.reshape((H * W, -1))
89 | weight_image = weight_image.reshape((H * W,))
90 |
91 | ### label each 3d point with their splat pixel index
92 | uv[:, 0] = uv[:, 0] * W
93 | uv[:, 1] = H - uv[:, 1] * H
94 |
95 | ### repeat to a neighborhood
96 | pcd = np.tile(pcd, (5, 1))
97 | material = np.tile(material, (5, 1))
98 | uv_up = np.copy(uv)
99 | uv_up[:, 1] -= 1
100 | uv_right = np.copy(uv)
101 | uv_right[:, 0] += 1
102 | uv_down = np.copy(uv)
103 | uv_down[:, 1] += 1
104 | uv_left = np.copy(uv)
105 | uv_left[:, 0] -= 1
106 | uv = np.concatenate((uv, uv_up, uv_right, uv_down, uv_left), axis=0)
107 |
108 | ### compute pixel coordinates
109 | pixel_col = np.floor(uv[:, 0])
110 | pixel_row = np.floor(uv[:, 1])
111 | label = (pixel_row * W + pixel_col).astype(int)
112 |
113 | ### filter out-of-range points
114 | mask = np.logical_and(label >= 0, label < H * W)
115 | label = label[mask]
116 | uv = uv[mask]
117 | material = material[mask]
118 | pcd = pcd[mask]
119 | pixel_col = pixel_col[mask]
120 | pixel_row = pixel_row[mask]
121 |
122 | # compute gaussian weight
123 | sigma = 1.0
124 | weight = np.exp(-((uv[:, 0] - pixel_col - 0.5) ** 2 + (uv[:, 1] - pixel_row - 0.5) ** 2) / (2 * sigma * sigma))
125 | # weight = np.ones_like(uv[:, 0])
126 |
127 | groupby_obj = Groupby(label)
128 | delta_xyz = groupby_obj.apply(np.sum, weight[:, np.newaxis] * pcd)
129 | delta_material = groupby_obj.apply(np.sum, weight[:, np.newaxis] * material)
130 | delta_weight = groupby_obj.apply(np.sum, weight)
131 |
132 | xyz_image[groupby_obj.unique_keys] += delta_xyz
133 | material_image[groupby_obj.unique_keys] += delta_material
134 | weight_image[groupby_obj.unique_keys] += delta_weight
135 |
136 | xyz_image = xyz_image.reshape((H, W, -1))
137 | material_image = material_image.reshape((H, W, -1))
138 | weight_image = weight_image.reshape((H, W))
139 |
140 | return xyz_image, material_image, weight_image
141 |
142 |
143 | def loadmesh_and_checkuv(obj_fpath, out_dir):
144 | os.makedirs(out_dir, exist_ok=True)
145 |
146 | vertices, texturecoords, _, face_vertices, face_texturecoords, _ = igl.read_obj(obj_fpath, dtype="float32")
147 |
148 | def make_rgba_color(float_rgb):
149 | float_rgba = np.concatenate((float_rgb, np.ones_like(float_rgb[:, 0:1])), axis=-1)
150 | return np.uint8(np.clip(float_rgba * 255.0, 0.0, 255.0))
151 |
152 | #### create debug plot
153 | pcd, pcd_uv = sample_surface(vertices, face_vertices, texturecoords, face_texturecoords, n_samples=10**6)
154 |
155 | uv_color = np.concatenate((pcd_uv, np.zeros_like(pcd_uv[:, 0:1])), axis=-1)
156 | trimesh.PointCloud(vertices=pcd, colors=make_rgba_color(uv_color)).export(os.path.join(out_dir, "check_uvmap.ply"))
157 | W, H = 512, 512
158 | grid_w, grid_h = np.meshgrid(np.linspace(0.0, 1.0, W), np.linspace(1, 0.0, H))
159 | grid_color = np.stack((grid_w, grid_h, np.zeros_like(grid_w)), axis=2)
160 | imageio.imwrite(os.path.join(out_dir, "check_uvmap.png"), to8b(grid_color))
161 |
162 | return vertices, face_vertices, texturecoords, face_texturecoords
163 |
164 |
165 | def export_materials(mesh_fpath, material_predictor, out_dir, max_num_pts=320000, texture_H=2048, texture_W=2048):
166 | """output material parameters"""
167 | os.makedirs(out_dir, exist_ok=True)
168 | vertices, face_vertices, texturecoords, face_texturecoords = loadmesh_and_checkuv(mesh_fpath, out_dir)
169 |
170 | xyz_image = np.zeros((texture_H, texture_W, 3), dtype=np.float32)
171 | material_image = np.zeros((texture_H, texture_W, 7), dtype=np.float32)
172 | weight_image = np.zeros((texture_H, texture_W), dtype=np.float32)
173 |
174 | for i in range(5):
175 | points, points_uv = sample_surface(
176 | vertices, face_vertices, texturecoords, face_texturecoords, n_samples=5 * 10**6
177 | )
178 |
179 | points = torch.from_numpy(points).cuda()
180 | merge_materials = []
181 | for points_split in torch.split(points, max_num_pts, dim=0):
182 | with torch.set_grad_enabled(False):
183 | diffuse_albedo, specular_albedo, specular_roughness = material_predictor(points_split)
184 | merge_materials.append(
185 | torch.cat((diffuse_albedo, specular_albedo, specular_roughness), dim=-1).detach().cpu()
186 | )
187 | merge_materials = torch.cat(merge_materials, dim=0).numpy()
188 | points = points.detach().cpu().numpy()
189 |
190 | accumulate_splat_material(xyz_image, material_image, weight_image, points, points_uv, merge_materials)
191 |
192 | final_xyz_image = xyz_image / (weight_image[:, :, np.newaxis] + 1e-10)
193 | final_material_image = material_image / (weight_image[:, :, np.newaxis] + 1e-10)
194 |
195 | imageio.imwrite(os.path.join(out_dir, "xyz.exr"), final_xyz_image)
196 | imageio.imwrite(os.path.join(out_dir, "diffuse_albedo.exr"), final_material_image[:, :, :3])
197 | imageio.imwrite(os.path.join(out_dir, "specular_albedo.exr"), final_material_image[:, :, 3:6])
198 | imageio.imwrite(os.path.join(out_dir, "roughness.exr"), final_material_image[:, :, 6])
199 |
200 | imageio.imwrite(os.path.join(out_dir, "xyz.png"), to8b(final_xyz_image * 0.5 + 0.5))
201 | imageio.imwrite(os.path.join(out_dir, "diffuse_albedo.png"), to8b(final_material_image[:, :, :3]))
202 | imageio.imwrite(os.path.join(out_dir, "specular_albedo.png"), to8b(final_material_image[:, :, 3:6]))
203 | imageio.imwrite(os.path.join(out_dir, "roughness.png"), to8b(final_material_image[:, :, 6]))
204 |
205 | out_mesh_fpath = mesh_fpath
206 | with open(out_mesh_fpath, "r") as original:
207 | data = original.read()
208 | with open(out_mesh_fpath, "w") as modified:
209 | modified.write("usemtl ./{}\n\n".format(os.path.basename(out_mesh_fpath)[:-4] + ".mtl") + data)
210 |
211 | with open(os.path.join(out_dir, os.path.basename(out_mesh_fpath)[:-4] + ".mtl"), "w") as fp:
212 | fp.write(
213 | "newmtl Wood\n"
214 | "Ka 1.000000 1.000000 1.000000\n"
215 | "Kd 0.640000 0.640000 0.640000\n"
216 | "Ks 0.500000 0.500000 0.500000\n"
217 | "Ns 96.078431\n"
218 | "Ni 1.000000\n"
219 | "d 1.000000\n"
220 | "illum 0\n"
221 | "map_Kd diffuse_albedo.png\n"
222 | )
223 |
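Not part of the repo: a minimal sketch of calling `export_materials` with a toy constant-material predictor, assuming a UV-unwrapped `.obj` (for example one produced by `models/export_uv.py`) and a CUDA device. In a real run the predictor would be the trained neural material network; note that the function also prepends a `usemtl` line to the input mesh in place.

```python
# Hypothetical usage: the predictor takes [N, 3] CUDA points and must return
# (diffuse_albedo [N, 3], specular_albedo [N, 3], specular_roughness [N, 1]),
# which export_materials splats into 2048x2048 texture maps in out_dir.
import torch
from models.export_materials import export_materials

def constant_material(points):
    n = points.shape[0]
    diffuse_albedo = torch.full((n, 3), 0.50, device=points.device)
    specular_albedo = torch.full((n, 3), 0.04, device=points.device)
    specular_roughness = torch.full((n, 1), 0.30, device=points.device)
    return diffuse_albedo, specular_albedo, specular_roughness

export_materials("mesh_uv.obj", constant_material, out_dir="./materials_out")  # paths are placeholders
```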
--------------------------------------------------------------------------------
/models/export_mesh.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import trimesh
4 | from skimage import measure
5 |
6 |
7 | def get_grid_uniform(resolution):
8 | x = np.linspace(-1.0, 1.0, resolution)
9 | y = x
10 | z = x
11 |
12 | xx, yy, zz = np.meshgrid(x, y, z)
13 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
14 |
15 | return {"grid_points": grid_points.cuda(), "shortest_axis_length": 2.0, "xyz": [x, y, z], "shortest_axis_index": 0}
16 |
17 |
18 | def get_grid(points, resolution, eps=0.1):
19 | input_min = torch.min(points, dim=0)[0].squeeze().numpy()
20 | input_max = torch.max(points, dim=0)[0].squeeze().numpy()
21 |
22 | bounding_box = input_max - input_min
23 | shortest_axis = np.argmin(bounding_box)
24 | if shortest_axis == 0:
25 | x = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
26 | length = np.max(x) - np.min(x)
27 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
28 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
29 | elif shortest_axis == 1:
30 | y = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
31 | length = np.max(y) - np.min(y)
32 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
33 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
34 | elif shortest_axis == 2:
35 | z = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
36 | length = np.max(z) - np.min(z)
37 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
38 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
39 |
40 | xx, yy, zz = np.meshgrid(x, y, z)
41 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
42 | return {
43 | "grid_points": grid_points,
44 | "shortest_axis_length": length,
45 | "xyz": [x, y, z],
46 | "shortest_axis_index": shortest_axis,
47 | }
48 |
49 |
50 | def export_mesh(sdf, mesh_fpath, resolution=512, max_n_pts=100000):
51 | assert mesh_fpath.endswith(".obj"), f"must use .obj format: {mesh_fpath}"
52 | # get low res mesh to sample point cloud
53 | grid = get_grid_uniform(100)
54 | z = []
55 | points = grid["grid_points"]
56 | for i, pnts in enumerate(torch.split(points, max_n_pts, dim=0)):
57 | z.append(sdf(pnts).detach().cpu().numpy())
58 | z = np.concatenate(z, axis=0).astype(np.float32)
59 | verts, faces, normals, values = measure.marching_cubes(
60 | volume=z.reshape(grid["xyz"][1].shape[0], grid["xyz"][0].shape[0], grid["xyz"][2].shape[0]).transpose(
61 | [1, 0, 2]
62 | ),
63 | level=0,
64 | spacing=(
65 | grid["xyz"][0][2] - grid["xyz"][0][1],
66 | grid["xyz"][0][2] - grid["xyz"][0][1],
67 | grid["xyz"][0][2] - grid["xyz"][0][1],
68 | ),
69 | )
70 | verts = verts + np.array([grid["xyz"][0][0], grid["xyz"][1][0], grid["xyz"][2][0]])
71 | mesh_low_res = trimesh.Trimesh(verts, faces, normals)
72 | components = mesh_low_res.split(only_watertight=False)
73 | areas = np.array([c.area for c in components], dtype=float)
74 | mesh_low_res = components[areas.argmax()]
75 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
76 | recon_pc = torch.from_numpy(recon_pc).float().cuda()
77 |
78 | # Center and align the recon pc
79 | s_mean = recon_pc.mean(dim=0)
80 | s_cov = recon_pc - s_mean
81 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
82 | vecs = torch.eig(s_cov, True)[1].transpose(0, 1)
83 | if torch.det(vecs) < 0:
84 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
85 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), (recon_pc - s_mean).unsqueeze(-1)).squeeze()
86 |
87 | grid_aligned = get_grid(helper.cpu(), resolution)
88 | grid_points = grid_aligned["grid_points"]
89 | g = []
90 | for i, pnts in enumerate(torch.split(grid_points, max_n_pts, dim=0)):
91 | g.append(
92 | (
93 | torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), pnts.unsqueeze(-1)).squeeze()
94 | + s_mean
95 | )
96 | .detach()
97 | .cpu()
98 | )
99 | grid_points = torch.cat(g, dim=0)
100 |
101 | # MC to new grid
102 | points = grid_points
103 | z = []
104 | for i, pnts in enumerate(torch.split(points, max_n_pts, dim=0)):
105 | z.append(sdf(pnts.cuda()).detach().cpu().numpy())
106 | z = np.concatenate(z, axis=0).astype(np.float32)
107 |
108 | if not (np.min(z) > 0 or np.max(z) < 0):
109 | verts, faces, normals, values = measure.marching_cubes(
110 | volume=z.reshape(
111 | grid_aligned["xyz"][1].shape[0], grid_aligned["xyz"][0].shape[0], grid_aligned["xyz"][2].shape[0]
112 | ).transpose([1, 0, 2]),
113 | level=0,
114 | spacing=(
115 | grid_aligned["xyz"][0][2] - grid_aligned["xyz"][0][1],
116 | grid_aligned["xyz"][0][2] - grid_aligned["xyz"][0][1],
117 | grid_aligned["xyz"][0][2] - grid_aligned["xyz"][0][1],
118 | ),
119 | )
120 |
121 | verts = torch.from_numpy(verts).float()
122 | verts = torch.bmm(
123 | vecs.detach().cpu().unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), verts.unsqueeze(-1)
124 | ).squeeze()
125 | verts = (verts + grid_points[0]).numpy()
126 |
127 | trimesh.Trimesh(verts, faces, normals).export(mesh_fpath)
128 |
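Not from the repo: a minimal sketch that runs the same marching-cubes pipeline on an analytic SDF, a sphere of radius 0.5, instead of a trained network; in the repo the callable would presumably be the `SDFNetwork`'s `sdf()` method. A CUDA device is required, and the eigendecomposition uses `torch.eig`, which exists in the `torch==1.11` pinned by `create_env.sh` but was removed in later releases.

```python
# Sketch under the assumptions above; output path and resolution are illustrative.
import torch
from models.export_mesh import export_mesh

sphere_sdf = lambda x: torch.linalg.norm(x, ord=2, dim=-1) - 0.5   # [N, 3] CUDA points -> [N] distances

export_mesh(sphere_sdf, "sphere.obj", resolution=128)
```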
--------------------------------------------------------------------------------
/models/export_uv.py:
--------------------------------------------------------------------------------
1 | # Usage: Blender --background --python export_uv.py {in_mesh_fpath} {out_mesh_fpath}
2 |
3 | import os
4 | import bpy
5 | import sys
6 |
7 |
8 | def export_uv(in_mesh_fpath, out_mesh_fpath):
9 | assert in_mesh_fpath.endswith(".obj"), f"must use .obj format: {in_mesh_fpath}"
10 | assert out_mesh_fpath.endswith(".obj"), f"must use .obj format: {out_mesh_fpath}"
11 |
12 | bpy.data.objects["Camera"].select_set(True)
13 | bpy.data.objects["Cube"].select_set(True)
14 | bpy.data.objects["Light"].select_set(True)
15 | bpy.ops.object.delete() # delete camera, cube, light
16 |
17 | mesh_fname = os.path.basename(in_mesh_fpath)[:-4]
18 | bpy.ops.import_scene.obj(
19 | filepath=in_mesh_fpath,
20 | use_edges=True,
21 | use_smooth_groups=True,
22 | use_split_objects=True,
23 | use_split_groups=True,
24 | use_groups_as_vgroups=False,
25 | use_image_search=True,
26 | split_mode="ON",
27 | global_clamp_size=0,
28 | axis_forward="-Z",
29 | axis_up="Y",
30 | )
31 |
32 | obj = bpy.data.objects[mesh_fname]
33 | obj.select_set(True)
34 | bpy.context.view_layer.objects.active = obj
35 | bpy.ops.object.mode_set(mode="EDIT")
36 | bpy.ops.mesh.select_all(action="SELECT")
37 | bpy.ops.uv.smart_project()
38 | bpy.ops.object.mode_set(mode="OBJECT")
39 |
40 | bpy.ops.export_scene.obj(
41 | filepath=out_mesh_fpath,
42 | axis_forward="-Z",
43 | axis_up="Y",
44 | use_selection=True,
45 | use_normals=True,
46 | use_uvs=True,
47 | use_materials=False,
48 | use_triangles=True,
49 | )
50 |
51 |
52 | print(sys.argv)
53 | export_uv(sys.argv[-2], sys.argv[-1])
54 |
--------------------------------------------------------------------------------
/models/fields.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from models.embedder import get_embedder
6 |
7 |
8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
9 | class SDFNetwork(nn.Module):
10 | def __init__(
11 | self,
12 | d_in,
13 | d_out,
14 | d_hidden,
15 | n_layers,
16 | skip_in=(4,),
17 | multires=0,
18 | bias=0.5,
19 | scale=1,
20 | geometric_init=True,
21 | weight_norm=True,
22 | inside_outside=False,
23 | ):
24 | super(SDFNetwork, self).__init__()
25 |
26 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
27 |
28 | self.embed_fn_fine = None
29 |
30 | if multires > 0:
31 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
32 | self.embed_fn_fine = embed_fn
33 | dims[0] = input_ch
34 |
35 | self.num_layers = len(dims)
36 | self.skip_in = skip_in
37 | self.scale = scale
38 |
39 | for l in range(0, self.num_layers - 1):
40 | if l + 1 in self.skip_in:
41 | out_dim = dims[l + 1] - dims[0]
42 | else:
43 | out_dim = dims[l + 1]
44 |
45 | lin = nn.Linear(dims[l], out_dim)
46 |
47 | if geometric_init:
48 | if l == self.num_layers - 2:
49 | if not inside_outside:
50 | torch.nn.init.normal_(
51 | lin.weight,
52 | mean=np.sqrt(np.pi) / np.sqrt(dims[l]),
53 | std=0.0001,
54 | )
55 | torch.nn.init.constant_(lin.bias, -bias)
56 | else:
57 | torch.nn.init.normal_(
58 | lin.weight,
59 | mean=-np.sqrt(np.pi) / np.sqrt(dims[l]),
60 | std=0.0001,
61 | )
62 | torch.nn.init.constant_(lin.bias, bias)
63 | elif multires > 0 and l == 0:
64 | torch.nn.init.constant_(lin.bias, 0.0)
65 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
66 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
67 | elif multires > 0 and l in self.skip_in:
68 | torch.nn.init.constant_(lin.bias, 0.0)
69 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
70 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
71 | else:
72 | torch.nn.init.constant_(lin.bias, 0.0)
73 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
74 |
75 | if weight_norm:
76 | lin = nn.utils.weight_norm(lin)
77 |
78 | setattr(self, "lin" + str(l), lin)
79 |
80 | self.activation = nn.Softplus(beta=100)
81 |
82 | def forward(self, inputs):
83 | inputs = inputs * self.scale
84 | if self.embed_fn_fine is not None:
85 | inputs = self.embed_fn_fine(inputs)
86 |
87 | x = inputs
88 | for l in range(0, self.num_layers - 1):
89 | lin = getattr(self, "lin" + str(l))
90 |
91 | if l in self.skip_in:
92 | x = torch.cat([x, inputs], -1) / np.sqrt(2)
93 |
94 | x = lin(x)
95 |
96 | if l < self.num_layers - 2:
97 | x = self.activation(x)
98 | return torch.cat([x[..., :1] / self.scale, x[..., 1:]], dim=-1)
99 |
100 | def sdf(self, x):
101 | return self.forward(x)[..., :1]
102 |
103 | def sdf_hidden_appearance(self, x):
104 | return self.forward(x)
105 |
106 | def gradient(self, x):
107 | x.requires_grad_(True)
108 | y = self.sdf(x)
109 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
110 | gradients = torch.autograd.grad(
111 | outputs=y,
112 | inputs=x,
113 | grad_outputs=d_output,
114 | create_graph=True,
115 | retain_graph=True,
116 | only_inputs=True,
117 | )[0]
118 | return gradients
119 |
120 | def get_all(self, x, is_training=True):
121 | with torch.enable_grad():
122 | x.requires_grad_(True)
123 | tmp = self.forward(x)
124 | y, feature = tmp[..., :1], tmp[..., 1:]
125 |
126 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
127 | gradients = torch.autograd.grad(
128 | outputs=y,
129 | inputs=x,
130 | grad_outputs=d_output,
131 | create_graph=is_training,
132 | retain_graph=is_training,
133 | only_inputs=True,
134 | )[0]
135 | if not is_training:
136 | return y.detach(), feature.detach(), gradients.detach()
137 | return y, feature, gradients
138 |
139 |
140 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
141 | class RenderingNetwork(nn.Module):
142 | def __init__(
143 | self,
144 | d_feature,
145 | mode,
146 | d_in,
147 | d_out,
148 | d_hidden,
149 | n_layers,
150 | weight_norm=True,
151 | multires=0,
152 | multires_view=0,
153 | squeeze_out=True,
154 | squeeze_out_scale=1.0,
155 | output_bias=0.0,
156 | output_scale=1.0,
157 | skip_in=(),
158 | ):
159 | super().__init__()
160 |
161 | self.mode = mode
162 | self.squeeze_out = squeeze_out
163 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
164 |
165 | self.embed_fn = None
166 | if multires > 0:
167 | embed_fn, input_ch = get_embedder(multires)
168 | self.embed_fn = embed_fn
169 | dims[0] += input_ch - 3
170 |
171 | self.embedview_fn = None
172 | if multires_view > 0:
173 | embedview_fn, input_ch = get_embedder(multires_view)
174 | self.embedview_fn = embedview_fn
175 | dims[0] += input_ch - 3
176 |
177 | self.num_layers = len(dims)
178 | self.skip_in = skip_in
179 |
180 | for l in range(0, self.num_layers - 1):
181 | if l in self.skip_in:
182 | dims[l] += dims[0]
183 |
184 | for l in range(0, self.num_layers - 1):
185 | if l + 1 in self.skip_in:
186 | out_dim = dims[l + 1] - dims[0]
187 | else:
188 | out_dim = dims[l + 1]
189 |
190 | lin = nn.Linear(dims[l], out_dim)
191 |
192 | if weight_norm:
193 | lin = nn.utils.weight_norm(lin)
194 |
195 | setattr(self, "lin" + str(l), lin)
196 |
197 | self.relu = nn.ReLU()
198 |
199 | self.output_bias = output_bias
200 | self.output_scale = output_scale
201 | self.squeeze_out_scale = squeeze_out_scale
202 |
203 | def forward(self, points, normals, view_dirs, feature_vectors):
204 |
205 | if self.embed_fn is not None:
206 | points = self.embed_fn(points)
207 |
208 | if self.embedview_fn is not None and self.mode != "no_view_dir":
209 | view_dirs = self.embedview_fn(view_dirs)
210 |
211 | rendering_input = None
212 |
213 | if self.mode == "idr":
214 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
215 | elif self.mode == "no_view_dir":
216 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
217 | elif self.mode == "no_normal":
218 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
219 |
220 | x = rendering_input
221 |
222 | for l in range(0, self.num_layers - 1):
223 | lin = getattr(self, "lin" + str(l))
224 |
225 | if l in self.skip_in:
226 | x = torch.cat([x, rendering_input], dim=-1) / np.sqrt(2)
227 |
228 | x = lin(x)
229 |
230 | if l < self.num_layers - 2:
231 | x = self.relu(x)
232 |
233 | x = self.output_scale * (x + self.output_bias)
234 | if self.squeeze_out:
235 | x = self.squeeze_out_scale * torch.sigmoid(x)
236 |
237 | return x
238 |
239 |
240 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch
241 | class NeRF(nn.Module):
242 | def __init__(
243 | self,
244 | D=8,
245 | W=256,
246 | d_in=3,
247 | d_in_view=3,
248 | multires=0,
249 | multires_view=0,
250 | output_ch=4,
251 | skips=[4],
252 | use_viewdirs=False,
253 | ):
254 | super(NeRF, self).__init__()
255 | self.D = D
256 | self.W = W
257 | self.d_in = d_in
258 | self.d_in_view = d_in_view
259 | self.input_ch = 3
260 | self.input_ch_view = 3
261 | self.embed_fn = None
262 | self.embed_fn_view = None
263 |
264 | if multires > 0:
265 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
266 | self.embed_fn = embed_fn
267 | self.input_ch = input_ch
268 |
269 | if multires_view > 0:
270 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view)
271 | self.embed_fn_view = embed_fn_view
272 | self.input_ch_view = input_ch_view
273 |
274 | self.skips = skips
275 | self.use_viewdirs = use_viewdirs
276 |
277 | self.pts_linears = nn.ModuleList(
278 | [nn.Linear(self.input_ch, W)]
279 | + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]
280 | )
281 |
282 | ### Implementation according to the official code release
283 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
284 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])
285 |
286 | ### Implementation according to the paper
287 | # self.views_linears = nn.ModuleList(
288 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
289 |
290 | if use_viewdirs:
291 | self.feature_linear = nn.Linear(W, W)
292 | self.alpha_linear = nn.Linear(W, 1)
293 | self.rgb_linear = nn.Linear(W // 2, 3)
294 | else:
295 | self.output_linear = nn.Linear(W, output_ch)
296 |
297 | def forward(self, input_pts, input_views):
298 | if self.embed_fn is not None:
299 | input_pts = self.embed_fn(input_pts)
300 | if self.embed_fn_view is not None:
301 | input_views = self.embed_fn_view(input_views)
302 |
303 | h = input_pts
304 | for i, l in enumerate(self.pts_linears):
305 | h = self.pts_linears[i](h)
306 | h = F.relu(h)
307 | if i in self.skips:
308 | h = torch.cat([input_pts, h], -1)
309 |
310 | if self.use_viewdirs:
311 | alpha = self.alpha_linear(h)
312 | feature = self.feature_linear(h)
313 | h = torch.cat([feature, input_views], -1)
314 |
315 | for i, l in enumerate(self.views_linears):
316 | h = self.views_linears[i](h)
317 | h = F.relu(h)
318 |
319 | rgb = self.rgb_linear(h)
320 | return alpha, rgb
321 | else:
322 | assert False
323 |
324 |
325 | class SingleVarianceNetwork(nn.Module):
326 | def __init__(self, init_val):
327 | super(SingleVarianceNetwork, self).__init__()
328 | self.register_parameter("variance", nn.Parameter(torch.tensor(init_val)))
329 |
330 | def forward(self, x):
331 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0)
332 |
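Not part of the repo: a small sketch that instantiates `SDFNetwork` with the hyper-parameters from `confs/womask_iron.conf` and queries a batch of points on a CUDA device.

```python
# Sketch: build the SDF network as configured (d_out = 257 = 1 SDF value + 256 features).
import torch
from models.fields import SDFNetwork

sdf_network = SDFNetwork(
    d_in=3, d_out=257, d_hidden=256, n_layers=8,
    skip_in=[4], multires=6, bias=0.5, scale=1.0,
    geometric_init=True, weight_norm=True,
).cuda()

pts = torch.rand(1024, 3).cuda() * 2.0 - 1.0          # points in [-1, 1]^3
sdf_vals = sdf_network.sdf(pts)                       # [1024, 1] signed distances
sdf_vals, feats, grads = sdf_network.get_all(pts, is_training=False)  # plus 256-d features and SDF gradients
```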
--------------------------------------------------------------------------------
/models/ggx/int_mts_diff_rtrans_data.txt:
--------------------------------------------------------------------------------
1 | 0.416354
2 | 0.416355
3 | 0.416354
4 | 0.416354
5 | 0.41635
6 | 0.416334
7 | 0.416277
8 | 0.416124
9 | 0.415772
10 | 0.415112
11 | 0.414269
12 | 0.413012
13 | 0.41189
14 | 0.410755
15 | 0.410089
16 | 0.409991
17 | 0.409841
18 | 0.410012
19 | 0.410206
20 | 0.410433
21 | 0.41088
22 | 0.41127
23 | 0.41126
24 | 0.411295
25 | 0.410715
26 | 0.409467
27 | 0.407075
28 | 0.403966
29 | 0.399456
30 | 0.393355
31 | 0.385689
32 | 0.376357
33 | 0.365266
34 | 0.352599
35 | 0.338526
36 | 0.323228
37 | 0.306972
38 | 0.290034
39 | 0.272687
40 | 0.25522
41 | 0.237934
42 | 0.220986
43 | 0.20462
44 | 0.188985
45 | 0.174162
46 | 0.160172
47 | 0.147244
48 | 0.13523
49 | 0.124225
50 | 0.114087
51 |
--------------------------------------------------------------------------------
/models/image_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import scipy.ndimage
6 |
7 | import warnings
8 | import kornia
9 |
10 | from icecream import ic
11 |
12 |
13 | class PyramidL2Loss(nn.Module):
14 | def __init__(self, use_cuda=True):
15 | super().__init__()
16 |
17 | dirac = np.zeros((7, 7), dtype=np.float32)
18 | dirac[3, 3] = 1.0
19 | f = np.zeros([3, 3, 7, 7], dtype=np.float32)
20 | gf = scipy.ndimage.gaussian_filter(dirac, 1.0)
21 | f[0, 0, :, :] = gf
22 | f[1, 1, :, :] = gf
23 | f[2, 2, :, :] = gf
24 | self.f = torch.from_numpy(f)
25 | if use_cuda:
26 | self.f = self.f.cuda()
27 | self.m = torch.nn.AvgPool2d(2)
28 |
29 | def forward(self, pred_img, trgt_img):
30 | """
31 | pred_img, trgt_img: [B, C, H, W]
32 | """
33 | diff_0 = pred_img - trgt_img
34 |
35 | h, w = pred_img.shape[-2:]
36 | # Convolve then downsample
37 | diff_1 = self.m(torch.nn.functional.conv2d(diff_0, self.f, padding=3))
38 | diff_2 = self.m(torch.nn.functional.conv2d(diff_1, self.f, padding=3))
39 | diff_3 = self.m(torch.nn.functional.conv2d(diff_2, self.f, padding=3))
40 | diff_4 = self.m(torch.nn.functional.conv2d(diff_3, self.f, padding=3))
41 | loss = (
42 | diff_0.pow(2).sum() / (h * w)
43 | + diff_1.pow(2).sum() / ((h / 2.0) * (w / 2.0))
44 | + diff_2.pow(2).sum() / ((h / 4.0) * (w / 4.0))
45 | + diff_3.pow(2).sum() / ((h / 8.0) * (w / 8.0))
46 | + diff_4.pow(2).sum() / ((h / 16.0) * (w / 16.0))
47 | )
48 | return loss
49 |
50 |
51 | def _fspecial_gauss_1d(size, sigma):
52 | r"""Create 1-D gauss kernel
53 | Args:
54 | size (int): the size of gauss kernel
55 | sigma (float): sigma of normal distribution
56 | Returns:
57 | torch.Tensor: 1D kernel (1 x 1 x size)
58 | """
59 | coords = torch.arange(size, dtype=torch.float)
60 | coords -= size // 2
61 |
62 | g = torch.exp(-(coords**2) / (2 * sigma**2))
63 | g /= g.sum()
64 |
65 | return g.unsqueeze(0).unsqueeze(0)
66 |
67 |
68 | def gaussian_filter(input, win):
69 | r"""Blur input with 1-D kernel
70 | Args:
71 | input (torch.Tensor): a batch of tensors to be blurred
72 | window (torch.Tensor): 1-D gauss kernel
73 | Returns:
74 | torch.Tensor: blurred tensors
75 | """
76 | assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape
77 | if len(input.shape) == 4:
78 | conv = F.conv2d
79 | elif len(input.shape) == 5:
80 | conv = F.conv3d
81 | else:
82 | raise NotImplementedError(input.shape)
83 |
84 | C = input.shape[1]
85 | out = input
86 | for i, s in enumerate(input.shape[2:]):
87 | if s >= win.shape[-1]:
88 | out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
89 | else:
90 | warnings.warn(
91 | f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
92 | )
93 |
94 | return out
95 |
96 |
97 | def ssim_loss_fn(X, Y, mask=None, data_range=1.0, win_size=11, win_sigma=1.5, K=(0.01, 0.03)):
98 | r"""Calculate ssim index for X and Y
99 | Args:
100 | X (torch.Tensor): images of shape [b, c, h, w]
101 | Y (torch.Tensor): images of shape [b, c, h, w]
102 | mask (torch.Tensor): [b, 1, h, w]
103 | win_size: (int, optional): the size of gauss kernel
104 | win_sigma: (float, optional): sigma of normal distribution
105 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
106 | Returns:
107 | torch.Tensor: scalar SSIM loss, i.e. 1 - mean SSIM over the (optionally masked) pixels
108 | """
109 | if not X.shape == Y.shape:
110 | raise ValueError("Input images should have the same dimensions.")
111 |
112 | if not X.type() == Y.type():
113 | raise ValueError("Input images should have the same dtype.")
114 |
115 | if len(X.shape) != 4:
116 | raise ValueError(f"Input images should be 4-d tensors, but got {X.shape}")
117 |
118 | if not (win_size % 2 == 1):
119 | raise ValueError("Window size should be odd.")
120 |
121 | win = _fspecial_gauss_1d(win_size, win_sigma)
122 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
123 |
124 | K1, K2 = K
125 | # batch, channel, [depth,] height, width = X.shape
126 | compensation = 1.0
127 |
128 | C1 = (K1 * data_range) ** 2
129 | C2 = (K2 * data_range) ** 2
130 |
131 | win = win.to(X.device, dtype=X.dtype)
132 |
133 | mu1 = gaussian_filter(X, win)
134 | mu2 = gaussian_filter(Y, win)
135 |
136 | mu1_sq = mu1.pow(2)
137 | mu2_sq = mu2.pow(2)
138 | mu1_mu2 = mu1 * mu2
139 |
140 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
141 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
142 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)
143 |
144 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1
145 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
146 | ssim_map = ssim_map.mean(dim=1, keepdim=True)
147 |
148 | if mask is not None:
149 | ### pad ssim_map to original size
150 | ssim_map = F.pad(
151 | ssim_map, (win_size // 2, win_size // 2, win_size // 2, win_size // 2), mode="constant", value=1.0
152 | )
153 |
154 | mask = kornia.morphology.erosion(mask.float(), torch.ones(win_size, win_size).float().to(mask.device)) > 0.5
155 | # ic(ssim_map.shape, mask.shape)
156 | ssim_map = ssim_map[mask]
157 |
158 | return 1.0 - ssim_map.mean()
159 |
160 |
161 | if __name__ == "__main__":
162 | pred_im = torch.rand(1, 3, 256, 256).cuda()
163 | # gt_im = torch.rand(1, 3, 256, 256).cuda()
164 | gt_im = pred_im.clone()
165 | mask = torch.ones(1, 1, 256, 256).bool().cuda()
166 |
167 | ssim_loss = ssim_loss_fn(pred_im, gt_im, mask)
168 | ic(ssim_loss)
169 |
--------------------------------------------------------------------------------
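Usage note: render_surface.py imports both PyramidL2Loss and ssim_loss_fn and weights the SSIM term by --ssim_weight. A minimal sketch of that combination (not taken from the repository; patch shapes and the device fallback are illustrative assumptions):

    import torch
    from models.image_losses import PyramidL2Loss, ssim_loss_fn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, 3, 256, 256, device=device)    # rendered patch, [B, C, H, W]
    gt = torch.rand(1, 3, 256, 256, device=device)      # ground-truth patch
    mask = torch.ones(1, 1, 256, 256, dtype=torch.bool, device=device)

    pyramid_loss_fn = PyramidL2Loss(use_cuda=(device == "cuda"))
    img_loss = pyramid_loss_fn(pred, gt)                 # multi-scale L2
    ssim_loss = ssim_loss_fn(pred, gt, mask)             # 1 - masked mean SSIM
    total_loss = img_loss + 1.0 * ssim_loss              # 1.0 matches the default --ssim_weight
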
/models/raytracer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import kornia
7 | import cv2
8 |
9 | from icecream import ic
10 |
11 | # Sphere tracing + dense sampling against the neural SDF, plus pinhole-camera
12 | # utilities and edge-aware differentiable rendering helpers.
13 |
14 | VERBOSE_MODE = False
15 |
16 |
17 | def reparam_points(nondiff_points, nondiff_grads, nondiff_trgt_dirs, diff_sdf_vals):
18 | # note that flipping the direction of nondiff_trgt_dirs would not change this equation at all,
19 | # hence we require dot >= 0
20 | dot = (nondiff_grads * nondiff_trgt_dirs).sum(dim=-1, keepdim=True)
21 | # assert (dot >= 0.).all(), 'dot>=0 not satisfied in reparam_points: {},{}'.format(dot.min().item(), dot.max().item())
22 | dot = torch.clamp(dot, min=1e-4)
23 | diff_points = nondiff_points - nondiff_trgt_dirs / dot * (diff_sdf_vals - diff_sdf_vals.detach())
24 | return diff_points
25 |
26 |
27 | class RayTracer(nn.Module):
28 | def __init__(
29 | self,
30 | sdf_threshold=5.0e-5,
31 | sphere_tracing_iters=16,
32 | n_steps=128,
33 | max_num_pts=200000,
34 | ):
35 | """sdf values of convergent points must be inside [-sdf_threshold, sdf_threshold]"""
36 | super().__init__()
37 | self.sdf_threshold = sdf_threshold
38 | # sphere tracing hyper-params
39 | self.sphere_tracing_iters = sphere_tracing_iters
40 | # dense sampling hyper-params
41 | self.n_steps = n_steps
42 |
43 | self.max_num_pts = max_num_pts
44 |
45 | @torch.no_grad()
46 | def forward(self, sdf, ray_o, ray_d, min_dis, max_dis, work_mask):
47 | (
48 | convergent_mask,
49 | unfinished_mask_start,
50 | curr_start_points,
51 | curr_start_sdf,
52 | acc_start_dis,
53 | ) = self.sphere_tracing(sdf, ray_o, ray_d, min_dis, max_dis, work_mask)
54 | sphere_tracing_cnt = convergent_mask.sum()
55 |
56 | sampler_work_mask = unfinished_mask_start
57 | sampler_cnt = 0
58 | if sampler_work_mask.sum() > 0:
59 | tmp_mask = (curr_start_sdf[sampler_work_mask] > 0.0).float()
60 | sampler_min_dis = (
61 | tmp_mask * acc_start_dis[sampler_work_mask] + (1.0 - tmp_mask) * min_dis[sampler_work_mask]
62 | )
63 | sampler_max_dis = (
64 | tmp_mask * max_dis[sampler_work_mask] + (1.0 - tmp_mask) * acc_start_dis[sampler_work_mask]
65 | )
66 |
67 | (sampler_convergent_mask, sampler_points, sampler_sdf, sampler_dis,) = self.ray_sampler(
68 | sdf,
69 | ray_o[sampler_work_mask],
70 | ray_d[sampler_work_mask],
71 | sampler_min_dis,
72 | sampler_max_dis,
73 | )
74 |
75 | convergent_mask[sampler_work_mask] = sampler_convergent_mask
76 | curr_start_points[sampler_work_mask] = sampler_points
77 | curr_start_sdf[sampler_work_mask] = sampler_sdf
78 | acc_start_dis[sampler_work_mask] = sampler_dis
79 | sampler_cnt = sampler_convergent_mask.sum()
80 |
81 | ret_dict = {
82 | "convergent_mask": convergent_mask,
83 | "points": curr_start_points,
84 | "sdf": curr_start_sdf,
85 | "distance": acc_start_dis,
86 | }
87 |
88 | if VERBOSE_MODE: # debug
89 | sdf_check = sdf(curr_start_points)
90 | ic(
91 | convergent_mask.sum() / convergent_mask.numel(),
92 | sdf_check[convergent_mask].min().item(),
93 | sdf_check[convergent_mask].max().item(),
94 | )
95 | debug_info = "Total,raytraced,convergent(sphere tracing+dense sampling): {},{},{} ({}+{})".format(
96 | work_mask.numel(),
97 | work_mask.sum(),
98 | convergent_mask.sum(),
99 | sphere_tracing_cnt,
100 | sampler_cnt,
101 | )
102 | ic(debug_info)
103 | return ret_dict
104 |
105 | def sphere_tracing(self, sdf, ray_o, ray_d, min_dis, max_dis, work_mask):
106 | """Run sphere tracing algorithm for max iterations"""
107 | iters = 0
108 | unfinished_mask_start = work_mask.clone()
109 | acc_start_dis = min_dis.clone()
110 | curr_start_points = ray_o + ray_d * acc_start_dis.unsqueeze(-1)
111 | curr_sdf_start = sdf(curr_start_points)
112 | while True:
113 | # Check convergence
114 | unfinished_mask_start = (
115 | unfinished_mask_start & (curr_sdf_start.abs() > self.sdf_threshold) & (acc_start_dis < max_dis)
116 | )
117 |
118 | if iters == self.sphere_tracing_iters or unfinished_mask_start.sum() == 0:
119 | break
120 | iters += 1
121 |
122 | # Make step
123 | tmp = curr_sdf_start[unfinished_mask_start]
124 | acc_start_dis[unfinished_mask_start] += tmp
125 | curr_start_points[unfinished_mask_start] += ray_d[unfinished_mask_start] * tmp.unsqueeze(-1)
126 | curr_sdf_start[unfinished_mask_start] = sdf(curr_start_points[unfinished_mask_start])
127 |
128 | convergent_mask = (
129 | work_mask
130 | & ~unfinished_mask_start
131 | & (curr_sdf_start.abs() <= self.sdf_threshold)
132 | & (acc_start_dis < max_dis)
133 | )
134 | return (
135 | convergent_mask,
136 | unfinished_mask_start,
137 | curr_start_points,
138 | curr_sdf_start,
139 | acc_start_dis,
140 | )
141 |
142 | def ray_sampler(self, sdf, ray_o, ray_d, min_dis, max_dis):
143 | """Sample the ray in a given range and perform rootfinding on ray segments which have sign transition"""
144 | intervals_dis = (
145 | torch.linspace(0, 1, steps=self.n_steps).float().to(min_dis.device).view(1, self.n_steps)
146 | ) # [1, n_steps]
147 | intervals_dis = min_dis.unsqueeze(-1) + intervals_dis * (
148 | max_dis.unsqueeze(-1) - min_dis.unsqueeze(-1)
149 | ) # [n_valid, n_steps]
150 | points = ray_o.unsqueeze(-2) + ray_d.unsqueeze(-2) * intervals_dis.unsqueeze(-1) # [n_valid, n_steps, 3]
151 |
152 | sdf_val = []
153 | for pnts in torch.split(points.reshape(-1, 3), self.max_num_pts, dim=0):
154 | sdf_val.append(sdf(pnts))
155 | sdf_val = torch.cat(sdf_val, dim=0).reshape(-1, self.n_steps)
156 |
157 | # To be returned
158 | sampler_pts = torch.zeros_like(ray_d)
159 | sampler_sdf = torch.zeros_like(min_dis)
160 | sampler_dis = torch.zeros_like(min_dis)
161 |
162 | tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).float().to(sdf_val.device).reshape(
163 | 1, self.n_steps
164 | )
165 | # return first negative sdf point if exists
166 | min_val, min_idx = torch.min(tmp, dim=-1)
167 | rootfind_work_mask = (min_val < 0.0) & (min_idx >= 1)
168 | n_rootfind = rootfind_work_mask.sum()
169 | if n_rootfind > 0:
170 | # [n_rootfind, 1]
171 | min_idx = min_idx[rootfind_work_mask].unsqueeze(-1)
172 | z_low = torch.gather(intervals_dis[rootfind_work_mask], dim=-1, index=min_idx - 1).squeeze(
173 | -1
174 | ) # [n_rootfind, ]
175 | # [n_rootfind, ]; > 0
176 | sdf_low = torch.gather(sdf_val[rootfind_work_mask], dim=-1, index=min_idx - 1).squeeze(-1)
177 | z_high = torch.gather(intervals_dis[rootfind_work_mask], dim=-1, index=min_idx).squeeze(
178 | -1
179 | ) # [n_rootfind, ]
180 | # [n_rootfind, ]; < 0
181 | sdf_high = torch.gather(sdf_val[rootfind_work_mask], dim=-1, index=min_idx).squeeze(-1)
182 |
183 | p_pred, z_pred, sdf_pred = self.rootfind(
184 | sdf,
185 | sdf_low,
186 | sdf_high,
187 | z_low,
188 | z_high,
189 | ray_o[rootfind_work_mask],
190 | ray_d[rootfind_work_mask],
191 | )
192 |
193 | sampler_pts[rootfind_work_mask] = p_pred
194 | sampler_sdf[rootfind_work_mask] = sdf_pred
195 | sampler_dis[rootfind_work_mask] = z_pred
196 |
197 | return rootfind_work_mask, sampler_pts, sampler_sdf, sampler_dis
198 |
199 | def rootfind(self, sdf, f_low, f_high, d_low, d_high, ray_o, ray_d):
200 | """binary search the root"""
201 | work_mask = (f_low > 0) & (f_high < 0)
202 | d_mid = (d_low + d_high) / 2.0
203 | i = 0
204 | while work_mask.any():
205 | p_mid = ray_o + ray_d * d_mid.unsqueeze(-1)
206 | f_mid = sdf(p_mid)
207 | ind_low = f_mid > 0
208 | ind_high = f_mid <= 0
209 | if ind_low.sum() > 0:
210 | d_low[ind_low] = d_mid[ind_low]
211 | f_low[ind_low] = f_mid[ind_low]
212 | if ind_high.sum() > 0:
213 | d_high[ind_high] = d_mid[ind_high]
214 | f_high[ind_high] = f_mid[ind_high]
215 | d_mid = (d_low + d_high) / 2.0
216 | work_mask &= (d_high - d_low) > 2 * self.sdf_threshold
217 | i += 1
218 | p_mid = ray_o + ray_d * d_mid.unsqueeze(-1)
219 | f_mid = sdf(p_mid)
220 | return p_mid, d_mid, f_mid
221 |
222 |
223 | @torch.no_grad()
224 | def intersect_sphere(ray_o, ray_d, r):
225 | """
226 | ray_o, ray_d: [..., 3]
227 | compute the near/far distances at which each ray intersects the sphere of radius r centered at the origin
228 | """
229 | # note: d1 becomes negative if this mid point is behind camera
230 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
231 | p = ray_o + d1.unsqueeze(-1) * ray_d
232 |
233 | tmp = r * r - torch.sum(p * p, dim=-1)
234 | mask_intersect = tmp > 0.0
235 | d2 = torch.sqrt(torch.clamp(tmp, min=0.0)) / torch.norm(ray_d, dim=-1)
236 |
237 | return mask_intersect, torch.clamp(d1 - d2, min=0.0), d1 + d2
238 |
239 |
240 | class Camera(object):
241 | def __init__(self, W, H, K, W2C):
242 | """
243 | W, H: int
244 | K, W2C: 4x4 tensor
245 | """
246 | self.W = W
247 | self.H = H
248 | self.K = K
249 | self.W2C = W2C
250 | self.K_inv = torch.inverse(K)
251 | self.C2W = torch.inverse(W2C)
252 | self.device = self.K.device
253 |
254 | def get_rays(self, uv):
255 | """
256 | uv: [..., 2]
257 | """
258 | dots_sh = list(uv.shape[:-1])
259 |
260 | uv = uv.view(-1, 2)
261 | uv = torch.cat((uv, torch.ones_like(uv[..., 0:1])), dim=-1)
262 | ray_d = torch.matmul(
263 | torch.matmul(uv, self.K_inv[:3, :3].transpose(1, 0)),
264 | self.C2W[:3, :3].transpose(1, 0),
265 | ).reshape(
266 | dots_sh
267 | + [
268 | 3,
269 | ]
270 | )
271 |
272 | ray_d_norm = ray_d.norm(dim=-1)
273 | ray_d = ray_d / ray_d_norm.unsqueeze(-1)
274 |
275 | ray_o = (
276 | self.C2W[:3, 3]
277 | .unsqueeze(0)
278 | .expand(uv.shape[0], -1)
279 | .reshape(
280 | dots_sh
281 | + [
282 | 3,
283 | ]
284 | )
285 | )
286 | return ray_o, ray_d, ray_d_norm
287 |
288 | def get_camera_origin(self, prefix_shape=None):
289 | ray_o = self.C2W[:3, 3]
290 | if prefix_shape is not None:
291 | prefix_shape = list(prefix_shape)
292 | ray_o = ray_o.view([1,] * len(prefix_shape) + [3,]).expand(
293 | prefix_shape
294 | + [
295 | 3,
296 | ]
297 | )
298 | return ray_o
299 |
300 | def get_uv(self):
301 | u, v = np.meshgrid(np.arange(self.W), np.arange(self.H))
302 | uv = torch.from_numpy(np.stack((u, v), axis=-1).astype(np.float32)).to(self.device) + 0.5
303 | return uv
304 |
305 | def project(self, points):
306 | """
307 | points: [..., 3]
308 | """
309 | dots_sh = list(points.shape[:-1])
310 |
311 | points = points.view(-1, 3)
312 | points = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
313 | uv = torch.matmul(
314 | torch.matmul(points, self.W2C.transpose(1, 0)),
315 | self.K.transpose(1, 0),
316 | )
317 | uv = uv[:, :2] / uv[:, 2:3]
318 |
319 | uv = uv.view(
320 | dots_sh
321 | + [
322 | 2,
323 | ]
324 | )
325 | return uv
326 |
327 | def crop_region(self, trgt_W, trgt_H, center_crop=False, ul_corner=None, image=None):
328 | K = self.K.clone()
329 | if ul_corner is not None:
330 | ul_col, ul_row = ul_corner
331 | elif center_crop:
332 | ul_col = self.W // 2 - trgt_W // 2
333 | ul_row = self.H // 2 - trgt_H // 2
334 | else:
335 | ul_col = np.random.randint(0, self.W - trgt_W)
336 | ul_row = np.random.randint(0, self.H - trgt_H)
337 | # modify K
338 | K[0, 2] -= ul_col
339 | K[1, 2] -= ul_row
340 |
341 | camera = Camera(trgt_W, trgt_H, K, self.W2C.clone())
342 |
343 | if image is not None:
344 | assert image.shape[0] == self.H and image.shape[1] == self.W, "image size does not match specified size"
345 | image = image[ul_row : ul_row + trgt_H, ul_col : ul_col + trgt_W]
346 | return camera, image
347 |
348 | def resize(self, factor, image=None):
349 | trgt_H, trgt_W = int(self.H * factor), int(self.W * factor)
350 | K = self.K.clone()
351 | K[0, :3] *= trgt_W / self.W
352 | K[1, :3] *= trgt_H / self.H
353 | camera = Camera(trgt_W, trgt_H, K, self.W2C.clone())
354 |
355 | if image is not None:
356 | device = image.device
357 | image = cv2.resize(image.detach().cpu().numpy(), (trgt_W, trgt_H), interpolation=cv2.INTER_AREA)
358 | image = torch.from_numpy(image).to(device)
359 | return camera, image
360 |
361 |
362 | @torch.no_grad()
363 | def raytrace_pixels(sdf_network, raytracer, uv, camera, mask=None, max_num_rays=200000):
364 | if mask is None:
365 | mask = torch.ones_like(uv[..., 0]).bool()
366 |
367 | dots_sh = list(uv.shape[:-1])
368 |
369 | ray_o, ray_d, ray_d_norm = camera.get_rays(uv)
370 | sdf = lambda x: sdf_network(x)[..., 0]
371 |
372 | merge_results = None
373 | for ray_o_split, ray_d_split, ray_d_norm_split, mask_split in zip(
374 | torch.split(ray_o.view(-1, 3), max_num_rays, dim=0),
375 | torch.split(ray_d.view(-1, 3), max_num_rays, dim=0),
376 | torch.split(
377 | ray_d_norm.view(
378 | -1,
379 | ),
380 | max_num_rays,
381 | dim=0,
382 | ),
383 | torch.split(
384 | mask.view(
385 | -1,
386 | ),
387 | max_num_rays,
388 | dim=0,
389 | ),
390 | ):
391 | mask_intersect_split, min_dis_split, max_dis_split = intersect_sphere(ray_o_split, ray_d_split, r=1.0)
392 | results = raytracer(
393 | sdf,
394 | ray_o_split,
395 | ray_d_split,
396 | min_dis_split,
397 | max_dis_split,
398 | mask_intersect_split & mask_split,
399 | )
400 | results["depth"] = results["distance"] / ray_d_norm_split
401 |
402 | if merge_results is None:
403 | merge_results = dict(
404 | [
405 | (
406 | x,
407 | [
408 | results[x],
409 | ],
410 | )
411 | for x in results.keys()
412 | if isinstance(results[x], torch.Tensor)
413 | ]
414 | )
415 | else:
416 | for x in results.keys():
417 | merge_results[x].append(results[x]) # gpu
418 |
419 | for x in list(merge_results.keys()):
420 | results = torch.cat(merge_results[x], dim=0).reshape(
421 | dots_sh
422 | + [
423 | -1,
424 | ]
425 | )
426 | if results.shape[-1] == 1:
427 | results = results[..., 0]
428 | merge_results[x] = results # gpu
429 |
430 | # append more results
431 | merge_results.update(
432 | {
433 | "uv": uv,
434 | "ray_o": ray_o,
435 | "ray_d": ray_d,
436 | "ray_d_norm": ray_d_norm,
437 | }
438 | )
439 | return merge_results
440 |
441 |
442 | def unique(x, dim=-1):
443 | """
444 | return: unique elements in x, and their original indices in x
445 | """
446 | unique, inverse = torch.unique(x, return_inverse=True, dim=dim)
447 | perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device)
448 | inverse, perm = inverse.flip([dim]), perm.flip([dim])
449 | return unique, inverse.new_empty(unique.size(dim)).scatter_(dim, inverse, perm)
450 |
451 |
452 | @torch.no_grad()
453 | def locate_edge_points(
454 | camera, walk_start_points, sdf_network, max_step, step_size, dot_threshold, max_num_rays=200000, mask=None
455 | ):
456 | """walk on the surface to locate 3d edge points with high precision"""
457 | if mask is None:
458 | mask = torch.ones_like(walk_start_points[..., 0]).bool()
459 |
460 | walk_finish_points = walk_start_points.clone()
461 | walk_edge_found_mask = mask.clone()
462 | n_valid = mask.sum()
463 | if n_valid > 0:
464 | dots_sh = list(walk_start_points.shape[:-1])
465 |
466 | walk_finish_points_valid = []
467 | walk_edge_found_mask_valid = []
468 | for cur_points_split in torch.split(walk_start_points[mask].clone().view(-1, 3).detach(), max_num_rays, dim=0):
469 | walk_edge_found_mask_split = torch.zeros_like(cur_points_split[..., 0]).bool()
470 | not_found_mask_split = ~walk_edge_found_mask_split
471 |
472 | ray_o_split = camera.get_camera_origin(prefix_shape=cur_points_split.shape[:-1])
473 |
474 | i = 0
475 | while True:
476 | cur_viewdir_split = ray_o_split[not_found_mask_split] - cur_points_split[not_found_mask_split]
477 | cur_viewdir_split = cur_viewdir_split / (cur_viewdir_split.norm(dim=-1, keepdim=True) + 1e-10)
478 | cur_sdf_split, _, cur_normal_split = sdf_network.get_all(
479 | cur_points_split[not_found_mask_split].view(-1, 3),
480 | is_training=False,
481 | )
482 | cur_normal_split = cur_normal_split / (cur_normal_split.norm(dim=-1, keepdim=True) + 1e-10)
483 |
484 | dot_split = (cur_normal_split * cur_viewdir_split).sum(dim=-1)
485 | tmp_not_found_mask = dot_split.abs() > dot_threshold
486 | walk_edge_found_mask_split[not_found_mask_split] = ~tmp_not_found_mask
487 | not_found_mask_split = ~walk_edge_found_mask_split
488 |
489 | if i >= max_step or not_found_mask_split.sum() == 0:
490 | break
491 |
492 | cur_walkdir_split = cur_normal_split - cur_viewdir_split / dot_split.unsqueeze(-1)
493 | cur_walkdir_split = cur_walkdir_split / (cur_walkdir_split.norm(dim=-1, keepdim=True) + 1e-10)
494 | # regularize walk direction such that we don't get far away from the zero iso-surface
495 | cur_walkdir_split = cur_walkdir_split - cur_sdf_split * cur_normal_split
496 | cur_points_split[not_found_mask_split] += (step_size * cur_walkdir_split)[tmp_not_found_mask]
497 |
498 | i += 1
499 |
500 | walk_finish_points_valid.append(cur_points_split)
501 | walk_edge_found_mask_valid.append(walk_edge_found_mask_split)
502 |
503 | walk_finish_points[mask] = torch.cat(walk_finish_points_valid, dim=0)
504 | walk_edge_found_mask[mask] = torch.cat(walk_edge_found_mask_valid, dim=0)
505 | walk_finish_points = walk_finish_points.reshape(
506 | dots_sh
507 | + [
508 | 3,
509 | ]
510 | )
511 | walk_edge_found_mask = walk_edge_found_mask.reshape(dots_sh)
512 |
513 | edge_points = walk_finish_points[walk_edge_found_mask]
514 | edge_mask = torch.zeros(camera.H, camera.W).bool().to(walk_finish_points.device)
515 | edge_uv = torch.zeros_like(edge_points[..., :2])
516 | update_pixels = torch.Tensor([]).long().to(walk_finish_points.device)
517 | if walk_edge_found_mask.any():
518 | # filter out edge points out of camera's fov;
519 | # if there are multiple edge points mapping to the same pixel, only keep one
520 | edge_uv = camera.project(edge_points)
521 | update_pixels = torch.floor(edge_uv.detach()).long()
522 | update_pixels = update_pixels[:, 1] * camera.W + update_pixels[:, 0]
523 | mask = (update_pixels < camera.H * camera.W) & (update_pixels >= 0)
524 | update_pixels, edge_points, edge_uv = update_pixels[mask], edge_points[mask], edge_uv[mask]
525 | if mask.any():
526 | cnt = update_pixels.shape[0]
527 | update_pixels, unique_idx = unique(update_pixels, dim=0)
528 | unique_idx = torch.arange(cnt, device=update_pixels.device)[unique_idx]
529 | # assert update_pixels.shape == unique_idx.shape, f"{update_pixels.shape},{unique_idx.shape}"
530 | edge_points = edge_points[unique_idx]
531 | edge_uv = edge_uv[unique_idx]
532 |
533 | edge_mask.view(-1)[update_pixels] = True
534 | # edge_cnt = edge_mask.sum()
535 | # assert (
536 | # edge_cnt == edge_points.shape[0]
537 | # ), f"{edge_cnt},{edge_points.shape},{edge_uv.shape},{update_pixels.shape},{torch.unique(update_pixels).shape},{update_pixels.min()},{update_pixels.max()}"
538 | # assert (
539 | # edge_cnt == edge_uv.shape[0]
540 | # ), f"{edge_cnt},{edge_points.shape},{edge_uv.shape},{update_pixels.shape},{torch.unique(update_pixels).shape}"
541 |
542 | # ic(edge_mask.shape, edge_points.shape, edge_uv.shape)
543 | results = {"edge_mask": edge_mask, "edge_points": edge_points, "edge_uv": edge_uv, "edge_pixel_idx": update_pixels}
544 |
545 | if VERBOSE_MODE: # debug
546 | edge_angles = torch.zeros_like(edge_mask).float()
547 | edge_sdf = torch.zeros_like(edge_mask).float().unsqueeze(-1)
548 | if edge_mask.any():
549 | ray_o = camera.get_camera_origin(prefix_shape=edge_points.shape[:-1])
550 | edge_viewdir = ray_o - edge_points
551 | edge_viewdir = edge_viewdir / (edge_viewdir.norm(dim=-1, keepdim=True) + 1e-10)
552 | with torch.enable_grad():
553 | edge_sdf_vals, _, edge_normals = sdf_network.get_all(edge_points, is_training=False)
554 | edge_normals = edge_normals / (edge_normals.norm(dim=-1, keepdim=True) + 1e-10)
555 | edge_dot = (edge_viewdir * edge_normals).sum(dim=-1)
556 | # edge_angles[edge_mask] = torch.rad2deg(torch.acos(edge_dot))
557 | # edge_sdf[edge_mask] = edge_sdf_vals
558 | edge_angles.view(-1)[update_pixels] = torch.rad2deg(torch.acos(edge_dot))
559 | edge_sdf.view(-1)[update_pixels] = edge_sdf_vals.squeeze(-1)
560 |
561 | results.update(
562 | {
563 | "walk_edge_found_mask": walk_edge_found_mask,
564 | "edge_angles": edge_angles,
565 | "edge_sdf": edge_sdf,
566 | }
567 | )
568 |
569 | return results
570 |
571 |
572 | @torch.no_grad()
573 | def raytrace_camera(
574 | camera,
575 | sdf_network,
576 | raytracer,
577 | max_num_rays=200000,
578 | fill_holes=False,
579 | detect_edges=False,
580 | ):
581 | results = raytrace_pixels(sdf_network, raytracer, camera.get_uv(), camera, max_num_rays=max_num_rays)
582 | results["depth"] *= results["convergent_mask"].float()
583 |
584 | if fill_holes:
585 | depth = results["depth"]
586 | kernel = torch.ones(3, 3).float().to(depth.device)
587 | depth = kornia.morphology.closing(depth.unsqueeze(0).unsqueeze(0), kernel).squeeze(0).squeeze(0)
588 | new_convergent_mask = depth > 1e-2
589 | update_mask = new_convergent_mask & (~results["convergent_mask"])
590 | if update_mask.any():
591 | results["depth"][update_mask] = depth[update_mask]
592 | results["convergent_mask"] = new_convergent_mask
593 | results["distance"] = results["depth"] * results["ray_d_norm"]
594 | results["points"] = results["ray_o"] + results["ray_d"] * results["distance"].unsqueeze(-1)
595 |
596 | if detect_edges:
597 | depth = results["depth"]
598 | convergent_mask = results["convergent_mask"]
599 | depth_grad_norm = kornia.filters.sobel(depth.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
600 | depth_edge_mask = (depth_grad_norm > 1e-2) & convergent_mask
601 | # depth_edge_mask = convergent_mask
602 |
603 | results.update(
604 | locate_edge_points(
605 | camera,
606 | results["points"],
607 | sdf_network,
608 | max_step=16,
609 | step_size=1e-3,
610 | dot_threshold=5e-2,
611 | max_num_rays=max_num_rays,
612 | mask=depth_edge_mask,
613 | )
614 | )
615 | results["convergent_mask"] &= ~results["edge_mask"]
616 |
617 | if VERBOSE_MODE: # debug
618 | results.update({"depth_grad_norm": depth_grad_norm, "depth_edge_mask": depth_edge_mask})
619 |
620 | return results
621 |
622 |
623 | def render_normal_and_color(
624 | results,
625 | sdf_network,
626 | color_network_dict,
627 | render_fn,
628 | is_training=False,
629 | max_num_pts=320000,
630 | ):
631 | """
632 | results: returned by raytrace_pixels function
633 |
634 | render interior and freespace pixels
635 | note: predicted color is black for freespace pixels
636 | """
637 | dots_sh = list(results["convergent_mask"].shape)
638 |
639 | merge_render_results = None
640 | for points_split, ray_d_split, ray_o_split, mask_split in zip(
641 | torch.split(results["points"].view(-1, 3), max_num_pts, dim=0),
642 | torch.split(results["ray_d"].view(-1, 3), max_num_pts, dim=0),
643 | torch.split(results["ray_o"].view(-1, 3), max_num_pts, dim=0),
644 | torch.split(results["convergent_mask"].view(-1), max_num_pts, dim=0),
645 | ):
646 | if mask_split.any():
647 | points_split, ray_d_split, ray_o_split = (
648 | points_split[mask_split],
649 | ray_d_split[mask_split],
650 | ray_o_split[mask_split],
651 | )
652 | sdf_split, feature_split, normal_split = sdf_network.get_all(points_split, is_training=is_training)
653 | if is_training:
654 | points_split = reparam_points(points_split, normal_split.detach(), -ray_d_split.detach(), sdf_split)
655 | # normal_split = normal_split / (normal_split.norm(dim=-1, keepdim=True) + 1e-10)
656 | else:
657 | points_split, ray_d_split, ray_o_split, normal_split, feature_split = (
658 | torch.Tensor([]).float().cuda(),
659 | torch.Tensor([]).float().cuda(),
660 | torch.Tensor([]).float().cuda(),
661 | torch.Tensor([]).float().cuda(),
662 | torch.Tensor([]).float().cuda(),
663 | )
664 |
665 | with torch.set_grad_enabled(is_training):
666 | render_results = render_fn(
667 | mask_split,
668 | color_network_dict,
669 | ray_o_split,
670 | ray_d_split,
671 | points_split,
672 | normal_split,
673 | feature_split,
674 | )
675 |
676 | if merge_render_results is None:
677 | merge_render_results = dict(
678 | [
679 | (
680 | x,
681 | [
682 | render_results[x],
683 | ],
684 | )
685 | for x in render_results.keys()
686 | ]
687 | )
688 | else:
689 | for x in render_results.keys():
690 | merge_render_results[x].append(render_results[x])
691 |
692 | for x in list(merge_render_results.keys()):
693 | tmp = torch.cat(merge_render_results[x], dim=0).reshape(
694 | dots_sh
695 | + [
696 | -1,
697 | ]
698 | )
699 | if tmp.shape[-1] == 1:
700 | tmp = tmp.squeeze(-1)
701 | merge_render_results[x] = tmp
702 |
703 | results.update(merge_render_results)
704 |
705 |
706 | def render_edge_pixels(
707 | results,
708 | camera,
709 | sdf_network,
710 | raytracer,
711 | color_network_dict,
712 | render_fn,
713 | is_training=False,
714 | ):
715 | edge_mask, edge_points, edge_uv, edge_pixel_idx = (
716 | results["edge_mask"],
717 | results["edge_points"],
718 | results["edge_uv"],
719 | results["edge_pixel_idx"],
720 | )
721 | edge_pixel_center = torch.floor(edge_uv) + 0.5
722 |
723 | edge_sdf, _, edge_grads = sdf_network.get_all(edge_points, is_training=is_training)
724 | edge_normals = edge_grads.detach() / (edge_grads.detach().norm(dim=-1, keepdim=True) + 1e-10)
725 | if is_training:
726 | edge_points = reparam_points(edge_points, edge_grads.detach(), edge_normals, edge_sdf)
727 | edge_uv = camera.project(edge_points)
728 |
729 | edge_normals2d = torch.matmul(edge_normals, camera.W2C[:3, :3].transpose(1, 0))[:, :2]
730 | edge_normals2d = edge_normals2d / (edge_normals2d.norm(dim=-1, keepdim=True) + 1e-10)
731 |
732 | # sample a point on both sides of the edge
733 | # approximate each pixel as a disk of radius 0.707 = sqrt(2)/2
734 | pixel_radius = 0.707
735 | pos_side_uv = edge_pixel_center - pixel_radius * edge_normals2d
736 | neg_side_uv = edge_pixel_center + pixel_radius * edge_normals2d
737 |
738 | dot2d = torch.sum((edge_uv - edge_pixel_center) * edge_normals2d, dim=-1)
739 | alpha = 2 * torch.arccos(torch.clamp(dot2d / pixel_radius, min=0.0, max=1.0))
740 | pos_side_weight = 1.0 - (alpha - torch.sin(alpha)) / (2.0 * np.pi)  # pixel-disk area fraction on the positive side (1 - circular-segment fraction)
741 |
742 | # render positive-side and negative-side colors by raytracing; speed up using edge mask
743 | pos_side_results = raytrace_pixels(sdf_network, raytracer, pos_side_uv, camera)
744 | neg_side_results = raytrace_pixels(sdf_network, raytracer, neg_side_uv, camera)
745 | render_normal_and_color(pos_side_results, sdf_network, color_network_dict, render_fn, is_training=is_training)
746 | render_normal_and_color(neg_side_results, sdf_network, color_network_dict, render_fn, is_training=is_training)
747 | # ic(pos_side_results.keys(), pos_side_results['convergent_mask'].sum())
748 |
749 | # assign colors to edge pixels
750 | edge_color = pos_side_results["color"] * pos_side_weight.unsqueeze(-1) + neg_side_results["color"] * (
751 | 1.0 - pos_side_weight.unsqueeze(-1)
752 | )
753 | # results["color"][edge_mask] = edge_color
754 | # results["normal"][edge_mask] = edge_normals
755 |
756 | results["color"].view(-1, 3)[edge_pixel_idx] = edge_color
757 | # results["normal"].view(-1, 3)[edge_pixel_idx] = edge_normals
758 | results["normal"].view(-1, 3)[edge_pixel_idx] = edge_grads
759 |
760 | results["edge_pos_neg_normal"] = torch.cat(
761 | [
762 | pos_side_results["normal"][pos_side_results["convergent_mask"]],
763 | neg_side_results["normal"][neg_side_results["convergent_mask"]],
764 | ],
765 | dim=0,
766 | )
767 | # debug
768 | # results["uv"][edge_mask] = edge_uv.detach()
769 | # results["points"][edge_mask] = edge_points.detach()
770 |
771 | results["uv"].view(-1, 2)[edge_pixel_idx] = edge_uv.detach()
772 | results["points"].view(-1, 3)[edge_pixel_idx] = edge_points.detach()
773 |
774 | if VERBOSE_MODE:
775 | pos_side_weight_fullsize = torch.zeros_like(edge_mask).float()
776 | # pos_side_weight_fullsize[edge_mask] = pos_side_weight
777 | pos_side_weight_fullsize.view(-1)[edge_pixel_idx] = pos_side_weight
778 |
779 | pos_side_depth = torch.zeros_like(edge_mask).float()
780 | # pos_side_depth[edge_mask] = pos_side_results["depth"]
781 | pos_side_depth.view(-1)[edge_pixel_idx] = pos_side_results["depth"]
782 | neg_side_depth = torch.zeros_like(edge_mask).float()
783 | # neg_side_depth[edge_mask] = neg_side_results["depth"]
784 | neg_side_depth.view(-1)[edge_pixel_idx] = neg_side_results["depth"]
785 |
786 | pos_side_color = (
787 | torch.zeros(
788 | list(edge_mask.shape)
789 | + [
790 | 3,
791 | ]
792 | )
793 | .float()
794 | .to(edge_mask.device)
795 | )
796 | # pos_side_color[edge_mask] = pos_side_results["color"]
797 | pos_side_color.view(-1, 3)[edge_pixel_idx] = pos_side_results["color"]
798 | neg_side_color = (
799 | torch.zeros(
800 | list(edge_mask.shape)
801 | + [
802 | 3,
803 | ]
804 | )
805 | .float()
806 | .to(edge_mask.device)
807 | )
808 | # neg_side_color[edge_mask] = neg_side_results["color"]
809 | neg_side_color.view(-1, 3)[edge_pixel_idx] = neg_side_results["color"]
810 | results.update(
811 | {
812 | "edge_pos_side_weight": pos_side_weight_fullsize,
813 | "edge_normals2d": edge_normals2d,
814 | "pos_side_uv": pos_side_uv,
815 | "neg_side_uv": neg_side_uv,
816 | "edge_pos_side_depth": pos_side_depth,
817 | "edge_neg_side_depth": neg_side_depth,
818 | "edge_pos_side_color": pos_side_color,
819 | "edge_neg_side_color": neg_side_color,
820 | }
821 | )
822 |
823 |
824 | def render_camera(
825 | camera,
826 | sdf_network,
827 | raytracer,
828 | color_network_dict,
829 | render_fn,
830 | fill_holes=True,
831 | handle_edges=True,
832 | is_training=False,
833 | ):
834 | results = raytrace_camera(
835 | camera,
836 | sdf_network,
837 | raytracer,
838 | max_num_rays=200000,
839 | fill_holes=fill_holes,
840 | detect_edges=handle_edges,
841 | )
842 | render_normal_and_color(
843 | results,
844 | sdf_network,
845 | color_network_dict,
846 | render_fn,
847 | is_training=is_training,
848 | max_num_pts=320000,
849 | )
850 | if handle_edges and results["edge_mask"].sum() > 0:
851 | render_edge_pixels(
852 | results,
853 | camera,
854 | sdf_network,
855 | raytracer,
856 | color_network_dict,
857 | render_fn,
858 | is_training=is_training,
859 | )
860 | return results
861 |
--------------------------------------------------------------------------------
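Usage note: raytrace_pixels bounds sphere tracing by each ray's entry and exit distances on a sphere of radius 1 around the object, obtained from intersect_sphere. A quick sanity check of that helper with a made-up ray placed 3 units from the origin and aimed at it:

    import torch
    from models.raytracer import intersect_sphere

    ray_o = torch.tensor([[0.0, 0.0, 3.0]])   # ray origin, 3 units from the object center
    ray_d = torch.tensor([[0.0, 0.0, -1.0]])  # unit direction pointing at the origin
    mask, near, far = intersect_sphere(ray_o, ray_d, r=1.0)
    # mask == [True], near == [2.0], far == [4.0]: the ray crosses the unit sphere between
    # distances 2 and 4, which become min_dis/max_dis for RayTracer.sphere_tracing.
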
/models/renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import logging
6 | import mcubes
7 | from icecream import ic
8 |
9 |
10 | def extract_fields(bound_min, bound_max, resolution, query_func):
11 | N = 64
12 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
13 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
14 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
15 |
16 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
17 | with torch.no_grad():
18 | for xi, xs in enumerate(X):
19 | for yi, ys in enumerate(Y):
20 | for zi, zs in enumerate(Z):
21 | xx, yy, zz = torch.meshgrid(xs, ys, zs)
22 | pts = torch.cat(
23 | [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
24 | dim=-1,
25 | )
26 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
27 | u[
28 | xi * N : xi * N + len(xs),
29 | yi * N : yi * N + len(ys),
30 | zi * N : zi * N + len(zs),
31 | ] = val
32 | return u
33 |
34 |
35 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
36 | print("threshold: {}".format(threshold))
37 | u = extract_fields(bound_min, bound_max, resolution, query_func)
38 | vertices, triangles = mcubes.marching_cubes(u, threshold)
39 | b_max_np = bound_max.detach().cpu().numpy()
40 | b_min_np = bound_min.detach().cpu().numpy()
41 |
42 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
43 | return vertices, triangles
44 |
45 |
46 | def sample_pdf(bins, weights, n_samples, det=False):
47 | # This implementation is from NeRF
48 | # Get pdf
49 | weights = weights + 1e-5 # prevent nans
50 | pdf = weights / torch.sum(weights, -1, keepdim=True)
51 | cdf = torch.cumsum(pdf, -1)
52 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
53 | # Take uniform samples
54 | if det:
55 | u = torch.linspace(0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples)
56 | u = u.expand(list(cdf.shape[:-1]) + [n_samples])
57 | else:
58 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
59 |
60 | # Invert CDF
61 | u = u.contiguous()
62 | inds = torch.searchsorted(cdf, u, right=True)
63 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
64 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
65 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
66 |
67 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
68 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
69 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
70 |
71 | denom = cdf_g[..., 1] - cdf_g[..., 0]
72 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
73 | t = (u - cdf_g[..., 0]) / denom
74 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
75 |
76 | return samples
77 |
78 |
79 | class NeuSRenderer:
80 | def __init__(
81 | self,
82 | nerf,
83 | sdf_network,
84 | deviation_network,
85 | color_network,
86 | n_samples,
87 | n_importance,
88 | n_outside,
89 | up_sample_steps,
90 | perturb,
91 | ):
92 | self.nerf = nerf
93 | self.sdf_network = sdf_network
94 | self.deviation_network = deviation_network
95 | self.color_network = color_network
96 | self.n_samples = n_samples
97 | self.n_importance = n_importance
98 | self.n_outside = n_outside
99 | self.up_sample_steps = up_sample_steps
100 | self.perturb = perturb
101 |
102 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None):
103 | """
104 | Render background
105 | """
106 | batch_size, n_samples = z_vals.shape
107 |
108 | # Section length
109 | dists = z_vals[..., 1:] - z_vals[..., :-1]
110 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
111 | mid_z_vals = z_vals + dists * 0.5
112 |
113 | # Section midpoints
114 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3
115 |
116 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
117 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4
118 |
119 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
120 |
121 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
122 | dirs = dirs.reshape(-1, 3)
123 |
124 | density, sampled_color = nerf(pts, dirs)
125 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
126 | alpha = alpha.reshape(batch_size, n_samples)
127 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1.0 - alpha + 1e-7], -1), -1)[:, :-1]
128 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
129 | color = (weights[:, :, None] * sampled_color).sum(dim=1)
130 | if background_rgb is not None:
131 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
132 |
133 | return {
134 | "color": color,
135 | "sampled_color": sampled_color,
136 | "alpha": alpha,
137 | "weights": weights,
138 | }
139 |
140 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
141 | """
142 | Up sampling given a fixed inv_s
143 | """
144 | batch_size, n_samples = z_vals.shape
145 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
146 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
147 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
148 | sdf = sdf.reshape(batch_size, n_samples)
149 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
150 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
151 | mid_sdf = (prev_sdf + next_sdf) * 0.5
152 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
153 |
154 | # ----------------------------------------------------------------------------------------------------------
155 | # Use min value of [ cos, prev_cos ]
156 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
157 | # robust when meeting situations like below:
158 | #
159 | # SDF
160 | # ^
161 | # |\ -----x----...
162 | # | \ /
163 | # | x x
164 | # |---\----/-------------> 0 level
165 | # | \ /
166 | # | \/
167 | # |
168 | # ----------------------------------------------------------------------------------------------------------
169 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
170 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
171 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
172 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
173 |
174 | dist = next_z_vals - prev_z_vals
175 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
176 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5
177 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
178 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
179 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
180 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1.0 - alpha + 1e-7], -1), -1)[:, :-1]
181 |
182 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
183 | return z_samples
184 |
185 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
186 | batch_size, n_samples = z_vals.shape
187 | _, n_importance = new_z_vals.shape
188 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
189 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
190 | z_vals, index = torch.sort(z_vals, dim=-1)
191 |
192 | if not last:
193 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
194 | sdf = torch.cat([sdf, new_sdf], dim=-1)
195 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
196 | index = index.reshape(-1)
197 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
198 |
199 | return z_vals, sdf
200 |
201 | def render_core(
202 | self,
203 | rays_o,
204 | rays_d,
205 | z_vals,
206 | sample_dist,
207 | sdf_network,
208 | deviation_network,
209 | color_network,
210 | background_alpha=None,
211 | background_sampled_color=None,
212 | background_rgb=None,
213 | cos_anneal_ratio=0.0,
214 | ):
215 | batch_size, n_samples = z_vals.shape
216 |
217 | # Section length
218 | dists = z_vals[..., 1:] - z_vals[..., :-1]
219 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
220 | mid_z_vals = z_vals + dists * 0.5
221 |
222 | # Section midpoints
223 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
224 | dirs = rays_d[:, None, :].expand(pts.shape)
225 |
226 | pts = pts.reshape(-1, 3)
227 | dirs = dirs.reshape(-1, 3)
228 |
229 | sdf_nn_output = sdf_network(pts)
230 | sdf = sdf_nn_output[:, :1]
231 | feature_vector = sdf_nn_output[:, 1:]
232 |
233 | gradients = sdf_network.gradient(pts)
234 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
235 |
236 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
237 | inv_s = inv_s.expand(batch_size * n_samples, 1)
238 |
239 | true_cos = (dirs * gradients).sum(-1, keepdim=True)
240 |
241 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
242 | # the cos value "not dead" at the beginning training iterations, for better convergence.
243 | iter_cos = -(
244 | F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + F.relu(-true_cos) * cos_anneal_ratio
245 | ) # always non-positive
246 |
247 | # Estimate signed distances at section points
248 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
249 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
250 |
251 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
252 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
253 |
254 | p = prev_cdf - next_cdf
255 | c = prev_cdf
256 |
257 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
258 |
259 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
260 | inside_sphere = (pts_norm < 1.0).float().detach()
261 | relax_inside_sphere = (pts_norm < 1.2).float().detach()
262 |
263 | # Render with background
264 | if background_alpha is not None:
265 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
266 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
267 | sampled_color = (
268 | sampled_color * inside_sphere[:, :, None]
269 | + background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
270 | )
271 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
272 |
273 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1.0 - alpha + 1e-7], -1), -1)[:, :-1]
274 | weights_sum = weights.sum(dim=-1, keepdim=True)
275 |
276 | color = (sampled_color * weights[:, :, None]).sum(dim=1)
277 | if background_rgb is not None: # Fixed background, usually black
278 | color = color + background_rgb * (1.0 - weights_sum)
279 |
280 | # Eikonal loss
281 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2, dim=-1) - 1.0) ** 2
282 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
283 |
284 | return {
285 | "color": color,
286 | "sdf": sdf,
287 | "dists": dists,
288 | "gradients": gradients.reshape(batch_size, n_samples, 3),
289 | "s_val": 1.0 / inv_s,
290 | "mid_z_vals": mid_z_vals,
291 | "weights": weights,
292 | "cdf": c.reshape(batch_size, n_samples),
293 | "gradient_error": gradient_error,
294 | "inside_sphere": inside_sphere,
295 | }
296 |
297 | def render(
298 | self,
299 | rays_o,
300 | rays_d,
301 | near,
302 | far,
303 | perturb_overwrite=-1,
304 | background_rgb=None,
305 | cos_anneal_ratio=0.0,
306 | ):
307 | batch_size = len(rays_o)
308 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
309 | z_vals = torch.linspace(0.0, 1.0, self.n_samples)
310 | z_vals = near + (far - near) * z_vals[None, :]
311 |
312 | z_vals_outside = None
313 | if self.n_outside > 0:
314 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
315 |
316 | n_samples = self.n_samples
317 | perturb = self.perturb
318 |
319 | if perturb_overwrite >= 0:
320 | perturb = perturb_overwrite
321 | if perturb > 0:
322 | t_rand = torch.rand([batch_size, 1]) - 0.5
323 | z_vals = z_vals + t_rand * 2.0 / self.n_samples
324 |
325 | if self.n_outside > 0:
326 | mids = 0.5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
327 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
328 | lower = torch.cat([z_vals_outside[..., :1], mids], -1)
329 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
330 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
331 |
332 | if self.n_outside > 0:
333 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
334 |
335 | background_alpha = None
336 | background_sampled_color = None
337 |
338 | # Up sample
339 | if self.n_importance > 0:
340 | with torch.no_grad():
341 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
342 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
343 |
344 | for i in range(self.up_sample_steps):
345 | new_z_vals = self.up_sample(
346 | rays_o,
347 | rays_d,
348 | z_vals,
349 | sdf,
350 | self.n_importance // self.up_sample_steps,
351 | 64 * 2**i,
352 | )
353 | z_vals, sdf = self.cat_z_vals(
354 | rays_o,
355 | rays_d,
356 | z_vals,
357 | new_z_vals,
358 | sdf,
359 | last=(i + 1 == self.up_sample_steps),
360 | )
361 |
362 | n_samples = self.n_samples + self.n_importance
363 |
364 | # Background model
365 | if self.n_outside > 0:
366 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
367 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
368 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)
369 |
370 | background_sampled_color = ret_outside["sampled_color"]
371 | background_alpha = ret_outside["alpha"]
372 |
373 | # Render core
374 | ret_fine = self.render_core(
375 | rays_o,
376 | rays_d,
377 | z_vals,
378 | sample_dist,
379 | self.sdf_network,
380 | self.deviation_network,
381 | self.color_network,
382 | background_rgb=background_rgb,
383 | background_alpha=background_alpha,
384 | background_sampled_color=background_sampled_color,
385 | cos_anneal_ratio=cos_anneal_ratio,
386 | )
387 |
388 | color_fine = ret_fine["color"]
389 | weights = ret_fine["weights"]
390 | weights_sum = weights.sum(dim=-1, keepdim=True)
391 | gradients = ret_fine["gradients"]
392 | s_val = ret_fine["s_val"].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
393 |
394 | return {
395 | "color_fine": color_fine,
396 | "s_val": s_val,
397 | "cdf_fine": ret_fine["cdf"],
398 | "weight_sum": weights_sum,
399 | "weight_max": torch.max(weights, dim=-1, keepdim=True)[0],
400 | "gradients": gradients,
401 | "weights": weights,
402 | "gradient_error": ret_fine["gradient_error"],
403 | "inside_sphere": ret_fine["inside_sphere"],
404 | }
405 |
406 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
407 | return extract_geometry(
408 | bound_min,
409 | bound_max,
410 | resolution=resolution,
411 | threshold=threshold,
412 | query_func=lambda pts: -self.sdf_network.sdf(pts),
413 | )
414 |
--------------------------------------------------------------------------------
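Usage note: sample_pdf is the standard NeRF inverse-CDF sampler that up_sample relies on; new depths are drawn where the previous pass placed high weight. A small sketch on a single ray with made-up bins and weights:

    import torch
    from models.renderer import sample_pdf

    bins = torch.linspace(0.0, 1.0, steps=5).unsqueeze(0)  # 1 ray, 4 depth intervals
    weights = torch.tensor([[0.1, 0.1, 0.7, 0.1]])         # most mass in the third interval
    new_z = sample_pdf(bins, weights, n_samples=8, det=True)
    # 5 of the 8 new depths fall inside [0.5, 0.75], the interval carrying weight 0.7
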
/models/renderer_ggx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 |
6 |
7 | ### https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/microfacet.h#L477
8 | def smithG1(cosTheta, alpha):
9 | sinTheta = torch.sqrt(1.0 - cosTheta * cosTheta)
10 | tanTheta = sinTheta / (cosTheta + 1e-10)
11 | root = alpha * tanTheta
12 | return 2.0 / (1.0 + torch.hypot(root, torch.ones_like(root)))
13 |
14 |
15 | class GGXColocatedRenderer(nn.Module):
16 | def __init__(self, use_cuda=False):
17 | super().__init__()
18 |
19 | self.MTS_TRANS = torch.from_numpy(
20 | np.loadtxt(os.path.join(os.path.dirname(os.path.abspath(__file__)), "ggx/ext_mts_rtrans_data.txt")).astype(
21 | np.float32
22 | )
23 | ) # 5000 entries, external IOR
24 | self.MTS_DIFF_TRANS = torch.from_numpy(
25 | np.loadtxt(
26 | os.path.join(os.path.dirname(os.path.abspath(__file__)), "ggx/int_mts_diff_rtrans_data.txt")
27 | ).astype(np.float32)
28 | ) # 50 entries, internal IOR
29 | self.num_theta_samples = 100
30 | self.num_alpha_samples = 50
31 |
32 | if use_cuda:
33 | self.MTS_TRANS = self.MTS_TRANS.cuda()
34 | self.MTS_DIFF_TRANS = self.MTS_DIFF_TRANS.cuda()
35 |
36 | def forward(self, light, distance, normal, viewdir, diffuse_albedo, specular_albedo, alpha):
37 | """
38 | light: scalar (or broadcastable) intensity of the colocated point light
39 | distance: [..., 1]
40 | normal, viewdir: [..., 3]; both normal and viewdir point away from objects
41 | diffuse_albedo, specular_albedo: [..., 3]
42 | alpha: [..., 1]; roughness
43 | """
44 | # decay light according to squared-distance falloff
45 | light_intensity = light / (distance * distance + 1e-10)
46 |
47 | # in the colocated setting the light, view, and half-vector directions coincide, so <wi,n> = <wo,n> = <h,n>
48 | dot = torch.sum(viewdir * normal, dim=-1, keepdims=True)
49 | dot = torch.clamp(dot, min=0.00001, max=0.99999) # must be very precise; cannot be 0.999
50 | # default value of IOR['polypropylene'] / IOR['air'].
51 | m_eta = 1.48958738
52 | m_invEta2 = 1.0 / (m_eta * m_eta)
53 |
54 | # clamp alpha for numeric stability
55 | alpha = torch.clamp(alpha, min=0.0001)
56 |
57 | # specular term: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/roughplastic.cpp#L347
58 | ## compute GGX NDF: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/microfacet.h#L191
59 | cosTheta2 = dot * dot
60 | root = cosTheta2 + (1.0 - cosTheta2) / (alpha * alpha + 1e-10)
61 | D = 1.0 / (np.pi * alpha * alpha * root * root + 1e-10)
62 | ## compute fresnel: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/libcore/util.cpp#L651
63 | # F = 0.04
64 | F = 0.03867
65 |
66 | ## compute shadowing term: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/microfacet.h#L520
67 | G = smithG1(dot, alpha) ** 2 # [..., 1]
68 |
69 | specular_rgb = light_intensity * specular_albedo * F * D * G / (4.0 * dot + 1e-10)
70 |
71 | # diffuse term: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/roughplastic.cpp#L367
72 | ## compute T12: : https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/rtrans.h#L183
73 | ### data_file: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/rtrans.h#L93
74 | ### assume eta is fixed
75 | warpedCosTheta = dot**0.25
76 | alphaMin, alphaMax = 0, 4
77 | warpedAlpha = ((alpha - alphaMin) / (alphaMax - alphaMin)) ** 0.25 # [..., 1]
78 | tx = torch.floor(warpedCosTheta * self.num_theta_samples).long()
79 | ty = torch.floor(warpedAlpha * self.num_alpha_samples).long()
80 | t_idx = ty * self.num_theta_samples + tx
81 |
82 | dots_sh = list(t_idx.shape[:-1])
83 | data = self.MTS_TRANS.view([1,] * len(dots_sh) + [-1,]).expand(
84 | dots_sh
85 | + [
86 | -1,
87 | ]
88 | )
89 |
90 | t_idx = torch.clamp(t_idx, min=0, max=data.shape[-1] - 1).long() # important
91 | T12 = torch.clamp(torch.gather(input=data, index=t_idx, dim=-1), min=0.0, max=1.0)
92 | T21 = T12 # colocated setting
93 |
94 | ## compute Fdr: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/rtrans.h#L249
95 | t_idx = torch.floor(warpedAlpha * self.num_alpha_samples).long()
96 | data = self.MTS_DIFF_TRANS.view([1,] * len(dots_sh) + [-1,]).expand(
97 | dots_sh
98 | + [
99 | -1,
100 | ]
101 | )
102 | t_idx = torch.clamp(t_idx, min=0, max=data.shape[-1] - 1).long() # important
103 | Fdr = torch.clamp(1.0 - torch.gather(input=data, index=t_idx, dim=-1), min=0.0, max=1.0) # [..., 1]
104 |
105 | diffuse_rgb = light_intensity * (diffuse_albedo / (1.0 - Fdr + 1e-10) / np.pi) * dot * T12 * T21 * m_invEta2
106 | ret = {"diffuse_rgb": diffuse_rgb, "specular_rgb": specular_rgb, "rgb": diffuse_rgb + specular_rgb}
107 | return ret
108 |
--------------------------------------------------------------------------------
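Usage note: a minimal sketch driving the colocated GGX shading model with constant, made-up materials; in render_surface.py the same call is fed by the albedo/roughness networks and the learned point-light intensity:

    import torch
    from models.renderer_ggx import GGXColocatedRenderer

    renderer = GGXColocatedRenderer(use_cuda=False)
    n = 4
    normal = torch.tensor([[0.0, 0.0, 1.0]]).expand(n, 3)
    viewdir = normal.clone()                        # colocated flash: light direction == view direction
    out = renderer(
        light=torch.tensor(30.0),                   # flash intensity (cf. PointLightNetwork)
        distance=torch.full((n, 1), 2.0),
        normal=normal,
        viewdir=viewdir,
        diffuse_albedo=torch.full((n, 3), 0.5),
        specular_albedo=torch.full((n, 3), 0.04),
        alpha=torch.full((n, 1), 0.3),              # roughness
    )
    rgb = out["rgb"]                                # [n, 3] = diffuse_rgb + specular_rgb
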
/readme_resources/assets_lowres.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/readme_resources/assets_lowres.png
--------------------------------------------------------------------------------
/readme_resources/inputs_outputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/readme_resources/inputs_outputs.png
--------------------------------------------------------------------------------
/render_surface.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import json
7 | import imageio
8 | imageio.plugins.freeimage.download()
9 | from torch.utils.tensorboard import SummaryWriter
10 | import configargparse
11 | from icecream import ic
12 | import glob
13 | import shutil
14 | import traceback
15 |
16 | from models.fields import SDFNetwork, RenderingNetwork
17 | from models.raytracer import RayTracer, Camera, render_camera
18 | from models.renderer_ggx import GGXColocatedRenderer
19 | from models.image_losses import PyramidL2Loss, ssim_loss_fn
20 | from models.export_mesh import export_mesh
21 | from models.export_materials import export_materials
22 |
23 | ###### arguments
24 | def config_parser():
25 | parser = configargparse.ArgumentParser()
26 | parser.add_argument("--data_dir", type=str, default=None, help="input data directory")
27 | parser.add_argument("--out_dir", type=str, default=None, help="output directory")
28 | parser.add_argument("--neus_ckpt_fpath", type=str, default=None, help="checkpoint to load")
29 | parser.add_argument("--num_iters", type=int, default=100001, help="number of iterations")
30 | parser.add_argument("--patch_size", type=int, default=128, help="width and height of the rendered patches")
31 | parser.add_argument("--eik_weight", type=float, default=0.1, help="weight for eikonal loss")
32 | parser.add_argument("--ssim_weight", type=float, default=1.0, help="weight for ssim loss")
33 | parser.add_argument("--roughrange_weight", type=float, default=0.1, help="weight for roughness range loss")
34 |
35 | parser.add_argument("--plot_image_name", type=str, default=None, help="image to plot during training")
36 | parser.add_argument("--no_edgesample", action="store_true", help="whether to disable edge sampling")
37 | parser.add_argument(
38 | "--inv_gamma_gt", action="store_true", help="whether to inverse gamma correct the ground-truth photos"
39 | )
40 | parser.add_argument("--gamma_pred", action="store_true", help="whether to gamma correct the predictions")
41 | parser.add_argument(
42 | "--is_metal",
43 | action="store_true",
44 | help="whether the object of interest is made of metals or the scene contains metals",
45 | )
46 | parser.add_argument("--init_light_scale", type=float, default=8.0, help="scaling parameters for light")
47 | parser.add_argument(
48 | "--export_all",
49 | action="store_true",
50 | help="whether to export meshes and uv textures",
51 | )
52 | parser.add_argument(
53 | "--render_all",
54 | action="store_true",
55 | help="whether to render the input image set",
56 | )
57 | return parser
58 |
59 |
60 | parser = config_parser()
61 | args = parser.parse_args()
62 | ic(args)
63 |
64 | ###### back up arguments and code scripts
65 | os.makedirs(args.out_dir, exist_ok=True)
66 | parser.write_config_file(
67 | args,
68 | [
69 | os.path.join(args.out_dir, "args.txt"),
70 | ],
71 | )
72 |
73 |
74 | ###### rendering functions
75 | def get_materials(color_network_dict, points, normals, features, is_metal=args.is_metal):
76 | diffuse_albedo = color_network_dict["diffuse_albedo_network"](points, normals, -normals, features).abs()[
77 | ..., [2, 1, 0]
78 | ]
79 | specular_albedo = color_network_dict["specular_albedo_network"](points, normals, None, features).abs()
80 | if not is_metal:
81 | specular_albedo = torch.mean(specular_albedo, dim=-1, keepdim=True).expand_as(specular_albedo)
82 | specular_roughness = color_network_dict["specular_roughness_network"](points, normals, None, features).abs() + 0.01
83 | return diffuse_albedo, specular_albedo, specular_roughness
84 |
85 |
86 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features):
87 | dots_sh = list(interior_mask.shape)
88 | rgb = torch.zeros(
89 | dots_sh
90 | + [
91 | 3,
92 | ],
93 | dtype=torch.float32,
94 | device=interior_mask.device,
95 | )
96 | diffuse_rgb = rgb.clone()
97 | specular_rgb = rgb.clone()
98 | diffuse_albedo = rgb.clone()
99 | specular_albedo = rgb.clone()
100 | specular_roughness = rgb[..., 0].clone()
101 | normals_pad = rgb.clone()
102 | if interior_mask.any():
103 | normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-10)
104 | interior_diffuse_albedo, interior_specular_albedo, interior_specular_roughness = get_materials(
105 | color_network_dict, points, normals, features
106 | )
107 | results = ggx_renderer(
108 | color_network_dict["point_light_network"](),
109 | (points - ray_o).norm(dim=-1, keepdim=True),
110 | normals,
111 | -ray_d,
112 | interior_diffuse_albedo,
113 | interior_specular_albedo,
114 | interior_specular_roughness,
115 | )
116 | rgb[interior_mask] = results["rgb"]
117 | diffuse_rgb[interior_mask] = results["diffuse_rgb"]
118 | specular_rgb[interior_mask] = results["specular_rgb"]
119 | diffuse_albedo[interior_mask] = interior_diffuse_albedo
120 | specular_albedo[interior_mask] = interior_specular_albedo
121 | specular_roughness[interior_mask] = interior_specular_roughness.squeeze(-1)
122 | normals_pad[interior_mask] = normals
123 |
124 | return {
125 | "color": rgb,
126 | "diffuse_color": diffuse_rgb,
127 | "specular_color": specular_rgb,
128 | "diffuse_albedo": diffuse_albedo,
129 | "specular_albedo": specular_albedo,
130 | "specular_roughness": specular_roughness,
131 | "normal": normals_pad,
132 | }
133 |
134 |
135 | ###### network specifications
136 | sdf_network = SDFNetwork(
137 | d_in=3,
138 | d_out=257,
139 | d_hidden=256,
140 | n_layers=8,
141 | skip_in=[
142 | 4,
143 | ],
144 | multires=6,
145 | bias=0.5,
146 | scale=1.0,
147 | geometric_init=True,
148 | weight_norm=True,
149 | ).cuda()
150 | raytracer = RayTracer()
151 |
152 |
153 | class PointLightNetwork(nn.Module):
154 | def __init__(self):
155 | super().__init__()
156 | self.register_parameter("light", nn.Parameter(torch.tensor(5.0)))
157 |
158 | def forward(self):
159 | return self.light
160 |
161 | def set_light(self, light):
162 | self.light.data.fill_(light)
163 |
164 | def get_light(self):
165 | return self.light.data.clone().detach()
166 |
167 |
168 | color_network_dict = {
169 | "color_network": RenderingNetwork(
170 | d_in=9,
171 | d_out=3,
172 | d_feature=256,
173 | d_hidden=256,
174 | n_layers=4,
175 | multires_view=4,
176 | mode="idr",
177 | squeeze_out=True,
178 | ).cuda(),
179 | "diffuse_albedo_network": RenderingNetwork(
180 | d_in=9,
181 | d_out=3,
182 | d_feature=256,
183 | d_hidden=256,
184 | n_layers=8,
185 | multires=10,
186 | multires_view=4,
187 | mode="idr",
188 | squeeze_out=True,
189 | skip_in=(4,),
190 | ).cuda(),
191 | "specular_albedo_network": RenderingNetwork(
192 | d_in=6,
193 | d_out=3,
194 | d_feature=256,
195 | d_hidden=256,
196 | n_layers=4,
197 | multires=6,
198 | multires_view=-1,
199 | mode="no_view_dir",
200 | squeeze_out=False,
201 | output_bias=0.4,
202 | output_scale=0.1,
203 | ).cuda(),
204 | "specular_roughness_network": RenderingNetwork(
205 | d_in=6,
206 | d_out=1,
207 | d_feature=256,
208 | d_hidden=256,
209 | n_layers=4,
210 | multires=6,
211 | multires_view=-1,
212 | mode="no_view_dir",
213 | squeeze_out=False,
214 | output_bias=0.1,
215 | output_scale=0.1,
216 | ).cuda(),
217 | "point_light_network": PointLightNetwork().cuda(),
218 | }
219 |
220 | ###### optimizer specifications
221 | sdf_optimizer = torch.optim.Adam(sdf_network.parameters(), lr=1e-5)
222 | color_optimizer_dict = {
223 | "color_network": torch.optim.Adam(color_network_dict["color_network"].parameters(), lr=1e-4),
224 | "diffuse_albedo_network": torch.optim.Adam(color_network_dict["diffuse_albedo_network"].parameters(), lr=1e-4),
225 | "specular_albedo_network": torch.optim.Adam(color_network_dict["specular_albedo_network"].parameters(), lr=1e-4),
226 | "specular_roughness_network": torch.optim.Adam(
227 | color_network_dict["specular_roughness_network"].parameters(), lr=1e-4
228 | ),
229 | "point_light_network": torch.optim.Adam(color_network_dict["point_light_network"].parameters(), lr=1e-2),
230 | }
231 |
232 | ###### loss specifications
233 | ggx_renderer = GGXColocatedRenderer(use_cuda=True)
234 | pyramidl2_loss_fn = PyramidL2Loss(use_cuda=True)
235 |
236 | ###### load dataset
237 | def to8b(x):
238 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8)
239 |
240 |
241 | def load_datadir(datadir):
242 | cam_dict = json.load(open(os.path.join(datadir, "cam_dict_norm.json")))
243 | imgnames = list(cam_dict.keys())
244 | try:
245 | imgnames = sorted(imgnames, key=lambda x: int(x[:-4]))
246 |     except ValueError:  # non-numeric file names fall back to plain lexicographic sorting
247 | imgnames = sorted(imgnames)
248 |
249 | image_fpaths = []
250 | gt_images = []
251 | Ks = []
252 | W2Cs = []
253 | for x in imgnames:
254 | fpath = os.path.join(datadir, "image", x)
255 | assert fpath[-4:] in [".jpg", ".png"], "must use ldr images as inputs"
256 | im = imageio.imread(fpath).astype(np.float32) / 255.0
257 | K = np.array(cam_dict[x]["K"]).reshape((4, 4)).astype(np.float32)
258 | W2C = np.array(cam_dict[x]["W2C"]).reshape((4, 4)).astype(np.float32)
259 |
260 | image_fpaths.append(fpath)
261 | gt_images.append(torch.from_numpy(im))
262 | Ks.append(torch.from_numpy(K))
263 | W2Cs.append(torch.from_numpy(W2C))
264 | gt_images = torch.stack(gt_images, dim=0)
265 | Ks = torch.stack(Ks, dim=0)
266 | W2Cs = torch.stack(W2Cs, dim=0)
267 | return image_fpaths, gt_images, Ks, W2Cs
268 |
269 |
270 | image_fpaths, gt_images, Ks, W2Cs = load_datadir(args.data_dir)
271 | cameras = [
272 | Camera(W=gt_images[i].shape[1], H=gt_images[i].shape[0], K=Ks[i].cuda(), W2C=W2Cs[i].cuda())
273 | for i in range(gt_images.shape[0])
274 | ]
275 | ic(len(image_fpaths), gt_images.shape, Ks.shape, W2Cs.shape, len(cameras))
276 |
277 | ###### initialization using neus
278 | ic(args.neus_ckpt_fpath)
279 | if os.path.isfile(args.neus_ckpt_fpath):
280 | ic(f"Loading from neus checkpoint: {args.neus_ckpt_fpath}")
281 | ckpt = torch.load(args.neus_ckpt_fpath, map_location=torch.device("cuda"))
282 | try:
283 | sdf_network.load_state_dict(ckpt["sdf_network_fine"])
284 | color_network_dict["diffuse_albedo_network"].load_state_dict(ckpt["color_network_fine"])
285 |     except Exception:
286 | traceback.print_exc()
287 | # ic("Failed to initialize diffuse_albedo_network from checkpoint: ", args.neus_ckpt_fpath)
288 | dist = np.median([torch.norm(cameras[i].get_camera_origin()).item() for i in range(len(cameras))])
289 |     init_light = args.init_light_scale * dist * dist  # scale with squared camera distance to offset the co-located light's inverse-square falloff
290 | color_network_dict["point_light_network"].set_light(init_light)
291 |
292 | #### load pretrained checkpoints
293 | start_step = -1
294 | ckpt_fpaths = glob.glob(os.path.join(args.out_dir, "ckpt_*.pth"))
295 | if len(ckpt_fpaths) > 0:
296 | path2step = lambda x: int(os.path.basename(x)[len("ckpt_") : -4])
297 | ckpt_fpaths = sorted(ckpt_fpaths, key=path2step)
298 | ckpt_fpath = ckpt_fpaths[-1]
299 | start_step = path2step(ckpt_fpath)
300 | ic("Reloading from checkpoint: ", ckpt_fpath)
301 | ckpt = torch.load(ckpt_fpath, map_location=torch.device("cuda"))
302 | sdf_network.load_state_dict(ckpt["sdf_network"])
303 | for x in list(color_network_dict.keys()):
304 | color_network_dict[x].load_state_dict(ckpt[x])
305 | # logim_names = [os.path.basename(x) for x in glob.glob(os.path.join(args.out_dir, "logim_*.png"))]
306 | # start_step = sorted([int(x[len("logim_") : -4]) for x in logim_names])[-1]
307 | ic(dist, color_network_dict["point_light_network"].light.data)  # note: dist is only defined when the NeuS checkpoint above was found
308 | ic(start_step)
309 |
310 |
311 | ###### export mesh and materials
312 | blender_fpath = "./blender-3.1.0-linux-x64/blender"
313 | if not os.path.isfile(blender_fpath):
314 | os.system(
315 | "wget https://mirror.clarkson.edu/blender/release/Blender3.1/blender-3.1.0-linux-x64.tar.xz && \
316 | tar -xvf blender-3.1.0-linux-x64.tar.xz"
317 | )
318 |
319 |
320 | def export_mesh_and_materials(export_out_dir, sdf_network, color_network_dict):
321 | ic(f"Exporting mesh and materials to: {export_out_dir}")
322 | sdf_fn = lambda x: sdf_network(x)[..., 0]
323 | ic("Exporting mesh and uv...")
324 | with torch.no_grad():
325 | export_mesh(sdf_fn, os.path.join(export_out_dir, "mesh.obj"))
326 | os.system(
327 | f"{blender_fpath} --background --python models/export_uv.py {os.path.join(export_out_dir, 'mesh.obj')} {os.path.join(export_out_dir, 'mesh.obj')}"
328 | )
329 |
330 | class MaterialPredictor(nn.Module):
331 | def __init__(self, sdf_network, color_network_dict):
332 | super().__init__()
333 | self.sdf_network = sdf_network
334 | self.color_network_dict = color_network_dict
335 |
336 | def forward(self, points):
337 | _, features, normals = self.sdf_network.get_all(points, is_training=False)
338 | normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-10)
339 | diffuse_albedo, specular_albedo, specular_roughness = get_materials(
340 | color_network_dict, points, normals, features
341 | )
342 | return diffuse_albedo, specular_albedo, specular_roughness
343 |
344 | ic("Exporting materials...")
345 | material_predictor = MaterialPredictor(sdf_network, color_network_dict)
346 | with torch.no_grad():
347 | export_materials(os.path.join(export_out_dir, "mesh.obj"), material_predictor, export_out_dir)
348 |
349 | ic(f"Exported mesh and materials to: {export_out_dir}")
350 |
351 |
352 | if args.export_all:
353 | export_out_dir = os.path.join(args.out_dir, f"mesh_and_materials_{start_step}")
354 | os.makedirs(export_out_dir, exist_ok=True)
355 | export_mesh_and_materials(export_out_dir, sdf_network, color_network_dict)
356 | exit(0)
357 |
358 |
359 | ###### render all images
360 | if args.render_all:
361 | render_out_dir = os.path.join(args.out_dir, f"render_{os.path.basename(args.data_dir)}_{start_step}")
362 | os.makedirs(render_out_dir, exist_ok=True)
363 | ic(f"Rendering images to: {render_out_dir}")
364 | n_cams = len(cameras)
365 | for i in tqdm.tqdm(range(n_cams)):
366 | cam, impath = cameras[i], image_fpaths[i]
367 | results = render_camera(
368 | cam,
369 | sdf_network,
370 | raytracer,
371 | color_network_dict,
372 | render_fn,
373 | fill_holes=True,
374 | handle_edges=True,
375 | is_training=False,
376 | )
377 | if args.gamma_pred:
378 | results["color"] = torch.pow(results["color"] + 1e-6, 1.0 / 2.2)
379 | for x in list(results.keys()):
380 | results[x] = results[x].detach().cpu().numpy()
381 | color_im = results["color"]
382 | imageio.imwrite(os.path.join(render_out_dir, os.path.basename(impath)), to8b(color_im))
383 | exit(0)
384 |
385 | ###### training
386 | fill_holes = False
387 | handle_edges = not args.no_edgesample
388 | is_training = True
389 | if args.inv_gamma_gt:
390 | ic("linearizing ground-truth images using inverse gamma correction")
391 | gt_images = torch.pow(gt_images, 2.2)
392 |
393 | ic(fill_holes, handle_edges, is_training, args.inv_gamma_gt)
394 | writer = SummaryWriter(log_dir=os.path.join(args.out_dir, "logs"))
395 |
396 | for global_step in tqdm.tqdm(range(start_step + 1, args.num_iters)):
397 | sdf_optimizer.zero_grad()
398 | for x in color_optimizer_dict.keys():
399 | color_optimizer_dict[x].zero_grad()
400 |
401 | idx = np.random.randint(0, gt_images.shape[0])
402 | camera_crop, gt_color_crop = cameras[idx].crop_region(
403 | trgt_W=args.patch_size, trgt_H=args.patch_size, image=gt_images[idx]
404 | )
405 |
406 | results = render_camera(
407 | camera_crop,
408 | sdf_network,
409 | raytracer,
410 | color_network_dict,
411 | render_fn,
412 | fill_holes=fill_holes,
413 | handle_edges=handle_edges,
414 | is_training=is_training,
415 | )
416 | if args.gamma_pred:
417 | results["color"] = torch.pow(results["color"] + 1e-6, 1.0 / 2.2)
418 | results["diffuse_color"] = torch.pow(results["diffuse_color"] + 1e-6, 1.0 / 2.2)
419 | results["specular_color"] = torch.clamp(results["color"] - results["diffuse_color"], min=0.0)
420 |
421 | mask = results["convergent_mask"]
422 | if handle_edges:
423 | mask = mask | results["edge_mask"]
424 |
425 | img_loss = torch.Tensor([0.0]).cuda()
426 | img_l2_loss = torch.Tensor([0.0]).cuda()
427 | img_ssim_loss = torch.Tensor([0.0]).cuda()
428 | roughrange_loss = torch.Tensor([0.0]).cuda()
429 |
430 | eik_points = torch.empty(camera_crop.H * camera_crop.W // 2, 3).cuda().float().uniform_(-1.0, 1.0)
431 | eik_grad = sdf_network.gradient(eik_points).view(-1, 3)
432 | eik_cnt = eik_grad.shape[0]
433 | eik_loss = ((eik_grad.norm(dim=-1) - 1) ** 2).sum()
434 | if mask.any():
435 | pred_img = results["color"].permute(2, 0, 1).unsqueeze(0)
436 | gt_img = gt_color_crop.permute(2, 0, 1).unsqueeze(0).to(pred_img.device)
437 | img_l2_loss = pyramidl2_loss_fn(pred_img, gt_img)
438 | img_ssim_loss = args.ssim_weight * ssim_loss_fn(pred_img, gt_img, mask.unsqueeze(0).unsqueeze(0))
439 | img_loss = img_l2_loss + img_ssim_loss
440 |
441 | eik_grad = results["normal"][mask]
442 | eik_cnt += eik_grad.shape[0]
443 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum()
444 | if "edge_pos_neg_normal" in results:
445 | eik_grad = results["edge_pos_neg_normal"]
446 | eik_cnt += eik_grad.shape[0]
447 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum()
448 |
449 | roughness = results["specular_roughness"][mask]
450 | roughness = roughness[roughness > 0.5]
451 | if roughness.numel() > 0:
452 | roughrange_loss = (roughness - 0.5).mean() * args.roughrange_weight
453 | eik_loss = eik_loss / eik_cnt * args.eik_weight
454 |
455 | loss = img_loss + eik_loss + roughrange_loss
456 | loss.backward()
457 | sdf_optimizer.step()
458 | for x in color_optimizer_dict.keys():
459 | color_optimizer_dict[x].step()
460 |
461 | if global_step % 50 == 0:
462 | writer.add_scalar("loss/loss", loss, global_step)
463 | writer.add_scalar("loss/img_loss", img_loss, global_step)
464 | writer.add_scalar("loss/img_l2_loss", img_l2_loss, global_step)
465 | writer.add_scalar("loss/img_ssim_loss", img_ssim_loss, global_step)
466 | writer.add_scalar("loss/eik_loss", eik_loss, global_step)
467 | writer.add_scalar("loss/roughrange_loss", roughrange_loss, global_step)
468 |         writer.add_scalar("light", color_network_dict["point_light_network"].get_light(), global_step)
469 |
470 | if global_step % 1000 == 0:
471 | torch.save(
472 | dict(
473 | [
474 | ("sdf_network", sdf_network.state_dict()),
475 | ]
476 | + [(x, color_network_dict[x].state_dict()) for x in color_network_dict.keys()]
477 | ),
478 | os.path.join(args.out_dir, f"ckpt_{global_step}.pth"),
479 | )
480 |
481 | if global_step % 500 == 0:
482 | ic(
483 | args.out_dir,
484 | global_step,
485 | loss.item(),
486 | img_loss.item(),
487 | img_l2_loss.item(),
488 | img_ssim_loss.item(),
489 | eik_loss.item(),
490 | roughrange_loss.item(),
491 | color_network_dict["point_light_network"].get_light().item(),
492 | )
493 |
494 | for x in list(results.keys()):
495 | del results[x]
496 |
497 | idx = 0
498 | if args.plot_image_name is not None:
499 | while idx < len(image_fpaths):
500 | if args.plot_image_name in image_fpaths[idx]:
501 | break
502 | idx += 1
503 |
504 | camera_resize, gt_color_resize = cameras[idx].resize(factor=0.25, image=gt_images[idx])
505 | results = render_camera(
506 | camera_resize,
507 | sdf_network,
508 | raytracer,
509 | color_network_dict,
510 | render_fn,
511 | fill_holes=fill_holes,
512 | handle_edges=handle_edges,
513 | is_training=False,
514 | )
515 | if args.gamma_pred:
516 | results["color"] = torch.pow(results["color"] + 1e-6, 1.0 / 2.2)
517 | results["diffuse_color"] = torch.pow(results["diffuse_color"] + 1e-6, 1.0 / 2.2)
518 | results["specular_color"] = torch.clamp(results["color"] - results["diffuse_color"], min=0.0)
519 | for x in list(results.keys()):
520 | results[x] = results[x].detach().cpu().numpy()
521 |
522 | gt_color_im = gt_color_resize.detach().cpu().numpy()
523 | color_im = results["color"]
524 | diffuse_color_im = results["diffuse_color"]
525 | specular_color_im = results["specular_color"]
526 | normal = results["normal"]
527 | normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10)
528 | normal_im = (normal + 1.0) / 2.0
529 | edge_mask_im = np.tile(results["edge_mask"][:, :, np.newaxis], (1, 1, 3))
530 | diffuse_albedo_im = results["diffuse_albedo"]
531 | specular_albedo_im = results["specular_albedo"]
532 | specular_roughness_im = np.tile(results["specular_roughness"][:, :, np.newaxis], (1, 1, 3))
533 | if args.inv_gamma_gt:
534 | gt_color_im = np.power(gt_color_im + 1e-6, 1.0 / 2.2)
535 | color_im = np.power(color_im + 1e-6, 1.0 / 2.2)
536 | diffuse_color_im = np.power(diffuse_color_im + 1e-6, 1.0 / 2.2)
537 | specular_color_im = color_im - diffuse_color_im
538 |
539 | row1 = np.concatenate([gt_color_im, normal_im, edge_mask_im], axis=1)
540 | row2 = np.concatenate([color_im, diffuse_color_im, specular_color_im], axis=1)
541 | row3 = np.concatenate([diffuse_albedo_im, specular_albedo_im, specular_roughness_im], axis=1)
542 | im = np.concatenate((row1, row2, row3), axis=0)
543 | imageio.imwrite(os.path.join(args.out_dir, f"logim_{global_step}.png"), to8b(im))
544 |
545 |
546 | ###### export mesh and materials
547 | export_out_dir = os.path.join(args.out_dir, f"mesh_and_materials_{global_step}")
548 | os.makedirs(export_out_dir, exist_ok=True)
549 | export_mesh_and_materials(export_out_dir, sdf_network, color_network_dict)
550 |
--------------------------------------------------------------------------------
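For reference, the two regularizers used in the training loop of render_surface.py, the eikonal term and the roughness-range penalty, written as standalone functions. This is a minimal sketch with illustrative names; it assumes sdf_network.gradient behaves as it does in the script above.

    import torch

    def eikonal_loss(sdf_network, n_points=4096, bound=1.0):
        # Sample random points in the normalized scene cube and push |grad SDF| toward 1.
        pts = torch.empty(n_points, 3, device="cuda").uniform_(-bound, bound)
        grad = sdf_network.gradient(pts).view(-1, 3)
        return ((grad.norm(dim=-1) - 1.0) ** 2).mean()

    def roughness_range_loss(roughness, threshold=0.5):
        # Penalize only the roughness predictions that exceed the threshold,
        # mirroring the clamp-style penalty in the training loop above.
        high = roughness[roughness > threshold]
        if high.numel() == 0:
            return roughness.new_zeros(())
        return (high - threshold).mean()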
/render_synthetic_data/render_rgb_flash_mat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import shutil
5 | import imageio
6 |
7 | imageio.plugins.freeimage.download()
8 |
9 |
10 | asset_dir = 'path/to/synthetic_assets'
11 | out_dir = 'path/to/output_folder'
12 |
13 | for scene in os.listdir(asset_dir):
14 | in_scene_dir = os.path.join(asset_dir, scene)
15 | out_scene_dir = os.path.join(out_dir, scene)
16 | os.makedirs(out_scene_dir, exist_ok=True)
17 |
18 | light = 20.
19 | with open(os.path.join(out_scene_dir, 'light.txt'), 'w') as fp:
20 | fp.write(f'{light}\n')
21 |
22 | for split in ['train', 'test']:
23 | out_split_dir = os.path.join(out_scene_dir, split)
24 | os.makedirs(os.path.join(out_split_dir, 'image'), exist_ok=True)
25 |
26 | cam_dict_fpath = os.path.join(asset_dir, f'{split}_cam_dict_norm.json')
27 | shutil.copy2(cam_dict_fpath, os.path.join(out_split_dir, 'cam_dict_norm.json'))
28 |
29 | cam_dict = json.load(open(cam_dict_fpath))
30 | img_list = list(cam_dict.keys())
31 | img_list = sorted(img_list, key=lambda x: int(x[:-4]))
32 |
33 | use_docker = True
34 |
35 | for index, img_name in enumerate(img_list):
36 | mesh = os.path.join(in_scene_dir, "model.obj")
37 | d_albedo = os.path.join(in_scene_dir, "diffuse_albedo.exr")
38 | s_albedo = os.path.join(in_scene_dir, "specular_albedo.exr")
39 | s_roughness = os.path.join(in_scene_dir, "specular_roughness.exr")
40 |
41 | K = np.array(cam_dict[img_name]["K"]).reshape((4, 4))
42 | focal = K[0, 0]
43 | width, height = cam_dict[img_name]["img_size"]
44 | fov = np.rad2deg(np.arctan(width / 2.0 / focal) * 2.0)
45 | w2c = np.array(cam_dict[img_name]["W2C"]).reshape((4, 4))
46 | # check if unit aspect ratio
47 | assert np.isclose(K[0, 0] - K[1, 1], 0.0), f"{K[0,0]} != {K[1,1]}"
48 |
49 | c2w = np.linalg.inv(w2c)
50 | c2w[:3, :2] *= -1 # mitsuba camera coordinate system: x-->left, y-->up, z-->scene
51 | origin = c2w[:3, 3]
52 | c2w = " ".join([str(x) for x in c2w.flatten().tolist()])
53 |
54 | out_fpath = os.path.join(out_split_dir, 'image', img_name[:-4] + ".exr")
55 | cmd = (
56 | 'mitsuba -b 10 rgb_flash_hdr_mat.xml -D fov={} -D width={} -D height={} -D c2w="{}" '
57 | "-D mesh={} -D d_albedo={} -D s_albedo={} -D s_roughness={} "
58 | "-D light={} "
59 | "-D px={} -D py={} -D pz={} "
60 | "-o {} ".format(
61 | fov,
62 | width,
63 | height,
64 | c2w,
65 | mesh,
66 | d_albedo,
67 | s_albedo,
68 | s_roughness,
69 | light,
70 | origin[0],
71 | origin[1],
72 | origin[2],
73 | out_fpath,
74 | )
75 | )
76 |
77 | if use_docker:
78 | docker_prefix = "docker run -w `pwd` --rm -v `pwd`:`pwd` -v /phoenix:/phoenix ninjaben/mitsuba-rgb "
79 | cmd = docker_prefix + cmd
80 |
81 | os.system(cmd)
82 | os.system("rm mitsuba.*.log")
83 |
84 |
--------------------------------------------------------------------------------
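The loop above converts an OpenCV-style K/W2C pair into the field of view and camera-to-world transform that Mitsuba expects, and reuses the camera center as the co-located flash position. The same conversion as a small numpy sketch (illustrative function name; the math matches the script):

    import numpy as np

    def opencv_to_mitsuba(K, W2C, width):
        # Horizontal field of view from the focal length (assumes square pixels, fx == fy).
        fov = np.rad2deg(2.0 * np.arctan(width / (2.0 * K[0, 0])))
        c2w = np.linalg.inv(W2C)
        c2w[:3, :2] *= -1    # Mitsuba's camera frame points x left and y up
        origin = c2w[:3, 3]  # camera center, reused as the flash position px/py/pz
        return fov, c2w, origin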
/render_synthetic_data/rgb_flash_hdr_mat.xml:
--------------------------------------------------------------------------------
[The XML markup of this file was stripped in this dump. It is a Mitsuba scene template that consumes the -D parameters supplied by render_rgb_flash_mat.py: fov, width, height, c2w, mesh, d_albedo, s_albedo, s_roughness, light, and the flash position px/py/pz.]
--------------------------------------------------------------------------------
/render_volume.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import argparse
5 | import numpy as np
6 | import cv2 as cv
7 | import trimesh
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.tensorboard import SummaryWriter
11 | from shutil import copyfile
12 | from icecream import ic
13 | from tqdm import tqdm
14 | from pyhocon import ConfigFactory
15 | from models.dataset import Dataset
16 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF
17 | from models.renderer import NeuSRenderer
18 |
19 |
20 | class Runner:
21 | def __init__(self, conf_path, mode="train", case="CASE_NAME", is_continue=False):
22 | self.device = torch.device("cuda")
23 |
24 | # Configuration
25 | self.conf_path = conf_path
26 | f = open(self.conf_path)
27 | conf_text = f.read()
28 | conf_text = conf_text.replace("CASE_NAME", case)
29 | f.close()
30 |
31 | self.conf = ConfigFactory.parse_string(conf_text)
32 | self.conf["dataset.data_dir"] = self.conf["dataset.data_dir"].replace("CASE_NAME", case)
33 | self.base_exp_dir = self.conf["general.base_exp_dir"]
34 | os.makedirs(self.base_exp_dir, exist_ok=True)
35 | self.dataset = Dataset(self.conf["dataset"])
36 | self.iter_step = 0
37 |
38 | # Training parameters
39 | self.end_iter = self.conf.get_int("train.end_iter")
40 | self.save_freq = self.conf.get_int("train.save_freq")
41 | self.report_freq = self.conf.get_int("train.report_freq")
42 | self.val_freq = self.conf.get_int("train.val_freq")
43 | self.val_mesh_freq = self.conf.get_int("train.val_mesh_freq")
44 | self.batch_size = self.conf.get_int("train.batch_size")
45 | self.validate_resolution_level = self.conf.get_int("train.validate_resolution_level")
46 | self.learning_rate = self.conf.get_float("train.learning_rate")
47 | self.learning_rate_alpha = self.conf.get_float("train.learning_rate_alpha")
48 | self.use_white_bkgd = self.conf.get_bool("train.use_white_bkgd")
49 | self.warm_up_end = self.conf.get_float("train.warm_up_end", default=0.0)
50 | self.anneal_end = self.conf.get_float("train.anneal_end", default=0.0)
51 |
52 | # Weights
53 | self.igr_weight = self.conf.get_float("train.igr_weight")
54 | self.mask_weight = self.conf.get_float("train.mask_weight")
55 | self.is_continue = is_continue
56 | self.mode = mode
57 | self.model_list = []
58 | self.writer = None
59 |
60 | # Networks
61 | params_to_train = []
62 | self.nerf_outside = NeRF(**self.conf["model.nerf"]).to(self.device)
63 | self.sdf_network = SDFNetwork(**self.conf["model.sdf_network"]).to(self.device)
64 | self.deviation_network = SingleVarianceNetwork(**self.conf["model.variance_network"]).to(self.device)
65 | self.color_network = RenderingNetwork(**self.conf["model.rendering_network"]).to(self.device)
66 | params_to_train += list(self.nerf_outside.parameters())
67 | params_to_train += list(self.sdf_network.parameters())
68 | params_to_train += list(self.deviation_network.parameters())
69 | params_to_train += list(self.color_network.parameters())
70 |
71 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
72 |
73 | self.renderer = NeuSRenderer(
74 | self.nerf_outside,
75 | self.sdf_network,
76 | self.deviation_network,
77 | self.color_network,
78 | **self.conf["model.neus_renderer"]
79 | )
80 |
81 | # Load checkpoint
82 | latest_model_name = None
83 | if is_continue:
84 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, "checkpoints"))
85 | model_list = []
86 | for model_name in model_list_raw:
87 | if model_name[-3:] == "pth" and int(model_name[5:-4]) <= self.end_iter:
88 | model_list.append(model_name)
89 | model_list.sort()
90 | latest_model_name = model_list[-1]
91 |
92 | if latest_model_name is not None:
93 | logging.info("Find checkpoint: {}".format(latest_model_name))
94 | self.load_checkpoint(latest_model_name)
95 |
96 | # Backup codes and configs for debug
97 | if self.mode[:5] == "train":
98 | self.file_backup()
99 |
100 | def train(self):
101 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, "logs"))
102 | self.update_learning_rate()
103 | res_step = self.end_iter - self.iter_step
104 | image_perm = self.get_image_perm()
105 |
106 | for iter_i in tqdm(range(res_step)):
107 | data = self.dataset.gen_random_rays_at(image_perm[self.iter_step % len(image_perm)], self.batch_size)
108 |
109 | rays_o, rays_d, true_rgb, mask = (
110 | data[:, :3],
111 | data[:, 3:6],
112 | data[:, 6:9],
113 | data[:, 9:10],
114 | )
115 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d)
116 |
117 | background_rgb = None
118 | if self.use_white_bkgd:
119 | background_rgb = torch.ones([1, 3])
120 |
121 | if self.mask_weight > 0.0:
122 | mask = (mask > 0.5).float()
123 | else:
124 | mask = torch.ones_like(mask)
125 |
126 | mask_sum = mask.sum() + 1e-5
127 | render_out = self.renderer.render(
128 | rays_o,
129 | rays_d,
130 | near,
131 | far,
132 | background_rgb=background_rgb,
133 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
134 | )
135 |
136 | color_fine = render_out["color_fine"]
137 | s_val = render_out["s_val"]
138 | cdf_fine = render_out["cdf_fine"]
139 | gradient_error = render_out["gradient_error"]
140 | weight_max = render_out["weight_max"]
141 | weight_sum = render_out["weight_sum"]
142 |
143 | # Loss
144 | color_error = (color_fine - true_rgb) * mask
145 | color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction="sum") / mask_sum
146 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb) ** 2 * mask).sum() / (mask_sum * 3.0)).sqrt())
147 |
148 | eikonal_loss = gradient_error
149 |
150 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask)
151 |
152 | loss = color_fine_loss + eikonal_loss * self.igr_weight + mask_loss * self.mask_weight
153 |
154 | self.optimizer.zero_grad()
155 | loss.backward()
156 | self.optimizer.step()
157 |
158 | self.iter_step += 1
159 |
160 | self.writer.add_scalar("Loss/loss", loss, self.iter_step)
161 | self.writer.add_scalar("Loss/color_loss", color_fine_loss, self.iter_step)
162 | self.writer.add_scalar("Loss/eikonal_loss", eikonal_loss, self.iter_step)
163 | self.writer.add_scalar("Statistics/s_val", s_val.mean(), self.iter_step)
164 | self.writer.add_scalar(
165 | "Statistics/cdf",
166 | (cdf_fine[:, :1] * mask).sum() / mask_sum,
167 | self.iter_step,
168 | )
169 | self.writer.add_scalar(
170 | "Statistics/weight_max",
171 | (weight_max * mask).sum() / mask_sum,
172 | self.iter_step,
173 | )
174 | self.writer.add_scalar("Statistics/psnr", psnr, self.iter_step)
175 |
176 | if self.iter_step % self.report_freq == 0:
177 | print(self.base_exp_dir)
178 |                 print("iter:{:>8d} loss = {} lr={}".format(self.iter_step, loss, self.optimizer.param_groups[0]["lr"]))
179 |
180 | if self.iter_step % self.save_freq == 0:
181 | self.save_checkpoint()
182 |
183 | if self.iter_step % self.val_freq == 0:
184 | self.validate_image()
185 |
186 | if self.iter_step % self.val_mesh_freq == 0:
187 | self.validate_mesh()
188 |
189 | self.update_learning_rate()
190 |
191 | if self.iter_step % len(image_perm) == 0:
192 | image_perm = self.get_image_perm()
193 |
194 | def get_image_perm(self):
195 | return torch.randperm(self.dataset.n_images)
196 |
197 | def get_cos_anneal_ratio(self):
198 | if self.anneal_end == 0.0:
199 | return 1.0
200 | else:
201 | return np.min([1.0, self.iter_step / self.anneal_end])
202 |
203 | def update_learning_rate(self):
204 | if self.iter_step < self.warm_up_end:
205 | learning_factor = self.iter_step / self.warm_up_end
206 | else:
207 | alpha = self.learning_rate_alpha
208 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end)
209 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha
210 |
211 | for g in self.optimizer.param_groups:
212 | g["lr"] = self.learning_rate * learning_factor
213 |
214 | def file_backup(self):
215 | dir_lis = self.conf["general.recording"]
216 | os.makedirs(os.path.join(self.base_exp_dir, "recording"), exist_ok=True)
217 | for dir_name in dir_lis:
218 | cur_dir = os.path.join(self.base_exp_dir, "recording", dir_name)
219 | os.makedirs(cur_dir, exist_ok=True)
220 | files = os.listdir(dir_name)
221 | for f_name in files:
222 | if f_name[-3:] == ".py":
223 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
224 |
225 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, "recording", "config.conf"))
226 |
227 | def load_checkpoint(self, checkpoint_name):
228 | checkpoint = torch.load(
229 | os.path.join(self.base_exp_dir, "checkpoints", checkpoint_name),
230 | map_location=self.device,
231 | )
232 | self.nerf_outside.load_state_dict(checkpoint["nerf"])
233 | self.sdf_network.load_state_dict(checkpoint["sdf_network_fine"])
234 | self.deviation_network.load_state_dict(checkpoint["variance_network_fine"])
235 | self.color_network.load_state_dict(checkpoint["color_network_fine"])
236 | self.optimizer.load_state_dict(checkpoint["optimizer"])
237 | self.iter_step = checkpoint["iter_step"]
238 |
239 | logging.info("End")
240 |
241 | def save_checkpoint(self):
242 | checkpoint = {
243 | "nerf": self.nerf_outside.state_dict(),
244 | "sdf_network_fine": self.sdf_network.state_dict(),
245 | "variance_network_fine": self.deviation_network.state_dict(),
246 | "color_network_fine": self.color_network.state_dict(),
247 | "optimizer": self.optimizer.state_dict(),
248 | "iter_step": self.iter_step,
249 | }
250 |
251 | os.makedirs(os.path.join(self.base_exp_dir, "checkpoints"), exist_ok=True)
252 | torch.save(
253 | checkpoint,
254 | os.path.join(
255 | self.base_exp_dir,
256 | "checkpoints",
257 | "ckpt_{:0>6d}.pth".format(self.iter_step),
258 | ),
259 | )
260 |
261 | def validate_image(self, idx=-1, resolution_level=-1):
262 | if idx < 0:
263 | idx = np.random.randint(self.dataset.n_images)
264 |
265 | print("Validate: iter: {}, camera: {}".format(self.iter_step, idx))
266 |
267 | if resolution_level < 0:
268 | resolution_level = self.validate_resolution_level
269 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
270 | H, W, _ = rays_o.shape
271 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
272 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
273 |
274 | out_rgb_fine = []
275 | out_normal_fine = []
276 |
277 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
278 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
279 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
280 |
281 | render_out = self.renderer.render(
282 | rays_o_batch,
283 | rays_d_batch,
284 | near,
285 | far,
286 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
287 | background_rgb=background_rgb,
288 | )
289 |
290 | def feasible(key):
291 | return (key in render_out) and (render_out[key] is not None)
292 |
293 | if feasible("color_fine"):
294 | out_rgb_fine.append(render_out["color_fine"].detach().cpu().numpy())
295 | if feasible("gradients") and feasible("weights"):
296 | n_samples = self.renderer.n_samples + self.renderer.n_importance
297 | normals = render_out["gradients"] * render_out["weights"][:, :n_samples, None]
298 | if feasible("inside_sphere"):
299 | normals = normals * render_out["inside_sphere"][..., None]
300 | normals = normals.sum(dim=1).detach().cpu().numpy()
301 | out_normal_fine.append(normals)
302 | del render_out
303 |
304 | img_fine = None
305 | if len(out_rgb_fine) > 0:
306 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
307 |
308 | normal_img = None
309 | if len(out_normal_fine) > 0:
310 | normal_img = np.concatenate(out_normal_fine, axis=0)
311 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy())
312 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]).reshape([H, W, 3, -1]) * 128 + 128).clip(
313 | 0, 255
314 | )
315 |
316 | os.makedirs(os.path.join(self.base_exp_dir, "validations_fine"), exist_ok=True)
317 | os.makedirs(os.path.join(self.base_exp_dir, "normals"), exist_ok=True)
318 |
319 | for i in range(img_fine.shape[-1]):
320 | if len(out_rgb_fine) > 0:
321 | cv.imwrite(
322 | os.path.join(
323 | self.base_exp_dir,
324 | "validations_fine",
325 | "{:0>8d}_{}_{}.png".format(self.iter_step, i, idx),
326 | ),
327 | np.concatenate(
328 | [
329 | img_fine[..., i],
330 | self.dataset.image_at(idx, resolution_level=resolution_level),
331 | ]
332 | ),
333 | )
334 | if len(out_normal_fine) > 0:
335 | cv.imwrite(
336 | os.path.join(
337 | self.base_exp_dir,
338 | "normals",
339 | "{:0>8d}_{}_{}.png".format(self.iter_step, i, idx),
340 | ),
341 | normal_img[..., i],
342 | )
343 |
344 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level):
345 | """
346 | Interpolate view between two cameras.
347 | """
348 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level)
349 | H, W, _ = rays_o.shape
350 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
351 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
352 |
353 | out_rgb_fine = []
354 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
355 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
356 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
357 |
358 | render_out = self.renderer.render(
359 | rays_o_batch,
360 | rays_d_batch,
361 | near,
362 | far,
363 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
364 | background_rgb=background_rgb,
365 | )
366 |
367 | out_rgb_fine.append(render_out["color_fine"].detach().cpu().numpy())
368 |
369 | del render_out
370 |
371 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8)
372 | return img_fine
373 |
374 | def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
375 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
376 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)
377 |
378 | vertices, triangles = self.renderer.extract_geometry(
379 | bound_min, bound_max, resolution=resolution, threshold=threshold
380 | )
381 | os.makedirs(os.path.join(self.base_exp_dir, "meshes"), exist_ok=True)
382 |
383 | if world_space:
384 | vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]
385 |
386 | mesh = trimesh.Trimesh(vertices, triangles)
387 | mesh.export(os.path.join(self.base_exp_dir, "meshes", "{:0>8d}.ply".format(self.iter_step)))
388 |
389 | logging.info("End")
390 |
391 | def interpolate_view(self, img_idx_0, img_idx_1):
392 | images = []
393 | n_frames = 60
394 | for i in range(n_frames):
395 | print(i)
396 | images.append(
397 | self.render_novel_image(
398 | img_idx_0,
399 | img_idx_1,
400 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5,
401 | resolution_level=4,
402 | )
403 | )
404 | for i in range(n_frames):
405 | images.append(images[n_frames - i - 1])
406 |
407 | fourcc = cv.VideoWriter_fourcc(*"mp4v")
408 | video_dir = os.path.join(self.base_exp_dir, "render")
409 | os.makedirs(video_dir, exist_ok=True)
410 | h, w, _ = images[0].shape
411 | writer = cv.VideoWriter(
412 | os.path.join(
413 | video_dir,
414 | "{:0>8d}_{}_{}.mp4".format(self.iter_step, img_idx_0, img_idx_1),
415 | ),
416 | fourcc,
417 | 30,
418 | (w, h),
419 | )
420 |
421 | for image in images:
422 | writer.write(image)
423 |
424 | writer.release()
425 |
426 |
427 | if __name__ == "__main__":
428 | print("Hello Wooden")
429 |
430 | torch.set_default_tensor_type("torch.cuda.FloatTensor")
431 |
432 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
433 | logging.basicConfig(level=logging.DEBUG, format=FORMAT)
434 |
435 | parser = argparse.ArgumentParser()
436 | parser.add_argument("--conf", type=str, default="./confs/base.conf")
437 | parser.add_argument("--mode", type=str, default="train")
438 | parser.add_argument("--mcube_threshold", type=float, default=0.0)
439 | parser.add_argument("--is_continue", default=False, action="store_true")
440 | parser.add_argument("--gpu", type=int, default=0)
441 | parser.add_argument("--case", type=str, default="")
442 |
443 | args = parser.parse_args()
444 |
445 | torch.cuda.set_device(args.gpu)
446 | runner = Runner(args.conf, args.mode, args.case, args.is_continue)
447 |
448 | if args.mode == "train":
449 | runner.train()
450 | elif args.mode == "validate_mesh":
451 | runner.validate_mesh(world_space=True, resolution=512, threshold=args.mcube_threshold)
452 | elif args.mode.startswith("interpolate"): # Interpolate views given two image indices
453 | _, img_idx_0, img_idx_1 = args.mode.split("_")
454 | img_idx_0 = int(img_idx_0)
455 | img_idx_1 = int(img_idx_1)
456 | runner.interpolate_view(img_idx_0, img_idx_1)
457 |
--------------------------------------------------------------------------------
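update_learning_rate in render_volume.py applies a linear warm-up followed by a cosine decay down to alpha times the base learning rate. The same schedule as a standalone sketch (illustrative name, same formula as above):

    import numpy as np

    def lr_factor(step, warm_up_end, end_iter, alpha):
        # Linear warm-up, then cosine decay from 1.0 down to the floor alpha.
        if step < warm_up_end:
            return step / warm_up_end
        progress = (step - warm_up_end) / (end_iter - warm_up_end)
        return (np.cos(np.pi * progress) + 1.0) * 0.5 * (1.0 - alpha) + alpha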
/singleview/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/singleview/12.png
--------------------------------------------------------------------------------
/singleview/cam_dict_norm.json:
--------------------------------------------------------------------------------
1 | {
2 | "12.png": {
3 | "K": [
4 | 811.9282694049824,
5 | 0.0,
6 | 256.0,
7 | 0.0,
8 | 0.0,
9 | 811.9282694049824,
10 | 256.0,
11 | 0.0,
12 | 0.0,
13 | 0.0,
14 | 1.0,
15 | 0.0,
16 | 0.0,
17 | 0.0,
18 | 0.0,
19 | 1.0
20 | ],
21 | "W2C": [
22 | 0.998867339183008,
23 | 0.0,
24 | -0.04758191582374219,
25 | 1.5416074755814572e-17,
26 | -0.013163727354886733,
27 | -0.9609695958324571,
28 | -0.27634064516051604,
29 | 1.1553154537250536e-16,
30 | -0.045724774418075535,
31 | 0.27665400030652737,
32 | -0.959881143224937,
33 | 2.0,
34 | 0.0,
35 | 0.0,
36 | 0.0,
37 | 1.0
38 | ],
39 | "img_size": [
40 | 512,
41 | 512
42 | ]
43 | }
44 | }
--------------------------------------------------------------------------------
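Each entry of cam_dict_norm.json stores a flattened 4x4 intrinsics matrix K, a flattened 4x4 world-to-camera matrix W2C, and the image size; this is the layout load_datadir in render_surface.py reads. A minimal sketch of parsing one entry (using the file in this folder):

    import json
    import numpy as np

    with open("singleview/cam_dict_norm.json") as fp:
        cam_dict = json.load(fp)

    entry = cam_dict["12.png"]
    K = np.array(entry["K"], dtype=np.float32).reshape(4, 4)      # intrinsics, padded to 4x4
    W2C = np.array(entry["W2C"], dtype=np.float32).reshape(4, 4)  # world-to-camera extrinsics
    width, height = entry["img_size"]
    cam_center = np.linalg.inv(W2C)[:3, 3]                        # camera origin in world space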
/test_mitsuba/render_rgb_envmap_mat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import imageio
5 |
6 | imageio.plugins.freeimage.download()
7 |
8 | import sys
9 |
10 | asset_dir = sys.argv[1]
11 | cam_dict_fpath = sys.argv[2]
12 | envmap_fpath = sys.argv[3]
13 | out_dir = sys.argv[4]
14 |
15 |
16 | d_albedo = os.path.join(asset_dir, "diffuse_albedo.exr")
17 | s_albedo = os.path.join(asset_dir, "specular_albedo.exr")
18 | s_roughness = os.path.join(asset_dir, "roughness.exr")
19 | mesh_fpath = os.path.join(asset_dir, "mesh.obj")
20 |
21 | os.makedirs(out_dir, exist_ok=True)
22 |
23 |
24 | envmap_fpath = os.path.join(asset_dir, "../envmap.exr")  # note: this overrides the envmap path passed as sys.argv[3]
25 |
26 | cam_dict = json.load(open(cam_dict_fpath))
27 |
28 | use_docker = True
29 |
30 | for img_name in list(cam_dict.keys()):
31 | out_fpath = os.path.join(out_dir, img_name[:-4] + ".exr")
32 | K = np.array(cam_dict[img_name]["K"]).reshape((4, 4))
33 | focal = K[0, 0]
34 | width, height = cam_dict[img_name]["img_size"]
35 | fov = np.rad2deg(np.arctan(width / 2.0 / focal) * 2.0)
36 | w2c = np.array(cam_dict[img_name]["W2C"]).reshape((4, 4))
37 |
38 | c2w = np.linalg.inv(w2c)
39 | c2w[:3, :2] *= -1 # mitsuba camera coordinate system: x-->left, y-->up, z-->scene
40 | origin = c2w[:3, 3]
41 | c2w = " ".join([str(x) for x in c2w.flatten().tolist()])
42 |
43 | cmd = (
44 | 'mitsuba -b 10 rgb_envmap_hdr_mat.xml -D fov={} -D width={} -D height={} -D c2w="{}" '
45 | "-D mesh={} -D d_albedo={} -D s_albedo={} -D s_roughness={} "
46 | "-D envmap={} "
47 | "-o {} ".format(fov, width, height, c2w, mesh_fpath, d_albedo, s_albedo, s_roughness, envmap_fpath, out_fpath)
48 | )
49 |
50 | if use_docker:
51 | docker_prefix = "docker run -w `pwd` --rm -v `pwd`:`pwd` -v /phoenix:/phoenix ninjaben/mitsuba-rgb "
52 | cmd = docker_prefix + cmd
53 |
54 | os.system(cmd)
55 | os.system("rm mitsuba.*.log")
56 |
57 | to8b = lambda x: np.uint8(np.clip(x * 255.0, 0.0, 255.0))
58 | im = imageio.imread(out_fpath).astype(np.float32)
59 | imageio.imwrite(out_fpath[:-4] + ".png", to8b(np.power(im, 1.0 / 2.2)))
60 |
--------------------------------------------------------------------------------
/test_mitsuba/render_rgb_flash_mat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import imageio
5 |
6 | imageio.plugins.freeimage.download()
7 | import sys
8 |
9 |
10 | cam_dict_fpath = sys.argv[1]
11 | asset_dir = sys.argv[2]
12 |
13 | out_dir = os.path.join(asset_dir, "mitsuba_render")
14 | os.makedirs(out_dir, exist_ok=True)
15 |
16 | light = 61.3303 # pony
17 | # light = 28.8344 # girl
18 | # light = 48.5146 # triton
19 | # light = 136.0487 # tree
20 | # light = 14.7209 # dragon
21 |
22 | cam_dict = json.load(open(cam_dict_fpath))
23 | img_list = list(cam_dict.keys())
24 | img_list = sorted(img_list, key=lambda x: int(x[:-4]))
25 |
26 |
27 | use_docker = True
28 |
29 | for index, img_name in enumerate(img_list):
30 | mesh = os.path.join(asset_dir, "mesh.obj")
31 | d_albedo = os.path.join(asset_dir, "diffuse_albedo.exr")
32 | s_albedo = os.path.join(asset_dir, "specular_albedo.exr")
33 | s_roughness = os.path.join(asset_dir, "roughness.exr")
34 |
35 | K = np.array(cam_dict[img_name]["K"]).reshape((4, 4))
36 | focal = K[0, 0]
37 | width, height = cam_dict[img_name]["img_size"]
38 | fov = np.rad2deg(np.arctan(width / 2.0 / focal) * 2.0)
39 | w2c = np.array(cam_dict[img_name]["W2C"]).reshape((4, 4))
40 | # check if unit aspect ratio
41 | assert np.isclose(K[0, 0] - K[1, 1], 0.0), f"{K[0,0]} != {K[1,1]}"
42 |
43 | c2w = np.linalg.inv(w2c)
44 | c2w[:3, :2] *= -1 # mitsuba camera coordinate system: x-->left, y-->up, z-->scene
45 | origin = c2w[:3, 3]
46 | c2w = " ".join([str(x) for x in c2w.flatten().tolist()])
47 |
48 | out_fpath = os.path.join(out_dir, img_name[:-4] + ".exr")
49 | cmd = (
50 | 'mitsuba -b 10 rgb_flash_hdr_mat.xml -D fov={} -D width={} -D height={} -D c2w="{}" '
51 | "-D mesh={} -D d_albedo={} -D s_albedo={} -D s_roughness={} "
52 | "-D light={} "
53 | "-D px={} -D py={} -D pz={} "
54 | "-o {} ".format(
55 | fov,
56 | width,
57 | height,
58 | c2w,
59 | mesh,
60 | d_albedo,
61 | s_albedo,
62 | s_roughness,
63 | light,
64 | origin[0],
65 | origin[1],
66 | origin[2],
67 | out_fpath,
68 | )
69 | )
70 |
71 | if use_docker:
72 | docker_prefix = "docker run -w `pwd` --rm -v `pwd`:`pwd` -v /phoenix:/phoenix ninjaben/mitsuba-rgb "
73 | cmd = docker_prefix + cmd
74 |
75 | os.system(cmd)
76 | os.system("rm mitsuba.*.log")
77 |
78 | to8b = lambda x: np.uint8(np.clip(x * 255.0, 0.0, 255.0))
79 | im = imageio.imread(out_fpath).astype(np.float32)
80 | imageio.imwrite(out_fpath[:-4] + ".png", to8b(np.power(im, 1.0 / 2.2)))
81 |
--------------------------------------------------------------------------------
/test_mitsuba/rgb_envmap_hdr_mat.xml:
--------------------------------------------------------------------------------
[The XML markup of this file was stripped in this dump. It is a Mitsuba scene template that consumes the -D parameters supplied by render_rgb_envmap_mat.py: fov, width, height, c2w, mesh, d_albedo, s_albedo, s_roughness, and envmap.]
--------------------------------------------------------------------------------
/test_mitsuba/rgb_flash_hdr_mat.xml:
--------------------------------------------------------------------------------
[The XML markup of this file was stripped in this dump. It is a Mitsuba scene template that consumes the -D parameters supplied by test_mitsuba/render_rgb_flash_mat.py: fov, width, height, c2w, mesh, d_albedo, s_albedo, s_roughness, light, and the flash position px/py/pz.]
--------------------------------------------------------------------------------
/tests/data_singleview/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/tests/data_singleview/12.png
--------------------------------------------------------------------------------
/tests/data_singleview/cam_dict_norm.json:
--------------------------------------------------------------------------------
1 | {
2 | "12.png": {
3 | "K": [
4 | 811.9282694049824,
5 | 0.0,
6 | 256.0,
7 | 0.0,
8 | 0.0,
9 | 811.9282694049824,
10 | 256.0,
11 | 0.0,
12 | 0.0,
13 | 0.0,
14 | 1.0,
15 | 0.0,
16 | 0.0,
17 | 0.0,
18 | 0.0,
19 | 1.0
20 | ],
21 | "W2C": [
22 | 0.998867339183008,
23 | 0.0,
24 | -0.04758191582374219,
25 | 1.5416074755814572e-17,
26 | -0.013163727354886733,
27 | -0.9609695958324571,
28 | -0.27634064516051604,
29 | 1.1553154537250536e-16,
30 | -0.045724774418075535,
31 | 0.27665400030652737,
32 | -0.959881143224937,
33 | 2.0,
34 | 0.0,
35 | 0.0,
36 | 0.0,
37 | 1.0
38 | ],
39 | "img_size": [
40 | 512,
41 | 512
42 | ]
43 | }
44 | }
--------------------------------------------------------------------------------
/tests/test_raytracer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | import trimesh
6 | import imageio
7 |
8 | imageio.plugins.freeimage.download()
9 |
10 | from icecream import ic
11 | import sys
12 |
13 | sys.path.append("../")
14 |
15 | from models.fields import SDFNetwork, RenderingNetwork
16 | import models.raytracer
17 |
18 | models.raytracer.VERBOSE_MODE = True
19 | from models.raytracer import RayTracer, Camera, render_camera
20 |
21 |
22 | def to8b(x):
23 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8)
24 |
25 |
26 | sdf_network = SDFNetwork(
27 | d_in=3,
28 | d_out=257,
29 | d_hidden=256,
30 | n_layers=8,
31 | skip_in=[
32 | 4,
33 | ],
34 | multires=6,
35 | bias=0.5,
36 | scale=1.0,
37 | geometric_init=True,
38 | weight_norm=True,
39 | ).cuda()
40 | color_network = RenderingNetwork(
41 | d_in=9,
42 | d_out=3,
43 | d_feature=256,
44 | d_hidden=256,
45 | n_layers=4,
46 | multires_view=4,
47 | mode="idr",
48 | squeeze_out=True,
49 | ).cuda()
50 | raytracer = RayTracer()
51 |
52 | scene = "dtu_scan69"
53 | ckpt_fpath = f"../exp/{scene}/womask_sphere/checkpoints/ckpt_300000.pth"
54 |
55 | ckpt = torch.load(ckpt_fpath, map_location=torch.device("cuda"))
56 | sdf_network.load_state_dict(ckpt["sdf_network_fine"])
57 | color_network.load_state_dict(ckpt["color_network_fine"])
58 |
59 | color_network_dict = {"color_network": color_network}
60 |
61 |
62 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features):
63 | interior_color = color_network_dict["color_network"](points, normals, ray_d, features) # [..., [2, 0, 1]]
64 |
65 | dots_sh = list(interior_mask.shape)
66 | color = torch.zeros(
67 | dots_sh
68 | + [
69 | 3,
70 | ],
71 | dtype=torch.float32,
72 | device=interior_mask.device,
73 | )
74 | color[interior_mask] = interior_color
75 |
76 | normals_pad = torch.zeros(
77 | dots_sh
78 | + [
79 | 3,
80 | ],
81 | dtype=torch.float32,
82 | device=interior_mask.device,
83 | )
84 | normals_pad[interior_mask] = normals
85 | return {"color": color, "normal": normals_pad}
86 |
87 |
88 | def load_datadir(data_dir):
89 | from glob import glob
90 | from models.dataset import load_K_Rt_from_P
91 |
92 | camera_dict = np.load(os.path.join(data_dir, "cameras_sphere.npz"))
93 | images_lis = sorted(glob(os.path.join(data_dir, "image/*.png")))
94 | n_images = len(images_lis)
95 | images = np.stack([imageio.imread(im_name) for im_name in images_lis]) / 255.0
96 | images = torch.from_numpy(images).float()
97 | # world_mat is a projection matrix from world to image
98 | world_mats_np = [camera_dict["world_mat_%d" % idx].astype(np.float32) for idx in range(n_images)]
99 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
100 | scale_mats_np = [camera_dict["scale_mat_%d" % idx].astype(np.float32) for idx in range(n_images)]
101 | intrinsics_all = []
102 | pose_all = []
103 | for scale_mat, world_mat in zip(scale_mats_np, world_mats_np):
104 | P = world_mat @ scale_mat
105 | P = P[:3, :4]
106 | intrinsics, pose = load_K_Rt_from_P(None, P)
107 | intrinsics_all.append(torch.from_numpy(intrinsics).float())
108 | pose_all.append(torch.from_numpy(pose).float())
109 | intrinsics_all = torch.stack(intrinsics_all, dim=0)
110 | pose_all = torch.stack(pose_all, dim=0) # C2W
111 | pose_all = torch.inverse(pose_all)
112 |
113 | ic(images.shape, intrinsics_all.shape, pose_all.shape)
114 | return images, intrinsics_all, pose_all
115 |
116 |
117 | gt_images, Ks, W2Cs = load_datadir(f"../public_data/{scene}")
118 |
119 | img_idx = 10
120 | gt_color = gt_images[img_idx]
121 | camera = Camera(W=gt_color.shape[1], H=gt_color.shape[0], K=Ks[img_idx].cuda(), W2C=W2Cs[img_idx].cuda())
122 |
123 | fill_holes = False
124 | handle_edges = True
125 | is_training = False
126 | out_dir = f"./debug_raytracer_{scene}_{fill_holes}_{handle_edges}_{is_training}"
127 | ic(out_dir)
128 | os.makedirs(out_dir, exist_ok=True)
129 |
130 | if is_training:
131 | camera, gt_color = camera.crop_region(trgt_W=256, trgt_H=256, center_crop=True, image=gt_color)
132 | ic(gt_color.shape, camera.H, camera.W)
133 |
134 | results = render_camera(
135 | camera,
136 | sdf_network,
137 | raytracer,
138 | color_network_dict,
139 | render_fn,
140 | fill_holes=fill_holes,
141 | handle_edges=handle_edges,
142 | is_training=is_training,
143 | )
144 |
145 | for x in list(results.keys()):
146 | results[x] = results[x].detach().cpu().numpy()
147 |
148 |
149 | def append_allones(x):
150 | return np.concatenate((x, np.ones_like(x[..., 0:1])), axis=-1)
151 |
152 |
153 | imageio.imwrite(os.path.join(out_dir, "convergent_mask.png"), to8b(results["convergent_mask"]))
154 | imageio.imwrite(os.path.join(out_dir, "distance.exr"), results["distance"])
155 | imageio.imwrite(os.path.join(out_dir, "depth.exr"), results["depth"])
156 | imageio.imwrite(os.path.join(out_dir, "sdf.exr"), results["sdf"])
157 | imageio.imwrite(os.path.join(out_dir, "points.exr"), results["points"])
158 | imageio.imwrite(os.path.join(out_dir, "normal.png"), to8b((results["normal"] + 1.0) / 2.0))
159 | imageio.imwrite(os.path.join(out_dir, "normal.exr"), results["normal"])
160 | imageio.imwrite(os.path.join(out_dir, "color.png"), to8b(results["color"])[..., ::-1])
161 | imageio.imwrite(os.path.join(out_dir, "color_gt.png"), to8b(gt_color.detach().cpu().numpy()))
162 | imageio.imwrite(os.path.join(out_dir, "uv.exr"), append_allones(results["uv"]))
163 |
164 | imageio.imwrite(os.path.join(out_dir, "depth_grad_norm.exr"), results["depth_grad_norm"])
165 | imageio.imwrite(os.path.join(out_dir, "depth_edge_mask.png"), to8b(results["depth_edge_mask"]))
166 | imageio.imwrite(
167 | os.path.join(out_dir, "walk_edge_found_mask.png"),
168 | to8b(results["walk_edge_found_mask"]),
169 | )
170 | trimesh.PointCloud(results["edge_points"].reshape((-1, 3))).export(os.path.join(out_dir, "edge_points.ply"))
171 | imageio.imwrite(os.path.join(out_dir, "edge_mask.png"), to8b(results["edge_mask"]))
172 | imageio.imwrite(os.path.join(out_dir, "edge_pos_side_weight.exr"), results["edge_pos_side_weight"])
173 | imageio.imwrite(os.path.join(out_dir, "edge_angles.exr"), results["edge_angles"])
174 | imageio.imwrite(os.path.join(out_dir, "edge_sdf.exr"), results["edge_sdf"])
175 | imageio.imwrite(os.path.join(out_dir, "edge_pos_side_depth.exr"), results["edge_pos_side_depth"])
176 | imageio.imwrite(os.path.join(out_dir, "edge_neg_side_depth.exr"), results["edge_neg_side_depth"])
177 | imageio.imwrite(os.path.join(out_dir, "edge_pos_side_color.png"), to8b(results["edge_pos_side_color"])[..., ::-1])
178 | imageio.imwrite(os.path.join(out_dir, "edge_neg_side_color.png"), to8b(results["edge_neg_side_color"])[..., ::-1])
179 |
--------------------------------------------------------------------------------
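test_raytracer.py relies on load_K_Rt_from_P (defined in models/dataset.py) to split the DTU-style projection matrix P = world_mat @ scale_mat into intrinsics and a camera-to-world pose. A hedged sketch of one common way to perform that decomposition with OpenCV, shown here only as an illustration rather than a claim about the repo's exact implementation:

    import cv2 as cv
    import numpy as np

    def decompose_projection(P):
        # P is a 3x4 projection matrix; OpenCV returns intrinsics, a world-to-camera
        # rotation, and the camera center in homogeneous coordinates.
        K, R, c_h = cv.decomposeProjectionMatrix(P)[:3]
        K = K / K[2, 2]
        pose = np.eye(4, dtype=np.float32)
        pose[:3, :3] = R.transpose()            # invert the rotation for camera-to-world
        pose[:3, 3] = (c_h[:3] / c_h[3])[:, 0]  # dehomogenized camera center
        return K, pose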
/tests/test_singleview.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import trimesh
7 | import json
8 | import imageio
9 |
10 | imageio.plugins.freeimage.download()
11 |
12 | from icecream import ic
13 | import sys
14 |
15 | sys.path.append("../")
16 |
17 | from models.fields import SDFNetwork
18 | import models.raytracer
19 |
20 | models.raytracer.VERBOSE_MODE = False
21 | from models.raytracer import RayTracer, Camera, render_camera
22 |
23 |
24 | def to8b(x):
25 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8)
26 |
27 |
28 | sdf_network = SDFNetwork(
29 | d_in=3,
30 | d_out=257,
31 | d_hidden=256,
32 | n_layers=8,
33 | skip_in=[
34 | 4,
35 | ],
36 | multires=6,
37 | bias=0.5,
38 | scale=1.0,
39 | geometric_init=True,
40 | weight_norm=True,
41 | ).cuda()
42 | raytracer = RayTracer()
43 |
44 | color_network_dict = {}
45 |
46 |
47 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features):
48 | dots_sh = list(interior_mask.shape)
49 | color = torch.zeros(
50 | dots_sh
51 | + [
52 | 3,
53 | ],
54 | dtype=torch.float32,
55 | device=interior_mask.device,
56 | )
57 | normals_pad = torch.zeros(
58 | dots_sh
59 | + [
60 | 3,
61 | ],
62 | dtype=torch.float32,
63 | device=interior_mask.device,
64 | )
65 | if interior_mask.any():
66 | interior_color = (
67 | torch.ones_like(points.view(-1, 3))
68 | * torch.Tensor([[237.0 / 255.0, 61.0 / 255.0, 100.0 / 255.0]]).float().cuda()
69 | )
70 | interior_color = interior_color.view(list(points.shape))
71 | color[interior_mask] = interior_color
72 | normals_pad[interior_mask] = normals
73 |
74 | return {"color": color, "normal": normals_pad}
75 |
76 |
77 | gt_color = imageio.imread("./data_singleview/12.png").astype(np.float32) / 255.0
78 | gt_color = torch.from_numpy(gt_color).cuda()
79 |
80 | cam_dict = json.load(open("./data_singleview/cam_dict_norm.json"))
81 | K = torch.from_numpy(np.array(cam_dict["12.png"]["K"]).reshape((4, 4)).astype(np.float32)).cuda()
82 | W2C = torch.from_numpy(np.array(cam_dict["12.png"]["W2C"]).reshape((4, 4)).astype(np.float32)).cuda()
83 | W, H = cam_dict["12.png"]["img_size"]
84 |
85 | camera = Camera(W=W, H=H, K=K, W2C=W2C)
86 |
87 | fill_holes = False
88 | handle_edges = True
89 | is_training = True
90 | out_dir = f"./debug_singleview_{fill_holes}_{handle_edges}_{is_training}"
91 | ic(out_dir)
92 | os.makedirs(out_dir, exist_ok=True)
93 |
94 | sdf_optimizer = torch.optim.Adam(sdf_network.parameters(), lr=1e-4)
95 |
96 | for global_step in range(15000):
97 | sdf_optimizer.zero_grad()
98 |
99 | camera_crop, gt_color_crop = camera.crop_region(trgt_W=128, trgt_H=128, image=gt_color)
100 |
101 | results = render_camera(
102 | camera_crop,
103 | sdf_network,
104 | raytracer,
105 | color_network_dict,
106 | render_fn,
107 | fill_holes=fill_holes,
108 | handle_edges=handle_edges,
109 | is_training=is_training,
110 | )
111 |
112 | mask = results["convergent_mask"]
113 | if handle_edges:
114 | # mask = mask | results["edge_mask"]
115 | mask = results["edge_mask"]
116 |
117 | img_loss = torch.Tensor(
118 | [
119 | 0.0,
120 | ]
121 | ).cuda()
122 | rand_eik_points = torch.empty(camera_crop.H * camera_crop.W // 2, 3).cuda().float().uniform_(-1.0, 1.0)
123 | eik_grad = sdf_network.gradient(rand_eik_points).view(-1, 3)
124 |
125 | if mask.any():
126 | img_loss = ((results["color"][mask] - gt_color_crop[mask]) ** 2).mean()
127 | interior_normals = results["normal"][mask | results["convergent_mask"]]
128 | eik_grad = torch.cat([eik_grad, interior_normals], dim=0)
129 | if "edge_pos_neg_normal" in results:
130 | eik_grad = torch.cat([eik_grad, results["edge_pos_neg_normal"]], dim=0)
131 | eik_loss = ((eik_grad.norm(dim=-1) - 1) ** 2).mean()
132 |
133 | loss = img_loss + 0.1 * eik_loss
134 | loss.backward()
135 | sdf_optimizer.step()
136 |
137 | if global_step % 200 == 0:
138 | ic(global_step, loss.item(), img_loss.item(), eik_loss.item())
139 | for x in list(results.keys()):
140 | del results[x]
141 |
142 | camera_resize, gt_color_resize = camera.resize(factor=0.25, image=gt_color)
143 | results = render_camera(
144 | camera_resize,
145 | sdf_network,
146 | raytracer,
147 | color_network_dict,
148 | render_fn,
149 | fill_holes=fill_holes,
150 | handle_edges=handle_edges,
151 | is_training=False,
152 | )
153 | for x in list(results.keys()):
154 | results[x] = results[x].detach().cpu().numpy()
155 |
156 | gt_color_im = gt_color_resize.detach().cpu().numpy()
157 | color_im = results["color"]
158 | normal = results["normal"]
159 | normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10)
160 | normal_im = (normal + 1.0) / 2.0
161 | edge_mask_im = np.tile(results["edge_mask"][:, :, np.newaxis], (1, 1, 3))
162 | im = np.concatenate([gt_color_im, color_im, normal_im, edge_mask_im], axis=1)
163 | imageio.imwrite(os.path.join(out_dir, f"logim_{global_step}.png"), to8b(im))
164 |
165 | torch.save(sdf_network.state_dict(), os.path.join(out_dir, "ckpt.pth"))
166 |
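
A minimal way to run this test (a sketch, assuming a CUDA-capable GPU and that the script is launched from the tests/ directory so that sys.path.append("../") can resolve the models package):

    # run from the tests/ directory
    cd tests
    python test_singleview.py

It fits the SDF to the single reference view in ./data_singleview and writes debug renderings (logim_*.png, every 200 steps) plus a final ckpt.pth into ./debug_singleview_False_True_True.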
--------------------------------------------------------------------------------
/tests/test_viewsynthesis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import trimesh
7 | import json
8 | import imageio
9 | from torch.utils.tensorboard import SummaryWriter
10 | import configargparse
11 | from icecream import ic
12 | import glob
13 |
14 | import sys
15 |
16 | sys.path.append("../")
17 |
18 | from models.fields import SDFNetwork, RenderingNetwork
19 | from models.raytracer import RayTracer, Camera, render_camera
20 | from models.renderer_ggx import GGXColocatedRenderer
21 | from models.image_losses import PyramidL2Loss, ssim_loss_fn
22 |
23 |
24 | def config_parser():
25 | parser = configargparse.ArgumentParser()
26 | parser.add_argument("--data_dir", type=str, default=None, help="input data directory")
27 | parser.add_argument("--out_dir", type=str, default=None, help="output directory")
28 | # parser.add_argument("--neus_ckpt_fpath", type=str, default=None, help="checkpoint to load")
29 | parser.add_argument("--num_iters", type=int, default=100001, help="number of iterations")
30 | # parser.add_argument("--white_specular_albedo", action='store_true', help='force specular albedo to be white')
31 | parser.add_argument("--eik_weight", type=float, default=0.1, help="weight for eikonal loss")
32 | parser.add_argument("--ssim_weight", type=float, default=1.0, help="weight for ssim loss")
33 | parser.add_argument("--roughrange_weight", type=float, default=0.1, help="weight for roughness range loss")
34 |
35 | parser.add_argument("--plot_image_name", type=str, default=None, help="image to plot during training")
36 | parser.add_argument("--no_edgesample", action="store_true", help="whether to disable edge sampling")
37 |
38 | return parser
39 |
40 |
41 | parser = config_parser()
42 | args = parser.parse_args()
43 | ic(args)
44 |
45 |
46 | def to8b(x):
47 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8)
48 |
49 |
50 | ggx_renderer = GGXColocatedRenderer(use_cuda=True)
51 | pyramidl2_loss_fn = PyramidL2Loss(use_cuda=True)
52 |
53 |
54 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features):
55 | dots_sh = list(interior_mask.shape)
56 | color = torch.zeros(
57 | dots_sh
58 | + [
59 | 3,
60 | ],
61 | dtype=torch.float32,
62 | device=interior_mask.device,
63 | )
64 | normals_pad = color.clone()
65 | if interior_mask.any():
66 | normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-10)
67 | interior_color = color_network_dict["color_network"](points, normals, ray_d, features)
68 |
69 | color[interior_mask] = interior_color
70 | normals_pad[interior_mask] = normals
71 |
72 | return {
73 | "color": color,
74 | "normal": normals_pad,
75 | }
76 |
77 |
78 | sdf_network = SDFNetwork(
79 | d_in=3,
80 | d_out=257,
81 | d_hidden=256,
82 | n_layers=8,
83 | skip_in=[
84 | 4,
85 | ],
86 | multires=6,
87 | bias=0.5,
88 | scale=1.0,
89 | geometric_init=True,
90 | weight_norm=True,
91 | ).cuda()
92 | raytracer = RayTracer()
93 |
94 |
95 | color_network_dict = {
96 | "color_network": RenderingNetwork(
97 | d_in=9,
98 | d_out=3,
99 | d_feature=256,
100 | d_hidden=256,
101 | n_layers=8,
102 | multires=10,
103 | multires_view=4,
104 | mode="idr",
105 | squeeze_out=True,
106 | skip_in=(4,),
107 | ).cuda()
108 | }
109 |
110 | sdf_optimizer = torch.optim.Adam(sdf_network.parameters(), lr=1e-5)
111 | color_optimizer_dict = {"color_network": torch.optim.Adam(color_network_dict["color_network"].parameters(), lr=1e-4)}
112 |
113 |
114 | def load_datadir(datadir):
115 | cam_dict = json.load(open(os.path.join(datadir, "cam_dict_norm.json")))
116 | imgnames = list(cam_dict.keys())
117 | try:
118 | imgnames = sorted(imgnames, key=lambda x: int(x[:-4]))
119 |     except ValueError:  # fall back to plain lexicographic order for non-numeric file names
120 | imgnames = sorted(imgnames)
121 |
122 | image_fpaths = []
123 | gt_images = []
124 | Ks = []
125 | W2Cs = []
126 | for x in imgnames:
127 | fpath = os.path.join(datadir, "image", x)
128 | assert fpath[-4:] in [".jpg", ".png"], "must use ldr images as inputs"
129 | im = imageio.imread(fpath).astype(np.float32) / 255.0
130 | K = np.array(cam_dict[x]["K"]).reshape((4, 4)).astype(np.float32)
131 | W2C = np.array(cam_dict[x]["W2C"]).reshape((4, 4)).astype(np.float32)
132 |
133 | image_fpaths.append(fpath)
134 | gt_images.append(torch.from_numpy(im))
135 | Ks.append(torch.from_numpy(K))
136 | W2Cs.append(torch.from_numpy(W2C))
137 | gt_images = torch.stack(gt_images, dim=0)
138 | Ks = torch.stack(Ks, dim=0)
139 | W2Cs = torch.stack(W2Cs, dim=0)
140 | return image_fpaths, gt_images, Ks, W2Cs
141 |
142 |
143 | image_fpaths, gt_images, Ks, W2Cs = load_datadir(args.data_dir)
144 | cameras = [
145 | Camera(W=gt_images[i].shape[1], H=gt_images[i].shape[0], K=Ks[i].cuda(), W2C=W2Cs[i].cuda())
146 | for i in range(gt_images.shape[0])
147 | ]
148 | ic(len(image_fpaths), gt_images.shape, Ks.shape, W2Cs.shape, len(cameras))
149 |
150 | #### load pretrained checkpoints
151 | start_step = -1
152 | ckpt_fpaths = glob.glob(os.path.join(args.out_dir, "ckpt_*.pth"))
153 | if len(ckpt_fpaths) > 0:
154 |     path2step = lambda x: int(os.path.basename(x)[len("ckpt_") : -4])  # ckpt_<step>.pth -> <step>
155 | ckpt_fpaths = sorted(ckpt_fpaths, key=path2step)
156 | ckpt_fpath = ckpt_fpaths[-1]
157 | start_step = path2step(ckpt_fpath)
158 | ic("Reloading from checkpoint: ", ckpt_fpath)
159 | ckpt = torch.load(ckpt_fpath, map_location=torch.device("cuda"))
160 | sdf_network.load_state_dict(ckpt["sdf_network"])
161 | for x in list(color_network_dict.keys()):
162 | color_network_dict[x].load_state_dict(ckpt[x])
163 | # logim_names = [os.path.basename(x) for x in glob.glob(os.path.join(args.out_dir, "logim_*.png"))]
164 | # start_step = sorted([int(x[len("logim_") : -4]) for x in logim_names])[-1]
165 |
166 | ic(start_step)
167 |
168 | fill_holes = False
169 | handle_edges = not args.no_edgesample
170 | is_training = True
171 | inv_gamma_gt = False
172 | if inv_gamma_gt:
173 | ic("linearizing ground-truth images using inverse gamma correction")
174 | gt_images = torch.pow(gt_images, 2.2)
175 |
176 | ic(fill_holes, handle_edges, is_training, inv_gamma_gt)
177 | os.makedirs(args.out_dir, exist_ok=True)
178 | writer = SummaryWriter(log_dir=os.path.join(args.out_dir, "logs"))
179 |
180 |
181 | for global_step in tqdm.tqdm(range(start_step + 1, args.num_iters)):
182 | sdf_optimizer.zero_grad()
183 | for x in color_optimizer_dict.keys():
184 | color_optimizer_dict[x].zero_grad()
185 |
186 | idx = np.random.randint(0, gt_images.shape[0])
187 | camera_crop, gt_color_crop = cameras[idx].crop_region(trgt_W=128, trgt_H=128, image=gt_images[idx])
188 |
189 | results = render_camera(
190 | camera_crop,
191 | sdf_network,
192 | raytracer,
193 | color_network_dict,
194 | render_fn,
195 | fill_holes=fill_holes,
196 | handle_edges=handle_edges,
197 | is_training=is_training,
198 | )
199 |
200 | mask = results["convergent_mask"]
201 | if handle_edges:
202 | mask = mask | results["edge_mask"]
203 |
204 | img_loss = torch.Tensor([0.0]).cuda()
205 | img_l2_loss = torch.Tensor([0.0]).cuda()
206 | img_ssim_loss = torch.Tensor([0.0]).cuda()
207 |
208 |     eik_points = torch.empty(camera_crop.H * camera_crop.W // 2, 3).cuda().float().uniform_(-1.0, 1.0)  # random points in [-1, 1]^3 for the eikonal regularizer
209 | eik_grad = sdf_network.gradient(eik_points).view(-1, 3)
210 | eik_cnt = eik_grad.shape[0]
211 | eik_loss = ((eik_grad.norm(dim=-1) - 1) ** 2).sum()
212 | if mask.any():
213 | pred_img = results["color"].permute(2, 0, 1).unsqueeze(0)
214 | gt_img = gt_color_crop.permute(2, 0, 1).unsqueeze(0).to(pred_img.device)
215 | img_l2_loss = pyramidl2_loss_fn(pred_img, gt_img)
216 | img_ssim_loss = args.ssim_weight * ssim_loss_fn(pred_img, gt_img, mask.unsqueeze(0).unsqueeze(0))
217 | img_loss = img_l2_loss + img_ssim_loss
218 |
219 | eik_grad = results["normal"][mask]
220 | eik_cnt += eik_grad.shape[0]
221 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum()
222 | if "edge_pos_neg_normal" in results:
223 | eik_grad = results["edge_pos_neg_normal"]
224 | eik_cnt += eik_grad.shape[0]
225 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum()
226 |
227 | eik_loss = eik_loss / eik_cnt * args.eik_weight
228 |
229 | loss = img_loss + eik_loss
230 | loss.backward()
231 | sdf_optimizer.step()
232 | for x in color_optimizer_dict.keys():
233 | color_optimizer_dict[x].step()
234 |
235 | if global_step % 50 == 0:
236 | writer.add_scalar("loss/loss", loss, global_step)
237 | writer.add_scalar("loss/img_loss", img_loss, global_step)
238 | writer.add_scalar("loss/img_l2_loss", img_l2_loss, global_step)
239 | writer.add_scalar("loss/img_ssim_loss", img_ssim_loss, global_step)
240 | writer.add_scalar("loss/eik_loss", eik_loss, global_step)
241 |
242 | if global_step % 1000 == 0:
243 | torch.save(
244 | dict(
245 | [
246 | ("sdf_network", sdf_network.state_dict()),
247 | ]
248 | + [(x, color_network_dict[x].state_dict()) for x in color_network_dict.keys()]
249 | ),
250 | os.path.join(args.out_dir, f"ckpt_{global_step}.pth"),
251 | )
252 |
253 | if global_step % 500 == 0:
254 | ic(
255 | args.out_dir,
256 | global_step,
257 | loss.item(),
258 | img_loss.item(),
259 | img_l2_loss.item(),
260 | img_ssim_loss.item(),
261 | eik_loss.item(),
262 | )
263 |
264 | for x in list(results.keys()):
265 | del results[x]
266 |
267 | idx = 0
268 | if args.plot_image_name is not None:
269 | while idx < len(image_fpaths):
270 | if args.plot_image_name in image_fpaths[idx]:
271 | break
272 | idx += 1
273 |
274 | camera_resize, gt_color_resize = cameras[idx].resize(factor=0.25, image=gt_images[idx])
275 | results = render_camera(
276 | camera_resize,
277 | sdf_network,
278 | raytracer,
279 | color_network_dict,
280 | render_fn,
281 | fill_holes=fill_holes,
282 | handle_edges=handle_edges,
283 | is_training=False,
284 | )
285 | for x in list(results.keys()):
286 | results[x] = results[x].detach().cpu().numpy()
287 |
288 | gt_color_im = gt_color_resize.detach().cpu().numpy()
289 | color_im = results["color"]
290 | normal = results["normal"]
291 | normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10)
292 | normal_im = (normal + 1.0) / 2.0
293 | edge_mask_im = np.tile(results["edge_mask"][:, :, np.newaxis], (1, 1, 3))
294 | if inv_gamma_gt:
295 | gt_color_im = np.power(gt_color_im + 1e-6, 1.0 / 2.2)
296 | color_im = np.power(color_im + 1e-6, 1.0 / 2.2)
297 |
298 | im = np.concatenate([gt_color_im, color_im, normal_im, edge_mask_im], axis=1)
299 | imageio.imwrite(os.path.join(args.out_dir, f"logim_{global_step}.png"), to8b(im))
300 |
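
A hypothetical invocation (a sketch; the paths below are illustrative, and --data_dir must point at a folder containing cam_dict_norm.json plus an image/ subdirectory of LDR .png/.jpg files, as load_datadir expects):

    cd tests
    # illustrative paths; adjust to your data layout
    python test_viewsynthesis.py \
        --data_dir /path/to/scene/train \
        --out_dir ./debug_viewsynthesis \
        --plot_image_name 12.png

TensorBoard scalars are logged every 50 iterations, checkpoints (ckpt_<step>.pth) are saved every 1000, and a side-by-side debug image (ground truth, rendering, normals, edge mask) is written every 500; training resumes automatically from the newest checkpoint found in --out_dir.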
--------------------------------------------------------------------------------
/train_scene.sh:
--------------------------------------------------------------------------------
1 | SCENE=$1
2 |
3 | python render_volume.py --mode train --conf ./confs/womask_iron.conf --case ${SCENE}
4 |
5 | python render_surface.py --data_dir ./data_flashlight/${SCENE}/train \
6 | --out_dir ./exp_iron_stage2/${SCENE} \
7 | --neus_ckpt_fpath ./exp_iron_stage1/${SCENE}/checkpoints/ckpt_100000.pth \
8 | --num_iters 50001 --gamma_pred
9 | # render test set
10 | python render_surface.py --data_dir ./data_flashlight/${SCENE}/test \
11 | --out_dir ./exp_iron_stage2/${SCENE} \
12 | --neus_ckpt_fpath ./exp_iron_stage1/${SCENE}/checkpoints/ckpt_100000.pth \
13 | --num_iters 50001 --gamma_pred --render_all
14 |
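
Example usage (a sketch; <SCENE> is a placeholder for a scene folder that provides ./data_flashlight/<SCENE>/train and /test, as the commands above expect):

    # <SCENE> is a placeholder scene name
    bash train_scene.sh <SCENE>

Stage 1 (render_volume.py) trains the volume model whose checkpoint stage 2 (render_surface.py) then loads via --neus_ckpt_fpath; the final command re-runs stage 2 on the test split with --render_all to render the held-out views.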
--------------------------------------------------------------------------------