├── .flake8 ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── INSTRUCTIONS_PIX3D.md ├── INSTRUCTIONS_SHAPENET.md ├── LICENSE ├── README.md ├── configs ├── pix3d │ ├── MESH-RCNN-FPN.yaml │ ├── meshrcnn_R50_FPN.yaml │ ├── pixel2mesh_R50_FPN.yaml │ ├── sphereinit_R50_FPN.yaml │ └── voxelrcnn_R50_FPN.yaml └── shapenet │ ├── pixel2mesh_R50.yaml │ ├── sphereinit_R50.yaml │ └── voxmesh_R50.yaml ├── datasets ├── pix3d │ └── download_pix3d.sh └── shapenet │ └── download_shapenet.sh ├── demo ├── README.md └── demo.py ├── infra └── linter.sh ├── meshrcnn ├── config │ ├── __init__.py │ └── config.py ├── data │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── builtin.py │ │ └── pix3d.py │ └── meshrcnn_transforms.py ├── evaluation │ ├── __init__.py │ └── pix3d_evaluation.py ├── modeling │ ├── __init__.py │ └── roi_heads │ │ ├── __init__.py │ │ ├── mask_head.py │ │ ├── mesh_head.py │ │ ├── roi_heads.py │ │ ├── voxel_head.py │ │ └── z_head.py ├── structures │ ├── __init__.py │ ├── mask.py │ ├── mesh.py │ └── voxel.py └── utils │ ├── VOCap.py │ ├── __init__.py │ ├── metrics.py │ ├── model_zoo.py │ ├── projtransform.py │ ├── shape.py │ └── vis.py ├── setup.cfg ├── setup.py ├── shapenet ├── config │ ├── __init__.py │ └── config.py ├── data │ ├── __init__.py │ ├── build_data_loader.py │ ├── builtin.py │ ├── mesh_vox.py │ └── utils.py ├── evaluation │ ├── __init__.py │ └── eval.py ├── modeling │ ├── __init__.py │ ├── backbone.py │ ├── heads │ │ ├── __init__.py │ │ ├── mesh_head.py │ │ ├── mesh_loss.py │ │ └── voxel_head.py │ └── mesh_arch.py ├── solver │ ├── __init__.py │ ├── build.py │ └── lr_schedule.py └── utils │ ├── __init__.py │ ├── binvox_torch.py │ ├── checkpoint.py │ ├── coords.py │ ├── defaults.py │ ├── model_zoo.py │ ├── timing.py │ └── vis.py └── tools ├── convert_cocomodel_for_init.py ├── preprocess_shapenet.py ├── train_net.py └── train_net_shapenet.py /.flake8: -------------------------------------------------------------------------------- 1 | # This is an example .flake8 config, used when developing *Black* itself. 2 | # Keep in sync with setup.cfg which is used for source packages. 
3 | 4 | [flake8] 5 | ignore = E203, E266, E501, W503, E221 6 | max-line-length = 100 7 | max-complexity = 18 8 | select = B,C,E,F,W,T4,B9 9 | exclude = build,__init__.py 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | output_demo 4 | 5 | 6 | *.jpg 7 | *.png 8 | *.txt 9 | 10 | # compilation and distribution 11 | __pycache__ 12 | _ext 13 | *.pyc 14 | *.so 15 | meshrcnn.egg-info/ 16 | build/ 17 | dist/ 18 | 19 | # pytorch/python/numpy formats 20 | *.pth 21 | *.pkl 22 | *.npy 23 | 24 | # ipython/jupyter notebooks 25 | *.ipynb 26 | **/.ipynb_checkpoints/ 27 | 28 | # Editor temporaries 29 | *.swn 30 | *.swo 31 | *.swp 32 | *~ 33 | 34 | # Pycharm editor settings 35 | .idea 36 | 37 | # project dirs 38 | /models 39 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to meshrcnn 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | While this codebase focuses on reproducing our paper results, we believe it could 7 | be extended and improved in many ways. Thus, feel free to suggest improvements 8 | or extensions. 9 | 10 | ## Pull Requests 11 | We actively welcome your pull requests. 12 | 13 | 1. Fork the repo and create your branch from `main`. 14 | 2. If you've changed APIs, update the documentation. 15 | 3. Train and report the performance of all models (and baselines, if applicable). 16 | 4. Make sure your code lints by running `./infra/linter.sh`. 17 | 5. If a PR contains multiple orthogonal changes, split it into several PRs. 18 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 19 | 20 | ## Contributor License Agreement ("CLA") 21 | In order to accept your pull request, we need you to submit a CLA. You only need 22 | to do this once to work on any of Facebook's open source projects. 23 | 24 | Complete your CLA here: 25 | 26 | ## License 27 | By contributing to meshrcnn, you agree that your contributions will be licensed 28 | under the LICENSE file in the root directory of this source tree. 29 | -------------------------------------------------------------------------------- /INSTRUCTIONS_PIX3D.md: -------------------------------------------------------------------------------- 1 | # Experiments on Pix3D 2 | 3 | ## Download Pix3D and splits 4 | 5 | Run 6 | 7 | ``` 8 | datasets/pix3d/download_pix3d.sh 9 | ``` 10 | 11 | to download [Pix3D][pix3d] and the `S1` & `S2` splits to `./datasets/pix3d/`. 12 | 13 | ## Training 14 | 15 | ``` 16 | python tools/train_net.py --num-gpus 8 \ 17 | --config-file configs/pix3d/meshrcnn_R50_FPN.yaml 18 | ``` 19 | 20 | *Note* that the above config is tuned for 8-GPU distributed training. 21 | Deviating from the provided training recipe means that other hyperparameters have to be retuned accordingly.
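As a rough sketch of what such re-tuning could look like, the example below applies the commonly used linear scaling rule to a hypothetical 2-GPU run, overriding the solver settings from the command line with the same `KEY VALUE` syntax used by the evaluation command below. The specific numbers (total batch size 4, learning rate and schedule scaled by 4/16) are illustrative assumptions, not a recipe we have validated.

```
# Hypothetical 2-GPU run: keep 2 images per GPU and scale the learning rate,
# step schedule, and iteration count by the batch-size ratio (4/16).
python tools/train_net.py --num-gpus 2 \
  --config-file configs/pix3d/meshrcnn_R50_FPN.yaml \
  SOLVER.IMS_PER_BATCH 4 SOLVER.BASE_LR 0.005 \
  SOLVER.MAX_ITER 44000 SOLVER.STEPS "(32000, 40000)"
```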
22 | 23 | ## Testing and Evaluation 24 | 25 | ``` 26 | python tools/train_net.py \ 27 | --config-file configs/pix3d/meshrcnn_R50_FPN.yaml \ 28 | --eval-only MODEL.WEIGHTS /path/to/checkpoint_file 29 | ``` 30 | 31 | If you wish to evaluate the provided pretrained models (see below for a model zoo), simply set `MODEL.WEIGHTS meshrcnn://meshrcnn_R50.pth`. *Note* that by default, the config files use the `S1` split. To switch between `S1` and `S2`, specify the split in the `DATASETS` section of the config file. 32 | 33 | ## Models 34 | 35 | We provide a model zoo for models trained on the Pix3D `S1` & `S2` splits (see the paper for more details). 36 | 37 | | | Mesh R-CNN | Pixel2Mesh | SphereInit | 38 | |------|:-------------------------:|:----------------------------:|:----------------------------:| 39 | | `S1` | [meshrcnn_R50.pth][m1] | [pixel2mesh_R50.pth][pm1] | [sphereinit_R50.pth][sp1] | 40 | | `S2` | [meshrcnn_S2_R50.pth][m2] | [pixel2mesh_S2_R50.pth][pm2] | [sphereinit_S2_R50.pth][sp2] | 41 | 42 | [pix3d]: http://pix3d.csail.mit.edu/data/pix3d.zip 43 | [m1]: https://dl.fbaipublicfiles.com/meshrcnn/pix3d/meshrcnn_R50.pth 44 | [m2]: https://dl.fbaipublicfiles.com/meshrcnn/pix3d/meshrcnn_S2_R50.pth 45 | [pm1]: https://dl.fbaipublicfiles.com/meshrcnn/pix3d/pixel2mesh_R50.pth 46 | [pm2]: https://dl.fbaipublicfiles.com/meshrcnn/pix3d/pixel2mesh_S2_R50.pth 47 | [sp1]: https://dl.fbaipublicfiles.com/meshrcnn/pix3d/sphereinit_R50.pth 48 | [sp2]: https://dl.fbaipublicfiles.com/meshrcnn/pix3d/sphereinit_S2_R50.pth 49 | -------------------------------------------------------------------------------- /INSTRUCTIONS_SHAPENET.md: -------------------------------------------------------------------------------- 1 | # Experiments on ShapeNet 2 | 3 | ## Data 4 | 5 | We use [ShapeNet][shapenet] data and their renderings, as provided by [R2N2][r2n2]. 6 | 7 | Run 8 | 9 | ``` 10 | datasets/shapenet/download_shapenet.sh 11 | ``` 12 | 13 | to download [R2N2][r2n2] and the train/val/test splits. 14 | You also need the original ShapeNet Core v1 & binvox datasets, which require [registration][shapenet_login] before downloading. 15 | 16 | ## Preprocessing 17 | 18 | ``` 19 | python tools/preprocess_shapenet.py \ 20 | --shapenet_dir /path/to/ShapeNetCore.v1 \ 21 | --shapenet_binvox_dir /path/to/ShapeNetCore.v1.binvox \ 22 | --output_dir ./datasets/shapenet/ShapeNetV1processed \ 23 | --zip_output 24 | ``` 25 | 26 | The above command preprocesses the ShapeNet dataset to reduce the data loading time. 27 | The preprocessed data will be saved in `./datasets/shapenet` and will be zipped. 28 | The zipped output is useful when training on clusters. 29 | 30 | ## Training 31 | 32 | ``` 33 | python tools/train_net_shapenet.py --num-gpus 8 \ 34 | --config-file configs/shapenet/voxmesh_R50.yaml 35 | ``` 36 | 37 | When `--copy_data` is specified, the preprocessed zipped data from above will be copied to a local `/tmp` directory. 38 | This is particularly useful when training on remote clusters, as it reduces the I/O time during training. 39 | 40 | ## Testing and Evaluation 41 | 42 | ``` 43 | python tools/train_net_shapenet.py --eval-only --num-gpus 1 \ 44 | --config-file configs/shapenet/voxmesh_R50.yaml \ 45 | MODEL.CHECKPOINT shapenet://voxmesh_R50.pth 46 | ``` 47 | 48 | The evaluation produces the results shown in Table 2 of our paper. 49 | To evaluate under the [Pixel2Mesh][p2m] protocol, as in Table 1 of our paper, add `--eval-p2m`.
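For example, evaluating the released checkpoint under the Pixel2Mesh protocol is the same command as above with the extra flag (its exact position on the command line is assumed not to matter):

```
python tools/train_net_shapenet.py --eval-only --eval-p2m --num-gpus 1 \
  --config-file configs/shapenet/voxmesh_R50.yaml \
  MODEL.CHECKPOINT shapenet://voxmesh_R50.pth
```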
50 | 51 | ## Models 52 | 53 | | Mesh R-CNN | Pixel2Mesh | SphereInit | 54 | |-----------------------------:|:----------------------------:|:----------------------------:| 55 | | [voxmesh_R50.pth][voxm] | [pixel2mesh_R50.pth][pm] | [sphereinit_R50.pth][sp] | 56 | 57 | Note that we release only the *light* and *pretty* for both our and the baseline models. 58 | 59 | ## Performance 60 | 61 | ### Scale-normalized Protocol 62 | 63 | Performance of our [model][voxm] on ShapeNet `test` set under the scale-normalized evaluation protocol (as in Table 2 of our paper). 64 | 65 | | category | #instances | chamfer | normal | F1(0.1) | F1(0.3) | F1(0.5) | 66 | |:------------:|:-------------|:----------|:---------|:----------|:----------|:----------| 67 | | bench | 8712 | 0.120899 | 0.657536 | 42.4005 | 86.0036 | 95.128 | 68 | | chair | 32520 | 0.183693 | 0.712362 | 31.6906 | 79.8275 | 92.0139 | 69 | | lamp | 11122 | 0.413965 | 0.672992 | 30.5048 | 70.3449 | 84.5068 | 70 | | speaker | 7752 | 0.253796 | 0.730829 | 24.8335 | 74.6606 | 88.237 | 71 | | firearm | 11386 | 0.168323 | 0.621439 | 47.2251 | 85.271 | 93.8171 | 72 | | table | 40796 | 0.148357 | 0.75642 | 42.249 | 86.2039 | 94.1623 | 73 | | watercraft | 9298 | 0.224168 | 0.642812 | 30.0589 | 75.5332 | 89.9764 | 74 | | plane | 19416 | 0.187465 | 0.684285 | 39.009 | 80.998 | 92.1069 | 75 | | cabinet | 7541 | 0.111294 | 0.75122 | 34.8227 | 86.9346 | 95.371 | 76 | | car | 35981 | 0.107605 | 0.647857 | 29.6397 | 85.7925 | 96.2938 | 77 | | monitor | 5256 | 0.218032 | 0.779365 | 27.2531 | 77.2979 | 90.904 | 78 | | couch | 15226 | 0.144279 | 0.72302 | 27.5734 | 81.684 | 94.3294 | 79 | | cellphone | 5045 | 0.121504 | 0.850437 | 42.9168 | 88.9888 | 96.1367 | 80 | | | | | | | | | 81 | | total | 210051 | 0.184875 | 0.710044 | 34.629 | 81.5031 | 92.5372 | 82 | | | | | | | | | 83 | | per-instance | 210051 | 0.171189 | 0.70275 | 34.9372 | 82.4107 | 93.1323 | 84 | 85 | 86 | ### Pixel2Mesh Protocol 87 | 88 | Performance of our [model][voxm] on ShapeNet `test` set under the pixel2mesh evaluation protocol (as in Table 1 of our paper). To evaluate under this protocol, add the `--eval-p2m` flag. 
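As a reference for reading the F1 columns in the table above and the one below: precision is the fraction of points sampled from the predicted surface that lie within a threshold tau of some ground-truth sample, recall is the symmetric quantity, and F1 is their harmonic mean. The snippet below is a minimal sketch of that definition for two sampled point clouds; it is not the repository's evaluation code, the `f1_score` helper name is made up, and the exact distance convention used for the `F1(0.0001)` / `F1(0.0002)` thresholds of the Pixel2Mesh protocol may differ.

```
import torch

def f1_score(pred_points: torch.Tensor, gt_points: torch.Tensor, tau: float) -> float:
    # pred_points: (N, 3) samples from the predicted mesh; gt_points: (M, 3) samples
    # from the ground-truth mesh. Returns F1 at distance threshold tau.
    dists = torch.cdist(pred_points, gt_points)                    # (N, M) pairwise distances
    precision = (dists.min(dim=1).values < tau).float().mean()     # pred -> gt coverage
    recall = (dists.min(dim=0).values < tau).float().mean()        # gt -> pred coverage
    return (2 * precision * recall / (precision + recall + 1e-8)).item()
```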
89 | 90 | | category | #instances | chamfer | normal | F1(0.0001) | F1(0.0002) | 91 | |:------------:|:-------------|:------------|:---------|:-------------|:-------------| 92 | | bench | 8712 | 0.000295252 | 0.657508 | 73.4681 | 84.4999 | 93 | | chair | 32520 | 0.000400415 | 0.712348 | 66.5227 | 79.3634 | 94 | | lamp | 11122 | 0.000788915 | 0.673057 | 60.1057 | 71.7711 | 95 | | speaker | 7752 | 0.000582152 | 0.730797 | 59.9974 | 73.8792 | 96 | | firearm | 11386 | 0.000357016 | 0.621438 | 75.9761 | 85.5111 | 97 | | table | 40796 | 0.000342991 | 0.756442 | 76.0776 | 85.4878 | 98 | | watercraft | 9298 | 0.000449061 | 0.642791 | 62.808 | 76.5464 | 99 | | plane | 19416 | 0.000313141 | 0.684333 | 75.8104 | 85.3897 | 100 | | cabinet | 7541 | 0.000293613 | 0.751306 | 72.6302 | 84.7327 | 101 | | car | 35981 | 0.000240585 | 0.647896 | 71.8118 | 85.5155 | 102 | | monitor | 5256 | 0.000470965 | 0.779397 | 64.2917 | 77.8422 | 103 | | couch | 15226 | 0.000355369 | 0.723013 | 64.1388 | 79.327 | 104 | | cellphone | 5045 | 0.000280397 | 0.850456 | 77.3011 | 87.8698 | 105 | | | | | | | | 106 | | total | 210051 | 0.000397682 | 0.71006 | 69.303 | 81.3643 | 107 | | | | | | | | 108 | | per-instance | 210051 | 0.000368317 | 0.702768 | 70.4479 | 82.3373 | 109 | 110 | 111 | [shapenet]: http://shapenet.cs.stanford.edu/ 112 | [shapenet_login]: https://www.shapenet.org/login/ 113 | [r2n2_data]: http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz 114 | [r2n2]: http://3d-r2n2.stanford.edu/ 115 | [p2m]: https://github.com/nywang16/Pixel2Mesh 116 | [voxm]: https://dl.fbaipublicfiles.com/meshrcnn/shapenet/voxmesh_R50.pth 117 | [pm]: https://dl.fbaipublicfiles.com/meshrcnn/shapenet/pixel2mesh_R50.pth 118 | [sp]: https://dl.fbaipublicfiles.com/meshrcnn/shapenet/sphereinit_R50.pth 119 | 120 | 125 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For meshrcnn software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mesh R-CNN 2 | 3 | Code for the paper 4 | 5 | **[Mesh R-CNN][1]** 6 | [Georgia Gkioxari][gg], Jitendra Malik, [Justin Johnson][jj] 7 | ICCV 2019 8 | 9 |
10 | 11 |
12 | 13 |   14 | 15 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eQLZrNYRZMo9zdnGGccE0hFswGiinO-Z?usp=sharing) 16 | 17 | (thanks to [Alberto Tono][at]!) 18 | 19 | ## Installation Requirements 20 | - [Detectron2][d2] 21 | - [PyTorch3D][py3d] 22 | 23 | The implementation of Mesh R-CNN is based on [Detectron2][d2] and [PyTorch3D][py3d]. 24 | You will first need to install those in order to be able to run Mesh R-CNN. 25 | 26 | To install 27 | ``` 28 | git clone https://github.com/facebookresearch/meshrcnn.git 29 | cd meshrcnn && pip install -e . 30 | ``` 31 | 32 | ## Demo 33 | 34 | Run Mesh R-CNN on an input image 35 | 36 | ``` 37 | python demo/demo.py \ 38 | --config-file configs/pix3d/meshrcnn_R50_FPN.yaml \ 39 | --input /path/to/image \ 40 | --output output_demo \ 41 | --onlyhighest MODEL.WEIGHTS meshrcnn://meshrcnn_R50.pth 42 | ``` 43 | 44 | See [demo.py](demo/demo.py) for more details. 45 | 46 | ## Running Experiments 47 | 48 | ### Pix3D 49 | See [INSTRUCTIONS_PIX3D.md](INSTRUCTIONS_PIX3D.md) for more instructions. 50 | 51 | ### ShapeNet 52 | See [INSTRUCTIONS_SHAPENET.md](INSTRUCTIONS_SHAPENET.md) for more instructions. 53 | 54 | ## License 55 | The Mesh R-CNN codebase is released under [BSD-3-Clause License](LICENSE) 56 | 57 | [1]: https://arxiv.org/abs/1906.02739 58 | [gg]: https://github.com/gkioxari 59 | [jj]: https://github.com/jcjohnson 60 | [d2]: https://github.com/facebookresearch/detectron2 61 | [py3d]: https://github.com/facebookresearch/pytorch3d 62 | [at]: https://github.com/albertotono 63 | -------------------------------------------------------------------------------- /configs/pix3d/MESH-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | POOLER_SAMPLING_RATIO: 2 29 | POOLER_TYPE: "ROIAlign" 30 | ROI_MASK_HEAD: 31 | NAME: "MaskRCNNConvUpsampleHead" 32 | NUM_CONV: 4 33 | POOLER_RESOLUTION: 14 34 | POOLER_SAMPLING_RATIO: 2 35 | POOLER_TYPE: "ROIAlign" 36 | ROI_VOXEL_HEAD: 37 | NAME: "MaskRCNNConvUpsampleHead" 38 | NUM_CONV: 4 39 | POOLER_RESOLUTION: 14 40 | POOLER_SAMPLING_RATIO: 2 41 | POOLER_TYPE: "ROIAlign" 42 | ROI_MESH_HEAD: 43 | NAME: "MeshRCNNGraphConvHead" 44 | POOLER_RESOLUTION: 14 45 | POOLER_SAMPLING_RATIO: 2 46 | POOLER_TYPE: "ROIAlign" 47 | DATASETS: 48 | TRAIN: ("pix3d_train",) 49 | TEST: ("pix3d_test",) 50 | SOLVER: 51 | IMS_PER_BATCH: 16 52 | BASE_LR: 0.02 53 | STEPS: (60000, 80000) 54 | MAX_ITER: 90000 55 | VERSION: 2 56 | -------------------------------------------------------------------------------- /configs/pix3d/meshrcnn_R50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "MESH-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "meshrcnn://coco_init_0719.pth" # "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | MASK_ON: True 5 | VOXEL_ON: True 6 | MESH_ON: True 7 | ZPRED_ON: True 8 | RESNETS: 9 | DEPTH: 50 10 | RPN: 11 | IOU_THRESHOLDS: [0.2, 0.5, 0.7] 12 | IOU_LABELS: [-1, 0, -1, 1] 13 | SMOOTH_L1_BETA: 0.111 14 | ROI_HEADS: 15 | NAME: "MeshRCNNROIHeads" 16 | BATCH_SIZE_PER_IMAGE: 64 17 | NUM_CLASSES: 9 # Number of foreground classes 18 | IOU_THRESHOLDS: [0.2, 0.5] 19 | IOU_LABELS: [-1, 0, 1] 20 | ROI_BOX_HEAD: 21 | SMOOTH_L1_BETA: 1.0 22 | ROI_Z_HEAD: 23 | NAME: "FastRCNNFCHead" 24 | Z_REG_WEIGHT: 1.0 25 | SMOOTH_L1_BETA: 1.0 26 | ROI_MASK_HEAD: 27 | NAME: "MaskRCNNConvUpsampleHead" 28 | POOLER_RESOLUTION: 14 29 | POOLER_SAMPLING_RATIO: 2 30 | NUM_CONV: 4 31 | ROI_VOXEL_HEAD: 32 | NAME: "VoxelRCNNConvUpsampleHead" 33 | POOLER_RESOLUTION: 12 34 | POOLER_SAMPLING_RATIO: 2 35 | NUM_CONV: 4 36 | NUM_DEPTH: 24 37 | CLS_AGNOSTIC_VOXEL: True 38 | LOSS_WEIGHT: 3.0 39 | CUBIFY_THRESH: 0.2 40 | ROI_MESH_HEAD: 41 | NAME: "MeshRCNNGraphConvHead" 42 | POOLER_RESOLUTION: 14 43 | POOLER_SAMPLING_RATIO: 2 44 | NUM_STAGES: 3 45 | NUM_GRAPH_CONVS: 3 46 | GRAPH_CONV_DIM: 128 47 | GRAPH_CONV_INIT: "normal" 48 | GT_COORD_THRESH: 5.0 49 | CHAMFER_LOSS_WEIGHT: 1.0 50 | NORMALS_LOSS_WEIGHT: 0.1 51 | EDGE_LOSS_WEIGHT: 1.0 52 | DATASETS: 53 | TRAIN: ("pix3d_s1_train",) 54 | TEST: ("pix3d_s1_test",) 55 | SOLVER: 56 | BASE_LR: 0.02 57 | WEIGHT_DECAY: 0.0001 58 | STEPS: (8000, 10000) 59 | MAX_ITER: 11000 60 | WARMUP_ITERS: 1000 61 | WARMUP_FACTOR: 0.1 62 | -------------------------------------------------------------------------------- /configs/pix3d/pixel2mesh_R50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "MESH-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "meshrcnn://coco_init_0719.pth" # "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | MASK_ON: True 5 | VOXEL_ON: False 6 | MESH_ON: True 7 | ZPRED_ON: True 8 | RESNETS: 9 | DEPTH: 50 10 | RPN: 11 | IOU_THRESHOLDS: [0.2, 0.5, 0.7] 12 | IOU_LABELS: [-1, 0, -1, 1] 13 | SMOOTH_L1_BETA: 0.111 14 | ROI_HEADS: 15 | NAME: "MeshRCNNROIHeads" 16 | BATCH_SIZE_PER_IMAGE: 64 17 | NUM_CLASSES: 9 # Number of foreground classes 18 | IOU_THRESHOLDS: [0.2, 0.5] 19 | IOU_LABELS: [-1, 0, 1] 20 | ROI_BOX_HEAD: 21 | SMOOTH_L1_BETA: 1.0 22 | ROI_Z_HEAD: 23 | 
NAME: "FastRCNNFCHead" 24 | Z_REG_WEIGHT: 1.0 25 | SMOOTH_L1_BETA: 1.0 26 | ROI_MASK_HEAD: 27 | NAME: "MaskRCNNConvUpsampleHead" 28 | POOLER_RESOLUTION: 14 29 | POOLER_SAMPLING_RATIO: 2 30 | NUM_CONV: 4 31 | ROI_MESH_HEAD: 32 | NAME: "MeshRCNNGraphConvSubdHead" 33 | POOLER_RESOLUTION: 14 34 | POOLER_SAMPLING_RATIO: 2 35 | NUM_STAGES: 3 36 | NUM_GRAPH_CONVS: 3 37 | GRAPH_CONV_DIM: 128 38 | GRAPH_CONV_INIT: "normal" 39 | GT_COORD_THRESH: 5.0 40 | CHAMFER_LOSS_WEIGHT: 1.0 41 | NORMALS_LOSS_WEIGHT: 0.1 42 | EDGE_LOSS_WEIGHT: 1.0 43 | ICO_SPHERE_LEVEL: 2 44 | DATASETS: 45 | TRAIN: ("pix3d_s1_train",) 46 | TEST: ("pix3d_s1_test",) 47 | SOLVER: 48 | BASE_LR: 0.02 49 | WEIGHT_DECAY: 0.0001 50 | STEPS: (8000, 10000) 51 | MAX_ITER: 11000 52 | WARMUP_ITERS: 1000 53 | WARMUP_FACTOR: 0.1 54 | -------------------------------------------------------------------------------- /configs/pix3d/sphereinit_R50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "MESH-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "meshrcnn://coco_init_0719.pth" # "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | MASK_ON: True 5 | VOXEL_ON: False 6 | MESH_ON: True 7 | ZPRED_ON: True 8 | RESNETS: 9 | DEPTH: 50 10 | RPN: 11 | IOU_THRESHOLDS: [0.2, 0.5, 0.7] 12 | IOU_LABELS: [-1, 0, -1, 1] 13 | SMOOTH_L1_BETA: 0.111 14 | ROI_HEADS: 15 | NAME: "MeshRCNNROIHeads" 16 | BATCH_SIZE_PER_IMAGE: 64 17 | NUM_CLASSES: 9 # Number of foreground classes 18 | IOU_THRESHOLDS: [0.2, 0.5] 19 | IOU_LABELS: [-1, 0, 1] 20 | ROI_BOX_HEAD: 21 | SMOOTH_L1_BETA: 1.0 22 | ROI_Z_HEAD: 23 | NAME: "FastRCNNFCHead" 24 | Z_REG_WEIGHT: 1.0 25 | SMOOTH_L1_BETA: 1.0 26 | ROI_MASK_HEAD: 27 | NAME: "MaskRCNNConvUpsampleHead" 28 | POOLER_RESOLUTION: 14 29 | POOLER_SAMPLING_RATIO: 2 30 | NUM_CONV: 4 31 | ROI_MESH_HEAD: 32 | NAME: "MeshRCNNGraphConvHead" 33 | POOLER_RESOLUTION: 14 34 | POOLER_SAMPLING_RATIO: 2 35 | NUM_STAGES: 3 36 | NUM_GRAPH_CONVS: 3 37 | GRAPH_CONV_DIM: 128 38 | GRAPH_CONV_INIT: "normal" 39 | GT_COORD_THRESH: 5.0 40 | CHAMFER_LOSS_WEIGHT: 1.0 41 | NORMALS_LOSS_WEIGHT: 0.1 42 | EDGE_LOSS_WEIGHT: 1.0 43 | ICO_SPHERE_LEVEL: 4 44 | DATASETS: 45 | TRAIN: ("pix3d_s1_train",) 46 | TEST: ("pix3d_s1_test",) 47 | SOLVER: 48 | BASE_LR: 0.02 49 | WEIGHT_DECAY: 0.0001 50 | STEPS: (8000, 10000) 51 | MAX_ITER: 11000 52 | WARMUP_ITERS: 1000 53 | WARMUP_FACTOR: 0.1 54 | -------------------------------------------------------------------------------- /configs/pix3d/voxelrcnn_R50_FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "MESH-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "meshrcnn://coco_init_0719.pth" # "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | MASK_ON: True 5 | VOXEL_ON: True 6 | ZPRED_ON: True 7 | RESNETS: 8 | DEPTH: 50 9 | RPN: 10 | IOU_THRESHOLDS: [0.2, 0.5, 0.7] 11 | IOU_LABELS: [-1, 0, -1, 1] 12 | SMOOTH_L1_BETA: 0.111 13 | ROI_HEADS: 14 | NAME: "MeshRCNNROIHeads" 15 | BATCH_SIZE_PER_IMAGE: 64 16 | NUM_CLASSES: 9 # Number of foreground classes 17 | IOU_THRESHOLDS: [0.2, 0.5] 18 | IOU_LABELS: [-1, 0, 1] 19 | ROI_BOX_HEAD: 20 | SMOOTH_L1_BETA: 1.0 21 | ROI_Z_HEAD: 22 | NAME: "FastRCNNFCHead" 23 | Z_REG_WEIGHT: 1.0 24 | SMOOTH_L1_BETA: 1.0 25 | ROI_MASK_HEAD: 26 | NAME: "MaskRCNNConvUpsampleHead" 27 | POOLER_RESOLUTION: 14 28 | POOLER_SAMPLING_RATIO: 2 29 | NUM_CONV: 4 30 | ROI_VOXEL_HEAD: 31 | NAME: "VoxelRCNNConvUpsampleHead" 32 | POOLER_RESOLUTION: 12 33 | POOLER_SAMPLING_RATIO: 2 34 | NUM_CONV: 4 35 | NUM_DEPTH: 24 36 | CLS_AGNOSTIC_VOXEL: True 37 | 
LOSS_WEIGHT: 3.0 38 | CUBIFY_THRESH: 0.2 39 | DATASETS: 40 | TRAIN: ("pix3d_s1_train",) 41 | TEST: ("pix3d_s1_test",) 42 | SOLVER: 43 | BASE_LR: 0.02 44 | WEIGHT_DECAY: 0.0001 45 | STEPS: (8000, 10000) 46 | MAX_ITER: 11000 47 | WARMUP_ITERS: 1000 48 | WARMUP_FACTOR: 0.1 49 | -------------------------------------------------------------------------------- /configs/shapenet/pixel2mesh_R50.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: "resnet50" 3 | MESH_ON: True 4 | CHECKPOINT: "" 5 | MESH_HEAD: 6 | NAME: "Pixel2MeshHead" 7 | NUM_STAGES: 3 8 | NUM_GRAPH_CONVS: 3 9 | GRAPH_CONV_DIM: 128 10 | GRAPH_CONV_INIT: "normal" 11 | GT_NUM_SAMPLES: 5000 12 | PRED_NUM_SAMPLES: 5000 13 | CHAMFER_LOSS_WEIGHT: 1.0 14 | NORMALS_LOSS_WEIGHT: 0.0 15 | EDGE_LOSS_WEIGHT: 0.2 16 | ICO_SPHERE_LEVEL: 2 17 | DATASETS: 18 | NAME: "shapenet" 19 | SOLVER: 20 | BATCH_SIZE: 64 # 32 21 | BATCH_SIZE_EVAL: 8 22 | NUM_EPOCHS: 25 23 | BASE_LR: 0.0001 24 | OPTIMIZER: "adam" 25 | LR_SCHEDULER_NAME: "constant" 26 | LOGGING_PERIOD: 50 27 | -------------------------------------------------------------------------------- /configs/shapenet/sphereinit_R50.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: "resnet50" 3 | MESH_ON: True 4 | CHECKPOINT: "" 5 | MESH_HEAD: 6 | NAME: "SphereInitHead" 7 | NUM_STAGES: 3 8 | NUM_GRAPH_CONVS: 3 9 | GRAPH_CONV_DIM: 128 10 | GRAPH_CONV_INIT: "normal" 11 | GT_NUM_SAMPLES: 5000 12 | PRED_NUM_SAMPLES: 5000 13 | CHAMFER_LOSS_WEIGHT: 1.0 14 | NORMALS_LOSS_WEIGHT: 0.0 15 | EDGE_LOSS_WEIGHT: 0.2 16 | ICO_SPHERE_LEVEL: 4 17 | DATASETS: 18 | NAME: "shapenet" 19 | SOLVER: 20 | BATCH_SIZE: 64 # 32 21 | BATCH_SIZE_EVAL: 8 22 | NUM_EPOCHS: 25 23 | BASE_LR: 0.0001 24 | OPTIMIZER: "adam" 25 | LR_SCHEDULER_NAME: "constant" 26 | LOGGING_PERIOD: 50 27 | -------------------------------------------------------------------------------- /configs/shapenet/voxmesh_R50.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: "resnet50" 3 | VOXEL_ON: True 4 | MESH_ON: True 5 | CHECKPOINT: "" 6 | VOXEL_HEAD: 7 | NUM_CONV: 4 8 | CONV_DIM: 256 9 | VOXEL_SIZE: 48 10 | LOSS_WEIGHT: 1.0 11 | CUBIFY_THRESH: 0.2 12 | VOXEL_ONLY_ITERS: 100 13 | MESH_HEAD: 14 | NAME: "VoxMeshHead" 15 | NUM_STAGES: 3 16 | NUM_GRAPH_CONVS: 3 17 | GRAPH_CONV_DIM: 128 18 | GRAPH_CONV_INIT: "normal" 19 | GT_NUM_SAMPLES: 5000 20 | PRED_NUM_SAMPLES: 5000 21 | CHAMFER_LOSS_WEIGHT: 1.0 22 | NORMALS_LOSS_WEIGHT: 0.0 23 | EDGE_LOSS_WEIGHT: 0.2 24 | DATASETS: 25 | NAME: "shapenet" 26 | SOLVER: 27 | BATCH_SIZE: 64 # 32 28 | BATCH_SIZE_EVAL: 8 29 | NUM_EPOCHS: 25 30 | BASE_LR: 0.0001 31 | OPTIMIZER: "adam" 32 | LR_SCHEDULER_NAME: "constant" 33 | LOGGING_PERIOD: 50 34 | -------------------------------------------------------------------------------- /datasets/pix3d/download_pix3d.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 4 | 5 | # Download the Pix3D dataset and split files 6 | 7 | cd "${0%/*}" 8 | 9 | wget http://pix3d.csail.mit.edu/data/pix3d.zip 10 | unzip -qq pix3d.zip 11 | 12 | BASE=https://dl.fbaipublicfiles.com/meshrcnn 13 | 14 | wget $BASE/pix3d/pix3d_s1_train.json 15 | wget $BASE/pix3d/pix3d_s1_test.json 16 | wget $BASE/pix3d/pix3d_s2_train.json 17 | wget $BASE/pix3d/pix3d_s2_test.json 18 | -------------------------------------------------------------------------------- /datasets/shapenet/download_shapenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | 5 | # Download R2N2 and associated splits 6 | # User needs to register and download ShapeNetCore.v1 and ShapeNetCore.v1.binvox 7 | 8 | cd "${0%/*}" 9 | 10 | # download r2n2 renderings 11 | wget http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz 12 | tar -xvzf ShapeNetRendering.tgz 13 | 14 | # download splits 15 | BASE=https://dl.fbaipublicfiles.com/meshrcnn 16 | wget $BASE/shapenet/pix2mesh_splits_val05.json 17 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | Run Mesh R-CNN on an input image 2 | 3 | ``` 4 | python demo/demo.py \ 5 | --config-file configs/pix3d/meshrcnn_R50_FPN.yaml \ 6 | --input /path/to/image \ 7 | --output output_demo \ 8 | --onlyhighest MODEL.WEIGHTS meshrcnn://meshrcnn_R50.pth 9 | ``` 10 | 11 | The `--onlyhighest` flag will return only the highest scoring object prediction. If you remove this flag, all predictions will be returned. 12 | 13 | Here are some notes to clarify the outputs of Mesh R-CNN and guide you in using them. 14 | 15 | ### What does Mesh R-CNN output? 16 | The Mesh R-CNN demo will detect the objects in the image from the Pix3D vocabulary of objects, along with their 2D bounding boxes, 2D instance masks and 3D meshes. For each detected object, Mesh R-CNN returns the 3D shape of the object in the camera coordinate system, confined to a 3D box which respects the aspect ratio of the object detected in the image. If you provide the _focal length_ `f` of the camera and the actual depth location `t_z` of the object's center, i.e. how far the center of the object is from the image plane along the Z axis, then Mesh R-CNN will pixel align the predicted 3D object shape with the image *and* the prediction will correspond to the true metric size of the object - its actual scale in the real world! 17 | 18 | ### Metric scale 19 | While most images nowadays carry their focal length `f` in their metadata, knowing `t_z` is difficult. We could of course supervise for `t_z`, but Pix3D does not contain useful metric object depth. In the Pix3D annotations, the provided tuple `(f, t_z)` corresponds neither to the actual camera metadata nor to the metric depth of the object; it is computed at annotation time by the annotation process and tool, and is thus somewhat ad hoc. This is the reason we don't tackle the problem of estimating `t_z` (this problem is also called the scene layout prediction problem). 20 | 21 | ### I don't care about metric scale. I just want to pixel align via rendering. 22 | However, if you don't care about metric scale and only care about pixel aligning the object to the image, that is possible with our demo!
**The demo runs with a default focal_length `f=20`** (this is the blender focal length assuming 32mm sensor width and is *not* the true focal_length of the image! We make it up!). The demo also places the object at some arbitrary `t_z > 0.0`, again this is not the true metric depth of the object. Given these choices of `(f, t_z)`, the demo will output an object shape placed at `t_z`. The metric size of the predicted object from the demo will not correspond to the true size of the object in the world, but it will be a scaled version of it. Now to pixel align the predicted shape with the image, **all you need to do is render the 3D mesh with `f=20`**. _Note that the value 20 is inconsequential. You would be getting the same pixel alignment if `f` was something else, but it's important that the value of `f` you pick when running the demo is also used when rendering!_ 23 | 24 | Here is an example! When I run the demo on an image from Pix3D (1st image), it recognizes the sofa (2nd image). I get a 3D shape prediction for the sofa which after **I render with blender with focal length `f=20`** I get the final result (3rd image). 25 | 26 | ![input](https://user-images.githubusercontent.com/4369065/77708628-cda99d00-6f85-11ea-949a-5dad891005ee.jpg) 27 | ![segmentation](https://user-images.githubusercontent.com/4369065/77709133-52e18180-6f87-11ea-901a-0706c3d4e3a3.png) 28 | ![rendered_output](https://user-images.githubusercontent.com/4369065/77708647-df8b4000-6f85-11ea-8d5f-4ae62ea3bf07.png) 29 | -------------------------------------------------------------------------------- /demo/demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | import argparse 5 | import logging 6 | import multiprocessing as mp 7 | import os 8 | 9 | import cv2 10 | 11 | # required so that .register() calls are executed in module scope 12 | import meshrcnn.data # noqa 13 | import meshrcnn.modeling # noqa 14 | import meshrcnn.utils # noqa 15 | import numpy as np 16 | import torch 17 | from detectron2.config import get_cfg 18 | from detectron2.data import MetadataCatalog 19 | from detectron2.data.detection_utils import read_image 20 | from detectron2.engine.defaults import DefaultPredictor 21 | from detectron2.utils.logger import setup_logger 22 | from meshrcnn.config import get_meshrcnn_cfg_defaults 23 | from meshrcnn.evaluation import transform_meshes_to_camera_coord_system 24 | from pytorch3d.io import save_obj 25 | from pytorch3d.structures import Meshes 26 | 27 | logger = logging.getLogger("demo") 28 | 29 | 30 | class VisualizationDemo: 31 | def __init__(self, cfg, vis_highest_scoring=True, output_dir="./vis"): 32 | """ 33 | Args: 34 | cfg (CfgNode): 35 | vis_highest_scoring (bool): If set to True visualizes only 36 | the highest scoring prediction 37 | """ 38 | self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0]) 39 | self.colors = self.metadata.thing_colors 40 | self.cat_names = self.metadata.thing_classes 41 | 42 | self.cpu_device = torch.device("cpu") 43 | self.vis_highest_scoring = vis_highest_scoring 44 | self.predictor = DefaultPredictor(cfg) 45 | 46 | os.makedirs(output_dir, exist_ok=True) 47 | self.output_dir = output_dir 48 | 49 | def run_on_image(self, image, focal_length=10.0): 50 | """ 51 | Args: 52 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 53 | This is the format used by OpenCV. 
54 | focal_length (float): the focal_length of the image 55 | 56 | Returns: 57 | predictions (dict): the output of the model. 58 | """ 59 | predictions = self.predictor(image) 60 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 61 | image = image[:, :, ::-1] 62 | 63 | # camera matrix 64 | imsize = [image.shape[0], image.shape[1]] 65 | # focal <- focal * image_width / 32 66 | focal_length = image.shape[1] / 32 * focal_length 67 | K = [focal_length, image.shape[1] / 2, image.shape[0] / 2] 68 | 69 | if "instances" in predictions: 70 | instances = predictions["instances"].to(self.cpu_device) 71 | scores = instances.scores 72 | boxes = instances.pred_boxes 73 | labels = instances.pred_classes 74 | masks = instances.pred_masks 75 | meshes = Meshes( 76 | verts=[mesh[0] for mesh in instances.pred_meshes], 77 | faces=[mesh[1] for mesh in instances.pred_meshes], 78 | ) 79 | pred_dz = instances.pred_dz[:, 0] * ( 80 | boxes.tensor[:, 3] - boxes.tensor[:, 1] 81 | ) 82 | tc = pred_dz.abs().max() + 1.0 83 | zranges = torch.stack( 84 | [ 85 | torch.stack( 86 | [ 87 | tc - tc * pred_dz[i] / 2.0 / focal_length, 88 | tc + tc * pred_dz[i] / 2.0 / focal_length, 89 | ] 90 | ) 91 | for i in range(len(meshes)) 92 | ], 93 | dim=0, 94 | ) 95 | 96 | Ks = torch.tensor(K).to(self.cpu_device).view(1, 3).expand(len(meshes), 3) 97 | meshes = transform_meshes_to_camera_coord_system( 98 | meshes, boxes.tensor, zranges, Ks, imsize 99 | ) 100 | 101 | if self.vis_highest_scoring: 102 | det_ids = [scores.argmax().item()] 103 | else: 104 | det_ids = range(len(scores)) 105 | 106 | for det_id in det_ids: 107 | self.visualize_prediction( 108 | det_id, 109 | image, 110 | boxes.tensor[det_id], 111 | labels[det_id], 112 | scores[det_id], 113 | masks[det_id], 114 | meshes[det_id], 115 | ) 116 | 117 | return predictions 118 | 119 | def visualize_prediction( 120 | self, det_id, image, box, label, score, mask, mesh, alpha=0.6, dpi=200 121 | ): 122 | mask_color = np.array(self.colors[label], dtype=np.float32) 123 | cat_name = self.cat_names[label] 124 | thickness = max([int(np.ceil(0.001 * image.shape[0])), 1]) 125 | box_color = (0, 255, 0) # '#00ff00', green 126 | text_color = (218, 227, 218) # gray 127 | 128 | composite = image.copy().astype(np.float32) 129 | 130 | # overlay mask 131 | idx = mask.nonzero() 132 | composite[idx[:, 0], idx[:, 1], :] *= 1.0 - alpha 133 | composite[idx[:, 0], idx[:, 1], :] += alpha * mask_color 134 | 135 | # overlay box 136 | (x0, y0, x1, y1) = (int(x + 0.5) for x in box) 137 | composite = cv2.rectangle( 138 | composite, (x0, y0), (x1, y1), color=box_color, thickness=thickness 139 | ) 140 | composite = composite.astype(np.uint8) 141 | 142 | # overlay text 143 | font_scale = 0.001 * image.shape[0] 144 | font_thickness = thickness 145 | font = cv2.FONT_HERSHEY_TRIPLEX 146 | text = "%s %.3f" % (cat_name, score) 147 | ((text_w, text_h), _) = cv2.getTextSize(text, font, font_scale, font_thickness) 148 | # Place text background. 
149 | if x0 + text_w > composite.shape[1]: 150 | x0 = composite.shape[1] - text_w 151 | if y0 - int(1.2 * text_h) < 0: 152 | y0 = int(1.2 * text_h) 153 | back_topleft = x0, y0 - int(1.3 * text_h) 154 | back_bottomright = x0 + text_w, y0 155 | cv2.rectangle(composite, back_topleft, back_bottomright, box_color, -1) 156 | # Show text 157 | text_bottomleft = x0, y0 - int(0.2 * text_h) 158 | cv2.putText( 159 | composite, 160 | text, 161 | text_bottomleft, 162 | font, 163 | font_scale, 164 | text_color, 165 | thickness=font_thickness, 166 | lineType=cv2.LINE_AA, 167 | ) 168 | 169 | save_file = os.path.join( 170 | self.output_dir, "%d_mask_%s_%.3f.png" % (det_id, cat_name, score) 171 | ) 172 | cv2.imwrite(save_file, composite[:, :, ::-1]) 173 | 174 | save_file = os.path.join( 175 | self.output_dir, "%d_mesh_%s_%.3f.obj" % (det_id, cat_name, score) 176 | ) 177 | verts, faces = mesh.get_mesh_verts_faces(0) 178 | save_obj(save_file, verts, faces) 179 | 180 | 181 | def setup_cfg(args): 182 | cfg = get_cfg() 183 | get_meshrcnn_cfg_defaults(cfg) 184 | cfg.merge_from_file(args.config_file) 185 | cfg.merge_from_list(args.opts) 186 | cfg.freeze() 187 | return cfg 188 | 189 | 190 | def get_parser(): 191 | parser = argparse.ArgumentParser(description="MeshRCNN Demo") 192 | parser.add_argument( 193 | "--config-file", 194 | default="configs/pix3d/meshrcnn_R50_FPN.yaml", 195 | metavar="FILE", 196 | help="path to config file", 197 | ) 198 | parser.add_argument("--input", help="A path to an input image") 199 | parser.add_argument("--output", help="A directory to save output visualizations") 200 | parser.add_argument( 201 | "--focal-length", type=float, default=20.0, help="Focal length for the image" 202 | ) 203 | parser.add_argument( 204 | "--onlyhighest", 205 | action="store_true", 206 | help="will return only the highest scoring detection", 207 | ) 208 | 209 | parser.add_argument( 210 | "opts", 211 | help="Modify model config options using the command-line", 212 | default=None, 213 | nargs=argparse.REMAINDER, 214 | ) 215 | return parser 216 | 217 | 218 | def main() -> None: 219 | mp.set_start_method("spawn", force=True) 220 | args = get_parser().parse_args() 221 | logger = setup_logger(name="demo") 222 | logger.info("Arguments: " + str(args)) 223 | 224 | cfg = setup_cfg(args) 225 | 226 | im_name = args.input.split("/")[-1].split(".")[0] 227 | 228 | demo = VisualizationDemo( 229 | cfg, 230 | vis_highest_scoring=args.onlyhighest, 231 | output_dir=os.path.join(args.output, im_name), 232 | ) 233 | 234 | # use PIL, to be consistent with evaluation 235 | img = read_image(args.input, format="BGR") 236 | predictions = demo.run_on_image(img, focal_length=args.focal_length) 237 | logger.info("Predictions saved in %s" % (os.path.join(args.output, im_name))) 238 | 239 | 240 | if __name__ == "__main__": 241 | main() # pragma: no cover 242 | -------------------------------------------------------------------------------- /infra/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | # Run this script "./infra/linter.sh" before you commit. 4 | echo "Running isort..." 5 | isort -y -sp . 6 | 7 | echo "Running black..." 8 | black -l 100 . 9 | 10 | echo "Running flake..." 11 | flake8 . 12 | 13 | command -v arc > /dev/null && { 14 | echo "Running arc lint ..." 
15 | arc lint 16 | } 17 | -------------------------------------------------------------------------------- /meshrcnn/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .config import get_meshrcnn_cfg_defaults 3 | -------------------------------------------------------------------------------- /meshrcnn/config/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def get_meshrcnn_cfg_defaults(cfg): 8 | """ 9 | Customize the detectron2 cfg to include some new keys and default values 10 | for Mesh R-CNN 11 | """ 12 | 13 | cfg.MODEL.VOXEL_ON = False 14 | cfg.MODEL.MESH_ON = False 15 | cfg.MODEL.ZPRED_ON = False 16 | cfg.MODEL.VIS_MINIBATCH = False # visualize minibatches 17 | 18 | # aspect ratio grouping has no difference in performance 19 | # but might reduce memory by a little bit 20 | cfg.DATALOADER.ASPECT_RATIO_GROUPING = False 21 | 22 | # ------------------------------------------------------------------------ # 23 | # Z Predict Head 24 | # ------------------------------------------------------------------------ # 25 | cfg.MODEL.ROI_Z_HEAD = CN() 26 | cfg.MODEL.ROI_Z_HEAD.NAME = "FastRCNNFCHead" 27 | cfg.MODEL.ROI_Z_HEAD.NUM_FC = 2 28 | cfg.MODEL.ROI_Z_HEAD.FC_DIM = 1024 29 | cfg.MODEL.ROI_Z_HEAD.POOLER_RESOLUTION = 7 30 | cfg.MODEL.ROI_Z_HEAD.POOLER_SAMPLING_RATIO = 2 31 | # Type of pooling operation applied to the incoming feature map for each RoI 32 | cfg.MODEL.ROI_Z_HEAD.POOLER_TYPE = "ROIAlign" 33 | # Whether to use class agnostic for z regression 34 | cfg.MODEL.ROI_Z_HEAD.CLS_AGNOSTIC_Z_REG = False 35 | # Default weight on (dz) for normalizing z regression targets 36 | cfg.MODEL.ROI_Z_HEAD.Z_REG_WEIGHT = 5.0 37 | # The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. 38 | cfg.MODEL.ROI_Z_HEAD.SMOOTH_L1_BETA = 0.0 39 | 40 | # ------------------------------------------------------------------------ # 41 | # Voxel Head 42 | # ------------------------------------------------------------------------ # 43 | cfg.MODEL.ROI_VOXEL_HEAD = CN() 44 | cfg.MODEL.ROI_VOXEL_HEAD.NAME = "VoxelRCNNConvUpsampleHead" 45 | cfg.MODEL.ROI_VOXEL_HEAD.POOLER_RESOLUTION = 14 46 | cfg.MODEL.ROI_VOXEL_HEAD.POOLER_SAMPLING_RATIO = 0 47 | # Type of pooling operation applied to the incoming feature map for each RoI 48 | cfg.MODEL.ROI_VOXEL_HEAD.POOLER_TYPE = "ROIAlign" 49 | # Whether to use class agnostic for voxel prediction 50 | cfg.MODEL.ROI_VOXEL_HEAD.CLS_AGNOSTIC_VOXEL = False 51 | # The number of convs in the voxel head and the number of channels 52 | cfg.MODEL.ROI_VOXEL_HEAD.NUM_CONV = 0 53 | cfg.MODEL.ROI_VOXEL_HEAD.CONV_DIM = 256 54 | # Normalization method for the convolution layers. 
Options: "" (no norm), "GN" 55 | cfg.MODEL.ROI_VOXEL_HEAD.NORM = "" 56 | # The number of depth channels for the predicted voxels 57 | cfg.MODEL.ROI_VOXEL_HEAD.NUM_DEPTH = 28 58 | cfg.MODEL.ROI_VOXEL_HEAD.LOSS_WEIGHT = 1.0 59 | cfg.MODEL.ROI_VOXEL_HEAD.CUBIFY_THRESH = 0.0 60 | 61 | # ------------------------------------------------------------------------ # 62 | # Mesh Head 63 | # ------------------------------------------------------------------------ # 64 | cfg.MODEL.ROI_MESH_HEAD = CN() 65 | cfg.MODEL.ROI_MESH_HEAD.NAME = "MeshRCNNGraphConvHead" 66 | cfg.MODEL.ROI_MESH_HEAD.POOLER_RESOLUTION = 14 67 | cfg.MODEL.ROI_MESH_HEAD.POOLER_SAMPLING_RATIO = 0 68 | # Type of pooling operation applied to the incoming feature map for each RoI 69 | cfg.MODEL.ROI_MESH_HEAD.POOLER_TYPE = "ROIAlign" 70 | # Numer of stages 71 | cfg.MODEL.ROI_MESH_HEAD.NUM_STAGES = 1 72 | cfg.MODEL.ROI_MESH_HEAD.NUM_GRAPH_CONVS = 1 # per stage 73 | cfg.MODEL.ROI_MESH_HEAD.GRAPH_CONV_DIM = 256 74 | cfg.MODEL.ROI_MESH_HEAD.GRAPH_CONV_INIT = "normal" 75 | # Mesh sampling 76 | cfg.MODEL.ROI_MESH_HEAD.GT_NUM_SAMPLES = 5000 77 | cfg.MODEL.ROI_MESH_HEAD.PRED_NUM_SAMPLES = 5000 78 | # loss weights 79 | cfg.MODEL.ROI_MESH_HEAD.CHAMFER_LOSS_WEIGHT = 1.0 80 | cfg.MODEL.ROI_MESH_HEAD.NORMALS_LOSS_WEIGHT = 1.0 81 | cfg.MODEL.ROI_MESH_HEAD.EDGE_LOSS_WEIGHT = 1.0 82 | # coord thresh 83 | cfg.MODEL.ROI_MESH_HEAD.GT_COORD_THRESH = 0.0 84 | # Init ico_sphere level (only for when voxel_on is false) 85 | cfg.MODEL.ROI_MESH_HEAD.ICO_SPHERE_LEVEL = -1 86 | 87 | return cfg 88 | -------------------------------------------------------------------------------- /meshrcnn/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # ensure the builtin datasets are registered 3 | from . import datasets # isort:skip 4 | 5 | from .meshrcnn_transforms import MeshRCNNMapper 6 | -------------------------------------------------------------------------------- /meshrcnn/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .pix3d import load_pix3d_json # isort:skip 3 | from . import builtin # ensure the builtin datasets are registered 4 | -------------------------------------------------------------------------------- /meshrcnn/data/datasets/builtin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file registers pre-defined datasets at hard-coded paths, and their metadata. 4 | 5 | We hard-code metadata for common datasets. This will enable: 6 | 1. Consistency check when loading the datasets 7 | 2. Use models on these standard datasets directly and run demos, 8 | without having to download the dataset annotations 9 | 10 | We hard-code some paths to the dataset that's assumed to 11 | exist in "./datasets/". 
12 | """ 13 | 14 | import os 15 | 16 | from detectron2.data import DatasetCatalog, MetadataCatalog 17 | 18 | from meshrcnn.data.datasets import load_pix3d_json 19 | 20 | 21 | def get_pix3d_metadata(): 22 | meta = [ 23 | {"name": "bed", "color": [255, 255, 25], "id": 1}, # noqa 24 | {"name": "bookcase", "color": [230, 25, 75], "id": 2}, # noqa 25 | {"name": "chair", "color": [250, 190, 190], "id": 3}, # noqa 26 | {"name": "desk", "color": [60, 180, 75], "id": 4}, # noqa 27 | {"name": "misc", "color": [230, 190, 255], "id": 5}, # noqa 28 | {"name": "sofa", "color": [0, 130, 200], "id": 6}, # noqa 29 | {"name": "table", "color": [245, 130, 48], "id": 7}, # noqa 30 | {"name": "tool", "color": [70, 240, 240], "id": 8}, # noqa 31 | {"name": "wardrobe", "color": [210, 245, 60], "id": 9}, # noqa 32 | ] 33 | return meta 34 | 35 | 36 | SPLITS = { 37 | "pix3d_s1_train": ("pix3d", "pix3d/pix3d_s1_train.json"), 38 | "pix3d_s1_test": ("pix3d", "pix3d/pix3d_s1_test.json"), 39 | "pix3d_s2_train": ("pix3d", "pix3d/pix3d_s2_train.json"), 40 | "pix3d_s2_test": ("pix3d", "pix3d/pix3d_s2_test.json"), 41 | } 42 | 43 | 44 | def register_pix3d(dataset_name, json_file, image_root, root="datasets"): 45 | DatasetCatalog.register( 46 | dataset_name, lambda: load_pix3d_json(json_file, image_root, dataset_name) 47 | ) 48 | things_ids = [k["id"] for k in get_pix3d_metadata()] 49 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(things_ids)} 50 | thing_classes = [k["name"] for k in get_pix3d_metadata()] 51 | thing_colors = [k["color"] for k in get_pix3d_metadata()] 52 | json_file = os.path.join(root, json_file) 53 | image_root = os.path.join(root, image_root) 54 | metadata = { 55 | "thing_classes": thing_classes, 56 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 57 | "thing_colors": thing_colors, 58 | } 59 | MetadataCatalog.get(dataset_name).set( 60 | json_file=json_file, image_root=image_root, evaluator_type="pix3d", **metadata 61 | ) 62 | 63 | 64 | for key, (data_root, anno_file) in SPLITS.items(): 65 | register_pix3d(key, anno_file, data_root) 66 | -------------------------------------------------------------------------------- /meshrcnn/data/datasets/pix3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import contextlib 3 | import io 4 | import logging 5 | import os 6 | 7 | from detectron2.data import MetadataCatalog 8 | from detectron2.structures import BoxMode 9 | from detectron2.utils.file_io import PathManager 10 | 11 | """ 12 | This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format". 13 | """ 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | __all__ = ["load_pix3d_json"] 18 | 19 | 20 | def load_pix3d_json(json_file, image_root, dataset_name=None): 21 | """ 22 | Load a json file with Pix3D's instances annotation format. 23 | 24 | Args: 25 | json_file (str): full path to the json file in COCO instances annotation format. 26 | image_root (str): the directory where the images in this json file exists. 27 | dataset_name (str): the name of the dataset (e.g., coco_2017_train). 28 | If provided, this function will also put "thing_classes" into 29 | the metadata associated with this dataset. 30 | 31 | Returns: 32 | list[dict]: a list of dicts in "Detectron2 Dataset" format. (See DATASETS.md) 33 | 34 | Notes: 35 | 1. This function does not read the image files. 36 | The results do not have the "image" field. 
37 | """ 38 | from pycocotools.coco import COCO 39 | 40 | json_file = PathManager.get_local_path(json_file) 41 | with contextlib.redirect_stdout(io.StringIO()): 42 | coco_api = COCO(json_file) 43 | 44 | id_map = None 45 | assert dataset_name is not None 46 | meta = MetadataCatalog.get(dataset_name) 47 | cat_ids = sorted(coco_api.getCatIds()) 48 | cats = coco_api.loadCats(cat_ids) 49 | # The categories in a custom json file may not be sorted. 50 | thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])] 51 | meta.thing_classes = thing_classes 52 | 53 | # In COCO, certain category ids are artificially removed, 54 | # and by convention they are always ignored. 55 | # We deal with COCO's id issue and translate 56 | # the category ids to contiguous ids in [0, 80). 57 | 58 | # It works by looking at the "categories" field in the json, therefore 59 | # if users' own json also have incontiguous ids, we'll 60 | # apply this mapping as well but print a warning. 61 | if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)): 62 | logger.warning( 63 | """ 64 | Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you. 65 | """ 66 | ) 67 | id_map = {v: i for i, v in enumerate(cat_ids)} 68 | meta.thing_dataset_id_to_contiguous_id = id_map 69 | 70 | # sort indices for reproducible results 71 | img_ids = sorted(list(coco_api.imgs.keys())) 72 | 73 | # imgs is a list of dicts, each looks something like: 74 | # {'license': 4, 75 | # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', 76 | # 'file_name': 'COCO_val2014_000000001268.jpg', 77 | # 'height': 427, 78 | # 'width': 640, 79 | # 'date_captured': '2013-11-17 05:57:24', 80 | # 'id': 1268} 81 | imgs = coco_api.loadImgs(img_ids) 82 | # anns is a list[list[dict]], where each dict is an annotation 83 | # record for an object. The inner list enumerates the objects in an image 84 | # and the outer list enumerates over images. Example of anns[0]: 85 | # [{'segmentation': [[192.81, 86 | # 247.09, 87 | # ... 88 | # 219.03, 89 | # 249.06]], 90 | # 'area': 1035.749, 91 | # 'iscrowd': 0, 92 | # 'image_id': 1268, 93 | # 'bbox': [192.81, 224.8, 74.73, 33.43], 94 | # 'category_id': 16, 95 | # 'id': 42986}, 96 | # ...] 97 | anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] 98 | imgs_anns = list(zip(imgs, anns)) 99 | 100 | logger.info( 101 | "Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file) 102 | ) 103 | 104 | dataset_dicts = [] 105 | 106 | for img_dict, anno_dict_list in imgs_anns: 107 | # examples with imgfiles = {img/table/1749.jpg, img/table/0045.png} 108 | # have a mismatch between images and masks. Thus, ignore 109 | if img_dict["file_name"] in ["img/table/1749.jpg", "img/table/0045.png"]: 110 | continue 111 | 112 | record = {} 113 | record["file_name"] = os.path.join(image_root, img_dict["file_name"]) 114 | record["height"] = img_dict["height"] 115 | record["width"] = img_dict["width"] 116 | image_id = record["image_id"] = img_dict["id"] 117 | 118 | objs = [] 119 | for anno in anno_dict_list: 120 | # Check that the image_id in this annotation is the same as 121 | # the image_id we're looking at. 122 | # This fails only when the data parsing logic or the annotation file is buggy. 
123 | 124 | assert anno["image_id"] == image_id 125 | assert anno.get("ignore", 0) == 0 126 | 127 | obj = { 128 | field: anno[field] 129 | for field in ["iscrowd", "bbox", "category_id"] 130 | if field in anno 131 | } 132 | 133 | segm = anno.get("segmentation", None) 134 | if segm: # string 135 | obj["segmentation"] = os.path.join(image_root, segm) 136 | 137 | voxel = anno.get("voxel", None) 138 | if voxel: 139 | obj["voxel"] = os.path.join(image_root, voxel) 140 | 141 | mesh = anno.get("model", None) 142 | if mesh: 143 | obj["mesh"] = mesh 144 | 145 | # camera 146 | obj["K"] = anno["K"] 147 | obj["R"] = anno["rot_mat"] 148 | obj["t"] = anno["trans_mat"] 149 | 150 | obj["bbox_mode"] = BoxMode.XYWH_ABS 151 | if id_map: 152 | obj["category_id"] = id_map[obj["category_id"]] 153 | objs.append(obj) 154 | record["annotations"] = objs 155 | dataset_dicts.append(record) 156 | 157 | return dataset_dicts 158 | 159 | 160 | def main() -> None: 161 | global logger 162 | """ 163 | Test the Pix3D json dataset loader. 164 | 165 | Usage: 166 | python -m meshrcnn.data.datasets.pix3d \ 167 | path/to/json path/to/image_root dataset_name 168 | 169 | "dataset_name" can be "coco", "coco_person", or other 170 | pre-registered ones 171 | """ 172 | import sys 173 | 174 | import cv2 175 | import detectron2.data.datasets # noqa # add pre-defined metadata 176 | from detectron2.utils.logger import setup_logger 177 | from meshrcnn.utils.vis import draw_pix3d_dict 178 | 179 | logger = setup_logger(name=__name__) 180 | meta = MetadataCatalog.get(sys.argv[3]) 181 | 182 | dicts = load_pix3d_json(sys.argv[1], sys.argv[2], sys.argv[3]) 183 | logger.info("Done loading {} samples.".format(len(dicts))) 184 | 185 | dirname = "pix3d-data-vis" 186 | os.makedirs(dirname, exist_ok=True) 187 | for d in dicts: 188 | vis = draw_pix3d_dict(d, meta.thing_classes + ["0"]) 189 | fpath = os.path.join(dirname, os.path.basename(d["file_name"])) 190 | cv2.imwrite(fpath, vis) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() # pragma: no cover 195 | -------------------------------------------------------------------------------- /meshrcnn/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .pix3d_evaluation import Pix3DEvaluator, transform_meshes_to_camera_coord_system 3 | -------------------------------------------------------------------------------- /meshrcnn/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .roi_heads import * # noqa 3 | -------------------------------------------------------------------------------- /meshrcnn/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .roi_heads import MeshRCNNROIHeads # noqa 3 | -------------------------------------------------------------------------------- /meshrcnn/modeling/roi_heads/mask_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | from detectron2.layers import cat 4 | from detectron2.utils.events import get_event_storage 5 | 6 | from meshrcnn.structures.mask import batch_crop_masks_within_box 7 | from torch.nn import functional as F 8 | 9 | 10 | def mask_rcnn_loss(pred_mask_logits, instances): 11 | """ 12 | Compute the mask prediction loss defined in the Mask R-CNN paper. 13 | 14 | Args: 15 | pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask) 16 | for class-specific or class-agnostic, where B is the total number of predicted masks 17 | in all images, C is the number of foreground classes, and Hmask, Wmask are the height 18 | and width of the mask predictions. The values are logits. 19 | instances (list[Instances]): A list of N Instances, where N is the number of images 20 | in the batch. These instances are in 1:1 21 | correspondence with the pred_mask_logits. The ground-truth labels (class, box, mask, 22 | ...) associated with each instance are stored in fields. 23 | 24 | Returns: 25 | mask_loss (Tensor): A scalar tensor containing the loss. 26 | gt_masks (Tensor): The cropped ground-truth masks, returned for visualization. 27 | """ 28 | cls_agnostic_mask = pred_mask_logits.size(1) == 1 29 | total_num_masks = pred_mask_logits.size(0) 30 | mask_side_len = pred_mask_logits.size(2) 31 | assert pred_mask_logits.size(2) == pred_mask_logits.size( 32 | 3 33 | ), "Mask prediction must be square!" 34 | 35 | gt_classes = [] 36 | gt_masks = [] 37 | for instances_per_image in instances: 38 | if len(instances_per_image) == 0: 39 | continue 40 | if not cls_agnostic_mask: 41 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 42 | gt_classes.append(gt_classes_per_image) 43 | 44 | gt_masks_per_image = batch_crop_masks_within_box( 45 | instances_per_image.gt_masks, 46 | instances_per_image.proposal_boxes.tensor, 47 | mask_side_len, 48 | ).to(device=pred_mask_logits.device) 49 | gt_masks.append(gt_masks_per_image) 50 | 51 | if len(gt_masks) == 0: 52 | return pred_mask_logits.sum() * 0, gt_masks 53 | 54 | gt_masks = cat(gt_masks, dim=0) 55 | assert gt_masks.numel() > 0, gt_masks.shape 56 | 57 | if cls_agnostic_mask: 58 | pred_mask_logits = pred_mask_logits[:, 0] 59 | else: 60 | indices = torch.arange(total_num_masks) 61 | gt_classes = cat(gt_classes, dim=0) 62 | pred_mask_logits = pred_mask_logits[indices, gt_classes] 63 | 64 | # Log the training accuracy (using gt classes and 0.5 threshold) 65 | # Note that here we allow gt_masks to be float as well 66 | # (depending on the implementation of rasterize()) 67 | mask_accurate = (pred_mask_logits > 0.5) == (gt_masks > 0.5) 68 | mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel() 69 | get_event_storage().put_scalar("mask_rcnn/accuracy", mask_accuracy) 70 | 71 | mask_loss = F.binary_cross_entropy_with_logits( 72 | pred_mask_logits, gt_masks.to(dtype=torch.float32), reduction="mean" 73 | ) 74 | return mask_loss, gt_masks 75 | -------------------------------------------------------------------------------- /meshrcnn/modeling/roi_heads/mesh_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | from collections import OrderedDict 3 | 4 | import fvcore.nn.weight_init as weight_init 5 | import torch 6 | from detectron2.layers import cat, ShapeSpec 7 | from detectron2.utils.registry import Registry 8 | 9 | from meshrcnn.structures.mesh import batch_crop_meshes_within_box, MeshInstances 10 | from pytorch3d.loss import chamfer_distance, mesh_edge_loss 11 | from pytorch3d.ops import ( 12 | GraphConv, 13 | sample_points_from_meshes, 14 | SubdivideMeshes, 15 | vert_align, 16 | ) 17 | from pytorch3d.structures import Meshes 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | ROI_MESH_HEAD_REGISTRY = Registry("ROI_MESH_HEAD") 22 | 23 | 24 | def mesh_rcnn_loss( 25 | pred_meshes, 26 | instances, 27 | loss_weights=None, 28 | gt_num_samples=5000, 29 | pred_num_samples=5000, 30 | gt_coord_thresh=None, 31 | ): 32 | """ 33 | Compute the mesh prediction loss defined in the Mesh R-CNN paper. 34 | 35 | Args: 36 | pred_meshes (list of Meshes): A list of K Meshes. Each entry contains B meshes, 37 | where B is the total number of predicted meshes in all images. 38 | K is the number of refinement stages. 39 | instances (list[Instances]): A list of N Instances, where N is the number of images 40 | in the batch. These instances are in 1:1 correspondence with the pred_meshes. 41 | The ground-truth labels (class, box, mask, ...) associated with each instance 42 | are stored in fields. 43 | loss_weights (dict): Contains the weights for the different losses, e.g. 44 | loss_weights = {'chamfer': 1.0, 'normals': 0.0, 'edge': 0.2} 45 | gt_num_samples (int): The number of points to sample from gt meshes 46 | pred_num_samples (int): The number of points to sample from predicted meshes 47 | gt_coord_thresh (float): A threshold value over which the batch is ignored 48 | Returns: 49 | loss_chamfer, loss_normals, loss_edge (Tensors): Scalar tensors with the weighted losses, plus the cropped gt_meshes (returned for visualization).
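    Example (illustrative only; toy icosphere meshes, sample counts, and the same weights as the loss_weights example above -- this mirrors what is done per refinement stage below, it is not a call to this function):

        from pytorch3d.loss import chamfer_distance, mesh_edge_loss
        from pytorch3d.ops import sample_points_from_meshes
        from pytorch3d.utils import ico_sphere

        pred, gt = ico_sphere(1), ico_sphere(2)
        pred_pts, pred_nrm = sample_points_from_meshes(pred, 1000, return_normals=True)
        gt_pts, gt_nrm = sample_points_from_meshes(gt, 1000, return_normals=True)
        l_chamfer, l_normals = chamfer_distance(
            pred_pts, gt_pts, x_normals=pred_nrm, y_normals=gt_nrm
        )
        l_edge = mesh_edge_loss(pred)
        # mesh_rcnn_loss applies this weighting for every refinement stage and sums over stages
        stage_loss = 1.0 * l_chamfer + 0.0 * l_normals + 0.2 * l_edge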
50 | """ 51 | if not isinstance(pred_meshes, list): 52 | raise ValueError("Expecting a list of Meshes") 53 | 54 | gt_verts, gt_faces = [], [] 55 | for instances_per_image in instances: 56 | if len(instances_per_image) == 0: 57 | continue 58 | 59 | gt_K = instances_per_image.gt_K 60 | gt_mesh_per_image = batch_crop_meshes_within_box( 61 | instances_per_image.gt_meshes, 62 | instances_per_image.proposal_boxes.tensor, 63 | gt_K, 64 | ).to(device=pred_meshes[0].device) 65 | gt_verts.extend(gt_mesh_per_image.verts_list()) 66 | gt_faces.extend(gt_mesh_per_image.faces_list()) 67 | 68 | if len(gt_verts) == 0: 69 | return None, None 70 | 71 | gt_meshes = Meshes(verts=gt_verts, faces=gt_faces) 72 | gt_valid = gt_meshes.valid 73 | gt_sampled_verts, gt_sampled_normals = sample_points_from_meshes( 74 | gt_meshes, num_samples=gt_num_samples, return_normals=True 75 | ) 76 | 77 | all_loss_chamfer = [] 78 | all_loss_normals = [] 79 | all_loss_edge = [] 80 | for pred_mesh in pred_meshes: 81 | pred_sampled_verts, pred_sampled_normals = sample_points_from_meshes( 82 | pred_mesh, num_samples=pred_num_samples, return_normals=True 83 | ) 84 | wts = (pred_mesh.valid * gt_valid).to(dtype=torch.float32) 85 | # chamfer loss 86 | loss_chamfer, loss_normals = chamfer_distance( 87 | pred_sampled_verts, 88 | gt_sampled_verts, 89 | x_normals=pred_sampled_normals, 90 | y_normals=gt_sampled_normals, 91 | weights=wts, 92 | ) 93 | 94 | # chamfer loss 95 | loss_chamfer = loss_chamfer * loss_weights["chamfer"] 96 | all_loss_chamfer.append(loss_chamfer) 97 | # normal loss 98 | loss_normals = loss_normals * loss_weights["normals"] 99 | all_loss_normals.append(loss_normals) 100 | # mesh edge regularization 101 | loss_edge = mesh_edge_loss(pred_mesh) 102 | loss_edge = loss_edge * loss_weights["edge"] 103 | all_loss_edge.append(loss_edge) 104 | 105 | loss_chamfer = sum(all_loss_chamfer) 106 | loss_normals = sum(all_loss_normals) 107 | loss_edge = sum(all_loss_edge) 108 | 109 | # if the rois are bad, the target verts can be arbitrarily large 110 | # causing exploding gradients. If this is the case, ignore the batch 111 | if gt_coord_thresh and gt_sampled_verts.abs().max() > gt_coord_thresh: 112 | loss_chamfer = loss_chamfer * 0.0 113 | loss_normals = loss_normals * 0.0 114 | loss_edge = loss_edge * 0.0 115 | 116 | return loss_chamfer, loss_normals, loss_edge, gt_meshes 117 | 118 | 119 | def mesh_rcnn_inference(pred_meshes, pred_instances): 120 | """ 121 | Return the predicted mesh for each predicted instance 122 | 123 | Args: 124 | pred_meshes (Meshes): A class of Meshes containing B meshes, where B is 125 | the total number of predictions in all images. 126 | pred_instances (list[Instances]): A list of N Instances, where N is the number of images 127 | in the batch. Each Instances must have field "pred_classes". 128 | 129 | Returns: 130 | None. 
pred_instances will contain an extra "pred_meshes" field storing the meshes 131 | """ 132 | num_boxes_per_image = [len(i) for i in pred_instances] 133 | pred_meshes = pred_meshes.split(num_boxes_per_image) 134 | 135 | for pred_mesh, instances in zip(pred_meshes, pred_instances): 136 | # NOTE do not save the Meshes object; pickle dumps become inefficient 137 | if pred_mesh.isempty(): 138 | continue 139 | verts_list = pred_mesh.verts_list() 140 | faces_list = pred_mesh.faces_list() 141 | instances.pred_meshes = MeshInstances( 142 | [(v, f) for (v, f) in zip(verts_list, faces_list)] 143 | ) 144 | 145 | 146 | class MeshRefinementStage(nn.Module): 147 | def __init__( 148 | self, img_feat_dim, vert_feat_dim, hidden_dim, stage_depth, gconv_init="normal" 149 | ): 150 | """ 151 | Args: 152 | img_feat_dim: Dimension of features we will get from vert_align 153 | vert_feat_dim: Dimension of vert_feats we will receive from the 154 | previous stage; can be 0 155 | hidden_dim: Output dimension for graph-conv layers 156 | stage_depth: Number of graph-conv layers to use 157 | gconv_init: How to initialize graph-conv layers 158 | """ 159 | super(MeshRefinementStage, self).__init__() 160 | 161 | # fc layer to reduce feature dimension 162 | self.bottleneck = nn.Linear(img_feat_dim, hidden_dim) 163 | 164 | # deform layer 165 | self.verts_offset = nn.Linear(hidden_dim + 3, 3) 166 | 167 | # graph convs 168 | self.gconvs = nn.ModuleList() 169 | for i in range(stage_depth): 170 | if i == 0: 171 | input_dim = hidden_dim + vert_feat_dim + 3 172 | else: 173 | input_dim = hidden_dim + 3 174 | gconv = GraphConv(input_dim, hidden_dim, init=gconv_init, directed=False) 175 | self.gconvs.append(gconv) 176 | 177 | # initialization 178 | nn.init.normal_(self.bottleneck.weight, mean=0.0, std=0.01) 179 | nn.init.constant_(self.bottleneck.bias, 0) 180 | 181 | nn.init.zeros_(self.verts_offset.weight) 182 | nn.init.constant_(self.verts_offset.bias, 0) 183 | 184 | def forward(self, x, mesh, vert_feats=None): 185 | img_feats = vert_align(x, mesh, return_packed=True, padding_mode="border") 186 | # 256 -> hidden_dim 187 | img_feats = F.relu(self.bottleneck(img_feats)) 188 | if vert_feats is None: 189 | # hidden_dim + 3 190 | vert_feats = torch.cat((img_feats, mesh.verts_packed()), dim=1) 191 | else: 192 | # hidden_dim * 2 + 3 193 | vert_feats = torch.cat((vert_feats, img_feats, mesh.verts_packed()), dim=1) 194 | for graph_conv in self.gconvs: 195 | vert_feats_nopos = F.relu(graph_conv(vert_feats, mesh.edges_packed())) 196 | vert_feats = torch.cat((vert_feats_nopos, mesh.verts_packed()), dim=1) 197 | 198 | # refine 199 | deform = torch.tanh(self.verts_offset(vert_feats)) 200 | mesh = mesh.offset_verts(deform) 201 | return mesh, vert_feats_nopos 202 | 203 | 204 | @ROI_MESH_HEAD_REGISTRY.register() 205 | class MeshRCNNGraphConvHead(nn.Module): 206 | """ 207 | A mesh head with vert align, graph conv layers and refine layers. 
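    Example (illustrative shapes only; a single MeshRefinementStage from above applied to a toy sphere and a random 14x14 feature map -- all dimensions here are made up, not taken from any config):

        import torch
        from pytorch3d.utils import ico_sphere

        stage = MeshRefinementStage(
            img_feat_dim=256, vert_feat_dim=0, hidden_dim=128, stage_depth=3
        )
        feats = torch.randn(1, 256, 14, 14)   # stands in for RoI-aligned image features
        mesh = ico_sphere(2)                  # a batch with one initial mesh
        refined, vert_feats = stage(feats, mesh)
        # This head simply chains NUM_STAGES such stages, feeding the refined mesh
        # and vert_feats from one stage into the next.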
208 | """ 209 | 210 | def __init__(self, cfg, input_shape: ShapeSpec): 211 | super(MeshRCNNGraphConvHead, self).__init__() 212 | 213 | # fmt: off 214 | num_stages = cfg.MODEL.ROI_MESH_HEAD.NUM_STAGES 215 | num_graph_convs = cfg.MODEL.ROI_MESH_HEAD.NUM_GRAPH_CONVS # per stage 216 | graph_conv_dim = cfg.MODEL.ROI_MESH_HEAD.GRAPH_CONV_DIM 217 | graph_conv_init = cfg.MODEL.ROI_MESH_HEAD.GRAPH_CONV_INIT 218 | input_channels = input_shape.channels 219 | # fmt: on 220 | 221 | self.stages = nn.ModuleList() 222 | for i in range(num_stages): 223 | vert_feat_dim = 0 if i == 0 else graph_conv_dim 224 | stage = MeshRefinementStage( 225 | input_channels, 226 | vert_feat_dim, 227 | graph_conv_dim, 228 | num_graph_convs, 229 | gconv_init=graph_conv_init, 230 | ) 231 | self.stages.append(stage) 232 | 233 | def forward(self, x, mesh): 234 | if x.numel() == 0 or mesh.isempty(): 235 | return [Meshes(verts=[], faces=[])] 236 | 237 | meshes = [] 238 | vert_feats = None 239 | for stage in self.stages: 240 | mesh, vert_feats = stage(x, mesh, vert_feats=vert_feats) 241 | meshes.append(mesh) 242 | return meshes 243 | 244 | 245 | @ROI_MESH_HEAD_REGISTRY.register() 246 | class MeshRCNNGraphConvSubdHead(nn.Module): 247 | """ 248 | A mesh head with vert align, graph conv layers and refine and subdivide layers. 249 | """ 250 | 251 | def __init__(self, cfg, input_shape: ShapeSpec): 252 | super(MeshRCNNGraphConvSubdHead, self).__init__() 253 | 254 | # fmt: off 255 | self.num_stages = cfg.MODEL.ROI_MESH_HEAD.NUM_STAGES 256 | num_graph_convs = cfg.MODEL.ROI_MESH_HEAD.NUM_GRAPH_CONVS # per stage 257 | graph_conv_dim = cfg.MODEL.ROI_MESH_HEAD.GRAPH_CONV_DIM 258 | graph_conv_init = cfg.MODEL.ROI_MESH_HEAD.GRAPH_CONV_INIT 259 | input_channels = input_shape.channels 260 | # fmt: on 261 | 262 | self.stages = nn.ModuleList() 263 | for i in range(self.num_stages): 264 | vert_feat_dim = 0 if i == 0 else graph_conv_dim 265 | stage = MeshRefinementStage( 266 | input_channels, 267 | vert_feat_dim, 268 | graph_conv_dim, 269 | num_graph_convs, 270 | gconv_init=graph_conv_init, 271 | ) 272 | self.stages.append(stage) 273 | 274 | def forward(self, x, mesh): 275 | if x.numel() == 0 or mesh.isempty(): 276 | return [Meshes(verts=[], faces=[])] 277 | 278 | meshes = [] 279 | vert_feats = None 280 | for i, stage in enumerate(self.stages): 281 | mesh, vert_feats = stage(x, mesh, vert_feats=vert_feats) 282 | meshes.append(mesh) 283 | if i < self.num_stages - 1: 284 | subdivide = SubdivideMeshes() 285 | mesh, vert_feats = subdivide(mesh, feats=vert_feats) 286 | return meshes 287 | 288 | 289 | def build_mesh_head(cfg, input_shape): 290 | name = cfg.MODEL.ROI_MESH_HEAD.NAME 291 | return ROI_MESH_HEAD_REGISTRY.get(name)(cfg, input_shape) 292 | -------------------------------------------------------------------------------- /meshrcnn/modeling/roi_heads/voxel_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from detectron2.layers import cat, Conv2d, ConvTranspose2d, get_norm 5 | from detectron2.utils.events import get_event_storage 6 | from detectron2.utils.registry import Registry 7 | 8 | from meshrcnn.structures.voxel import batch_crop_voxels_within_box 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | ROI_VOXEL_HEAD_REGISTRY = Registry("ROI_VOXEL_HEAD") 13 | 14 | 15 | def voxel_rcnn_loss(pred_voxel_logits, instances, loss_weight=1.0): 16 | """ 17 | Compute the voxel prediction loss defined in the Mesh R-CNN paper. 18 | 19 | Args: 20 | pred_voxel_logits (Tensor): A tensor of shape (B, C, D, H, W) or (B, 1, D, H, W) 21 | for class-specific or class-agnostic, where B is the total number of predicted voxels 22 | in all images, C is the number of foreground classes, and D, H, W are the depth, 23 | height and width of the voxel predictions. The values are logits. 24 | instances (list[Instances]): A list of N Instances, where N is the number of images 25 | in the batch. These instances are in 1:1 26 | correspondence with the pred_voxel_logits. The ground-truth labels (class, box, mask, 27 | ...) associated with each instance are stored in fields. 28 | loss_weight (float): A float to multiply the loss with. 29 | 30 | Returns: 31 | voxel_loss (Tensor): A scalar tensor containing the loss. 32 | """ 33 | cls_agnostic_voxel = pred_voxel_logits.size(1) == 1 34 | total_num_voxels = pred_voxel_logits.size(0) 35 | voxel_side_len = pred_voxel_logits.size(2) 36 | assert pred_voxel_logits.size(2) == pred_voxel_logits.size( 37 | 3 38 | ), "Voxel prediction must be square!" 39 | assert pred_voxel_logits.size(2) == pred_voxel_logits.size( 40 | 4 41 | ), "Voxel prediction must be square!" 
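    # What follows builds the training targets and the loss:
    #   1) for each proposal, the ground-truth voxel occupancies are cropped and
    #      resampled into the proposal box at voxel_side_len resolution,
    #   2) for class-specific heads, the logits of the ground-truth class are
    #      selected for each proposal,
    #   3) the selected logits are compared to the targets with binary
    #      cross-entropy, scaled by loss_weight.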
42 | 43 | gt_classes = [] 44 | gt_voxel_logits = [] 45 | for instances_per_image in instances: 46 | if len(instances_per_image) == 0: 47 | continue 48 | if not cls_agnostic_voxel: 49 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 50 | gt_classes.append(gt_classes_per_image) 51 | 52 | gt_voxels = instances_per_image.gt_voxels 53 | gt_K = instances_per_image.gt_K 54 | gt_voxel_logits_per_image = batch_crop_voxels_within_box( 55 | gt_voxels, instances_per_image.proposal_boxes.tensor, gt_K, voxel_side_len 56 | ).to(device=pred_voxel_logits.device) 57 | gt_voxel_logits.append(gt_voxel_logits_per_image) 58 | 59 | if len(gt_voxel_logits) == 0: 60 | return pred_voxel_logits.sum() * 0, gt_voxel_logits 61 | 62 | gt_voxel_logits = cat(gt_voxel_logits, dim=0) 63 | assert gt_voxel_logits.numel() > 0, gt_voxel_logits.shape 64 | 65 | if cls_agnostic_voxel: 66 | pred_voxel_logits = pred_voxel_logits[:, 0] 67 | else: 68 | indices = torch.arange(total_num_voxels) 69 | gt_classes = cat(gt_classes, dim=0) 70 | pred_voxel_logits = pred_voxel_logits[indices, gt_classes] 71 | 72 | # Log the training accuracy (using gt classes and 0.5 threshold) 73 | # Note that here we allow gt_voxel_logits to be float as well 74 | # (depending on the implementation of rasterize()) 75 | voxel_accurate = (pred_voxel_logits > 0.5) == (gt_voxel_logits > 0.5) 76 | voxel_accuracy = voxel_accurate.nonzero().size(0) / voxel_accurate.numel() 77 | get_event_storage().put_scalar("voxel_rcnn/accuracy", voxel_accuracy) 78 | 79 | voxel_loss = F.binary_cross_entropy_with_logits( 80 | pred_voxel_logits, gt_voxel_logits, reduction="mean" 81 | ) 82 | voxel_loss = voxel_loss * loss_weight 83 | return voxel_loss, gt_voxel_logits 84 | 85 | 86 | def voxel_rcnn_inference(pred_voxel_logits, pred_instances): 87 | """ 88 | Convert pred_voxel_logits to estimated foreground probability voxels while also 89 | extracting only the voxels for the predicted classes in pred_instances. For each 90 | predicted box, the voxel of the same class is attached to the instance by adding a 91 | new "pred_voxels" field to pred_instances. 92 | 93 | Args: 94 | pred_voxel_logits (Tensor): A tensor of shape (B, C, D, H, W) or (B, 1, D, H, W) 95 | for class-specific or class-agnostic, where B is the total number of predicted voxels 96 | in all images, C is the number of foreground classes, and D, H, W are the depth, height 97 | and width of the voxel predictions. The values are logits. 98 | pred_instances (list[Instances]): A list of N Instances, where N is the number of images 99 | in the batch. Each Instances must have field "pred_classes". 100 | 101 | Returns: 102 | None. pred_instances will contain an extra "pred_voxels" field storing a voxel of size (D, H, 103 | W) for the predicted class. Note that the voxels are returned as soft (non-quantized) 104 | probabilities at the resolution predicted by the network; post-processing steps are left 105 | to the caller.
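        One possible post-processing step (an illustrative sketch only; the 0.5 threshold and the use of pytorch3d's cubify are assumptions of this example, not something this function requires):

            from pytorch3d.ops import cubify

            for instances in pred_instances:
                # instances.pred_voxels: (num_boxes, 1, D, H, W) soft probabilities
                vox = instances.pred_voxels[:, 0]      # (num_boxes, D, H, W)
                meshes = cubify(vox, thresh=0.5)       # one cuboidified mesh per box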
106 | """ 107 | cls_agnostic_voxel = pred_voxel_logits.size(1) == 1 108 | 109 | if cls_agnostic_voxel: 110 | voxel_probs_pred = pred_voxel_logits.sigmoid() 111 | else: 112 | # Select voxels corresponding to the predicted classes 113 | num_voxels = pred_voxel_logits.shape[0] 114 | class_pred = cat([i.pred_classes for i in pred_instances]) 115 | indices = torch.arange(num_voxels, device=class_pred.device) 116 | voxel_probs_pred = pred_voxel_logits[indices, class_pred][:, None].sigmoid() 117 | # voxel_probs_pred.shape: (B, 1, D, H, W) 118 | 119 | num_boxes_per_image = [len(i) for i in pred_instances] 120 | voxel_probs_pred = voxel_probs_pred.split(num_boxes_per_image, dim=0) 121 | 122 | for prob, instances in zip(voxel_probs_pred, pred_instances): 123 | instances.pred_voxels = prob # (1, D, H, W) 124 | 125 | 126 | @ROI_VOXEL_HEAD_REGISTRY.register() 127 | class VoxelRCNNConvUpsampleHead(nn.Module): 128 | """ 129 | A voxel head with several conv layers, plus an upsample layer (with `ConvTranspose2d`). 130 | """ 131 | 132 | def __init__(self, cfg, input_shape): 133 | super(VoxelRCNNConvUpsampleHead, self).__init__() 134 | 135 | # fmt: off 136 | num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 137 | conv_dims = cfg.MODEL.ROI_VOXEL_HEAD.CONV_DIM 138 | self.norm = cfg.MODEL.ROI_VOXEL_HEAD.NORM 139 | num_conv = cfg.MODEL.ROI_VOXEL_HEAD.NUM_CONV 140 | input_channels = input_shape.channels 141 | cls_agnostic_voxel = cfg.MODEL.ROI_VOXEL_HEAD.CLS_AGNOSTIC_VOXEL 142 | # fmt: on 143 | 144 | self.conv_norm_relus = [] 145 | self.num_depth = cfg.MODEL.ROI_VOXEL_HEAD.NUM_DEPTH 146 | self.num_classes = 1 if cls_agnostic_voxel else num_classes 147 | 148 | for k in range(num_conv): 149 | conv = Conv2d( 150 | input_channels if k == 0 else conv_dims, 151 | conv_dims, 152 | kernel_size=3, 153 | stride=1, 154 | padding=1, 155 | bias=not self.norm, 156 | norm=get_norm(self.norm, conv_dims), 157 | activation=F.relu, 158 | ) 159 | self.add_module("voxel_fcn{}".format(k + 1), conv) 160 | self.conv_norm_relus.append(conv) 161 | 162 | self.deconv = ConvTranspose2d( 163 | conv_dims if num_conv > 0 else input_channels, 164 | conv_dims, 165 | kernel_size=2, 166 | stride=2, 167 | padding=0, 168 | ) 169 | 170 | self.predictor = Conv2d( 171 | conv_dims, 172 | self.num_classes * self.num_depth, 173 | kernel_size=1, 174 | stride=1, 175 | padding=0, 176 | ) 177 | 178 | for layer in self.conv_norm_relus + [self.deconv]: 179 | weight_init.c2_msra_fill(layer) 180 | # use normal distribution initialization for voxel prediction layer 181 | nn.init.normal_(self.predictor.weight, std=0.001) 182 | if self.predictor.bias is not None: 183 | nn.init.constant_(self.predictor.bias, 0) 184 | 185 | def forward(self, x): 186 | for layer in self.conv_norm_relus: 187 | x = layer(x) 188 | x = F.relu(self.deconv(x)) 189 | x = self.predictor(x) 190 | # reshape from (N, CD, H, W) to (N, C, D, H, W) 191 | x = x.reshape(x.size(0), self.num_classes, self.num_depth, x.size(2), x.size(3)) 192 | return x 193 | 194 | 195 | def build_voxel_head(cfg, input_shape): 196 | name = cfg.MODEL.ROI_VOXEL_HEAD.NAME 197 | return ROI_VOXEL_HEAD_REGISTRY.get(name)(cfg, input_shape) 198 | -------------------------------------------------------------------------------- /meshrcnn/modeling/roi_heads/z_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import math 3 | 4 | import fvcore.nn.weight_init as weight_init 5 | import numpy as np 6 | import torch 7 | from detectron2.layers import cat, ShapeSpec 8 | from detectron2.utils.registry import Registry 9 | from fvcore.nn import smooth_l1_loss 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | ROI_Z_HEAD_REGISTRY = Registry("ROI_Z_HEAD") 14 | 15 | 16 | @ROI_Z_HEAD_REGISTRY.register() 17 | class FastRCNNFCHead(nn.Module): 18 | """ 19 | A head with several fc layers (each followed by relu). 20 | """ 21 | 22 | def __init__(self, cfg, input_shape: ShapeSpec): 23 | """ 24 | The following attributes are parsed from config: 25 | num_fc: the number of fc layers 26 | fc_dim: the dimension of the fc layers 27 | """ 28 | super().__init__() 29 | 30 | # fmt: off 31 | num_fc = cfg.MODEL.ROI_Z_HEAD.NUM_FC 32 | fc_dim = cfg.MODEL.ROI_Z_HEAD.FC_DIM 33 | cls_agnostic = cfg.MODEL.ROI_Z_HEAD.CLS_AGNOSTIC_Z_REG 34 | num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 35 | # fmt: on 36 | 37 | self._output_size = ( 38 | input_shape.channels, 39 | input_shape.height, 40 | input_shape.width, 41 | ) 42 | 43 | self.fcs = [] 44 | for k in range(num_fc): 45 | fc = nn.Linear(np.prod(self._output_size), fc_dim) 46 | self.add_module("z_fc{}".format(k + 1), fc) 47 | self.fcs.append(fc) 48 | self._output_size = fc_dim 49 | 50 | num_z_reg_classes = 1 if cls_agnostic else num_classes 51 | self.z_pred = nn.Linear(fc_dim, num_z_reg_classes) 52 | 53 | for layer in self.fcs: 54 | weight_init.c2_xavier_fill(layer) 55 | 56 | nn.init.normal_(self.z_pred.weight, std=0.001) 57 | nn.init.constant_(self.z_pred.bias, 0) 58 | 59 | def forward(self, x): 60 | x = x.view(x.shape[0], np.prod(x.shape[1:])) 61 | for layer in self.fcs: 62 | x = F.relu(layer(x)) 63 | x = self.z_pred(x) 64 | return x 65 | 66 | @property 67 | def output_size(self): 68 | return self._output_size 69 | 70 | 71 | def z_rcnn_loss(z_pred, instances, src_boxes, loss_weight=1.0, smooth_l1_beta=0.0): 72 | """ 73 | Compute the z_pred loss. 74 | 75 | Args: 76 | z_pred (Tensor): A tensor of shape (B, C) or (B, 1) for class-specific or class-agnostic, 77 | where B is the total number of foreground regions in all images, C is the number of foreground classes, 78 | instances (list[Instances]): A list of N Instances, where N is the number of images 79 | in the batch. The ground-truth labels (class, box, mask, 80 | ...) associated with each instance are stored in fields. 81 | 82 | Returns: 83 | loss (Tensor): A scalar tensor containing the loss. 
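        Example (illustrative numbers only; it shows the log-space parameterization this loss regresses to and that z_rcnn_inference below undoes):

            import math

            box_height = 120.0                        # proposal box height in pixels
            gt_dz = 90.0                              # ground-truth dz for that box
            target = math.log(gt_dz / box_height)     # regression target used here
            decoded = math.exp(min(target, math.log(1000.0 / 16)))
            assert abs(decoded * box_height - gt_dz) < 1e-6   # recovered at eval time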
84 | """ 85 | cls_agnostic_z = z_pred.size(1) == 1 86 | total_num = z_pred.size(0) 87 | 88 | gt_classes = [] 89 | gt_dz = [] 90 | for instances_per_image in instances: 91 | if len(instances_per_image) == 0: 92 | continue 93 | if not cls_agnostic_z: 94 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 95 | gt_classes.append(gt_classes_per_image) 96 | 97 | gt_dz.append(instances_per_image.gt_dz) 98 | 99 | if len(gt_dz) == 0: 100 | return z_pred.sum() * 0 101 | 102 | gt_dz = cat(gt_dz, dim=0) 103 | assert gt_dz.numel() > 0 104 | src_heights = src_boxes[:, 3] - src_boxes[:, 1] 105 | dz = torch.log(gt_dz / src_heights) 106 | 107 | if cls_agnostic_z: 108 | z_pred = z_pred[:, 0] 109 | else: 110 | indices = torch.arange(total_num) 111 | gt_classes = cat(gt_classes, dim=0) 112 | z_pred = z_pred[indices, gt_classes] 113 | 114 | loss_z_reg = smooth_l1_loss(z_pred, dz, smooth_l1_beta, reduction="sum") 115 | loss_z_reg = loss_weight * loss_z_reg / gt_classes.numel() 116 | return loss_z_reg 117 | 118 | 119 | def z_rcnn_inference(z_pred, pred_instances): 120 | cls_agnostic = z_pred.size(1) == 1 121 | 122 | if not cls_agnostic: 123 | num_z = z_pred.shape[0] 124 | class_pred = cat([i.pred_classes for i in pred_instances]) 125 | indices = torch.arange(num_z, device=class_pred.device) 126 | z_pred = z_pred[indices, class_pred][:, None] 127 | 128 | z_pred = torch.clamp(z_pred, max=math.log(1000.0 / 16)) 129 | z_pred = torch.exp(z_pred) 130 | 131 | # The multiplication with the heights of the boxes will happen at eval time 132 | # See appendix for more. 133 | 134 | num_boxes_per_image = [len(i) for i in pred_instances] 135 | z_pred = z_pred.split(num_boxes_per_image, dim=0) 136 | 137 | for z_reg, instances in zip(z_pred, pred_instances): 138 | instances.pred_dz = z_reg 139 | 140 | 141 | def build_z_head(cfg, input_shape): 142 | name = cfg.MODEL.ROI_Z_HEAD.NAME 143 | return ROI_Z_HEAD_REGISTRY.get(name)(cfg, input_shape) 144 | -------------------------------------------------------------------------------- /meshrcnn/structures/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .mesh import MeshInstances 3 | from .voxel import VoxelInstances 4 | -------------------------------------------------------------------------------- /meshrcnn/structures/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def crop_mask_within_box(mask, box, mask_size): 7 | """ 8 | Crop the mask content in the given box. 9 | The cropped mask is resized to (mask_size, mask_size). 10 | 11 | This function is used when generating training targets for mask head in Mask R-CNN. 12 | Given original ground-truth masks for an image, new ground-truth mask 13 | training targets in the size of `mask_size x mask_size` 14 | must be provided for each predicted box. This function will be called to 15 | produce such targets. 16 | 17 | Args: 18 | mask (Tensor): A tensor mask image. 19 | box: 4 elements 20 | mask_size (int): 21 | 22 | Returns: 23 | Tensor: ByteTensor of shape (mask_size, mask_size) 24 | """ 25 | # 1. Crop mask 26 | roi = box.clone().int() 27 | cropped_mask = mask[roi[1] : roi[3], roi[0] : roi[2]] 28 | 29 | # 2. 
Resize mask 30 | cropped_mask = cropped_mask.unsqueeze(0).unsqueeze(0) 31 | cropped_mask = F.interpolate( 32 | cropped_mask, size=(mask_size, mask_size), mode="bilinear" 33 | ) 34 | cropped_mask = cropped_mask.squeeze(0).squeeze(0) 35 | 36 | # 3. Binarize 37 | cropped_mask = (cropped_mask > 0).float() 38 | 39 | return cropped_mask 40 | 41 | 42 | def batch_crop_masks_within_box(masks, boxes, mask_side_len): 43 | """ 44 | Batched version of :func:`crop_mask_within_box`. 45 | 46 | Args: 47 | masks (Masks): store N masks for an image in 2D array format. 48 | boxes (Tensor): store N boxes corresponding to the masks. 49 | mask_side_len (int): the size of the mask. 50 | 51 | Returns: 52 | Tensor: A byte tensor of shape (N, mask_side_len, mask_side_len), where 53 | N is the number of predicted boxes for this image. 54 | """ 55 | device = boxes.device 56 | # Put boxes on the CPU, as the representation for masks is not efficient 57 | # GPU-wise (possibly several small tensors for representing a single instance mask) 58 | boxes = boxes.to(torch.device("cpu")) 59 | masks = masks.to(torch.device("cpu")) 60 | 61 | results = [ 62 | crop_mask_within_box(mask, box, mask_side_len) 63 | for mask, box in zip(masks, boxes) 64 | ] 65 | 66 | if len(results) == 0: 67 | return torch.empty(0, dtype=torch.float32, device=device) 68 | return torch.stack(results, dim=0).to(device=device) 69 | -------------------------------------------------------------------------------- /meshrcnn/structures/mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import meshrcnn.utils.shape as shape_utils 3 | import torch 4 | from pytorch3d.structures import Meshes 5 | 6 | 7 | def batch_crop_meshes_within_box(meshes, boxes, Ks): 8 | """ 9 | Batched version of :func:`crop_mesh_within_box`. 10 | 11 | Args: 12 | mesh (MeshInstances): store N meshes for an image 13 | boxes (Tensor): store N boxes corresponding to the meshes. 14 | Ks (Tensor): store N camera matrices 15 | 16 | Returns: 17 | Meshes: A Meshes structure of N meshes where N is the number of 18 | predicted boxes for this image. 
19 | """ 20 | device = boxes.device 21 | im_sizes = Ks[:, 1:] * 2.0 22 | verts = torch.stack([mesh[0] for mesh in meshes], dim=0) 23 | zranges = torch.stack([verts[:, :, 2].min(1)[0], verts[:, :, 2].max(1)[0]], dim=1) 24 | cub3D = shape_utils.box2D_to_cuboid3D(zranges, Ks, boxes.clone(), im_sizes) 25 | txz, tyz = shape_utils.cuboid3D_to_unitbox3D(cub3D) 26 | x, y, z = verts.split(1, dim=2) 27 | xz = torch.cat([x, z], dim=2) 28 | yz = torch.cat([y, z], dim=2) 29 | pxz = txz(xz) 30 | pyz = tyz(yz) 31 | new_verts = torch.stack([pxz[:, :, 0], pyz[:, :, 0], pxz[:, :, 1]], dim=2) 32 | 33 | # align to image 34 | new_verts[:, :, 0] = -new_verts[:, :, 0] 35 | new_verts[:, :, 1] = -new_verts[:, :, 1] 36 | 37 | verts_list = [new_verts[i] for i in range(boxes.shape[0])] 38 | faces_list = [mesh[1] for mesh in meshes] 39 | 40 | return Meshes(verts=verts_list, faces=faces_list).to(device=device) 41 | 42 | 43 | class MeshInstances: 44 | """ 45 | Class to hold meshes of varying topology to interface with Instances 46 | """ 47 | 48 | def __init__(self, meshes): 49 | assert isinstance(meshes, list) 50 | assert torch.is_tensor(meshes[0][0]) 51 | assert torch.is_tensor(meshes[0][1]) 52 | self.data = meshes 53 | 54 | def to(self, device): 55 | to_meshes = [(mesh[0].to(device), mesh[1].to(device)) for mesh in self] 56 | return MeshInstances(to_meshes) 57 | 58 | def __getitem__(self, item): 59 | if isinstance(item, (int, slice)): 60 | selected_data = [self.data[item]] 61 | else: 62 | # advanced indexing on a single dimension 63 | selected_data = [] 64 | if isinstance(item, torch.Tensor) and ( 65 | item.dtype == torch.uint8 or item.dtype == torch.bool 66 | ): 67 | item = item.nonzero() 68 | item = item.squeeze(1) if item.numel() > 0 else item 69 | item = item.tolist() 70 | for i in item: 71 | selected_data.append(self.data[i]) 72 | return MeshInstances(selected_data) 73 | 74 | def __iter__(self): 75 | return iter(self.data) 76 | 77 | def __len__(self): 78 | return len(self.data) 79 | 80 | def __repr__(self): 81 | s = self.__class__.__name__ + "(" 82 | s += "num_instances={}) ".format(len(self)) 83 | return s 84 | -------------------------------------------------------------------------------- /meshrcnn/structures/voxel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import meshrcnn.utils.shape as shape_utils 3 | import torch 4 | 5 | 6 | def batch_crop_voxels_within_box(voxels, boxes, Ks, voxel_side_len): 7 | """ 8 | Batched version of :func:`crop_voxel_within_box`. 9 | 10 | Args: 11 | voxels (VoxelInstances): store N voxels for an image 12 | boxes (Tensor): store N boxes corresponding to the masks. 13 | Ks (Tensor): store N camera matrices 14 | voxel_side_len (int): the size of the voxel. 15 | 16 | Returns: 17 | Tensor: A byte tensor of shape (N, voxel_side_len, voxel_side_len, voxel_side_len), 18 | where N is the number of predicted boxes for this image. 
19 | """ 20 | device = boxes.device 21 | im_sizes = Ks[:, 1:] * 2.0 22 | voxels_tensor = torch.stack(voxels.data, 0) 23 | zranges = torch.stack( 24 | [voxels_tensor[:, :, 2].min(1)[0], voxels_tensor[:, :, 2].max(1)[0]], dim=1 25 | ) 26 | cub3D = shape_utils.box2D_to_cuboid3D(zranges, Ks, boxes.clone(), im_sizes) 27 | txz, tyz = shape_utils.cuboid3D_to_unitbox3D(cub3D) 28 | x, y, z = voxels_tensor.split(1, dim=2) 29 | xz = torch.cat([x, z], dim=2) 30 | yz = torch.cat([y, z], dim=2) 31 | pxz = txz(xz) 32 | pyz = tyz(yz) 33 | cropped_verts = torch.stack([pxz[:, :, 0], pyz[:, :, 0], pxz[:, :, 1]], dim=2) 34 | results = [ 35 | verts2voxel(cropped_vert, [voxel_side_len] * 3).permute(2, 0, 1) 36 | for cropped_vert in cropped_verts 37 | ] 38 | 39 | if len(results) == 0: 40 | return torch.empty(0, dtype=torch.float32, device=device) 41 | return torch.stack(results, dim=0).to(device=device) 42 | 43 | 44 | class VoxelInstances: 45 | """ 46 | Class to hold voxels of varying dimensions to interface with Instances 47 | """ 48 | 49 | def __init__(self, voxels): 50 | assert isinstance(voxels, list) 51 | assert torch.is_tensor(voxels[0]) 52 | self.data = voxels 53 | 54 | def to(self, device): 55 | to_voxels = [voxel.to(device) for voxel in self] 56 | return VoxelInstances(to_voxels) 57 | 58 | def __getitem__(self, item): 59 | if isinstance(item, (int, slice)): 60 | selected_data = [self.data[item]] 61 | else: 62 | # advanced indexing on a single dimension 63 | selected_data = [] 64 | if isinstance(item, torch.Tensor) and ( 65 | item.dtype == torch.uint8 or item.dtype == torch.bool 66 | ): 67 | item = item.nonzero() 68 | item = item.squeeze(1) if item.numel() > 0 else item 69 | item = item.tolist() 70 | for i in item: 71 | selected_data.append(self.data[i]) 72 | return VoxelInstances(selected_data) 73 | 74 | def __iter__(self): 75 | return iter(self.data) 76 | 77 | def __len__(self): 78 | return len(self.data) 79 | 80 | def __repr__(self): 81 | s = self.__class__.__name__ + "(" 82 | s += "num_instances={}) ".format(len(self)) 83 | return s 84 | 85 | 86 | def downsample(vox_in, n, use_max=True): 87 | """ 88 | Downsample a 3-d tensor n times 89 | Inputs: 90 | - vox_in (Tensor): HxWxD tensor 91 | - n (int): number of times to downsample each dimension 92 | - use_max (bool): use maximum value when downsampling. If set to False 93 | the mean value is used. 
94 | Output: 95 | - vox_out (Tensor): (H/n)x(W/n)x(D/n) tensor 96 | """ 97 | dimy = vox_in.size(0) // n 98 | dimx = vox_in.size(1) // n 99 | dimz = vox_in.size(2) // n 100 | vox_out = torch.zeros((dimy, dimx, dimz)) 101 | for x in range(dimx): 102 | for y in range(dimy): 103 | for z in range(dimz): 104 | subx = x * n 105 | suby = y * n 106 | subz = z * n 107 | subvox = vox_in[suby : suby + n, subx : subx + n, subz : subz + n] 108 | if use_max: 109 | vox_out[y, x, z] = torch.max(subvox) 110 | else: 111 | vox_out[y, x, z] = torch.mean(subvox) 112 | return vox_out 113 | 114 | 115 | def verts2voxel(verts, voxel_size): 116 | def valid_coords(x, y, z, vx_size): 117 | Hv, Wv, Zv = vx_size 118 | indx = (x >= 0) * (x < Wv) 119 | indy = (y >= 0) * (y < Hv) 120 | indz = (z >= 0) * (z < Zv) 121 | return indx * indy * indz 122 | 123 | Hv, Wv, Zv = voxel_size 124 | # create original voxel of size VxVxV 125 | orig_voxel = torch.zeros((Hv, Wv, Zv), dtype=torch.float32) 126 | 127 | x = (verts[:, 0] + 1) * (Wv - 1) / 2 128 | x = x.long() 129 | y = (verts[:, 1] + 1) * (Hv - 1) / 2 130 | y = y.long() 131 | z = (verts[:, 2] + 1) * (Zv - 1) / 2 132 | z = z.long() 133 | 134 | keep = valid_coords(x, y, z, voxel_size) 135 | x = x[keep] 136 | y = y[keep] 137 | z = z[keep] 138 | 139 | orig_voxel[y, x, z] = 1.0 140 | 141 | # align with image coordinate system 142 | flip_idx = torch.tensor(list(range(Hv)[::-1])) 143 | orig_voxel = orig_voxel.index_select(0, flip_idx) 144 | flip_idx = torch.tensor(list(range(Wv)[::-1])) 145 | orig_voxel = orig_voxel.index_select(1, flip_idx) 146 | return orig_voxel 147 | 148 | 149 | def normalize_verts(verts): 150 | # centering and normalization 151 | min, _ = torch.min(verts, 0) 152 | min_x, min_y, min_z = min 153 | max, _ = torch.max(verts, 0) 154 | max_x, max_y, max_z = max 155 | x_ctr = (min_x + max_x) / 2.0 156 | y_ctr = (min_y + max_y) / 2.0 157 | z_ctr = (min_z + max_z) / 2.0 158 | x_scale = 2.0 / (max_x - min_x) 159 | y_scale = 2.0 / (max_y - min_y) 160 | z_scale = 2.0 / (max_z - min_z) 161 | verts[:, 0] = (verts[:, 0] - x_ctr) * x_scale 162 | verts[:, 1] = (verts[:, 1] - y_ctr) * y_scale 163 | verts[:, 2] = (verts[:, 2] - z_ctr) * z_scale 164 | return verts 165 | -------------------------------------------------------------------------------- /meshrcnn/utils/VOCap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | 4 | 5 | def compute_ap(scores, labels, npos, device=None): 6 | if device is None: 7 | device = scores.device 8 | 9 | if len(scores) == 0: 10 | return 0.0 11 | tp = labels == 1 12 | fp = labels == 0 13 | sc = scores 14 | assert tp.size() == sc.size() 15 | assert tp.size() == fp.size() 16 | sc, ind = torch.sort(sc, descending=True) 17 | tp = tp[ind].to(dtype=torch.float32) 18 | fp = fp[ind].to(dtype=torch.float32) 19 | tp = torch.cumsum(tp, dim=0) 20 | fp = torch.cumsum(fp, dim=0) 21 | 22 | # # Compute precision/recall 23 | rec = tp / npos 24 | prec = tp / (fp + tp) 25 | ap = xVOCap(rec, prec, device) 26 | 27 | return ap 28 | 29 | 30 | def xVOCap(rec, prec, device): 31 | z = rec.new_zeros((1)) 32 | o = rec.new_ones((1)) 33 | mrec = torch.cat((z, rec, o)) 34 | mpre = torch.cat((z, prec, z)) 35 | 36 | for i in range(len(mpre) - 2, -1, -1): 37 | mpre[i] = max(mpre[i], mpre[i + 1]) 38 | 39 | I = (mrec[1:] != mrec[0:-1]).nonzero()[:, 0] + 1 40 | ap = 0 41 | for i in I: 42 | ap = ap + (mrec[i] - mrec[i - 1]) * mpre[i] 43 | return ap 44 | -------------------------------------------------------------------------------- /meshrcnn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from . import model_zoo # registers pathhandlers 3 | -------------------------------------------------------------------------------- /meshrcnn/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import logging 3 | from collections import defaultdict 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from pytorch3d.ops import knn_gather, knn_points, sample_points_from_meshes 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @torch.no_grad() 13 | def compare_meshes( 14 | pred_meshes, 15 | gt_meshes, 16 | num_samples=10000, 17 | scale="gt-10", 18 | thresholds=None, 19 | reduce=True, 20 | eps=1e-8, 21 | ): 22 | """ 23 | Compute evaluation metrics to compare meshes. We currently compute the 24 | following metrics: 25 | 26 | - L2 Chamfer distance 27 | - Normal consistency 28 | - Absolute normal consistency 29 | - Precision at various thresholds 30 | - Recall at various thresholds 31 | - F1 score at various thresholds 32 | 33 | Inputs: 34 | - pred_meshes (Meshes): Contains N predicted meshes 35 | - gt_meshes (Meshes): Contains 1 or N ground-truth meshes. If gt_meshes 36 | contains 1 mesh, it is replicated N times. 37 | - num_samples: The number of samples to take on the surface of each mesh. 38 | This can be one of the following: 39 | - (int): Take that many uniform samples from the surface of the mesh 40 | - 'verts': Use the vertex positions as samples for each mesh 41 | - A tuple of length 2: To use different sampling strategies for the 42 | predicted and ground-truth meshes (respectively). 43 | - scale: How to scale the predicted and ground-truth meshes before comparing. 44 | This can be one of the following: 45 | - (float): Multiply the vertex positions of both meshes by this value 46 | - A tuple of two floats: Multiply the vertex positions of the predicted 47 | and ground-truth meshes by these two different values 48 | - A string of the form 'gt-[SCALE]', where [SCALE] is a float literal. 
49 | In this case, each (predicted, ground-truth) pair is scaled differently, 50 | so that bounding box of the (rescaled) ground-truth mesh has longest 51 | edge length [SCALE]. 52 | - thresholds: The distance thresholds to use when computing precision, recall, 53 | and F1 scores. 54 | - reduce: If True, then return the average of each metric over the batch; 55 | otherwise return the value of each metric between each predicted and 56 | ground-truth mesh. 57 | - eps: Small constant for numeric stability when computing F1 scores. 58 | 59 | Returns: 60 | - metrics: A dictionary mapping metric names to their values. If reduce is 61 | True then the values are the average value of the metric over the batch; 62 | otherwise the values are Tensors of shape (N,). 63 | """ 64 | if thresholds is None: 65 | thresholds = [0.1, 0.2, 0.3, 0.4, 0.5] 66 | 67 | pred_meshes, gt_meshes = _scale_meshes(pred_meshes, gt_meshes, scale) 68 | 69 | if isinstance(num_samples, tuple): 70 | num_samples_pred, num_samples_gt = num_samples 71 | else: 72 | num_samples_pred = num_samples_gt = num_samples 73 | 74 | pred_points, pred_normals = _sample_meshes(pred_meshes, num_samples_pred) 75 | gt_points, gt_normals = _sample_meshes(gt_meshes, num_samples_gt) 76 | if pred_points is None: 77 | logger.info("WARNING: Sampling predictions failed during eval") 78 | return None 79 | elif gt_points is None: 80 | logger.info("WARNING: Sampling GT failed during eval") 81 | return None 82 | 83 | if len(gt_meshes) == 1: 84 | # (1, S, 3) to (N, S, 3) 85 | gt_points = gt_points.expand(len(pred_meshes), -1, -1) 86 | gt_normals = gt_normals.expand(len(pred_meshes), -1, -1) 87 | 88 | if torch.is_tensor(pred_points) and torch.is_tensor(gt_points): 89 | # We can compute all metrics at once in this case 90 | metrics = _compute_sampling_metrics( 91 | pred_points, pred_normals, gt_points, gt_normals, thresholds, eps 92 | ) 93 | else: 94 | # Slow path when taking vert samples from non-equisized meshes; we need 95 | # to iterate over the batch 96 | metrics = defaultdict(list) 97 | for cur_points_pred, cur_points_gt in zip(pred_points, gt_points): 98 | cur_metrics = _compute_sampling_metrics( 99 | cur_points_pred[None], None, cur_points_gt[None], None, thresholds, eps 100 | ) 101 | for k, v in cur_metrics.items(): 102 | metrics[k].append(v.item()) 103 | metrics = {k: torch.tensor(vs) for k, vs in metrics.items()} 104 | 105 | if reduce: 106 | # Average each metric over the batch 107 | metrics = {k: v.mean().item() for k, v in metrics.items()} 108 | 109 | return metrics 110 | 111 | 112 | def _scale_meshes(pred_meshes, gt_meshes, scale): 113 | if isinstance(scale, float): 114 | # Assume scale is a single scalar to use for both preds and GT 115 | pred_scale = gt_scale = scale 116 | elif isinstance(scale, tuple): 117 | # Rescale preds and GT with different scalars 118 | pred_scale, gt_scale = scale 119 | elif scale.startswith("gt-"): 120 | # Rescale both preds and GT so that the largest edge length of each GT 121 | # mesh is target 122 | target = float(scale[3:]) 123 | bbox = gt_meshes.get_bounding_boxes() # (N, 3, 2) 124 | long_edge = (bbox[:, :, 1] - bbox[:, :, 0]).max(dim=1)[0] # (N,) 125 | scale = target / long_edge 126 | if scale.numel() == 1: 127 | scale = scale.expand(len(pred_meshes)) 128 | pred_scale, gt_scale = scale, scale 129 | else: 130 | raise ValueError("Invalid scale: %r" % scale) 131 | pred_meshes = pred_meshes.scale_verts(pred_scale) 132 | gt_meshes = gt_meshes.scale_verts(gt_scale) 133 | return pred_meshes, gt_meshes 134 | 135 | 136 | 
def _sample_meshes(meshes, num_samples): 137 | """ 138 | Helper to either sample points uniformly from the surface of a mesh 139 | (with normals), or take the verts of the mesh as samples. 140 | 141 | Inputs: 142 | - meshes: A MeshList 143 | - num_samples: An integer, or the string 'verts' 144 | 145 | Outputs: 146 | - verts: Either a Tensor of shape (N, S, 3) if we take the same number of 147 | samples from each mesh; otherwise a list of length N, whose ith element 148 | is a Tensor of shape (S_i, 3) 149 | - normals: Either a Tensor of shape (N, S, 3) or None if we take verts 150 | as samples. 151 | """ 152 | if num_samples == "verts": 153 | normals = None 154 | if meshes.equisized: 155 | verts = meshes.verts_batch 156 | else: 157 | verts = meshes.verts_list 158 | else: 159 | verts, normals = sample_points_from_meshes( 160 | meshes, num_samples, return_normals=True 161 | ) 162 | return verts, normals 163 | 164 | 165 | def _compute_sampling_metrics( 166 | pred_points, pred_normals, gt_points, gt_normals, thresholds, eps 167 | ): 168 | """ 169 | Compute metrics that are based on sampling points and normals: 170 | 171 | - L2 Chamfer distance 172 | - Precision at various thresholds 173 | - Recall at various thresholds 174 | - F1 score at various thresholds 175 | - Normal consistency (if normals are provided) 176 | - Absolute normal consistency (if normals are provided) 177 | 178 | Inputs: 179 | - pred_points: Tensor of shape (N, S, 3) giving coordinates of sampled points 180 | for each predicted mesh 181 | - pred_normals: Tensor of shape (N, S, 3) giving normals of points sampled 182 | from the predicted mesh, or None if such normals are not available 183 | - gt_points: Tensor of shape (N, S, 3) giving coordinates of sampled points 184 | for each ground-truth mesh 185 | - gt_normals: Tensor of shape (N, S, 3) giving normals of points sampled from 186 | the ground-truth verts, or None of such normals are not available 187 | - thresholds: Distance thresholds to use for precision / recall / F1 188 | - eps: epsilon value to handle numerically unstable F1 computation 189 | 190 | Returns: 191 | - metrics: A dictionary where keys are metric names and values are Tensors of 192 | shape (N,) giving the value of the metric for the batch 193 | """ 194 | metrics = {} 195 | lengths_pred = torch.full( 196 | (pred_points.shape[0],), 197 | pred_points.shape[1], 198 | dtype=torch.int64, 199 | device=pred_points.device, 200 | ) 201 | lengths_gt = torch.full( 202 | (gt_points.shape[0],), 203 | gt_points.shape[1], 204 | dtype=torch.int64, 205 | device=gt_points.device, 206 | ) 207 | 208 | # For each predicted point, find its neareast-neighbor GT point 209 | knn_pred = knn_points( 210 | pred_points, gt_points, lengths1=lengths_pred, lengths2=lengths_gt, K=1 211 | ) 212 | # Compute L1 and L2 distances between each pred point and its nearest GT 213 | pred_to_gt_dists2 = knn_pred.dists[..., 0] # (N, S) 214 | pred_to_gt_dists = pred_to_gt_dists2.sqrt() # (N, S) 215 | if gt_normals is not None: 216 | pred_normals_near = knn_gather(gt_normals, knn_pred.idx, lengths_gt)[ 217 | ..., 0, : 218 | ] # (N, S, 3) 219 | else: 220 | pred_normals_near = None 221 | 222 | # For each GT point, find its nearest-neighbor predicted point 223 | knn_gt = knn_points( 224 | gt_points, pred_points, lengths1=lengths_gt, lengths2=lengths_pred, K=1 225 | ) 226 | # Compute L1 and L2 dists between each GT point and its nearest pred point 227 | gt_to_pred_dists2 = knn_gt.dists[..., 0] # (N, S) 228 | gt_to_pred_dists = gt_to_pred_dists2.sqrt() # (N, 
S) 229 | 230 | if pred_normals is not None: 231 | gt_normals_near = knn_gather(pred_normals, knn_gt.idx, lengths_pred)[ 232 | ..., 0, : 233 | ] # (N, S, 3) 234 | else: 235 | gt_normals_near = None 236 | 237 | # Compute L2 chamfer distances 238 | chamfer_l2 = pred_to_gt_dists2.mean(dim=1) + gt_to_pred_dists2.mean(dim=1) 239 | metrics["Chamfer-L2"] = chamfer_l2 240 | 241 | # Compute normal consistency and absolute normal consistance only if 242 | # we actually got normals for both meshes 243 | if pred_normals is not None and gt_normals is not None: 244 | pred_to_gt_cos = F.cosine_similarity(pred_normals, pred_normals_near, dim=2) 245 | gt_to_pred_cos = F.cosine_similarity(gt_normals, gt_normals_near, dim=2) 246 | 247 | pred_to_gt_cos_sim = pred_to_gt_cos.mean(dim=1) 248 | pred_to_gt_abs_cos_sim = pred_to_gt_cos.abs().mean(dim=1) 249 | gt_to_pred_cos_sim = gt_to_pred_cos.mean(dim=1) 250 | gt_to_pred_abs_cos_sim = gt_to_pred_cos.abs().mean(dim=1) 251 | normal_dist = 0.5 * (pred_to_gt_cos_sim + gt_to_pred_cos_sim) 252 | abs_normal_dist = 0.5 * (pred_to_gt_abs_cos_sim + gt_to_pred_abs_cos_sim) 253 | metrics["NormalConsistency"] = normal_dist 254 | metrics["AbsNormalConsistency"] = abs_normal_dist 255 | 256 | # Compute precision, recall, and F1 based on L2 distances 257 | for t in thresholds: 258 | precision = 100.0 * (pred_to_gt_dists < t).float().mean(dim=1) 259 | recall = 100.0 * (gt_to_pred_dists < t).float().mean(dim=1) 260 | f1 = (2.0 * precision * recall) / (precision + recall + eps) 261 | metrics["Precision@%f" % t] = precision 262 | metrics["Recall@%f" % t] = recall 263 | metrics["F1@%f" % t] = f1 264 | 265 | # Move all metrics to CPU 266 | metrics = {k: v.cpu() for k, v in metrics.items()} 267 | return metrics 268 | -------------------------------------------------------------------------------- /meshrcnn/utils/model_zoo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from detectron2.utils.file_io import PathHandler, PathManager 4 | 5 | __all__ = ["MeshRCNNHandler"] 6 | 7 | 8 | class MeshRCNNHandler(PathHandler): 9 | """ 10 | Resolve anything that's in Mesh R-CNN's model zoo. 11 | """ 12 | 13 | PREFIX = "meshrcnn://" 14 | MESHRCNN_PREFIX = "https://dl.fbaipublicfiles.com/meshrcnn/pix3d/" 15 | 16 | def _get_supported_prefixes(self): 17 | return [self.PREFIX] 18 | 19 | def _get_local_path(self, path): 20 | name = path[len(self.PREFIX) :] 21 | return PathManager.get_local_path(self.MESHRCNN_PREFIX + name) 22 | 23 | def _open(self, path, mode="r", **kwargs): 24 | return PathManager.open(self._get_local_path(path), mode, **kwargs) 25 | 26 | 27 | PathManager.register_handler(MeshRCNNHandler()) 28 | -------------------------------------------------------------------------------- /meshrcnn/utils/projtransform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | 5 | class ProjectiveTransform: 6 | """ 7 | Projective Transformation in PyTorch: 8 | Follows a similar design to skimage.ProjectiveTransform 9 | https://github.com/scikit-image/scikit-image/blob/master/skimage/transform/_geometric.py#L494 10 | The implementation assumes batched representations, 11 | so every tensor is assumed to be of shape batch x dim1 x dim2 x etc. 
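    Example (an illustrative sketch; four exact correspondences determine the homography, and the toy points below are made up):

        src = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]])  # Bx4x2
        dst = 2.0 * src + 1.0            # a scale-plus-shift is a special projective map
        t = ProjectiveTransform()
        if t.estimate(src, dst, method="svd"):
            mapped = t(src)              # approximately equal to dst
            back = t.inverse(dst)        # approximately equal to src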
12 | """ 13 | 14 | def __init__(self, matrix=None): 15 | if matrix is None: 16 | # default to an identity transform 17 | matrix = torch.eye(3).view(1, 3, 3) 18 | if matrix.ndim != 3 and matrix.shape[-1] != 3 and matrix.shape[-2] != 3: 19 | raise ValueError("Shape of transformation matrix should be Bx3x3") 20 | self.params = matrix 21 | 22 | @property 23 | def _inv_matrix(self): 24 | return torch.inverse(self.params) 25 | 26 | def _apply_mat(self, coords, matrix): 27 | """ 28 | Applies matrix transformation 29 | Input: 30 | coords: FloatTensor of shape BxNx2 31 | matrix: FloatTensor of shape Bx3x3 32 | Returns: 33 | new_coords: FloatTensor of shape BxNx2 34 | """ 35 | if coords.shape[0] != matrix.shape[0]: 36 | raise ValueError("Mismatch in the batch dimension") 37 | if coords.ndim != 3 or coords.shape[-1] != 2: 38 | raise ValueError("Input tensors should be of shape BxNx2") 39 | 40 | # append 1s, shape: BxNx2 -> BxNx3 41 | src = torch.cat( 42 | [ 43 | coords, 44 | torch.ones( 45 | (coords.shape[0], coords.shape[1], 1), 46 | device=coords.device, 47 | dtype=torch.float32, 48 | ), 49 | ], 50 | dim=2, 51 | ) 52 | dst = torch.bmm(matrix, src.transpose(1, 2)).transpose(1, 2) 53 | # rescale to homogeneous coordinates 54 | dst[:, :, 0] /= dst[:, :, 2] 55 | dst[:, :, 1] /= dst[:, :, 2] 56 | 57 | return dst[:, :, :2] 58 | 59 | def __call__(self, coords): 60 | """Apply forward transformation. 61 | Input: 62 | coords: FloatTensor of shape BxNx2 63 | Output: 64 | coords: FloateTensor of shape BxNx2 65 | """ 66 | return self._apply_mat(coords, self.params) 67 | 68 | def inverse(self, coords): 69 | """Apply inverse transformation. 70 | Input: 71 | coords: FloatTensor of shape BxNx2 72 | Output: 73 | coords: FloatTensor of shape BxNx2 74 | """ 75 | return self._apply_mat(coords, self._inv_matrix) 76 | 77 | def estimate(self, src, dst, method="svd"): 78 | """ 79 | Estimates the matrix to transform src to dst. 80 | Input: 81 | src: FloatTensor of shape BxNx2 82 | dst: FloatTensor of shape BxNx2 83 | method: Specifies the method to solve the linear system 84 | """ 85 | if src.shape != dst.shape: 86 | raise ValueError("src and dst tensors but be of same shape") 87 | if src.ndim != 3 or src.shape[-1] != 2: 88 | raise ValueError("Input should be of shape BxNx2") 89 | device = src.device 90 | batch = src.shape[0] 91 | 92 | # Center and normalize image points for better numerical stability. 93 | try: 94 | src_matrix, src = _center_and_normalize_points(src) 95 | dst_matrix, dst = _center_and_normalize_points(dst) 96 | except ZeroDivisionError: 97 | self.params = torch.zeros((batch, 3, 3), device=device) 98 | return False 99 | 100 | xs = src[:, :, 0] 101 | ys = src[:, :, 1] 102 | xd = dst[:, :, 0] 103 | yd = dst[:, :, 1] 104 | rows = src.shape[1] 105 | 106 | # params: a0, a1, a2, b0, b1, b2, c0, c1, (c3=1) 107 | A = torch.zeros((batch, rows * 2, 9), device=device, dtype=torch.float32) 108 | A[:, :rows, 0] = xs 109 | A[:, :rows, 1] = ys 110 | A[:, :rows, 2] = 1 111 | A[:, :rows, 6] = -xd * xs 112 | A[:, :rows, 7] = -xd * ys 113 | A[:, rows:, 3] = xs 114 | A[:, rows:, 4] = ys 115 | A[:, rows:, 5] = 1 116 | A[:, rows:, 6] = -yd * xs 117 | A[:, rows:, 7] = -yd * ys 118 | A[:, :rows, 8] = xd 119 | A[:, rows:, 8] = yd 120 | 121 | if method == "svd": 122 | A = A.cpu() # faster computation in cpu 123 | # Solve for the nullspace of the constraint matrix. 
124 | _, _, V = torch.svd(A, some=False) 125 | V = V.transpose(1, 2) 126 | 127 | H = torch.ones((batch, 9), device=device, dtype=torch.float32) 128 | H[:, :-1] = -V[:, -1, :-1] / V[:, -1, -1].view(-1, 1) 129 | H = H.reshape(batch, 3, 3) 130 | # H[:, 2, 2] = 1.0 131 | elif method == "least_sqr": 132 | A = A.cpu() # faster computation in cpu 133 | # Least square solution 134 | x, _ = torch.solve(-A[:, :, -1].view(-1, 1), A[:, :, :-1]) 135 | H = torch.cat([-x, torch.ones((1, 1), dtype=x.dtype, device=device)]) 136 | H = H.reshape(3, 3) 137 | elif method == "inv": 138 | # x = inv(A'A)*A'*b 139 | invAtA = torch.inverse(torch.mm(A[:, :-1].t(), A[:, :-1])) 140 | Atb = torch.mm(A[:, :-1].t(), -A[:, -1].view(-1, 1)) 141 | x = torch.mm(invAtA, Atb) 142 | H = torch.cat([-x, torch.ones((1, 1), dtype=x.dtype, device=device)]) 143 | H = H.reshape(3, 3) 144 | else: 145 | raise ValueError("method {} undefined".format(method)) 146 | 147 | # De-center and de-normalize 148 | self.params = torch.bmm(torch.bmm(torch.inverse(dst_matrix), H), src_matrix) 149 | return True 150 | 151 | 152 | def _center_and_normalize_points(points): 153 | """Center and normalize points. 154 | The points are transformed in a two-step procedure that is expressed 155 | as a transformation matrix. The matrix of the resulting points is usually 156 | better conditioned than the matrix of the original points. 157 | Center the points, such that the new coordinate system has its 158 | origin at the centroid of the image points. 159 | Normalize the points, such that the mean distance from the points 160 | to the origin of the coordinate system is sqrt(2). 161 | Inputs: 162 | points: FloatTensor of shape BxNx2 of the coordinates of the image points. 163 | Outputs: 164 | matrix: FloatTensor of shape Bx3x3 of the transformation matrix to obtain 165 | the new points. 166 | new_points: FloatTensor of shape BxNx2 of the transformed image points. 167 | 168 | References 169 | ---------- 170 | .. [1] Hartley, Richard I. "In defense of the eight-point algorithm." 171 | Pattern Analysis and Machine Intelligence, IEEE Transactions on 19.6 172 | (1997): 580-593. 
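    Concretely (from the construction below), with s = sqrt(2) / rms and
    centroid (cx, cy), each per-batch transformation matrix has the form
        [[s, 0, -s * cx],
         [0, s, -s * cy],
         [0, 0,       1]]
    so the transformed points are centered at the origin with an RMS distance
    of sqrt(2) from it.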
173 | """ 174 | device = points.device 175 | centroid = torch.mean(points, 1, keepdim=True) 176 | 177 | rms = torch.sqrt( 178 | torch.sum((points - centroid) ** 2.0, dim=(1, 2)) / points.shape[1] 179 | ) 180 | 181 | norm_factor = torch.sqrt(torch.tensor([2.0], device=device)) / rms 182 | 183 | matrix = torch.zeros((points.shape[0], 3, 3), dtype=torch.float32, device=device) 184 | matrix[:, 0, 0] = norm_factor 185 | matrix[:, 0, 2] = -norm_factor * centroid[:, 0, 0] 186 | matrix[:, 1, 1] = norm_factor 187 | matrix[:, 1, 2] = -norm_factor * centroid[:, 0, 1] 188 | matrix[:, 2, 2] = 1.0 189 | 190 | # matrix = torch.tensor( 191 | # [ 192 | # [norm_factor, 0.0, -norm_factor * centroid[0]], 193 | # [0.0, norm_factor, -norm_factor * centroid[1]], 194 | # [0.0, 0.0, 1.0], 195 | # ], device=device, dtype=torch.float32) 196 | 197 | pointsh = torch.cat( 198 | [ 199 | points, 200 | torch.ones( 201 | (points.shape[0], points.shape[1], 1), 202 | device=device, 203 | dtype=torch.float32, 204 | ), 205 | ], 206 | dim=2, 207 | ) 208 | 209 | new_pointsh = torch.bmm(matrix, pointsh.transpose(1, 2)).transpose(1, 2) 210 | 211 | new_points = new_pointsh[:, :, :2] 212 | new_points[:, :, 0] /= new_pointsh[:, :, 2] 213 | new_points[:, :, 1] /= new_pointsh[:, :, 2] 214 | 215 | return matrix, new_points 216 | -------------------------------------------------------------------------------- /meshrcnn/utils/shape.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | import torch 4 | from detectron2.utils.file_io import PathManager 5 | 6 | from meshrcnn.utils.projtransform import ProjectiveTransform 7 | from scipy import io as sio 8 | 9 | 10 | def cuboid3D_to_unitbox3D(cub3D): 11 | device = cub3D.device 12 | dst = torch.tensor( 13 | [[-1.0, -1.0], [1.0, -1.0], [-1.0, 1.0], [1.0, 1.0]], 14 | device=device, 15 | dtype=torch.float32, 16 | ) 17 | dst = dst.view(1, 4, 2).expand(cub3D.shape[0], -1, -1) 18 | # for (x,z) plane 19 | txz = ProjectiveTransform() 20 | src = torch.stack( 21 | [ 22 | torch.stack([cub3D[:, 0, 0], cub3D[:, 4, 0]], dim=1), 23 | torch.stack([cub3D[:, 0, 1], cub3D[:, 4, 0]], dim=1), 24 | torch.stack([cub3D[:, 2, 0], cub3D[:, 4, 1]], dim=1), 25 | torch.stack([cub3D[:, 2, 1], cub3D[:, 4, 1]], dim=1), 26 | ], 27 | dim=1, 28 | ) 29 | if not txz.estimate(src, dst): 30 | raise ValueError("Estimate failed") 31 | # for (y,z) plane 32 | tyz = ProjectiveTransform() 33 | src = torch.stack( 34 | [ 35 | torch.stack([cub3D[:, 1, 0], cub3D[:, 4, 0]], dim=1), 36 | torch.stack([cub3D[:, 1, 1], cub3D[:, 4, 0]], dim=1), 37 | torch.stack([cub3D[:, 3, 0], cub3D[:, 4, 1]], dim=1), 38 | torch.stack([cub3D[:, 3, 1], cub3D[:, 4, 1]], dim=1), 39 | ], 40 | dim=1, 41 | ) 42 | if not tyz.estimate(src, dst): 43 | raise ValueError("Estimate failed") 44 | return txz, tyz 45 | 46 | 47 | def box2D_to_cuboid3D(zranges, Ks, boxes, im_sizes): 48 | device = boxes.device 49 | if boxes.shape[0] != Ks.shape[0] != zranges.shape[0]: 50 | raise ValueError("Ks, boxes and zranges must have the same batch dimension") 51 | if zranges.shape[1] != 2: 52 | raise ValueError("zrange must have two entries per example") 53 | w, h = im_sizes.t() 54 | sx, px, py = Ks.t() 55 | sy = sx 56 | x1, y1, x2, y2 = boxes.t() 57 | # transform 2d box from image coordinates to world coordinates 58 | x1 = w - 1 - x1 - px 59 | y1 = h - 1 - y1 - py 60 | x2 = w - 1 - x2 - px 61 | y2 = h - 1 - y2 - py 62 | 63 | cub3D = torch.zeros((boxes.shape[0], 
5, 2), device=device, dtype=torch.float32) 64 | for i in range(2): 65 | z = zranges[:, i] 66 | x3D_min = x2 * z / sx 67 | x3D_max = x1 * z / sx 68 | y3D_min = y2 * z / sy 69 | y3D_max = y1 * z / sy 70 | cub3D[:, i * 2 + 0, 0] = x3D_min 71 | cub3D[:, i * 2 + 0, 1] = x3D_max 72 | cub3D[:, i * 2 + 1, 0] = y3D_min 73 | cub3D[:, i * 2 + 1, 1] = y3D_max 74 | cub3D[:, 4, 0] = zranges[:, 0] 75 | cub3D[:, 4, 1] = zranges[:, 1] 76 | return cub3D 77 | 78 | 79 | def transform_verts(verts, R, t): 80 | """ 81 | Transforms verts with rotation R and translation t 82 | Inputs: 83 | - verts (tensor): of shape (N, 3) 84 | - R (tensor): of shape (3, 3) or None 85 | - t (tensor): of shape (3,) or None 86 | Outputs: 87 | - rotated_verts (tensor): of shape (N, 3) 88 | """ 89 | rot_verts = verts.clone().t() 90 | if R is not None: 91 | assert R.dim() == 2 92 | assert R.size(0) == 3 and R.size(1) == 3 93 | rot_verts = torch.mm(R, rot_verts) 94 | if t is not None: 95 | assert t.dim() == 1 96 | assert t.size(0) == 3 97 | rot_verts = rot_verts + t.unsqueeze(1) 98 | rot_verts = rot_verts.t() 99 | return rot_verts 100 | 101 | 102 | def read_voxel(voxelfile): 103 | """ 104 | Reads voxel and transforms it in the form of verts 105 | """ 106 | with PathManager.open(voxelfile, "rb") as f: 107 | voxel = sio.loadmat(f)["voxel"] 108 | voxel = np.rot90(voxel, k=3, axes=(1, 2)) 109 | verts = np.argwhere(voxel > 0).astype(np.float32, copy=False) 110 | 111 | # centering and normalization 112 | min_x = np.min(verts[:, 0]) 113 | max_x = np.max(verts[:, 0]) 114 | min_y = np.min(verts[:, 1]) 115 | max_y = np.max(verts[:, 1]) 116 | min_z = np.min(verts[:, 2]) 117 | max_z = np.max(verts[:, 2]) 118 | verts[:, 0] = verts[:, 0] - (max_x + min_x) / 2 119 | verts[:, 1] = verts[:, 1] - (max_y + min_y) / 2 120 | verts[:, 2] = verts[:, 2] - (max_z + min_z) / 2 121 | scale = np.sqrt(np.max(np.sum(verts**2, axis=1))) * 2 122 | verts /= scale 123 | verts = torch.tensor(verts, dtype=torch.float32) 124 | 125 | return verts 126 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | known_standard_library=numpy,setuptools 5 | known_myself=meshrcnn 6 | known_third_party=fvcore,detectron2,pytorch3d,torch,pycocotools,yacs,termcolor,scipy,simplejson,matplotlib 7 | no_lines_before=STDLIB,THIRDPARTY 8 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 9 | default_section=FIRSTPARTY 10 | 11 | [mypy] 12 | python_version=3.6 13 | ignore_missing_imports = True 14 | warn_unused_configs = True 15 | disallow_untyped_defs = True 16 | check_untyped_defs = True 17 | warn_unused_ignores = True 18 | warn_redundant_casts = True 19 | show_column_numbers = True 20 | follow_imports = silent 21 | allow_redefinition = True 22 | ; Require all functions to be annotated 23 | disallow_incomplete_defs = True 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="meshrcnn", 7 | version="1.0", 8 | author="FAIR", 9 | url="https://github.com/facebookresearch/meshrcnn", 10 | description="Code for Mesh R-CNN", 11 | packages=find_packages(exclude=("configs", "tests")), 12 | install_requires=["torchvision>=0.4", "fvcore", "detectron2", "pytorch3d"], 13 | ) 14 | -------------------------------------------------------------------------------- /shapenet/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .config import get_shapenet_cfg 3 | -------------------------------------------------------------------------------- /shapenet/config/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from fvcore.common.config import CfgNode as CN 5 | 6 | 7 | # ----------------------------------------------------------------------------- 8 | # Config definition 9 | # ----------------------------------------------------------------------------- 10 | def get_shapenet_cfg(): 11 | cfg = CN() 12 | cfg.MODEL = CN() 13 | cfg.MODEL.BACKBONE = "resnet50" 14 | cfg.MODEL.VOXEL_ON = False 15 | cfg.MODEL.MESH_ON = False 16 | 17 | # ------------------------------------------------------------------------ # 18 | # Checkpoint 19 | # ------------------------------------------------------------------------ # 20 | cfg.MODEL.CHECKPOINT = "" # path to checkpoint 21 | 22 | # ------------------------------------------------------------------------ # 23 | # Voxel Head 24 | # ------------------------------------------------------------------------ # 25 | cfg.MODEL.VOXEL_HEAD = CN() 26 | # The number of convs in the voxel head and the number of channels 27 | cfg.MODEL.VOXEL_HEAD.NUM_CONV = 0 28 | cfg.MODEL.VOXEL_HEAD.CONV_DIM = 256 29 | # Normalization method for the convolution layers. 
Options: "" (no norm), "GN" 30 | cfg.MODEL.VOXEL_HEAD.NORM = "" 31 | # The number of depth channels for the predicted voxels 32 | cfg.MODEL.VOXEL_HEAD.VOXEL_SIZE = 28 33 | cfg.MODEL.VOXEL_HEAD.LOSS_WEIGHT = 1.0 34 | cfg.MODEL.VOXEL_HEAD.CUBIFY_THRESH = 0.0 35 | # voxel only iterations 36 | cfg.MODEL.VOXEL_HEAD.VOXEL_ONLY_ITERS = 100 37 | 38 | # ------------------------------------------------------------------------ # 39 | # Mesh Head 40 | # ------------------------------------------------------------------------ # 41 | cfg.MODEL.MESH_HEAD = CN() 42 | cfg.MODEL.MESH_HEAD.NAME = "VoxMeshHead" 43 | # Numer of stages 44 | cfg.MODEL.MESH_HEAD.NUM_STAGES = 1 45 | cfg.MODEL.MESH_HEAD.NUM_GRAPH_CONVS = 1 # per stage 46 | cfg.MODEL.MESH_HEAD.GRAPH_CONV_DIM = 256 47 | cfg.MODEL.MESH_HEAD.GRAPH_CONV_INIT = "normal" 48 | # Mesh sampling 49 | cfg.MODEL.MESH_HEAD.GT_NUM_SAMPLES = 5000 50 | cfg.MODEL.MESH_HEAD.PRED_NUM_SAMPLES = 5000 51 | # loss weights 52 | cfg.MODEL.MESH_HEAD.CHAMFER_LOSS_WEIGHT = 1.0 53 | cfg.MODEL.MESH_HEAD.NORMALS_LOSS_WEIGHT = 1.0 54 | cfg.MODEL.MESH_HEAD.EDGE_LOSS_WEIGHT = 1.0 55 | # Init ico_sphere level (only for when voxel_on is false) 56 | cfg.MODEL.MESH_HEAD.ICO_SPHERE_LEVEL = -1 57 | 58 | # ------------------------------------------------------------------------ # 59 | # Solver 60 | # ------------------------------------------------------------------------ # 61 | cfg.SOLVER = CN() 62 | cfg.SOLVER.LR_SCHEDULER_NAME = "constant" # {'constant', 'cosine'} 63 | cfg.SOLVER.BATCH_SIZE = 32 64 | cfg.SOLVER.BATCH_SIZE_EVAL = 8 65 | cfg.SOLVER.NUM_EPOCHS = 25 66 | cfg.SOLVER.BASE_LR = 0.0001 67 | cfg.SOLVER.OPTIMIZER = "adam" # {'sgd', 'adam'} 68 | cfg.SOLVER.MOMENTUM = 0.9 69 | cfg.SOLVER.WARMUP_ITERS = 500 70 | cfg.SOLVER.WARMUP_FACTOR = 0.1 71 | cfg.SOLVER.CHECKPOINT_PERIOD = 24949 # in iters 72 | cfg.SOLVER.LOGGING_PERIOD = 50 # in iters 73 | # stable training 74 | cfg.SOLVER.SKIP_LOSS_THRESH = 50.0 75 | cfg.SOLVER.LOSS_SKIP_GAMMA = 0.9 76 | 77 | # ------------------------------------------------------------------------ # 78 | # Datasets 79 | # ------------------------------------------------------------------------ # 80 | cfg.DATASETS = CN() 81 | cfg.DATASETS.NAME = "shapenet" 82 | 83 | # ------------------------------------------------------------------------ # 84 | # Misc options 85 | # ------------------------------------------------------------------------ # 86 | # Directory where output files are written 87 | cfg.OUTPUT_DIR = "./output" 88 | 89 | return cfg 90 | -------------------------------------------------------------------------------- /shapenet/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .build_data_loader import build_data_loader 3 | from .builtin import register_shapenet 4 | -------------------------------------------------------------------------------- /shapenet/data/build_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import json 3 | import logging 4 | 5 | import torch 6 | from detectron2.utils import comm 7 | from fvcore.common.file_io import PathManager 8 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Subset 9 | from torch.utils.data.distributed import DistributedSampler 10 | 11 | from .mesh_vox import MeshVoxDataset 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def _identity(x): 17 | return x 18 | 19 | 20 | def build_data_loader( 21 | cfg, 22 | dataset, 23 | split_name, 24 | num_workers=4, 25 | multigpu=False, 26 | shuffle=True, 27 | num_samples=None, 28 | ): 29 | batch_size = cfg.SOLVER.BATCH_SIZE 30 | return_mesh, sample_online, return_id_str = False, False, False 31 | if split_name in ["train_eval", "val"]: 32 | batch_size = cfg.SOLVER.BATCH_SIZE_EVAL 33 | return_mesh = True 34 | sample_online = True 35 | elif split_name == "test": 36 | batch_size = cfg.SOLVER.BATCH_SIZE_EVAL 37 | return_mesh = True 38 | return_id_str = True 39 | 40 | splits_file = cfg.DATASETS.SPLITS_FILE 41 | 42 | with PathManager.open(splits_file, "r") as f: 43 | splits = json.load(f) 44 | if split_name is not None: 45 | if split_name in ["train", "train_eval"]: 46 | split = splits["train"] 47 | else: 48 | split = splits[split_name] 49 | 50 | num_gpus = 1 51 | if multigpu: 52 | num_gpus = comm.get_world_size() 53 | assert batch_size % num_gpus == 0, "num_gpus must divide batch size" 54 | batch_size //= num_gpus 55 | 56 | logger.info('Building dataset for split "%s"' % split_name) 57 | if dataset == "MeshVox": 58 | dset = MeshVoxDataset( 59 | cfg.DATASETS.DATA_DIR, 60 | split=split, 61 | num_samples=cfg.MODEL.MESH_HEAD.GT_NUM_SAMPLES, 62 | voxel_size=cfg.MODEL.VOXEL_HEAD.VOXEL_SIZE, 63 | return_mesh=return_mesh, 64 | sample_online=sample_online, 65 | return_id_str=return_id_str, 66 | ) 67 | collate_fn = MeshVoxDataset.collate_fn 68 | else: 69 | raise ValueError("Dataset %s not registered" % dataset) 70 | 71 | loader_kwargs = { 72 | "batch_size": batch_size, 73 | "collate_fn": collate_fn, 74 | "num_workers": num_workers, 75 | } 76 | 77 | if hasattr(dset, "postprocess"): 78 | postprocess_fn = dset.postprocess 79 | else: 80 | postprocess_fn = _identity 81 | 82 | # In this case we want to subsample num_samples elements from the underlying 83 | # dataset. We can wrap the dataset in a Subset dataset, so we need to tell 84 | # it which indices of the underlying dataset to use. Either take the first 85 | # or a random subset depending on whether shuffling was requested. 86 | if num_samples is not None: 87 | if shuffle: 88 | idx = torch.randperm(len(dset))[:num_samples] 89 | else: 90 | idx = torch.arange(min(num_samples, len(dset))) 91 | dset = Subset(dset, idx) 92 | 93 | # Right now we only do evaluation with a single GPU on the main process, 94 | # so only use a DistributedSampler for the training set. 95 | # TODO: Change this once we do evaluation on multiple GPUs 96 | if multigpu: 97 | assert shuffle, "Cannot sample sequentially with distributed training" 98 | sampler = DistributedSampler(dset) 99 | else: 100 | if shuffle: 101 | sampler = RandomSampler(dset) 102 | else: 103 | sampler = SequentialSampler(dset) 104 | loader_kwargs["sampler"] = sampler 105 | loader = DataLoader(dset, **loader_kwargs) 106 | 107 | # WARNING this is really gross! 
We want to access the underlying 108 | # dataset.postprocess method so we can run it on the main Python process, 109 | # but the dataset might be wrapped in a Subset instance, or may not even 110 | # define a postprocess method at all. To get around this we monkeypatch 111 | # the DataLoader object with the postprocess function we want; this will 112 | # be a bound method of the underlying Dataset, or an identity function. 113 | # Maybe it would be cleaner to subclass DataLoader for this? 114 | loader.postprocess = postprocess_fn 115 | 116 | return loader 117 | -------------------------------------------------------------------------------- /shapenet/data/builtin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | This file registers pre-defined datasets at hard-coded paths 4 | """ 5 | 6 | import os 7 | 8 | # each dataset contains name : (data_dir, splits_file) 9 | _PREDEFINED_SPLITS_SHAPENET = { 10 | "shapenet": ("shapenet/ShapeNetV1processed", "shapenet/pix2mesh_splits_val05.json") 11 | } 12 | 13 | 14 | def register_shapenet(dataset_name, root="datasets"): 15 | if dataset_name not in _PREDEFINED_SPLITS_SHAPENET.keys(): 16 | raise ValueError("%s not registered" % dataset_name) 17 | data_dir = _PREDEFINED_SPLITS_SHAPENET[dataset_name][0] 18 | splits_file = _PREDEFINED_SPLITS_SHAPENET[dataset_name][1] 19 | return os.path.join(root, data_dir), os.path.join(root, splits_file) 20 | -------------------------------------------------------------------------------- /shapenet/data/mesh_vox.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import json 3 | import logging 4 | import os 5 | 6 | import torch 7 | 8 | import torchvision.transforms as T 9 | from fvcore.common.file_io import PathManager 10 | from PIL import Image 11 | from pytorch3d.ops import sample_points_from_meshes 12 | from pytorch3d.structures import Meshes 13 | from shapenet.data.utils import imagenet_preprocess 14 | from shapenet.utils.coords import project_verts, SHAPENET_MAX_ZMAX, SHAPENET_MIN_ZMIN 15 | from torch.utils.data import Dataset 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class MeshVoxDataset(Dataset): 21 | def __init__( 22 | self, 23 | data_dir, 24 | normalize_images=True, 25 | split=None, 26 | return_mesh=False, 27 | voxel_size=32, 28 | num_samples=5000, 29 | sample_online=False, 30 | in_memory=False, 31 | return_id_str=False, 32 | ): 33 | super(MeshVoxDataset, self).__init__() 34 | if not return_mesh and sample_online: 35 | raise ValueError("Cannot sample online without returning mesh") 36 | self.data_dir = data_dir 37 | self.return_mesh = return_mesh 38 | self.voxel_size = voxel_size 39 | self.num_samples = num_samples 40 | self.sample_online = sample_online 41 | self.return_id_str = return_id_str 42 | 43 | self.synset_ids = [] 44 | self.model_ids = [] 45 | self.image_ids = [] 46 | self.mid_to_samples = {} 47 | 48 | transform = [T.ToTensor()] 49 | if normalize_images: 50 | transform.append(imagenet_preprocess()) 51 | self.transform = T.Compose(transform) 52 | 53 | summary_json = os.path.join(data_dir, "summary.json") 54 | with PathManager.open(summary_json, "r") as f: 55 | summary = json.load(f) 56 | for sid in summary: 57 | logger.info("Starting synset %s" % sid) 58 | allowed_mids = None 59 | if split is not None: 60 | if sid not in split: 61 | logger.info("Skipping 
synset %s" % sid) 62 | continue 63 | elif isinstance(split[sid], list): 64 | allowed_mids = set(split[sid]) 65 | elif isinstance(split, dict): 66 | allowed_mids = set(split[sid].keys()) 67 | for mid, num_imgs in summary[sid].items(): 68 | if allowed_mids is not None and mid not in allowed_mids: 69 | continue 70 | allowed_iids = None 71 | if split is not None and isinstance(split[sid], dict): 72 | allowed_iids = set(split[sid][mid]) 73 | if not sample_online and in_memory: 74 | samples_path = os.path.join(data_dir, sid, mid, "samples.pt") 75 | with PathManager.open(samples_path, "rb") as f: 76 | samples = torch.load(f) 77 | self.mid_to_samples[mid] = samples 78 | for iid in range(num_imgs): 79 | if allowed_iids is None or iid in allowed_iids: 80 | self.synset_ids.append(sid) 81 | self.model_ids.append(mid) 82 | self.image_ids.append(iid) 83 | 84 | def __len__(self): 85 | return len(self.synset_ids) 86 | 87 | def __getitem__(self, idx): 88 | sid = self.synset_ids[idx] 89 | mid = self.model_ids[idx] 90 | iid = self.image_ids[idx] 91 | 92 | # Always read metadata for this model; TODO cache in __init__? 93 | metadata_path = os.path.join(self.data_dir, sid, mid, "metadata.pt") 94 | with PathManager.open(metadata_path, "rb") as f: 95 | metadata = torch.load(f) 96 | K = metadata["intrinsic"] 97 | RT = metadata["extrinsics"][iid] 98 | img_path = metadata["image_list"][iid] 99 | img_path = os.path.join(self.data_dir, sid, mid, "images", img_path) 100 | 101 | # Load the image 102 | with PathManager.open(img_path, "rb") as f: 103 | img = Image.open(f).convert("RGB") 104 | img = self.transform(img) 105 | 106 | # Maybe read mesh 107 | verts, faces = None, None 108 | if self.return_mesh: 109 | mesh_path = os.path.join(self.data_dir, sid, mid, "mesh.pt") 110 | with PathManager.open(mesh_path, "rb") as f: 111 | mesh_data = torch.load(f) 112 | verts, faces = mesh_data["verts"], mesh_data["faces"] 113 | verts = project_verts(verts, RT) 114 | 115 | # Maybe use cached samples 116 | points, normals = None, None 117 | if not self.sample_online: 118 | samples = self.mid_to_samples.get(mid, None) 119 | if samples is None: 120 | # They were not cached in memory, so read off disk 121 | samples_path = os.path.join(self.data_dir, sid, mid, "samples.pt") 122 | with PathManager.open(samples_path, "rb") as f: 123 | samples = torch.load(f) 124 | points = samples["points_sampled"] 125 | normals = samples["normals_sampled"] 126 | idx = torch.randperm(points.shape[0])[: self.num_samples] 127 | points, normals = points[idx], normals[idx] 128 | points = project_verts(points, RT) 129 | normals = normals.mm(RT[:3, :3].t()) # Only rotate, don't translate 130 | 131 | voxels, P = None, None 132 | if self.voxel_size > 0: 133 | # Use precomputed voxels if we have them, otherwise return voxel_coords 134 | # and we will compute voxels in postprocess 135 | voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid) 136 | voxel_file = os.path.join(self.data_dir, sid, mid, voxel_file) 137 | if PathManager.isfile(voxel_file): 138 | with PathManager.open(voxel_file, "rb") as f: 139 | voxels = torch.load(f) 140 | else: 141 | voxel_path = os.path.join(self.data_dir, sid, mid, "voxels.pt") 142 | with PathManager.open(voxel_path, "rb") as f: 143 | voxel_data = torch.load(f) 144 | voxels = voxel_data["voxel_coords"] 145 | P = K.mm(RT) 146 | 147 | id_str = "%s-%s-%02d" % (sid, mid, iid) 148 | return img, verts, faces, points, normals, voxels, P, id_str 149 | 150 | def _voxelize(self, voxel_coords, P): 151 | V = self.voxel_size 152 | device = 
voxel_coords.device 153 | voxel_coords = project_verts(voxel_coords, P) 154 | 155 | # In the original coordinate system, the object fits in a unit sphere 156 | # centered at the origin. Thus after transforming by RT, it will fit 157 | # in a unit sphere centered at T = RT[:, 3] = (0, 0, RT[2, 3]). We need 158 | # to figure out what the range of z will be after being further 159 | # transformed by K. We can work this out explicitly. 160 | # z0 = RT[2, 3].item() 161 | # zp, zm = z0 - 0.5, z0 + 0.5 162 | # k22, k23 = K[2, 2].item(), K[2, 3].item() 163 | # k32, k33 = K[3, 2].item(), K[3, 3].item() 164 | # zmin = (zm * k22 + k23) / (zm * k32 + k33) 165 | # zmax = (zp * k22 + k23) / (zp * k32 + k33) 166 | 167 | # Using the actual zmin and zmax of the model is bad because we need them 168 | # to perform the inverse transform, which transform voxels back into world 169 | # space for refinement or evaluation. Instead we use an empirical min and 170 | # max over the dataset; that way it is consistent for all images. 171 | zmin = SHAPENET_MIN_ZMIN 172 | zmax = SHAPENET_MAX_ZMAX 173 | 174 | # Once we know zmin and zmax, we need to adjust the z coordinates so the 175 | # range [zmin, zmax] instead runs from [-1, 1] 176 | m = 2.0 / (zmax - zmin) 177 | b = -2.0 * zmin / (zmax - zmin) - 1 178 | voxel_coords[:, 2].mul_(m).add_(b) 179 | voxel_coords[:, 1].mul_(-1) # Flip y 180 | 181 | # Now voxels are in [-1, 1]^3; map to [0, V-1)^3 182 | voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0) 183 | voxel_coords = voxel_coords.round().to(torch.int64) 184 | valid = (0 <= voxel_coords) * (voxel_coords < V) 185 | valid = valid[:, 0] * valid[:, 1] * valid[:, 2] 186 | x, y, z = voxel_coords.unbind(dim=1) 187 | x, y, z = x[valid], y[valid], z[valid] 188 | voxels = torch.zeros(V, V, V, dtype=torch.int64, device=device) 189 | voxels[z, y, x] = 1 190 | 191 | return voxels 192 | 193 | @staticmethod 194 | def collate_fn(batch): 195 | imgs, verts, faces, points, normals, voxels, Ps, id_strs = zip(*batch) 196 | imgs = torch.stack(imgs, dim=0) 197 | if verts[0] is not None and faces[0] is not None: 198 | # TODO(gkioxari) Meshes should accept tuples 199 | meshes = Meshes(verts=list(verts), faces=list(faces)) 200 | else: 201 | meshes = None 202 | if points[0] is not None and normals[0] is not None: 203 | points = torch.stack(points, dim=0) 204 | normals = torch.stack(normals, dim=0) 205 | else: 206 | points, normals = None, None 207 | if voxels[0] is None: 208 | voxels = None 209 | Ps = None 210 | elif voxels[0].dim() == 2: 211 | # They are voxel coords 212 | Ps = torch.stack(Ps, dim=0) 213 | elif voxels[0].dim() == 3: 214 | # They are actual voxels 215 | voxels = torch.stack(voxels, dim=0) 216 | return imgs, meshes, points, normals, voxels, Ps, id_strs 217 | 218 | def postprocess(self, batch, device=None): 219 | if device is None: 220 | device = torch.device("cuda") 221 | imgs, meshes, points, normals, voxels, Ps, id_strs = batch 222 | imgs = imgs.to(device) 223 | if meshes is not None: 224 | meshes = meshes.to(device) 225 | if points is not None and normals is not None: 226 | points = points.to(device) 227 | normals = normals.to(device) 228 | else: 229 | points, normals = sample_points_from_meshes( 230 | meshes, num_samples=self.num_samples, return_normals=True 231 | ) 232 | if voxels is not None: 233 | if torch.is_tensor(voxels): 234 | # We used cached voxels on disk, just cast and return 235 | voxels = voxels.to(device) 236 | else: 237 | # We got a list of voxel_coords, and need to compute voxels on-the-fly 238 | 
voxel_coords = voxels 239 | Ps = Ps.to(device) 240 | voxels = [] 241 | for i, cur_voxel_coords in enumerate(voxel_coords): 242 | cur_voxel_coords = cur_voxel_coords.to(device) 243 | cur_voxels = self._voxelize(cur_voxel_coords, Ps[i]) 244 | voxels.append(cur_voxels) 245 | voxels = torch.stack(voxels, dim=0) 246 | 247 | if self.return_id_str: 248 | return imgs, meshes, points, normals, voxels, id_strs 249 | else: 250 | return imgs, meshes, points, normals, voxels 251 | -------------------------------------------------------------------------------- /shapenet/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torchvision.transforms as T 3 | 4 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 5 | IMAGENET_STD = [0.229, 0.224, 0.225] 6 | 7 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 8 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 9 | 10 | 11 | def imagenet_preprocess(): 12 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 13 | 14 | 15 | def rescale(x): 16 | lo, hi = x.min(), x.max() 17 | return x.sub(lo).div(hi - lo) 18 | 19 | 20 | def imagenet_deprocess(rescale_image=True): 21 | transforms = [ 22 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), 23 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), 24 | ] 25 | if rescale_image: 26 | transforms.append(rescale) 27 | return T.Compose(transforms) 28 | 29 | 30 | def image_to_numpy(img): 31 | return img.detach().cpu().mul(255).byte().numpy().transpose(1, 2, 0) 32 | -------------------------------------------------------------------------------- /shapenet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .eval import evaluate_split, evaluate_test, evaluate_test_p2m 3 | -------------------------------------------------------------------------------- /shapenet/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import logging 3 | from collections import defaultdict 4 | 5 | import detectron2.utils.comm as comm 6 | import numpy as np 7 | 8 | import shapenet.utils.vis as vis_utils 9 | import torch 10 | from detectron2.evaluation import inference_context 11 | 12 | from meshrcnn.utils.metrics import compare_meshes 13 | from shapenet.data.utils import image_to_numpy, imagenet_deprocess 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | @torch.no_grad() 19 | def evaluate_test(model, data_loader, vis_preds=False): 20 | """ 21 | This function evaluates the model on the dataset defined by data_loader. 22 | The metrics reported are described in Table 2 of our paper. 
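    Concretely, the loop below accumulates per-synset Chamfer-L2, absolute normal
    consistency, and F1 at distance thresholds 0.1, 0.3 and 0.5 (all computed by
    compare_meshes), and reports them with print_instances_class_histogram.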
23 | """ 24 | # Note that all eval runs on main process 25 | assert comm.is_main_process() 26 | deprocess = imagenet_deprocess(rescale_image=False) 27 | device = torch.device("cuda:0") 28 | # evaluation 29 | class_names = { 30 | "02828884": "bench", 31 | "03001627": "chair", 32 | "03636649": "lamp", 33 | "03691459": "speaker", 34 | "04090263": "firearm", 35 | "04379243": "table", 36 | "04530566": "watercraft", 37 | "02691156": "plane", 38 | "02933112": "cabinet", 39 | "02958343": "car", 40 | "03211117": "monitor", 41 | "04256520": "couch", 42 | "04401088": "cellphone", 43 | } 44 | 45 | num_instances = {i: 0 for i in class_names} 46 | chamfer = {i: 0 for i in class_names} 47 | normal = {i: 0 for i in class_names} 48 | f1_01 = {i: 0 for i in class_names} 49 | f1_03 = {i: 0 for i in class_names} 50 | f1_05 = {i: 0 for i in class_names} 51 | 52 | num_batch_evaluated = 0 53 | for batch in data_loader: 54 | batch = data_loader.postprocess(batch, device) 55 | imgs, meshes_gt, _, _, _, id_strs = batch 56 | sids = [id_str.split("-")[0] for id_str in id_strs] 57 | for sid in sids: 58 | num_instances[sid] += 1 59 | 60 | with inference_context(model): 61 | voxel_scores, meshes_pred = model(imgs) 62 | cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt, reduce=False) 63 | cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu() 64 | cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu() 65 | 66 | for i, sid in enumerate(sids): 67 | chamfer[sid] += cur_metrics["Chamfer-L2"][i].item() 68 | normal[sid] += cur_metrics["AbsNormalConsistency"][i].item() 69 | f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item() 70 | f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item() 71 | f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item() 72 | 73 | if vis_preds: 74 | img = image_to_numpy(deprocess(imgs[i])) 75 | vis_utils.visualize_prediction( 76 | id_strs[i], img, meshes_pred[-1][i], "/tmp/output" 77 | ) 78 | 79 | num_batch_evaluated += 1 80 | logger.info( 81 | "Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader)) 82 | ) 83 | 84 | vis_utils.print_instances_class_histogram( 85 | num_instances, 86 | class_names, 87 | { 88 | "chamfer": chamfer, 89 | "normal": normal, 90 | "f1_01": f1_01, 91 | "f1_03": f1_03, 92 | "f1_05": f1_05, 93 | }, 94 | ) 95 | 96 | 97 | @torch.no_grad() 98 | def evaluate_test_p2m(model, data_loader): 99 | """ 100 | This function evaluates the model on the dataset defined by data_loader. 101 | The metrics reported are described in Table 1 of our paper, following previous 102 | reported approaches (like Pixel2Mesh - p2m), where meshes are 103 | rescaled by a factor of 0.57. See the paper for more details. 
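    Worked note on the F1 thresholds: Pixel2Mesh thresholds the squared L2
    distance at tau = 1e-4 and 2e-4, whereas compare_meshes thresholds the plain
    L2 distance, so the equivalent values passed below are sqrt(1e-4) = 0.01 and
    sqrt(2e-4) ~= 0.014142.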
104 | """ 105 | assert comm.is_main_process() 106 | device = torch.device("cuda:0") 107 | # evaluation 108 | class_names = { 109 | "02828884": "bench", 110 | "03001627": "chair", 111 | "03636649": "lamp", 112 | "03691459": "speaker", 113 | "04090263": "firearm", 114 | "04379243": "table", 115 | "04530566": "watercraft", 116 | "02691156": "plane", 117 | "02933112": "cabinet", 118 | "02958343": "car", 119 | "03211117": "monitor", 120 | "04256520": "couch", 121 | "04401088": "cellphone", 122 | } 123 | 124 | num_instances = {i: 0 for i in class_names} 125 | chamfer = {i: 0 for i in class_names} 126 | normal = {i: 0 for i in class_names} 127 | f1_1e_4 = {i: 0 for i in class_names} 128 | f1_2e_4 = {i: 0 for i in class_names} 129 | 130 | num_batch_evaluated = 0 131 | for batch in data_loader: 132 | batch = data_loader.postprocess(batch, device) 133 | imgs, meshes_gt, _, _, _, id_strs = batch 134 | sids = [id_str.split("-")[0] for id_str in id_strs] 135 | for sid in sids: 136 | num_instances[sid] += 1 137 | 138 | with inference_context(model): 139 | voxel_scores, meshes_pred = model(imgs) 140 | # NOTE that for the F1 thresholds we take the square root of 1e-4 & 2e-4 141 | # as `compare_meshes` returns the euclidean distance (L2) of two pointclouds. 142 | # In Pixel2Mesh, the squared L2 (L2^2) is computed instead. 143 | # i.e. (L2^2 < τ) <=> (L2 < sqrt(τ)) 144 | cur_metrics = compare_meshes( 145 | meshes_pred[-1], 146 | meshes_gt, 147 | scale=0.57, 148 | thresholds=[0.01, 0.014142], 149 | reduce=False, 150 | ) 151 | cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu() 152 | cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu() 153 | 154 | for i, sid in enumerate(sids): 155 | chamfer[sid] += cur_metrics["Chamfer-L2"][i].item() 156 | normal[sid] += cur_metrics["AbsNormalConsistency"][i].item() 157 | f1_1e_4[sid] += cur_metrics["F1@%f" % 0.01][i].item() 158 | f1_2e_4[sid] += cur_metrics["F1@%f" % 0.014142][i].item() 159 | 160 | num_batch_evaluated += 1 161 | logger.info( 162 | "Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader)) 163 | ) 164 | 165 | vis_utils.print_instances_class_histogram_p2m( 166 | num_instances, 167 | class_names, 168 | {"chamfer": chamfer, "normal": normal, "f1_1e_4": f1_1e_4, "f1_2e_4": f1_2e_4}, 169 | ) 170 | 171 | 172 | @torch.no_grad() 173 | def evaluate_split( 174 | model, 175 | loader, 176 | max_predictions=-1, 177 | num_predictions_keep=10, 178 | prefix="", 179 | store_predictions=False, 180 | ): 181 | """ 182 | This function is used to report validation performance during training. 
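    Illustrative call (a sketch; val_loader stands for any loader built by
    build_data_loader for the validation split):

        metrics, preds = evaluate_split(model, val_loader, prefix="val_")
        # metrics maps prefixed names such as "val_Chamfer-L2" to scalar means;
        # preds is only populated when store_predictions=True.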
183 | """ 184 | # Note that all eval runs on main process 185 | assert comm.is_main_process() 186 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 187 | model = model.module 188 | 189 | device = torch.device("cuda:0") 190 | num_predictions = 0 191 | num_predictions_kept = 0 192 | predictions = defaultdict(list) 193 | metrics = defaultdict(list) 194 | deprocess = imagenet_deprocess(rescale_image=False) 195 | for batch in loader: 196 | batch = loader.postprocess(batch, device) 197 | imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch 198 | voxel_scores, meshes_pred = model(imgs) 199 | 200 | # Only compute metrics for the final predicted meshes, not intermediates 201 | cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt) 202 | if cur_metrics is None: 203 | continue 204 | for k, v in cur_metrics.items(): 205 | metrics[k].append(v) 206 | 207 | # Store input images and predicted meshes 208 | if store_predictions: 209 | N = imgs.shape[0] 210 | for i in range(N): 211 | if num_predictions_kept >= num_predictions_keep: 212 | break 213 | num_predictions_kept += 1 214 | 215 | img = image_to_numpy(deprocess(imgs[i])) 216 | predictions["%simg_input" % prefix].append(img) 217 | for level, cur_meshes_pred in enumerate(meshes_pred): 218 | verts, faces = cur_meshes_pred.get_mesh(i) 219 | verts_key = "%sverts_pred_%d" % (prefix, level) 220 | faces_key = "%sfaces_pred_%d" % (prefix, level) 221 | predictions[verts_key].append(verts.cpu().numpy()) 222 | predictions[faces_key].append(faces.cpu().numpy()) 223 | 224 | num_predictions += len(meshes_gt) 225 | logger.info("Evaluated %d predictions so far" % num_predictions) 226 | if 0 < max_predictions <= num_predictions: 227 | break 228 | 229 | # Average numeric metrics, and concatenate images 230 | metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()} 231 | if store_predictions: 232 | img_key = "%simg_input" % prefix 233 | predictions[img_key] = np.stack(predictions[img_key], axis=0) 234 | 235 | return metrics, predictions 236 | -------------------------------------------------------------------------------- /shapenet/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .heads import MeshLoss 3 | from .mesh_arch import build_model 4 | -------------------------------------------------------------------------------- /shapenet/modeling/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch.nn as nn 3 | 4 | import torchvision 5 | 6 | 7 | class ResNetBackbone(nn.Module): 8 | def __init__(self, net): 9 | super(ResNetBackbone, self).__init__() 10 | self.stem = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool) 11 | self.stage1 = net.layer1 12 | self.stage2 = net.layer2 13 | self.stage3 = net.layer3 14 | self.stage4 = net.layer4 15 | 16 | def forward(self, imgs): 17 | feats = self.stem(imgs) 18 | conv1 = self.stage1(feats) # 18, 34: 64 19 | conv2 = self.stage2(conv1) 20 | conv3 = self.stage3(conv2) 21 | conv4 = self.stage4(conv3) 22 | 23 | return [conv1, conv2, conv3, conv4] 24 | 25 | 26 | _FEAT_DIMS = { 27 | "resnet18": (64, 128, 256, 512), 28 | "resnet34": (64, 128, 256, 512), 29 | "resnet50": (256, 512, 1024, 2048), 30 | "resnet101": (256, 512, 1024, 2048), 31 | "resnet152": (256, 512, 1024, 2048), 32 | } 33 | 34 | 35 | def build_backbone(name, pretrained=True): 36 | resnets = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] 37 | if name in resnets and name in _FEAT_DIMS: 38 | cnn = getattr(torchvision.models, name)(pretrained=pretrained) 39 | backbone = ResNetBackbone(cnn) 40 | feat_dims = _FEAT_DIMS[name] 41 | return backbone, feat_dims 42 | else: 43 | raise ValueError('Unrecognized backbone type "%s"' % name) 44 | -------------------------------------------------------------------------------- /shapenet/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .mesh_head import MeshRefinementHead 3 | from .mesh_loss import MeshLoss 4 | from .voxel_head import VoxelHead 5 | -------------------------------------------------------------------------------- /shapenet/modeling/heads/mesh_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | from pytorch3d.ops import GraphConv, SubdivideMeshes, vert_align 5 | 6 | from shapenet.utils.coords import project_verts 7 | from torch.nn import functional as F 8 | 9 | 10 | class MeshRefinementHead(nn.Module): 11 | def __init__(self, cfg): 12 | super(MeshRefinementHead, self).__init__() 13 | 14 | # fmt: off 15 | input_channels = cfg.MODEL.MESH_HEAD.COMPUTED_INPUT_CHANNELS 16 | self.num_stages = cfg.MODEL.MESH_HEAD.NUM_STAGES 17 | hidden_dim = cfg.MODEL.MESH_HEAD.GRAPH_CONV_DIM 18 | stage_depth = cfg.MODEL.MESH_HEAD.NUM_GRAPH_CONVS 19 | graph_conv_init = cfg.MODEL.MESH_HEAD.GRAPH_CONV_INIT 20 | # fmt: on 21 | 22 | self.stages = nn.ModuleList() 23 | for i in range(self.num_stages): 24 | vert_feat_dim = 0 if i == 0 else hidden_dim 25 | stage = MeshRefinementStage( 26 | input_channels, 27 | vert_feat_dim, 28 | hidden_dim, 29 | stage_depth, 30 | gconv_init=graph_conv_init, 31 | ) 32 | self.stages.append(stage) 33 | 34 | def forward(self, img_feats, meshes, P=None, subdivide=False): 35 | """ 36 | Args: 37 | img_feats (tensor): Tensor of shape (N, C, H, W) giving image features, 38 | or a list of such tensors. 39 | meshes (Meshes): Meshes class of N meshes 40 | P (tensor): Tensor of shape (N, 4, 4) giving projection matrix to be applied 41 | to vertex positions before vert-align. If None, don't project verts. 
42 | subdivide (bool): Flag whether to subdivice the mesh after refinement 43 | 44 | Returns: 45 | output_meshes (list of Meshes): A list with S Meshes, where S is the 46 | number of refinement stages 47 | """ 48 | output_meshes = [] 49 | vert_feats = None 50 | for i, stage in enumerate(self.stages): 51 | meshes, vert_feats = stage(img_feats, meshes, vert_feats, P) 52 | output_meshes.append(meshes) 53 | if subdivide and i < self.num_stages - 1: 54 | subdivide = SubdivideMeshes() 55 | meshes, vert_feats = subdivide(meshes, feats=vert_feats) 56 | return output_meshes 57 | 58 | 59 | class MeshRefinementStage(nn.Module): 60 | def __init__( 61 | self, img_feat_dim, vert_feat_dim, hidden_dim, stage_depth, gconv_init="normal" 62 | ): 63 | """ 64 | Args: 65 | img_feat_dim (int): Dimension of features we will get from vert_align 66 | vert_feat_dim (int): Dimension of vert_feats we will receive from the 67 | previous stage; can be 0 68 | hidden_dim (int): Output dimension for graph-conv layers 69 | stage_depth (int): Number of graph-conv layers to use 70 | gconv_init (int): Specifies weight initialization for graph-conv layers 71 | """ 72 | super(MeshRefinementStage, self).__init__() 73 | 74 | self.bottleneck = nn.Linear(img_feat_dim, hidden_dim) 75 | 76 | self.vert_offset = nn.Linear(hidden_dim + 3, 3) 77 | 78 | self.gconvs = nn.ModuleList() 79 | for i in range(stage_depth): 80 | if i == 0: 81 | input_dim = hidden_dim + vert_feat_dim + 3 82 | else: 83 | input_dim = hidden_dim + 3 84 | gconv = GraphConv(input_dim, hidden_dim, init=gconv_init, directed=False) 85 | self.gconvs.append(gconv) 86 | 87 | # initialization for bottleneck and vert_offset 88 | nn.init.normal_(self.bottleneck.weight, mean=0.0, std=0.01) 89 | nn.init.constant_(self.bottleneck.bias, 0) 90 | 91 | nn.init.zeros_(self.vert_offset.weight) 92 | nn.init.constant_(self.vert_offset.bias, 0) 93 | 94 | def forward(self, img_feats, meshes, vert_feats=None, P=None): 95 | """ 96 | Args: 97 | img_feats (tensor): Features from the backbone 98 | meshes (Meshes): Initial meshes which will get refined 99 | vert_feats (tensor): Features from the previous refinement stage 100 | P (tensor): Tensor of shape (N, 4, 4) giving projection matrix to be applied 101 | to vertex positions before vert-align. If None, don't project verts. 
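        Returns:
            meshes_out (Meshes): Refined meshes obtained by offsetting the input verts
            vert_feats_nopos (tensor): Per-vertex features (excluding vertex positions)
                that are passed on to the next refinement stage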
102 | """ 103 | # Project verts if we are making predictions in world space 104 | verts_padded_to_packed_idx = meshes.verts_padded_to_packed_idx() 105 | 106 | if P is not None: 107 | vert_pos_padded = project_verts(meshes.verts_padded(), P) 108 | vert_pos_packed = _padded_to_packed( 109 | vert_pos_padded, verts_padded_to_packed_idx 110 | ) 111 | else: 112 | vert_pos_padded = meshes.verts_padded() 113 | vert_pos_packed = meshes.verts_packed() 114 | 115 | # flip y coordinate 116 | device, dtype = vert_pos_padded.device, vert_pos_padded.dtype 117 | factor = torch.tensor([1, -1, 1], device=device, dtype=dtype).view(1, 1, 3) 118 | vert_pos_padded = vert_pos_padded * factor 119 | # Get features from the image 120 | vert_align_feats = vert_align(img_feats, vert_pos_padded) 121 | vert_align_feats = _padded_to_packed( 122 | vert_align_feats, verts_padded_to_packed_idx 123 | ) 124 | vert_align_feats = F.relu(self.bottleneck(vert_align_feats)) 125 | 126 | # Prepare features for first graph conv layer 127 | first_layer_feats = [vert_align_feats, vert_pos_packed] 128 | if vert_feats is not None: 129 | first_layer_feats.append(vert_feats) 130 | vert_feats = torch.cat(first_layer_feats, dim=1) 131 | 132 | # Run graph conv layers 133 | for gconv in self.gconvs: 134 | vert_feats_nopos = F.relu(gconv(vert_feats, meshes.edges_packed())) 135 | vert_feats = torch.cat([vert_feats_nopos, vert_pos_packed], dim=1) 136 | 137 | # Predict a new mesh by offsetting verts 138 | vert_offsets = torch.tanh(self.vert_offset(vert_feats)) 139 | meshes_out = meshes.offset_verts(vert_offsets) 140 | 141 | return meshes_out, vert_feats_nopos 142 | 143 | 144 | def _padded_to_packed(x, idx): 145 | """ 146 | Convert features from padded to packed. 147 | 148 | Args: 149 | x: (N, V, D) 150 | idx: LongTensor of shape (VV,) 151 | 152 | Returns: 153 | feats_packed: (VV, D) 154 | """ 155 | 156 | D = x.shape[-1] 157 | idx = idx.view(-1, 1).expand(-1, D) 158 | x_packed = x.view(-1, D).gather(0, idx) 159 | return x_packed 160 | -------------------------------------------------------------------------------- /shapenet/modeling/heads/mesh_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from pytorch3d.loss import chamfer_distance, mesh_edge_loss 8 | from pytorch3d.ops import sample_points_from_meshes 9 | from pytorch3d.structures import Meshes 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MeshLoss(nn.Module): 15 | def __init__( 16 | self, 17 | chamfer_weight=1.0, 18 | normal_weight=0.0, 19 | edge_weight=0.1, 20 | voxel_weight=0.0, 21 | gt_num_samples=5000, 22 | pred_num_samples=5000, 23 | ): 24 | super(MeshLoss, self).__init__() 25 | self.chamfer_weight = chamfer_weight 26 | self.normal_weight = normal_weight 27 | self.edge_weight = edge_weight 28 | self.gt_num_samples = gt_num_samples 29 | self.pred_num_samples = pred_num_samples 30 | self.voxel_weight = voxel_weight 31 | 32 | self.skip_mesh_loss = False 33 | if chamfer_weight == 0.0 and normal_weight == 0.0 and edge_weight == 0.0: 34 | self.skip_mesh_loss = True 35 | 36 | def forward(self, voxel_scores, meshes_pred, voxels_gt, meshes_gt): 37 | """ 38 | Args: 39 | meshes_pred: Meshes 40 | meshes_gt: Either Meshes, or a tuple (points_gt, normals_gt) 41 | 42 | Returns: 43 | loss (float): Torch scalar giving the total loss, or None if an error occured and 44 | we should skip this loss. TODO use an exception instead? 45 | losses (dict): A dictionary mapping loss names to Torch scalars giving their 46 | (unweighted) values. 47 | """ 48 | # Sample from meshes_gt if we haven't already 49 | if isinstance(meshes_gt, tuple): 50 | points_gt, normals_gt = meshes_gt 51 | else: 52 | points_gt, normals_gt = sample_points_from_meshes( 53 | meshes_gt, num_samples=self.gt_num_samples, return_normals=True 54 | ) 55 | 56 | total_loss = torch.tensor(0.0).to(points_gt) 57 | losses = {} 58 | 59 | if voxel_scores is not None and voxels_gt is not None and self.voxel_weight > 0: 60 | voxels_gt = voxels_gt.float() 61 | voxel_loss = F.binary_cross_entropy_with_logits(voxel_scores, voxels_gt) 62 | total_loss = total_loss + self.voxel_weight * voxel_loss 63 | losses["voxel"] = voxel_loss 64 | 65 | if isinstance(meshes_pred, Meshes): 66 | meshes_pred = [meshes_pred] 67 | elif meshes_pred is None: 68 | meshes_pred = [] 69 | 70 | # Now assume meshes_pred is a list 71 | if not self.skip_mesh_loss: 72 | for i, cur_meshes_pred in enumerate(meshes_pred): 73 | cur_out = self._mesh_loss(cur_meshes_pred, points_gt, normals_gt) 74 | cur_loss, cur_losses = cur_out 75 | if total_loss is None or cur_loss is None: 76 | total_loss = None 77 | else: 78 | total_loss = total_loss + cur_loss / len(meshes_pred) 79 | for k, v in cur_losses.items(): 80 | losses["%s_%d" % (k, i)] = v 81 | 82 | return total_loss, losses 83 | 84 | def _mesh_loss(self, meshes_pred, points_gt, normals_gt): 85 | """ 86 | Args: 87 | meshes_pred: Meshes containing N meshes 88 | points_gt: Tensor of shape NxPx3 89 | normals_gt: Tensor of shape NxPx3 90 | 91 | Returns: 92 | total_loss (float): The sum of all losses specific to meshes 93 | losses (dict): All (unweighted) mesh losses in a dictionary 94 | """ 95 | zero = torch.tensor(0.0).to(meshes_pred.verts_list()[0]) 96 | losses = {"chamfer": zero, "normal": zero, "edge": zero} 97 | points_pred, normals_pred = sample_points_from_meshes( 98 | meshes_pred, num_samples=self.pred_num_samples, return_normals=True 99 | ) 100 | 101 | total_loss = torch.tensor(0.0).to(points_pred) 102 | if points_pred is None or points_gt is None: 103 | # Sampling failed, so return None 104 | total_loss = None 105 | which = 
"predictions" if points_pred is None else "GT" 106 | logger.info("WARNING: Sampling %s failed" % (which)) 107 | return total_loss, losses 108 | 109 | losses = {} 110 | cham_loss, normal_loss = chamfer_distance( 111 | points_pred, points_gt, x_normals=normals_pred, y_normals=normals_gt 112 | ) 113 | 114 | total_loss = total_loss + self.chamfer_weight * cham_loss 115 | total_loss = total_loss + self.normal_weight * normal_loss 116 | losses["chamfer"] = cham_loss 117 | losses["normal"] = normal_loss 118 | 119 | edge_loss = mesh_edge_loss(meshes_pred) 120 | total_loss = total_loss + self.edge_weight * edge_loss 121 | losses["edge"] = edge_loss 122 | 123 | return total_loss, losses 124 | -------------------------------------------------------------------------------- /shapenet/modeling/heads/voxel_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from detectron2.layers import Conv2d, ConvTranspose2d, get_norm 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class VoxelHead(nn.Module): 10 | def __init__(self, cfg): 11 | super(VoxelHead, self).__init__() 12 | 13 | # fmt: off 14 | self.voxel_size = cfg.MODEL.VOXEL_HEAD.VOXEL_SIZE 15 | conv_dims = cfg.MODEL.VOXEL_HEAD.CONV_DIM 16 | num_conv = cfg.MODEL.VOXEL_HEAD.NUM_CONV 17 | input_channels = cfg.MODEL.VOXEL_HEAD.COMPUTED_INPUT_CHANNELS 18 | self.norm = cfg.MODEL.VOXEL_HEAD.NORM 19 | # fmt: on 20 | 21 | assert self.voxel_size % 2 == 0 22 | 23 | self.conv_norm_relus = [] 24 | prev_dim = input_channels 25 | for k in range(num_conv): 26 | conv = Conv2d( 27 | prev_dim, 28 | conv_dims, 29 | kernel_size=3, 30 | stride=1, 31 | padding=1, 32 | bias=not self.norm, 33 | norm=get_norm(self.norm, conv_dims), 34 | activation=F.relu, 35 | ) 36 | self.add_module("voxel_fcn{}".format(k + 1), conv) 37 | self.conv_norm_relus.append(conv) 38 | prev_dim = conv_dims 39 | 40 | self.deconv = ConvTranspose2d( 41 | conv_dims if num_conv > 0 else input_channels, 42 | conv_dims, 43 | kernel_size=2, 44 | stride=2, 45 | padding=0, 46 | ) 47 | 48 | self.predictor = Conv2d( 49 | conv_dims, self.voxel_size, kernel_size=1, stride=1, padding=0 50 | ) 51 | 52 | for layer in self.conv_norm_relus + [self.deconv]: 53 | weight_init.c2_msra_fill(layer) 54 | # use normal distribution initialization for voxel prediction layer 55 | nn.init.normal_(self.predictor.weight, std=0.001) 56 | if self.predictor.bias is not None: 57 | nn.init.constant_(self.predictor.bias, 0) 58 | 59 | def forward(self, x): 60 | V = self.voxel_size 61 | x = F.interpolate(x, size=V // 2, mode="bilinear", align_corners=False) 62 | for layer in self.conv_norm_relus: 63 | x = layer(x) 64 | x = F.relu(self.deconv(x)) 65 | x = self.predictor(x) 66 | return x 67 | -------------------------------------------------------------------------------- /shapenet/modeling/mesh_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | from detectron2.utils.registry import Registry 5 | from pytorch3d.ops import cubify 6 | from pytorch3d.structures import Meshes 7 | from pytorch3d.utils import ico_sphere 8 | 9 | from shapenet.modeling.backbone import build_backbone 10 | from shapenet.modeling.heads import MeshRefinementHead, VoxelHead 11 | from shapenet.utils.coords import get_blender_intrinsic_matrix, voxel_to_world 12 | 13 | MESH_ARCH_REGISTRY = Registry("MESH_ARCH") 14 | 15 | 16 | @MESH_ARCH_REGISTRY.register() 17 | class VoxMeshHead(nn.Module): 18 | def __init__(self, cfg): 19 | super(VoxMeshHead, self).__init__() 20 | 21 | # fmt: off 22 | backbone = cfg.MODEL.BACKBONE 23 | self.cubify_threshold = cfg.MODEL.VOXEL_HEAD.CUBIFY_THRESH 24 | self.voxel_size = cfg.MODEL.VOXEL_HEAD.VOXEL_SIZE 25 | # fmt: on 26 | 27 | self.register_buffer("K", get_blender_intrinsic_matrix()) 28 | # backbone 29 | self.backbone, feat_dims = build_backbone(backbone) 30 | # voxel head 31 | cfg.MODEL.VOXEL_HEAD.COMPUTED_INPUT_CHANNELS = feat_dims[-1] 32 | self.voxel_head = VoxelHead(cfg) 33 | # mesh head 34 | cfg.MODEL.MESH_HEAD.COMPUTED_INPUT_CHANNELS = sum(feat_dims) 35 | self.mesh_head = MeshRefinementHead(cfg) 36 | 37 | def _get_projection_matrix(self, N, device): 38 | return self.K[None].repeat(N, 1, 1).to(device).detach() 39 | 40 | def _dummy_mesh(self, N, device): 41 | verts_batch = torch.randn(N, 4, 3, device=device) 42 | faces = [[0, 1, 2], [0, 2, 3], [0, 3, 1], [1, 3, 2]] 43 | faces = torch.tensor(faces, dtype=torch.int64) 44 | faces_batch = faces.view(1, 4, 3).expand(N, 4, 3).to(device) 45 | return Meshes(verts=verts_batch, faces=faces_batch) 46 | 47 | def cubify(self, voxel_scores): 48 | V = self.voxel_size 49 | N = voxel_scores.shape[0] 50 | voxel_probs = voxel_scores.sigmoid() 51 | active_voxels = voxel_probs > self.cubify_threshold 52 | voxels_per_mesh = (active_voxels.view(N, -1).sum(dim=1)).tolist() 53 | start = V // 4 54 | stop = start + V // 2 55 | for i in range(N): 56 | if voxels_per_mesh[i] == 0: 57 | voxel_probs[i, start:stop, start:stop, start:stop] = 1 58 | meshes = cubify(voxel_probs, self.cubify_threshold) 59 | 60 | meshes = self._add_dummies(meshes) 61 | meshes = voxel_to_world(meshes) 62 | return meshes 63 | 64 | def _add_dummies(self, meshes): 65 | N = len(meshes) 66 | dummies = self._dummy_mesh(N, meshes.device) 67 | verts_list = meshes.verts_list() 68 | faces_list = meshes.faces_list() 69 | for i in range(N): 70 | if faces_list[i].shape[0] == 0: 71 | # print('Adding dummmy mesh at index ', i) 72 | vv, ff = dummies.get_mesh(i) 73 | verts_list[i] = vv 74 | faces_list[i] = ff 75 | return Meshes(verts=verts_list, faces=faces_list) 76 | 77 | def forward(self, imgs, voxel_only=False): 78 | N = imgs.shape[0] 79 | device = imgs.device 80 | 81 | img_feats = self.backbone(imgs) 82 | voxel_scores = self.voxel_head(img_feats[-1]) 83 | P = self._get_projection_matrix(N, device) 84 | 85 | if voxel_only: 86 | dummy_meshes = self._dummy_mesh(N, device) 87 | dummy_refined = self.mesh_head(img_feats, dummy_meshes, P) 88 | return voxel_scores, dummy_refined 89 | 90 | cubified_meshes = self.cubify(voxel_scores) 91 | refined_meshes = self.mesh_head(img_feats, cubified_meshes, P) 92 | return voxel_scores, refined_meshes 93 | 94 | 95 | @MESH_ARCH_REGISTRY.register() 96 | class SphereInitHead(nn.Module): 97 | def __init__(self, cfg): 98 | super(SphereInitHead, self).__init__() 99 | 100 | # fmt: off 101 | backbone = cfg.MODEL.BACKBONE 102 | self.ico_sphere_level = 
cfg.MODEL.MESH_HEAD.ICO_SPHERE_LEVEL 103 | # fmt: on 104 | 105 | self.register_buffer("K", get_blender_intrinsic_matrix()) 106 | # backbone 107 | self.backbone, feat_dims = build_backbone(backbone) 108 | # mesh head 109 | cfg.MODEL.MESH_HEAD.COMPUTED_INPUT_CHANNELS = sum(feat_dims) 110 | self.mesh_head = MeshRefinementHead(cfg) 111 | 112 | def _get_projection_matrix(self, N, device): 113 | return self.K[None].repeat(N, 1, 1).to(device).detach() 114 | 115 | def forward(self, imgs): 116 | N = imgs.shape[0] 117 | device = imgs.device 118 | 119 | img_feats = self.backbone(imgs) 120 | P = self._get_projection_matrix(N, device) 121 | 122 | init_meshes = ico_sphere(self.ico_sphere_level, device).extend(N) 123 | refined_meshes = self.mesh_head(img_feats, init_meshes, P) 124 | return None, refined_meshes 125 | 126 | 127 | @MESH_ARCH_REGISTRY.register() 128 | class Pixel2MeshHead(nn.Module): 129 | def __init__(self, cfg): 130 | super(Pixel2MeshHead, self).__init__() 131 | 132 | # fmt: off 133 | backbone = cfg.MODEL.BACKBONE 134 | self.ico_sphere_level = cfg.MODEL.MESH_HEAD.ICO_SPHERE_LEVEL 135 | # fmt: on 136 | 137 | self.register_buffer("K", get_blender_intrinsic_matrix()) 138 | # backbone 139 | self.backbone, feat_dims = build_backbone(backbone) 140 | # mesh head 141 | cfg.MODEL.MESH_HEAD.COMPUTED_INPUT_CHANNELS = sum(feat_dims) 142 | self.mesh_head = MeshRefinementHead(cfg) 143 | 144 | def _get_projection_matrix(self, N, device): 145 | return self.K[None].repeat(N, 1, 1).to(device).detach() 146 | 147 | def forward(self, imgs): 148 | N = imgs.shape[0] 149 | device = imgs.device 150 | 151 | img_feats = self.backbone(imgs) 152 | P = self._get_projection_matrix(N, device) 153 | 154 | init_meshes = ico_sphere(self.ico_sphere_level, device).extend(N) 155 | refined_meshes = self.mesh_head(img_feats, init_meshes, P, subdivide=True) 156 | return None, refined_meshes 157 | 158 | 159 | def build_model(cfg): 160 | name = cfg.MODEL.MESH_HEAD.NAME 161 | return MESH_ARCH_REGISTRY.get(name)(cfg) 162 | -------------------------------------------------------------------------------- /shapenet/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .build import build_lr_scheduler, build_optimizer 3 | -------------------------------------------------------------------------------- /shapenet/solver/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | from .lr_schedule import ConstantLR, WarmupCosineLR 5 | 6 | 7 | def build_lr_scheduler(cfg, optimizer): 8 | name = cfg.SOLVER.LR_SCHEDULER_NAME 9 | if name == "constant": 10 | return ConstantLR(optimizer) 11 | elif name == "cosine": 12 | return WarmupCosineLR( 13 | optimizer, 14 | total_iters=cfg.SOLVER.COMPUTED_MAX_ITERS, 15 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 16 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 17 | ) 18 | 19 | 20 | def build_optimizer(cfg, model): 21 | # TODO add weight decay?
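# A possible sketch for the TODO above (a suggestion only, not the repo's current behavior):
# weight decay could be read from a hypothetical cfg.SOLVER.WEIGHT_DECAY option and passed
# through, since both torch.optim.SGD and torch.optim.Adam accept a weight_decay keyword:
#     wd = cfg.SOLVER.WEIGHT_DECAY  # assumed config key, e.g. defaulting to 0.0
#     torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
#     torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)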
22 | name = cfg.SOLVER.OPTIMIZER 23 | lr = cfg.SOLVER.BASE_LR 24 | momentum = cfg.SOLVER.MOMENTUM 25 | if name == "sgd": 26 | return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum) 27 | elif name == "adam": 28 | return torch.optim.Adam(model.parameters(), lr=lr) 29 | -------------------------------------------------------------------------------- /shapenet/solver/lr_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | 4 | import torch 5 | 6 | 7 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | total_iters, 12 | warmup_iters=500, 13 | warmup_factor=0.1, 14 | eta_min=0.0, 15 | last_epoch=-1, 16 | warmup_method="cosine", 17 | ): 18 | self.total_iters = total_iters 19 | self.warmup_iters = warmup_iters 20 | self.warmup_factor = warmup_factor 21 | assert warmup_method in ["linear", "cosine"] 22 | self.warmup_method = warmup_method 23 | self.eta_min = eta_min 24 | super().__init__(optimizer, last_epoch) 25 | 26 | def get_lr(self): 27 | if self.last_epoch < self.warmup_iters: 28 | if self.warmup_method == "linear": 29 | alpha = self.last_epoch / self.warmup_iters 30 | lr_factor = self.warmup_factor * (1 - alpha) + alpha 31 | elif self.warmup_method == "cosine": 32 | t = 1.0 + self.last_epoch / self.warmup_iters 33 | cos_factor = (1.0 + math.cos(math.pi * t)) / 2.0 34 | lr_factor = self.warmup_factor + (1.0 - self.warmup_factor) * cos_factor 35 | else: 36 | raise ValueError("Unsupported warmup method") 37 | return [lr_factor * base_lr for base_lr in self.base_lrs] 38 | 39 | num_decay_iters = self.total_iters - self.warmup_iters 40 | t = (self.last_epoch - self.warmup_iters) / num_decay_iters 41 | cos_factor = (1.0 + math.cos(math.pi * t)) / 2.0 42 | lrs = [] 43 | for base_lr in self.base_lrs: 44 | lr = self.eta_min + (base_lr - self.eta_min) * cos_factor 45 | lrs.append(lr) 46 | return lrs 47 | 48 | 49 | class ConstantLR(torch.optim.lr_scheduler._LRScheduler): 50 | def __init__(self, optimizer, last_epoch=-1): 51 | super().__init__(optimizer, last_epoch) 52 | 53 | def get_lr(self): 54 | return [base_lr for base_lr in self.base_lrs] 55 | -------------------------------------------------------------------------------- /shapenet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from . import model_zoo # registers pathhandlers 3 | from .checkpoint import Checkpoint, clean_state_dict 4 | from .defaults import * 5 | from .timing import Timer 6 | -------------------------------------------------------------------------------- /shapenet/utils/binvox_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | 5 | def read_binvox_coords(f, integer_division=True, dtype=torch.float32): 6 | """ 7 | Read a binvox file and return the indices of all nonzero voxels. 8 | 9 | This matches the behavior of binvox_rw.read_as_coord_array 10 | (https://github.com/dimatura/binvox-rw-py/blob/public/binvox_rw.py#L153) 11 | but this implementation uses torch rather than numpy, and is more efficient 12 | due to improved vectorization. 
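    A minimal usage sketch (the filename is illustrative only): the file must be
    opened in binary mode, since the header is parsed as bytes, e.g.

        with open("chair.binvox", "rb") as f:
            coords = read_binvox_coords(f)  # (N, 3) tensor of voxel indices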
13 | 14 | I think that binvox_rw.read_as_coord_array actually has a bug; when converting 15 | linear indices into three-dimensional indices, they use floating-point 16 | division instead of integer division. We can reproduce their incorrect 17 | implementation by passing integer_division=False. 18 | 19 | Args: 20 | f (str): A file pointer to the binvox file to read 21 | integer_division (bool): If False, then match the buggy implementation from binvox_rw 22 | dtype: Datatype of the output tensor. Use float64 to match binvox_rw 23 | 24 | Returns: 25 | coords (tensor): A tensor of shape (N, 3) where N is the number of nonzero voxels, 26 | and coords[i] = (x, y, z) gives the index of the ith nonzero voxel. If the 27 | voxel grid has shape (V, V, V) then we have 0 <= x, y, z < V. 28 | """ 29 | size, translation, scale = _read_binvox_header(f) 30 | storage = torch.ByteStorage.from_buffer(f.read()) 31 | data = torch.tensor([], dtype=torch.uint8) 32 | data.set_(source=storage) 33 | vals, counts = data[::2], data[1::2] 34 | idxs = _compute_idxs_v2(vals, counts) 35 | if not integer_division: 36 | idxs = idxs.to(dtype) 37 | x_idxs = idxs / (size * size) 38 | zy_idxs = idxs % (size * size) 39 | z_idxs = zy_idxs / size 40 | y_idxs = zy_idxs % size 41 | coords = torch.stack([x_idxs, y_idxs, z_idxs], dim=1) 42 | return coords.to(dtype) 43 | 44 | 45 | def _compute_idxs_v1(vals, counts): 46 | """Naive version of index computation with loops""" 47 | idxs = [] 48 | cur = 0 49 | for i in range(vals.shape[0]): 50 | val, count = vals[i].item(), counts[i].item() 51 | if val == 1: 52 | idxs.append(torch.arange(cur, cur + count)) 53 | cur += count 54 | idxs = torch.cat(idxs, dim=0) 55 | return idxs 56 | 57 | 58 | def _compute_idxs_v2(vals, counts): 59 | """Fast vectorized version of index computation""" 60 | # Consider an example where: 61 | # vals = [0, 1, 0, 1, 1] 62 | # counts = [2, 3, 3, 2, 1] 63 | # 64 | # These values of counts and vals mean that the dense binary grid is: 65 | # [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1] 66 | # 67 | # So the nonzero indices we want to return are: 68 | # [2, 3, 4, 8, 9, 10] 69 | 70 | # After the cumsum we will have: 71 | # end_idxs = [2, 5, 8, 10, 11] 72 | end_idxs = counts.cumsum(dim=0) 73 | 74 | # After masking and computing start_idx we have: 75 | # end_idxs = [5, 10, 11] 76 | # counts = [3, 2, 1] 77 | # start_idxs = [2, 8, 10] 78 | mask = vals == 1 79 | end_idxs = end_idxs[mask] 80 | counts = counts[mask].to(end_idxs) 81 | start_idxs = end_idxs - counts 82 | 83 | # We initialize delta as: 84 | # [2, 1, 1, 1, 1, 1] 85 | delta = torch.ones(counts.sum().item(), dtype=torch.int64) 86 | delta[0] = start_idxs[0] 87 | 88 | # We compute pos = [3, 5], val = [3, 0]; then delta is 89 | # [2, 1, 1, 4, 1, 1] 90 | pos = counts.cumsum(dim=0)[:-1] 91 | val = start_idxs[1:] - end_idxs[:-1] 92 | delta[pos] += val 93 | 94 | # A final cumsum gives the idx we want: [2, 3, 4, 8, 9, 10] 95 | idxs = delta.cumsum(dim=0) 96 | return idxs 97 | 98 | 99 | def _read_binvox_header(f): 100 | # First line of the header should be "#binvox 1" 101 | line = f.readline().strip() 102 | if line != b"#binvox 1": 103 | raise ValueError("Invalid header (line 1)") 104 | 105 | # Second line of the header should be "dim [int] [int] [int]" 106 | # and all three int should be the same 107 | line = f.readline().strip() 108 | if not line.startswith(b"dim "): 109 | raise ValueError("Invalid header (line 2)") 110 | dims = line.split(b" ") 111 | try: 112 | dims = [int(d) for d in dims[1:]] 113 | except ValueError: 114 | raise 
ValueError("Invalid header (line 2)") 115 | if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]: 116 | raise ValueError("Invalid header (line 2)") 117 | size = dims[0] 118 | 119 | # Third line of the header should be "translate [float] [float] [float]" 120 | line = f.readline().strip() 121 | if not line.startswith(b"translate "): 122 | raise ValueError("Invalid header (line 3)") 123 | translation = line.split(b" ") 124 | if len(translation) != 4: 125 | raise ValueError("Invalid header (line 3)") 126 | try: 127 | translation = tuple(float(t) for t in translation[1:]) 128 | except ValueError: 129 | raise ValueError("Invalid header (line 3)") 130 | 131 | # Fourth line of the header should be "scale [float]" 132 | line = f.readline().strip() 133 | if not line.startswith(b"scale "): 134 | raise ValueError("Invalid header (line 4)") 135 | line = line.split(b" ") 136 | if not len(line) == 2: 137 | raise ValueError("Invalid header (line 4)") 138 | scale = float(line[1]) 139 | 140 | # Fifth line of the header should be "data" 141 | line = f.readline().strip() 142 | if not line == b"data": 143 | raise ValueError("Invalid header (line 5)") 144 | 145 | return size, translation, scale 146 | -------------------------------------------------------------------------------- /shapenet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import logging 3 | import os 4 | 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class Checkpoint: 11 | # These keys are saved in all checkpoints 12 | KEYS_TO_SAVE = [ 13 | "t", 14 | "epoch", 15 | "metrics", 16 | "metrics_ts", 17 | "data", 18 | "early_stop_metric", 19 | "with_model_path", 20 | "no_model_path", 21 | "restarts", 22 | ] 23 | 24 | # These keys are saved for "big" checkpoints that include the model state 25 | STATE_KEYS = ["latest_states", "latest_states_ts", "best_states", "best_states_ts"] 26 | 27 | def __init__( 28 | self, output_path="checkpoint.pt", early_stop_metric=None, overwrite=False 29 | ): 30 | output_dir, filename = os.path.split(output_path) 31 | filename, ext = os.path.splitext(filename) 32 | self.with_model_path = "%s_with_model%s" % (filename, ext) 33 | self.with_model_path = os.path.join(output_dir, self.with_model_path) 34 | self.no_model_path = "%s_no_model%s" % (filename, ext) 35 | self.no_model_path = os.path.join(output_dir, self.no_model_path) 36 | 37 | self.t = 0 38 | self.epoch = 0 39 | 40 | # Metrics change over time, data doesn't 41 | self.metrics = {} 42 | self.metrics_ts = {} 43 | self.data = {} 44 | 45 | self.latest_states = {} 46 | self.latest_states_ts = {} 47 | self.best_states = {} 48 | self.best_states_ts = {} 49 | self.early_stop_metric = early_stop_metric 50 | 51 | self.restarts = [] 52 | if os.path.isfile(self.with_model_path) and not overwrite: 53 | logger.info('Loading checkpoint from "%s"' % self.with_model_path) 54 | self.from_dict(torch.load(self.with_model_path)) 55 | self.restarts.append(self.t) 56 | 57 | def step(self): 58 | self.t += 1 59 | 60 | def step_epoch(self): 61 | self.epoch += 1 62 | 63 | def store_data(self, k, v): 64 | self.data[k] = v 65 | 66 | def store_metric(self, **kwargs): 67 | for k, v in kwargs.items(): 68 | if k not in self.metrics: 69 | self.metrics[k] = [] 70 | self.metrics_ts[k] = [] 71 | self.metrics[k].append(v) 72 | self.metrics_ts[k].append(self.t) 73 | 74 | def store_state(self, name, state, best=None): 75 | 
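        # The latest state is always recorded here; the best state is only updated
        # when `best` is True. When `best` is None it is derived from the early-stop
        # metric: the state counts as best if that metric has never been logged, or
        # if the value logged at the current step equals its running maximum.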
self.latest_states[name] = state 76 | self.latest_states_ts[name] = self.t 77 | 78 | if best is None: 79 | k = self.early_stop_metric 80 | if k not in self.metrics: 81 | best = True 82 | else: 83 | max_v = max(self.metrics[k]) 84 | last_v = self.metrics[k][-1] 85 | last_t = self.metrics_ts[k][-1] 86 | if self.t == last_t and last_v == max_v: 87 | best = True 88 | else: 89 | best = False 90 | 91 | if best is None: 92 | raise ValueError("Cannot determine whether current state is best") 93 | 94 | if best: 95 | logger.info('Storing new best state for "%s"' % name) 96 | self.best_states[name] = state 97 | self.best_states_ts[name] = self.t 98 | 99 | def to_dict(self, include_states=False): 100 | keys = [k for k in self.KEYS_TO_SAVE] 101 | if include_states: 102 | keys += self.STATE_KEYS 103 | d = {k: getattr(self, k) for k in keys} 104 | return d 105 | 106 | def from_dict(self, d): 107 | for k in d.keys(): 108 | setattr(self, k, d[k]) 109 | 110 | def save(self): 111 | logger.info('Saving checkpoint (with model) to "%s"' % self.with_model_path) 112 | torch.save(self.to_dict(include_states=True), self.with_model_path) 113 | 114 | logger.info('Saving checkpoint (without model) to "%s"' % self.no_model_path) 115 | torch.save(self.to_dict(include_states=False), self.no_model_path) 116 | 117 | 118 | def clean_state_dict(state_dict): 119 | # Ugly hack to clean up the state dict in case we forgot to unpack the 120 | # underlying model from DistributedDataParallel when training 121 | out = {} 122 | for k, v in state_dict.items(): 123 | while k.startswith("module."): 124 | k = k[7:] 125 | out[k] = v 126 | return out 127 | -------------------------------------------------------------------------------- /shapenet/utils/coords.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """Utilities for working with different 3D coordinate systems""" 3 | 4 | import copy 5 | import math 6 | 7 | import torch 8 | from pytorch3d.structures import Meshes 9 | 10 | SHAPENET_MIN_ZMIN = 0.67 11 | SHAPENET_MAX_ZMAX = 0.92 12 | 13 | SHAPENET_AVG_ZMIN = 0.77 14 | SHAPENET_AVG_ZMAX = 0.90 15 | 16 | 17 | def get_blender_intrinsic_matrix(N=None): 18 | """ 19 | This is the (default) matrix that blender uses to map from camera coordinates 20 | to normalized device coordinates. We can extract it from Blender like this: 21 | 22 | import bpy 23 | camera = bpy.data.objects['Camera'] 24 | render = bpy.context.scene.render 25 | K = camera.calc_matrix_camera( 26 | render.resolution_x, 27 | render.resolution_y, 28 | render.pixel_aspect_x, 29 | render.pixel_aspect_y) 30 | """ 31 | K = [ 32 | [2.1875, 0.0, 0.0, 0.0], 33 | [0.0, 2.1875, 0.0, 0.0], 34 | [0.0, 0.0, -1.002002, -0.2002002], 35 | [0.0, 0.0, -1.0, 0.0], 36 | ] 37 | K = torch.tensor(K) 38 | if N is not None: 39 | K = K.view(1, 4, 4).expand(N, 4, 4) 40 | return K 41 | 42 | 43 | def blender_ndc_to_world(verts): 44 | """ 45 | Inverse operation to projecting by the Blender intrinsic operation above.
46 | In other words, the following should hold: 47 | 48 | K = get_blender_intrinsic_matrix() 49 | verts == blender_ndc_to_world(project_verts(verts, K)) 50 | """ 51 | xx, yy, zz = verts.unbind(dim=1) 52 | a1, a2, a3 = 2.1875, 2.1875, -1.002002 53 | b1, b2 = -0.2002002, -1.0 54 | z = b1 / (b2 * zz - a3) 55 | y = (b2 / a2) * (z * yy) 56 | x = (b2 / a1) * (z * xx) 57 | out = torch.stack([x, y, z], dim=1) 58 | return out 59 | 60 | 61 | def voxel_to_world(meshes): 62 | """ 63 | When predicting voxels, we operate in a [-1, 1]^3 coordinate space where the 64 | intrinsic matrix has already been applied, the y-axis has been flipped to 65 | to align with the image plane, and the z-axis has been rescaled so the min/max 66 | z values in the dataset correspond to -1 / 1. This function undoes these 67 | transformations, and projects a Meshes from voxel-space into world space. 68 | 69 | TODO: This projection logic is tightly coupled to the MeshVox Dataset; 70 | they should maybe both be refactored? 71 | 72 | Input: 73 | - meshes: Meshes in voxel coordinate system 74 | 75 | Output: 76 | - meshes: Meshes in world coordinate system 77 | """ 78 | verts = meshes.verts_packed() 79 | x, y, z = verts.unbind(dim=1) 80 | 81 | zmin, zmax = SHAPENET_MIN_ZMIN, SHAPENET_MAX_ZMAX 82 | m = 2.0 / (zmax - zmin) 83 | b = -2.0 * zmin / (zmax - zmin) - 1 84 | 85 | y = -y 86 | z = (z - b) / m 87 | verts = torch.stack([x, y, z], dim=1) 88 | verts = blender_ndc_to_world(verts) 89 | 90 | verts_list = list(verts.split(meshes.num_verts_per_mesh().tolist(), dim=0)) 91 | faces_list = copy.deepcopy(meshes.faces_list()) 92 | meshes_world = Meshes(verts=verts_list, faces=faces_list) 93 | 94 | return meshes_world 95 | 96 | 97 | def compute_extrinsic_matrix(azimuth, elevation, distance): 98 | """ 99 | Compute 4x4 extrinsic matrix that converts from homogenous world coordinates 100 | to homogenous camera coordinates. We assume that the camera is looking at the 101 | origin. 102 | 103 | Inputs: 104 | - azimuth: Rotation about the z-axis, in degrees 105 | - elevation: Rotation above the xy-plane, in degrees 106 | - distance: Distance from the origin 107 | 108 | Returns: 109 | - FloatTensor of shape (4, 4) 110 | """ 111 | azimuth, elevation, distance = (float(azimuth), float(elevation), float(distance)) 112 | az_rad = -math.pi * azimuth / 180.0 113 | el_rad = -math.pi * elevation / 180.0 114 | sa = math.sin(az_rad) 115 | ca = math.cos(az_rad) 116 | se = math.sin(el_rad) 117 | ce = math.cos(el_rad) 118 | R_world2obj = torch.tensor( 119 | [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]] 120 | ) 121 | R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) 122 | R_world2cam = R_obj2cam.mm(R_world2obj) 123 | cam_location = torch.tensor([[distance, 0, 0]]).t() 124 | T_world2cam = -R_obj2cam.mm(cam_location) 125 | RT = torch.cat([R_world2cam, T_world2cam], dim=1) 126 | RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])]) 127 | 128 | # For some reason I cannot fathom, when Blender loads a .obj file it rotates 129 | # the model 90 degrees about the x axis. 
To compensate for this quirk we roll 130 | # that rotation into the extrinsic matrix here 131 | rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 132 | RT = RT.mm(rot.to(RT)) 133 | 134 | return RT 135 | 136 | 137 | def rotate_verts(RT, verts): 138 | """ 139 | Inputs: 140 | - RT: (N, 4, 4) array of extrinsic matrices 141 | - verts: (N, V, 3) array of vertex positions 142 | """ 143 | singleton = False 144 | if RT.dim() == 2: 145 | assert verts.dim() == 2 146 | RT, verts = RT[None], verts[None] 147 | singleton = True 148 | 149 | if isinstance(verts, list): 150 | verts_rot = [] 151 | for i, v in enumerate(verts): 152 | verts_rot.append(rotate_verts(RT[i], v)) 153 | return verts_rot 154 | 155 | R = RT[:, :3, :3] 156 | verts_rot = verts.bmm(R.transpose(1, 2)) 157 | if singleton: 158 | verts_rot = verts_rot[0] 159 | return verts_rot 160 | 161 | 162 | def project_verts(verts, P, eps=1e-1): 163 | """ 164 | Project verticies using a 4x4 transformation matrix 165 | 166 | Inputs: 167 | - verts: FloatTensor of shape (N, V, 3) giving a batch of vertex positions. 168 | - P: FloatTensor of shape (N, 4, 4) giving projection matrices 169 | 170 | Outputs: 171 | - verts_out: FloatTensor of shape (N, V, 3) giving vertex positions (x, y, z) 172 | where verts_out[i] is the result of transforming verts[i] by P[i]. 173 | """ 174 | # Handle unbatched inputs 175 | singleton = False 176 | if verts.dim() == 2: 177 | assert P.dim() == 2 178 | singleton = True 179 | verts, P = verts[None], P[None] 180 | 181 | N, V = verts.shape[0], verts.shape[1] 182 | dtype, device = verts.dtype, verts.device 183 | 184 | # Add an extra row of ones to the world-space coordinates of verts before 185 | # multiplying by the projection matrix. We could avoid this allocation by 186 | # instead multiplying by a 4x3 submatrix of the projectio matrix, then 187 | # adding the remaining 4x1 vector. Not sure whether there will be much 188 | # performance difference between the two. 189 | ones = torch.ones(N, V, 1, dtype=dtype, device=device) 190 | verts_hom = torch.cat([verts, ones], dim=2) 191 | verts_cam_hom = torch.bmm(verts_hom, P.transpose(1, 2)) 192 | 193 | # Avoid division by zero by clamping the absolute value 194 | w = verts_cam_hom[:, :, 3:] 195 | w_sign = w.sign() 196 | w_sign[w == 0] = 1 197 | w = w_sign * w.abs().clamp(min=eps) 198 | 199 | verts_proj = verts_cam_hom[:, :, :3] / w 200 | 201 | if singleton: 202 | return verts_proj[0] 203 | return verts_proj 204 | -------------------------------------------------------------------------------- /shapenet/utils/defaults.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import argparse 4 | import os 5 | 6 | __all__ = ["default_argument_parser"] 7 | 8 | 9 | def default_argument_parser(): 10 | """ 11 | Create a parser. 
12 | 13 | Returns: 14 | argparse.ArgumentParser: 15 | """ 16 | parser = argparse.ArgumentParser(description="ShapeNet Training") 17 | parser.add_argument( 18 | "--config-file", default="", metavar="FILE", help="path to config file" 19 | ) 20 | parser.add_argument( 21 | "--resume", 22 | action="store_true", 23 | help="whether to attempt to resume from the checkpoint directory", 24 | ) 25 | parser.add_argument( 26 | "--eval-only", action="store_true", help="perform evaluation only" 27 | ) 28 | parser.add_argument( 29 | "--eval-p2m", action="store_true", help="pix2mesh evaluation mode" 30 | ) 31 | parser.add_argument( 32 | "--no-color", action="store_true", help="disable colorful logging" 33 | ) 34 | parser.add_argument( 35 | "--num-gpus", type=int, default=1, help="number of gpus per machine" 36 | ) 37 | parser.add_argument("--num-machines", type=int, default=1) 38 | parser.add_argument( 39 | "--machine-rank", 40 | type=int, 41 | default=0, 42 | help="the rank of this machine (unique per machine)", 43 | ) 44 | port = 2**15 + 2**14 + hash(os.getuid()) % 2**14 45 | parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port)) 46 | parser.add_argument( 47 | "--data-dir", 48 | default="./datasets/shapenet/ShapeNetV1processed.zip", 49 | help="Path to the ShapeNet zipped data from preprocessing - used ONLY when copying data", 50 | ) 51 | parser.add_argument("--tmp-dir", default="/tmp") 52 | parser.add_argument("--copy-data", action="store_true", help="copy data") 53 | parser.add_argument( 54 | "--torch-home", 55 | default="$XDG_CACHE_HOME/torch", 56 | help="Path to torchvision model zoo", 57 | ) 58 | parser.add_argument( 59 | "opts", 60 | help="Modify config options using the command-line", 61 | default=None, 62 | nargs=argparse.REMAINDER, 63 | ) 64 | return parser.parse_args() 65 | -------------------------------------------------------------------------------- /shapenet/utils/model_zoo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from fvcore.common.file_io import PathHandler, PathManager 4 | 5 | __all__ = ["ShapenetHandler"] 6 | 7 | 8 | class ShapenetHandler(PathHandler): 9 | """ 10 | Resolve anything that's in Mesh R-CNN's model zoo. 11 | """ 12 | 13 | PREFIX = "shapenet://" 14 | SHAPENET_PREFIX = "https://dl.fbaipublicfiles.com/meshrcnn/shapenet/" 15 | 16 | def _get_supported_prefixes(self): 17 | return [self.PREFIX] 18 | 19 | def _get_local_path(self, path): 20 | name = path[len(self.PREFIX) :] 21 | return PathManager.get_local_path(self.SHAPENET_PREFIX + name) 22 | 23 | def _open(self, path, mode="r", **kwargs): 24 | return PathManager.open(self._get_local_path(path), mode, **kwargs) 25 | 26 | 27 | PathManager.register_handler(ShapenetHandler()) 28 | -------------------------------------------------------------------------------- /shapenet/utils/timing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """Utilities for timing GPU operations in PyTorch.""" 3 | 4 | import logging 5 | import time 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import torch 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def time_backward(f, x, key=None, timing=None): 15 | """ 16 | Utility function for timing the backward pass. 
Suppose we have the operation 17 | y = f(x) and we want to know how long the backward pass will take. We can 18 | then write: 19 | 20 | y = time_backward(f, x, 'f') 21 | 22 | This will set up backward hooks in the graph that start a Timer once grad_y 23 | has been computed, and stop the Timer when grad_x has been computed. 24 | """ 25 | if callable(f): 26 | y = f(x) 27 | else: 28 | y = f 29 | timer = Timer(key=key, timing=timing) 30 | 31 | def y_hook(_grad_y): 32 | timer.start() 33 | 34 | def x_hook(_grad_x): 35 | timer.stop() 36 | 37 | if y.requires_grad and x.requires_grad: 38 | y.register_hook(y_hook) 39 | x.register_hook(x_hook) 40 | return y 41 | 42 | 43 | def timeit(f, x, key=None, timing=None): 44 | """ 45 | Utility function that times both the forward and backward pass of y = f(x). 46 | """ 47 | f_key = "%s-forward" % key 48 | b_key = "%s-backward" % key 49 | with Timer(f_key, timing): 50 | y = time_backward(f, x, b_key, timing) 51 | return y 52 | 53 | 54 | class Timer: 55 | """ 56 | A context manager for timing nested chunks of code, like this: 57 | 58 | with Timer('my_loop'): 59 | out = 0 60 | for x in range(100): 61 | with Timer('my_op'): 62 | out += f(x) 63 | 64 | If you set Timer.timing = True then this will print mean and std dev timing 65 | for both my_loop and my_op. 66 | """ 67 | 68 | _indent_level = 0 69 | timing = False 70 | _times = defaultdict(list) 71 | 72 | @classmethod 73 | def _adjust_indent(cls, val): 74 | cls._indent_level += val 75 | 76 | @classmethod 77 | def _record_time(cls, key, val): 78 | cls._times[key].append(val) 79 | 80 | @classmethod 81 | def get_stats(cls, key): 82 | times = cls._times[key] 83 | return np.mean(times), np.std(times) 84 | 85 | @classmethod 86 | def reset(cls): 87 | cls._times = defaultdict(list) 88 | 89 | def __init__(self, key=None, timing=None): 90 | self._key = key 91 | self._local_timing = timing 92 | 93 | def _should_time(self): 94 | if self._local_timing is not None: 95 | return self._local_timing 96 | return self.timing 97 | 98 | def start(self): 99 | if self._should_time(): 100 | self._adjust_indent(1) 101 | torch.cuda.synchronize() 102 | self._t0 = time.time() 103 | 104 | def stop(self): 105 | if self._should_time(): 106 | torch.cuda.synchronize() 107 | self._t1 = time.time() 108 | duration_ms = (self._t1 - self._t0) * 1000.0 109 | key = self._key 110 | space = " " * self._indent_level 111 | if key is not None: 112 | self._record_time(key, duration_ms) 113 | mean, std = self.get_stats(key) 114 | msg = "[timeit]%s%s: %.4f ms (mean=%.4f ms, std=%.4f ms)" % ( 115 | space, 116 | key, 117 | duration_ms, 118 | mean, 119 | std, 120 | ) 121 | else: 122 | msg = "[timeit]%s%.4f" % (space, duration_ms) 123 | logger.info(msg) 124 | self._adjust_indent(-1) 125 | 126 | def tick(self): 127 | self.stop() 128 | self.start() 129 | 130 | def __enter__(self): 131 | self.start() 132 | return self 133 | 134 | def __exit__(self, exc_type, value, traceback): 135 | self.stop() 136 | -------------------------------------------------------------------------------- /shapenet/utils/vis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import itertools 3 | import os 4 | 5 | from pytorch3d.io import save_obj 6 | 7 | from tabulate import tabulate 8 | from termcolor import colored 9 | 10 | try: 11 | import cv2 # noqa 12 | except ImportError: 13 | # If opencv is not available, everything else should still run 14 | pass 15 | 16 | 17 | def print_instances_class_histogram(num_instances, class_names, results): 18 | """ 19 | Args: 20 | num_instances (list): list of dataset dicts. 21 | """ 22 | num_classes = len(class_names) 23 | N_COLS = 7 24 | data = list( 25 | itertools.chain( 26 | *[ 27 | [ 28 | class_names[id], 29 | v, 30 | results["chamfer"][id] / v, 31 | results["normal"][id] / v, 32 | results["f1_01"][id] / v, 33 | results["f1_03"][id] / v, 34 | results["f1_05"][id] / v, 35 | ] 36 | for id, v in num_instances.items() 37 | ] 38 | ) 39 | ) 40 | total_num_instances = sum(data[1::7]) 41 | mean_chamfer = sum(data[2::7]) / num_classes 42 | mean_normal = sum(data[3::7]) / num_classes 43 | mean_f1_01 = sum(data[4::7]) / num_classes 44 | mean_f1_03 = sum(data[5::7]) / num_classes 45 | mean_f1_05 = sum(data[6::7]) / num_classes 46 | data.extend([None] * (N_COLS - (len(data) % N_COLS))) 47 | data.extend( 48 | [ 49 | "total", 50 | total_num_instances, 51 | mean_chamfer, 52 | mean_normal, 53 | mean_f1_01, 54 | mean_f1_03, 55 | mean_f1_05, 56 | ] 57 | ) 58 | data.extend([None] * (N_COLS - (len(data) % N_COLS))) 59 | data.extend( 60 | [ 61 | "per-instance", 62 | total_num_instances, 63 | sum(results["chamfer"].values()) / total_num_instances, 64 | sum(results["normal"].values()) / total_num_instances, 65 | sum(results["f1_01"].values()) / total_num_instances, 66 | sum(results["f1_03"].values()) / total_num_instances, 67 | sum(results["f1_05"].values()) / total_num_instances, 68 | ] 69 | ) 70 | data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) 71 | table = tabulate( 72 | data, 73 | headers=[ 74 | "category", 75 | "#instances", 76 | "chamfer", 77 | "normal", 78 | "F1(0.1)", 79 | "F1(0.3)", 80 | "F1(0.5)", 81 | ] 82 | * (N_COLS // 2), 83 | tablefmt="pipe", 84 | numalign="left", 85 | stralign="center", 86 | ) 87 | print( 88 | "Distribution of testing instances among all {} categories:\n".format( 89 | num_classes 90 | ) 91 | + colored(table, "cyan") 92 | ) 93 | 94 | 95 | def print_instances_class_histogram_p2m(num_instances, class_names, results): 96 | """ 97 | Args: 98 | num_instances (list): list of dataset dicts. 
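        class_names: mapping from category id to the category's readable name.
        results (dict): one dict per metric ("chamfer", "normal", "f1_1e_4",
            "f1_2e_4"), each mapping category id to that metric summed over the
            category's instances; the function divides by the counts in
            num_instances to report per-category means.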
99 | """ 100 | num_classes = len(class_names) 101 | N_COLS = 6 102 | data = list( 103 | itertools.chain( 104 | *[ 105 | [ 106 | class_names[id], 107 | v, 108 | results["chamfer"][id] / v, 109 | results["normal"][id] / v, 110 | results["f1_1e_4"][id] / v, 111 | results["f1_2e_4"][id] / v, 112 | ] 113 | for id, v in num_instances.items() 114 | ] 115 | ) 116 | ) 117 | total_num_instances = sum(data[1::6]) 118 | mean_chamfer = sum(data[2::6]) / num_classes 119 | mean_normal = sum(data[3::6]) / num_classes 120 | mean_f1_1e_4 = sum(data[4::6]) / num_classes 121 | mean_f1_2e_4 = sum(data[5::6]) / num_classes 122 | data.extend([None] * (N_COLS - (len(data) % N_COLS))) 123 | data.extend( 124 | [ 125 | "total", 126 | total_num_instances, 127 | mean_chamfer, 128 | mean_normal, 129 | mean_f1_1e_4, 130 | mean_f1_2e_4, 131 | ] 132 | ) 133 | data.extend([None] * (N_COLS - (len(data) % N_COLS))) 134 | data.extend( 135 | [ 136 | "per-instance", 137 | total_num_instances, 138 | sum(results["chamfer"].values()) / total_num_instances, 139 | sum(results["normal"].values()) / total_num_instances, 140 | sum(results["f1_1e_4"].values()) / total_num_instances, 141 | sum(results["f1_2e_4"].values()) / total_num_instances, 142 | ] 143 | ) 144 | data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) 145 | table = tabulate( 146 | data, 147 | headers=[ 148 | "category", 149 | "#instances", 150 | "chamfer", 151 | "normal", 152 | "F1(0.0001)", 153 | "F1(0.0002)", 154 | ] 155 | * (N_COLS // 2), 156 | tablefmt="pipe", 157 | numalign="left", 158 | stralign="center", 159 | ) 160 | print( 161 | "Distribution of testing instances among all {} categories:\n".format( 162 | num_classes 163 | ) 164 | + colored(table, "cyan") 165 | ) 166 | 167 | 168 | def visualize_prediction(image_id, img, mesh, output_dir): 169 | # create vis_dir 170 | output_dir = os.path.join(output_dir, "results_shapenet") 171 | os.makedirs(output_dir, exist_ok=True) 172 | 173 | save_img = os.path.join(output_dir, "%s.png" % (image_id)) 174 | cv2.imwrite(save_img, img[:, :, ::-1]) 175 | 176 | save_mesh = os.path.join(output_dir, "%s.obj" % (image_id)) 177 | verts, faces = mesh.get_mesh_verts_faces(0) 178 | save_obj(save_mesh, verts, faces) 179 | -------------------------------------------------------------------------------- /tools/convert_cocomodel_for_init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | r""" 4 | Convert coco model for init. 
Remove class specific heads, optimizer and scheduler 5 | so that this model can be used for pre-training 6 | """ 7 | 8 | import argparse 9 | 10 | import torch 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser(description="Convert models for init") 15 | parser.add_argument( 16 | "--model-file", 17 | default="", 18 | dest="modelfile", 19 | metavar="FILE", 20 | help="path to model", 21 | type=str, 22 | ) 23 | parser.add_argument( 24 | "--output-file", 25 | default="", 26 | dest="outputfile", 27 | metavar="FILE", 28 | help="path to model", 29 | type=str, 30 | ) 31 | 32 | args = parser.parse_args() 33 | 34 | model = torch.load(args.modelfile) 35 | # pop the optimizer 36 | model.pop("optimizer") 37 | # pop the scheduler 38 | model.pop("scheduler") 39 | # pop the iteration 40 | model.pop("iteration") 41 | # pop the class specific weights from the coco pretrained model 42 | heads = [ 43 | "roi_heads.box_predictor.cls_score.weight", 44 | "roi_heads.box_predictor.cls_score.bias", 45 | "roi_heads.box_predictor.bbox_pred.weight", 46 | "roi_heads.box_predictor.bbox_pred.bias", 47 | "roi_heads.mask_head.predictor.weight", 48 | "roi_heads.mask_head.predictor.bias", 49 | ] 50 | for head in heads: 51 | model["model"].pop(head) 52 | torch.save(model, args.outputfile) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /tools/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | from collections import OrderedDict 5 | 6 | import detectron2.utils.comm as comm 7 | 8 | # required so that .register() calls are executed in module scope 9 | import meshrcnn.modeling # noqa 10 | from detectron2.checkpoint import DetectionCheckpointer 11 | from detectron2.config import get_cfg 12 | from detectron2.data import ( 13 | build_detection_test_loader, 14 | build_detection_train_loader, 15 | MetadataCatalog, 16 | ) 17 | from detectron2.engine import ( 18 | default_argument_parser, 19 | default_setup, 20 | DefaultTrainer, 21 | launch, 22 | ) 23 | from detectron2.evaluation import inference_on_dataset 24 | from detectron2.utils.logger import setup_logger 25 | from meshrcnn.config import get_meshrcnn_cfg_defaults 26 | from meshrcnn.data import MeshRCNNMapper 27 | from meshrcnn.evaluation import Pix3DEvaluator 28 | 29 | 30 | class Trainer(DefaultTrainer): 31 | @classmethod 32 | def build_evaluator(cls, cfg, dataset_name): 33 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 34 | if evaluator_type == "pix3d": 35 | return Pix3DEvaluator(dataset_name, cfg, True) 36 | else: 37 | raise ValueError("The evaluator type is wrong") 38 | 39 | @classmethod 40 | def build_test_loader(cls, cfg, dataset_name): 41 | return build_detection_test_loader( 42 | cfg, 43 | dataset_name, 44 | mapper=MeshRCNNMapper(cfg, False, dataset_names=(dataset_name,)), 45 | ) 46 | 47 | @classmethod 48 | def build_train_loader(cls, cfg): 49 | dataset_names = cfg.DATASETS.TRAIN 50 | return build_detection_train_loader( 51 | cfg, mapper=MeshRCNNMapper(cfg, True, dataset_names=dataset_names) 52 | ) 53 | 54 | @classmethod 55 | def test(cls, cfg, model): 56 | """ 57 | Args: 58 | cfg (CfgNode): 59 | model (nn.Module): 60 | 61 | Returns: 62 | dict: a dict of result metrics 63 | """ 64 | results = OrderedDict() 65 | for dataset_name in cfg.DATASETS.TEST: 66 | data_loader = 
cls.build_test_loader(cfg, dataset_name) 67 | evaluator = cls.build_evaluator(cfg, dataset_name) 68 | results_i = inference_on_dataset(model, data_loader, evaluator) 69 | results[dataset_name] = results_i 70 | if comm.is_main_process(): 71 | assert isinstance( 72 | results_i, dict 73 | ), "Evaluator must return a dict on the main process. Got {} instead.".format( 74 | results_i 75 | ) 76 | return results 77 | 78 | 79 | def setup(args): 80 | cfg = get_cfg() 81 | get_meshrcnn_cfg_defaults(cfg) 82 | cfg.merge_from_file(args.config_file) 83 | cfg.merge_from_list(args.opts) 84 | cfg.freeze() 85 | default_setup(cfg, args) 86 | # Setup logger for "meshrcnn" module 87 | setup_logger( 88 | output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="meshrcnn" 89 | ) 90 | return cfg 91 | 92 | 93 | def main(args): 94 | cfg = setup(args) 95 | 96 | if args.eval_only: 97 | model = Trainer.build_model(cfg) 98 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 99 | cfg.MODEL.WEIGHTS, resume=args.resume 100 | ) 101 | res = Trainer.test(cfg, model) 102 | return res 103 | 104 | trainer = Trainer(cfg) 105 | trainer.resume_or_load(resume=args.resume) 106 | return trainer.train() 107 | 108 | 109 | def invoke_main() -> None: 110 | args = default_argument_parser().parse_args() 111 | print("Command Line Args:", args) 112 | launch( 113 | main, 114 | args.num_gpus, 115 | num_machines=args.num_machines, 116 | machine_rank=args.machine_rank, 117 | dist_url=args.dist_url, 118 | args=(args,), 119 | ) 120 | 121 | 122 | if __name__ == "__main__": 123 | invoke_main() # pragma: no cover 124 | --------------------------------------------------------------------------------