├── LICENSE ├── README.md ├── output ├── 000008.png ├── 200.ckpt ├── scene.gif └── scene_rotate.gif ├── scripts ├── detectron2 │ ├── LICENSE │ ├── configs │ │ └── Base-RCNN-FPN.yaml │ └── projects │ │ └── PointRend │ │ ├── configs │ │ ├── InstanceSegmentation │ │ │ ├── Base-PointRend-RCNN-FPN.yaml │ │ │ ├── pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml │ │ │ ├── pointrend_rcnn_R_50_FPN_1x_coco.yaml │ │ │ ├── pointrend_rcnn_R_50_FPN_3x_coco.yaml │ │ │ └── pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml │ │ └── SemanticSegmentation │ │ │ ├── Base-PointRend-Semantic-FPN.yaml │ │ │ └── pointrend_semantic_R_101_FPN_1x_cityscapes.yaml │ │ └── point_rend │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── coarse_mask_head.cpython-38.pyc │ │ ├── color_augmentation.cpython-38.pyc │ │ ├── config.cpython-37.pyc │ │ ├── config.cpython-38.pyc │ │ ├── point_features.cpython-38.pyc │ │ ├── point_head.cpython-38.pyc │ │ ├── roi_heads.cpython-38.pyc │ │ └── semantic_seg.cpython-38.pyc │ │ ├── coarse_mask_head.py │ │ ├── color_augmentation.py │ │ ├── config.py │ │ ├── point_features.py │ │ ├── point_head.py │ │ ├── roi_heads.py │ │ └── semantic_seg.py └── preproc.py └── src ├── __pycache__ ├── encoder.cpython-37.pyc ├── kitti.cpython-37.pyc ├── kitti_util.cpython-37.pyc ├── loss.cpython-37.pyc ├── models.cpython-37.pyc ├── nerf.cpython-37.pyc └── renderer.cpython-37.pyc ├── kitti.py ├── kitti_util.py ├── loss.py ├── models.py ├── renderer.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoRF (unofficial) 2 | This is an unofficial implementation of "AutoRF: Learning 3D Object Radiance Fields from Single View Observations", which performs implicit neural reconstruction, manipulation and scene composition for 3D objects. In this repo, we use the KITTI dataset. 3 | 4 | drawing 5 | drawing 6 | drawing 7 | 8 | 9 |
10 | Dependencies (click to expand) 11 | 12 | ## Dependencies 13 | - pytorch==1.10.1 14 | - matplotlib 15 | - numpy 16 | - imageio 17 |
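This repo does not ship an install script, so the following is only a minimal environment sketch under common assumptions: it presumes a conda/pip workflow and that PointRend is obtained through a standard detectron2 install; the exact PyTorch/CUDA build and detectron2 version you need may differ on your machine.

```
# Hypothetical setup, not part of this repo; pick the CUDA/PyTorch build for your system.
conda create -n autorf python=3.8
conda activate autorf
pip install torch==1.10.1 matplotlib numpy imageio
# The vendored point_rend code (color_augmentation.py) imports cv2, so OpenCV may also be needed.
pip install opencv-python
# PointRend lives in detectron2; one common way to install it is from source:
pip install 'git+https://github.com/facebookresearch/detectron2.git'
```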
18 | 19 | ## Quick Start 20 | 21 | Download the KITTI data; here we only use the image data. 22 | ```plain 23 | └── DATA_DIR 24 | ├── training <-- training data 25 | | ├── image_2 26 | | ├── label_2 27 | | ├── calib 28 | ``` 29 | Run the preprocessing script, which produces instance masks using a pretrained PointRend model. 30 | ``` 31 | python scripts/preproc.py 32 | ``` 33 | After this, you will have a directory that contains the image patch, mask and 3D annotation of each instance. 34 | ```plain 35 | └── DATA_DIR 36 | ├── training 37 | | ├── nerf 38 | | ├── 0000008_01_patch.png 39 | | ├── 0000008_01_mask.png 40 | | ├── 0000008_01_label.png 41 | ``` 42 | 43 | Run the following script to train a NeRF model: 44 | 45 | ``` 46 | python src/train.py 47 | ``` 48 | 49 | After training for several iterations (as many as you find sufficient), you can find the checkpoint file in the `output` folder, and then perform scene rendering by running 50 | 51 | ``` 52 | python src/train.py --demo 53 | ``` 54 | 55 | 56 | ## Notice 57 | 58 | You can adjust the manipulation function (in kitti.py) yourself; here I only provide camera pushing/pulling and instance rotation. 59 | 60 | 61 | -------------------------------------------------------------------------------- /output/000008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/output/000008.png -------------------------------------------------------------------------------- /output/200.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/output/200.ckpt -------------------------------------------------------------------------------- /output/scene.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/output/scene.gif -------------------------------------------------------------------------------- /output/scene_rotate.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/output/scene_rotate.gif -------------------------------------------------------------------------------- /scripts/detectron2/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /scripts/detectron2/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../../../configs/Base-RCNN-FPN.yaml" 2 | MODEL: 3 | ROI_HEADS: 4 | NAME: "PointRendROIHeads" 5 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 6 | ROI_BOX_HEAD: 7 | TRAIN_ON_PRED_BOXES: True 8 | ROI_MASK_HEAD: 9 | NAME: "CoarseMaskHead" 10 | FC_DIM: 1024 11 | NUM_FC: 2 12 | OUTPUT_SIDE_RESOLUTION: 7 13 | IN_FEATURES: ["p2"] 14 | POINT_HEAD_ON: True 15 | POINT_HEAD: 16 | FC_DIM: 256 17 | NUM_FC: 3 18 | IN_FEATURES: ["p2"] 19 | INPUT: 20 | # PointRend for instance segmenation does not work with "polygon" mask_format. 
21 | MASK_FORMAT: "bitmask" 22 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | ROI_HEADS: 8 | NUM_CLASSES: 8 9 | POINT_HEAD: 10 | NUM_CLASSES: 8 11 | DATASETS: 12 | TEST: ("cityscapes_fine_instance_seg_val",) 13 | TRAIN: ("cityscapes_fine_instance_seg_train",) 14 | SOLVER: 15 | BASE_LR: 0.01 16 | IMS_PER_BATCH: 8 17 | MAX_ITER: 24000 18 | STEPS: (18000,) 19 | INPUT: 20 | MAX_SIZE_TEST: 2048 21 | MAX_SIZE_TRAIN: 2048 22 | MIN_SIZE_TEST: 1024 23 | MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960, 992, 1024) 24 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_1x_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | # To add COCO AP evaluation against the higher-quality LVIS annotations. 8 | # DATASETS: 9 | # TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") 10 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | SOLVER: 8 | STEPS: (210000, 250000) 9 | MAX_ITER: 270000 10 | # To add COCO AP evaluation against the higher-quality LVIS annotations. 
11 | # DATASETS: 12 | # TEST: ("coco_2017_val", "lvis_v0.5_val_cocofied") 13 | 14 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | MASK_ON: True 4 | WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" 5 | PIXEL_STD: [57.375, 57.120, 58.395] 6 | RESNETS: 7 | STRIDE_IN_1X1: False # this is a C2 model 8 | NUM_GROUPS: 32 9 | WIDTH_PER_GROUP: 8 10 | DEPTH: 101 11 | SOLVER: 12 | STEPS: (210000, 250000) 13 | MAX_ITER: 270000 -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/SemanticSegmentation/Base-PointRend-Semantic-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../../../../configs/Base-RCNN-FPN.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: "SemanticSegmentor" 4 | BACKBONE: 5 | FREEZE_AT: 0 6 | SEM_SEG_HEAD: 7 | NAME: "PointRendSemSegHead" 8 | POINT_HEAD: 9 | NUM_CLASSES: 54 10 | FC_DIM: 256 11 | NUM_FC: 3 12 | IN_FEATURES: ["p2"] 13 | TRAIN_NUM_POINTS: 1024 14 | SUBDIVISION_STEPS: 2 15 | SUBDIVISION_NUM_POINTS: 8192 16 | COARSE_SEM_SEG_HEAD_NAME: "SemSegFPNHead" 17 | COARSE_PRED_EACH_LAYER: False 18 | DATASETS: 19 | TRAIN: ("coco_2017_train_panoptic_stuffonly",) 20 | TEST: ("coco_2017_val_panoptic_stuffonly",) 21 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/configs/SemanticSegmentation/pointrend_semantic_R_101_FPN_1x_cityscapes.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-Semantic-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-101.pkl 4 | RESNETS: 5 | DEPTH: 101 6 | SEM_SEG_HEAD: 7 | NUM_CLASSES: 19 8 | POINT_HEAD: 9 | NUM_CLASSES: 19 10 | TRAIN_NUM_POINTS: 2048 11 | SUBDIVISION_NUM_POINTS: 8192 12 | DATASETS: 13 | TRAIN: ("cityscapes_fine_sem_seg_train",) 14 | TEST: ("cityscapes_fine_sem_seg_val",) 15 | SOLVER: 16 | BASE_LR: 0.01 17 | STEPS: (40000, 55000) 18 | MAX_ITER: 65000 19 | IMS_PER_BATCH: 32 20 | INPUT: 21 | MIN_SIZE_TRAIN: (512, 768, 1024, 1280, 1536, 1792, 2048) 22 | MIN_SIZE_TRAIN_SAMPLING: "choice" 23 | MIN_SIZE_TEST: 1024 24 | MAX_SIZE_TRAIN: 4096 25 | MAX_SIZE_TEST: 2048 26 | CROP: 27 | ENABLED: True 28 | TYPE: "absolute" 29 | SIZE: (512, 1024) 30 | SINGLE_CATEGORY_MAX_AREA: 0.75 31 | COLOR_AUG_SSD: True 32 | DATALOADER: 33 | NUM_WORKERS: 10 34 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from .config import add_pointrend_config 3 | from .coarse_mask_head import CoarseMaskHead 4 | from .roi_heads import PointRendROIHeads 5 | from .semantic_seg import PointRendSemSegHead 6 | from .color_augmentation import ColorAugSSDTransform 7 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/coarse_mask_head.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/coarse_mask_head.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/color_augmentation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/color_augmentation.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/point_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/point_features.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/point_head.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/point_head.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/roi_heads.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/roi_heads.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/__pycache__/semantic_seg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/scripts/detectron2/projects/PointRend/point_rend/__pycache__/semantic_seg.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/coarse_mask_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import Conv2d, ShapeSpec 8 | from detectron2.modeling import ROI_MASK_HEAD_REGISTRY 9 | 10 | 11 | @ROI_MASK_HEAD_REGISTRY.register() 12 | class CoarseMaskHead(nn.Module): 13 | """ 14 | A mask head with fully connected layers. Given pooled features it first reduces channels and 15 | spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously 16 | to the standard box head. 
17 | """ 18 | 19 | def __init__(self, cfg, input_shape: ShapeSpec): 20 | """ 21 | The following attributes are parsed from config: 22 | conv_dim: the output dimension of the conv layers 23 | fc_dim: the feature dimenstion of the FC layers 24 | num_fc: the number of FC layers 25 | output_side_resolution: side resolution of the output square mask prediction 26 | """ 27 | super(CoarseMaskHead, self).__init__() 28 | 29 | # fmt: off 30 | self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 31 | conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM 32 | self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM 33 | num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC 34 | self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION 35 | self.input_channels = input_shape.channels 36 | self.input_h = input_shape.height 37 | self.input_w = input_shape.width 38 | # fmt: on 39 | 40 | self.conv_layers = [] 41 | if self.input_channels > conv_dim: 42 | self.reduce_channel_dim_conv = Conv2d( 43 | self.input_channels, 44 | conv_dim, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=True, 49 | activation=F.relu, 50 | ) 51 | self.conv_layers.append(self.reduce_channel_dim_conv) 52 | 53 | self.reduce_spatial_dim_conv = Conv2d( 54 | conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu 55 | ) 56 | self.conv_layers.append(self.reduce_spatial_dim_conv) 57 | 58 | input_dim = conv_dim * self.input_h * self.input_w 59 | input_dim //= 4 60 | 61 | self.fcs = [] 62 | for k in range(num_fc): 63 | fc = nn.Linear(input_dim, self.fc_dim) 64 | self.add_module("coarse_mask_fc{}".format(k + 1), fc) 65 | self.fcs.append(fc) 66 | input_dim = self.fc_dim 67 | 68 | output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution 69 | 70 | self.prediction = nn.Linear(self.fc_dim, output_dim) 71 | # use normal distribution initialization for mask prediction layer 72 | nn.init.normal_(self.prediction.weight, std=0.001) 73 | nn.init.constant_(self.prediction.bias, 0) 74 | 75 | for layer in self.conv_layers: 76 | weight_init.c2_msra_fill(layer) 77 | for layer in self.fcs: 78 | weight_init.c2_xavier_fill(layer) 79 | 80 | def forward(self, x): 81 | # unlike BaseMaskRCNNHead, this head only outputs intermediate 82 | # features, because the features will be used later by PointHead. 83 | N = x.shape[0] 84 | x = x.view(N, self.input_channels, self.input_h, self.input_w) 85 | for layer in self.conv_layers: 86 | x = layer(x) 87 | x = torch.flatten(x, start_dim=1) 88 | for layer in self.fcs: 89 | x = F.relu(layer(x)) 90 | return self.prediction(x).view( 91 | N, self.num_classes, self.output_side_resolution, self.output_side_resolution 92 | ) 93 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/color_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | import random 4 | import cv2 5 | from fvcore.transforms.transform import Transform 6 | 7 | 8 | class ColorAugSSDTransform(Transform): 9 | """ 10 | A color related data augmentation used in Single Shot Multibox Detector (SSD). 11 | 12 | Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, 13 | Scott Reed, Cheng-Yang Fu, Alexander C. Berg. 14 | SSD: Single Shot MultiBox Detector. ECCV 2016. 
15 | 16 | Implementation based on: 17 | 18 | https://github.com/weiliu89/caffe/blob 19 | /4817bf8b4200b35ada8ed0dc378dceaf38c539e4 20 | /src/caffe/util/im_transforms.cpp 21 | 22 | https://github.com/chainer/chainercv/blob 23 | /7159616642e0be7c5b3ef380b848e16b7e99355b/chainercv 24 | /links/model/ssd/transforms.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | img_format, 30 | brightness_delta=32, 31 | contrast_low=0.5, 32 | contrast_high=1.5, 33 | saturation_low=0.5, 34 | saturation_high=1.5, 35 | hue_delta=18, 36 | ): 37 | super().__init__() 38 | assert img_format in ["BGR", "RGB"] 39 | self.is_rgb = img_format == "RGB" 40 | del img_format 41 | self._set_attributes(locals()) 42 | 43 | def apply_coords(self, coords): 44 | return coords 45 | 46 | def apply_segmentation(self, segmentation): 47 | return segmentation 48 | 49 | def apply_image(self, img, interp=None): 50 | if self.is_rgb: 51 | img = img[:, :, [2, 1, 0]] 52 | img = self.brightness(img) 53 | if random.randrange(2): 54 | img = self.contrast(img) 55 | img = self.saturation(img) 56 | img = self.hue(img) 57 | else: 58 | img = self.saturation(img) 59 | img = self.hue(img) 60 | img = self.contrast(img) 61 | if self.is_rgb: 62 | img = img[:, :, [2, 1, 0]] 63 | return img 64 | 65 | def convert(self, img, alpha=1, beta=0): 66 | img = img.astype(np.float32) * alpha + beta 67 | img = np.clip(img, 0, 255) 68 | return img.astype(np.uint8) 69 | 70 | def brightness(self, img): 71 | if random.randrange(2): 72 | return self.convert( 73 | img, beta=random.uniform(-self.brightness_delta, self.brightness_delta) 74 | ) 75 | return img 76 | 77 | def contrast(self, img): 78 | if random.randrange(2): 79 | return self.convert(img, alpha=random.uniform(self.contrast_low, self.contrast_high)) 80 | return img 81 | 82 | def saturation(self, img): 83 | if random.randrange(2): 84 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 85 | img[:, :, 1] = self.convert( 86 | img[:, :, 1], alpha=random.uniform(self.saturation_low, self.saturation_high) 87 | ) 88 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 89 | return img 90 | 91 | def hue(self, img): 92 | if random.randrange(2): 93 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 94 | img[:, :, 0] = ( 95 | img[:, :, 0].astype(int) + random.randint(-self.hue_delta, self.hue_delta) 96 | ) % 180 97 | return cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 98 | return img 99 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_pointrend_config(cfg): 8 | """ 9 | Add config for PointRend. 10 | """ 11 | # We retry random cropping until no single category in semantic segmentation GT occupies more 12 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 13 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 14 | # Color augmentatition from SSD paper for semantic segmentation model during training. 15 | cfg.INPUT.COLOR_AUG_SSD = False 16 | 17 | # Names of the input feature maps to be used by a coarse mask head. 18 | cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",) 19 | cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024 20 | cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2 21 | # The side size of a coarse mask head prediction. 22 | cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7 23 | # True if point head is used. 
24 | cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False 25 | 26 | cfg.MODEL.POINT_HEAD = CN() 27 | cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead" 28 | cfg.MODEL.POINT_HEAD.NUM_CLASSES = 80 29 | # Names of the input feature maps to be used by a mask point head. 30 | cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",) 31 | # Number of points sampled during training for a mask point head. 32 | cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14 33 | # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the 34 | # original paper. 35 | cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3 36 | # Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in 37 | # the original paper. 38 | cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75 39 | # Number of subdivision steps during inference. 40 | cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5 41 | # Maximum number of points selected at each subdivision step (N). 42 | cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28 43 | cfg.MODEL.POINT_HEAD.FC_DIM = 256 44 | cfg.MODEL.POINT_HEAD.NUM_FC = 3 45 | cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False 46 | # If True, then coarse prediction features are used as inout for each layer in PointRend's MLP. 47 | cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True 48 | cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead" 49 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/point_features.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from detectron2.layers import cat 6 | from detectron2.structures import Boxes 7 | 8 | 9 | """ 10 | Shape shorthand in this module: 11 | 12 | N: minibatch dimension size, i.e. the number of RoIs for instance segmenation or the 13 | number of images for semantic segmenation. 14 | R: number of ROIs, combined over all images, in the minibatch 15 | P: number of points 16 | """ 17 | 18 | 19 | def point_sample(input, point_coords, **kwargs): 20 | """ 21 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 22 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 23 | [0, 1] x [0, 1] square. 24 | 25 | Args: 26 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid. 27 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 28 | [0, 1] x [0, 1] normalized point coordinates. 29 | 30 | Returns: 31 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 32 | features for points in `point_coords`. The features are obtained via bilinear 33 | interplation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 34 | """ 35 | add_dim = False 36 | if point_coords.dim() == 3: 37 | add_dim = True 38 | point_coords = point_coords.unsqueeze(2) 39 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 40 | if add_dim: 41 | output = output.squeeze(3) 42 | return output 43 | 44 | 45 | def generate_regular_grid_point_coords(R, side_size, device): 46 | """ 47 | Generate regular square grid of points in [0, 1] x [0, 1] coordinate space. 48 | 49 | Args: 50 | R (int): The number of grids to sample, one for each region. 51 | side_size (int): The side size of the regular grid. 
52 | device (torch.device): Desired device of returned tensor. 53 | 54 | Returns: 55 | (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates 56 | for the regular grids. 57 | """ 58 | aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device) 59 | r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False) 60 | return r.view(1, -1, 2).expand(R, -1, -1) 61 | 62 | 63 | def get_uncertain_point_coords_with_randomness( 64 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 65 | ): 66 | """ 67 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The unceratinties 68 | are calculated for each point using 'uncertainty_func' function that takes point's logit 69 | prediction as input. 70 | See PointRend paper for details. 71 | 72 | Args: 73 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 74 | class-specific or class-agnostic prediction. 75 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 76 | contains logit predictions for P points and returns their uncertainties as a Tensor of 77 | shape (N, 1, P). 78 | num_points (int): The number of points P to sample. 79 | oversample_ratio (int): Oversampling parameter. 80 | importance_sample_ratio (float): Ratio of points that are sampled via importnace sampling. 81 | 82 | Returns: 83 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 84 | sampled points. 85 | """ 86 | assert oversample_ratio >= 1 87 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 88 | num_boxes = coarse_logits.shape[0] 89 | num_sampled = int(num_points * oversample_ratio) 90 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 91 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 92 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 93 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 94 | # to incorrect results. 95 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between 96 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value. 97 | # However, if we calculate uncertainties for the coarse predictions first, 98 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty. 99 | point_uncertainties = uncertainty_func(point_logits) 100 | num_uncertain_points = int(importance_sample_ratio * num_points) 101 | num_random_points = num_points - num_uncertain_points 102 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 103 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 104 | idx += shift[:, None] 105 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 106 | num_boxes, num_uncertain_points, 2 107 | ) 108 | if num_random_points > 0: 109 | point_coords = cat( 110 | [ 111 | point_coords, 112 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 113 | ], 114 | dim=1, 115 | ) 116 | return point_coords 117 | 118 | 119 | def get_uncertain_point_coords_on_grid(uncertainty_map, num_points): 120 | """ 121 | Find `num_points` most uncertain points from `uncertainty_map` grid. 
122 | 123 | Args: 124 | uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty 125 | values for a set of points on a regular H x W grid. 126 | num_points (int): The number of points P to select. 127 | 128 | Returns: 129 | point_indices (Tensor): A tensor of shape (N, P) that contains indices from 130 | [0, H x W) of the most uncertain points. 131 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized 132 | coordinates of the most uncertain points from the H x W grid. 133 | """ 134 | R, _, H, W = uncertainty_map.shape 135 | h_step = 1.0 / float(H) 136 | w_step = 1.0 / float(W) 137 | 138 | num_points = min(H * W, num_points) 139 | point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1] 140 | point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device) 141 | point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step 142 | point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step 143 | return point_indices, point_coords 144 | 145 | 146 | def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords): 147 | """ 148 | Get features from feature maps in `features_list` that correspond to specific point coordinates 149 | inside each bounding box from `boxes`. 150 | 151 | Args: 152 | features_list (list[Tensor]): A list of feature map tensors to get features from. 153 | feature_scales (list[float]): A list of scales for tensors in `features_list`. 154 | boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all 155 | together. 156 | point_coords (Tensor): A tensor of shape (R, P, 2) that contains 157 | [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 158 | 159 | Returns: 160 | point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled 161 | from all features maps in feature_list for P sampled points for all R boxes in `boxes`. 162 | point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level 163 | coordinates of P points. 164 | """ 165 | cat_boxes = Boxes.cat(boxes) 166 | num_boxes = [len(b) for b in boxes] 167 | 168 | point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords) 169 | split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes) 170 | 171 | point_features = [] 172 | for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image): 173 | point_features_per_image = [] 174 | for idx_feature, feature_map in enumerate(features_list): 175 | h, w = feature_map.shape[-2:] 176 | scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature] 177 | point_coords_scaled = point_coords_wrt_image_per_image / scale 178 | point_features_per_image.append( 179 | point_sample( 180 | feature_map[idx_img].unsqueeze(0), 181 | point_coords_scaled.unsqueeze(0), 182 | align_corners=False, 183 | ) 184 | .squeeze(0) 185 | .transpose(1, 0) 186 | ) 187 | point_features.append(cat(point_features_per_image, dim=1)) 188 | 189 | return cat(point_features, dim=0), point_coords_wrt_image 190 | 191 | 192 | def get_point_coords_wrt_image(boxes_coords, point_coords): 193 | """ 194 | Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates. 195 | 196 | Args: 197 | boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes. 198 | coordinates. 
199 | point_coords (Tensor): A tensor of shape (R, P, 2) that contains 200 | [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 201 | 202 | Returns: 203 | point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains 204 | image-level coordinates of the P sampled points. 205 | """ 206 | with torch.no_grad(): 207 | point_coords_wrt_image = point_coords.clone() 208 | point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * ( 209 | boxes_coords[:, None, 2] - boxes_coords[:, None, 0] 210 | ) 211 | point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * ( 212 | boxes_coords[:, None, 3] - boxes_coords[:, None, 1] 213 | ) 214 | point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0] 215 | point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1] 216 | return point_coords_wrt_image 217 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/point_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import ShapeSpec, cat 8 | from detectron2.structures import BitMasks 9 | from detectron2.utils.events import get_event_storage 10 | from detectron2.utils.registry import Registry 11 | 12 | from .point_features import point_sample 13 | 14 | POINT_HEAD_REGISTRY = Registry("POINT_HEAD") 15 | POINT_HEAD_REGISTRY.__doc__ = """ 16 | Registry for point heads, which make predictions for a given set of per-point features. 17 | 18 | The registered object will be called with `obj(cfg, input_shape)`. 19 | """ 20 | 21 | 22 | def roi_mask_point_loss(mask_logits, instances, points_coord): 23 | """ 24 | Compute the point-based loss for instance segmentation mask predictions. 25 | 26 | Args: 27 | mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or 28 | class-agnostic, where R is the total number of predicted masks in all images, C is the 29 | number of foreground classes, and P is the number of points sampled for each mask. 30 | The values are logits. 31 | instances (list[Instances]): A list of N Instances, where N is the number of images 32 | in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So, the i-th 33 | element of the list contains R_i objects and R_1 + ... + R_N is equal to R. 34 | The ground-truth labels (class, box, mask, ...) associated with each instance are stored 35 | in fields. 36 | points_coord (Tensor): A tensor of shape (R, P, 2), where R is the total number of 37 | predicted masks and P is the number of points for each mask. The coordinates are in 38 | the image pixel coordinate space, i.e. [0, H] x [0, W]. 39 | Returns: 40 | point_loss (Tensor): A scalar tensor containing the loss. 41 | """ 42 | with torch.no_grad(): 43 | cls_agnostic_mask = mask_logits.size(1) == 1 44 | total_num_masks = mask_logits.size(0) 45 | 46 | gt_classes = [] 47 | gt_mask_logits = [] 48 | idx = 0 49 | for instances_per_image in instances: 50 | if len(instances_per_image) == 0: 51 | continue 52 | assert isinstance( 53 | instances_per_image.gt_masks, BitMasks 54 | ), "Point head works with GT in 'bitmask' format. Set INPUT.MASK_FORMAT to 'bitmask'."
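            # Per image: collect the GT classes (unless class-agnostic) and sample the GT bitmask
            # at the given point locations. `points_coord` is in image pixel space, so it is
            # divided by (w, h) below to obtain the [0, 1]-normalized coordinates expected by
            # `point_sample`.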
55 | 56 | if not cls_agnostic_mask: 57 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 58 | gt_classes.append(gt_classes_per_image) 59 | 60 | gt_bit_masks = instances_per_image.gt_masks.tensor 61 | h, w = instances_per_image.gt_masks.image_size 62 | scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device) 63 | points_coord_grid_sample_format = ( 64 | points_coord[idx : idx + len(instances_per_image)] / scale 65 | ) 66 | idx += len(instances_per_image) 67 | gt_mask_logits.append( 68 | point_sample( 69 | gt_bit_masks.to(torch.float32).unsqueeze(1), 70 | points_coord_grid_sample_format, 71 | align_corners=False, 72 | ).squeeze(1) 73 | ) 74 | 75 | if len(gt_mask_logits) == 0: 76 | return mask_logits.sum() * 0 77 | 78 | gt_mask_logits = cat(gt_mask_logits) 79 | assert gt_mask_logits.numel() > 0, gt_mask_logits.shape 80 | 81 | if cls_agnostic_mask: 82 | mask_logits = mask_logits[:, 0] 83 | else: 84 | indices = torch.arange(total_num_masks) 85 | gt_classes = cat(gt_classes, dim=0) 86 | mask_logits = mask_logits[indices, gt_classes] 87 | 88 | # Log the training accuracy (using gt classes and 0.0 threshold for the logits) 89 | mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8) 90 | mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel() 91 | get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy) 92 | 93 | point_loss = F.binary_cross_entropy_with_logits( 94 | mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean" 95 | ) 96 | return point_loss 97 | 98 | 99 | @POINT_HEAD_REGISTRY.register() 100 | class StandardPointHead(nn.Module): 101 | """ 102 | A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head 103 | takes both fine-grained and coarse prediction features as its input. 
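    Both inputs are per-point tensors of shape (R, C, P). They are concatenated along the
    channel dimension, so the kernel-1 conv1d layers act as a per-point MLP that is shared
    across all P points. In the instance segmentation setup above the shapes are roughly as
    follows (exact channel counts depend on the config):

        fine_grained_features: (R, C_fine, P)
        coarse_features:       (R, num_classes, P)
        output logits:         (R, num_classes, P)   # (R, 1, P) if class-agnostic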
104 | """ 105 | 106 | def __init__(self, cfg, input_shape: ShapeSpec): 107 | """ 108 | The following attributes are parsed from config: 109 | fc_dim: the output dimension of each FC layer 110 | num_fc: the number of FC layers 111 | coarse_pred_each_layer: if True, coarse prediction features are concatenated to each 112 | layer's input 113 | """ 114 | super(StandardPointHead, self).__init__() 115 | # fmt: off 116 | num_classes = cfg.MODEL.POINT_HEAD.NUM_CLASSES 117 | fc_dim = cfg.MODEL.POINT_HEAD.FC_DIM 118 | num_fc = cfg.MODEL.POINT_HEAD.NUM_FC 119 | cls_agnostic_mask = cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK 120 | self.coarse_pred_each_layer = cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER 121 | input_channels = input_shape.channels 122 | # fmt: on 123 | 124 | fc_dim_in = input_channels + num_classes 125 | self.fc_layers = [] 126 | for k in range(num_fc): 127 | fc = nn.Conv1d(fc_dim_in, fc_dim, kernel_size=1, stride=1, padding=0, bias=True) 128 | self.add_module("fc{}".format(k + 1), fc) 129 | self.fc_layers.append(fc) 130 | fc_dim_in = fc_dim 131 | fc_dim_in += num_classes if self.coarse_pred_each_layer else 0 132 | 133 | num_mask_classes = 1 if cls_agnostic_mask else num_classes 134 | self.predictor = nn.Conv1d(fc_dim_in, num_mask_classes, kernel_size=1, stride=1, padding=0) 135 | 136 | for layer in self.fc_layers: 137 | weight_init.c2_msra_fill(layer) 138 | # use normal distribution initialization for mask prediction layer 139 | nn.init.normal_(self.predictor.weight, std=0.001) 140 | if self.predictor.bias is not None: 141 | nn.init.constant_(self.predictor.bias, 0) 142 | 143 | def forward(self, fine_grained_features, coarse_features): 144 | x = torch.cat((fine_grained_features, coarse_features), dim=1) 145 | for layer in self.fc_layers: 146 | x = F.relu(layer(x)) 147 | if self.coarse_pred_each_layer: 148 | x = cat((x, coarse_features), dim=1) 149 | return self.predictor(x) 150 | 151 | 152 | def build_point_head(cfg, input_channels): 153 | """ 154 | Build a point head defined by `cfg.MODEL.POINT_HEAD.NAME`. 155 | """ 156 | head_name = cfg.MODEL.POINT_HEAD.NAME 157 | return POINT_HEAD_REGISTRY.get(head_name)(cfg, input_channels) 158 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/roi_heads.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import numpy as np 4 | import torch 5 | 6 | from detectron2.layers import ShapeSpec, cat, interpolate 7 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 8 | from detectron2.modeling.roi_heads.mask_head import ( 9 | build_mask_head, 10 | mask_rcnn_inference, 11 | mask_rcnn_loss, 12 | ) 13 | from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals 14 | 15 | from .point_features import ( 16 | generate_regular_grid_point_coords, 17 | get_uncertain_point_coords_on_grid, 18 | get_uncertain_point_coords_with_randomness, 19 | point_sample, 20 | point_sample_fine_grained_features, 21 | ) 22 | from .point_head import build_point_head, roi_mask_point_loss 23 | 24 | 25 | def calculate_uncertainty(logits, classes): 26 | """ 27 | We estimate uncertainty as the L1 distance between 0.0 and the logit prediction in 'logits' for the 28 | foreground class in `classes`. 29 | 30 | Args: 31 | logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...)
for class-specific or 32 | class-agnostic, where R is the total number of predicted masks in all images and C is 33 | the number of foreground classes. The values are logits. 34 | classes (list): A list of length R that contains either the predicted or the ground-truth class 35 | for each predicted mask. 36 | 37 | Returns: 38 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 39 | the most uncertain locations having the highest uncertainty score. 40 | """ 41 | if logits.shape[1] == 1: 42 | gt_class_logits = logits.clone() 43 | else: 44 | gt_class_logits = logits[ 45 | torch.arange(logits.shape[0], device=logits.device), classes 46 | ].unsqueeze(1) 47 | return -(torch.abs(gt_class_logits)) 48 | 49 | 50 | @ROI_HEADS_REGISTRY.register() 51 | class PointRendROIHeads(StandardROIHeads): 52 | """ 53 | The RoI heads class for PointRend instance segmentation models. 54 | 55 | In this class we redefine the mask head of `StandardROIHeads` leaving all other heads intact. 56 | To avoid namespace conflict with other heads we use names starting with `mask_` for all 57 | variables that correspond to the mask head in the class's namespace. 58 | """ 59 | 60 | def __init__(self, cfg, input_shape): 61 | # TODO use explicit args style 62 | super().__init__(cfg, input_shape) 63 | self._init_mask_head(cfg, input_shape) 64 | 65 | def _init_mask_head(self, cfg, input_shape): 66 | # fmt: off 67 | self.mask_on = cfg.MODEL.MASK_ON 68 | if not self.mask_on: 69 | return 70 | self.mask_coarse_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES 71 | self.mask_coarse_side_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION 72 | self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()} 73 | # fmt: on 74 | 75 | in_channels = np.sum([input_shape[f].channels for f in self.mask_coarse_in_features]) 76 | self.mask_coarse_head = build_mask_head( 77 | cfg, 78 | ShapeSpec( 79 | channels=in_channels, 80 | width=self.mask_coarse_side_size, 81 | height=self.mask_coarse_side_size, 82 | ), 83 | ) 84 | self._init_point_head(cfg, input_shape) 85 | 86 | def _init_point_head(self, cfg, input_shape): 87 | # fmt: off 88 | self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON 89 | if not self.mask_point_on: 90 | return 91 | assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES 92 | self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES 93 | self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS 94 | self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO 95 | self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO 96 | # the next two parameters are used in the adaptive subdivision inference procedure 97 | self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS 98 | self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS 99 | # fmt: on 100 | 101 | in_channels = np.sum([input_shape[f].channels for f in self.mask_point_in_features]) 102 | self.mask_point_head = build_point_head( 103 | cfg, ShapeSpec(channels=in_channels, width=1, height=1) 104 | ) 105 | 106 | def _forward_mask(self, features, instances): 107 | """ 108 | Forward logic of the mask prediction branch. 109 | 110 | Args: 111 | features (dict[str, Tensor]): #level input features for mask prediction 112 | instances (list[Instances]): the per-image instances to train/predict masks. 113 | In training, they can be the proposals. 114 | In inference, they can be the predicted boxes.
115 | 116 | Returns: 117 | In training, a dict of losses. 118 | In inference, update `instances` with new fields "pred_masks" and return it. 119 | """ 120 | if not self.mask_on: 121 | return {} if self.training else instances 122 | 123 | if self.training: 124 | proposals, _ = select_foreground_proposals(instances, self.num_classes) 125 | proposal_boxes = [x.proposal_boxes for x in proposals] 126 | mask_coarse_logits = self._forward_mask_coarse(features, proposal_boxes) 127 | 128 | losses = {"loss_mask": mask_rcnn_loss(mask_coarse_logits, proposals)} 129 | losses.update(self._forward_mask_point(features, mask_coarse_logits, proposals)) 130 | return losses 131 | else: 132 | pred_boxes = [x.pred_boxes for x in instances] 133 | mask_coarse_logits = self._forward_mask_coarse(features, pred_boxes) 134 | 135 | mask_logits = self._forward_mask_point(features, mask_coarse_logits, instances) 136 | mask_rcnn_inference(mask_logits, instances) 137 | return instances 138 | 139 | def _forward_mask_coarse(self, features, boxes): 140 | """ 141 | Forward logic of the coarse mask head. 142 | """ 143 | point_coords = generate_regular_grid_point_coords( 144 | np.sum(len(x) for x in boxes), self.mask_coarse_side_size, boxes[0].device 145 | ) 146 | mask_coarse_features_list = [features[k] for k in self.mask_coarse_in_features] 147 | features_scales = [self._feature_scales[k] for k in self.mask_coarse_in_features] 148 | # For regular grids of points, this function is equivalent to `len(features_list)' calls 149 | # of `ROIAlign` (with `SAMPLING_RATIO=2`), and concat the results. 150 | mask_features, _ = point_sample_fine_grained_features( 151 | mask_coarse_features_list, features_scales, boxes, point_coords 152 | ) 153 | return self.mask_coarse_head(mask_features) 154 | 155 | def _forward_mask_point(self, features, mask_coarse_logits, instances): 156 | """ 157 | Forward logic of the mask point head. 
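        During training, point coordinates are sampled from the coarse logits with
        `get_uncertain_point_coords_with_randomness` and a point-wise loss is returned.
        During inference, the coarse mask is upsampled 2x for `mask_point_subdivision_steps`
        steps; at each step the `mask_point_subdivision_num_points` most uncertain grid points
        are re-predicted by the point head and scattered back into the upsampled logits.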
158 | """ 159 | if not self.mask_point_on: 160 | return {} if self.training else mask_coarse_logits 161 | 162 | mask_features_list = [features[k] for k in self.mask_point_in_features] 163 | features_scales = [self._feature_scales[k] for k in self.mask_point_in_features] 164 | 165 | if self.training: 166 | proposal_boxes = [x.proposal_boxes for x in instances] 167 | gt_classes = cat([x.gt_classes for x in instances]) 168 | with torch.no_grad(): 169 | point_coords = get_uncertain_point_coords_with_randomness( 170 | mask_coarse_logits, 171 | lambda logits: calculate_uncertainty(logits, gt_classes), 172 | self.mask_point_train_num_points, 173 | self.mask_point_oversample_ratio, 174 | self.mask_point_importance_sample_ratio, 175 | ) 176 | 177 | fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features( 178 | mask_features_list, features_scales, proposal_boxes, point_coords 179 | ) 180 | coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False) 181 | point_logits = self.mask_point_head(fine_grained_features, coarse_features) 182 | return { 183 | "loss_mask_point": roi_mask_point_loss( 184 | point_logits, instances, point_coords_wrt_image 185 | ) 186 | } 187 | else: 188 | pred_boxes = [x.pred_boxes for x in instances] 189 | pred_classes = cat([x.pred_classes for x in instances]) 190 | # The subdivision code will fail with the empty list of boxes 191 | if len(pred_classes) == 0: 192 | return mask_coarse_logits 193 | 194 | mask_logits = mask_coarse_logits.clone() 195 | for subdivions_step in range(self.mask_point_subdivision_steps): 196 | mask_logits = interpolate( 197 | mask_logits, scale_factor=2, mode="bilinear", align_corners=False 198 | ) 199 | # If `mask_point_subdivision_num_points` is larger or equal to the 200 | # resolution of the next step, then we can skip this step 201 | H, W = mask_logits.shape[-2:] 202 | if ( 203 | self.mask_point_subdivision_num_points >= 4 * H * W 204 | and subdivions_step < self.mask_point_subdivision_steps - 1 205 | ): 206 | continue 207 | uncertainty_map = calculate_uncertainty(mask_logits, pred_classes) 208 | point_indices, point_coords = get_uncertain_point_coords_on_grid( 209 | uncertainty_map, self.mask_point_subdivision_num_points 210 | ) 211 | fine_grained_features, _ = point_sample_fine_grained_features( 212 | mask_features_list, features_scales, pred_boxes, point_coords 213 | ) 214 | coarse_features = point_sample( 215 | mask_coarse_logits, point_coords, align_corners=False 216 | ) 217 | point_logits = self.mask_point_head(fine_grained_features, coarse_features) 218 | 219 | # put mask point predictions to the right places on the upsampled grid. 220 | R, C, H, W = mask_logits.shape 221 | point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) 222 | mask_logits = ( 223 | mask_logits.reshape(R, C, H * W) 224 | .scatter_(2, point_indices, point_logits) 225 | .view(R, C, H, W) 226 | ) 227 | return mask_logits 228 | -------------------------------------------------------------------------------- /scripts/detectron2/projects/PointRend/point_rend/semantic_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import numpy as np 3 | from typing import Dict 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from detectron2.layers import ShapeSpec, cat 9 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 10 | 11 | from .point_features import ( 12 | get_uncertain_point_coords_on_grid, 13 | get_uncertain_point_coords_with_randomness, 14 | point_sample, 15 | ) 16 | from .point_head import build_point_head 17 | 18 | 19 | def calculate_uncertainty(sem_seg_logits): 20 | """ 21 | For each location of the prediction `sem_seg_logits` we estimate uncertainty as the 22 | difference between the two largest predicted logits. 23 | 24 | Args: 25 | sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and 26 | C is the number of foreground classes. The values are logits. 27 | 28 | Returns: 29 | scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with 30 | the most uncertain locations having the highest uncertainty score. 31 | """ 32 | top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0] 33 | return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) 34 | 35 | 36 | @SEM_SEG_HEADS_REGISTRY.register() 37 | class PointRendSemSegHead(nn.Module): 38 | """ 39 | A semantic segmentation head that combines a head set in `POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME` 40 | and a point head set in `MODEL.POINT_HEAD.NAME`. 41 | """ 42 | 43 | def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): 44 | super().__init__() 45 | 46 | self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE 47 | 48 | self.coarse_sem_seg_head = SEM_SEG_HEADS_REGISTRY.get( 49 | cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME 50 | )(cfg, input_shape) 51 | self._init_point_head(cfg, input_shape) 52 | 53 | def _init_point_head(self, cfg, input_shape: Dict[str, ShapeSpec]): 54 | # fmt: off 55 | assert cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES 56 | feature_channels = {k: v.channels for k, v in input_shape.items()} 57 | self.in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES 58 | self.train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS 59 | self.oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO 60 | self.importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO 61 | self.subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS 62 | self.subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS 63 | # fmt: on 64 | 65 | in_channels = np.sum([feature_channels[f] for f in self.in_features]) 66 | self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1)) 67 | 68 | def forward(self, features, targets=None): 69 | coarse_sem_seg_logits = self.coarse_sem_seg_head.layers(features) 70 | 71 | if self.training: 72 | losses = self.coarse_sem_seg_head.losses(coarse_sem_seg_logits, targets) 73 | 74 | with torch.no_grad(): 75 | point_coords = get_uncertain_point_coords_with_randomness( 76 | coarse_sem_seg_logits, 77 | calculate_uncertainty, 78 | self.train_num_points, 79 | self.oversample_ratio, 80 | self.importance_sample_ratio, 81 | ) 82 | coarse_features = point_sample(coarse_sem_seg_logits, point_coords, align_corners=False) 83 | 84 | fine_grained_features = cat( 85 | [ 86 | point_sample(features[in_feature], point_coords, align_corners=False) 87 | for in_feature in self.in_features 88 | ], 89 | dim=1, 90 | ) 91 | point_logits = self.point_head(fine_grained_features, coarse_features) 92 | point_targets = ( 93 | point_sample( 94 |
targets.unsqueeze(1).to(torch.float), 95 | point_coords, 96 | mode="nearest", 97 | align_corners=False, 98 | ) 99 | .squeeze(1) 100 | .to(torch.long) 101 | ) 102 | losses["loss_sem_seg_point"] = F.cross_entropy( 103 | point_logits, point_targets, reduction="mean", ignore_index=self.ignore_value 104 | ) 105 | return None, losses 106 | else: 107 | sem_seg_logits = coarse_sem_seg_logits.clone() 108 | for _ in range(self.subdivision_steps): 109 | sem_seg_logits = F.interpolate( 110 | sem_seg_logits, scale_factor=2, mode="bilinear", align_corners=False 111 | ) 112 | uncertainty_map = calculate_uncertainty(sem_seg_logits) 113 | point_indices, point_coords = get_uncertain_point_coords_on_grid( 114 | uncertainty_map, self.subdivision_num_points 115 | ) 116 | fine_grained_features = cat( 117 | [ 118 | point_sample(features[in_feature], point_coords, align_corners=False) 119 | for in_feature in self.in_features 120 | ] 121 | ) 122 | coarse_features = point_sample( 123 | coarse_sem_seg_logits, point_coords, align_corners=False 124 | ) 125 | point_logits = self.point_head(fine_grained_features, coarse_features) 126 | 127 | # put sem seg point predictions to the right places on the upsampled grid. 128 | N, C, H, W = sem_seg_logits.shape 129 | point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) 130 | sem_seg_logits = ( 131 | sem_seg_logits.reshape(N, C, H * W) 132 | .scatter_(2, point_indices, point_logits) 133 | .view(N, C, H, W) 134 | ) 135 | return sem_seg_logits, {} 136 | -------------------------------------------------------------------------------- /scripts/preproc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert( 5 | 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")) 6 | ) 7 | 8 | import torch 9 | 10 | from detectron2.engine import DefaultPredictor 11 | from detectron2.config import get_cfg 12 | from detectron2.data import MetadataCatalog 13 | from detectron2 import structures 14 | from detectron2.projects import point_rend 15 | coco_metadata = MetadataCatalog.get("coco_2017_val") 16 | 17 | import numpy as np 18 | import cv2 19 | 20 | 21 | import kitti_util 22 | 23 | cfg = get_cfg() 24 | # Add PointRend-specific config 25 | point_rend.add_pointrend_config(cfg) 26 | # Load a config from file 27 | cfg.merge_from_file("scripts/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml") 28 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model 29 | # Use a model from PointRend model zoo: https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend#pretrained-models 30 | cfg.MODEL.WEIGHTS = "detectron2://PointRend/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_edd263.pkl" 31 | predictor = DefaultPredictor(cfg) 32 | 33 | class Prepare(torch.utils.data.Dataset): 34 | 35 | def __init__(self, ): 36 | super().__init__() 37 | 38 | self.ids = range( 39 | len(os.listdir( 40 | '/data0/billyhe/KITTI/training/label_2')) 41 | ) 42 | 43 | def __len__(self): 44 | return len(self.ids) 45 | 46 | def __getitem__(self, idx): 47 | id = self.ids[idx] 48 | 49 | objs = kitti_util.read_label('/data0/billyhe/KITTI/training/label_2/%06d.txt' % id) 50 | img = cv2.imread('/data0/billyhe/KITTI/training/image_2/%06d.png' % id) 51 | 52 | insts = predictor(img)["instances"] 53 | insts = insts[insts.pred_classes == 2] # 2 for ca 54 | ious = structures.pairwise_iou( 55 | structures.Boxes(torch.Tensor([obj.box2d 
for obj in objs])).to(insts.pred_boxes.device), 56 | insts.pred_boxes 57 | ) 58 | 59 | if ious.numel() == 0: 60 | return 1 61 | 62 | for i, obj in enumerate(objs): 63 | 64 | if obj.type == 'DontCare': 65 | continue 66 | if obj.t[2] > 50: 67 | continue 68 | if obj.ymax - obj.ymin < 64: 69 | continue 70 | iou, j = torch.max(ious[i]), torch.argmax(ious[i]) 71 | if iou<.8: 72 | continue 73 | rgb_gt = img[int(obj.ymin):int(obj.ymax), int(obj.xmin):int(obj.xmax), :] 74 | msk_gt = insts.pred_masks[j][int(obj.ymin):int(obj.ymax), int(obj.xmin):int(obj.xmax)] 75 | 76 | cv2.imwrite('/data0/billyhe/KITTI/training/nerf/%06d_%02d_patch.png' % (id, i), rgb_gt) 77 | cv2.imwrite('/data0/billyhe/KITTI/training/nerf/%06d_%02d_mask.png' % (id, i), np.stack([msk_gt.cpu()*255]*3, -1)) 78 | anno = [obj.xmin, obj.xmax, obj.ymin, obj.ymax] + list(obj.t) + list(obj.dim) + [obj.ry] 79 | anno = [str(x) for x in anno] 80 | with open('/data0/billyhe/KITTI/training/nerf/%06d_%02d_label.txt' % (id, i), 'w') as f: 81 | f.writelines(' '.join(anno)) 82 | 83 | 84 | return 1 85 | 86 | 87 | if __name__ == "__main__": 88 | 89 | 90 | loader = torch.utils.data.DataLoader( 91 | Prepare(), 92 | batch_size=1, 93 | shuffle=False, 94 | num_workers=0 95 | ) 96 | 97 | for _ in loader: 98 | pass -------------------------------------------------------------------------------- /src/__pycache__/encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/encoder.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/kitti.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/kitti.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/kitti_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/kitti_util.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/nerf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/nerf.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/renderer.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/skyhehe123/AutoRF-pytorch/0be4e13a2543b25c4b91d76359bbe7d2fb0c5eea/src/__pycache__/renderer.cpython-37.pyc -------------------------------------------------------------------------------- /src/kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import cv2 5 | 6 | import torchvision.transforms as T 7 | 8 | import numpy as np 9 | 10 | import kitti_util 11 | 12 | img_transform = T.Compose([T.Resize((128, 128)), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 13 | 14 | def manipulate(objs, txyz): 15 | objs[:, 0] += txyz[0] 16 | objs[:, 1] += txyz[1] 17 | objs[:, 2] += txyz[2] 18 | 19 | # objs[:, :3] = kitti_util.rotate_yaw(objs[:, :3], np.pi/15) 20 | # objs[:, 6] += np.pi/15 21 | 22 | # corners = np.stack(get_corners(obj) for obj in objs) 23 | 24 | # kitti_util.visualize_offscreen(np.zeros([1,3]), corners, save_path='boxes.png') 25 | return objs 26 | 27 | class KITTI(torch.utils.data.Dataset): 28 | 29 | def __init__(self, ): 30 | super().__init__() 31 | 32 | self.filelist = [f[:-10] for f in os.listdir("/data0/billyhe/KITTI/training/nerf") if "label" in f ] 33 | self.filelist.sort() 34 | 35 | self.cam_pos = torch.eye(4)[None, :, :] 36 | self.cam_pos[:, 2, 2] = -1 37 | self.cam_pos[:, 1, 1] = -1 38 | 39 | def __getitem__(self, idx): 40 | # idx=2 41 | id = self.filelist[idx] 42 | 43 | img = cv2.imread('/data0/billyhe/KITTI/training/nerf/%s_patch.png' % id) 44 | msk = cv2.imread('/data0/billyhe/KITTI/training/nerf/%s_mask.png' % id) 45 | 46 | with open('/data0/billyhe/KITTI/training/nerf/%s_label.txt' % id , 'r') as f: 47 | obj = f.readlines()[0].split() 48 | 49 | sid = id[:6] 50 | 51 | 52 | calib = kitti_util.Calibration('/data0/billyhe/KITTI/training/calib/%s.txt' % sid) 53 | imshape = cv2.imread('/data0/billyhe/KITTI/training/image_2/%s.png' % sid).shape 54 | 55 | render_rays = kitti_util.gen_rays( 56 | self.cam_pos, imshape[1], imshape[0], 57 | torch.tensor([calib.f_u, calib.f_v]), 0, np.inf, 58 | torch.tensor([calib.c_u, calib.c_v]) 59 | )[0].numpy() 60 | 61 | xmin, xmax, ymin, ymax, tx, ty, tz, dx, dy, dz, ry = [float(a) for a in obj] 62 | 63 | cam_rays = render_rays[int(ymin):int(ymax), int(xmin):int(xmax), :].reshape(-1, 8) 64 | 65 | objs = np.array([tx, ty, tz, dx, dy, dz, ry]).reshape(1, 7) 66 | 67 | 68 | ray_o = kitti_util.world2object(np.zeros((len(cam_rays), 3)), objs) 69 | ray_d = kitti_util.world2object(cam_rays[:, 3:6], objs, use_dir=True) 70 | 71 | z_in, z_out, intersect = kitti_util.ray_box_intersection(ray_o, ray_d) 72 | 73 | bounds = np.ones((*ray_o.shape[:-1], 2)) * -1 74 | bounds [intersect, 0] = z_in 75 | bounds [intersect, 1] = z_out 76 | 77 | cam_rays = np.concatenate([ray_o, ray_d, bounds], -1) 78 | 79 | 80 | return img, msk, cam_rays 81 | 82 | def __len__(self): 83 | return len(self.filelist) 84 | 85 | def __getviews__(self, idx, 86 | ry_list = [0, np.pi/2, np.pi, 1.75*np.pi], 87 | txyz=[0., 1.75, 12]): 88 | 89 | id = self.filelist[idx] 90 | 91 | img = cv2.imread('/data0/billyhe/KITTI/training/nerf/%s_patch.png' % id) 92 | 93 | with open('/data0/billyhe/KITTI/training/nerf/%s_label.txt' % id , 'r') as f: 94 | obj = f.readlines()[0].split() 95 | 96 | sid = id[:6] 97 | 98 | calib = kitti_util.Calibration('/data0/billyhe/KITTI/training/calib/%s.txt' % sid) 99 | canvas = cv2.imread('/data0/billyhe/KITTI/training/image_2/%s.png' % sid) 100 | 101 | render_rays = kitti_util.gen_rays( 102 | self.cam_pos, canvas.shape[1], canvas.shape[0], 103 | torch.tensor([calib.f_u, 
calib.f_v]), 0, np.inf, 104 | torch.tensor([calib.c_u, calib.c_v]) 105 | )[0].numpy() 106 | 107 | 108 | test_data = list() 109 | out_shape = list() 110 | for ry in ry_list: 111 | _,_,_,_,_,_,_, l, h, w, _ = [float(a) for a in obj] 112 | xmin, ymin, xmax, ymax = box3d_to_image_roi(txyz + [l, h, w, ry], calib.P, canvas.shape) 113 | 114 | cam_rays = render_rays[int(ymin):int(ymax), int(xmin):int(xmax), :].reshape(-1, 8) 115 | 116 | objs = np.array(txyz + [l, h, w, ry]).reshape(1, 7) 117 | 118 | ray_o = kitti_util.world2object(np.zeros((len(cam_rays), 3)), objs) 119 | ray_d = kitti_util.world2object(cam_rays[:, 3:6], objs, use_dir=True) 120 | 121 | z_in, z_out, intersect = kitti_util.ray_box_intersection(ray_o, ray_d) 122 | 123 | bounds = np.ones((*ray_o.shape[:-1], 2)) * -1 124 | bounds [intersect, 0] = z_in 125 | bounds [intersect, 1] = z_out 126 | 127 | cam_rays = np.concatenate([ray_o, ray_d, bounds], -1) 128 | 129 | out_shape.append( [int(ymax)-int(ymin), int(xmax)-int(xmin) ]) 130 | 131 | test_data.append( collate_lambda_test(img, cam_rays) ) 132 | 133 | return img, test_data, out_shape 134 | 135 | def __getscene__(self, sid, manipulation=None): 136 | calib = kitti_util.Calibration('/data0/billyhe/KITTI/training/calib/%06d.txt' % sid) 137 | canvas = cv2.imread('/data0/billyhe/KITTI/training/image_2/%06d.png' % sid) 138 | 139 | render_rays = kitti_util.gen_rays( 140 | self.cam_pos, canvas.shape[1], canvas.shape[0], 141 | torch.tensor([calib.f_u, calib.f_v]), 0, np.inf, 142 | torch.tensor([calib.c_u, calib.c_v]) 143 | )[0].flatten(0,1).numpy() 144 | 145 | objs = kitti_util.read_label('/data0/billyhe/KITTI/training/label_2/%06d.txt' % sid) 146 | 147 | objs_pose = np.array([obj.t for obj in objs if obj.type == 'Car']).reshape(-1, 3) 148 | objs_dim = np.array([obj.dim for obj in objs if obj.type == 'Car']).reshape(-1, 3) 149 | objs_yaw = np.array([obj.ry for obj in objs if obj.type == 'Car']).reshape(-1, 1) 150 | # objs_box = np.stack([obj.box2d for obj in objs if obj.type == 'Car']).reshape(-1, 4) 151 | 152 | objs = np.concatenate([objs_pose, objs_dim, objs_yaw], -1) 153 | 154 | ##################### 155 | rois = list() 156 | for obj in objs: 157 | 158 | xmin, ymin, xmax, ymax = box3d_to_image_roi(obj, calib.P, canvas.shape) 159 | 160 | roi = canvas[int(ymin):int(ymax), int(xmin):int(xmax), :] 161 | roi = T.ToTensor()(roi) 162 | roi = img_transform(roi) 163 | rois.append(roi) 164 | 165 | rois = torch.stack(rois) 166 | 167 | # manipulate 3d boxes 168 | if manipulation is not None: 169 | objs = manipulate(objs, manipulation) 170 | 171 | # get rays from 3d boxes 172 | ray_o = kitti_util.world2object(np.zeros((len(render_rays), 3)), objs) 173 | ray_d = kitti_util.world2object(render_rays[:, 3:6], objs, use_dir=True) 174 | 175 | z_in, z_out, intersect = kitti_util.ray_box_intersection(ray_o, ray_d) 176 | 177 | bounds = np.ones((*ray_o.shape[:-1], 2)) * -1 178 | bounds [intersect, 0] = z_in 179 | bounds [intersect, 1] = z_out 180 | 181 | scene_render_rays = np.concatenate([ray_o, ray_d, bounds], -1) 182 | _, nb, nc = scene_render_rays.shape 183 | scene_render_rays = scene_render_rays.reshape(canvas.shape[0], canvas.shape[1], nb, nc) 184 | 185 | return canvas, \ 186 | torch.FloatTensor(scene_render_rays), \ 187 | rois, \ 188 | torch.from_numpy( np.any(intersect, 1) ),\ 189 | torch.FloatTensor(objs) 190 | 191 | 192 | 193 | 194 | 195 | 196 | def collate_lambda_train(batch, ray_batch_size=1024): 197 | imgs = list() 198 | msks = list() 199 | rays = list() 200 | rgbs = list() 201 | 202 | for el in 
batch: 203 | im, msk, cam_rays = el 204 | im = T.ToTensor()(im) 205 | msk = T.ToTensor()(msk) 206 | cam_rays = torch.FloatTensor(cam_rays) 207 | 208 | _, H, W = im.shape 209 | 210 | pix_inds = torch.randint(0, H * W, (ray_batch_size,)) 211 | 212 | rgb_gt = im.permute(1,2,0).flatten(0,1)[pix_inds,...] 213 | msk_gt = msk.permute(1,2,0).flatten(0,1)[pix_inds,...] 214 | ray = cam_rays.view(-1, cam_rays.shape[-1])[pix_inds] 215 | 216 | imgs.append( 217 | img_transform(im) 218 | ) 219 | msks.append(msk_gt) 220 | rays.append(ray) 221 | rgbs.append(rgb_gt) 222 | 223 | imgs = torch.stack(imgs) 224 | rgbs = torch.stack(rgbs, 1) 225 | msks = torch.stack(msks, 1) 226 | rays = torch.stack(rays, 1) 227 | 228 | return imgs, rays, rgbs, msks 229 | 230 | 231 | 232 | def collate_lambda_test(im, cam_rays, ray_batch_size=1024): 233 | imgs = list() 234 | rays = list() 235 | 236 | im = T.ToTensor()(im) 237 | cam_rays = torch.FloatTensor(cam_rays) 238 | 239 | N = cam_rays.shape[0] 240 | 241 | for i in range(N// ray_batch_size + 1): 242 | 243 | pix_inds = np.arange(i*ray_batch_size, i*ray_batch_size + ray_batch_size) 244 | 245 | if i == N // ray_batch_size: 246 | pix_inds = np.clip(pix_inds, 0, N-1) 247 | 248 | ray = cam_rays[pix_inds] 249 | rays.append(ray) 250 | 251 | imgs = img_transform(im).unsqueeze(0) 252 | rays = torch.stack(rays) 253 | 254 | return imgs, rays 255 | 256 | 257 | def get_corners(obj): 258 | if isinstance(obj, list): 259 | tx, ty, tz, l, h, w, ry = obj 260 | else: 261 | tx, ty, tz, l, h, w, ry = obj.tolist() 262 | 263 | # 3d bounding box corners 264 | x_corners = [l/2,l/2,-l/2,-l/2,l/2,l/2,-l/2,-l/2] 265 | y_corners = [0,0,0,0,-h,-h,-h,-h] 266 | z_corners = [w/2,-w/2,-w/2,w/2,w/2,-w/2,-w/2,w/2] 267 | 268 | R = kitti_util.roty(ry) 269 | # rotate and translate 3d bounding box 270 | corners_3d = np.dot(R, np.vstack([x_corners,y_corners,z_corners])) 271 | #print corners_3d.shape 272 | corners_3d[0,:] = corners_3d[0,:] + tx 273 | corners_3d[1,:] = corners_3d[1,:] + ty 274 | corners_3d[2,:] = corners_3d[2,:] + tz 275 | return np.transpose(corners_3d) 276 | 277 | 278 | def box3d_to_image_roi(obj, P, imshape=None): 279 | corners_3d = get_corners(obj) 280 | 281 | # project the 3d bounding box into the image plane 282 | corners_2d = kitti_util.project_to_image(corners_3d, P) 283 | xmin, ymin = np.min(corners_2d, axis=0) 284 | xmax, ymax = np.max(corners_2d, axis=0) 285 | 286 | if imshape is not None: 287 | xmin = np.clip(xmin, 0, imshape[1]) 288 | xmax = np.clip(xmax, 0, imshape[1]) 289 | ymin = np.clip(ymin, 0, imshape[0]) 290 | ymax = np.clip(ymax, 0, imshape[0]) 291 | 292 | return xmin, ymin, xmax, ymax 293 | 294 | if __name__ == "__main__": 295 | ds = KITTI() 296 | ds.__getscene__(8) 297 | -------------------------------------------------------------------------------- /src/kitti_util.py: -------------------------------------------------------------------------------- 1 | from dotmap import DotMap 2 | from matplotlib import use 3 | import numpy as np 4 | import cv2 5 | import os 6 | 7 | import torch 8 | 9 | class Object3d(object): 10 | ''' 3d object label ''' 11 | def __init__(self, label_file_line): 12 | data = label_file_line.split(' ') 13 | data[1:] = [float(x) for x in data[1:]] 14 | 15 | # extract label, truncation, occlusion 16 | self.type = data[0] # 'Car', 'Pedestrian', ... 
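        # A KITTI label line has 15 space-separated fields:
        # type, truncated, occluded, alpha, 2D bbox (left, top, right, bottom),
        # 3D dimensions (height, width, length), location (x, y, z) and rotation_y,
        # which are unpacked field by field below.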
17 | self.truncation = data[1] # truncated pixel ratio [0..1] 18 | self.occlusion = int(data[2]) # 0=visible, 1=partly occluded, 2=fully occluded, 3=unknown 19 | self.alpha = data[3] # object observation angle [-pi..pi] 20 | 21 | # extract 2d bounding box in 0-based coordinates 22 | self.xmin = data[4] # left 23 | self.ymin = data[5] # top 24 | self.xmax = data[6] # right 25 | self.ymax = data[7] # bottom 26 | self.box2d = np.array([self.xmin,self.ymin,self.xmax,self.ymax]) 27 | 28 | # extract 3d bounding box information 29 | self.h = data[8] # box height 30 | self.w = data[9] # box width 31 | self.l = data[10] # box length (in meters) 32 | self.t = (data[11],data[12],data[13]) # location (x,y,z) in camera coord. 33 | self.dim = (self.l, self.h, self.w) 34 | self.ry = data[14] # yaw angle (around Y-axis in camera coordinates) [-pi..pi] 35 | 36 | def print_object(self): 37 | print('Type, truncation, occlusion, alpha: %s, %d, %d, %f' % \ 38 | (self.type, self.truncation, self.occlusion, self.alpha)) 39 | print('2d bbox (x0,y0,x1,y1): %f, %f, %f, %f' % \ 40 | (self.xmin, self.ymin, self.xmax, self.ymax)) 41 | print('3d bbox h,w,l: %f, %f, %f' % \ 42 | (self.h, self.w, self.l)) 43 | print('3d bbox location, ry: (%f, %f, %f), %f' % \ 44 | (self.t[0],self.t[1],self.t[2],self.ry)) 45 | 46 | 47 | 48 | class Calibration(object): 49 | ''' Calibration matrices and utils 50 | 3d XYZ in