├── data
│   └── gargoyle.ply
├── README.md
├── confs
│   └── ngpull.conf
├── LICENSE
├── .github
│   └── workflows
│       └── jekyll-gh-pages.yml
├── .gitignore
├── models
│   ├── utils.py
│   ├── dataset.py
│   └── fields.py
└── run.py

/data/gargoyle.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CuiRuikai/NumGrad-Pull/HEAD/data/gargoyle.ply
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NumGrad-Pull: Numerical Gradient Guided Tri-plane Representation for Surface Reconstruction from Point Clouds
2 | 
3 | 
4 | ---
5 | 
6 | This is the official repository for the paper *"NumGrad-Pull: Numerical Gradient Guided Tri-plane Representation for Surface Reconstruction from Point Clouds"*.
7 | 
8 | ## Usage:
9 | ```bash
10 | python run.py --gpu 0 --conf confs/ngpull.conf --dataname gargoyle --dir gargoyle
11 | ```
12 | You can find the generated mesh and the log in `./outs`.
13 | 
--------------------------------------------------------------------------------
/confs/ngpull.conf:
--------------------------------------------------------------------------------
1 | general {
2 |     base_exp_dir = ./outs/
3 |     recording = [
4 |         ./,
5 |         ./models
6 |     ]
7 | }
8 | 
9 | dataset {
10 |     data_dir = data/
11 |     np_data_name = carnew1w_norm.npz
12 | }
13 | 
14 | train {
15 |     learning_rate = 0.001
16 |     maxiter = 50000
17 |     warm_up_end = 1000
18 |     eval_num_points = 100000
19 | 
20 | 
21 |     batch_size = 5000
22 | 
23 |     save_freq = 5000
24 |     val_freq = 2500
25 |     report_freq = 1000
26 | 
27 |     igr_weight = 0.1
28 |     mask_weight = 0.0
29 | 
30 |     lr_net = 0.001
31 |     lr_tri = 0.05
32 |     grad_eps = 1e-2
33 |     resolution = 48
34 |     c2f_scale = [3000, 8000, 12000]
35 | }
36 | 
37 | model {
38 |     sdf_network {
39 |         d_out = 1
40 |         d_in = 32
41 |         d_hidden = 128
42 |         n_layers = 3
43 |         skip_in = []
44 |         multires = 0
45 |         bias = 0.5
46 |         scale = 1.0
47 |         geometric_init = True
48 |         weight_norm = True
49 |     }
50 |     triplane {
51 |         init_type = geo_init
52 |     }
53 | }
54 | 
55 | 
56 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 CuiRuikai
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
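
A note on the usage and configuration above: `confs/ngpull.conf` is parsed with pyhocon in `run.py`. The sketch below is not a file from this repository; it assumes pyhocon is installed and that it is run from the repository root, and it only illustrates how the hyperparameters shown above are read.

```python
# Minimal sketch: reading confs/ngpull.conf the same way run.py does.
from pyhocon import ConfigFactory

with open("confs/ngpull.conf") as f:
    conf = ConfigFactory.parse_string(f.read())

print(conf["general.base_exp_dir"])      # ./outs/  (experiment output root)
print(conf.get_float("train.grad_eps"))  # 0.01     (step size for numerical SDF gradients)
print(conf.get_int("train.resolution"))  # 48       (final tri-plane resolution)
print(conf.get_list("train.c2f_scale"))  # [3000, 8000, 12000]  (coarse-to-fine milestones)
```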
22 | -------------------------------------------------------------------------------- /.github/workflows/jekyll-gh-pages.yml: -------------------------------------------------------------------------------- 1 | # Sample workflow for building and deploying a Jekyll site to GitHub Pages 2 | name: Deploy Jekyll with GitHub Pages dependencies preinstalled 3 | 4 | on: 5 | # Runs on pushes targeting the default branch 6 | push: 7 | branches: ["main"] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 13 | permissions: 14 | contents: read 15 | pages: write 16 | id-token: write 17 | 18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 20 | concurrency: 21 | group: "pages" 22 | cancel-in-progress: false 23 | 24 | jobs: 25 | # Build job 26 | build: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - name: Checkout 30 | uses: actions/checkout@v4 31 | - name: Setup Pages 32 | uses: actions/configure-pages@v5 33 | - name: Build with Jekyll 34 | uses: actions/jekyll-build-pages@v1 35 | with: 36 | source: ./ 37 | destination: ./_site 38 | - name: Upload artifact 39 | uses: actions/upload-pages-artifact@v3 40 | 41 | # Deployment job 42 | deploy: 43 | environment: 44 | name: github-pages 45 | url: ${{ steps.deployment.outputs.page_url }} 46 | runs-on: ubuntu-latest 47 | needs: build 48 | steps: 49 | - name: Deploy to GitHub Pages 50 | id: deployment 51 | uses: actions/deploy-pages@v4 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # my folders 2 | outs/ 3 | *.npz 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # UV 102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | #uv.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | #.idea/ 173 | 174 | # PyPI configuration file 175 | .pypirc -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO, name="main"): 8 | """Get root logger and add a keyword filter to it. 9 | The logger will be initialized if it has not been initialized. By default a 10 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 11 | also be added. The name of the root logger is the top-level package name, 12 | e.g., "mmdet3d". 13 | Args: 14 | log_file (str, optional): File path of log. Defaults to None. 15 | log_level (int, optional): The level of logger. 16 | Defaults to logging.INFO. 17 | name (str, optional): The name of the root logger, also used as a 18 | filter keyword. Defaults to 'mmdet3d'. 19 | Returns: 20 | :obj:`logging.Logger`: The obtained logger 21 | """ 22 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 23 | # add a logging filter 24 | logging_filter = logging.Filter(name) 25 | logging_filter.filter = lambda record: record.find(name) != -1 26 | 27 | return logger 28 | 29 | 30 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="w"): 31 | """Initialize and get a logger by name. 32 | If the logger has not been initialized, this method will initialize the 33 | logger by adding one or two handlers, otherwise the initialized logger will 34 | be directly returned. During initialization, a StreamHandler will always be 35 | added. If `log_file` is specified and the process rank is 0, a FileHandler 36 | will also be added. 37 | Args: 38 | name (str): Logger name. 39 | log_file (str | None): The log filename. If specified, a FileHandler 40 | will be added to the logger. 41 | log_level (int): The logger level. Note that only the process of 42 | rank 0 is affected, and other processes will set the level to 43 | "Error" thus be silent most of the time. 44 | file_mode (str): The file mode used in opening log file. 45 | Defaults to 'w'. 46 | Returns: 47 | logging.Logger: The expected logger. 48 | """ 49 | logger = logging.getLogger(name) 50 | if name in logger_initialized: 51 | return logger 52 | # handle hierarchical names 53 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 54 | # initialization since it is a child of "a". 55 | for logger_name in logger_initialized: 56 | if name.startswith(logger_name): 57 | return logger 58 | 59 | # handle duplicate logs to the console 60 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 61 | # to the root logger. As logger.propagate is True by default, this root 62 | # level handler causes logging messages from rank>0 processes to 63 | # unexpectedly show up on the console, creating much unwanted clutter. 64 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 65 | # at the ERROR level. 
66 |     for handler in logger.root.handlers:
67 |         if type(handler) is logging.StreamHandler:
68 |             handler.setLevel(logging.ERROR)
69 | 
70 |     stream_handler = logging.StreamHandler()
71 |     handlers = [stream_handler]
72 | 
73 |     if dist.is_available() and dist.is_initialized():
74 |         rank = dist.get_rank()
75 |     else:
76 |         rank = 0
77 | 
78 |     # only rank 0 will add a FileHandler
79 |     if rank == 0 and log_file is not None:
80 |         # Here, the default behaviour of the official logger is 'a'. Thus, we
81 |         # provide an interface to change the file mode to the default
82 |         # behaviour.
83 |         file_handler = logging.FileHandler(log_file, file_mode)
84 |         handlers.append(file_handler)
85 | 
86 |     formatter = logging.Formatter(
87 |         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
88 |     )
89 |     for handler in handlers:
90 |         handler.setFormatter(formatter)
91 |         handler.setLevel(log_level)
92 |         logger.addHandler(handler)
93 | 
94 |     if rank == 0:
95 |         logger.setLevel(log_level)
96 |     else:
97 |         logger.setLevel(logging.ERROR)
98 | 
99 |     logger_initialized[name] = True
100 | 
101 |     return logger
102 | 
103 | 
104 | def print_log(msg, logger=None, level=logging.INFO):
105 |     """Print a log message.
106 |     Args:
107 |         msg (str): The message to be logged.
108 |         logger (logging.Logger | str | None): The logger to be used.
109 |             Some special loggers are:
110 |             - "silent": no message will be printed.
111 |             - other str: the logger obtained with `get_root_logger(logger)`.
112 |             - None: The `print()` method will be used to print log messages.
113 |         level (int): Logging level. Only available when `logger` is a Logger
114 |             object or "root".
115 |     """
116 |     if logger is None:
117 |         print(msg)
118 |     elif isinstance(logger, logging.Logger):
119 |         logger.log(level, msg)
120 |     elif logger == "silent":
121 |         pass
122 |     elif isinstance(logger, str):
123 |         _logger = get_logger(logger)
124 |         _logger.log(level, msg)
125 |     else:
126 |         raise TypeError(
127 |             "logger should be either a logging.Logger object, str, "
128 |             f'"silent" or None, but got {type(logger)}'
129 |         )
130 | 
--------------------------------------------------------------------------------
/models/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import os
5 | from scipy.spatial import cKDTree
6 | import trimesh
7 | 
8 | 
9 | def search_nearest_point(point_batch, point_gt):
10 |     num_point_batch, num_point_gt = point_batch.shape[0], point_gt.shape[0]
11 |     point_batch = point_batch.unsqueeze(1).repeat(1, num_point_gt, 1)
12 |     point_gt = point_gt.unsqueeze(0).repeat(num_point_batch, 1, 1)
13 | 
14 |     distances = torch.sqrt(torch.sum((point_batch - point_gt) ** 2, axis=-1) + 1e-12)
15 |     dis_idx = torch.argmin(distances, axis=1).detach().cpu().numpy()
16 | 
17 |     return dis_idx
18 | 
19 | 
20 | def process_data(data_dir, dataname):
21 |     if os.path.exists(os.path.join(data_dir, dataname) + ".ply"):
22 |         pointcloud = trimesh.load(os.path.join(data_dir, dataname) + ".ply").vertices
23 |         pointcloud = np.asarray(pointcloud)
24 |     elif os.path.exists(os.path.join(data_dir, dataname) + ".xyz"):
25 |         pointcloud = np.loadtxt(os.path.join(data_dir, dataname) + ".xyz")  # .xyz files are plain-text point lists
26 |     else:
27 |         print("Only support .xyz or .ply data. 
Please make adjust your data.") 28 | exit() 29 | shape_scale = np.max( 30 | [ 31 | np.max(pointcloud[:, 0]) - np.min(pointcloud[:, 0]), 32 | np.max(pointcloud[:, 1]) - np.min(pointcloud[:, 1]), 33 | np.max(pointcloud[:, 2]) - np.min(pointcloud[:, 2]), 34 | ] 35 | ) 36 | shape_center = [ 37 | (np.max(pointcloud[:, 0]) + np.min(pointcloud[:, 0])) / 2, 38 | (np.max(pointcloud[:, 1]) + np.min(pointcloud[:, 1])) / 2, 39 | (np.max(pointcloud[:, 2]) + np.min(pointcloud[:, 2])) / 2, 40 | ] 41 | pointcloud = pointcloud - shape_center 42 | pointcloud = pointcloud / shape_scale 43 | 44 | POINT_NUM = pointcloud.shape[0] // 60 45 | POINT_NUM_GT = pointcloud.shape[0] // 60 * 60 46 | QUERY_EACH = 1000000 // POINT_NUM_GT 47 | 48 | point_idx = np.random.choice(pointcloud.shape[0], POINT_NUM_GT, replace=False) 49 | pointcloud = pointcloud[point_idx, :] 50 | # print(np.max(pointcloud[:,0]),np.max(pointcloud[:,1]),np.max(pointcloud[:,2]),np.min(pointcloud[:,0]),np.min(pointcloud[:,1]),np.min(pointcloud[:,2])) 51 | ptree = cKDTree(pointcloud) 52 | sigmas = [] 53 | for p in np.array_split(pointcloud, 100, axis=0): 54 | d = ptree.query(p, 51) 55 | sigmas.append(d[0][:, -1]) 56 | 57 | sigmas = np.concatenate(sigmas) 58 | sample = [] 59 | sample_near = [] 60 | 61 | for i in range(QUERY_EACH): 62 | scale = 0.25 * np.sqrt(POINT_NUM_GT / 20000) 63 | tt = pointcloud + scale * np.expand_dims(sigmas, -1) * np.random.normal( 64 | 0.0, 1.0, size=pointcloud.shape 65 | ) 66 | sample.append(tt) 67 | tt = tt.reshape(-1, POINT_NUM, 3) 68 | 69 | sample_near_tmp = [] 70 | for j in range(tt.shape[0]): 71 | nearest_idx = search_nearest_point( 72 | torch.tensor(tt[j]).float().cuda(), 73 | torch.tensor(pointcloud).float().cuda(), 74 | ) 75 | nearest_points = pointcloud[nearest_idx] 76 | nearest_points = np.asarray(nearest_points).reshape(-1, 3) 77 | sample_near_tmp.append(nearest_points) 78 | sample_near_tmp = np.asarray(sample_near_tmp) 79 | sample_near_tmp = sample_near_tmp.reshape(-1, 3) 80 | sample_near.append(sample_near_tmp) 81 | 82 | # sample points uniformly in the unit box 83 | for i in range(QUERY_EACH // 8): 84 | tt = np.random.rand(POINT_NUM_GT, 3) * 2 - 1 85 | sample.append(tt) 86 | tt = tt.reshape(-1, POINT_NUM, 3) 87 | 88 | sample_near_tmp = [] 89 | for j in range(tt.shape[0]): 90 | nearest_idx = search_nearest_point( 91 | torch.tensor(tt[j]).float().cuda(), 92 | torch.tensor(pointcloud).float().cuda(), 93 | ) 94 | nearest_points = pointcloud[nearest_idx] 95 | nearest_points = np.asarray(nearest_points).reshape(-1, 3) 96 | sample_near_tmp.append(nearest_points) 97 | sample_near_tmp = np.asarray(sample_near_tmp) 98 | sample_near_tmp = sample_near_tmp.reshape(-1, 3) 99 | sample_near.append(sample_near_tmp) 100 | 101 | sample = np.asarray(sample) 102 | sample_near = np.asarray(sample_near) 103 | 104 | np.savez( 105 | os.path.join(data_dir, dataname) + ".npz", 106 | sample=sample, 107 | point=pointcloud, 108 | sample_near=sample_near, 109 | ) 110 | 111 | 112 | class DatasetNP: 113 | def __init__(self, conf, dataname): 114 | super(DatasetNP, self).__init__() 115 | self.device = torch.device("cuda") 116 | self.conf = conf 117 | 118 | self.data_dir = conf.get_string("data_dir") 119 | self.np_data_name = dataname + ".npz" 120 | 121 | if os.path.exists(os.path.join(self.data_dir, self.np_data_name)): 122 | print("Data existing. Loading data...") 123 | else: 124 | print("Data not found. 
Processing data...") 125 | process_data(self.data_dir, dataname) 126 | load_data = np.load(os.path.join(self.data_dir, self.np_data_name)) 127 | 128 | self.point = np.asarray(load_data["sample_near"]).reshape(-1, 3) 129 | self.sample = np.asarray(load_data["sample"]).reshape(-1, 3) 130 | self.point_gt = np.asarray(load_data["point"]).reshape(-1, 3) 131 | self.sample_points_num = self.sample.shape[0] - 1 132 | 133 | self.object_bbox_min = ( 134 | np.array( 135 | [ 136 | np.min(self.point[:, 0]), 137 | np.min(self.point[:, 1]), 138 | np.min(self.point[:, 2]), 139 | ] 140 | ) 141 | - 0.05 142 | ) 143 | self.object_bbox_max = ( 144 | np.array( 145 | [ 146 | np.max(self.point[:, 0]), 147 | np.max(self.point[:, 1]), 148 | np.max(self.point[:, 2]), 149 | ] 150 | ) 151 | + 0.05 152 | ) 153 | print("Data bounding box:", self.object_bbox_min, self.object_bbox_max) 154 | 155 | self.point = torch.from_numpy(self.point).to(self.device).float() 156 | self.sample = torch.from_numpy(self.sample).to(self.device).float() 157 | self.point_gt = torch.from_numpy(self.point_gt).to(self.device).float() 158 | 159 | print("NP Load data: End") 160 | 161 | def np_train_data(self, batch_size): 162 | index_coarse = np.random.choice(10, 1) 163 | index_fine = np.random.choice( 164 | self.sample_points_num // 10, batch_size, replace=False 165 | ) 166 | index = index_fine * 10 + index_coarse 167 | points = self.point[index] 168 | sample = self.sample[index] 169 | return points, sample, self.point_gt 170 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import trimesh 6 | 7 | 8 | class Triplane(nn.Module): 9 | def __init__( 10 | self, 11 | reso=256, 12 | channel=32, 13 | init_type="geo_init", 14 | ): 15 | super().__init__() 16 | if init_type == "geo_init": 17 | sdf_proxy = nn.Sequential( 18 | nn.Linear(3, channel), 19 | nn.Softplus(beta=100), 20 | nn.Linear(channel, channel), 21 | ) 22 | torch.nn.init.constant_(sdf_proxy[0].bias, 0.0) 23 | torch.nn.init.normal_( 24 | sdf_proxy[0].weight, 0.0, np.sqrt(2) / np.sqrt(channel) 25 | ) 26 | torch.nn.init.constant_(sdf_proxy[2].bias, 0.0) 27 | torch.nn.init.normal_( 28 | sdf_proxy[2].weight, 0.0, np.sqrt(2) / np.sqrt(channel) 29 | ) 30 | 31 | ini_sdf = torch.zeros([3, channel, reso, reso]) 32 | X = torch.linspace(-1.0, 1.0, reso) 33 | (U, V) = torch.meshgrid(X, X, indexing="ij") 34 | Z = torch.zeros(reso, reso) 35 | inputx = torch.stack([Z, U, V], -1).reshape(-1, 3) 36 | inputy = torch.stack([U, Z, V], -1).reshape(-1, 3) 37 | inputz = torch.stack([U, V, Z], -1).reshape(-1, 3) 38 | ini_sdf[0] = sdf_proxy(inputx).permute(1, 0).reshape(channel, reso, reso) 39 | ini_sdf[1] = sdf_proxy(inputy).permute(1, 0).reshape(channel, reso, reso) 40 | ini_sdf[2] = sdf_proxy(inputz).permute(1, 0).reshape(channel, reso, reso) 41 | self.triplane = torch.nn.Parameter(ini_sdf / 3, requires_grad=True) 42 | elif init_type == "rand_init": 43 | self.triplane = torch.nn.Parameter( 44 | torch.randn([3, channel, reso, reso]) * 0.001, requires_grad=True 45 | ) 46 | else: 47 | raise ValueError("Unknown init_type") 48 | 49 | self.R = reso 50 | self.C = channel 51 | self.register_buffer( 52 | "plane_axes", 53 | torch.tensor( 54 | [ 55 | [[0, 1, 0], [1, 0, 0], [0, 0, 1]], 56 | [[0, 0, 1], [1, 0, 0], [0, 1, 0]], 57 | [[0, 1, 0], [0, 0, 1], [1, 0, 0]], 58 | ], 59 | 
dtype=torch.float32, 60 | ), 61 | ) 62 | 63 | def project_onto_planes(self, xyz): 64 | M, _ = xyz.shape 65 | xyz = xyz.unsqueeze(0).expand(3, -1, -1).reshape(3, M, 3) 66 | inv_planes = torch.linalg.inv(self.plane_axes).reshape(3, 3, 3) 67 | projections = torch.bmm(xyz, inv_planes) 68 | return projections[..., :2] # [3, M, 2] 69 | 70 | def forward(self, xyz): 71 | # pts: [M,3] 72 | M, _ = xyz.shape 73 | projected_coordinates = self.project_onto_planes(xyz).unsqueeze(1) 74 | feats = F.grid_sample( 75 | self.triplane, # [3,C,R,R] 76 | projected_coordinates.float(), # [3,1,M,2] 77 | mode="bilinear", 78 | padding_mode="zeros", 79 | align_corners=True, 80 | ) # [3,C,1,M] 81 | feats = feats.permute(0, 3, 2, 1).reshape(3, M, self.C).sum(0) 82 | return feats # [M,C] 83 | 84 | def update_resolution(self, new_reso): 85 | new_tri = F.interpolate( 86 | self.triplane.data, 87 | size=(new_reso, new_reso), 88 | mode="bilinear", 89 | align_corners=True, 90 | ) 91 | self.R = new_reso 92 | self.triplane = torch.nn.Parameter(new_tri, requires_grad=True) 93 | 94 | 95 | class NGPullNetwork(nn.Module): 96 | def __init__( 97 | self, 98 | d_in, 99 | d_out, 100 | d_hidden, 101 | n_layers, 102 | skip_in=(4,), 103 | multires=0, 104 | bias=0.5, 105 | scale=1, 106 | geometric_init=True, 107 | weight_norm=True, 108 | inside_outside=False, 109 | ): 110 | super(NGPullNetwork, self).__init__() 111 | 112 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 113 | 114 | self.embed_fn_fine = None 115 | 116 | self.num_layers = len(dims) 117 | self.skip_in = skip_in 118 | self.scale = scale 119 | 120 | for l in range(0, self.num_layers - 1): 121 | if l + 1 in self.skip_in: 122 | out_dim = dims[l + 1] - dims[0] 123 | else: 124 | out_dim = dims[l + 1] 125 | 126 | lin = nn.Linear(dims[l], out_dim) 127 | 128 | if geometric_init: 129 | if l == self.num_layers - 2: 130 | if not inside_outside: 131 | torch.nn.init.normal_( 132 | lin.weight, 133 | mean=np.sqrt(np.pi) / np.sqrt(dims[l]), 134 | std=0.0001, 135 | ) 136 | torch.nn.init.constant_(lin.bias, -bias) 137 | else: 138 | torch.nn.init.normal_( 139 | lin.weight, 140 | mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), 141 | std=0.0001, 142 | ) 143 | torch.nn.init.constant_(lin.bias, bias) 144 | elif multires > 0 and l == 0: 145 | torch.nn.init.constant_(lin.bias, 0.0) 146 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 147 | torch.nn.init.normal_( 148 | lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim) 149 | ) 150 | elif multires > 0 and l in self.skip_in: 151 | torch.nn.init.constant_(lin.bias, 0.0) 152 | torch.nn.init.normal_( 153 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim) 154 | ) 155 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) 156 | else: 157 | torch.nn.init.constant_(lin.bias, 0.0) 158 | torch.nn.init.normal_( 159 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim) 160 | ) 161 | 162 | if weight_norm: 163 | lin = nn.utils.weight_norm(lin) 164 | setattr(self, "lin" + str(l), lin) 165 | 166 | self.activation = nn.ReLU() 167 | 168 | def forward(self, inputs): 169 | inputs = inputs * self.scale 170 | if self.embed_fn_fine is not None: 171 | inputs = self.embed_fn_fine(inputs) 172 | 173 | x = inputs 174 | for l in range(0, self.num_layers - 1): 175 | lin = getattr(self, "lin" + str(l)) 176 | if l in self.skip_in: 177 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 178 | 179 | x = lin(x) 180 | if l < self.num_layers - 2: 181 | x = self.activation(x) 182 | 183 | return x / self.scale 184 | 185 | def sdf(self, x): 186 | return self.forward(x) 187 | 188 | 189 | 
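# Illustrative usage sketch (added for clarity; not part of the original module).
# A batch of query points is lifted to tri-plane features by Triplane, and the
# NGPullNetwork MLP maps those features to a scalar signed distance per point.
if __name__ == "__main__":
    _triplane = Triplane(reso=48, channel=32, init_type="geo_init")
    _sdf_net = NGPullNetwork(
        d_in=32, d_out=1, d_hidden=128, n_layers=3, skip_in=(),
        multires=0, bias=0.5, scale=1.0, geometric_init=True, weight_norm=True,
    )
    _pts = torch.rand(8, 3) * 2 - 1      # query points in [-1, 1]^3
    _feats = _triplane(_pts)             # [8, 32] summed per-plane features
    _sdf = _sdf_net(_feats)              # [8, 1] predicted signed distances
    print(_feats.shape, _sdf.shape)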
def as_mesh(scene_or_mesh): 190 | """ 191 | Convert a possible scene to a mesh. 192 | 193 | If conversion occurs, the returned mesh has only vertex and face data. 194 | Suggested by https://github.com/mikedh/trimesh/issues/507 195 | """ 196 | if isinstance(scene_or_mesh, trimesh.Scene): 197 | if len(scene_or_mesh.geometry) == 0: 198 | mesh = None # empty scene 199 | else: 200 | # we lose texture information here 201 | mesh = trimesh.util.concatenate( 202 | tuple( 203 | trimesh.Trimesh(vertices=g.vertices, faces=g.faces) 204 | for g in scene_or_mesh.geometry.values() 205 | ) 206 | ) 207 | else: 208 | print("is_mesh") 209 | assert isinstance(scene_or_mesh, trimesh.Trimesh) 210 | mesh = scene_or_mesh 211 | return mesh 212 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import time 5 | import math 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import trimesh 11 | import mcubes 12 | from shutil import copyfile 13 | from tqdm import tqdm 14 | from pyhocon import ConfigFactory 15 | from models.utils import get_root_logger, print_log 16 | from models.dataset import DatasetNP 17 | from models.fields import NGPullNetwork, Triplane 18 | 19 | import warnings 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | def create_optimizer(net, triplane, lr_net=1e-3, lr_tri=1e-2): 25 | params_to_train = [] 26 | if net is not None: 27 | params_to_train += [{"name": "net", "params": net.parameters(), "lr": lr_net}] 28 | if triplane is not None: 29 | params_to_train += [ 30 | {"name": "tri", "params": triplane.parameters(), "lr": lr_tri} 31 | ] 32 | return torch.optim.Adam(params_to_train) 33 | 34 | 35 | class Runner: 36 | def __init__(self, args, conf_path, mode="train"): 37 | self.device = torch.device("cuda") 38 | 39 | # Configuration 40 | self.conf_path = conf_path 41 | f = open(self.conf_path) 42 | conf_text = f.read() 43 | f.close() 44 | 45 | self.conf = ConfigFactory.parse_string(conf_text) 46 | self.conf["dataset.np_data_name"] = self.conf["dataset.np_data_name"] 47 | self.base_exp_dir = self.conf["general.base_exp_dir"] + args.dir 48 | os.makedirs(self.base_exp_dir, exist_ok=True) 49 | 50 | self.dataset_np = DatasetNP(self.conf["dataset"], args.dataname) 51 | self.dataname = args.dataname 52 | self.iter_step = 0 53 | 54 | # Training parameters 55 | self.maxiter = self.conf.get_int("train.maxiter") 56 | self.save_freq = self.conf.get_int("train.save_freq") 57 | self.report_freq = self.conf.get_int("train.report_freq") 58 | self.val_freq = self.conf.get_int("train.val_freq") 59 | self.batch_size = self.conf.get_int("train.batch_size") 60 | self.learning_rate = self.conf.get_float("train.learning_rate") 61 | self.warm_up_end = self.conf.get_float("train.warm_up_end", default=0.0) 62 | self.eval_num_points = self.conf.get_int("train.eval_num_points") 63 | 64 | self.lr_net = self.conf.get_float("train.lr_net") 65 | self.lr_tri = self.conf.get_float("train.lr_tri") 66 | 67 | self.eps = self.conf.get_float("train.grad_eps") 68 | self.resolution = self.conf.get_int("train.resolution") 69 | self.c2f_scale = self.conf.get_list("train.c2f_scale") 70 | 71 | self.mode = mode 72 | 73 | # Networks 74 | self.sdf_network = NGPullNetwork(**self.conf["model.sdf_network"]).to( 75 | self.device 76 | ) 77 | self.triplane = Triplane( 78 | reso=self.resolution // (2 ** len(self.c2f_scale)), 79 | 
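            # (clarifying comment) coarse-to-fine schedule: the tri-plane starts at
            # resolution // 2**len(c2f_scale) (48 // 8 = 6 with confs/ngpull.conf) and is
            # doubled by Triplane.update_resolution at each iteration listed in c2f_scale
            # (3000, 8000, 12000), reaching the full 48x48 planes for the rest of training.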
channel=self.conf["model.sdf_network.d_in"], 80 | init_type=self.conf.get_string("model.triplane.init_type"), 81 | ).to(self.device) 82 | self.optimizer = create_optimizer( 83 | self.sdf_network, self.triplane, self.lr_net, self.lr_tri 84 | ) 85 | 86 | # Backup codes and configs for debug 87 | if self.mode[:5] == "train": 88 | self.file_backup() 89 | 90 | def train(self): 91 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 92 | log_file = os.path.join(os.path.join(self.base_exp_dir), f"{timestamp}.log") 93 | logger = get_root_logger(log_file=log_file, name="outs") 94 | self.logger = logger 95 | batch_size = self.batch_size 96 | 97 | res_step = self.maxiter - self.iter_step 98 | 99 | eps_tensor = ( 100 | torch.cat( 101 | [ 102 | torch.as_tensor([[self.eps, 0.0, 0.0]]), 103 | torch.as_tensor([[-self.eps, 0.0, 0.0]]), 104 | torch.as_tensor([[0.0, self.eps, 0.0]]), 105 | torch.as_tensor([[0.0, -self.eps, 0.0]]), 106 | torch.as_tensor([[0.0, 0.0, self.eps]]), 107 | torch.as_tensor([[0.0, 0.0, -self.eps]]), 108 | ] 109 | ) 110 | .unsqueeze(1) 111 | .to(self.device) 112 | ) 113 | 114 | for iter_i in tqdm(range(res_step)): 115 | if iter_i in self.c2f_scale: 116 | new_reso = int( 117 | self.resolution 118 | / (2 ** (len(self.c2f_scale) - self.c2f_scale.index(iter_i) - 1)) 119 | ) 120 | print_log("Update resolution to {}".format(new_reso), logger=logger) 121 | self.triplane.update_resolution(new_reso) 122 | self.optimizer = create_optimizer(self.sdf_network, self.triplane) 123 | torch.cuda.empty_cache() 124 | 125 | self.update_learning_rate_np(iter_i) 126 | 127 | points, samples, point_gt = self.dataset_np.np_train_data(batch_size) 128 | samples_all = torch.cat( 129 | [(samples.unsqueeze(0) + eps_tensor).reshape(-1, 3), samples], dim=0 130 | ) 131 | sdfs_all = self.sdf_network(self.triplane(samples_all)).reshape( 132 | 7, -1, 1 133 | ) # [7, N, 1] 134 | gradients_sample = torch.cat( 135 | [ 136 | 0.5 * (sdfs_all[0, :] - sdfs_all[1, :]) / self.eps, 137 | 0.5 * (sdfs_all[2, :] - sdfs_all[3, :]) / self.eps, 138 | 0.5 * (sdfs_all[4, :] - sdfs_all[5, :]) / self.eps, 139 | ], 140 | dim=-1, 141 | ) 142 | 143 | sdf_sample = sdfs_all[-1, :] # 5000x1 144 | grad_norm = F.normalize(gradients_sample, dim=1) # 5000x3 145 | sample_moved = samples - grad_norm * sdf_sample # 5000x3 146 | 147 | loss = torch.linalg.norm((points - sample_moved), ord=2, dim=-1).mean() 148 | 149 | self.optimizer.zero_grad() 150 | loss.backward() 151 | self.optimizer.step() 152 | 153 | self.iter_step += 1 154 | if self.iter_step % self.report_freq == 0: 155 | print_log( 156 | "iter:{:8>d} cd_l1 = {} lr={}".format( 157 | self.iter_step, loss, self.optimizer.param_groups[0]["lr"] 158 | ), 159 | logger=logger, 160 | ) 161 | 162 | if self.iter_step % self.val_freq == 0 and self.iter_step != 0: 163 | self.validate_mesh( 164 | resolution=256, 165 | threshold=args.mcubes_threshold, 166 | point_gt=point_gt, 167 | iter_step=self.iter_step, 168 | logger=logger, 169 | ) 170 | 171 | if self.iter_step % self.save_freq == 0 and self.iter_step != 0: 172 | self.save_checkpoint() 173 | 174 | def validate_mesh( 175 | self, resolution=64, threshold=0.0, point_gt=None, iter_step=0, logger=None 176 | ): 177 | bound_min = torch.tensor(self.dataset_np.object_bbox_min, dtype=torch.float32) 178 | bound_max = torch.tensor(self.dataset_np.object_bbox_max, dtype=torch.float32) 179 | os.makedirs(os.path.join(self.base_exp_dir, "outputs"), exist_ok=True) 180 | mesh = self.extract_geometry( 181 | bound_min, 182 | bound_max, 183 | resolution=resolution, 
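            # (clarifying comment) query_func below negates the network SDF before marching
            # cubes; extract_geometry evaluates it on a resolution^3 grid (256^3 during
            # validation) in chunks of 32 per axis and rescales the resulting vertices from
            # grid indices back into the data bounding box.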
184 | threshold=threshold, 185 | query_func=lambda pts: -self.sdf_network.sdf(self.triplane(pts)), 186 | ) 187 | 188 | mesh.export( 189 | os.path.join( 190 | self.base_exp_dir, 191 | "outputs", 192 | "{:0>8d}_{}.ply".format(self.iter_step, str(threshold)), 193 | ) 194 | ) 195 | 196 | def update_learning_rate_np(self, iter_step): 197 | warn_up = self.warm_up_end 198 | max_iter = self.maxiter 199 | init_lr = self.learning_rate 200 | lr = ( 201 | (iter_step / warn_up) 202 | if iter_step < warn_up 203 | else 0.5 204 | * (math.cos((iter_step - warn_up) / (max_iter - warn_up) * math.pi) + 1) 205 | ) 206 | lr = lr * init_lr 207 | for g in self.optimizer.param_groups: 208 | g["lr"] = lr 209 | 210 | def extract_fields(self, bound_min, bound_max, resolution, query_func): 211 | N = 32 212 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 213 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 214 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 215 | 216 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 217 | with torch.no_grad(): 218 | for xi, xs in enumerate(X): 219 | for yi, ys in enumerate(Y): 220 | for zi, zs in enumerate(Z): 221 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 222 | pts = torch.cat( 223 | [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], 224 | dim=-1, 225 | ) 226 | val = ( 227 | query_func(pts) 228 | .reshape(len(xs), len(ys), len(zs)) 229 | .detach() 230 | .cpu() 231 | .numpy() 232 | ) 233 | u[ 234 | xi * N : xi * N + len(xs), 235 | yi * N : yi * N + len(ys), 236 | zi * N : zi * N + len(zs), 237 | ] = val 238 | return u 239 | 240 | def extract_geometry(self, bound_min, bound_max, resolution, threshold, query_func): 241 | print("Creating mesh with threshold: {}".format(threshold)) 242 | u = self.extract_fields(bound_min, bound_max, resolution, query_func) 243 | vertices, triangles = mcubes.marching_cubes(u, threshold) 244 | b_max_np = bound_max.detach().cpu().numpy() 245 | b_min_np = bound_min.detach().cpu().numpy() 246 | 247 | vertices = ( 248 | vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] 249 | + b_min_np[None, :] 250 | ) 251 | mesh = trimesh.Trimesh(vertices, triangles) 252 | 253 | return mesh 254 | 255 | def file_backup(self): 256 | dir_lis = self.conf["general.recording"] 257 | os.makedirs(os.path.join(self.base_exp_dir, "recording"), exist_ok=True) 258 | for dir_name in dir_lis: 259 | cur_dir = os.path.join(self.base_exp_dir, "recording", dir_name) 260 | os.makedirs(cur_dir, exist_ok=True) 261 | files = os.listdir(dir_name) 262 | for f_name in files: 263 | if f_name[-3:] == ".py": 264 | copyfile( 265 | os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name) 266 | ) 267 | 268 | copyfile( 269 | self.conf_path, os.path.join(self.base_exp_dir, "recording", "config.conf") 270 | ) 271 | 272 | def load_checkpoint(self, checkpoint_name): 273 | checkpoint = torch.load( 274 | os.path.join(self.base_exp_dir, "checkpoints", checkpoint_name), 275 | map_location=self.device, 276 | ) 277 | print(os.path.join(self.base_exp_dir, "checkpoints", checkpoint_name)) 278 | self.sdf_network.load_state_dict(checkpoint["sdf_network_fine"]) 279 | self.triplane.load_state_dict(checkpoint["triplane"]) 280 | 281 | self.iter_step = checkpoint["iter_step"] 282 | 283 | def save_checkpoint(self): 284 | checkpoint = { 285 | "sdf_network_fine": self.sdf_network.state_dict(), 286 | "triplane": self.triplane.state_dict(), 287 | "iter_step": self.iter_step, 288 | } 289 | 
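        # (clarifying comment) checkpoints are written every save_freq iterations (5000 in
        # confs/ngpull.conf) to <base_exp_dir>/checkpoints/ckpt_XXXXXX.pth; load_checkpoint
        # above restores the SDF network, the tri-plane parameters, and the iteration counter.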
os.makedirs(os.path.join(self.base_exp_dir, "checkpoints"), exist_ok=True) 290 | torch.save( 291 | checkpoint, 292 | os.path.join( 293 | self.base_exp_dir, 294 | "checkpoints", 295 | "ckpt_{:0>6d}.pth".format(self.iter_step), 296 | ), 297 | ) 298 | 299 | 300 | if __name__ == "__main__": 301 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 302 | parser = argparse.ArgumentParser() 303 | parser.add_argument("--conf", type=str, default="./confs/np_srb.conf") 304 | parser.add_argument("--mode", type=str, default="train") 305 | parser.add_argument("--mcubes_threshold", type=float, default=0.0) 306 | parser.add_argument("--gpu", type=int, default=0) 307 | parser.add_argument("--dir", type=str, default="gargoyle") 308 | parser.add_argument("--dataname", type=str, default="gargoyle") 309 | args = parser.parse_args() 310 | 311 | torch.cuda.set_device(args.gpu) 312 | runner = Runner(args, args.conf, args.mode) 313 | 314 | if args.mode == "train": 315 | runner.train() 316 | elif args.mode == "validate_mesh": 317 | threshs = [ 318 | -0.001, 319 | -0.0025, 320 | -0.005, 321 | -0.01, 322 | -0.02, 323 | 0.0, 324 | 0.001, 325 | 0.0025, 326 | 0.005, 327 | 0.01, 328 | 0.02, 329 | ] 330 | for thresh in threshs: 331 | runner.validate_mesh(resolution=256, threshold=thresh) 332 | --------------------------------------------------------------------------------
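
The core of the training loop in `run.py` above is the numerical-gradient pull: the SDF of each query sample is differentiated by central differences (six offset evaluations with step `grad_eps`), the sample is moved along the normalized gradient by its predicted signed distance, and the result is compared against the sample's precomputed nearest surface point. The standalone sketch below condenses that step; the function name and the toy sphere SDF are illustrative only and do not appear in the repository.

```python
import torch
import torch.nn.functional as F


def numerical_gradient_pull_loss(sdf_func, samples, nearest_points, eps=1e-2):
    """Condensed version of the pull step in Runner.train() (for illustration only)."""
    # six axis-aligned offsets for central differences
    offsets = torch.tensor(
        [[eps, 0, 0], [-eps, 0, 0], [0, eps, 0], [0, -eps, 0], [0, 0, eps], [0, 0, -eps]],
        dtype=samples.dtype,
        device=samples.device,
    )
    # evaluate the SDF at the six offset copies plus the samples themselves: [7, N, 1]
    all_pts = torch.cat(
        [(samples.unsqueeze(0) + offsets.unsqueeze(1)).reshape(-1, 3), samples], dim=0
    )
    sdfs = sdf_func(all_pts).reshape(7, -1, 1)
    # central differences along x, y, z give a numerical gradient: [N, 3]
    grad = torch.cat(
        [
            0.5 * (sdfs[0] - sdfs[1]) / eps,
            0.5 * (sdfs[2] - sdfs[3]) / eps,
            0.5 * (sdfs[4] - sdfs[5]) / eps,
        ],
        dim=-1,
    )
    # pull each sample onto the zero level set and compare with its nearest surface point
    moved = samples - F.normalize(grad, dim=1) * sdfs[-1]
    return torch.linalg.norm(nearest_points - moved, ord=2, dim=-1).mean()


if __name__ == "__main__":
    # sanity check with the exact SDF of the unit sphere: pulled points should land
    # (almost) on the sphere, so the loss against the true projections is near zero
    dirs = F.normalize(torch.randn(128, 3), dim=-1)
    pts = dirs * (1.0 + torch.rand(128, 1))   # points at radius 1..2 around the unit sphere

    def sphere_sdf(p):
        # exact signed distance to the unit sphere
        return p.norm(dim=-1, keepdim=True) - 1.0

    print(numerical_gradient_pull_loss(sphere_sdf, pts, dirs).item())  # close to zero
```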