├── .gitignore
├── .gitmodules
├── DSS
├── __init__.py
├── core
│   ├── __init__.py
│   ├── camera.py
│   ├── cloud.py
│   ├── lighting.py
│   ├── rasterizer.py
│   ├── renderer.py
│   └── texture.py
├── csrc
│   ├── bitmask.cuh
│   ├── cuda_utils.h
│   ├── ext.cpp
│   ├── macros.hpp
│   ├── rasterization_utils.cuh
│   ├── rasterize_points.cu
│   ├── rasterize_points.h
│   ├── rasterize_points_backward.cu
│   ├── rasterize_points_cpu.cpp
│   ├── types.hpp
│   ├── weighted_sum.cu
│   └── weighted_sum.h
├── logger.py
├── misc
│   ├── __init__.py
│   ├── checkpoints.py
│   ├── imageFilters.py
│   ├── pix2pix
│   │   ├── .gitignore
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── aligned_dataset.py
│   │   │   ├── base_dataset.py
│   │   │   └── single_dataset.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── networks.py
│   │   │   ├── pix2pix_model.py
│   │   │   ├── template_model.py
│   │   │   └── test_model.py
│   │   ├── options
│   │   │   ├── base_options.py
│   │   │   └── test_options.py
│   │   └── util
│   │   │   ├── __init__.py
│   │   │   ├── html.py
│   │   │   └── util.py
│   └── visualize.py
├── models
│   ├── __init__.py
│   ├── common.py
│   └── point_modeling.py
├── options
│   ├── base_options.py
│   ├── deformation_options.py
│   ├── filter_options.py
│   ├── finetune_options.py
│   └── render_options.py
├── training
│   ├── losses.py
│   ├── scheduler.py
│   └── trainer.py
└── utils
│   ├── __init__.py
│   ├── dataset.py
│   ├── io.py
│   ├── mathHelper.py
│   ├── matrixConstruction.py
│   └── sampler.py
├── README.md
├── common.py
├── config.py
├── configs
├── default.yaml
└── dss.yml
├── environment.yml
├── example_data
├── mesh
│   └── yoga6.ply
└── pointclouds
│   ├── Kangaroo_V10k.ply
│   ├── Kangaroo_V10k_nc.ply
│   ├── Koala_V10k.ply
│   ├── Koala_V10k_nc.ply
│   ├── a72-seated_jew_aligned_pca.ply
│   ├── armadillo_aligned_pca.ply
│   ├── bunny-8000.ply
│   ├── cube_20k.ply
│   ├── grid32.ply
│   ├── noisy03_points
│   ├── A9-vulcan_aligned_pca.ply
│   ├── Gramme_aligned_pca.ply
│   ├── a72-seated_jew_aligned_pca.ply
│   ├── asklepios_aligned_pca.ply
│   ├── baron_seutin_aligned_pca.ply
│   ├── charite_-_CleanUp_-_LowPoly_aligned_pca.ply
│   ├── cheval_terracotta_-_LowPoly-RealOne_aligned_pca.ply
│   ├── cupid_aligned_pca.ply
│   ├── dame_assise_-_CleanUp_-_LowPoly_aligned_pca.ply
│   ├── drunkard_-_CleanUp_-_LowPoly_aligned_pca.ply
│   ├── madeleine_aligned_pca.ply
│   ├── retheur_-_LowPoly_aligned_pca.ply
│   └── saint_lambert_aligned_pca.ply
│   ├── noisy1_points
│   ├── A9-vulcan_aligned_pca.ply
│   ├── Gramme_aligned_pca.ply
│   ├── a72-seated_jew_aligned_pca.ply
│   ├── asklepios_aligned_pca.ply
│   ├── baron_seutin_aligned_pca.ply
│   ├── charite_-_CleanUp_-_LowPoly_aligned_pca.ply
│   ├── cheval_terracotta_-_LowPoly-RealOne_aligned_pca.ply
│   ├── cupid_aligned_pca.ply
│   ├── dame_assise_-_CleanUp_-_LowPoly_aligned_pca.ply
│   ├── drunkard_-_CleanUp_-_LowPoly_aligned_pca.ply
│   ├── madeleine_aligned_pca.ply
│   ├── retheur_-_LowPoly_aligned_pca.ply
│   └── saint_lambert_aligned_pca.ply
│   ├── point-one.ply
│   ├── sphere_2k.ply
│   ├── sphere_300.ply
│   ├── sphere_normal_dense.ply
│   ├── teapot_clean.ply
│   ├── teapot_normal_dense.ply
│   ├── yoga1_out.ply
│   └── yoga6_out.ply
├── images
├── 2D_teapot.gif
├── armadillo_2_all.png
├── seated_all.png
├── teapot_2D.gif
├── teapot_3D.gif
├── teapot_sequence.gif
├── teaser.png
├── video-thumb.png
├── yoga6-1.gif
└── yoga6.gif
├── learn_image_filter.py
├── requirements.txt
├── scripts
├── create_mvr_data_from_mesh.py
├── evaluatePointClouds.py
├── pcl2Mesh.py
├── poisson_sampling.mlx
├── poisson_sampling_pca.mlx
├── random_displacement.mlx
└── run_meshlab_filter.sh
├── sequences.py
├── setup.py
├── test_opendr.py
├── 
train_mvr.py └── trained_models └── download_data.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # C++ build directory 2 | build 3 | 4 | 5 | # Created by https://www.gitignore.io/api/c++,cmake 6 | # Edit at https://www.gitignore.io/?templates=c++,cmake 7 | 8 | ### C++ ### 9 | # Prerequisites 10 | *.d 11 | 12 | # Compiled Object files 13 | *.slo 14 | *.lo 15 | *.o 16 | *.obj 17 | 18 | # Precompiled Headers 19 | *.gch 20 | *.pch 21 | 22 | # Compiled Dynamic libraries 23 | *.so 24 | *.dylib 25 | *.dll 26 | 27 | # Fortran module files 28 | *.mod 29 | *.smod 30 | 31 | # Compiled Static libraries 32 | *.lai 33 | *.la 34 | *.a 35 | *.lib 36 | 37 | # Executables 38 | *.exe 39 | *.out 40 | *.app 41 | 42 | ### CMake ### 43 | CMakeLists.txt.user 44 | CMakeCache.txt 45 | CMakeFiles 46 | CMakeScripts 47 | Testing 48 | Makefile 49 | cmake_install.cmake 50 | install_manifest.txt 51 | compile_commands.json 52 | CTestTestfile.cmake 53 | 54 | # End of https://www.gitignore.io/api/c++,cmake 55 | 56 | 57 | # Created by https://www.gitignore.io/api/vim,linux,macos,python,windows 58 | 59 | ### Linux ### 60 | *~ 61 | 62 | # temporary files which can be created if a process still has a handle open of a deleted file 63 | .fuse_hidden* 64 | 65 | # KDE directory preferences 66 | .directory 67 | 68 | # Linux trash folder which might appear on any partition or disk 69 | .Trash-* 70 | 71 | # .nfs files are created when an open file is removed but is still being accessed 72 | .nfs* 73 | 74 | ### macOS ### 75 | # General 76 | .DS_Store 77 | .AppleDouble 78 | .LSOverride 79 | 80 | # Icon must end with two \r 81 | Icon 82 | 83 | # Thumbnails 84 | ._* 85 | 86 | # Files that might appear in the root of a volume 87 | .DocumentRevisions-V100 88 | .fseventsd 89 | .Spotlight-V100 90 | .TemporaryItems 91 | .Trashes 92 | .VolumeIcon.icns 93 | .com.apple.timemachine.donotpresent 94 | 95 | # Directories potentially created on remote AFP share 96 | .AppleDB 97 | .AppleDesktop 98 | Network Trash Folder 99 | Temporary Items 100 | .apdisk 101 | 102 | 103 | ### Python ### 104 | # Byte-compiled / optimized / DLL files 105 | __pycache__/ 106 | *.py[cod] 107 | *$py.class 108 | 109 | # C extensions 110 | *.so 111 | 112 | # Distribution / packaging 113 | .Python 114 | build/ 115 | develop-eggs/ 116 | dist/ 117 | downloads/ 118 | eggs/ 119 | .eggs/ 120 | lib/ 121 | lib64/ 122 | parts/ 123 | sdist/ 124 | var/ 125 | wheels/ 126 | *.egg-info/ 127 | .installed.cfg 128 | *.egg 129 | MANIFEST 130 | 131 | # PyInstaller 132 | # Usually these files are written by a python script from a template 133 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
134 | *.manifest 135 | *.spec 136 | 137 | # Installer logs 138 | pip-log.txt 139 | pip-delete-this-directory.txt 140 | 141 | # Unit test / coverage reports 142 | htmlcov/ 143 | .tox/ 144 | .nox/ 145 | .coverage 146 | .coverage.* 147 | .cache 148 | nosetests.xml 149 | coverage.xml 150 | *.cover 151 | .hypothesis/ 152 | .pytest_cache/ 153 | 154 | # Translations 155 | *.mo 156 | *.pot 157 | 158 | # Django stuff: 159 | *.log 160 | local_settings.py 161 | db.sqlite3 162 | 163 | # Flask stuff: 164 | instance/ 165 | .webassets-cache 166 | 167 | # Scrapy stuff: 168 | .scrapy 169 | 170 | # Sphinx documentation 171 | docs/_build/ 172 | 173 | # PyBuilder 174 | target/ 175 | 176 | # Jupyter Notebook 177 | .ipynb_checkpoints 178 | 179 | # IPython 180 | profile_default/ 181 | ipython_config.py 182 | 183 | # pyenv 184 | .python-version 185 | 186 | # celery beat schedule file 187 | celerybeat-schedule 188 | 189 | # SageMath parsed files 190 | *.sage.py 191 | 192 | # Environments 193 | .env 194 | .venv 195 | env/ 196 | venv/ 197 | ENV/ 198 | env.bak/ 199 | venv.bak/ 200 | 201 | # Spyder project settings 202 | .spyderproject 203 | .spyproject 204 | 205 | # Rope project settings 206 | .ropeproject 207 | 208 | # mkdocs documentation 209 | /site 210 | 211 | # mypy 212 | .mypy_cache/ 213 | .dmypy.json 214 | dmypy.json 215 | 216 | ### Python Patch ### 217 | .venv/ 218 | 219 | ### Python.VirtualEnv Stack ### 220 | # Virtualenv 221 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 222 | [Bb]in 223 | [Ii]nclude 224 | [Ll]ib 225 | [Ll]ib64 226 | [Ll]ocal 227 | #[Ss]cripts 228 | pyvenv.cfg 229 | pip-selfcheck.json 230 | 231 | 232 | ### Vim ### 233 | # Swap 234 | [._]*.s[a-v][a-z] 235 | [._]*.sw[a-p] 236 | [._]s[a-rt-v][a-z] 237 | [._]ss[a-gi-z] 238 | [._]sw[a-p] 239 | 240 | # Session 241 | Session.vim 242 | 243 | # Temporary 244 | .netrwhist 245 | # Auto-generated tag files 246 | tags 247 | # Persistent undo 248 | [._]*.un~ 249 | 250 | ### Windows ### 251 | # Windows thumbnail cache files 252 | Thumbs.db 253 | ehthumbs.db 254 | ehthumbs_vista.db 255 | 256 | # Dump file 257 | *.stackdump 258 | 259 | # Folder config file 260 | [Dd]esktop.ini 261 | 262 | # Recycle Bin used on file shares 263 | $RECYCLE.BIN/ 264 | 265 | # Windows Installer files 266 | *.cab 267 | *.msi 268 | *.msix 269 | *.msm 270 | *.msp 271 | 272 | # Windows shortcuts 273 | *.lnk 274 | 275 | 276 | trained_models/** 277 | exp 278 | mypy/ 279 | 280 | 281 | .vscode 282 | 283 | # data generated by scripts/create_mvr_data_from_mesh.py 284 | example_data/images/ 285 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/prefix_sum"] 2 | path = external/prefix_sum 3 | url = https://github.com/lxxue/prefix_sum.git 4 | ignore = untracked 5 | [submodule "external/FRNN"] 6 | path = external/FRNN 7 | url = https://github.com/lxxue/FRNN.git 8 | ignore = untracked 9 | [submodule "external/torch-batch-svd"] 10 | path = external/torch-batch-svd 11 | url = https://github.com/KinglittleQ/torch-batch-svd.git 12 | ignore = untracked 13 | -------------------------------------------------------------------------------- /DSS/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import numpy as np 4 | from .logger import get_logger 5 | from collections import OrderedDict 6 | 7 | logger_py = get_logger(__name__) 8 | 9 | _debug 
= False 10 | _debugging_tensor = None 11 | 12 | def set_deterministic_(): 13 | torch.manual_seed(0) 14 | torch.backends.cudnn.deterministic = True 15 | torch.backends.cudnn.benchmark = False 16 | np.random.seed(0) 17 | 18 | # Each attribute contains list of tensors or dictionaries, where 19 | # each element in the list is a sample in the minibatch. 20 | # If dictionaries are used, then the (keys, tensor) will be used to plot 21 | # debugging visuals separately. 22 | class DebuggingTensor: 23 | __slots__ = ['pts_world', 24 | 'pts_world_grad', 25 | 'img_mask_grad'] 26 | 27 | def __init__(self,): 28 | self.pts_world = OrderedDict() 29 | self.pts_world_grad = OrderedDict() 30 | self.img_mask_grad = OrderedDict() 31 | 32 | 33 | def set_debugging_mode_(is_debug, *args, **kwargs): 34 | global _debugging_tensor, _debug 35 | _debug = is_debug 36 | if _debug: 37 | _debugging_tensor = DebuggingTensor(*args, **kwargs) 38 | logger_py.info('Enabled debugging mode.') 39 | else: 40 | _debugging_tensor = None 41 | 42 | 43 | def get_debugging_mode(): 44 | return _debug 45 | 46 | 47 | def get_debugging_tensor(): 48 | if _debugging_tensor is None: 49 | logger_py.warning( 50 | 'Attempt to get debugging tensor before setting debugging mode to true.') 51 | set_debugging_mode_(True) 52 | return _debugging_tensor 53 | 54 | 55 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 56 | -------------------------------------------------------------------------------- /DSS/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/DSS/core/__init__.py -------------------------------------------------------------------------------- /DSS/core/camera.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch3d.renderer.cameras import (PerspectiveCameras, 3 | look_at_view_transform) 4 | 5 | 6 | class CameraSampler(object): 7 | """ 8 | create camera transformations looking at the origin of the coordinate 9 | from varying distance 10 | 11 | Attributes: 12 | R, T: (num_cams_total, 3, 3) and (num_cams_total, 3) 13 | camera_type (Class): class to create a new camera 14 | camera_params (dict): camera parameters to call camera_type 15 | (besides R, T) 16 | """ 17 | 18 | def __init__(self, num_cams_total, num_cams_batch, 19 | distance_range=(5, 10), sort_distance=True, 20 | return_cams=True, 21 | camera_type=PerspectiveCameras, camera_params=None): 22 | """ 23 | Args: 24 | num_cams_total (int): the total number of cameras to sample 25 | num_cams_batch (int): the number of cameras per iteration 26 | distance_range (tensor or list): (num_cams_total, 2) or (1, 2) 27 | the range of camera distance for uniform sampling 28 | sort_distance: sort the created camera transformations by the 29 | distance in ascending order 30 | return_cams (bool): whether to return camera instances or just the R,T 31 | camera_type (class): camera type from pytorch3d.renderer.cameras 32 | camera_params (dict): camera parameters besides R, T 33 | """ 34 | self.num_cams_batch = num_cams_batch 35 | self.num_cams_total = num_cams_total 36 | 37 | self.sort_distance = sort_distance 38 | self.camera_type = camera_type 39 | self.camera_params = {} if camera_params is None else camera_params 40 | 41 | # create camera locations 42 | distance_scale = distance_range[:, -1] - distance_range[:, 0] 43 | distances = torch.rand(num_cams_total) * distance_scale + \ 44 | 
distance_range[:, 0] 45 | if sort_distance: 46 | distances, _ = distances.sort(descending=True) 47 | azim = torch.rand(num_cams_total) * 360 - 180 48 | elev = torch.rand(num_cams_total) * 180 - 90 49 | at = torch.rand((num_cams_total, 3)) * 0.1 - 0.05 50 | self.R, self.T = look_at_view_transform( 51 | distances, elev, azim, at=at, degrees=True) 52 | 53 | self._idx = 0 54 | 55 | def __len__(self): 56 | return (self.R.shape[0] + self.num_cams_batch - 1) // \ 57 | self.num_cams_batch 58 | 59 | def __iter__(self): 60 | return self 61 | 62 | def __next__(self): 63 | if self._idx >= len(self): 64 | raise StopIteration 65 | start_idx = self._idx * self.num_cams_batch 66 | end_idx = min(start_idx + self.num_cams_batch, self.R.shape[0]) 67 | cameras = self.camera_type(R=self.R[start_idx:end_idx], 68 | T=self.T[start_idx:end_idx], 69 | **self.camera_params) 70 | self._idx += 1 71 | return cameras 72 | -------------------------------------------------------------------------------- /DSS/core/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch3d.renderer import PointsRenderer, NormWeightedCompositor 3 | from pytorch3d.renderer.compositing import weighted_sum 4 | from .. import logger_py 5 | 6 | 7 | __all__ = ['SurfaceSplattingRenderer'] 8 | 9 | """ 10 | Returns a 4-Channel image for RGBA 11 | """ 12 | 13 | 14 | class SurfaceSplattingRenderer(PointsRenderer): 15 | 16 | def __init__(self, rasterizer, compositor, antialiasing_sigma: float = 1.0, 17 | density: float = 1e-4, frnn_radius=-1): 18 | super().__init__(rasterizer, compositor) 19 | 20 | self.cameras = self.rasterizer.cameras 21 | self._Vrk_h = None 22 | # screen space low pass filter 23 | self.antialiasing_sigma = antialiasing_sigma 24 | # average of squared distance to the nearest neighbors 25 | self.density = density 26 | 27 | if self.compositor is None: 28 | logger_py.info('Composite with weighted sum.') 29 | elif not isinstance(self.compositor, NormWeightedCompositor): 30 | logger_py.warning('Expect a NormWeightedCompositor, but initialized with {}'.format( 31 | self.compositor.__class__.__name__)) 32 | 33 | self.frnn_radius = frnn_radius 34 | # logger_py.error("frnn_radius: {}".format(frnn_radius)) 35 | 36 | def forward(self, point_clouds, **kwargs) -> torch.Tensor: 37 | """ 38 | point_clouds_filter: used to get activation mask and update visibility mask 39 | cutoff_threshold 40 | """ 41 | if point_clouds.isempty(): 42 | return None 43 | 44 | # rasterize 45 | fragments = kwargs.get('fragments', None) 46 | if fragments is None: 47 | if kwargs.get('verbose', False): 48 | fragments, point_clouds, per_point_info = self.rasterizer(point_clouds, **kwargs) 49 | else: 50 | fragments, point_clouds = self.rasterizer(point_clouds, **kwargs) 51 | 52 | # compute weight: scalar*exp(-0.5Q) 53 | weights = torch.exp(-0.5 * fragments.qvalue) * fragments.scaler 54 | weights = weights.permute(0, 3, 1, 2) 55 | 56 | # from fragments to rgba 57 | pts_rgb = point_clouds.features_packed()[:, :3] 58 | 59 | if self.compositor is None: 60 | # NOTE: weight _splat_points_weights_backward, weighted sum will return 61 | # zero gradient for the weights. 
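            # Editor's note (illustrative, not part of the original source):
            # after the permutes below, idx and weights are (N, K, H, W) and
            # pts_rgb.permute(1, 0) is (C, P), so the weighted-sum compositing
            # accumulates, per pixel,
            #   images[b, c, i, j] = sum_k weights[b, k, i, j] * pts_rgb[c, idx[b, k, i, j]]
            # which is the formula documented in csrc/weighted_sum.h.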
62 | images = weighted_sum(fragments.idx.long().permute(0, 3, 1, 2), 63 | weights, 64 | pts_rgb.permute(1, 0), 65 | **kwargs) 66 | else: 67 | images = self.compositor( 68 | fragments.idx.long().permute(0, 3, 1, 2), 69 | weights, 70 | pts_rgb.permute(1, 0), 71 | **kwargs 72 | ) 73 | 74 | # permute so image comes at the end 75 | images = images.permute(0, 2, 3, 1) 76 | mask = fragments.occupancy 77 | 78 | images = torch.cat([images, mask.unsqueeze(-1)], dim=-1) 79 | 80 | if kwargs.get('verbose', False): 81 | return images, fragments 82 | return images 83 | -------------------------------------------------------------------------------- /DSS/core/texture.py: -------------------------------------------------------------------------------- 1 | """ 2 | PointTexture class 3 | 4 | Inputs should be fragments (including point location, 5 | normals and other features) 6 | Output is the color per point (doesn't have the blending step) 7 | 8 | diffuse shader 9 | specular shader 10 | neural shader 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from pytorch3d.renderer.cameras import OrthographicCameras 16 | from .lighting import DirectionalLights 17 | from .cloud import PointClouds3D 18 | from .. import logger_py 19 | from ..utils import gather_batch_to_packed 20 | 21 | 22 | __all__ = ["LightingTexture", "NeuralTexture"] 23 | 24 | 25 | def apply_lighting(points, normals, lights, cameras, 26 | specular=True, shininess=64): 27 | """ 28 | Args: 29 | points: torch tensor of shape (N, P, 3) or (P, 3). 30 | normals: torch tensor of shape (N, P, 3) or (P, 3) 31 | lights: instance of the Lights class. 32 | cameras: instance of the Cameras class. 33 | shininess: scalar for the specular coefficient. 34 | specular: (bool) whether to add the specular effect 35 | 36 | Returns: 37 | ambient_color: same shape as materials.ambient_color 38 | diffuse_color: same shape as the input points 39 | specular_color: same shape as the input points 40 | """ 41 | light_diffuse = lights.diffuse(normals=normals, points=points) 42 | light_specular = lights.specular( 43 | normals=normals, 44 | points=points, 45 | camera_position=cameras.get_camera_center(), 46 | shininess=shininess, 47 | ) 48 | ambient_color = lights.ambient_color 49 | if ambient_color.ndim==3: 50 | if ambient_color.shape[1] > 1: 51 | logger_py.warn('Found multiple ambient colors') 52 | ambient_color = torch.sum(ambient_color, dim=1) 53 | diffuse_color = light_diffuse 54 | specular_color = light_specular 55 | if normals.dim() == 2 and points.dim() == 2: 56 | # If given packed inputs remove batch dim in output. 
57 | return ( 58 | ambient_color.squeeze(0), 59 | diffuse_color.squeeze(0), 60 | specular_color.squeeze(0), 61 | ) 62 | return ambient_color, diffuse_color, specular_color 63 | 64 | 65 | class LightingTexture(nn.Module): 66 | def __init__(self, device="cpu", 67 | cameras=None, lights=None, materials=None): 68 | super().__init__() 69 | self.lights = lights 70 | self.cameras = cameras 71 | if materials is not None: 72 | logger_py.warning("Material is not supported, ignored.") 73 | 74 | def forward(self, pointclouds, shininess=64, **kwargs) -> PointClouds3D: 75 | """ 76 | Args: 77 | pointclouds (Pointclouds3D) 78 | points_rgb (P, 3): same shape as the packed features 79 | Returns: 80 | pointclouds (Pointclouds3D) with features set to RGB colors 81 | """ 82 | if pointclouds.isempty(): 83 | return pointclouds 84 | 85 | lights = kwargs.get("lights", self.lights).to(pointclouds.device) 86 | cameras = kwargs.get("cameras", self.cameras).to(pointclouds.device) 87 | if len(cameras) != len(pointclouds) and len(pointclouds) == 1: 88 | pointclouds = pointclouds.extend(len(cameras)) 89 | points = pointclouds.points_packed() 90 | point_normals = pointclouds.normals_packed() 91 | points_rgb = kwargs.get("points_rgb", None) 92 | if points_rgb is None: 93 | try: 94 | points_rgb = pointclouds.features_packed()[:, :3] 95 | except: 96 | points_rgb = torch.ones_like(points) 97 | 98 | if point_normals is None: 99 | logger_py.warning("Point normals are required, " 100 | "but not available in pointclouds. " 101 | "Using estimated normals instead.") 102 | 103 | vert_to_cloud_idx = pointclouds.packed_to_cloud_idx() 104 | if points_rgb.shape[-1] != 3: 105 | raise ValueError("Expected points_rgb to be 3-channel," 106 | "got {}".format(points_rgb.shape)) 107 | 108 | # Format properties of lights and materials so they are compatible 109 | # with the packed representation of the vertices. This transforms 110 | # all tensor properties in the class from shape (N, ...) -> (V, ...) where 111 | # V is the number of packed vertices. If the number of meshes in the 112 | # batch is one then this is not necessary. 
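        # Editor's note (illustrative, added): for a batch of two clouds with
        # 3 and 2 points, packed_to_cloud_idx() would be tensor([0, 0, 0, 1, 1]),
        # so gather_props(vert_to_cloud_idx) below repeats each cloud's light
        # and camera properties once per packed point.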
113 | if len(pointclouds) > 1: 114 | lights = lights.clone().gather_props(vert_to_cloud_idx) 115 | cameras = cameras.clone().gather_props(vert_to_cloud_idx) 116 | 117 | # Calculate the illumination at each point 118 | ambient, diffuse, specular = apply_lighting( 119 | points, point_normals, lights, cameras, 120 | shininess=shininess, 121 | ) 122 | points_colors_shaded = points_rgb * (ambient + diffuse) + specular 123 | 124 | pointclouds_colored = pointclouds.clone() 125 | pointclouds_colored.update_features_(points_colors_shaded) 126 | 127 | return pointclouds_colored 128 | 129 | 130 | class NeuralTexture(nn.Module): 131 | def __init__(self, decoder, view_dependent=True): 132 | super().__init__() 133 | self.view_dependent = view_dependent 134 | self.decoder = decoder 135 | 136 | def forward(self, pointclouds: PointClouds3D, c=None, **kwargs) -> PointClouds3D: 137 | if self.decoder.dim == 3 and not self.view_dependent: 138 | x = pointclouds.points_packed() 139 | else: 140 | x = pointclouds.normals_packed() 141 | assert(x is not None) 142 | # x = F.normalize(x, dim=-1, eps=1e-15) 143 | p = pointclouds.points_packed() 144 | x = torch.cat([x,p], dim=-1) 145 | if self.view_dependent: 146 | cameras = kwargs.get('cameras', None) 147 | if cameras is not None: 148 | cameras = cameras.to(pointclouds.device) 149 | cam_pos = cameras.get_camera_center() 150 | cam_pos = gather_batch_to_packed( 151 | cam_pos, pointclouds.packed_to_cloud_idx()) 152 | view_direction = p[...,:3].detach() - cam_pos 153 | view_direction = F.normalize(view_direction, dim=-1) 154 | if hasattr(self.decoder, 'embed_fn') and self.decoder.embed_fn is not None: 155 | view_direction = self.decoder.embed_fn(view_direction) 156 | x = torch.cat([x, view_direction], dim=-1) 157 | 158 | 159 | output = self.decoder(x, c=c, **kwargs) 160 | pointclouds_colored = pointclouds.clone() 161 | pointclouds_colored.update_features_(output.rgb) 162 | return pointclouds_colored 163 | -------------------------------------------------------------------------------- /DSS/csrc/bitmask.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #pragma once 4 | #define BINMASK_H 5 | 6 | // A BitMask represents a bool array of shape (H, W, N). We pack values into 7 | // the bits of unsigned ints; a single unsigned int has B = 32 bits, so to hold 8 | // all values we use H * W * (N / B) = H * W * D values. We want to store 9 | // BitMasks in shared memory, so we assume that the memory has already been 10 | // allocated for it elsewhere. 11 | class BitMask { 12 | public: 13 | __device__ BitMask(unsigned int* data, int H, int W, int N) 14 | : data(data), H(H), W(W), B(8 * sizeof(unsigned int)), D(N / B) { 15 | // TODO: check if the data is null. 
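    // Editor's note (illustrative, added): a bool at (y, x, d) is packed into
    // word data[y * W * D + x * D + d / B] at bit position d % B, as computed
    // by _get_elem_idx() and _get_bit_idx() below.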
16 | N = ceilf(N % 32); // take ceil incase N % 32 != 0 17 | block_clear(); // clear the data 18 | } 19 | 20 | // Use all threads in the current block to clear all bits of this BitMask 21 | __device__ void block_clear() { 22 | for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) { 23 | data[i] = 0; 24 | } 25 | __syncthreads(); 26 | } 27 | 28 | __device__ int _get_elem_idx(int y, int x, int d) { 29 | return y * W * D + x * D + d / B; 30 | } 31 | 32 | __device__ int _get_bit_idx(int d) { 33 | return d % B; 34 | } 35 | 36 | // Turn on a single bit (y, x, d) 37 | __device__ void set(int y, int x, int d) { 38 | int elem_idx = _get_elem_idx(y, x, d); 39 | int bit_idx = _get_bit_idx(d); 40 | const unsigned int mask = 1U << bit_idx; 41 | atomicOr(data + elem_idx, mask); 42 | } 43 | 44 | // Turn off a single bit (y, x, d) 45 | __device__ void unset(int y, int x, int d) { 46 | int elem_idx = _get_elem_idx(y, x, d); 47 | int bit_idx = _get_bit_idx(d); 48 | const unsigned int mask = ~(1U << bit_idx); 49 | atomicAnd(data + elem_idx, mask); 50 | } 51 | 52 | // Check whether the bit (y, x, d) is on or off 53 | __device__ bool get(int y, int x, int d) { 54 | int elem_idx = _get_elem_idx(y, x, d); 55 | int bit_idx = _get_bit_idx(d); 56 | return (data[elem_idx] >> bit_idx) & 1U; 57 | } 58 | 59 | // Compute the number of bits set in the row (y, x, :) 60 | __device__ int count(int y, int x) { 61 | int total = 0; 62 | for (int i = 0; i < D; ++i) { 63 | int elem_idx = y * W * D + x * D + i; 64 | unsigned int elem = data[elem_idx]; 65 | total += __popc(elem); 66 | } 67 | return total; 68 | } 69 | 70 | private: 71 | unsigned int* data; 72 | int H, W, B, D; 73 | }; 74 | -------------------------------------------------------------------------------- /DSS/csrc/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | #include 4 | #include 5 | 6 | #define TOTAL_THREADS 512 7 | 8 | inline int opt_n_threads(int work_size) { 9 | // round work_size to power of 2 between 512 and 1 10 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 11 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 12 | } 13 | 14 | inline dim3 opt_block_config(int x, int y) { 15 | const int x_threads = opt_n_threads(x); 16 | const int y_threads = 17 | std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 18 | dim3 block_config(x_threads, y_threads, 1); 19 | 20 | return block_config; 21 | } 22 | 23 | #define CUDA_CHECK_ERRORS() \ 24 | do { \ 25 | cudaError_t err = cudaGetLastError(); \ 26 | if (cudaSuccess != err) { \ 27 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 28 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 29 | __FILE__); \ 30 | exit(-1); \ 31 | } \ 32 | } while (0) 33 | #endif 34 | -------------------------------------------------------------------------------- /DSS/csrc/ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "rasterize_points.h" 3 | 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // module docstring 7 | m.doc() = "pybind11 compute_visibility_maps plugin"; 8 | m.def("splat_points", &RasterizePoints); 9 | m.def("_splat_points_naive", &RasterizePointsNaive); 10 | m.def("_splat_points_occ_backward", &RasterizePointsOccBackward); 11 | m.def("_rasterize_coarse", &RasterizePointsCoarse); 12 | m.def("_rasterize_fine", &RasterizePointsFine); 13 | #ifdef WITH_CUDA 14 | 
m.def("_splat_points_occ_fast_cuda_backward", &RasterizePointsBackwardCudaFast); 15 | #endif 16 | m.def("_splat_points_occ_backward", &RasterizePointsOccBackward); 17 | m.def("_backward_zbuf", &RasterizeZbufBackward); 18 | } 19 | -------------------------------------------------------------------------------- /DSS/csrc/macros.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #define CHECK_INPUT(x) \ 5 | TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \ 6 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); -------------------------------------------------------------------------------- /DSS/csrc/rasterization_utils.cuh: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #pragma once 4 | 5 | // Given a pixel coordinate 0 <= i < S, convert it to a normalized device 6 | // coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized 7 | // pixels, and assume that each pixel falls in the *center* of its range. 8 | __device__ inline float PixToNdc(int i, int S) { 9 | // NDC x-offset + (i * pixel_width + half_pixel_width) 10 | return -1 + (2 * i + 1.0f) / S; 11 | } 12 | 13 | // The maximum number of points per pixel that we can return. Since we use 14 | // thread-local arrays to hold and sort points, the maximum size of the array 15 | // needs to be known at compile time. There might be some fancy template magic 16 | // we could use to make this more dynamic, but for now just fix a constant. 17 | // TODO: is 8 enough? Would increasing have performance considerations? 18 | const int32_t kMaxPointsPerPixel = 150; 19 | 20 | const int32_t kMaxFacesPerBin = 22; 21 | 22 | template 23 | __device__ inline void BubbleSort(T* arr, int n) { 24 | // Bubble sort. We only use it for tiny thread-local arrays (n < 8); in this 25 | // regime we care more about warp divergence than computational complexity. 26 | for (int i = 0; i < n - 1; ++i) { 27 | for (int j = 0; j < n - i - 1; ++j) { 28 | if (arr[j + 1] < arr[j]) { 29 | T temp = arr[j]; 30 | arr[j] = arr[j + 1]; 31 | arr[j + 1] = temp; 32 | } 33 | } 34 | } 35 | } 36 | 37 | template 38 | __device__ inline T eps_denom(const T denom, const T eps) 39 | { 40 | int denom_sign = (T(0.0) < denom) - (denom < T(0.0)); 41 | T safe_denom = T(denom_sign) * max(abs(denom), eps); 42 | return safe_denom; 43 | } 44 | template 45 | __device__ inline T clamp(const T x, const T a, const T b) 46 | { 47 | return max(a, min(b, x)); 48 | } -------------------------------------------------------------------------------- /DSS/csrc/types.hpp: -------------------------------------------------------------------------------- 1 | using PointIndex = int; 2 | using Coord = int; 3 | using Float = float; 4 | 5 | -------------------------------------------------------------------------------- /DSS/csrc/weighted_sum.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | #include 4 | #include "utils/pytorch3d_cutils.h" 5 | 6 | #include 7 | 8 | // Perform weighted sum compositing of points in a z-buffer. 9 | // 10 | // Inputs: 11 | // features: FloatTensor of shape (C, P) which gives the features 12 | // of each point where C is the size of the feature and 13 | // P the number of points. 
14 | // alphas: FloatTensor of shape (N, points_per_pixel, W, W) where 15 | // points_per_pixel is the number of points in the z-buffer 16 | // sorted in z-order, and W is the image size. 17 | // points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the 18 | // indices of the nearest points at each pixel, sorted in z-order. 19 | // Returns: 20 | // weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated 21 | // feature in each point. Concretely, it gives: 22 | // weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * 23 | // features[c,points_idx[b,k,i,j]] 24 | 25 | // CUDA declarations 26 | #ifdef WITH_CUDA 27 | torch::Tensor weightedSumCudaForward( 28 | const torch::Tensor& features, 29 | const torch::Tensor& alphas, 30 | const torch::Tensor& points_idx); 31 | 32 | std::tuple weightedSumCudaBackward( 33 | const torch::Tensor& grad_outputs, 34 | const torch::Tensor& features, 35 | const torch::Tensor& alphas, 36 | const torch::Tensor& points_idx); 37 | #endif 38 | 39 | // C++ declarations 40 | torch::Tensor weightedSumCpuForward( 41 | const torch::Tensor& features, 42 | const torch::Tensor& alphas, 43 | const torch::Tensor& points_idx); 44 | 45 | std::tuple weightedSumCpuBackward( 46 | const torch::Tensor& grad_outputs, 47 | const torch::Tensor& features, 48 | const torch::Tensor& alphas, 49 | const torch::Tensor& points_idx); 50 | 51 | torch::Tensor weightedSumForward( 52 | torch::Tensor& features, 53 | torch::Tensor& alphas, 54 | torch::Tensor& points_idx) { 55 | features = features.contiguous(); 56 | alphas = alphas.contiguous(); 57 | points_idx = points_idx.contiguous(); 58 | 59 | if (features.is_cuda()) { 60 | #ifdef WITH_CUDA 61 | CHECK_CUDA(features); 62 | CHECK_CUDA(alphas); 63 | CHECK_CUDA(points_idx); 64 | return weightedSumCudaForward(features, alphas, points_idx); 65 | #else 66 | AT_ERROR("Not compiled with GPU support"); 67 | #endif 68 | } else { 69 | return weightedSumCpuForward(features, alphas, points_idx); 70 | } 71 | } 72 | 73 | std::tuple weightedSumBackward( 74 | torch::Tensor& grad_outputs, 75 | torch::Tensor& features, 76 | torch::Tensor& alphas, 77 | torch::Tensor& points_idx) { 78 | grad_outputs = grad_outputs.contiguous(); 79 | features = features.contiguous(); 80 | alphas = alphas.contiguous(); 81 | points_idx = points_idx.contiguous(); 82 | 83 | if (grad_outputs.is_cuda()) { 84 | #ifdef WITH_CUDA 85 | CHECK_CUDA(grad_outputs); 86 | CHECK_CUDA(features); 87 | CHECK_CUDA(alphas); 88 | CHECK_CUDA(points_idx); 89 | 90 | return weightedSumCudaBackward(grad_outputs, features, alphas, points_idx); 91 | #else 92 | AT_ERROR("Not compiled with GPU support"); 93 | #endif 94 | } else { 95 | return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /DSS/logger.py: -------------------------------------------------------------------------------- 1 | """ From https://github.com/t177398/best_python_logger """ 2 | import logging 3 | import sys 4 | 5 | 6 | def color_cheat_sheet(): 7 | # This doesn't work very good in IDEs python consoles. 8 | terse = "-t" in sys.argv[1:] or "--terse" in sys.argv[1:] 9 | write = sys.stdout.write 10 | for i in range(2 if terse else 10): 11 | for j in range(30, 38): 12 | for k in range(40, 48): 13 | if terse: 14 | write("\33[%d;%d;%dm%d;%d;%d\33[m " % (i, j, k, i, j, k)) 15 | else: 16 | write("%d;%d;%d: \33[%d;%d;%dm Hello, World! 
\33[m \n" % 17 | (i, j, k, i, j, k,)) 18 | write("\n") 19 | 20 | 21 | class _CustomFormatter(logging.Formatter): 22 | """Logging Formatter to add colors and count warning / errors""" 23 | 24 | grey = "\x1b[0;37m" 25 | green = "\x1b[1;32m" 26 | yellow = "\x1b[1;33m" 27 | red = "\x1b[1;31m" 28 | purple = "\x1b[1;35m" 29 | blue = "\x1b[1;34m" 30 | light_blue = "\x1b[1;36m" 31 | reset = "\x1b[0m" 32 | blink_red = "\x1b[5m\x1b[1;31m" 33 | format_prefix = f"{purple}%(asctime)s{reset} " \ 34 | f"{blue}%(name)s{reset} " \ 35 | f"{light_blue}(%(filename)s:%(lineno)d){reset} " 36 | 37 | format_suffix = "%(levelname)s - %(message)s" 38 | 39 | FORMATS = { 40 | logging.DEBUG: format_prefix + green + format_suffix + reset, 41 | logging.INFO: format_prefix + grey + format_suffix + reset, 42 | logging.WARNING: format_prefix + yellow + format_suffix + reset, 43 | logging.ERROR: format_prefix + red + format_suffix + reset, 44 | logging.CRITICAL: format_prefix + blink_red + format_suffix + reset 45 | } 46 | 47 | def format(self, record): 48 | log_fmt = self.FORMATS.get(record.levelno) 49 | formatter = logging.Formatter(log_fmt) 50 | return formatter.format(record) 51 | 52 | 53 | # Just import this function into your programs 54 | # "from logger import get_logger" 55 | # "logger = get_logger(__name__)" 56 | # Use the variable __name__ so the logger will print the file's name also 57 | 58 | def get_logger(name): 59 | logger = logging.getLogger(name) 60 | logger.setLevel(logging.DEBUG) 61 | ch = logging.StreamHandler() 62 | ch.setLevel(logging.DEBUG) 63 | ch.setFormatter(_CustomFormatter()) 64 | logger.addHandler(ch) 65 | return logger -------------------------------------------------------------------------------- /DSS/misc/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | from .. import logger_py 4 | 5 | 6 | class Thread(threading.Thread): 7 | def __init__(self, target, name='', args=(), kwargs={}): 8 | super().__init__(target=target, name=name, args=args, kwargs=kwargs) 9 | self.args = args 10 | self.kwargs = kwargs 11 | self.name 12 | 13 | def run(self): 14 | t0 = time.time() 15 | super().run() 16 | t1 = time.time() 17 | logger_py.info('{}: {:.3f} seconds'.format(self.name, t1 - t0)) 18 | -------------------------------------------------------------------------------- /DSS/misc/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import torch 4 | from torch.utils import model_zoo 5 | import shutil 6 | import datetime 7 | 8 | 9 | class CheckpointIO(object): 10 | ''' CheckpointIO class. 11 | 12 | It handles saving and loading checkpoints. 13 | 14 | Args: 15 | checkpoint_dir (str): path where checkpoints are saved 16 | ''' 17 | 18 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 19 | self.module_dict = kwargs 20 | self.checkpoint_dir = checkpoint_dir 21 | if not os.path.exists(checkpoint_dir): 22 | os.makedirs(checkpoint_dir) 23 | 24 | def register_modules(self, **kwargs): 25 | ''' Registers modules in current module dictionary. 26 | ''' 27 | self.module_dict.update(kwargs) 28 | 29 | def save(self, filename, **kwargs): 30 | ''' Saves the current module dictionary. 
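        Illustrative usage sketch (added; `model`, `optimizer` and `epoch_it`
        are hypothetical names):
            >>> ckpt = CheckpointIO('./chkpts', model=model, optimizer=optimizer)
            >>> ckpt.save('model.pt', epoch_it=10)
            >>> scalars = ckpt.load('model.pt')  # -> non-module entries, e.g. {'epoch_it': 10}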
31 | 32 | Args: 33 | filename (str): name of output file 34 | ''' 35 | if not os.path.isabs(filename): 36 | filename = os.path.join(self.checkpoint_dir, filename) 37 | 38 | outdict = kwargs 39 | for k, v in self.module_dict.items(): 40 | outdict[k] = v.state_dict() 41 | torch.save(outdict, filename) 42 | 43 | def backup_model_best(self, filename, **kwargs): 44 | if not os.path.isabs(filename): 45 | filename = os.path.join(self.checkpoint_dir, filename) 46 | if os.path.exists(filename): 47 | # Backup model 48 | backup_dir = os.path.join(self.checkpoint_dir, 'backup_model_best') 49 | if not os.path.exists(backup_dir): 50 | os.makedirs(backup_dir) 51 | ts = datetime.datetime.now().timestamp() 52 | filename_backup = os.path.join(backup_dir, '%s.pt' % ts) 53 | shutil.copy(filename, filename_backup) 54 | 55 | def load(self, filename): 56 | '''Loads a module dictionary from local file or url. 57 | 58 | Args: 59 | filename (str): name of saved module dictionary 60 | ''' 61 | if is_url(filename): 62 | return self.load_url(filename) 63 | else: 64 | return self.load_file(filename) 65 | 66 | def load_file(self, filename): 67 | '''Loads a module dictionary from file. 68 | 69 | Args: 70 | filename (str): name of saved module dictionary 71 | ''' 72 | 73 | if not os.path.isabs(filename): 74 | filename = os.path.join(self.checkpoint_dir, filename) 75 | 76 | if os.path.exists(filename): 77 | print(filename) 78 | print('=> Loading checkpoint from local file...', end='') 79 | state_dict = torch.load(filename) 80 | scalars = self.parse_state_dict(state_dict) 81 | print('Done!') 82 | return scalars 83 | else: 84 | raise FileExistsError 85 | 86 | def load_url(self, url): 87 | '''Load a module dictionary from url. 88 | 89 | Args: 90 | url (str): url to saved model 91 | ''' 92 | print(url) 93 | print('=> Loading checkpoint from url...', end='') 94 | state_dict = model_zoo.load_url(url, progress=True) 95 | scalars = self.parse_state_dict(state_dict) 96 | print('Done!') 97 | return scalars 98 | 99 | def parse_state_dict(self, state_dict): 100 | '''Parse state_dict of model and return scalars. 101 | 102 | Args: 103 | state_dict (dict): State dict of model 104 | ''' 105 | 106 | for k, v in self.module_dict.items(): 107 | if k in state_dict: 108 | if isinstance(v, torch.optim.Optimizer): 109 | v.load_state_dict(state_dict[k]) 110 | else: 111 | missing_keys, unexpected_keys = v.load_state_dict(state_dict[k], strict=False) 112 | if len(missing_keys) > 0: 113 | print('Warning: Could not find %s in checkpoint!' % missing_keys) 114 | if len(unexpected_keys) > 0: 115 | print('Warning: Found unexpectedly %s in checkpoint!' % unexpected_keys) 116 | 117 | else: 118 | print('Warning: Could not find %s in checkpoint!' % k) 119 | scalars = {k: v for k, v in state_dict.items() 120 | if k not in self.module_dict} 121 | return scalars 122 | 123 | 124 | def is_url(url): 125 | ''' Checks if input string is a URL. 126 | 127 | Args: 128 | url (string): URL 129 | ''' 130 | scheme = urllib.parse.urlparse(url).scheme 131 | return scheme in ('http', 'https') 132 | -------------------------------------------------------------------------------- /DSS/misc/imageFilters.py: -------------------------------------------------------------------------------- 1 | # https://github.com/HTDerekLiu/Paparazzi/blob/master/utils/imageL0Smooth.py 2 | # Referneces: 3 | # 1. Xu et al. "Image Smoothing via L0 Gradient Minimization", 2011 4 | # 2. 
This code is adapted from https://github.com/t-suzuki/l0_gradient_minimization_test 5 | 6 | import numpy as np 7 | from scipy.fftpack import fft2, ifft2 8 | import skimage 9 | from skimage.segmentation import slic 10 | import torch 11 | 12 | 13 | def box(img, r): 14 | """ O(1) box filter 15 | img - >= 2d image 16 | r - radius of box filter 17 | """ 18 | (rows, cols) = img.shape[:2] 19 | imDst = np.zeros_like(img) 20 | 21 | tile = [1] * img.ndim 22 | tile[0] = r 23 | imCum = np.cumsum(img, 0) 24 | imDst[0:r+1, :, ...] = imCum[r:2*r+1, :, ...] 25 | imDst[r+1:rows-r, :, ...] = imCum[2*r+1:rows, :, ...] - imCum[0:rows-2*r-1, :, ...] 26 | imDst[rows-r:rows, :, ...] = np.tile(imCum[rows-1:rows, :, ...], tile) - imCum[rows-2*r-1:rows-r-1, :, ...] 27 | 28 | tile = [1] * img.ndim 29 | tile[1] = r 30 | imCum = np.cumsum(imDst, 1) 31 | imDst[:, 0:r+1, ...] = imCum[:, r:2*r+1, ...] 32 | imDst[:, r+1:cols-r, ...] = imCum[:, 2*r+1: cols, ...] - imCum[:, 0: cols-2*r-1, ...] 33 | imDst[:, cols-r: cols, ...] = np.tile(imCum[:, cols-1:cols, ...], tile) - imCum[:, cols-2*r-1: cols-r-1, ...] 34 | 35 | return imDst 36 | 37 | 38 | def gf(I, p, r, eps, s=None): 39 | """ Color guided filter 40 | I - guide image (rgb) 41 | p - filtering input (single channel) 42 | r - window radius 43 | eps - regularization (roughly, variance of non-edge noise) 44 | s - subsampling factor for fast guided filter 45 | """ 46 | fullI = I 47 | fullP = p 48 | if s is not None: 49 | I = scipy.ndimage.zoom(fullI, [1/s, 1/s, 1], order=1) 50 | p = scipy.ndimage.zoom(fullP, [1/s, 1/s], order=1) 51 | r = round(r / s) 52 | 53 | h, w = p.shape[:2] 54 | N = box(np.ones((h, w)), r) 55 | 56 | mI_r = box(I[:, :, 0], r) / N 57 | mI_g = box(I[:, :, 1], r) / N 58 | mI_b = box(I[:, :, 2], r) / N 59 | 60 | mP = box(p, r) / N 61 | 62 | # mean of I * p 63 | mIp_r = box(I[:, :, 0]*p, r) / N 64 | mIp_g = box(I[:, :, 1]*p, r) / N 65 | mIp_b = box(I[:, :, 2]*p, r) / N 66 | 67 | # per-patch covariance of (I, p) 68 | covIp_r = mIp_r - mI_r * mP 69 | covIp_g = mIp_g - mI_g * mP 70 | covIp_b = mIp_b - mI_b * mP 71 | 72 | # symmetric covariance matrix of I in each patch: 73 | # rr rg rb 74 | # rg gg gb 75 | # rb gb bb 76 | var_I_rr = box(I[:, :, 0] * I[:, :, 0], r) / N - mI_r * mI_r 77 | var_I_rg = box(I[:, :, 0] * I[:, :, 1], r) / N - mI_r * mI_g 78 | var_I_rb = box(I[:, :, 0] * I[:, :, 2], r) / N - mI_r * mI_b 79 | 80 | var_I_gg = box(I[:, :, 1] * I[:, :, 1], r) / N - mI_g * mI_g 81 | var_I_gb = box(I[:, :, 1] * I[:, :, 2], r) / N - mI_g * mI_b 82 | 83 | var_I_bb = box(I[:, :, 2] * I[:, :, 2], r) / N - mI_b * mI_b 84 | 85 | a = np.zeros((h, w, 3)) 86 | for i in range(h): 87 | for j in range(w): 88 | sig = np.array([ 89 | [var_I_rr[i, j], var_I_rg[i, j], var_I_rb[i, j]], 90 | [var_I_rg[i, j], var_I_gg[i, j], var_I_gb[i, j]], 91 | [var_I_rb[i, j], var_I_gb[i, j], var_I_bb[i, j]] 92 | ]) 93 | covIp = np.array([covIp_r[i, j], covIp_g[i, j], covIp_b[i, j]]) 94 | a[i, j, :] = np.linalg.solve(sig + eps * np.eye(3), covIp) 95 | 96 | b = mP - a[:, :, 0] * mI_r - a[:, :, 1] * mI_g - a[:, :, 2] * mI_b 97 | 98 | meanA = box(a, r) / N[..., np.newaxis] 99 | meanB = box(b, r) / N 100 | 101 | if s is not None: 102 | meanA = scipy.ndimage.zoom(meanA, [s, s, 1], order=1) 103 | meanB = scipy.ndimage.zoom(meanB, [s, s], order=1) 104 | 105 | q = np.sum(meanA * fullI, axis=2) + meanB 106 | 107 | return q 108 | 109 | 110 | def SuperPixel(images): 111 | # SLIC superpixel [Achanta et al. 
2012] 112 | compactness = 20 113 | numSegments = 150 114 | maxIter = 3000 115 | imgSize = 256 116 | results = [None] * len(images) 117 | for idx, I in enumerate(images): 118 | isTensor = False 119 | if isinstance(I, torch.Tensor): 120 | device = I.device 121 | I = I.cpu().numpy() 122 | isTensor = True 123 | # compute FFT denominator (second part only) 124 | segs = skimage.segmentation.slic(I, compactness=compactness, n_segments=numSegments, enforce_connectivity=False) 125 | S = skimage.color.label2rgb(segs, I, kind='avg') 126 | if isTensor: 127 | results[idx] = torch.from_numpy(S).to(device=device, dtype=torch.float) 128 | else: 129 | results[idx] = S 130 | return results 131 | 132 | 133 | def L0Smooth(images, lmd=0.05): 134 | results = [] 135 | for idx, I in enumerate(images): 136 | betaMax = 1e5 137 | beta = 0.1 138 | betaRate = 2.0 139 | numIter = 40 140 | isTensor = False 141 | if isinstance(I, torch.Tensor): 142 | device = I.device 143 | I = I.cpu().numpy() 144 | isTensor = True 145 | # compute FFT denominator (second part only) 146 | FI = fft2(I, axes=(0, 1)) 147 | dx = np.zeros((I.shape[0], I.shape[1])) # gradient along x direction 148 | dy = np.zeros((I.shape[0], I.shape[1])) # gradient along y direction 149 | dx[dx.shape[0]//2, dx.shape[1]//2-1:dx.shape[1]//2+1] = [-1, 1] 150 | dy[dy.shape[0]//2-1:dy.shape[0]//2+1, dy.shape[1]//2] = [-1, 1] 151 | denominator_second = np.conj(fft2(dx))*fft2(dx) + np.conj(fft2(dy))*fft2(dy) 152 | denominator_second = np.tile(np.expand_dims(denominator_second, axis=2), [1, 1, I.shape[2]]) 153 | 154 | S = I 155 | hp = 0*I 156 | vp = 0*I 157 | for iter in range(numIter): 158 | # solve hp, vp 159 | hp = np.concatenate((S[:, 1:], S[:, :1]), axis=1) - S 160 | vp = np.concatenate((S[1:, :], S[:1, :]), axis=0) - S 161 | if len(I.shape) == 3: 162 | zeroIdx = np.sum(hp**2+vp**2, axis=2) < lmd/beta 163 | else: 164 | zeroIdx = hp**2.0 + vp**2.0 < lmd/beta 165 | hp[zeroIdx] = 0.0 166 | vp[zeroIdx] = 0.0 167 | 168 | # solve S 169 | hv = np.concatenate((hp[:, -1:], hp[:, :-1]), axis=1) - hp + np.concatenate((vp[-1:, :], vp[:-1, :]), axis=0) - vp 170 | S = np.real(ifft2((FI + (beta*fft2(hv, axes=(0, 1)))) / (1+beta*denominator_second), axes=(0, 1))) 171 | 172 | # update parameters 173 | beta *= betaRate 174 | if beta > betaMax: 175 | break 176 | if isTensor: 177 | results.append(torch.from_numpy(S).to(device=device, dtype=torch.float)) 178 | else: 179 | results.append(S) 180 | return results 181 | 182 | 183 | def Pix2PixDenoising(images, model=None): 184 | import os 185 | from .pix2pix.options.test_options import TestOptions 186 | from .pix2pix.data import create_dataset 187 | from .pix2pix.models import create_model 188 | import torch 189 | opt = TestOptions().parse() # get test options 190 | opt.gpu_ids = [torch.cuda.current_device()] 191 | # hard-code some parameters for test 192 | opt.num_threads = 0 # test code only supports num_threads = 1 193 | opt.batch_size = 1 # test code only supports batch_size = 1 194 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 195 | opt.no_flip = True # no flip; comment this line if results on flipped images are needed. 196 | opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
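    # Editor's note (illustrative, added): `images` is expected to be a list of
    # (H, W, 3) tensors with values roughly in [0, 1] (inferred from the -0.5
    # shift below); they are stacked, permuted to (B, 3, H, W) and shifted by
    # -0.5 before running the pix2pix generator, and the 'fake_B' output is
    # shifted back by +0.5 and permuted to (B, H, W, 3).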
197 | opt.checkpoints_dir = 'trained_models' 198 | # opt.name = 'render_PCA_resnet' 199 | opt.name = model or 'render_PCA_resnet_noise03_knn15' 200 | opt.epoch = "latest" 201 | # opt.name = 'render_test_25shape_reconv_pix' 202 | opt.norm = 'pixel' 203 | # opt.netG = 'unet_256_Re1' 204 | opt.netG = 'resnet_9blocks' 205 | 206 | model = create_model(opt) # create a model given opt.model and other options 207 | model.setup(opt) # regular setup: load and print networks; create schedulers 208 | model.eval() 209 | 210 | images = torch.stack(images, dim=0) 211 | with torch.no_grad(): 212 | # B, 3, W, H 213 | images = images.permute(0, 3, 1, 2) 214 | images_normalized = images - 0.5 215 | input_dict = {'A': images_normalized, 'A_paths': [None]} 216 | model.set_input(input_dict) # unpack data from data loader 217 | model.test() # run inference 218 | results = model.get_current_visuals() # get image results 219 | B, C, H, W = results['real_A'].shape 220 | minValues = torch.min(results['real_A'].view(B, C, -1), dim=2, keepdim=True)[0].view(B, C, 1, 1) 221 | maxValues = torch.max(results['real_A'].view(B, C, -1), dim=2, keepdim=True)[0].view(B, C, 1, 1) 222 | results['fake_B'] = torch.min(results['fake_B'], maxValues.expand(-1, -1, H, W)) 223 | # results['real_A'] = torch.min(results['real_A'], maxValues.expand(-1, -1, H, W)) 224 | results['fake_B'] = torch.max(results['fake_B'], minValues.expand(-1, -1, H, W)) 225 | # results['real_A'] = torch.max(results['real_A'], minValues.expand(-1, -1, H, W)) 226 | 227 | results = results["fake_B"] + 0.5 228 | results = results.permute(0, 2, 3, 1).contiguous() 229 | return results 230 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | trained_model/ 3 | result/ 4 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from .base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = __name__ + "." 
+ dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads)) 80 | 81 | def load_data(self): 82 | return self 83 | 84 | def __len__(self): 85 | """Return the number of data in the dataset""" 86 | return min(len(self.dataset), self.opt.max_dataset_size) 87 | 88 | def __iter__(self): 89 | """Return a batch of data""" 90 | for i, data in enumerate(self.dataloader): 91 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 92 | break 93 | yield data 94 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | from data.base_dataset import BaseDataset, get_params, get_transform 4 | import torchvision.transforms as transforms 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | from os import listdir 8 | import numpy as np 9 | from util.util import is_image_file, load_img, save_img_tensor, tensor2im, save_image 10 | import torch 11 | import glob 12 | from os import walk 13 | 14 | 15 | def npy_loader(path): 16 | sample = torch.from_numpy(np.load(path)) 17 | return sample 18 | 19 | 20 | class AlignedDataset(BaseDataset): 21 | """A dataset class for paired image dataset. 22 | 23 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. 24 | During test time, you need to prepare a directory '/path/to/data/test'. 25 | """ 26 | 27 | def __init__(self, opt): 28 | """Initialize this dataset class. 
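        Note (added for clarity): unlike the stock pix2pix AlignedDataset, this
        DSS variant reads .npy renderings from <dataroot>/input_rendered and
        <dataroot>/target_rendered, paired by filename ('_input' vs. '_target'),
        as set up in the code below.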
29 | 30 | Parameters: 31 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 32 | """ 33 | BaseDataset.__init__(self, opt) 34 | # self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 35 | 36 | # self.dir_A = os.path.join(opt.dataroot, 'trainA') 37 | # self.dir_B = os.path.join(opt.dataroot, 'trainB') 38 | self.dir_A = os.path.join(opt.dataroot, 'input_rendered') 39 | self.dir_B = os.path.join(opt.dataroot, 'target_rendered') 40 | 41 | # self.image_filenames = [x for x in listdir(self.dir_B) if is_image_file(x)] 42 | 43 | self.image_filenames = [] 44 | # glob.glob(self.dir_B + "/*.npy") 45 | 46 | for (root, dirs, files) in walk(self.dir_B): 47 | for filename in files: 48 | if filename.endswith(('.npy')): 49 | self.image_filenames.append(os.path.join(root, filename)) 50 | # print(os.path.join(root, filename)) 51 | 52 | # self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # get image paths 53 | # self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) 54 | 55 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image 56 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 57 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 58 | 59 | def __getitem__(self, index): 60 | """Return a data point and its metadata information. 61 | 62 | Parameters: 63 | index - - a random integer for data indexing 64 | 65 | Returns a dictionary that contains A, B, A_paths and B_paths 66 | A (tensor) - - an image in the input domain 67 | B (tensor) - - its corresponding image in the target domain 68 | A_paths (str) - - image paths 69 | B_paths (str) - - image paths (same as A_paths) 70 | """ 71 | # read a image given a random integer index 72 | 73 | # AB_path = self.AB_paths[index] 74 | # AB = Image.open(AB_path).convert('RGB') 75 | # # split AB image into A and B 76 | # w, h = AB.size 77 | # w2 = int(w / 2) 78 | # A = AB.crop((0, 0, w2, h)) 79 | # B = AB.crop((w2, 0, w, h)) 80 | 81 | # A = Image.open(os.path.join(self.dir_A, self.image_filenames[index])).convert('RGB') 82 | # B = Image.open(os.path.join(self.dir_B, self.image_filenames[index])).convert('RGB') 83 | # 84 | # transform_params = get_params(self.opt, A.size) 85 | # A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 86 | # B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 87 | 88 | # print(self.image_filenames[index]) 89 | # print(self.image_filenames[index].replace('_target', '_input').replace('target_rendered', 'input_rendered')) 90 | 91 | # A = np.load(os.path.join(self.dir_A, self.image_filenames[index].replace('_target', '_input'))) 92 | # B = np.load(os.path.join(self.dir_B, self.image_filenames[index])) 93 | 94 | A = np.load(self.image_filenames[index].replace('_target', '_input').replace('target_rendered', 'input_rendered')) 95 | B = np.load(self.image_filenames[index]) 96 | 97 | # print(A.shape) 98 | # print(A) 99 | # apply the same transform to both A and B 100 | 101 | A_transform = get_transform(self.opt) 102 | B_transform = get_transform(self.opt) 103 | 104 | #### 105 | import pdb 106 | pdb.set_trace() 107 | A = A_transform(A) 108 | B = B_transform(B) 109 | 110 | # save_img_tensor(A, './testA.png') 111 | # save_img_tensor(B, './testB.png') 112 | 113 | return {'A': A, 'B': B} 114 | 115 | def __len__(self): 116 | """Return the total number 
of images in the dataset.""" 117 | return len(self.image_filenames) 118 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
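        A minimal sketch of a typical return value, assuming a paired
        image-to-image dataset such as the aligned dataset above (the key names
        follow that dataset; the tensor names are illustrative only):

            {'A': input_tensor,     # (C, H, W) image in the input domain
             'B': target_tensor,    # (C, H, W) image in the target domain
             'A_paths': path_str}   # optional metadata, e.g. the source file path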
59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | # if grayscale: 84 | # transform_list.append(transforms.Grayscale(1)) 85 | # if 'resize' in opt.preprocess: 86 | # osize = [opt.load_size, opt.load_size] 87 | # transform_list.append(transforms.Resize(osize, method)) 88 | # elif 'scale_width' in opt.preprocess: 89 | # transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 90 | # 91 | # if 'crop' in opt.preprocess: 92 | # if params is None: 93 | # transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | # else: 95 | # transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | # if opt.preprocess == 'none': 98 | # transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | # if not opt.no_flip: 101 | # if params is None: 102 | # transform_list.append(transforms.RandomHorizontalFlip()) 103 | # elif params['flip']: 104 | # transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | # transform_list += [transforms.ToTensor(), 108 | # transforms.Normalize((0.5, 0.5, 0.5), 109 | # (0.5, 0.5, 0.5))] 110 | transform_list += [transforms.ToTensor(), 111 | transforms.Normalize((0.5, 0.5, 0.5), 112 | (1.0, 1.0, 1.0))] 113 | return transforms.Compose(transform_list) 114 | 115 | 116 | def __make_power_2(img, base, method=Image.BICUBIC): 117 | ow, oh = img.size 118 | h = int(round(oh / base) * base) 119 | w = int(round(ow / base) * base) 120 | if (h == oh) and (w == ow): 121 | return img 122 | 123 | __print_size_warning(ow, oh, w, h) 124 | return img.resize((w, h), method) 125 | 126 | 127 | def __scale_width(img, target_width, method=Image.BICUBIC): 128 | ow, oh = img.size 129 | if (ow == target_width): 130 | return img 131 | w = target_width 132 | h = int(target_width * oh / ow) 133 | return img.resize((w, h), method) 134 | 135 | 136 | def __crop(img, pos, size): 137 | ow, oh = img.size 138 | x1, y1 = pos 139 | tw = th = size 140 | if (ow > tw or oh > th): 141 | return img.crop((x1, y1, x1 + tw, y1 + th)) 142 | return img 143 | 144 | 145 | def __flip(img, flip): 146 | if flip: 147 | return img.transpose(Image.FLIP_LEFT_RIGHT) 148 | return img 149 | 150 | 151 | def __print_size_warning(ow, oh, w, h): 152 | """Print warning information about image size(only print once)""" 153 | if not hasattr(__print_size_warning, 'has_printed'): 154 | print("The image size needs to be a multiple of 4. " 155 | "The loaded image size was (%d, %d), so it was adjusted to " 156 | "(%d, %d). 
This adjustment will be done to all images " 157 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 158 | __print_size_warning.has_printed = True 159 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset, get_transform 2 | # from .image_folder import make_dataset 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | 7 | 8 | def make_dataset(dir, max_dataset_size=float("inf")): 9 | images = [] 10 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 11 | 12 | for root, _, fnames in sorted(os.walk(dir)): 13 | for fname in fnames: 14 | if fname[-3:] == "npy": 15 | path = os.path.join(root, fname) 16 | images.append(path) 17 | return images[:min(max_dataset_size, len(images))] 18 | 19 | 20 | def make_dataset(dir, max_dataset_size=float("inf")): 21 | images = [] 22 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 23 | 24 | for root, _, fnames in sorted(os.walk(dir)): 25 | for fname in fnames: 26 | if fname[-3:] == "npy": 27 | path = os.path.join(root, fname) 28 | images.append(path) 29 | return images[:min(max_dataset_size, len(images))] 30 | 31 | 32 | class SingleDataset(BaseDataset): 33 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 34 | 35 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 36 | """ 37 | 38 | def __init__(self, opt): 39 | """Initialize this dataset class. 40 | 41 | Parameters: 42 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 43 | """ 44 | BaseDataset.__init__(self, opt) 45 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 46 | input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 47 | self.transform = get_transform(opt, grayscale=(input_nc == 1)) 48 | 49 | def __getitem__(self, index): 50 | """Return a data point and its metadata information. 51 | 52 | Parameters: 53 | index - - a random integer for data indexing 54 | 55 | Returns a dictionary that contains A and A_paths 56 | A(tensor) - - an image in one domain 57 | A_paths(str) - - the path of the image 58 | """ 59 | A_path = self.A_paths[index] 60 | A_img = np.load(A_path) 61 | A = self.transform(A_img) 62 | return {'A': A, 'A_paths': A_path} 63 | 64 | def __len__(self): 65 | """Return the total number of images in the dataset.""" 66 | return len(self.A_paths) 67 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 
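A minimal sketch of such a subclass for a hypothetical 'dummy' model (the method bodies are
placeholders; 'template_model.py' is a complete, runnable example):

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser                    # optionally add model-specific flags here
        def __init__(self, opt):
            BaseModel.__init__(self, opt)    # then define the four lists described below
        def set_input(self, input):
            self.data = input                # unpack data from the dataset and preprocess
        def forward(self):
            pass                             # produce intermediate results
        def optimize_parameters(self):
            pass                             # compute losses, gradients, and update weights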
10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from .base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = __name__ + "." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/models/pix2pix_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | 5 | 6 | class Pix2PixModel(BaseModel): 7 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 8 | 9 | The model training requires '--dataset_mode aligned' dataset. 10 | By default, it uses a '--netG unet256' U-Net generator, 11 | a '--netD basic' discriminator (PatchGAN), 12 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 13 | 14 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 15 | """ 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train=True): 18 | """Add new dataset-specific options, and rewrite default values for existing options. 19 | 20 | Parameters: 21 | parser -- original option parser 22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 
23 | 24 | Returns: 25 | the modified parser. 26 | 27 | For pix2pix, we do not use image buffer 28 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 29 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 30 | """ 31 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 32 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 33 | if is_train: 34 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 35 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 36 | 37 | return parser 38 | 39 | def __init__(self, opt): 40 | """Initialize the pix2pix class. 41 | 42 | Parameters: 43 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 44 | """ 45 | BaseModel.__init__(self, opt) 46 | # specify the training losses you want to print out. The training/test scripts will call 47 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 48 | # specify the images you want to save/display. The training/test scripts will call 49 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 50 | # specify the models you want to save to the disk. The training/test scripts will call and 51 | if self.isTrain: 52 | self.model_names = ['G', 'D'] 53 | else: # during test time, only load G 54 | self.model_names = ['G'] 55 | # define networks (both generator and discriminator) 56 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 57 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 58 | 59 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 60 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 61 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 62 | 63 | if self.isTrain: 64 | # define loss functions 65 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 66 | self.criterionL1 = torch.nn.L1Loss() 67 | # initialize optimizers; schedulers will be automatically created by function . 68 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 69 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 70 | self.optimizers.append(self.optimizer_G) 71 | self.optimizers.append(self.optimizer_D) 72 | 73 | def set_input(self, input): 74 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 75 | 76 | Parameters: 77 | input (dict): include the data itself and its metadata information. 78 | 79 | The option 'direction' can be used to swap images in domain A and domain B. 
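        A minimal sketch of the expected input, assuming a batch produced by the
        aligned dataset above (only the 'A' and 'B' keys are used by this
        modified version):

            input = {'A': batch_A,   # (N, input_nc, H, W)
                     'B': batch_B}   # (N, output_nc, H, W)
            # direction == 'AtoB': real_A <- input['A'], real_B <- input['B']
            # direction == 'BtoA': the two roles are swapped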
80 | """ 81 | AtoB = self.opt.direction == 'AtoB' 82 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 83 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 84 | # self.image_paths = input['A_paths' if AtoB else 'B_paths'] 85 | 86 | def forward(self): 87 | """Run forward pass; called by both functions and .""" 88 | self.fake_B = self.netG(self.real_A) # G(A) 89 | 90 | def backward_D(self): 91 | """Calculate GAN loss for the discriminator""" 92 | # Fake; stop backprop to the generator by detaching fake_B 93 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator 94 | pred_fake = self.netD(fake_AB.detach()) 95 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 96 | # Real 97 | real_AB = torch.cat((self.real_A, self.real_B), 1) 98 | pred_real = self.netD(real_AB) 99 | self.loss_D_real = self.criterionGAN(pred_real, True) 100 | # combine loss and calculate gradients 101 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 102 | self.loss_D.backward() 103 | 104 | def backward_G(self): 105 | """Calculate GAN and L1 loss for the generator""" 106 | # First, G(A) should fake the discriminator 107 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 108 | pred_fake = self.netD(fake_AB) 109 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 110 | # Second, G(A) = B 111 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 112 | # combine loss and calculate gradients 113 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 114 | self.loss_G.backward() 115 | 116 | def optimize_parameters(self): 117 | self.forward() # compute fake images: G(A) 118 | # update D 119 | self.set_requires_grad(self.netD, True) # enable backprop for D 120 | self.optimizer_D.zero_grad() # set D's gradients to zero 121 | self.backward_D() # calculate gradients for D 122 | self.optimizer_D.step() # update D's weights 123 | # update G 124 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 125 | self.optimizer_G.zero_grad() # set G's gradients to zero 126 | self.backward_G() # calculate graidents for G 127 | self.optimizer_G.step() # udpate G's weights 128 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/models/template_model.py: -------------------------------------------------------------------------------- 1 | """Model class template 2 | 3 | This module provides a template for users to implement custom models. 4 | You can specify '--model template' to use this model. 5 | The class name should be consistent with both the filename and its model option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | It implements a simple image-to-image translation baseline based on regression loss. 9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: 10 | min_ ||netG(data_A) - data_B||_1 11 | You need to implement the following functions: 12 | : Add model-specific options and rewrite default values for existing options. 13 | <__init__>: Initialize this model class. 14 | : Unpack input data and perform data pre-processing. 15 | : Run forward pass. This will be called by both and . 16 | : Update network weights; it will be called in every training iteration. 17 | """ 18 | import torch 19 | from .base_model import BaseModel 20 | from . 
import networks 21 | 22 | 23 | class TemplateModel(BaseModel): 24 | @staticmethod 25 | def modify_commandline_options(parser, is_train=True): 26 | """Add new model-specific options and rewrite default values for existing options. 27 | 28 | Parameters: 29 | parser -- the option parser 30 | is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. 31 | 32 | Returns: 33 | the modified parser. 34 | """ 35 | parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. 36 | if is_train: 37 | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. 38 | 39 | return parser 40 | 41 | def __init__(self, opt): 42 | """Initialize this model class. 43 | 44 | Parameters: 45 | opt -- training/test options 46 | 47 | A few things can be done here. 48 | - (required) call the initialization function of BaseModel 49 | - define loss function, visualization images, model names, and optimizers 50 | """ 51 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel 52 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. 53 | self.loss_names = ['loss_G'] 54 | # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. 55 | self.visual_names = ['data_A', 'data_B', 'output'] 56 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. 57 | # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. 58 | self.model_names = ['G'] 59 | # define networks; you can use opt.isTrain to specify different behaviors for training and test. 60 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) 61 | if self.isTrain: # only defined during training time 62 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. 63 | # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) 64 | self.criterionLoss = torch.nn.L1Loss() 65 | # define and initialize optimizers. You can define one optimizer for each network. 66 | # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 67 | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 68 | self.optimizers = [self.optimizer] 69 | 70 | # Our program will automatically call to define schedulers, load networks, and print networks 71 | 72 | def set_input(self, input): 73 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 74 | 75 | Parameters: 76 | input: a dictionary that contains the data itself and its metadata information. 
77 | """ 78 | AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B 79 | self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A 80 | self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B 81 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths 82 | 83 | def forward(self): 84 | """Run forward pass. This will be called by both functions and .""" 85 | self.output = self.netG(self.data_A) # generate output image given the input data_A 86 | 87 | def backward(self): 88 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 89 | # caculate the intermediate results if necessary; here self.output has been computed during function 90 | # calculate loss given the input and intermediate results 91 | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression 92 | self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G 93 | 94 | def optimize_parameters(self): 95 | """Update network weights; it will be called in every training iteration.""" 96 | self.forward() # first call forward to calculate intermediate results 97 | self.optimizer.zero_grad() # clear network G's existing gradients 98 | self.backward() # calculate gradients for network G 99 | self.optimizer.step() # update gradients for network G 100 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/models/test_model.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from . import networks 3 | 4 | 5 | class TestModel(BaseModel): 6 | """ This TesteModel can be used to generate CycleGAN results for only one direction. 7 | This model will automatically set '--dataset_mode single', which only loads the images from one collection. 8 | 9 | See the test instruction for more details. 10 | """ 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train=True): 13 | """Add new dataset-specific options, and rewrite default values for existing options. 14 | 15 | Parameters: 16 | parser -- original option parser 17 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 18 | 19 | Returns: 20 | the modified parser. 21 | 22 | The model can only be used during test time. It requires '--dataset_mode single'. 23 | You need to specify the network using the option '--model_suffix'. 24 | """ 25 | assert not is_train, 'TestModel cannot be used during training time' 26 | parser.set_defaults(dataset_mode='single') 27 | parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.') 28 | 29 | return parser 30 | 31 | def __init__(self, opt): 32 | """Initialize the pix2pix class. 33 | 34 | Parameters: 35 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 36 | """ 37 | assert(not opt.isTrain) 38 | BaseModel.__init__(self, opt) 39 | # specify the training losses you want to print out. The training/test scripts will call 40 | self.loss_names = [] 41 | # specify the images you want to save/display. The training/test scripts will call 42 | self.visual_names = ['real_A', 'fake_B'] 43 | # specify the models you want to save to the disk. The training/test scripts will call and 44 | self.model_names = ['G' + opt.model_suffix] # only generator is needed. 
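        # For example, with '--model_suffix _A' and the default '--epoch latest', the
        # checkpoint expected by the loading code is 'latest_net_G_A.pth' inside
        # checkpoints_dir (see the --model_suffix help above); the suffix '_A' is
        # purely illustrative.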
45 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, 46 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 47 | 48 | # assigns the model to self.netG_[suffix] so that it can be loaded 49 | # please see 50 | setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self. 51 | 52 | def set_input(self, input): 53 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 54 | 55 | Parameters: 56 | input: a dictionary that contains the data itself and its metadata information. 57 | 58 | We need to use 'single_dataset' dataset mode. It only load images from one domain. 59 | """ 60 | self.real_A = input['A'].to(self.device) 61 | self.image_paths = input['A_paths'] 62 | 63 | def forward(self): 64 | """Run forward pass.""" 65 | self.fake_B = self.netG(self.real_A) # G(A) 66 | 67 | def optimize_parameters(self): 68 | """No optimization for test model.""" 69 | pass 70 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from DSS.misc.pix2pix.util import util 5 | from DSS.misc.pix2pix import models 6 | from DSS.misc.pix2pix import data 7 | 8 | 9 | class BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in functions in both dataset class and model class. 14 | """ 15 | 16 | def __init__(self): 17 | """Reset the class; indicates the class hasn't been initailized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # basic parameters 23 | parser.add_argument('--dataroot', required=False, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 24 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | parser.add_argument('--checkpoints_dir', type=str, help='models are saved here') 27 | # model parameters 28 | parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') 29 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 30 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 31 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 32 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 33 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') 34 | parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') 35 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 36 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 37 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 38 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 39 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 40 | # dataset parameters 41 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') 42 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 43 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 44 | parser.add_argument('--num_threads', default=1, type=int, help='# threads for loading data') 45 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 46 | parser.add_argument('--load_size', type=int, default=256, help='scale images to this size') 47 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 48 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 49 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 50 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 51 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 52 | # additional parameters 53 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 54 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 55 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 56 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 57 | self.initialized = True 58 | return parser 59 | 60 | def gather_options(self): 61 | """Initialize our parser with basic options(only once). 62 | Add additional model-specific and dataset-specific options. 63 | These options are defined in the function 64 | in model and dataset classes. 
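        Roughly equivalent to the following sketch, assuming '--model pix2pix' and
        '--dataset_mode aligned' (the class names are shown for illustration only):

            parser = self.initialize(argparse.ArgumentParser(...))
            parser = Pix2PixModel.modify_commandline_options(parser, self.isTrain)
            parser = AlignedDataset.modify_commandline_options(parser, self.isTrain)
            opt = parser.parse_args("")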
65 | """ 66 | if not self.initialized: # check if it has been initialized 67 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 68 | parser = self.initialize(parser) 69 | 70 | # get the basic options 71 | opt, _ = parser.parse_known_args([""]) 72 | 73 | # modify model-related parser options 74 | model_name = opt.model 75 | model_option_setter = models.get_option_setter(model_name) 76 | parser = model_option_setter(parser, self.isTrain) 77 | opt, _ = parser.parse_known_args([""]) # parse again with new defaults 78 | 79 | # modify dataset-related parser options 80 | dataset_name = opt.dataset_mode 81 | dataset_option_setter = data.get_option_setter(dataset_name) 82 | parser = dataset_option_setter(parser, self.isTrain) 83 | 84 | # save and return the parser 85 | self.parser = parser 86 | return parser.parse_args("") 87 | 88 | def print_options(self, opt): 89 | """Print and save options 90 | 91 | It will print both current options and default values(if different). 92 | It will save options into a text file / [checkpoints_dir] / opt.txt 93 | """ 94 | message = '' 95 | message += '----------------- Options ---------------\n' 96 | for k, v in sorted(vars(opt).items()): 97 | comment = '' 98 | default = self.parser.get_default(k) 99 | if v != default: 100 | comment = '\t[default: %s]' % str(default) 101 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 102 | message += '----------------- End -------------------' 103 | print(message) 104 | 105 | # # save to the disk 106 | # expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 107 | # util.mkdirs(expr_dir) 108 | # file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 109 | # with open(file_name, 'wt') as opt_file: 110 | # opt_file.write(message) 111 | # opt_file.write('\n') 112 | 113 | def parse(self): 114 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 115 | opt = self.gather_options() 116 | opt.isTrain = self.isTrain # train or test 117 | 118 | # process opt.suffix 119 | if opt.suffix: 120 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 121 | opt.name = opt.name + suffix 122 | 123 | self.print_options(opt) 124 | self.opt = opt 125 | return self.opt 126 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | # Dropout and Batchnorm has different behavioir during training and test. 
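        # For example, switching the generator to eval mode via the flag below disables
        # Dropout and makes BatchNorm use its running statistics instead of per-batch
        # statistics.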
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 19 | # rewrite devalue values 20 | parser.set_defaults(model='test') 21 | # To avoid cropping, the load_size should be the same as crop_size 22 | parser.set_defaults(load_size=parser.get_default('crop_size')) 23 | self.isTrain = False 24 | return parser 25 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 
77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /DSS/misc/pix2pix/util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 25 | else: # if it is a numpy array, do nothing 26 | image_numpy = input_image 27 | return image_numpy.astype(imtype) 28 | 29 | 30 | def diagnose_network(net, name='network'): 31 | """Calculate and print the mean of average absolute(gradients) 32 | 33 | Parameters: 34 | net (torch network) -- Torch network 35 | name (str) -- the name of the network 36 | """ 37 | mean = 0.0 38 | count = 0 39 | for param in net.parameters(): 40 | if param.grad is not None: 41 | mean += torch.mean(torch.abs(param.grad.data)) 42 | count += 1 43 | if count > 0: 44 | mean = mean / count 45 | print(name) 46 | print(mean) 47 | 48 | 49 | def save_image(image_numpy, image_path): 50 | """Save a numpy image to the disk 51 | 52 | Parameters: 53 | image_numpy (numpy array) -- input numpy array 54 | image_path (str) -- the path of the image 55 | """ 56 | image_pil = Image.fromarray(image_numpy) 57 | image_pil.save(image_path) 58 | 59 | 60 | def print_numpy(x, val=True, shp=False): 61 | """Print the mean, min, max, median, std, and size of a numpy array 62 | 63 | Parameters: 64 | val (bool) -- if print the values of the numpy array 65 | shp (bool) -- if print the shape of the numpy array 66 | """ 67 | x = x.astype(np.float64) 68 | if shp: 69 | print('shape,', x.shape) 70 | if val: 71 | x = x.flatten() 72 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 73 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 74 | 75 | 76 | def mkdirs(paths): 77 | """create empty directories if they don't exist 78 | 79 | Parameters: 80 | paths (str list) -- a list of directory paths 81 | """ 82 | if isinstance(paths, list) and not isinstance(paths, str): 83 | for path in paths: 84 | mkdir(path) 85 | else: 86 | mkdir(paths) 87 | 88 | 89 | def mkdir(path): 90 | """create a single empty directory if it didn't exist 91 | 92 | Parameters: 93 | path (str) -- a single directory path 94 | """ 95 | if not os.path.exists(path): 96 | os.makedirs(path) 97 | 98 | 99 | def is_image_file(filename): 100 | # return any(filename.endswith(extension) for extension in [".png", ".jpg", 
".jpeg"]) 101 | return any(filename.endswith(extension) for extension in [".npy"]) 102 | 103 | 104 | def load_img(filepath): 105 | img = Image.open(filepath).convert('RGB') 106 | img = img.resize((256, 256), Image.BICUBIC) 107 | return img 108 | 109 | # 110 | 111 | 112 | def save_img_tensor(image_tensor, filename): 113 | image_numpy = image_tensor.float().numpy() 114 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 115 | image_numpy = image_numpy.clip(0, 255) 116 | image_numpy = image_numpy.astype(np.uint8) 117 | image_pil = Image.fromarray(image_numpy) 118 | image_pil.save(filename) 119 | print("Image saved as {}".format(filename)) 120 | -------------------------------------------------------------------------------- /DSS/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | 4 | __all__ = ['BaseGenerator', 'PointModel'] 5 | 6 | 7 | class BaseGenerator(object): 8 | def __init__(self, model, device): 9 | self.model = model.to(device) 10 | self.device = device 11 | 12 | def generate_meshes(self, *args, **kwargs): 13 | return [] 14 | 15 | def generate_pointclouds(self, *args, **kwargs): 16 | return [] 17 | 18 | def generate_images(self, *args, **kwargs): 19 | return [] 20 | 21 | from .point_modeling import Model as PointModel 22 | -------------------------------------------------------------------------------- /DSS/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import yaml 5 | 6 | 7 | class BaseOptions(): 8 | """This class defines options used for basic inverse rendering, e.g. large deformation. 9 | """ 10 | 11 | def __init__(self): 12 | """Reset the class; indicates the class hasn't been initailized""" 13 | self.initialized = False 14 | 15 | def initialize(self, parser): 16 | """Define the common options that are used in both training and test.""" 17 | # basic parameters 18 | parser.add_argument('source', metavar="source", nargs='?', 19 | default="example_data/scenes/sphere.json", 20 | help='json|config file defining scenes initialization') 21 | parser.add_argument('-t', "--target", dest="ref", nargs='?', 22 | default="example_data/scenes/bunny.json", help='reference scene.') 23 | parser.add_argument('-d', '--device', dest='device', default='cuda:0', 24 | help='Device to run the computations on, options: cpu, cuda:{ID}') 25 | parser.add_argument('--name', default="experiment") 26 | parser.add_argument('-o', '--output', default="./learn_examples") 27 | parser.add_argument('-sS', '--startingStep', type=int, default=0) 28 | parser.add_argument('-C', '--cycles', type=int, default=12, 29 | help="number of (step_point, step_normal) optimization cycles") 30 | parser.add_argument('--modifiers', type=str, nargs='+', 31 | default=['localNormals', 'localPoints']) 32 | parser.add_argument('--steps', type=int, nargs='+', default=[15, 25]) 33 | parser.add_argument('--learningRates', type=float, 34 | nargs='+', default=[2000, 5]) 35 | parser.add_argument('--clip', dest="clipGrad", 36 | type=float, default=0.01, help='clip gradient') 37 | parser.add_argument('--verbose', action="store_true", 38 | help='more prints') 39 | parser.add_argument('--debug', action="store_true", 40 | help="log debugging information") 41 | parser.add_argument('--type', default="DSS", 42 | help="DSS or Baseline", choices=["DSS", "Baseline"]) 43 | parser.add_argument('--width', type=int, 44 | 
default=256, help="image width") 45 | parser.add_argument('--height', type=int, 46 | default=256, help="image height") 47 | parser.add_argument('--sv', default=128, help="view scale") 48 | parser.add_argument('-k', '--topK', dest="mergeTopK", 49 | type=int, default=5, help='topK for merging depth') 50 | parser.add_argument('-mT', '--mergeThreshold', type=float, 51 | default=0.05, help='threshold for merging depth') 52 | parser.add_argument('--img-loss-type', 53 | choices=["SMAPE", "L1", "L2"], default="SMAPE") 54 | parser.add_argument('--no-z', dest='considerZ', action="store_false", 55 | help='do not optimize Z in backward (default false)') 56 | parser.add_argument('-rR', '--repulsionRadius', type=float, 57 | default=0.05, help='radius for repulsion loss') 58 | parser.add_argument('-rW', '--repulsionWeight', type=float, 59 | default=0.03, help='weight for repulsion loss') 60 | parser.add_argument('-pR', '--projectionRadius', type=float, 61 | default=0.3, help='radius for projection loss') 62 | parser.add_argument('-pW', '--projectionWeight', type=float, 63 | default=0.05, help='weight for projection loss') 64 | parser.add_argument('-aW', '--averageWeight', type=float, 65 | default=0, help='weight for average term') 66 | parser.add_argument('--average-term', action="store_true", 67 | help="apply average term") 68 | parser.add_argument('-iW', '--imageWeight', type=float, 69 | default=1, help='weight for projection loss') 70 | parser.add_argument('-fR', '--repulsionFreq', type=int, 71 | default=1, help='frequency for repulsion term') 72 | parser.add_argument('-fP', '--projectionFreq', type=int, 73 | default=2, help='frequency for denoising term') 74 | parser.add_argument('-c', '--genCamera', type=int, 75 | default=12, help='number of random cameras') 76 | parser.add_argument( 77 | '--cameraFile', default="example_data/pointclouds/sphere_300.ply") 78 | parser.add_argument('-cO', '--camOffset', type=float, 79 | default=15, help='depth offset for generated cameras') 80 | parser.add_argument('-cF', '--camFocalLength', type=float, 81 | default=15, help='focal length for generated cameras') 82 | parser.add_argument('--cutOffThreshold', 83 | type=float, default=1, help='cutoff threshold') 84 | parser.add_argument('--Vrk_h', type=float, default=0.02, help='standard deviation for V_r^h in EWA') 85 | parser.add_argument('--backwardLocalSize', default=128, 86 | type=int, help='window size for computing pixel loss') 87 | parser.add_argument('--backwardLocalSizeDecay', default=0.9, 88 | type=float, help='decay for backward window size after each cycle') 89 | parser.add_argument('--baseline', action="store_true", 90 | help="use baseline depth renderer") 91 | parser.add_argument('--sharpnessSigma', default=60, 92 | type=float, help="sharpness sigma for weighted PCA") 93 | self.initialized = True 94 | return parser 95 | 96 | def gather_options(self): 97 | """Initialize our parser with basic options(only once). 98 | Add additional model-specific and dataset-specific options. 99 | These options are defined in the function 100 | in model and dataset classes. 
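        A minimal usage sketch (hypothetical call site; in practice the subclasses
        defined below, e.g. FinetuneOptions or RenderOptions, are instantiated instead):

            opt = BaseOptions().parse()     # parse CLI args and select the device
            print(opt.width, opt.height)    # 256 256 unless overridden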
101 | """ 102 | if not self.initialized: # check if it has been initialized 103 | parser = argparse.ArgumentParser( 104 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='DSS optimization') 105 | parser = self.initialize(parser) 106 | 107 | # save and return the parser 108 | self.parser = parser 109 | # get the basic options 110 | opt, _ = self.parser.parse_known_args() 111 | return opt 112 | 113 | return self.parser.parse_args() 114 | 115 | def print_options(self, opt): 116 | """Print and save options 117 | 118 | It will print both current options and default values(if different). 119 | It will save options into a text file / [checkpoints_dir] / opt.txt 120 | """ 121 | message = '' 122 | message += '----------------- Options ---------------\n' 123 | opt_dict = {} 124 | for k, v in sorted(vars(opt).items()): 125 | opt_dict[str(k)] = str(v) 126 | comment = '' 127 | default = self.parser.get_default(k) 128 | if str(k) == "device": 129 | opt_dict[str(k)] = str(v) 130 | else: 131 | opt_dict[k] = v 132 | if v != default: 133 | comment = '\t[default: %s]' % str(default) 134 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 135 | message += '----------------- End -------------------' 136 | print(message) 137 | 138 | # save to the disk 139 | expr_dir = os.path.join(opt.output, opt.name) 140 | os.makedirs(expr_dir, exist_ok=True) 141 | file_name = os.path.join(expr_dir, 'opt.txt') 142 | with open(file_name, 'wt') as opt_file: 143 | opt_file.write(message) 144 | opt_file.write('\n') 145 | opt_file_name = os.path.join(expr_dir, 'opt.yaml') 146 | with open(opt_file_name, 'wt') as opt_file: 147 | yaml.dump(opt_dict, opt_file) 148 | 149 | def parse(self): 150 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 151 | opt = self.gather_options() 152 | device, isCuda = parse_device(opt.device) 153 | opt.device = device 154 | torch.cuda.set_device(opt.device) 155 | self.opt = opt 156 | return self.opt 157 | 158 | 159 | def parse_device(device): 160 | if "cuda" in device: 161 | device = torch.device(device) 162 | isCpu = False 163 | elif device == 'cpu': 164 | device = torch.device('cpu') 165 | isCpu = True 166 | else: 167 | print("Unknown device name " + str(device) + ", falling back to cpu") 168 | device = torch.device('cuda:0') 169 | isCpu = False 170 | return (device, isCpu) 171 | -------------------------------------------------------------------------------- /DSS/options/deformation_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | import json 3 | 4 | 5 | class DeformationOptions(BaseOptions): 6 | """ 7 | This class defines options used during finetuning. 8 | """ 9 | 10 | def parse(self): 11 | self.opt = super().parse() 12 | with open(self.opt.ref, "r") as f: 13 | targetJson = json.load(f) 14 | if "cmdLineArgs" in targetJson: 15 | self.parser.set_defaults(**targetJson["cmdLineArgs"]) 16 | # parser again with new defaults 17 | # self.opt, _ = self.parser.parse_known_args() 18 | self.opt = super().parse() 19 | self.print_options(self.opt) 20 | return self.opt 21 | -------------------------------------------------------------------------------- /DSS/options/filter_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | import yaml 3 | 4 | 5 | class FilterOptions(BaseOptions): 6 | """This class defines options used during finetuning. 
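    A minimal usage sketch (hypothetical call site; the scene and point cloud paths
    are made up):

        opt = FilterOptions().parse()   # e.g. with 'scene.yaml --cloud noisy.ply' as CLI args
        # parse() sets opt.isFiltering = True; with the default im_filter
        # 'Pix2PixDenoising' it also forces diffuse shading and enables the average term.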
7 | """ 8 | 9 | def initialize(self, parser): 10 | """ 11 | defines additional paramters 12 | """ 13 | parser = BaseOptions.initialize(self, parser) # define shared options 14 | parser.add_argument('--flip-normal', action="store_false", 15 | dest="backfaceCulling", help='flip normal wrt view direction') 16 | parser.add_argument('--pix2pix', type=str, default="render_PCA_resnet", 17 | help="model name for pix2pix") 18 | parser.add_argument('--recursiveFiltering', action="store_true", 19 | help="apply the same filter on the output") 20 | parser.add_argument('--cloud', nargs='?', help='source cloud') 21 | parser.add_argument('--im_filter', '-f', default="Pix2PixDenoising", help='filter function') 22 | 23 | parser.set_defaults(steps=[19, 1], learningRates=[ 24 | 2000, 1], projectionRadius=0.1, projectionWeight=0.05, 25 | repulsionRadius=0.03, repulsionWeight=0.05, 26 | repulsionFreq=1, projectionFreq=1, 27 | averageWeight=0.01, 28 | camOffset=10, camFocalLength=15, name="filter", 29 | backward_bb=100) 30 | self.initialized = True 31 | return parser 32 | 33 | def parse(self): 34 | self.opt = super().parse() 35 | self.opt.modifiers = ["localNormals", "localPoints"] 36 | self.opt.isFiltering = True 37 | with open(self.opt.source, "r") as f: 38 | targetJson = yaml.load(f) 39 | if "cmdLineArgs" in targetJson: 40 | self.parser.set_defaults(**targetJson["cmdLineArgs"]) 41 | self.opt, _ = self.parser.parse_known_args() 42 | if self.opt.im_filter == "Pix2PixDenoising": 43 | self.opt.shading = "diffuse" 44 | self.opt.average_term = True 45 | self.opt.recursiveFiltering = False 46 | self.print_options(self.opt) 47 | return self.opt 48 | -------------------------------------------------------------------------------- /DSS/options/finetune_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import json 5 | from .deformation_options import DeformationOptions 6 | from .base_options import BaseOptions 7 | 8 | 9 | class FinetuneOptions(BaseOptions): 10 | """This class defines options used during finetuning. 
11 | """ 12 | 13 | def initialize(self, parser): 14 | """ 15 | defines additional paramters 16 | """ 17 | parser = super().initialize(parser) # define shared options 18 | parser.set_defaults(steps=[20, 15], learningRates=[ 19 | 2000, 1], cycles=15, 20 | projectionRadius=0.1, projectionWeight=0.02, 21 | repulsionRadius=0.05, repulsionWeight=0.05, 22 | repulsionFreq=1, projectionFreq=1, 23 | cutOffThreshold=1.5, 24 | camOffset=9, camFocalLength=15, 25 | backwardLocalSizeDecay=0.95) 26 | self.initialized = True 27 | return parser 28 | 29 | def parse(self): 30 | self.opt = super().parse() 31 | with open(self.opt.ref, "r") as f: 32 | targetJson = json.load(f) 33 | if "cmdLineArgs" in targetJson: 34 | self.parser.set_defaults(**targetJson["cmdLineArgs"]) 35 | if "finetuneArgs" in targetJson: 36 | self.parser.set_defaults(**targetJson["finetuneArgs"]) 37 | try: 38 | self.parser.set_defaults(camOffset=0.5*targetJson["cmdLineArgs"]["camOffset"]) 39 | except KeyError: 40 | pass 41 | # parser again with new defaults 42 | self.opt = super().parse() 43 | self.print_options(self.opt) 44 | return self.opt 45 | -------------------------------------------------------------------------------- /DSS/options/render_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import pdb 5 | from .base_options import BaseOptions 6 | import json 7 | 8 | 9 | class RenderOptions(BaseOptions): 10 | """This class defines options used during finetuning. 11 | """ 12 | 13 | def initialize(self, parser): 14 | super().initialize(parser) 15 | parser.add_argument("--points", nargs="*", help="paths to points") 16 | parser.add_argument('--rot-axis', '-a', help="rotation axis", default="y") 17 | parser.set_defaults(output="renders/") 18 | return parser 19 | 20 | def parse(self): 21 | self.opt = super().parse() 22 | with open(self.opt.source, "r") as f: 23 | targetJson = json.load(f) 24 | if "cmdLineArgs" in targetJson: 25 | for key, value in targetJson['cmdLineArgs'].items(): 26 | if key == "source": 27 | continue 28 | setattr(self.opt, key, value) 29 | self.opt = super().parse() 30 | self.print_options(self.opt) 31 | return self.opt 32 | -------------------------------------------------------------------------------- /DSS/training/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | change trainer settings according to iterations 3 | """ 4 | from typing import List 5 | import bisect 6 | from .. 
import logger_py 7 | 8 | 9 | class TrainerScheduler(object): 10 | """ Increase n_points_per_cloud and Reduce n_training_points """ 11 | 12 | def __init__(self, init_dss_backward_radii: float = 0, 13 | steps_dss_backward_radii: int = -1, 14 | steps_proj: int=-1, 15 | warm_up_iters: int = 0, 16 | gamma_dss_backward_radii: float = 0.99, 17 | gamma_proj: float = 5, 18 | limit_dss_backward_radii: float = 1.5, 19 | limit_proj: float = 1.0, 20 | ): 21 | """ steps_n_points_dss: list """ 22 | 23 | self.init_dss_backward_radii = init_dss_backward_radii 24 | 25 | self.steps_dss_backward_radii = steps_dss_backward_radii 26 | self.steps_proj = steps_proj 27 | 28 | self.gamma_dss_backward_radii = gamma_dss_backward_radii 29 | self.gamma_proj = gamma_proj 30 | 31 | self.limit_dss_backward_radii = limit_dss_backward_radii 32 | self.limit_proj = limit_proj 33 | 34 | self.warm_up_iters = warm_up_iters 35 | 36 | def step(self, trainer, it): 37 | # change rasterize backward radii 38 | if self.steps_dss_backward_radii > 0 and hasattr(trainer.model, 'renderer'): 39 | # shortcut 40 | raster_settings = trainer.model.renderer.rasterizer.raster_settings 41 | i = it // self.steps_dss_backward_radii 42 | gamma = self.gamma_dss_backward_radii ** i 43 | old_backward_scaler = raster_settings.radii_backward_scaler 44 | raster_settings.radii_backward_scaler = max( 45 | self.init_dss_backward_radii * gamma, self.limit_dss_backward_radii) 46 | if old_backward_scaler != raster_settings.radii_backward_scaler: 47 | logger_py.info('Updated radii_backward_scaler: {} -> {}'.format( 48 | old_backward_scaler, raster_settings.radii_backward_scaler)) 49 | 50 | if self.steps_proj > 0: 51 | i = it // self.steps_proj 52 | gamma = self.gamma_proj ** i 53 | trainer.lambda_dr_proj = min(trainer.lambda_dr_proj * gamma, self.limit_proj) -------------------------------------------------------------------------------- /DSS/utils/io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import plyfile 4 | import numpy as np 5 | from matplotlib import cm 6 | import matplotlib.colors as mpc 7 | 8 | 9 | def saveDebugPNG(projPoint, imgTensor, savePath): 10 | """ 11 | save imgTensor to PNG, highlight the projPoint with grid lines 12 | params: 13 | projPoint (1, 2) 14 | imgTensor (H,W,3or1) torch.Tensor or numpy.array 15 | """ 16 | import matplotlib.pyplot as plt 17 | # normalize imgTensor 18 | plt.clf() 19 | cmin = imgTensor.min() 20 | cmax = imgTensor.max() 21 | imgTensor = (imgTensor - cmin) / (cmax - cmin) 22 | imgTensor[np.isnan(imgTensor) != False] = 0.0 23 | if imgTensor.ndim == 2 or (imgTensor.ndim == 3 and imgTensor.shape[-1] == 1): 24 | plt.imshow(imgTensor, cmap='gray') 25 | else: 26 | plt.imshow(imgTensor) 27 | i, j = projPoint.flatten()[:] 28 | plt.scatter(i, j, facecolors='none', edgecolors="cyan") 29 | plt.axvline(x=i, color='red') 30 | plt.axhline(y=j, color='red') 31 | plt.savefig(savePath) 32 | 33 | 34 | def encodeFlow(flowTensor: torch.Tensor, logScale=False): 35 | """ 36 | encode the vector field to a colored image 37 | :params 38 | flowTensor: (H,W,2) 39 | :return 40 | rgb: (H,W,3) numpy array floating type 41 | """ 42 | h, w = flowTensor.shape[:2] 43 | rho, phi = cart2pol(flowTensor[:, :, 0], flowTensor[:, :, 1]) 44 | rmin, rmax = rho.min(), rho.max() 45 | rho = (rho - rmin) / (rmax - rmin) * 255 46 | if logScale: 47 | rho = torch.log(1 + rho) 48 | rho[np.isnan(rho) != False] = 0.0 49 | hsv = np.full((h, w, 3), 255, dtype=np.uint8) 50 | hsv[..., 0] = phi * 255 / 2 
/ np.pi 51 | hsv[..., 2] = rho 52 | from skimage.color import hsv2rgb 53 | rgb = hsv2rgb(hsv) 54 | return rgb 55 | 56 | 57 | def cart2pol(x, y): 58 | """ 59 | cartesian coordinates to polar coordinates 60 | return: 61 | rho: length 62 | phi: (, 2pi) 63 | """ 64 | rho = (x**2 + y**2).sqrt() 65 | phi = np.arctan2(y, x) + np.pi 66 | return (rho, phi) 67 | 68 | 69 | def pol2cart(rho, phi): 70 | """ polar to cartesian """ 71 | x = rho * phi.cos() 72 | y = rho * phi.sin() 73 | return (x, y) 74 | 75 | 76 | def read_ply(file): 77 | loaded = plyfile.PlyData.read(file) 78 | points = np.vstack([loaded['vertex'].data['x'], 79 | loaded['vertex'].data['y'], loaded['vertex'].data['z']]) 80 | if 'nx' in loaded['vertex'].data.dtype.names: 81 | normals = np.vstack([loaded['vertex'].data['nx'], 82 | loaded['vertex'].data['ny'], loaded['vertex'].data['nz']]) 83 | points = np.concatenate([points, normals], axis=0) 84 | 85 | points = points.transpose(1, 0) 86 | return points 87 | 88 | 89 | def save_ply(filename, points, colors=None, normals=None, binary=True): 90 | """ 91 | save 3D/2D points to ply file 92 | Args: 93 | points (numpy array): (N,2or3) 94 | colors (numpy uint8 array): (N, 3or4) 95 | """ 96 | assert(points.ndim == 2) 97 | if points.shape[-1] == 2: 98 | points = np.concatenate( 99 | [points, np.zeros_like(points)[:, :1]], axis=-1) 100 | 101 | vertex = np.core.records.fromarrays(points.transpose( 102 | 1, 0), names='x, y, z', formats='f4, f4, f4') 103 | num_vertex = len(vertex) 104 | desc = vertex.dtype.descr 105 | 106 | if normals is not None: 107 | assert(normals.ndim == 2) 108 | if normals.shape[-1] == 2: 109 | normals = np.concatenate( 110 | [normals, np.zeros_like(normals)[:, :1]], axis=-1) 111 | vertex_normal = np.core.records.fromarrays( 112 | normals.transpose(1, 0), names='nx, ny, nz', formats='f4, f4, f4') 113 | assert len(vertex_normal) == num_vertex 114 | desc = desc + vertex_normal.dtype.descr 115 | 116 | if colors is not None: 117 | assert len(colors) == num_vertex 118 | if colors.max() <= 1: 119 | colors = colors * 255 120 | if colors.shape[1] == 4: 121 | vertex_color = np.core.records.fromarrays(colors.transpose( 122 | 1, 0), names='red, green, blue, alpha', formats='u1, u1, u1, u1') 123 | else: 124 | vertex_color = np.core.records.fromarrays(colors.transpose( 125 | 1, 0), names='red, green, blue', formats='u1, u1, u1') 126 | desc = desc + vertex_color.dtype.descr 127 | 128 | vertex_all = np.empty(num_vertex, dtype=desc) 129 | 130 | for prop in vertex.dtype.names: 131 | vertex_all[prop] = vertex[prop] 132 | 133 | if normals is not None: 134 | for prop in vertex_normal.dtype.names: 135 | vertex_all[prop] = vertex_normal[prop] 136 | 137 | if colors is not None: 138 | for prop in vertex_color.dtype.names: 139 | vertex_all[prop] = vertex_color[prop] 140 | 141 | ply = plyfile.PlyData( 142 | [plyfile.PlyElement.describe(vertex_all, 'vertex')], text=(not binary)) 143 | if not os.path.exists(os.path.dirname(filename)): 144 | os.makedirs(os.path.dirname(filename)) 145 | ply.write(filename) 146 | 147 | 148 | def save_ply_property(filename, points, property, 149 | property_max=None, property_min=None, 150 | normals=None, cmap_name='Set1', binary=True): 151 | point_num = points.shape[0] 152 | colors = np.full([point_num, 3], 0.5) 153 | cmap = cm.get_cmap(cmap_name) 154 | if property_max is None: 155 | property_max = np.amax(property, axis=0) 156 | if property_min is None: 157 | property_min = np.amin(property, axis=0) 158 | p_range = property_max - property_min 159 | if property_max == 
property_min: 160 | property_max = property_min + 1 161 | normalizer = mpc.Normalize(vmin=property_min, vmax=property_max) 162 | p = normalizer(property) 163 | colors = cmap(p)[:, :3] 164 | save_ply(filename, points, colors, normals, binary) 165 | -------------------------------------------------------------------------------- /DSS/utils/mathHelper.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Optional 2 | import numpy as np 3 | import torch 4 | from torch_batch_svd import svd as batch_svd 5 | from pytorch3d.ops.utils import convert_pointclouds_to_tensor 6 | from pytorch3d.ops import knn_points 7 | from pytorch3d.ops.points_normals import _disambiguate_vector_directions 8 | 9 | 10 | def eps_denom(denom, eps=1e-17): 11 | """ Prepare denominator for division """ 12 | denom_sign = denom.sign() + (denom == 0.0).type_as(denom) 13 | denom = denom_sign * torch.clamp(denom.abs(), eps) 14 | return denom 15 | 16 | def eps_sqrt(squared, eps=1e-17): 17 | """ 18 | Prepare for the input for sqrt, make sure the input positive and 19 | larger than eps 20 | """ 21 | return torch.clamp(squared.abs(), eps) 22 | 23 | 24 | def pinverse(inputs: torch.Tensor): 25 | assert(inputs.ndim >= 2) 26 | shp = inputs.shape 27 | U, S, V = batch_svd(inputs.view(-1, shp[-2], shp[-1])) 28 | S[S < 1e-6] = 0 29 | S_inv = torch.where(S < 1e-5, torch.zeros_like(S), 1/S) 30 | pinv = V @ torch.diag_embed(S_inv) @ U.transpose(1,2) 31 | return pinv.view(shp) 32 | 33 | 34 | def estimate_pointcloud_local_coord_frames( 35 | pointclouds: Union[torch.Tensor, "Pointclouds"], 36 | neighborhood_size: int = 50, 37 | disambiguate_directions: bool = True, 38 | return_knn_result: bool = False, 39 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional['KNN']]: 40 | """ 41 | Faster version of pytorch3d estimate_pointcloud_local_coord_frames 42 | 43 | Estimates the principal directions of curvature (which includes normals) 44 | of a batch of `pointclouds`. 45 | Returns: 46 | curvatures (N,P,3) ascending order 47 | local_frames (N,P,3,3) corresponding eigenvectors 48 | """ 49 | points_padded, num_points = convert_pointclouds_to_tensor(pointclouds) 50 | 51 | ba, N, dim = points_padded.shape 52 | if dim != 3: 53 | raise ValueError( 54 | "The pointclouds argument has to be of shape (minibatch, N, 3)" 55 | ) 56 | 57 | if (num_points <= neighborhood_size).any(): 58 | raise ValueError( 59 | "The neighborhood_size argument has to be" 60 | + " >= size of each of the point clouds." 
61 | ) 62 | # undo global mean for stability 63 | # TODO: replace with tutil.wmean once landed 64 | pcl_mean = points_padded.sum(1) / num_points[:, None] 65 | points_centered = points_padded - pcl_mean[:, None, :] 66 | 67 | # get K nearest neighbor idx for each point in the point cloud 68 | knn_result = knn_points( 69 | points_padded, 70 | points_padded, 71 | lengths1=num_points, 72 | lengths2=num_points, 73 | K=neighborhood_size, 74 | return_nn=True, 75 | ) 76 | k_nearest_neighbors = knn_result.knn 77 | # obtain the mean of the neighborhood 78 | pt_mean = k_nearest_neighbors.mean(2, keepdim=True) 79 | # compute the diff of the neighborhood and the mean of the neighborhood 80 | # N,P,K,3 81 | central_diff = k_nearest_neighbors - pt_mean 82 | per_pts_diff = central_diff.view(-1, neighborhood_size, 3) 83 | # S (NP,3) and local_coord_framds (NP,3,3) 84 | _, S, local_coord_frames = batch_svd(per_pts_diff) 85 | curvature = S * S / neighborhood_size 86 | local_coord_frames = local_coord_frames.view(ba, N, dim, dim) 87 | curvature = curvature.view(ba, N, dim) 88 | 89 | # flip to ascending order 90 | curvature = curvature.flip(-1) 91 | local_coord_frames = local_coord_frames.flip(-1) 92 | 93 | # disambiguate the directions of individual principal vectors 94 | if disambiguate_directions: 95 | # disambiguate normal 96 | n = _disambiguate_vector_directions( 97 | points_centered, k_nearest_neighbors, local_coord_frames[:, :, :, 0] 98 | ) 99 | # disambiguate the main curvature 100 | z = _disambiguate_vector_directions( 101 | points_centered, k_nearest_neighbors, local_coord_frames[:, :, :, 2] 102 | ) 103 | # the secondary curvature is just a cross between n and z 104 | y = torch.cross(n, z, dim=2) 105 | # cat to form the set of principal directions 106 | local_coord_frames = torch.stack((n, y, z), dim=3) 107 | 108 | if return_knn_result: 109 | return curvature, local_coord_frames, knn_result 110 | return curvature, local_coord_frames 111 | 112 | 113 | def estimate_pointcloud_normals( 114 | pointclouds: Union[torch.Tensor, "Pointclouds"], 115 | neighborhood_size: int = 50, 116 | disambiguate_directions: bool = True, 117 | ) -> torch.Tensor: 118 | """ 119 | Estimates the normals of a batch of `pointclouds` using fast `estimate_pointcloud_local_coord_frames 120 | 121 | Args: 122 | **pointclouds**: Batch of 3-dimensional points of shape 123 | `(minibatch, num_point, 3)` or a `Pointclouds` object. 124 | **neighborhood_size**: The size of the neighborhood used to estimate the 125 | geometry around each point. 126 | **disambiguate_directions**: If `True`, uses the algorithm from [1] to 127 | ensure sign consistency of the normals of neigboring points. 128 | 129 | Returns: 130 | **normals**: A tensor of normals for each input point 131 | of shape `(minibatch, num_point, 3)`. 132 | If `pointclouds` are of `Pointclouds` class, returns a padded tensor. 133 | 134 | References: 135 | [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for 136 | Local Surface Description, ECCV 2010. 
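        Example (an illustrative sketch; assumes a CUDA device, since
        ``torch_batch_svd`` operates on GPU tensors)::

            >>> pts = torch.rand(1, 1000, 3).cuda()   # random batch of 1000 3D points
            >>> normals = estimate_pointcloud_normals(pts, neighborhood_size=30)
            >>> normals.shape
            torch.Size([1, 1000, 3])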
137 | """ 138 | curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames( 139 | pointclouds, 140 | neighborhood_size=neighborhood_size, 141 | disambiguate_directions=disambiguate_directions, 142 | ) 143 | 144 | # the normals correspond to the first vector of each local coord frame 145 | normals = local_coord_frames[:, :, :, 0] 146 | 147 | return normals 148 | 149 | 150 | def ndc_to_pix(p, resolution): 151 | """ 152 | Reverse of pytorch3d pix_to_ndc function 153 | Args: 154 | p (float tensor): (..., 3) 155 | resolution (scalar): image resolution (for now, supports only aspectratio = 1) 156 | Returns: 157 | pix (long tensor): (..., 2) 158 | """ 159 | pix = resolution - ((p[..., :2] + 1.0) * resolution - 1.0) / 2 160 | return pix 161 | 162 | 163 | def decompose_to_R_and_t(transform_mat, row_major=True): 164 | """ decompose a 4x4 transform matrix to R (3,3) and t (1,3)""" 165 | assert(transform_mat.shape[-2:] == (4, 4)), \ 166 | "Expecting batches of 4x4 matrice" 167 | # ... 3x3 168 | if not row_major: 169 | transform_mat = transform_mat.transpose(-2, -1) 170 | 171 | R = transform_mat[..., :3, :3] 172 | t = transform_mat[..., -1, :3] 173 | 174 | return R, t 175 | 176 | 177 | def to_homogen(x, dim=-1): 178 | """ append one to the specified dimension """ 179 | if dim < 0: 180 | dim = x.ndim + dim 181 | shp = x.shape 182 | new_shp = shp[:dim] + (1, ) + shp[dim + 1:] 183 | x_homogen = x.new_ones(new_shp) 184 | x_homogen = torch.cat([x, x_homogen], dim=dim) 185 | return x_homogen 186 | -------------------------------------------------------------------------------- /DSS/utils/matrixConstruction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import unittest 4 | 5 | from .mathHelper import normalize 6 | 7 | 8 | def rotationMatrixX(alpha): 9 | tc = torch.cos(alpha) 10 | ts = torch.sin(alpha) 11 | R = torch.eye(3, device=alpha.device) 12 | R[1, 1] = tc 13 | R[2, 2] = tc 14 | R[1, 2] = -ts 15 | R[2, 1] = ts 16 | return R 17 | 18 | 19 | def rotationMatrixY(alpha): 20 | tc = torch.cos(alpha) 21 | ts = torch.sin(alpha) 22 | R = torch.eye(3, device=alpha.device) 23 | R[0, 0] = tc 24 | R[2, 2] = tc 25 | R[2, 0] = -ts 26 | R[0, 2] = ts 27 | return R 28 | 29 | 30 | def rotationMatrixZ(alpha): 31 | tc = torch.cos(alpha) 32 | ts = torch.sin(alpha) 33 | R = torch.eye(3, device=alpha.device) 34 | R[0, 0] = tc 35 | R[1, 1] = tc 36 | R[0, 1] = -ts 37 | R[1, 0] = ts 38 | return R 39 | 40 | 41 | def rotationMatrix(alpha, beta, gamma): 42 | return rotationMatrixX(alpha).mm(rotationMatrixY(beta).mm(rotationMatrixZ(gamma))) 43 | 44 | 45 | def convertWorldToCameraTransform(Rw, C): 46 | """ 47 | Takes a camera transformation in world space and returns the camera transformation in 48 | camera space to be used as extrinsic parameters 49 | Rw (B, 3, 3) 50 | C (B, 3) 51 | """ 52 | Rc = Rw.transpose(1, 2) 53 | t = -Rc.matmul(C.unsqueeze(-1)).squeeze(-1) 54 | return (Rc, t) 55 | 56 | 57 | def batchAffineMatrix(R, t, scale=1.0, column_matrix=True): 58 | """ 59 | affine transformation with uniform scaling->rotation->tranlation 60 | Args: 61 | R (..., 3, 3): rotation matrix 62 | t (..., 3): translation tensor 63 | scale (scaler or vector): scale vector 64 | column_matrix (bool): if True, [R | t] expect transformation R @ p, 65 | otherwise p @ R (in pytorch3d) 66 | 67 | Returns: 68 | (..., 4, 4) transformation matrix 69 | """ 70 | assert R.shape[-2:] == (3, 3), "R must be of shape (..., 3, 3)" 71 | assert t.dim() == (R.dim() - 1), f"t must be 
of shape (..., 3) ({t.shape=}" 72 | out_shape = list(R.shape[:-2]) + [4, 4] 73 | T = R.new_zeros(out_shape) 74 | T[..., 3, 3] = 1.0 75 | T[..., :3, :3] = scale * R 76 | if column_matrix: 77 | T[..., :3, 3] = t 78 | else: 79 | T[..., 3, :3] = t 80 | return T 81 | 82 | 83 | def batchLookAt(fromP, toP, upP): 84 | """ 85 | construct rotation and translation using from, forward, upward vectors 86 | fromP batches of (3,) vectors 87 | toP batches of (3,) vectors 88 | upP batches of (3,) vectors 89 | """ 90 | # change (..., 3) to (..., 3, 3) 91 | shapeP = list(fromP.shape) 92 | shapeP.append(3) 93 | translation = fromP 94 | fromP = fromP.view(-1, 3) 95 | toP = toP.view(-1, 3) 96 | upP = upP.view(-1, 3) 97 | b, _ = fromP.shape 98 | forward = normalize(toP - fromP, dim=-1) 99 | right = normalize(forward.cross(upP), dim=-1) 100 | upP = normalize(right.cross(forward), dim=-1) 101 | rotation = torch.empty([b, 3, 3], device=fromP.device) 102 | rotation[:, :, 0] = right 103 | rotation[:, :, 1] = upP 104 | rotation[:, :, 2] = forward 105 | rotation = rotation.view(*shapeP) 106 | return rotation, translation 107 | 108 | 109 | def lookAt(fromP, toP, upP): 110 | forward = normalize(toP - fromP) 111 | right = normalize(forward.cross(upP)) 112 | upP = right.cross(forward) 113 | rotation = torch.empty([3, 3], device=fromP.device) 114 | rotation[:, 0] = right 115 | rotation[:, 1] = upP 116 | rotation[:, 2] = forward 117 | translation = fromP 118 | return (rotation, translation) 119 | 120 | 121 | if __name__ == '__main__': 122 | unittest.main() 123 | -------------------------------------------------------------------------------- /DSS/utils/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._six import int_classes as _int_classes 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class WeightedSubsetRandomSampler(Sampler): 7 | r"""Samples elements from a given list of indices with given probabilities (weights), with replacement. 
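    Example (illustrative; ``dataset`` stands for any map-style dataset)::

        >>> # sample from indices 1, 3, 5 only, weighted by the full per-item weight list
        >>> sampler = WeightedSubsetRandomSampler([1, 3, 5], [0.1, 0.5, 0.1, 0.2, 0.1, 1.0])
        >>> loader = torch.utils.data.DataLoader(dataset, sampler=sampler)
        >>> len(sampler)   # num_samples defaults to the subset size
        3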
8 | 9 | Arguments: 10 | weights (sequence) : a sequence of weights, not necessarily summing to one 11 | num_samples (int): number of samples to draw 12 | """ 13 | 14 | def __init__(self, indices, weights, num_samples=0): 15 | if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool): 16 | raise ValueError("num_samples should be a non-negative integer " 17 | "value, but got num_samples={}".format(num_samples)) 18 | self.indices = indices 19 | weights = [weights[i] for i in self.indices] 20 | self.weights = torch.tensor(weights, dtype=torch.double) 21 | if num_samples == 0: 22 | self.num_samples = len(self.weights) 23 | else: 24 | self.num_samples = num_samples 25 | self.replacement = True 26 | 27 | def __iter__(self): 28 | return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, self.replacement)) 29 | 30 | def __len__(self): 31 | return self.num_samples 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DSS: Differentiable Surface Splatting 2 | 3 | | [Paper PDF](https://igl.ethz.ch/projects/differentiable-surface-splatting/DSS-2019-SA-Yifan-etal.pdf) | [Project page](https://igl.ethz.ch/projects/differentiable-surface-splatting/) | 4 | | ----------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ | 5 | 6 | ![bunny](images/teaser.png) 7 | 8 | Code for the paper Differentiable Surface Splatting for Point-based Geometry Processing. 9 | 10 | ```diff 11 | + Mar 2021: major update, tagged 2.0. 12 | + > Now supports simultaneous normal and point position updates. 13 | + > Unified learning rate using Adam optimizer. 14 | + > Highly optimized CUDA operations 15 | + > Shares pytorch3d structure 16 | ``` 17 | 18 | - [DSS: Differentiable Surface Splatting](#dss-differentiable-surface-splatting) 19 | - [Installation](#installation) 20 | - [Demos](#demos) 21 | - [inverse rendering - shape deformation](#inverse-rendering---shape-deformation) 22 | - [~~denoising (TBA)~~](#denoising-tba) 23 | - [video](#video) 24 | - [cite](#cite) 25 | - [Acknowledgement](#acknowledgement) 26 | 27 | ## Installation 28 | 29 | 1. Install prerequisites. Our code uses Python 3.8, PyTorch 1.6.0 and PyTorch3D 0.2.5. The installation instructions assume a recent Anaconda. 30 | 31 | ```bash 32 | # we tested with cuda 10.2, pytorch 1.6.0, and pytorch3d 0.2.5 33 | # install requirements 34 | conda create -n DSS python=3.8 35 | conda activate DSS 36 | conda install -c pytorch pytorch=1.6.0 torchvision cudatoolkit=10.2 37 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 38 | conda install -c bottler nvidiacub 39 | conda install -c pytorch3d pytorch3d=0.2.5 40 | pip install -r requirements.txt 41 | pip install "git+https://github.com/mmolero/pypoisson.git" 42 | ``` 43 | 44 | 2. Clone and compile. 45 | 46 | ```bash 47 | git clone --recursive https://github.com/yifita/DSS.git 48 | cd DSS 49 | # if you have cloned it without `--recursive`, you can execute this command under DSS/ 50 | # git submodule update --init --recursive 51 | # compile external dependencies 52 | cd external/prefix_sum 53 | pip install . 54 | cd ../FRNN 55 | pip install . 56 | cd ../torch-batch-svd 57 | pip install . 58 | # compile library 59 | cd ../.. 60 | pip install -e .
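# optional sanity check: the following import should succeed if compilation worked
# python -c "import DSS, torch_batch_svd"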
61 | ``` 62 | 63 | ## Demos 64 | 65 | ### inverse rendering - shape deformation 66 | 67 | ```bash 68 | # create mvr images using intrinsics defined in the script 69 | python scripts/create_mvr_data_from_mesh.py --points example_data/mesh/yoga6.ply --output example_data/images --num_cameras 128 --image-size 512 --tri_color_light --point_lights --has_specular 70 | 71 | python train_mvr.py --config configs/dss.yml 72 | ``` 73 | 74 | Check the optimization process in tensorboard. 75 | 76 | ``` 77 | tensorboard --logdir=exp/dss_proj 78 | ``` 79 | 80 | ### ~~denoising (TBA)~~ 81 | 82 | We will add back this function ASAP. 83 | 84 | ![denoise_1noise](images/armadillo_2_all.png) 85 | 86 | ## video 87 | 88 | [![accompanying video](images/video-thumb.png)](https://youtu.be/MIu59GiJZ2s "Accompanying video") 89 | 90 | 91 | 92 | ## cite 93 | 94 | Please cite us if you find the code useful! 95 | 96 | ``` 97 | @article{Yifan:DSS:2019, 98 | author = {Yifan, Wang and 99 | Serena, Felice and 100 | Wu, Shihao and 101 | {\"{O}}ztireli, Cengiz and 102 | Sorkine{-}Hornung, Olga}, 103 | title = {Differentiable Surface Splatting for Point-based Geometry Processing}, 104 | journal = {ACM Transactions on Graphics (proceedings of ACM SIGGRAPH ASIA)}, 105 | volume = {38}, 106 | number = {6}, 107 | year = {2019}, 108 | } 109 | ``` 110 | 111 | ## Acknowledgement 112 | 113 | We would like to thank Federico Danieli for the insightful discussion, Phillipp Herholz for the timely feedack, Romann Weber for the video voice-over and Derek Liu for the help during the rebuttal. 114 | This work was supported in part by gifts from Adobe, Facebook and Snap, Inc. 115 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import math 5 | import random 6 | from DSS.misc.visualize import animate_points, animate_mesh, figures_to_html 7 | from DSS.core.lighting import PointLights, DirectionalLights 8 | from DSS import logger_py 9 | 10 | 11 | def create_animation(pts_dir, show_max=-1): 12 | figs = [] 13 | # points 14 | pts_files = [f for f in os.listdir(pts_dir) if 'pts' in f and f[-4:].lower() in ('.ply', 'obj')] 15 | if len(pts_files) == 0: 16 | logger_py.info("Couldn't find '*pts*' files in {}".format(pts_dir)) 17 | else: 18 | pts_files.sort() 19 | if show_max > 0: 20 | pts_files = pts_files[::max(len(pts_files) // show_max, 1)] 21 | pts_names = list(map(lambda x: os.path.basename(x) 22 | [:-4].split('_')[0], pts_files)) 23 | pts_paths = [os.path.join(pts_dir, fname) for fname in pts_files] 24 | fig = animate_points(pts_paths, pts_names) 25 | figs.append(fig) 26 | # mesh 27 | mesh_files = [f for f in os.listdir(pts_dir) if 'mesh' in f and f[-4:].lower() in ('.ply', '.obj')] 28 | # mesh_files = list(filter(lambda x: x.split('_') 29 | # [1] == '000.obj', mesh_files)) 30 | if len(mesh_files) == 0: 31 | logger_py.info("Couldn't find '*mesh*' files in {}".format(pts_dir)) 32 | else: 33 | mesh_files.sort() 34 | if show_max > 0: 35 | mesh_files = mesh_files[::max(len(mesh_files) // show_max, 1)] 36 | mesh_names = list(map(lambda x: os.path.basename(x) 37 | [:-4].split('_')[0], mesh_files)) 38 | mesh_paths = [os.path.join(pts_dir, fname) for fname in mesh_files] 39 | fig = animate_mesh(mesh_paths, mesh_names) 40 | figs.append(fig) 41 | 42 | save_html = os.path.join(pts_dir, 'animation.html') 43 | os.makedirs(os.path.dirname(save_html), exist_ok=True) 44 | 
figures_to_html(figs, save_html) 45 | 46 | 47 | def get_tri_color_lights_for_view(cams, has_specular=False, point_lights=True): 48 | """ 49 | Create RGB lights direction in the half dome 50 | The direction is given in the same coordinates as the pointcloud 51 | Args: 52 | cams 53 | Returns: 54 | Lights with three RGB light sources (B: right, G: left, R: bottom) 55 | """ 56 | import math 57 | from DSS.core.lighting import (DirectionalLights, PointLights) 58 | from pytorch3d.renderer.cameras import look_at_rotation 59 | from pytorch3d.transforms import Rotate 60 | 61 | elev = torch.tensor(((30, 30, 30),),device=cams.device) 62 | azim = torch.tensor(((-60, 60, 180),),device=cams.device) 63 | elev = math.pi / 180.0 * elev 64 | azim = math.pi / 180.0 * azim 65 | 66 | x = torch.cos(elev) * torch.sin(azim) 67 | y = torch.sin(elev) 68 | z = torch.cos(elev) * torch.cos(azim) 69 | light_directions = torch.stack([x, y, z], dim=-1) 70 | cam_pos = cams.get_camera_center() 71 | R = look_at_rotation(torch.zeros_like(cam_pos), at=F.normalize(torch.cross(cam_pos, torch.rand_like(cam_pos)), dim=-1), up=cam_pos) 72 | light_directions = Rotate(R=R.transpose(1,2), device=cams.device).transform_points(light_directions) 73 | # trimesh.Trimesh(vertices=torch.cat([cam_pos, light_directions[0]], dim=0).cpu().numpy(), process=False).export('tests/outputs/light_dir.ply') 74 | ambient_color = torch.FloatTensor((((0.2, 0.2, 0.2), ), )) 75 | diffuse_color = torch.FloatTensor( 76 | (((0.0, 0.0, 0.8), (0.0, 0.8, 0.0), (0.8, 0.0, 0.0), ), )) 77 | if has_specular: 78 | specular_color = 0.15 * diffuse_color 79 | diffuse_color *= 0.85 80 | else: 81 | specular_color = (((0, 0, 0), (0, 0, 0), (0, 0, 0), ), ) 82 | if not point_lights: 83 | lights = DirectionalLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 84 | specular_color=specular_color, direction=light_directions) 85 | else: 86 | location = light_directions*5 87 | lights = PointLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 88 | specular_color=specular_color, location=location) 89 | return lights 90 | 91 | def get_light_for_view(cams, point_lights, has_specular): 92 | # create tri-color lights and a specular+diffuse shader 93 | ambient_color = torch.FloatTensor((((0.6, 0.6, 0.6),),)) 94 | diffuse_color = torch.FloatTensor( 95 | (((0.2, 0.2, 0.2),),)) 96 | 97 | if has_specular: 98 | specular_color = 0.15 * diffuse_color 99 | diffuse_color *= 0.85 100 | else: 101 | specular_color = (((0, 0, 0),),) 102 | 103 | elev = torch.tensor(((random.randint(10, 90),),), dtype=torch.float, device=cams.device) 104 | azim = torch.tensor(((random.randint(0, 360)),), dtype=torch.float, device=cams.device) 105 | elev = math.pi / 180.0 * elev 106 | azim = math.pi / 180.0 * azim 107 | 108 | x = torch.cos(elev) * torch.sin(azim) 109 | y = torch.sin(elev) 110 | z = torch.cos(elev) * torch.cos(azim) 111 | light_directions = torch.stack([x, y, z], dim=-1) 112 | # transform from camera to world 113 | light_directions = cams.get_world_to_view_transform().inverse().transform_points(light_directions) 114 | if not point_lights: 115 | lights = DirectionalLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 116 | specular_color=specular_color, direction=light_directions) 117 | else: 118 | location = light_directions*5 119 | lights = PointLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 120 | specular_color=specular_color, location=location) 121 | return lights -------------------------------------------------------------------------------- 
/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | from easydict import EasyDict as edict 4 | import torch 5 | from pytorch3d.utils import ico_sphere 6 | from pytorch3d.ops import sample_points_from_meshes 7 | from pytorch3d.renderer import FoVPerspectiveCameras 8 | from DSS.core.texture import LightingTexture, NeuralTexture 9 | from DSS.utils import get_class_from_string 10 | from DSS.training.trainer import Trainer 11 | from DSS import logger_py 12 | 13 | 14 | # General config 15 | def load_config(path, default_path=None): 16 | ''' Loads config file. 17 | 18 | Args: 19 | path (str): path to config file 20 | default_path (bool): whether to use default path 21 | ''' 22 | # Load configuration from file itself 23 | cfg_special = None 24 | with open(path, 'r') as f: 25 | cfg_special = edict(yaml.load(f, Loader=yaml.Loader)) 26 | 27 | # Check if we should inherit from a config 28 | inherit_from = cfg_special.get('inherit_from') 29 | 30 | # If yes, load this config first as default 31 | # If no, use the default_path 32 | if inherit_from is not None: 33 | cfg = load_config(inherit_from, default_path) 34 | elif default_path is not None: 35 | with open(default_path, 'r') as f: 36 | cfg = edict(yaml.load(f, Loader=yaml.Loader)) 37 | else: 38 | cfg = edict() 39 | 40 | # Include main configuration 41 | update_recursive(cfg, cfg_special) 42 | 43 | return cfg 44 | 45 | 46 | def save_config(path, config): 47 | """ 48 | Save config dictionary as json file 49 | """ 50 | out_dir = os.path.dirname(path) 51 | if not os.path.exists(out_dir): 52 | os.makedirs(out_dir) 53 | 54 | if os.path.isfile(path): 55 | logger_py.warn( 56 | "Found file existing in {}, overwriting the existing file.".format(out_dir)) 57 | 58 | with open(path, 'w') as f: 59 | yaml.dump(config, f, sort_keys=False) 60 | 61 | logger_py.info("Saved config to {}".format(path)) 62 | 63 | 64 | def update_recursive(dict1, dict2): 65 | ''' Update two config dictionaries recursively. 66 | 67 | Args: 68 | dict1 (dict): first dictionary to be updated 69 | dict2 (dict): second dictionary which entries should be used 70 | 71 | ''' 72 | for k, v in dict2.items(): 73 | if k not in dict1: 74 | dict1[k] = edict() 75 | if isinstance(v, dict): 76 | update_recursive(dict1[k], v) 77 | else: 78 | dict1[k] = v 79 | 80 | 81 | def _get_tensor_with_default(opt, key, size, fill_value=0.0): 82 | if key not in opt: 83 | return torch.zeros(*size).fill_(fill_value) 84 | else: 85 | return torch.FloatTensor(opt[key]) 86 | 87 | 88 | def create_point_texture(opt_renderer_texture): 89 | from DSS.core.texture import (NeuralTexture, LightingTexture) 90 | """ create shader that generate per-point color """ 91 | if opt_renderer_texture.texture.is_neural_shader: 92 | texture = NeuralTexture(opt_renderer_texture.texture) 93 | else: 94 | lights = create_lights(opt_renderer_texture.get('lights', None)) 95 | texture = LightingTexture( 96 | specular=opt_renderer_texture.texture.specular, lights=lights) 97 | 98 | return texture 99 | 100 | 101 | def create_lights(opt_renderer_texture_lights): 102 | """ 103 | Create lights specified by opt, if no sun or point lights 104 | are given, create the tri-color lights. 
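    For instance, an opt carrying the fields read below (illustrative placeholder values)::

        {'type': 'sun',
         'ambient_color': [[0.2, 0.2, 0.2]],
         'diffuse_color': [[0.8, 0.8, 0.8]],
         'specular_color': [[0.0, 0.0, 0.0]],
         'direction': [[0.0, 0.0, 1.0]]}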
105 | Currently only supports the same lights for all batches 106 | """ 107 | from DSS.core.lighting import (DirectionalLights, PointLights) 108 | ambient_color = torch.tensor( 109 | opt_renderer_texture_lights.ambient_color).view(1, -1, 3) 110 | specular_color = torch.tensor( 111 | opt_renderer_texture_lights.specular_color).view(1, -1, 3) 112 | diffuse_color = torch.tensor( 113 | opt_renderer_texture_lights.diffuse_color).view(1, -1, 3) 114 | if opt_renderer_texture_lights['type'] == "sun": 115 | direction = torch.tensor( 116 | opt_renderer_texture_lights.direction).view(1, -1, 3) 117 | lights = DirectionalLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 118 | specular_color=specular_color, direction=direction) 119 | elif opt_renderer_texture_lights['type'] == 'point': 120 | location = torch.tensor( 121 | opt_renderer_texture_lights.location).view(1, -1, 3) 122 | lights = PointLights(ambient_color=ambient_color, diffuse_color=diffuse_color, 123 | specular_color=specular_color, location=location) 124 | 125 | return lights 126 | 127 | 128 | def create_cameras(opt): 129 | pass 130 | 131 | 132 | def create_dataset(opt_data, mode="train"): 133 | import DSS.utils.dataset as DssDataset 134 | if opt_data.type == 'MVR': 135 | dataset = DssDataset.MVRDataset(**opt_data, mode=mode) 136 | elif opt_data.type == 'DTU': 137 | dataset = DssDataset.DTUDataset(**opt_data, mode=mode) 138 | else: 139 | raise NotImplementedError 140 | return dataset 141 | 142 | 143 | def create_model(cfg, device, mode="train", camera_model=None, **kwargs): 144 | ''' Returns model 145 | 146 | Args: 147 | cfg (edict): imported yaml config 148 | device (device): pytorch device 149 | ''' 150 | if cfg.model.type == 'point': 151 | decoder = None 152 | 153 | texture = None 154 | use_lighting = (cfg.renderer is not None and not cfg.renderer.get( 155 | 'is_neural_texture', True)) 156 | if use_lighting: 157 | texture = LightingTexture() 158 | else: 159 | if 'rgb' not in cfg.model.decoder_kwargs.out_dims: 160 | Texture = get_class_from_string(cfg.model.texture_type) 161 | cfg.model.texture_kwargs['c_dim'] = cfg.model.decoder_kwargs.out_dims.get('latent', 0) 162 | texture_decoder = Texture(**cfg.model.texture_kwargs) 163 | else: 164 | texture_decoder = decoder 165 | logger_py.info("Decoder used as NeuralTexture") 166 | 167 | texture = NeuralTexture( 168 | view_dependent=cfg.model.texture_kwargs.view_dependent, decoder=texture_decoder).to(device=device) 169 | logger_py.info("Created NeuralTexture {}".format(texture.__class__)) 170 | logger_py.info(texture) 171 | 172 | Model = get_class_from_string( 173 | "DSS.models.{}_modeling.Model".format(cfg.model.type)) 174 | 175 | # if not using decoder, then use non-parameterized point renderer 176 | # create icosphere as initial point cloud 177 | sphere_mesh = ico_sphere(level=4) 178 | sphere_mesh.scale_verts_(0.5) 179 | points, normals = sample_points_from_meshes( 180 | sphere_mesh, num_samples=int( 181 | cfg['model']['model_kwargs']['n_points_per_cloud']), 182 | return_normals=True) 183 | colors = torch.ones_like(points) 184 | renderer = create_renderer(cfg.renderer).to(device) 185 | model = Model( 186 | points, normals, colors, 187 | renderer, 188 | device=device, 189 | texture=texture, 190 | **cfg.model.model_kwargs, 191 | ).to(device=device) 192 | 193 | return model 194 | 195 | 196 | def create_generator(cfg, model, device, **kwargs): 197 | ''' Returns the generator object. 
198 | 199 | Args: 200 | model (nn.Module): model 201 | cfg (dict): imported yaml config 202 | device (device): pytorch device 203 | ''' 204 | Generator = get_class_from_string( 205 | 'DSS.models.{}_modeling.Generator'.format(cfg.model.type)) 206 | 207 | generator = Generator(model, device, 208 | threshold=cfg['test']['threshold'], 209 | **cfg.generation) 210 | return generator 211 | 212 | 213 | def create_trainer(cfg, model, optimizer, scheduler, generator, train_loader, val_loader, device, **kwargs): 214 | ''' Returns the trainer object. 215 | 216 | Args: 217 | model (nn.Module): the model 218 | optimizer (optimizer): pytorch optimizer object 219 | cfg (dict): imported yaml config 220 | device (device): pytorch device 221 | generator (Generator): generator instance to 222 | generate meshes for visualization 223 | ''' 224 | threshold = cfg['test']['threshold'] 225 | out_dir = os.path.join(cfg['training']['out_dir'], cfg['name']) 226 | vis_dir = os.path.join(out_dir, 'vis') 227 | debug_dir = os.path.join(out_dir, 'debug') 228 | log_dir = os.path.join(out_dir, 'logs') 229 | val_dir = os.path.join(out_dir, 'val') 230 | 231 | trainer = Trainer( 232 | model, optimizer, scheduler, generator, train_loader, val_loader, 233 | device=device, 234 | vis_dir=vis_dir, debug_dir=debug_dir, log_dir=log_dir, val_dir=val_dir, 235 | threshold=threshold, 236 | **cfg.training) 237 | 238 | return trainer 239 | 240 | 241 | def create_renderer(render_opt): 242 | """ Create rendere """ 243 | Renderer = get_class_from_string(render_opt.renderer_type) 244 | Raster = get_class_from_string(render_opt.raster_type) 245 | i = render_opt.raster_type.rfind('.') 246 | raster_setting_type = render_opt.raster_type[:i] + \ 247 | '.PointsRasterizationSettings' 248 | if render_opt.compositor_type is not None: 249 | Compositor = get_class_from_string(render_opt.compositor_type) 250 | compositor = Compositor() 251 | else: 252 | compositor = None 253 | 254 | RasterSetting = get_class_from_string(raster_setting_type) 255 | raster_settings = RasterSetting(**render_opt.raster_params) 256 | 257 | renderer = Renderer( 258 | rasterizer=Raster( 259 | cameras=FoVPerspectiveCameras(), raster_settings=raster_settings), 260 | compositor=compositor, 261 | ) 262 | return renderer 263 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | name: "demo" 2 | data: 3 | type: MVR 4 | data_dir: "example_data" 5 | data_dict: "data_dict.npz" 6 | img_folder: image 7 | mask_folder: mask 8 | depth_folder: depth 9 | img_extension: png 10 | img_extension_input: jpg 11 | mask_extension: png 12 | img_with_camera: true 13 | img_with_mask: true 14 | n_imgs: null 15 | resolution: [512, 512] 16 | renderer: 17 | is_neural_texture: False 18 | renderer_type: DSS.core.renderer.SurfaceSplattingRenderer 19 | raster_type: DSS.core.rasterizer.SurfaceSplatting 20 | raster_params: 21 | backface_culling: false 22 | Vrk_isotropic: false 23 | bin_size: null 24 | clip_pts_grad: 0.05 25 | cutoff_threshold: 0.5 26 | depth_merging_threshold: 0.05 27 | image_size: 512 28 | max_points_per_bin: null 29 | points_per_pixel: 5 30 | radii_backward_scaler: 5 31 | compositor_type: pytorch3d.renderer.NormWeightedCompositor 32 | composite_params: {} 33 | lighting: 'from_data' 34 | # 'from_data' | 'default' 35 | # from_data: use ground truth lighting (in this case learn_colors should be false) 36 | # default: use white ambient light, in this case 
learn_normals shoudl be false 37 | model: 38 | type: point 39 | model_kwargs: 40 | learn_points: true 41 | learn_normals: true 42 | learn_colors: False 43 | n_points_per_cloud: 8000 44 | training: 45 | out_dir: exp 46 | # loss when renderer to get predicted image 47 | lambda_dr_rgb: 1.0 48 | lambda_dr_silhouette: 1.0 49 | lambda_dr_proj: 0.1 50 | lambda_dr_repel: 0.1 51 | batch_size: 1 52 | batch_size_val: 1 53 | patch_size: 1 54 | print_every: 10 55 | checkpoint_every: 500 56 | visualize_every: 100 57 | validate_every: 500 58 | debug_every: 500 59 | learning_rate: 0.0001 60 | scheduler_milestones: [500, 800] 61 | scheduler_gamma: 0.5 62 | n_workers: 1 63 | logfile: train.log 64 | overwrite_visualization: false 65 | n_debug_points: -1 66 | resume_from: model.pt 67 | point_file: shape_pts.ply 68 | model_selection_metric: chamfer_point 69 | model_selection_mode: minimize 70 | generation: 71 | batch_size: 1 72 | vis_n_outputs: 30 73 | generation_dir: generation 74 | refinement_step: -1 75 | points_batch_size: 100000 76 | with_colors: true 77 | with_normals: true 78 | mesh_extension: ply 79 | test: 80 | eval_file_name: eval_meshes 81 | threshold: 0.0 82 | model_file: model_best.pt 83 | -------------------------------------------------------------------------------- /configs/dss.yml: -------------------------------------------------------------------------------- 1 | name: dss_proj 2 | data: 3 | type: MVR 4 | data_dir: example_data/images/yoga6_variational_light 5 | model: 6 | type: point 7 | model_kwargs: 8 | n_points_per_cloud: 5000 9 | learn_colors: false 10 | learn_points: true 11 | learn_normals: true 12 | renderer: 13 | is_neural_texture: false 14 | raster_params: 15 | Vrk_invariant: true 16 | Vrk_isotropic: false 17 | clip_pts_grad: 0.05 18 | cutoff_threshold: 1.0 19 | depth_merging_threshold: 0.05 20 | image_size: 512 21 | points_per_pixel: 5 22 | radii_backward_scaler: 5 23 | raster_type: DSS.core.rasterizer.SurfaceSplatting 24 | renderer_type: DSS.core.renderer.SurfaceSplattingRenderer 25 | training: 26 | backup_every: 1000 27 | batch_size: 8 28 | checkpoint_every: 400 29 | debug_every: 100 30 | lambda_dr_proj: 0.01 31 | lambda_dr_repel: 0.0 32 | lambda_dr_rgb: 1.0 33 | lambda_dr_silhouette: 1.0 34 | n_training_points: 0 35 | n_workers: 0 36 | point_file: shape_pts.ply 37 | print_every: 10 38 | scheduler_gamma: 0.5 39 | visualize_every: 100 40 | steps_dss_backward_radii: 200 41 | gamma_dss_backward_radii: 0.9 42 | limit_dss_backward_radii: 2 43 | type: MVR 44 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DSS 2 | channels: 3 | - pytorch3d 4 | - bottler 5 | - iopath 6 | - conda-forge 7 | - pytorch 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=1_gnu 12 | - absl-py=0.12.0=pyhd8ed1ab_0 13 | - attrs=20.3.0=pyhd3deb0d_0 14 | - blas=1.0=mkl 15 | - c-ares=1.17.1=h7f98852_1 16 | - ca-certificates=2020.12.5=ha878542_0 17 | - certifi=2020.12.5=py38h578d9bd_1 18 | - cloudpickle=1.6.0=py_0 19 | - cudatoolkit=10.2.89=h8f6ccaa_8 20 | - cycler=0.10.0=py_2 21 | - cython=0.29.22=py38h709712a_0 22 | - cytoolz=0.11.0=py38h497a2fe_3 23 | - dask-core=2021.3.0=pyhd8ed1ab_0 24 | - dbus=1.13.18=hb2f20db_0 25 | - decorator=4.4.2=py_0 26 | - easydict=1.9=py_0 27 | - expat=2.2.10=h9c3ff4c_0 28 | - fontconfig=2.13.1=he4413a7_1000 29 | - freetype=2.10.4=h0708190_1 30 | - fvcore=0.1.3.post20210223=pyhd8ed1ab_0 31 | - 
gettext=0.19.8.1=h0b5b191_1005 32 | - gitdb=4.0.5=pyhd8ed1ab_1 33 | - gitpython=3.1.14=pyhd8ed1ab_0 34 | - glib=2.66.7=h9c3ff4c_1 35 | - glib-tools=2.66.7=h9c3ff4c_1 36 | - grpcio=1.36.1=py38hdd6454d_0 37 | - gst-plugins-base=1.14.0=hbbd80ab_1 38 | - gstreamer=1.14.0=h28cd5cc_2 39 | - icu=58.2=hf484d3e_1000 40 | - imageio=2.8.0=py_0 41 | - importlib-metadata=3.7.3=py38h578d9bd_0 42 | - iniconfig=1.1.1=pyh9f0ad1d_0 43 | - intel-openmp=2020.2=254 44 | - iopath=0.1.6.post20210311=py38 45 | - jpeg=9d=h36c2ea0_0 46 | - kiwisolver=1.3.1=py38h1fd1430_1 47 | - lcms2=2.12=hddcbb42_0 48 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 49 | - libffi=3.3=h58526e2_2 50 | - libgcc-ng=9.3.0=h2828fa1_18 51 | - libgfortran-ng=7.5.0=h14aa051_18 52 | - libgfortran4=7.5.0=h14aa051_18 53 | - libglib=2.66.7=h3e27bee_1 54 | - libgomp=9.3.0=h2828fa1_18 55 | - libiconv=1.16=h516909a_0 56 | - libpng=1.6.37=h21135ba_2 57 | - libprotobuf=3.15.6=h780b84a_0 58 | - libstdcxx-ng=9.3.0=h6de172a_18 59 | - libtiff=4.2.0=hdc55705_0 60 | - libuuid=2.32.1=h7f98852_1000 61 | - libwebp-base=1.2.0=h7f98852_0 62 | - libxcb=1.13=h7f98852_1003 63 | - libxml2=2.9.9=h13577e0_2 64 | - llvm-meta=7.0.0=0 65 | - lz4-c=1.9.3=h9c3ff4c_0 66 | - markdown=3.3.4=pyhd8ed1ab_0 67 | - matplotlib=3.2.1=0 68 | - matplotlib-base=3.2.1=py38hef1b27d_0 69 | - mkl=2020.2=256 70 | - mkl-service=2.3.0=py38h1e0a361_2 71 | - mkl_fft=1.3.0=py38h54f3939_0 72 | - mkl_random=1.2.0=py38hc5bc63f_1 73 | - more-itertools=8.7.0=pyhd8ed1ab_0 74 | - ncurses=6.2=h58526e2_4 75 | - networkx=2.5=py_0 76 | - ninja=1.10.2=h4bd325d_0 77 | - numpy=1.19.2=py38h54aff64_0 78 | - numpy-base=1.19.2=py38hfa32c7d_0 79 | - nvidiacub=1.10.0=0 80 | - olefile=0.46=pyh9f0ad1d_1 81 | - openmp=7.0.0=h2d50403_0 82 | - openssl=1.1.1j=h7f98852_0 83 | - packaging=20.9=pyh44b312d_0 84 | - pcre=8.44=he1b5a44_0 85 | - pillow=8.1.2=py38ha0e1e83_0 86 | - pip=21.0.1=pyhd8ed1ab_0 87 | - plotly=4.7.1=pyh9f0ad1d_0 88 | - pluggy=0.13.1=py38h578d9bd_4 89 | - plyfile=0.7.2=pyh9f0ad1d_0 90 | - point_cloud_utils=0.17.1=py38hc10631b_2 91 | - portalocker=1.7.0=py38h578d9bd_1 92 | - protobuf=3.15.6=py38h709712a_0 93 | - pthread-stubs=0.4=h36c2ea0_1001 94 | - py=1.10.0=pyhd3deb0d_0 95 | - pyparsing=2.4.7=pyh9f0ad1d_0 96 | - pyqt=5.9.2=py38h05f1152_4 97 | - pytest=6.2.1=py38h578d9bd_1 98 | - python=3.8.8=hffdb5ce_0_cpython 99 | - python-dateutil=2.8.1=py_0 100 | - python_abi=3.8=1_cp38 101 | - pytorch=1.6.0=py3.8_cuda10.2.89_cudnn7.6.5_0 102 | - pytorch3d=0.4.0=py38_cu102_pyt160 103 | - pywavelets=1.1.1=py38h5c078b8_3 104 | - pyyaml=5.4.1=py38h497a2fe_0 105 | - qt=5.9.7=h5867ecd_1 106 | - readline=8.0=he28a2e2_2 107 | - retrying=1.3.3=py_2 108 | - scikit-image=0.16.2=py38hb3f55d8_0 109 | - scipy=1.6.1=py38h91f5cce_0 110 | - setuptools=49.6.0=py38h578d9bd_3 111 | - sip=4.19.13=py38he6710b0_0 112 | - six=1.15.0=pyh9f0ad1d_0 113 | - smmap=3.0.5=pyh44b312d_0 114 | - sqlite=3.34.0=h74cdb3f_0 115 | - tabulate=0.8.9=pyhd8ed1ab_0 116 | - tensorboard=1.15.0=py38_0 117 | - termcolor=1.1.0=py_2 118 | - tk=8.6.10=h21135ba_1 119 | - toml=0.10.2=pyhd8ed1ab_0 120 | - toolz=0.11.1=py_0 121 | - torchvision=0.7.0=py38_cu102 122 | - tornado=6.1=py38h497a2fe_1 123 | - tqdm=4.59.0=pyhd8ed1ab_0 124 | - trimesh=3.6.34=pyh9f0ad1d_0 125 | - werkzeug=1.0.1=pyh9f0ad1d_0 126 | - wheel=0.36.2=pyhd3deb0d_0 127 | - xorg-libxau=1.0.9=h7f98852_0 128 | - xorg-libxdmcp=1.1.3=h7f98852_0 129 | - xz=5.2.5=h516909a_1 130 | - yaml=0.2.5=h516909a_0 131 | - zipp=3.4.1=pyhd8ed1ab_0 132 | - zlib=1.2.11=h516909a_1010 133 | - zstd=1.4.9=ha95c52a_0 134 | - pip: 135 | - 
frnn==0.0.0 136 | - future==0.18.2 137 | - prefix-sum==0.0.0 138 | - pymeshlab==0.2 139 | - pypoisson==0.10 140 | - torch-batch-svd==1.0.0 141 | - yacs==0.1.8 142 | prefix: /home/ywang/anaconda3/envs/DSS 143 | -------------------------------------------------------------------------------- /example_data/mesh/yoga6.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/mesh/yoga6.ply -------------------------------------------------------------------------------- /example_data/pointclouds/Kangaroo_V10k.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/Kangaroo_V10k.ply -------------------------------------------------------------------------------- /example_data/pointclouds/Kangaroo_V10k_nc.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/Kangaroo_V10k_nc.ply -------------------------------------------------------------------------------- /example_data/pointclouds/Koala_V10k.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/Koala_V10k.ply -------------------------------------------------------------------------------- /example_data/pointclouds/Koala_V10k_nc.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/Koala_V10k_nc.ply -------------------------------------------------------------------------------- /example_data/pointclouds/a72-seated_jew_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/a72-seated_jew_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/armadillo_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/armadillo_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/bunny-8000.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/bunny-8000.ply -------------------------------------------------------------------------------- /example_data/pointclouds/cube_20k.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/cube_20k.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/A9-vulcan_aligned_pca.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/A9-vulcan_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/Gramme_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/Gramme_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/a72-seated_jew_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/a72-seated_jew_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/asklepios_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/asklepios_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/baron_seutin_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/baron_seutin_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/charite_-_CleanUp_-_LowPoly_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/charite_-_CleanUp_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/cheval_terracotta_-_LowPoly-RealOne_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/cheval_terracotta_-_LowPoly-RealOne_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/cupid_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/cupid_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/dame_assise_-_CleanUp_-_LowPoly_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/dame_assise_-_CleanUp_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/drunkard_-_CleanUp_-_LowPoly_aligned_pca.ply: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/drunkard_-_CleanUp_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/madeleine_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/madeleine_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/retheur_-_LowPoly_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/retheur_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy03_points/saint_lambert_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy03_points/saint_lambert_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/A9-vulcan_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/A9-vulcan_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/Gramme_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/Gramme_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/a72-seated_jew_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/a72-seated_jew_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/asklepios_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/asklepios_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/baron_seutin_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/baron_seutin_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/charite_-_CleanUp_-_LowPoly_aligned_pca.ply: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/charite_-_CleanUp_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/cheval_terracotta_-_LowPoly-RealOne_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/cheval_terracotta_-_LowPoly-RealOne_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/cupid_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/cupid_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/dame_assise_-_CleanUp_-_LowPoly_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/dame_assise_-_CleanUp_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/drunkard_-_CleanUp_-_LowPoly_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/drunkard_-_CleanUp_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/madeleine_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/madeleine_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/retheur_-_LowPoly_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/retheur_-_LowPoly_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/noisy1_points/saint_lambert_aligned_pca.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/noisy1_points/saint_lambert_aligned_pca.ply -------------------------------------------------------------------------------- /example_data/pointclouds/point-one.ply: -------------------------------------------------------------------------------- 1 | ply 2 | format ascii 1.0 3 | comment Some example points for test renders 4 | element vertex 1 5 | property float x 6 | property float y 7 | property float z 8 | property float nx 9 | property float ny 10 | property float nz 11 | property uchar red 12 | property uchar green 13 | 
property uchar blue 14 | end_header 15 | 0.0000 0.0000 0.0000 0.0000 0.0000 1.0000 200 200 200 16 | -------------------------------------------------------------------------------- /example_data/pointclouds/sphere_2k.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/sphere_2k.ply -------------------------------------------------------------------------------- /example_data/pointclouds/sphere_300.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/sphere_300.ply -------------------------------------------------------------------------------- /example_data/pointclouds/teapot_normal_dense.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/teapot_normal_dense.ply -------------------------------------------------------------------------------- /example_data/pointclouds/yoga1_out.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/yoga1_out.ply -------------------------------------------------------------------------------- /example_data/pointclouds/yoga6_out.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/example_data/pointclouds/yoga6_out.ply -------------------------------------------------------------------------------- /images/2D_teapot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/2D_teapot.gif -------------------------------------------------------------------------------- /images/armadillo_2_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/armadillo_2_all.png -------------------------------------------------------------------------------- /images/seated_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/seated_all.png -------------------------------------------------------------------------------- /images/teapot_2D.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/teapot_2D.gif -------------------------------------------------------------------------------- /images/teapot_3D.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/teapot_3D.gif -------------------------------------------------------------------------------- /images/teapot_sequence.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/teapot_sequence.gif 
-------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/teaser.png -------------------------------------------------------------------------------- /images/video-thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/video-thumb.png -------------------------------------------------------------------------------- /images/yoga6-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/yoga6-1.gif -------------------------------------------------------------------------------- /images/yoga6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifita/DSS/8fd8d86593272a56305b13ec7d9b4bdd9d241ef9/images/yoga6.gif -------------------------------------------------------------------------------- /learn_image_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 | import numpy as np 5 | import time 6 | import importlib 7 | from DSS.utils.splatterIo import readCloud, readScene, saveAsPng, writeScene 8 | from DSS.utils.trainer import FilterTrainer as Trainer 9 | from DSS.options.filter_options import FilterOptions 10 | from pytorch_points.network.operations import normalize_point_batch 11 | from pytorch_points.utils.pc_utils import load 12 | 13 | 14 | def trainImageFilter(scene, benchmark=False): 15 | expr_dir = os.path.join(opt.output, opt.name) 16 | if not os.path.isdir(expr_dir): 17 | os.makedirs(expr_dir) 18 | 19 | trainer = Trainer(opt, scene) 20 | trainer.setup(opt, scene.cloud) 21 | 22 | logInterval = math.floor(1+sum(opt.steps)//20) 23 | renderForwardTime = 0.0 24 | lossTime = 0.0 25 | optimizerStep = 0.0 26 | 27 | with torch.autograd.detect_anomaly(): 28 | with open(os.path.join(expr_dir, "loss.csv"), 'w') as loss_log: 29 | for c in range(opt.cycles): 30 | # create new reference 31 | tb = c*sum(opt.steps)+opt.startingStep 32 | te = (c+1)*sum(opt.steps)+opt.startingStep 33 | t = tb 34 | 35 | with torch.no_grad(): 36 | trainer.create_reference(scene) 37 | trainer.initiate_cycle() 38 | for i, pair in enumerate(zip(trainer.groundtruths, trainer.predictions)): 39 | post, pre = pair 40 | diff = post - pre 41 | saveAsPng(pre.cpu(), os.path.join(expr_dir, 't%03d_cam%d_init.png' % (t, i))) 42 | saveAsPng(post.cpu(), os.path.join(expr_dir, 't%03d_cam%d_gt.png' % (t, i))) 43 | saveAsPng(diff.cpu(), os.path.join(expr_dir, 't%03d_cam%d_diff.png' % (t, i))) 44 | 45 | for t in range(tb, te): 46 | if t % logInterval == 0 and not benchmark: 47 | writeScene(scene, os.path.join(expr_dir, 't%03d' % t + 48 | '_values.json'), os.path.join(expr_dir, 't%03d' % t + '.ply')) 49 | 50 | trainer.optimize_parameters() 51 | if t % logInterval == 0 and not benchmark: 52 | for i, prediction in enumerate(trainer.predictions): 53 | saveAsPng(prediction.detach().cpu()[0], os.path.join(expr_dir, 't%03d_cam%d' % (t, i) + ".png")) 54 | 55 | if not benchmark: 56 | loss_str = ",".join(["%.3f" % (100*v) for v in trainer.loss_image]) 57 | reg_str = ",".join(["%.3f" % (100*v) for v in trainer.loss_reg]) 58 | entries =
[trainer.modifier] + [loss_str] + [reg_str] 59 | loss_log.write(",".join(entries)+"\n") 60 | print("{:03d} {}: lr {} loss ({}) \n : reg ({})".format( 61 | t, trainer.modifier, trainer.lr, loss_str, reg_str)) 62 | 63 | trainer.finish_cycle() 64 | 65 | writeScene(scene, os.path.join(expr_dir, 'final_scene.json'), 66 | os.path.join(expr_dir, 'final_cloud.ply')) 67 | 68 | 69 | if __name__ == "__main__": 70 | opt = FilterOptions().parse() 71 | 72 | torch.manual_seed(24) 73 | torch.backends.cudnn.deterministic = True 74 | torch.backends.cudnn.benchmark = False 75 | np.random.seed(24) 76 | 77 | # Create ground truth 78 | scene = readScene(opt.source, device="cpu") 79 | if opt.cloud: 80 | points = readCloud(opt.cloud, device="cpu") 81 | points_coords, _, _ = normalize_point_batch( 82 | points[:, :3].unsqueeze(0), NCHW=False) 83 | points[:, :3] = points_coords.squeeze(0)*2 84 | scene.loadPoints(points) 85 | 86 | trainImageFilter(scene, benchmark=opt.benchmark) 87 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | easydict==1.9 5 | imageio==2.8.0 6 | matplotlib==3.2.1 7 | plotly==4.7.1 8 | plyfile==0.7.2 9 | pytest==6.2.1 10 | scikit-image==0.16.2 11 | tensorboard==1.15.0 12 | tqdm 13 | trimesh==3.6.34 14 | point_cloud_utils 15 | pymeshlab 16 | GitPython 17 | protobuf==3.20.* 18 | open3d==0.15.2 19 | -------------------------------------------------------------------------------- /scripts/evaluatePointClouds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from glob import glob 6 | import re 7 | import csv 8 | from collections import OrderedDict 9 | from pytorch_points.network.operations import normalize_point_batch, group_knn 10 | from pytorch_points.network.model_loss import nndistance 11 | from pytorch_points.utils.pc_utils import save_ply_property, load 12 | 13 | 14 | def get_filenames(source, extension): 15 | # If extension is a list 16 | if source is None: 17 | return [] 18 | # Seamlessly load a single file, a list of files, or files from directories.
19 | source_fns = [] 20 | if isinstance(source, str): 21 | if os.path.isdir(source): 22 | if not isinstance(extension, str): 23 | for fmt in extension: 24 | source_fns += get_filenames(source, fmt) 25 | else: 26 | source_fns = sorted( 27 | glob("{}/**/*{}".format(source, extension), recursive=True)) 28 | elif os.path.isfile(source): 29 | source_fns = [source] 30 | elif len(source) and isinstance(source[0], str): 31 | for s in source: 32 | source_fns.extend(get_filenames(s, extension=extension)) 33 | return source_fns 34 | 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--gt", type=str, required=True, help="directory or file name for ground truth point clouds") 38 | parser.add_argument("--pred", type=str, nargs="+", required=True, help="directories of predictions") 39 | parser.add_argument("--name", type=str, required=True, help="name pattern if provided directory for pred and gt") 40 | FLAGS = parser.parse_args() 41 | if os.path.isdir(FLAGS.gt): 42 | GT_DIR = FLAGS.gt 43 | gt_paths = get_filenames(GT_DIR, ("ply", "pcd", "xyz")) 44 | gt_names = [os.path.basename(p)[:-4] for p in gt_paths] 45 | elif os.path.isfile(FLAGS.gt): 46 | gt_paths = [FLAGS.gt] 47 | 48 | PRED_DIR = FLAGS.pred 49 | NAME = FLAGS.name 50 | 51 | 52 | fieldnames = ["name", "CD", "hausdorff", "p2f avg", "p2f std"] + ["nuc_%d" % d for d in range(7)] 53 | print("{:60s} ".format("name"), "|".join(["{:>15s}".format(d) for d in fieldnames[1:]])) 54 | for D in PRED_DIR: 55 | avg_md_forward_value = 0 56 | avg_md_backward_value = 0 57 | avg_hd_value = 0 58 | counter = 0 59 | pred_paths = glob(os.path.join(D, "**", NAME), recursive=True) 60 | if len(pred_paths) >= 1 and len(gt_paths) > 1: 61 | gt_pred_pairs = [] 62 | for p in pred_paths: 63 | name, ext = os.path.splitext(os.path.basename(p)) 64 | assert(ext in (".ply", ".xyz")) 65 | try: 66 | gt = gt_paths[gt_names.index(name)] 67 | except ValueError: 68 | pass 69 | else: 70 | gt_pred_pairs.append((gt, p)) 71 | else: 72 | gt_pred_pairs = [] 73 | for p in pred_paths: 74 | gt_pred_pairs.append((gt_paths[0], p)) 75 | 76 | # print("total inputs ", len(gt_pred_pairs)) 77 | # tag = re.search("/(\w+)/result", os.path.dirname(gt_pred_pairs[0][1])) 78 | tag = os.path.basename(os.path.dirname(gt_pred_pairs[0][1])) 79 | 80 | print("{:60s}".format(tag), end=' ') 81 | global_p2f = [] 82 | global_density = [] 83 | with open(os.path.join(os.path.dirname(gt_pred_pairs[0][1]), "evaluation.csv"), "w") as f: 84 | writer = csv.DictWriter(f, fieldnames=fieldnames, restval="-", extrasaction="ignore") 85 | writer.writeheader() 86 | for gt_path, pred_path in gt_pred_pairs: 87 | row = {} 88 | gt = load(gt_path)[:, :3] 89 | gt = gt[np.newaxis, ...] 90 | pred = load(pred_path) 91 | pred = pred[:, :3] 92 | 93 | row["name"] = os.path.basename(pred_path) 94 | pred = pred[np.newaxis, ...]
95 | 96 | pred = torch.from_numpy(pred).cuda() 97 | gt = torch.from_numpy(gt).cuda() 98 | 99 | pred_tensor, centroid, furthest_distance = normalize_point_batch(pred) 100 | gt_tensor, centroid, furthest_distance = normalize_point_batch(gt) 101 | 102 | # B, P_predict, 1 103 | cd_forward, cd_backward = nndistance(pred, gt) 104 | # cd_forward, _ = knn_point(1, gt_tensor, pred_tensor) 105 | # cd_backward, _ = knn_point(1, pred_tensor, gt_tensor) 106 | # cd_forward = cd_forward[0, :, 0] 107 | # cd_backward = cd_backward[0, :, 0] 108 | cd_forward = cd_forward.detach().cpu().numpy()[0] 109 | cd_backward = cd_backward.detach().cpu().numpy()[0] 110 | 111 | save_ply_property(pred.squeeze(0).detach().cpu().numpy(), cd_forward, pred_path[:-4]+"_cdF.ply", property_max=0.003, cmap_name="jet") 112 | save_ply_property(gt.squeeze(0).detach().cpu().numpy(), cd_backward, pred_path[:-4]+"_cdB.ply", property_max=0.003, cmap_name="jet") 113 | 114 | md_value = np.mean(cd_forward)+np.mean(cd_backward) 115 | hd_value = np.max(np.amax(cd_forward, axis=0)+np.amax(cd_backward, axis=0)) 116 | cd_backward = np.mean(cd_backward) 117 | cd_forward = np.mean(cd_forward) 118 | # row["CD_forward"] = np.mean(cd_forward) 119 | # row["CD_backwar"] = np.mean(cd_backward) 120 | row["CD"] = cd_forward+cd_backward 121 | 122 | row["hausdorff"] = hd_value 123 | avg_md_forward_value += cd_forward 124 | avg_md_backward_value += cd_backward 125 | avg_hd_value += hd_value 126 | if os.path.isfile(pred_path[:-4] + "_point2mesh_distance.xyz"): 127 | point2mesh_distance = load(pred_path[:-4] + "_point2mesh_distance.xyz") 128 | if point2mesh_distance.size == 0: 129 | continue 130 | point2mesh_distance = point2mesh_distance[:, 3] 131 | row["p2f avg"] = np.nanmean(point2mesh_distance) 132 | row["p2f std"] = np.nanstd(point2mesh_distance) 133 | global_p2f.append(point2mesh_distance) 134 | if os.path.isfile(pred_path[:-4] + "_density.xyz"): 135 | density = load(pred_path[:-4] + "_density.xyz") 136 | global_density.append(density) 137 | std = np.std(density, axis=0) 138 | for i in range(7): 139 | row["nuc_%d" % i] = std[i] 140 | writer.writerow(row) 141 | counter += 1 142 | 143 | row = OrderedDict() 144 | 145 | avg_md_forward_value /= counter 146 | avg_md_backward_value /= counter 147 | avg_hd_value /= counter 148 | # row["CD_forward"] = avg_md_forward_value 149 | # row["CD_backward"] = avg_md_backward_value 150 | row["CD"] = avg_md_forward_value+avg_md_backward_value 151 | row["hausdorff"] = avg_hd_value 152 | if global_p2f: 153 | global_p2f = np.concatenate(global_p2f, axis=0) 154 | mean_p2f = np.nanmean(global_p2f) 155 | std_p2f = np.nanstd(global_p2f) 156 | row["p2f avg"] = mean_p2f 157 | row["p2f std"] = std_p2f 158 | if global_density: 159 | global_density = np.concatenate(global_density, axis=0) 160 | nuc = np.std(global_density, axis=0) 161 | for i in range(7): 162 | row["nuc_%d" % i] = nuc[i] 163 | 164 | writer.writerow(row) 165 | print("|".join(["{:>15.8f}".format(d) for d in row.values()])) 166 | -------------------------------------------------------------------------------- /scripts/pcl2Mesh.py: -------------------------------------------------------------------------------- 1 | from plyfile import PlyData, PlyElement 2 | import numpy as np 3 | import math 4 | import torch 5 | import os 6 | import argparse 7 | 8 | from neural_point_splatter.geometryConstruction import triangle3dmesh, circle3dmesh, normal2RotMatrix, pcl2Mesh 9 | from neural_point_splatter.splatterIo import readCloud 10 | 11 | if __name__ == "__main__": 12 | parser =
argparse.ArgumentParser(description='Convert a ply holding a point cloud to a ply holding a mesh with circles of radius r.') 13 | parser.add_argument('input', metavar="input", 14 | help='Input file') 15 | parser.add_argument('-o', '--output', dest='output', 16 | help='Output file') 17 | parser.add_argument('-r', '--radius', dest='radius', type=float, default="0.005", 18 | help='radius of circle to replace points with') 19 | parser.add_argument('-t', '--triangles', dest='triangles', type=int, default=5, 20 | help='number of triangles to use for approximation') 21 | parser.add_argument('--text', action='store_true', default=False, 22 | help='By default, PLY files are written in binary format. Add this flag to store as text.') 23 | args = parser.parse_args() 24 | if(args.output is None): 25 | filename, file_extension = os.path.splitext(args.input) 26 | # for the moment only ply files 27 | args.output = filename + "_mesh" + ".ply" 28 | 29 | if not os.path.exists(args.input): 30 | print("ERROR: Please provide a path to an existing input file") 31 | parser.parse_args(['-h']) 32 | exit(-1) 33 | print("Will write to: " + args.output) 34 | if args.triangles > 3: 35 | (circ3, circ3Conn) = circle3dmesh(args.triangles) 36 | else: 37 | (circ3, circ3Conn) = triangle3dmesh(args.triangles) 38 | circ3 = args.radius * circ3 39 | points = readCloud(args.input).float() 40 | pointsCount = points.size()[0] 41 | print("Found " + str(pointsCount) + " points") 42 | 43 | (verts, faces) = pcl2Mesh(points[:,0:3], points[:,3:6], points[:,6:9], circ3, circ3Conn) 44 | verts = verts.numpy() 45 | faces = faces.numpy() 46 | verts = list(map(tuple, verts)) 47 | faces = list(map(lambda row : (tuple(row[0:3]), row[3], row[4], row[5]), faces)) 48 | npVerts = np.array(verts, dtype=[ 49 | ('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 50 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 51 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 52 | npFaces = np.array(faces, dtype=[ 53 | ('vertex_indices', 'i4', (3,)), 54 | ('red', 'u1'), 55 | ('green', 'u1'), 56 | ('blue', 'u1') 57 | ]) 58 | elV = PlyElement.describe(npVerts, 'vertex') 59 | elF = PlyElement.describe(npFaces, 'face') 60 | print("Created vertices: " + str(len(verts)) + ", expected: " + str(pointsCount*(args.triangles+1))) 61 | print("Created triangles: " + str(len(faces)) + ", expected: " + str(pointsCount*args.triangles)) 62 | PlyData([elV, elF], text=args.text).write(args.output) 63 | 64 | -------------------------------------------------------------------------------- /scripts/poisson_sampling.mlx: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /scripts/poisson_sampling_pca.mlx: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /scripts/random_displacement.mlx: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /scripts/run_meshlab_filter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # usage: ./run_meshlab_filter.sh /mnt/external/points/data/ModelNet40 "*.ply" <outputDir> <filterScript.mlx> 3 |
num_procs=1 4 | 5 | inputDir=$1 6 | name="$2" 7 | outputDir="$3" 8 | myDir=$(pwd) 9 | scriptFile="$myDir/$4" 10 | echo "input: $inputDir output: $outputDir extension: $name" 11 | 12 | cd $inputDir 13 | find . -type d -exec mkdir -p "$outputDir"/{} \; 14 | 15 | 16 | function meshlab_poisson_reconstruct () { 17 | iFile="$1" 18 | iName="$(basename $iFile)" 19 | # remove last extension 20 | iName="${iName%.*}" 21 | iDir="$(dirname $iFile)" 22 | oFile="$3/$iName".ply 23 | sFile="$2" 24 | # meshlab.meshlabserver -i $iFile -o $oFile -m vn -s $sFile 25 | # echo "meshlab.meshlabserver -i $iFile -o $oFile -s $sFile" 26 | if [ ! -f "$oFile" ]; then 27 | meshlab.meshlabserver -i $iFile -o $oFile -m vn -s $sFile 28 | # meshlabserver -i $oFile -o $oFile2 -s $sFile2 29 | fi 30 | } 31 | export -f meshlab_poisson_reconstruct 32 | 33 | echo $scriptFile 34 | find . -type f -wholename "$name" 35 | find . -type f -wholename "$name" | xargs -P $num_procs -I % bash -c 'meshlab_poisson_reconstruct "$@"' _ % $scriptFile $outputDir 36 | cd $myDir 37 | -------------------------------------------------------------------------------- /sequences.py: -------------------------------------------------------------------------------- 1 | """render a point cloud in 360 degree""" 2 | from __future__ import division, print_function 3 | import torch 4 | import os 5 | import argparse 6 | import time 7 | import numpy as np 8 | from itertools import chain 9 | from glob import glob 10 | import sys 11 | from DSS.utils.matrixConstruction import rotationMatrixY, rotationMatrixX, rotationMatrixZ, batchAffineMatrix 12 | from DSS.utils.splatterIo import saveAsPng, readScene, readCloud, getBasename, writeCameras, writeScene 13 | from DSS.core.renderer import createSplatter 14 | from DSS.core.camera import CameraSampler 15 | from DSS.options.render_options import RenderOptions 16 | 17 | 18 | def rotMatrix(axis): 19 | if axis.lower() == "x": 20 | return rotationMatrixX 21 | elif axis.lower() == "y": 22 | return rotationMatrixY 23 | else: 24 | return rotationMatrixZ 25 | 26 | 27 | if __name__ == "__main__": 28 | opt = RenderOptions().parse() 29 | points_paths = list(chain.from_iterable(glob(p) for p in opt.points)) 30 | assert(len(points_paths) > 0), "Found no point clouds with path {}".format(opt.points) 31 | points_relpaths = None 32 | if len(points_paths) > 1: 33 | points_dir = os.path.commonpath(points_paths) 34 | points_relpaths = [os.path.relpath(p, points_dir) for p in points_paths] 35 | else: 36 | points_relpaths = [os.path.basename(p) for p in points_paths] 37 | 38 | torch.manual_seed(24) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | np.random.seed(24) 42 | 43 | scene = readScene(opt.source, device="cpu") 44 | 45 | getRotationMatrix = rotMatrix(opt.rot_axis) 46 | with torch.no_grad(): 47 | splatter = createSplatter(opt, scene=scene) 48 | 49 | for i in range(len(scene.cameras)): 50 | scene.cameras[i].width = opt.width 51 | scene.cameras[i].height = opt.height 52 | # scene.cameras[i].focalLength = opt.camFocalLength 53 | 54 | splatter.initCameras(cameras=scene.cameras) 55 | 56 | for pointPath, pointRelPath in zip(points_paths, points_relpaths): 57 | keyName = os.path.join(opt.output, pointRelPath[:-4]) 58 | print(pointPath) 59 | points = readCloud(pointPath, device="cpu") 60 | fileName = getBasename(pointPath) 61 | # find point center 62 | center = torch.mean(points[:, :3], dim=0, keepdim=True) 63 | points[:, :3] -= center 64 | scene.loadPoints(points) 65 |
splatter.setCloud(scene.cloud) 66 | splatter.pointPosition.data.copy_(center) 67 | for i, cam in enumerate(scene.cameras): 68 | # compute object rotation 69 | cnt = 0 70 | for ang in range(0, 360, 3): 71 | rot = getRotationMatrix(torch.tensor(ang*np.pi/180).to(device=splatter.pointRotation.device)) 72 | splatter.pointRotation.data.copy_(rot.unsqueeze(0)) 73 | splatter.m2w = batchAffineMatrix(splatter.pointRotation, splatter.pointPosition, splatter.pointScale) 74 | 75 | # set camera to look at the center 76 | splatter.setCamera(i) 77 | result = splatter.render() 78 | if result is None: 79 | continue 80 | result = result.detach()[0] 81 | 82 | if splatter.shading == "albedo": 83 | cmax = 1 84 | saveAsPng(result.cpu(), keyName + '_cam%02d_%03d.png' % (i, cnt), cmin=0, cmax=cmax) 85 | else: 86 | saveAsPng(result.cpu(), keyName + '_cam%02d_%03d.png' % (i, cnt), cmin=0) 87 | 88 | cnt += 1 89 | 90 | print(pointRelPath) 91 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from typing import List 3 | import os 4 | 5 | import torch 6 | from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension, CppExtension 7 | 8 | extra_compile_args = {"cxx": ["-std=c++14"]} 9 | define_macros = [] 10 | 11 | force_cuda = os.getenv("FORCE_CUDA", "0") == "1" 12 | if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda: 13 | extension = CUDAExtension 14 | # sources += source_cuda 15 | define_macros += [("WITH_CUDA", None)] 16 | nvcc_args = [ 17 | "-DCUDA_HAS_FP16=1", 18 | "-D__CUDA_NO_HALF_OPERATORS__", 19 | "-D__CUDA_NO_HALF_CONVERSIONS__", 20 | "-D__CUDA_NO_HALF2_OPERATORS__", 21 | ] 22 | nvcc_flags_env = os.getenv("NVCC_FLAGS", "") 23 | if nvcc_flags_env != "": 24 | nvcc_args.extend(nvcc_flags_env.split(" ")) 25 | 26 | # It's better if pytorch can do this by default .. 
27 | CC = os.environ.get("CC", None) 28 | if CC is not None: 29 | CC_arg = "-ccbin={}".format(CC) 30 | if CC_arg not in nvcc_args: 31 | if any(arg.startswith("-ccbin") for arg in nvcc_args): 32 | raise ValueError("Inconsistent ccbins") 33 | nvcc_args.append(CC_arg) 34 | 35 | extra_compile_args["nvcc"] = nvcc_args 36 | else: 37 | print('Cuda is not available!') 38 | 39 | ext_modules = [ 40 | CppExtension('DSS._C', [ 41 | 'DSS/csrc/ext.cpp', 42 | 'DSS/csrc/rasterize_points_cpu.cpp', 43 | ]) 44 | ] 45 | include_dirs = torch.utils.cpp_extension.include_paths() 46 | ext_modules += [ 47 | CUDAExtension('DSS._C', [ 48 | 'DSS/csrc/ext.cpp', 49 | 'DSS/csrc/rasterize_points.cu', 50 | 'DSS/csrc/rasterize_points_backward.cu', 51 | 'DSS/csrc/rasterize_points_cpu.cpp', 52 | ], 53 | include_dirs=['DSS/csrc'], 54 | define_macros=define_macros, 55 | extra_compile_args=extra_compile_args 56 | ) 57 | ] 58 | 59 | INSTALL_REQUIREMENTS = ['numpy', 'torch', 'plyfile', 'pytorch3d', 'imageio', 'frnn'] 60 | 61 | setup( 62 | name='DSS', 63 | description='Differentiable Surface Splatter', 64 | author='Yifan Wang, Lixin Xue and Felice Serena', 65 | packages=find_packages(exclude=('tests',)), 66 | license='MIT License', 67 | version='2.0', 68 | install_requires=INSTALL_REQUIREMENTS, 69 | ext_modules=ext_modules, 70 | cmdclass={'build_ext': BuildExtension} 71 | ) 72 | -------------------------------------------------------------------------------- /train_mvr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import git 5 | import os 6 | import logging 7 | import config 8 | import torch 9 | import torch.optim as optim 10 | from DSS.utils import tolerating_collate 11 | from DSS.misc.checkpoints import CheckpointIO 12 | from DSS.utils.sampler import WeightedSubsetRandomSampler 13 | from DSS import logger_py, set_deterministic_ 14 | 15 | set_deterministic_() 16 | 17 | 18 | # Arguments 19 | parser = argparse.ArgumentParser( 20 | description='Train implicit representations without 3D supervision.'
21 | ) 22 | parser.add_argument('--config', type=str, 23 | default="configs/donut_dss_complete.yml", help='Path to config file.') 24 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 25 | parser.add_argument('--exit-after', type=int, default=600, 26 | help='Checkpoint and exit after specified number of ' 27 | 'seconds with exit code 2.') 28 | 29 | args = parser.parse_args() 30 | cfg = config.load_config(args.config, 'configs/default.yaml') 31 | 32 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 33 | device = torch.device("cuda" if is_cuda else "cpu") 34 | 35 | # Shorthands 36 | out_dir = os.path.join(cfg['training']['out_dir'], cfg['name']) 37 | backup_every = cfg['training']['backup_every'] 38 | exit_after = args.exit_after 39 | lr = cfg['training']['learning_rate'] 40 | batch_size = cfg['training']['batch_size'] 41 | batch_size_val = cfg['training']['batch_size_val'] 42 | n_workers = cfg['training']['n_workers'] 43 | model_selection_metric = cfg['training']['model_selection_metric'] 44 | if cfg['training']['model_selection_mode'] == 'maximize': 45 | model_selection_sign = 1 46 | elif cfg['training']['model_selection_mode'] == 'minimize': 47 | model_selection_sign = -1 48 | else: 49 | raise ValueError('model_selection_mode must be ' 50 | 'either maximize or minimize.') 51 | 52 | 53 | # Output directory 54 | if not os.path.exists(out_dir): 55 | os.makedirs(out_dir) 56 | 57 | # Begin logging also to the log file 58 | fileHandler = logging.FileHandler(os.path.join(out_dir, cfg.training.logfile)) 59 | fileHandler.setLevel(logging.DEBUG) 60 | logger_py.addHandler(fileHandler) 61 | 62 | repo = git.Repo(search_parent_directories=False) 63 | sha = repo.head.object.hexsha 64 | logger_py.debug('Git commit: %s' % sha) 65 | 66 | # Data 67 | train_dataset = config.create_dataset(cfg.data, mode='train') 68 | val_dataset = config.create_dataset(cfg.data, mode='val') 69 | val_loader = torch.utils.data.DataLoader( 70 | val_dataset, batch_size=batch_size_val, num_workers=int(n_workers // 2), 71 | shuffle=False, collate_fn=tolerating_collate, 72 | ) 73 | # data_viz = next(iter(val_loader)) 74 | model = config.create_model( 75 | cfg, camera_model=train_dataset.get_cameras(), device=device) 76 | 77 | # Create rendering objects from loaded data 78 | cameras = train_dataset.get_cameras() 79 | lights = train_dataset.get_lights() 80 | 81 | 82 | # Optimizer 83 | param_groups = [] 84 | if cfg.model.model_kwargs.learn_normals: 85 | param_groups.append( 86 | {"params": [model.normals], "lr": 0.01, 'betas': (0.5, 0.9)}) 87 | if cfg.model.model_kwargs.learn_points: 88 | param_groups.append( 89 | {"params": [model.points], "lr": 0.01, 'betas': (0.5, 0.9)}) 90 | if cfg.model.model_kwargs.learn_colors: 91 | param_groups.append( 92 | {"params": [model.colors], "lr": 1.0, 'betas': (0.5, 0.9)}) 93 | 94 | # optimizer = optim.SGD(param_groups, lr=lr) 95 | optimizer = optim.Adam(param_groups, lr=0.01, betas=(0.5, 0.9)) 96 | 97 | # Loads checkpoints 98 | checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer) 99 | try: 100 | load_dict = checkpoint_io.load(cfg.training.resume_from) 101 | except FileExistsError: 102 | load_dict = dict() 103 | 104 | epoch_it = load_dict.get('epoch_it', -1) 105 | it = load_dict.get('it', -1) 106 | 107 | # Save config to log directory 108 | config.save_config(os.path.join(out_dir, 'config.yaml'), cfg) 109 | 110 | # Update Metrics from loaded 111 | model_selection_metric = cfg['training']['model_selection_metric'] 112 | metric_val_best = 
load_dict.get( 113 | 'loss_val_best', -model_selection_sign * np.inf) 114 | 115 | if metric_val_best == np.inf or metric_val_best == -np.inf: 116 | metric_val_best = -model_selection_sign * np.inf 117 | 118 | logger_py.info('Current best validation metric (%s): %.8f' 119 | % (model_selection_metric, metric_val_best)) 120 | 121 | # Shorthands 122 | print_every = cfg['training']['print_every'] 123 | checkpoint_every = cfg['training']['checkpoint_every'] 124 | validate_every = cfg['training']['validate_every'] 125 | visualize_every = cfg['training']['visualize_every'] 126 | debug_every = cfg['training']['debug_every'] 127 | 128 | scheduler = optim.lr_scheduler.MultiStepLR( 129 | optimizer, cfg['training']['scheduler_milestones'], 130 | gamma=cfg['training']['scheduler_gamma'], last_epoch=epoch_it) 131 | 132 | # Set mesh extraction to low resolution for fast visualization 133 | # during training 134 | cfg['generation']['resolution'] = 64 135 | cfg['generation']['img_size'] = tuple(x // 4 for x in train_dataset.resolution) 136 | generator = config.create_generator(cfg, model, device=device) 137 | trainer = config.create_trainer( 138 | cfg, model, optimizer, scheduler, generator, None, val_loader, device=device) 139 | 140 | # Print model 141 | nparameters = sum(p.numel() for p in model.parameters()) 142 | logger_py.info('Total number of parameters: %d' % nparameters) 143 | 144 | 145 | # Start training loop 146 | t0 = time.time() 147 | t0b = time.time() 148 | sample_weights = np.ones(len(train_dataset)).astype('float32') 149 | 150 | while True: 151 | epoch_it += 1 152 | train_sampler = WeightedSubsetRandomSampler( 153 | list(range(len(train_dataset))), sample_weights) 154 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, 155 | num_workers=n_workers, drop_last=True, 156 | collate_fn=tolerating_collate) 157 | trainer.train_loader = train_loader 158 | for batch in train_loader: 159 | it += 1 160 | 161 | loss = trainer.train_step(batch, cameras=cameras, lights=lights, it=it) 162 | 163 | # Visualize output 164 | if it > 0 and visualize_every > 0 and (it % visualize_every) == 0: 165 | logger_py.info('Visualizing') 166 | trainer.visualize(batch, it=it, vis_type='image', 167 | cameras=cameras, lights=lights) 168 | trainer.visualize( 169 | batch, it=it, vis_type='pointcloud', cameras=cameras, lights=lights) 170 | 171 | # Print output 172 | if print_every > 0 and (it % print_every) == 0: 173 | logger_py.info('[Epoch %02d] it=%03d, loss=%.4f, time=%.4f' 174 | % (epoch_it, it, loss, time.time() - t0b)) 175 | t0b = time.time() 176 | 177 | # Debug visualization 178 | if it > 0 and debug_every > 0 and (it % debug_every) == 0: 179 | logger_py.info('Visualizing gradients') 180 | trainer.debug(batch, cameras=cameras, lights=lights, it=it, 181 | mesh_gt=train_dataset.get_meshes()) 182 | 183 | # Save checkpoint 184 | if it > 0 and (checkpoint_every > 0 and (it % checkpoint_every) == 0): 185 | logger_py.info('Saving checkpoint') 186 | print('Saving checkpoint') 187 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 188 | loss_val_best=metric_val_best) 189 | 190 | # Backup if necessary 191 | if it > 0 and (backup_every > 0 and (it % backup_every) == 0): 192 | logger_py.info('Backup checkpoint') 193 | checkpoint_io.save('model_%d.pt' % it, epoch_it=epoch_it, it=it, 194 | loss_val_best=metric_val_best) 195 | 196 | # Run validation and adjust sampling rate 197 | if it > 0 and validate_every > 0 and (it % validate_every) == 0: 198 | if 'chamfer' in
model_selection_metric: 199 | eval_dict = trainer.evaluate_3d( 200 | val_loader, it, cameras=cameras, lights=lights) 201 | else: 202 | eval_dict = trainer.evaluate_2d( 203 | val_loader, cameras=cameras, lights=lights) 204 | metric_val = eval_dict[model_selection_metric] 205 | 206 | logger_py.info('Validation metric (%s): %.4g' % 207 | (model_selection_metric, metric_val)) 208 | 209 | if model_selection_sign * (metric_val - metric_val_best) > 0: 210 | metric_val_best = metric_val 211 | logger_py.info('New best model (loss %.4g)' % metric_val_best) 212 | checkpoint_io.backup_model_best('model_best.pt') 213 | checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it, 214 | loss_val_best=metric_val_best) 215 | # save point cloud 216 | pointcloud = trainer.generator.generate_pointclouds( 217 | {}, with_colors=False, with_normals=True)[0] 218 | pointcloud.export(os.path.join(trainer.val_dir, 'best.ply')) 219 | 220 | # Exit if necessary 221 | if exit_after > 0 and (time.time() - t0) >= exit_after: 222 | logger_py.info('Time limit reached. Exiting.') 223 | checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, 224 | loss_val_best=metric_val_best) 225 | for t in trainer._threads: 226 | t.join() 227 | exit(3) 228 | 229 | # Make scheduler step after full epoch 230 | trainer.update_learning_rate(it) 231 | -------------------------------------------------------------------------------- /trained_models/download_data.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | curl -L https://polybox.ethz.ch/index.php/s/sKvxkn8pKLFFhBd/download --output render_PCA_resnet.zip 3 | unzip render_PCA_resnet.zip --------------------------------------------------------------------------------
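
The PLY point clouds under example_data/pointclouds, such as point-one.ply shown above, store one vertex per point with a position (x, y, z), a normal (nx, ny, nz) and an 8-bit RGB color; scripts such as scripts/pcl2Mesh.py read these nine values as columns 0:3, 3:6 and 6:9 of the loaded array. The following is a minimal sketch, not a file from this repository, of loading such a cloud into an (N, 9) numpy array with plyfile (the parser pinned in requirements.txt); the helper name load_point_cloud and the float32 cast are illustrative assumptions only.

import numpy as np
from plyfile import PlyData

def load_point_cloud(path):
    # Read the PLY file and stack the per-vertex properties into an (N, 9) array:
    # x, y, z, nx, ny, nz, red, green, blue, the layout used by point-one.ply.
    vertex = PlyData.read(path)["vertex"]
    fields = ["x", "y", "z", "nx", "ny", "nz", "red", "green", "blue"]
    return np.stack([np.asarray(vertex[f], dtype=np.float32) for f in fields], axis=1)

points = load_point_cloud("example_data/pointclouds/point-one.ply")
print(points.shape)  # expected (1, 9) for the single-point example cloud above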