├── .gitignore ├── LICENSE ├── README.md ├── apps ├── __init__.py ├── eval.py └── train.py ├── environment.yaml └── lib ├── __init__.py ├── data ├── BOP_BP_YCBV.py └── __init__.py ├── data_utils ├── __init__.py ├── aug_util.py └── sample_frustum_util.py ├── debug_pyrender_util.py ├── eval_Rt_time_util.py ├── geometry.py ├── loss_util.py ├── mesh_util.py ├── model ├── BasePIFuNet.py ├── HGFilters.py ├── HGPIFuNet.py ├── RayDistanceNormalizer.py ├── SurfaceClassifier.py └── __init__.py ├── net_util.py ├── options.py ├── rigid_fit ├── calculate_rmsd.py ├── ransac.py └── ransac_kabsch.py ├── sdf.py └── sym_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International 2 | 3 | Creative Commons Corporation ("Creative Commons") is not a law firm and 4 | does not provide legal services or legal advice. 
Distribution of 5 | Creative Commons public licenses does not create a lawyer-client or 6 | other relationship. Creative Commons makes its licenses and related 7 | information available on an "as-is" basis. Creative Commons gives no 8 | warranties regarding its licenses, any material licensed under their 9 | terms and conditions, or any related information. Creative Commons 10 | disclaims all liability for damages resulting from their use to the 11 | fullest extent possible. 12 | 13 | Using Creative Commons Public Licenses 14 | 15 | Creative Commons public licenses provide a standard set of terms and 16 | conditions that creators and other rights holders may use to share 17 | original works of authorship and other material subject to copyright and 18 | certain other rights specified in the public license below. The 19 | following considerations are for informational purposes only, are not 20 | exhaustive, and do not form part of our licenses. 21 | 22 | - Considerations for licensors: Our public licenses are intended for 23 | use by those authorized to give the public permission to use 24 | material in ways otherwise restricted by copyright and certain other 25 | rights. Our licenses are irrevocable. Licensors should read and 26 | understand the terms and conditions of the license they choose 27 | before applying it. Licensors should also secure all rights 28 | necessary before applying our licenses so that the public can reuse 29 | the material as expected. Licensors should clearly mark any material 30 | not subject to the license. This includes other CC-licensed 31 | material, or material used under an exception or limitation to 32 | copyright. More considerations for licensors : 33 | wiki.creativecommons.org/Considerations_for_licensors 34 | 35 | - Considerations for the public: By using one of our public licenses, 36 | a licensor grants the public permission to use the licensed material 37 | under specified terms and conditions. If the licensor's permission 38 | is not necessary for any reason–for example, because of any 39 | applicable exception or limitation to copyright–then that use is not 40 | regulated by the license. Our licenses grant only permissions under 41 | copyright and certain other rights that a licensor has authority to 42 | grant. Use of the licensed material may still be restricted for 43 | other reasons, including because others have copyright or other 44 | rights in the material. A licensor may make special requests, such 45 | as asking that all changes be marked or described. Although not 46 | required by our licenses, you are encouraged to respect those 47 | requests where reasonable. More considerations for the public : 48 | wiki.creativecommons.org/Considerations_for_licensees 49 | 50 | Creative Commons Attribution-NonCommercial 4.0 International Public 51 | License 52 | 53 | By exercising the Licensed Rights (defined below), You accept and agree 54 | to be bound by the terms and conditions of this Creative Commons 55 | Attribution-NonCommercial 4.0 International Public License ("Public 56 | License"). To the extent this Public License may be interpreted as a 57 | contract, You are granted the Licensed Rights in consideration of Your 58 | acceptance of these terms and conditions, and the Licensor grants You 59 | such rights in consideration of benefits the Licensor receives from 60 | making the Licensed Material available under these terms and conditions. 61 | 62 | - Section 1 – Definitions. 63 | 64 | - a. 
Adapted Material means material subject to Copyright and 65 | Similar Rights that is derived from or based upon the Licensed 66 | Material and in which the Licensed Material is translated, 67 | altered, arranged, transformed, or otherwise modified in a 68 | manner requiring permission under the Copyright and Similar 69 | Rights held by the Licensor. For purposes of this Public 70 | License, where the Licensed Material is a musical work, 71 | performance, or sound recording, Adapted Material is always 72 | produced where the Licensed Material is synched in timed 73 | relation with a moving image. 74 | - b. Adapter's License means the license You apply to Your 75 | Copyright and Similar Rights in Your contributions to Adapted 76 | Material in accordance with the terms and conditions of this 77 | Public License. 78 | - c. Copyright and Similar Rights means copyright and/or similar 79 | rights closely related to copyright including, without 80 | limitation, performance, broadcast, sound recording, and Sui 81 | Generis Database Rights, without regard to how the rights are 82 | labeled or categorized. For purposes of this Public License, the 83 | rights specified in Section 2(b)(1)-(2) are not Copyright and 84 | Similar Rights. 85 | - d. Effective Technological Measures means those measures that, 86 | in the absence of proper authority, may not be circumvented 87 | under laws fulfilling obligations under Article 11 of the WIPO 88 | Copyright Treaty adopted on December 20, 1996, and/or similar 89 | international agreements. 90 | - e. Exceptions and Limitations means fair use, fair dealing, 91 | and/or any other exception or limitation to Copyright and 92 | Similar Rights that applies to Your use of the Licensed 93 | Material. 94 | - f. Licensed Material means the artistic or literary work, 95 | database, or other material to which the Licensor applied this 96 | Public License. 97 | - g. Licensed Rights means the rights granted to You subject to 98 | the terms and conditions of this Public License, which are 99 | limited to all Copyright and Similar Rights that apply to Your 100 | use of the Licensed Material and that the Licensor has authority 101 | to license. 102 | - h. Licensor means the individual(s) or entity(ies) granting 103 | rights under this Public License. 104 | - i. NonCommercial means not primarily intended for or directed 105 | towards commercial advantage or monetary compensation. For 106 | purposes of this Public License, the exchange of the Licensed 107 | Material for other material subject to Copyright and Similar 108 | Rights by digital file-sharing or similar means is NonCommercial 109 | provided there is no payment of monetary compensation in 110 | connection with the exchange. 111 | - j. Share means to provide material to the public by any means or 112 | process that requires permission under the Licensed Rights, such 113 | as reproduction, public display, public performance, 114 | distribution, dissemination, communication, or importation, and 115 | to make material available to the public including in ways that 116 | members of the public may access the material from a place and 117 | at a time individually chosen by them. 118 | - k. Sui Generis Database Rights means rights other than copyright 119 | resulting from Directive 96/9/EC of the European Parliament and 120 | of the Council of 11 March 1996 on the legal protection of 121 | databases, as amended and/or succeeded, as well as other 122 | essentially equivalent rights anywhere in the world. 123 | - l. 
You means the individual or entity exercising the Licensed 124 | Rights under this Public License. Your has a corresponding 125 | meaning. 126 | 127 | - Section 2 – Scope. 128 | 129 | - a. License grant. 130 | - 1. Subject to the terms and conditions of this Public 131 | License, the Licensor hereby grants You a worldwide, 132 | royalty-free, non-sublicensable, non-exclusive, irrevocable 133 | license to exercise the Licensed Rights in the Licensed 134 | Material to: 135 | - A. reproduce and Share the Licensed Material, in whole 136 | or in part, for NonCommercial purposes only; and 137 | - B. produce, reproduce, and Share Adapted Material for 138 | NonCommercial purposes only. 139 | - 2. Exceptions and Limitations. For the avoidance of doubt, 140 | where Exceptions and Limitations apply to Your use, this 141 | Public License does not apply, and You do not need to comply 142 | with its terms and conditions. 143 | - 3. Term. The term of this Public License is specified in 144 | Section 6(a). 145 | - 4. Media and formats; technical modifications allowed. The 146 | Licensor authorizes You to exercise the Licensed Rights in 147 | all media and formats whether now known or hereafter 148 | created, and to make technical modifications necessary to do 149 | so. The Licensor waives and/or agrees not to assert any 150 | right or authority to forbid You from making technical 151 | modifications necessary to exercise the Licensed Rights, 152 | including technical modifications necessary to circumvent 153 | Effective Technological Measures. For purposes of this 154 | Public License, simply making modifications authorized by 155 | this Section 2(a)(4) never produces Adapted Material. 156 | - 5. Downstream recipients. 157 | - A. Offer from the Licensor – Licensed Material. Every 158 | recipient of the Licensed Material automatically 159 | receives an offer from the Licensor to exercise the 160 | Licensed Rights under the terms and conditions of this 161 | Public License. 162 | - B. No downstream restrictions. You may not offer or 163 | impose any additional or different terms or conditions 164 | on, or apply any Effective Technological Measures to, 165 | the Licensed Material if doing so restricts exercise of 166 | the Licensed Rights by any recipient of the Licensed 167 | Material. 168 | - 6. No endorsement. Nothing in this Public License 169 | constitutes or may be construed as permission to assert or 170 | imply that You are, or that Your use of the Licensed 171 | Material is, connected with, or sponsored, endorsed, or 172 | granted official status by, the Licensor or others 173 | designated to receive attribution as provided in Section 174 | 3(a)(1)(A)(i). 175 | - b. Other rights. 176 | - 1. Moral rights, such as the right of integrity, are not 177 | licensed under this Public License, nor are publicity, 178 | privacy, and/or other similar personality rights; however, 179 | to the extent possible, the Licensor waives and/or agrees 180 | not to assert any such rights held by the Licensor to the 181 | limited extent necessary to allow You to exercise the 182 | Licensed Rights, but not otherwise. 183 | - 2. Patent and trademark rights are not licensed under this 184 | Public License. 185 | - 3. To the extent possible, the Licensor waives any right to 186 | collect royalties from You for the exercise of the Licensed 187 | Rights, whether directly or through a collecting society 188 | under any voluntary or waivable statutory or compulsory 189 | licensing scheme. 
In all other cases the Licensor expressly 190 | reserves any right to collect such royalties, including when 191 | the Licensed Material is used other than for NonCommercial 192 | purposes. 193 | 194 | - Section 3 – License Conditions. 195 | 196 | Your exercise of the Licensed Rights is expressly made subject to 197 | the following conditions. 198 | 199 | - a. Attribution. 200 | - 1. If You Share the Licensed Material (including in modified 201 | form), You must: 202 | - A. retain the following if it is supplied by the 203 | Licensor with the Licensed Material: 204 | - i. identification of the creator(s) of the Licensed 205 | Material and any others designated to receive 206 | attribution, in any reasonable manner requested by 207 | the Licensor (including by pseudonym if designated); 208 | - ii. a copyright notice; 209 | - iii. a notice that refers to this Public License; 210 | - iv. a notice that refers to the disclaimer of 211 | warranties; 212 | - v. a URI or hyperlink to the Licensed Material to 213 | the extent reasonably practicable; 214 | - B. indicate if You modified the Licensed Material and 215 | retain an indication of any previous modifications; and 216 | - C. indicate the Licensed Material is licensed under this 217 | Public License, and include the text of, or the URI or 218 | hyperlink to, this Public License. 219 | - 2. You may satisfy the conditions in Section 3(a)(1) in any 220 | reasonable manner based on the medium, means, and context in 221 | which You Share the Licensed Material. For example, it may 222 | be reasonable to satisfy the conditions by providing a URI 223 | or hyperlink to a resource that includes the required 224 | information. 225 | - 3. If requested by the Licensor, You must remove any of the 226 | information required by Section 3(a)(1)(A) to the extent 227 | reasonably practicable. 228 | - 4. If You Share Adapted Material You produce, the Adapter's 229 | License You apply must not prevent recipients of the Adapted 230 | Material from complying with this Public License. 231 | 232 | - Section 4 – Sui Generis Database Rights. 233 | 234 | Where the Licensed Rights include Sui Generis Database Rights that 235 | apply to Your use of the Licensed Material: 236 | 237 | - a. for the avoidance of doubt, Section 2(a)(1) grants You the 238 | right to extract, reuse, reproduce, and Share all or a 239 | substantial portion of the contents of the database for 240 | NonCommercial purposes only; 241 | - b. if You include all or a substantial portion of the database 242 | contents in a database in which You have Sui Generis Database 243 | Rights, then the database in which You have Sui Generis Database 244 | Rights (but not its individual contents) is Adapted Material; 245 | and 246 | - c. You must comply with the conditions in Section 3(a) if You 247 | Share all or a substantial portion of the contents of the 248 | database. 249 | 250 | For the avoidance of doubt, this Section 4 supplements and does not 251 | replace Your obligations under this Public License where the 252 | Licensed Rights include other Copyright and Similar Rights. 253 | 254 | - Section 5 – Disclaimer of Warranties and Limitation of Liability. 255 | 256 | - a. Unless otherwise separately undertaken by the Licensor, to 257 | the extent possible, the Licensor offers the Licensed Material 258 | as-is and as-available, and makes no representations or 259 | warranties of any kind concerning the Licensed Material, whether 260 | express, implied, statutory, or other. 
This includes, without 261 | limitation, warranties of title, merchantability, fitness for a 262 | particular purpose, non-infringement, absence of latent or other 263 | defects, accuracy, or the presence or absence of errors, whether 264 | or not known or discoverable. Where disclaimers of warranties 265 | are not allowed in full or in part, this disclaimer may not 266 | apply to You. 267 | - b. To the extent possible, in no event will the Licensor be 268 | liable to You on any legal theory (including, without 269 | limitation, negligence) or otherwise for any direct, special, 270 | indirect, incidental, consequential, punitive, exemplary, or 271 | other losses, costs, expenses, or damages arising out of this 272 | Public License or use of the Licensed Material, even if the 273 | Licensor has been advised of the possibility of such losses, 274 | costs, expenses, or damages. Where a limitation of liability is 275 | not allowed in full or in part, this limitation may not apply to 276 | You. 277 | - c. The disclaimer of warranties and limitation of liability 278 | provided above shall be interpreted in a manner that, to the 279 | extent possible, most closely approximates an absolute 280 | disclaimer and waiver of all liability. 281 | 282 | - Section 6 – Term and Termination. 283 | 284 | - a. This Public License applies for the term of the Copyright and 285 | Similar Rights licensed here. However, if You fail to comply 286 | with this Public License, then Your rights under this Public 287 | License terminate automatically. 288 | - b. Where Your right to use the Licensed Material has terminated 289 | under Section 6(a), it reinstates: 290 | 291 | - 1. automatically as of the date the violation is cured, 292 | provided it is cured within 30 days of Your discovery of the 293 | violation; or 294 | - 2. upon express reinstatement by the Licensor. 295 | 296 | For the avoidance of doubt, this Section 6(b) does not affect 297 | any right the Licensor may have to seek remedies for Your 298 | violations of this Public License. 299 | 300 | - c. For the avoidance of doubt, the Licensor may also offer the 301 | Licensed Material under separate terms or conditions or stop 302 | distributing the Licensed Material at any time; however, doing 303 | so will not terminate this Public License. 304 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 305 | License. 306 | 307 | - Section 7 – Other Terms and Conditions. 308 | 309 | - a. The Licensor shall not be bound by any additional or 310 | different terms or conditions communicated by You unless 311 | expressly agreed. 312 | - b. Any arrangements, understandings, or agreements regarding the 313 | Licensed Material not stated herein are separate from and 314 | independent of the terms and conditions of this Public License. 315 | 316 | - Section 8 – Interpretation. 317 | 318 | - a. For the avoidance of doubt, this Public License does not, and 319 | shall not be interpreted to, reduce, limit, restrict, or impose 320 | conditions on any use of the Licensed Material that could 321 | lawfully be made without permission under this Public License. 322 | - b. To the extent possible, if any provision of this Public 323 | License is deemed unenforceable, it shall be automatically 324 | reformed to the minimum extent necessary to make it enforceable. 325 | If the provision cannot be reformed, it shall be severed from 326 | this Public License without affecting the enforceability of the 327 | remaining terms and conditions. 328 | - c. 
No term or condition of this Public License will be waived 329 | and no failure to comply consented to unless expressly agreed to 330 | by the Licensor. 331 | - d. Nothing in this Public License constitutes or may be 332 | interpreted as a limitation upon, or waiver of, any privileges 333 | and immunities that apply to the Licensor or You, including from 334 | the legal processes of any jurisdiction or authority. 335 | 336 | Creative Commons is not a party to its public licenses. Notwithstanding, 337 | Creative Commons may elect to apply one of its public licenses to 338 | material it publishes and in those instances will be considered the 339 | "Licensor." The text of the Creative Commons public licenses is 340 | dedicated to the public domain under the CC0 Public Domain Dedication. 341 | Except for the limited purpose of indicating that material is shared 342 | under a Creative Commons public license or as otherwise permitted by the 343 | Creative Commons policies published at creativecommons.org/policies, 344 | Creative Commons does not authorize the use of the trademark "Creative 345 | Commons" or any other trademark or logo of Creative Commons without its 346 | prior written consent including, without limitation, in connection with 347 | any unauthorized modifications to any of its public licenses or any 348 | other arrangements, understandings, or agreements concerning use of 349 | licensed material. For the avoidance of doubt, this paragraph does not 350 | form part of the public licenses. 351 | 352 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Correspondence Field for Object Pose Estimation 2 | 3 | This repository provides the source code and trained models of the 6D object pose estimation method presented in: 4 | 5 | [Lin Huang](https://linhuang17.github.io/), [Tomas Hodan](http://www.hodan.xyz), [Lingni Ma](https://www.linkedin.com/in/lingnima/), [Linguang Zhang](https://lg-zhang.github.io/), [Luan Tran](https://www.linkedin.com/in/luan-tran-3185009b/), [Christopher Twigg](https://chris.twi.gg/), [Po-Chen Wu](http://media.ee.ntu.edu.tw/personal/pcwu/), [Junsong Yuan](https://cse.buffalo.edu/~jsyuan/), [Cem Keskin](https://www.linkedin.com/in/cem-keskin-23692a15/), [Robert Wang](http://people.csail.mit.edu/rywang/)
6 | **Neural Correspondence Field for Object Pose Estimation**
7 | European Conference on Computer Vision (ECCV) 2022
8 | [Paper](https://arxiv.org/pdf/2208.00113.pdf) | [Webpage](https://linhuang17.github.io/NCF/) | [Bibtex](https://linhuang17.github.io/NCF/resources/huang2022ncf.txt) 9 | 10 | Contents: [Setup](#setup) | [Usage](#usage) | [Pre-trained models](#pre-trained-models) 11 | 12 | 13 | ## 1. Setup 14 | 15 | ### 1.1 Cloning the repository 16 | 17 | Download the code: 18 | ``` 19 | git clone https://github.com/LinHuang17/NCF-code.git 20 | cd NCF-code 21 | ``` 22 | 23 | ### 1.2 Python environment and dependencies 24 | 25 | Create and activate conda environment with dependencies: 26 | ``` 27 | conda env create -f environment.yaml 28 | conda activate ncf 29 | ``` 30 | 31 | ### 1.3 BOP datasets 32 | 33 | For experiments on existing [BOP datasets](https://bop.felk.cvut.cz/datasets/), please follow the instructions on the [website](https://bop.felk.cvut.cz/datasets/) to download the base archives, 3D object models, the training images, and the test images. 34 | 35 | For YCB-V, you are expected to have files: `ycbv_base.zip`, `ycbv_models.zip`, `ycbv_train_pbr.zip`, `ycbv_train_real.zip` (used for training models with real images), and `ycbv_test_bop19.zip`. Then, unpack them into folder ``. 36 | 37 | ## 2. Usage 38 | 39 | ### 2.1 Inference with a pre-trained model 40 | 41 | To evaluate on an object (e.g., cracker box) from YCB-V: 42 | 43 | First, download and unpack the [pre-trained models](#pre-trained-models) into folder ``. 44 | 45 | Then, run the following command with the cracker box's pre-trained model: 46 | ``` 47 | export CUDA_VISIBLE_DEVICES=0 48 | python -m apps.eval --exp_id ncf_ycbv_run2_eval --work_base_path --model_dir --ds_ycbv_dir --obj_id 2 --bbx_size 380 --eval_perf True --load_netG_checkpoint_path --num_in_batch 10000 49 | ``` 50 | 51 | where `work_base_path` is the path to the results (e.g., the estimated pose csv file as `ncf-obj2_ycbv-Rt-time.csv`), `model_dir` is the path to the YCB-V 3D object models, `ds_ycbv_dir` is the path to the YCB-V dataset, and `load_netG_checkpoint_path` is the path to the cracker box's pre-trained model. 52 | 53 | 54 | ### 2.2 Training your own model 55 | 56 | To train on an object (e.g., cracker box) from YCB-V: 57 | 58 | Run the following command: 59 | 60 | ``` 61 | export CUDA_VISIBLE_DEVICES=0 62 | python -m apps.train --exp_id ncf_ycbv_run2_train --work_base_path --model_dir --ds_ycbv_dir --obj_id 2 --bbx_size 380 --num_in_batch 10000 63 | ``` 64 | 65 | where `work_base_path` is the path to the results (e.g., the estimated pose csv file as `ncf-obj2_ycbv-Rt-time.csv`), `model_dir` is the path to the YCB-V 3D object models, and `ds_ycbv_dir` is the path to the YCB-V dataset. 66 | 67 | 68 | ## 3. Pre-trained models 69 | 70 | - [YCB-V](https://drive.google.com/file/d/19rcvuIC7Ilu0MHPgLxmbxeUkOgBHR2be/view?usp=sharing) -------------------------------------------------------------------------------- /apps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinHuang17/NCF-code/8e429efb320b136786dc5438ea7d78c231ffab16/apps/__init__.py -------------------------------------------------------------------------------- /apps/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | eval. for rigid obj. 
3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 9 | ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | 11 | import cv2 12 | import json 13 | import time 14 | import random 15 | import torch 16 | import numpy as np 17 | from tqdm import tqdm 18 | import torch.nn as nn 19 | from torch.utils.data import DataLoader, ConcatDataset 20 | 21 | from lib.data import * 22 | from lib.model import * 23 | from lib.net_util import * 24 | from lib.eval_Rt_time_util import * 25 | from lib.options import BaseOptions 26 | 27 | 28 | # get options 29 | opt = BaseOptions().parse() 30 | 31 | 32 | def evaluate(opt): 33 | 34 | # seed 35 | if opt.deterministic: 36 | seed = opt.seed 37 | print("Set manual random Seed: ", seed) 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) # cpu 41 | torch.cuda.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | else: 46 | torch.backends.cudnn.benchmark = True 47 | print("cuDNN benchmarking enabled") 48 | 49 | # set path 50 | work_path = os.path.join(opt.work_base_path, f"{opt.exp_id}") 51 | os.makedirs(work_path, exist_ok=True) 52 | checkpoints_path = os.path.join(work_path, "checkpoints") 53 | os.makedirs(checkpoints_path, exist_ok=True) 54 | results_path = os.path.join(work_path, "results") 55 | os.makedirs(results_path, exist_ok=True) 56 | tb_dir = os.path.join(work_path, "tb") 57 | os.makedirs(tb_dir, exist_ok=True) 58 | tb_runs_dir = os.path.join(tb_dir, "runs") 59 | os.makedirs(tb_runs_dir, exist_ok=True) 60 | debug_dir = os.path.join(work_path, "debug") 61 | os.makedirs(debug_dir, exist_ok=True) 62 | 63 | # set gpu environment 64 | devices_ids = opt.GPU_ID 65 | num_GPU = len(devices_ids) 66 | torch.cuda.set_device(devices_ids[0]) 67 | 68 | # dataset 69 | test_dataset_list = [] 70 | test_data_ids = [opt.eval_data] 71 | for data_id in test_data_ids: 72 | if data_id == 'lm_bop_cha': 73 | test_dataset_list.append(BOP_BP_LM(opt, phase='test')) 74 | if data_id == 'lmo_bop_cha': 75 | test_dataset_list.append(BOP_BP_LMO(opt, phase='test')) 76 | if data_id == 'ycbv_bop_cha': 77 | test_dataset_list.append(BOP_BP_YCBV(opt, phase='test')) 78 | projection_mode = test_dataset_list[0].projection_mode 79 | test_dataset = ConcatDataset(test_dataset_list) 80 | # create test data loader 81 | # NOTE: batch size should be 1 and use all the points for evaluation 82 | test_data_loader = DataLoader(test_dataset, 83 | batch_size=1, shuffle=False, 84 | num_workers=opt.num_threads, pin_memory=(opt.num_threads == 0)) 85 | # persistent_workers=(opt.num_threads > 0)) 86 | # num_workers=opt.num_threads, pin_memory=opt.pin_memory) 87 | print('test data size: ', len(test_dataset)) 88 | 89 | # define model, multi-gpu, checkpoint 90 | sdf_criterion = None 91 | xyz_criterion = None 92 | netG = HGPIFuNet(opt, projection_mode, 93 | sdf_loss_term=sdf_criterion, 94 | xyz_loss_term=xyz_criterion) 95 | print('Using Network: ', netG.name) 96 | 97 | def set_eval(): 98 | netG.eval() 99 | 100 | # load checkpoints 101 | if opt.continue_train or opt.eval_perf: 102 | print('Loading for net G ...', opt.load_netG_checkpoint_path) 103 | netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=torch.device('cpu'))) 104 | 105 | # Data Parallel 106 | # if num_GPU > 1: 107 | netG = torch.nn.DataParallel(netG, device_ids=devices_ids, output_device=devices_ids[0]) 108 | # netG 
= torch.nn.parallel.DistributedDataParallel(netG, device_ids=devices_ids, output_device=devices_ids[0]) 109 | print(f'Data Paralleling on GPU: {devices_ids}') 110 | netG.cuda() 111 | 112 | os.makedirs(checkpoints_path, exist_ok=True) 113 | os.makedirs(results_path, exist_ok=True) 114 | os.makedirs('%s/%s' % (checkpoints_path, opt.name), exist_ok=True) 115 | os.makedirs('%s/%s' % (results_path, opt.name), exist_ok=True) 116 | opt_log = os.path.join(results_path, opt.name, 'opt.txt') 117 | with open(opt_log, 'w') as outfile: 118 | outfile.write(json.dumps(vars(opt), indent=2)) 119 | 120 | # evaluation 121 | with torch.no_grad(): 122 | set_eval() 123 | obj_id = [opt.obj_id][0] 124 | print('eval. for obj. pose and time (test) ...') 125 | save_csv_path = os.path.join(results_path, opt.name, f'ncf-obj{obj_id}_{opt.dataset}-Rt-time.csv') 126 | eval_Rt_time(opt, netG.module, test_data_loader, save_csv_path) 127 | 128 | if __name__ == '__main__': 129 | evaluate(opt) 130 | -------------------------------------------------------------------------------- /apps/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | train & eval. for rigid obj. 3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 9 | ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | 11 | import cv2 12 | import json 13 | import time 14 | import random 15 | import torch 16 | import numpy as np 17 | from tqdm import tqdm 18 | import torch.nn as nn 19 | from torch.utils.data import DataLoader, ConcatDataset 20 | 21 | from lib.data import * 22 | from lib.model import * 23 | from lib.net_util import * 24 | from lib.sym_util import * 25 | from lib.loss_util import * 26 | from lib.eval_Rt_time_util import * 27 | from lib.options import BaseOptions 28 | 29 | from lib.debug_pyrender_util import * 30 | from torch.utils.tensorboard import SummaryWriter 31 | 32 | # get options 33 | opt = BaseOptions().parse() 34 | 35 | class meter(): 36 | 37 | def __init__(self, opt): 38 | 39 | self.opt = opt 40 | 41 | self.load_time = AverageMeter() 42 | self.forward_time = AverageMeter() 43 | 44 | self.sdf_loss_meter = AverageMeter() 45 | if self.opt.use_xyz: 46 | self.xyz_loss_meter = AverageMeter() 47 | self.total_loss_meter = AverageMeter() 48 | 49 | def update_time(self, time, end, state): 50 | 51 | if state == 'forward': 52 | self.forward_time.update(time - end) 53 | elif state == 'load': 54 | self.load_time.update(time - end) 55 | 56 | def update_total_loss(self, total_loss, size): 57 | 58 | self.total_loss_meter.update(total_loss.item(), size) 59 | 60 | def update_loss(self, loss_dict, size): 61 | 62 | self.sdf_loss_meter.update(loss_dict['sdf_loss'].mean().item(), size) 63 | if self.opt.use_xyz: 64 | self.xyz_loss_meter.update(loss_dict['xyz_loss'].mean().item(), size) 65 | 66 | def set_dataset_train_mode(dataset, mode=True): 67 | for dataset_idx in range(len(dataset.datasets)): 68 | dataset.datasets[dataset_idx].is_train = mode 69 | 70 | def train(opt): 71 | 72 | # seed 73 | if opt.deterministic: 74 | seed = opt.seed 75 | print("Set manual random Seed: ", seed) 76 | random.seed(seed) 77 | np.random.seed(seed) 78 | torch.manual_seed(seed) # cpu 79 | torch.cuda.manual_seed(seed) 80 | torch.cuda.manual_seed_all(seed) 81 | torch.backends.cudnn.deterministic = True 82 | torch.backends.cudnn.benchmark = False 83 | else: 84 | torch.backends.cudnn.benchmark = True 85 | print("cuDNN benchmarking enabled") 86 | 87 
| # set path 88 | work_path = os.path.join(opt.work_base_path, f"{opt.exp_id}") 89 | os.makedirs(work_path, exist_ok=True) 90 | checkpoints_path = os.path.join(work_path, "checkpoints") 91 | os.makedirs(checkpoints_path, exist_ok=True) 92 | results_path = os.path.join(work_path, "results") 93 | os.makedirs(results_path, exist_ok=True) 94 | tb_dir = os.path.join(work_path, "tb") 95 | os.makedirs(tb_dir, exist_ok=True) 96 | tb_runs_dir = os.path.join(tb_dir, "runs") 97 | os.makedirs(tb_runs_dir, exist_ok=True) 98 | debug_dir = os.path.join(work_path, "debug") 99 | os.makedirs(debug_dir, exist_ok=True) 100 | 101 | writer = SummaryWriter(os.path.join(tb_runs_dir, f'{opt.exp_id}')) 102 | writer.add_text('Info', 'ncf for obj. Rt est. in frustum space using pred. sdf & xyz') 103 | 104 | # set gpu environment 105 | devices_ids = opt.GPU_ID 106 | num_GPU = len(devices_ids) 107 | torch.cuda.set_device(devices_ids[0]) 108 | 109 | # dataset 110 | train_dataset_list = [] 111 | train_data_ids = [opt.train_data] + [opt.more_train_data] 112 | for data_id in train_data_ids: 113 | if data_id == 'lm': 114 | train_dataset_list.append(BOP_BP_LM(opt, phase='train')) 115 | if data_id == 'ycbv': 116 | train_dataset_list.append(BOP_BP_YCBV(opt, phase='train')) 117 | if data_id == 'ycbv_real': 118 | train_dataset_list.append(BOP_BP_YCBV_real(opt, phase='train')) 119 | projection_mode = train_dataset_list[0].projection_mode 120 | train_dataset = ConcatDataset(train_dataset_list) 121 | # create train data loader 122 | train_data_loader = DataLoader(train_dataset, 123 | batch_size=opt.batch_size, shuffle=not opt.serial_batches, 124 | num_workers=opt.num_threads, pin_memory=(opt.num_threads == 0)) 125 | # persistent_workers=(opt.num_threads > 0)) 126 | # num_workers=opt.num_threads, pin_memory=opt.pin_memory) 127 | print('train data size: ', len(train_dataset)) 128 | 129 | test_dataset_list = [] 130 | test_data_ids = [opt.eval_data] 131 | for data_id in test_data_ids: 132 | if data_id == 'lm_bop_cha': 133 | test_dataset_list.append(BOP_BP_LM(opt, phase='test')) 134 | if data_id == 'lmo_bop_cha': 135 | test_dataset_list.append(BOP_BP_LMO(opt, phase='test')) 136 | if data_id == 'ycbv_bop_cha': 137 | test_dataset_list.append(BOP_BP_YCBV(opt, phase='test')) 138 | test_dataset = ConcatDataset(test_dataset_list) 139 | # create test data loader 140 | # NOTE: batch size should be 1 and use all the points for evaluation 141 | test_data_loader = DataLoader(test_dataset, 142 | batch_size=1, shuffle=False, 143 | num_workers=opt.num_threads, pin_memory=(opt.num_threads == 0)) 144 | # persistent_workers=(opt.num_threads > 0)) 145 | # num_workers=opt.num_threads, pin_memory=opt.pin_memory) 146 | print('test data size: ', len(test_dataset)) 147 | 148 | # pre-define pool of symmetric poses 149 | sym_pool=[] 150 | obj_id = [opt.obj_id][0] 151 | # load obj. param. 152 | obj_params = get_obj_params(opt.model_dir, [opt.train_data][0]) 153 | # Load meta info about the models (including symmetries). 
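    # Added note (editorial, unnumbered): models_info.json (BOP format) stores per-object metadata and,
    # for symmetric objects, the discrete/continuous symmetries. get_symmetry_transformations() expands
    # these into a finite set of {R, t} poses (continuous symmetries are discretized according to
    # opt.max_sym_disc_step); each pose is packed below into a 4x4 homogeneous matrix and collected in
    # sym_pool, which the symmetry-aware XYZ loss (XYZLoss_sym, selected further down when
    # len(sym_pool) > 1) uses during training.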
154 | models_info = load_json(obj_params['models_info_path'], keys_to_int=True) 155 | sym_poses = get_symmetry_transformations(models_info[obj_id], opt.max_sym_disc_step) 156 | for sym_pose in sym_poses: 157 | Rt = np.concatenate([sym_pose['R'], sym_pose['t'].reshape(3,1)], axis=1) 158 | Rt = np.concatenate([Rt, np.array([0, 0, 0, 1]).reshape(1, 4)], axis=0) 159 | sym_pool.append(torch.Tensor(Rt)) 160 | 161 | # define model, multi-gpu, checkpoint 162 | if opt.loss_type == 'mse': 163 | sdf_criterion = torch.nn.MSELoss() 164 | elif opt.loss_type == 'l1': 165 | sdf_criterion = torch.nn.L1Loss() 166 | elif opt.loss_type == 'huber': 167 | sdf_criterion = torch.nn.SmoothL1Loss() 168 | xyz_criterion = None 169 | if opt.use_xyz: 170 | if (len(sym_pool) > 1): 171 | xyz_criterion = XYZLoss_sym(use_xyz_mask=True, sym_pool=sym_pool) 172 | else: 173 | xyz_criterion = XYZLoss(use_xyz_mask=True) 174 | netG = HGPIFuNet(opt, projection_mode, 175 | sdf_loss_term=sdf_criterion, 176 | xyz_loss_term=xyz_criterion) 177 | print('Using Network: ', netG.name) 178 | 179 | def set_train(): 180 | netG.train() 181 | 182 | def set_eval(): 183 | netG.eval() 184 | 185 | # load checkpoints 186 | if opt.continue_train or opt.eval_perf: 187 | print('Loading for net G ...', opt.load_netG_checkpoint_path) 188 | netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=torch.device('cpu'))) 189 | 190 | # Data Parallel 191 | # if num_GPU > 1: 192 | netG = torch.nn.DataParallel(netG, device_ids=devices_ids, output_device=devices_ids[0]) 193 | # netG = torch.nn.parallel.DistributedDataParallel(netG, device_ids=devices_ids, output_device=devices_ids[0]) 194 | print(f'Data Paralleling on GPU: {devices_ids}') 195 | netG.cuda() 196 | 197 | os.makedirs(checkpoints_path, exist_ok=True) 198 | os.makedirs(results_path, exist_ok=True) 199 | os.makedirs('%s/%s' % (checkpoints_path, opt.name), exist_ok=True) 200 | os.makedirs('%s/%s' % (results_path, opt.name), exist_ok=True) 201 | opt_log = os.path.join(results_path, opt.name, 'opt.txt') 202 | with open(opt_log, 'w') as outfile: 203 | outfile.write(json.dumps(vars(opt), indent=2)) 204 | 205 | # optimizer 206 | lr = opt.learning_rate 207 | if opt.optimizer == 'rms': 208 | optimizerG = torch.optim.RMSprop(netG.module.parameters(), lr=lr, momentum=0, weight_decay=0) 209 | print(f'Using optimizer: rms') 210 | # optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr, momentum=0, weight_decay=0) 211 | # optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) 212 | elif opt.optimizer == 'adam': 213 | optimizerG = torch.optim.Adam(netG.module.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08) 214 | print(f'Using optimizer: adam') 215 | # optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08) 216 | # load optimizer 217 | if opt.continue_train and opt.load_optG_checkpoint_path is not None: 218 | print('Loading for opt G ...', opt.load_optG_checkpoint_path) 219 | optimizerG.load_state_dict(torch.load(opt.load_optG_checkpoint_path)) 220 | 221 | # training 222 | tb_train_idx = 0 223 | start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch,0) 224 | for epoch in range(start_epoch, opt.num_epoch): 225 | # log lr 226 | writer.add_scalar('train/learning_rate', lr, epoch) 227 | 228 | # meter, time, train mode 229 | train_meter = meter(opt) 230 | epoch_start_time = time.time() 231 | set_train() 232 | # torch.cuda.synchronize() 233 | iter_data_time = time.time() 234 | for train_idx, 
train_data in enumerate(train_data_loader): 235 | tb_train_idx += 1 236 | # measure elapsed data loading time in batch 237 | iter_start_time = time.time() 238 | train_meter.update_time(iter_start_time, iter_data_time, 'load') 239 | 240 | # retrieve the data 241 | # shape (B, 3, 480, 640) 242 | image_tensor = train_data['img'].cuda() 243 | # shape (B, 4, 4) 244 | calib_tensor = train_data['calib'].cuda() 245 | # shape (B, 3, 5000) 246 | sample_tensor = train_data['samples'].cuda() 247 | batch = image_tensor.size(0) 248 | 249 | # shape (B, 1, 5000) 250 | label_tensor = train_data['labels'].cuda() 251 | if opt.use_xyz: 252 | # shape (B, 1, 5000) 253 | xyz_tensor = train_data['xyzs'].cuda() 254 | xyz_mask_tensor = train_data['xyz_mask'].cuda() 255 | transforms = torch.zeros([batch,2,3]).cuda() 256 | transforms[:, 0,0] = 1 / (opt.img_size[0] // 2) 257 | transforms[:, 1,1] = 1 / (opt.img_size[1] // 2) 258 | transforms[:, 0,2] = -1 259 | transforms[:, 1,2] = -1 260 | if opt.use_xyz: 261 | results, loss_dict, xyzs, uvz = netG(image_tensor, sample_tensor, calib_tensor, 262 | labels=label_tensor, transforms=transforms, 263 | gt_xyzs=xyz_tensor, gt_xyz_mask=xyz_mask_tensor) 264 | else: 265 | results, loss_dict, uvz = netG(image_tensor, sample_tensor, calib_tensor, 266 | labels=label_tensor, transforms=transforms) 267 | 268 | optimizerG.zero_grad() 269 | # for param in netG.module.parameters(): 270 | # for param in netG.parameters(): 271 | # param.grad = None 272 | loss_dict['total_loss'].mean().backward() 273 | # error.backward() 274 | optimizerG.step() 275 | 276 | # measure elapsed forward time in batch 277 | # torch.cuda.synchronize() 278 | iter_net_time = time.time() 279 | train_meter.update_time(iter_net_time, iter_start_time, 'forward') 280 | eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - ( 281 | iter_net_time - epoch_start_time) 282 | 283 | # update loss 284 | train_meter.update_loss(loss_dict, batch) 285 | # update total loss 286 | train_meter.update_total_loss(loss_dict['total_loss'].mean(), batch) 287 | 288 | writer.add_scalar('train/total_loss_per_batch', train_meter.total_loss_meter.val, tb_train_idx) 289 | writer.add_scalar('train/sdf_loss_per_batch', train_meter.sdf_loss_meter.val, tb_train_idx) 290 | if opt.use_xyz: 291 | writer.add_scalar('train/xyz_loss_per_batch', train_meter.xyz_loss_meter.val, tb_train_idx) 292 | if train_idx % opt.freq_plot == 0: 293 | print('Name: {0} | Epoch: {1} | {2}/{3} | Loss: {4:.06f} | LR: {5:.06f} | dataT: {6:.05f} | netT: {7:.05f} | ETA: {8:02d}:{9:02d}'.format( 294 | opt.name, epoch, train_idx, len(train_data_loader), loss_dict['total_loss'].mean().item(), lr, 295 | iter_start_time - iter_data_time, iter_net_time - iter_start_time, int(eta // 60), int(eta - 60 * (eta // 60)))) 296 | 297 | if train_idx % opt.freq_debug == 0: 298 | with torch.no_grad(): 299 | # debug for rgb, mask, rendering of object model 300 | # shape (4, 3, 480, 640) 301 | name = train_data['name'][0] 302 | model_mesh = train_data_loader.dataset.datasets[0].model_mesh_dict[name].copy(include_cache=False) 303 | img = (np.transpose(image_tensor[0].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5) 304 | save_debug_path = os.path.join(debug_dir, f'train_sample{train_idx}_epoch{epoch}_debug.jpeg') 305 | viz_debug_data(img, model_mesh, 306 | train_data['extrinsic'][0].detach().numpy(), train_data['aug_intrinsic'][0].detach().numpy(), 307 | save_debug_path) 308 | 309 | # debug for query projection during forward 310 | # shape (4, 3, 5000), (4, 1, 5000) 
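                # Added note (editorial, unnumbered): inv_trans is the inverse of the `transforms`
                # matrix built for the forward pass above. `transforms` maps pixel coordinates into
                # [-1, 1] (scale 1/(W/2), 1/(H/2), shift -1), so inv_trans scales by W/2, H/2 and
                # shifts by W/2, H/2 to bring the projected query points in uvz back to pixel
                # coordinates for the debug visualization saved below.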
311 | inv_trans = torch.zeros([1,2,3]) 312 | inv_trans[:, 0,0] = (opt.img_size[0] // 2) 313 | inv_trans[:, 1,1] = (opt.img_size[1] // 2) 314 | inv_trans[:, 0,2] = (opt.img_size[0] // 2) 315 | inv_trans[:, 1,2] = (opt.img_size[1] // 2) 316 | scale = inv_trans[:, :2, :2] 317 | shift = inv_trans[:, :2, 2:3] 318 | uv = torch.baddbmm(shift, scale, uvz[0].detach().cpu()[:2, :].unsqueeze(0)) 319 | query_res = {'img': image_tensor[0].detach().cpu(), 'samples': uv.squeeze(0), 'labels': label_tensor[0].detach().cpu()} 320 | save_in_query_path = os.path.join(debug_dir, f'train_sample{train_idx}_epoch{epoch}_in_query.jpeg') 321 | save_out_query_path = os.path.join(debug_dir, f'train_sample{train_idx}_epoch{epoch}_out_query.jpeg') 322 | viz_debug_query_forward(opt.out_type, query_res, save_in_query_path, save_out_query_path) 323 | 324 | # debug for prediction & gt ply for query & its label 325 | save_gt_path = os.path.join(debug_dir, f'train_sample{train_idx}_epoch{epoch}_sdf_gt.ply') 326 | save_sdf_path = os.path.join(debug_dir, f'train_sample{train_idx}_epoch{epoch}_sdf_est.ply') 327 | r = results[0].cpu() 328 | points = sample_tensor[0].transpose(0, 1).cpu() 329 | if opt.out_type[-3:] == 'sdf': 330 | save_samples_truncted_sdf(save_gt_path, points.detach().numpy(), label_tensor[0].transpose(0, 1).cpu().detach().numpy(), thres=opt.norm_clamp_dist) 331 | save_samples_truncted_sdf(save_sdf_path, points.detach().numpy(), r.detach().numpy(), thres=opt.norm_clamp_dist) 332 | if opt.use_xyz: 333 | norm_xyz_factor = train_data['norm_xyz_factor'][0].item() 334 | pred_xyzs = (xyzs[0].transpose(0, 1).cpu()) * norm_xyz_factor 335 | save_sdf_xyz_path = os.path.join(debug_dir, f'train_sample{train_idx}_epoch{epoch}_xyz_est.ply') 336 | save_samples_truncted_sdf(save_sdf_xyz_path, pred_xyzs.detach().numpy(), r.detach().numpy(), thres=opt.norm_clamp_dist) 337 | 338 | iter_data_time = time.time() 339 | 340 | writer.add_scalars('train/time_per_epoch', {'forward_per_batch': train_meter.forward_time.avg, 'dataload_per_batch': train_meter.load_time.avg}, epoch) 341 | writer.add_scalar('train/total_loss_per_epoch', train_meter.total_loss_meter.avg, epoch) 342 | writer.add_scalar('train/sdf_loss_per_epoch', train_meter.sdf_loss_meter.avg, epoch) 343 | if opt.use_xyz: 344 | writer.add_scalar('train/xyz_loss_per_epoch', train_meter.xyz_loss_meter.avg, epoch) 345 | # update learning rate 346 | lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma) 347 | # save checkpoints 348 | torch.save(netG.module.state_dict(), '%s/%s/netG_epoch_%d' % (checkpoints_path, opt.name, epoch)) 349 | torch.save(optimizerG.state_dict(), '%s/%s/optG_epoch_%d' % (checkpoints_path, opt.name, epoch)) 350 | 351 | #### test 352 | with torch.no_grad(): 353 | set_eval() 354 | obj_id = [opt.obj_id][0] 355 | if epoch > 0 and epoch % opt.freq_eval_all == 0 and opt.use_xyz and opt.gen_obj_pose: 356 | print('eval. for obj. 
pose and time (test) ...') 357 | save_csv_path = os.path.join(results_path, opt.name, f'ncf-obj{obj_id}_{opt.dataset}-Rt-time.csv') 358 | eval_Rt_time(opt, netG.module, test_data_loader, save_csv_path) 359 | 360 | writer.close() 361 | 362 | if __name__ == '__main__': 363 | train(opt) 364 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: ncf 2 | channels: 3 | - nodefaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: 7 | - python=3.7 8 | - cudatoolkit=10.1 9 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 10 | - torchvision=0.5.0=py37_cu101 11 | - numpy 12 | - tqdm 13 | - pyembree 14 | - shapely 15 | - xxhash 16 | - trimesh 17 | - eigenpy 18 | - rtree 19 | - scikit-image==0.16.2 20 | - matplotlib 21 | - scipy 22 | - imageio 23 | - cython 24 | - pip 25 | - pip: 26 | - pypng 27 | - pysdf 28 | - pyrender 29 | - tensorboard 30 | - transforms3d 31 | - opencv-python 32 | - Pillow==8.2.0 33 | - git+https://github.com/hassony2/chumpy.git 34 | - git+https://github.com/mmatl/pyopengl.git 35 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinHuang17/NCF-code/8e429efb320b136786dc5438ea7d78c231ffab16/lib/__init__.py -------------------------------------------------------------------------------- /lib/data/BOP_BP_YCBV.py: -------------------------------------------------------------------------------- 1 | """ 2 | dataset class of bop ycbv 3 | """ 4 | 5 | import os 6 | import sys 7 | import pdb 8 | import random 9 | import logging 10 | import inspect 11 | 12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 13 | parentdir = os.path.dirname(currentdir) 14 | sys.path.insert(0, parentdir) 15 | 16 | import PIL 17 | import json 18 | import torch 19 | import pickle 20 | import numpy as np 21 | from PIL import Image, ImageOps 22 | from PIL.ImageFilter import GaussianBlur 23 | 24 | from torch.utils.data import Dataset 25 | from torch.utils.data import DataLoader 26 | import torchvision.transforms as transforms 27 | 28 | from data_utils.aug_util import AugmentOp, augment_image 29 | from data_utils.sample_frustum_util import load_trimesh, wks_sampling_sdf_xyz_calc, wks_sampling_eff_csdf_xyz_calc, xyz_mask_calc 30 | 31 | from options import BaseOptions 32 | from debug_pyrender_util import * 33 | 34 | # log = logging.getLogger('trimesh') 35 | # log.setLevel(40) 36 | 37 | class BOP_BP_YCBV(Dataset): 38 | @staticmethod 39 | def modify_commandline_options(parser, is_train): 40 | return parser 41 | 42 | def __init__(self, opt, phase='train'): 43 | self.opt = opt 44 | # path & state setup 45 | self.phase = phase 46 | self.is_train = (self.phase == 'train') 47 | 48 | # 3D->2D projection: 'orthogonal' or 'perspective' 49 | self.projection_mode = 'perspective' 50 | 51 | # ABBox or Sphere in cam. c.s. 52 | B_SHIFT = self.opt.bbx_shift 53 | Bx_SIZE = self.opt.bbx_size // 2 54 | By_SIZE = self.opt.bbx_size // 2 55 | Bz_SIZE = self.opt.bbx_size // 2 56 | self.B_MIN = np.array([-Bx_SIZE, -By_SIZE, -Bz_SIZE]) 57 | self.B_MAX = np.array([Bx_SIZE, By_SIZE, Bz_SIZE]) 58 | # wks box in cam. c.s. 
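        # Added note (editorial, unnumbered): the "wks" (workspace) box is an axis-aligned volume in
        # camera coordinates with size opt.wks_size, centered at depth opt.wks_z_shift along the
        # camera z-axis. It is handed to the wks_sampling_* helpers in get_item() to draw 3D query
        # points, while B_MIN/B_MAX above give the +/- bbx_size/2 box around the object (at test time
        # its half-size, bbx_size/2, is returned as norm_xyz_factor).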
59 | self.CAM_Bz_SHIFT = self.opt.wks_z_shift 60 | Cam_Bx_SIZE = self.opt.wks_size[0] // 2 61 | Cam_By_SIZE = self.opt.wks_size[1] // 2 62 | Cam_Bz_SIZE = self.opt.wks_size[2] // 2 63 | self.CAM_B_MIN = np.array([-Cam_Bx_SIZE, -Cam_By_SIZE, -Cam_Bz_SIZE+self.CAM_Bz_SHIFT]) 64 | self.CAM_B_MAX = np.array([Cam_Bx_SIZE, Cam_By_SIZE, Cam_Bz_SIZE+self.CAM_Bz_SHIFT]) 65 | # test wks box in cam. c.s. 66 | self.TEST_CAM_Bz_SHIFT = self.opt.test_wks_z_shift 67 | Test_Cam_Bx_SIZE = self.opt.test_wks_size[0] // 2 68 | Test_Cam_By_SIZE = self.opt.test_wks_size[1] // 2 69 | Test_Cam_Bz_SIZE = self.opt.test_wks_size[2] // 2 70 | self.TEST_CAM_B_MIN = np.array([-Test_Cam_Bx_SIZE, -Test_Cam_By_SIZE, -Test_Cam_Bz_SIZE+self.TEST_CAM_Bz_SHIFT]) 71 | self.TEST_CAM_B_MAX = np.array([Test_Cam_Bx_SIZE, Test_Cam_By_SIZE, Test_Cam_Bz_SIZE+self.TEST_CAM_Bz_SHIFT]) 72 | 73 | # PIL to tensor 74 | self.to_tensor = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 77 | ]) 78 | 79 | self.aug_ops = [ 80 | AugmentOp('blur', 0.4, [1, self.opt.aug_blur]), 81 | AugmentOp('sharpness', 0.3, [0.0, self.opt.aug_sha]), 82 | AugmentOp('contrast', 0.3, [0.2, self.opt.aug_con]), 83 | AugmentOp('brightness', 0.5, [0.1, self.opt.aug_bri]), 84 | AugmentOp('color', 0.3, [0.0, self.opt.aug_col]), 85 | ] 86 | 87 | # ycbv train 88 | # self.obj_id = self.opt.obj_id 89 | self.obj_id_list = [self.opt.obj_id] 90 | self.model_dir = self.opt.model_dir 91 | self.ds_root_dir = self.opt.ds_ycbv_dir 92 | if self.phase == 'train': 93 | self.ds_dir = os.path.join(self.ds_root_dir, 'train_pbr') 94 | start = 0 95 | end = 50 96 | self.visib_fract_thresh = self.opt.visib_fract_thresh 97 | elif self.phase == 'test': 98 | self.ds_dir = os.path.join(self.ds_root_dir, 'test') 99 | start = 48 100 | end = 60 101 | self.visib_fract_thresh = 0.0 102 | self.all_gt_info = [] 103 | for folder_id in range(start, end): 104 | self.scene_gt_dict = {} 105 | self.scene_gt_info_dict = {} 106 | self.scene_camera_dict = {} 107 | with open(os.path.join(self.ds_dir, f'{int(folder_id):06d}/scene_gt_info.json')) as f: 108 | self.scene_gt_info_dict = json.load(f) 109 | 110 | with open(os.path.join(self.ds_dir, f'{int(folder_id):06d}/scene_camera.json')) as f: 111 | self.scene_camera_dict = json.load(f) 112 | 113 | with open(os.path.join(self.ds_dir, f'{int(folder_id):06d}/scene_gt.json')) as f: 114 | self.scene_gt_dict = json.load(f) 115 | 116 | # for data_idx in range(len(self.scene_gt_info_dict)): 117 | for data_id in self.scene_gt_info_dict.keys(): 118 | len_item = len(self.scene_gt_info_dict[str(data_id)]) 119 | for obj_idx in range(len_item): 120 | # self.all_gt_info[str(obj_id)] = [] 121 | if self.scene_gt_dict[str(data_id)][obj_idx]['obj_id'] in self.obj_id_list: 122 | if self.scene_gt_info_dict[str(data_id)][obj_idx]['visib_fract'] > self.visib_fract_thresh: 123 | single_annot = {} 124 | single_annot['folder_id'] = folder_id 125 | single_annot['frame_id'] = int(data_id) 126 | single_annot['cam_R_m2c'] = self.scene_gt_dict[str(data_id)][obj_idx]['cam_R_m2c'] 127 | single_annot['cam_t_m2c'] = self.scene_gt_dict[str(data_id)][obj_idx]['cam_t_m2c'] 128 | single_annot['obj_id'] = self.scene_gt_dict[str(data_id)][obj_idx]['obj_id'] 129 | single_annot['cam_K'] = self.scene_camera_dict[str(data_id)]['cam_K'] 130 | # self.all_gt_info[str(obj_id)].append(single_annot) 131 | self.all_gt_info.append(single_annot) 132 | self.model_mesh_dict = load_trimesh(self.model_dir, self.opt.model_unit) 133 | 134 | def __len__(self): 135 
| 136 | return len(self.all_gt_info) 137 | 138 | def get_img_cam(self, frame_id): 139 | 140 | data_gt_info = self.all_gt_info[frame_id] 141 | folder_id = data_gt_info['folder_id'] 142 | frame_id = data_gt_info['frame_id'] 143 | rgb_parent_path = os.path.join(self.ds_dir, f'{int(folder_id):06d}', 'rgb') 144 | if self.phase == 'train': 145 | rgb_path = os.path.join(rgb_parent_path, f'{int(frame_id):06d}.jpg') 146 | elif self.phase == 'test': 147 | rgb_path = os.path.join(rgb_parent_path, f'{int(frame_id):06d}.png') 148 | 149 | # shape (H, W, C)/(480, 640, 3) 150 | render = Image.open(rgb_path).convert('RGB') 151 | w, h = render.size 152 | 153 | # original camera intrinsic 154 | K = np.array(data_gt_info['cam_K']).reshape(3, 3) 155 | camera = dict(K=K.astype(np.float32), aug_K=np.copy(K.astype(np.float32)), resolution=(w, h)) 156 | 157 | objects = [] 158 | # annotation for every object in the scene 159 | # Rotation matrix from model to cam 160 | R_m2c = np.array(data_gt_info['cam_R_m2c']).reshape(3, 3) 161 | # translation vector from model to cam 162 | # unit: mm -> cm 163 | t_m2c = np.array(data_gt_info['cam_t_m2c']) 164 | # Rigid Transform class from model to cam/model c.s. 6D pose in cam c.s./extrinsic 165 | RT_m2c = np.concatenate([R_m2c, t_m2c.reshape(3,1)], axis=1) 166 | # model to cam: Rigid Transform homo. matrix 167 | RT_m2c = np.concatenate([RT_m2c, np.array([0, 0, 0, 1]).reshape(1, 4)], axis=0) 168 | obj_id = data_gt_info['obj_id'] 169 | name = f'obj_{int(obj_id):06d}' 170 | obj = dict(label=name, name=name, RT_m2c=RT_m2c.astype(np.float32)) 171 | objects.append(obj) 172 | 173 | # object name 174 | objname = objects[0]['name'] 175 | 176 | # color aug. 177 | if self.is_train and self.opt.use_aug: 178 | render = augment_image(render, self.aug_ops) 179 | 180 | aug_intrinsic = camera['aug_K'] 181 | aug_intrinsic = np.concatenate([aug_intrinsic, np.array([0, 0, 0]).reshape(3, 1)], 1) 182 | aug_intrinsic = np.concatenate([aug_intrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0) 183 | extrinsic = objects[0]['RT_m2c'] 184 | calib = torch.Tensor(np.matmul(aug_intrinsic, extrinsic)).float() 185 | extrinsic = torch.Tensor(extrinsic).float() 186 | aug_intrinsic = torch.Tensor(aug_intrinsic).float() 187 | 188 | render = self.to_tensor(render) 189 | 190 | # shape (C, H, W), ... 191 | return {'img': render, 'calib': aug_intrinsic, 'extrinsic': extrinsic, 'aug_intrinsic': aug_intrinsic, 'folder_id': folder_id, 'frame_id': frame_id, 'obj_id': obj_id, 'name': objname} 192 | 193 | def get_item(self, index): 194 | 195 | res = { 196 | 'b_min': self.CAM_B_MIN, 197 | 'b_max': self.CAM_B_MAX, 198 | 'test_b_min': self.TEST_CAM_B_MIN, 199 | 'test_b_max': self.TEST_CAM_B_MAX, 200 | } 201 | 202 | render_data = self.get_img_cam(index) 203 | res.update(render_data) 204 | if self.is_train: 205 | if self.opt.out_type[:3] == 'eff': 206 | # efficient conventional-SDF calculation 207 | sample_data = wks_sampling_eff_csdf_xyz_calc(self.opt, 208 | # bouding box 209 | self.B_MAX, self.B_MIN, 210 | # wks 211 | self.CAM_B_MAX, self.CAM_B_MIN, 212 | # model mesh 213 | self.model_mesh_dict[res['name']].copy(include_cache=False), 214 | # camera param. 
& bouding volume 215 | res['extrinsic'].clone(), res['calib'].clone(), bounding='sphere') 216 | else: 217 | # Ray-SDF or conventional-SDF calculation 218 | sample_data = wks_sampling_sdf_xyz_calc(self.opt, 219 | # bouding box 220 | self.B_MAX, self.B_MIN, 221 | # wks 222 | self.CAM_B_MAX, self.CAM_B_MIN, 223 | # model mesh 224 | self.model_mesh_dict[res['name']].copy(include_cache=False), 225 | # camera param. & bouding volume 226 | res['extrinsic'].clone(), res['calib'].clone(), bounding='sphere') 227 | if self.opt.use_xyz: 228 | xyz_mask = xyz_mask_calc(sdfs=sample_data['labels'].clone(), xyz_range=self.opt.norm_clamp_dist) 229 | res.update(xyz_mask) 230 | res.update(sample_data) 231 | else: 232 | norm_xyz_factor = self.opt.bbx_size / 2 233 | res['norm_xyz_factor'] = torch.tensor(norm_xyz_factor) 234 | 235 | return res 236 | 237 | def __getitem__(self, index): 238 | return self.get_item(index) 239 | 240 | 241 | if __name__ == '__main__': 242 | 243 | phase = 'train' 244 | opt = BaseOptions().parse() 245 | debug_path = f'/data1/lin/ncf_results/data/ycbv_{opt.out_type}_obj{opt.obj_id}_{phase}' 246 | os.makedirs(debug_path, exist_ok=True) 247 | dataset = BOP_BP_YCBV(opt, phase=phase) 248 | print(f'len. of dataset {len(dataset)}') 249 | 250 | num_debug = 10 251 | for idx in range(0, len(dataset), len(dataset) // num_debug): 252 | print(f'Debugging for sample: {idx}') 253 | res = dataset[idx] 254 | 255 | # debug for rgb, mask, rendering of object model 256 | # img = np.uint8((np.transpose(res['img'].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0) 257 | model_mesh = dataset.model_mesh_dict[res['name']].copy(include_cache=False) 258 | img = (np.transpose(res['img'].numpy(), (1, 2, 0)) * 0.5 + 0.5) 259 | save_debug_path = os.path.join(debug_path, f'data_sample{idx}_debug_bop_ycbv.jpeg') 260 | viz_debug_data(img, model_mesh, 261 | res['extrinsic'].numpy(), res['aug_intrinsic'].numpy(), 262 | save_debug_path) 263 | 264 | # debug for sampled points with labels: same for each sample 265 | save_sdf_path = os.path.join(debug_path, f'data_sample{idx}_clamp{opt.norm_clamp_dist}_sdf.ply') 266 | save_sdf_xyz_path = os.path.join(debug_path, f'data_sample{idx}_clamp{opt.norm_clamp_dist}_xyz.ply') 267 | save_samples_truncted_sdf(save_sdf_path, res['samples'].numpy().T, res['labels'].numpy().T, thres=opt.norm_clamp_dist) 268 | save_samples_truncted_sdf(save_sdf_xyz_path, res['xyzs'].numpy().T, res['labels'].numpy().T, thres=opt.norm_clamp_dist) 269 | 270 | # debug for query projection 271 | save_in_query_path = os.path.join(debug_path, f'data_sample{idx}_in_query.jpeg') 272 | save_out_query_path = os.path.join(debug_path, f'data_sample{idx}_out_query.jpeg') 273 | viz_debug_query(opt.out_type, res, save_in_query_path, save_out_query_path) 274 | -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .BOP_BP_YCBV import BOP_BP_YCBV 2 | -------------------------------------------------------------------------------- /lib/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinHuang17/NCF-code/8e429efb320b136786dc5438ea7d78c231ffab16/lib/data_utils/__init__.py -------------------------------------------------------------------------------- /lib/data_utils/aug_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 
"""Image augmentation operations.""" 4 | 5 | 6 | import random 7 | from typing import Iterable, Tuple, NamedTuple 8 | 9 | import numpy as np 10 | from PIL import Image, ImageEnhance, ImageFilter 11 | from PIL.Image import Image as ImageType 12 | 13 | 14 | class AugmentOp(NamedTuple): 15 | """Parameters of an augmentation operation. 16 | 17 | name: Name of the augmentation operation. 18 | prob: Probability with which the operation is applied. 19 | param_range: A range from which to sample the value of the key parameter 20 | (each augmentation operation is assumed to have one key parameter). 21 | """ 22 | 23 | name: str 24 | prob: float 25 | param_range: Tuple[float, float] 26 | 27 | 28 | def _augment_pil_filter( 29 | im: ImageType, 30 | fn: ImageFilter.MultibandFilter, 31 | prob: float, 32 | param_range: Tuple[float, float], 33 | ) -> ImageType: 34 | """Generic function for augmentations based on PIL's filter function. 35 | 36 | Args: 37 | im: An input image. 38 | fn: A filtering function to apply to the image. 39 | prob: Probability with which the function is applied. 40 | param_range: A range from which the value of the key parameter is sampled. 41 | Returns: 42 | A potentially augmented image. 43 | """ 44 | 45 | if random.random() <= prob: 46 | im = im.filter(fn(random.randint(*map(int, param_range)))) # pyre-ignore 47 | return im 48 | 49 | 50 | def _augment_pil_enhance( 51 | im: ImageType, 52 | fn: ImageEnhance._Enhance, 53 | prob: float, 54 | param_range: Tuple[float, float], 55 | ) -> ImageType: 56 | """Generic function for augmentations based on PIL's enhance function. 57 | 58 | Args: 59 | im: An input image. 60 | fn: A filtering function to apply to the image. 61 | prob: Probability with which the function is applied. 62 | param_range: A range from which the value of the key parameter is sampled. 63 | Returns: 64 | A potentially augmented image. 65 | """ 66 | 67 | if random.random() <= prob: 68 | im = fn(im).enhance(factor=random.uniform(*param_range)) # pyre-ignore 69 | return im 70 | 71 | 72 | def blur(im, prob=0.5, param_range=(1, 3)): 73 | return _augment_pil_filter(im, ImageFilter.GaussianBlur, prob, param_range) 74 | 75 | 76 | def sharpness(im, prob=0.5, param_range=(0.0, 50.0)): 77 | return _augment_pil_enhance(im, ImageEnhance.Sharpness, prob, param_range) 78 | 79 | 80 | def contrast(im, prob=0.5, param_range=(0.2, 50.0)): 81 | return _augment_pil_enhance(im, ImageEnhance.Contrast, prob, param_range) 82 | 83 | 84 | def brightness(im, prob=0.5, param_range=(0.1, 6.0)): 85 | return _augment_pil_enhance(im, ImageEnhance.Brightness, prob, param_range) 86 | 87 | 88 | def color(im, prob=0.5, param_range=(0.0, 20.0)): 89 | return _augment_pil_enhance(im, ImageEnhance.Color, prob, param_range) 90 | 91 | 92 | # def augment_image(im: np.ndarray, augment_ops: Iterable[AugmentOp]) -> np.ndarray: 93 | # """Applies a list of augmentations to an image. 94 | 95 | # Args: 96 | # im: An input image. 97 | # augment_ops: A list of augmentations to apply. 98 | # Returns: 99 | # A potentially augmented image. 100 | # """ 101 | 102 | # im_pil = Image.fromarray(im) 103 | # for op in augment_ops: 104 | # im_pil = globals()[op.name](im_pil, op.prob, op.param_range) 105 | # return np.array(im_pil) 106 | 107 | def augment_image(im: ImageType, augment_ops: Iterable[AugmentOp]) -> ImageType: 108 | """Applies a list of augmentations to an image. 109 | 110 | Args: 111 | im: An input image. 112 | augment_ops: A list of augmentations to apply. 113 | Returns: 114 | A potentially augmented image. 
115 | """ 116 | 117 | # im_pil = Image.fromarray(im) 118 | im_pil = im 119 | for op in augment_ops: 120 | im_pil = globals()[op.name](im_pil, op.prob, op.param_range) 121 | return im_pil 122 | -------------------------------------------------------------------------------- /lib/data_utils/sample_frustum_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | utils for trimesh loading 3 | sampling surface, bouding volume, camera space 4 | SDF calc. via trimesh or pysdf 5 | xyz_correspondence & its mask gen. 6 | """ 7 | import os 8 | import pdb 9 | import random 10 | import logging 11 | 12 | import torch 13 | import numpy as np 14 | 15 | import trimesh 16 | from trimesh.ray import ray_pyembree 17 | 18 | from pysdf import SDF 19 | 20 | log = logging.getLogger('trimesh') 21 | log.setLevel(40) 22 | 23 | 24 | def load_trimesh(model_dir, model_unit): 25 | files = os.listdir(model_dir) 26 | mesh_dict = {} 27 | for idx, filename in enumerate(files): 28 | if filename[-4:] == '.ply': 29 | # load mesh in model space 30 | model_mesh = trimesh.load(os.path.join(model_dir, filename), process=False) 31 | # m -> mm unit if orig. ycbv 32 | if model_unit == 'meter': 33 | # m -> mm unit 34 | model_mesh.vertices = model_mesh.vertices * 1000 35 | key = filename[:-4] 36 | mesh_dict[key] = model_mesh 37 | 38 | return mesh_dict 39 | 40 | 41 | def sampling_in_ball(num_points, dimension, radius=1): 42 | 43 | # First generate random directions by normalizing the length of a 44 | # vector of random-normal values (these distribute evenly on ball). 45 | random_directions = np.random.normal(size=(dimension,num_points)) 46 | random_directions /= np.linalg.norm(random_directions, axis=0) 47 | 48 | # Second generate a random radius with probability proportional to 49 | # the surface area of a ball with a given radius. 50 | random_radii = np.random.random(num_points) ** (1/dimension) 51 | 52 | # Return the list of random (direction & length) points. 53 | return radius * (random_directions * random_radii).T 54 | 55 | 56 | def out_of_plane_mask_calc(cam_pts, calib, img_size): 57 | # deal with out-of-plane cases 58 | c2i_rot = calib[:3, :3] 59 | c2i_trans = calib[:3, 3:4] 60 | img_sample_pts = torch.addmm(c2i_trans, c2i_rot, torch.Tensor(cam_pts.T).float()) 61 | img_sample_uvs = img_sample_pts[:2, :] / img_sample_pts[2:3, :] 62 | 63 | # normalize to [-1,1] 64 | transforms = torch.zeros([2,3]) 65 | transforms[0,0] = 1 / (img_size[0] // 2) 66 | transforms[1,1] = 1 / (img_size[1] // 2) 67 | transforms[0,2] = -1 68 | transforms[1,2] = -1 69 | scale = transforms[:2, :2] 70 | shift = transforms[:2, 2:3] 71 | img_sample_norm_uvs = torch.addmm(shift, scale, img_sample_uvs) 72 | in_img = (img_sample_norm_uvs[0,:] >= -1.0) & (img_sample_norm_uvs[0,:] <= 1.0) & (img_sample_norm_uvs[1,:] >= -1.0) & (img_sample_norm_uvs[1,:] <= 1.0) 73 | not_in_img = torch.logical_not(in_img).numpy() 74 | 75 | return not_in_img 76 | 77 | 78 | # ray-SDF or conventional-SDF 79 | def wks_sampling_sdf_xyz_calc(opt, bmax, bmin, cam_bmax, cam_bmin, model_mesh, extrinsic, calib, bounding): 80 | # if not self.is_train: 81 | # random.seed(1991) 82 | # np.random.seed(1991) 83 | # torch.manual_seed(1991) 84 | 85 | # extrinsic to transform from model to cam. space 86 | m2c_rot = extrinsic.numpy()[:3, :3] 87 | m2c_trans = extrinsic.numpy()[:3, 3:4] 88 | # (N, 3) 89 | cam_vert_pts = (m2c_rot.dot(model_mesh.vertices.T) + m2c_trans.reshape((3, 1))).T 90 | # load mesh in cam. 
space 91 | cam_mesh = trimesh.Trimesh(vertices=cam_vert_pts, faces=model_mesh.faces, process=False) 92 | # (1) sampling in surface with gaussian noise 93 | surf_ratio = float(opt.sample_ratio) / 8 94 | surface_points_cam, _ = trimesh.sample.sample_surface(cam_mesh, int(surf_ratio * opt.num_sample_inout)) 95 | # with gaussian noise 96 | sigma = opt.sigma_ratio * opt.clamp_dist 97 | noisy_surface_points_cam = surface_points_cam + np.random.normal(scale=sigma, size=surface_points_cam.shape) 98 | 99 | # (2) sampling in tight sphere: add random points within image space 100 | # 16:1=1250/16:0.5=625 in tight sphere 101 | bd_length = bmax - bmin 102 | zero_rot = np.identity(3) 103 | wks_ratio = opt.sample_ratio // 4 104 | if bounding == 'abb': 105 | bounding_points_model = np.random.rand(opt.num_sample_inout // wks_ratio, 3) * bd_length + bmin 106 | elif bounding == 'sphere': 107 | radius = bd_length.max() / 2 108 | bounding_points_model = sampling_in_ball(opt.num_sample_inout // wks_ratio, 3, radius=radius) 109 | # (N, 3) 110 | bounding_points_trans = (zero_rot.dot(bounding_points_model.T) + m2c_trans.reshape((3, 1))).T 111 | 112 | # (3) sampling in 3D frustum inside the 3D genearl workspace in front of the camera 113 | # 16:1=1250/16:0.5=625 in 3D workspace 114 | wks_sample_flag = True 115 | frustum_points_trans_list = [] 116 | wks_length = cam_bmax - cam_bmin 117 | while wks_sample_flag: 118 | # (N, 3) 119 | wks_points_trans = np.random.rand((opt.num_sample_inout // wks_ratio) * 10, 3) * wks_length + cam_bmin 120 | # filter out pts not in camera frustum 121 | # (N,) 122 | wks_not_in_img = out_of_plane_mask_calc(wks_points_trans, calib, opt.img_size) 123 | # (N,) 124 | wks_in_img = np.logical_not(wks_not_in_img) 125 | frustum_points_trans_list = frustum_points_trans_list + wks_points_trans[wks_in_img].tolist() 126 | if len(frustum_points_trans_list) >= (opt.num_sample_inout // wks_ratio): 127 | wks_sample_flag = False 128 | frustum_points_trans = np.array(frustum_points_trans_list[:(opt.num_sample_inout // wks_ratio)]) 129 | 130 | # (N, 3): combine all 21250 points 131 | sample_points_cam = np.concatenate([noisy_surface_points_cam, bounding_points_trans, frustum_points_trans], 0) 132 | np.random.shuffle(sample_points_cam) 133 | 134 | inside = cam_mesh.contains(sample_points_cam) 135 | inside_points = sample_points_cam[inside] 136 | outside_points = sample_points_cam[np.logical_not(inside)] 137 | 138 | nin = inside_points.shape[0] 139 | inside_points = inside_points[ 140 | :opt.num_sample_inout // 2] if nin > opt.num_sample_inout // 2 else inside_points 141 | outside_points = outside_points[ 142 | :opt.num_sample_inout // 2] if nin > opt.num_sample_inout // 2 else outside_points[ 143 | :(opt.num_sample_inout - nin)] 144 | # (N, 3) 145 | cam_sample_pts = np.concatenate([inside_points, outside_points], 0) 146 | 147 | # trimesh-based ray-SDF 148 | if opt.out_type == 'rsdf': 149 | # (N, 1) 150 | labels = np.concatenate([np.ones((1, inside_points.shape[0])), np.zeros((1, outside_points.shape[0]))], 1).T 151 | 152 | ray_mesh_emb = ray_pyembree.RayMeshIntersector(cam_mesh, scale_to_box=False) 153 | cam_sample_pt_sdf = np.zeros(cam_sample_pts.shape[0]) 154 | ray_origins = np.zeros_like(cam_sample_pts) 155 | delta_vect = (cam_sample_pts - ray_origins) 156 | norm_delta = np.expand_dims(np.linalg.norm(delta_vect, axis=1), axis=1) 157 | unit_ray_dir = delta_vect / norm_delta 158 | 159 | # intersect = ray_mesh_emb.intersects_any(ray_origins, unit_ray_dir) 160 | _, hit_index_ray, hit_locations = 
ray_mesh_emb.intersects_id(ray_origins, unit_ray_dir, multiple_hits=True, return_locations=True) 161 | # intersect mask 162 | hit_unique_idx_ray = np.unique(hit_index_ray) 163 | hit_ray_mask = np.zeros(cam_sample_pts.shape[0], dtype=bool) 164 | hit_ray_mask[hit_unique_idx_ray] = True 165 | for idx, pt in enumerate(cam_sample_pts): 166 | if hit_ray_mask[idx]: 167 | min_df = np.inf 168 | hit_idx_list = (np.where(np.array(hit_index_ray) == idx)[0]).tolist() 169 | for hit_idx in hit_idx_list: 170 | cur_df = np.linalg.norm((hit_locations[hit_idx] - pt)) 171 | if cur_df < min_df: 172 | min_df = cur_df 173 | if labels[idx]: 174 | cam_sample_pt_sdf[idx] = -min_df 175 | else: 176 | cam_sample_pt_sdf[idx] = min_df 177 | else: 178 | cam_sample_pt_sdf[idx] = 100 * opt.clamp_dist 179 | # pysdf-based conventional-SDF 180 | if opt.out_type == 'csdf': 181 | sdf_calc_func = SDF(cam_mesh.vertices, cam_mesh.faces) 182 | cam_sample_pt_sdf = (-1) * sdf_calc_func(cam_sample_pts) 183 | 184 | # shape (N, 1) 185 | sdfs = np.expand_dims(cam_sample_pt_sdf, axis=1) 186 | 187 | # deal with out-of-plane cases 188 | not_in_img = out_of_plane_mask_calc(cam_sample_pts, calib, opt.img_size) 189 | sdfs[not_in_img] = 100 * opt.clamp_dist 190 | 191 | norm_sdfs = sdfs / (opt.clamp_dist / opt.norm_clamp_dist) 192 | 193 | # obtain for xyz of correspondence in model space 194 | inverse_ext = torch.inverse(extrinsic) 195 | c2m_rot = inverse_ext[:3, :3] 196 | c2m_trans = inverse_ext[:3, 3:4] 197 | # (3, N) 198 | model_sample_pts = torch.addmm(c2m_trans, c2m_rot, torch.Tensor(cam_sample_pts.T).float()).float() 199 | norm_xyz_factor = opt.bbx_size / 2 200 | norm_model_sample_pts = model_sample_pts / norm_xyz_factor 201 | # (3, N) 202 | cam_sample_pts = torch.Tensor(cam_sample_pts.T).float() 203 | # (1, N) 204 | norm_sdfs = torch.Tensor(norm_sdfs.T).float() 205 | 206 | del model_mesh 207 | del cam_mesh 208 | 209 | return { 210 | 'samples': cam_sample_pts, 211 | 'labels': norm_sdfs, 212 | 'xyzs': norm_model_sample_pts, 213 | 'norm_xyz_factor': torch.tensor(norm_xyz_factor) 214 | } 215 | 216 | 217 | # efficient conventional-SDF 218 | def wks_sampling_eff_csdf_xyz_calc(opt, bmax, bmin, cam_bmax, cam_bmin, model_mesh, extrinsic, calib, bounding): 219 | # if not self.is_train: 220 | # random.seed(1991) 221 | # np.random.seed(1991) 222 | # torch.manual_seed(1991) 223 | 224 | # extrinsic to transform from model to cam. space 225 | m2c_rot = extrinsic.numpy()[:3, :3] 226 | m2c_trans = extrinsic.numpy()[:3, 3:4] 227 | # (N, 3) 228 | cam_vert_pts = (m2c_rot.dot(model_mesh.vertices.T) + m2c_trans.reshape((3, 1))).T 229 | # load mesh in cam. 
space 230 | cam_mesh = trimesh.Trimesh(vertices=cam_vert_pts, faces=model_mesh.faces, process=False) 231 | # (1) sampling in surface with gaussian noise 232 | surf_ratio = float(opt.sample_ratio) / 8 233 | surface_points_cam, _ = trimesh.sample.sample_surface(cam_mesh, int(surf_ratio * opt.num_sample_inout)) 234 | # with gaussian noise 235 | sigma = opt.sigma_ratio * opt.clamp_dist 236 | noisy_surface_points_cam = surface_points_cam + np.random.normal(scale=sigma, size=surface_points_cam.shape) 237 | 238 | # (2) sampling in tight sphere: add random points within image space 239 | # 16:1=1250/16:0.5=625 in tight sphere 240 | bd_length = bmax - bmin 241 | zero_rot = np.identity(3) 242 | wks_ratio = opt.sample_ratio // 4 243 | if bounding == 'abb': 244 | bounding_points_model = np.random.rand(opt.num_sample_inout // wks_ratio, 3) * bd_length + bmin 245 | elif bounding == 'sphere': 246 | radius = bd_length.max() / 2 247 | bounding_points_model = sampling_in_ball(opt.num_sample_inout // wks_ratio, 3, radius=radius) 248 | # (N, 3) 249 | bounding_points_trans = (zero_rot.dot(bounding_points_model.T) + m2c_trans.reshape((3, 1))).T 250 | 251 | # (3) sampling in 3D frustum inside the 3D genearl workspace in front of the camera 252 | # 16:1=1250/16:0.5=625 in 3D workspace 253 | wks_sample_flag = True 254 | frustum_points_trans_list = [] 255 | wks_length = cam_bmax - cam_bmin 256 | while wks_sample_flag: 257 | # (N, 3) 258 | wks_points_trans = np.random.rand((opt.num_sample_inout // wks_ratio) * 10, 3) * wks_length + cam_bmin 259 | # filter out pts not in camera frustum 260 | # (N,) 261 | wks_not_in_img = out_of_plane_mask_calc(wks_points_trans, calib, opt.img_size) 262 | # (N,) 263 | wks_in_img = np.logical_not(wks_not_in_img) 264 | frustum_points_trans_list = frustum_points_trans_list + wks_points_trans[wks_in_img].tolist() 265 | if len(frustum_points_trans_list) >= (opt.num_sample_inout // wks_ratio): 266 | wks_sample_flag = False 267 | frustum_points_trans = np.array(frustum_points_trans_list[:(opt.num_sample_inout // wks_ratio)]) 268 | 269 | # (N, 3): combine all 21250 points 270 | sample_points_cam = np.concatenate([noisy_surface_points_cam, bounding_points_trans, frustum_points_trans], 0) 271 | np.random.shuffle(sample_points_cam) 272 | 273 | # pysdf-based conventional-SDF 274 | sdf_calc_func = SDF(cam_mesh.vertices, cam_mesh.faces) 275 | sample_points_cam_sdf = (-1) * sdf_calc_func(sample_points_cam) 276 | 277 | inside = (sample_points_cam_sdf < 0) 278 | inside_points = sample_points_cam[inside] 279 | outside_points = sample_points_cam[np.logical_not(inside)] 280 | inside_points_sdf = sample_points_cam_sdf[inside] 281 | outside_points_sdf = sample_points_cam_sdf[np.logical_not(inside)] 282 | 283 | nin = inside_points.shape[0] 284 | inside_points = inside_points[ 285 | :opt.num_sample_inout // 2] if nin > opt.num_sample_inout // 2 else inside_points 286 | outside_points = outside_points[ 287 | :opt.num_sample_inout // 2] if nin > opt.num_sample_inout // 2 else outside_points[ 288 | :(opt.num_sample_inout - nin)] 289 | inside_points_sdf = inside_points_sdf[ 290 | :opt.num_sample_inout // 2] if nin > opt.num_sample_inout // 2 else inside_points_sdf 291 | outside_points_sdf = outside_points_sdf[ 292 | :opt.num_sample_inout // 2] if nin > opt.num_sample_inout // 2 else outside_points_sdf[ 293 | :(opt.num_sample_inout - nin)] 294 | # (N, 3) 295 | cam_sample_pts = np.concatenate([inside_points, outside_points], 0) 296 | cam_sample_pt_sdf = np.concatenate([inside_points_sdf, outside_points_sdf], 0) 
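    # Note: cam_sample_pts / cam_sample_pt_sdf now hold at most opt.num_sample_inout queries,
    # with inside points (SDF < 0 under the sign convention used here) capped at
    # num_sample_inout // 2, so a training batch stays roughly balanced between inside and
    # outside samples.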
297 | 298 | # shape (N, 1) 299 | sdfs = np.expand_dims(cam_sample_pt_sdf, axis=1) 300 | 301 | # deal with out-of-plane cases 302 | not_in_img = out_of_plane_mask_calc(cam_sample_pts, calib, opt.img_size) 303 | sdfs[not_in_img] = 100 * opt.clamp_dist 304 | 305 | norm_sdfs = sdfs / (opt.clamp_dist / opt.norm_clamp_dist) 306 | 307 | # obtain for xyz of correspondence in model space 308 | inverse_ext = torch.inverse(extrinsic) 309 | c2m_rot = inverse_ext[:3, :3] 310 | c2m_trans = inverse_ext[:3, 3:4] 311 | # (3, N) 312 | model_sample_pts = torch.addmm(c2m_trans, c2m_rot, torch.Tensor(cam_sample_pts.T).float()).float() 313 | norm_xyz_factor = opt.bbx_size / 2 314 | norm_model_sample_pts = model_sample_pts / norm_xyz_factor 315 | # (3, N) 316 | cam_sample_pts = torch.Tensor(cam_sample_pts.T).float() 317 | # (1, N) 318 | norm_sdfs = torch.Tensor(norm_sdfs.T).float() 319 | 320 | del model_mesh 321 | del cam_mesh 322 | 323 | return { 324 | 'samples': cam_sample_pts, 325 | 'labels': norm_sdfs, 326 | 'xyzs': norm_model_sample_pts, 327 | 'norm_xyz_factor': torch.tensor(norm_xyz_factor) 328 | } 329 | 330 | 331 | def xyz_mask_calc(sdfs, xyz_range): 332 | 333 | # shape (1, num_sample_inout) 334 | return {'xyz_mask': (abs(sdfs) < xyz_range).float()} -------------------------------------------------------------------------------- /lib/debug_pyrender_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | """ 5 | debugging utils 6 | """ 7 | 8 | import os 9 | import sys 10 | import pdb 11 | import code 12 | import json 13 | import random 14 | import pickle 15 | import warnings 16 | import datetime 17 | import subprocess 18 | 19 | import PIL 20 | import cv2 21 | import torch 22 | import numpy as np 23 | from matplotlib import pyplot as plt 24 | 25 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 26 | # import pyrender 27 | 28 | import trimesh 29 | # import transforms3d as t3d 30 | 31 | 32 | # """ 33 | # pyrender-based rendering 34 | # """ 35 | # class Renderer(object): 36 | # """ 37 | # Render mesh using PyRender for visualization. 
38 | # in m unit by default 39 | # """ 40 | # def __init__(self, alight_color, dlight_color, dlight_int=2.0, bg='black', im_width=640, im_height=480): 41 | 42 | # self.im_width = im_width 43 | # self.im_height = im_height 44 | # # light initialization 45 | # self.alight_color = alight_color 46 | # self.dlight_int = dlight_int 47 | # self.dlight_color = dlight_color 48 | # # blending coe for bg 49 | # if bg == 'white': 50 | # self.bg_color = [1.0, 1.0, 1.0] 51 | # elif bg == 'black': 52 | # self.bg_color = [0.0, 0.0, 0.0] 53 | 54 | # # render creation 55 | # self.renderer = pyrender.OffscreenRenderer(self.im_width, self.im_height) 56 | # # renderer_flags = pyrender.constants.RenderFlags.DEPTH_ONLY 57 | # # renderer_flags = pyrender.constants.RenderFlags.FLAT 58 | # # renderer_flags = pyrender.constants.RenderFlags.RGBA 59 | 60 | # # light creation 61 | # self.direc_light = pyrender.DirectionalLight(color=self.dlight_color, intensity=self.dlight_int) 62 | 63 | # def render(self, cam_intr, cam_pose, tri_mesh): 64 | 65 | # # scene creation 66 | # self.scene = pyrender.Scene(ambient_light=self.alight_color, bg_color=self.bg_color) 67 | 68 | # # camera 69 | # K = np.copy(cam_intr) 70 | # fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] 71 | # # fx, fy, cx, cy = K[0][0], K[1][1], K[0][2], K[1][2] 72 | # camera = pyrender.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy) 73 | 74 | # # Object->Camera to Camera->Object. 75 | # # camera_pose = np.linalg.inv(camera_pose) 76 | 77 | # # OpenCV to OpenGL coordinate system. 78 | # camera_pose = self.opencv_to_opengl_transformation(cam_pose) 79 | 80 | # # create mesh node 81 | # tri_mesh.vertices *= 0.001 # To meters. 82 | # mesh = pyrender.Mesh.from_trimesh(tri_mesh) 83 | 84 | # # add mesh 85 | # self.scene.add(mesh) 86 | # # add direc_light 87 | # self.scene.add(self.direc_light, pose=camera_pose) 88 | # # Create a camera node and add 89 | # camera_node = pyrender.Node(camera=camera, matrix=camera_pose) 90 | # self.scene.add_node(camera_node) 91 | 92 | # # render 93 | # color, _ = self.renderer.render(self.scene) 94 | # # color, _ = renderer.render(scene, flags=renderer_flags) 95 | # self.scene.remove_node(camera_node) 96 | # color = np.uint8(color) 97 | 98 | # return color 99 | 100 | # def opencv_to_opengl_transformation(self, trans): 101 | # """Converts a transformation from OpenCV to OpenGL coordinate system. 102 | 103 | # :param trans: A 4x4 transformation matrix. 104 | # """ 105 | # yz_flip = np.eye(4, dtype=np.float64) 106 | # yz_flip[1, 1], yz_flip[2, 2] = -1, -1 107 | # trans = trans.dot(yz_flip) 108 | # return trans 109 | 110 | """ 111 | save point cloud 112 | """ 113 | def save_samples_rgb(fname, points, rgb): 114 | ''' 115 | Save the visualization of sampling to a ply file. 116 | Red points represent positive predictions. 117 | Green points represent negative predictions. 
118 | :param fname: File name to save 119 | :param points: [N, 3] array of points 120 | :param rgb: [N, 3] array of rgb values in the range [0~1] 121 | :return: 122 | ''' 123 | to_save = np.concatenate([points, rgb * 255], axis=-1) 124 | return np.savetxt(fname, 125 | to_save, 126 | fmt='%.6f %.6f %.6f %d %d %d', 127 | comments='', 128 | header=( 129 | 'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format( 130 | points.shape[0]) 131 | ) 132 | 133 | def save_samples_truncted_prob(fname, points, prob): 134 | ''' 135 | Save the visualization of sampling to a ply file. 136 | Red points represent positive predictions. 137 | Green points represent negative predictions. 138 | :param fname: File name to save 139 | :param points: [N, 3] array of points 140 | :param prob: [N, 1] array of predictions in the range [0~1] 141 | :return: 142 | ''' 143 | r = (prob > 0.5).reshape([-1, 1]) * 255 144 | g = (prob < 0.5).reshape([-1, 1]) * 255 145 | b = np.zeros(r.shape) 146 | 147 | to_save = np.concatenate([points, r, g, b], axis=-1) 148 | return np.savetxt(fname, 149 | to_save, 150 | fmt='%.6f %.6f %.6f %d %d %d', 151 | comments='', 152 | header=( 153 | 'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format( 154 | points.shape[0]) 155 | ) 156 | 157 | def save_samples_truncted_sdf(fname, points, sdf, thres): 158 | ''' 159 | Save the visualization of sampling to a ply file. 160 | Red points represent positive predictions. 161 | Green points represent negative predictions. 162 | :param fname: File name to save 163 | :param points: [N, 3] array of points 164 | :param sdf: [N, 1] array of predictions in the range [0~1] 165 | :return: 166 | ''' 167 | r = (sdf <= -thres).reshape([-1, 1]) * 255 168 | g = (sdf >= thres).reshape([-1, 1]) * 255 169 | b = (abs(sdf) < thres).reshape([-1, 1]) * 255 170 | # b = np.zeros(r.shape) 171 | # pdb.set_trace() 172 | to_save = np.concatenate([points, r, g, b], axis=-1) 173 | return np.savetxt(fname, 174 | to_save, 175 | fmt='%.6f %.6f %.6f %d %d %d', 176 | comments='', 177 | header=( 178 | 'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format( 179 | points.shape[0]) 180 | ) 181 | 182 | """ 183 | save mesh 184 | """ 185 | def save_obj_mesh(mesh_path, verts, faces): 186 | file = open(mesh_path, 'w') 187 | 188 | for v in verts: 189 | file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) 190 | for f in faces: 191 | f_plus = f + 1 192 | file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1])) 193 | file.close() 194 | 195 | 196 | def save_obj_mesh_with_color(mesh_path, verts, faces, colors): 197 | file = open(mesh_path, 'w') 198 | 199 | for idx, v in enumerate(verts): 200 | c = colors[idx] 201 | file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2])) 202 | for f in faces: 203 | f_plus = f + 1 204 | file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1])) 205 | file.close() 206 | 207 | 208 | def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs): 209 | file = open(mesh_path, 'w') 210 | 211 | for idx, v in enumerate(verts): 212 | vt = uvs[idx] 213 | file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) 214 | file.write('vt %.4f %.4f\n' % (vt[0], vt[1])) 215 | 216 | for f 
in faces: 217 | f_plus = f + 1 218 | file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0], 219 | f_plus[2], f_plus[2], 220 | f_plus[1], f_plus[1])) 221 | file.close() 222 | 223 | 224 | """ 225 | viz img, mask, rendering 226 | """ 227 | def viz_debug_data(img, model_mesh, extrinsic, aug_intrinsic, save_debug_path): 228 | 229 | fig = plt.figure(figsize=(3, 3)) 230 | ax = fig.add_subplot(1,1,1) 231 | ax.imshow(img) 232 | plt.axis('off') 233 | 234 | plt.tight_layout() 235 | plt.savefig(save_debug_path, dpi=100) 236 | 237 | 238 | """ 239 | viz query projection for debugging 240 | """ 241 | def viz_debug_query(out_type, res, save_in_query_path, save_out_query_path): 242 | 243 | # from RGB order to opencv BGR order 244 | img = np.uint8((np.transpose(res['img'].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0) 245 | img_cp = np.copy(img) 246 | rot = res['calib'][:3, :3] 247 | trans = res['calib'][:3, 3:4] 248 | 249 | # draw points inside 250 | # pts = torch.addmm(trans, rot, sample_data['samples']) # [3, N] 251 | if out_type[-3:] == 'sdf': 252 | pts = torch.addmm(trans, rot, res['samples'][:, res['labels'][0] < 0]) # [3, N] 253 | uv = pts[:2, :] / pts[2:3, :] 254 | uvz = torch.cat([uv, pts[2:3, :]], 0) 255 | # draw projected queries 256 | img = np.ascontiguousarray(img, dtype=np.uint8) 257 | for pt in torch.transpose(uvz, 0, 1): 258 | img = cv2.circle(img, (int(pt[0]), int(pt[1])), 2, (0,0,255), -1) 259 | cv2.imwrite(save_in_query_path, img) 260 | 261 | # draw points outside 262 | if out_type[-3:] == 'sdf': 263 | pts = torch.addmm(trans, rot, res['samples'][:, res['labels'][0] > 0]) # [3, N] 264 | uv = pts[:2, :] / pts[2:3, :] 265 | uvz = torch.cat([uv, pts[2:3, :]], 0) 266 | # draw projected queries 267 | img_cp = np.ascontiguousarray(img_cp, dtype=np.uint8) 268 | for pt in torch.transpose(uvz, 0, 1): 269 | img_cp = cv2.circle(img_cp, (int(pt[0]), int(pt[1])), 2, (0,255,0), -1) 270 | cv2.imwrite(save_out_query_path, img_cp) 271 | 272 | def viz_debug_query_forward(out_type, res, save_in_query_path, save_out_query_path): 273 | 274 | # from RGB order to opencv BGR order 275 | img = np.uint8((np.transpose(res['img'].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0) 276 | img_cp = np.copy(img) 277 | 278 | # draw points inside 279 | if out_type[-3:] == 'sdf': 280 | uv = (res['samples'][:, res['labels'][0] < 0]) # [2, N] 281 | # draw projected queries 282 | img = np.ascontiguousarray(img, dtype=np.uint8) 283 | for pt in torch.transpose(uv, 0, 1): 284 | img = cv2.circle(img, (int(pt[0]), int(pt[1])), 2, (0,0,255), -1) 285 | cv2.imwrite(save_in_query_path, img) 286 | 287 | # draw points outside 288 | if out_type[-3:] == 'sdf': 289 | uv = (res['samples'][:, res['labels'][0] > 0]) # [2, N] 290 | # draw projected queries 291 | img_cp = np.ascontiguousarray(img_cp, dtype=np.uint8) 292 | for pt in torch.transpose(uv, 0, 1): 293 | img_cp = cv2.circle(img_cp, (int(pt[0]), int(pt[1])), 2, (0,255,0), -1) 294 | cv2.imwrite(save_out_query_path, img_cp) 295 | 296 | """ 297 | Meter for recording 298 | """ 299 | class AverageMeter(object): 300 | """ 301 | refer to https://github.com/bearpaw/pytorch-pose 302 | Computes and stores the average and current value 303 | """ 304 | def __init__(self): 305 | self.reset() 306 | 307 | def reset(self): 308 | self.val = 0 309 | self.avg = 0 310 | self.sum = 0 311 | self.count = 0 312 | 313 | def update(self, val, n=1): 314 | self.val = val 315 | self.sum += val * n 316 | self.count += n 317 | self.avg = self.sum / self.count 318 | 
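
# Minimal usage sketch for the helpers above (illustrative only; the file path, point count and
# threshold below are assumptions, not values used by the original pipeline):
if __name__ == '__main__':
    # running average of a scalar metric
    meter = AverageMeter()
    for loss in [0.9, 0.7, 0.4]:
        meter.update(loss)
    print(f'last={meter.val:.2f} avg={meter.avg:.2f}')

    # save a random point cloud colored by a truncated SDF
    # (red: sdf <= -thres, green: sdf >= thres, blue: |sdf| < thres)
    pts = np.random.uniform(-1.0, 1.0, size=(1000, 3))
    sdf = np.random.uniform(-0.2, 0.2, size=(1000, 1))
    save_samples_truncted_sdf('/tmp/debug_sdf.ply', pts, sdf, thres=0.1)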
-------------------------------------------------------------------------------- /lib/eval_Rt_time_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | comprehensive evaluation for: 3 | SDF, predicted corresopndence, 6D pose 4 | """ 5 | 6 | import os 7 | import json 8 | import time 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import numpy as np 13 | from PIL import Image 14 | 15 | import trimesh 16 | 17 | from .geometry import * 18 | 19 | from lib.rigid_fit.ransac import RansacEstimator 20 | from lib.rigid_fit.ransac_kabsch import Procrustes 21 | 22 | from .sdf import create_grid, eval_sdf_xyz_grid_frustum 23 | 24 | 25 | def save_bop_results(path, results, version='bop19'): 26 | """Saves 6D object pose estimates to a file. 27 | :param path: Path to the output file. 28 | :param results: Dictionary with pose estimates. 29 | :param version: Version of the results. 30 | """ 31 | # See docs/bop_challenge_2019.md for details. 32 | if version == 'bop19': 33 | lines = ['scene_id,im_id,obj_id,score,R,t,time'] 34 | for res in results: 35 | if 'time' in res: 36 | run_time = res['time'] 37 | else: 38 | run_time = -1 39 | 40 | lines.append('{scene_id},{im_id},{obj_id},{score},{R},{t},{time}'.format( 41 | scene_id=res['scene_id'], 42 | im_id=res['im_id'], 43 | obj_id=res['obj_id'], 44 | score=res['score'], 45 | R=' '.join(map(str, res['R'].flatten().tolist())), 46 | t=' '.join(map(str, res['t'].flatten().tolist())), 47 | time=run_time)) 48 | 49 | with open(path, 'w') as f: 50 | f.write('\n'.join(lines)) 51 | 52 | else: 53 | raise ValueError('Unknown version of BOP results.') 54 | 55 | 56 | def out_of_plane_mask_calc(cam_pts, calib, img_size): 57 | # deal with out-of-plane cases 58 | c2i_rot = calib[:3, :3] 59 | c2i_trans = calib[:3, 3:4] 60 | img_sample_pts = torch.addmm(c2i_trans, c2i_rot, torch.Tensor(cam_pts.T).float()) 61 | img_sample_uvs = img_sample_pts[:2, :] / img_sample_pts[2:3, :] 62 | 63 | # normalize to [-1,1] 64 | transforms = torch.zeros([2,3]) 65 | transforms[0,0] = 1 / (img_size[0] // 2) 66 | transforms[1,1] = 1 / (img_size[1] // 2) 67 | transforms[0,2] = -1 68 | transforms[1,2] = -1 69 | scale = transforms[:2, :2] 70 | shift = transforms[:2, 2:3] 71 | img_sample_norm_uvs = torch.addmm(shift, scale, img_sample_uvs) 72 | in_img = (img_sample_norm_uvs[0,:] >= -1.0) & (img_sample_norm_uvs[0,:] <= 1.0) & (img_sample_norm_uvs[1,:] >= -1.0) & (img_sample_norm_uvs[1,:] <= 1.0) 73 | not_in_img = torch.logical_not(in_img).numpy() 74 | 75 | return not_in_img 76 | 77 | 78 | """ 79 | generate 6D rigid pose based on SDF & Corresopndence 80 | calculate eval. 
time 81 | """ 82 | def eval_Rt_time(opt, net, test_data_loader, save_csv_path): 83 | 84 | with torch.no_grad(): 85 | preds = [] 86 | # for test_idx, test_data in enumerate(test_data_loader): 87 | for test_idx, test_data in enumerate(tqdm(test_data_loader)): 88 | 89 | # retrieve the data 90 | # resolution = opt.resolution 91 | resolution_X = int(opt.test_wks_size[0] / opt.step_size) 92 | resolution_Y = int(opt.test_wks_size[1] / opt.step_size) 93 | resolution_Z = int(opt.test_wks_size[2] / opt.step_size) 94 | image_tensor = test_data['img'].cuda() 95 | calib_tensor = test_data['calib'].cuda() 96 | norm_xyz_factor = test_data['norm_xyz_factor'][0].item() 97 | 98 | # get all 3D queries 99 | # create a grid by resolution 100 | # and transforming matrix for grid coordinates to real world xyz 101 | b_min = np.array(test_data['test_b_min'][0]) 102 | b_max = np.array(test_data['test_b_max'][0]) 103 | coords, mat = create_grid(resolution_X, resolution_Y, resolution_Z, b_min, b_max, transform=None) 104 | # (M=KxKxK, 3) 105 | coords = coords.reshape([3, -1]).T 106 | # (M,) 107 | coords_not_in_img = out_of_plane_mask_calc(coords, test_data['calib'][0], opt.img_size) 108 | # (M,) 109 | coords_in_img = np.logical_not(coords_not_in_img) 110 | # (3, N) 111 | coords_in_frustum = coords[coords_in_img].T 112 | 113 | # transform for proj. 114 | transforms = torch.zeros([1,2,3]).cuda() 115 | transforms[:, 0,0] = 1 / (opt.img_size[0] // 2) 116 | transforms[:, 1,1] = 1 / (opt.img_size[1] // 2) 117 | transforms[:, 0,2] = -1 118 | transforms[:, 1,2] = -1 119 | 120 | # create ransac 121 | ransac = RansacEstimator( 122 | min_samples=opt.min_samples, 123 | residual_threshold=(opt.res_thresh)**2, 124 | max_trials=opt.max_trials, 125 | ) 126 | 127 | eval_start_time = time.time() 128 | # get 2D feat. maps 129 | net.filter(image_tensor) 130 | # Then we define the lambda function for cell evaluation 131 | def eval_func(points): 132 | points = np.expand_dims(points, axis=0) 133 | # points = np.repeat(points, net.num_views, axis=0) 134 | samples = torch.from_numpy(points).cuda().float() 135 | 136 | transforms = torch.zeros([1,2,3]).cuda() 137 | transforms[:, 0,0] = 1 / (opt.img_size[0] // 2) 138 | transforms[:, 1,1] = 1 / (opt.img_size[1] // 2) 139 | transforms[:, 0,2] = -1 140 | transforms[:, 1,2] = -1 141 | net.query(samples, calib_tensor, transforms=transforms) 142 | # shape (B, 1, N) -> (N) 143 | eval_sdfs = net.preds[0][0] 144 | # shape (B, 3, N) -> (3, N) 145 | eval_xyzs = net.xyzs[0] 146 | return eval_sdfs.detach().cpu().numpy(), eval_xyzs.detach().cpu().numpy() 147 | # (N), (3, N), all the predicted dfs and xyzs 148 | pred_sdfs, pred_xyzs = eval_sdf_xyz_grid_frustum(coords_in_frustum, eval_func, num_samples=opt.num_in_batch) 149 | # norm_xyz_factor = max(opt.bbx_size) / 2 150 | pred_xyzs = pred_xyzs * norm_xyz_factor 151 | # get sdf & xyz within clamping distance 152 | pos_anchor_mask = (abs(pred_sdfs) < opt.norm_clamp_dist) 153 | est_cam_pts = coords_in_frustum[:, pos_anchor_mask] 154 | est_model_pts = pred_xyzs[:, pos_anchor_mask] 155 | # mask_sdfs = pred_sdfs[pos_anchor_mask] 156 | 157 | # estimate 6D pose with RANSAC-based kabsch or procruste 158 | ret = ransac.fit(Procrustes(), [est_model_pts.T, est_cam_pts.T]) 159 | eval_end_time = time.time() 160 | eval_time = eval_end_time - eval_start_time 161 | 162 | # est. 
RT 163 | RT_m2c_est = ret["best_params"] 164 | R_m2c_est = RT_m2c_est[:3, :3] 165 | t_m2c_est = RT_m2c_est[:3, 3:4] 166 | 167 | scene_id = int(test_data['folder_id'][0]) 168 | im_id = int(test_data['frame_id'][0]) 169 | obj_id = int(test_data['obj_id'][0]) 170 | pred = dict(scene_id=scene_id, 171 | im_id=im_id, 172 | obj_id=obj_id, 173 | score=1, 174 | R=np.array(R_m2c_est).reshape(3, 3), 175 | t=np.array(t_m2c_est), 176 | time=eval_time) 177 | preds.append(pred) 178 | save_bop_results(save_csv_path, preds) 179 | -------------------------------------------------------------------------------- /lib/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def index(feat, uv): 5 | ''' 6 | 7 | :param feat: [B, C, H, W] image features 8 | :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1] 9 | :return: [B, C, N] image features at the uv coordinates 10 | ''' 11 | uv = uv.transpose(1, 2) # [B, N, 2] 12 | uv = uv.unsqueeze(2) # [B, N, 1, 2] 13 | # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample 14 | # for old versions, simply remove the aligned_corners argument. 15 | samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1] 16 | return samples[:, :, :, 0] # [B, C, N] 17 | 18 | 19 | def orthogonal(points, calibrations, transforms=None): 20 | ''' 21 | Compute the orthogonal projections of 3D points into the image plane by given projection matrix 22 | :param points: [B, 3, N] Tensor of 3D points 23 | :param calibrations: [B, 4, 4] Tensor of projection matrix 24 | :param transforms: [B, 2, 3] Tensor of image transform matrix 25 | :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane 26 | ''' 27 | rot = calibrations[:, :3, :3] 28 | trans = calibrations[:, :3, 3:4] 29 | pts = torch.baddbmm(trans, rot, points) # [B, 3, N] 30 | if transforms is not None: 31 | scale = transforms[:2, :2] 32 | shift = transforms[:2, 2:3] 33 | pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :]) 34 | return pts 35 | 36 | 37 | def perspective(points, calibrations, transforms=None): 38 | ''' 39 | Compute the perspective projections of 3D points into the image plane by given projection matrix 40 | :param points: [Bx3xN] Tensor of 3D points 41 | :param calibrations: [Bx4x4] Tensor of projection matrix 42 | :param transforms: [Bx2x3] Tensor of image transform matrix 43 | :return: uv: [Bx2xN] Tensor of uv coordinates in the image plane 44 | ''' 45 | rot = calibrations[:, :3, :3] 46 | trans = calibrations[:, :3, 3:4] 47 | homo = torch.baddbmm(trans, rot, points) # [B, 3, N] 48 | uv = homo[:, :2, :] / homo[:, 2:3, :] 49 | if transforms is not None: 50 | scale = transforms[:, :2, :2] 51 | shift = transforms[:, :2, 2:3] 52 | # scale = transforms[:2, :2] 53 | # shift = transforms[:2, 2:3] 54 | uv = torch.baddbmm(shift, scale, uv) 55 | 56 | uvz = torch.cat([uv, homo[:, 2:3, :]], 1) 57 | return uvz 58 | -------------------------------------------------------------------------------- /lib/loss_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2020-present Zerong Zheng. All Rights Reserved. 
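# Hypothetical call pattern for the XYZ losses defined below (shapes follow the docstrings;
# the variable names are illustrative, not taken from the training code):
#
#   criterion = XYZLoss(use_xyz_mask=True)
#   loss = criterion(pred_xyz, gt_xyz, xyz_mask)  # (B, 3, N), (B, 3, N), (B, 1, N)
#
# where xyz_mask keeps only queries whose |SDF| lies within the clamping distance
# (see xyz_mask_calc in lib/data_utils/sample_frustum_util.py).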
3 | 4 | import os 5 | import json 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import pdb 11 | 12 | 13 | """ 14 | XYZ loss w/o symmetry 15 | """ 16 | class XYZLoss(nn.Module): 17 | def __init__(self, use_xyz_mask=True): 18 | super(XYZLoss, self).__init__() 19 | # self.criterion = nn.MSELoss(reduction='mean') 20 | self.criterion = nn.SmoothL1Loss() 21 | self.use_xyz_mask = use_xyz_mask 22 | 23 | def forward(self, output, target, xyz_mask): 24 | ''' 25 | should be consistent for tensor shape 26 | (B, 3, N)/(B, 3, N)/(B, 1, N) or 27 | (B, N, 3)/(B, N, 3)/(B, N, 1) 28 | ''' 29 | if self.use_xyz_mask: 30 | loss = self.criterion( 31 | output.mul(xyz_mask), 32 | target.mul(xyz_mask) 33 | ) 34 | else: 35 | # loss += 0.5 * self.criterion(xyz_pred, xyz_gt) 36 | loss = self.criterion(output, target) 37 | 38 | return loss 39 | 40 | """ 41 | XYZ loss with symmetry 42 | """ 43 | class XYZLoss_sym(nn.Module): 44 | def __init__(self, use_xyz_mask=True, sym_pool=None): 45 | super(XYZLoss_sym, self).__init__() 46 | # self.criterion = nn.MSELoss(reduction='mean') 47 | self.criterion = nn.SmoothL1Loss(reduction='none') 48 | self.use_xyz_mask = use_xyz_mask 49 | 50 | self.sym_pool = sym_pool 51 | 52 | def forward(self, output, target, xyz_mask): 53 | ''' 54 | should be consistent for tensor shape 55 | (B, 3, N)/(B, 3, N)/(B, 1, N) or 56 | (B, N, 3)/(B, N, 3)/(B, N, 1) 57 | ''' 58 | output = output.permute(0,2,1) 59 | target = target.permute(0,2,1) 60 | xyz_mask = xyz_mask.permute(0,2,1) 61 | if (len(self.sym_pool) > 1): 62 | for sym_id, transform in enumerate(self.sym_pool): 63 | # repeat: (3, 3) -> (B, 3, 3) 64 | rot = transform[:3, :3].cuda().repeat((target.size(0),1,1)) 65 | # repeat: (3, 1) -> (B, 3, 1) 66 | trans = transform[:3, 3:4].cuda().repeat((target.size(0),1,1)) 67 | # (B, 3, 3) * (B, 3, N) + (B, 3, 1) -> (B, 3, N) -> (B, N, 3) 68 | sym_target = torch.baddbmm(trans, rot, target.permute(0,2,1)).permute(0,2,1) 69 | if self.use_xyz_mask: 70 | # (B, N, 3) 71 | loss_xyz_temp = self.criterion(output.mul(xyz_mask), sym_target.mul(xyz_mask)) 72 | else: 73 | # loss += 0.5 * self.criterion(xyz_pred, xyz_gt) 74 | # (B, N, 3) 75 | loss_xyz_temp = self.criterion(output, sym_target) 76 | # (B, N) 77 | loss_xyz_temp = torch.sum(loss_xyz_temp, dim=2) / 3 78 | # (B) 79 | loss_sum = torch.sum(loss_xyz_temp, dim=1) 80 | if(sym_id > 0): 81 | # (M, B) 82 | loss_sums = torch.cat((loss_sums, loss_sum.unsqueeze(0)), dim=0) 83 | # (M, B, N) 84 | loss_xyzs = torch.cat((loss_xyzs, loss_xyz_temp.unsqueeze(0)), dim=0) 85 | else: 86 | loss_sums = loss_sum.unsqueeze(0) 87 | loss_xyzs = loss_xyz_temp.unsqueeze(0) 88 | # (1, B) 89 | min_values = torch.min(loss_sums, dim=0, keepdim=True)[0] 90 | # (M, B) 91 | loss_switch = torch.eq(loss_sums, min_values).type(output.dtype) 92 | # (M, B, 1) * (M, B, N) -> (M, B, N) 93 | loss_xyz = loss_switch.unsqueeze(2) * loss_xyzs 94 | # (B, N) 95 | loss_xyz = torch.sum(loss_xyz, dim=0) 96 | else: 97 | if self.use_xyz_mask: 98 | # (B, N, 3) 99 | loss_xyz = self.criterion(output.mul(xyz_mask), target.mul(xyz_mask)) 100 | else: 101 | # (B, N, 3) 102 | loss_xyz = self.criterion(output, target) 103 | # (B, N) 104 | loss_xyz = torch.sum(loss_xyz, dim=2) / 3 105 | loss = loss_xyz 106 | loss = torch.mean(loss) 107 | 108 | return loss 109 | 110 | class XYZLoss_old(nn.Module): 111 | def __init__(self, use_xyz_mask=True): 112 | super(XYZLoss_orig, self).__init__() 113 | # self.criterion = nn.MSELoss(reduction='mean') 114 | self.criterion = nn.SmoothL1Loss() 115 | 
self.use_xyz_mask = use_xyz_mask 116 | 117 | def forward(self, output, target, use_xyz_mask): 118 | batch_size = output.size(0) 119 | num_queries = output.size(1) 120 | xyzs_pred = output.reshape((batch_size, num_queries, -1)).split(1, 1) 121 | xyzs_gt = target.reshape((batch_size, num_queries, -1)).split(1, 1) 122 | loss = 0 123 | 124 | for idx in range(num_queries): 125 | xyz_pred = xyzs_pred[idx].squeeze() 126 | xyz_gt = xyzs_gt[idx].squeeze() 127 | if self.use_xyz_mask: 128 | # loss += 0.5 * self.criterion( 129 | loss += self.criterion( 130 | xyz_pred.mul(use_xyz_mask[:, idx]), 131 | xyz_gt.mul(use_xyz_mask[:, idx]) 132 | ) 133 | else: 134 | # loss += 0.5 * self.criterion(xyz_pred, xyz_gt) 135 | loss +=self.criterion(xyz_pred, xyz_gt) 136 | 137 | return loss / num_queries 138 | 139 | 140 | class LipschitzLoss(nn.Module): 141 | def __init__(self, k, reduction=None): 142 | super(LipschitzLoss, self).__init__() 143 | self.relu = nn.ReLU() 144 | self.k = k 145 | self.reduction = reduction 146 | 147 | def forward(self, x1, x2, y1, y2): 148 | l = self.relu(torch.norm(y1-y2, dim=-1) / (torch.norm(x1-x2, dim=-1)+1e-3) - self.k) 149 | # l = torch.clamp(l, 0.0, 5.0) # avoid 150 | if self.reduction is None or self.reduction == "mean": 151 | return torch.mean(l) 152 | else: 153 | return torch.sum(l) 154 | 155 | 156 | class HuberFunc(nn.Module): 157 | def __init__(self, reduction=None): 158 | super(HuberFunc, self).__init__() 159 | self.reduction = reduction 160 | 161 | def forward(self, x, delta): 162 | n = torch.abs(x) 163 | cond = n < delta 164 | l = torch.where(cond, 0.5 * n ** 2, n*delta - 0.5 * delta**2) 165 | if self.reduction is None or self.reduction == "mean": 166 | return torch.mean(l) 167 | else: 168 | return torch.sum(l) 169 | 170 | 171 | class SoftL1Loss(nn.Module): 172 | def __init__(self, reduction=None): 173 | super(SoftL1Loss, self).__init__() 174 | self.reduction = reduction 175 | 176 | def forward(self, input, target, eps=0.0, lamb=0.0): 177 | ret = torch.abs(input - target) - eps 178 | ret = torch.clamp(ret, min=0.0, max=100.0) 179 | ret = ret * (1 + lamb * torch.sign(target) * torch.sign(target-input)) 180 | if self.reduction is None or self.reduction == "mean": 181 | return torch.mean(ret) 182 | else: 183 | return torch.sum(ret) 184 | 185 | 186 | 187 | if __name__ == '__main__': 188 | 189 | criterion1 = XYZLoss() 190 | criterion2 = XYZLoss_orig() 191 | aa = torch.rand((2,5000,3)) 192 | bb = torch.rand((2,5000,3)) 193 | # bb = aa.clone() 194 | mask = torch.rand((2,5000,1)) > 0.3 195 | pdb.set_trace() 196 | 197 | loss1 = criterion1(aa,bb,mask) 198 | loss2 = criterion1(aa.permute(0,2,1),bb.permute(0,2,1),mask.permute(0,2,1)) 199 | loss3 = criterion2(aa,bb,mask) 200 | 201 | print('debug') 202 | -------------------------------------------------------------------------------- /lib/mesh_util.py: -------------------------------------------------------------------------------- 1 | from skimage import measure 2 | import numpy as np 3 | import torch 4 | from .sdf import create_grid, eval_grid_octree, eval_grid 5 | from skimage import measure 6 | 7 | 8 | def reconstruction(opt, net, calib_tensor, 9 | resolution, b_min, b_max, thresh=0.5, 10 | use_octree=False, num_samples=10000, transform=None): 11 | ''' 12 | Reconstruct meshes from sdf predicted by the network. 13 | :param net: a BasePixImpNet object. call image filter beforehead. 
14 | :param cuda: cuda device 15 | :param calib_tensor: calibration tensor 16 | :param resolution: resolution of the grid cell 17 | :param b_min: bounding box corner [x_min, y_min, z_min] 18 | :param b_max: bounding box corner [x_max, y_max, z_max] 19 | :param use_octree: whether to use octree acceleration 20 | :param num_samples: how many points to query each gpu iteration 21 | :return: marching cubes results. 22 | ''' 23 | # First we create a grid by resolution 24 | # and transforming matrix for grid coordinates to real world xyz 25 | coords, mat = create_grid(resolution, resolution, resolution, 26 | b_min, b_max, transform=transform) 27 | 28 | # Then we define the lambda function for cell evaluation 29 | def eval_func(points): 30 | points = np.expand_dims(points, axis=0) 31 | points = np.repeat(points, net.num_views, axis=0) 32 | samples = torch.from_numpy(points).cuda().float() 33 | 34 | transforms = torch.zeros([1,2,3]).cuda() 35 | transforms[:, 0,0] = 1 / (opt.img_size[0] // 2) 36 | transforms[:, 1,1] = 1 / (opt.img_size[1] // 2) 37 | transforms[:, 0,2] = -1 38 | transforms[:, 1,2] = -1 39 | net.query(samples, calib_tensor, transforms=transforms) 40 | pred = net.get_preds()[0][0] 41 | return pred.detach().cpu().numpy() 42 | 43 | # Then we evaluate the grid 44 | if use_octree: 45 | sdf = eval_grid_octree(coords, eval_func, num_samples=num_samples) 46 | else: 47 | sdf = eval_grid(coords, eval_func, num_samples=num_samples) 48 | 49 | # Finally we do marching cubes 50 | try: 51 | verts, faces, normals, values = measure.marching_cubes_lewiner(sdf, thresh) 52 | # transform verts into world coordinate system 53 | verts = np.matmul(mat[:3, :3], verts.T) + mat[:3, 3:4] 54 | verts = verts.T 55 | return verts, faces, normals, values 56 | except: 57 | print('error cannot marching cubes') 58 | return -1 59 | 60 | -------------------------------------------------------------------------------- /lib/model/BasePIFuNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..geometry import index, orthogonal, perspective 6 | 7 | class BasePIFuNet(nn.Module): 8 | def __init__(self, 9 | projection_mode='perspective', 10 | sdf_loss_term=nn.L1Loss(), 11 | xyz_loss_term=nn.SmoothL1Loss(), 12 | ): 13 | """ 14 | :param projection_mode: 15 | Either orthogonal or perspective. 16 | It will call the corresponding function for projection. 17 | :param error_term: 18 | nn Loss between the predicted [B, Res, N] and the label [B, Res, N] 19 | """ 20 | super(BasePIFuNet, self).__init__() 21 | self.name = 'base' 22 | 23 | self.sdf_loss_term = sdf_loss_term 24 | self.xyz_loss_term = xyz_loss_term 25 | 26 | self.index = index 27 | self.projection = orthogonal if projection_mode == 'orthogonal' else perspective 28 | 29 | self.preds = None 30 | self.labels = None 31 | 32 | def forward(self, points, images, calibs, transforms=None): 33 | ''' 34 | :param points: [B, 3, N] world space coordinates of points 35 | :param images: [B, C, H, W] input images 36 | :param calibs: [B, 3, 4] calibration matrices for each image 37 | :param transforms: Optional [B, 2, 3] image space coordinate transforms 38 | :return: [B, Res, N] predictions for each point 39 | ''' 40 | self.filter(images) 41 | self.query(points, calibs, transforms) 42 | return self.get_preds() 43 | 44 | def filter(self, images): 45 | ''' 46 | Filter the input images 47 | store all intermediate features. 
48 | :param images: [B, C, H, W] input images 49 | ''' 50 | None 51 | 52 | def query(self, points, calibs, transforms=None, labels=None): 53 | ''' 54 | Given 3D points, query the network predictions for each point. 55 | Image features should be pre-computed before this call. 56 | store all intermediate features. 57 | query() function may behave differently during training/testing. 58 | :param points: [B, 3, N] world space coordinates of points 59 | :param calibs: [B, 3, 4] calibration matrices for each image 60 | :param transforms: Optional [B, 2, 3] image space coordinate transforms 61 | :param labels: Optional [B, Res, N] gt labeling 62 | :return: [B, Res, N] predictions for each point 63 | ''' 64 | None 65 | 66 | def get_preds(self): 67 | ''' 68 | Get the predictions from the last query 69 | :return: [B, Res, N] network prediction for the last query 70 | ''' 71 | return self.preds 72 | 73 | def get_loss(self): 74 | ''' 75 | Get the network loss from the last query 76 | :return: loss term 77 | ''' 78 | return self.sdf_loss_term(self.preds, self.labels) 79 | -------------------------------------------------------------------------------- /lib/model/HGFilters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..net_util import * 5 | 6 | 7 | class HourGlass(nn.Module): 8 | def __init__(self, num_modules, depth, num_features, norm='batch'): 9 | super(HourGlass, self).__init__() 10 | self.num_modules = num_modules 11 | self.depth = depth 12 | self.features = num_features 13 | self.norm = norm 14 | 15 | self._generate_network(self.depth) 16 | 17 | def _generate_network(self, level): 18 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 19 | 20 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 21 | 22 | if level > 1: 23 | self._generate_network(level - 1) 24 | else: 25 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 26 | 27 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm)) 28 | 29 | def _forward(self, level, inp): 30 | # Upper branch 31 | up1 = inp 32 | up1 = self._modules['b1_' + str(level)](up1) 33 | 34 | # Lower branch 35 | low1 = F.avg_pool2d(inp, 2, stride=2) 36 | low1 = self._modules['b2_' + str(level)](low1) 37 | 38 | if level > 1: 39 | low2 = self._forward(level - 1, low1) 40 | else: 41 | low2 = low1 42 | low2 = self._modules['b2_plus_' + str(level)](low2) 43 | 44 | low3 = low2 45 | low3 = self._modules['b3_' + str(level)](low3) 46 | 47 | # NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample 48 | # if the pretrained model behaves weirdly, switch with the commented line. 49 | # NOTE: I also found that "bicubic" works better. 
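        # Upsample the lower-branch features back to the upper-branch resolution; the two
        # branches are then merged by element-wise addition (the hourglass skip connection).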
50 | up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True) 51 | # up2 = F.interpolate(low3, scale_factor=2, mode='nearest) 52 | 53 | return up1 + up2 54 | 55 | def forward(self, x): 56 | return self._forward(self.depth, x) 57 | 58 | 59 | class HGFilter(nn.Module): 60 | def __init__(self, opt): 61 | super(HGFilter, self).__init__() 62 | self.num_modules = opt.num_stack 63 | 64 | self.opt = opt 65 | 66 | # Base part 67 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 68 | 69 | if self.opt.norm == 'batch': 70 | self.bn1 = nn.BatchNorm2d(64) 71 | elif self.opt.norm == 'group': 72 | self.bn1 = nn.GroupNorm(32, 64) 73 | 74 | if self.opt.hg_down == 'conv64': 75 | self.conv2 = ConvBlock(64, 64, self.opt.norm) 76 | self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 77 | elif self.opt.hg_down == 'conv128': 78 | self.conv2 = ConvBlock(64, 128, self.opt.norm) 79 | self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1) 80 | elif self.opt.hg_down == 'ave_pool': 81 | self.conv2 = ConvBlock(64, 128, self.opt.norm) 82 | else: 83 | raise NameError('Unknown Fan Filter setting!') 84 | 85 | self.conv3 = ConvBlock(128, 128, self.opt.norm) 86 | self.conv4 = ConvBlock(128, 256, self.opt.norm) 87 | 88 | # Stacking part 89 | for hg_module in range(self.num_modules): 90 | self.add_module('m' + str(hg_module), HourGlass(1, opt.num_hourglass, 256, self.opt.norm)) 91 | 92 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.opt.norm)) 93 | self.add_module('conv_last' + str(hg_module), 94 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 95 | if self.opt.norm == 'batch': 96 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 97 | elif self.opt.norm == 'group': 98 | self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256)) 99 | 100 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 101 | opt.hourglass_dim, kernel_size=1, stride=1, padding=0)) 102 | 103 | if hg_module < self.num_modules - 1: 104 | self.add_module( 105 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 106 | self.add_module('al' + str(hg_module), nn.Conv2d(opt.hourglass_dim, 107 | 256, kernel_size=1, stride=1, padding=0)) 108 | 109 | def forward(self, x): 110 | x = F.relu(self.bn1(self.conv1(x)), True) 111 | tmpx = x 112 | if self.opt.hg_down == 'ave_pool': 113 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 114 | elif self.opt.hg_down in ['conv64', 'conv128']: 115 | x = self.conv2(x) 116 | x = self.down_conv2(x) 117 | else: 118 | raise NameError('Unknown Fan Filter setting!') 119 | 120 | normx = x 121 | 122 | x = self.conv3(x) 123 | x = self.conv4(x) 124 | 125 | previous = x 126 | 127 | outputs = [] 128 | for i in range(self.num_modules): 129 | hg = self._modules['m' + str(i)](previous) 130 | 131 | ll = hg 132 | ll = self._modules['top_m_' + str(i)](ll) 133 | 134 | ll = F.relu(self._modules['bn_end' + str(i)] 135 | (self._modules['conv_last' + str(i)](ll)), True) 136 | 137 | # Predict heatmaps 138 | tmp_out = self._modules['l' + str(i)](ll) 139 | outputs.append(tmp_out) 140 | 141 | if i < self.num_modules - 1: 142 | ll = self._modules['bl' + str(i)](ll) 143 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 144 | previous = previous + ll + tmp_out_ 145 | 146 | return outputs, tmpx.detach(), normx 147 | -------------------------------------------------------------------------------- /lib/model/HGPIFuNet.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .BasePIFuNet import BasePIFuNet 5 | from .SurfaceClassifier import SurfaceClassifier 6 | from .RayDistanceNormalizer import RayDistanceNormalizer 7 | from .HGFilters import * 8 | from ..net_util import init_net 9 | 10 | 11 | class HGPIFuNet(BasePIFuNet): 12 | ''' 13 | HG PIFu network uses Hourglass stacks as the image filter. 14 | It does the following: 15 | 1. Compute image feature stacks and store it in self.im_feat_list 16 | self.im_feat_list[-1] is the last stack (output stack) 17 | 2. Calculate calibration 18 | 3. If training, it index on every intermediate stacks, 19 | If testing, it index on the last stack. 20 | 4. Classification. 21 | 5. During training, error is calculated on all stacks. 22 | ''' 23 | 24 | def __init__(self, 25 | opt, 26 | projection_mode='perspective', 27 | sdf_loss_term=nn.L1Loss(), 28 | xyz_loss_term=nn.SmoothL1Loss(), 29 | ): 30 | super(HGPIFuNet, self).__init__( 31 | projection_mode=projection_mode, 32 | sdf_loss_term=sdf_loss_term, 33 | xyz_loss_term=xyz_loss_term) 34 | 35 | self.name = 'hgpifu' 36 | 37 | self.opt = opt 38 | self.num_views = self.opt.num_views 39 | 40 | self.image_filter = HGFilter(opt) 41 | 42 | self.last_op = None 43 | if self.opt.out_type[-3:] == 'sdf': 44 | if self.opt.use_tanh: 45 | self.last_op = nn.Tanh() 46 | if self.opt.use_xyz: 47 | mlp_dim = self.opt.mlp_dim_xyz 48 | else: 49 | mlp_dim = self.opt.mlp_dim 50 | 51 | self.surface_classifier = SurfaceClassifier( 52 | filter_channels=mlp_dim, 53 | num_views=self.opt.num_views, 54 | no_residual=self.opt.no_residual, 55 | last_op=self.last_op) 56 | 57 | self.normalizer = RayDistanceNormalizer(opt) 58 | 59 | # This is a list of [B x Feat_i x H x W] features 60 | self.im_feat_list = [] 61 | self.tmpx = None 62 | self.normx = None 63 | 64 | self.intermediate_preds_list = [] 65 | 66 | # init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 67 | # gain (float) -- scaling factor for normal, xavier and orthogonal. 68 | # init_net(self) 69 | init_net(self, init_type=self.opt.init_type, init_gain=self.opt.init_gain) 70 | 71 | def filter(self, images): 72 | ''' 73 | Filter the input images 74 | store all intermediate features. 75 | :param images: [B, C, H, W] input images 76 | ''' 77 | self.im_feat_list, self.tmpx, self.normx = self.image_filter(images) 78 | # If it is not in training, only produce the last im_feat 79 | if not self.training: 80 | self.im_feat_list = [self.im_feat_list[-1]] 81 | 82 | def query(self, points, calibs, transforms=None, labels=None): 83 | ''' 84 | Given 3D points, query the network predictions for each point. 85 | Image features should be pre-computed before this call. 86 | store all intermediate features. 87 | query() function may behave differently during training/testing. 
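        (Editor's note, not in the original: the projected uv coordinates are expected in
        normalized [-1, 1] image space; points falling outside that range are masked via
        in_img and, for SDF outputs, assigned a large constant distance. The per-point
        feature passed to the MLP is the sampled image feature of opt.hourglass_dim
        channels plus the 1-channel ray-distance feature (and the early tmpx features when
        --skip_hourglass is set), which matches the 257-dim first entry of the default
        --mlp_dim / --mlp_dim_xyz options in lib/options.py.)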
88 | :param points: [B, 3, N] world space coordinates of points 89 | :param calibs: [B, 3, 4] calibration matrices for each image 90 | :param transforms: Optional [B, 2, 3] image space coordinate transforms 91 | :param labels: Optional [B, Res, N] gt labeling 92 | :return: [B, Res, N] predictions for each point 93 | ''' 94 | if labels is not None: 95 | self.labels = labels 96 | 97 | self.uvz = self.projection(points, calibs, transforms) 98 | uv = self.uvz[:, :2, :] 99 | z = self.uvz[:, 2:3, :] 100 | 101 | # debug for query during forward 102 | # pdb.set_trace() 103 | # debug_dir = '/mnt/data0/lin/results/hopifu/debug/forward_query/' 104 | # res = {'img': images.cpu()[0], 'samples': xyz.cpu()[0], 'labels': labels.cpu()[0]} 105 | # viz_debug_query_forward(res, 1, debug_dir) 106 | 107 | in_img = (uv[:, 0] >= -1.0) & (uv[:, 0] <= 1.0) & (uv[:, 1] >= -1.0) & (uv[:, 1] <= 1.0) 108 | 109 | # self.z_feat = self.normalizer(z, calibs=calibs) 110 | self.dist_ray_feat = self.normalizer(points, uv, transforms=transforms, calibs=calibs) 111 | 112 | if self.opt.skip_hourglass: 113 | tmpx_local_feature = self.index(self.tmpx, uv) 114 | 115 | self.intermediate_preds_list = [] 116 | 117 | for im_feat in self.im_feat_list: 118 | # [B, Feat_i + z, N] 119 | point_local_feat_list = [self.index(im_feat, uv), self.dist_ray_feat] 120 | 121 | if self.opt.skip_hourglass: 122 | point_local_feat_list.append(tmpx_local_feature) 123 | 124 | point_local_feat = torch.cat(point_local_feat_list, 1) 125 | 126 | # out of image plane is always set to 0 for occupancy or 1000 for sdf 127 | # pred (B, 1, 5000)/(B, 4, 5000) 128 | # in_img (B, N), not_in_img (B, 1, N) 129 | # ((in_img == False).nonzero(as_tuple=True)) 130 | pred = in_img[:,None].float() * self.surface_classifier(point_local_feat) 131 | if self.opt.out_type[-3:] == 'sdf': 132 | norm_factor = (self.opt.clamp_dist / self.opt.norm_clamp_dist) 133 | not_in_img = (torch.logical_not(in_img).float() * (100 * self.opt.clamp_dist / norm_factor)).unsqueeze(1) 134 | if self.opt.use_xyz: 135 | added_zeros = torch.zeros((pred.shape[0], 3, pred.shape[2])).cuda() 136 | pred = pred + torch.cat((not_in_img, added_zeros), dim=1) 137 | else: 138 | pred = pred + not_in_img 139 | self.intermediate_preds_list.append(pred) 140 | 141 | # shape (B, 1, 5000) 142 | if self.opt.use_xyz: 143 | self.preds = self.intermediate_preds_list[-1][:,0,:].unsqueeze(1) 144 | # shape (B, 3, 5000) 145 | self.xyzs = self.intermediate_preds_list[-1][:,1:,:] 146 | else: 147 | self.preds = self.intermediate_preds_list[-1] 148 | 149 | def get_im_feat(self): 150 | ''' 151 | Get the image filter 152 | :return: [B, C_feat, H, W] image feature after filtering 153 | ''' 154 | return self.im_feat_list[-1] 155 | 156 | # def get_error(self): 157 | # ''' 158 | # Hourglass has its own intermediate supervision scheme 159 | # ''' 160 | # error = 0 161 | # for preds in self.intermediate_preds_list: 162 | # error += self.error_term(preds, self.labels) 163 | # error /= len(self.intermediate_preds_list) 164 | 165 | # return error 166 | def get_loss(self): 167 | ''' 168 | Hourglass has its own intermediate supervision scheme 169 | ''' 170 | loss_dict = {} 171 | loss_dict['sdf_loss'] = 0. 172 | if self.opt.use_xyz: 173 | loss_dict['xyz_loss'] = 0. 174 | loss_dict['total_loss'] = 0. 
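# --- Editor's note (not part of the original source) ---------------------------------
# Intermediate supervision: every hourglass stack's prediction contributes to the loss.
# Predicted and ground-truth SDF values are clamped to +/- opt.norm_clamp_dist before
# the SDF loss term (L1 by default), the per-stack losses are averaged, and the optional
# xyz branch is weighted by opt.xyz_lambda when accumulated into 'total_loss' below.
# ---------------------------------------------------------------------------------------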
175 | for preds in self.intermediate_preds_list: 176 | 177 | if self.opt.out_type[-3:] == 'sdf': 178 | pred_sdf = torch.clamp(preds[:,0,:].unsqueeze(1), -self.opt.norm_clamp_dist, self.opt.norm_clamp_dist) 179 | gt_sdf = torch.clamp(self.labels, -self.opt.norm_clamp_dist, self.opt.norm_clamp_dist) 180 | loss_dict['sdf_loss'] += self.sdf_loss_term(pred_sdf, gt_sdf) 181 | 182 | if self.opt.use_xyz: 183 | loss_dict['xyz_loss'] += self.xyz_loss_term(preds[:,1:,:], self.norm_gt_xyzs, self.gt_xyz_mask) 184 | 185 | loss_dict['sdf_loss'] /= len(self.intermediate_preds_list) 186 | loss_dict['total_loss'] += loss_dict['sdf_loss'] 187 | if self.opt.use_xyz: 188 | loss_dict['xyz_loss'] /= len(self.intermediate_preds_list) 189 | loss_dict['total_loss'] += self.opt.xyz_lambda * loss_dict['xyz_loss'] 190 | 191 | return loss_dict 192 | 193 | def forward(self, images, points, calibs, labels=None, transforms=None, gt_xyzs=None, gt_xyz_mask=None, pairwise_dist=None, pairwise_pt_idxs=None): 194 | # pdb.set_trace() 195 | if self.opt.use_xyz: 196 | # norm_xyz_factor = max(self.opt.bbx_size) / 2 197 | # self.norm_points_model = xyzs / norm_xyz_factor 198 | self.norm_gt_xyzs = gt_xyzs 199 | self.gt_xyz_mask = gt_xyz_mask 200 | 201 | # Get image feature 202 | self.filter(images) 203 | 204 | # Phase 2: point query 205 | self.query(points=points, calibs=calibs, transforms=transforms, labels=labels) 206 | 207 | # get the prediction 208 | res = self.get_preds() 209 | 210 | # get the error 211 | loss_dict = self.get_loss() 212 | 213 | if self.opt.use_xyz: 214 | return res, loss_dict, self.xyzs, self.uvz 215 | else: 216 | return res, loss_dict, self.uvz 217 | -------------------------------------------------------------------------------- /lib/model/RayDistanceNormalizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RayDistanceNormalizer(nn.Module): 6 | def __init__(self, opt): 7 | super(RayDistanceNormalizer, self).__init__() 8 | 9 | self.opt = opt 10 | self.norm_method = self.opt.rdist_norm 11 | 12 | if self.norm_method == 'uvf': 13 | self.half_w = (self.opt.img_size[0] // 2) 14 | self.half_h = (self.opt.img_size[1] // 2) 15 | if self.norm_method == 'minmax': 16 | CAM_Bz_SHIFT = self.opt.wks_z_shift 17 | Bx_SIZE = self.opt.wks_size[0] / 2 18 | By_SIZE = self.opt.wks_size[1] / 2 19 | Bz_SIZE = self.opt.wks_size[2] / 2 20 | self.rdist_min = -Bz_SIZE + CAM_Bz_SHIFT 21 | self.rdist_max = torch.norm(torch.tensor([Bx_SIZE, By_SIZE, Bz_SIZE + CAM_Bz_SHIFT], dtype=torch.float)).item() 22 | 23 | def forward(self, queries, norm_uv=None, transforms=None, calibs=None): 24 | ''' 25 | Normalize dist_ray_feature 26 | :param dist_ray_feature: [B, 1, N] query distance along the ray normalized by projected uv distance along the ray 27 | :return: 28 | ''' 29 | batch_size = queries.shape[0] 30 | pt_size = queries.shape[2] 31 | # (B, 1, N) = (B, 3, N) 32 | abs_dist_ray = torch.norm(queries, dim=1).unsqueeze(1) 33 | 34 | if self.norm_method == 'uvf': 35 | # (B, 2, 3) 36 | inv_trans = torch.zeros_like(transforms) 37 | inv_trans[:, 0,0] = self.half_w 38 | inv_trans[:, 1,1] = self.half_h 39 | # inv_trans[:, 0,2] = self.half_w 40 | # inv_trans[:, 1,2] = self.half_h 41 | inv_trans[:, 0,2] = 0 42 | inv_trans[:, 1,2] = 0 43 | scale = inv_trans[:, :2, :2] 44 | shift = inv_trans[:, :2, 2:3] 45 | # (B, 2, N) 46 | uv = torch.baddbmm(shift, scale, norm_uv) 47 | # (B) 48 | ave_focal = (calibs[:, 0,0] + calibs[:, 1,1]) / 2 
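# --- Editor's note (not part of the original source) ---------------------------------
# 'uvf' normalization: uv was rescaled back to pixel offsets from the image centre via
# half_w/half_h, and abs_dist_ray (the query's distance from the origin) is divided
# below by the length of the pixel ray (u, v, f_avg). Under a pinhole model with
# camera-frame query points this ratio is roughly the depth expressed in focal-length
# units, which keeps the feature independent of the image resolution.
# ---------------------------------------------------------------------------------------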
49 | # (B, 1, N) 50 | ave_focal = ave_focal.unsqueeze(1).expand(batch_size, pt_size).unsqueeze(1) 51 | # (B, 3, N) 52 | proj_uvf = torch.cat((uv, ave_focal), dim=1) 53 | # (B, 1, N) 54 | proj_dist_ray = torch.norm(proj_uvf, dim=1).unsqueeze(1) 55 | 56 | return abs_dist_ray / proj_dist_ray 57 | 58 | elif self.norm_method == 'minmax': 59 | 60 | return (abs_dist_ray - self.rdist_min) / (self.rdist_max - self.rdist_min) 61 | 62 | -------------------------------------------------------------------------------- /lib/model/SurfaceClassifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SurfaceClassifier(nn.Module): 7 | def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None): 8 | super(SurfaceClassifier, self).__init__() 9 | 10 | self.filters = [] 11 | self.num_views = num_views 12 | self.no_residual = no_residual 13 | filter_channels = filter_channels 14 | self.last_op = last_op 15 | 16 | if self.no_residual: 17 | for l in range(0, len(filter_channels) - 1): 18 | self.filters.append(nn.Conv1d( 19 | filter_channels[l], 20 | filter_channels[l + 1], 21 | 1)) 22 | self.add_module("conv%d" % l, self.filters[l]) 23 | else: 24 | for l in range(0, len(filter_channels) - 1): 25 | if 0 != l: 26 | self.filters.append( 27 | nn.Conv1d( 28 | filter_channels[l] + filter_channels[0], 29 | filter_channels[l + 1], 30 | 1)) 31 | else: 32 | self.filters.append(nn.Conv1d( 33 | filter_channels[l], 34 | filter_channels[l + 1], 35 | 1)) 36 | 37 | self.add_module("conv%d" % l, self.filters[l]) 38 | 39 | def forward(self, feature): 40 | ''' 41 | 42 | :param feature: list of [BxC_inxHxW] tensors of image features 43 | :param xy: [Bx3xN] tensor of (x,y) coodinates in the image plane 44 | :return: [BxC_outxN] tensor of features extracted at the coordinates 45 | ''' 46 | 47 | y = feature 48 | tmpy = feature 49 | for i, f in enumerate(self.filters): 50 | if self.no_residual: 51 | y = self._modules['conv' + str(i)](y) 52 | else: 53 | y = self._modules['conv' + str(i)]( 54 | y if i == 0 55 | else torch.cat([y, tmpy], 1) 56 | ) 57 | if i != len(self.filters) - 1: 58 | y = F.leaky_relu(y) 59 | 60 | if self.num_views > 1 and i == len(self.filters) // 2: 61 | y = y.view( 62 | -1, self.num_views, y.shape[1], y.shape[2] 63 | ).mean(dim=1) 64 | tmpy = feature.view( 65 | -1, self.num_views, feature.shape[1], feature.shape[2] 66 | ).mean(dim=1) 67 | 68 | if self.last_op: 69 | # y = self.last_op(y) 70 | y[:,0,:] = self.last_op(y[:,0,:]) 71 | 72 | return y 73 | -------------------------------------------------------------------------------- /lib/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .BasePIFuNet import BasePIFuNet 2 | from .HGPIFuNet import HGPIFuNet 3 | -------------------------------------------------------------------------------- /lib/net_util.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import torch 4 | import functools 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import torch.nn as nn 9 | from torch.nn import init 10 | import torch.nn.functional as F 11 | 12 | 13 | def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma): 14 | """Sets the learning rate to the initial LR decayed by schedule""" 15 | if epoch in schedule: 16 | lr *= gamma 17 | for param_group in optimizer.param_groups: 18 | 
param_group['lr'] = lr 19 | return lr 20 | 21 | def init_weights(net, init_type='normal', init_gain=0.02): 22 | """Initialize network weights. 23 | 24 | Parameters: 25 | net (network) -- network to be initialized 26 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 27 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 28 | 29 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 30 | work better for some applications. Feel free to try yourself. 31 | """ 32 | 33 | def init_func(m): # define the initialization function 34 | classname = m.__class__.__name__ 35 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 36 | if init_type == 'normal': 37 | init.normal_(m.weight.data, 0.0, init_gain) 38 | elif init_type == 'xavier': 39 | init.xavier_normal_(m.weight.data, gain=init_gain) 40 | elif init_type == 'kaiming': 41 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 42 | elif init_type == 'orthogonal': 43 | init.orthogonal_(m.weight.data, gain=init_gain) 44 | else: 45 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 46 | if hasattr(m, 'bias') and m.bias is not None: 47 | init.constant_(m.bias.data, 0.0) 48 | elif classname.find( 49 | 'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 50 | init.normal_(m.weight.data, 1.0, init_gain) 51 | init.constant_(m.bias.data, 0.0) 52 | 53 | print('initialize network with %s' % init_type) 54 | net.apply(init_func) # apply the initialization function 55 | 56 | 57 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 58 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 59 | Parameters: 60 | net (network) -- the network to be initialized 61 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 62 | gain (float) -- scaling factor for normal, xavier and orthogonal. 63 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 64 | 65 | Return an initialized network. 66 | """ 67 | if len(gpu_ids) > 0: 68 | assert (torch.cuda.is_available()) 69 | net.to(gpu_ids[0]) 70 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 71 | init_weights(net, init_type, init_gain=init_gain) 72 | return net 73 | 74 | 75 | def imageSpaceRotation(xy, rot): 76 | ''' 77 | args: 78 | xy: (B, 2, N) input 79 | rot: (B, 2) x,y axis rotation angles 80 | 81 | rotation center will be always image center (other rotation center can be represented by additional z translation) 82 | ''' 83 | disp = rot.unsqueeze(2).sin().expand_as(xy) 84 | return (disp * xy).sum(dim=1) 85 | 86 | 87 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 88 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 89 | 90 | Arguments: 91 | netD (network) -- discriminator network 92 | real_data (tensor array) -- real images 93 | fake_data (tensor array) -- generated images from the generator 94 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 95 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 
96 | constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 97 | lambda_gp (float) -- weight for this loss 98 | 99 | Returns the gradient penalty loss 100 | """ 101 | if lambda_gp > 0.0: 102 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 103 | interpolatesv = real_data 104 | elif type == 'fake': 105 | interpolatesv = fake_data 106 | elif type == 'mixed': 107 | alpha = torch.rand(real_data.shape[0], 1) 108 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view( 109 | *real_data.shape) 110 | alpha = alpha.to(device) 111 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 112 | else: 113 | raise NotImplementedError('{} not implemented'.format(type)) 114 | interpolatesv.requires_grad_(True) 115 | disc_interpolates = netD(interpolatesv) 116 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 117 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 118 | create_graph=True, retain_graph=True, only_inputs=True) 119 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 120 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 121 | return gradient_penalty, gradients 122 | else: 123 | return 0.0, None 124 | 125 | def get_norm_layer(norm_type='instance'): 126 | """Return a normalization layer 127 | Parameters: 128 | norm_type (str) -- the name of the normalization layer: batch | instance | none 129 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 130 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 131 | """ 132 | if norm_type == 'batch': 133 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 134 | elif norm_type == 'instance': 135 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 136 | elif norm_type == 'group': 137 | norm_layer = functools.partial(nn.GroupNorm, 32) 138 | elif norm_type == 'none': 139 | norm_layer = None 140 | else: 141 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 142 | return norm_layer 143 | 144 | class Flatten(nn.Module): 145 | def forward(self, input): 146 | return input.view(input.size(0), -1) 147 | 148 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 149 | "3x3 convolution with padding" 150 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 151 | stride=strd, padding=padding, bias=bias) 152 | 153 | class ConvBlock(nn.Module): 154 | def __init__(self, in_planes, out_planes, norm='batch'): 155 | super(ConvBlock, self).__init__() 156 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 157 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 158 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 159 | 160 | if norm == 'batch': 161 | self.bn1 = nn.BatchNorm2d(in_planes) 162 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 163 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 164 | self.bn4 = nn.BatchNorm2d(in_planes) 165 | elif norm == 'group': 166 | self.bn1 = nn.GroupNorm(32, in_planes) 167 | self.bn2 = nn.GroupNorm(32, int(out_planes / 2)) 168 | self.bn3 = nn.GroupNorm(32, int(out_planes / 4)) 169 | self.bn4 = nn.GroupNorm(32, in_planes) 170 | 171 | if in_planes != out_planes: 172 | self.downsample = nn.Sequential( 173 | self.bn4, 174 | nn.ReLU(True), 175 | 
nn.Conv2d(in_planes, out_planes, 176 | kernel_size=1, stride=1, bias=False), 177 | ) 178 | else: 179 | self.downsample = None 180 | 181 | def forward(self, x): 182 | residual = x 183 | 184 | out1 = self.bn1(x) 185 | out1 = F.relu(out1, True) 186 | out1 = self.conv1(out1) 187 | 188 | out2 = self.bn2(out1) 189 | out2 = F.relu(out2, True) 190 | out2 = self.conv2(out2) 191 | 192 | out3 = self.bn3(out2) 193 | out3 = F.relu(out3, True) 194 | out3 = self.conv3(out3) 195 | 196 | out3 = torch.cat((out1, out2, out3), 1) 197 | 198 | if self.downsample is not None: 199 | residual = self.downsample(residual) 200 | 201 | out3 += residual 202 | 203 | return out3 204 | -------------------------------------------------------------------------------- /lib/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | class BaseOptions(): 6 | def __init__(self): 7 | self.initialized = False 8 | 9 | def initialize(self, parser): 10 | # Experiment launch: Logistic/Datasets related 11 | g_logistic = parser.add_argument_group('Logistic') 12 | g_logistic.add_argument('--exp_id', type=str, default='ncf_ycbv_run2',help='') 13 | g_logistic.add_argument('--work_base_path', type=str, default='/data1/lin/ncf_results/runs',help='') 14 | 15 | g_logistic.add_argument('--dataset', type=str, default='ycbv',help='lm | lmo | ycbv') 16 | g_logistic.add_argument('--train_data', type=str, default='ycbv', help='lm | ycbv | ycbv_real') 17 | g_logistic.add_argument('--more_train_data', type=str, default='none', help='ycbv_real') 18 | g_logistic.add_argument('--eval_data', type=str, default='ycbv_bop_cha', help='lm_bop_cha | lmo_bop_cha | ycbv_bop_cha') 19 | g_logistic.add_argument('--model_dir', type=str, default='/data2/lin/bop_datasets/ycbv/models', help='') 20 | g_logistic.add_argument('--ds_lm_dir', type=str, default='/data2/lin/bop_datasets/lm', help='') 21 | g_logistic.add_argument('--ds_lmo_dir', type=str, default='/data2/lin/bop_datasets/lmo', help='') 22 | g_logistic.add_argument('--ds_ycbv_dir', type=str, default='/data2/lin/bop_datasets/ycbv', help='') 23 | 24 | g_logistic.add_argument('--visib_fract_thresh', type=float, default=0.3, help='0.05 | 0.1 | 0.15 | 0.3') 25 | g_logistic.add_argument('--model_unit', type=str, default='mm', help='meter | mm') 26 | 27 | g_logistic.add_argument('--obj_id', default=2, type=int, help='ids for object') 28 | g_logistic.add_argument('--wks_size', type=int, default=[1600, 1600, 2000], help='size of workspace/mm') 29 | g_logistic.add_argument('--wks_z_shift', type=int, default=1010, help='shift of workspace/mm') 30 | g_logistic.add_argument('--test_wks_size', type=int, default=[1200, 1200, 930], help='size of test workspace/mm') 31 | g_logistic.add_argument('--test_wks_z_shift', type=int, default=925, help='shift of test workspace/mm') 32 | g_logistic.add_argument('--max_sym_disc_step', type=float, default=0.01, help='') 33 | g_logistic.add_argument('--sample_ratio', type=int, default=20, help='20 | 24 | 16 | 32 for surf') 34 | g_logistic.add_argument('--bbx_size', type=int, default=380, help='size of object bounding box/mm') 35 | g_logistic.add_argument('--bbx_shift', type=int, default=0, help='shift of object bounding box/mm') 36 | g_logistic.add_argument('--use_remap', type=bool, default=True, help='') 37 | g_logistic.add_argument('--rdist_norm', type=str, default='uvf', help='normlization method for ray distance, uvf|minmax') 38 | 39 | g_logistic.add_argument('--img_size', type=int, default=[640,480], 
help='image shape') 40 | g_logistic.add_argument('--num_views', type=int, default=1, help='How many views to use for multiview network.') 41 | 42 | g_logistic.add_argument('--GPU_ID', default=[0], type=int, help='# of GPUs') 43 | g_logistic.add_argument('--deterministic', type=bool, default=False, help='') 44 | g_logistic.add_argument('--seed', type=int, default=0) 45 | 46 | g_logistic.add_argument('--continue_train', type=bool, default=False, help='continue training: load model') 47 | g_logistic.add_argument('--resume_epoch', type=int, default=0, help='epoch resuming the training') 48 | g_logistic.add_argument('--eval_perf', type=bool, default=False, help='evaluation: load model') 49 | g_logistic.add_argument('--eval_epoch', type=int, default=0, help='epoch for eval.') 50 | 51 | g_logistic.add_argument('--load_netG_checkpoint_path', type=str, default=None, help='path to save checkpoints') 52 | g_logistic.add_argument('--load_optG_checkpoint_path', type=str, default=None, help='path to save checkpoints') 53 | g_logistic.add_argument('--name', type=str, default='example', 54 | help='name of the experiment. It decides where to store/load samples and models') 55 | 56 | # Sampling related 57 | g_sample = parser.add_argument_group('Sampling') 58 | g_sample.add_argument('--sigma_ratio', type=float, default=0.5, help='perturbation ratio of standard deviation for positions: 0.5 | 0.75') 59 | 60 | g_sample.add_argument('--num_sample_inout', type=int, default=5000, help='# of sampling points: 5000') 61 | 62 | # Rigid pose related 63 | g_rigid = parser.add_argument_group('Rigid') 64 | g_rigid.add_argument('--min_samples', type=int, default=3, help='min. #samples for ransac') 65 | g_rigid.add_argument('--res_thresh', type=float, default=20, help='residual threshold for selecting inliers') 66 | g_rigid.add_argument('--max_trials', type=int, default=200, help='max. #iterations') 67 | 68 | # Pre. & Aug. 
related 69 | g_aug = parser.add_argument_group('aug') 70 | # appearance 71 | g_aug.add_argument('--use_aug', type=bool, default=True, help='') 72 | g_aug.add_argument('--aug_blur', type=int, default=3, help='augmentation blur') 73 | g_aug.add_argument('--aug_sha', type=float, default=50.0, help='augmentation sharpness') 74 | g_aug.add_argument('--aug_con', type=float, default=50.0, help='augmentation contrast') 75 | g_aug.add_argument('--aug_bri', type=float, default=6.0, help='augmentation brightness') 76 | g_aug.add_argument('--aug_col', type=float, default=20.0, help='augmentation color') 77 | 78 | # Training related 79 | g_train = parser.add_argument_group('Training') 80 | g_train.add_argument('--batch_size', type=int, default=4, help='input batch size') 81 | 82 | g_train.add_argument('--num_threads', default=1, type=int, help='# sthreads for loading data') 83 | g_train.add_argument('--serial_batches', action='store_true', 84 | help='if true, takes images in order to make batches, otherwise takes them randomly') 85 | # g_train.add_argument('--pin_memory', type=bool, default=True, help='pin_memory') 86 | 87 | g_train.add_argument('--out_type', type=str, default='rsdf', help='rsdf | csdf | eff_csdf') 88 | g_train.add_argument('--loss_type', type=str, default='l1', help='mse | l1 | huber') 89 | g_train.add_argument('--clamp_dist', type=float, default=5.0, help='') 90 | g_train.add_argument('--norm_clamp_dist', type=float, default=0.1, help='') 91 | g_train.add_argument('--use_xyz', type=bool, default=True, help='') 92 | g_train.add_argument('--xyz_lambda', type=float, default=1.0, help='') 93 | 94 | g_train.add_argument('--init_type', type=str, default='normal', help='normal | xavier | kaiming | orthogonal') 95 | g_train.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal') 96 | g_train.add_argument('--optimizer', choices=["adam", "rms"], default="rms") 97 | g_train.add_argument('--learning_rate', type=float, default=1e-4, help='') # 1e-3 98 | g_train.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 99 | g_train.add_argument('--schedule', type=int, nargs='+', default=[500, 1000, 1500], 100 | help='Decrease learning rate at these epochs.') 101 | g_train.add_argument('--num_epoch', type=int, default=2000, help='num epoch to train') 102 | 103 | g_train.add_argument('--freq_plot', type=int, default=7000, help='freqency of the error plot') 104 | g_train.add_argument('--freq_debug', type=int, default=7000, help='frequence of the visualization') 105 | 106 | # Model related 107 | g_model = parser.add_argument_group('Model') 108 | # General 109 | g_model.add_argument('--norm', type=str, default='group', 110 | help='instance normalization or batch normalization or group normalization') 111 | # hg filter specify 112 | g_model.add_argument('--num_stack', type=int, default=4, help='# of stacked layer of hourglass') 113 | g_model.add_argument('--num_hourglass', type=int, default=2, help='# of hourglass') 114 | g_model.add_argument('--skip_hourglass', action='store_true', help='skip connection in hourglass') 115 | g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave pool || conv64 || conv128') 116 | g_model.add_argument('--hourglass_dim', type=int, default='256', help='256 | 512') 117 | 118 | # Classification General 119 | g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int, 120 | help='# of dimensions of mlp') 121 | 
g_model.add_argument('--mlp_dim_xyz', nargs='+', default=[257, 1024, 512, 256, 128, 4], 122 | type=int, help='# of dimensions of mlp') 123 | 124 | g_model.add_argument('--use_tanh', type=bool, default=True, 125 | help='using tanh after last conv of image_filter network') 126 | 127 | g_model.add_argument('--no_residual', action='store_true', help='no skip connection in mlp') 128 | 129 | # Eval. related 130 | g_eval = parser.add_argument_group('Evaluation') 131 | g_eval.add_argument('--step_size', type=int, default=10, help='step size (mm) of grid') 132 | g_eval.add_argument('--num_in_batch', type=int, default=1500000, help='number of each batch for eval.') 133 | g_eval.add_argument('--thresh', type=float, default=0.0, help='0.0999 | 0.0 | -0.0999') 134 | 135 | g_eval.add_argument('--freq_eval_all', type=int, default=20, help='freqency of the eval. for all') 136 | g_eval.add_argument('--gen_obj_pose', type=bool, default=True, help='') 137 | 138 | # special tasks 139 | self.initialized = True 140 | return parser 141 | 142 | def gather_options(self): 143 | # initialize parser with basic options 144 | if not self.initialized: 145 | parser = argparse.ArgumentParser( 146 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 147 | parser = self.initialize(parser) 148 | 149 | self.parser = parser 150 | 151 | return parser.parse_args() 152 | 153 | def print_options(self, opt): 154 | message = '' 155 | message += '----------------- Options ---------------\n' 156 | for k, v in sorted(vars(opt).items()): 157 | comment = '' 158 | default = self.parser.get_default(k) 159 | if v != default: 160 | comment = '\t[default: %s]' % str(default) 161 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 162 | message += '----------------- End -------------------' 163 | print(message) 164 | 165 | def parse(self): 166 | opt = self.gather_options() 167 | return opt 168 | -------------------------------------------------------------------------------- /lib/rigid_fit/ransac.py: -------------------------------------------------------------------------------- 1 | """A simple RANSAC class implementation. 2 | References: 3 | [1] : https://github.com/scikit-image/scikit-image/blob/master/skimage/measure/fit.py 4 | [2] : https://github.com/scikit-learn/scikit-learn/blob/e5698bde9/sklearn/linear_model/_ransac.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class RansacEstimator: 11 | """Random Sample Consensus. 12 | """ 13 | def __init__(self, min_samples=None, residual_threshold=None, max_trials=100): 14 | """Constructor. 15 | 16 | Args: 17 | min_samples: The minimal number of samples needed to fit the model 18 | to the data. If `None`, we assume a linear model in which case 19 | the minimum number is one more than the feature dimension. 20 | residual_threshold: The maximum allowed residual for a sample to 21 | be classified as an inlier. If `None`, the threshold is chosen 22 | to be the median absolute deviation of the target variable. 23 | max_trials: The maximum number of trials to run RANSAC for. By 24 | default, this value is 100. 25 | """ 26 | self.min_samples = min_samples 27 | self.residual_threshold = residual_threshold 28 | self.max_trials = max_trials 29 | 30 | def fit(self, model, data): 31 | """Robustely fit a model to the data. 32 | 33 | Args: 34 | model: a class object that implements `estimate` and 35 | `residuals` methods. 36 | data: the data to fit the model to. Can be a list of 37 | data pairs, such as `X` and `y` in the case of 38 | regression. 
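          (Editor's note, not in the original: every array passed in `data` must share the
          same leading dimension; each trial draws `min_samples` row indices once and
          applies them to all arrays, so corresponding rows must describe the same
          observation.)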
39 | 40 | Returns: 41 | A dictionary containing: 42 | best_model: the model with the largest consensus set 43 | and lowest residual error. 44 | inliers: a boolean mask indicating the inlier subset 45 | of the data for the best model. 46 | """ 47 | best_model = None 48 | best_inliers = None 49 | best_num_inliers = 0 50 | best_residual_sum = np.inf 51 | 52 | if not isinstance(data, (tuple, list)): 53 | data = [data] 54 | num_data, num_feats = data[0].shape 55 | 56 | for trial in range(self.max_trials): 57 | # randomly select subset 58 | rand_subset_idxs = np.random.choice( 59 | np.arange(num_data), size=self.min_samples, replace=False) 60 | rand_subset = [d[rand_subset_idxs] for d in data] 61 | 62 | # estimate with model 63 | model.estimate(*rand_subset) 64 | 65 | # compute residuals 66 | residuals = model.residuals(*data) 67 | # residuals_sum = residuals.sum() 68 | inliers = residuals <= self.residual_threshold 69 | num_inliers = np.sum(inliers) 70 | 71 | # decide if better 72 | # if (best_num_inliers < num_inliers) or (best_residual_sum > residuals_sum): 73 | if (best_num_inliers < num_inliers): 74 | best_num_inliers = num_inliers 75 | # best_residual_sum = residuals_sum 76 | best_inliers = inliers 77 | 78 | # refit model using all inliers for this set 79 | if best_num_inliers == 0: 80 | data_inliers = data 81 | else: 82 | data_inliers = [d[best_inliers] for d in data] 83 | model.estimate(*data_inliers) 84 | 85 | ret = { 86 | "best_params": model.params, 87 | "best_inliers": best_inliers, 88 | } 89 | return ret -------------------------------------------------------------------------------- /lib/rigid_fit/ransac_kabsch.py: -------------------------------------------------------------------------------- 1 | """Estimate a rigid transform between 2 point clouds. 2 | """ 3 | 4 | import numpy as np 5 | from .ransac import RansacEstimator 6 | 7 | import pdb 8 | 9 | def gen_data(N=100, frac=0.1): 10 | # create a random rigid transform 11 | transform = np.eye(4) 12 | # transform[:3, :3] = RotationMatrix.random() 13 | transform[:3, :3] = np.array([-0.52573111, 0.85065081, 0.0, 0.84825128, 0.52424812, -0.07505775, -0.06384793, -0.03946019, -0.99717919]).reshape(3,3) 14 | transform[:3, 3] = 2 * np.random.randn(3) + 1 15 | 16 | # create a random source point cloud 17 | src_pc = 5 * np.random.randn(N, 3) + 2 18 | dst_pc = Procrustes.transform_xyz(src_pc, transform) 19 | 20 | # corrupt 21 | rand_corrupt = np.random.choice(np.arange(len(src_pc)), replace=False, size=int(frac*N)) 22 | dst_pc[rand_corrupt] += np.random.uniform(-10, 10, (int(frac*N), 3)) 23 | 24 | return src_pc, dst_pc, transform, rand_corrupt 25 | 26 | 27 | def transform_from_rotm_tr(rotm, tr): 28 | transform = np.eye(4) 29 | transform[:3, :3] = rotm 30 | transform[:3, 3] = tr 31 | return transform 32 | 33 | class Procrustes: 34 | """Determines the best rigid transform [1] between two point clouds. 35 | 36 | References: 37 | [1]: https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem 38 | """ 39 | def __init__(self, transform=None): 40 | self._transform = transform 41 | 42 | def __call__(self, xyz): 43 | return Procrustes.transform_xyz(xyz, self._transform) 44 | 45 | @staticmethod 46 | def transform_xyz(xyz, transform): 47 | """Applies a rigid transform to an (N, 3) point cloud. 
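    (Editor's note, not in the original: `transform` is a 4x4 homogeneous matrix; the
    points are padded with a column of ones, multiplied, and the first three coordinates
    of the result are returned.)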
48 | """ 49 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1))]) # homogenize 3D pointcloud 50 | xyz_t_h = (transform @ xyz_h.T).T # apply transform 51 | return xyz_t_h[:, :3] 52 | 53 | # def estimate(self, X, Y): 54 | # # find centroids 55 | # X_c = np.mean(X, axis=0) 56 | # Y_c = np.mean(Y, axis=0) 57 | 58 | # # shift 59 | # X_s = X - X_c 60 | # Y_s = Y - Y_c 61 | 62 | # # compute SVD of covariance matrix 63 | # cov = Y_s.T @ X_s 64 | # u, _, vt = np.linalg.svd(cov) 65 | 66 | # # determine rotation 67 | # rot = u @ vt 68 | # if np.linalg.det(rot) < 0.: 69 | # vt[2, :] *= -1 70 | # rot = u @ vt 71 | 72 | # # determine optimal translation 73 | # trans = Y_c - rot @ X_c 74 | 75 | # self._transform = transform_from_rotm_tr(rot, trans) 76 | 77 | def estimate(self, X, Y): 78 | # find centroids 79 | X_c = np.mean(X, axis=0) 80 | Y_c = np.mean(Y, axis=0) 81 | 82 | # shift 83 | X_s = X - X_c 84 | Y_s = Y - Y_c 85 | 86 | # Computation of the covariance matrix 87 | C = np.dot(np.transpose(Y_s), X_s) 88 | 89 | # Computation of the optimal rotation matrix 90 | # This can be done using singular value decomposition (SVD) 91 | # Getting the sign of the det(V)*(W) to decide 92 | # whether we need to correct our rotation matrix to ensure a 93 | # right-handed coordinate system. 94 | # And finally calculating the optimal rotation matrix U 95 | # see http://en.wikipedia.org/wiki/Kabsch_algorithm 96 | V, S, W = np.linalg.svd(C) 97 | d = (np.linalg.det(V) * np.linalg.det(W)) < 0.0 98 | 99 | if d: 100 | S[-1] = -S[-1] 101 | V[:, -1] = -V[:, -1] 102 | 103 | # Create Rotation matrix U 104 | rot = np.dot(V, W) 105 | 106 | # determine optimal translation 107 | trans = Y_c - rot @ X_c 108 | 109 | self._transform = transform_from_rotm_tr(rot, trans) 110 | 111 | def residuals(self, X, Y): 112 | """L2 distance between point correspondences. 113 | """ 114 | Y_est = self(X) 115 | sum_sq = np.sum((Y_est - Y)**2, axis=1) 116 | return sum_sq 117 | 118 | @property 119 | def params(self): 120 | return self._transform 121 | 122 | 123 | if __name__ == "__main__": 124 | src_pc, dst_pc, transform_true, rand_corrupt = gen_data(frac=0.2) 125 | 126 | # estimate without ransac, i.e. 
using all 127 | # point correspondences 128 | naive_model = Procrustes() 129 | naive_model.estimate(src_pc, dst_pc) 130 | transform_naive = naive_model.params 131 | mse_naive = np.sqrt(naive_model.residuals(src_pc, dst_pc).mean()) 132 | print("mse naive: {}".format(mse_naive)) 133 | 134 | 135 | # estimate with RANSAC 136 | ransac = RansacEstimator( 137 | min_samples=3, 138 | # 5, 10, 20 139 | residual_threshold=(10)**2, 140 | max_trials=100, 141 | ) 142 | ret = ransac.fit(Procrustes(), [src_pc, dst_pc]) 143 | transform_ransac = ret["best_params"] 144 | 145 | 146 | inliers_ransac = ret["best_inliers"] 147 | mse_ransac = np.sqrt(Procrustes(transform_ransac).residuals(src_pc, dst_pc).mean()) 148 | print("mse ransac all: {}".format(mse_ransac)) 149 | mse_ransac_inliers = np.sqrt( 150 | Procrustes(transform_ransac).residuals(src_pc[inliers_ransac], dst_pc[inliers_ransac]).mean()) 151 | print("mse ransac inliers: {}".format(mse_ransac_inliers)) 152 | pdb.set_trace() -------------------------------------------------------------------------------- /lib/sdf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def create_grid(resX, resY, resZ, b_min=np.array([0, 0, 0]), b_max=np.array([1, 1, 1]), transform=None): 5 | ''' 6 | Create a dense grid of given resolution and bounding box 7 | :param resX: resolution along X axis 8 | :param resY: resolution along Y axis 9 | :param resZ: resolution along Z axis 10 | :param b_min: vec3 (x_min, y_min, z_min) bounding box corner 11 | :param b_max: vec3 (x_max, y_max, z_max) bounding box corner 12 | :return: [3, resX, resY, resZ] coordinates of the grid, and transform matrix from mesh index 13 | ''' 14 | coords = np.mgrid[:resX, :resY, :resZ] 15 | coords = coords.reshape(3, -1) 16 | coords_matrix = np.eye(4) 17 | length = b_max - b_min 18 | coords_matrix[0, 0] = length[0] / resX 19 | coords_matrix[1, 1] = length[1] / resY 20 | coords_matrix[2, 2] = length[2] / resZ 21 | coords_matrix[0:3, 3] = b_min 22 | coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4] 23 | if transform is not None: 24 | coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4] 25 | coords_matrix = np.matmul(transform, coords_matrix) 26 | coords = coords.reshape(3, resX, resY, resZ) 27 | return coords, coords_matrix 28 | 29 | 30 | def batch_eval(points, eval_func, num_samples=512 * 512 * 512): 31 | num_pts = points.shape[1] 32 | sdf = np.zeros(num_pts) 33 | 34 | num_batches = num_pts // num_samples 35 | for i in range(num_batches): 36 | sdf[i * num_samples:i * num_samples + num_samples] = eval_func( 37 | points[:, i * num_samples:i * num_samples + num_samples]) 38 | if num_pts % num_samples: 39 | sdf[num_batches * num_samples:] = eval_func(points[:, num_batches * num_samples:]) 40 | 41 | return sdf 42 | 43 | 44 | def eval_grid(coords, eval_func, num_samples=512 * 512 * 512): 45 | resolution = coords.shape[1:4] 46 | coords = coords.reshape([3, -1]) 47 | sdf = batch_eval(coords, eval_func, num_samples=num_samples) 48 | return sdf.reshape(resolution) 49 | 50 | 51 | def eval_grid_octree(coords, eval_func, 52 | init_resolution=64, threshold=0.01, 53 | num_samples=512 * 512 * 512): 54 | resolution = coords.shape[1:4] 55 | 56 | sdf = np.zeros(resolution) 57 | 58 | dirty = np.ones(resolution, dtype=np.bool) 59 | grid_mask = np.zeros(resolution, dtype=np.bool) 60 | 61 | reso = resolution[0] // init_resolution 62 | 63 | while reso > 0: 64 | # subdivide the grid 65 | grid_mask[0:resolution[0]:reso, 
0:resolution[1]:reso, 0:resolution[2]:reso] = True 66 | # test samples in this iteration 67 | test_mask = np.logical_and(grid_mask, dirty) 68 | #print('step size:', reso, 'test sample size:', test_mask.sum()) 69 | points = coords[:, test_mask] 70 | 71 | sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples) 72 | dirty[test_mask] = False 73 | 74 | # do interpolation 75 | if reso <= 1: 76 | break 77 | for x in range(0, resolution[0] - reso, reso): 78 | for y in range(0, resolution[1] - reso, reso): 79 | for z in range(0, resolution[2] - reso, reso): 80 | # if center marked, return 81 | if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]: 82 | continue 83 | v0 = sdf[x, y, z] 84 | v1 = sdf[x, y, z + reso] 85 | v2 = sdf[x, y + reso, z] 86 | v3 = sdf[x, y + reso, z + reso] 87 | v4 = sdf[x + reso, y, z] 88 | v5 = sdf[x + reso, y, z + reso] 89 | v6 = sdf[x + reso, y + reso, z] 90 | v7 = sdf[x + reso, y + reso, z + reso] 91 | v = np.array([v0, v1, v2, v3, v4, v5, v6, v7]) 92 | v_min = v.min() 93 | v_max = v.max() 94 | # this cell is all the same 95 | if (v_max - v_min) < threshold: 96 | sdf[x:x + reso, y:y + reso, z:z + reso] = (v_max + v_min) / 2 97 | dirty[x:x + reso, y:y + reso, z:z + reso] = False 98 | reso //= 2 99 | 100 | return sdf.reshape(resolution) 101 | 102 | 103 | 104 | 105 | """ 106 | for hopifu-only 107 | """ 108 | import pdb 109 | def batch_eval_sdf_xyz(points, eval_func, num_samples=512 * 512 * 512): 110 | num_pts = points.shape[1] 111 | sdf = np.zeros(num_pts) 112 | xyz = np.zeros((3, num_pts)) 113 | 114 | # pdb.set_trace() 115 | num_batches = num_pts // num_samples 116 | for i in range(num_batches): 117 | sdf[i * num_samples:i * num_samples + num_samples], xyz[:, i * num_samples:i * num_samples + num_samples] = eval_func( 118 | points[:, i * num_samples:i * num_samples + num_samples]) 119 | if num_pts % num_samples: 120 | sdf[num_batches * num_samples:], xyz[:, num_batches * num_samples:] = eval_func(points[:, num_batches * num_samples:]) 121 | 122 | return sdf, xyz 123 | 124 | def eval_sdf_xyz_grid(coords, eval_func, num_samples=512 * 512 * 512): 125 | resolution = coords.shape[1:4] 126 | coords = coords.reshape([3, -1]) 127 | sdf, xyz = batch_eval_sdf_xyz(coords, eval_func, num_samples=num_samples) 128 | return sdf, xyz 129 | 130 | def eval_sdf_xyz_grid_frustum(coords, eval_func, num_samples=512 * 512 * 512): 131 | # coords = coords.reshape([3, -1]) 132 | sdf, xyz = batch_eval_sdf_xyz(coords, eval_func, num_samples=num_samples) 133 | return sdf, xyz 134 | -------------------------------------------------------------------------------- /lib/sym_util.py: -------------------------------------------------------------------------------- 1 | # Author: Tomas Hodan (hodantom@cmp.felk.cvut.cz) 2 | # Center for Machine Perception, Czech Technical University in Prague 3 | 4 | """Parameters of the BOP datasets.""" 5 | 6 | import os 7 | import pdb 8 | import json 9 | 10 | import math 11 | import numpy as np 12 | 13 | def load_json(path, keys_to_int=False): 14 | """Loads content of a JSON file. 15 | 16 | :param path: Path to the JSON file. 17 | :return: Content of the loaded JSON file. 18 | """ 19 | # Keys to integers. 
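# --- Editor's note (not part of the original source) ---------------------------------
# JSON object keys are always strings, but BOP's models_info.json is indexed by integer
# object ids elsewhere in this file (e.g. models_info[obj_id]); the helper below converts
# purely numeric keys back to int when keys_to_int=True is requested.
# ---------------------------------------------------------------------------------------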
20 | def convert_keys_to_int(x): 21 | return {int(k) if k.lstrip('-').isdigit() else k: v for k, v in x.items()} 22 | 23 | with open(path, 'r') as f: 24 | if keys_to_int: 25 | content = json.load(f, object_hook=lambda x: convert_keys_to_int(x)) 26 | else: 27 | content = json.load(f) 28 | 29 | return content 30 | 31 | def get_obj_params(models_path, dataset_name): 32 | """Returns parameters of object models for the specified dataset. 33 | 34 | :param models_path: Path to a folder with models. 35 | :param dataset_name: Name of the dataset for which to return the parameters. 36 | :return: Dictionary with object model parameters for the specified dataset. 37 | """ 38 | # Object ID's. 39 | obj_ids = { 40 | 'lm': list(range(1, 16)), 41 | 'lmo': [1, 5, 6, 8, 9, 10, 11, 12], 42 | 'tudl': list(range(1, 4)), 43 | 'tyol': list(range(1, 22)), 44 | 'ruapc': list(range(1, 15)), 45 | 'icmi': list(range(1, 7)), 46 | 'icbin': list(range(1, 3)), 47 | 'itodd': list(range(1, 29)), 48 | 'hbs': [1, 3, 4, 8, 9, 10, 12, 15, 17, 18, 19, 22, 23, 29, 32, 33], 49 | 'hb': list(range(1, 34)), # Full HB dataset. 50 | 'ycbv': list(range(1, 22)), 51 | 'hope': list(range(1, 29)), 52 | }[dataset_name] 53 | 54 | # ID's of objects with ambiguous views evaluated using the ADI pose error 55 | # function (the others are evaluated using ADD). See Hodan et al. (ECCVW'16). 56 | symmetric_obj_ids = { 57 | 'lm': [3, 7, 10, 11], 58 | 'lmo': [10, 11], 59 | 'tudl': [], 60 | 'tyol': [3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 15, 16, 17, 18, 19, 21], 61 | 'ruapc': [8, 9, 12, 13], 62 | 'icmi': [1, 2, 6], 63 | 'icbin': [1], 64 | 'itodd': [2, 3, 4, 5, 7, 8, 9, 11, 12, 14, 17, 18, 19, 23, 24, 25, 27, 28], 65 | 'hbs': [10, 12, 18, 29], 66 | 'hb': [6, 10, 11, 12, 13, 14, 18, 24, 29], 67 | 'ycbv': [1, 13, 14, 16, 18, 19, 20, 21], 68 | 'hope': None, # Not defined yet. 69 | }[dataset_name] 70 | 71 | # Both versions of the HB dataset share the same directory. 72 | if dataset_name == 'hbs': 73 | dataset_name = 'hb' 74 | 75 | p = { 76 | # ID's of all objects included in the dataset. 77 | 'obj_ids': obj_ids, 78 | 79 | # ID's of objects with symmetries. 80 | 'symmetric_obj_ids': symmetric_obj_ids, 81 | 82 | # Path to a file with meta information about the object models. 83 | 'models_info_path': os.path.join(models_path, 'models_info.json') 84 | } 85 | 86 | return p 87 | 88 | def unit_vector(data, axis=None, out=None): 89 | """Return ndarray normalized by length, i.e. Euclidean norm, along axis. 
90 | 91 | >>> v0 = numpy.random.random(3) 92 | >>> v1 = unit_vector(v0) 93 | >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0)) 94 | True 95 | >>> v0 = numpy.random.rand(5, 4, 3) 96 | >>> v1 = unit_vector(v0, axis=-1) 97 | >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2) 98 | >>> numpy.allclose(v1, v2) 99 | True 100 | >>> v1 = unit_vector(v0, axis=1) 101 | >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1) 102 | >>> numpy.allclose(v1, v2) 103 | True 104 | >>> v1 = numpy.empty((5, 4, 3)) 105 | >>> unit_vector(v0, axis=1, out=v1) 106 | >>> numpy.allclose(v1, v2) 107 | True 108 | >>> list(unit_vector([])) 109 | [] 110 | >>> list(unit_vector([1])) 111 | [1.0] 112 | 113 | """ 114 | if out is None: 115 | data = np.array(data, dtype=np.float64, copy=True) 116 | if data.ndim == 1: 117 | data /= math.sqrt(np.dot(data, data)) 118 | return data 119 | else: 120 | if out is not data: 121 | out[:] = np.array(data, copy=False) 122 | data = out 123 | length = np.atleast_1d(np.sum(data * data, axis)) 124 | np.sqrt(length, length) 125 | if axis is not None: 126 | length = np.expand_dims(length, axis) 127 | data /= length 128 | if out is None: 129 | return data 130 | 131 | def rotation_matrix(angle, direction, point=None): 132 | """Return matrix to rotate about axis defined by point and direction. 133 | 134 | >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0]) 135 | >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1]) 136 | True 137 | >>> angle = (random.random() - 0.5) * (2*math.pi) 138 | >>> direc = numpy.random.random(3) - 0.5 139 | >>> point = numpy.random.random(3) - 0.5 140 | >>> R0 = rotation_matrix(angle, direc, point) 141 | >>> R1 = rotation_matrix(angle-2*math.pi, direc, point) 142 | >>> is_same_transform(R0, R1) 143 | True 144 | >>> R0 = rotation_matrix(angle, direc, point) 145 | >>> R1 = rotation_matrix(-angle, -direc, point) 146 | >>> is_same_transform(R0, R1) 147 | True 148 | >>> I = numpy.identity(4, numpy.float64) 149 | >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc)) 150 | True 151 | >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2, 152 | ... direc, point))) 153 | True 154 | 155 | """ 156 | sina = math.sin(angle) 157 | cosa = math.cos(angle) 158 | direction = unit_vector(direction[:3]) 159 | # rotation matrix around unit vector 160 | R = np.diag([cosa, cosa, cosa]) 161 | R += np.outer(direction, direction) * (1.0 - cosa) 162 | direction *= sina 163 | R += np.array([[0.0, -direction[2], direction[1]], 164 | [direction[2], 0.0, -direction[0]], 165 | [-direction[1], direction[0], 0.0]]) 166 | M = np.identity(4) 167 | M[:3, :3] = R 168 | if point is not None: 169 | # rotation not around origin 170 | point = np.array(point[:3], dtype=np.float64, copy=False) 171 | M[:3, 3] = point - np.dot(R, point) 172 | return M 173 | 174 | def get_symmetry_transformations(model_info, max_sym_disc_step): 175 | """Returns a set of symmetry transformations for an object model. 176 | 177 | :param model_info: See files models_info.json provided with the datasets. 178 | :param max_sym_disc_step: The maximum fraction of the object diameter which 179 | the vertex that is the furthest from the axis of continuous rotational 180 | symmetry travels between consecutive discretized rotations. 181 | :return: The set of symmetry transformations. 182 | """ 183 | # Discrete symmetries. 184 | trans_disc = [{'R': np.eye(3), 't': np.array([[0, 0, 0]]).T}] # Identity. 
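# --- Editor's note (not part of the original source) ---------------------------------
# trans_disc starts from the identity and is extended below with any 4x4 matrices listed
# under 'symmetries_discrete'. Continuous axis symmetries are discretized by splitting a
# full turn into int(np.ceil(np.pi / max_sym_disc_step)) steps, generating one rotation
# per non-zero step; each of these is then composed with every entry of trans_disc
# (when no continuous symmetry exists, trans_disc is returned unchanged).
# ---------------------------------------------------------------------------------------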
185 | if 'symmetries_discrete' in model_info: 186 | for sym in model_info['symmetries_discrete']: 187 | sym_4x4 = np.reshape(sym, (4, 4)) 188 | R = sym_4x4[:3, :3] 189 | t = sym_4x4[:3, 3].reshape((3, 1)) 190 | trans_disc.append({'R': R, 't': t}) 191 | 192 | # Discretized continuous symmetries. 193 | trans_cont = [] 194 | if 'symmetries_continuous' in model_info: 195 | for sym in model_info['symmetries_continuous']: 196 | axis = np.array(sym['axis']) 197 | offset = np.array(sym['offset']).reshape((3, 1)) 198 | 199 | # (PI * diam.) / (max_sym_disc_step * diam.) = discrete_steps_count 200 | discrete_steps_count = int(np.ceil(np.pi / max_sym_disc_step)) 201 | 202 | # Discrete step in radians. 203 | discrete_step = 2.0 * np.pi / discrete_steps_count 204 | 205 | for i in range(1, discrete_steps_count): 206 | R = rotation_matrix(i * discrete_step, axis)[:3, :3] 207 | t = -R.dot(offset) + offset 208 | trans_cont.append({'R': R, 't': t}) 209 | 210 | # Combine the discrete and the discretized continuous symmetries. 211 | trans = [] 212 | for tran_disc in trans_disc: 213 | if len(trans_cont): 214 | for tran_cont in trans_cont: 215 | R = tran_cont['R'].dot(tran_disc['R']) 216 | t = tran_cont['R'].dot(tran_disc['t']) + tran_cont['t'] 217 | trans.append({'R': R, 't': t}) 218 | else: 219 | trans.append(tran_disc) 220 | 221 | return trans 222 | 223 | 224 | if __name__ == '__main__': 225 | # PARAMETERS. 226 | ################################################################################ 227 | p = { 228 | # See dataset_params.py for options. 229 | 'dataset': 'ycbv', 230 | 231 | # See misc.get_symmetry_transformations(). 232 | 'max_sym_disc_step': 0.01, 233 | 234 | # Folder containing the BOP datasets. 235 | 'models_path': '/mnt/data0/lin/bop_datasets/ycbv/models', 236 | 237 | } 238 | ################################################################################ 239 | 240 | # Load dataset parameters. 241 | obj_params = get_obj_params(p['models_path'], p['dataset']) 242 | 243 | # Load meta info about the models (including symmetries). 244 | models_info = load_json(obj_params['models_info_path'], keys_to_int=True) 245 | 246 | # for obj_id in obj_params['obj_ids']: 247 | import torch 248 | sym_pool=[] 249 | obj_id = 13 250 | sym_poses = get_symmetry_transformations(models_info[obj_id], p['max_sym_disc_step']) 251 | for sym_pose in sym_poses: 252 | Rt = np.concatenate([sym_pose['R'], sym_pose['t'].reshape(3,1)], axis=1) 253 | Rt = np.concatenate([Rt, np.array([0, 0, 0, 1]).reshape(1, 4)], axis=0) 254 | sym_pool.append(torch.Tensor(Rt)) 255 | pdb.set_trace() 256 | --------------------------------------------------------------------------------