├── .gitignore ├── LICENSE ├── README.md ├── confs └── womask_iron.conf ├── create_env.sh ├── download_data.sh ├── evaluation ├── eval_image_folder.py └── eval_mesh.py ├── models ├── dataset.py ├── embedder.py ├── export_materials.py ├── export_mesh.py ├── export_uv.py ├── fields.py ├── ggx │ ├── ext_mts_rtrans_data.txt │ └── int_mts_diff_rtrans_data.txt ├── image_losses.py ├── raytracer.py ├── renderer.py └── renderer_ggx.py ├── readme_resources ├── assets_lowres.png └── inputs_outputs.png ├── render_surface.py ├── render_synthetic_data ├── render_rgb_flash_mat.py └── rgb_flash_hdr_mat.xml ├── render_volume.py ├── singleview ├── 12.png └── cam_dict_norm.json ├── test_mitsuba ├── render_rgb_envmap_mat.py ├── render_rgb_flash_mat.py ├── rgb_envmap_hdr_mat.xml └── rgb_flash_hdr_mat.xml ├── tests ├── data_singleview │ ├── 12.png │ └── cam_dict_norm.json ├── test_raytracer.py ├── test_singleview.py └── test_viewsynthesis.py └── train_scene.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | ### VirtualEnv template 93 | # Virtualenv 94 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 95 | .Python 96 | [Bb]in 97 | [Ii]nclude 98 | [Ll]ib 99 | [Ll]ib64 100 | [Ll]ocal 101 | [Ss]cripts 102 | pyvenv.cfg 103 | .venv 104 | pip-selfcheck.json 105 | ### JetBrains template 106 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 107 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 108 | 109 | # User-specific stuff: 110 | .idea/workspace.xml 111 | .idea/tasks.xml 112 | .idea/dictionaries 113 | .idea/vcs.xml 114 | .idea/jsLibraryMappings.xml 115 | 116 | # Sensitive or high-churn files: 117 | .idea/dataSources.ids 118 | .idea/dataSources.xml 119 | .idea/dataSources.local.xml 120 | .idea/sqlDataSources.xml 121 | .idea/dynamic.xml 122 | .idea/uiDesigner.xml 123 | 124 | # Gradle: 125 | .idea/gradle.xml 126 | .idea/libraries 127 | 128 | # Mongo Explorer 
plugin: 129 | .idea/mongoSettings.xml 130 | 131 | .idea/ 132 | 133 | ## File-based project format: 134 | *.iws 135 | 136 | ## Plugin-specific files: 137 | 138 | # IntelliJ 139 | /out/ 140 | 141 | # mpeltonen/sbt-idea plugin 142 | .idea_modules/ 143 | 144 | # JIRA plugin 145 | atlassian-ide-plugin.xml 146 | 147 | # Crashlytics plugin (for Android Studio and IntelliJ) 148 | com_crashlytics_export_strings.xml 149 | crashlytics.properties 150 | crashlytics-build.properties 151 | fabric.properties 152 | 153 | data 154 | public_data 155 | exp 156 | tmp 157 | 158 | */debug_raytracer* 159 | */debug_singleview* 160 | */debug_multiview* 161 | */debug_inverse_rendering* 162 | */debug_viewsynthesis* 163 | */*/.DS_Store 164 | */*/test_buddha 165 | */*/test_kitty 166 | exp_iron* 167 | blender* 168 | data_flashlight* 169 | */*/*.log -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2022, Kai Zhang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IRON: Inverse Rendering by Optimizing Neural SDFs and Materials from Photometric Images 2 | 3 | Note: this repo is still under construction. 4 | 5 | Project page: 6 | 7 | ![example results](./readme_resources/inputs_outputs.png) 8 | 9 | ## Usage 10 | 11 | ### Create environment 12 | 13 | ```shell 14 | git clone https://github.com/Kai-46/iron.git && cd iron && . ./create_env.sh 15 | ``` 16 | 17 | ### Download data 18 | 19 | ```shell 20 | . ./download_data.sh 21 | ``` 22 | 23 | ### Training and testing 24 | 25 | ```shell 26 | . ./train_scene.sh drv/dragon 27 | ``` 28 | 29 | Once training is done, you will see the recovered mesh and materials under the folder ```./exp_iron_stage2/drv/dragon/mesh_and_materials_50000/```. 
At the same time, the rendered test images are under the folder ``````./exp_iron_stage2/drv/dragon/render_test_50000/`````` 30 | 31 | ### Relight the 3D assets using envmaps 32 | 33 | Check ```test_mitsuba/render_rgb_envmap_mat.py```. 34 | 35 | ### Evaluation 36 | 37 | Check ```evaluation/eval_mesh.py``` and ```evaluation/eval_image_folder.py```. 38 | 39 | ### Render synthetic data using Mitsuba 40 | 41 | Check ```render_synthetic_data/render_rgb_flash_mat.py```. To make renderings more shiny, try scaling up the specular albedo and scaling down the specular roughness; to make renderings more diffuse, try the opposite. 42 | 43 | ### Camera parameters convention 44 | 45 | We use the OpenCV camera convention just like [NeRF++](https://github.com/Kai-46/nerfplusplus); you might want to use the camera visualization and debugging tools in that codebase to inspect if there's any issue with the camera parameters. Note we also assume the objects are inside the unit sphere. 46 | 47 | ## Citations 48 | 49 | ``` 50 | @inproceedings{iron-2022, 51 | title={IRON: Inverse Rendering by Optimizing Neural SDFs and Materials from Photometric Images}, 52 | author={Zhang, Kai and Luan, Fujun and Li, Zhengqi and Snavely, Noah}, 53 | booktitle={IEEE Conf. Comput. Vis. Pattern Recog.}, 54 | year={2022} 55 | } 56 | ``` 57 | 58 | ## Example results 59 | 60 | 61 | 62 | ![example results](./readme_resources/assets_lowres.png) 63 | 64 | ## Acknowledgements 65 | 66 | We would like to thank the authors of [IDR](https://github.com/lioryariv/idr) and [NeuS](https://github.com/Totoro97/NeuS) for open-sourcing their projects. 67 | -------------------------------------------------------------------------------- /confs/womask_iron.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp_iron_stage1/CASE_NAME/ 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | data_dir = ./data_flashlight/CASE_NAME/train/ 11 | render_cameras_name = cameras_sphere.npz 12 | object_cameras_name = cameras_sphere.npz 13 | } 14 | 15 | train { 16 | learning_rate = 5e-4 17 | learning_rate_alpha = 0.05 18 | end_iter = 100001 19 | 20 | batch_size = 512 21 | validate_resolution_level = 4 22 | warm_up_end = 5000 23 | anneal_end = 50000 24 | use_white_bkgd = False 25 | 26 | save_freq = 10000 27 | val_freq = 2500 28 | val_mesh_freq = 5000 29 | report_freq = 100 30 | 31 | igr_weight = 0.1 32 | mask_weight = 0.0 33 | } 34 | 35 | model { 36 | nerf { 37 | D = 8, 38 | d_in = 4, 39 | d_in_view = 3, 40 | W = 256, 41 | multires = 10, 42 | multires_view = 4, 43 | output_ch = 4, 44 | skips=[4], 45 | use_viewdirs=True 46 | } 47 | 48 | sdf_network { 49 | d_out = 257 50 | d_in = 3 51 | d_hidden = 256 52 | n_layers = 8 53 | skip_in = [4] 54 | multires = 6 55 | bias = 0.5 56 | scale = 1.0 57 | geometric_init = True 58 | weight_norm = True 59 | } 60 | 61 | variance_network { 62 | init_val = 0.3 63 | } 64 | 65 | rendering_network { 66 | d_feature = 256 67 | mode = idr 68 | d_in = 9 69 | d_out = 3 70 | d_hidden = 256 71 | n_layers = 8 72 | skip_in = [4] 73 | weight_norm = True 74 | multires = 10 75 | multires_view = 4 76 | squeeze_out = True 77 | } 78 | 79 | neus_renderer { 80 | n_samples = 64 81 | n_importance = 64 82 | n_outside = 32 83 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 84 | perturb = 1.0 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /create_env.sh: 
-------------------------------------------------------------------------------- 1 | conda create -y -n iron python=3.8 && conda activate iron 2 | pip install numpy scipy trimesh opencv_python scikit-image imageio imageio-ffmpeg pyhocon PyMCubes tqdm icecream configargparse 3 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 4 | pip install tensorboard kornia 5 | conda install -c conda-forge igl 6 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | 3 | echo "Downloading image indices for Bi et al 2020: Deep Reflectance Volumes: Relightable Reconstructions from Multi-View Photometric Images" 4 | echo "Please ask the authors of this work for data, and then split the data using the image indices" 5 | gdown 1BThZgEnHgsL7dgyVTQuSFYZjAkZzQozx 6 | unzip "Bi et al 2020-image_indices.zip" 7 | 8 | echo "Downloading real data captured by Luan et al 2021: Unified Shape and SVBRDF Recovery using Differentiable Monte Carlo Rendering" 9 | echo "Please credit the original paper if you use this data" 10 | gdown 1BO6XZjUm8PhHof5RZ7O0Y3C815loBlqj 11 | unzip "Luan et al 2021.zip" 12 | 13 | echo "Downloading synthetic assets for creating synthetic data with Mitsuba" 14 | gdown 1EhDI06NsluXsC98ZErvB7UN_TPI1_6sn 15 | unzip "synthetic_assets.zip" 16 | -------------------------------------------------------------------------------- /evaluation/eval_image_folder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import os 4 | from skimage.metrics import structural_similarity 5 | import lpips 6 | import torch 7 | import glob 8 | 9 | 10 | def skimage_ssim(pred_im, trgt_im): 11 | ssim = 0. 12 | for ch in range(3): 13 | ssim += structural_similarity(trgt_im[:, :, ch], pred_im[:, :, ch], 14 | data_range=1.0, win_size=11, sigma=1.5, 15 | use_sample_covariance=False, k1=0.01, k2=0.03) 16 | ssim /= 3. 17 | return ssim 18 | 19 | def read_image(fpath): 20 | return imageio.imread(fpath).astype(np.float32) / 255. 21 | 22 | mse2psnr = lambda x: -10. * np.log(x+1e-10) / np.log(10.) 23 | 24 | import sys 25 | folder = sys.argv[1] 26 | 27 | all_psnr = [] 28 | all_ssim = [] 29 | all_lpips = [] 30 | 31 | loss_fn_alex = lpips.LPIPS(net='alex').cuda() # best forward scores 32 | # loss_fn_vgg = lpips.LPIPS(net='vgg') # closer to "traditional" perceptual loss, when used for optimization 33 | 34 | with open(os.path.join(folder, '../metrics.txt'), 'w') as fp: 35 | fp.write('img_name\tpsnr\tssim\tlpips\n') 36 | for _, fpath in enumerate(glob.glob(os.path.join(folder, '*_truth.png'))): 37 | name = os.path.basename(fpath) 38 | idx = name.find('_') 39 | idx = int(name[:idx]) 40 | 41 | pred_im = read_image(os.path.join(folder, '{}_prediction.png'.format(idx))) 42 | trgt_im = read_image(os.path.join(folder, '{}_truth.png'.format(idx))) 43 | 44 | psnr = mse2psnr(np.mean((pred_im - trgt_im) ** 2)) 45 | 46 | ssim = skimage_ssim(trgt_im, pred_im) 47 | 48 | pred_im = torch.from_numpy(pred_im).permute(2, 0, 1).unsqueeze(0) * 2. - 1. 49 | trgt_im = torch.from_numpy(trgt_im).permute(2, 0, 1).unsqueeze(0) * 2. - 1. 
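# lpips.LPIPS expects NCHW tensors scaled to [-1, 1]; the "* 2. - 1." mapping above rescales the [0, 1] images before the perceptual distance is computed.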
50 | d = loss_fn_alex(trgt_im.cuda(), pred_im.cuda()).item() 51 | 52 | fp.write('{}_prediction.png\t{:.3f}\t{:.3f}\t{:.4f}\n'.format(idx, psnr, ssim, d)) 53 | 54 | all_psnr.append(psnr) 55 | all_ssim.append(ssim) 56 | all_lpips.append(d) 57 | fp.write('\nAverage\t{:.3f}\t{:.3f}\t{:.4f}\n'.format(np.mean(all_psnr), np.mean(all_ssim), np.mean(all_lpips))) 58 | 59 | -------------------------------------------------------------------------------- /evaluation/eval_mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import igl 4 | 5 | 6 | def cal_mesh_err(va, fa, vb, fb): 7 | sqrD1, _, _ = igl.point_mesh_squared_distance(va, vb, fb) 8 | sqrD2, _, _ = igl.point_mesh_squared_distance(vb, va, fa) 9 | D1 = np.sqrt(sqrD1) 10 | D2 = np.sqrt(sqrD2) 11 | ret = (D1.mean() + D2.mean()) * 0.5 12 | return ret 13 | 14 | 15 | def eval_obj_meshes(pred_mesh_fpath, trgt_mesh_fpath): 16 | v1, _, n1, f1, _, _ = igl.read_obj(pred_mesh_fpath) 17 | v4, _, n4, f4, _, _ = igl.read_obj(trgt_mesh_fpath) 18 | 19 | return cal_mesh_err(v1, f1, v4, f4) 20 | 21 | 22 | import sys 23 | pred_mesh_fpath = sys.argv[1] 24 | trgt_mesh_fpath = sys.argv[2] 25 | dist_bidirectional = eval_obj_meshes(pred_mesh_fpath, trgt_mesh_fpath) 26 | print('\tChamfer_dist: ', dist_bidirectional) 27 | -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | from icecream import ic 8 | from scipy.spatial.transform import Rotation as Rot 9 | from scipy.spatial.transform import Slerp 10 | import traceback 11 | 12 | 13 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 14 | def load_K_Rt_from_P(filename, P=None): 15 | if P is None: 16 | lines = open(filename).read().splitlines() 17 | if len(lines) == 4: 18 | lines = lines[1:] 19 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 20 | P = np.asarray(lines).astype(np.float32).squeeze() 21 | 22 | out = cv.decomposeProjectionMatrix(P) 23 | K = out[0] 24 | R = out[1] 25 | t = out[2] 26 | 27 | K = K / K[2, 2] 28 | intrinsics = np.eye(4) 29 | intrinsics[:3, :3] = K 30 | 31 | pose = np.eye(4, dtype=np.float32) 32 | pose[:3, :3] = R.transpose() 33 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 34 | 35 | return intrinsics, pose 36 | 37 | 38 | class Dataset: 39 | def __init__(self, conf): 40 | super(Dataset, self).__init__() 41 | print("Load data: Begin") 42 | self.device = torch.device("cuda") 43 | self.conf = conf 44 | 45 | self.data_dir = conf.get_string("data_dir") 46 | self.render_cameras_name = conf.get_string("render_cameras_name") 47 | self.object_cameras_name = conf.get_string("object_cameras_name") 48 | 49 | self.camera_outside_sphere = conf.get_bool("camera_outside_sphere", default=True) 50 | # self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1) # not used 51 | 52 | import json 53 | 54 | camera_dict = json.load(open(os.path.join(self.data_dir, "cam_dict_norm.json"))) 55 | for x in list(camera_dict.keys()): 56 | x = x[:-4] + ".png" 57 | camera_dict[x]["K"] = np.array(camera_dict[x]["K"]).reshape((4, 4)) 58 | camera_dict[x]["W2C"] = np.array(camera_dict[x]["W2C"]).reshape((4, 4)) 59 | 60 | self.camera_dict = camera_dict 61 | 62 | try: 63 | self.images_lis = sorted(glob(os.path.join(self.data_dir, "image/*.png"))) 
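# if the PNG glob is empty, np.stack below raises and the bare except falls back to loading EXR images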
64 | self.n_images = len(self.images_lis) 65 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 255.0 66 | except: 67 | # traceback.print_exc() 68 | 69 | print("Loading png images failed; try loading exr images") 70 | import pyexr 71 | 72 | self.images_lis = sorted(glob(os.path.join(self.data_dir, "image/*.exr"))) 73 | self.n_images = len(self.images_lis) 74 | self.images_np = np.clip( 75 | np.power(np.stack([pyexr.open(im_name).get()[:, :, ::-1] for im_name in self.images_lis]), 1.0 / 2.2), 76 | 0.0, 77 | 1.0, 78 | ) 79 | 80 | no_mask = True 81 | if no_mask: 82 | print("Not using masks") 83 | self.masks_lis = None 84 | self.masks_np = np.ones_like(self.images_np) 85 | else: 86 | try: 87 | self.masks_lis = sorted(glob(os.path.join(self.data_dir, "mask/*.png"))) 88 | self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 255.0 89 | except: 90 | # traceback.print_exc() 91 | 92 | print("Loading mask images failed; try not using masks") 93 | self.masks_lis = None 94 | self.masks_np = np.ones_like(self.images_np) 95 | 96 | self.images_np = self.images_np[..., :3] 97 | self.masks_np = self.masks_np[..., :3] 98 | 99 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin. 100 | self.scale_mats_np = [np.eye(4).astype(np.float32) for idx in range(self.n_images)] 101 | 102 | self.intrinsics_all = [] 103 | self.pose_all = [] 104 | self.world_mats_np = [] 105 | for x in self.images_lis: 106 | x = os.path.basename(x)[:-4] + ".png" 107 | K = self.camera_dict[x]["K"].astype(np.float32) 108 | W2C = self.camera_dict[x]["W2C"].astype(np.float32) 109 | C2W = np.linalg.inv(self.camera_dict[x]["W2C"]).astype(np.float32) 110 | self.intrinsics_all.append(torch.from_numpy(K)) 111 | self.pose_all.append(torch.from_numpy(C2W)) 112 | self.world_mats_np.append(W2C) 113 | 114 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3] 115 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu() # [n_images, H, W, 3] 116 | print("image shape, mask shape: ", self.images.shape, self.masks.shape) 117 | print("image pixel range: ", self.images.min().item(), self.images.max().item()) 118 | 119 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] 120 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4] 121 | self.focal = self.intrinsics_all[0][0, 0] 122 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4] 123 | self.H, self.W = self.images.shape[1], self.images.shape[2] 124 | self.image_pixels = self.H * self.W 125 | 126 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0]) 127 | object_bbox_max = np.array([1.01, 1.01, 1.01, 1.0]) 128 | # Object scale mat: region of interest to **extract mesh** 129 | object_scale_mat = np.eye(4).astype(np.float32) 130 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None] 131 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None] 132 | self.object_bbox_min = object_bbox_min[:3, 0] 133 | self.object_bbox_max = object_bbox_max[:3, 0] 134 | 135 | print("Load data: End") 136 | 137 | def gen_rays_at(self, img_idx, resolution_level=1): 138 | """ 139 | Generate rays at world space from one camera. 
140 | """ 141 | l = resolution_level 142 | tx = torch.linspace(0, self.W - 1, self.W // l) 143 | ty = torch.linspace(0, self.H - 1, self.H // l) 144 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 145 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 146 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 147 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 148 | rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 149 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3 150 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 151 | 152 | def gen_random_rays_at(self, img_idx, batch_size): 153 | """ 154 | Generate random rays at world space from one camera. 155 | """ 156 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) 157 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) 158 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3 159 | mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3 160 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3 161 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3 162 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 163 | rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3 164 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3 165 | return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda() # batch_size, 10 166 | 167 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 168 | """ 169 | Interpolate pose between two cameras. 
170 | """ 171 | l = resolution_level 172 | tx = torch.linspace(0, self.W - 1, self.W // l) 173 | ty = torch.linspace(0, self.H - 1, self.H // l) 174 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 175 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 176 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 177 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 178 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 179 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 180 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 181 | pose_0 = np.linalg.inv(pose_0) 182 | pose_1 = np.linalg.inv(pose_1) 183 | rot_0 = pose_0[:3, :3] 184 | rot_1 = pose_1[:3, :3] 185 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 186 | key_times = [0, 1] 187 | slerp = Slerp(key_times, rots) 188 | rot = slerp(ratio) 189 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 190 | pose = pose.astype(np.float32) 191 | pose[:3, :3] = rot.as_matrix() 192 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 193 | pose = np.linalg.inv(pose) 194 | rot = torch.from_numpy(pose[:3, :3]).cuda() 195 | trans = torch.from_numpy(pose[:3, 3]).cuda() 196 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 197 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 198 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 199 | 200 | def near_far_from_sphere(self, rays_o, rays_d): 201 | a = torch.sum(rays_d**2, dim=-1, keepdim=True) 202 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 203 | mid = 0.5 * (-b) / a 204 | near = mid - 1.0 205 | far = mid + 1.0 206 | return near, far 207 | 208 | def image_at(self, idx, resolution_level): 209 | if self.images_lis[idx].endswith(".exr"): 210 | import pyexr 211 | 212 | img = np.power(pyexr.open(self.images_lis[idx]).get()[:, :, ::-1], 1.0 / 2.2) * 255.0 213 | else: 214 | img = cv.imread(self.images_lis[idx]) 215 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255).astype(np.uint8) 216 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 
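# For an input x in R^d and frequencies 2^0 .. 2^(multires-1), the embedder below concatenates
# [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)],
# so out_dim = d * (2 * num_freqs + 1) when include_input is True.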
6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs["input_dims"] 14 | out_dim = 0 15 | if self.kwargs["include_input"]: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs["max_freq_log2"] 20 | N_freqs = self.kwargs["num_freqs"] 21 | 22 | if self.kwargs["log_sampling"]: 23 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs["periodic_fns"]: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | "include_input": True, 42 | "input_dims": input_dims, 43 | "max_freq_log2": multires - 1, 44 | "num_freqs": multires, 45 | "log_sampling": True, 46 | "periodic_fns": [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | 51 | def embed(x, eo=embedder_obj): 52 | return eo.embed(x) 53 | 54 | return embed, embedder_obj.out_dim 55 | -------------------------------------------------------------------------------- /models/export_materials.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import igl 4 | import trimesh 5 | import os 6 | import shutil 7 | import torch 8 | 9 | 10 | to8b = lambda x: np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8) 11 | 12 | 13 | def sample_surface(vertices, face_vertices, texturecoords, face_texturecoords, n_samples): 14 | """ 15 | Samples point cloud on the surface of the model defined as vectices and 16 | faces. This function uses vectorized operations so fast at the cost of some 17 | memory. 18 | """ 19 | vec_cross = np.cross( 20 | vertices[face_vertices[:, 0], :] - vertices[face_vertices[:, 2], :], 21 | vertices[face_vertices[:, 1], :] - vertices[face_vertices[:, 2], :], 22 | ) 23 | face_areas = np.sqrt(np.sum(vec_cross**2, 1)) 24 | face_areas = face_areas / np.sum(face_areas) 25 | 26 | # Sample exactly n_samples. 
First, oversample points and remove redundant 27 | # Error fix by Yangyan (yangyan.lee@gmail.com) 2017-Aug-7 28 | n_samples_per_face = np.ceil(n_samples * face_areas).astype(int) 29 | floor_num = np.sum(n_samples_per_face) - n_samples 30 | if floor_num > 0: 31 | indices = np.where(n_samples_per_face > 0)[0] 32 | floor_indices = np.random.choice(indices, floor_num, replace=True) 33 | n_samples_per_face[floor_indices] -= 1 34 | 35 | n_samples = np.sum(n_samples_per_face) 36 | 37 | # Create a vector that contains the face indices 38 | sample_face_idx = np.zeros((n_samples,), dtype=int) 39 | acc = 0 40 | for face_idx, _n_sample in enumerate(n_samples_per_face): 41 | sample_face_idx[acc : acc + _n_sample] = face_idx 42 | acc += _n_sample 43 | 44 | r = np.random.rand(n_samples, 2) 45 | 46 | A = vertices[face_vertices[sample_face_idx, 0], :] 47 | B = vertices[face_vertices[sample_face_idx, 1], :] 48 | C = vertices[face_vertices[sample_face_idx, 2], :] 49 | P = (1 - np.sqrt(r[:, 0:1])) * A + np.sqrt(r[:, 0:1]) * (1 - r[:, 1:]) * B + np.sqrt(r[:, 0:1]) * r[:, 1:] * C 50 | 51 | A = texturecoords[face_texturecoords[sample_face_idx, 0], :] 52 | B = texturecoords[face_texturecoords[sample_face_idx, 1], :] 53 | C = texturecoords[face_texturecoords[sample_face_idx, 2], :] 54 | P_uv = (1 - np.sqrt(r[:, 0:1])) * A + np.sqrt(r[:, 0:1]) * (1 - r[:, 1:]) * B + np.sqrt(r[:, 0:1]) * r[:, 1:] * C 55 | 56 | return P.astype(np.float32), P_uv.astype(np.float32) 57 | 58 | 59 | class Groupby(object): 60 | def __init__(self, keys): 61 | """note keys are assumed to by integer""" 62 | super().__init__() 63 | 64 | self.unique_keys, self.keys_as_int = np.unique(keys, return_inverse=True) 65 | self.n_keys = len(self.unique_keys) 66 | self.indices = [[] for i in range(self.n_keys)] 67 | for i, k in enumerate(self.keys_as_int): 68 | self.indices[k].append(i) 69 | self.indices = [np.array(elt) for elt in self.indices] 70 | 71 | def apply(self, function, vector): 72 | assert len(vector.shape) <= 2 73 | if len(vector.shape) == 2: 74 | result = np.zeros((self.n_keys, vector.shape[-1])) 75 | else: 76 | result = np.zeros((self.n_keys,)) 77 | 78 | for k, idx in enumerate(self.indices): 79 | result[k] = function(vector[idx], axis=0) 80 | 81 | return result 82 | 83 | 84 | def accumulate_splat_material(xyz_image, material_image, weight_image, pcd, uv, material): 85 | H, W = material_image.shape[:2] 86 | 87 | xyz_image = xyz_image.reshape((H * W, -1)) 88 | material_image = material_image.reshape((H * W, -1)) 89 | weight_image = weight_image.reshape((H * W,)) 90 | 91 | ### label each 3d point with their splat pixel index 92 | uv[:, 0] = uv[:, 0] * W 93 | uv[:, 1] = H - uv[:, 1] * H 94 | 95 | ### repeat to a neighborhood 96 | pcd = np.tile(pcd, (5, 1)) 97 | material = np.tile(material, (5, 1)) 98 | uv_up = np.copy(uv) 99 | uv_up[:, 1] -= 1 100 | uv_right = np.copy(uv) 101 | uv_right[:, 0] += 1 102 | uv_down = np.copy(uv) 103 | uv_down[:, 1] += 1 104 | uv_left = np.copy(uv) 105 | uv_left[:, 0] -= 1 106 | uv = np.concatenate((uv, uv_up, uv_right, uv_down, uv_left), axis=0) 107 | 108 | ### compute pixel coordinates 109 | pixel_col = np.floor(uv[:, 0]) 110 | pixel_row = np.floor(uv[:, 1]) 111 | label = (pixel_row * W + pixel_col).astype(int) 112 | 113 | ### filter out-of-range points 114 | mask = np.logical_and(label >= 0, label < H * W) 115 | label = label[mask] 116 | uv = uv[mask] 117 | material = material[mask] 118 | pcd = pcd[mask] 119 | pixel_col = pixel_col[mask] 120 | pixel_row = pixel_row[mask] 121 | 122 | # compute gaussian weight 
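# each surface sample splats into its own pixel and the 4-neighborhood tiled above; its contribution
# is weighted by a Gaussian of the offset between the sample's uv location and the pixel center
# (pixel_col + 0.5, pixel_row + 0.5), and is normalized later by the accumulated weight_image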
123 | sigma = 1.0 124 | weight = np.exp(-((uv[:, 0] - pixel_col - 0.5) ** 2 + (uv[:, 1] - pixel_row - 0.5) ** 2) / (2 * sigma * sigma)) 125 | # weight = np.ones_like(uv[:, 0]) 126 | 127 | groupby_obj = Groupby(label) 128 | delta_xyz = groupby_obj.apply(np.sum, weight[:, np.newaxis] * pcd) 129 | delta_material = groupby_obj.apply(np.sum, weight[:, np.newaxis] * material) 130 | delta_weight = groupby_obj.apply(np.sum, weight) 131 | 132 | xyz_image[groupby_obj.unique_keys] += delta_xyz 133 | material_image[groupby_obj.unique_keys] += delta_material 134 | weight_image[groupby_obj.unique_keys] += delta_weight 135 | 136 | xyz_image = xyz_image.reshape((H, W, -1)) 137 | material_image = material_image.reshape((H, W, -1)) 138 | weight_image = weight_image.reshape((H, W)) 139 | 140 | return xyz_image, material_image, weight_image 141 | 142 | 143 | def loadmesh_and_checkuv(obj_fpath, out_dir): 144 | os.makedirs(out_dir, exist_ok=True) 145 | 146 | vertices, texturecoords, _, face_vertices, face_texturecoords, _ = igl.read_obj(obj_fpath, dtype="float32") 147 | 148 | def make_rgba_color(float_rgb): 149 | float_rgba = np.concatenate((float_rgb, np.ones_like(float_rgb[:, 0:1])), axis=-1) 150 | return np.uint8(np.clip(float_rgba * 255.0, 0.0, 255.0)) 151 | 152 | #### create debug plot 153 | pcd, pcd_uv = sample_surface(vertices, face_vertices, texturecoords, face_texturecoords, n_samples=10**6) 154 | 155 | uv_color = np.concatenate((pcd_uv, np.zeros_like(pcd_uv[:, 0:1])), axis=-1) 156 | trimesh.PointCloud(vertices=pcd, colors=make_rgba_color(uv_color)).export(os.path.join(out_dir, "check_uvmap.ply")) 157 | W, H = 512, 512 158 | grid_w, grid_h = np.meshgrid(np.linspace(0.0, 1.0, W), np.linspace(1, 0.0, H)) 159 | grid_color = np.stack((grid_w, grid_h, np.zeros_like(grid_w)), axis=2) 160 | imageio.imwrite(os.path.join(out_dir, "check_uvmap.png"), to8b(grid_color)) 161 | 162 | return vertices, face_vertices, texturecoords, face_texturecoords 163 | 164 | 165 | def export_materials(mesh_fpath, material_predictor, out_dir, max_num_pts=320000, texture_H=2048, texture_W=2048): 166 | """output material parameters""" 167 | os.makedirs(out_dir, exist_ok=True) 168 | vertices, face_vertices, texturecoords, face_texturecoords = loadmesh_and_checkuv(mesh_fpath, out_dir) 169 | 170 | xyz_image = np.zeros((texture_H, texture_W, 3), dtype=np.float32) 171 | material_image = np.zeros((texture_H, texture_W, 7), dtype=np.float32) 172 | weight_image = np.zeros((texture_H, texture_W), dtype=np.float32) 173 | 174 | for i in range(5): 175 | points, points_uv = sample_surface( 176 | vertices, face_vertices, texturecoords, face_texturecoords, n_samples=5 * 10**6 177 | ) 178 | 179 | points = torch.from_numpy(points).cuda() 180 | merge_materials = [] 181 | for points_split in torch.split(points, max_num_pts, dim=0): 182 | with torch.set_grad_enabled(False): 183 | diffuse_albedo, specular_albedo, specular_roughness = material_predictor(points_split) 184 | merge_materials.append( 185 | torch.cat((diffuse_albedo, specular_albedo, specular_roughness), dim=-1).detach().cpu() 186 | ) 187 | merge_materials = torch.cat(merge_materials, dim=0).numpy() 188 | points = points.detach().cpu().numpy() 189 | 190 | accumulate_splat_material(xyz_image, material_image, weight_image, points, points_uv, merge_materials) 191 | 192 | final_xyz_image = xyz_image / (weight_image[:, :, np.newaxis] + 1e-10) 193 | final_material_image = material_image / (weight_image[:, :, np.newaxis] + 1e-10) 194 | 195 | imageio.imwrite(os.path.join(out_dir, "xyz.exr"), 
final_xyz_image) 196 | imageio.imwrite(os.path.join(out_dir, "diffuse_albedo.exr"), final_material_image[:, :, :3]) 197 | imageio.imwrite(os.path.join(out_dir, "specular_albedo.exr"), final_material_image[:, :, 3:6]) 198 | imageio.imwrite(os.path.join(out_dir, "roughness.exr"), final_material_image[:, :, 6]) 199 | 200 | imageio.imwrite(os.path.join(out_dir, "xyz.png"), to8b(final_xyz_image * 0.5 + 0.5)) 201 | imageio.imwrite(os.path.join(out_dir, "diffuse_albedo.png"), to8b(final_material_image[:, :, :3])) 202 | imageio.imwrite(os.path.join(out_dir, "specular_albedo.png"), to8b(final_material_image[:, :, 3:6])) 203 | imageio.imwrite(os.path.join(out_dir, "roughness.png"), to8b(final_material_image[:, :, 6])) 204 | 205 | out_mesh_fpath = mesh_fpath 206 | with open(out_mesh_fpath, "r") as original: 207 | data = original.read() 208 | with open(out_mesh_fpath, "w") as modified: 209 | modified.write("usemtl ./{}\n\n".format(os.path.basename(out_mesh_fpath)[:-4] + ".mtl") + data) 210 | 211 | with open(os.path.join(out_dir, os.path.basename(out_mesh_fpath)[:-4] + ".mtl"), "w") as fp: 212 | fp.write( 213 | "newmtl Wood\n" 214 | "Ka 1.000000 1.000000 1.000000\n" 215 | "Kd 0.640000 0.640000 0.640000\n" 216 | "Ks 0.500000 0.500000 0.500000\n" 217 | "Ns 96.078431\n" 218 | "Ni 1.000000\n" 219 | "d 1.000000\n" 220 | "illum 0\n" 221 | "map_Kd diffuse_albedo.png\n" 222 | ) 223 | -------------------------------------------------------------------------------- /models/export_mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import trimesh 4 | from skimage import measure 5 | 6 | 7 | def get_grid_uniform(resolution): 8 | x = np.linspace(-1.0, 1.0, resolution) 9 | y = x 10 | z = x 11 | 12 | xx, yy, zz = np.meshgrid(x, y, z) 13 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float) 14 | 15 | return {"grid_points": grid_points.cuda(), "shortest_axis_length": 2.0, "xyz": [x, y, z], "shortest_axis_index": 0} 16 | 17 | 18 | def get_grid(points, resolution, eps=0.1): 19 | input_min = torch.min(points, dim=0)[0].squeeze().numpy() 20 | input_max = torch.max(points, dim=0)[0].squeeze().numpy() 21 | 22 | bounding_box = input_max - input_min 23 | shortest_axis = np.argmin(bounding_box) 24 | if shortest_axis == 0: 25 | x = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution) 26 | length = np.max(x) - np.min(x) 27 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 28 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 29 | elif shortest_axis == 1: 30 | y = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution) 31 | length = np.max(y) - np.min(y) 32 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 33 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 34 | elif shortest_axis == 2: 35 | z = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution) 36 | length = np.max(z) - np.min(z) 37 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 38 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 39 | 40 | xx, yy, zz = np.meshgrid(x, y, z) 41 | grid_points = 
torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 42 | return { 43 | "grid_points": grid_points, 44 | "shortest_axis_length": length, 45 | "xyz": [x, y, z], 46 | "shortest_axis_index": shortest_axis, 47 | } 48 | 49 | 50 | def export_mesh(sdf, mesh_fpath, resolution=512, max_n_pts=100000): 51 | assert mesh_fpath.endswith(".obj"), f"must use .obj format: {mesh_fpath}" 52 | # get low res mesh to sample point cloud 53 | grid = get_grid_uniform(100) 54 | z = [] 55 | points = grid["grid_points"] 56 | for i, pnts in enumerate(torch.split(points, max_n_pts, dim=0)): 57 | z.append(sdf(pnts).detach().cpu().numpy()) 58 | z = np.concatenate(z, axis=0).astype(np.float32) 59 | verts, faces, normals, values = measure.marching_cubes( 60 | volume=z.reshape(grid["xyz"][1].shape[0], grid["xyz"][0].shape[0], grid["xyz"][2].shape[0]).transpose( 61 | [1, 0, 2] 62 | ), 63 | level=0, 64 | spacing=( 65 | grid["xyz"][0][2] - grid["xyz"][0][1], 66 | grid["xyz"][0][2] - grid["xyz"][0][1], 67 | grid["xyz"][0][2] - grid["xyz"][0][1], 68 | ), 69 | ) 70 | verts = verts + np.array([grid["xyz"][0][0], grid["xyz"][1][0], grid["xyz"][2][0]]) 71 | mesh_low_res = trimesh.Trimesh(verts, faces, normals) 72 | components = mesh_low_res.split(only_watertight=False) 73 | areas = np.array([c.area for c in components], dtype=np.float) 74 | mesh_low_res = components[areas.argmax()] 75 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] 76 | recon_pc = torch.from_numpy(recon_pc).float().cuda() 77 | 78 | # Center and align the recon pc 79 | s_mean = recon_pc.mean(dim=0) 80 | s_cov = recon_pc - s_mean 81 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) 82 | vecs = torch.eig(s_cov, True)[1].transpose(0, 1) 83 | if torch.det(vecs) < 0: 84 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) 85 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), (recon_pc - s_mean).unsqueeze(-1)).squeeze() 86 | 87 | grid_aligned = get_grid(helper.cpu(), resolution) 88 | grid_points = grid_aligned["grid_points"] 89 | g = [] 90 | for i, pnts in enumerate(torch.split(grid_points, max_n_pts, dim=0)): 91 | g.append( 92 | ( 93 | torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), pnts.unsqueeze(-1)).squeeze() 94 | + s_mean 95 | ) 96 | .detach() 97 | .cpu() 98 | ) 99 | grid_points = torch.cat(g, dim=0) 100 | 101 | # MC to new grid 102 | points = grid_points 103 | z = [] 104 | for i, pnts in enumerate(torch.split(points, max_n_pts, dim=0)): 105 | z.append(sdf(pnts.cuda()).detach().cpu().numpy()) 106 | z = np.concatenate(z, axis=0).astype(np.float32) 107 | 108 | if not (np.min(z) > 0 or np.max(z) < 0): 109 | verts, faces, normals, values = measure.marching_cubes( 110 | volume=z.reshape( 111 | grid_aligned["xyz"][1].shape[0], grid_aligned["xyz"][0].shape[0], grid_aligned["xyz"][2].shape[0] 112 | ).transpose([1, 0, 2]), 113 | level=0, 114 | spacing=( 115 | grid_aligned["xyz"][0][2] - grid_aligned["xyz"][0][1], 116 | grid_aligned["xyz"][0][2] - grid_aligned["xyz"][0][1], 117 | grid_aligned["xyz"][0][2] - grid_aligned["xyz"][0][1], 118 | ), 119 | ) 120 | 121 | verts = torch.from_numpy(verts).float() 122 | verts = torch.bmm( 123 | vecs.detach().cpu().unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), verts.unsqueeze(-1) 124 | ).squeeze() 125 | verts = (verts + grid_points[0]).numpy() 126 | 127 | trimesh.Trimesh(verts, faces, normals).export(mesh_fpath) 128 | 
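The exporter above only needs a callable that maps [N, 3] query points to SDF values, so it can be driven directly from a trained stage-1 SDF network. A minimal sketch, assuming the hyperparameters from ```confs/womask_iron.conf```; the checkpoint path, the ```sdf_network_fine``` state-dict key, and the output mesh path are illustrative assumptions, not fixed by this repo:

```python
# Hypothetical driver for export_mesh; checkpoint path and state-dict key are assumptions.
import torch
from models.fields import SDFNetwork
from models.export_mesh import export_mesh

sdf_network = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8,
                         skip_in=[4], multires=6, bias=0.5, scale=1.0,
                         geometric_init=True, weight_norm=True).cuda()
ckpt = torch.load("./exp_iron_stage1/drv/dragon/checkpoints/ckpt_100000.pth")  # assumed path
sdf_network.load_state_dict(ckpt["sdf_network_fine"])  # assumed key

# export_mesh evaluates the callable on grid points and runs marching cubes at level 0.
with torch.no_grad():
    export_mesh(lambda x: sdf_network.sdf(x).squeeze(-1), "./dragon_mesh.obj", resolution=512)
```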
-------------------------------------------------------------------------------- /models/export_uv.py: -------------------------------------------------------------------------------- 1 | # Usage: Blender --background --python export_uv.py {in_mesh_fpath} {out_mesh_fpath} 2 | 3 | import os 4 | import bpy 5 | import sys 6 | 7 | 8 | def export_uv(in_mesh_fpath, out_mesh_fpath): 9 | assert in_mesh_fpath.endswith(".obj"), f"must use .obj format: {in_mesh_fpath}" 10 | assert out_mesh_fpath.endswith(".obj"), f"must use .obj format: {out_mesh_fpath}" 11 | 12 | bpy.data.objects["Camera"].select_set(True) 13 | bpy.data.objects["Cube"].select_set(True) 14 | bpy.data.objects["Light"].select_set(True) 15 | bpy.ops.object.delete() # delete camera, cube, light 16 | 17 | mesh_fname = os.path.basename(in_mesh_fpath)[:-4] 18 | bpy.ops.import_scene.obj( 19 | filepath=in_mesh_fpath, 20 | use_edges=True, 21 | use_smooth_groups=True, 22 | use_split_objects=True, 23 | use_split_groups=True, 24 | use_groups_as_vgroups=False, 25 | use_image_search=True, 26 | split_mode="ON", 27 | global_clamp_size=0, 28 | axis_forward="-Z", 29 | axis_up="Y", 30 | ) 31 | 32 | obj = bpy.data.objects[mesh_fname] 33 | obj.select_set(True) 34 | bpy.context.view_layer.objects.active = obj 35 | bpy.ops.object.mode_set(mode="EDIT") 36 | bpy.ops.mesh.select_all(action="SELECT") 37 | bpy.ops.uv.smart_project() 38 | bpy.ops.object.mode_set(mode="OBJECT") 39 | 40 | bpy.ops.export_scene.obj( 41 | filepath=out_mesh_fpath, 42 | axis_forward="-Z", 43 | axis_up="Y", 44 | use_selection=True, 45 | use_normals=True, 46 | use_uvs=True, 47 | use_materials=False, 48 | use_triangles=True, 49 | ) 50 | 51 | 52 | print(sys.argv) 53 | export_uv(sys.argv[-2], sys.argv[-1]) 54 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | 8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 9 | class SDFNetwork(nn.Module): 10 | def __init__( 11 | self, 12 | d_in, 13 | d_out, 14 | d_hidden, 15 | n_layers, 16 | skip_in=(4,), 17 | multires=0, 18 | bias=0.5, 19 | scale=1, 20 | geometric_init=True, 21 | weight_norm=True, 22 | inside_outside=False, 23 | ): 24 | super(SDFNetwork, self).__init__() 25 | 26 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 27 | 28 | self.embed_fn_fine = None 29 | 30 | if multires > 0: 31 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 32 | self.embed_fn_fine = embed_fn 33 | dims[0] = input_ch 34 | 35 | self.num_layers = len(dims) 36 | self.skip_in = skip_in 37 | self.scale = scale 38 | 39 | for l in range(0, self.num_layers - 1): 40 | if l + 1 in self.skip_in: 41 | out_dim = dims[l + 1] - dims[0] 42 | else: 43 | out_dim = dims[l + 1] 44 | 45 | lin = nn.Linear(dims[l], out_dim) 46 | 47 | if geometric_init: 48 | if l == self.num_layers - 2: 49 | if not inside_outside: 50 | torch.nn.init.normal_( 51 | lin.weight, 52 | mean=np.sqrt(np.pi) / np.sqrt(dims[l]), 53 | std=0.0001, 54 | ) 55 | torch.nn.init.constant_(lin.bias, -bias) 56 | else: 57 | torch.nn.init.normal_( 58 | lin.weight, 59 | mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), 60 | std=0.0001, 61 | ) 62 | torch.nn.init.constant_(lin.bias, bias) 63 | elif multires > 0 and l == 0: 64 | torch.nn.init.constant_(lin.bias, 0.0) 65 | torch.nn.init.constant_(lin.weight[:, 
3:], 0.0) 66 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 67 | elif multires > 0 and l in self.skip_in: 68 | torch.nn.init.constant_(lin.bias, 0.0) 69 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 70 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) 71 | else: 72 | torch.nn.init.constant_(lin.bias, 0.0) 73 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 74 | 75 | if weight_norm: 76 | lin = nn.utils.weight_norm(lin) 77 | 78 | setattr(self, "lin" + str(l), lin) 79 | 80 | self.activation = nn.Softplus(beta=100) 81 | 82 | def forward(self, inputs): 83 | inputs = inputs * self.scale 84 | if self.embed_fn_fine is not None: 85 | inputs = self.embed_fn_fine(inputs) 86 | 87 | x = inputs 88 | for l in range(0, self.num_layers - 1): 89 | lin = getattr(self, "lin" + str(l)) 90 | 91 | if l in self.skip_in: 92 | x = torch.cat([x, inputs], -1) / np.sqrt(2) 93 | 94 | x = lin(x) 95 | 96 | if l < self.num_layers - 2: 97 | x = self.activation(x) 98 | return torch.cat([x[..., :1] / self.scale, x[..., 1:]], dim=-1) 99 | 100 | def sdf(self, x): 101 | return self.forward(x)[..., :1] 102 | 103 | def sdf_hidden_appearance(self, x): 104 | return self.forward(x) 105 | 106 | def gradient(self, x): 107 | x.requires_grad_(True) 108 | y = self.sdf(x) 109 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 110 | gradients = torch.autograd.grad( 111 | outputs=y, 112 | inputs=x, 113 | grad_outputs=d_output, 114 | create_graph=True, 115 | retain_graph=True, 116 | only_inputs=True, 117 | )[0] 118 | return gradients 119 | 120 | def get_all(self, x, is_training=True): 121 | with torch.enable_grad(): 122 | x.requires_grad_(True) 123 | tmp = self.forward(x) 124 | y, feature = tmp[..., :1], tmp[..., 1:] 125 | 126 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 127 | gradients = torch.autograd.grad( 128 | outputs=y, 129 | inputs=x, 130 | grad_outputs=d_output, 131 | create_graph=is_training, 132 | retain_graph=is_training, 133 | only_inputs=True, 134 | )[0] 135 | if not is_training: 136 | return y.detach(), feature.detach(), gradients.detach() 137 | return y, feature, gradients 138 | 139 | 140 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 141 | class RenderingNetwork(nn.Module): 142 | def __init__( 143 | self, 144 | d_feature, 145 | mode, 146 | d_in, 147 | d_out, 148 | d_hidden, 149 | n_layers, 150 | weight_norm=True, 151 | multires=0, 152 | multires_view=0, 153 | squeeze_out=True, 154 | squeeze_out_scale=1.0, 155 | output_bias=0.0, 156 | output_scale=1.0, 157 | skip_in=(), 158 | ): 159 | super().__init__() 160 | 161 | self.mode = mode 162 | self.squeeze_out = squeeze_out 163 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 164 | 165 | self.embed_fn = None 166 | if multires > 0: 167 | embed_fn, input_ch = get_embedder(multires) 168 | self.embed_fn = embed_fn 169 | dims[0] += input_ch - 3 170 | 171 | self.embedview_fn = None 172 | if multires_view > 0: 173 | embedview_fn, input_ch = get_embedder(multires_view) 174 | self.embedview_fn = embedview_fn 175 | dims[0] += input_ch - 3 176 | 177 | self.num_layers = len(dims) 178 | self.skip_in = skip_in 179 | 180 | for l in range(0, self.num_layers - 1): 181 | if l in self.skip_in: 182 | dims[l] += dims[0] 183 | 184 | for l in range(0, self.num_layers - 1): 185 | if l + 1 in self.skip_in: 186 | out_dim = dims[l + 1] - dims[0] 187 | else: 188 | out_dim = dims[l + 1] 189 | 190 | lin = 
nn.Linear(dims[l], out_dim) 191 | 192 | if weight_norm: 193 | lin = nn.utils.weight_norm(lin) 194 | 195 | setattr(self, "lin" + str(l), lin) 196 | 197 | self.relu = nn.ReLU() 198 | 199 | self.output_bias = output_bias 200 | self.output_scale = output_scale 201 | self.squeeze_out_scale = squeeze_out_scale 202 | 203 | def forward(self, points, normals, view_dirs, feature_vectors): 204 | 205 | if self.embed_fn is not None: 206 | points = self.embed_fn(points) 207 | 208 | if self.embedview_fn is not None and self.mode != "no_view_dir": 209 | view_dirs = self.embedview_fn(view_dirs) 210 | 211 | rendering_input = None 212 | 213 | if self.mode == "idr": 214 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 215 | elif self.mode == "no_view_dir": 216 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) 217 | elif self.mode == "no_normal": 218 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 219 | 220 | x = rendering_input 221 | 222 | for l in range(0, self.num_layers - 1): 223 | lin = getattr(self, "lin" + str(l)) 224 | 225 | if l in self.skip_in: 226 | x = torch.cat([x, rendering_input], dim=-1) / np.sqrt(2) 227 | 228 | x = lin(x) 229 | 230 | if l < self.num_layers - 2: 231 | x = self.relu(x) 232 | 233 | x = self.output_scale * (x + self.output_bias) 234 | if self.squeeze_out: 235 | x = self.squeeze_out_scale * torch.sigmoid(x) 236 | 237 | return x 238 | 239 | 240 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch 241 | class NeRF(nn.Module): 242 | def __init__( 243 | self, 244 | D=8, 245 | W=256, 246 | d_in=3, 247 | d_in_view=3, 248 | multires=0, 249 | multires_view=0, 250 | output_ch=4, 251 | skips=[4], 252 | use_viewdirs=False, 253 | ): 254 | super(NeRF, self).__init__() 255 | self.D = D 256 | self.W = W 257 | self.d_in = d_in 258 | self.d_in_view = d_in_view 259 | self.input_ch = 3 260 | self.input_ch_view = 3 261 | self.embed_fn = None 262 | self.embed_fn_view = None 263 | 264 | if multires > 0: 265 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 266 | self.embed_fn = embed_fn 267 | self.input_ch = input_ch 268 | 269 | if multires_view > 0: 270 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view) 271 | self.embed_fn_view = embed_fn_view 272 | self.input_ch_view = input_ch_view 273 | 274 | self.skips = skips 275 | self.use_viewdirs = use_viewdirs 276 | 277 | self.pts_linears = nn.ModuleList( 278 | [nn.Linear(self.input_ch, W)] 279 | + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)] 280 | ) 281 | 282 | ### Implementation according to the official code release 283 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 284 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 285 | 286 | ### Implementation according to the paper 287 | # self.views_linears = nn.ModuleList( 288 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 289 | 290 | if use_viewdirs: 291 | self.feature_linear = nn.Linear(W, W) 292 | self.alpha_linear = nn.Linear(W, 1) 293 | self.rgb_linear = nn.Linear(W // 2, 3) 294 | else: 295 | self.output_linear = nn.Linear(W, output_ch) 296 | 297 | def forward(self, input_pts, input_views): 298 | if self.embed_fn is not None: 299 | input_pts = self.embed_fn(input_pts) 300 | if self.embed_fn_view is not None: 301 | input_views = self.embed_fn_view(input_views) 302 | 303 | h = 
input_pts 304 | for i, l in enumerate(self.pts_linears): 305 | h = self.pts_linears[i](h) 306 | h = F.relu(h) 307 | if i in self.skips: 308 | h = torch.cat([input_pts, h], -1) 309 | 310 | if self.use_viewdirs: 311 | alpha = self.alpha_linear(h) 312 | feature = self.feature_linear(h) 313 | h = torch.cat([feature, input_views], -1) 314 | 315 | for i, l in enumerate(self.views_linears): 316 | h = self.views_linears[i](h) 317 | h = F.relu(h) 318 | 319 | rgb = self.rgb_linear(h) 320 | return alpha, rgb 321 | else: 322 | assert False 323 | 324 | 325 | class SingleVarianceNetwork(nn.Module): 326 | def __init__(self, init_val): 327 | super(SingleVarianceNetwork, self).__init__() 328 | self.register_parameter("variance", nn.Parameter(torch.tensor(init_val))) 329 | 330 | def forward(self, x): 331 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0) 332 | -------------------------------------------------------------------------------- /models/ggx/int_mts_diff_rtrans_data.txt: -------------------------------------------------------------------------------- 1 | 0.416354 2 | 0.416355 3 | 0.416354 4 | 0.416354 5 | 0.41635 6 | 0.416334 7 | 0.416277 8 | 0.416124 9 | 0.415772 10 | 0.415112 11 | 0.414269 12 | 0.413012 13 | 0.41189 14 | 0.410755 15 | 0.410089 16 | 0.409991 17 | 0.409841 18 | 0.410012 19 | 0.410206 20 | 0.410433 21 | 0.41088 22 | 0.41127 23 | 0.41126 24 | 0.411295 25 | 0.410715 26 | 0.409467 27 | 0.407075 28 | 0.403966 29 | 0.399456 30 | 0.393355 31 | 0.385689 32 | 0.376357 33 | 0.365266 34 | 0.352599 35 | 0.338526 36 | 0.323228 37 | 0.306972 38 | 0.290034 39 | 0.272687 40 | 0.25522 41 | 0.237934 42 | 0.220986 43 | 0.20462 44 | 0.188985 45 | 0.174162 46 | 0.160172 47 | 0.147244 48 | 0.13523 49 | 0.124225 50 | 0.114087 51 | -------------------------------------------------------------------------------- /models/image_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import scipy.ndimage 6 | 7 | import warnings 8 | import kornia 9 | 10 | from icecream import ic 11 | 12 | 13 | class PyramidL2Loss(nn.Module): 14 | def __init__(self, use_cuda=True): 15 | super().__init__() 16 | 17 | dirac = np.zeros((7, 7), dtype=np.float32) 18 | dirac[3, 3] = 1.0 19 | f = np.zeros([3, 3, 7, 7], dtype=np.float32) 20 | gf = scipy.ndimage.filters.gaussian_filter(dirac, 1.0) 21 | f[0, 0, :, :] = gf 22 | f[1, 1, :, :] = gf 23 | f[2, 2, :, :] = gf 24 | self.f = torch.from_numpy(f) 25 | if use_cuda: 26 | self.f = self.f.cuda() 27 | self.m = torch.nn.AvgPool2d(2) 28 | 29 | def forward(self, pred_img, trgt_img): 30 | """ 31 | pred_img, trgt_img: [B, C, H, W] 32 | """ 33 | diff_0 = pred_img - trgt_img 34 | 35 | h, w = pred_img.shape[-2:] 36 | # Convolve then downsample 37 | diff_1 = self.m(torch.nn.functional.conv2d(diff_0, self.f, padding=3)) 38 | diff_2 = self.m(torch.nn.functional.conv2d(diff_1, self.f, padding=3)) 39 | diff_3 = self.m(torch.nn.functional.conv2d(diff_2, self.f, padding=3)) 40 | diff_4 = self.m(torch.nn.functional.conv2d(diff_3, self.f, padding=3)) 41 | loss = ( 42 | diff_0.pow(2).sum() / (h * w) 43 | + diff_1.pow(2).sum() / ((h / 2.0) * (w / 2.0)) 44 | + diff_2.pow(2).sum() / ((h / 4.0) * (w / 4.0)) 45 | + diff_3.pow(2).sum() / ((h / 8.0) * (w / 8.0)) 46 | + diff_4.pow(2).sum() / ((h / 16.0) * (w / 16.0)) 47 | ) 48 | return loss 49 | 50 | 51 | def _fspecial_gauss_1d(size, sigma): 52 | r"""Create 1-D gauss kernel 53 | Args: 54 | size (int): the 
size of gauss kernel 55 | sigma (float): sigma of normal distribution 56 | Returns: 57 | torch.Tensor: 1D kernel (1 x 1 x size) 58 | """ 59 | coords = torch.arange(size, dtype=torch.float) 60 | coords -= size // 2 61 | 62 | g = torch.exp(-(coords**2) / (2 * sigma**2)) 63 | g /= g.sum() 64 | 65 | return g.unsqueeze(0).unsqueeze(0) 66 | 67 | 68 | def gaussian_filter(input, win): 69 | r"""Blur input with 1-D kernel 70 | Args: 71 | input (torch.Tensor): a batch of tensors to be blurred 72 | window (torch.Tensor): 1-D gauss kernel 73 | Returns: 74 | torch.Tensor: blurred tensors 75 | """ 76 | assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape 77 | if len(input.shape) == 4: 78 | conv = F.conv2d 79 | elif len(input.shape) == 5: 80 | conv = F.conv3d 81 | else: 82 | raise NotImplementedError(input.shape) 83 | 84 | C = input.shape[1] 85 | out = input 86 | for i, s in enumerate(input.shape[2:]): 87 | if s >= win.shape[-1]: 88 | out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C) 89 | else: 90 | warnings.warn( 91 | f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}" 92 | ) 93 | 94 | return out 95 | 96 | 97 | def ssim_loss_fn(X, Y, mask=None, data_range=1.0, win_size=11, win_sigma=1.5, K=(0.01, 0.03)): 98 | r"""Calculate ssim index for X and Y 99 | Args: 100 | X (torch.Tensor): images of shape [b, c, h, w] 101 | Y (torch.Tensor): images of shape [b, c, h, w] 102 | mask (torch.Tensor): [b, 1, h, w] 103 | win_size: (int, optional): the size of gauss kernel 104 | win_sigma: (float, optional): sigma of normal distribution 105 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 106 | Returns: 107 | torch.Tensor: per pixel ssim results (same size as input images X, Y) 108 | """ 109 | if not X.shape == Y.shape: 110 | raise ValueError("Input images should have the same dimensions.") 111 | 112 | if not X.type() == Y.type(): 113 | raise ValueError("Input images should have the same dtype.") 114 | 115 | if len(X.shape) != 4: 116 | raise ValueError(f"Input images should be 4-d tensors, but got {X.shape}") 117 | 118 | if not (win_size % 2 == 1): 119 | raise ValueError("Window size should be odd.") 120 | 121 | win = _fspecial_gauss_1d(win_size, win_sigma) 122 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 123 | 124 | K1, K2 = K 125 | # batch, channel, [depth,] height, width = X.shape 126 | compensation = 1.0 127 | 128 | C1 = (K1 * data_range) ** 2 129 | C2 = (K2 * data_range) ** 2 130 | 131 | win = win.to(X.device, dtype=X.dtype) 132 | 133 | mu1 = gaussian_filter(X, win) 134 | mu2 = gaussian_filter(Y, win) 135 | 136 | mu1_sq = mu1.pow(2) 137 | mu2_sq = mu2.pow(2) 138 | mu1_mu2 = mu1 * mu2 139 | 140 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) 141 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) 142 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) 143 | 144 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 145 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 146 | ssim_map = ssim_map.mean(dim=1, keepdim=True) 147 | 148 | if mask is not None: 149 | ### pad ssim_map to original size 150 | ssim_map = F.pad( 151 | ssim_map, (win_size // 2, win_size // 2, win_size // 2, win_size // 2), mode="constant", value=1.0 152 | ) 153 | 154 | mask = kornia.morphology.erosion(mask.float(), torch.ones(win_size, win_size).float().to(mask.device)) > 0.5 155 | # ic(ssim_map.shape, 
mask.shape) 156 | ssim_map = ssim_map[mask] 157 | 158 | return 1.0 - ssim_map.mean() 159 | 160 | 161 | if __name__ == "__main__": 162 | pred_im = torch.rand(1, 3, 256, 256).cuda() 163 | # gt_im = torch.rand(1, 3, 256, 256).cuda() 164 | gt_im = pred_im.clone() 165 | mask = torch.ones(1, 1, 256, 256).bool().cuda() 166 | 167 | ssim_loss = ssim_loss_fn(pred_im, gt_im, mask) 168 | ic(ssim_loss) 169 | -------------------------------------------------------------------------------- /models/raytracer.py: -------------------------------------------------------------------------------- 1 | from email import contentmanager 2 | from operator import contains 3 | import os 4 | from sys import prefix 5 | from turtle import update 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import kornia 10 | import cv2 11 | 12 | from icecream import ic 13 | 14 | VERBOSE_MODE = False 15 | 16 | 17 | def reparam_points(nondiff_points, nondiff_grads, nondiff_trgt_dirs, diff_sdf_vals): 18 | # note that flipping the direction of nondiff_trgt_dirs would not change this equations at all 19 | # hence we require dot >= 0 20 | dot = (nondiff_grads * nondiff_trgt_dirs).sum(dim=-1, keepdim=True) 21 | # assert (dot >= 0.).all(), 'dot>=0 not satisfied in reparam_points: {},{}'.format(dot.min().item(), dot.max().item()) 22 | dot = torch.clamp(dot, min=1e-4) 23 | diff_points = nondiff_points - nondiff_trgt_dirs / dot * (diff_sdf_vals - diff_sdf_vals.detach()) 24 | return diff_points 25 | 26 | 27 | class RayTracer(nn.Module): 28 | def __init__( 29 | self, 30 | sdf_threshold=5.0e-5, 31 | sphere_tracing_iters=16, 32 | n_steps=128, 33 | max_num_pts=200000, 34 | ): 35 | super().__init__() 36 | """sdf values of convergent points must be inside [-sdf_threshold, sdf_threshold]""" 37 | self.sdf_threshold = sdf_threshold 38 | # sphere tracing hyper-params 39 | self.sphere_tracing_iters = sphere_tracing_iters 40 | # dense sampling hyper-params 41 | self.n_steps = n_steps 42 | 43 | self.max_num_pts = max_num_pts 44 | 45 | @torch.no_grad() 46 | def forward(self, sdf, ray_o, ray_d, min_dis, max_dis, work_mask): 47 | ( 48 | convergent_mask, 49 | unfinished_mask_start, 50 | curr_start_points, 51 | curr_start_sdf, 52 | acc_start_dis, 53 | ) = self.sphere_tracing(sdf, ray_o, ray_d, min_dis, max_dis, work_mask) 54 | sphere_tracing_cnt = convergent_mask.sum() 55 | 56 | sampler_work_mask = unfinished_mask_start 57 | sampler_cnt = 0 58 | if sampler_work_mask.sum() > 0: 59 | tmp_mask = (curr_start_sdf[sampler_work_mask] > 0.0).float() 60 | sampler_min_dis = ( 61 | tmp_mask * acc_start_dis[sampler_work_mask] + (1.0 - tmp_mask) * min_dis[sampler_work_mask] 62 | ) 63 | sampler_max_dis = ( 64 | tmp_mask * max_dis[sampler_work_mask] + (1.0 - tmp_mask) * acc_start_dis[sampler_work_mask] 65 | ) 66 | 67 | (sampler_convergent_mask, sampler_points, sampler_sdf, sampler_dis,) = self.ray_sampler( 68 | sdf, 69 | ray_o[sampler_work_mask], 70 | ray_d[sampler_work_mask], 71 | sampler_min_dis, 72 | sampler_max_dis, 73 | ) 74 | 75 | convergent_mask[sampler_work_mask] = sampler_convergent_mask 76 | curr_start_points[sampler_work_mask] = sampler_points 77 | curr_start_sdf[sampler_work_mask] = sampler_sdf 78 | acc_start_dis[sampler_work_mask] = sampler_dis 79 | sampler_cnt = sampler_convergent_mask.sum() 80 | 81 | ret_dict = { 82 | "convergent_mask": convergent_mask, 83 | "points": curr_start_points, 84 | "sdf": curr_start_sdf, 85 | "distance": acc_start_dis, 86 | } 87 | 88 | if VERBOSE_MODE: # debug 89 | sdf_check = sdf(curr_start_points) 90 | ic( 91 | 
convergent_mask.sum() / convergent_mask.numel(), 92 | sdf_check[convergent_mask].min().item(), 93 | sdf_check[convergent_mask].max().item(), 94 | ) 95 | debug_info = "Total,raytraced,convergent(sphere tracing+dense sampling): {},{},{} ({}+{})".format( 96 | work_mask.numel(), 97 | work_mask.sum(), 98 | convergent_mask.sum(), 99 | sphere_tracing_cnt, 100 | sampler_cnt, 101 | ) 102 | ic(debug_info) 103 | return ret_dict 104 | 105 | def sphere_tracing(self, sdf, ray_o, ray_d, min_dis, max_dis, work_mask): 106 | """Run sphere tracing algorithm for max iterations""" 107 | iters = 0 108 | unfinished_mask_start = work_mask.clone() 109 | acc_start_dis = min_dis.clone() 110 | curr_start_points = ray_o + ray_d * acc_start_dis.unsqueeze(-1) 111 | curr_sdf_start = sdf(curr_start_points) 112 | while True: 113 | # Check convergence 114 | unfinished_mask_start = ( 115 | unfinished_mask_start & (curr_sdf_start.abs() > self.sdf_threshold) & (acc_start_dis < max_dis) 116 | ) 117 | 118 | if iters == self.sphere_tracing_iters or unfinished_mask_start.sum() == 0: 119 | break 120 | iters += 1 121 | 122 | # Make step 123 | tmp = curr_sdf_start[unfinished_mask_start] 124 | acc_start_dis[unfinished_mask_start] += tmp 125 | curr_start_points[unfinished_mask_start] += ray_d[unfinished_mask_start] * tmp.unsqueeze(-1) 126 | curr_sdf_start[unfinished_mask_start] = sdf(curr_start_points[unfinished_mask_start]) 127 | 128 | convergent_mask = ( 129 | work_mask 130 | & ~unfinished_mask_start 131 | & (curr_sdf_start.abs() <= self.sdf_threshold) 132 | & (acc_start_dis < max_dis) 133 | ) 134 | return ( 135 | convergent_mask, 136 | unfinished_mask_start, 137 | curr_start_points, 138 | curr_sdf_start, 139 | acc_start_dis, 140 | ) 141 | 142 | def ray_sampler(self, sdf, ray_o, ray_d, min_dis, max_dis): 143 | """Sample the ray in a given range and perform rootfinding on ray segments which have sign transition""" 144 | intervals_dis = ( 145 | torch.linspace(0, 1, steps=self.n_steps).float().to(min_dis.device).view(1, self.n_steps) 146 | ) # [1, n_steps] 147 | intervals_dis = min_dis.unsqueeze(-1) + intervals_dis * ( 148 | max_dis.unsqueeze(-1) - min_dis.unsqueeze(-1) 149 | ) # [n_valid, n_steps] 150 | points = ray_o.unsqueeze(-2) + ray_d.unsqueeze(-2) * intervals_dis.unsqueeze(-1) # [n_valid, n_steps, 3] 151 | 152 | sdf_val = [] 153 | for pnts in torch.split(points.reshape(-1, 3), self.max_num_pts, dim=0): 154 | sdf_val.append(sdf(pnts)) 155 | sdf_val = torch.cat(sdf_val, dim=0).reshape(-1, self.n_steps) 156 | 157 | # To be returned 158 | sampler_pts = torch.zeros_like(ray_d) 159 | sampler_sdf = torch.zeros_like(min_dis) 160 | sampler_dis = torch.zeros_like(min_dis) 161 | 162 | tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).float().to(sdf_val.device).reshape( 163 | 1, self.n_steps 164 | ) 165 | # return first negative sdf point if exists 166 | min_val, min_idx = torch.min(tmp, dim=-1) 167 | rootfind_work_mask = (min_val < 0.0) & (min_idx >= 1) 168 | n_rootfind = rootfind_work_mask.sum() 169 | if n_rootfind > 0: 170 | # [n_rootfind, 1] 171 | min_idx = min_idx[rootfind_work_mask].unsqueeze(-1) 172 | z_low = torch.gather(intervals_dis[rootfind_work_mask], dim=-1, index=min_idx - 1).squeeze( 173 | -1 174 | ) # [n_rootfind, ] 175 | # [n_rootfind, ]; > 0 176 | sdf_low = torch.gather(sdf_val[rootfind_work_mask], dim=-1, index=min_idx - 1).squeeze(-1) 177 | z_high = torch.gather(intervals_dis[rootfind_work_mask], dim=-1, index=min_idx).squeeze( 178 | -1 179 | ) # [n_rootfind, ] 180 | # [n_rootfind, ]; < 0 181 | sdf_high = 
torch.gather(sdf_val[rootfind_work_mask], dim=-1, index=min_idx).squeeze(-1) 182 | 183 | p_pred, z_pred, sdf_pred = self.rootfind( 184 | sdf, 185 | sdf_low, 186 | sdf_high, 187 | z_low, 188 | z_high, 189 | ray_o[rootfind_work_mask], 190 | ray_d[rootfind_work_mask], 191 | ) 192 | 193 | sampler_pts[rootfind_work_mask] = p_pred 194 | sampler_sdf[rootfind_work_mask] = sdf_pred 195 | sampler_dis[rootfind_work_mask] = z_pred 196 | 197 | return rootfind_work_mask, sampler_pts, sampler_sdf, sampler_dis 198 | 199 | def rootfind(self, sdf, f_low, f_high, d_low, d_high, ray_o, ray_d): 200 | """binary search the root""" 201 | work_mask = (f_low > 0) & (f_high < 0) 202 | d_mid = (d_low + d_high) / 2.0 203 | i = 0 204 | while work_mask.any(): 205 | p_mid = ray_o + ray_d * d_mid.unsqueeze(-1) 206 | f_mid = sdf(p_mid) 207 | ind_low = f_mid > 0 208 | ind_high = f_mid <= 0 209 | if ind_low.sum() > 0: 210 | d_low[ind_low] = d_mid[ind_low] 211 | f_low[ind_low] = f_mid[ind_low] 212 | if ind_high.sum() > 0: 213 | d_high[ind_high] = d_mid[ind_high] 214 | f_high[ind_high] = f_mid[ind_high] 215 | d_mid = (d_low + d_high) / 2.0 216 | work_mask &= (d_high - d_low) > 2 * self.sdf_threshold 217 | i += 1 218 | p_mid = ray_o + ray_d * d_mid.unsqueeze(-1) 219 | f_mid = sdf(p_mid) 220 | return p_mid, d_mid, f_mid 221 | 222 | 223 | @torch.no_grad() 224 | def intersect_sphere(ray_o, ray_d, r): 225 | """ 226 | ray_o, ray_d: [..., 3] 227 | compute the depth of the intersection point between this ray and unit sphere 228 | """ 229 | # note: d1 becomes negative if this mid point is behind camera 230 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) 231 | p = ray_o + d1.unsqueeze(-1) * ray_d 232 | 233 | tmp = r * r - torch.sum(p * p, dim=-1) 234 | mask_intersect = tmp > 0.0 235 | d2 = torch.sqrt(torch.clamp(tmp, min=0.0)) / torch.norm(ray_d, dim=-1) 236 | 237 | return mask_intersect, torch.clamp(d1 - d2, min=0.0), d1 + d2 238 | 239 | 240 | class Camera(object): 241 | def __init__(self, W, H, K, W2C): 242 | """ 243 | W, H: int 244 | K, W2C: 4x4 tensor 245 | """ 246 | self.W = W 247 | self.H = H 248 | self.K = K 249 | self.W2C = W2C 250 | self.K_inv = torch.inverse(K) 251 | self.C2W = torch.inverse(W2C) 252 | self.device = self.K.device 253 | 254 | def get_rays(self, uv): 255 | """ 256 | uv: [..., 2] 257 | """ 258 | dots_sh = list(uv.shape[:-1]) 259 | 260 | uv = uv.view(-1, 2) 261 | uv = torch.cat((uv, torch.ones_like(uv[..., 0:1])), dim=-1) 262 | ray_d = torch.matmul( 263 | torch.matmul(uv, self.K_inv[:3, :3].transpose(1, 0)), 264 | self.C2W[:3, :3].transpose(1, 0), 265 | ).reshape( 266 | dots_sh 267 | + [ 268 | 3, 269 | ] 270 | ) 271 | 272 | ray_d_norm = ray_d.norm(dim=-1) 273 | ray_d = ray_d / ray_d_norm.unsqueeze(-1) 274 | 275 | ray_o = ( 276 | self.C2W[:3, 3] 277 | .unsqueeze(0) 278 | .expand(uv.shape[0], -1) 279 | .reshape( 280 | dots_sh 281 | + [ 282 | 3, 283 | ] 284 | ) 285 | ) 286 | return ray_o, ray_d, ray_d_norm 287 | 288 | def get_camera_origin(self, prefix_shape=None): 289 | ray_o = self.C2W[:3, 3] 290 | if prefix_shape is not None: 291 | prefix_shape = list(prefix_shape) 292 | ray_o = ray_o.view([1,] * len(prefix_shape) + [3,]).expand( 293 | prefix_shape 294 | + [ 295 | 3, 296 | ] 297 | ) 298 | return ray_o 299 | 300 | def get_uv(self): 301 | u, v = np.meshgrid(np.arange(self.W), np.arange(self.H)) 302 | uv = torch.from_numpy(np.stack((u, v), axis=-1).astype(np.float32)).to(self.device) + 0.5 303 | return uv 304 | 305 | def project(self, points): 306 | """ 307 | points: [..., 3] 308 | """ 309 
| dots_sh = list(points.shape[:-1]) 310 | 311 | points = points.view(-1, 3) 312 | points = torch.cat([points, torch.ones_like(points[:, :1])], dim=1) 313 | uv = torch.matmul( 314 | torch.matmul(points, self.W2C.transpose(1, 0)), 315 | self.K.transpose(1, 0), 316 | ) 317 | uv = uv[:, :2] / uv[:, 2:3] 318 | 319 | uv = uv.view( 320 | dots_sh 321 | + [ 322 | 2, 323 | ] 324 | ) 325 | return uv 326 | 327 | def crop_region(self, trgt_W, trgt_H, center_crop=False, ul_corner=None, image=None): 328 | K = self.K.clone() 329 | if ul_corner is not None: 330 | ul_col, ul_row = ul_corner 331 | elif center_crop: 332 | ul_col = self.W // 2 - trgt_W // 2 333 | ul_row = self.H // 2 - trgt_H // 2 334 | else: 335 | ul_col = np.random.randint(0, self.W - trgt_W) 336 | ul_row = np.random.randint(0, self.H - trgt_H) 337 | # modify K 338 | K[0, 2] -= ul_col 339 | K[1, 2] -= ul_row 340 | 341 | camera = Camera(trgt_W, trgt_H, K, self.W2C.clone()) 342 | 343 | if image is not None: 344 | assert image.shape[0] == self.H and image.shape[1] == self.W, "image size does not match specfied size" 345 | image = image[ul_row : ul_row + trgt_H, ul_col : ul_col + trgt_W] 346 | return camera, image 347 | 348 | def resize(self, factor, image=None): 349 | trgt_H, trgt_W = int(self.H * factor), int(self.W * factor) 350 | K = self.K.clone() 351 | K[0, :3] *= trgt_W / self.W 352 | K[1, :3] *= trgt_H / self.H 353 | camera = Camera(trgt_W, trgt_H, K, self.W2C.clone()) 354 | 355 | if image is not None: 356 | device = image.device 357 | image = cv2.resize(image.detach().cpu().numpy(), (trgt_W, trgt_H), interpolation=cv2.INTER_AREA) 358 | image = torch.from_numpy(image).to(device) 359 | return camera, image 360 | 361 | 362 | @torch.no_grad() 363 | def raytrace_pixels(sdf_network, raytracer, uv, camera, mask=None, max_num_rays=200000): 364 | if mask is None: 365 | mask = torch.ones_like(uv[..., 0]).bool() 366 | 367 | dots_sh = list(uv.shape[:-1]) 368 | 369 | ray_o, ray_d, ray_d_norm = camera.get_rays(uv) 370 | sdf = lambda x: sdf_network(x)[..., 0] 371 | 372 | merge_results = None 373 | for ray_o_split, ray_d_split, ray_d_norm_split, mask_split in zip( 374 | torch.split(ray_o.view(-1, 3), max_num_rays, dim=0), 375 | torch.split(ray_d.view(-1, 3), max_num_rays, dim=0), 376 | torch.split( 377 | ray_d_norm.view( 378 | -1, 379 | ), 380 | max_num_rays, 381 | dim=0, 382 | ), 383 | torch.split( 384 | mask.view( 385 | -1, 386 | ), 387 | max_num_rays, 388 | dim=0, 389 | ), 390 | ): 391 | mask_intersect_split, min_dis_split, max_dis_split = intersect_sphere(ray_o_split, ray_d_split, r=1.0) 392 | results = raytracer( 393 | sdf, 394 | ray_o_split, 395 | ray_d_split, 396 | min_dis_split, 397 | max_dis_split, 398 | mask_intersect_split & mask_split, 399 | ) 400 | results["depth"] = results["distance"] / ray_d_norm_split 401 | 402 | if merge_results is None: 403 | merge_results = dict( 404 | [ 405 | ( 406 | x, 407 | [ 408 | results[x], 409 | ], 410 | ) 411 | for x in results.keys() 412 | if isinstance(results[x], torch.Tensor) 413 | ] 414 | ) 415 | else: 416 | for x in results.keys(): 417 | merge_results[x].append(results[x]) # gpu 418 | 419 | for x in list(merge_results.keys()): 420 | results = torch.cat(merge_results[x], dim=0).reshape( 421 | dots_sh 422 | + [ 423 | -1, 424 | ] 425 | ) 426 | if results.shape[-1] == 1: 427 | results = results[..., 0] 428 | merge_results[x] = results # gpu 429 | 430 | # append more results 431 | merge_results.update( 432 | { 433 | "uv": uv, 434 | "ray_o": ray_o, 435 | "ray_d": ray_d, 436 | "ray_d_norm": ray_d_norm, 437 | 
} 438 | ) 439 | return merge_results 440 | 441 | 442 | def unique(x, dim=-1): 443 | """ 444 | return: unique elements in x, and their original indices in x 445 | """ 446 | unique, inverse = torch.unique(x, return_inverse=True, dim=dim) 447 | perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device) 448 | inverse, perm = inverse.flip([dim]), perm.flip([dim]) 449 | return unique, inverse.new_empty(unique.size(dim)).scatter_(dim, inverse, perm) 450 | 451 | 452 | @torch.no_grad() 453 | def locate_edge_points( 454 | camera, walk_start_points, sdf_network, max_step, step_size, dot_threshold, max_num_rays=200000, mask=None 455 | ): 456 | """walk on the surface to locate 3d edge points with high precision""" 457 | if mask is None: 458 | mask = torch.ones_like(walk_start_points[..., 0]).bool() 459 | 460 | walk_finish_points = walk_start_points.clone() 461 | walk_edge_found_mask = mask.clone() 462 | n_valid = mask.sum() 463 | if n_valid > 0: 464 | dots_sh = list(walk_start_points.shape[:-1]) 465 | 466 | walk_finish_points_valid = [] 467 | walk_edge_found_mask_valid = [] 468 | for cur_points_split in torch.split(walk_start_points[mask].clone().view(-1, 3).detach(), max_num_rays, dim=0): 469 | walk_edge_found_mask_split = torch.zeros_like(cur_points_split[..., 0]).bool() 470 | not_found_mask_split = ~walk_edge_found_mask_split 471 | 472 | ray_o_split = camera.get_camera_origin(prefix_shape=cur_points_split.shape[:-1]) 473 | 474 | i = 0 475 | while True: 476 | cur_viewdir_split = ray_o_split[not_found_mask_split] - cur_points_split[not_found_mask_split] 477 | cur_viewdir_split = cur_viewdir_split / (cur_viewdir_split.norm(dim=-1, keepdim=True) + 1e-10) 478 | cur_sdf_split, _, cur_normal_split = sdf_network.get_all( 479 | cur_points_split[not_found_mask_split].view(-1, 3), 480 | is_training=False, 481 | ) 482 | cur_normal_split = cur_normal_split / (cur_normal_split.norm(dim=-1, keepdim=True) + 1e-10) 483 | 484 | dot_split = (cur_normal_split * cur_viewdir_split).sum(dim=-1) 485 | tmp_not_found_mask = dot_split.abs() > dot_threshold 486 | walk_edge_found_mask_split[not_found_mask_split] = ~tmp_not_found_mask 487 | not_found_mask_split = ~walk_edge_found_mask_split 488 | 489 | if i >= max_step or not_found_mask_split.sum() == 0: 490 | break 491 | 492 | cur_walkdir_split = cur_normal_split - cur_viewdir_split / dot_split.unsqueeze(-1) 493 | cur_walkdir_split = cur_walkdir_split / (cur_walkdir_split.norm(dim=-1, keepdim=True) + 1e-10) 494 | # regularize walk direction such that we don't get far away from the zero iso-surface 495 | cur_walkdir_split = cur_walkdir_split - cur_sdf_split * cur_normal_split 496 | cur_points_split[not_found_mask_split] += (step_size * cur_walkdir_split)[tmp_not_found_mask] 497 | 498 | i += 1 499 | 500 | walk_finish_points_valid.append(cur_points_split) 501 | walk_edge_found_mask_valid.append(walk_edge_found_mask_split) 502 | 503 | walk_finish_points[mask] = torch.cat(walk_finish_points_valid, dim=0) 504 | walk_edge_found_mask[mask] = torch.cat(walk_edge_found_mask_valid, dim=0) 505 | walk_finish_points = walk_finish_points.reshape( 506 | dots_sh 507 | + [ 508 | 3, 509 | ] 510 | ) 511 | walk_edge_found_mask = walk_edge_found_mask.reshape(dots_sh) 512 | 513 | edge_points = walk_finish_points[walk_edge_found_mask] 514 | edge_mask = torch.zeros(camera.H, camera.W).bool().to(walk_finish_points.device) 515 | edge_uv = torch.zeros_like(edge_points[..., :2]) 516 | update_pixels = torch.Tensor([]).long().to(walk_finish_points.device) 517 | if 
walk_edge_found_mask.any(): 518 | # filter out edge points out of camera's fov; 519 | # if there are multiple edge points mapping to the same pixel, only keep one 520 | edge_uv = camera.project(edge_points) 521 | update_pixels = torch.floor(edge_uv.detach()).long() 522 | update_pixels = update_pixels[:, 1] * camera.W + update_pixels[:, 0] 523 | mask = (update_pixels < camera.H * camera.W) & (update_pixels >= 0) 524 | update_pixels, edge_points, edge_uv = update_pixels[mask], edge_points[mask], edge_uv[mask] 525 | if mask.any(): 526 | cnt = update_pixels.shape[0] 527 | update_pixels, unique_idx = unique(update_pixels, dim=0) 528 | unique_idx = torch.arange(cnt, device=update_pixels.device)[unique_idx] 529 | # assert update_pixels.shape == unique_idx.shape, f"{update_pixels.shape},{unique_idx.shape}" 530 | edge_points = edge_points[unique_idx] 531 | edge_uv = edge_uv[unique_idx] 532 | 533 | edge_mask.view(-1)[update_pixels] = True 534 | # edge_cnt = edge_mask.sum() 535 | # assert ( 536 | # edge_cnt == edge_points.shape[0] 537 | # ), f"{edge_cnt},{edge_points.shape},{edge_uv.shape},{update_pixels.shape},{torch.unique(update_pixels).shape},{update_pixels.min()},{update_pixels.max()}" 538 | # assert ( 539 | # edge_cnt == edge_uv.shape[0] 540 | # ), f"{edge_cnt},{edge_points.shape},{edge_uv.shape},{update_pixels.shape},{torch.unique(update_pixels).shape}" 541 | 542 | # ic(edge_mask.shape, edge_points.shape, edge_uv.shape) 543 | results = {"edge_mask": edge_mask, "edge_points": edge_points, "edge_uv": edge_uv, "edge_pixel_idx": update_pixels} 544 | 545 | if VERBOSE_MODE: # debug 546 | edge_angles = torch.zeros_like(edge_mask).float() 547 | edge_sdf = torch.zeros_like(edge_mask).float().unsqueeze(-1) 548 | if edge_mask.any(): 549 | ray_o = camera.get_camera_origin(prefix_shape=edge_points.shape[:-1]) 550 | edge_viewdir = ray_o - edge_points 551 | edge_viewdir = edge_viewdir / (edge_viewdir.norm(dim=-1, keepdim=True) + 1e-10) 552 | with torch.enable_grad(): 553 | edge_sdf_vals, _, edge_normals = sdf_network.get_all(edge_points, is_training=False) 554 | edge_normals = edge_normals / (edge_normals.norm(dim=-1, keepdim=True) + 1e-10) 555 | edge_dot = (edge_viewdir * edge_normals).sum(dim=-1) 556 | # edge_angles[edge_mask] = torch.rad2deg(torch.acos(edge_dot)) 557 | # edge_sdf[edge_mask] = edge_sdf_vals 558 | edge_angles.view(-1)[update_pixels] = torch.rad2deg(torch.acos(edge_dot)) 559 | edge_sdf.view(-1)[update_pixels] = edge_sdf_vals.squeeze(-1) 560 | 561 | results.update( 562 | { 563 | "walk_edge_found_mask": walk_edge_found_mask, 564 | "edge_angles": edge_angles, 565 | "edge_sdf": edge_sdf, 566 | } 567 | ) 568 | 569 | return results 570 | 571 | 572 | @torch.no_grad() 573 | def raytrace_camera( 574 | camera, 575 | sdf_network, 576 | raytracer, 577 | max_num_rays=200000, 578 | fill_holes=False, 579 | detect_edges=False, 580 | ): 581 | results = raytrace_pixels(sdf_network, raytracer, camera.get_uv(), camera, max_num_rays=max_num_rays) 582 | results["depth"] *= results["convergent_mask"].float() 583 | 584 | if fill_holes: 585 | depth = results["depth"] 586 | kernel = torch.ones(3, 3).float().to(depth.device) 587 | depth = kornia.morphology.closing(depth.unsqueeze(0).unsqueeze(0), kernel).squeeze(0).squeeze(0) 588 | new_convergent_mask = depth > 1e-2 589 | update_mask = new_convergent_mask & (~results["convergent_mask"]) 590 | if update_mask.any(): 591 | results["depth"][update_mask] = depth[update_mask] 592 | results["convergent_mask"] = new_convergent_mask 593 | results["distance"] = results["depth"] 
* results["ray_d_norm"] 594 | results["points"] = results["ray_o"] + results["ray_d"] * results["distance"].unsqueeze(-1) 595 | 596 | if detect_edges: 597 | depth = results["depth"] 598 | convergent_mask = results["convergent_mask"] 599 | depth_grad_norm = kornia.filters.sobel(depth.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 600 | depth_edge_mask = (depth_grad_norm > 1e-2) & convergent_mask 601 | # depth_edge_mask = convergent_mask 602 | 603 | results.update( 604 | locate_edge_points( 605 | camera, 606 | results["points"], 607 | sdf_network, 608 | max_step=16, 609 | step_size=1e-3, 610 | dot_threshold=5e-2, 611 | max_num_rays=max_num_rays, 612 | mask=depth_edge_mask, 613 | ) 614 | ) 615 | results["convergent_mask"] &= ~results["edge_mask"] 616 | 617 | if VERBOSE_MODE: # debug 618 | results.update({"depth_grad_norm": depth_grad_norm, "depth_edge_mask": depth_edge_mask}) 619 | 620 | return results 621 | 622 | 623 | def render_normal_and_color( 624 | results, 625 | sdf_network, 626 | color_network_dict, 627 | render_fn, 628 | is_training=False, 629 | max_num_pts=320000, 630 | ): 631 | """ 632 | results: returned by raytrace_pixels function 633 | 634 | render interior and freespace pixels 635 | note: predicted color is black for freespace pixels 636 | """ 637 | dots_sh = list(results["convergent_mask"].shape) 638 | 639 | merge_render_results = None 640 | for points_split, ray_d_split, ray_o_split, mask_split in zip( 641 | torch.split(results["points"].view(-1, 3), max_num_pts, dim=0), 642 | torch.split(results["ray_d"].view(-1, 3), max_num_pts, dim=0), 643 | torch.split(results["ray_o"].view(-1, 3), max_num_pts, dim=0), 644 | torch.split(results["convergent_mask"].view(-1), max_num_pts, dim=0), 645 | ): 646 | if mask_split.any(): 647 | points_split, ray_d_split, ray_o_split = ( 648 | points_split[mask_split], 649 | ray_d_split[mask_split], 650 | ray_o_split[mask_split], 651 | ) 652 | sdf_split, feature_split, normal_split = sdf_network.get_all(points_split, is_training=is_training) 653 | if is_training: 654 | points_split = reparam_points(points_split, normal_split.detach(), -ray_d_split.detach(), sdf_split) 655 | # normal_split = normal_split / (normal_split.norm(dim=-1, keepdim=True) + 1e-10) 656 | else: 657 | points_split, ray_d_split, ray_o_split, normal_split, feature_split = ( 658 | torch.Tensor([]).float().cuda(), 659 | torch.Tensor([]).float().cuda(), 660 | torch.Tensor([]).float().cuda(), 661 | torch.Tensor([]).float().cuda(), 662 | torch.Tensor([]).float().cuda(), 663 | ) 664 | 665 | with torch.set_grad_enabled(is_training): 666 | render_results = render_fn( 667 | mask_split, 668 | color_network_dict, 669 | ray_o_split, 670 | ray_d_split, 671 | points_split, 672 | normal_split, 673 | feature_split, 674 | ) 675 | 676 | if merge_render_results is None: 677 | merge_render_results = dict( 678 | [ 679 | ( 680 | x, 681 | [ 682 | render_results[x], 683 | ], 684 | ) 685 | for x in render_results.keys() 686 | ] 687 | ) 688 | else: 689 | for x in render_results.keys(): 690 | merge_render_results[x].append(render_results[x]) 691 | 692 | for x in list(merge_render_results.keys()): 693 | tmp = torch.cat(merge_render_results[x], dim=0).reshape( 694 | dots_sh 695 | + [ 696 | -1, 697 | ] 698 | ) 699 | if tmp.shape[-1] == 1: 700 | tmp = tmp.squeeze(-1) 701 | merge_render_results[x] = tmp 702 | 703 | results.update(merge_render_results) 704 | 705 | 706 | def render_edge_pixels( 707 | results, 708 | camera, 709 | sdf_network, 710 | raytracer, 711 | color_network_dict, 712 | render_fn, 713 | 
is_training=False, 714 | ): 715 | edge_mask, edge_points, edge_uv, edge_pixel_idx = ( 716 | results["edge_mask"], 717 | results["edge_points"], 718 | results["edge_uv"], 719 | results["edge_pixel_idx"], 720 | ) 721 | edge_pixel_center = torch.floor(edge_uv) + 0.5 722 | 723 | edge_sdf, _, edge_grads = sdf_network.get_all(edge_points, is_training=is_training) 724 | edge_normals = edge_grads.detach() / (edge_grads.detach().norm(dim=-1, keepdim=True) + 1e-10) 725 | if is_training: 726 | edge_points = reparam_points(edge_points, edge_grads.detach(), edge_normals, edge_sdf) 727 | edge_uv = camera.project(edge_points) 728 | 729 | edge_normals2d = torch.matmul(edge_normals, camera.W2C[:3, :3].transpose(1, 0))[:, :2] 730 | edge_normals2d = edge_normals2d / (edge_normals2d.norm(dim=-1, keepdim=True) + 1e-10) 731 | 732 | # sample a point on both sides of the edge 733 | # approximately think of each pixel as being approximately a circle with radius 0.707=sqrt(2)/2 734 | pixel_radius = 0.707 735 | pos_side_uv = edge_pixel_center - pixel_radius * edge_normals2d 736 | neg_side_uv = edge_pixel_center + pixel_radius * edge_normals2d 737 | 738 | dot2d = torch.sum((edge_uv - edge_pixel_center) * edge_normals2d, dim=-1) 739 | alpha = 2 * torch.arccos(torch.clamp(dot2d / pixel_radius, min=0.0, max=1.0)) 740 | pos_side_weight = 1.0 - (alpha - torch.sin(alpha)) / (2.0 * np.pi) 741 | 742 | # render positive-side and negative-side colors by raytracing; speed up using edge mask 743 | pos_side_results = raytrace_pixels(sdf_network, raytracer, pos_side_uv, camera) 744 | neg_side_results = raytrace_pixels(sdf_network, raytracer, neg_side_uv, camera) 745 | render_normal_and_color(pos_side_results, sdf_network, color_network_dict, render_fn, is_training=is_training) 746 | render_normal_and_color(neg_side_results, sdf_network, color_network_dict, render_fn, is_training=is_training) 747 | # ic(pos_side_results.keys(), pos_side_results['convergent_mask'].sum()) 748 | 749 | # assign colors to edge pixels 750 | edge_color = pos_side_results["color"] * pos_side_weight.unsqueeze(-1) + neg_side_results["color"] * ( 751 | 1.0 - pos_side_weight.unsqueeze(-1) 752 | ) 753 | # results["color"][edge_mask] = edge_color 754 | # results["normal"][edge_mask] = edge_normals 755 | 756 | results["color"].view(-1, 3)[edge_pixel_idx] = edge_color 757 | # results["normal"].view(-1, 3)[edge_pixel_idx] = edge_normals 758 | results["normal"].view(-1, 3)[edge_pixel_idx] = edge_grads 759 | 760 | results["edge_pos_neg_normal"] = torch.cat( 761 | [ 762 | pos_side_results["normal"][pos_side_results["convergent_mask"]], 763 | neg_side_results["normal"][neg_side_results["convergent_mask"]], 764 | ], 765 | dim=0, 766 | ) 767 | # debug 768 | # results["uv"][edge_mask] = edge_uv.detach() 769 | # results["points"][edge_mask] = edge_points.detach() 770 | 771 | results["uv"].view(-1, 2)[edge_pixel_idx] = edge_uv.detach() 772 | results["points"].view(-1, 3)[edge_pixel_idx] = edge_points.detach() 773 | 774 | if VERBOSE_MODE: 775 | pos_side_weight_fullsize = torch.zeros_like(edge_mask).float() 776 | # pos_side_weight_fullsize[edge_mask] = pos_side_weight 777 | pos_side_weight_fullsize.view(-1)[edge_pixel_idx] = pos_side_weight 778 | 779 | pos_side_depth = torch.zeros_like(edge_mask).float() 780 | # pos_side_depth[edge_mask] = pos_side_results["depth"] 781 | pos_side_depth.view(-1)[edge_pixel_idx] = pos_side_results["depth"] 782 | neg_side_depth = torch.zeros_like(edge_mask).float() 783 | # neg_side_depth[edge_mask] = neg_side_results["depth"] 784 | 
neg_side_depth.view(-1)[edge_pixel_idx] = neg_side_results["depth"] 785 | 786 | pos_side_color = ( 787 | torch.zeros( 788 | list(edge_mask.shape) 789 | + [ 790 | 3, 791 | ] 792 | ) 793 | .float() 794 | .to(edge_mask.device) 795 | ) 796 | # pos_side_color[edge_mask] = pos_side_results["color"] 797 | pos_side_color.view(-1, 3)[edge_pixel_idx] = pos_side_results["color"] 798 | neg_side_color = ( 799 | torch.zeros( 800 | list(edge_mask.shape) 801 | + [ 802 | 3, 803 | ] 804 | ) 805 | .float() 806 | .to(edge_mask.device) 807 | ) 808 | # neg_side_color[edge_mask] = neg_side_results["color"] 809 | neg_side_color.view(-1, 3)[edge_pixel_idx] = neg_side_results["color"] 810 | results.update( 811 | { 812 | "edge_pos_side_weight": pos_side_weight_fullsize, 813 | "edge_normals2d": edge_normals2d, 814 | "pos_side_uv": pos_side_uv, 815 | "neg_side_uv": neg_side_uv, 816 | "edge_pos_side_depth": pos_side_depth, 817 | "edge_neg_side_depth": neg_side_depth, 818 | "edge_pos_side_color": pos_side_color, 819 | "edge_neg_side_color": neg_side_color, 820 | } 821 | ) 822 | 823 | 824 | def render_camera( 825 | camera, 826 | sdf_network, 827 | raytracer, 828 | color_network_dict, 829 | render_fn, 830 | fill_holes=True, 831 | handle_edges=True, 832 | is_training=False, 833 | ): 834 | results = raytrace_camera( 835 | camera, 836 | sdf_network, 837 | raytracer, 838 | max_num_rays=200000, 839 | fill_holes=fill_holes, 840 | detect_edges=handle_edges, 841 | ) 842 | render_normal_and_color( 843 | results, 844 | sdf_network, 845 | color_network_dict, 846 | render_fn, 847 | is_training=is_training, 848 | max_num_pts=320000, 849 | ) 850 | if handle_edges and results["edge_mask"].sum() > 0: 851 | render_edge_pixels( 852 | results, 853 | camera, 854 | sdf_network, 855 | raytracer, 856 | color_network_dict, 857 | render_fn, 858 | is_training=is_training, 859 | ) 860 | return results 861 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import logging 6 | import mcubes 7 | from icecream import ic 8 | 9 | 10 | def extract_fields(bound_min, bound_max, resolution, query_func): 11 | N = 64 12 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 13 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 14 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 15 | 16 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 17 | with torch.no_grad(): 18 | for xi, xs in enumerate(X): 19 | for yi, ys in enumerate(Y): 20 | for zi, zs in enumerate(Z): 21 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 22 | pts = torch.cat( 23 | [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], 24 | dim=-1, 25 | ) 26 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 27 | u[ 28 | xi * N : xi * N + len(xs), 29 | yi * N : yi * N + len(ys), 30 | zi * N : zi * N + len(zs), 31 | ] = val 32 | return u 33 | 34 | 35 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): 36 | print("threshold: {}".format(threshold)) 37 | u = extract_fields(bound_min, bound_max, resolution, query_func) 38 | vertices, triangles = mcubes.marching_cubes(u, threshold) 39 | b_max_np = bound_max.detach().cpu().numpy() 40 | b_min_np = bound_min.detach().cpu().numpy() 41 | 42 | vertices = vertices / (resolution - 1.0) * (b_max_np - 
b_min_np)[None, :] + b_min_np[None, :] 43 | return vertices, triangles 44 | 45 | 46 | def sample_pdf(bins, weights, n_samples, det=False): 47 | # This implementation is from NeRF 48 | # Get pdf 49 | weights = weights + 1e-5 # prevent nans 50 | pdf = weights / torch.sum(weights, -1, keepdim=True) 51 | cdf = torch.cumsum(pdf, -1) 52 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 53 | # Take uniform samples 54 | if det: 55 | u = torch.linspace(0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples) 56 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 57 | else: 58 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]) 59 | 60 | # Invert CDF 61 | u = u.contiguous() 62 | inds = torch.searchsorted(cdf, u, right=True) 63 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 64 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 65 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 66 | 67 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 68 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 69 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 70 | 71 | denom = cdf_g[..., 1] - cdf_g[..., 0] 72 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 73 | t = (u - cdf_g[..., 0]) / denom 74 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 75 | 76 | return samples 77 | 78 | 79 | class NeuSRenderer: 80 | def __init__( 81 | self, 82 | nerf, 83 | sdf_network, 84 | deviation_network, 85 | color_network, 86 | n_samples, 87 | n_importance, 88 | n_outside, 89 | up_sample_steps, 90 | perturb, 91 | ): 92 | self.nerf = nerf 93 | self.sdf_network = sdf_network 94 | self.deviation_network = deviation_network 95 | self.color_network = color_network 96 | self.n_samples = n_samples 97 | self.n_importance = n_importance 98 | self.n_outside = n_outside 99 | self.up_sample_steps = up_sample_steps 100 | self.perturb = perturb 101 | 102 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None): 103 | """ 104 | Render background 105 | """ 106 | batch_size, n_samples = z_vals.shape 107 | 108 | # Section length 109 | dists = z_vals[..., 1:] - z_vals[..., :-1] 110 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 111 | mid_z_vals = z_vals + dists * 0.5 112 | 113 | # Section midpoints 114 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 115 | 116 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10) 117 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 118 | 119 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3) 120 | 121 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0)) 122 | dirs = dirs.reshape(-1, 3) 123 | 124 | density, sampled_color = nerf(pts, dirs) 125 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists) 126 | alpha = alpha.reshape(batch_size, n_samples) 127 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1.0 - alpha + 1e-7], -1), -1)[:, :-1] 128 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3) 129 | color = (weights[:, :, None] * sampled_color).sum(dim=1) 130 | if background_rgb is not None: 131 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True)) 132 | 133 | return { 134 | "color": color, 135 | "sampled_color": sampled_color, 136 | "alpha": 
alpha, 137 | "weights": weights, 138 | } 139 | 140 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s): 141 | """ 142 | Up sampling give a fixed inv_s 143 | """ 144 | batch_size, n_samples = z_vals.shape 145 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 146 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False) 147 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0) 148 | sdf = sdf.reshape(batch_size, n_samples) 149 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] 150 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] 151 | mid_sdf = (prev_sdf + next_sdf) * 0.5 152 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) 153 | 154 | # ---------------------------------------------------------------------------------------------------------- 155 | # Use min value of [ cos, prev_cos ] 156 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more 157 | # robust when meeting situations like below: 158 | # 159 | # SDF 160 | # ^ 161 | # |\ -----x----... 162 | # | \ / 163 | # | x x 164 | # |---\----/-------------> 0 level 165 | # | \ / 166 | # | \/ 167 | # | 168 | # ---------------------------------------------------------------------------------------------------------- 169 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1) 170 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1) 171 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False) 172 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere 173 | 174 | dist = next_z_vals - prev_z_vals 175 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5 176 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5 177 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s) 178 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s) 179 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) 180 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1.0 - alpha + 1e-7], -1), -1)[:, :-1] 181 | 182 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() 183 | return z_samples 184 | 185 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False): 186 | batch_size, n_samples = z_vals.shape 187 | _, n_importance = new_z_vals.shape 188 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] 189 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) 190 | z_vals, index = torch.sort(z_vals, dim=-1) 191 | 192 | if not last: 193 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance) 194 | sdf = torch.cat([sdf, new_sdf], dim=-1) 195 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) 196 | index = index.reshape(-1) 197 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) 198 | 199 | return z_vals, sdf 200 | 201 | def render_core( 202 | self, 203 | rays_o, 204 | rays_d, 205 | z_vals, 206 | sample_dist, 207 | sdf_network, 208 | deviation_network, 209 | color_network, 210 | background_alpha=None, 211 | background_sampled_color=None, 212 | background_rgb=None, 213 | cos_anneal_ratio=0.0, 214 | ): 215 | batch_size, n_samples = z_vals.shape 216 | 217 | # Section length 218 | dists = z_vals[..., 1:] - z_vals[..., :-1] 219 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 220 | mid_z_vals = z_vals + dists * 0.5 221 | 222 | # Section midpoints 223 | pts = rays_o[:, None, :] + rays_d[:, None, :] * 
mid_z_vals[..., :, None] # n_rays, n_samples, 3 224 | dirs = rays_d[:, None, :].expand(pts.shape) 225 | 226 | pts = pts.reshape(-1, 3) 227 | dirs = dirs.reshape(-1, 3) 228 | 229 | sdf_nn_output = sdf_network(pts) 230 | sdf = sdf_nn_output[:, :1] 231 | feature_vector = sdf_nn_output[:, 1:] 232 | 233 | gradients = sdf_network.gradient(pts) 234 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3) 235 | 236 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 237 | inv_s = inv_s.expand(batch_size * n_samples, 1) 238 | 239 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 240 | 241 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 242 | # the cos value "not dead" at the beginning training iterations, for better convergence. 243 | iter_cos = -( 244 | F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) + F.relu(-true_cos) * cos_anneal_ratio 245 | ) # always non-positive 246 | 247 | # Estimate signed distances at section points 248 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5 249 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5 250 | 251 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) 252 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) 253 | 254 | p = prev_cdf - next_cdf 255 | c = prev_cdf 256 | 257 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0) 258 | 259 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples) 260 | inside_sphere = (pts_norm < 1.0).float().detach() 261 | relax_inside_sphere = (pts_norm < 1.2).float().detach() 262 | 263 | # Render with background 264 | if background_alpha is not None: 265 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere) 266 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1) 267 | sampled_color = ( 268 | sampled_color * inside_sphere[:, :, None] 269 | + background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None] 270 | ) 271 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1) 272 | 273 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1.0 - alpha + 1e-7], -1), -1)[:, :-1] 274 | weights_sum = weights.sum(dim=-1, keepdim=True) 275 | 276 | color = (sampled_color * weights[:, :, None]).sum(dim=1) 277 | if background_rgb is not None: # Fixed background, usually black 278 | color = color + background_rgb * (1.0 - weights_sum) 279 | 280 | # Eikonal loss 281 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2, dim=-1) - 1.0) ** 2 282 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5) 283 | 284 | return { 285 | "color": color, 286 | "sdf": sdf, 287 | "dists": dists, 288 | "gradients": gradients.reshape(batch_size, n_samples, 3), 289 | "s_val": 1.0 / inv_s, 290 | "mid_z_vals": mid_z_vals, 291 | "weights": weights, 292 | "cdf": c.reshape(batch_size, n_samples), 293 | "gradient_error": gradient_error, 294 | "inside_sphere": inside_sphere, 295 | } 296 | 297 | def render( 298 | self, 299 | rays_o, 300 | rays_d, 301 | near, 302 | far, 303 | perturb_overwrite=-1, 304 | background_rgb=None, 305 | cos_anneal_ratio=0.0, 306 | ): 307 | batch_size = len(rays_o) 308 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere 309 | z_vals = 
torch.linspace(0.0, 1.0, self.n_samples) 310 | z_vals = near + (far - near) * z_vals[None, :] 311 | 312 | z_vals_outside = None 313 | if self.n_outside > 0: 314 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside) 315 | 316 | n_samples = self.n_samples 317 | perturb = self.perturb 318 | 319 | if perturb_overwrite >= 0: 320 | perturb = perturb_overwrite 321 | if perturb > 0: 322 | t_rand = torch.rand([batch_size, 1]) - 0.5 323 | z_vals = z_vals + t_rand * 2.0 / self.n_samples 324 | 325 | if self.n_outside > 0: 326 | mids = 0.5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1]) 327 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1) 328 | lower = torch.cat([z_vals_outside[..., :1], mids], -1) 329 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]]) 330 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand 331 | 332 | if self.n_outside > 0: 333 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples 334 | 335 | background_alpha = None 336 | background_sampled_color = None 337 | 338 | # Up sample 339 | if self.n_importance > 0: 340 | with torch.no_grad(): 341 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] 342 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples) 343 | 344 | for i in range(self.up_sample_steps): 345 | new_z_vals = self.up_sample( 346 | rays_o, 347 | rays_d, 348 | z_vals, 349 | sdf, 350 | self.n_importance // self.up_sample_steps, 351 | 64 * 2**i, 352 | ) 353 | z_vals, sdf = self.cat_z_vals( 354 | rays_o, 355 | rays_d, 356 | z_vals, 357 | new_z_vals, 358 | sdf, 359 | last=(i + 1 == self.up_sample_steps), 360 | ) 361 | 362 | n_samples = self.n_samples + self.n_importance 363 | 364 | # Background model 365 | if self.n_outside > 0: 366 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1) 367 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1) 368 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf) 369 | 370 | background_sampled_color = ret_outside["sampled_color"] 371 | background_alpha = ret_outside["alpha"] 372 | 373 | # Render core 374 | ret_fine = self.render_core( 375 | rays_o, 376 | rays_d, 377 | z_vals, 378 | sample_dist, 379 | self.sdf_network, 380 | self.deviation_network, 381 | self.color_network, 382 | background_rgb=background_rgb, 383 | background_alpha=background_alpha, 384 | background_sampled_color=background_sampled_color, 385 | cos_anneal_ratio=cos_anneal_ratio, 386 | ) 387 | 388 | color_fine = ret_fine["color"] 389 | weights = ret_fine["weights"] 390 | weights_sum = weights.sum(dim=-1, keepdim=True) 391 | gradients = ret_fine["gradients"] 392 | s_val = ret_fine["s_val"].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True) 393 | 394 | return { 395 | "color_fine": color_fine, 396 | "s_val": s_val, 397 | "cdf_fine": ret_fine["cdf"], 398 | "weight_sum": weights_sum, 399 | "weight_max": torch.max(weights, dim=-1, keepdim=True)[0], 400 | "gradients": gradients, 401 | "weights": weights, 402 | "gradient_error": ret_fine["gradient_error"], 403 | "inside_sphere": ret_fine["inside_sphere"], 404 | } 405 | 406 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0): 407 | return extract_geometry( 408 | bound_min, 409 | bound_max, 410 | resolution=resolution, 411 | threshold=threshold, 412 | query_func=lambda pts: -self.sdf_network.sdf(pts), 413 | ) 414 | -------------------------------------------------------------------------------- 
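Editor's note: the `render_core` routine in `models/renderer.py` above turns signed distances into per-section opacities by differencing sigmoid CDFs of the (estimated) endpoint SDFs and normalizing by the entering CDF, all scaled by the learned sharpness `inv_s`. Below is a minimal, self-contained sketch of that conversion on a toy ray; names such as `neus_alpha`, `sdf_prev`, and `sdf_next` are illustrative only and not part of this repository's API.

```python
import torch

def neus_alpha(sdf_prev: torch.Tensor, sdf_next: torch.Tensor, inv_s: float) -> torch.Tensor:
    """Opacity of a ray section from SDF values at its two endpoints (NeuS-style)."""
    prev_cdf = torch.sigmoid(sdf_prev * inv_s)
    next_cdf = torch.sigmoid(sdf_next * inv_s)
    # fraction of the occupancy CDF consumed inside the section, clipped to [0, 1]
    return torch.clamp((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5), 0.0, 1.0)

# toy ray whose samples cross the zero level set between the 3rd and 4th point
sdf = torch.tensor([0.30, 0.15, 0.02, -0.10, -0.25])
alpha = neus_alpha(sdf[:-1], sdf[1:], inv_s=64.0)
# standard front-to-back compositing weights, as in render_core / render_core_outside
weights = alpha * torch.cumprod(torch.cat([torch.ones(1), 1.0 - alpha + 1e-7]), dim=0)[:-1]
print(alpha)    # opacity peaks at the section containing the surface crossing
print(weights)  # weight mass concentrates around that crossing
```

In `render_core` itself, the endpoint SDFs are additionally estimated from the mid-point SDF and the annealed cosine term (`iter_cos`) before this same formula is applied, and samples outside the unit sphere are blended with the background NeRF when one is used.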
/models/renderer_ggx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | 6 | 7 | ### https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/microfacet.h#L477 8 | def smithG1(cosTheta, alpha): 9 | sinTheta = torch.sqrt(1.0 - cosTheta * cosTheta) 10 | tanTheta = sinTheta / (cosTheta + 1e-10) 11 | root = alpha * tanTheta 12 | return 2.0 / (1.0 + torch.hypot(root, torch.ones_like(root))) 13 | 14 | 15 | class GGXColocatedRenderer(nn.Module): 16 | def __init__(self, use_cuda=False): 17 | super().__init__() 18 | 19 | self.MTS_TRANS = torch.from_numpy( 20 | np.loadtxt(os.path.join(os.path.dirname(os.path.abspath(__file__)), "ggx/ext_mts_rtrans_data.txt")).astype( 21 | np.float32 22 | ) 23 | ) # 5000 entries, external IOR 24 | self.MTS_DIFF_TRANS = torch.from_numpy( 25 | np.loadtxt( 26 | os.path.join(os.path.dirname(os.path.abspath(__file__)), "ggx/int_mts_diff_rtrans_data.txt") 27 | ).astype(np.float32) 28 | ) # 50 entries, internal IOR 29 | self.num_theta_samples = 100 30 | self.num_alpha_samples = 50 31 | 32 | if use_cuda: 33 | self.MTS_TRANS = self.MTS_TRANS.cuda() 34 | self.MTS_DIFF_TRANS = self.MTS_DIFF_TRANS.cuda() 35 | 36 | def forward(self, light, distance, normal, viewdir, diffuse_albedo, specular_albedo, alpha): 37 | """ 38 | light: 39 | distance: [..., 1] 40 | normal, viewdir: [..., 3]; both normal and viewdir point away from objects 41 | diffuse_albedo, specular_albedo: [..., 3] 42 | alpha: [..., 1]; roughness 43 | """ 44 | # decay light according to squared-distance falloff 45 | light_intensity = light / (distance * distance + 1e-10) 46 | 47 | # = = in colocated setting 48 | dot = torch.sum(viewdir * normal, dim=-1, keepdims=True) 49 | dot = torch.clamp(dot, min=0.00001, max=0.99999) # must be very precise; cannot be 0.999 50 | # default value of IOR['polypropylene'] / IOR['air']. 
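# (editorial note, assuming Mitsuba's standard IOR table: polypropylene ~= 1.49, air ~= 1.00028)
# eta = 1.49 / 1.00028 ~= 1.4896, consistent with m_eta below, and m_invEta2 = 1 / eta**2 ~= 0.451,
# the factor later applied to the internally scattered diffuse term.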
51 | m_eta = 1.48958738 52 | m_invEta2 = 1.0 / (m_eta * m_eta) 53 | 54 | # clamp alpha for numeric stability 55 | alpha = torch.clamp(alpha, min=0.0001) 56 | 57 | # specular term: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/roughplastic.cpp#L347 58 | ## compute GGX NDF: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/microfacet.h#L191 59 | cosTheta2 = dot * dot 60 | root = cosTheta2 + (1.0 - cosTheta2) / (alpha * alpha + 1e-10) 61 | D = 1.0 / (np.pi * alpha * alpha * root * root + 1e-10) 62 | ## compute fresnel: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/libcore/util.cpp#L651 63 | # F = 0.04 64 | F = 0.03867 65 | 66 | ## compute shadowing term: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/microfacet.h#L520 67 | G = smithG1(dot, alpha) ** 2 # [..., 1] 68 | 69 | specular_rgb = light_intensity * specular_albedo * F * D * G / (4.0 * dot + 1e-10) 70 | 71 | # diffuse term: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/roughplastic.cpp#L367 72 | ## compute T12: : https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/rtrans.h#L183 73 | ### data_file: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/rtrans.h#L93 74 | ### assume eta is fixed 75 | warpedCosTheta = dot**0.25 76 | alphaMin, alphaMax = 0, 4 77 | warpedAlpha = ((alpha - alphaMin) / (alphaMax - alphaMin)) ** 0.25 # [..., 1] 78 | tx = torch.floor(warpedCosTheta * self.num_theta_samples).long() 79 | ty = torch.floor(warpedAlpha * self.num_alpha_samples).long() 80 | t_idx = ty * self.num_theta_samples + tx 81 | 82 | dots_sh = list(t_idx.shape[:-1]) 83 | data = self.MTS_TRANS.view([1,] * len(dots_sh) + [-1,]).expand( 84 | dots_sh 85 | + [ 86 | -1, 87 | ] 88 | ) 89 | 90 | t_idx = torch.clamp(t_idx, min=0, max=data.shape[-1] - 1).long() # important 91 | T12 = torch.clamp(torch.gather(input=data, index=t_idx, dim=-1), min=0.0, max=1.0) 92 | T21 = T12 # colocated setting 93 | 94 | ## compute Fdr: https://github.com/mitsuba-renderer/mitsuba/blob/cfeb7766e7a1513492451f35dc65b86409655a7b/src/bsdfs/rtrans.h#L249 95 | t_idx = torch.floor(warpedAlpha * self.num_alpha_samples).long() 96 | data = self.MTS_DIFF_TRANS.view([1,] * len(dots_sh) + [-1,]).expand( 97 | dots_sh 98 | + [ 99 | -1, 100 | ] 101 | ) 102 | t_idx = torch.clamp(t_idx, min=0, max=data.shape[-1] - 1).long() # important 103 | Fdr = torch.clamp(1.0 - torch.gather(input=data, index=t_idx, dim=-1), min=0.0, max=1.0) # [..., 1] 104 | 105 | diffuse_rgb = light_intensity * (diffuse_albedo / (1.0 - Fdr + 1e-10) / np.pi) * dot * T12 * T21 * m_invEta2 106 | ret = {"diffuse_rgb": diffuse_rgb, "specular_rgb": specular_rgb, "rgb": diffuse_rgb + specular_rgb} 107 | return ret 108 | -------------------------------------------------------------------------------- /readme_resources/assets_lowres.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/readme_resources/assets_lowres.png -------------------------------------------------------------------------------- /readme_resources/inputs_outputs.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/readme_resources/inputs_outputs.png -------------------------------------------------------------------------------- /render_surface.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import json 7 | import imageio 8 | imageio.plugins.freeimage.download() 9 | from torch.utils.tensorboard import SummaryWriter 10 | import configargparse 11 | from icecream import ic 12 | import glob 13 | import shutil 14 | import traceback 15 | 16 | from models.fields import SDFNetwork, RenderingNetwork 17 | from models.raytracer import RayTracer, Camera, render_camera 18 | from models.renderer_ggx import GGXColocatedRenderer 19 | from models.image_losses import PyramidL2Loss, ssim_loss_fn 20 | from models.export_mesh import export_mesh 21 | from models.export_materials import export_materials 22 | 23 | ###### arguments 24 | def config_parser(): 25 | parser = configargparse.ArgumentParser() 26 | parser.add_argument("--data_dir", type=str, default=None, help="input data directory") 27 | parser.add_argument("--out_dir", type=str, default=None, help="output directory") 28 | parser.add_argument("--neus_ckpt_fpath", type=str, default=None, help="checkpoint to load") 29 | parser.add_argument("--num_iters", type=int, default=100001, help="number of iterations") 30 | parser.add_argument("--patch_size", type=int, default=128, help="width and height of the rendered patches") 31 | parser.add_argument("--eik_weight", type=float, default=0.1, help="weight for eikonal loss") 32 | parser.add_argument("--ssim_weight", type=float, default=1.0, help="weight for ssim loss") 33 | parser.add_argument("--roughrange_weight", type=float, default=0.1, help="weight for roughness range loss") 34 | 35 | parser.add_argument("--plot_image_name", type=str, default=None, help="image to plot during training") 36 | parser.add_argument("--no_edgesample", action="store_true", help="whether to disable edge sampling") 37 | parser.add_argument( 38 | "--inv_gamma_gt", action="store_true", help="whether to inverse gamma correct the ground-truth photos" 39 | ) 40 | parser.add_argument("--gamma_pred", action="store_true", help="whether to gamma correct the predictions") 41 | parser.add_argument( 42 | "--is_metal", 43 | action="store_true", 44 | help="whether the object of interest is made of metals or the scene contains metals", 45 | ) 46 | parser.add_argument("--init_light_scale", type=float, default=8.0, help="scaling parameters for light") 47 | parser.add_argument( 48 | "--export_all", 49 | action="store_true", 50 | help="whether to export meshes and uv textures", 51 | ) 52 | parser.add_argument( 53 | "--render_all", 54 | action="store_true", 55 | help="whether to render the input image set", 56 | ) 57 | return parser 58 | 59 | 60 | parser = config_parser() 61 | args = parser.parse_args() 62 | ic(args) 63 | 64 | ###### back up arguments and code scripts 65 | os.makedirs(args.out_dir, exist_ok=True) 66 | parser.write_config_file( 67 | args, 68 | [ 69 | os.path.join(args.out_dir, "args.txt"), 70 | ], 71 | ) 72 | 73 | 74 | ###### rendering functions 75 | def get_materials(color_network_dict, points, normals, features, is_metal=args.is_metal): 76 | diffuse_albedo = color_network_dict["diffuse_albedo_network"](points, normals, -normals, features).abs()[ 77 | ..., [2, 1, 0] 78 | ] 79 | specular_albedo = 
color_network_dict["specular_albedo_network"](points, normals, None, features).abs() 80 | if not is_metal: 81 | specular_albedo = torch.mean(specular_albedo, dim=-1, keepdim=True).expand_as(specular_albedo) 82 | specular_roughness = color_network_dict["specular_roughness_network"](points, normals, None, features).abs() + 0.01 83 | return diffuse_albedo, specular_albedo, specular_roughness 84 | 85 | 86 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features): 87 | dots_sh = list(interior_mask.shape) 88 | rgb = torch.zeros( 89 | dots_sh 90 | + [ 91 | 3, 92 | ], 93 | dtype=torch.float32, 94 | device=interior_mask.device, 95 | ) 96 | diffuse_rgb = rgb.clone() 97 | specular_rgb = rgb.clone() 98 | diffuse_albedo = rgb.clone() 99 | specular_albedo = rgb.clone() 100 | specular_roughness = rgb[..., 0].clone() 101 | normals_pad = rgb.clone() 102 | if interior_mask.any(): 103 | normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-10) 104 | interior_diffuse_albedo, interior_specular_albedo, interior_specular_roughness = get_materials( 105 | color_network_dict, points, normals, features 106 | ) 107 | results = ggx_renderer( 108 | color_network_dict["point_light_network"](), 109 | (points - ray_o).norm(dim=-1, keepdim=True), 110 | normals, 111 | -ray_d, 112 | interior_diffuse_albedo, 113 | interior_specular_albedo, 114 | interior_specular_roughness, 115 | ) 116 | rgb[interior_mask] = results["rgb"] 117 | diffuse_rgb[interior_mask] = results["diffuse_rgb"] 118 | specular_rgb[interior_mask] = results["specular_rgb"] 119 | diffuse_albedo[interior_mask] = interior_diffuse_albedo 120 | specular_albedo[interior_mask] = interior_specular_albedo 121 | specular_roughness[interior_mask] = interior_specular_roughness.squeeze(-1) 122 | normals_pad[interior_mask] = normals 123 | 124 | return { 125 | "color": rgb, 126 | "diffuse_color": diffuse_rgb, 127 | "specular_color": specular_rgb, 128 | "diffuse_albedo": diffuse_albedo, 129 | "specular_albedo": specular_albedo, 130 | "specular_roughness": specular_roughness, 131 | "normal": normals_pad, 132 | } 133 | 134 | 135 | ###### network specifications 136 | sdf_network = SDFNetwork( 137 | d_in=3, 138 | d_out=257, 139 | d_hidden=256, 140 | n_layers=8, 141 | skip_in=[ 142 | 4, 143 | ], 144 | multires=6, 145 | bias=0.5, 146 | scale=1.0, 147 | geometric_init=True, 148 | weight_norm=True, 149 | ).cuda() 150 | raytracer = RayTracer() 151 | 152 | 153 | class PointLightNetwork(nn.Module): 154 | def __init__(self): 155 | super().__init__() 156 | self.register_parameter("light", nn.Parameter(torch.tensor(5.0))) 157 | 158 | def forward(self): 159 | return self.light 160 | 161 | def set_light(self, light): 162 | self.light.data.fill_(light) 163 | 164 | def get_light(self): 165 | return self.light.data.clone().detach() 166 | 167 | 168 | color_network_dict = { 169 | "color_network": RenderingNetwork( 170 | d_in=9, 171 | d_out=3, 172 | d_feature=256, 173 | d_hidden=256, 174 | n_layers=4, 175 | multires_view=4, 176 | mode="idr", 177 | squeeze_out=True, 178 | ).cuda(), 179 | "diffuse_albedo_network": RenderingNetwork( 180 | d_in=9, 181 | d_out=3, 182 | d_feature=256, 183 | d_hidden=256, 184 | n_layers=8, 185 | multires=10, 186 | multires_view=4, 187 | mode="idr", 188 | squeeze_out=True, 189 | skip_in=(4,), 190 | ).cuda(), 191 | "specular_albedo_network": RenderingNetwork( 192 | d_in=6, 193 | d_out=3, 194 | d_feature=256, 195 | d_hidden=256, 196 | n_layers=4, 197 | multires=6, 198 | multires_view=-1, 199 | mode="no_view_dir", 200 | 
squeeze_out=False, 201 | output_bias=0.4, 202 | output_scale=0.1, 203 | ).cuda(), 204 | "specular_roughness_network": RenderingNetwork( 205 | d_in=6, 206 | d_out=1, 207 | d_feature=256, 208 | d_hidden=256, 209 | n_layers=4, 210 | multires=6, 211 | multires_view=-1, 212 | mode="no_view_dir", 213 | squeeze_out=False, 214 | output_bias=0.1, 215 | output_scale=0.1, 216 | ).cuda(), 217 | "point_light_network": PointLightNetwork().cuda(), 218 | } 219 | 220 | ###### optimizer specifications 221 | sdf_optimizer = torch.optim.Adam(sdf_network.parameters(), lr=1e-5) 222 | color_optimizer_dict = { 223 | "color_network": torch.optim.Adam(color_network_dict["color_network"].parameters(), lr=1e-4), 224 | "diffuse_albedo_network": torch.optim.Adam(color_network_dict["diffuse_albedo_network"].parameters(), lr=1e-4), 225 | "specular_albedo_network": torch.optim.Adam(color_network_dict["specular_albedo_network"].parameters(), lr=1e-4), 226 | "specular_roughness_network": torch.optim.Adam( 227 | color_network_dict["specular_roughness_network"].parameters(), lr=1e-4 228 | ), 229 | "point_light_network": torch.optim.Adam(color_network_dict["point_light_network"].parameters(), lr=1e-2), 230 | } 231 | 232 | ###### loss specifications 233 | ggx_renderer = GGXColocatedRenderer(use_cuda=True) 234 | pyramidl2_loss_fn = PyramidL2Loss(use_cuda=True) 235 | 236 | ###### load dataset 237 | def to8b(x): 238 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8) 239 | 240 | 241 | def load_datadir(datadir): 242 | cam_dict = json.load(open(os.path.join(datadir, "cam_dict_norm.json"))) 243 | imgnames = list(cam_dict.keys()) 244 | try: 245 | imgnames = sorted(imgnames, key=lambda x: int(x[:-4])) 246 | except: 247 | imgnames = sorted(imgnames) 248 | 249 | image_fpaths = [] 250 | gt_images = [] 251 | Ks = [] 252 | W2Cs = [] 253 | for x in imgnames: 254 | fpath = os.path.join(datadir, "image", x) 255 | assert fpath[-4:] in [".jpg", ".png"], "must use ldr images as inputs" 256 | im = imageio.imread(fpath).astype(np.float32) / 255.0 257 | K = np.array(cam_dict[x]["K"]).reshape((4, 4)).astype(np.float32) 258 | W2C = np.array(cam_dict[x]["W2C"]).reshape((4, 4)).astype(np.float32) 259 | 260 | image_fpaths.append(fpath) 261 | gt_images.append(torch.from_numpy(im)) 262 | Ks.append(torch.from_numpy(K)) 263 | W2Cs.append(torch.from_numpy(W2C)) 264 | gt_images = torch.stack(gt_images, dim=0) 265 | Ks = torch.stack(Ks, dim=0) 266 | W2Cs = torch.stack(W2Cs, dim=0) 267 | return image_fpaths, gt_images, Ks, W2Cs 268 | 269 | 270 | image_fpaths, gt_images, Ks, W2Cs = load_datadir(args.data_dir) 271 | cameras = [ 272 | Camera(W=gt_images[i].shape[1], H=gt_images[i].shape[0], K=Ks[i].cuda(), W2C=W2Cs[i].cuda()) 273 | for i in range(gt_images.shape[0]) 274 | ] 275 | ic(len(image_fpaths), gt_images.shape, Ks.shape, W2Cs.shape, len(cameras)) 276 | 277 | ###### initialization using neus 278 | ic(args.neus_ckpt_fpath) 279 | if os.path.isfile(args.neus_ckpt_fpath): 280 | ic(f"Loading from neus checkpoint: {args.neus_ckpt_fpath}") 281 | ckpt = torch.load(args.neus_ckpt_fpath, map_location=torch.device("cuda")) 282 | try: 283 | sdf_network.load_state_dict(ckpt["sdf_network_fine"]) 284 | color_network_dict["diffuse_albedo_network"].load_state_dict(ckpt["color_network_fine"]) 285 | except: 286 | traceback.print_exc() 287 | # ic("Failed to initialize diffuse_albedo_network from checkpoint: ", args.neus_ckpt_fpath) 288 | dist = np.median([torch.norm(cameras[i].get_camera_origin()).item() for i in range(len(cameras))]) 289 | init_light = 
args.init_light_scale * dist * dist 290 | color_network_dict["point_light_network"].set_light(init_light) 291 | 292 | #### load pretrained checkpoints 293 | start_step = -1 294 | ckpt_fpaths = glob.glob(os.path.join(args.out_dir, "ckpt_*.pth")) 295 | if len(ckpt_fpaths) > 0: 296 | path2step = lambda x: int(os.path.basename(x)[len("ckpt_") : -4]) 297 | ckpt_fpaths = sorted(ckpt_fpaths, key=path2step) 298 | ckpt_fpath = ckpt_fpaths[-1] 299 | start_step = path2step(ckpt_fpath) 300 | ic("Reloading from checkpoint: ", ckpt_fpath) 301 | ckpt = torch.load(ckpt_fpath, map_location=torch.device("cuda")) 302 | sdf_network.load_state_dict(ckpt["sdf_network"]) 303 | for x in list(color_network_dict.keys()): 304 | color_network_dict[x].load_state_dict(ckpt[x]) 305 | # logim_names = [os.path.basename(x) for x in glob.glob(os.path.join(args.out_dir, "logim_*.png"))] 306 | # start_step = sorted([int(x[len("logim_") : -4]) for x in logim_names])[-1] 307 | ic(dist, color_network_dict["point_light_network"].light.data) 308 | ic(start_step) 309 | 310 | 311 | ###### export mesh and materials 312 | blender_fpath = "./blender-3.1.0-linux-x64/blender" 313 | if not os.path.isfile(blender_fpath): 314 | os.system( 315 | "wget https://mirror.clarkson.edu/blender/release/Blender3.1/blender-3.1.0-linux-x64.tar.xz && \ 316 | tar -xvf blender-3.1.0-linux-x64.tar.xz" 317 | ) 318 | 319 | 320 | def export_mesh_and_materials(export_out_dir, sdf_network, color_network_dict): 321 | ic(f"Exporting mesh and materials to: {export_out_dir}") 322 | sdf_fn = lambda x: sdf_network(x)[..., 0] 323 | ic("Exporting mesh and uv...") 324 | with torch.no_grad(): 325 | export_mesh(sdf_fn, os.path.join(export_out_dir, "mesh.obj")) 326 | os.system( 327 | f"{blender_fpath} --background --python models/export_uv.py {os.path.join(export_out_dir, 'mesh.obj')} {os.path.join(export_out_dir, 'mesh.obj')}" 328 | ) 329 | 330 | class MaterialPredictor(nn.Module): 331 | def __init__(self, sdf_network, color_network_dict): 332 | super().__init__() 333 | self.sdf_network = sdf_network 334 | self.color_network_dict = color_network_dict 335 | 336 | def forward(self, points): 337 | _, features, normals = self.sdf_network.get_all(points, is_training=False) 338 | normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-10) 339 | diffuse_albedo, specular_albedo, specular_roughness = get_materials( 340 | color_network_dict, points, normals, features 341 | ) 342 | return diffuse_albedo, specular_albedo, specular_roughness 343 | 344 | ic("Exporting materials...") 345 | material_predictor = MaterialPredictor(sdf_network, color_network_dict) 346 | with torch.no_grad(): 347 | export_materials(os.path.join(export_out_dir, "mesh.obj"), material_predictor, export_out_dir) 348 | 349 | ic(f"Exported mesh and materials to: {export_out_dir}") 350 | 351 | 352 | if args.export_all: 353 | export_out_dir = os.path.join(args.out_dir, f"mesh_and_materials_{start_step}") 354 | os.makedirs(export_out_dir, exist_ok=True) 355 | export_mesh_and_materials(export_out_dir, sdf_network, color_network_dict) 356 | exit(0) 357 | 358 | 359 | ###### render all images 360 | if args.render_all: 361 | render_out_dir = os.path.join(args.out_dir, f"render_{os.path.basename(args.data_dir)}_{start_step}") 362 | os.makedirs(render_out_dir, exist_ok=True) 363 | ic(f"Rendering images to: {render_out_dir}") 364 | n_cams = len(cameras) 365 | for i in tqdm.tqdm(range(n_cams)): 366 | cam, impath = cameras[i], image_fpaths[i] 367 | results = render_camera( 368 | cam, 369 | sdf_network, 370 | raytracer, 
371 | color_network_dict, 372 | render_fn, 373 | fill_holes=True, 374 | handle_edges=True, 375 | is_training=False, 376 | ) 377 | if args.gamma_pred: 378 | results["color"] = torch.pow(results["color"] + 1e-6, 1.0 / 2.2) 379 | for x in list(results.keys()): 380 | results[x] = results[x].detach().cpu().numpy() 381 | color_im = results["color"] 382 | imageio.imwrite(os.path.join(render_out_dir, os.path.basename(impath)), to8b(color_im)) 383 | exit(0) 384 | 385 | ###### training 386 | fill_holes = False 387 | handle_edges = not args.no_edgesample 388 | is_training = True 389 | if args.inv_gamma_gt: 390 | ic("linearizing ground-truth images using inverse gamma correction") 391 | gt_images = torch.pow(gt_images, 2.2) 392 | 393 | ic(fill_holes, handle_edges, is_training, args.inv_gamma_gt) 394 | writer = SummaryWriter(log_dir=os.path.join(args.out_dir, "logs")) 395 | 396 | for global_step in tqdm.tqdm(range(start_step + 1, args.num_iters)): 397 | sdf_optimizer.zero_grad() 398 | for x in color_optimizer_dict.keys(): 399 | color_optimizer_dict[x].zero_grad() 400 | 401 | idx = np.random.randint(0, gt_images.shape[0]) 402 | camera_crop, gt_color_crop = cameras[idx].crop_region( 403 | trgt_W=args.patch_size, trgt_H=args.patch_size, image=gt_images[idx] 404 | ) 405 | 406 | results = render_camera( 407 | camera_crop, 408 | sdf_network, 409 | raytracer, 410 | color_network_dict, 411 | render_fn, 412 | fill_holes=fill_holes, 413 | handle_edges=handle_edges, 414 | is_training=is_training, 415 | ) 416 | if args.gamma_pred: 417 | results["color"] = torch.pow(results["color"] + 1e-6, 1.0 / 2.2) 418 | results["diffuse_color"] = torch.pow(results["diffuse_color"] + 1e-6, 1.0 / 2.2) 419 | results["specular_color"] = torch.clamp(results["color"] - results["diffuse_color"], min=0.0) 420 | 421 | mask = results["convergent_mask"] 422 | if handle_edges: 423 | mask = mask | results["edge_mask"] 424 | 425 | img_loss = torch.Tensor([0.0]).cuda() 426 | img_l2_loss = torch.Tensor([0.0]).cuda() 427 | img_ssim_loss = torch.Tensor([0.0]).cuda() 428 | roughrange_loss = torch.Tensor([0.0]).cuda() 429 | 430 | eik_points = torch.empty(camera_crop.H * camera_crop.W // 2, 3).cuda().float().uniform_(-1.0, 1.0) 431 | eik_grad = sdf_network.gradient(eik_points).view(-1, 3) 432 | eik_cnt = eik_grad.shape[0] 433 | eik_loss = ((eik_grad.norm(dim=-1) - 1) ** 2).sum() 434 | if mask.any(): 435 | pred_img = results["color"].permute(2, 0, 1).unsqueeze(0) 436 | gt_img = gt_color_crop.permute(2, 0, 1).unsqueeze(0).to(pred_img.device) 437 | img_l2_loss = pyramidl2_loss_fn(pred_img, gt_img) 438 | img_ssim_loss = args.ssim_weight * ssim_loss_fn(pred_img, gt_img, mask.unsqueeze(0).unsqueeze(0)) 439 | img_loss = img_l2_loss + img_ssim_loss 440 | 441 | eik_grad = results["normal"][mask] 442 | eik_cnt += eik_grad.shape[0] 443 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum() 444 | if "edge_pos_neg_normal" in results: 445 | eik_grad = results["edge_pos_neg_normal"] 446 | eik_cnt += eik_grad.shape[0] 447 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum() 448 | 449 | roughness = results["specular_roughness"][mask] 450 | roughness = roughness[roughness > 0.5] 451 | if roughness.numel() > 0: 452 | roughrange_loss = (roughness - 0.5).mean() * args.roughrange_weight 453 | eik_loss = eik_loss / eik_cnt * args.eik_weight 454 | 455 | loss = img_loss + eik_loss + roughrange_loss 456 | loss.backward() 457 | sdf_optimizer.step() 458 | for x in color_optimizer_dict.keys(): 459 | color_optimizer_dict[x].step() 460 | 461 | if 
global_step % 50 == 0: 462 | writer.add_scalar("loss/loss", loss, global_step) 463 | writer.add_scalar("loss/img_loss", img_loss, global_step) 464 | writer.add_scalar("loss/img_l2_loss", img_l2_loss, global_step) 465 | writer.add_scalar("loss/img_ssim_loss", img_ssim_loss, global_step) 466 | writer.add_scalar("loss/eik_loss", eik_loss, global_step) 467 | writer.add_scalar("loss/roughrange_loss", roughrange_loss, global_step) 468 | writer.add_scalar("light", color_network_dict["point_light_network"].get_light()) 469 | 470 | if global_step % 1000 == 0: 471 | torch.save( 472 | dict( 473 | [ 474 | ("sdf_network", sdf_network.state_dict()), 475 | ] 476 | + [(x, color_network_dict[x].state_dict()) for x in color_network_dict.keys()] 477 | ), 478 | os.path.join(args.out_dir, f"ckpt_{global_step}.pth"), 479 | ) 480 | 481 | if global_step % 500 == 0: 482 | ic( 483 | args.out_dir, 484 | global_step, 485 | loss.item(), 486 | img_loss.item(), 487 | img_l2_loss.item(), 488 | img_ssim_loss.item(), 489 | eik_loss.item(), 490 | roughrange_loss.item(), 491 | color_network_dict["point_light_network"].get_light().item(), 492 | ) 493 | 494 | for x in list(results.keys()): 495 | del results[x] 496 | 497 | idx = 0 498 | if args.plot_image_name is not None: 499 | while idx < len(image_fpaths): 500 | if args.plot_image_name in image_fpaths[idx]: 501 | break 502 | idx += 1 503 | 504 | camera_resize, gt_color_resize = cameras[idx].resize(factor=0.25, image=gt_images[idx]) 505 | results = render_camera( 506 | camera_resize, 507 | sdf_network, 508 | raytracer, 509 | color_network_dict, 510 | render_fn, 511 | fill_holes=fill_holes, 512 | handle_edges=handle_edges, 513 | is_training=False, 514 | ) 515 | if args.gamma_pred: 516 | results["color"] = torch.pow(results["color"] + 1e-6, 1.0 / 2.2) 517 | results["diffuse_color"] = torch.pow(results["diffuse_color"] + 1e-6, 1.0 / 2.2) 518 | results["specular_color"] = torch.clamp(results["color"] - results["diffuse_color"], min=0.0) 519 | for x in list(results.keys()): 520 | results[x] = results[x].detach().cpu().numpy() 521 | 522 | gt_color_im = gt_color_resize.detach().cpu().numpy() 523 | color_im = results["color"] 524 | diffuse_color_im = results["diffuse_color"] 525 | specular_color_im = results["specular_color"] 526 | normal = results["normal"] 527 | normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10) 528 | normal_im = (normal + 1.0) / 2.0 529 | edge_mask_im = np.tile(results["edge_mask"][:, :, np.newaxis], (1, 1, 3)) 530 | diffuse_albedo_im = results["diffuse_albedo"] 531 | specular_albedo_im = results["specular_albedo"] 532 | specular_roughness_im = np.tile(results["specular_roughness"][:, :, np.newaxis], (1, 1, 3)) 533 | if args.inv_gamma_gt: 534 | gt_color_im = np.power(gt_color_im + 1e-6, 1.0 / 2.2) 535 | color_im = np.power(color_im + 1e-6, 1.0 / 2.2) 536 | diffuse_color_im = np.power(diffuse_color_im + 1e-6, 1.0 / 2.2) 537 | specular_color_im = color_im - diffuse_color_im 538 | 539 | row1 = np.concatenate([gt_color_im, normal_im, edge_mask_im], axis=1) 540 | row2 = np.concatenate([color_im, diffuse_color_im, specular_color_im], axis=1) 541 | row3 = np.concatenate([diffuse_albedo_im, specular_albedo_im, specular_roughness_im], axis=1) 542 | im = np.concatenate((row1, row2, row3), axis=0) 543 | imageio.imwrite(os.path.join(args.out_dir, f"logim_{global_step}.png"), to8b(im)) 544 | 545 | 546 | ###### export mesh and materials 547 | export_out_dir = os.path.join(args.out_dir, f"mesh_and_materials_{global_step}") 548 | 
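
Before the final export below, each iteration of the loop above optimizes a single scalar objective: the pyramid L2 and SSIM image terms, plus an eikonal penalty on SDF gradient samples and a penalty on roughness values above 0.5. A minimal self-contained sketch of that combination (a plain MSE stands in for the pyramid loss, and tensor shapes are assumed rather than taken from the repo's API):

```python
import torch

def stage2_objective(pred_rgb, gt_rgb, grad_samples, roughness,
                     eik_weight=0.1, roughrange_weight=0.1):
    """Simplified re-statement of the scalar loss assembled in the loop above.
    `grad_samples` stacks every SDF gradient sample (random eikonal points,
    surface normals, edge samples); plain MSE replaces the pyramid L2 + SSIM terms."""
    img_loss = ((pred_rgb - gt_rgb) ** 2).mean()
    eik_loss = eik_weight * ((grad_samples.norm(dim=-1) - 1.0) ** 2).mean()
    rough = roughness[roughness > 0.5]                 # only roughness above 0.5 is penalized
    if rough.numel() > 0:
        roughrange_loss = roughrange_weight * (rough - 0.5).mean()
    else:
        roughrange_loss = torch.zeros((), device=pred_rgb.device)
    return img_loss + eik_loss + roughrange_loss
```
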
os.makedirs(export_out_dir, exist_ok=True) 549 | export_mesh_and_materials(export_out_dir, sdf_network, color_network_dict) 550 | -------------------------------------------------------------------------------- /render_synthetic_data/render_rgb_flash_mat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import shutil 5 | import imageio 6 | 7 | imageio.plugins.freeimage.download() 8 | 9 | 10 | asset_dir = 'path/to/synthetic_assets' 11 | out_dir = 'path/to/output_folder' 12 | 13 | for scene in os.listdir(asset_dir): 14 | in_scene_dir = os.path.join(asset_dir, scene) 15 | out_scene_dir = os.path.join(out_dir, scene) 16 | os.makedirs(out_scene_dir, exist_ok=True) 17 | 18 | light = 20. 19 | with open(os.path.join(out_scene_dir, 'light.txt'), 'w') as fp: 20 | fp.write(f'{light}\n') 21 | 22 | for split in ['train', 'test']: 23 | out_split_dir = os.path.join(out_scene_dir, split) 24 | os.makedirs(os.path.join(out_split_dir, 'image'), exist_ok=True) 25 | 26 | cam_dict_fpath = os.path.join(asset_dir, f'{split}_cam_dict_norm.json') 27 | shutil.copy2(cam_dict_fpath, os.path.join(out_split_dir, 'cam_dict_norm.json')) 28 | 29 | cam_dict = json.load(open(cam_dict_fpath)) 30 | img_list = list(cam_dict.keys()) 31 | img_list = sorted(img_list, key=lambda x: int(x[:-4])) 32 | 33 | use_docker = True 34 | 35 | for index, img_name in enumerate(img_list): 36 | mesh = os.path.join(in_scene_dir, "model.obj") 37 | d_albedo = os.path.join(in_scene_dir, "diffuse_albedo.exr") 38 | s_albedo = os.path.join(in_scene_dir, "specular_albedo.exr") 39 | s_roughness = os.path.join(in_scene_dir, "specular_roughness.exr") 40 | 41 | K = np.array(cam_dict[img_name]["K"]).reshape((4, 4)) 42 | focal = K[0, 0] 43 | width, height = cam_dict[img_name]["img_size"] 44 | fov = np.rad2deg(np.arctan(width / 2.0 / focal) * 2.0) 45 | w2c = np.array(cam_dict[img_name]["W2C"]).reshape((4, 4)) 46 | # check if unit aspect ratio 47 | assert np.isclose(K[0, 0] - K[1, 1], 0.0), f"{K[0,0]} != {K[1,1]}" 48 | 49 | c2w = np.linalg.inv(w2c) 50 | c2w[:3, :2] *= -1 # mitsuba camera coordinate system: x-->left, y-->up, z-->scene 51 | origin = c2w[:3, 3] 52 | c2w = " ".join([str(x) for x in c2w.flatten().tolist()]) 53 | 54 | out_fpath = os.path.join(out_split_dir, 'image', img_name[:-4] + ".exr") 55 | cmd = ( 56 | 'mitsuba -b 10 rgb_flash_hdr_mat.xml -D fov={} -D width={} -D height={} -D c2w="{}" ' 57 | "-D mesh={} -D d_albedo={} -D s_albedo={} -D s_roughness={} " 58 | "-D light={} " 59 | "-D px={} -D py={} -D pz={} " 60 | "-o {} ".format( 61 | fov, 62 | width, 63 | height, 64 | c2w, 65 | mesh, 66 | d_albedo, 67 | s_albedo, 68 | s_roughness, 69 | light, 70 | origin[0], 71 | origin[1], 72 | origin[2], 73 | out_fpath, 74 | ) 75 | ) 76 | 77 | if use_docker: 78 | docker_prefix = "docker run -w `pwd` --rm -v `pwd`:`pwd` -v /phoenix:/phoenix ninjaben/mitsuba-rgb " 79 | cmd = docker_prefix + cmd 80 | 81 | os.system(cmd) 82 | os.system("rm mitsuba.*.log") 83 | 84 | -------------------------------------------------------------------------------- /render_synthetic_data/rgb_flash_hdr_mat.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /render_volume.py: 
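
render_volume.py below is the NeuS-style volume-rendering stage; the checkpoints it writes provide the `sdf_network_fine` and `color_network_fine` weights that render_surface.py warm-starts from via `--neus_ckpt_fpath`. A quick way to inspect such a checkpoint before handing it to stage 2 (the path is only an example; `save_checkpoint()` in this file writes `ckpt_{iter:06d}.pth` under `<base_exp_dir>/checkpoints/`):

```python
# Example only: the checkpoint path depends on your conf's base_exp_dir and case name.
import torch

ckpt = torch.load("exp/dragon/womask_iron/checkpoints/ckpt_100000.pth", map_location="cpu")
print(sorted(ckpt.keys()))
# expected keys written by save_checkpoint() below:
# ['color_network_fine', 'iter_step', 'nerf', 'optimizer', 'sdf_network_fine', 'variance_network_fine']
```
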
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import argparse 5 | import numpy as np 6 | import cv2 as cv 7 | import trimesh 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from shutil import copyfile 12 | from icecream import ic 13 | from tqdm import tqdm 14 | from pyhocon import ConfigFactory 15 | from models.dataset import Dataset 16 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF 17 | from models.renderer import NeuSRenderer 18 | 19 | 20 | class Runner: 21 | def __init__(self, conf_path, mode="train", case="CASE_NAME", is_continue=False): 22 | self.device = torch.device("cuda") 23 | 24 | # Configuration 25 | self.conf_path = conf_path 26 | f = open(self.conf_path) 27 | conf_text = f.read() 28 | conf_text = conf_text.replace("CASE_NAME", case) 29 | f.close() 30 | 31 | self.conf = ConfigFactory.parse_string(conf_text) 32 | self.conf["dataset.data_dir"] = self.conf["dataset.data_dir"].replace("CASE_NAME", case) 33 | self.base_exp_dir = self.conf["general.base_exp_dir"] 34 | os.makedirs(self.base_exp_dir, exist_ok=True) 35 | self.dataset = Dataset(self.conf["dataset"]) 36 | self.iter_step = 0 37 | 38 | # Training parameters 39 | self.end_iter = self.conf.get_int("train.end_iter") 40 | self.save_freq = self.conf.get_int("train.save_freq") 41 | self.report_freq = self.conf.get_int("train.report_freq") 42 | self.val_freq = self.conf.get_int("train.val_freq") 43 | self.val_mesh_freq = self.conf.get_int("train.val_mesh_freq") 44 | self.batch_size = self.conf.get_int("train.batch_size") 45 | self.validate_resolution_level = self.conf.get_int("train.validate_resolution_level") 46 | self.learning_rate = self.conf.get_float("train.learning_rate") 47 | self.learning_rate_alpha = self.conf.get_float("train.learning_rate_alpha") 48 | self.use_white_bkgd = self.conf.get_bool("train.use_white_bkgd") 49 | self.warm_up_end = self.conf.get_float("train.warm_up_end", default=0.0) 50 | self.anneal_end = self.conf.get_float("train.anneal_end", default=0.0) 51 | 52 | # Weights 53 | self.igr_weight = self.conf.get_float("train.igr_weight") 54 | self.mask_weight = self.conf.get_float("train.mask_weight") 55 | self.is_continue = is_continue 56 | self.mode = mode 57 | self.model_list = [] 58 | self.writer = None 59 | 60 | # Networks 61 | params_to_train = [] 62 | self.nerf_outside = NeRF(**self.conf["model.nerf"]).to(self.device) 63 | self.sdf_network = SDFNetwork(**self.conf["model.sdf_network"]).to(self.device) 64 | self.deviation_network = SingleVarianceNetwork(**self.conf["model.variance_network"]).to(self.device) 65 | self.color_network = RenderingNetwork(**self.conf["model.rendering_network"]).to(self.device) 66 | params_to_train += list(self.nerf_outside.parameters()) 67 | params_to_train += list(self.sdf_network.parameters()) 68 | params_to_train += list(self.deviation_network.parameters()) 69 | params_to_train += list(self.color_network.parameters()) 70 | 71 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate) 72 | 73 | self.renderer = NeuSRenderer( 74 | self.nerf_outside, 75 | self.sdf_network, 76 | self.deviation_network, 77 | self.color_network, 78 | **self.conf["model.neus_renderer"] 79 | ) 80 | 81 | # Load checkpoint 82 | latest_model_name = None 83 | if is_continue: 84 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, "checkpoints")) 85 | model_list = [] 86 | for model_name in 
model_list_raw: 87 | if model_name[-3:] == "pth" and int(model_name[5:-4]) <= self.end_iter: 88 | model_list.append(model_name) 89 | model_list.sort() 90 | latest_model_name = model_list[-1] 91 | 92 | if latest_model_name is not None: 93 | logging.info("Find checkpoint: {}".format(latest_model_name)) 94 | self.load_checkpoint(latest_model_name) 95 | 96 | # Backup codes and configs for debug 97 | if self.mode[:5] == "train": 98 | self.file_backup() 99 | 100 | def train(self): 101 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, "logs")) 102 | self.update_learning_rate() 103 | res_step = self.end_iter - self.iter_step 104 | image_perm = self.get_image_perm() 105 | 106 | for iter_i in tqdm(range(res_step)): 107 | data = self.dataset.gen_random_rays_at(image_perm[self.iter_step % len(image_perm)], self.batch_size) 108 | 109 | rays_o, rays_d, true_rgb, mask = ( 110 | data[:, :3], 111 | data[:, 3:6], 112 | data[:, 6:9], 113 | data[:, 9:10], 114 | ) 115 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d) 116 | 117 | background_rgb = None 118 | if self.use_white_bkgd: 119 | background_rgb = torch.ones([1, 3]) 120 | 121 | if self.mask_weight > 0.0: 122 | mask = (mask > 0.5).float() 123 | else: 124 | mask = torch.ones_like(mask) 125 | 126 | mask_sum = mask.sum() + 1e-5 127 | render_out = self.renderer.render( 128 | rays_o, 129 | rays_d, 130 | near, 131 | far, 132 | background_rgb=background_rgb, 133 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 134 | ) 135 | 136 | color_fine = render_out["color_fine"] 137 | s_val = render_out["s_val"] 138 | cdf_fine = render_out["cdf_fine"] 139 | gradient_error = render_out["gradient_error"] 140 | weight_max = render_out["weight_max"] 141 | weight_sum = render_out["weight_sum"] 142 | 143 | # Loss 144 | color_error = (color_fine - true_rgb) * mask 145 | color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction="sum") / mask_sum 146 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb) ** 2 * mask).sum() / (mask_sum * 3.0)).sqrt()) 147 | 148 | eikonal_loss = gradient_error 149 | 150 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask) 151 | 152 | loss = color_fine_loss + eikonal_loss * self.igr_weight + mask_loss * self.mask_weight 153 | 154 | self.optimizer.zero_grad() 155 | loss.backward() 156 | self.optimizer.step() 157 | 158 | self.iter_step += 1 159 | 160 | self.writer.add_scalar("Loss/loss", loss, self.iter_step) 161 | self.writer.add_scalar("Loss/color_loss", color_fine_loss, self.iter_step) 162 | self.writer.add_scalar("Loss/eikonal_loss", eikonal_loss, self.iter_step) 163 | self.writer.add_scalar("Statistics/s_val", s_val.mean(), self.iter_step) 164 | self.writer.add_scalar( 165 | "Statistics/cdf", 166 | (cdf_fine[:, :1] * mask).sum() / mask_sum, 167 | self.iter_step, 168 | ) 169 | self.writer.add_scalar( 170 | "Statistics/weight_max", 171 | (weight_max * mask).sum() / mask_sum, 172 | self.iter_step, 173 | ) 174 | self.writer.add_scalar("Statistics/psnr", psnr, self.iter_step) 175 | 176 | if self.iter_step % self.report_freq == 0: 177 | print(self.base_exp_dir) 178 | print("iter:{:8>d} loss = {} lr={}".format(self.iter_step, loss, self.optimizer.param_groups[0]["lr"])) 179 | 180 | if self.iter_step % self.save_freq == 0: 181 | self.save_checkpoint() 182 | 183 | if self.iter_step % self.val_freq == 0: 184 | self.validate_image() 185 | 186 | if self.iter_step % self.val_mesh_freq == 0: 187 | self.validate_mesh() 188 | 189 | self.update_learning_rate() 190 | 191 | if 
self.iter_step % len(image_perm) == 0: 192 | image_perm = self.get_image_perm() 193 | 194 | def get_image_perm(self): 195 | return torch.randperm(self.dataset.n_images) 196 | 197 | def get_cos_anneal_ratio(self): 198 | if self.anneal_end == 0.0: 199 | return 1.0 200 | else: 201 | return np.min([1.0, self.iter_step / self.anneal_end]) 202 | 203 | def update_learning_rate(self): 204 | if self.iter_step < self.warm_up_end: 205 | learning_factor = self.iter_step / self.warm_up_end 206 | else: 207 | alpha = self.learning_rate_alpha 208 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end) 209 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha 210 | 211 | for g in self.optimizer.param_groups: 212 | g["lr"] = self.learning_rate * learning_factor 213 | 214 | def file_backup(self): 215 | dir_lis = self.conf["general.recording"] 216 | os.makedirs(os.path.join(self.base_exp_dir, "recording"), exist_ok=True) 217 | for dir_name in dir_lis: 218 | cur_dir = os.path.join(self.base_exp_dir, "recording", dir_name) 219 | os.makedirs(cur_dir, exist_ok=True) 220 | files = os.listdir(dir_name) 221 | for f_name in files: 222 | if f_name[-3:] == ".py": 223 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) 224 | 225 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, "recording", "config.conf")) 226 | 227 | def load_checkpoint(self, checkpoint_name): 228 | checkpoint = torch.load( 229 | os.path.join(self.base_exp_dir, "checkpoints", checkpoint_name), 230 | map_location=self.device, 231 | ) 232 | self.nerf_outside.load_state_dict(checkpoint["nerf"]) 233 | self.sdf_network.load_state_dict(checkpoint["sdf_network_fine"]) 234 | self.deviation_network.load_state_dict(checkpoint["variance_network_fine"]) 235 | self.color_network.load_state_dict(checkpoint["color_network_fine"]) 236 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 237 | self.iter_step = checkpoint["iter_step"] 238 | 239 | logging.info("End") 240 | 241 | def save_checkpoint(self): 242 | checkpoint = { 243 | "nerf": self.nerf_outside.state_dict(), 244 | "sdf_network_fine": self.sdf_network.state_dict(), 245 | "variance_network_fine": self.deviation_network.state_dict(), 246 | "color_network_fine": self.color_network.state_dict(), 247 | "optimizer": self.optimizer.state_dict(), 248 | "iter_step": self.iter_step, 249 | } 250 | 251 | os.makedirs(os.path.join(self.base_exp_dir, "checkpoints"), exist_ok=True) 252 | torch.save( 253 | checkpoint, 254 | os.path.join( 255 | self.base_exp_dir, 256 | "checkpoints", 257 | "ckpt_{:0>6d}.pth".format(self.iter_step), 258 | ), 259 | ) 260 | 261 | def validate_image(self, idx=-1, resolution_level=-1): 262 | if idx < 0: 263 | idx = np.random.randint(self.dataset.n_images) 264 | 265 | print("Validate: iter: {}, camera: {}".format(self.iter_step, idx)) 266 | 267 | if resolution_level < 0: 268 | resolution_level = self.validate_resolution_level 269 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) 270 | H, W, _ = rays_o.shape 271 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 272 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 273 | 274 | out_rgb_fine = [] 275 | out_normal_fine = [] 276 | 277 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 278 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 279 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 280 | 281 | render_out = self.renderer.render( 282 | rays_o_batch, 283 | 
rays_d_batch, 284 | near, 285 | far, 286 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 287 | background_rgb=background_rgb, 288 | ) 289 | 290 | def feasible(key): 291 | return (key in render_out) and (render_out[key] is not None) 292 | 293 | if feasible("color_fine"): 294 | out_rgb_fine.append(render_out["color_fine"].detach().cpu().numpy()) 295 | if feasible("gradients") and feasible("weights"): 296 | n_samples = self.renderer.n_samples + self.renderer.n_importance 297 | normals = render_out["gradients"] * render_out["weights"][:, :n_samples, None] 298 | if feasible("inside_sphere"): 299 | normals = normals * render_out["inside_sphere"][..., None] 300 | normals = normals.sum(dim=1).detach().cpu().numpy() 301 | out_normal_fine.append(normals) 302 | del render_out 303 | 304 | img_fine = None 305 | if len(out_rgb_fine) > 0: 306 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) 307 | 308 | normal_img = None 309 | if len(out_normal_fine) > 0: 310 | normal_img = np.concatenate(out_normal_fine, axis=0) 311 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy()) 312 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]).reshape([H, W, 3, -1]) * 128 + 128).clip( 313 | 0, 255 314 | ) 315 | 316 | os.makedirs(os.path.join(self.base_exp_dir, "validations_fine"), exist_ok=True) 317 | os.makedirs(os.path.join(self.base_exp_dir, "normals"), exist_ok=True) 318 | 319 | for i in range(img_fine.shape[-1]): 320 | if len(out_rgb_fine) > 0: 321 | cv.imwrite( 322 | os.path.join( 323 | self.base_exp_dir, 324 | "validations_fine", 325 | "{:0>8d}_{}_{}.png".format(self.iter_step, i, idx), 326 | ), 327 | np.concatenate( 328 | [ 329 | img_fine[..., i], 330 | self.dataset.image_at(idx, resolution_level=resolution_level), 331 | ] 332 | ), 333 | ) 334 | if len(out_normal_fine) > 0: 335 | cv.imwrite( 336 | os.path.join( 337 | self.base_exp_dir, 338 | "normals", 339 | "{:0>8d}_{}_{}.png".format(self.iter_step, i, idx), 340 | ), 341 | normal_img[..., i], 342 | ) 343 | 344 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level): 345 | """ 346 | Interpolate view between two cameras. 
347 | """ 348 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level) 349 | H, W, _ = rays_o.shape 350 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 351 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 352 | 353 | out_rgb_fine = [] 354 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 355 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 356 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 357 | 358 | render_out = self.renderer.render( 359 | rays_o_batch, 360 | rays_d_batch, 361 | near, 362 | far, 363 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 364 | background_rgb=background_rgb, 365 | ) 366 | 367 | out_rgb_fine.append(render_out["color_fine"].detach().cpu().numpy()) 368 | 369 | del render_out 370 | 371 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8) 372 | return img_fine 373 | 374 | def validate_mesh(self, world_space=False, resolution=64, threshold=0.0): 375 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) 376 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) 377 | 378 | vertices, triangles = self.renderer.extract_geometry( 379 | bound_min, bound_max, resolution=resolution, threshold=threshold 380 | ) 381 | os.makedirs(os.path.join(self.base_exp_dir, "meshes"), exist_ok=True) 382 | 383 | if world_space: 384 | vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None] 385 | 386 | mesh = trimesh.Trimesh(vertices, triangles) 387 | mesh.export(os.path.join(self.base_exp_dir, "meshes", "{:0>8d}.ply".format(self.iter_step))) 388 | 389 | logging.info("End") 390 | 391 | def interpolate_view(self, img_idx_0, img_idx_1): 392 | images = [] 393 | n_frames = 60 394 | for i in range(n_frames): 395 | print(i) 396 | images.append( 397 | self.render_novel_image( 398 | img_idx_0, 399 | img_idx_1, 400 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, 401 | resolution_level=4, 402 | ) 403 | ) 404 | for i in range(n_frames): 405 | images.append(images[n_frames - i - 1]) 406 | 407 | fourcc = cv.VideoWriter_fourcc(*"mp4v") 408 | video_dir = os.path.join(self.base_exp_dir, "render") 409 | os.makedirs(video_dir, exist_ok=True) 410 | h, w, _ = images[0].shape 411 | writer = cv.VideoWriter( 412 | os.path.join( 413 | video_dir, 414 | "{:0>8d}_{}_{}.mp4".format(self.iter_step, img_idx_0, img_idx_1), 415 | ), 416 | fourcc, 417 | 30, 418 | (w, h), 419 | ) 420 | 421 | for image in images: 422 | writer.write(image) 423 | 424 | writer.release() 425 | 426 | 427 | if __name__ == "__main__": 428 | print("Hello Wooden") 429 | 430 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 431 | 432 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 433 | logging.basicConfig(level=logging.DEBUG, format=FORMAT) 434 | 435 | parser = argparse.ArgumentParser() 436 | parser.add_argument("--conf", type=str, default="./confs/base.conf") 437 | parser.add_argument("--mode", type=str, default="train") 438 | parser.add_argument("--mcube_threshold", type=float, default=0.0) 439 | parser.add_argument("--is_continue", default=False, action="store_true") 440 | parser.add_argument("--gpu", type=int, default=0) 441 | parser.add_argument("--case", type=str, default="") 442 | 443 | args = parser.parse_args() 444 | 445 | torch.cuda.set_device(args.gpu) 446 | runner = Runner(args.conf, args.mode, args.case, args.is_continue) 447 | 448 | if 
args.mode == "train": 449 | runner.train() 450 | elif args.mode == "validate_mesh": 451 | runner.validate_mesh(world_space=True, resolution=512, threshold=args.mcube_threshold) 452 | elif args.mode.startswith("interpolate"): # Interpolate views given two image indices 453 | _, img_idx_0, img_idx_1 = args.mode.split("_") 454 | img_idx_0 = int(img_idx_0) 455 | img_idx_1 = int(img_idx_1) 456 | runner.interpolate_view(img_idx_0, img_idx_1) 457 | -------------------------------------------------------------------------------- /singleview/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/singleview/12.png -------------------------------------------------------------------------------- /singleview/cam_dict_norm.json: -------------------------------------------------------------------------------- 1 | { 2 | "12.png": { 3 | "K": [ 4 | 811.9282694049824, 5 | 0.0, 6 | 256.0, 7 | 0.0, 8 | 0.0, 9 | 811.9282694049824, 10 | 256.0, 11 | 0.0, 12 | 0.0, 13 | 0.0, 14 | 1.0, 15 | 0.0, 16 | 0.0, 17 | 0.0, 18 | 0.0, 19 | 1.0 20 | ], 21 | "W2C": [ 22 | 0.998867339183008, 23 | 0.0, 24 | -0.04758191582374219, 25 | 1.5416074755814572e-17, 26 | -0.013163727354886733, 27 | -0.9609695958324571, 28 | -0.27634064516051604, 29 | 1.1553154537250536e-16, 30 | -0.045724774418075535, 31 | 0.27665400030652737, 32 | -0.959881143224937, 33 | 2.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 1.0 38 | ], 39 | "img_size": [ 40 | 512, 41 | 512 42 | ] 43 | } 44 | } -------------------------------------------------------------------------------- /test_mitsuba/render_rgb_envmap_mat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import imageio 5 | 6 | imageio.plugins.freeimage.download() 7 | 8 | import sys 9 | 10 | asset_dir = sys.argv[1] 11 | cam_dict_fpath = sys.argv[2] 12 | envmap_fpath = sys.argv[3] 13 | out_dir = sys.argv[4] 14 | 15 | 16 | d_albedo = os.path.join(asset_dir, "diffuse_albedo.exr") 17 | s_albedo = os.path.join(asset_dir, "specular_albedo.exr") 18 | s_roughness = os.path.join(asset_dir, "roughness.exr") 19 | mesh_fpath = os.path.join(asset_dir, "mesh.obj") 20 | 21 | os.makedirs(out_dir, exist_ok=True) 22 | 23 | 24 | envmap_fpath = os.path.join(asset_dir, "../envmap.exr") 25 | 26 | cam_dict = json.load(open(cam_dict_fpath)) 27 | 28 | use_docker = True 29 | 30 | for img_name in list(cam_dict.keys()): 31 | out_fpath = os.path.join(out_dir, img_name[:-4] + ".exr") 32 | K = np.array(cam_dict[img_name]["K"]).reshape((4, 4)) 33 | focal = K[0, 0] 34 | width, height = cam_dict[img_name]["img_size"] 35 | fov = np.rad2deg(np.arctan(width / 2.0 / focal) * 2.0) 36 | w2c = np.array(cam_dict[img_name]["W2C"]).reshape((4, 4)) 37 | 38 | c2w = np.linalg.inv(w2c) 39 | c2w[:3, :2] *= -1 # mitsuba camera coordinate system: x-->left, y-->up, z-->scene 40 | origin = c2w[:3, 3] 41 | c2w = " ".join([str(x) for x in c2w.flatten().tolist()]) 42 | 43 | cmd = ( 44 | 'mitsuba -b 10 rgb_envmap_hdr_mat.xml -D fov={} -D width={} -D height={} -D c2w="{}" ' 45 | "-D mesh={} -D d_albedo={} -D s_albedo={} -D s_roughness={} " 46 | "-D envmap={} " 47 | "-o {} ".format(fov, width, height, c2w, mesh_fpath, d_albedo, s_albedo, s_roughness, envmap_fpath, out_fpath) 48 | ) 49 | 50 | if use_docker: 51 | docker_prefix = "docker run -w `pwd` --rm -v `pwd`:`pwd` -v /phoenix:/phoenix ninjaben/mitsuba-rgb " 52 | cmd = docker_prefix + cmd 53 | 54 | 
os.system(cmd) 55 | os.system("rm mitsuba.*.log") 56 | 57 | to8b = lambda x: np.uint8(np.clip(x * 255.0, 0.0, 255.0)) 58 | im = imageio.imread(out_fpath).astype(np.float32) 59 | imageio.imwrite(out_fpath[:-4] + ".png", to8b(np.power(im, 1.0 / 2.2))) 60 | -------------------------------------------------------------------------------- /test_mitsuba/render_rgb_flash_mat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import imageio 5 | 6 | imageio.plugins.freeimage.download() 7 | import sys 8 | 9 | 10 | cam_dict_fpath = sys.argv[1] 11 | asset_dir = sys.argv[2] 12 | 13 | out_dir = os.path.join(asset_dir, "mitsuba_render") 14 | os.makedirs(out_dir, exist_ok=True) 15 | 16 | light = 61.3303 # pony 17 | # light = 28.8344 # girl 18 | # light = 48.5146 # triton 19 | # light = 136.0487 # tree 20 | # light = 14.7209 # dragon 21 | 22 | cam_dict = json.load(open(cam_dict_fpath)) 23 | img_list = list(cam_dict.keys()) 24 | img_list = sorted(img_list, key=lambda x: int(x[:-4])) 25 | 26 | 27 | use_docker = True 28 | 29 | for index, img_name in enumerate(img_list): 30 | mesh = os.path.join(asset_dir, "mesh.obj") 31 | d_albedo = os.path.join(asset_dir, "diffuse_albedo.exr") 32 | s_albedo = os.path.join(asset_dir, "specular_albedo.exr") 33 | s_roughness = os.path.join(asset_dir, "roughness.exr") 34 | 35 | K = np.array(cam_dict[img_name]["K"]).reshape((4, 4)) 36 | focal = K[0, 0] 37 | width, height = cam_dict[img_name]["img_size"] 38 | fov = np.rad2deg(np.arctan(width / 2.0 / focal) * 2.0) 39 | w2c = np.array(cam_dict[img_name]["W2C"]).reshape((4, 4)) 40 | # check if unit aspect ratio 41 | assert np.isclose(K[0, 0] - K[1, 1], 0.0), f"{K[0,0]} != {K[1,1]}" 42 | 43 | c2w = np.linalg.inv(w2c) 44 | c2w[:3, :2] *= -1 # mitsuba camera coordinate system: x-->left, y-->up, z-->scene 45 | origin = c2w[:3, 3] 46 | c2w = " ".join([str(x) for x in c2w.flatten().tolist()]) 47 | 48 | out_fpath = os.path.join(out_dir, img_name[:-4] + ".exr") 49 | cmd = ( 50 | 'mitsuba -b 10 rgb_flash_hdr_mat.xml -D fov={} -D width={} -D height={} -D c2w="{}" ' 51 | "-D mesh={} -D d_albedo={} -D s_albedo={} -D s_roughness={} " 52 | "-D light={} " 53 | "-D px={} -D py={} -D pz={} " 54 | "-o {} ".format( 55 | fov, 56 | width, 57 | height, 58 | c2w, 59 | mesh, 60 | d_albedo, 61 | s_albedo, 62 | s_roughness, 63 | light, 64 | origin[0], 65 | origin[1], 66 | origin[2], 67 | out_fpath, 68 | ) 69 | ) 70 | 71 | if use_docker: 72 | docker_prefix = "docker run -w `pwd` --rm -v `pwd`:`pwd` -v /phoenix:/phoenix ninjaben/mitsuba-rgb " 73 | cmd = docker_prefix + cmd 74 | 75 | os.system(cmd) 76 | os.system("rm mitsuba.*.log") 77 | 78 | to8b = lambda x: np.uint8(np.clip(x * 255.0, 0.0, 255.0)) 79 | im = imageio.imread(out_fpath).astype(np.float32) 80 | imageio.imwrite(out_fpath[:-4] + ".png", to8b(np.power(im, 1.0 / 2.2))) 81 | -------------------------------------------------------------------------------- /test_mitsuba/rgb_envmap_hdr_mat.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /test_mitsuba/rgb_flash_hdr_mat.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 
15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /tests/data_singleview/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kai-46/IRON/8e9a7c172542afd52b8e6ef28bc96ad52b5ffd5a/tests/data_singleview/12.png -------------------------------------------------------------------------------- /tests/data_singleview/cam_dict_norm.json: -------------------------------------------------------------------------------- 1 | { 2 | "12.png": { 3 | "K": [ 4 | 811.9282694049824, 5 | 0.0, 6 | 256.0, 7 | 0.0, 8 | 0.0, 9 | 811.9282694049824, 10 | 256.0, 11 | 0.0, 12 | 0.0, 13 | 0.0, 14 | 1.0, 15 | 0.0, 16 | 0.0, 17 | 0.0, 18 | 0.0, 19 | 1.0 20 | ], 21 | "W2C": [ 22 | 0.998867339183008, 23 | 0.0, 24 | -0.04758191582374219, 25 | 1.5416074755814572e-17, 26 | -0.013163727354886733, 27 | -0.9609695958324571, 28 | -0.27634064516051604, 29 | 1.1553154537250536e-16, 30 | -0.045724774418075535, 31 | 0.27665400030652737, 32 | -0.959881143224937, 33 | 2.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 1.0 38 | ], 39 | "img_size": [ 40 | 512, 41 | 512 42 | ] 43 | } 44 | } -------------------------------------------------------------------------------- /tests/test_raytracer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | import trimesh 6 | import imageio 7 | 8 | imageio.plugins.freeimage.download() 9 | 10 | from icecream import ic 11 | import sys 12 | 13 | sys.path.append("../") 14 | 15 | from models.fields import SDFNetwork, RenderingNetwork 16 | import models.raytracer 17 | 18 | models.raytracer.VERBOSE_MODE = True 19 | from models.raytracer import RayTracer, Camera, render_camera 20 | 21 | 22 | def to8b(x): 23 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8) 24 | 25 | 26 | sdf_network = SDFNetwork( 27 | d_in=3, 28 | d_out=257, 29 | d_hidden=256, 30 | n_layers=8, 31 | skip_in=[ 32 | 4, 33 | ], 34 | multires=6, 35 | bias=0.5, 36 | scale=1.0, 37 | geometric_init=True, 38 | weight_norm=True, 39 | ).cuda() 40 | color_network = RenderingNetwork( 41 | d_in=9, 42 | d_out=3, 43 | d_feature=256, 44 | d_hidden=256, 45 | n_layers=4, 46 | multires_view=4, 47 | mode="idr", 48 | squeeze_out=True, 49 | ).cuda() 50 | raytracer = RayTracer() 51 | 52 | scene = "dtu_scan69" 53 | ckpt_fpath = f"../exp/{scene}/womask_sphere/checkpoints/ckpt_300000.pth" 54 | 55 | ckpt = torch.load(ckpt_fpath, map_location=torch.device("cuda")) 56 | sdf_network.load_state_dict(ckpt["sdf_network_fine"]) 57 | color_network.load_state_dict(ckpt["color_network_fine"]) 58 | 59 | color_network_dict = {"color_network": color_network} 60 | 61 | 62 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features): 63 | interior_color = color_network_dict["color_network"](points, normals, ray_d, features) # [..., [2, 0, 1]] 64 | 65 | dots_sh = list(interior_mask.shape) 66 | color = torch.zeros( 67 | dots_sh 68 | + [ 69 | 3, 70 | ], 71 | dtype=torch.float32, 72 | device=interior_mask.device, 73 | ) 74 | color[interior_mask] = interior_color 75 | 76 | normals_pad = torch.zeros( 77 | dots_sh 78 | + [ 79 | 3, 80 | ], 81 | dtype=torch.float32, 82 | device=interior_mask.device, 83 | ) 84 | normals_pad[interior_mask] = normals 85 | return {"color": color, "normal": normals_pad} 86 | 87 | 88 | def 
load_datadir(data_dir): 89 | from glob import glob 90 | from models.dataset import load_K_Rt_from_P 91 | 92 | camera_dict = np.load(os.path.join(data_dir, "cameras_sphere.npz")) 93 | images_lis = sorted(glob(os.path.join(data_dir, "image/*.png"))) 94 | n_images = len(images_lis) 95 | images = np.stack([imageio.imread(im_name) for im_name in images_lis]) / 255.0 96 | images = torch.from_numpy(images).float() 97 | # world_mat is a projection matrix from world to image 98 | world_mats_np = [camera_dict["world_mat_%d" % idx].astype(np.float32) for idx in range(n_images)] 99 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin. 100 | scale_mats_np = [camera_dict["scale_mat_%d" % idx].astype(np.float32) for idx in range(n_images)] 101 | intrinsics_all = [] 102 | pose_all = [] 103 | for scale_mat, world_mat in zip(scale_mats_np, world_mats_np): 104 | P = world_mat @ scale_mat 105 | P = P[:3, :4] 106 | intrinsics, pose = load_K_Rt_from_P(None, P) 107 | intrinsics_all.append(torch.from_numpy(intrinsics).float()) 108 | pose_all.append(torch.from_numpy(pose).float()) 109 | intrinsics_all = torch.stack(intrinsics_all, dim=0) 110 | pose_all = torch.stack(pose_all, dim=0) # C2W 111 | pose_all = torch.inverse(pose_all) 112 | 113 | ic(images.shape, intrinsics_all.shape, pose_all.shape) 114 | return images, intrinsics_all, pose_all 115 | 116 | 117 | gt_images, Ks, W2Cs = load_datadir(f"../public_data/{scene}") 118 | 119 | img_idx = 10 120 | gt_color = gt_images[img_idx] 121 | camera = Camera(W=gt_color.shape[1], H=gt_color.shape[0], K=Ks[img_idx].cuda(), W2C=W2Cs[img_idx].cuda()) 122 | 123 | fill_holes = False 124 | handle_edges = True 125 | is_training = False 126 | out_dir = f"./debug_raytracer_{scene}_{fill_holes}_{handle_edges}_{is_training}" 127 | ic(out_dir) 128 | os.makedirs(out_dir, exist_ok=True) 129 | 130 | if is_training: 131 | camera, gt_color = camera.crop_region(trgt_W=256, trgt_H=256, center_crop=True, image=gt_color) 132 | ic(gt_color.shape, camera.H, camera.W) 133 | 134 | results = render_camera( 135 | camera, 136 | sdf_network, 137 | raytracer, 138 | color_network_dict, 139 | render_fn, 140 | fill_holes=fill_holes, 141 | handle_edges=handle_edges, 142 | is_training=is_training, 143 | ) 144 | 145 | for x in list(results.keys()): 146 | results[x] = results[x].detach().cpu().numpy() 147 | 148 | 149 | def append_allones(x): 150 | return np.concatenate((x, np.ones_like(x[..., 0:1])), axis=-1) 151 | 152 | 153 | imageio.imwrite(os.path.join(out_dir, "convergent_mask.png"), to8b(results["convergent_mask"])) 154 | imageio.imwrite(os.path.join(out_dir, "distance.exr"), results["distance"]) 155 | imageio.imwrite(os.path.join(out_dir, "depth.exr"), results["depth"]) 156 | imageio.imwrite(os.path.join(out_dir, "sdf.exr"), results["sdf"]) 157 | imageio.imwrite(os.path.join(out_dir, "points.exr"), results["points"]) 158 | imageio.imwrite(os.path.join(out_dir, "normal.png"), to8b((results["normal"] + 1.0) / 2.0)) 159 | imageio.imwrite(os.path.join(out_dir, "normal.exr"), results["normal"]) 160 | imageio.imwrite(os.path.join(out_dir, "color.png"), to8b(results["color"])[..., ::-1]) 161 | imageio.imwrite(os.path.join(out_dir, "color_gt.png"), to8b(gt_color.detach().cpu().numpy())) 162 | imageio.imwrite(os.path.join(out_dir, "uv.exr"), append_allones(results["uv"])) 163 | 164 | imageio.imwrite(os.path.join(out_dir, "depth_grad_norm.exr"), results["depth_grad_norm"]) 165 | imageio.imwrite(os.path.join(out_dir, "depth_edge_mask.png"), 
to8b(results["depth_edge_mask"])) 166 | imageio.imwrite( 167 | os.path.join(out_dir, "walk_edge_found_mask.png"), 168 | to8b(results["walk_edge_found_mask"]), 169 | ) 170 | trimesh.PointCloud(results["edge_points"].reshape((-1, 3))).export(os.path.join(out_dir, "edge_points.ply")) 171 | imageio.imwrite(os.path.join(out_dir, "edge_mask.png"), to8b(results["edge_mask"])) 172 | imageio.imwrite(os.path.join(out_dir, "edge_pos_side_weight.exr"), results["edge_pos_side_weight"]) 173 | imageio.imwrite(os.path.join(out_dir, "edge_angles.exr"), results["edge_angles"]) 174 | imageio.imwrite(os.path.join(out_dir, "edge_sdf.exr"), results["edge_sdf"]) 175 | imageio.imwrite(os.path.join(out_dir, "edge_pos_side_depth.exr"), results["edge_pos_side_depth"]) 176 | imageio.imwrite(os.path.join(out_dir, "edge_neg_side_depth.exr"), results["edge_neg_side_depth"]) 177 | imageio.imwrite(os.path.join(out_dir, "edge_pos_side_color.png"), to8b(results["edge_pos_side_color"])[..., ::-1]) 178 | imageio.imwrite(os.path.join(out_dir, "edge_neg_side_color.png"), to8b(results["edge_neg_side_color"])[..., ::-1]) 179 | -------------------------------------------------------------------------------- /tests/test_singleview.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from unittest import result 4 | import numpy as np 5 | import torch 6 | import trimesh 7 | import json 8 | import imageio 9 | 10 | imageio.plugins.freeimage.download() 11 | 12 | from icecream import ic 13 | import sys 14 | 15 | sys.path.append("../") 16 | 17 | from models.fields import SDFNetwork 18 | import models.raytracer 19 | 20 | models.raytracer.VERBOSE_MODE = False 21 | from models.raytracer import RayTracer, Camera, render_camera 22 | 23 | 24 | def to8b(x): 25 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8) 26 | 27 | 28 | sdf_network = SDFNetwork( 29 | d_in=3, 30 | d_out=257, 31 | d_hidden=256, 32 | n_layers=8, 33 | skip_in=[ 34 | 4, 35 | ], 36 | multires=6, 37 | bias=0.5, 38 | scale=1.0, 39 | geometric_init=True, 40 | weight_norm=True, 41 | ).cuda() 42 | raytracer = RayTracer() 43 | 44 | color_network_dict = {} 45 | 46 | 47 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features): 48 | dots_sh = list(interior_mask.shape) 49 | color = torch.zeros( 50 | dots_sh 51 | + [ 52 | 3, 53 | ], 54 | dtype=torch.float32, 55 | device=interior_mask.device, 56 | ) 57 | normals_pad = torch.zeros( 58 | dots_sh 59 | + [ 60 | 3, 61 | ], 62 | dtype=torch.float32, 63 | device=interior_mask.device, 64 | ) 65 | if interior_mask.any(): 66 | interior_color = ( 67 | torch.ones_like(points.view(-1, 3)) 68 | * torch.Tensor([[237.0 / 255.0, 61.0 / 255.0, 100.0 / 255.0]]).float().cuda() 69 | ) 70 | interior_color = interior_color.view(list(points.shape)) 71 | color[interior_mask] = interior_color 72 | normals_pad[interior_mask] = normals 73 | 74 | return {"color": color, "normal": normals_pad} 75 | 76 | 77 | gt_color = imageio.imread("./data_singleview/12.png").astype(np.float32) / 255.0 78 | gt_color = torch.from_numpy(gt_color).cuda() 79 | 80 | cam_dict = json.load(open("./data_singleview/cam_dict_norm.json")) 81 | K = torch.from_numpy(np.array(cam_dict["12.png"]["K"]).reshape((4, 4)).astype(np.float32)).cuda() 82 | W2C = torch.from_numpy(np.array(cam_dict["12.png"]["W2C"]).reshape((4, 4)).astype(np.float32)).cuda() 83 | W, H = cam_dict["12.png"]["img_size"] 84 | 85 | camera = Camera(W=W, H=H, K=K, W2C=W2C) 86 | 87 | fill_holes = False 88 | handle_edges = True 89 
| is_training = True 90 | out_dir = f"./debug_singleview_{fill_holes}_{handle_edges}_{is_training}" 91 | ic(out_dir) 92 | os.makedirs(out_dir, exist_ok=True) 93 | 94 | sdf_optimizer = torch.optim.Adam(sdf_network.parameters(), lr=1e-4) 95 | 96 | for global_step in range(15000): 97 | sdf_optimizer.zero_grad() 98 | 99 | camera_crop, gt_color_crop = camera.crop_region(trgt_W=128, trgt_H=128, image=gt_color) 100 | 101 | results = render_camera( 102 | camera_crop, 103 | sdf_network, 104 | raytracer, 105 | color_network_dict, 106 | render_fn, 107 | fill_holes=fill_holes, 108 | handle_edges=handle_edges, 109 | is_training=is_training, 110 | ) 111 | 112 | mask = results["convergent_mask"] 113 | if handle_edges: 114 | # mask = mask | results["edge_mask"] 115 | mask = results["edge_mask"] 116 | 117 | img_loss = torch.Tensor( 118 | [ 119 | 0.0, 120 | ] 121 | ).cuda() 122 | rand_eik_points = torch.empty(camera_crop.H * camera_crop.W // 2, 3).cuda().float().uniform_(-1.0, 1.0) 123 | eik_grad = sdf_network.gradient(rand_eik_points).view(-1, 3) 124 | 125 | if mask.any(): 126 | img_loss = ((results["color"][mask] - gt_color_crop[mask]) ** 2).mean() 127 | interior_normals = results["normal"][mask | results["convergent_mask"]] 128 | eik_grad = torch.cat([eik_grad, interior_normals], dim=0) 129 | if "edge_pos_neg_normal" in results: 130 | eik_grad = torch.cat([eik_grad, results["edge_pos_neg_normal"]], dim=0) 131 | eik_loss = ((eik_grad.norm(dim=-1) - 1) ** 2).mean() 132 | 133 | loss = img_loss + 0.1 * eik_loss 134 | loss.backward() 135 | sdf_optimizer.step() 136 | 137 | if global_step % 200 == 0: 138 | ic(global_step, loss.item(), img_loss.item(), eik_loss.item()) 139 | for x in list(results.keys()): 140 | del results[x] 141 | 142 | camera_resize, gt_color_resize = camera.resize(factor=0.25, image=gt_color) 143 | results = render_camera( 144 | camera_resize, 145 | sdf_network, 146 | raytracer, 147 | color_network_dict, 148 | render_fn, 149 | fill_holes=fill_holes, 150 | handle_edges=handle_edges, 151 | is_training=False, 152 | ) 153 | for x in list(results.keys()): 154 | results[x] = results[x].detach().cpu().numpy() 155 | 156 | gt_color_im = gt_color_resize.detach().cpu().numpy() 157 | color_im = results["color"] 158 | normal = results["normal"] 159 | normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10) 160 | normal_im = (normal + 1.0) / 2.0 161 | edge_mask_im = np.tile(results["edge_mask"][:, :, np.newaxis], (1, 1, 3)) 162 | im = np.concatenate([gt_color_im, color_im, normal_im, edge_mask_im], axis=1) 163 | imageio.imwrite(os.path.join(out_dir, f"logim_{global_step}.png"), to8b(im)) 164 | 165 | torch.save(sdf_network.state_dict(), os.path.join(out_dir, "ckpt.pth")) 166 | -------------------------------------------------------------------------------- /tests/test_viewsynthesis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import trimesh 7 | import json 8 | import imageio 9 | from torch.utils.tensorboard import SummaryWriter 10 | import configargparse 11 | from icecream import ic 12 | import glob 13 | 14 | import sys 15 | 16 | sys.path.append("../") 17 | 18 | from models.fields import SDFNetwork, RenderingNetwork 19 | from models.raytracer import RayTracer, Camera, render_camera 20 | from models.renderer_ggx import GGXColocatedRenderer 21 | from models.image_losses import PyramidL2Loss, ssim_loss_fn 22 | 23 | 24 | def config_parser(): 25 | parser = 
configargparse.ArgumentParser() 26 | parser.add_argument("--data_dir", type=str, default=None, help="input data directory") 27 | parser.add_argument("--out_dir", type=str, default=None, help="output directory") 28 | # parser.add_argument("--neus_ckpt_fpath", type=str, default=None, help="checkpoint to load") 29 | parser.add_argument("--num_iters", type=int, default=100001, help="number of iterations") 30 | # parser.add_argument("--white_specular_albedo", action='store_true', help='force specular albedo to be white') 31 | parser.add_argument("--eik_weight", type=float, default=0.1, help="weight for eikonal loss") 32 | parser.add_argument("--ssim_weight", type=float, default=1.0, help="weight for ssim loss") 33 | parser.add_argument("--roughrange_weight", type=float, default=0.1, help="weight for roughness range loss") 34 | 35 | parser.add_argument("--plot_image_name", type=str, default=None, help="image to plot during training") 36 | parser.add_argument("--no_edgesample", action="store_true", help="whether to disable edge sampling") 37 | 38 | return parser 39 | 40 | 41 | parser = config_parser() 42 | args = parser.parse_args() 43 | ic(args) 44 | 45 | 46 | def to8b(x): 47 | return np.clip(x * 255.0, 0.0, 255.0).astype(np.uint8) 48 | 49 | 50 | ggx_renderer = GGXColocatedRenderer(use_cuda=True) 51 | pyramidl2_loss_fn = PyramidL2Loss(use_cuda=True) 52 | 53 | 54 | def render_fn(interior_mask, color_network_dict, ray_o, ray_d, points, normals, features): 55 | dots_sh = list(interior_mask.shape) 56 | color = torch.zeros( 57 | dots_sh 58 | + [ 59 | 3, 60 | ], 61 | dtype=torch.float32, 62 | device=interior_mask.device, 63 | ) 64 | normals_pad = color.clone() 65 | if interior_mask.any(): 66 | normals = normals / (normals.norm(dim=-1, keepdim=True) + 1e-10) 67 | interior_color = color_network_dict["color_network"](points, normals, ray_d, features) 68 | 69 | color[interior_mask] = interior_color 70 | normals_pad[interior_mask] = normals 71 | 72 | return { 73 | "color": color, 74 | "normal": normals_pad, 75 | } 76 | 77 | 78 | sdf_network = SDFNetwork( 79 | d_in=3, 80 | d_out=257, 81 | d_hidden=256, 82 | n_layers=8, 83 | skip_in=[ 84 | 4, 85 | ], 86 | multires=6, 87 | bias=0.5, 88 | scale=1.0, 89 | geometric_init=True, 90 | weight_norm=True, 91 | ).cuda() 92 | raytracer = RayTracer() 93 | 94 | 95 | color_network_dict = { 96 | "color_network": RenderingNetwork( 97 | d_in=9, 98 | d_out=3, 99 | d_feature=256, 100 | d_hidden=256, 101 | n_layers=8, 102 | multires=10, 103 | multires_view=4, 104 | mode="idr", 105 | squeeze_out=True, 106 | skip_in=(4,), 107 | ).cuda() 108 | } 109 | 110 | sdf_optimizer = torch.optim.Adam(sdf_network.parameters(), lr=1e-5) 111 | color_optimizer_dict = {"color_network": torch.optim.Adam(color_network_dict["color_network"].parameters(), lr=1e-4)} 112 | 113 | 114 | def load_datadir(datadir): 115 | cam_dict = json.load(open(os.path.join(datadir, "cam_dict_norm.json"))) 116 | imgnames = list(cam_dict.keys()) 117 | try: 118 | imgnames = sorted(imgnames, key=lambda x: int(x[:-4])) 119 | except: 120 | imgnames = sorted(imgnames) 121 | 122 | image_fpaths = [] 123 | gt_images = [] 124 | Ks = [] 125 | W2Cs = [] 126 | for x in imgnames: 127 | fpath = os.path.join(datadir, "image", x) 128 | assert fpath[-4:] in [".jpg", ".png"], "must use ldr images as inputs" 129 | im = imageio.imread(fpath).astype(np.float32) / 255.0 130 | K = np.array(cam_dict[x]["K"]).reshape((4, 4)).astype(np.float32) 131 | W2C = np.array(cam_dict[x]["W2C"]).reshape((4, 4)).astype(np.float32) 132 | 133 | 
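        # Per-image entries in cam_dict_norm.json (cf. singleview/cam_dict_norm.json above):
        #   "K":        16 floats, row-major 4x4 intrinsics (fx, fy on the diagonal,
        #               principal point in the third column of the first two rows)
        #   "W2C":      16 floats, row-major 4x4 world-to-camera extrinsics
        #   "img_size": [width, height]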
image_fpaths.append(fpath) 134 | gt_images.append(torch.from_numpy(im)) 135 | Ks.append(torch.from_numpy(K)) 136 | W2Cs.append(torch.from_numpy(W2C)) 137 | gt_images = torch.stack(gt_images, dim=0) 138 | Ks = torch.stack(Ks, dim=0) 139 | W2Cs = torch.stack(W2Cs, dim=0) 140 | return image_fpaths, gt_images, Ks, W2Cs 141 | 142 | 143 | image_fpaths, gt_images, Ks, W2Cs = load_datadir(args.data_dir) 144 | cameras = [ 145 | Camera(W=gt_images[i].shape[1], H=gt_images[i].shape[0], K=Ks[i].cuda(), W2C=W2Cs[i].cuda()) 146 | for i in range(gt_images.shape[0]) 147 | ] 148 | ic(len(image_fpaths), gt_images.shape, Ks.shape, W2Cs.shape, len(cameras)) 149 | 150 | #### load pretrained checkpoints 151 | start_step = -1 152 | ckpt_fpaths = glob.glob(os.path.join(args.out_dir, "ckpt_*.pth")) 153 | if len(ckpt_fpaths) > 0: 154 | path2step = lambda x: int(os.path.basename(x)[len("ckpt_") : -4]) 155 | ckpt_fpaths = sorted(ckpt_fpaths, key=path2step) 156 | ckpt_fpath = ckpt_fpaths[-1] 157 | start_step = path2step(ckpt_fpath) 158 | ic("Reloading from checkpoint: ", ckpt_fpath) 159 | ckpt = torch.load(ckpt_fpath, map_location=torch.device("cuda")) 160 | sdf_network.load_state_dict(ckpt["sdf_network"]) 161 | for x in list(color_network_dict.keys()): 162 | color_network_dict[x].load_state_dict(ckpt[x]) 163 | # logim_names = [os.path.basename(x) for x in glob.glob(os.path.join(args.out_dir, "logim_*.png"))] 164 | # start_step = sorted([int(x[len("logim_") : -4]) for x in logim_names])[-1] 165 | 166 | ic(start_step) 167 | 168 | fill_holes = False 169 | handle_edges = not args.no_edgesample 170 | is_training = True 171 | inv_gamma_gt = False 172 | if inv_gamma_gt: 173 | ic("linearizing ground-truth images using inverse gamma correction") 174 | gt_images = torch.pow(gt_images, 2.2) 175 | 176 | ic(fill_holes, handle_edges, is_training, inv_gamma_gt) 177 | os.makedirs(args.out_dir, exist_ok=True) 178 | writer = SummaryWriter(log_dir=os.path.join(args.out_dir, "logs")) 179 | 180 | 181 | for global_step in tqdm.tqdm(range(start_step + 1, args.num_iters)): 182 | sdf_optimizer.zero_grad() 183 | for x in color_optimizer_dict.keys(): 184 | color_optimizer_dict[x].zero_grad() 185 | 186 | idx = np.random.randint(0, gt_images.shape[0]) 187 | camera_crop, gt_color_crop = cameras[idx].crop_region(trgt_W=128, trgt_H=128, image=gt_images[idx]) 188 | 189 | results = render_camera( 190 | camera_crop, 191 | sdf_network, 192 | raytracer, 193 | color_network_dict, 194 | render_fn, 195 | fill_holes=fill_holes, 196 | handle_edges=handle_edges, 197 | is_training=is_training, 198 | ) 199 | 200 | mask = results["convergent_mask"] 201 | if handle_edges: 202 | mask = mask | results["edge_mask"] 203 | 204 | img_loss = torch.Tensor([0.0]).cuda() 205 | img_l2_loss = torch.Tensor([0.0]).cuda() 206 | img_ssim_loss = torch.Tensor([0.0]).cuda() 207 | 208 | eik_points = torch.empty(camera_crop.H * camera_crop.W // 2, 3).cuda().float().uniform_(-1.0, 1.0) 209 | eik_grad = sdf_network.gradient(eik_points).view(-1, 3) 210 | eik_cnt = eik_grad.shape[0] 211 | eik_loss = ((eik_grad.norm(dim=-1) - 1) ** 2).sum() 212 | if mask.any(): 213 | pred_img = results["color"].permute(2, 0, 1).unsqueeze(0) 214 | gt_img = gt_color_crop.permute(2, 0, 1).unsqueeze(0).to(pred_img.device) 215 | img_l2_loss = pyramidl2_loss_fn(pred_img, gt_img) 216 | img_ssim_loss = args.ssim_weight * ssim_loss_fn(pred_img, gt_img, mask.unsqueeze(0).unsqueeze(0)) 217 | img_loss = img_l2_loss + img_ssim_loss 218 | 219 | eik_grad = results["normal"][mask] 220 | eik_cnt += eik_grad.shape[0] 221 | 
eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum() 222 | if "edge_pos_neg_normal" in results: 223 | eik_grad = results["edge_pos_neg_normal"] 224 | eik_cnt += eik_grad.shape[0] 225 | eik_loss = eik_loss + ((eik_grad.norm(dim=-1) - 1) ** 2).sum() 226 | 227 | eik_loss = eik_loss / eik_cnt * args.eik_weight 228 | 229 | loss = img_loss + eik_loss 230 | loss.backward() 231 | sdf_optimizer.step() 232 | for x in color_optimizer_dict.keys(): 233 | color_optimizer_dict[x].step() 234 | 235 | if global_step % 50 == 0: 236 | writer.add_scalar("loss/loss", loss, global_step) 237 | writer.add_scalar("loss/img_loss", img_loss, global_step) 238 | writer.add_scalar("loss/img_l2_loss", img_l2_loss, global_step) 239 | writer.add_scalar("loss/img_ssim_loss", img_ssim_loss, global_step) 240 | writer.add_scalar("loss/eik_loss", eik_loss, global_step) 241 | 242 | if global_step % 1000 == 0: 243 | torch.save( 244 | dict( 245 | [ 246 | ("sdf_network", sdf_network.state_dict()), 247 | ] 248 | + [(x, color_network_dict[x].state_dict()) for x in color_network_dict.keys()] 249 | ), 250 | os.path.join(args.out_dir, f"ckpt_{global_step}.pth"), 251 | ) 252 | 253 | if global_step % 500 == 0: 254 | ic( 255 | args.out_dir, 256 | global_step, 257 | loss.item(), 258 | img_loss.item(), 259 | img_l2_loss.item(), 260 | img_ssim_loss.item(), 261 | eik_loss.item(), 262 | ) 263 | 264 | for x in list(results.keys()): 265 | del results[x] 266 | 267 | idx = 0 268 | if args.plot_image_name is not None: 269 | while idx < len(image_fpaths): 270 | if args.plot_image_name in image_fpaths[idx]: 271 | break 272 | idx += 1 273 | 274 | camera_resize, gt_color_resize = cameras[idx].resize(factor=0.25, image=gt_images[idx]) 275 | results = render_camera( 276 | camera_resize, 277 | sdf_network, 278 | raytracer, 279 | color_network_dict, 280 | render_fn, 281 | fill_holes=fill_holes, 282 | handle_edges=handle_edges, 283 | is_training=False, 284 | ) 285 | for x in list(results.keys()): 286 | results[x] = results[x].detach().cpu().numpy() 287 | 288 | gt_color_im = gt_color_resize.detach().cpu().numpy() 289 | color_im = results["color"] 290 | normal = results["normal"] 291 | normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-10) 292 | normal_im = (normal + 1.0) / 2.0 293 | edge_mask_im = np.tile(results["edge_mask"][:, :, np.newaxis], (1, 1, 3)) 294 | if inv_gamma_gt: 295 | gt_color_im = np.power(gt_color_im + 1e-6, 1.0 / 2.2) 296 | color_im = np.power(color_im + 1e-6, 1.0 / 2.2) 297 | 298 | im = np.concatenate([gt_color_im, color_im, normal_im, edge_mask_im], axis=1) 299 | imageio.imwrite(os.path.join(args.out_dir, f"logim_{global_step}.png"), to8b(im)) 300 | -------------------------------------------------------------------------------- /train_scene.sh: -------------------------------------------------------------------------------- 1 | SCENE=$1 2 | 3 | python render_volume.py --mode train --conf ./confs/womask_iron.conf --case ${SCENE} 4 | 5 | python render_surface.py --data_dir ./data_flashlight/${SCENE}/train \ 6 | --out_dir ./exp_iron_stage2/${SCENE} \ 7 | --neus_ckpt_fpath ./exp_iron_stage1/${SCENE}/checkpoints/ckpt_100000.pth \ 8 | --num_iters 50001 --gamma_pred 9 | # render test set 10 | python render_surface.py --data_dir ./data_flashlight/${SCENE}/test \ 11 | --out_dir ./exp_iron_stage2/${SCENE} \ 12 | --neus_ckpt_fpath ./exp_iron_stage1/${SCENE}/checkpoints/ckpt_100000.pth \ 13 | --num_iters 50001 --gamma_pred --render_all 14 | --------------------------------------------------------------------------------
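
Note: both test scripts above regularize the SDF with an eikonal term that pushes the network's gradient norm toward 1 at randomly sampled points, optionally concatenating surface/edge normals before taking the penalty. The following is a minimal, self-contained sketch of that term, not a drop-in replacement for the in-loop code; it assumes only a CUDA-enabled module exposing a `gradient(points)` method, as `SDFNetwork` in `models/fields.py` does, and the `extra_grads` argument is a hypothetical stand-in for the interior/edge normals the loops append.

```python
import torch

def eikonal_loss(sdf_network, num_points=4096, bound=1.0, extra_grads=None):
    # Sample random points inside the normalized scene box [-bound, bound]^3,
    # mirroring the uniform sampling used in the training loops.
    pts = torch.empty(num_points, 3, device="cuda").uniform_(-bound, bound)
    grad = sdf_network.gradient(pts).view(-1, 3)
    # Optionally append gradients/normals gathered at surface or edge points,
    # as the loops do with interior normals and "edge_pos_neg_normal".
    if extra_grads is not None:
        grad = torch.cat([grad, extra_grads.view(-1, 3)], dim=0)
    # A signed distance field has unit-norm gradient almost everywhere,
    # so penalize the squared deviation of the gradient norm from 1.
    return ((grad.norm(dim=-1) - 1.0) ** 2).mean()
```

This sketch averages over all sampled points (as in tests/test_singleview.py); tests/test_viewsynthesis.py instead sums the per-point penalties and divides by the running count `eik_cnt` before applying `--eik_weight`.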