├── .env ├── .gitignore ├── .gitmodules ├── DSS ├── __init__.py ├── core │ ├── __init__.py │ ├── camera.py │ ├── cloud.py │ ├── lighting.py │ ├── rasterizer.py │ ├── renderer.py │ └── texture.py ├── csrc │ ├── bitmask.cuh │ ├── cuda_utils.h │ ├── ext.cpp │ ├── macros.hpp │ ├── rasterization_utils.cuh │ ├── rasterize_backward_cuda_kernel.cu │ ├── rasterize_forward_cuda_kernel.cu │ ├── rasterize_points.cu │ ├── rasterize_points.h │ ├── rasterize_points_backward.cu │ ├── rasterize_points_cpu.cpp │ ├── types.hpp │ ├── weighted_sum.cu │ └── weighted_sum.h ├── logger.py ├── misc │ ├── __init__.py │ ├── checkpoints.py │ └── visualize.py ├── models │ ├── __init__.py │ ├── combined_modeling.py │ ├── common.py │ ├── implicit_modeling.py │ ├── levelset_sampling.py │ ├── occupancy_modeling.py │ └── point_modeling.py ├── training │ ├── losses.py │ ├── scheduler.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── dataset.py │ ├── io.py │ ├── mathHelper.py │ ├── point_processing.py │ └── sampler.py ├── README.md ├── common.py ├── config.py ├── environment.yml ├── evaluation.py ├── generate_mvr.py ├── images ├── idr-mvr.png ├── idr-rabbit.png ├── points.png ├── sampling.png ├── siren-pointcloud.png └── siren-synthetic-mvr.png ├── requirements.txt ├── scripts ├── create_mvr_data_from_mesh.py ├── evaluatePointClouds.py ├── filter_dtu_predictions.py ├── gen_denoising_pairs.py └── plot_evaluations.py ├── setup.py ├── test_dtu_points.py ├── tests ├── test_data.py ├── test_dtu_points.py ├── test_dvr_camera.py ├── test_projection.py └── test_uniform_projection.py └── train_mvr.py /.env: -------------------------------------------------------------------------------- 1 | PYTHONPATH="./DSS":${PYTHONPATH} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # C++ build directory 2 | build 3 | 4 | 5 | # Created by https://www.gitignore.io/api/c++,cmake 6 | # Edit at https://www.gitignore.io/?templates=c++,cmake 7 | 8 | ### C++ ### 9 | # Prerequisites 10 | *.d 11 | 12 | # Compiled Object files 13 | *.slo 14 | *.lo 15 | *.o 16 | *.obj 17 | 18 | # Precompiled Headers 19 | *.gch 20 | *.pch 21 | 22 | # Compiled Dynamic libraries 23 | *.so 24 | *.dylib 25 | *.dll 26 | 27 | # Fortran module files 28 | *.mod 29 | *.smod 30 | 31 | # Compiled Static libraries 32 | *.lai 33 | *.la 34 | *.a 35 | *.lib 36 | 37 | # Executables 38 | *.exe 39 | *.out 40 | *.app 41 | 42 | ### CMake ### 43 | CMakeLists.txt.user 44 | CMakeCache.txt 45 | CMakeFiles 46 | CMakeScripts 47 | Testing 48 | Makefile 49 | cmake_install.cmake 50 | install_manifest.txt 51 | compile_commands.json 52 | CTestTestfile.cmake 53 | 54 | # End of https://www.gitignore.io/api/c++,cmake 55 | 56 | 57 | # Created by https://www.gitignore.io/api/vim,linux,macos,python,windows 58 | 59 | ### Linux ### 60 | *~ 61 | 62 | # temporary files which can be created if a process still has a handle open of a deleted file 63 | .fuse_hidden* 64 | 65 | # KDE directory preferences 66 | .directory 67 | 68 | # Linux trash folder which might appear on any partition or disk 69 | .Trash-* 70 | 71 | # .nfs files are created when an open file is removed but is still being accessed 72 | .nfs* 73 | 74 | ### macOS ### 75 | # General 76 | .DS_Store 77 | .AppleDouble 78 | .LSOverride 79 | 80 | # Icon must end with two \r 81 | Icon 82 | 83 | # Thumbnails 84 | ._* 85 | 86 | # Files that might appear in the root of a volume 87 | .DocumentRevisions-V100 88 | .fseventsd 89 | 
.Spotlight-V100 90 | .TemporaryItems 91 | .Trashes 92 | .VolumeIcon.icns 93 | .com.apple.timemachine.donotpresent 94 | 95 | # Directories potentially created on remote AFP share 96 | .AppleDB 97 | .AppleDesktop 98 | Network Trash Folder 99 | Temporary Items 100 | .apdisk 101 | 102 | 103 | ### Python ### 104 | # Byte-compiled / optimized / DLL files 105 | __pycache__/ 106 | *.py[cod] 107 | *$py.class 108 | 109 | # C extensions 110 | *.so 111 | 112 | # Distribution / packaging 113 | .Python 114 | build/ 115 | develop-eggs/ 116 | dist/ 117 | downloads/ 118 | eggs/ 119 | .eggs/ 120 | lib/ 121 | lib64/ 122 | parts/ 123 | sdist/ 124 | var/ 125 | wheels/ 126 | *.egg-info/ 127 | .installed.cfg 128 | *.egg 129 | MANIFEST 130 | 131 | # PyInstaller 132 | # Usually these files are written by a python script from a template 133 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 134 | *.manifest 135 | *.spec 136 | 137 | # Installer logs 138 | pip-log.txt 139 | pip-delete-this-directory.txt 140 | 141 | # Django stuff: 142 | *.log 143 | local_settings.py 144 | db.sqlite3 145 | 146 | # Flask stuff: 147 | instance/ 148 | .webassets-cache 149 | 150 | # Scrapy stuff: 151 | .scrapy 152 | 153 | # Sphinx documentation 154 | docs/_build/ 155 | 156 | # PyBuilder 157 | target/ 158 | 159 | # Jupyter Notebook 160 | .ipynb_checkpoints 161 | 162 | # IPython 163 | profile_default/ 164 | ipython_config.py 165 | 166 | # pyenv 167 | .python-version 168 | 169 | # celery beat schedule file 170 | celerybeat-schedule 171 | 172 | # SageMath parsed files 173 | *.sage.py 174 | 175 | # mypy 176 | .mypy_cache/ 177 | .dmypy.json 178 | dmypy.json 179 | 180 | ### Vim ### 181 | # Swap 182 | [._]*.s[a-v][a-z] 183 | [._]*.sw[a-p] 184 | [._]s[a-rt-v][a-z] 185 | [._]ss[a-gi-z] 186 | [._]sw[a-p] 187 | 188 | # Session 189 | Session.vim 190 | 191 | # Temporary 192 | .netrwhist 193 | # Auto-generated tag files 194 | tags 195 | # Persistent undo 196 | [._]*.un~ 197 | 198 | ### Windows ### 199 | # Windows thumbnail cache files 200 | Thumbs.db 201 | ehthumbs.db 202 | ehthumbs_vista.db 203 | 204 | # Dump file 205 | *.stackdump 206 | 207 | # Folder config file 208 | [Dd]esktop.ini 209 | 210 | # Recycle Bin used on file shares 211 | $RECYCLE.BIN/ 212 | 213 | # End of https://www.gitignore.io/api/vim,linux,macos,python,windows 214 | .vscode 215 | 216 | learn_examples/ 217 | trained_models/** 218 | example_data/pointclouds/ 219 | scripts/renders 220 | renders/ 221 | .mypy_cache/ 222 | tests/outputs 223 | tests/test_inputs 224 | data 225 | exp 226 | configs 227 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/torch-batch-svd"] 2 | path = external/torch-batch-svd 3 | url = https://github.com/KinglittleQ/torch-batch-svd.git 4 | ignore = untracked 5 | [submodule "external/FRNN"] 6 | path = external/FRNN 7 | url = https://github.com/lxxue/FRNN.git 8 | ignore = untracked 9 | -------------------------------------------------------------------------------- /DSS/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .logger import get_logger 4 | from collections import OrderedDict 5 | 6 | logger_py = get_logger(__name__) 7 | 8 | _debug = False 9 | _debugging_tensor = None 10 | 11 | def set_deterministic_(): 12 | torch.manual_seed(0) 13 | torch.backends.cudnn.deterministic = True 14 | 
torch.backends.cudnn.benchmark = False 15 | np.random.seed(0) 16 | 17 | # Each attribute contains list of tensors or dictionaries, where 18 | # each element in the list is a sample in the minibatch. 19 | # If dictionaries are used, then the (keys, tensor) will be used to plot 20 | # debugging visuals separately. 21 | class DebuggingTensor: 22 | __slots__ = ['pts_world', 23 | 'pts_world_grad', 24 | 'img_mask_grad'] 25 | 26 | def __init__(self,): 27 | self.pts_world = OrderedDict() 28 | self.pts_world_grad = OrderedDict() 29 | self.img_mask_grad = OrderedDict() 30 | 31 | 32 | def set_debugging_mode_(is_debug, *args, **kwargs): 33 | global _debugging_tensor, _debug 34 | _debug = is_debug 35 | if _debug: 36 | _debugging_tensor = DebuggingTensor(*args, **kwargs) 37 | logger_py.info('Enabled debugging mode.') 38 | else: 39 | _debugging_tensor = None 40 | 41 | 42 | def get_debugging_mode(): 43 | return _debug 44 | 45 | 46 | def get_debugging_tensor(): 47 | if _debugging_tensor is None: 48 | logger_py.warning( 49 | 'Attempt to get debugging tensor before setting debugging mode to true.') 50 | set_debugging_mode_(True) 51 | return _debugging_tensor 52 | 53 | 54 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 55 | -------------------------------------------------------------------------------- /DSS/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/DSS/core/__init__.py -------------------------------------------------------------------------------- /DSS/core/camera.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch3d.renderer.cameras import (PerspectiveCameras, 3 | look_at_view_transform) 4 | 5 | 6 | class CameraSampler(object): 7 | """ 8 | create camera transformations looking at the origin of the coordinate 9 | from varying distance 10 | 11 | Attributes: 12 | R, T: (num_cams_total, 3, 3) and (num_cams_total, 3) 13 | camera_type (Class): class to create a new camera 14 | camera_params (dict): camera parameters to call camera_type 15 | (besides R, T) 16 | """ 17 | 18 | def __init__(self, num_cams_total, num_cams_batch, 19 | distance_range=(5, 10), sort_distance=True, 20 | return_cams=True, 21 | camera_type=PerspectiveCameras, camera_params=None): 22 | """ 23 | Args: 24 | num_cams_total (int): the total number of cameras to sample 25 | num_cams_batch (int): the number of cameras per iteration 26 | distance_range (tensor or list): (num_cams_total, 2) or (1, 2) 27 | the range of camera distance for uniform sampling 28 | sort_distance: sort the created camera transformations by the 29 | distance in ascending order 30 | return_cams (bool): whether to return camera instances or just the R,T 31 | camera_type (class): camera type from pytorch3d.renderer.cameras 32 | camera_params (dict): camera parameters besides R, T 33 | """ 34 | self.num_cams_batch = num_cams_batch 35 | self.num_cams_total = num_cams_total 36 | 37 | self.sort_distance = sort_distance 38 | self.camera_type = camera_type 39 | self.camera_params = {} if camera_params is None else camera_params 40 | 41 | # create camera locations 42 | distance_scale = distance_range[:, -1] - distance_range[:, 0] 43 | distances = torch.rand(num_cams_total) * distance_scale + \ 44 | distance_range[:, 0] 45 | if sort_distance: 46 | distances, _ = distances.sort(descending=True) 47 | azim = torch.rand(num_cams_total) * 360 - 180 48 | elev 
= torch.rand(num_cams_total) * 180 - 90 49 | at = torch.rand((num_cams_total, 3)) * 0.1 - 0.05 50 | self.R, self.T = look_at_view_transform( 51 | distances, elev, azim, at=at, degrees=True) 52 | 53 | self._idx = 0 54 | 55 | def __len__(self): 56 | return (self.R.shape[0] + self.num_cams_batch - 1) // \ 57 | self.num_cams_batch 58 | 59 | def __iter__(self): 60 | return self 61 | 62 | def __next__(self): 63 | if self._idx >= len(self): 64 | raise StopIteration 65 | start_idx = self._idx * self.num_cams_batch 66 | end_idx = min(start_idx + self.num_cams_batch, self.R.shape[0]) 67 | cameras = self.camera_type(R=self.R[start_idx:end_idx], 68 | T=self.T[start_idx:end_idx], 69 | **self.camera_params) 70 | self._idx += 1 71 | return cameras 72 | -------------------------------------------------------------------------------- /DSS/core/lighting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handle multiple light sources per batch for pytorch3d.renderer.lighting 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from pytorch3d.renderer import lighting, convert_to_tensors_and_broadcast 8 | 9 | 10 | def diffuse(normals, color, direction) -> torch.Tensor: 11 | """ 12 | Calculate the diffuse component of light reflection using Lambert's 13 | cosine law. 14 | 15 | Args: 16 | normals: (N, ..., 3) xyz normal vectors. Normals and points are 17 | expected to have the same shape. 18 | color: (1, L, 3) or (N, L, 3) RGB color of the diffuse component of the light. 19 | direction: (x,y,z) direction of the light 20 | 21 | Returns: 22 | colors: (N, ..., 3), same shape as the input points. 23 | 24 | The normals and light direction should be in the same coordinate frame 25 | i.e. if the points have been transformed from world -> view space then 26 | the normals and direction should also be in view space. 27 | 28 | NOTE: to use with the packed vertices (i.e. no batch dimension) reformat the 29 | inputs in the following way. 30 | 31 | .. code-block:: python 32 | 33 | Args: 34 | normals: (P, 3) 35 | color: (N, L, 3)[batch_idx, :] -> (P, L, 3) 36 | direction: (N, L, 3)[batch_idx, :] -> (P, L, 3) 37 | 38 | Returns: 39 | colors: (P, 3) 40 | 41 | where batch_idx is of shape (P). For meshes, batch_idx can be: 42 | meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx() 43 | depending on whether points refers to the vertex coordinates or 44 | average/interpolated face coordinates. 45 | """ 46 | # TODO: handle attentuation. 47 | # Ensure color and location have same batch dimension as normals 48 | # (N,3) (1,L,3) 49 | normals, color, direction = convert_to_tensors_and_broadcast( 50 | normals, color, direction, device=normals.device 51 | ) 52 | 53 | # Ensure the same number of light color and light direction 54 | num_lights_per_batch = color.shape[1] 55 | assert(direction.shape[1] == num_lights_per_batch), \ 56 | "color and direction must have the length on dimension (1), {} != {}".format( 57 | color.shape, direction.shape 58 | ) 59 | 60 | normals = normals[:, None, ...] 61 | # Reshape direction and color so they have all the arbitrary intermediate 62 | # dimensions as normals. Assume first dim = batch dim, seconde dim = light dim 63 | # and last dim = 3. 
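    # (editor note) Shape bookkeeping for the view() below, e.g. with
    # normals of shape (N=2, P=1024, 3): after the unsqueeze above normals is
    # (2, 1, 1024, 3), points_dims is (1024,), and a (2, L, 3) direction or
    # color is reshaped to (2, L, 1, 3) so the elementwise product broadcasts
    # over the points dimension; the light dimension L is summed out at the end.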
64 | points_dims = normals.shape[2:-1] 65 | expand_dims = (-1, num_lights_per_batch) + (1,) * len(points_dims) + (3,) 66 | if direction.shape[2: -1] != normals.shape[2: -1]: 67 | direction = direction.view(expand_dims) 68 | if color.shape[2: -1] != normals.shape[2: -1]: 69 | color = color.view(expand_dims) 70 | 71 | # Renormalize the normals in case they have been interpolated. 72 | normals = F.normalize(normals, p=2, dim=-1, eps=1e-6) 73 | direction = F.normalize(direction, p=2, dim=-1, eps=1e-6) 74 | angle = F.relu(torch.sum(normals * direction, dim=-1)) 75 | 76 | # Sum light sources 77 | acc_color = torch.sum(color * angle[..., None], dim=1) 78 | return acc_color 79 | 80 | 81 | def specular( 82 | points, normals, direction, color, camera_position, shininess 83 | ) -> torch.Tensor: 84 | """ 85 | Calculate the specular component of light reflection. 86 | 87 | Args: 88 | points: (N, ..., 3) xyz coordinates of the points. 89 | normals: (N, ..., 3) xyz normal vectors for each point. 90 | color: (N, L, 3) RGB color of the specular component of the light. 91 | direction: (N, L, 3) vector direction of the light. 92 | camera_position: (N, 3) The xyz position of the camera. 93 | shininess: (N) The specular exponent of the material. 94 | 95 | Returns: 96 | colors: (N, ..., 3), same shape as the input points. 97 | 98 | The points, normals, camera_position, and direction should be in the same 99 | coordinate frame i.e. if the points have been transformed from 100 | world -> view space then the normals, camera_position, and light direction 101 | should also be in view space. 102 | 103 | To use with a batch of packed points reindex in the following way. 104 | .. code-block:: python:: 105 | 106 | Args: 107 | points: (P, 3) 108 | normals: (P, 3) 109 | color: (N, L, 3)[batch_idx] -> (P, 3) 110 | direction: (N, L, 3)[batch_idx] -> (P, 3) 111 | camera_position: (N, 3)[batch_idx] -> (P, 3) 112 | shininess: (N)[batch_idx] -> (P) 113 | Returns: 114 | colors: (P, 3) 115 | 116 | where batch_idx is of shape (P). For meshes batch_idx can be: 117 | meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx(). 118 | """ 119 | # TODO: attentuate based on inverse squared distance to the light source 120 | if points.shape != normals.shape: 121 | msg = "Expected points and normals to have the same shape: got %r, %r" 122 | raise ValueError(msg % (points.shape, normals.shape)) 123 | 124 | # Ensure all inputs have same batch dimension as points 125 | matched_tensors = convert_to_tensors_and_broadcast( 126 | points, color, direction, camera_position, shininess, device=points.device 127 | ) 128 | _, color, direction, camera_position, shininess = matched_tensors 129 | 130 | # Ensure the same number of light color and light direction 131 | num_lights_per_batch = color.shape[1] 132 | assert(direction.shape[1] == num_lights_per_batch), \ 133 | "color and direction must have the length on dimension (1), {} != {}".format( 134 | color.shape, direction.shape 135 | ) 136 | batch_size = color.shape[0] 137 | points = points[:, None, ...] 138 | normals = normals[:, None, ...] 139 | camera_position = camera_position[:, None, ...] 140 | # Reshape direction and color so they have all the arbitrary intermediate 141 | # dimensions as points. Assume first dim = batch dim, seconde dim = light dim 142 | # and last dim = 3. 
143 | points_dims = normals.shape[2:-1] 144 | expand_dims = (-1,) + (1,) * len(points_dims) 145 | if direction.shape[2: -1] != normals.shape[2: -1]: 146 | direction = direction.view((batch_size,) + expand_dims + (3,)) 147 | if color.shape[2: -1] != normals.shape[2: -1]: 148 | color = color.view((batch_size,) + expand_dims + (3,)) 149 | if camera_position.shape != normals.shape: 150 | camera_position = camera_position.view( 151 | (batch_size,) + expand_dims + (3,)) 152 | if shininess.shape != normals.shape: 153 | shininess = shininess.view((batch_size,) + expand_dims) 154 | 155 | # Renormalize the normals in case they have been interpolated. 156 | normals = F.normalize(normals, p=2, dim=-1, eps=1e-6) 157 | direction = F.normalize(direction, p=2, dim=-1, eps=1e-6) 158 | cos_angle = torch.sum(normals * direction, dim=-1) 159 | # No specular highlights if angle is less than 0. 160 | mask = (cos_angle > 0).to(torch.float32) 161 | 162 | # Calculate the specular reflection. 163 | view_direction = camera_position - points 164 | view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6) 165 | reflect_direction = -direction + 2 * (cos_angle[..., None] * normals) 166 | 167 | # Cosine of the angle between the reflected light ray and the viewer 168 | alpha = F.relu(torch.sum(view_direction * reflect_direction, 169 | dim=-1)) * mask 170 | 171 | acc_color = torch.sum(color * torch.pow(alpha, shininess)[..., None], 172 | dim=1) 173 | return acc_color 174 | 175 | 176 | class DirectionalLights(lighting.DirectionalLights): 177 | 178 | def __init__( 179 | self, 180 | ambient_color=(((0.5, 0.5, 0.5), ), ), 181 | diffuse_color=(((0.3, 0.3, 0.3), ), ), 182 | specular_color=(((0.2, 0.2, 0.2), ), ), 183 | direction=(((0, 1, 0), ), ), 184 | device: str = "cpu", **kwargs 185 | ): 186 | """ 187 | Args: 188 | ambient_color: RGB color of the ambient component. 189 | diffuse_color: RGB color of the diffuse component. 190 | specular_color: RGB color of the specular component. 191 | direction: (x, y, z) direction vector of the light. 192 | device: torch.device on which the tensors should be located 193 | 194 | The inputs can each be 195 | - 3 element tuple/list or list of lists 196 | - torch tensor of shape (1, 1, 3) 197 | - torch tensor of shape (1, L, 3) 198 | - torch tensor of shape (N, L, 3) 199 | The inputs are broadcast against each other so they all have batch 200 | dimension N. 201 | """ 202 | super().__init__( 203 | device=device, 204 | ambient_color=ambient_color, 205 | diffuse_color=diffuse_color, 206 | specular_color=specular_color, 207 | direction=direction, 208 | ) 209 | # check diffuse_color, specular_color and direction 210 | for prop in ('diffuse_color', 'specular_color', 'direction'): 211 | if getattr(self, prop).dim() != 3: 212 | raise ValueError("{} must be (N,L,3) tensor, got {}".format( 213 | prop, repr(getattr(self, prop).shape))) 214 | 215 | def diffuse(self, normals, points=None) -> torch.Tensor: 216 | # NOTE: Points is not used but is kept in the args so that the API is 217 | # the same for directional and point lights. The call sites should not 218 | # need to know the light type. 
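        # (editor sketch, not part of the original code) Multiple lights per
        # batch element are expressed through the L dimension, e.g.
        #   lights = DirectionalLights(
        #       diffuse_color=(((0.3, 0.3, 0.3), (0.1, 0.1, 0.1)),),  # (1, 2, 3)
        #       direction=(((0.0, 1.0, 0.0), (1.0, 0.0, 0.0)),))      # (1, 2, 3)
        #   lights.diffuse(normals)  # sums the Lambertian term of both lights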
219 | return diffuse( 220 | normals=normals, color=self.diffuse_color, direction=self.direction 221 | ) 222 | 223 | def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: 224 | return specular( 225 | points=points, 226 | normals=normals, 227 | color=self.specular_color, 228 | direction=self.direction, 229 | camera_position=camera_position, 230 | shininess=shininess, 231 | ) 232 | 233 | 234 | class PointLights(lighting.PointLights): 235 | def __init__( 236 | self, 237 | ambient_color=(((0.5, 0.5, 0.5), ), ), 238 | diffuse_color=(((0.3, 0.3, 0.3), ), ), 239 | specular_color=(((0.2, 0.2, 0.2), ), ), 240 | location=(((0, 1, 0), ), ), 241 | device: str = "cpu", **kwargs 242 | ): 243 | """ 244 | Args: 245 | ambient_color: RGB color of the ambient component 246 | diffuse_color: RGB color of the diffuse component 247 | specular_color: RGB color of the specular component 248 | location: xyz position of the light. 249 | device: torch.device on which the tensors should be located 250 | 251 | The inputs can each be 252 | - 3 element tuple/list or list of lists 253 | - torch tensor of shape (1, L, 3) 254 | - torch tensor of shape (N, L, 3) 255 | The inputs are broadcast against each other so they all have batch 256 | dimension N. 257 | """ 258 | super().__init__( 259 | device=device, 260 | ambient_color=ambient_color, 261 | diffuse_color=diffuse_color, 262 | specular_color=specular_color, 263 | location=location, 264 | ) 265 | 266 | def diffuse(self, normals, points) -> torch.Tensor: 267 | location, points = convert_to_tensors_and_broadcast( 268 | self.location, points, device=self.device) 269 | batch, L = location.shape[:2] 270 | location = location 271 | 272 | location = location.view( 273 | (batch, L) + (1,) * (points.ndim - 2) + (3,)) 274 | 275 | direction = location - points.unsqueeze(1) 276 | return diffuse(normals=normals, color=self.diffuse_color, direction=direction) 277 | 278 | def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: 279 | """ 280 | Args: 281 | points (N,*,3) 282 | normals (N,*,3) 283 | camera_position (N,3) or (1,3) 284 | shininess 285 | """ 286 | location, points = convert_to_tensors_and_broadcast( 287 | self.location, points, device=self.device) 288 | batch, L = location.shape[:2] 289 | location = location 290 | 291 | location = location.view( 292 | (batch, L) + (1,) * (points.ndim - 2) + (3,)) 293 | 294 | direction = location - points.unsqueeze(1) 295 | return specular( 296 | points=points, 297 | normals=normals, 298 | color=self.specular_color, 299 | direction=direction, 300 | camera_position=camera_position, 301 | shininess=shininess, 302 | ) 303 | -------------------------------------------------------------------------------- /DSS/core/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch3d.renderer import PointsRenderer, NormWeightedCompositor 3 | from pytorch3d.renderer.compositing import weighted_sum 4 | from .. 
import logger_py 5 | 6 | 7 | __all__ = ['SurfaceSplattingRenderer'] 8 | 9 | """ 10 | Returns a 4-Channel image for RGBA 11 | """ 12 | 13 | 14 | class SurfaceSplattingRenderer(PointsRenderer): 15 | 16 | def __init__(self, rasterizer, compositor, antialiasing_sigma: float = 1.0, 17 | density: float = 1e-4, frnn_radius=-1): 18 | super().__init__(rasterizer, compositor) 19 | 20 | self.cameras = self.rasterizer.cameras 21 | self._Vrk_h = None 22 | # screen space low pass filter 23 | self.antialiasing_sigma = antialiasing_sigma 24 | # average of squared distance to the nearest neighbors 25 | self.density = density 26 | 27 | if self.compositor is None: 28 | logger_py.info('Composite with weighted sum.') 29 | elif not isinstance(self.compositor, NormWeightedCompositor): 30 | logger_py.warning('Expect a NormWeightedCompositor, but initialized with {}'.format( 31 | self.compositor.__class__.__name__)) 32 | 33 | self.frnn_radius = frnn_radius 34 | # logger_py.error("frnn_radius: {}".format(frnn_radius)) 35 | 36 | def forward(self, point_clouds, **kwargs) -> torch.Tensor: 37 | """ 38 | point_clouds_filter: used to get activation mask and update visibility mask 39 | cutoff_threshold 40 | """ 41 | if point_clouds.isempty(): 42 | return None 43 | 44 | # rasterize 45 | fragments = kwargs.get('fragments', None) 46 | if fragments is None: 47 | if kwargs.get('verbose', False): 48 | fragments, point_clouds, per_point_info = self.rasterizer(point_clouds, **kwargs) 49 | else: 50 | fragments, point_clouds = self.rasterizer(point_clouds, **kwargs) 51 | 52 | # compute weight: scalar*exp(-0.5Q) 53 | weights = torch.exp(-0.5 * fragments.qvalue) * fragments.scaler 54 | weights = weights.permute(0, 3, 1, 2) 55 | 56 | # from fragments to rgba 57 | pts_rgb = point_clouds.features_packed()[:, :3] 58 | 59 | if self.compositor is None: 60 | # NOTE: weight _splat_points_weights_backward, weighted sum will return 61 | # zero gradient for the weights. 62 | images = weighted_sum(fragments.idx.long().permute(0, 3, 1, 2), 63 | weights, 64 | pts_rgb.permute(1, 0), 65 | **kwargs) 66 | else: 67 | images = self.compositor( 68 | fragments.idx.long().permute(0, 3, 1, 2), 69 | weights, 70 | pts_rgb.permute(1, 0), 71 | **kwargs 72 | ) 73 | 74 | # permute so image comes at the end 75 | images = images.permute(0, 2, 3, 1) 76 | mask = fragments.occupancy 77 | 78 | images = torch.cat([images, mask.unsqueeze(-1)], dim=-1) 79 | 80 | if kwargs.get('verbose', False): 81 | return images, fragments 82 | return images 83 | -------------------------------------------------------------------------------- /DSS/core/texture.py: -------------------------------------------------------------------------------- 1 | """ 2 | PointTexture class 3 | 4 | Inputs should be fragments (including point location, 5 | normals and other features) 6 | Output is the color per point (doesn't have the blending step) 7 | 8 | diffuse shader 9 | specular shader 10 | neural shader 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from pytorch3d.renderer.cameras import OrthographicCameras 16 | from .lighting import DirectionalLights 17 | from .cloud import PointClouds3D 18 | from .. import logger_py 19 | from ..utils import gather_batch_to_packed 20 | 21 | 22 | __all__ = ["LightingTexture", "NeuralTexture"] 23 | 24 | 25 | def apply_lighting(points, normals, lights, cameras, 26 | specular=True, shininess=64): 27 | """ 28 | Args: 29 | points: torch tensor of shape (N, P, 3) or (P, 3). 
30 | normals: torch tensor of shape (N, P, 3) or (P, 3) 31 | lights: instance of the Lights class. 32 | cameras: instance of the Cameras class. 33 | shininess: scalar for the specular coefficient. 34 | specular: (bool) whether to add the specular effect 35 | 36 | Returns: 37 | ambient_color: same shape as materials.ambient_color 38 | diffuse_color: same shape as the input points 39 | specular_color: same shape as the input points 40 | """ 41 | light_diffuse = lights.diffuse(normals=normals, points=points) 42 | light_specular = lights.specular( 43 | normals=normals, 44 | points=points, 45 | camera_position=cameras.get_camera_center(), 46 | shininess=shininess, 47 | ) 48 | ambient_color = lights.ambient_color 49 | if ambient_color.ndim==3: 50 | if ambient_color.shape[1] > 1: 51 | logger_py.warn('Found multiple ambient colors') 52 | ambient_color = torch.sum(ambient_color, dim=1) 53 | diffuse_color = light_diffuse 54 | specular_color = light_specular 55 | if normals.dim() == 2 and points.dim() == 2: 56 | # If given packed inputs remove batch dim in output. 57 | return ( 58 | ambient_color.squeeze(0), 59 | diffuse_color.squeeze(0), 60 | specular_color.squeeze(0), 61 | ) 62 | return ambient_color, diffuse_color, specular_color 63 | 64 | 65 | class LightingTexture(nn.Module): 66 | def __init__(self, device="cpu", 67 | cameras=None, lights=None, materials=None): 68 | super().__init__() 69 | self.lights = lights 70 | self.cameras = cameras 71 | if materials is not None: 72 | logger_py.warning("Material is not supported, ignored.") 73 | 74 | def forward(self, pointclouds, shininess=64, **kwargs) -> PointClouds3D: 75 | """ 76 | Args: 77 | pointclouds (Pointclouds3D) 78 | points_rgb (P, 3): same shape as the packed features 79 | Returns: 80 | pointclouds (Pointclouds3D) with features set to RGB colors 81 | """ 82 | if pointclouds.isempty(): 83 | return pointclouds 84 | 85 | lights = kwargs.get("lights", self.lights).to(pointclouds.device) 86 | cameras = kwargs.get("cameras", self.cameras).to(pointclouds.device) 87 | if len(cameras) != len(pointclouds) and len(pointclouds) == 1: 88 | pointclouds = pointclouds.extend(len(cameras)) 89 | points = pointclouds.points_packed() 90 | point_normals = pointclouds.normals_packed() 91 | points_rgb = kwargs.get("points_rgb", None) 92 | if points_rgb is None: 93 | try: 94 | points_rgb = pointclouds.features_packed()[:, :3] 95 | except: 96 | points_rgb = torch.ones_like(points) 97 | 98 | if point_normals is None: 99 | logger_py.warning("Point normals are required, " 100 | "but not available in pointclouds. " 101 | "Using estimated normals instead.") 102 | 103 | vert_to_cloud_idx = pointclouds.packed_to_cloud_idx() 104 | if points_rgb.shape[-1] != 3: 105 | raise ValueError("Expected points_rgb to be 3-channel," 106 | "got {}".format(points_rgb.shape)) 107 | 108 | # Format properties of lights and materials so they are compatible 109 | # with the packed representation of the vertices. This transforms 110 | # all tensor properties in the class from shape (N, ...) -> (V, ...) where 111 | # V is the number of packed vertices. If the number of meshes in the 112 | # batch is one then this is not necessary. 
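        # (editor note) gather_props comes from pytorch3d's TensorProperties:
        # indexing with vert_to_cloud_idx (one cloud index per packed point)
        # repeats each cloud's light/camera properties per point, e.g. two
        # clouds with 100 and 200 points give a (300,) index of 0s and 1s and
        # gathered properties of shape (300, ...).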
113 | if len(pointclouds) > 1: 114 | lights = lights.clone().gather_props(vert_to_cloud_idx) 115 | cameras = cameras.clone().gather_props(vert_to_cloud_idx) 116 | 117 | # Calculate the illumination at each point 118 | ambient, diffuse, specular = apply_lighting( 119 | points, point_normals, lights, cameras, 120 | shininess=shininess, 121 | ) 122 | points_colors_shaded = points_rgb * (ambient + diffuse) + specular 123 | 124 | pointclouds_colored = pointclouds.clone() 125 | pointclouds_colored.update_features_(points_colors_shaded) 126 | 127 | return pointclouds_colored 128 | 129 | 130 | class NeuralTexture(nn.Module): 131 | def __init__(self, decoder, view_dependent=True): 132 | super().__init__() 133 | self.view_dependent = view_dependent 134 | self.decoder = decoder 135 | 136 | def forward(self, pointclouds: PointClouds3D, c=None, **kwargs) -> PointClouds3D: 137 | if self.decoder.dim == 3 and not self.view_dependent: 138 | x = pointclouds.points_packed() 139 | else: 140 | x = pointclouds.normals_packed() 141 | assert(x is not None) 142 | # x = F.normalize(x, dim=-1, eps=1e-15) 143 | p = pointclouds.points_packed() 144 | x = torch.cat([x,p], dim=-1) 145 | if self.view_dependent: 146 | cameras = kwargs.get('cameras', None) 147 | if cameras is not None: 148 | cameras = cameras.to(pointclouds.device) 149 | cam_pos = cameras.get_camera_center() 150 | cam_pos = gather_batch_to_packed( 151 | cam_pos, pointclouds.packed_to_cloud_idx()) 152 | view_direction = p[...,:3].detach() - cam_pos 153 | view_direction = F.normalize(view_direction, dim=-1) 154 | if hasattr(self.decoder, 'embed_fn') and self.decoder.embed_fn is not None: 155 | view_direction = self.decoder.embed_fn(view_direction) 156 | x = torch.cat([x, view_direction], dim=-1) 157 | 158 | 159 | output = self.decoder(x, c=c, **kwargs) 160 | pointclouds_colored = pointclouds.clone() 161 | pointclouds_colored.update_features_(output.rgb) 162 | return pointclouds_colored 163 | -------------------------------------------------------------------------------- /DSS/csrc/bitmask.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #pragma once 4 | #define BINMASK_H 5 | 6 | // A BitMask represents a bool array of shape (H, W, N). We pack values into 7 | // the bits of unsigned ints; a single unsigned int has B = 32 bits, so to hold 8 | // all values we use H * W * (N / B) = H * W * D values. We want to store 9 | // BitMasks in shared memory, so we assume that the memory has already been 10 | // allocated for it elsewhere. 11 | class BitMask { 12 | public: 13 | __device__ BitMask(unsigned int* data, int H, int W, int N) 14 | : data(data), H(H), W(W), B(8 * sizeof(unsigned int)), D(N / B) { 15 | // TODO: check if the data is null. 
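    // (editor note) D was already fixed to N / B in the initializer list, so
    // the reassignment of N on the next line does not change the allocated
    // depth; the mask therefore only holds floor(N / 32) * 32 bits per pixel
    // unless callers pass an N that is a multiple of 32.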
16 | N = ceilf(N % 32); // take ceil incase N % 32 != 0 17 | block_clear(); // clear the data 18 | } 19 | 20 | // Use all threads in the current block to clear all bits of this BitMask 21 | __device__ void block_clear() { 22 | for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) { 23 | data[i] = 0; 24 | } 25 | __syncthreads(); 26 | } 27 | 28 | __device__ int _get_elem_idx(int y, int x, int d) { 29 | return y * W * D + x * D + d / B; 30 | } 31 | 32 | __device__ int _get_bit_idx(int d) { 33 | return d % B; 34 | } 35 | 36 | // Turn on a single bit (y, x, d) 37 | __device__ void set(int y, int x, int d) { 38 | int elem_idx = _get_elem_idx(y, x, d); 39 | int bit_idx = _get_bit_idx(d); 40 | const unsigned int mask = 1U << bit_idx; 41 | atomicOr(data + elem_idx, mask); 42 | } 43 | 44 | // Turn off a single bit (y, x, d) 45 | __device__ void unset(int y, int x, int d) { 46 | int elem_idx = _get_elem_idx(y, x, d); 47 | int bit_idx = _get_bit_idx(d); 48 | const unsigned int mask = ~(1U << bit_idx); 49 | atomicAnd(data + elem_idx, mask); 50 | } 51 | 52 | // Check whether the bit (y, x, d) is on or off 53 | __device__ bool get(int y, int x, int d) { 54 | int elem_idx = _get_elem_idx(y, x, d); 55 | int bit_idx = _get_bit_idx(d); 56 | return (data[elem_idx] >> bit_idx) & 1U; 57 | } 58 | 59 | // Compute the number of bits set in the row (y, x, :) 60 | __device__ int count(int y, int x) { 61 | int total = 0; 62 | for (int i = 0; i < D; ++i) { 63 | int elem_idx = y * W * D + x * D + i; 64 | unsigned int elem = data[elem_idx]; 65 | total += __popc(elem); 66 | } 67 | return total; 68 | } 69 | 70 | private: 71 | unsigned int* data; 72 | int H, W, B, D; 73 | }; 74 | -------------------------------------------------------------------------------- /DSS/csrc/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | #include 4 | #include 5 | 6 | #define TOTAL_THREADS 512 7 | 8 | inline int opt_n_threads(int work_size) { 9 | // round work_size to power of 2 between 512 and 1 10 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 11 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 12 | } 13 | 14 | inline dim3 opt_block_config(int x, int y) { 15 | const int x_threads = opt_n_threads(x); 16 | const int y_threads = 17 | std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 18 | dim3 block_config(x_threads, y_threads, 1); 19 | 20 | return block_config; 21 | } 22 | 23 | #define CUDA_CHECK_ERRORS() \ 24 | do { \ 25 | cudaError_t err = cudaGetLastError(); \ 26 | if (cudaSuccess != err) { \ 27 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 28 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 29 | __FILE__); \ 30 | exit(-1); \ 31 | } \ 32 | } while (0) 33 | #endif 34 | -------------------------------------------------------------------------------- /DSS/csrc/ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "rasterize_points.h" 3 | 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // module docstring 7 | m.doc() = "pybind11 compute_visibility_maps plugin"; 8 | m.def("splat_points", &RasterizePoints); 9 | m.def("_splat_points_naive", &RasterizePointsNaive); 10 | m.def("_splat_points_occ_backward", &RasterizePointsOccBackward); 11 | m.def("_rasterize_coarse", &RasterizePointsCoarse); 12 | m.def("_rasterize_fine", &RasterizePointsFine); 13 | #ifdef WITH_CUDA 14 | 
m.def("_splat_points_occ_fast_cuda_backward", &RasterizePointsBackwardCudaFast); 15 | #endif 16 | m.def("_splat_points_occ_backward", &RasterizePointsOccBackward); 17 | m.def("_backward_zbuf", &RasterizeZbufBackward); 18 | } 19 | -------------------------------------------------------------------------------- /DSS/csrc/macros.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #define CHECK_INPUT(x) \ 5 | TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \ 6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); -------------------------------------------------------------------------------- /DSS/csrc/rasterization_utils.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #pragma once 4 | 5 | // Given a pixel coordinate 0 <= i < S, convert it to a normalized device 6 | // coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized 7 | // pixels, and assume that each pixel falls in the *center* of its range. 8 | __device__ inline float PixToNdc(int i, int S) { 9 | // NDC x-offset + (i * pixel_width + half_pixel_width) 10 | return -1 + (2 * i + 1.0f) / S; 11 | } 12 | 13 | // The maximum number of points per pixel that we can return. Since we use 14 | // thread-local arrays to hold and sort points, the maximum size of the array 15 | // needs to be known at compile time. There might be some fancy template magic 16 | // we could use to make this more dynamic, but for now just fix a constant. 17 | // TODO: is 8 enough? Would increasing have performance considerations? 18 | const int32_t kMaxPointsPerPixel = 150; 19 | 20 | const int32_t kMaxFacesPerBin = 22; 21 | 22 | template 23 | __device__ inline void BubbleSort(T* arr, int n) { 24 | // Bubble sort. We only use it for tiny thread-local arrays (n < 8); in this 25 | // regime we care more about warp divergence than computational complexity. 26 | for (int i = 0; i < n - 1; ++i) { 27 | for (int j = 0; j < n - i - 1; ++j) { 28 | if (arr[j + 1] < arr[j]) { 29 | T temp = arr[j]; 30 | arr[j] = arr[j + 1]; 31 | arr[j + 1] = temp; 32 | } 33 | } 34 | } 35 | } 36 | 37 | template 38 | __device__ inline T eps_denom(const T denom, const T eps) 39 | { 40 | int denom_sign = (T(0.0) < denom) - (denom < T(0.0)); 41 | T safe_denom = T(denom_sign) * max(abs(denom), eps); 42 | return safe_denom; 43 | } 44 | template 45 | __device__ inline T clamp(const T x, const T a, const T b) 46 | { 47 | return max(a, min(b, x)); 48 | } -------------------------------------------------------------------------------- /DSS/csrc/rasterize_forward_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "cuda_utils.h" 2 | #include "macros.hpp" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | /* 12 | Given point cloud's screen position, local depth values and local 13 | filter values rho, 14 | output a PxHxW point index map indicating which point contributes 15 | to which pixel by how much. 
16 | (pytorch tensors are row-major) 17 | */ 18 | template 19 | __global__ void 20 | gather_maps_kernel(int batchSize, int numPoint, int imgWidth, int imgHeight, 21 | int topK, int C, 22 | const indice_t *__restrict__ indices, // BxHxWxtopK 23 | const scalar_t *__restrict__ data, // BxNxC 24 | const scalar_t defaultValue, 25 | scalar_t *output) // BxHxWxtopKxC 26 | { 27 | const int numPixels = imgWidth * imgHeight; 28 | // loop all pixels 29 | for (int b = blockIdx.x; b < batchSize; b += gridDim.x) { 30 | for (int p = threadIdx.x + blockIdx.y * blockDim.x; p < numPixels; 31 | p += blockDim.x * gridDim.y) { 32 | const int pixID = b * numPixels + p; 33 | // loop over topK dimension 34 | for (int i = 0; i < topK; i++) { 35 | const indice_t pid = indices[pixID * topK + i]; 36 | for (int c = 0; c < C; c++) { 37 | // dereference point from the N dimension of data 38 | if (pid < 0 || pid > numPoint) 39 | output[pixID * topK * C + i * C + c] = defaultValue; 40 | else 41 | output[pixID * topK * C + i * C + c] = 42 | data[(b * numPoint + pid) * C + c]; 43 | } 44 | } 45 | } 46 | } 47 | } 48 | /* put gradient BxHxWxKxC to the correct position in BxNxC */ 49 | template 50 | __global__ void 51 | scatter_maps_kernel(int batchSize, int numPoint, int imgWidth, int imgHeight, 52 | int topK, int C, 53 | const scalar_t *__restrict__ outGrad, // BxHxWxtopKxC 54 | const indice_t *__restrict__ indices, // BxHxWxtopK 55 | scalar_t *__restrict__ dataGrad) // BxNxC 56 | { 57 | // const int numPixels = imgWidth * imgHeight; 58 | // loop all points 59 | // TODO instead of looping over points then pixels, loop over pixels only and write 60 | // in points with cudasync? 61 | for (int b = blockIdx.x; b < batchSize; b += gridDim.x) { 62 | for (indice_t p = threadIdx.x + blockIdx.y * blockDim.x; p < numPoint; 63 | p += blockDim.x * gridDim.y) { 64 | // loop over all pixels 65 | for (int i = 0; i < imgHeight; i++) { 66 | for (int j = 0; j < imgWidth; j++) { 67 | int pixID = b * imgWidth * imgHeight + i * imgWidth + j; 68 | const indice_t *iOffset = indices + pixID * topK; 69 | for (int k = 0; k < topK; k++) { 70 | indice_t pid = iOffset[k]; 71 | if (pid == p) { 72 | for (int c = 0; c < C; c++) { 73 | dataGrad[(b * numPoint + pid) * C + c] += 74 | outGrad[(pixID * topK + k) * C + c]; 75 | } 76 | } 77 | } 78 | } 79 | } 80 | } 81 | } 82 | } 83 | /* put gradient BxHxWxKxC to the correct position in BxNxC */ 84 | template 85 | __global__ void 86 | guided_scatter_maps_kernel(int batchSize, int numPoint, int imgWidth, 87 | int imgHeight, int topK, int C, 88 | const scalar_t *__restrict__ outGrad, // BxHxWxtopKxC 89 | const indice_t *__restrict__ indices, // BxHxWxtopK 90 | const indice_t *__restrict__ boundingBoxes, // BxNx4 91 | scalar_t *__restrict__ dataGrad) // BxNxC 92 | { 93 | // const int numPixels = imgWidth * imgHeight; 94 | // loop all points 95 | for (indice_t b = blockIdx.x; b < batchSize; b += gridDim.x) { 96 | for (indice_t p = threadIdx.x + blockIdx.y * blockDim.x; p < numPoint; 97 | p += blockDim.x * gridDim.y) { 98 | const indice_t curPointIdx = b * numPoint + p; 99 | scalar_t xmin = max(boundingBoxes[curPointIdx * 4], indice_t(0)); 100 | indice_t ymin = max(boundingBoxes[curPointIdx * 4 + 1], indice_t(0)); 101 | indice_t xmax = 102 | min(indice_t(boundingBoxes[curPointIdx * 4 + 2]), indice_t(imgWidth)); 103 | indice_t ymax = min(indice_t(boundingBoxes[curPointIdx * 4 + 3]), 104 | indice_t(imgHeight)); 105 | // loop over all pixels 106 | for (indice_t i = ymin; i < ymax; i++) { 107 | for (indice_t j = xmin; j < 
xmax; j++) { 108 | indice_t pixID = b * imgWidth * imgHeight + i * imgWidth + j; 109 | const indice_t *iOffset = indices + pixID * topK; 110 | for (indice_t k = 0; k < topK; k++) { 111 | indice_t pid = iOffset[k]; 112 | if (pid == p) { 113 | for (indice_t c = 0; c < C; c++) { 114 | dataGrad[(b * numPoint + pid) * C + c] += 115 | outGrad[(pixID * topK + k) * C + c]; 116 | } 117 | } 118 | } 119 | } 120 | } 121 | } 122 | } 123 | } 124 | template 125 | __device__ void 126 | update_IndexMap(const scalar_t depth, const int pointId, const int yInBB, 127 | const int xInBB, const int topK, indice_t *pointIdxList, 128 | indice_t *bbPositionList, scalar_t *pointDepthList) { 129 | // compare depth with topK depth list of the current pixel 130 | for (int i = 0; i < topK; i++) { 131 | if (depth < pointDepthList[i]) { 132 | // insert current pointID, yInBB and xInBB and depth to the list 133 | // by shifting [i, topK-1] part of the topK list 134 | for (int j = topK - 1; j > i; j--) { 135 | pointDepthList[j] = pointDepthList[j - 1]; 136 | pointIdxList[j] = pointIdxList[j - 1]; 137 | bbPositionList[j * 2] = bbPositionList[(j - 1) * 2]; 138 | bbPositionList[j * 2 + 1] = bbPositionList[(j - 1) * 2 + 1]; 139 | } 140 | pointIdxList[i] = indice_t(pointId); 141 | bbPositionList[i * 2] = indice_t(yInBB); 142 | bbPositionList[i * 2 + 1] = indice_t(xInBB); 143 | pointDepthList[i] = depth; 144 | break; 145 | } 146 | } 147 | } 148 | // visibility kernel, outputs pointIdxMap and depthMap that saves 5 points per 149 | // pixel which are order by the increasing z-value. bbPositionMap is the 150 | // relative position of the current pixel within the point's bounding box 151 | template 152 | __global__ void compute_visiblity_maps_kernel( 153 | int batchSize, int numPoint, int imgWidth, int imgHeight, int bbWidth, 154 | int bbHeight, int topK, 155 | const indice_t *__restrict__ boundingBoxes, // BxPx2 156 | const scalar_t *__restrict__ inPlane, // BxPxhxwx3 157 | indice_t *pointIdxMap, // BxHxWxK 158 | indice_t *bbPositionMap, // BxHxWxKx2 159 | scalar_t *depthMap // BxHxWxK 160 | ) { 161 | if (numPoint <= 0) 162 | return; 163 | int numPixels = imgWidth * imgHeight; 164 | int bbSize = bbWidth * bbHeight; 165 | // loop in the batch 166 | for (int b = blockIdx.x; b < batchSize; b += gridDim.x) { 167 | // loop all pixels 168 | for (int p = threadIdx.x + blockIdx.y * blockDim.x; p < numPixels; 169 | p += blockDim.x * gridDim.y) { 170 | // current pixel position (h,w) 171 | const int h = p / imgWidth; 172 | const int w = p % imgWidth; 173 | const int pixID = b * numPixels + p; 174 | assert(pixID < batchSize * numPixels && pixID >= 0); 175 | // loop all points 176 | for (int k = 0; k < numPoint; k++) { 177 | const int pointId = b * numPoint + k; 178 | const indice_t xmin = boundingBoxes[pointId * 2]; 179 | const indice_t ymin = boundingBoxes[pointId * 2 + 1]; 180 | // if current pixel is inside the point's boundingbox, 181 | // compare with depth. 
Update pointIdxMap and depthMap 182 | if (xmin <= w && ymin <= h && (xmin + bbWidth > w) && 183 | (ymin + bbHeight > h)) { 184 | // relative position inside the bounding box 185 | const int yInBB = h - ymin; 186 | const int xInBB = w - xmin; 187 | assert(yInBB >= 0 && yInBB < bbHeight); 188 | assert(xInBB >= 0 && xInBB < bbWidth); 189 | const scalar_t depth = inPlane[pointId * bbSize * 3 + 190 | yInBB * bbWidth * 3 + xInBB * 3 + 2]; 191 | update_IndexMap( 192 | depth, k, yInBB, xInBB, topK, pointIdxMap + pixID * topK, 193 | bbPositionMap + pixID * topK * 2, depthMap + pixID * topK); 194 | } 195 | } 196 | } 197 | } 198 | } 199 | 200 | void compute_visiblity_maps_cuda(const at::Tensor &boundingBoxes, 201 | const at::Tensor &inPlane, 202 | at::Tensor &pointIdxMap, 203 | at::Tensor &bbPositionMap, 204 | at::Tensor &depthMap) { 205 | TORCH_CHECK(inPlane.dim() == 5); 206 | const int batchSize = inPlane.size(0); 207 | const int numPoint = inPlane.size(1); 208 | const int bbHeight = inPlane.size(2); 209 | const int bbWidth = inPlane.size(3); 210 | const int imgHeight = pointIdxMap.size(1); 211 | const int imgWidth = pointIdxMap.size(2); 212 | const int topK = pointIdxMap.size(-1); 213 | 214 | int device; 215 | cudaGetDevice(&device); 216 | // printf("compute_visiblity_maps_cuda using device %d\n", device); 217 | unsigned int n_threads, n_blocks; 218 | int numPixels = imgWidth * imgHeight; 219 | n_threads = opt_n_threads(numPixels); 220 | n_blocks = min(32, (numPixels * batchSize + n_threads - 1) / n_threads); 221 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 222 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 223 | inPlane.scalar_type(), "compute_visiblity_maps_kernel", ([&] { 224 | compute_visiblity_maps_kernel 225 | <<>>( 226 | batchSize, numPoint, imgWidth, imgHeight, bbWidth, bbHeight, 227 | topK, 228 | boundingBoxes.data_ptr(), 229 | inPlane.data_ptr(), pointIdxMap.data_ptr(), 230 | bbPositionMap.data_ptr(), depthMap.data_ptr()); 231 | })); 232 | cudaError_t err = cudaGetLastError(); 233 | // cudaError_t err = cudaGetLastError(); 234 | if (err != cudaSuccess) { 235 | printf("compute_visiblity_maps_cuda kernel failed: %s\n", 236 | cudaGetErrorString(err)); 237 | exit(-1); 238 | } 239 | return; 240 | } 241 | 242 | // data BxNxC, indices BxHxWxK value (0~N-1), output BxHxKxC 243 | at::Tensor gather_maps_cuda(const at::Tensor &data, const at::Tensor &indices, 244 | const double defaultValue) { 245 | const int batchSize = data.size(0); 246 | const int numPoint = data.size(1); 247 | const int imgHeight = indices.size(1); 248 | const int imgWidth = indices.size(2); 249 | const int topK = indices.size(3); 250 | const int channels = data.size(2); 251 | at::Scalar dv = at::Scalar(defaultValue); 252 | auto output = at::full({batchSize, imgHeight, imgWidth, topK, channels}, dv, 253 | data.options()); 254 | unsigned int n_threads, n_blocks; 255 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 256 | int pixelNumber = imgWidth * imgHeight; 257 | n_threads = opt_n_threads(pixelNumber); 258 | n_blocks = min(32, (pixelNumber * batchSize + n_threads - 1) / n_threads); 259 | // printf("gather_maps_cuda: kernel config (%d, %d, %d)\n", batchSize, 260 | // n_blocks, 261 | // n_threads); 262 | AT_DISPATCH_ALL_TYPES( 263 | data.scalar_type(), "gather_maps_kernel", ([&] { 264 | gather_maps_kernel 265 | <<>>( 266 | batchSize, numPoint, imgWidth, imgHeight, topK, channels, 267 | indices.data_ptr(), data.data_ptr(), 268 | dv.to(), output.data_ptr()); 269 | })); 270 | // cudaError_t err = 
cudaDeviceSynchronize(); 271 | cudaError_t err = cudaGetLastError(); 272 | if (err != cudaSuccess) { 273 | printf("gather_maps_cuda kernel failed: %s\n", cudaGetErrorString(err)); 274 | exit(-1); 275 | } 276 | return output; 277 | } 278 | 279 | // the inverse of gather_maps 280 | // src BxHxWxKxC, indices BxHxWxK value (0~N-1), output BxHxC 281 | at::Tensor scatter_maps_cuda(const int64_t numPoint, const at::Tensor &src, 282 | const at::Tensor &indices) { 283 | const int batchSize = indices.size(0); 284 | const int imgHeight = indices.size(1); 285 | const int imgWidth = indices.size(2); 286 | const int topK = indices.size(3); 287 | const int channels = src.size(-1); 288 | unsigned int n_threads, n_blocks; 289 | const int nP = int(numPoint); 290 | n_threads = opt_n_threads(nP); 291 | n_blocks = min(32, (nP * batchSize + n_threads - 1) / n_threads); 292 | // initialize with zeros 293 | auto dataGrad = at::zeros({batchSize, nP, channels}, src.options()); 294 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 295 | AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, 296 | src.scalar_type(), "scatter_maps_kernel", ([&] { 297 | scatter_maps_kernel 298 | <<>>( 299 | batchSize, nP, imgWidth, imgHeight, topK, channels, 300 | src.data_ptr(), indices.data_ptr(), 301 | dataGrad.data_ptr()); 302 | })); 303 | cudaError_t err = cudaDeviceSynchronize(); 304 | // cudaError_t err = cudaGetLastError(); 305 | if (err != cudaSuccess) { 306 | printf("scatter_maps_cuda kernel failed: %s\n", cudaGetErrorString(err)); 307 | exit(-1); 308 | } 309 | return dataGrad; 310 | } 311 | 312 | // the inverse of gather_maps, use boundingboxes to restrict search areas 313 | // src BxHxWxKxC, indices BxHxWxK value (0~N-1), output BxHxC 314 | at::Tensor guided_scatter_maps_cuda(const int64_t numPoint, 315 | const at::Tensor &src, 316 | const at::Tensor &indices, 317 | const at::Tensor &boundingBoxes) { 318 | const int batchSize = indices.size(0); 319 | const int imgHeight = indices.size(1); 320 | const int imgWidth = indices.size(2); 321 | const int topK = indices.size(3); 322 | CHECK_EQ(src.dim(), 5); 323 | const int channels = src.size(-1); 324 | unsigned int n_threads, n_blocks; 325 | const int nP = int(numPoint); 326 | n_threads = opt_n_threads(nP); 327 | // 2D grid (batchSize, n_blocks) 328 | n_blocks = min(32, (nP * batchSize + n_threads - 1) / n_threads); 329 | // initialize with zeros 330 | auto dataGrad = at::zeros({batchSize, nP, channels}, src.options()); 331 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 332 | AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, 333 | src.scalar_type(), "guided_scatter_maps_kernel", ([&] { 334 | guided_scatter_maps_kernel 335 | <<>>( 336 | batchSize, nP, imgWidth, imgHeight, topK, channels, 337 | src.data_ptr(), 338 | indices.data_ptr(), 339 | boundingBoxes.data_ptr(), 340 | dataGrad.data_ptr()); 341 | })); 342 | cudaError_t err = cudaDeviceSynchronize(); 343 | // cudaError_t err = cudaGetLastError(); 344 | if (err != cudaSuccess) { 345 | printf("guided_scatter_maps_cuda kernel failed: %s\n", 346 | cudaGetErrorString(err)); 347 | exit(-1); 348 | } 349 | return dataGrad; 350 | } -------------------------------------------------------------------------------- /DSS/csrc/rasterize_points_backward.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "rasterization_utils.cuh" 9 | 10 | #define GRID_3D_MIN_X 0 11 | #define GRID_3D_MIN_Y 1 12 | 
#define GRID_3D_MIN_Z 2 13 | #define GRID_3D_DELTA 3 14 | #define GRID_3D_RES_X 4 15 | #define GRID_3D_RES_Y 5 16 | #define GRID_3D_RES_Z 6 17 | #define GRID_3D_TOTAL 7 18 | #define GRID_3D_PARAMS_SIZE 8 19 | #define GRID_3D_MAX_RES 128 20 | 21 | #define GRID_2D_MIN_X 0 22 | #define GRID_2D_MIN_Y 1 23 | #define GRID_2D_DELTA 2 24 | #define GRID_2D_RES_X 3 25 | #define GRID_2D_RES_Y 4 26 | #define GRID_2D_TOTAL 5 27 | #define GRID_2D_PARAMS_SIZE 6 28 | #define GRID_2D_MAX_RES 1024 29 | 30 | __global__ void RasterizePointsBackwardCudaFastKernel( 31 | const float* __restrict__ points_sorted, // (P,3) 32 | const float* __restrict__ radii_sorted, // (P,2) 33 | const float* __restrict__ rs, // (N,) 34 | const long* __restrict__ num_points_per_cloud, // (N,) 35 | const long* __restrict__ cloud_to_packed_first_idx, // (N,) 36 | const int32_t* __restrict__ points_grid_off, // (N,G) offset for the entire pack 37 | const float* __restrict__ grid_params, 38 | const float* __restrict__ grad_occ, // (N,H,W) 39 | // const int32_t * __restrict__ point_idxs, // (N,H,W,K) 40 | // const float* __restrict__ grad_zbuf, // (N,H,W,K) 41 | const int N, 42 | const int H, 43 | const int W, 44 | // const int K, 45 | const int B, 46 | const int G, 47 | float* grad_points 48 | ) { 49 | // loop over all pixels. indexing s.t. neighboring threads get pixels inside the same grid 50 | const int BIN_SIZE_Y = (H + B/2) / B; 51 | const int BIN_SIZE_X = (W + B/2) / B; 52 | 53 | const float PIXEL_SIZE_X = 2.0f / W; 54 | const float PIXEL_SIZE_Y = 2.0f / H; 55 | const float PIXEL_SIZE = (PIXEL_SIZE_X+PIXEL_SIZE_Y) / 2.0; 56 | 57 | const int num_pixels = N * BIN_SIZE_X * BIN_SIZE_Y * B * B; 58 | const int num_threads = gridDim.x * blockDim.x; 59 | const int tid = blockIdx.x * blockDim.x + threadIdx.x; 60 | 61 | for (int pid = tid; pid < num_pixels; pid += num_threads) { 62 | int i = pid; 63 | 64 | // Convert linear index into bin and pixel indices. 
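        // (editor note) pid is laid out as (n, bin_y, bin_x, y-in-bin, x-in-bin),
        // so consecutive threads cover neighbouring pixels of the same bin and
        // reuse the same grid cells when searching for nearby points.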
65 | const int n = i / (BIN_SIZE_X * BIN_SIZE_Y * B * B); // batch 66 | i %= BIN_SIZE_X * BIN_SIZE_Y * B * B; 67 | const int by = i / (B * BIN_SIZE_X * BIN_SIZE_Y); // bin_y 68 | i %= B * BIN_SIZE_X * BIN_SIZE_Y; 69 | const int bx = i / (BIN_SIZE_X * BIN_SIZE_Y); // bin_x 70 | assert(n < N && n >= 0); 71 | assert(bx < B && bx >= 0); 72 | assert(by < B && by >= 0); 73 | // lixin 74 | // i %= B * BIN_SIZE_X; 75 | // const int bx = i / BIN_SIZE_X; 76 | i %= BIN_SIZE_X * BIN_SIZE_Y; // index inside the bin 77 | 78 | // Pixel indices 79 | const int yidx = i / BIN_SIZE_X + by * BIN_SIZE_Y; 80 | const int xidx = i % BIN_SIZE_X + bx * BIN_SIZE_X; 81 | 82 | if (yidx >= H || xidx >= W) { 83 | continue; 84 | } 85 | const float grad_occ_pix = grad_occ[n*H*W + yidx*W + xidx]; 86 | if (grad_occ_pix != 0.0f) { 87 | // reverse because NDC assuming +y is up and +x is left 88 | const int yi = H - 1 - yidx; 89 | const int xi = W - 1 - xidx; 90 | assert(xi >= 0 && xi < W); 91 | assert(yi >= 0 && yi < H); 92 | 93 | // Pixel in NDC coordinates 94 | const float xf = PixToNdc(xi, W); 95 | const float yf = PixToNdc(yi, H); 96 | assert(abs(xf) <= 1.0 && abs(yf) <= 1.0); 97 | 98 | const long cur_first_idx = cloud_to_packed_first_idx[n]; 99 | const float cur_r = rs[n]; // search radius 100 | const float cur_r2 = cur_r * cur_r; 101 | 102 | const float grid_min_x = grid_params[n*GRID_2D_PARAMS_SIZE+GRID_2D_MIN_X]; 103 | const float grid_min_y = grid_params[n*GRID_2D_PARAMS_SIZE+GRID_2D_MIN_Y]; 104 | const float grid_delta = grid_params[n*GRID_2D_PARAMS_SIZE+GRID_2D_DELTA]; // 1/cell_size 105 | const int grid_res_x = grid_params[n*GRID_2D_PARAMS_SIZE+GRID_2D_RES_X]; 106 | const int grid_res_y = grid_params[n*GRID_2D_PARAMS_SIZE+GRID_2D_RES_Y]; 107 | const int grid_total = grid_params[n*GRID_2D_PARAMS_SIZE+GRID_2D_TOTAL]; 108 | // const float grad_occ_pix = grad_occ[i]; 109 | assert(n*H*W + yi*W + xi < N*H*W); 110 | 111 | int min_gc_x = (int) floor((xf-grid_min_x-cur_r) * grid_delta); 112 | int min_gc_y = (int) floor((yf-grid_min_y-cur_r) * grid_delta); 113 | int max_gc_x = (int) floor((xf-grid_min_x+cur_r) * grid_delta); 114 | int max_gc_y = (int) floor((yf-grid_min_y+cur_r) * grid_delta); 115 | 116 | // Search the relevant grid 117 | for (int x=max(min_gc_x, 0); x<=min(max_gc_x, grid_res_x-1); ++x) { 118 | for (int y=max(min_gc_y, 0); y<=min(max_gc_y, grid_res_y-1); ++y) { 119 | int cell_idx = x*grid_res_y + y; 120 | assert(cell_idx < grid_total); 121 | // Get the relevant index range of points 122 | const int64_t p2_start = points_grid_off[n*G + cell_idx]; 123 | int p2_end; 124 | if (cell_idx+1 == grid_total) { 125 | p2_end = num_points_per_cloud[n]; 126 | } 127 | else { 128 | p2_end = points_grid_off[n*G+cell_idx+1]; 129 | } 130 | if (p2_end > cur_first_idx+num_points_per_cloud[n]){ 131 | printf("points_grid_off[%d, %d] = %d, grid_total = %d, p2_end = %d, num_points_per_cloud[%d] = %d, cur_first_idx = %d, pid = %d\n", n, cell_idx, p2_start, grid_total, n, num_points_per_cloud[n], p2_end, cur_first_idx, pid); 132 | } 133 | assert(p2_end <= cur_first_idx+num_points_per_cloud[n]); 134 | // Loop over the relevant points, aggregate gradients 135 | for (int p_idx=p2_start; p_idx= cur_first_idx); 141 | const float px = points_sorted[p_idx * 3 + 0]; 142 | const float py = points_sorted[p_idx * 3 + 1]; 143 | const float pz = points_sorted[p_idx * 3 + 2]; 144 | // outside renderable area 145 | if (pz < 0 || abs(py) > 1.0 || abs(px) > 1.0) 146 | continue; 147 | const float dx = xf - px; 148 | const float dy = yf - py; 149 | 150 
| const float radiix = radii_sorted[p_idx * 2 + 0]; 151 | const float radiiy = radii_sorted[p_idx * 2 + 1]; 152 | 153 | const float dist2 = dx * dx + dy * dy; 154 | 155 | // inside backpropagation radius? 156 | if (dist2 > cur_r2) 157 | continue; // Skip if pixel out of precomputed radii range 158 | 159 | // inside splat? NOTE: this is not as accurate as check qvalue < cutoffthreshold 160 | // but it's a close approximation 161 | const bool pix_outside_splat = (abs(dx) > radiix) || (abs(dy) > radiiy); 162 | 163 | // if grad_occ_pix > 0, it means that this pixel shouldn't be occluded 164 | // but if it's outside the splat, it doesn't generate meaninigful information 165 | // for in which direction the point should move. 166 | if (grad_occ_pix > 0.0f && pix_outside_splat) 167 | // if (grad_occ_pix > 0.0f) 168 | continue; 169 | 170 | const float denom = eps_denom(dist2, 1e-10f); 171 | const float grad_px = dx / denom * grad_occ_pix; 172 | const float grad_py = dy / denom * grad_occ_pix; 173 | // const float grad_px = clamp(dx / denom, -10/PIXEL_SIZE, 10/PIXEL_SIZE) * grad_occ_pix; 174 | // const float grad_py = clamp(dy / denom, -10/PIXEL_SIZE, 10/PIXEL_SIZE) * grad_occ_pix; 175 | 176 | // printf("grad_pts[%d] = [%g, %g]\n", p_idx, grad_px, grad_py); 177 | gpuAtomicAdd(grad_points + p_idx * 2 + 0, grad_px); 178 | gpuAtomicAdd(grad_points + p_idx * 2 + 1, grad_py); 179 | 180 | } 181 | 182 | // // If inside splat, copy the grad_pz 183 | // if (!pix_outside_splat) { 184 | // const int ik = (n*H*W+yi*W+xi) * K; // pid is for (B*BIN_SIZE_Y)x(B*BIN_SIZE_X) 185 | // assert(n < N); 186 | // assert(yi < H); 187 | // assert(xi < W); 188 | // assert(n*H*W+yi*W+xi < N*H*W); 189 | // assert(ik+K-1 < N*H*W*K); 190 | // // if (ik >= N*H*W*K) { 191 | // // printf("N: %d, n: %d\n", N, n); 192 | // // // printf("H: %d, h: %d\n", H, yi); 193 | // // // printf("W: %d, w: %d\n", W, xi); 194 | // // // printf("N*H*W: %d, n*H*W+yi*W+xi: %d\n", N*H*W, n*H*W+yi*W+xi); 195 | // // printf("N*H*W*K: %d, (n*H*W+yi*W+xi)*K: %d\n", N*H*W*K, (n*H*W+yi*W+xi)*K); 196 | // // assert(ik < N*H*W*K); 197 | // // } 198 | // for (int k = 0; k < K; k++) 199 | // { 200 | // const int z_idx = point_idxs[ik + k]; 201 | // if (z_idx < 0) 202 | // break; 203 | // const float grad_pz = grad_zbuf[ik + k]; 204 | // gpuAtomicAdd(grad_points + z_idx * 3 + 2, grad_pz); 205 | // } 206 | // } 207 | } 208 | 209 | } 210 | } 211 | } 212 | } 213 | /* 214 | Args: 215 | points, // (P, 3) 216 | radii, // (P, 2) 217 | idx, // (N, H, W, K) 218 | rs, // (N, ) 219 | grad_occ, // (N, H, W) 220 | grad_zbuf, // (N, H, W, K) 221 | num_points_per_cloud, // (N,) 222 | cloud_to_packed_first_idx, // (N,) 223 | points_grid_off (N, G) The packed index of the first point stored in a grid 224 | Returns: 225 | grad_points: (P, 3) 226 | */ 227 | at::Tensor RasterizePointsBackwardCudaFast( 228 | const at::Tensor &points_sorted, // (P, 3) 229 | const at::Tensor &radii_sorted, // (P, 2) 230 | // const at::Tensor &idxs, // (N, H, W, K) 231 | const at::Tensor &rs, // (N, ) 232 | const at::Tensor &grad_occ, // (N, H, W) 233 | // const at::Tensor &grad_zbuf, // (N, H, W, K) 234 | const at::Tensor &num_points_per_cloud, // (N,) 235 | const at::Tensor &cloud_to_packed_first_idx, // (N,) 236 | const at::Tensor &points_grid_off, // (N, G) 237 | const at::Tensor &grid_params // (N, GRID_2D_PARAMS_SIZE) 238 | ) { 239 | // Check inputs are on the same device 240 | at::TensorArg points_t{points_sorted, "points", 1}, 241 | radii_t{radii_sorted, "radii", 2}, 242 | // idxs_t{idxs, 
"idxs", 3}, 243 | rs_t{rs, "rs", 3}, 244 | grad_occ_t{grad_occ, "grad_occ", 4}, 245 | // grad_zbuf_t{grad_zbuf, "grad_zbuf", 6}, 246 | num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 5}, 247 | cloud_to_packed_first_idx_t{ 248 | cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 6}, 249 | points_grid_off_t{points_grid_off, "points_grid_off", 7}, 250 | grid_params_t{grid_params, "grid_params", 8}; 251 | at::CheckedFrom c = "RasterizePointsBackwardCudaFast"; 252 | at::checkDim(c, points_t, 2); 253 | at::checkDim(c, radii_t, 2); 254 | // at::checkDim(c, idxs_t, 4); 255 | at::checkDim(c, rs_t, 1); 256 | at::checkDim(c, grad_occ_t, 3); 257 | // at::checkDim(c, grad_zbuf_t, 4); 258 | at::checkDim(c, num_points_per_cloud_t, 1); 259 | at::checkDim(c, cloud_to_packed_first_idx_t, 1); 260 | at::checkDim(c, points_grid_off_t, 2); 261 | at::checkDim(c, grid_params_t, 2); 262 | at::checkSize(c, grid_params_t, 1, GRID_2D_PARAMS_SIZE); 263 | at::checkSize(c, points_t, 1, 3); 264 | at::checkSize(c, radii_t, {points_t->size(0), 2}); 265 | // at::checkAllSameSize(c, {rs_t, num_points_per_cloud_t, cloud_to_packed_first_idx_t}); 266 | at::checkSameSize(c, rs_t, cloud_to_packed_first_idx_t); 267 | at::checkSameSize(c, rs_t, num_points_per_cloud_t); 268 | // at::checkSameSize(c, grad_zbuf_t, idxs_t); 269 | 270 | at::checkAllSameGPU( 271 | c, {points_t, radii_t, rs_t, grad_occ_t, 272 | num_points_per_cloud_t, cloud_to_packed_first_idx_t, points_grid_off_t}); 273 | 274 | // Set the device for the kernel launch based on the device of the input 275 | at::cuda::CUDAGuard device_guard(points_sorted.device()); 276 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 277 | 278 | const int P = points_sorted.size(0); 279 | const int G = points_grid_off.size(1); 280 | const int N = grad_occ.size(0); 281 | const int H = grad_occ.size(1); 282 | const int W = grad_occ.size(2); 283 | int B = 1; 284 | const int S = min(H, W); 285 | 286 | if (S >= 64) 287 | B = 8; 288 | if (S >= 128) 289 | B = 16; 290 | if (S >= 256) 291 | B = 32; 292 | if (S >= 512) 293 | B = 64; 294 | 295 | // call backward fast kernel on sorted points_sorted and sorted radii_sorted, this will return a gradient [P, 3] of the *sorted* points_sorted 296 | const size_t blocks = 1024; 297 | const size_t threads = 64; 298 | at::Tensor grad_points_sorted = at::zeros({P, 2}, points_sorted.options()); 299 | RasterizePointsBackwardCudaFastKernel<<>>( 300 | points_sorted.contiguous().data_ptr(), // (P,3) 301 | radii_sorted.contiguous().data_ptr(), // (P,2) 302 | rs.contiguous().data_ptr(), // (N,) 303 | num_points_per_cloud.contiguous().data_ptr(), // (N,) 304 | cloud_to_packed_first_idx.contiguous().data_ptr(), // (N,) 305 | points_grid_off.contiguous().data_ptr(), // (N,G) 306 | grid_params.contiguous().data_ptr(), // (N,8) 307 | grad_occ.contiguous().data_ptr(), // (N,H,W) 308 | // idxs.contiguous().data_ptr(), // (N,H,W) 309 | // grad_zbuf.contiguous().data_ptr(), // (N,H,W) 310 | N, 311 | H, 312 | W, 313 | // K, 314 | B, 315 | G, // grid_res_x * grid_res_y 316 | grad_points_sorted.contiguous().data_ptr() 317 | ); 318 | 319 | AT_CUDA_CHECK(cudaGetLastError()); 320 | 321 | return grad_points_sorted; 322 | } -------------------------------------------------------------------------------- /DSS/csrc/types.hpp: -------------------------------------------------------------------------------- 1 | using PointIndex = int; 2 | using Coord = int; 3 | using Float = float; 4 | 5 | 
-------------------------------------------------------------------------------- /DSS/csrc/weighted_sum.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | // TODO(gkioxari) support all data types once AtomicAdd supports doubles. 15 | // Currently, support is for floats only. 16 | 17 | // Modifications: use scalar_k and Q value map to get alpha 18 | // Do this in python as done in points/renderer.py 19 | // so the input is alphas, and no need to change things here 20 | // Inputs: 21 | // features: FloatTensor of shape (C, P) which gives the features 22 | // of each point where C is the size of the feature and 23 | // P the number of points. 24 | // alphas: FloatTensor of shape (N, points_per_pixel, W, W) where 25 | // points_per_pixel is the number of points in the z-buffer 26 | // sorted in z-order, and W is the image size. 27 | // points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the 28 | // indices of the nearest points at each pixel, sorted in z-order. 29 | // Returns: 30 | // weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated 31 | // feature in each point. Concretely, it gives: 32 | // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * 33 | // features[c,points_idx[b,k,i,j]] 34 | 35 | 36 | 37 | 38 | __global__ void weightedSumCudaForwardKernel( 39 | // clang-format off 40 | at::PackedTensorAccessor64 result, 41 | // at::PackedTensorAccessor64 alphas, 42 | const at::PackedTensorAccessor64 features, 43 | const at::PackedTensorAccessor64 alphas, 44 | // const at::PackedTensorAccessor64 scalars, 45 | // const at::PackedTensorAccessor64 qvalue_map, 46 | const at::PackedTensorAccessor64 points_idx) { 47 | // clang-format on 48 | const int64_t batch_size = result.size(0); 49 | const int64_t C = features.size(0); 50 | const int64_t H = points_idx.size(2); 51 | const int64_t W = points_idx.size(3); 52 | 53 | // Get the batch and index 54 | const int batch = blockIdx.x; 55 | 56 | const int num_pixels = C * W * H; 57 | const int num_threads = gridDim.y * blockDim.x; 58 | const int tid = blockIdx.y * blockDim.x + threadIdx.x; 59 | 60 | // Parallelize over each feature in each pixel in images of size H * W, 61 | // for each image in the batch of size batch_size 62 | for (int pid = tid; pid < num_pixels; pid += num_threads) { 63 | int ch = pid / (W * H); 64 | int j = (pid % (W * H)) / H; 65 | int i = (pid % (W * H)) % H; 66 | 67 | // Iterate through the closest K points for this pixel 68 | for (int k = 0; k < points_idx.size(1); ++k) { 69 | int n_idx = points_idx[batch][k][j][i]; 70 | // Sentinel value is -1 indicating no point overlaps the pixel 71 | if (n_idx < 0) { 72 | continue; 73 | } 74 | 75 | // Accumulate the values 76 | float alpha = alphas[batch][k][j][i]; 77 | // TODO(gkioxari) It might be more efficient to have threads write in a 78 | // local variable, and move atomicAdd outside of the loop such that 79 | // atomicAdd is executed once per thread. 80 | atomicAdd(&result[batch][ch][j][i], features[ch][n_idx] * alpha); 81 | } 82 | } 83 | } 84 | 85 | // TODO(gkioxari) support all data types once AtomicAdd supports doubles. 86 | // Currently, support is for floats only. 
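*Editor's note — illustrative reference for the forward kernel above.* The doc comment states the compositing rule `weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * features[c, points_idx[b,k,i,j]]`. The sketch below re-implements that rule in plain PyTorch so it can be checked without compiling the extension; the function name, the masking of the `-1` sentinel, and the dense gather are mine and are not part of the DSS API.

```python
import torch

def weighted_sum_reference(features, alphas, points_idx):
    """Pure-PyTorch sketch of the weighted-sum compositing rule (illustration only).
    features: (C, P), alphas: (N, K, H, W), points_idx: (N, K, H, W), -1 = no point."""
    idx = points_idx.clamp(min=0)                  # make the -1 sentinel safe to index with
    gathered = features[:, idx]                    # (C, N, K, H, W)
    mask = (points_idx >= 0).to(features.dtype)    # zero out empty z-buffer entries
    out = (gathered * (alphas * mask).unsqueeze(0)).sum(dim=2)  # sum over K -> (C, N, H, W)
    return out.permute(1, 0, 2, 3)                 # (N, C, H, W), matching the kernel output

# quick shape check with random inputs
f = torch.rand(3, 10)
a = torch.rand(2, 4, 8, 8)
i = torch.randint(-1, 10, (2, 4, 8, 8))
print(weighted_sum_reference(f, a, i).shape)       # torch.Size([2, 3, 8, 8])
```

Up to floating-point accumulation order, this should agree with the CUDA forward pass; the kernel additionally parallelizes over batch, channel, and pixels and uses atomic adds.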
87 | __global__ void weightedSumCudaBackwardKernel( 88 | // clang-format off 89 | at::PackedTensorAccessor64 grad_features, 90 | at::PackedTensorAccessor64 grad_alphas, 91 | const at::PackedTensorAccessor64 grad_outputs, 92 | const at::PackedTensorAccessor64 features, 93 | const at::PackedTensorAccessor64 alphas, 94 | const at::PackedTensorAccessor64 points_idx) { 95 | // clang-format on 96 | const int64_t batch_size = points_idx.size(0); 97 | const int64_t C = features.size(0); 98 | const int64_t H = points_idx.size(2); 99 | const int64_t W = points_idx.size(3); 100 | 101 | // Get the batch and index 102 | const int batch = blockIdx.x; 103 | 104 | const int num_pixels = C * W * H; 105 | const int num_threads = gridDim.y * blockDim.x; 106 | const int tid = blockIdx.y * blockDim.x + threadIdx.x; 107 | 108 | // Iterate over each pixel to compute the contribution to the 109 | // gradient for the features and weights 110 | for (int pid = tid; pid < num_pixels; pid += num_threads) { 111 | int ch = pid / (W * H); 112 | int j = (pid % (W * H)) / H; 113 | int i = (pid % (W * H)) % H; 114 | 115 | // Iterate through the closest K points for this pixel 116 | for (int k = 0; k < points_idx.size(1); ++k) { 117 | int n_idx = points_idx[batch][k][j][i]; 118 | // Sentinel value is -1 indicating no point overlaps the pixel 119 | if (n_idx < 0) { 120 | continue; 121 | } 122 | float alpha = alphas[batch][k][j][i]; 123 | 124 | // TODO(gkioxari) It might be more efficient to have threads write in a 125 | // local variable, and move atomicAdd outside of the loop such that 126 | // atomicAdd is executed once per thread. 127 | atomicAdd( 128 | &grad_alphas[batch][k][j][i], 129 | features[ch][n_idx] * grad_outputs[batch][ch][j][i]); 130 | atomicAdd( 131 | &grad_features[ch][n_idx], alpha * grad_outputs[batch][ch][j][i]); 132 | } 133 | } 134 | } 135 | 136 | at::Tensor weightedSumCudaForward( 137 | const at::Tensor& features, 138 | const at::Tensor& alphas, 139 | const at::Tensor& points_idx) { 140 | // Check inputs are on the same device 141 | at::TensorArg features_t{features, "features", 1}, 142 | alphas_t{alphas, "alphas", 2}, points_idx_t{points_idx, "points_idx", 3}; 143 | at::CheckedFrom c = "weightedSumCudaForward"; 144 | at::checkAllSameGPU(c, {features_t, alphas_t, points_idx_t}); 145 | at::checkAllSameType(c, {features_t, alphas_t}); 146 | 147 | // Set the device for the kernel launch based on the device of the input 148 | at::cuda::CUDAGuard device_guard(features.device()); 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | 151 | const int64_t batch_size = points_idx.size(0); 152 | const int64_t C = features.size(0); 153 | const int64_t H = points_idx.size(2); 154 | const int64_t W = points_idx.size(3); 155 | 156 | auto result = at::zeros({batch_size, C, H, W}, features.options()); 157 | 158 | if (result.numel() == 0) { 159 | AT_CUDA_CHECK(cudaGetLastError()); 160 | return result; 161 | } 162 | 163 | const dim3 threadsPerBlock(64); 164 | const dim3 numBlocks(batch_size, 1024 / batch_size + 1); 165 | 166 | // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports 167 | // doubles. Currently, support is for floats only. 168 | weightedSumCudaForwardKernel<<>>( 169 | // clang-format off 170 | // As we are using packed accessors here the tensors 171 | // do not need to be made contiguous. 
172 | result.packed_accessor64(), 173 | features.packed_accessor64(), 174 | alphas.packed_accessor64(), 175 | points_idx.packed_accessor64()); 176 | // clang-format on 177 | AT_CUDA_CHECK(cudaGetLastError()); 178 | return result; 179 | } 180 | 181 | std::tuple weightedSumCudaBackward( 182 | const at::Tensor& grad_outputs, 183 | const at::Tensor& features, 184 | const at::Tensor& alphas, 185 | const at::Tensor& points_idx) { 186 | // Check inputs are on the same device 187 | at::TensorArg grad_outputs_t{grad_outputs, "grad_outputs", 1}, 188 | features_t{features, "features", 2}, alphas_t{alphas, "alphas", 3}, 189 | points_idx_t{points_idx, "points_idx", 4}; 190 | at::CheckedFrom c = "weightedSumCudaBackward"; 191 | at::checkAllSameGPU(c, {grad_outputs_t, features_t, alphas_t, points_idx_t}); 192 | at::checkAllSameType(c, {grad_outputs_t, features_t, alphas_t}); 193 | 194 | // Set the device for the kernel launch based on the device of the input 195 | at::cuda::CUDAGuard device_guard(features.device()); 196 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 197 | 198 | auto grad_features = at::zeros_like(features); 199 | auto grad_alphas = at::zeros_like(alphas); 200 | 201 | if (grad_features.numel() == 0 || grad_alphas.numel() == 0) { 202 | AT_CUDA_CHECK(cudaGetLastError()); 203 | return std::make_tuple(grad_features, grad_alphas); 204 | } 205 | 206 | const int64_t bs = points_idx.size(0); 207 | 208 | const dim3 threadsPerBlock(64); 209 | const dim3 numBlocks(bs, 1024 / bs + 1); 210 | 211 | // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports 212 | // doubles. Currently, support is for floats only. 213 | weightedSumCudaBackwardKernel<<>>( 214 | // clang-format off 215 | // As we are using packed accessors here the tensors 216 | // do not need to be made contiguous. 217 | grad_features.packed_accessor64(), 218 | grad_alphas.packed_accessor64(), 219 | grad_outputs.packed_accessor64(), 220 | features.packed_accessor64(), 221 | alphas.packed_accessor64(), 222 | points_idx.packed_accessor64()); 223 | // clang-format on 224 | AT_CUDA_CHECK(cudaGetLastError()); 225 | return std::make_tuple(grad_features, grad_alphas); 226 | } 227 | -------------------------------------------------------------------------------- /DSS/csrc/weighted_sum.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #include 4 | #include "utils/pytorch3d_cutils.h" 5 | 6 | #include 7 | 8 | // Perform weighted sum compositing of points in a z-buffer. 9 | // 10 | // Inputs: 11 | // features: FloatTensor of shape (C, P) which gives the features 12 | // of each point where C is the size of the feature and 13 | // P the number of points. 14 | // alphas: FloatTensor of shape (N, points_per_pixel, W, W) where 15 | // points_per_pixel is the number of points in the z-buffer 16 | // sorted in z-order, and W is the image size. 17 | // points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the 18 | // indices of the nearest points at each pixel, sorted in z-order. 19 | // Returns: 20 | // weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated 21 | // feature in each point. 
Concretely, it gives: 22 | // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * 23 | // features[c,points_idx[b,k,i,j]] 24 | 25 | // CUDA declarations 26 | #ifdef WITH_CUDA 27 | torch::Tensor weightedSumCudaForward( 28 | const torch::Tensor& features, 29 | const torch::Tensor& alphas, 30 | const torch::Tensor& points_idx); 31 | 32 | std::tuple weightedSumCudaBackward( 33 | const torch::Tensor& grad_outputs, 34 | const torch::Tensor& features, 35 | const torch::Tensor& alphas, 36 | const torch::Tensor& points_idx); 37 | #endif 38 | 39 | // C++ declarations 40 | torch::Tensor weightedSumCpuForward( 41 | const torch::Tensor& features, 42 | const torch::Tensor& alphas, 43 | const torch::Tensor& points_idx); 44 | 45 | std::tuple weightedSumCpuBackward( 46 | const torch::Tensor& grad_outputs, 47 | const torch::Tensor& features, 48 | const torch::Tensor& alphas, 49 | const torch::Tensor& points_idx); 50 | 51 | torch::Tensor weightedSumForward( 52 | torch::Tensor& features, 53 | torch::Tensor& alphas, 54 | torch::Tensor& points_idx) { 55 | features = features.contiguous(); 56 | alphas = alphas.contiguous(); 57 | points_idx = points_idx.contiguous(); 58 | 59 | if (features.is_cuda()) { 60 | #ifdef WITH_CUDA 61 | CHECK_CUDA(features); 62 | CHECK_CUDA(alphas); 63 | CHECK_CUDA(points_idx); 64 | return weightedSumCudaForward(features, alphas, points_idx); 65 | #else 66 | AT_ERROR("Not compiled with GPU support"); 67 | #endif 68 | } else { 69 | return weightedSumCpuForward(features, alphas, points_idx); 70 | } 71 | } 72 | 73 | std::tuple weightedSumBackward( 74 | torch::Tensor& grad_outputs, 75 | torch::Tensor& features, 76 | torch::Tensor& alphas, 77 | torch::Tensor& points_idx) { 78 | grad_outputs = grad_outputs.contiguous(); 79 | features = features.contiguous(); 80 | alphas = alphas.contiguous(); 81 | points_idx = points_idx.contiguous(); 82 | 83 | if (grad_outputs.is_cuda()) { 84 | #ifdef WITH_CUDA 85 | CHECK_CUDA(grad_outputs); 86 | CHECK_CUDA(features); 87 | CHECK_CUDA(alphas); 88 | CHECK_CUDA(points_idx); 89 | 90 | return weightedSumCudaBackward(grad_outputs, features, alphas, points_idx); 91 | #else 92 | AT_ERROR("Not compiled with GPU support"); 93 | #endif 94 | } else { 95 | return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /DSS/logger.py: -------------------------------------------------------------------------------- 1 | """ From https://github.com/t177398/best_python_logger """ 2 | import logging 3 | import sys 4 | 5 | class _CustomFormatter(logging.Formatter): 6 | """Logging Formatter to add colors and count warning / errors""" 7 | 8 | grey = "\x1b[0;37m" 9 | green = "\x1b[1;32m" 10 | yellow = "\x1b[1;33m" 11 | red = "\x1b[1;31m" 12 | purple = "\x1b[1;35m" 13 | blue = "\x1b[1;34m" 14 | light_blue = "\x1b[1;36m" 15 | reset = "\x1b[0m" 16 | blink_red = "\x1b[5m\x1b[1;31m" 17 | format_prefix = f"{purple}%(asctime)s{reset} " \ 18 | f"{blue}%(name)s{reset} " \ 19 | f"{light_blue}(%(filename)s:%(lineno)d){reset} " 20 | 21 | format_suffix = "%(levelname)s - %(message)s" 22 | 23 | FORMATS = { 24 | logging.DEBUG: format_prefix + green + format_suffix + reset, 25 | logging.INFO: format_prefix + grey + format_suffix + reset, 26 | logging.WARNING: format_prefix + yellow + format_suffix + reset, 27 | logging.ERROR: format_prefix + red + format_suffix + reset, 28 | logging.CRITICAL: format_prefix + blink_red + format_suffix + reset 29 | } 30 | 31 | def format(self, record): 32 | 
log_fmt = self.FORMATS.get(record.levelno) 33 | formatter = logging.Formatter(log_fmt) 34 | return formatter.format(record) 35 | 36 | 37 | # Just import this function into your programs 38 | # "from logger import get_logger" 39 | # "logger = get_logger(__name__)" 40 | # Use the variable __name__ so the logger will print the file's name also 41 | 42 | def get_logger(name): 43 | logger = logging.getLogger(name) 44 | logger.setLevel(logging.DEBUG) 45 | ch = logging.StreamHandler() 46 | ch.setLevel(logging.DEBUG) 47 | ch.setFormatter(_CustomFormatter()) 48 | logger.addHandler(ch) 49 | return logger -------------------------------------------------------------------------------- /DSS/misc/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | from .. import logger_py 4 | 5 | 6 | class Thread(threading.Thread): 7 | def __init__(self, target, name='', args=(), kwargs={}): 8 | super().__init__(target=target, name=name, args=args, kwargs=kwargs) 9 | self.args = args 10 | self.kwargs = kwargs 11 | self.name 12 | 13 | def run(self): 14 | t0 = time.time() 15 | super().run() 16 | t1 = time.time() 17 | logger_py.info('{}: {:.3f} seconds'.format(self.name, t1 - t0)) 18 | -------------------------------------------------------------------------------- /DSS/misc/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | import shutil 6 | import datetime 7 | 8 | 9 | class CheckpointIO(object): 10 | ''' CheckpointIO class. 11 | 12 | It handles saving and loading checkpoints. 13 | 14 | Args: 15 | checkpoint_dir (str): path where checkpoints are saved 16 | ''' 17 | 18 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 19 | self.module_dict = kwargs 20 | self.checkpoint_dir = checkpoint_dir 21 | if not os.path.exists(checkpoint_dir): 22 | os.makedirs(checkpoint_dir) 23 | 24 | def register_modules(self, **kwargs): 25 | ''' Registers modules in current module dictionary. 26 | ''' 27 | self.module_dict.update(kwargs) 28 | 29 | def save(self, filename, **kwargs): 30 | ''' Saves the current module dictionary. 31 | 32 | Args: 33 | filename (str): name of output file 34 | ''' 35 | if not os.path.isabs(filename): 36 | filename = os.path.join(self.checkpoint_dir, filename) 37 | 38 | outdict = kwargs 39 | for k, v in self.module_dict.items(): 40 | outdict[k] = v.state_dict() 41 | torch.save(outdict, filename) 42 | 43 | def backup_model_best(self, filename, **kwargs): 44 | if not os.path.isabs(filename): 45 | filename = os.path.join(self.checkpoint_dir, filename) 46 | if os.path.exists(filename): 47 | # Backup model 48 | backup_dir = os.path.join(self.checkpoint_dir, 'backup_model_best') 49 | if not os.path.exists(backup_dir): 50 | os.makedirs(backup_dir) 51 | ts = datetime.datetime.now().timestamp() 52 | filename_backup = os.path.join(backup_dir, '%s.pt' % ts) 53 | shutil.copy(filename, filename_backup) 54 | 55 | def load(self, filename): 56 | '''Loads a module dictionary from local file or url. 57 | 58 | Args: 59 | filename (str): name of saved module dictionary 60 | ''' 61 | if is_url(filename): 62 | return self.load_url(filename) 63 | else: 64 | return self.load_file(filename) 65 | 66 | def load_file(self, filename): 67 | '''Loads a module dictionary from file. 
68 | 69 | Args: 70 | filename (str): name of saved module dictionary 71 | ''' 72 | 73 | if not os.path.isabs(filename): 74 | filename = os.path.join(self.checkpoint_dir, filename) 75 | 76 | if os.path.exists(filename): 77 | print(filename) 78 | print('=> Loading checkpoint from local file...', end='') 79 | state_dict = torch.load(filename) 80 | scalars = self.parse_state_dict(state_dict) 81 | print('Done!') 82 | return scalars 83 | else: 84 | raise FileExistsError 85 | 86 | def load_url(self, url): 87 | '''Load a module dictionary from url. 88 | 89 | Args: 90 | url (str): url to saved model 91 | ''' 92 | print(url) 93 | print('=> Loading checkpoint from url...', end='') 94 | state_dict = model_zoo.load_url(url, progress=True) 95 | scalars = self.parse_state_dict(state_dict) 96 | print('Done!') 97 | return scalars 98 | 99 | def parse_state_dict(self, state_dict): 100 | '''Parse state_dict of model and return scalars. 101 | 102 | Args: 103 | state_dict (dict): State dict of model 104 | ''' 105 | 106 | for k, v in self.module_dict.items(): 107 | if k in state_dict: 108 | if isinstance(v, torch.optim.Optimizer): 109 | v.load_state_dict(state_dict[k]) 110 | else: 111 | missing_keys, unexpected_keys = v.load_state_dict(state_dict[k], strict=False) 112 | if len(missing_keys) > 0: 113 | print('Warning: Could not find %s in checkpoint!' % missing_keys) 114 | if len(unexpected_keys) > 0: 115 | print('Warning: Found unexpectedly %s in checkpoint!' % unexpected_keys) 116 | 117 | else: 118 | print('Warning: Could not find %s in checkpoint!' % k) 119 | scalars = {k: v for k, v in state_dict.items() 120 | if k not in self.module_dict} 121 | return scalars 122 | 123 | 124 | def is_url(url): 125 | ''' Checks if input string is a URL. 126 | 127 | Args: 128 | url (string): URL 129 | ''' 130 | scheme = urllib.parse.urlparse(url).scheme 131 | return scheme in ('http', 'https') 132 | -------------------------------------------------------------------------------- /DSS/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | 4 | __all__ = ['BaseGenerator', 'ImplicitModel', 'OccupancyModel', 5 | 'PointModel', 'CombinedModel', 'ModelReturns'] 6 | 7 | 8 | class BaseGenerator(object): 9 | def __init__(self, model, device): 10 | self.model = model.to(device) 11 | self.device = device 12 | 13 | def generate_meshes(self, *args, **kwargs): 14 | return [] 15 | 16 | def generate_pointclouds(self, *args, **kwargs): 17 | return [] 18 | 19 | def generate_images(self, *args, **kwargs): 20 | return [] 21 | 22 | 23 | ModelReturns = namedtuple( 24 | 'ModelReturns', 'pointclouds mask_pred sdf_freespace sdf_occupancy img_pred mask_img_pred') 25 | 26 | from .implicit_modeling import Model as ImplicitModel 27 | from .point_modeling import Model as PointModel 28 | from .combined_modeling import Model as CombinedModel 29 | -------------------------------------------------------------------------------- /DSS/training/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | change trainer settings according to iterations 3 | """ 4 | from typing import List 5 | import bisect 6 | from .. 
import logger_py 7 | from ..models.levelset_sampling import LevelSetProjection 8 | 9 | 10 | class TrainerScheduler(object): 11 | """ Increase n_points_per_cloud and Reduce n_training_points """ 12 | 13 | def __init__(self, init_n_points_dss: int = 0, init_n_rays: int = 0, 14 | init_proj_tolerance: float = 0, 15 | init_sdf_alpha: float=1.0, 16 | init_lambda_occupied: float=5.0, init_lambda_freespace: float=5.0, 17 | steps_n_points_dss: int = -1, steps_n_rays: int= -1, 18 | steps_proj_tolerance: int = -1, steps_lambda_rgb: int = -1, 19 | steps_sdf_alpha:int = -1, 20 | steps_lambda_sdf: int = -1, 21 | init_lambda_rgb: float = 1.0, warm_up_iters: int = None, 22 | gamma_n_points_dss: float = 2.0, gamma_n_rays: float = 0.5, 23 | gamma_lambda_rgb: float = 1.0, 24 | gamma_sdf_alpha: float = 1.0, 25 | gamma_lambda_sdf: float = 1.0, 26 | limit_n_points_dss: int = 1e5, limit_n_rays: int = 0, 27 | limit_lambda_rgb: float = 1.0, 28 | limit_sdf_alpha: float=100, 29 | limit_lambda_freespace: float=1.0, limit_lambda_occupied=1.0, 30 | gamma_proj_tolerance: float = 0.1, limit_proj_tolerance: float = 5e-5, 31 | ): 32 | """ steps_n_points_dss: list """ 33 | 34 | self.init_n_points_dss = init_n_points_dss 35 | self.init_n_rays = init_n_rays 36 | self.init_proj_tolerance = init_proj_tolerance 37 | self.init_sdf_alpha = init_sdf_alpha 38 | self.init_lambda_rgb = init_lambda_rgb 39 | self.init_lambda_freespace = init_lambda_freespace 40 | self.init_lambda_occupied = init_lambda_occupied 41 | 42 | self.steps_n_points_dss = steps_n_points_dss 43 | self.steps_n_rays = steps_n_rays 44 | self.steps_proj_tolerance = steps_proj_tolerance 45 | self.steps_lambda_rgb = steps_lambda_rgb 46 | self.steps_sdf_alpha = steps_sdf_alpha 47 | self.steps_lambda_sdf = steps_lambda_sdf 48 | 49 | self.gamma_n_points_dss = gamma_n_points_dss 50 | self.gamma_n_rays = gamma_n_rays 51 | self.gamma_proj_tolerance = gamma_proj_tolerance 52 | self.gamma_lambda_rgb = gamma_lambda_rgb 53 | self.gamma_sdf_alpha = gamma_sdf_alpha 54 | self.gamma_lambda_sdf = gamma_lambda_sdf 55 | 56 | self.limit_n_points_dss = limit_n_points_dss 57 | self.limit_n_rays = limit_n_rays 58 | self.limit_proj_tolerance = limit_proj_tolerance 59 | self.limit_sdf_alpha = limit_sdf_alpha 60 | self.limit_lambda_freespace = limit_lambda_freespace 61 | self.limit_lambda_occupied = limit_lambda_occupied 62 | self.limit_lambda_rgb = limit_lambda_rgb 63 | 64 | self.warm_up_iters = warm_up_iters 65 | 66 | def step(self, trainer, it): 67 | if it < 0: 68 | return 69 | 70 | if self.steps_n_points_dss > 0 and hasattr(trainer.model, 'n_points_per_cloud'): 71 | i = it // self.steps_n_points_dss 72 | gamma = self.gamma_n_points_dss ** i 73 | old_n_points_per_cloud = trainer.model.n_points_per_cloud 74 | trainer.model.n_points_per_cloud = min( 75 | int(self.init_n_points_dss * gamma), self.limit_n_points_dss) 76 | if old_n_points_per_cloud != trainer.model.n_points_per_cloud: 77 | logger_py.info('Updated n_points_per_cloud: {} -> {}'.format( 78 | old_n_points_per_cloud, trainer.model.n_points_per_cloud)) 79 | 80 | if self.steps_n_rays > 0: 81 | # reduce n_rays gradually 82 | i = it // self.steps_n_rays 83 | gamma = self.gamma_n_rays ** i 84 | old_n_rays = trainer.n_training_points 85 | if self.gamma_n_rays < 1: 86 | trainer.n_training_points = max( 87 | int(self.init_n_rays * gamma), self.limit_n_rays) 88 | else: 89 | trainer.n_training_points = min( 90 | int(self.init_n_rays * gamma), self.limit_n_rays) 91 | if old_n_rays != trainer.n_training_points: 92 | logger_py.info('Updated 
n_training_points: {} -> {}'.format( 93 | old_n_rays, trainer.n_training_points)) 94 | 95 | # adjust projection tolerance and proj_max_iters 96 | if self.steps_proj_tolerance > 0 and it % self.steps_proj_tolerance == 0: 97 | if hasattr(trainer.model, 'projection') and isinstance(trainer.model.projection, LevelSetProjection): 98 | old_proj_tol = trainer.model.projection.proj_tolerance 99 | i = it // self.steps_proj_tolerance 100 | gamma = self.gamma_proj_tolerance ** i 101 | trainer.model.projection.proj_tolerance = max(self.init_proj_tolerance*gamma, self.limit_proj_tolerance) 102 | if old_proj_tol != trainer.model.projection.proj_tolerance: 103 | logger_py.info('Updated projection.proj_tolerance: {} -> {}'.format( 104 | old_proj_tol, trainer.model.projection.proj_tolerance)) 105 | trainer.model.projection.proj_max_iters = min(trainer.model.projection.proj_max_iters *2, 50) 106 | if hasattr(trainer.model, 'sphere_tracing') and isinstance(trainer.model.sphere_tracing, LevelSetProjection): 107 | old_proj_tol = trainer.model.sphere_tracing.proj_tolerance 108 | trainer.model.sphere_tracing.proj_tolerance = max(self.init_proj_tolerance*gamma, self.limit_proj_tolerance) 109 | if old_proj_tol != trainer.model.sphere_tracing.proj_tolerance: 110 | logger_py.info('Updated sphere_tracing.proj_tolerance: {} -> {}'.format( 111 | old_proj_tol, trainer.model.sphere_tracing.proj_tolerance)) 112 | trainer.model.sphere_tracing.proj_max_iters = min(trainer.model.sphere_tracing.proj_max_iters *2, 50) 113 | 114 | 115 | # increase lambda_rgb slowly 116 | if self.steps_lambda_rgb > 0: 117 | # also change the init_lambda_rgb 118 | old_lambda_rgb = trainer.lambda_rgb 119 | trainer.lambda_rgb = self.init_lambda_rgb * self.gamma_lambda_rgb ** ( 120 | it // self.steps_lambda_rgb) 121 | trainer.lambda_rgb = min(trainer.lambda_rgb, self.limit_lambda_rgb) 122 | if old_lambda_rgb != trainer.lambda_rgb: 123 | logger_py.info('Updated lambda_rgb: {} -> {}'.format( 124 | old_lambda_rgb, trainer.lambda_rgb)) 125 | 126 | # update init_lambda_occupied and init_lambda_freespace 127 | if self.steps_lambda_sdf > 0: 128 | old_lambda = trainer.lambda_freespace 129 | scale = self.gamma_lambda_sdf ** (it // self.steps_lambda_sdf) 130 | trainer.lambda_freespace = self.init_lambda_freespace * scale 131 | if self.gamma_lambda_sdf < 1.0 and self.limit_lambda_freespace < self.init_lambda_freespace: 132 | trainer.lambda_freespace = max(trainer.lambda_freespace, self.limit_lambda_freespace) 133 | else: 134 | trainer.lambda_freespace = min(trainer.lambda_freespace, self.limit_lambda_freespace) 135 | 136 | if old_lambda != trainer.lambda_freespace: 137 | logger_py.info('Updated lambda_freespace: {} -> {}'.format( 138 | old_lambda, trainer.lambda_freespace)) 139 | 140 | old_lambda = trainer.lambda_occupied 141 | scale = self.gamma_lambda_sdf ** (it // self.steps_lambda_sdf) 142 | trainer.lambda_occupied = self.init_lambda_occupied * scale 143 | if self.gamma_lambda_sdf < 1.0 and self.limit_lambda_occupied < self.init_lambda_occupied: 144 | trainer.lambda_occupied = max(trainer.lambda_occupied, self.limit_lambda_occupied) 145 | else: 146 | trainer.lambda_occupied = min(trainer.lambda_occupied, self.limit_lambda_occupied) 147 | 148 | if old_lambda != trainer.lambda_occupied: 149 | logger_py.info('Updated lambda_occupied: {} -> {}'.format( 150 | old_lambda, trainer.lambda_occupied)) 151 | 152 | if self.steps_sdf_alpha > 0: 153 | # change sdf loss weight gradually 154 | i = it // self.steps_sdf_alpha 155 | gamma = self.gamma_sdf_alpha ** i 156 | 
old_alpha = trainer.sdf_alpha 157 | if self.gamma_sdf_alpha < 1: 158 | trainer.sdf_alpha = max( 159 | int(self.init_sdf_alpha * gamma), self.limit_sdf_alpha) 160 | else: 161 | trainer.sdf_alpha = min( 162 | int(self.init_sdf_alpha * gamma), self.limit_sdf_alpha) 163 | if old_alpha != trainer.sdf_alpha: 164 | logger_py.info('Updated sdf_alpha: {} -> {}'.format( 165 | old_alpha, trainer.sdf_alpha)) 166 | -------------------------------------------------------------------------------- /DSS/utils/io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import plyfile 4 | import numpy as np 5 | from matplotlib import cm 6 | import matplotlib.colors as mpc 7 | 8 | 9 | def saveDebugPNG(projPoint, imgTensor, savePath): 10 | """ 11 | save imgTensor to PNG, highlight the projPoint with grid lines 12 | params: 13 | projPoint (1, 2) 14 | imgTensor (H,W,3or1) torch.Tensor or numpy.array 15 | """ 16 | import matplotlib.pyplot as plt 17 | # normalize imgTensor 18 | plt.clf() 19 | cmin = imgTensor.min() 20 | cmax = imgTensor.max() 21 | imgTensor = (imgTensor - cmin) / (cmax - cmin) 22 | imgTensor[np.isnan(imgTensor) != False] = 0.0 23 | if imgTensor.ndim == 2 or (imgTensor.ndim == 3 and imgTensor.shape[-1] == 1): 24 | plt.imshow(imgTensor, cmap='gray') 25 | else: 26 | plt.imshow(imgTensor) 27 | i, j = projPoint.flatten()[:] 28 | plt.scatter(i, j, facecolors='none', edgecolors="cyan") 29 | plt.axvline(x=i, color='red') 30 | plt.axhline(y=j, color='red') 31 | plt.savefig(savePath) 32 | 33 | 34 | def encodeFlow(flowTensor: torch.Tensor, logScale=False): 35 | """ 36 | encode the vector field to a colored image 37 | :params 38 | flowTensor: (H,W,2) 39 | :return 40 | rgb: (H,W,3) numpy array floating type 41 | """ 42 | h, w = flowTensor.shape[:2] 43 | rho, phi = cart2pol(flowTensor[:, :, 0], flowTensor[:, :, 1]) 44 | rmin, rmax = rho.min(), rho.max() 45 | rho = (rho - rmin) / (rmax - rmin) * 255 46 | if logScale: 47 | rho = torch.log(1 + rho) 48 | rho[np.isnan(rho) != False] = 0.0 49 | hsv = np.full((h, w, 3), 255, dtype=np.uint8) 50 | hsv[..., 0] = phi * 255 / 2 / np.pi 51 | hsv[..., 2] = rho 52 | from skimage.color import hsv2rgb 53 | rgb = hsv2rgb(hsv) 54 | return rgb 55 | 56 | 57 | def cart2pol(x, y): 58 | """ 59 | cartesian coordinates to polar coordinates 60 | return: 61 | rho: length 62 | phi: (, 2pi) 63 | """ 64 | rho = (x**2 + y**2).sqrt() 65 | phi = np.arctan2(y, x) + np.pi 66 | return (rho, phi) 67 | 68 | 69 | def pol2cart(rho, phi): 70 | """ polar to cartesian """ 71 | x = rho * phi.cos() 72 | y = rho * phi.sin() 73 | return (x, y) 74 | 75 | 76 | def read_ply(file): 77 | loaded = plyfile.PlyData.read(file) 78 | points = np.vstack([loaded['vertex'].data['x'], 79 | loaded['vertex'].data['y'], loaded['vertex'].data['z']]) 80 | if 'nx' in loaded['vertex'].data.dtype.names: 81 | normals = np.vstack([loaded['vertex'].data['nx'], 82 | loaded['vertex'].data['ny'], loaded['vertex'].data['nz']]) 83 | points = np.concatenate([points, normals], axis=0) 84 | 85 | points = points.transpose(1, 0) 86 | return points 87 | 88 | 89 | def save_ply(filename, points, colors=None, normals=None, binary=True): 90 | """ 91 | save 3D/2D points to ply file 92 | Args: 93 | points (numpy array): (N,2or3) 94 | colors (numpy uint8 array): (N, 3or4) 95 | """ 96 | assert(points.ndim == 2) 97 | if points.shape[-1] == 2: 98 | points = np.concatenate( 99 | [points, np.zeros_like(points)[:, :1]], axis=-1) 100 | 101 | vertex = np.core.records.fromarrays(points.transpose( 
102 | 1, 0), names='x, y, z', formats='f4, f4, f4') 103 | num_vertex = len(vertex) 104 | desc = vertex.dtype.descr 105 | 106 | if normals is not None: 107 | assert(normals.ndim == 2) 108 | if normals.shape[-1] == 2: 109 | normals = np.concatenate( 110 | [normals, np.zeros_like(normals)[:, :1]], axis=-1) 111 | vertex_normal = np.core.records.fromarrays( 112 | normals.transpose(1, 0), names='nx, ny, nz', formats='f4, f4, f4') 113 | assert len(vertex_normal) == num_vertex 114 | desc = desc + vertex_normal.dtype.descr 115 | 116 | if colors is not None: 117 | assert len(colors) == num_vertex 118 | if colors.max() <= 1: 119 | colors = colors * 255 120 | if colors.shape[1] == 4: 121 | vertex_color = np.core.records.fromarrays(colors.transpose( 122 | 1, 0), names='red, green, blue, alpha', formats='u1, u1, u1, u1') 123 | else: 124 | vertex_color = np.core.records.fromarrays(colors.transpose( 125 | 1, 0), names='red, green, blue', formats='u1, u1, u1') 126 | desc = desc + vertex_color.dtype.descr 127 | 128 | vertex_all = np.empty(num_vertex, dtype=desc) 129 | 130 | for prop in vertex.dtype.names: 131 | vertex_all[prop] = vertex[prop] 132 | 133 | if normals is not None: 134 | for prop in vertex_normal.dtype.names: 135 | vertex_all[prop] = vertex_normal[prop] 136 | 137 | if colors is not None: 138 | for prop in vertex_color.dtype.names: 139 | vertex_all[prop] = vertex_color[prop] 140 | 141 | ply = plyfile.PlyData( 142 | [plyfile.PlyElement.describe(vertex_all, 'vertex')], text=(not binary)) 143 | if not os.path.exists(os.path.dirname(filename)): 144 | os.makedirs(os.path.dirname(filename)) 145 | ply.write(filename) 146 | 147 | 148 | def save_ply_property(filename, points, property, 149 | property_max=None, property_min=None, 150 | normals=None, cmap_name='Set1', binary=True): 151 | point_num = points.shape[0] 152 | colors = np.full([point_num, 3], 0.5) 153 | cmap = cm.get_cmap(cmap_name) 154 | if property_max is None: 155 | property_max = np.amax(property, axis=0) 156 | if property_min is None: 157 | property_min = np.amin(property, axis=0) 158 | p_range = property_max - property_min 159 | if property_max == property_min: 160 | property_max = property_min + 1 161 | normalizer = mpc.Normalize(vmin=property_min, vmax=property_max) 162 | p = normalizer(property) 163 | colors = cmap(p)[:, :3] 164 | save_ply(filename, points, colors, normals, binary) 165 | -------------------------------------------------------------------------------- /DSS/utils/mathHelper.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Optional 2 | import numpy as np 3 | import torch 4 | from torch_batch_svd import svd as batch_svd 5 | import scipy.sparse as sp 6 | from scipy.sparse.linalg import spsolve 7 | from pytorch3d.structures import Pointclouds 8 | from pytorch3d.loss.mesh_laplacian_smoothing import laplacian_cot 9 | from pytorch3d.ops.utils import convert_pointclouds_to_tensor 10 | from pytorch3d.ops import knn_points 11 | from pytorch3d.ops.points_normals import _disambiguate_vector_directions 12 | 13 | 14 | def eps_denom(denom, eps=1e-17): 15 | """ Prepare denominator for division """ 16 | denom_sign = denom.sign() + (denom == 0.0).type_as(denom) 17 | denom = denom_sign * torch.clamp(denom.abs(), eps) 18 | return denom 19 | 20 | def eps_sqrt(squared, eps=1e-17): 21 | """ 22 | Prepare for the input for sqrt, make sure the input positive and 23 | larger than eps 24 | """ 25 | return torch.clamp(squared.abs(), eps) 26 | 27 | 28 | def pinverse(inputs: 
torch.Tensor): 29 | assert(inputs.ndim >= 2) 30 | shp = inputs.shape 31 | U, S, V = batch_svd(inputs.view(-1, shp[-2], shp[-1])) 32 | S[S < 1e-6] = 0 33 | S_inv = torch.where(S < 1e-5, torch.zeros_like(S), 1/S) 34 | pinv = V @ torch.diag_embed(S_inv) @ U.transpose(1,2) 35 | return pinv.view(shp) 36 | 37 | def clip_norm(x, dim=-1, max_value=1.0): 38 | """ clip norm in the given dimension """ 39 | x_norm = torch.norm(x, dim=dim) 40 | factor = torch.where(x_norm > max_value, max_value / x_norm, torch.ones_like(x_norm)) 41 | return x * factor.unsqueeze(dim=dim) 42 | 43 | def estimate_pointcloud_local_coord_frames( 44 | pointclouds: Union[torch.Tensor, Pointclouds], 45 | neighborhood_size: int = 50, 46 | disambiguate_directions: bool = True, 47 | return_knn_result: bool = False, 48 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional['KNN']]: 49 | """ 50 | Faster version of pytorch3d estimate_pointcloud_local_coord_frames 51 | 52 | Estimates the principal directions of curvature (which includes normals) 53 | of a batch of `pointclouds`. 54 | Returns: 55 | curvatures (N,P,3) ascending order 56 | local_frames (N,P,3,3) corresponding eigenvectors 57 | """ 58 | points_padded, num_points = convert_pointclouds_to_tensor(pointclouds) 59 | 60 | ba, N, dim = points_padded.shape 61 | if dim != 3: 62 | raise ValueError( 63 | "The pointclouds argument has to be of shape (minibatch, N, 3)" 64 | ) 65 | 66 | if (num_points <= neighborhood_size).any(): 67 | raise ValueError( 68 | "The neighborhood_size argument has to be" 69 | + " >= size of each of the point clouds." 70 | ) 71 | # undo global mean for stability 72 | # TODO: replace with tutil.wmean once landed 73 | pcl_mean = points_padded.sum(1) / num_points[:, None] 74 | points_centered = points_padded - pcl_mean[:, None, :] 75 | 76 | # get K nearest neighbor idx for each point in the point cloud 77 | knn_result = knn_points( 78 | points_padded, 79 | points_padded, 80 | lengths1=num_points, 81 | lengths2=num_points, 82 | K=neighborhood_size, 83 | return_nn=True, 84 | ) 85 | k_nearest_neighbors = knn_result.knn 86 | # obtain the mean of the neighborhood 87 | pt_mean = k_nearest_neighbors.mean(2, keepdim=True) 88 | # compute the diff of the neighborhood and the mean of the neighborhood 89 | # N,P,K,3 90 | central_diff = k_nearest_neighbors - pt_mean 91 | per_pts_diff = central_diff.view(-1, neighborhood_size, 3) 92 | # S (NP,3) and local_coord_framds (NP,3,3) 93 | _, S, local_coord_frames = batch_svd(per_pts_diff) 94 | curvature = S * S / neighborhood_size 95 | local_coord_frames = local_coord_frames.view(ba, N, dim, dim) 96 | curvature = curvature.view(ba, N, dim) 97 | 98 | # flip to ascending order 99 | curvature = curvature.flip(-1) 100 | local_coord_frames = local_coord_frames.flip(-1) 101 | 102 | # disambiguate the directions of individual principal vectors 103 | if disambiguate_directions: 104 | # disambiguate normal 105 | n = _disambiguate_vector_directions( 106 | points_centered, k_nearest_neighbors, local_coord_frames[:, :, :, 0] 107 | ) 108 | # disambiguate the main curvature 109 | z = _disambiguate_vector_directions( 110 | points_centered, k_nearest_neighbors, local_coord_frames[:, :, :, 2] 111 | ) 112 | # the secondary curvature is just a cross between n and z 113 | y = torch.cross(n, z, dim=2) 114 | # cat to form the set of principal directions 115 | local_coord_frames = torch.stack((n, y, z), dim=3) 116 | 117 | if return_knn_result: 118 | return curvature, local_coord_frames, knn_result 119 | return curvature, local_coord_frames 120 | 121 | 122 | 
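*Editor's note — usage sketch for the estimator defined above (illustration only; the random input, variable names, and the surface-variation score are mine, not part of mathHelper.py).* The function returns per-point eigenvalues in ascending order and the corresponding local frames; a CUDA tensor is assumed because `torch_batch_svd` runs only on the GPU.

```python
import torch

# Feed a random point cloud through the estimator and derive common quantities.
pts = torch.randn(2, 2048, 3, device='cuda')                    # (minibatch, num_points, 3)
curv, frames = estimate_pointcloud_local_coord_frames(pts, neighborhood_size=50)
normals = frames[..., 0]                                         # smallest-eigenvalue direction
surface_variation = curv[..., 0] / curv.sum(dim=-1).clamp_min(1e-10)  # in [0, 1/3]
print(curv.shape, frames.shape, normals.shape)                   # (2,2048,3) (2,2048,3,3) (2,2048,3)
```

The surface-variation score is one common way to weight samples toward high-curvature regions; it is shown here only to make the meaning of the returned tensors concrete.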
def estimate_pointcloud_normals( 123 | pointclouds: Union[torch.Tensor, Pointclouds], 124 | neighborhood_size: int = 50, 125 | disambiguate_directions: bool = True, 126 | ) -> torch.Tensor: 127 | """ 128 | Estimates the normals of a batch of `pointclouds` using fast `estimate_pointcloud_local_coord_frames 129 | 130 | Args: 131 | **pointclouds**: Batch of 3-dimensional points of shape 132 | `(minibatch, num_point, 3)` or a `Pointclouds` object. 133 | **neighborhood_size**: The size of the neighborhood used to estimate the 134 | geometry around each point. 135 | **disambiguate_directions**: If `True`, uses the algorithm from [1] to 136 | ensure sign consistency of the normals of neigboring points. 137 | 138 | Returns: 139 | **normals**: A tensor of normals for each input point 140 | of shape `(minibatch, num_point, 3)`. 141 | If `pointclouds` are of `Pointclouds` class, returns a padded tensor. 142 | 143 | References: 144 | [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for 145 | Local Surface Description, ECCV 2010. 146 | """ 147 | curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames( 148 | pointclouds, 149 | neighborhood_size=neighborhood_size, 150 | disambiguate_directions=disambiguate_directions, 151 | ) 152 | 153 | # the normals correspond to the first vector of each local coord frame 154 | normals = local_coord_frames[:, :, :, 0] 155 | 156 | return normals 157 | 158 | 159 | def ndc_to_pix(p, resolution): 160 | """ 161 | Reverse of pytorch3d pix_to_ndc function 162 | Args: 163 | p (float tensor): (..., 3) 164 | resolution (scalar): image resolution (for now, supports only aspectratio = 1) 165 | Returns: 166 | pix (long tensor): (..., 2) 167 | """ 168 | pix = resolution - ((p[..., :2] + 1.0) * resolution - 1.0) / 2 169 | return pix 170 | 171 | 172 | def decompose_to_R_and_t(transform_mat, row_major=True): 173 | """ decompose a 4x4 transform matrix to R (3,3) and t (1,3)""" 174 | assert(transform_mat.shape[-2:] == (4, 4)), \ 175 | "Expecting batches of 4x4 matrice" 176 | # ... 
3x3 177 | if not row_major: 178 | transform_mat = transform_mat.transpose(-2, -1) 179 | 180 | R = transform_mat[..., :3, :3] 181 | t = transform_mat[..., -1, :3] 182 | 183 | return R, t 184 | 185 | 186 | def to_homogen(x, dim=-1): 187 | """ append one to the specified dimension """ 188 | if dim < 0: 189 | dim = x.ndim + dim 190 | shp = x.shape 191 | new_shp = shp[:dim] + (1, ) + shp[dim + 1:] 192 | x_homogen = x.new_ones(new_shp) 193 | x_homogen = torch.cat([x, x_homogen], dim=dim) 194 | return x_homogen 195 | 196 | 197 | def vectors_to_angles(x, y, z): 198 | """ Returns azim [0, 2pi), elev (-pi/2, pi/2) """ 199 | # x = dist * torch.cos(elev) * torch.sin(azim) 200 | # y = dist * torch.sin(elev) 201 | # z = dist * torch.cos(elev) * torch.cos(azim) 202 | azim = torch.atan2(x, z) 203 | dist = torch.sqrt(eps_sqrt(x * x + y * y + z * z)) 204 | elev = torch.asin(y / eps_denom(dist)) 205 | return azim, elev, dist 206 | 207 | 208 | def angles_to_vectors(azim, elev, dist): 209 | """ Returns the rotation matrix from azim, elev and dist """ 210 | x = torch.cos(elev) * torch.sin(azim) 211 | y = torch.sin(elev) 212 | z = torch.cos(elev) * torch.cos(azim) 213 | vec = torch.stack([x, y, z], dim=-1) 214 | return vec 215 | 216 | 217 | def cot_laplacian_matrix(meshes: 'Meshes', smooth=False): 218 | """ 219 | Returns a scipy sparse matrix 220 | """ 221 | L, inv_areas = laplacian_cot(meshes) 222 | L = L.coalesce() 223 | L_sp = sp.coo_matrix((L._values(), (L._indices()[1], L._indices()[0])), L.shape).tocsr() 224 | norm_w = 1.0/L_sp.sum(axis=1).squeeze(1) 225 | L_sp_final = sp.spdiags(norm_w, 0, L_sp.shape[0], L_sp.shape[1]).dot(L_sp) - sp.identity(L_sp.shape[0]) 226 | return L_sp_final 227 | 228 | def smooth_mesh_curvature(meshes, iters=10, alpha=0.5): 229 | L_sp_final = cot_laplacian_matrix(meshes) 230 | vertices = meshes.verts_packed().cpu().detach().numpy() 231 | curvatures = np.linalg.norm(L_sp_final.dot(vertices), axis=-1) 232 | 233 | for it in range(iters): 234 | curvatures = spsolve((sp.identity(vertices.shape[0]) - alpha*L_sp_final), curvatures) 235 | 236 | curvatures = torch.from_numpy(curvatures, device=meshes.device, dtype=vertices.dtype) 237 | return curvatures 238 | 239 | 240 | class RunningStat(object): 241 | """ 242 | Running Mean and Variance 243 | """ 244 | def __init__(self, n_values, device='cpu'): 245 | self._counter = torch.zeros(n_values, dtype=torch.float).to(device=device) 246 | self._mean = torch.zeros(n_values, dtype=torch.float).to(device=device) 247 | self._variance = torch.zeros(n_values, dtype=torch.float).to(device=device) 248 | 249 | def add(self, values, mask=None): 250 | with torch.autograd.no_grad(): 251 | if mask is not None: 252 | self._counter += mask.float() 253 | # m_newS = m_oldS + (x - m_oldM)*(x - m_newM) 254 | new_mean = self._mean[mask] + (values[mask] - self._mean[mask])/self._counter[mask] 255 | self._variance[mask] = self._variance[mask] + (values[mask] - new_mean)*(values[mask]-self._mean[mask]) 256 | self._mean[mask] = new_mean 257 | else: 258 | self._counter += 1 259 | new_mean = self._mean + (values - self._mean)/self._counter 260 | self._variance = self._variance + (values - self._mean)*(values-new_mean) 261 | self._mean = new_mean 262 | 263 | if self._variance.isinf().any(): 264 | __import__('pdb').set_trace() 265 | 266 | def mean(self): 267 | return self._mean 268 | 269 | def variance(self): 270 | return self._variance/torch.where(self._counter>1, self._counter - 1, torch.ones_like(self._counter)) 271 | 272 | def std(self): 273 | return 
torch.where(self.variance() > 0, self.variance().sqrt(), torch.zeros_like(self._variance))
274 |
--------------------------------------------------------------------------------
/DSS/utils/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._six import int_classes as _int_classes
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class WeightedSubsetRandomSampler(Sampler):
7 |     r"""Samples elements from a given list of indices with given probabilities (weights), with replacement.
8 |
9 |     Arguments:
10 |         weights (sequence) : a sequence of weights, not necessarily summing up to one
11 |         num_samples (int): number of samples to draw
12 |     """
13 |
14 |     def __init__(self, indices, weights, num_samples=0):
15 |         if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool):
16 |             raise ValueError("num_samples should be a non-negative integer "
17 |                              "value, but got num_samples={}".format(num_samples))
18 |         self.indices = indices
19 |         weights = [weights[i] for i in self.indices]
20 |         self.weights = torch.tensor(weights, dtype=torch.double)
21 |         if num_samples == 0:
22 |             self.num_samples = len(self.weights)
23 |         else:
24 |             self.num_samples = num_samples
25 |         self.replacement = True
26 |
27 |     def __iter__(self):
28 |         return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, self.replacement))
29 |
30 |     def __len__(self):
31 |         return self.num_samples
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implementation for Iso-Points (CVPR 2021)
2 |
3 | *Official code for paper Iso-Points: Optimizing Neural Implicit Surfaces with Hybrid Representations*
4 |
5 | [paper](https://igl.ethz.ch/projects/iso-points/iso_points-CVPR2021-yifan.pdf) | [supplementary material](https://igl.ethz.ch/projects/iso-points/iso_points-supp-CVPR2021-yifan.pdf) | [project page](https://igl.ethz.ch/projects/iso-points/)
6 |
7 |
8 |
9 | ## Overview
10 | ***Iso-points*** are well-distributed points which lie on the neural iso-surface; they are an explicit representation of the implicit surface. We propose using iso-points to augment the optimization of implicit neural surfaces.
11 | The implicit and explicit surface representations are coupled, i.e. the implicit model determines the locations and normals of iso-points, whereas the iso-points can be utilized to control the optimization of the implicit model.
12 |
13 | ***The implementation*** of the key steps for iso-points extraction is in `levelset_sampling.py` and `utils/point_processing.py`.
14 | To demonstrate the utilisation of *iso-points*, we provide scripts for multiple applications and scenarios:
15 | - [multiview reconstruction](#multiview-reconstruction)
16 | - [Demo 1](#sampling-with-iso-points): importance sampling with iso-points
17 | - [Demo 2](#DTU-data): [Multiview Neural Surface Reconstruction][IDR] on [DTU] data with iso-points
18 | - [surface reconstruction from sparse point cloud](#implicit-surface-to-noisy-point-cloud)
19 |
20 | ## Demo
21 | ### Installation
22 | This code is built as an extension of our Differentiable Surface Splatting pytorch library ([DSS](https://github.com/yifita/dss)), which depends on [pytorch3d](https://github.com/facebookresearch/pytorch3d), [torch_cluster](https://github.com/rusty1s/pytorch_cluster).
23 | **Currently we support up to pytorch 1.6**.
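*Editor's note — before building the CUDA extensions below, it can help to confirm that the local PyTorch install matches the stated support range. The snippet is illustrative only; the repo does not ship such a check, and the only pinned constraint it states is "up to pytorch 1.6". `FRNN` and `torch-batch-svd` (built in the next step) require a CUDA-enabled PyTorch.*

```python
import torch
print(torch.__version__)          # the repo states support up to 1.6.x
print(torch.version.cuda)         # CUDA toolkit version PyTorch was built against
print(torch.cuda.is_available())  # must be True to build/run FRNN and torch-batch-svd
```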
24 | 25 | ````bash 26 | git clone --recursive https://github.com/yifita/iso-points.git 27 | cd iso-points 28 | 29 | # conda environment and dependencies 30 | # update conda 31 | conda update -n base -c defaults conda 32 | # install requirements 33 | conda env create --name DSS -f environment.yml 34 | conda activate DSS 35 | 36 | # build additional dependencies of DSS 37 | # FRNN - fixed radius nearest neighbors 38 | cd external/FRNN/external 39 | git submodule update --init --recursive 40 | cd prefix_sum 41 | python setup.py install 42 | cd ../.. 43 | python setup.py install 44 | 45 | # build batch-svd 46 | cd ../torch-batch-svd 47 | python setup.py install 48 | 49 | # build DSS itself 50 | cd ../.. 51 | python setup.py develop 52 | ```` 53 | 54 | ### prepare data 55 | Download data 56 | ```bash 57 | cd data 58 | wget https://igl.ethz.ch/projects/iso-points/data.zip 59 | unzip data.zip 60 | rm data.zip 61 | ``` 62 | Including subset of masked [DTU data](https://www.dropbox.com/sh/5tam07ai8ch90pf/AADniBT3dmAexvm_J1oL__uoa) (courtesy of Yariv et.al.), synthetic rendered multiview data, and masked furu stereo reconstruction of DTU dataset. 63 | 64 | ### multiview reconstruction 65 | #### sampling-with-iso-points 66 | ```bash 67 | # train baseline implicit representation only using ray-tracing 68 | python train_mvr.py configs/compressor_implicit.yml --exit-after 6000 69 | 70 | # train with uniform iso-points 71 | python train_mvr.py configs/compressor_uni.yml --exit-after 6000 72 | 73 | # train with iso-points distributed according to loss value (hard example mining) 74 | python train_mvr.py configs/compressor_uni_lossS.yml --exit-after 6000 75 | ``` 76 | sampling result 77 | 78 | ### DTU-data 79 | ```bash 80 | python train_mvr.py configs/dtu55_iso.yml 81 | ``` 82 | dtu mvr result 83 | 84 | ### implicit surface to noisy point cloud 85 | ```bash 86 | python test_dtu_points.py data/DTU_furu/scan122.ply --use_off_normal_loss -o exp/points_3d_outputs/scan122_ours 87 | ``` 88 | 89 | 90 | 91 | ## cite 92 | Please cite us if you find the code useful! 93 | ``` 94 | @inproceedings{yifan2020isopoints, 95 | title={Iso-Points: Optimizing Neural Implicit Surfaces with Hybrid Representations}, 96 | author={Wang Yifan and Shihao Wu and Cengiz Oztireli and Olga Sorkine-Hornung}, 97 | year={2020}, 98 | booktitle = {CVPR}, 99 | year = {2020}, 100 | } 101 | ``` 102 | 103 | ## Acknowledgement 104 | This work was supported in parts by Apple scholarship, SWISSHEART Failure Network (SHFN), and UKRI Future Leaders Fellowship [grant number MR/T043229/1] 105 | 106 | A good portion of this codebase uses or adapts codes from previous works and implementations. We sincerely thank the authors for their effort in making their work accessible. 
107 | Most notably we refer to the following repos 108 | - Siren: https://github.com/vsitzmann/siren 109 | - IDR: https://github.com/lioryariv/idr 110 | - DVR: https://github.com/autonomousvision/differentiable_volumetric_rendering 111 | 112 | [IDR]: https://github.com/lioryariv/idr 113 | [DTU]: http://roboimagedata.compute.dtu.dk/?page_id=36 114 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import random 5 | from DSS.misc.visualize import animate_points, animate_mesh, figures_to_html 6 | from DSS import logger_py 7 | 8 | 9 | def create_animation(pts_dir, show_max=-1): 10 | figs = [] 11 | # points 12 | pts_files = [f for f in os.listdir(pts_dir) if 'pts' in f and f[-4:].lower() in ('.ply', 'obj')] 13 | if len(pts_files) == 0: 14 | logger_py.info("Couldn't find '*pts*' files in {}".format(pts_dir)) 15 | else: 16 | pts_files.sort() 17 | if show_max > 0: 18 | pts_files = pts_files[::max(len(pts_files) // show_max, 1)] 19 | pts_names = list(map(lambda x: os.path.basename(x) 20 | [:-4].split('_')[0], pts_files)) 21 | pts_paths = [os.path.join(pts_dir, fname) for fname in pts_files] 22 | fig = animate_points(pts_paths, pts_names) 23 | figs.append(fig) 24 | # mesh 25 | mesh_files = [f for f in os.listdir(pts_dir) if 'mesh' in f and f[-4:].lower() in ('.ply', '.obj')] 26 | # mesh_files = list(filter(lambda x: x.split('_') 27 | # [1] == '000.obj', mesh_files)) 28 | if len(mesh_files) == 0: 29 | logger_py.info("Couldn't find '*mesh*' files in {}".format(pts_dir)) 30 | else: 31 | mesh_files.sort() 32 | if show_max > 0: 33 | mesh_files = mesh_files[::max(len(mesh_files) // show_max, 1)] 34 | mesh_names = list(map(lambda x: os.path.basename(x) 35 | [:-4].split('_')[0], mesh_files)) 36 | mesh_paths = [os.path.join(pts_dir, fname) for fname in mesh_files] 37 | fig = animate_mesh(mesh_paths, mesh_names) 38 | figs.append(fig) 39 | 40 | save_html = os.path.join(pts_dir, 'animation.html') 41 | os.makedirs(os.path.dirname(save_html), exist_ok=True) 42 | figures_to_html(figs, save_html) 43 | 44 | 45 | def get_tri_color_lights_for_view(cams, has_specular=False, point_lights=True): 46 | """ 47 | Create RGB lights direction in the half dome 48 | The direction is given in the same coordinates as the pointcloud 49 | Args: 50 | cams 51 | Returns: 52 | Lights with three RGB light sources (B: right, G: left, R: bottom) 53 | """ 54 | import math 55 | from DSS.core.lighting import (DirectionalLights, PointLights) 56 | from pytorch3d.renderer.cameras import look_at_rotation 57 | from pytorch3d.transforms import Rotate 58 | 59 | elev = torch.tensor(((30, 30, 30),),device=cams.device) 60 | azim = torch.tensor(((-60, 60, 180),),device=cams.device) 61 | elev = math.pi / 180.0 * elev 62 | azim = math.pi / 180.0 * azim 63 | 64 | x = torch.cos(elev) * torch.sin(azim) 65 | y = torch.sin(elev) 66 | z = torch.cos(elev) * torch.cos(azim) 67 | light_directions = torch.stack([x, y, z], dim=-1) 68 | # import trimesh 69 | # import pdb; pdb.set_trace() 70 | # trimesh.Trimesh(vertices=light_directions[0].cpu().numpy(), process=False).export('tests/outputs/light_dir_pre.ply') 71 | # transform from y-up to z-up 72 | # transform from camera to world 73 | cam_pos = cams.get_camera_center() 74 | R = look_at_rotation(torch.zeros_like(cam_pos), at=F.normalize(torch.cross(cam_pos, torch.rand_like(cam_pos)), dim=-1), up=cam_pos) 75 | 
light_directions = Rotate(R=R.transpose(1,2), device=cams.device).transform_points(light_directions) 76 | # trimesh.Trimesh(vertices=torch.cat([cam_pos, light_directions[0]], dim=0).cpu().numpy(), process=False).export('tests/outputs/light_dir.ply') 77 | ambient_color = torch.FloatTensor((((0.2, 0.2, 0.2), ), )) 78 | diffuse_color = torch.FloatTensor( 79 | (((0.0, 0.0, 0.8), (0.0, 0.8, 0.0), (0.8, 0.0, 0.0), ), )) 80 | if has_specular: 81 | specular_color = 0.15 * diffuse_color 82 | diffuse_color *= 0.85 83 | else: 84 | specular_color = (((0, 0, 0), (0, 0, 0), (0, 0, 0), ), ) 85 | if not point_lights: 86 | lights = DirectionalLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 87 | specular_color=specular_color, direction=light_directions) 88 | else: 89 | location = light_directions*5 90 | lights = PointLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 91 | specular_color=specular_color, location=location) 92 | return lights 93 | 94 | def get_light_for_view(cams, has_specular): 95 | # create tri-color lights and a specular+diffuse shader 96 | ambient_color = torch.FloatTensor((((0.6, 0.6, 0.6),),)) 97 | diffuse_color = torch.FloatTensor( 98 | (((0.2, 0.2, 0.2),),)) 99 | 100 | if opt.has_specular: 101 | specular_color = 0.15 * diffuse_color 102 | diffuse_color *= 0.85 103 | else: 104 | specular_color = (((0, 0, 0),),) 105 | 106 | elev = torch.FloatTensor(((random.randint(10, 90),),)) 107 | azim = torch.FloatTensor(((random.randint(0, 360)),)) 108 | elev = math.pi / 180.0 * elev 109 | azim = math.pi / 180.0 * azim 110 | 111 | x = torch.cos(elev) * torch.sin(azim) 112 | y = torch.sin(elev) 113 | z = torch.cos(elev) * torch.cos(azim) 114 | light_directions = torch.stack([x, y, z], dim=-1) 115 | # transform from camera to world 116 | light_directions = cams.get_world_to_view_transform().inverse().transform_points(light_directions) 117 | if not point_lights: 118 | lights = DirectionalLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 119 | specular_color=specular_color, direction=light_directions) 120 | else: 121 | location = light_directions*5 122 | lights = PointLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 123 | specular_color=specular_color, location=location) 124 | return lights -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | from easydict import EasyDict as edict 4 | import torch 5 | import pytorch3d 6 | import trimesh 7 | import numpy as np 8 | from pytorch3d.utils import ico_sphere 9 | from pytorch3d.ops import sample_points_from_meshes 10 | from DSS.core.texture import LightingTexture, NeuralTexture 11 | from DSS.utils import get_class_from_string 12 | from DSS.training.trainer import Trainer 13 | from DSS import set_debugging_mode_ 14 | from DSS import logger_py 15 | 16 | 17 | # General config 18 | def load_config(path, default_path=None): 19 | ''' Loads config file. 
20 | 21 | Args: 22 | path (str): path to config file 23 | default_path (bool): whether to use default path 24 | ''' 25 | # Load configuration from file itself 26 | cfg_special = None 27 | with open(path, 'r') as f: 28 | cfg_special = edict(yaml.load(f, Loader=yaml.Loader)) 29 | 30 | # Check if we should inherit from a config 31 | inherit_from = cfg_special.get('inherit_from') 32 | 33 | # If yes, load this config first as default 34 | # If no, use the default_path 35 | if inherit_from is not None: 36 | cfg = load_config(inherit_from, default_path) 37 | elif default_path is not None: 38 | with open(default_path, 'r') as f: 39 | cfg = edict(yaml.load(f, Loader=yaml.Loader)) 40 | else: 41 | cfg = edict() 42 | 43 | # Include main configuration 44 | update_recursive(cfg, cfg_special) 45 | 46 | return cfg 47 | 48 | 49 | def save_config(path, config): 50 | """ 51 | Save config dictionary as json file 52 | """ 53 | out_dir = os.path.dirname(path) 54 | if not os.path.exists(out_dir): 55 | os.makedirs(out_dir) 56 | 57 | if os.path.isfile(path): 58 | logger_py.warn( 59 | "Found file existing in {}, overwriting the existing file.".format(out_dir)) 60 | 61 | with open(path, 'w') as f: 62 | yaml.dump(config, f, sort_keys=False) 63 | 64 | logger_py.info("Saved config to {}".format(path)) 65 | 66 | 67 | def update_recursive(dict1, dict2): 68 | ''' Update two config dictionaries recursively. 69 | 70 | Args: 71 | dict1 (dict): first dictionary to be updated 72 | dict2 (dict): second dictionary which entries should be used 73 | 74 | ''' 75 | for k, v in dict2.items(): 76 | if k not in dict1: 77 | dict1[k] = edict() 78 | if isinstance(v, dict): 79 | update_recursive(dict1[k], v) 80 | else: 81 | dict1[k] = v 82 | 83 | 84 | def _get_tensor_with_default(opt, key, size, fill_value=0.0): 85 | if key not in opt: 86 | return torch.zeros(*size).fill_(fill_value) 87 | else: 88 | return torch.FloatTensor(opt[key]) 89 | 90 | 91 | def create_point_texture(opt_renderer_texture): 92 | from DSS.core.texture import (NeuralTexture, LightingTexture) 93 | """ create shader that generate per-point color """ 94 | if opt_renderer_texture.texture.is_neural_shader: 95 | texture = NeuralTexture(opt_renderer_texture.texture) 96 | else: 97 | lights = create_lights(opt_renderer_texture.get('lights', None)) 98 | texture = LightingTexture( 99 | specular=opt_renderer_texture.texture.specular, lights=lights) 100 | 101 | return texture 102 | 103 | 104 | def create_lights(opt_renderer_texture_lights): 105 | """ 106 | Create lights specified by opt, if no sun or point lights 107 | are given, create the tri-color lights. 
108 | Currently only supports the same lights for all batches 109 | """ 110 | from DSS.core.lighting import (DirectionalLights, PointLights) 111 | ambient_color = torch.tensor( 112 | opt_renderer_texture_lights.ambient_color).view(1, -1, 3) 113 | specular_color = torch.tensor( 114 | opt_renderer_texture_lights.specular_color).view(1, -1, 3) 115 | diffuse_color = torch.tensor( 116 | opt_renderer_texture_lights.diffuse_color).view(1, -1, 3) 117 | if opt_renderer_texture_lights['type'] == "sun": 118 | direction = torch.tensor( 119 | opt_renderer_texture_lights.direction).view(1, -1, 3) 120 | lights = DirectionalLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 121 | specular_color=specular_color, direction=direction) 122 | elif opt_renderer_texture_lights['type'] == 'point': 123 | location = torch.tensor( 124 | opt_renderer_texture_lights.location).view(1, -1, 3) 125 | lights = PointLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 126 | specular_color=specular_color, location=location) 127 | 128 | return lights 129 | 130 | 131 | def create_cameras(opt): 132 | pass 133 | 134 | 135 | def create_dataset(opt_data, mode="train"): 136 | import DSS.utils.dataset as DssDataset 137 | if opt_data.type == 'MVR': 138 | dataset = DssDataset.MVRDataset(**opt_data, mode=mode) 139 | elif opt_data.type == 'DTU': 140 | dataset = DssDataset.DTUDataset(**opt_data, mode=mode) 141 | else: 142 | raise NotImplementedError 143 | return dataset 144 | 145 | 146 | def create_model(cfg, device, mode="train", camera_model=None, **kwargs): 147 | ''' Returns model 148 | 149 | Args: 150 | cfg (edict): imported yaml config 151 | device (device): pytorch device 152 | ''' 153 | decoder = cfg['model']['decoder'] 154 | encoder = cfg['model']['encoder'] 155 | 156 | if mode == 'test' and cfg.model.type == 'combined': 157 | cfg.model.type = 'implicit' 158 | 159 | if cfg.model.type == 'point': 160 | decoder = None 161 | 162 | if decoder is not None: 163 | c_dim = cfg['model']['c_dim'] 164 | Decoder = get_class_from_string(cfg.model.decoder_type) 165 | decoder = Decoder( 166 | c_dim=c_dim, dim=3, **cfg.model.decoder_kwargs).to(device=device) 167 | logger_py.info("Created Decoder {}".format(decoder.__class__)) 168 | logger_py.info(decoder) 169 | # initialize siren model to be a sphere 170 | if cfg.model.decoder_type == 'DSS.models.common.Siren': 171 | decoder_kwargs = cfg.model.decoder_kwargs 172 | if cfg.training.init_siren_from_sphere: 173 | try: 174 | pretrained_model_file = os.path.join('data', 'trained_model', 'siren_l{}_c{}_o{}.pt'.format( 175 | decoder_kwargs.n_layers, decoder_kwargs.hidden_size, decoder_kwargs.first_omega_0)) 176 | loaded_state_dict = torch.load(pretrained_model_file) 177 | decoder.load_state_dict(loaded_state_dict) 178 | logger_py.info('initialized Siren decoder with {}'.format( 179 | pretrained_model_file)) 180 | except Exception: 181 | pass 182 | 183 | texture = None 184 | use_lighting = (cfg.renderer is not None and not cfg.renderer.get( 185 | 'is_neural_texture', True)) 186 | if use_lighting: 187 | texture = LightingTexture() 188 | else: 189 | if 'rgb' not in cfg.model.decoder_kwargs.out_dims: 190 | Texture = get_class_from_string(cfg.model.texture_type) 191 | cfg.model.texture_kwargs['c_dim'] = cfg.model.decoder_kwargs.out_dims.get('latent', 0) 192 | texture_decoder = Texture(**cfg.model.texture_kwargs) 193 | else: 194 | texture_decoder = decoder 195 | logger_py.info("Decoder used as NeuralTexture") 196 | 197 | texture = NeuralTexture( 198 | 
view_dependent=cfg.model.texture_kwargs.view_dependent, decoder=texture_decoder).to(device=device) 199 | logger_py.info("Created NeuralTexture {}".format(texture.__class__)) 200 | logger_py.info(texture) 201 | 202 | Model = get_class_from_string( 203 | "DSS.models.{}_modeling.Model".format(cfg.model.type)) 204 | 205 | if cfg.model.type == 'implicit': 206 | model = Model(decoder, renderer=None, 207 | texture=texture, encoder=encoder, cameras=camera_model, 208 | device=device, **cfg.model.model_kwargs) 209 | 210 | elif cfg.model.type == 'combined': 211 | renderer = create_renderer(cfg.renderer).to(device) 212 | # TODO: load 213 | points = None 214 | point_file = os.path.join( 215 | cfg.training.out_dir, cfg.name, cfg.training.point_file) 216 | if os.path.isfile(point_file): 217 | # load point or mesh then sample 218 | loaded_shape = trimesh.load(point_file) 219 | if isinstance(loaded_shape, trimesh.PointCloud): 220 | # overide n_points_per_cloud 221 | cfg.model.model_kwargs.n_points_per_cloud = loaded_shape.vertices.shape[0] 222 | points = loaded_shape.vertices 223 | else: 224 | n_points = cfg.model.model_kwargs['n_points_per_cloud'] 225 | try: 226 | # reject sampling can produce less points, hence sample more 227 | points = trimesh.sample.sample_surface_even(loaded_shape, 228 | int(n_points * 1.1), 229 | radius=0.01)[0] 230 | p_idx = np.random.permutation( 231 | loaded_shape.vertices.shape[0])[:n_points] 232 | points = points[p_idx, ...] 233 | except Exception: 234 | # randomly 235 | p_idx = np.random.permutation(loaded_shape.vertices.shape[0])[ 236 | :n_points] 237 | points = loaded_shape.vertices[p_idx, ...] 238 | 239 | points = torch.tensor(points, dtype=torch.float, device=device) 240 | 241 | model = Model( 242 | decoder, renderer, texture=texture, encoder=encoder, cameras=camera_model, device=device, points=points, 243 | **cfg.model.model_kwargs) 244 | 245 | else: 246 | ValueError('model type must be combined|point|implicit|occupancy') 247 | 248 | return model 249 | 250 | 251 | def create_generator(cfg, model, device, **kwargs): 252 | ''' Returns the generator object. 253 | 254 | Args: 255 | model (nn.Module): model 256 | cfg (dict): imported yaml config 257 | device (device): pytorch device 258 | ''' 259 | decoder = cfg.model.decoder_type 260 | Generator = get_class_from_string( 261 | 'DSS.models.{}_modeling.Generator'.format(cfg.model.type)) 262 | 263 | generator = Generator(model, device, 264 | threshold=cfg['test']['threshold'], 265 | **cfg.generation) 266 | return generator 267 | 268 | 269 | def create_trainer(cfg, model, optimizer, scheduler, generator, train_loader, val_loader, device, **kwargs): 270 | ''' Returns the trainer object. 
271 | 272 | Args: 273 | model (nn.Module): the model 274 | optimizer (optimizer): pytorch optimizer object 275 | cfg (dict): imported yaml config 276 | device (device): pytorch device 277 | generator (Generator): generator instance to 278 | generate meshes for visualization 279 | ''' 280 | threshold = cfg['test']['threshold'] 281 | out_dir = os.path.join(cfg['training']['out_dir'], cfg['name']) 282 | vis_dir = os.path.join(out_dir, 'vis') 283 | debug_dir = os.path.join(out_dir, 'debug') 284 | log_dir = os.path.join(out_dir, 'logs') 285 | val_dir = os.path.join(out_dir, 'val') 286 | depth_from_visual_hull = cfg['data']['depth_from_visual_hull'] 287 | depth_range = cfg['data']['depth_range'] 288 | 289 | trainer = Trainer( 290 | model, optimizer, scheduler, generator, train_loader, val_loader, 291 | device=device, 292 | vis_dir=vis_dir, debug_dir=debug_dir, log_dir=log_dir, val_dir=val_dir, 293 | threshold=threshold, 294 | depth_from_visual_hull=depth_from_visual_hull, 295 | depth_range=depth_range, 296 | **cfg.training) 297 | 298 | return trainer 299 | 300 | 301 | def create_renderer(render_opt): 302 | """ Create rendere """ 303 | Renderer = get_class_from_string(render_opt.renderer_type) 304 | Raster = get_class_from_string(render_opt.raster_type) 305 | i = render_opt.raster_type.rfind('.') 306 | raster_setting_type = render_opt.raster_type[:i] + \ 307 | '.PointsRasterizationSettings' 308 | if render_opt.compositor_type is not None: 309 | Compositor = get_class_from_string(render_opt.compositor_type) 310 | compositor = Compositor() 311 | else: 312 | compositor = None 313 | 314 | RasterSetting = get_class_from_string(raster_setting_type) 315 | raster_settings = RasterSetting(**render_opt.raster_params) 316 | 317 | renderer = Renderer( 318 | rasterizer=Raster( 319 | cameras=None, raster_settings=raster_settings), 320 | compositor=compositor, 321 | ) 322 | return renderer 323 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch3d 2 | channels: 3 | - pytorch 4 | - pytorch3d 5 | - conda-forge 6 | - bottler 7 | - defaults 8 | dependencies: 9 | - python=3.8 10 | - cudatoolkit=10.2 11 | - pytorch::pytorch=1.6.0 12 | - pytorch::torchvision 13 | - fvcore 14 | - jupyter 15 | - scikit-image 16 | - matplotlib 17 | - imageio 18 | - pytorch3d::pytorch3d 19 | - pip 20 | - tensorboard=1.15.0 21 | - trimesh 22 | - cython 23 | - pyyaml 24 | - pandas 25 | - easydict 26 | - nvidiacub 27 | prefix: /home/ywang/anaconda3/envs/pytorch3d 28 | 29 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import csv 4 | import numpy as np 5 | from glob import glob 6 | from collections import OrderedDict, defaultdict 7 | import config 8 | import trimesh 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import pytorch3d 13 | from pytorch3d.structures import Meshes, Pointclouds 14 | from pytorch3d.loss import point_mesh_face_distance, chamfer_distance 15 | from pytorch3d.ops import sample_points_from_meshes 16 | import point_cloud_utils as pcu 17 | from DSS import set_deterministic_ 18 | from DSS.utils.io import read_ply 19 | 20 | """ 21 | Given an experiment folder, evaluate the meshes `vis/*_mesh.ply` and `generation/mesh00.ply`, 22 | write the results in `val/all.csv` 23 | """ 24 | 25 | set_deterministic_() 26 | 
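# Example invocation (hypothetical experiment folder; any directory containing a config.yaml works):
#   python evaluation.py --dirs exp/dtu55_iso --n_pts 50000
# Per-checkpoint results are appended to <exp_dir>/vis/evaluation_n<n_pts>.csv and the final
# mesh is scored in <exp_dir>/generation/evaluation_n<n_pts>.csv (see eval_one_dir below).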
27 | 28 | def get_filenames(source, extension): 29 | # If extension is a list 30 | if source is None: 31 | return [] 32 | # Seamlessy load single file, list of files and files from directories. 33 | source_fns = [] 34 | if isinstance(source, str): 35 | if os.path.isdir(source): 36 | if not isinstance(extension, str): 37 | for fmt in extension: 38 | source_fns += get_filenames(source, fmt) 39 | else: 40 | source_fns = sorted( 41 | glob("{}/**/*{}".format(source, extension), recursive=True)) 42 | elif os.path.isfile(source): 43 | source_fns = [source] 44 | elif len(source) and isinstance(source[0], str): 45 | for s in source: 46 | source_fns.extend(get_filenames(s, extension=extension)) 47 | return source_fns 48 | 49 | 50 | def eval_one_dir(exp_dir, n_pts=50000): 51 | """ 52 | Function for one directory 53 | """ 54 | device = torch.device('cuda:0') 55 | cfg = config.load_config(os.path.join(exp_dir, 'config.yaml')) 56 | dataset = config.create_dataset(cfg.data, mode='val') 57 | meshes_gt = dataset.get_meshes().to(device) 58 | val_gt_pts_file = os.path.join(cfg.data.data_dir, 'val%d.ply' % n_pts) 59 | if os.path.isfile(val_gt_pts_file): 60 | points, normals = np.split(read_ply(val_gt_pts_file), 2, axis=1) 61 | pcl_gt = Pointclouds(torch.from_numpy(points[None, ...]).float(), 62 | torch.from_numpy(normals[None, ...]).float()).to(device) 63 | else: 64 | pcl_gt = dataset.get_pointclouds(n_pts).to(device) 65 | trimesh.Trimesh(pcl_gt.points_packed().cpu().numpy(), 66 | vertex_normals=pcl_gt.normals_packed().cpu().numpy(), process=False).export( 67 | val_gt_pts_file, vertex_normal=True 68 | ) 69 | 70 | # load vis directories 71 | vis_dir = os.path.join(exp_dir, 'vis') 72 | vis_files = sorted(get_filenames(vis_dir, '_mesh.ply')) 73 | iters = [int(os.path.basename(v).split('_')[0]) for v in vis_files] 74 | best_dict = defaultdict(lambda: float('inf')) 75 | vis_eval_csv = os.path.join(vis_dir, "evaluation_n%d.csv" % n_pts) 76 | if not os.path.isfile(vis_eval_csv): 77 | with open(os.path.join(vis_dir, "evaluation_n%d.csv" % n_pts), "w") as f: 78 | fieldnames = ['mtime', 'it', 'chamfer_p', 'chamfer_n', 'pf_dist'] 79 | writer = csv.DictWriter(f, fieldnames=fieldnames, 80 | restval="-", extrasaction="ignore") 81 | writer.writeheader() 82 | mtime0 = None 83 | for it, vis_file in zip(iters, vis_files): 84 | eval_dict = OrderedDict() 85 | mtime = os.path.getmtime(vis_file) 86 | if mtime0 is None: 87 | mtime0 = mtime 88 | eval_dict['it'] = it 89 | eval_dict['mtime'] = mtime - mtime0 90 | val_pts_file = os.path.join(vis_dir, os.path.basename( 91 | vis_file).replace('_mesh', '_val%d' % n_pts)) 92 | if os.path.isfile(val_pts_file): 93 | points, normals = np.split( 94 | read_ply(val_pts_file), 2, axis=1) 95 | points = torch.from_numpy(points).float().to( 96 | device=device).view(1, -1, 3) 97 | normals = torch.from_numpy(normals).float().to( 98 | device=device).view(1, -1, 3) 99 | else: 100 | mesh = trimesh.load(vis_file, process=False) 101 | # points, normals = pcu.sample_mesh_poisson_disk( 102 | # mesh.vertices, mesh.faces, 103 | # mesh.vertex_normals.ravel().reshape(-1, 3), n_pts, use_geodesic_distance=True) 104 | # p_idx = np.random.permutation(points.shape[0])[:n_pts] 105 | # points = points[p_idx, ...] 106 | # normals = normals[p_idx, ...] 
107 | # points = torch.from_numpy(points).float().to( 108 | # device=device).view(1, -1, 3) 109 | # normals = torch.from_numpy(normals).float().to( 110 | # device=device).view(1, -1, 3) 111 | meshes = Meshes(torch.from_numpy(mesh.vertices[None, ...]).float(), 112 | torch.from_numpy(mesh.faces[None, ...]).float()).to(device) 113 | points, normals = sample_points_from_meshes( 114 | meshes, n_pts, return_normals=True) 115 | trimesh.Trimesh(points.cpu().numpy()[0], vertex_normals=normals.cpu().numpy()[0], process=False).export( 116 | val_pts_file, vertex_normal=True 117 | ) 118 | pcl = Pointclouds(points, normals) 119 | chamfer_p, chamfer_n = chamfer_distance( 120 | points, pcl_gt.points_padded(), 121 | x_normals=normals, y_normals=pcl_gt.normals_padded(), 122 | ) 123 | eval_dict['chamfer_p'] = chamfer_p.item() 124 | eval_dict['chamfer_n'] = chamfer_n.item() 125 | pf_dist = point_mesh_face_distance(meshes_gt, pcl) 126 | eval_dict['pf_dist'] = pf_dist.item() 127 | writer.writerow(eval_dict) 128 | for k, v in eval_dict.items(): 129 | if v < best_dict[k]: 130 | best_dict[k] = v 131 | print('best {} so far ({}): {:.4g}'.format(k, vis_file, v)) 132 | 133 | # generation dictories 134 | gen_dir = os.path.join(exp_dir, 'generation') 135 | if not os.path.isdir(gen_dir): 136 | return 137 | 138 | final_file = os.path.join(gen_dir, 'mesh.ply') 139 | val_pts_file = final_file[:-4] + '_val%d' % n_pts + '.ply' 140 | if not os.path.isfile(final_file): 141 | return 142 | 143 | gen_file_csv = os.path.join(gen_dir, "evaluation_n%d.csv" % n_pts) 144 | if not os.path.isfile(gen_file_csv): 145 | with open(os.path.join(gen_dir, "evaluation_n%d.csv" % n_pts), "w") as f: 146 | fieldnames = ['chamfer_p', 'chamfer_n', 'pf_dist'] 147 | writer = csv.DictWriter(f, fieldnames=fieldnames, 148 | restval="-", extrasaction="ignore") 149 | writer.writeheader() 150 | eval_dict = OrderedDict() 151 | mesh = trimesh.load(final_file) 152 | # points, normals = pcu.sample_mesh_poisson_disk( 153 | # mesh.vertices, mesh.faces, 154 | # mesh.vertex_normals.ravel().reshape(-1, 3), n_pts, use_geodesic_distance=True) 155 | # p_idx = np.random.permutation(points.shape[0])[:n_pts] 156 | # points = points[p_idx, ...] 157 | # normals = normals[p_idx, ...] 
158 | # points = torch.from_numpy(points).float().to( 159 | # device=device).view(1, -1, 3) 160 | # normals = torch.from_numpy(normals).float().to( 161 | # device=device).view(1, -1, 3) 162 | meshes = Meshes(torch.from_numpy(mesh.vertices[None, ...]).float(), 163 | torch.from_numpy(mesh.faces[None, ...]).float()).to(device) 164 | points, normals = sample_points_from_meshes( 165 | meshes, n_pts, return_normals=True) 166 | trimesh.Trimesh(points.cpu().numpy()[0], vertex_normals=normals.cpu().numpy()[0], process=False).export( 167 | val_pts_file, vertex_normal=True) 168 | pcl = Pointclouds(points, normals) 169 | chamfer_p, chamfer_n = chamfer_distance( 170 | points, pcl_gt.points_padded(), 171 | x_normals=normals, y_normals=pcl_gt.normals_padded(), 172 | ) 173 | eval_dict['chamfer_p'] = chamfer_p.item() 174 | eval_dict['chamfer_n'] = chamfer_n.item() 175 | pf_dist = point_mesh_face_distance(meshes_gt, pcl) 176 | eval_dict['pf_dist'] = pf_dist.item() 177 | writer.writerow(eval_dict) 178 | for k, v in eval_dict.items(): 179 | if v < best_dict[k]: 180 | best_dict[k] = v 181 | print('best {} so far ({}): {:.4g}'.format(k, final_file, v)) 182 | 183 | 184 | if __name__ == '__main__': 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--dirs", type=str, nargs='+', required=True, 187 | help="Experiment directories") 188 | parser.add_argument("--n_pts", type=int, default=50000, 189 | help="number of points used for evaluation") 190 | args = parser.parse_args() 191 | for exp in args.dirs: 192 | eval_one_dir(exp, args.n_pts) 193 | -------------------------------------------------------------------------------- /generate_mvr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from numbers import Number 3 | # import torch.distributions as dist 4 | import os 5 | import shutil 6 | import argparse 7 | import trimesh 8 | from tqdm import tqdm 9 | import time 10 | from collections import defaultdict 11 | import pandas as pd 12 | import numpy as np 13 | import config 14 | from DSS.misc.checkpoints import CheckpointIO 15 | import imageio 16 | import plotly.graph_objs as go 17 | from DSS import logger_py 18 | from DSS.utils import get_surface_high_res_mesh 19 | 20 | 21 | parser = argparse.ArgumentParser( 22 | description='Extract meshes from occupancy process.' 
23 | ) 24 | parser.add_argument('config', type=str, help='Path to config file.') 25 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 26 | parser.add_argument('--img_size', type=int, nargs='*', help="overwrite original image size") 27 | parser.add_argument('--resolution', type=int, default=512, 28 | help='Overrites the default resolution in config') 29 | parser.add_argument('--mesh-only', action='store_true') 30 | parser.add_argument('--render-only', action='store_true') 31 | 32 | 33 | args = parser.parse_args() 34 | cfg = config.load_config(args.config, 'configs/default.yaml') 35 | 36 | 37 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 38 | device = torch.device("cuda" if is_cuda else "cpu") 39 | 40 | # Shortcuts 41 | out_dir = os.path.join(cfg['training']['out_dir'], cfg['name']) 42 | generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir']) 43 | if not os.path.exists(generation_dir): 44 | os.makedirs(generation_dir) 45 | 46 | batch_size = 1 47 | vis_n_outputs = cfg['generation']['vis_n_outputs'] 48 | mesh_extension = cfg['generation']['mesh_extension'] 49 | 50 | # Dataset 51 | dataset = config.create_dataset(cfg.data, mode='test') 52 | test_loader = torch.utils.data.DataLoader( 53 | dataset, batch_size=batch_size, num_workers=1, shuffle=False) 54 | img_size = args.img_size or dataset.resolution 55 | if isinstance(img_size, Number): 56 | img_size = (img_size, img_size) 57 | 58 | # Model 59 | model = config.create_model(cfg, mode='test', device=device, camera_model=dataset.get_cameras()).to(device=device) 60 | 61 | checkpoint_io = CheckpointIO(out_dir, model=model) 62 | checkpoint_io.load(cfg['test']['model_file']) 63 | 64 | # Generator 65 | generator = config.create_generator(cfg, model, device=device) 66 | 67 | torch.manual_seed(0) 68 | 69 | # Generate 70 | with torch.autograd.no_grad(): 71 | model.eval() 72 | # Generate meshes 73 | if not args.render_only: 74 | logger_py.info('Generating mesh...') 75 | mesh = get_surface_high_res_mesh(lambda x: model.decode(x).sdf.squeeze(), resolution=args.resolution) 76 | if cfg.data.type == 'DTU': 77 | mesh.apply_transform(dataset.get_scale_mat()) 78 | mesh_out_file = os.path.join(generation_dir, 'mesh.%s' % mesh_extension) 79 | mesh.export(mesh_out_file) 80 | 81 | # Generate cuts 82 | logger_py.info('Generating cross section plots') 83 | img = generator.generate_iso_contour(imgs_per_cut=5) 84 | out_file = os.path.join(generation_dir, 'iso') 85 | img.write_html(out_file + '.html') 86 | 87 | if not args.mesh_only: 88 | # generate images 89 | for i, batch in enumerate(test_loader): 90 | img_mask = batch['img.mask'] 91 | cam_mat = batch['camera_mat'] 92 | cameras = dataset.get_cameras(cam_mat) 93 | lights = dataset.get_lights(**batch.get('lights', {})) 94 | rgbas = generator.raytrace_images(img_size, img_mask, cameras=cameras, lights=lights) 95 | for rgba in rgbas: 96 | imageio.imwrite(os.path.join(generation_dir, '%05d.png' % i), rgba) 97 | torch.cuda.empty_cache() 98 | -------------------------------------------------------------------------------- /images/idr-mvr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/images/idr-mvr.png -------------------------------------------------------------------------------- /images/idr-rabbit.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/images/idr-rabbit.png -------------------------------------------------------------------------------- /images/points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/images/points.png -------------------------------------------------------------------------------- /images/sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/images/sampling.png -------------------------------------------------------------------------------- /images/siren-pointcloud.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/images/siren-pointcloud.png -------------------------------------------------------------------------------- /images/siren-synthetic-mvr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/iso-points/3a126b3765bb895a8fb148e8186befe2b3a35f73/images/siren-synthetic-mvr.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.17 2 | easydict==1.9 3 | -e git+https://github.com/lxxue/FRNN.git@eab337feaa61074d63e08af58f83fe87fe87bf3f#egg=frnn 4 | future @ file:///home/conda/feedstock_root/build_artifacts/future_1602538316091/work 5 | fvcore @ file:///home/conda/feedstock_root/build_artifacts/fvcore_1604059542311/work 6 | imageio==2.8.0 7 | matplotlib==3.2.1 8 | numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1597938342049/work 9 | pickleshare==0.7.5 10 | Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1594213113401/work 11 | plotly==4.7.1 12 | plyfile==0.7.2 13 | point-cloud-utils==0.8.0 14 | prefix-sum==0.0.0 15 | protobuf==3.13.0 16 | pymeshlab==0.1.8 17 | pypng==0.0.20 18 | pypoisson==0.10 19 | pytest==6.2.1 20 | PyYAML==5.3.1 21 | scikit-image==0.16.2 22 | tensorboard==1.15.0 23 | torch==1.6.0 24 | # Editable install with no version control (torch-batch-svd==0.0.0) 25 | -e /home/ywang/anaconda3/envs/pytorch3d/lib/python3.8/site-packages/torch_batch_svd-0.0.0-py3.8-linux-x86_64.egg 26 | torchvision==0.7.0 27 | tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1596476591553/work 28 | trimesh==3.6.34 29 | typing-extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1602702424206/work 30 | -------------------------------------------------------------------------------- /scripts/create_mvr_data_from_mesh.py: -------------------------------------------------------------------------------- 1 | """ 2 | Create Synthetic MVR data using pytorch3d point renderer 3 | per-shape: 4 | data_dict.npz: 5 | cameras_type 6 | cameras_params 7 | lights_type 8 | camera_mat [V,4,4] matrix 9 | saving per-view: 10 | mask (png) 11 | RGB (png) 12 | lights (dict) 13 | camera_mat (4,4) 14 | ------------------ 15 | (used for dvr only) 16 | cameras.npz 17 | camera_mat_%d (4,4) projection scaling part (top-left 2x2 matrix from pytorch3d projection matrix) 18 | world_mat_%d (4,4) source-to-view matrix 19 | scale_mat_%d (4,4) identity matrix 20 | pcl.npz (sparse point 
clouds) 21 | points 22 | colors 23 | normals 24 | """ 25 | from pytorch3d.renderer import ( 26 | RasterizationSettings, 27 | FoVPerspectiveCameras, 28 | MeshRenderer, 29 | MeshRasterizer, 30 | HardFlatShader, 31 | ) 32 | from pytorch3d.ops import eyes, sample_points_from_meshes 33 | from pytorch3d.io import load_obj, load_ply, save_obj 34 | from itertools import chain 35 | from glob import glob 36 | import numpy as np 37 | import imageio 38 | import argparse 39 | import os 40 | import torch 41 | from DSS.core.camera import CameraSampler 42 | from DSS.core.lighting import PointLights, DirectionalLights 43 | from pytorch3d.renderer import Textures 44 | from pytorch3d.structures import Meshes 45 | from DSS.utils import convert_tensor_property_to_value_dict 46 | from common import get_tri_color_lights_for_view, get_light_for_view 47 | 48 | # torch.manual_seed(0) 49 | # torch.backends.cudnn.deterministic = True 50 | # torch.backends.cudnn.benchmark = False 51 | # np.random.seed(0) 52 | 53 | 54 | def get_names_and_paths(opt): 55 | points_paths = list(chain.from_iterable(glob(p) for p in opt.points)) 56 | assert(len(points_paths) > 0), "Found no point clouds in with path {}".format( 57 | points_paths) 58 | 59 | if len(points_paths) > 1: 60 | points_dir = os.path.commonpath(points_paths) 61 | points_relpaths = [os.path.relpath( 62 | p, points_dir) for p in points_paths] 63 | else: 64 | points_relpaths = [os.path.basename(p) for p in points_paths] 65 | 66 | name_and_path = {(os.path.splitext(rel_path)[0].replace(os.path.sep, "_"), path) 67 | for rel_path, path in zip(points_relpaths, points_paths)} 68 | return name_and_path 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser( 73 | "Create Synthetic MVR data saving per-view: RGBA, camera matrix, depth") 74 | parser.add_argument("--points", required=True, nargs="+", 75 | help="String to grob point clouds, e.g. 
\"data/**/*.ply\"") 76 | parser.add_argument("--num_cameras", type=int, default=120) 77 | parser.add_argument("--image-size", type=int, default=512) 78 | parser.add_argument("--output", type=str, default="data") 79 | parser.add_argument("--tri_color_light", action='store_true') 80 | parser.add_argument("--point_lights", action='store_true') 81 | parser.add_argument("--has_specular", action='store_true') 82 | parser.add_argument("--min_dist", type=float, default=1.2) 83 | parser.add_argument("--max_dist", type=float, default=2.2) 84 | parser.add_argument("--znear", type=float, default=0.1) 85 | opt, _ = parser.parse_known_args() 86 | 87 | device = torch.device("cuda:0") 88 | torch.cuda.set_device(device) 89 | 90 | names_and_path = get_names_and_paths(opt) 91 | for mesh_name, mesh_path in names_and_path: 92 | output_dir = os.path.join(opt.output, mesh_name +'_variational_light') 93 | rgb_dir = os.path.join(output_dir, "image") 94 | mask_dir = os.path.join(output_dir, "mask") 95 | depth_dir = os.path.join(output_dir, "depth") 96 | os.makedirs(output_dir, exist_ok=True) 97 | os.makedirs(rgb_dir, exist_ok=True) 98 | os.makedirs(mask_dir, exist_ok=True) 99 | os.makedirs(depth_dir, exist_ok=True) 100 | 101 | # load and normalize mesh 102 | if os.path.splitext(mesh_path)[1].lower() == ".ply": 103 | verts, faces = load_ply(mesh_path) 104 | verts_idx = faces 105 | elif os.path.splitext(mesh_path)[1].lower() == ".obj": 106 | verts, faces, aux = load_obj(mesh_path) 107 | verts_idx = faces.verts_idx 108 | else: 109 | raise NotImplementedError 110 | 111 | # # normalize to unit box 112 | # vert_range = (verts.max(dim=0)[0] - verts.min(dim=0)[0]).max() 113 | # vert_center = (verts.max(dim=0)[0] + verts.min(dim=0)[0]) / 2 114 | # verts -= vert_center 115 | # verts /= vert_range 116 | 117 | # normalize to unit sphere 118 | vert_center = (verts.max(dim=0)[0] + verts.min(dim=0)[0]) / 2 119 | verts -= vert_center 120 | vert_scale = torch.norm(verts, dim=1).max() 121 | verts /= vert_scale 122 | 123 | save_obj(os.path.join(output_dir, "mesh.obj"), 124 | verts=verts, faces=verts_idx) 125 | textures = Textures(verts_rgb=torch.ones( 126 | 1, verts.shape[0], 3)).to(device=device) 127 | meshes = Meshes(verts=[verts], faces=[verts_idx], 128 | textures=textures).to(device=device) 129 | 130 | # Initialize an OpenGL perspective camera. 131 | batch_size = 1 132 | camera_params = {"znear": opt.znear} 133 | camera_sampler = CameraSampler(opt.num_cameras, 134 | batch_size, distance_range=torch.tensor( 135 | ((opt.min_dist, opt.max_dist),)), # min distance should be larger than znear+obj_dim 136 | sort_distance=True, 137 | camera_type=FoVPerspectiveCameras, 138 | camera_params=camera_params) 139 | 140 | 141 | # Define the settings for rasterization and shading. 142 | # Refer to raster_points.py for explanations of these parameters. 
143 | raster_settings = RasterizationSettings( 144 | image_size=opt.image_size, 145 | blur_radius=0.0, 146 | faces_per_pixel=5, 147 | # this setting controls whether naive or coarse-to-fine rasterization is used 148 | bin_size=None, 149 | max_faces_per_bin=None # this setting is for coarse rasterization 150 | ) 151 | 152 | renderer = MeshRenderer( 153 | rasterizer=MeshRasterizer( 154 | cameras=None, raster_settings=raster_settings), 155 | shader=HardFlatShader(device=device) 156 | ) 157 | renderer.to(device) 158 | 159 | if opt.point_lights: 160 | template_lights = PointLights() 161 | else: 162 | template_lights = DirectionalLights() 163 | 164 | # pcl_dict = {'points': pointclouds.points_padded[0].cpu().numpy()} 165 | data_dict = {"cameras_type": '.'.join([camera_sampler.camera_type.__module__, 166 | camera_sampler.camera_type.__name__]), 167 | "cameras_params": camera_params, 168 | "lights_type": '.'.join([template_lights.__module__, template_lights.__class__.__name__]), 169 | } 170 | num_points = 20000 171 | V, V_normal = sample_points_from_meshes( 172 | meshes, num_samples=num_points, return_normals=True) 173 | num_points = V.shape[1] 174 | data_dict['points'] = V[0].cpu().numpy() 175 | data_dict['normals'] = V_normal[0].cpu().numpy() 176 | data_dict['colors'] = np.ones_like( 177 | data_dict['points'], dtype=np.float32) 178 | data_dict['camera_mat'] = torch.empty(opt.num_cameras, 4, 4) 179 | 180 | # DVR data no projection step, assumes use SfMcamera 181 | cameras_dict = {} 182 | pcl_dict = {} 183 | pcl_dict['points'] = data_dict['points'] 184 | pcl_dict['normals'] = data_dict['normals'] 185 | pcl_dict['colors'] = data_dict['colors'] 186 | pcl_dict['is_in_visual_hull'] = np.full((num_points), True) 187 | 188 | idx = 0 189 | for c_idx, cams in enumerate(camera_sampler): 190 | meshes_batch = meshes.extend(batch_size) 191 | cams = cams.to(device) 192 | 193 | # create tri-color lights and a specular+diffuse shader 194 | if opt.tri_color_light: 195 | lights = get_tri_color_lights_for_view(cams, 196 | point_lights=opt.point_lights, has_specular=opt.has_specular) 197 | else: 198 | lights = get_light_for_view(cams, point_lights=opt.point_lights, has_specular=opt.has_specular) 199 | 200 | assert(type(lights) is type(template_lights)) 201 | lights.to(device=device) 202 | 203 | # renderer function (flat shading) 204 | fragments = renderer.rasterizer(meshes_batch, cameras=cams) 205 | images = renderer.shader( 206 | fragments, meshes_batch, lights=lights, cameras=cams) 207 | 208 | mask = fragments.pix_to_face[..., :1] >= 0 209 | mask_imgs = mask.to(dtype=torch.uint8) * 255 210 | 211 | # use hard alpha values 212 | images = torch.cat([images[..., :3], mask.float()], dim=-1) 213 | dense_depths = cams.zfar.view(-1, 1, 214 | 1, 1).clone().expand_as(mask_imgs) 215 | dense_depths = torch.where( 216 | mask, fragments.zbuf[..., :1], dense_depths) 217 | 218 | # cameras 219 | camera_mat = cams.get_projection_transform().get_matrix().cpu() 220 | world_mat = cams.get_world_to_view_transform().get_matrix().cpu() 221 | id_mat = np.eye(4) 222 | # DVR scales x,y and does the projection step manually (/z) 223 | dvr_camera_mat = eyes(4, camera_mat.shape[0]).to(camera_mat.device) 224 | dvr_camera_mat[:, :2, :2] = camera_mat[:, :2, :2] 225 | # dense depth read from rasterizer 226 | for b in range(images.shape[0]): 227 | # save camera data 228 | data_dict['camera_mat'][idx, ...] 
= world_mat[b] 229 | data_dict['lights_%d' % idx] = convert_tensor_property_to_value_dict(lights) 230 | 231 | # DVR camera data 232 | cameras_dict['world_mat_%d' % 233 | idx] = world_mat[b].transpose(0, 1) 234 | cameras_dict['scale_mat_%d' % idx] = id_mat 235 | cameras_dict['camera_mat_%d' % 236 | idx] = dvr_camera_mat[b].transpose(0, 1) 237 | # save dense depth 238 | imageio.imwrite(os.path.join(depth_dir, "%06d.exr" % idx), 239 | dense_depths[b, ...].cpu()) 240 | # save rgb 241 | imageio.imwrite(os.path.join(rgb_dir, "%06d.png" % idx), 242 | (images[b].cpu().numpy() * 255.0).astype('uint8'),) 243 | # save mask 244 | imageio.imwrite(os.path.join(mask_dir, "%06d.png" % idx), 245 | mask_imgs[b, ...].cpu()) 246 | idx += 1 247 | 248 | data_dict['camera_mat'] = data_dict['camera_mat'].tolist() 249 | np.savez(os.path.join(output_dir, "data_dict.npz"), 250 | allow_pickle=False, **data_dict) 251 | np.savez(os.path.join(output_dir, "cameras.npz"), 252 | allow_pickle=False, **cameras_dict) 253 | -------------------------------------------------------------------------------- /scripts/evaluatePointClouds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from glob import glob 6 | import re 7 | import csv 8 | from collections import OrderedDict 9 | from pytorch_points.network.operations import normalize_point_batch 10 | from pytorch_points.network.model_loss import nndistance 11 | from pytorch_points.utils.pc_utils import save_ply_property, load, save_ply_property 12 | 13 | 14 | def get_filenames(source, extension): 15 | # If extension is a list 16 | if source is None: 17 | return [] 18 | # Seamlessy load single file, list of files and files from directories. 
19 | source_fns = [] 20 | if isinstance(source, str): 21 | if os.path.isdir(source): 22 | if not isinstance(extension, str): 23 | for fmt in extension: 24 | source_fns += get_filenames(source, fmt) 25 | else: 26 | source_fns = sorted( 27 | glob("{}/**/*{}".format(source, extension), recursive=True)) 28 | elif os.path.isfile(source): 29 | source_fns = [source] 30 | elif len(source) and isinstance(source[0], str): 31 | for s in source: 32 | source_fns.extend(get_filenames(s, extension=extension)) 33 | return source_fns 34 | 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--gt", type=str, required=True, help="directory or file name for ground truth point clouds") 38 | parser.add_argument("--pred", type=str, nargs="+", required=True, help="directorie of predictions") 39 | parser.add_argument("--name", type=str, required=True, help="name pattern if provided directory for pred and gt") 40 | FLAGS = parser.parse_args() 41 | if os.path.isdir(FLAGS.gt): 42 | GT_DIR = FLAGS.gt 43 | gt_paths = get_filenames(GT_DIR, ("ply", "pcd", "xyz")) 44 | gt_names = [os.path.basename(p)[:-4] for p in gt_paths] 45 | elif os.path.isfile(FLAGS.gt): 46 | gt_paths = [FLAGS.gt] 47 | 48 | PRED_DIR = FLAGS.pred 49 | NAME = FLAGS.name 50 | 51 | 52 | fieldnames = ["name", "CD", "hausdorff", "p2f avg", "p2f std"] + ["nuc_%d" % d for d in range(7)] 53 | print("{:60s} ".format("name"), "|".join(["{:>15s}".format(d) for d in fieldnames[1:]])) 54 | for D in PRED_DIR: 55 | avg_md_forward_value = 0 56 | avg_md_backward_value = 0 57 | avg_hd_value = 0 58 | counter = 0 59 | pred_paths = glob(os.path.join(D, "**", NAME), recursive=True) 60 | if len(pred_paths) == 1 and len(pred_paths) > 1: 61 | gt_pred_pairs = [] 62 | for p in pred_paths: 63 | name, ext = os.path.splitext(os.path.basename(p)) 64 | assert(ext in (".ply", ".xyz")) 65 | try: 66 | gt = gt_paths[gt_names.index(name)] 67 | except ValueError: 68 | pass 69 | else: 70 | gt_pred_pairs.append((gt, p)) 71 | else: 72 | gt_pred_pairs = [] 73 | for p in pred_paths: 74 | gt_pred_pairs.append((gt_paths[0], p)) 75 | 76 | # print("total inputs ", len(gt_pred_pairs)) 77 | # tag = re.search("/(\w+)/result", os.path.dirname(gt_pred_pairs[0][1])) 78 | tag = os.path.basename(os.path.dirname(gt_pred_pairs[0][1])) 79 | 80 | print("{:60s}".format(tag), end=' ') 81 | global_p2f = [] 82 | global_density = [] 83 | with open(os.path.join(os.path.dirname(gt_pred_pairs[0][1]), "evaluation.csv"), "w") as f: 84 | writer = csv.DictWriter(f, fieldnames=fieldnames, restval="-", extrasaction="ignore") 85 | writer.writeheader() 86 | for gt_path, pred_path in gt_pred_pairs: 87 | row = {} 88 | gt = load(gt_path)[:, :3] 89 | gt = gt[np.newaxis, ...] 90 | pred = load(pred_path) 91 | pred = pred[:, :3] 92 | 93 | row["name"] = os.path.basename(pred_path) 94 | pred = pred[np.newaxis, ...] 
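# A batch dimension is added so the batched helpers below (normalize_point_batch, nndistance)
# receive (1, N, 3) tensors; "CD" reported below is the sum of the mean forward/backward
# nearest-neighbour distances, and "hausdorff" the maximum of the summed per-direction maxima.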
95 | 96 | pred = torch.from_numpy(pred).cuda() 97 | gt = torch.from_numpy(gt).cuda() 98 | 99 | pred_tensor, centroid, furthest_distance = normalize_point_batch(pred) 100 | gt_tensor, centroid, furthest_distance = normalize_point_batch(gt) 101 | 102 | # B, P_predict, 1 103 | cd_forward, cd_backward = nndistance(pred, gt) 104 | # cd_forward, _ = knn_point(1, gt_tensor, pred_tensor) 105 | # cd_backward, _ = knn_point(1, pred_tensor, gt_tensor) 106 | # cd_forward = cd_forward[0, :, 0] 107 | # cd_backward = cd_backward[0, :, 0] 108 | cd_forward = cd_forward.detach().cpu().numpy()[0] 109 | cd_backward = cd_backward.detach().cpu().numpy()[0] 110 | 111 | save_ply_property(pred.squeeze(0).detach().cpu().numpy(), cd_forward, pred_path[:-4]+"_cdF.ply", property_max=0.003, cmap_name="jet") 112 | save_ply_property(gt.squeeze(0).detach().cpu().numpy(), cd_backward, pred_path[:-4]+"_cdB.ply", property_max=0.003, cmap_name="jet") 113 | 114 | md_value = np.mean(cd_forward)+np.mean(cd_backward) 115 | hd_value = np.max(np.amax(cd_forward, axis=0)+np.amax(cd_backward, axis=0)) 116 | cd_backward = np.mean(cd_backward) 117 | cd_forward = np.mean(cd_forward) 118 | # row["CD_forward"] = np.mean(cd_forward) 119 | # row["CD_backwar"] = np.mean(cd_backward) 120 | row["CD"] = cd_forward+cd_backward 121 | 122 | row["hausdorff"] = hd_value 123 | avg_md_forward_value += cd_forward 124 | avg_md_backward_value += cd_backward 125 | avg_hd_value += hd_value 126 | if os.path.isfile(pred_path[:-4] + "_point2mesh_distance.xyz"): 127 | point2mesh_distance = load(pred_path[:-4] + "_point2mesh_distance.xyz") 128 | if point2mesh_distance.size == 0: 129 | continue 130 | point2mesh_distance = point2mesh_distance[:, 3] 131 | row["p2f avg"] = np.nanmean(point2mesh_distance) 132 | row["p2f std"] = np.nanstd(point2mesh_distance) 133 | global_p2f.append(point2mesh_distance) 134 | if os.path.isfile(pred_path[:-4] + "_density.xyz"): 135 | density = load(pred_path[:-4] + "_density.xyz") 136 | global_density.append(density) 137 | std = np.std(density, axis=0) 138 | for i in range(7): 139 | row["nuc_%d" % i] = std[i] 140 | writer.writerow(row) 141 | counter += 1 142 | 143 | row = OrderedDict() 144 | 145 | avg_md_forward_value /= counter 146 | avg_md_backward_value /= counter 147 | avg_hd_value /= counter 148 | # row["CD_forward"] = avg_md_forward_value 149 | # row["CD_backward"] = avg_md_backward_value 150 | row["CD"] = avg_md_forward_value+avg_md_backward_value 151 | row["hausdorff"] = avg_hd_value 152 | if global_p2f: 153 | global_p2f = np.concatenate(global_p2f, axis=0) 154 | mean_p2f = np.nanmean(global_p2f) 155 | std_p2f = np.nanstd(global_p2f) 156 | row["p2f avg"] = mean_p2f 157 | row["p2f std"] = std_p2f 158 | if global_density: 159 | global_density = np.concatenate(global_density, axis=0) 160 | nuc = np.std(global_density, axis=0) 161 | for i in range(7): 162 | row["nuc_%d" % i] = std[i] 163 | 164 | writer.writerow(row) 165 | print("|".join(["{:>15.8f}".format(d) for d in row.values()])) 166 | -------------------------------------------------------------------------------- /scripts/filter_dtu_predictions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from tqdm import tqdm, trange 5 | import numpy as np 6 | from im2mesh import config 7 | from DSS.utils.io import read_ply, save_ply 8 | from DSS.utils.dataset import DTUDataset 9 | from skimage.morphology import binary_dilation, disk 10 | import open3d 11 | 12 | 13 | if __name__ == '__main__': 14 
| # Adjust this to your paths; the input path should point to the 15 | # DTU dataset including the mvs data which can be downloaded here 16 | # http://roboimagedata.compute.dtu.dk/ 17 | INPUT_PATH = '/home/mnt/points/data/DTU_MVS/' 18 | INPUT_PATH = os.path.join(INPUT_PATH, 'Points') 19 | if not os.path.exists(INPUT_PATH): 20 | raise FileNotFoundError("The input path is not pointing to the DTU Dataset. " + \ 21 | "Please download the DTU Dataset and adjust your input path.") 22 | 23 | methods = ['furu', 'tola', 'camp', 'stl'] 24 | # Shortcuts 25 | out_dir = '/home/mnt/points/data/DTU_MVS/Points' 26 | generation_dir = os.path.join(out_dir, '..', 'Points_in_Mask') 27 | 28 | if not os.path.isdir(generation_dir): 29 | os.makedirs(generation_dir) 30 | 31 | parser = argparse.ArgumentParser( 32 | description='Filter the DTU baseline predictions with the object masks.' 33 | ) 34 | parser.add_argument('scan_ids', type=int, nargs='+', help='Path to config file.') 35 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 36 | args = parser.parse_args() 37 | scan_ids = args.scan_ids 38 | 39 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 40 | device = torch.device("cuda" if is_cuda else "cpu") 41 | 42 | 43 | def filter_points(p, data): 44 | n_images = len(data) 45 | 46 | p = torch.from_numpy(p) 47 | n_p = p.shape[0] 48 | inside_mask = np.ones((n_p,), dtype=np.bool) 49 | inside_img = np.zeros((n_p,), dtype=np.bool) 50 | for i in trange(n_images): 51 | # get data 52 | maski_in = data.object_masks[i][0].astype('float32') 53 | 54 | # Apply binary dilation to account for errors in the mask 55 | maski = torch.from_numpy(binary_dilation(maski_in, disk(12))).float() 56 | 57 | #h, w = maski.shape 58 | h, w = maski.shape 59 | w_mat = torch.from_numpy(data.data_dict['world_mat_%d' % i]) 60 | c_mat = torch.from_numpy(data.data_dict['camera_mat_%d' % i]) 61 | s_mat = torch.from_numpy(data.data_dict['scale_mat_%d' % i]) 62 | 63 | # project points into image 64 | phom = torch.cat([p, torch.ones(n_p, 1)], dim=-1).transpose(1, 0) 65 | proj = c_mat @ w_mat @ phom 66 | proj = (proj[:2] / proj[-2].unsqueeze(0)).transpose(1, 0) 67 | 68 | # check which points are inside image; by our definition, 69 | # the coordinates have to be in [-1, 1] 70 | mask_p_inside = ((proj[:, 0] >= -1) & 71 | (proj[:, 1] >= -1) & 72 | (proj[:, 0] <= 1) & 73 | (proj[:, 1] <= 1) 74 | ) 75 | inside_img |= mask_p_inside.cpu().numpy() 76 | 77 | # get image coordinates 78 | proj[:, 0] = (proj[:, 0] + 1) * (w - 1) / 2. 79 | proj[:, 1] = (proj[:, 1] + 1) * (h - 1) / 2. 80 | proj = proj.long() 81 | 82 | # fill occupancy values 83 | proj = proj[mask_p_inside] 84 | occ = torch.ones(n_p) 85 | occ[mask_p_inside] = maski[proj[:, 1], proj[:, 0]] 86 | inside_mask &= (occ.cpu().numpy() >= 0.5) 87 | 88 | occ_out = np.zeros((n_p,)) 89 | occ_out[inside_img & inside_mask] = 1. 
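# A point is kept only if it projects inside the image bounds of at least one view (inside_img)
# and never projects onto the background of any view's dilated object mask (inside_mask).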
90 | 91 | return occ_out 92 | 93 | # Dataset 94 | for scan_id in scan_ids: 95 | dataset = DTUDataset('data/DTU/scan%d' % scan_id) 96 | 97 | for method in methods: 98 | out_dir = os.path.join(generation_dir, method) 99 | 100 | if not os.path.isdir(out_dir): 101 | os.makedirs(out_dir) 102 | 103 | in_dir = os.path.join(INPUT_PATH, method) 104 | if method != 'stl': 105 | scan_path = os.path.join(in_dir, '%s%03d_l3.ply' % (method, scan_id)) 106 | else: 107 | scan_path = os.path.join(in_dir, '%s%03d_total.ply' % (method, scan_id)) 108 | 109 | print(scan_path) 110 | 111 | out_file = os.path.join(out_dir, 'scan%d.ply' % scan_id) 112 | if not os.path.exists(out_file): 113 | pcl = open3d.io.read_point_cloud(scan_path) 114 | p = np.asarray(pcl.points).astype(np.float32) 115 | occ = filter_points(p, dataset) > 0.5 116 | pcl.points = open3d.utility.Vector3dVector(p[occ]) 117 | if len(pcl.colors) != 0: 118 | c = np.asarray(pcl.colors) 119 | pcl.colors = open3d.utility.Vector3dVector(c[occ]) 120 | if len(pcl.normals) != 0: 121 | n = np.asarray(pcl.normals) 122 | pcl.normals = open3d.utility.Vector3dVector(n[occ]) 123 | open3d.io.write_point_cloud(out_file, pcl) 124 | -------------------------------------------------------------------------------- /scripts/gen_denoising_pairs.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # import os 3 | # import tqdm 4 | # import argparse 5 | # import time 6 | # import numpy as np 7 | # from itertools import chain 8 | # from glob import glob 9 | # import sys 10 | # sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 11 | # from neural_point_splatter.neuralSplatter import NeuralPointSplatter, BaselineRenderer, CameraSampler 12 | # from neural_point_splatter.mathHelper import dot 13 | # from neural_point_splatter.splatterIo import saveAsPng, readScene, readCloud, getBasename, checkScenePaths, writeCameras 14 | # from demos.app_skeletton import renderScene, parse_device, writeScene 15 | # from pytorch_points.network.operations import normalize_point_batch, batch_normals 16 | # from pytorch_points.utils.pc_utils import save_ply 17 | 18 | # if __name__ == "__main__": 19 | 20 | # parser = argparse.ArgumentParser(description='Render a given point cloud.') 21 | # parser.add_argument("--input", nargs=1, help="paths to input") 22 | # parser.add_argument("--target", nargs=1, help="paths to target") 23 | # parser.add_argument('-s', '--scene', nargs=1, default=["../example_data/scenes/template.json"], 24 | # help='Input file') 25 | # parser.add_argument('-o', '--output', dest='output', 26 | # help='Output file path') 27 | # parser.add_argument('--width', dest='image_width', type=int, default=None, 28 | # help='Desired image width in pixels.') 29 | # parser.add_argument('--height', dest='image_height', type=int, default=None, 30 | # help='Desired image height in pixels.') 31 | # parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', default=False, 32 | # help='If true: show additional output like inbetween calculations') 33 | # parser.add_argument('-d', '--device', dest='device', default='cuda:0', help='Device to run the computations on, options: cpu, cuda') 34 | # parser.add_argument('-c', '--gen-camera', type=int, default=0, help='number of random cameras') 35 | # parser.add_argument('-cO', '--cam-offset', dest="gen_camera_offset", type=float, nargs=2, default=[5, 20], help='depth offset for generated cameras') 36 | # parser.add_argument('-cF', '--cam-focal', 
dest="gen_camera_focal", type=float, default=15, help='focal length for generated cameras') 37 | # parser.add_argument('--cutoff', type=float, default=1, help='cutoff threshold') 38 | # parser.add_argument('--baseline', action="store_true", help="use baseline depth renderer") 39 | # parser.add_argument('-k', '--topK', dest='topK', type=int, default=5, help='topK for merging depth') 40 | # parser.add_argument('-mT', '--merge_threshold', type=float, default=0.05, help='threshold for merging depth') 41 | # parser.add_argument('--vrk-mode', help="nearestNeighbor or constant", choices=["nearestNeighbor", "constant"], default="constant") 42 | # parser.add_argument('--pca-normal', action="store_true", help="recompute noisy point cloud normal with pca") 43 | # parser.add_argument('--name', type=str, default="*.ply") 44 | 45 | # args = parser.parse_args() 46 | # args.input = args.input.pop() 47 | # args.target = args.target.pop() 48 | # target_points_paths = glob(os.path.join(args.target, "**", args.name), recursive=True) 49 | # input_points_paths = glob(os.path.join(args.input, "**", args.name), recursive=True) 50 | # assert(len(target_points_paths) == len(input_points_paths)) 51 | 52 | # if args.output is None: 53 | # args.output = 'renders/' 54 | # VERBOSE = args.verbose 55 | 56 | # torch.manual_seed(24) 57 | # torch.backends.cudnn.deterministic = True 58 | # torch.backends.cudnn.benchmark = False 59 | # np.random.seed(24) 60 | 61 | # (device, isCpu) = parse_device(args) 62 | # torch.cuda.set_device(device) 63 | 64 | # if VERBOSE: 65 | # print("Rendering on:", device) 66 | 67 | # scenePath = checkScenePaths(args.scene).pop() 68 | # scene = readScene(scenePath, device=device) 69 | # scene.cloud.VrkMode = args.vrk_mode 70 | # with torch.no_grad(): 71 | # splatter = MODEL(scene, device=device, verbose=False, shading=scene.cloud.shading, 72 | # mergeTopK=args.topK, mergeThreshold=args.merge_threshold, cutOffThreshold=args.cutoff) 73 | # for inputPath, targetPath in zip(input_points_paths, target_points_paths): 74 | # inputRelPath = os.path.relpath(inputPath, args.input) 75 | # targetRelPath = os.path.relpath(targetPath, args.target) 76 | # input_outDir = os.path.join(args.output, "input_rendered", inputRelPath[:-4]) 77 | # target_outDir = os.path.join(args.output, "target_rendered", targetRelPath[:-4]) 78 | 79 | # readSceneTick = time.time() 80 | # targetPoints = readCloud(targetPath, device=device) 81 | # # targetPoints_, _, _ = normalize_point_batch(targetPoints[:, :, :3], NCHW=False) 82 | # # targetPoints_.squeeze_(0) 83 | # # targetPoints[:, :3] = targetPoints_ 84 | # inputPoints = readCloud(inputPath, device=device) 85 | # # inputPoints_, _, _ = normalize_point_batch(inputPoints[:, :, :3], NCHW=False) 86 | # # inputPoints.squeeze_(0) 87 | # # inputPoints[:, :3] = inputPoints_ 88 | # if args.pca_normal: 89 | # inputNormals = batch_normals(inputPoints[:, :3].unsqueeze(0), nn_size=32, NCHW=False) 90 | # inputNormals = torch.where(dot(inputNormals, targetPoints[:, 3:6].unsqueeze(0), dim=-1).unsqueeze(-1) < 0, -inputNormals, inputNormals) 91 | # inputPoints[:, 3:6] = inputNormals.squeeze(0) 92 | # # targetNormals = batch_normals(targetPoints[:, :3].unsqueeze(0), nn_size=32, NCHW=False) 93 | # # targetNormals = torch.where(dot(targetNormals, targetPoints[:, 3:6].unsqueeze(0), dim=-1).unsqueeze(-1) < 0, -targetNormals, targetNormals) 94 | # # targetPoints[:, 3:6] = targetNormals.squeeze(0) 95 | # # save_ply(targetPoints[:, :3].cpu().numpy(), targetPath[:-4]+"_pca.ply", normals=targetPoints[:, 
3:6].cpu().numpy()) 96 | # save_ply(inputPoints[:, :3].cpu().numpy(), inputPath[:-4]+"_pca.ply", normals=inputPoints[:, 3:6].cpu().numpy()) 97 | 98 | # readSceneTock = time.time() 99 | # renderCount = 0 100 | 101 | # # for offset in range(int(args.gen_camera_offset[0]), int(args.gen_camera_offset[1])): 102 | # # offsets = (np.clip(np.random.randn(2), -0.2, 0.2)+1)*offset 103 | # # for o in offsets: 104 | # # if args.gen_camera > 0: 105 | # # camSampler = CameraSampler(args.gen_camera, o, args.gen_camera_focal, 106 | # # points=targetPoints[:, :3].unsqueeze(0), 107 | # # camWidth=args.image_width, 108 | # # camHeight=args.image_height, 109 | # # filename="../example_data/pointclouds/dome_300.ply") 110 | # # cameras = [] 111 | # # for i in range(args.gen_camera): 112 | # # cam = next(camSampler) 113 | # # cameras.append(cam) 114 | # # else: 115 | # # cameras = None 116 | 117 | # # if cameras is not None: 118 | # # splatter.initCameras(cameras=cameras) 119 | # # # writeCameras(scene, args.output + '/cameras.ply') 120 | # # else: 121 | # # splatter.initCameras(cameras=scene.cameras) 122 | 123 | # # rendered = [] 124 | # # for i, cam in enumerate(scene.cameras): 125 | # # try: 126 | # # splatter.setCamera(i) 127 | # # scene.loadPoints(inputPoints) 128 | # # splatter.setCloud(scene.cloud) 129 | # # inputRendered = splatter.render().detach()[0] 130 | # # scene.loadPoints(targetPoints) 131 | # # splatter.setCloud(scene.cloud) 132 | # # targetRendered = splatter.render().detach()[0] 133 | # # except Exception as e: 134 | # # print(inputPath, targetPath, renderCount, e) 135 | # # else: 136 | # # if VERBOSE and i == 0: 137 | # # pngDir = os.path.join(input_outDir, "png") 138 | # # os.makedirs(os.path.join(input_outDir, "png"), exist_ok=True) 139 | # # saveAsPng(inputRendered.cpu(), os.path.join(pngDir, 'cam%04d_input.png' % renderCount)) 140 | # # np.save(os.path.join(input_outDir, 'cam%04d_input.npy' % renderCount), inputRendered.cpu().numpy()) 141 | # # if VERBOSE and i == 0: 142 | # # pngDir = os.path.join(target_outDir, "png") 143 | # # os.makedirs(os.path.join(target_outDir, "png"), exist_ok=True) 144 | # # saveAsPng(targetRendered.cpu(), os.path.join(pngDir, 'cam%04d_target.png' % (renderCount))) 145 | # # np.save(os.path.join(target_outDir, 'cam%04d_target.npy' % renderCount), targetRendered.cpu().numpy()) 146 | # # renderCount += 1 147 | 148 | # print(inputPath, targetPath) 149 | -------------------------------------------------------------------------------- /scripts/plot_evaluations.py: -------------------------------------------------------------------------------- 1 | import plotly 2 | from plotly import graph_objects as go 3 | from plotly.subplots import make_subplots 4 | import plotly.express as px 5 | import argparse 6 | import datetime 7 | import os 8 | import csv 9 | from glob import glob 10 | 11 | 12 | def plot_evaluations(in_dirs): 13 | metrics = ['chamfer_p', 'chamfer_n', 'pf_dist'] 14 | fig = make_subplots(rows=3, cols=1, 15 | subplot_titles=metrics, 16 | shared_xaxes='all', 17 | vertical_spacing=0.1, horizontal_spacing=0.01, 18 | ) 19 | df = dict() 20 | colors = px.colors.qualitative.Plotly 21 | for i, exp_dir in enumerate(in_dirs): 22 | exp_name = os.path.basename(exp_dir.rstrip('/')) 23 | evals_in_exp = glob(os.path.join(exp_dir, 'vis', 'evaluation*.csv')) 24 | for eval_f in evals_in_exp: 25 | eval_name = os.path.splitext(os.path.basename(eval_f))[0] 26 | with open(eval_f, 'r') as csvfile: 27 | print(eval_f) 28 | fieldnames = ['mtime', 'it', 29 | 'chamfer_p', 'chamfer_n', 
'pf_dist'] 30 | for k in fieldnames: 31 | df['.'.join([eval_name, exp_name, k])] = [] 32 | reader = csv.DictReader( 33 | csvfile, fieldnames=fieldnames, restval='-', ) 34 | for it, row in enumerate(reader): 35 | if it == 0 or row['it'] == 0: 36 | # skip header 37 | continue 38 | for k, v in row.items(): 39 | if v == '-': 40 | continue 41 | df['.'.join([eval_name, exp_name, k])].append( 42 | float(v)) 43 | 44 | name_prefix = '.'.join([eval_name, exp_name]) 45 | for idx, k in enumerate(metrics): 46 | y_data = df[name_prefix + '.' + k] 47 | x_data = df[name_prefix + '.' + 'mtime'] 48 | fig.add_trace(go.Scatter(x=x_data, y=y_data, 49 | mode='lines+markers', 50 | name=name_prefix + '.' + k, 51 | marker_color=colors[i]), 52 | col=1, row=idx + 1) 53 | 54 | fig.update_yaxes(type="log", autorange=True) 55 | fig.update_layout(legend=dict( 56 | orientation="h", 57 | yanchor="bottom", 58 | y=1.02, 59 | xanchor="right", 60 | x=1 61 | ), template='plotly_white') 62 | return fig 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--dirs", type=str, nargs='+', required=True, 68 | help="Experiment directories") 69 | args = parser.parse_args() 70 | fig = plot_evaluations(args.dirs) 71 | out_fname = 'eval' + datetime.datetime.now().strftime("-%Y%m%d-%H%M%S") + '.html' 72 | fig.write_html( 73 | out_fname) 74 | print('Saved to ' + out_fname) 75 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | from glob import glob 4 | import torch 5 | from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension, CppExtension 6 | 7 | extra_compile_args = {"cxx": ["-std=c++14"]} 8 | define_macros = [] 9 | sources = glob(os.path.join('DSS', 'csrc', '*.cpp')) 10 | source_cuda = glob(os.path.join('DSS', 'csrc', '*.cu')) 11 | 12 | force_cuda = os.getenv("FORCE_CUDA", "0") == "1" 13 | if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda: 14 | extension = CUDAExtension 15 | sources += source_cuda 16 | define_macros += [("WITH_CUDA", None)] 17 | nvcc_args = [ 18 | "-DCUDA_HAS_FP16=1", 19 | "-D__CUDA_NO_HALF_OPERATORS__", 20 | "-D__CUDA_NO_HALF_CONVERSIONS__", 21 | "-D__CUDA_NO_HALF2_OPERATORS__", 22 | ] 23 | nvcc_flags_env = os.getenv("NVCC_FLAGS", "") 24 | if nvcc_flags_env != "": 25 | nvcc_args.extend(nvcc_flags_env.split(" ")) 26 | 27 | # It's better if pytorch can do this by default .. 
28 | CC = os.environ.get("CC", None) 29 | if CC is not None: 30 | CC_arg = "-ccbin={}".format(CC) 31 | if CC_arg not in nvcc_args: 32 | if any(arg.startswith("-ccbin") for arg in nvcc_args): 33 | raise ValueError("Inconsistent ccbins") 34 | nvcc_args.append(CC_arg) 35 | 36 | extra_compile_args["nvcc"] = nvcc_args 37 | else: 38 | print('Cuda is not available!') 39 | extension = CppExtension  # fall back to a CPU-only extension 40 | include_dirs = torch.utils.cpp_extension.include_paths() 41 | ext_modules = [ 42 | extension('DSS._C', sources, 43 | include_dirs=['DSS/csrc']+include_dirs, 44 | define_macros=define_macros, 45 | extra_compile_args=extra_compile_args 46 | ) 47 | ] 48 | 49 | INSTALL_REQUIREMENTS = ['numpy', 'torch', 'plyfile',] 50 | 51 | class BuildExtension(torch.utils.cpp_extension.BuildExtension): 52 | def __init__(self, *args, **kwargs): 53 | super().__init__(use_ninja=False, *args, **kwargs) 54 | 55 | setup( 56 | name='DSS', 57 | description='Differentiable Surface Splatter', 58 | author='Yifan Wang, Lixin Xue and Felice Serena', 59 | packages=find_packages(exclude=('tests',)), 60 | license='MIT License', 61 | version='1.0', 62 | install_requires=INSTALL_REQUIREMENTS, 63 | ext_modules=ext_modules, 64 | cmdclass={'build_ext': BuildExtension} 65 | ) 66 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test camera matrices for generated data by reprojecting the point cloud/object 3 | """ 4 | import unittest 5 | from pytorch3d.renderer import ( 6 | look_at_view_transform, 7 | RasterizationSettings, 8 | FoVPerspectiveCameras, 9 | MeshRenderer, 10 | MeshRasterizer, 11 | SoftPhongShader, 12 | ) 13 | from pytorch3d.ops import ( 14 | packed_to_padded, 15 | eyes, 16 | sample_points_from_meshes, 17 | padded_to_packed 18 | ) 19 | from pytorch3d.io import load_ply, save_obj, load_objs_as_meshes 20 | from pytorch3d.structures import Meshes, Textures, padded_to_list 21 | from itertools import chain 22 | from glob import glob 23 | import numpy as np 24 | import imageio 25 | import argparse 26 | import os 27 | import torch 28 | import sys 29 | import trimesh 30 | sys.path.append(".") 31 | from DSS.core.camera import CameraSampler 32 | from DSS.core.texture import LightingTexture 33 | from DSS.utils.dataset import MVRDataset 34 | from DSS.utils.mathHelper import decompose_to_R_and_t 35 | from common import get_tri_color_lights 36 | from im2mesh.common import (transform_to_camera_space, 37 | sample_patch_points, 38 | arange_pixels, 39 | transform_to_world, 40 | get_tensor_values) 41 | 42 | 43 | class TestMVRData(unittest.TestCase): 44 | def test_dataset(self): 45 | # 1. rerender input point clouds / meshes using the saved camera_mat 46 | # compare mask image with saved mask image 47 | # 2.
backproject masked points to space with dense depth map, 48 | # fuse all views and save 49 | batch_size = 1 50 | device = torch.device('cuda:0') 51 | 52 | data_dir = 'data/synthetic/cube_mesh' 53 | output_dir = os.path.join('tests', 'outputs', 'test_data') 54 | if not os.path.isdir(output_dir): 55 | os.makedirs(output_dir) 56 | 57 | # dataset 58 | dataset = MVRDataset(data_dir=data_dir, load_dense_depth=True, mode="train") 59 | data_loader = torch.utils.data.DataLoader( 60 | dataset, batch_size=batch_size, num_workers=0, shuffle=False 61 | ) 62 | meshes = load_objs_as_meshes([os.path.join(data_dir, 'mesh.obj')]).to(device) 63 | cams = dataset.get_cameras().to(device) 64 | image_size = imageio.imread(dataset.image_files[0]).shape[0] 65 | 66 | # initialize the rasterizer; we only check the mask pngs, so there is no need to create lights, shaders, etc. 67 | raster_settings = RasterizationSettings( 68 | image_size=image_size, 69 | blur_radius=0.0, 70 | faces_per_pixel=5, 71 | bin_size = None, # this setting controls whether naive or coarse-to-fine rasterization is used 72 | max_faces_per_bin = None # this setting is for coarse rasterization 73 | ) 74 | rasterizer = MeshRasterizer(cameras=None, raster_settings=raster_settings) 75 | 76 | # render with the loaded camera positions and training transformation functions 77 | pixel_world_all = [] 78 | for idx, data in enumerate(data_loader): 79 | # get data 80 | img = data.get('img.rgb').to(device) 81 | assert(img.min() >= 0 and img.max() <= 1), "Image must be a floating point number between 0 and 1." 82 | mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1) 83 | 84 | camera_mat = data['camera_mat'].to(device) 85 | 86 | cams.R, cams.T = decompose_to_R_and_t(camera_mat) 87 | cams._N = cams.R.shape[0] 88 | cams.to(device) 89 | self.assertTrue(torch.equal(cams.get_world_to_view_transform().get_matrix(), camera_mat)) 90 | 91 | # transform to view and rerender with non-rotated camera 92 | verts_padded = transform_to_camera_space(meshes.verts_padded(), cams) 93 | meshes_in_view = meshes.offset_verts( 94 | -meshes.verts_packed() + 95 | padded_to_packed(verts_padded, meshes.mesh_to_verts_packed_first_idx(), meshes.verts_packed().shape[0])) 96 | 97 | fragments = rasterizer(meshes_in_view, cameras=dataset.get_cameras().to(device)) 98 | 99 | # compare mask 100 | mask = fragments.pix_to_face[..., :1] >= 0 101 | imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx), mask[0, ...].cpu().to(dtype=torch.uint8)*255) 102 | # allow a 5-pixel difference 103 | self.assertTrue(torch.sum(mask_gt != mask) < 5) 104 | 105 | # check dense maps 106 | # backproject points to the world pixel range (-1, 1) 107 | pixels = arange_pixels((image_size, image_size), batch_size)[1].to(device) 108 | 109 | depth_img = data.get('img.depth').to(device) 110 | # get the depth and mask at the sampled pixel positions 111 | depth_gt = get_tensor_values(depth_img, pixels, squeeze_channel_dim=True) 112 | mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(), pixels, squeeze_channel_dim=True).bool() 113 | # get pixels and depth inside the masked area 114 | pixels_packed = pixels[mask_gt] 115 | depth_gt_packed = depth_gt[mask_gt] 116 | first_idx = torch.zeros((pixels.shape[0],), device=device, dtype=torch.long) 117 | num_pts_in_mask = mask_gt.sum(dim=1) 118 | first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1] 119 | pixels_padded = packed_to_padded(pixels_packed, first_idx, num_pts_in_mask.max().item()) 120 | depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx, num_pts_in_mask.max().item()) 121
| # backproject to world coordinates 122 | # contains nan and infinite values due to depth_gt_padded containing 0.0 123 | pixel_world_padded = transform_to_world(pixels_padded, depth_gt_padded[..., None], cams) 124 | # transform back to list, containing no padded values 125 | split_size = num_pts_in_mask[..., None].repeat(1, 2) 126 | split_size[:, 1] = 3 127 | pixel_world_list = padded_to_list(pixel_world_padded, split_size) 128 | pixel_world_all.extend(pixel_world_list) 129 | 130 | idx += 1 131 | if idx >= 10: 132 | break 133 | 134 | pixel_world_all = torch.cat(pixel_world_all, dim=0) 135 | mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(), faces=None,process=False) 136 | mesh.export(os.path.join(output_dir, 'pixel_to_world.ply')) -------------------------------------------------------------------------------- /tests/test_dvr_camera.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import glob 4 | import os 5 | import imageio 6 | import trimesh 7 | import torch 8 | from DSS.utils.mathHelper import decompose_to_R_and_t 9 | from pytorch3d.io import load_objs_as_meshes 10 | from pytorch3d.structures import Meshes 11 | from pytorch3d.renderer import (PerspectiveCameras, 12 | MeshRenderer, 13 | MeshRasterizer, 14 | SoftPhongShader, 15 | RasterizationSettings, 16 | TexturesVertex) 17 | 18 | 19 | class TestDVRData(unittest.TestCase): 20 | """ 21 | parse DVR camera data 22 | """ 23 | def test_cameras(self): 24 | """ 25 | DVR cameras 26 | """ 27 | device = torch.device('cuda:0') 28 | input_dir = '/home/ywang/Documents/points/neural_splatter/differentiable_volumetric_rendering_upstream/data/DTU/scan106/scan106' 29 | out_dir = os.path.join('tests', 'outputs', 'test_dvr_data') 30 | if not os.path.exists(out_dir): 31 | os.makedirs(out_dir) 32 | 33 | dvr_camera_file = os.path.join(input_dir, 'cameras.npz') 34 | dvr_camera_dict = np.load(dvr_camera_file) 35 | n_views = len(glob.glob(os.path.join(input_dir, 'image', '*.png'))) 36 | 37 | focal_lengths = dvr_camera_dict['camera_mat_0'][(0,1),(0,1)].reshape(1,2) 38 | principal_point = dvr_camera_dict['camera_mat_0'][(0,1),(2,2)].reshape(1,2) 39 | cameras = PerspectiveCameras(focal_length=focal_lengths, principal_point=principal_point).to(device) 40 | # Define the settings for rasterization and shading. 41 | # Refer to raster_points.py for explanations of these parameters. 
42 | raster_settings = RasterizationSettings( 43 | image_size=512, 44 | blur_radius=0.0, 45 | faces_per_pixel=5, 46 | # this setting controls whether naive or coarse-to-fine rasterization is used 47 | bin_size=None, 48 | max_faces_per_bin=None # this setting is for coarse rasterization 49 | ) 50 | renderer = MeshRenderer( 51 | rasterizer=MeshRasterizer( 52 | cameras=None, raster_settings=raster_settings), 53 | shader=SoftPhongShader(device=device) 54 | ) 55 | mesh = trimesh.load_mesh('/home/ywang/Documents/points/neural_splatter/differentiable_volumetric_rendering_upstream/out/multi_view_reconstruction/birds/ours_depth_mvs/vis/000_0000477500.ply') 56 | textures = TexturesVertex(verts_features=torch.ones( 57 | 1, mesh.vertices.shape[0], 3)).to(device=device) 58 | meshes = Meshes(verts=[torch.tensor(mesh.vertices).float()], faces=[torch.tensor(mesh.faces)], 59 | textures=textures).to(device=device) 60 | for i in range(n_views): 61 | transform_mat = torch.from_numpy(dvr_camera_dict['scale_mat_%d' % i].T @ dvr_camera_dict['world_mat_%d' % i].T).to(device).unsqueeze(0).float() 62 | cameras.R, cameras.T = decompose_to_R_and_t(transform_mat) 63 | cameras._N = cameras.R.shape[0] 64 | imgs = renderer(meshes, cameras=cameras, zfar=1e4, znear=1.0) 65 | # save the rendered view for visual inspection 66 | imageio.imwrite(os.path.join(out_dir, '%06d.png' % i), (imgs[0].detach().cpu().numpy()*255).astype('uint8')) -------------------------------------------------------------------------------- /train_mvr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import git 3 | import time 4 | import numpy as np 5 | import os 6 | import logging 7 | import config 8 | import torch 9 | import torch.optim as optim 10 | from DSS.models import ImplicitModel, PointModel 11 | from DSS.utils import tolerating_collate 12 | from common import create_animation 13 | from DSS.misc.checkpoints import CheckpointIO 14 | from DSS import logger_py, set_deterministic_ 15 | 16 | set_deterministic_() 17 | 18 | # Arguments 19 | parser = argparse.ArgumentParser( 20 | description='Train implicit representations without 3D supervision.'
21 | ) 22 | parser.add_argument('config', type=str, help='Path to config file.') 23 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 24 | parser.add_argument('--exit-after', type=int, default=-1, 25 | help='Checkpoint and exit after specified number of ' 26 | 'seconds with exit code 2.') 27 | 28 | args = parser.parse_args() 29 | cfg = config.load_config(args.config, 'configs/default.yaml') 30 | 31 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 32 | device = torch.device("cuda" if is_cuda else "cpu") 33 | 34 | # Shorthands 35 | out_dir = os.path.join(cfg['training']['out_dir'], cfg['name']) 36 | backup_every = cfg['training']['backup_every'] 37 | exit_after = args.exit_after 38 | lr = cfg['training']['learning_rate'] 39 | batch_size = cfg['training']['batch_size'] 40 | batch_size_val = cfg['training']['batch_size_val'] 41 | n_workers = cfg['training']['n_workers'] 42 | model_selection_metric = cfg['training']['model_selection_metric'] 43 | if cfg['training']['model_selection_mode'] == 'maximize': 44 | model_selection_sign = 1 45 | elif cfg['training']['model_selection_mode'] == 'minimize': 46 | model_selection_sign = -1 47 | else: 48 | raise ValueError('model_selection_mode must be ' 49 | 'either maximize or minimize.') 50 | 51 | 52 | # Output directory 53 | if not os.path.exists(out_dir): 54 | os.makedirs(out_dir) 55 | 56 | # Begin logging also to the log file 57 | fileHandler = logging.FileHandler(os.path.join(out_dir, cfg.training.logfile)) 58 | fileHandler.setLevel(logging.DEBUG) 59 | logger_py.addHandler(fileHandler) 60 | 61 | repo = git.Repo(search_parent_directories=False) 62 | sha = repo.head.object.hexsha 63 | logger_py.debug('Git commit: %s' % sha) 64 | 65 | # Data 66 | train_dataset = config.create_dataset(cfg.data, mode='train') 67 | val_dataset = config.create_dataset(cfg.data, mode='val') 68 | val_loader = torch.utils.data.DataLoader( 69 | val_dataset, batch_size=batch_size_val, num_workers=int(n_workers // 2), 70 | shuffle=False, collate_fn=tolerating_collate, 71 | ) 72 | # data_viz = next(iter(val_loader)) 73 | model = config.create_model(cfg, camera_model=train_dataset.get_cameras(), device=device) 74 | 75 | # Create rendering objects from loaded data 76 | cameras = train_dataset.get_cameras() 77 | lights = train_dataset.get_lights() 78 | 79 | 80 | # Optimizer 81 | if cfg.model.type == 'point': 82 | optimizer = optim.SGD( 83 | [p for p in model.parameters() if p.requires_grad], lr=lr) 84 | else: 85 | if cfg.renderer.is_neural_texture: 86 | optimizer = optim.Adam(model.parameters(), lr=lr) 87 | else: 88 | optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.99)) 89 | 90 | # Loads checkpoints 91 | checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer) 92 | try: 93 | load_dict = checkpoint_io.load(cfg.training.resume_from) 94 | except FileExistsError: 95 | load_dict = dict() 96 | 97 | epoch_it = load_dict.get('epoch_it', -1) 98 | it = load_dict.get('it', -1) 99 | 100 | # Save config to log directory 101 | config.save_config(os.path.join(out_dir, 'config.yaml'), cfg) 102 | 103 | # Update Metrics from loaded 104 | model_selection_metric = cfg['training']['model_selection_metric'] 105 | metric_val_best = load_dict.get( 106 | 'loss_val_best', -model_selection_sign * np.inf) 107 | 108 | if metric_val_best == np.inf or metric_val_best == -np.inf: 109 | metric_val_best = -model_selection_sign * np.inf 110 | 111 | logger_py.info('Current best validation metric (%s): %.8f' 112 | % (model_selection_metric, 
metric_val_best)) 113 | 114 | # Shorthands 115 | print_every = cfg['training']['print_every'] 116 | checkpoint_every = cfg['training']['checkpoint_every'] 117 | validate_every = cfg['training']['validate_every'] 118 | visualize_every = cfg['training']['visualize_every'] 119 | debug_every = cfg['training']['debug_every'] 120 | reweight_every = cfg['training']['reweight_every'] 121 | 122 | scheduler = optim.lr_scheduler.MultiStepLR( 123 | optimizer, cfg['training']['scheduler_milestones'], 124 | gamma=cfg['training']['scheduler_gamma'], last_epoch=epoch_it) 125 | 126 | # Set mesh extraction to low resolution for fast visualization 127 | # during training 128 | cfg['generation']['resolution'] = cfg['training']['visualize_resolution'] 129 | cfg['generation']['img_size'] = tuple(x//4 for x in train_dataset.resolution) 130 | generator = config.create_generator(cfg, model, device=device) 131 | trainer = config.create_trainer( 132 | cfg, model, optimizer, scheduler, generator, None, val_loader, device=device) 133 | 134 | # Print model 135 | nparameters = sum(p.numel() for p in model.parameters()) 136 | logger_py.info('Total number of parameters: %d' % nparameters) 137 | 138 | 139 | # Start training loop 140 | t0 = time.time() 141 | t0b = time.time() 142 | sample_weights = np.ones(len(train_dataset)).astype('float32') 143 | 144 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 145 | num_workers=n_workers, drop_last=True, 146 | collate_fn=tolerating_collate) 147 | while True: 148 | epoch_it += 1 149 | for batch in train_loader: 150 | it += 1 151 | 152 | loss = trainer.train_step(batch, cameras=cameras, lights=lights, it=it) 153 | 154 | # Visualize output 155 | if it > 0 and visualize_every > 0 and (it % visualize_every) == 0: 156 | logger_py.info('Visualizing') 157 | trainer.visualize(batch, it=it, vis_type='image', cameras=cameras, lights=lights) 158 | if isinstance(trainer.model, PointModel): 159 | trainer.visualize( 160 | batch, it=it, vis_type='pointcloud', cameras=cameras, lights=lights) 161 | if isinstance(trainer.model, ImplicitModel): 162 | trainer.visualize( 163 | batch, it=it, vis_type='mesh', cameras=cameras, lights=lights) 164 | 165 | # Print output 166 | if print_every > 0 and (it % print_every) == 0: 167 | logger_py.info('[Epoch %02d] it=%03d, loss=%.4f, time=%.4f' 168 | % (epoch_it, it, loss, time.time() - t0b)) 169 | t0b = time.time() 170 | 171 | # Debug visualization 172 | if it > 0 and debug_every > 0 and (it % debug_every) == 0: 173 | logger_py.info('Visualizing gradients') 174 | trainer.debug(batch, cameras=cameras, lights=lights, it=it, 175 | mesh_gt=train_dataset.get_meshes()) 176 | # Save checkpoint 177 | if it > 0 and (checkpoint_every > 0 and (it % checkpoint_every) == 0): 178 | logger_py.info('Saving checkpoint') 179 | print('Saving checkpoint') 180 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 181 | loss_val_best=metric_val_best) 182 | trainer.save_shape(os.path.join(out_dir, 'shape'), it) 183 | 184 | # Backup if necessary 185 | if it > 0 and (backup_every > 0 and (it % backup_every) == 0): 186 | logger_py.info('Backup checkpoint') 187 | checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it, 188 | loss_val_best=metric_val_best) 189 | trainer.save_shape(os.path.join(out_dir, 'shape_%d' % it), it) 190 | 191 | # Run validation and adjust sampling rate 192 | if it > 0 and validate_every > 0 and (it % validate_every) == 0: 193 | mesh_gt = train_dataset.get_meshes() 194 | if mesh_gt is not None: 195 | eval_dict =
trainer.evaluate_mesh(val_loader, it, cameras=cameras, lights=lights) 196 | metric_val = eval_dict['chamfer'] 197 | else: 198 | eval_dict = trainer.evaluate(val_loader, cameras=cameras, lights=lights) 199 | metric_val = eval_dict[model_selection_metric] 200 | 201 | logger_py.info('Validation metric (%s): %.4g' % (model_selection_metric, metric_val)) 202 | 203 | if model_selection_sign * (metric_val - metric_val_best) > 0: 204 | metric_val_best = metric_val 205 | logger_py.info('New best model (loss %.4g)' % metric_val_best) 206 | checkpoint_io.backup_model_best('model_best.pt') 207 | checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it, 208 | loss_val_best=metric_val_best) 209 | trainer.save_shape(os.path.join(out_dir, 'shape_best'), it) 210 | 211 | # Exit if necessary 212 | if exit_after > 0 and (time.time() - t0) >= exit_after: 213 | logger_py.info('Time limit reached. Exiting.') 214 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 215 | loss_val_best=metric_val_best) 216 | # create animation 217 | try: 218 | os.makedirs(trainer.vis_dir, exist_ok=True) 219 | create_animation(trainer.vis_dir, show_max=20) 220 | except Exception as e: 221 | logger_py.warning( 222 | "Couldn't create animated sequence: {}".format(e)) 223 | for t in trainer._threads: 224 | t.join() 225 | exit(3) 226 | 227 | # Make scheduler step after full epoch 228 | trainer.update_learning_rate(it) --------------------------------------------------------------------------------
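A small standalone sketch (illustrative names only, not a file in this repository) of the model-selection rule used in train_mvr.py above: model_selection_sign folds both "maximize" and "minimize" metrics into a single comparison, and initializing the best value to -sign * inf guarantees the first validation result is always accepted.

import numpy as np

def is_new_best(metric_val, metric_val_best, mode='minimize'):
    # sign = +1 when larger is better, -1 when smaller is better
    sign = 1 if mode == 'maximize' else -1
    return sign * (metric_val - metric_val_best) > 0

# mode='minimize' starts at +inf, mirroring -model_selection_sign * np.inf in the script
metric_val_best = np.inf
for chamfer in [0.31, 0.27, 0.29]:
    if is_new_best(chamfer, metric_val_best, mode='minimize'):
        metric_val_best = chamfer  # this is the point where model_best.pt would be saved
print(metric_val_best)  # 0.27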