├── .gitignore ├── LICENSE ├── README.md ├── configs ├── interactron.yaml ├── interactron_random.yaml ├── multi_frame_baseline.yaml └── single_frame_baseline.yaml ├── data_collection ├── collect_ithor_tree_data.py └── data_collection_utils.py ├── datasets ├── __init__.py ├── interactive_dataset.py └── sequence_dataset.py ├── engine ├── __init__.py ├── direct_supervision_trainer.py ├── interactive_evaluator.py ├── interactron_random_trainer.py ├── interactron_trainer.py └── random_policy_evaluator.py ├── evaluate.py ├── images └── teaser-wide.png ├── models ├── __init__.py ├── components.py ├── detr.py ├── detr_models │ ├── __init__.py │ ├── backbone.py │ ├── detr.py │ ├── matcher.py │ ├── position_encoding.py │ ├── segmentation.py │ ├── transformer.py │ └── util │ │ ├── __init__.py │ │ ├── box_ops.py │ │ ├── misc.py │ │ └── transforms.py ├── detr_multiframe.py ├── five_frame_baseline.py ├── gpt.py ├── interactron.py ├── interactron_random.py ├── learned_loss.py ├── mlp_detector.py ├── new_transformer.py ├── single_frame_baseline.py └── transformer.py ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── config_utils.py ├── constants.py ├── detection_utils.py ├── logging_utils.py ├── meta_utils.py ├── model_utils.py ├── storage_utils.py ├── time_utils.py ├── transform_utis.py └── viz_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interactron: Embodied Adaptive Object Detection 2 | 3 | By Klemen Kotar and Roozbeh Mottaghi 4 | 5 | ![teaser](images/teaser-wide.png) 6 | 7 | Interactron is a model for interactive, embodied object detection. 8 | It is the official codebase for the paper 9 | [Interactron: Embodied Adaptive Object Detection](https://arxiv.org/abs/2202.00660).
10 | Traditionally, object detectors are trained on a fixed training set and frozen at evaluation. 11 | This project explores methods of dynamically adapting object detection models to their test 12 | time environments using MAML-style meta-learning and interactive exploration (a schematic sketch of this adaptation loop appears at the end of this README). 13 | 14 | 15 | ## Setup 16 | 17 | - Clone the repository with `git clone https://github.com/allenai/interactron.git && cd interactron`. 18 | 19 | - Install the necessary packages. If you are using pip, simply run `pip install -r requirements.txt`. 20 | 21 | - If running on GPUs, we strongly recommend installing PyTorch with conda. 22 | 23 | - Download the [pretrained weights](https://interactron.s3.us-east-2.amazonaws.com/pretrained_weights.tar.gz) and 24 | [data](https://interactron.s3.us-east-2.amazonaws.com/data.tar.gz) to the `interactron` directory. Untar with 25 | ```bash 26 | tar -xzf pretrained_weights.tar.gz 27 | tar -xzf data.tar.gz 28 | ``` 29 | 30 | ## Results 31 | 32 | Below is a summary of the results of the various models. 33 | 34 | | Model | Policy | Adaptive | AP | AP_50 | 35 | |------------------|----------|----------|-------|-------| 36 | | DETR | No Move | No | 0.256 | 0.448 | 37 | | Multi-Frame | Random | No | 0.288 | 0.517 | 38 | | Interactron-Rand | Random | Yes | 0.313 | 0.551 | 39 | | Interactron | Learned | Yes | 0.328 | 0.575 | 40 | 41 | For more detailed results, please see the full paper 42 | [Interactron: Embodied Adaptive Object Detection](https://arxiv.org/abs/2202.00660). 43 | 44 | ## Evaluation 45 | 46 | Evaluation of the Interactron model can be performed by running ``python evaluate.py --config=configs/interactron.yaml``. 47 | The code will automatically take over any available GPUs. Running the evaluation on a CPU could 48 | take several minutes. The evaluator will output visualizations and results in a folder called 49 | `evaluation_results/`. To evaluate other models, select one of the other config files in `configs/`. 50 | 51 | 52 | ## Training 53 | 54 | Training of the Interactron model can be performed by running ``python train.py --config=configs/interactron.yaml``. 55 | The code will automatically take over any available GPUs. To train using the default configuration, 56 | at least 12GB of VRAM is necessary. Training takes roughly five days on a high-performance machine using an RTX 3090 GPU. 57 | The trainer will output results in a folder called 58 | `training_results/`. To train other models, select one of the other config files in `configs/`. 59 | 60 | 61 | ## Citation 62 | ``` 63 | @inproceedings{kotar2022interactron, 64 | title={Interactron: Embodied Adaptive Object Detection}, 65 | author={Klemen Kotar and Roozbeh Mottaghi}, 66 | booktitle={CVPR}, 67 | year={2022}, 68 | } 69 | ``` 70 | 71 | Parts of the codebase were derived from other repositories and modified (like the DETR model code) 72 | and have a crediting comment on the first line of the file.
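As a purely schematic illustration of the adaptive detection idea described in the introduction above: at test time the detector takes a few gradient steps on a learned, label-free loss computed from the frames gathered by interaction, and then re-predicts. The sketch below is not the repository's implementation (see `models/interactron.py` and `models/learned_loss.py` for that); `detector`, `learned_loss`, and `frames` are stand-in names, and the inner learning rate simply mirrors the `ADAPTIVE_LR: 1e-3` value used in the configs.

```python
# Schematic sketch only -- not the code used by this repository.
# `detector` is assumed to map a batch of frames to detections, and
# `learned_loss` to score those detections without ground-truth labels.
import copy
import torch


def adapt_and_predict(detector, learned_loss, frames, inner_lr=1e-3, inner_steps=1):
    adapted = copy.deepcopy(detector)           # keep the meta-trained weights intact
    optim = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
    for _ in range(inner_steps):
        preds = adapted(frames)                 # detections on the frames seen so far
        loss = learned_loss(preds, frames)      # learned adaptation objective (no labels)
        optim.zero_grad()
        loss.backward()
        optim.step()
    with torch.no_grad():
        return adapted(frames[:1])              # detections for the frame of interest
```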
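The configuration files that follow share a common top-level layout: a `MODEL` block (detector type, pretrained DETR weights, and transformer hyper-parameters), a `DATASET` block pointing at the annotation and image roots, and `TRAINER`/`EVALUATOR` blocks that pick the loop from `engine/` and the checkpoint to load (the evaluation-only `single_frame_baseline.yaml` has no `TRAINER` block). Below is a minimal sketch of inspecting one of them with PyYAML; the repository itself loads configs through its own helpers in `utils/config_utils.py`, which are not shown in this section.

```python
# Minimal sketch: peek at one of the YAML configs with PyYAML.
# Illustrative only; the repository uses its own config utilities.
import yaml

with open("configs/interactron.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["MODEL"]["TYPE"])                        # "interactron"
print(cfg["TRAINER"]["TYPE"])                      # selects the training loop in engine/
print(cfg["DATASET"]["TRAIN"]["ANNOTATION_ROOT"])  # annotations used for training
```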
73 | -------------------------------------------------------------------------------- /configs/interactron.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: "interactron" 3 | WEIGHTS: "pretrained_weights/detr-dc5-backbone.pth" 4 | NUM_CLASSES: 1235 5 | BACKBONE: "resnet50" 6 | SET_COST_CLASS: 1.0 7 | SET_COST_BBOX: 5.0 8 | SET_COST_GIOU: 2.0 9 | TEST_RESOLUTION: 300 10 | PREDICT_ACTIONS: True 11 | NUM_LAYERS: 4 12 | NUM_HEADS: 8 13 | EMBEDDING_DIM: 512 14 | BLOCK_SIZE: 2060 15 | IMG_FEATURE_SIZE: 256 16 | OUTPUT_SIZE: 512 17 | BOX_EMB_SIZE: 256 18 | EMBEDDING_PDROP: 0.1 19 | RESIDUAL_PDROP: 0.1 20 | ATTENTION_PDROP: 0.1 21 | ADAPTIVE_LR: 1e-3 22 | DATASET: 23 | TRAIN: 24 | TYPE: "sequence" 25 | MODE: "train" 26 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_train.json" 27 | IMAGE_ROOT: "data/interactron/train" 28 | TEST: 29 | TYPE: "sequence" 30 | MODE: "test" 31 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_test.json" 32 | IMAGE_ROOT: "data/interactron/test" 33 | TRAINER: 34 | TYPE: "interactron" 35 | BATCH_SIZE: 16 36 | NUM_WORKERS: 16 37 | MAX_EPOCHS: 2000 38 | SAVE_WINDOW: 500 39 | DETECTOR_LR: 1e-5 40 | SUPERVISOR_LR: 1e-4 41 | BETA1: 0.9 42 | BETA2: 0.95 43 | MOMENTUM: 0.9 44 | GRAD_NORM_CLIP: 1.0 45 | WEIGHT_DECAY: 0.1 46 | OPTIM_TYPE: "Adam" 47 | LR_DECAY: False 48 | WARMUP_TOKENS: 0 49 | FINAL_TOKENS: 0.8e7 50 | OUTPUT_DIRECTORY: "training_results/interactron" 51 | EVALUATOR: 52 | TYPE: "interactive_evaluator" 53 | BATCH_SIZE: 1 54 | NUM_WORKERS: 1 55 | OUTPUT_DIRECTORY: "evaluation_results/interactron" 56 | CHECKPOINT: "pretrained_weights/interactron.pt" 57 | -------------------------------------------------------------------------------- /configs/interactron_random.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: "interactron_random" 3 | WEIGHTS: "pretrained_weights/detr-dc5-backbone.pth" 4 | NUM_CLASSES: 1235 5 | BACKBONE: "resnet50" 6 | SET_COST_CLASS: 1.0 7 | SET_COST_BBOX: 5.0 8 | SET_COST_GIOU: 2.0 9 | TEST_RESOLUTION: 300 10 | PREDICT_ACTIONS: False 11 | NUM_LAYERS: 4 12 | NUM_HEADS: 8 13 | EMBEDDING_DIM: 512 14 | BLOCK_SIZE: 2060 15 | IMG_FEATURE_SIZE: 256 16 | OUTPUT_SIZE: 512 17 | BOX_EMB_SIZE: 256 18 | EMBEDDING_PDROP: 0.1 19 | RESIDUAL_PDROP: 0.1 20 | ATTENTION_PDROP: 0.1 21 | ADAPTIVE_LR: 1e-3 22 | DATASET: 23 | TRAIN: 24 | TYPE: "sequence" 25 | MODE: "train" 26 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_train.json" 27 | IMAGE_ROOT: "data/interactron/train" 28 | TEST: 29 | TYPE: "sequence" 30 | MODE: "test" 31 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_test.json" 32 | IMAGE_ROOT: "data/interactron/test" 33 | TRAINER: 34 | TYPE: "interactron_random" 35 | BATCH_SIZE: 16 36 | NUM_WORKERS: 16 37 | MAX_EPOCHS: 2000 38 | SAVE_WINDOW: 500 39 | DETECTOR_LR: 1e-5 40 | SUPERVISOR_LR: 1e-4 41 | BETA1: 0.9 42 | BETA2: 0.95 43 | MOMENTUM: 0.9 44 | GRAD_NORM_CLIP: 1.0 45 | WEIGHT_DECAY: 0.1 46 | OPTIM_TYPE: "Adam" 47 | LR_DECAY: False 48 | WARMUP_TOKENS: 0 49 | FINAL_TOKENS: 1.0e7 50 | OUTPUT_DIRECTORY: "training_results/interactron_random" 51 | EVALUATOR: 52 | TYPE: "random_policy_evaluator" 53 | BATCH_SIZE: 1 54 | NUM_WORKERS: 1 55 | OUTPUT_DIRECTORY: "evaluation_results/interactron_random" 56 | CHECKPOINT: "pretrained_weights/interactron_random.pt" 57 | -------------------------------------------------------------------------------- /configs/multi_frame_baseline.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: "detr_multiframe" 3 | WEIGHTS: "pretrained_weights/detr-dc5-backbone.pth" 4 | NUM_CLASSES: 1235 5 | BACKBONE: "resnet50" 6 | SET_COST_CLASS: 1.0 7 | SET_COST_BBOX: 5.0 8 | SET_COST_GIOU: 2.0 9 | TEST_RESOLUTION: 300 10 | PREDICT_ACTIONS: False 11 | NUM_LAYERS: 4 12 | NUM_HEADS: 8 13 | EMBEDDING_DIM: 512 14 | BLOCK_SIZE: 2060 15 | IMG_FEATURE_SIZE: 256 16 | OUTPUT_SIZE: 512 17 | BOX_EMB_SIZE: 256 18 | EMBEDDING_PDROP: 0.1 19 | RESIDUAL_PDROP: 0.1 20 | ATTENTION_PDROP: 0.1 21 | DATASET: 22 | TRAIN: 23 | TYPE: "sequence" 24 | MODE: "train" 25 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_train.json" 26 | IMAGE_ROOT: "data/interactron/train" 27 | TEST: 28 | TYPE: "sequence" 29 | MODE: "test" 30 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_test.json" 31 | IMAGE_ROOT: "data/interactron/test" 32 | TRAINER: 33 | TYPE: "direct_supervision" 34 | BATCH_SIZE: 16 35 | NUM_WORKERS: 16 36 | MAX_EPOCHS: 2000 37 | SAVE_WINDOW: 500 38 | LEARNING_RATE: 1e-5 39 | BETA1: 0.9 40 | BETA2: 0.95 41 | MOMENTUM: 0.9 42 | GRAD_NORM_CLIP: 1.0 43 | WEIGHT_DECAY: 0.1 44 | OPTIM_TYPE: "Adam" 45 | LR_DECAY: False 46 | WARMUP_TOKENS: 0 47 | FINAL_TOKENS: 0 48 | OUTPUT_DIRECTORY: "training_results/detr_multiframe" 49 | EVALUATOR: 50 | TYPE: "random_policy_evaluator" 51 | BATCH_SIZE: 1 52 | NUM_WORKERS: 1 53 | OUTPUT_DIRECTORY: "evaluation_results/multi_frame_baseline" 54 | CHECKPOINT: "pretrained_weights/detr_multiframe.pt" 55 | -------------------------------------------------------------------------------- /configs/single_frame_baseline.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: "detr" 3 | WEIGHTS: "pretrained_weights/detr-dc5.pth" 4 | NUM_CLASSES: 1235 5 | FROZEN_WEIGHTS: "pretrained_weights/detr-dc5.pth" 6 | BACKBONE: "resnet50" 7 | SET_COST_CLASS: 1.0 8 | SET_COST_BBOX: 5.0 9 | SET_COST_GIOU: 2.0 10 | TEST_RESOLUTION: 300 11 | DATASET: 12 | TRAIN: 13 | TYPE: "sequence" 14 | MODE: "train" 15 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_train.json" 16 | IMAGE_ROOT: "data/interactron/train" 17 | TEST: 18 | TYPE: "sequence" 19 | MODE: "test" 20 | ANNOTATION_ROOT: "data/interactron/annotations/interactron_v1_test.json" 21 | IMAGE_ROOT: "data/interactron/test" 22 | EVALUATOR: 23 | TYPE: "random_policy_evaluator" 24 | BATCH_SIZE: 1 25 | NUM_WORKERS: 1 26 | OUTPUT_DIRECTORY: "evaluation_results/single_frame_baseline" 27 | CHECKPOINT: "pretrained_weights/single_frame_baseline.pt" 28 | 29 | -------------------------------------------------------------------------------- /data_collection/collect_ithor_tree_data.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import random 3 | import json 4 | import sys 5 | import cv2 6 | import tqdm 7 | import os 8 | 9 | from ai2thor.controller import Controller 10 | from data_collection_utils import ( 11 | pos_to_id, 12 | teleport_to, 13 | take_step, 14 | find_shortes_terminal_path 15 | ) 16 | 17 | # define scenes 18 | kitchens = [f"FloorPlan{i}" for i in range(1, 31)] 19 | living_rooms = [f"FloorPlan{200 + i}" for i in range(1, 31)] 20 | bedrooms = [f"FloorPlan{300 + i}" for i in range(1, 31)] 21 | bathrooms = [f"FloorPlan{400 + i}" for i in range(1, 31)] 22 | 23 | train_scenes = kitchens[:20] + living_rooms[:20] + bedrooms[:20] + bathrooms[:20] 24 | val_scenes = kitchens[20:25] + living_rooms[20:25] +
bedrooms[20:25] + bathrooms[20:25] 25 | test_scenes = kitchens[25:] + living_rooms[25:] + bedrooms[25:] + bathrooms[25:] 26 | 27 | # define dataset collection parameters 28 | TRAIN = False if sys.argv[1] == "test" else True 29 | NUM_ANCHORS = 1000 if TRAIN else 100 30 | NUM_STEPS = 4 31 | ROT_ANGLE = 30 32 | ACTIONS = ["MoveAhead", "MoveBack", "RotateLeft", "RotateRight"] 33 | SCENES = train_scenes + val_scenes if TRAIN else test_scenes 34 | IMG_ROOT = '../data/interactron/train' if TRAIN else '../data/interactron/test' 35 | ANN_PATH = '../data/interactron/annotations/interactron_v1_train.json' if TRAIN \ 36 | else '../data/interactron/annotations/interactron_v1_test.json' 37 | CTRL = Controller( 38 | rotateStepDegrees=ROT_ANGLE, 39 | renderDepthImage=True, 40 | renderInstanceSegmentation=True, 41 | height=300, 42 | width=300, 43 | gridSize=0.25, 44 | snapToGrid=False, 45 | ) 46 | 47 | 48 | def rollout_rec(root_state, state_table, d=0): 49 | # if we reached the end of the rollout return empty dict of next steps 50 | if d >= NUM_STEPS: 51 | return {} 52 | # otherwise generate the data for the steps we can take from this state 53 | if pos_to_id(root_state) in state_table and len(state_table[pos_to_id(root_state)]['actions']) > 0: 54 | steps = state_table[pos_to_id(root_state)]['actions'] 55 | else: 56 | steps = {} 57 | for action in ACTIONS: 58 | new_state = take_step(CTRL, root_state, action) 59 | steps[action] = pos_to_id(new_state) 60 | if pos_to_id(new_state) not in state_table: 61 | state_table[pos_to_id(new_state)] = new_state 62 | state_table[pos_to_id(new_state)]["actions"] = {} 63 | for state_name in steps.values(): 64 | state = state_table[state_name] 65 | next_steps = rollout_rec(state, state_table, d=d+1) 66 | if len(state_table[pos_to_id(state)]["actions"]) == 0: 67 | state_table[pos_to_id(state)]["actions"] = next_steps 68 | return steps 69 | 70 | 71 | def collect_dataset(): 72 | if NUM_ANCHORS % len(SCENES) != 0: 73 | warnings.warn("The number of anchors specified (%d) is not integer divisible by the number" 74 | "of scenes (%d). 
To maintain dataset balance the number of anchors will" 75 | "be reduced to (%d)" % (NUM_ANCHORS, len(SCENES), NUM_ANCHORS // len(SCENES))) 76 | samples_per_scene = NUM_ANCHORS // len(SCENES) 77 | annotations = { 78 | "data": [], 79 | "metadata": { 80 | "actions": ACTIONS, 81 | "max_steps": NUM_STEPS, 82 | "rotation_angle": ROT_ANGLE, 83 | "scenes": SCENES 84 | } 85 | } 86 | for scene in tqdm.tqdm(SCENES): 87 | CTRL.reset(scene=scene) 88 | rotations = [{"x": 0.0, "y": float(theta), "z": 0.0} for theta in range(0, 360, ROT_ANGLE)] 89 | horizons = [0] 90 | standing = [True] 91 | for i in range(samples_per_scene): 92 | # try generating data until you have a complete validated tree 93 | validated_root = False 94 | while not validated_root: 95 | # randomize scene 96 | CTRL.reset(scene=scene) 97 | # find a valid root 98 | num_valid_objects = 0 99 | while num_valid_objects < 3: 100 | # select random starting rotation, horizon and standing state 101 | p = random.choice(CTRL.step(action="GetReachablePositions").metadata["actionReturn"]) 102 | r = random.choice(rotations) 103 | h = random.choice(horizons) 104 | s = random.choice(standing) 105 | root = teleport_to(CTRL, {"pos": p, "rot": r, "hor": h, "stand": s}) 106 | num_valid_objects = len(root["detections"]) 107 | # generate data from this root 108 | root_id = pos_to_id(root) 109 | state_table = {root_id: root} 110 | state_table[root_id]["actions"] = {} 111 | state_table[root_id]["actions"] = rollout_rec(root, state_table) 112 | # check to see that all paths in this tree are at least as long as out max depth 113 | validated_root = find_shortes_terminal_path(root_id, state_table, max_depth=NUM_STEPS) >= NUM_STEPS 114 | # save data 115 | scene_name = "{}_{:05d}".format(scene, i) 116 | os.makedirs("{}/{}".format(IMG_ROOT, scene_name), exist_ok=True) 117 | for state, values in state_table.items(): 118 | cv2.imwrite("{}/{}/{}.jpg".format(IMG_ROOT, scene_name, pos_to_id(values)), values["img"]) 119 | # reformat state table into dataset format 120 | light_state_table = {} 121 | for name, fields in state_table.items(): 122 | light_state_table[name] = { 123 | "pos": fields["pos"], 124 | "rot": fields["rot"], 125 | "hor": fields["hor"], 126 | "stand": fields["stand"], 127 | "detections": fields["detections"], 128 | "actions": fields["actions"] 129 | } 130 | annotations["data"].append({ 131 | "scene_name": scene_name, 132 | "state_table": light_state_table, 133 | "root": root_id 134 | }) 135 | # save annotations 136 | with open(ANN_PATH, 'w') as f: 137 | json.dump(annotations, f) 138 | 139 | # close env 140 | CTRL.stop() 141 | 142 | 143 | if __name__ == '__main__': 144 | collect_dataset() 145 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/interactron/e94a1c3c7bd442708f4a3a6d8bc5f586f597a02d/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/interactive_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import random 5 | from PIL import Image 6 | 7 | from utils.constants import ACTIONS 8 | 9 | 10 | class InteractiveDaatset(Dataset): 11 | """Interactive Dataset.""" 12 | 13 | def __init__(self, img_root, annotations_path, mode="train", transform=None): 14 | """ 15 | Args: 16 | root_dir (string): 
Directory with the train and test images and annotations 17 | test: Flag to indicate if the train or test set is used 18 | """ 19 | assert mode in ["train", "test"], "Only train and test modes supported" 20 | self.mode = mode 21 | with open(annotations_path) as f: 22 | self.annotations = json.load(f) 23 | # remove trailing slash if present 24 | self.img_dir = img_root if img_root[-1] != "/" else img_root[:-1] 25 | self.transform = transform 26 | # interactive 27 | self.idx = -1 28 | self.actions = [] 29 | 30 | def reset(self): 31 | 32 | self.idx += 1 33 | if self.idx >= len(self.annotations["data"]): 34 | self.idx = 0 35 | self.actions = [] 36 | scene = self.annotations["data"][self.idx] 37 | 38 | state_name = scene["root"] 39 | state = scene["state_table"][state_name] 40 | actions = self.actions 41 | frames = [] 42 | masks = [] 43 | object_ids = [] 44 | category_ids = [] 45 | bounding_boxes = [] 46 | initial_img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 47 | for i in range(len(self.actions)+1): 48 | # load image 49 | img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 50 | frame = Image.open(img_path) 51 | # get img dimensions 52 | imgw, imgh = frame.size 53 | masks.append(torch.zeros((imgw, imgh), dtype=torch.long)) 54 | img_object_ids = [] 55 | img_class_ids = [] 56 | img_bounding_boxes = [] 57 | for k, v in state["detections"].items(): 58 | img_object_ids.append(hash(k.encode())) 59 | img_class_ids.append(v["category_id"]+1) 60 | w, h, cw, ch = v["bbox"] 61 | img_bounding_boxes.append([w, h, w+cw, h+ch]) 62 | if len(img_bounding_boxes) != 0: 63 | boxes = torch.tensor(img_bounding_boxes, dtype=torch.float) 64 | targets = { 65 | "boxes": boxes, 66 | "labels": torch.tensor(img_class_ids, dtype=torch.long), 67 | "areas": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]), 68 | "iscrowd": torch.zeros((len(img_object_ids))).bool() 69 | } 70 | else: 71 | targets = None 72 | if self.transform: 73 | frame, targets = self.transform(frame, targets) 74 | frames.append(frame) 75 | bounding_boxes.append(targets["boxes"] if targets is not None else torch.zeros(0, 4)) 76 | object_ids.append(img_object_ids) 77 | category_ids.append(targets["labels"] if targets is not None else torch.zeros(0).long()) 78 | if i < len(actions): 79 | state_name = state["actions"][actions[i]] 80 | state = scene["state_table"][state_name] 81 | 82 | sample = { 83 | 'frames': torch.stack(frames, dim=0).unsqueeze(0), 84 | "masks": torch.stack(masks, dim=0).unsqueeze(0), 85 | "actions": torch.tensor([ACTIONS.index(a) for a in actions], dtype=torch.long).unsqueeze(0), 86 | "category_ids": [category_ids], 87 | "boxes": [bounding_boxes], 88 | "episode_ids": self.idx, 89 | "initial_image_path": [initial_img_path] 90 | } 91 | 92 | return sample 93 | 94 | def step(self, action): 95 | 96 | self.actions.append(ACTIONS[action]) 97 | scene = self.annotations["data"][self.idx] 98 | 99 | state_name = scene["root"] 100 | state = scene["state_table"][state_name] 101 | actions = self.actions 102 | frames = [] 103 | masks = [] 104 | object_ids = [] 105 | category_ids = [] 106 | bounding_boxes = [] 107 | initial_img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 108 | for i in range(len(self.actions)+1): 109 | # load image 110 | img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 111 | frame = Image.open(img_path) 112 | # get img dimensions 113 | imgw, imgh = frame.size 114 | masks.append(torch.zeros((imgw, imgh), 
dtype=torch.long)) 115 | img_object_ids = [] 116 | img_class_ids = [] 117 | img_bounding_boxes = [] 118 | for k, v in state["detections"].items(): 119 | img_object_ids.append(hash(k.encode())) 120 | img_class_ids.append(v["category_id"]+1) 121 | w, h, cw, ch = v["bbox"] 122 | img_bounding_boxes.append([w, h, w+cw, h+ch]) 123 | if len(img_bounding_boxes) != 0: 124 | boxes = torch.tensor(img_bounding_boxes, dtype=torch.float) 125 | targets = { 126 | "boxes": boxes, 127 | "labels": torch.tensor(img_class_ids, dtype=torch.long), 128 | "areas": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]), 129 | "iscrowd": torch.zeros((len(img_object_ids))).bool() 130 | } 131 | else: 132 | targets = None 133 | if self.transform: 134 | frame, targets = self.transform(frame, targets) 135 | frames.append(frame) 136 | bounding_boxes.append(targets["boxes"] if targets is not None else torch.zeros(0, 4)) 137 | object_ids.append(img_object_ids) 138 | category_ids.append(targets["labels"] if targets is not None else torch.zeros(0).long()) 139 | if i < len(actions): 140 | state_name = state["actions"][actions[i]] 141 | state = scene["state_table"][state_name] 142 | 143 | sample = { 144 | 'frames': torch.stack(frames, dim=0).unsqueeze(0), 145 | "masks": torch.stack(masks, dim=0).unsqueeze(0), 146 | "actions": torch.tensor([ACTIONS.index(a) for a in actions], dtype=torch.long).unsqueeze(0), 147 | "object_ids": object_ids, 148 | "category_ids": [category_ids], 149 | "boxes": [bounding_boxes], 150 | "episode_ids": self.idx, 151 | "initial_image_path": [initial_img_path] 152 | } 153 | 154 | return sample 155 | 156 | def __len__(self): 157 | return len(self.annotations["data"]) 158 | 159 | def __getitem__(self, idx): 160 | if torch.is_tensor(idx): 161 | idx = idx.tolist() 162 | 163 | scene = self.annotations["data"][idx] 164 | 165 | state_name = scene["root"] 166 | state = scene["state_table"][state_name] 167 | actions = [random.choice(self.annotations["metadata"]["actions"]) for _ in range(5)] 168 | frames = [] 169 | masks = [] 170 | object_ids = [] 171 | category_ids = [] 172 | bounding_boxes = [] 173 | initial_img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 174 | for i in range(5): 175 | # load image 176 | img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 177 | frame = Image.open(img_path) 178 | # get img dimensions 179 | imgw, imgh = frame.size 180 | masks.append(torch.zeros((imgw, imgh), dtype=torch.long)) 181 | img_object_ids = [] 182 | img_class_ids = [] 183 | img_bounding_boxes = [] 184 | for k, v in state["detections"].items(): 185 | img_object_ids.append(hash(k.encode())) 186 | img_class_ids.append(v["category_id"]) 187 | w, h, cw, ch = v["bbox"] 188 | img_bounding_boxes.append([w, h, w+cw, h+ch]) 189 | # apply transforms to image 190 | if len(img_bounding_boxes) != 0: 191 | boxes = torch.tensor(img_bounding_boxes, dtype=torch.float) 192 | targets = { 193 | "boxes": boxes, 194 | "labels": torch.tensor(img_class_ids, dtype=torch.long), 195 | "areas": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]), 196 | "iscrowd": torch.zeros((len(img_object_ids))).bool() 197 | } 198 | else: 199 | targets = None 200 | if self.transform: 201 | frame, targets = self.transform(frame, targets) 202 | frames.append(frame) 203 | bounding_boxes.append(targets["boxes"] if targets is not None else torch.zeros(0, 4)) 204 | object_ids.append(img_object_ids) 205 | category_ids.append(targets["labels"] if targets is not None else torch.zeros(0).long()) 206 | if i 
< 4: 207 | if self.mode == "test": 208 | state_name = state["actions"][actions[i]] 209 | else: 210 | state_name = random.choice(list(scene["state_table"])) 211 | state = scene["state_table"][state_name] 212 | 213 | sample = { 214 | 'frames': frames, 215 | "masks": masks, 216 | "actions": [ACTIONS.index(a) for a in actions], 217 | "object_ids": object_ids, 218 | "category_ids": category_ids, 219 | "boxes": bounding_boxes, 220 | "episode_ids": idx, 221 | "initial_image_path": initial_img_path 222 | } 223 | 224 | return sample 225 | 226 | -------------------------------------------------------------------------------- /datasets/sequence_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import random 5 | from PIL import Image 6 | 7 | from utils.constants import ACTIONS 8 | 9 | 10 | class SequenceDataset(Dataset): 11 | """Sequence Rollout Dataset.""" 12 | 13 | def __init__(self, img_root, annotations_path, mode="train", transform=None): 14 | """ 15 | Args: 16 | root_dir (string): Directory with the train and test images and annotations 17 | test: Flag to indicate if the train or test set is used 18 | """ 19 | assert mode in ["train", "test"], "Only train and test modes supported" 20 | self.mode = mode 21 | with open(annotations_path) as f: 22 | self.annotations = json.load(f) 23 | # remove trailing slash if present 24 | self.img_dir = img_root if img_root[-1] != "/" else img_root[:-1] 25 | self.transform = transform 26 | 27 | def __len__(self): 28 | return len(self.annotations["data"]) 29 | 30 | def __getitem__(self, idx, actions=None): 31 | if torch.is_tensor(idx): 32 | idx = idx.tolist() 33 | 34 | scene = self.annotations["data"][idx] 35 | 36 | # seed the random generator 37 | if self.mode == "test" and actions is None: 38 | actions = ['RotateLeft', 'MoveAhead', 'RotateLeft', 'MoveBack', 'RotateRight'] 39 | 40 | state_name = scene["root"] 41 | state = scene["state_table"][state_name] 42 | if actions is None: 43 | actions = [random.choice(self.annotations["metadata"]["actions"]) for _ in range(5)] 44 | frames = [] 45 | masks = [] 46 | object_ids = [] 47 | category_ids = [] 48 | bounding_boxes = [] 49 | initial_img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 50 | for i in range(5): 51 | # load image 52 | img_path = "{}/{}/{}.jpg".format(self.img_dir, scene["scene_name"], state_name) 53 | frame = Image.open(img_path) 54 | # get img dimensions 55 | imgw, imgh = frame.size 56 | masks.append(torch.zeros((imgw, imgh), dtype=torch.long)) 57 | img_object_ids = [] 58 | img_class_ids = [] 59 | img_bounding_boxes = [] 60 | for k, v in state["detections"].items(): 61 | img_object_ids.append(hash(k.encode())) 62 | img_class_ids.append(v["category_id"]+1) 63 | w, h, cw, ch = v["bbox"] 64 | img_bounding_boxes.append([w, h, w+cw, h+ch]) 65 | # apply transforms to image 66 | if len(img_bounding_boxes) != 0: 67 | boxes = torch.tensor(img_bounding_boxes, dtype=torch.float) 68 | targets = { 69 | "boxes": boxes, 70 | "labels": torch.tensor(img_class_ids, dtype=torch.long), 71 | "areas": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]), 72 | "iscrowd": torch.zeros((len(img_object_ids))).bool() 73 | } 74 | else: 75 | targets = None 76 | if self.transform: 77 | frame, targets = self.transform(frame, targets) 78 | frames.append(frame) 79 | bounding_boxes.append(targets["boxes"] if targets is not None else torch.zeros(0, 4)) 80 | 
object_ids.append(img_object_ids) 81 | category_ids.append(targets["labels"] if targets is not None else torch.zeros(0).long()) 82 | if i < 4: 83 | state_name = state["actions"][actions[i]] 84 | state = scene["state_table"][state_name] 85 | 86 | sample = { 87 | 'frames': frames, 88 | "masks": masks, 89 | "actions": [ACTIONS.index(a) for a in actions], 90 | "object_ids": object_ids, 91 | "category_ids": category_ids, 92 | "boxes": bounding_boxes, 93 | "episode_ids": idx, 94 | "initial_image_path": initial_img_path 95 | } 96 | 97 | return sample 98 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/interactron/e94a1c3c7bd442708f4a3a6d8bc5f586f597a02d/engine/__init__.py -------------------------------------------------------------------------------- /engine/direct_supervision_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Direct Supervision Random Training Loop 3 | The model is trained on random sequences of data. 4 | """ 5 | 6 | import math 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import os 11 | from datetime import datetime 12 | 13 | import torch 14 | from torch.utils.data.dataloader import DataLoader 15 | 16 | from datasets.sequence_dataset import SequenceDataset 17 | from utils.transform_utis import transform, train_transform 18 | from utils.logging_utils import TBLogger 19 | from utils.storage_utils import collate_fn 20 | 21 | 22 | class DirectSupervisionTrainer: 23 | 24 | def __init__(self, model, config, evaluator=None): 25 | self.model = model 26 | self.config = config 27 | self.evaluator = evaluator 28 | 29 | # set up logging and saving 30 | self.out_dir = os.path.join(self.config.TRAINER.OUTPUT_DIRECTORY, datetime.now().strftime("%m-%d-%Y:%H:%M:%S")) 31 | self.logger = TBLogger(os.path.join(self.out_dir, "logs")) 32 | self.model.set_logger(self.logger) 33 | self.checkpoint_path = os.path.join(self.out_dir, "detector.pt") 34 | self.saved_checkpoints = None 35 | 36 | self.train_dataset = SequenceDataset(config.DATASET.TRAIN.IMAGE_ROOT, config.DATASET.TRAIN.ANNOTATION_ROOT, 37 | config.DATASET.TRAIN.MODE, transform=train_transform) 38 | self.test_dataset = SequenceDataset(config.DATASET.TEST.IMAGE_ROOT, config.DATASET.TEST.ANNOTATION_ROOT, 39 | config.DATASET.TEST.MODE, transform=transform) 40 | 41 | # take over whatever gpus are on the system 42 | self.device = 'cpu' 43 | if torch.cuda.is_available(): 44 | self.device = torch.cuda.current_device() 45 | self.model = torch.nn.DataParallel(self.model).to(self.device) 46 | 47 | def record_checkpoint(self, w=1.0): 48 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 49 | raw_parameters = raw_model.state_dict() 50 | if self.saved_checkpoints is None: 51 | print("New Save", w) 52 | self.saved_checkpoints = {k: w * v for k, v in raw_parameters.items()} 53 | else: 54 | print("Add on save", w) 55 | for param_name, weight in raw_parameters.items(): 56 | self.saved_checkpoints[param_name] += w * weight 57 | 58 | def save_checkpoint(self): 59 | if self.saved_checkpoints is None: 60 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 61 | raw_parameters = raw_model.state_dict() 62 | else: 63 | raw_parameters = self.saved_checkpoints 64 | torch.save({"model": raw_parameters}, self.checkpoint_path) 65 | 66 | def train(self): 67 | model, config = 
self.model, self.config.TRAINER 68 | raw_model = model.module if hasattr(self.model, "module") else model 69 | optimizer = torch.optim.Adam(raw_model.get_optimizer_groups(config), lr=config.LEARNING_RATE) 70 | 71 | def run_epoch(split): 72 | is_train = split == 'train' 73 | loader = DataLoader(self.train_dataset if is_train else self.test_dataset, shuffle=is_train, 74 | pin_memory=True, batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS, 75 | collate_fn=collate_fn) 76 | 77 | loss_list = [] 78 | pbar = tqdm(enumerate(loader), total=len(loader)) 79 | for it, data in pbar: 80 | 81 | # place data on the correct device 82 | data["frames"] = data["frames"].to(self.device) 83 | data["masks"] = data["masks"].to(self.device) 84 | data["category_ids"] = [[j.to(self.device) for j in i] for i in data["category_ids"]] 85 | data["boxes"] = [[j.to(self.device) for j in i] for i in data["boxes"]] 86 | 87 | # forward the model 88 | predictions, losses = model(data) 89 | loss = losses["loss_detector_ce"] + 5 * losses["loss_detector_bbox"] \ 90 | + 2 * losses["loss_detector_giou"] 91 | 92 | # log the losses 93 | for name, loss_comp in losses.items(): 94 | self.logger.add_value("{}/{}".format("Train" if is_train else "Test", name), loss_comp.mean()) 95 | self.logger.add_value("{}/Total Loss".format("Train" if is_train else "Test"), loss.mean()) 96 | loss_list.append(loss.item()) 97 | 98 | if is_train: 99 | 100 | # backprop and update the parameters 101 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRAD_NORM_CLIP) 102 | optimizer.step() 103 | optimizer.zero_grad() 104 | 105 | # decay the learning rate based on our progress 106 | if config.LR_DECAY: 107 | self.tokens += data["frames"].shape[0] 108 | if self.tokens < config.WARMUP_TOKENS: 109 | # linear warmup 110 | lr_mult = float(self.tokens) / float(max(1, config.WARMUP_TOKENS)) 111 | else: 112 | # cosine learning rate decay 113 | progress = float(self.tokens - config.WARMUP_TOKENS) / \ 114 | float(max(1, config.FINAL_TOKENS - config.WARMUP_TOKENS)) 115 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 116 | lr = config.LEARNING_RATE * lr_mult 117 | for param_group in optimizer.param_groups: 118 | param_group['lr'] = lr 119 | else: 120 | lr = config.LEARNING_RATE 121 | 122 | # report progress 123 | pbar.set_description( 124 | f"epoch {epoch} iter {it}: train loss {float(np.mean(loss_list)):.5f}. 
lr {lr:e}" 125 | ) 126 | 127 | if not is_train: 128 | test_loss = float(np.mean(loss_list)) 129 | return test_loss 130 | 131 | def run_evaluation(): 132 | test_loss = run_epoch('test') 133 | mAP_50, mAP, tps, fps, fns = self.evaluator.evaluate(save_results=False) 134 | self.logger.add_value("Test/TP", tps) 135 | self.logger.add_value("Test/FP", fps) 136 | self.logger.add_value("Test/FN", fns) 137 | self.logger.add_value("Test/mAP_50", mAP_50) 138 | self.logger.add_value("Test/mAP", mAP) 139 | model.zero_grad() 140 | return mAP 141 | 142 | self.tokens = 0 # counter used for learning rate decay 143 | run_evaluation() 144 | self.logger.log_values() 145 | for epoch in range(1, config.MAX_EPOCHS): 146 | run_epoch('train') 147 | if epoch % 1 == 0 and self.test_dataset is not None and self.evaluator is not None: 148 | run_evaluation() 149 | self.logger.log_values() 150 | 151 | if self.test_dataset is not None and config.MAX_EPOCHS - epoch <= config.SAVE_WINDOW: 152 | self.record_checkpoint(w=1/config.SAVE_WINDOW) 153 | self.save_checkpoint() 154 | -------------------------------------------------------------------------------- /engine/interactive_evaluator.py: -------------------------------------------------------------------------------- 1 | import torchvision.ops 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | from datetime import datetime 6 | from PIL import ImageDraw, ImageFont 7 | import torch 8 | 9 | from utils.constants import THOR_CLASS_IDS, tlvis_classes 10 | from utils.detection_utils import match_predictions_to_detections 11 | from utils.transform_utis import transform, inv_transform 12 | from models.detr_models.util.box_ops import box_cxcywh_to_xyxy 13 | from datasets.interactive_dataset import InteractiveDaatset 14 | 15 | 16 | class InteractiveEvaluator: 17 | 18 | def __init__(self, model, config, load_checkpoint=False): 19 | self.model = model 20 | if load_checkpoint: 21 | self.model.load_state_dict( 22 | torch.load(config.EVALUATOR.CHECKPOINT, map_location=torch.device('cpu'))['model'], strict=False) 23 | self.test_dataset = InteractiveDaatset(config.DATASET.TEST.IMAGE_ROOT, config.DATASET.TEST.ANNOTATION_ROOT, 24 | config.DATASET.TEST.MODE, transform=transform) 25 | self.config = config 26 | 27 | # take over whatever gpus are on the system 28 | self.device = 'cpu' 29 | if torch.cuda.is_available(): 30 | self.device = torch.cuda.current_device() 31 | self.model.to(self.device) 32 | 33 | self.out_dir = config.EVALUATOR.OUTPUT_DIRECTORY + "/" + datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + "/" 34 | 35 | def evaluate(self, save_results=False): 36 | 37 | # prepare data folder if we are saving 38 | if save_results: 39 | os.makedirs(self.out_dir + "images/", exist_ok=True) 40 | 41 | model, config = self.model, self.config.EVALUATOR 42 | loader = self.test_dataset 43 | 44 | detections = [] 45 | for i in range(len(self.test_dataset)): 46 | 47 | model.eval() 48 | data = loader.reset() 49 | 50 | # place data on the correct device 51 | data["frames"] = data["frames"].to(self.device) 52 | data["masks"] = data["masks"].to(self.device) 53 | data["category_ids"] = [[j.to(self.device) for j in i] for i in data["category_ids"]] 54 | data["boxes"] = [[j.to(self.device) for j in i] for i in data["boxes"]] 55 | 56 | for i in range(4): 57 | action = model.get_next_action(data) 58 | data = loader.step(action) 59 | # place data on the correct device 60 | data["frames"] = data["frames"].to(self.device) 61 | data["masks"] = data["masks"].to(self.device) 62 | 
data["category_ids"] = [[j.to(self.device) for j in i] for i in data["category_ids"]] 63 | data["boxes"] = [[j.to(self.device) for j in i] for i in data["boxes"]] 64 | 65 | # forward the model 66 | predictions = model.predict(data) 67 | 68 | with torch.no_grad(): 69 | for b in range(predictions["pred_boxes"].shape[0]): 70 | img_detections = [] 71 | # get predictions and labels for this image 72 | pred_boxes = box_cxcywh_to_xyxy(predictions["pred_boxes"][b][0]) 73 | pred_scores, pred_cats = predictions["pred_logits"][b][0].softmax(dim=-1).max(dim=-1) 74 | gt_boxes = box_cxcywh_to_xyxy(data["boxes"][b][0]) 75 | gt_cats = data["category_ids"][b][0] 76 | # remove background predictions 77 | non_background_idx = pred_cats != 1235 78 | pred_boxes = pred_boxes[non_background_idx] 79 | pred_cats = pred_cats[non_background_idx] 80 | pred_scores = pred_scores[non_background_idx] 81 | # perform nms 82 | pruned_idxs = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.5) 83 | pred_cats = pred_cats[pruned_idxs] 84 | pred_boxes = pred_boxes[pruned_idxs] 85 | pred_scores = pred_scores[pruned_idxs] 86 | # get sets of categories of predictions and labels 87 | pred_cat_set = set([int(c) for c in pred_cats]) 88 | gt_cat_set = set([int(c) for c in gt_cats]) 89 | pred_only_cat_set = set(THOR_CLASS_IDS).intersection(pred_cat_set - gt_cat_set) 90 | # add each prediction to the list of detections 91 | for cat in gt_cat_set: 92 | if torch.any(pred_cats == cat): 93 | cat_pred_boxes = pred_boxes[pred_cats == cat] 94 | cat_pred_scores = pred_scores[pred_cats == cat] 95 | cat_gt_boxes = gt_boxes[gt_cats == cat] 96 | cat_ious = torchvision.ops.box_iou(cat_pred_boxes, cat_gt_boxes) 97 | cat_best_ious, cat_best_match_idx = match_predictions_to_detections(cat_ious) 98 | for i in range(cat_ious.shape[0]): 99 | if torch.any(cat_best_match_idx == i): 100 | img_detections.append({ 101 | "iou": cat_ious[i].max().item(), 102 | "category_match": True, 103 | "type": "tp", 104 | "pred_cat": cat, 105 | "pred_score": cat_pred_scores[i].item(), 106 | "box": [coord.item() for coord in cat_pred_boxes[i]], 107 | "area": ((cat_pred_boxes[i][2] - cat_pred_boxes[i][0]) * 108 | (cat_pred_boxes[i][3] - cat_pred_boxes[i][1])).item(), 109 | "img": data["initial_image_path"][b] 110 | }) 111 | else: 112 | img_detections.append({ 113 | "iou": cat_ious[i].max().item(), 114 | "category_match": True, 115 | "type": "fp", 116 | "pred_cat": cat, 117 | "pred_score": cat_pred_scores[i].item(), 118 | "box": [coord.item() for coord in cat_pred_boxes[i]], 119 | "area": ((cat_pred_boxes[i][2] - cat_pred_boxes[i][0]) * 120 | (cat_pred_boxes[i][3] - cat_pred_boxes[i][1])).item(), 121 | "img": data["initial_image_path"][b] 122 | }) 123 | for j in range(cat_ious.shape[1]): 124 | if cat_best_ious[j] == 0.0: 125 | img_detections.append({ 126 | "iou": 0.0, 127 | "category_match": False, 128 | "type": "fn", 129 | "pred_cat": cat, 130 | "pred_score": 0.0, 131 | "box": [coord.item() for coord in cat_gt_boxes[j]], 132 | "area": ((cat_gt_boxes[j][2] - cat_gt_boxes[j][0]) * 133 | (cat_gt_boxes[j][3] - cat_gt_boxes[j][1])).item(), 134 | "img": data["initial_image_path"][b] 135 | }) 136 | else: 137 | cat_gt_boxes = gt_boxes[gt_cats == cat] 138 | for j in range(cat_gt_boxes.shape[0]): 139 | img_detections.append({ 140 | "iou": 0.0, 141 | "category_match": False, 142 | "type": "fn", 143 | "pred_cat": cat, 144 | "pred_score": 0.0, 145 | "box": [coord.item() for coord in cat_gt_boxes[j]], 146 | "area": ((cat_gt_boxes[j][2] - cat_gt_boxes[j][0]) * 147 | 
(cat_gt_boxes[j][3] - cat_gt_boxes[j][1])).item(), 148 | "img": data["initial_image_path"][b] 149 | }) 150 | for cat in pred_only_cat_set: 151 | cat_pred_boxes = pred_boxes[pred_cats == cat] 152 | cat_pred_scores = pred_scores[pred_cats == cat] 153 | for i in range(cat_pred_scores.shape[0]): 154 | img_detections.append({ 155 | "iou": 0.0, 156 | "category_match": False, 157 | "type": "fp", 158 | "pred_cat": cat, 159 | "pred_score": cat_pred_scores[i].item(), 160 | "box": [coord.item() for coord in cat_pred_boxes[i]], 161 | "area": ((cat_pred_boxes[i][2] - cat_pred_boxes[i][0]) * 162 | (cat_pred_boxes[i][3] - cat_pred_boxes[i][1])).item(), 163 | "img": data["initial_image_path"][b], 164 | }) 165 | detections = detections + img_detections 166 | if save_results: 167 | img = inv_transform(data["frames"][b][0].detach().cpu()).resize((1200, 1200)) 168 | font = ImageFont.load_default() 169 | draw = ImageDraw.Draw(img) 170 | for det in img_detections: 171 | color = None 172 | if det["type"] == "tp": 173 | if det["iou"] >= 0.5: 174 | color = "blue" 175 | else: 176 | color = "black" 177 | if det["type"] == "fn": 178 | continue 179 | if det["type"] == "fp" and det["pred_score"] > 0.5: 180 | continue 181 | if color is not None: 182 | draw.rectangle([1200 * c for c in det["box"]], outline=color, width=2) 183 | text = tlvis_classes[det["pred_cat"]] 184 | x, y = 1200 * det["box"][0], 1200 * (det["box"][1] - 0.02) 185 | w, h = font.getsize(text) 186 | draw.rectangle((x, y, x + w, y + h), fill=color) 187 | draw.text((x, y), text, fill="white", font=font) 188 | img_root = self.out_dir + "images/" 189 | img.save(img_root + img_detections[0]["img"].split("/")[-1]) 190 | 191 | tps = [x for x in detections if x["type"] == "tp"] 192 | fps = [x for x in detections if x["type"] == "fp"] 193 | fns = [x for x in detections if x["type"] == "fn"] 194 | 195 | ap_50 = self.compute_ap(detections, nsamples=100, iou_thresholds=[0.5]) 196 | ap_75 = self.compute_ap(detections, nsamples=100, iou_thresholds=[0.75]) 197 | ap = self.compute_ap(detections, nsamples=100, iou_thresholds=list(np.arange(0.5, 1.0, 0.05))) 198 | ap_small = self.compute_ap(detections, nsamples=100, iou_thresholds=list(np.arange(0.5, 1.0, 0.05)), 199 | min_area=0.0, max_area=32**2/300**2) 200 | ap_medium = self.compute_ap(detections, nsamples=100, iou_thresholds=list(np.arange(0.5, 1.0, 0.05)), 201 | min_area=32**2/300**2, max_area=96**2/300**2) 202 | ap_large = self.compute_ap(detections, nsamples=100, iou_thresholds=list(np.arange(0.5, 1.0, 0.05)), 203 | min_area=96**2/300**2, max_area=1.0) 204 | 205 | if not save_results: 206 | return ap_50, ap, len(tps), len(fps), len(fns) 207 | 208 | print("AP_50:", ap_50, "AP_75", ap_75, "AP", ap, "AP_small", ap_small, "AP_medium", ap_medium, "AP_large", ap_large) 209 | 210 | @staticmethod 211 | def compute_ap(detections, nsamples=100, iou_thresholds=[0.5], min_area=0.0, max_area=1.0): 212 | aps = [] 213 | detections = [d for d in detections if min_area < d["area"] < max_area] 214 | 215 | # compute ap for every iou threshold specified 216 | for iou_thresh in iou_thresholds: 217 | tps = [d for d in detections if d["type"] == "tp"] 218 | fps = [d for d in detections if d["type"] == "fp"] 219 | fns = [d for d in detections if d["type"] == "fn"] 220 | p = [] 221 | r = [] 222 | 223 | # move all detections with an iou under the threshold from the tp set to the fp set 224 | i = 0 225 | while i < len(tps): 226 | if tps[i]["iou"] < iou_thresh: 227 | fps.append(tps.pop(i)) 228 | else: 229 | i += 1 230 | 231 | # compute PR 
curve for various confidence levels 232 | for conf_thresh in np.arange(0.0, 1.0, 1.0 / nsamples): 233 | # remove all prediction with a confidence bellow the threshold 234 | i = 0 235 | while i < len(tps): 236 | if tps[i]["pred_score"] < conf_thresh: 237 | tps.pop(i) 238 | else: 239 | i += 1 240 | i = 0 241 | while i < len(fps): 242 | if fps[i]["pred_score"] < conf_thresh: 243 | fps.pop(i) 244 | else: 245 | i += 1 246 | 247 | # compute p and r values for current confidence threshold 248 | p.append(0 if len(tps) == 0 else len(tps) / (len(tps) + len(fps))) 249 | r.append(0 if len(tps) == 0 else len(tps) / (len(tps) + len(fns))) 250 | 251 | # compute AP using 11 Point Interpolation of PR Curve 252 | p = [0.0] + p 253 | r = [r[0] + 0.000001] + r 254 | interpolation_samples = [] 255 | r_idx = 0 256 | for r_cutoff in np.arange(1.0, -0.0001, -0.01): 257 | while r_idx < len(r)-1 and r[r_idx] > r_cutoff: 258 | r_idx += 1 259 | interpolation_samples.append(max(p[:r_idx+1])) 260 | aps.append(np.mean(interpolation_samples)) 261 | 262 | return np.mean(aps) 263 | 264 | @staticmethod 265 | def compute_pr(detections, nsamples=100, iou_thresh=0.5, min_area=0.0, max_area=1.0): 266 | p = [] 267 | r = [] 268 | detections = [d for d in detections if min_area < d["area"] < max_area] 269 | tps = [d for d in detections if d["type"] == "tp"] 270 | fps = [d for d in detections if d["type"] == "fp"] 271 | fns = [d for d in detections if d["type"] == "fn"] 272 | i = 0 273 | while i < len(tps): 274 | if tps[i]["iou"] < iou_thresh: 275 | fps.append(tps.pop(i)) 276 | else: 277 | i += 1 278 | for conf_thresh in np.arange(0.0, 1.0, 1.0/nsamples): 279 | i = 0 280 | while i < len(tps): 281 | if tps[i]["pred_score"] < conf_thresh: 282 | tps.pop(i) 283 | else: 284 | i += 1 285 | i = 0 286 | while i < len(fps): 287 | if fps[i]["pred_score"] < conf_thresh: 288 | fps.pop(i) 289 | else: 290 | i += 1 291 | p.append(0 if len(tps) == 0 else len(tps) / (len(tps) + len(fps))) 292 | r.append(0 if len(tps) == 0 else len(tps) / (len(tps) + len(fns))) 293 | 294 | return p, r 295 | 296 | -------------------------------------------------------------------------------- /engine/interactron_random_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interactron Random Training Loop 3 | The interactorn model is trained on random sequences of data. 
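Two Adam optimizers are used: one for the DETR detector and one for the fusion (supervisor)
transformer, with an optional linear-warmup / cosine-decay schedule applied to the supervisor
learning rate.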
4 | """ 5 | 6 | import math 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import os 11 | from datetime import datetime 12 | 13 | import torch 14 | from torch.utils.data.dataloader import DataLoader 15 | 16 | from datasets.sequence_dataset import SequenceDataset 17 | from utils.transform_utis import transform, train_transform 18 | from utils.logging_utils import TBLogger 19 | from utils.storage_utils import collate_fn 20 | 21 | 22 | class InteractronRandomTrainer: 23 | 24 | def __init__(self, model, config, evaluator=None): 25 | self.model = model 26 | self.config = config 27 | self.evaluator = evaluator 28 | 29 | # set up logging and saving 30 | self.out_dir = os.path.join(self.config.TRAINER.OUTPUT_DIRECTORY, datetime.now().strftime("%m-%d-%Y:%H:%M:%S")) 31 | # os.makedirs(self.out_dir, exist_ok=True) 32 | self.logger = TBLogger(os.path.join(self.out_dir, "logs")) 33 | self.model.set_logger(self.logger) 34 | self.checkpoint_path = os.path.join(self.out_dir, "detector.pt") 35 | self.saved_checkpoints = None 36 | 37 | self.train_dataset = SequenceDataset(config.DATASET.TRAIN.IMAGE_ROOT, config.DATASET.TRAIN.ANNOTATION_ROOT, 38 | config.DATASET.TRAIN.MODE, transform=train_transform) 39 | self.test_dataset = SequenceDataset(config.DATASET.TEST.IMAGE_ROOT, config.DATASET.TEST.ANNOTATION_ROOT, 40 | config.DATASET.TEST.MODE, transform=transform) 41 | 42 | # take over whatever gpus are on the system 43 | self.device = 'cpu' 44 | if torch.cuda.is_available(): 45 | self.device = torch.cuda.current_device() 46 | self.model = torch.nn.DataParallel(self.model).to(self.device) 47 | 48 | def record_checkpoint(self, w=1.0): 49 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 50 | raw_parameters = raw_model.state_dict() 51 | if self.saved_checkpoints is None: 52 | print("New Save", w) 53 | self.saved_checkpoints = {k: w * v for k, v in raw_parameters.items()} 54 | else: 55 | print("Add on save", w) 56 | for param_name, weight in raw_parameters.items(): 57 | self.saved_checkpoints[param_name] += w * weight 58 | 59 | def save_checkpoint(self): 60 | if self.saved_checkpoints is None: 61 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 62 | raw_parameters = raw_model.state_dict() 63 | else: 64 | raw_parameters = self.saved_checkpoints 65 | torch.save({"model": raw_parameters}, self.checkpoint_path) 66 | 67 | def train(self): 68 | model, config = self.model, self.config.TRAINER 69 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 70 | detector_optimizer = torch.optim.Adam(raw_model.detector.parameters(), lr=1e-5) 71 | supervisor_optimizer = torch.optim.Adam(raw_model.fusion.parameters(), lr=1e-4) 72 | model.train() 73 | 74 | def run_epoch(split): 75 | is_train = split == 'train' 76 | loader = DataLoader(self.train_dataset if is_train else self.test_dataset, shuffle=is_train, 77 | pin_memory=True, batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS, 78 | collate_fn=collate_fn) 79 | 80 | loss_list = [] 81 | pbar = tqdm(enumerate(loader), total=len(loader)) 82 | for it, data in pbar: 83 | 84 | # place data on the correct device 85 | data["frames"] = data["frames"].to(self.device) 86 | data["masks"] = data["masks"].to(self.device) 87 | data["category_ids"] = [[j.to(self.device) for j in i] for i in data["category_ids"]] 88 | data["boxes"] = [[j.to(self.device) for j in i] for i in data["boxes"]] 89 | 90 | # forward the model 91 | predictions, losses = model(data, train=is_train) 92 | detector_loss = 
losses["loss_detector_ce"] + 5*losses["loss_detector_giou"] + \ 93 | 2*losses["loss_detector_bbox"] 94 | supervisor_loss = losses["loss_supervisor_ce"] +5*losses["loss_supervisor_giou"] + \ 95 | 2*losses["loss_supervisor_bbox"] 96 | 97 | # log the losses 98 | for name, loss_comp in losses.items(): 99 | self.logger.add_value("{}/{}".format("Train" if is_train else "Test", name), loss_comp.mean()) 100 | total_loss = detector_loss + supervisor_loss 101 | self.logger.add_value("{}/Total Loss".format("Train" if is_train else "Test"), total_loss.mean()) 102 | loss_list.append(total_loss.item()) 103 | 104 | if is_train: 105 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRAD_NORM_CLIP) 106 | detector_optimizer.step() 107 | supervisor_optimizer.step() 108 | detector_optimizer.zero_grad() 109 | supervisor_optimizer.zero_grad() 110 | 111 | # decay the learning rate based on our progress 112 | if config.LR_DECAY: 113 | self.tokens += data["frames"].shape[0] * data["frames"].shape[1] 114 | if self.tokens < config.WARMUP_TOKENS: 115 | # linear warmup 116 | lr_mult = float(self.tokens) / float(max(1, config.WARMUP_TOKENS)) 117 | else: 118 | # cosine learning rate decay 119 | progress = float(self.tokens - config.WARMUP_TOKENS) / \ 120 | float(max(1, config.FINAL_TOKENS - config.WARMUP_TOKENS)) 121 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 122 | lr = config.SUPERVISOR_LR * lr_mult 123 | for param_group in supervisor_optimizer.param_groups: 124 | param_group['lr'] = lr 125 | else: 126 | lr = config.SUPERVISOR_LR 127 | 128 | # report progress 129 | pbar.set_description( 130 | f"epoch {epoch} iter {it}: train loss {float(np.mean(loss_list)):.5f}. lr {lr:e}" 131 | ) 132 | 133 | if not is_train: 134 | test_loss = float(np.mean(loss_list)) 135 | return test_loss 136 | 137 | def run_evaluation(): 138 | test_loss = run_epoch('test') 139 | mAP_50, mAP, tps, fps, fns = self.evaluator.evaluate(save_results=False) 140 | self.logger.add_value("Test/TP", tps) 141 | self.logger.add_value("Test/FP", fps) 142 | self.logger.add_value("Test/FN", fns) 143 | self.logger.add_value("Test/mAP_50", mAP_50) 144 | self.logger.add_value("Test/mAP", mAP) 145 | detector_optimizer.zero_grad() 146 | supervisor_optimizer.zero_grad() 147 | return mAP 148 | 149 | self.tokens = 0 # counter used for learning rate decay 150 | run_evaluation() 151 | self.logger.log_values() 152 | for epoch in range(1, config.MAX_EPOCHS): 153 | run_epoch('train') 154 | if epoch % 1 == 0 and self.test_dataset is not None and self.evaluator is not None: 155 | run_evaluation() 156 | self.logger.log_values() 157 | 158 | if self.test_dataset is not None and config.MAX_EPOCHS - epoch <= config.SAVE_WINDOW: 159 | self.record_checkpoint(w=1 / config.SAVE_WINDOW) 160 | self.save_checkpoint() 161 | 162 | -------------------------------------------------------------------------------- /engine/interactron_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interactron Training Loop 3 | The interactorn model is trained on random sequences of data and supervised to pursue good paths 4 | """ 5 | 6 | import math 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import os 11 | from datetime import datetime 12 | 13 | import torch 14 | from torch.utils.data.dataloader import DataLoader 15 | 16 | from datasets.sequence_dataset import SequenceDataset 17 | from utils.transform_utis import transform, train_transform 18 | from utils.logging_utils import TBLogger 19 | from 
utils.storage_utils import collate_fn 20 | 21 | 22 | class InteractronTrainer: 23 | 24 | def __init__(self, model, config, evaluator=None): 25 | self.model = model 26 | self.config = config 27 | self.evaluator = evaluator 28 | 29 | # set up logging and saving 30 | self.out_dir = os.path.join(self.config.TRAINER.OUTPUT_DIRECTORY, datetime.now().strftime("%m-%d-%Y:%H:%M:%S")) 31 | os.makedirs(self.out_dir, exist_ok=True) 32 | self.logger = TBLogger(os.path.join(self.out_dir, "logs")) 33 | self.model.set_logger(self.logger) 34 | self.checkpoint_path = os.path.join(self.out_dir, "detector.pt") 35 | self.saved_checkpoints = None 36 | 37 | self.train_dataset = SequenceDataset(config.DATASET.TRAIN.IMAGE_ROOT, config.DATASET.TRAIN.ANNOTATION_ROOT, 38 | config.DATASET.TRAIN.MODE, transform=train_transform) 39 | self.test_dataset = SequenceDataset(config.DATASET.TEST.IMAGE_ROOT, config.DATASET.TEST.ANNOTATION_ROOT, 40 | config.DATASET.TEST.MODE, transform=transform) 41 | 42 | # take over whatever gpus are on the system 43 | self.device = 'cpu' 44 | if torch.cuda.is_available(): 45 | self.device = torch.cuda.current_device() 46 | self.model = torch.nn.DataParallel(self.model).to(self.device) 47 | 48 | def record_checkpoint(self, w=1.0): 49 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 50 | raw_parameters = raw_model.state_dict() 51 | if self.saved_checkpoints is None: 52 | print("New Save", w) 53 | self.saved_checkpoints = {k: w * v for k, v in raw_parameters.items()} 54 | else: 55 | print("Add on save", w) 56 | for param_name, weight in raw_parameters.items(): 57 | self.saved_checkpoints[param_name] += w * weight 58 | 59 | def save_checkpoint(self): 60 | if self.saved_checkpoints is None: 61 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 62 | raw_parameters = raw_model.state_dict() 63 | else: 64 | raw_parameters = self.saved_checkpoints 65 | torch.save({"model": raw_parameters}, self.checkpoint_path) 66 | 67 | def train(self): 68 | 69 | model, config = self.model, self.config.TRAINER 70 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 71 | detector_optimizer = torch.optim.Adam(raw_model.detector.parameters(), lr=config.DETECTOR_LR) 72 | supervisor_optimizer = torch.optim.Adam(raw_model.fusion.parameters(), lr=config.SUPERVISOR_LR) 73 | model.train() 74 | 75 | def run_epoch(split): 76 | is_train = split == 'train' 77 | 78 | loader = DataLoader(self.train_dataset if is_train else self.test_dataset, shuffle=is_train, 79 | pin_memory=True, batch_size=config.BATCH_SIZE, num_workers=config.NUM_WORKERS, 80 | collate_fn=collate_fn) 81 | 82 | loss_list = [] 83 | pbar = tqdm(enumerate(loader), total=len(loader)) 84 | for it, data in pbar: 85 | 86 | # place data on the correct device 87 | data["frames"] = data["frames"].to(self.device) 88 | data["masks"] = data["masks"].to(self.device) 89 | data["category_ids"] = [[j.to(self.device) for j in i] for i in data["category_ids"]] 90 | data["boxes"] = [[j.to(self.device) for j in i] for i in data["boxes"]] 91 | 92 | # forward the model 93 | predictions, losses = model(data) 94 | detector_loss = losses["loss_detector_ce"] + 5 * losses["loss_detector_giou"] + \ 95 | 2 * losses["loss_detector_bbox"] 96 | supervisor_loss = losses["loss_supervisor_ce"] + 5 * losses["loss_supervisor_giou"] + \ 97 | 2 * losses["loss_supervisor_bbox"] 98 | 99 | # log the losses 100 | for name, loss_comp in losses.items(): 101 | self.logger.add_value("{}/{}".format("Train" if is_train else 
"Test", name), loss_comp.mean()) 102 | total_loss = detector_loss + supervisor_loss 103 | self.logger.add_value("{}/Total Loss".format("Train" if is_train else "Test"), total_loss.mean()) 104 | loss_list.append(total_loss.item()) 105 | 106 | if is_train: 107 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRAD_NORM_CLIP) 108 | detector_optimizer.step() 109 | supervisor_optimizer.step() 110 | detector_optimizer.zero_grad() 111 | supervisor_optimizer.zero_grad() 112 | 113 | # decay the learning rate based on our progress 114 | if config.LR_DECAY: 115 | self.tokens += data["frames"].shape[0] * data["frames"].shape[1] 116 | if self.tokens < config.WARMUP_TOKENS: 117 | # linear warmup 118 | lr_mult = float(self.tokens) / float(max(1, config.WARMUP_TOKENS)) 119 | else: 120 | # cosine learning rate decay 121 | progress = float(self.tokens - config.WARMUP_TOKENS) / \ 122 | float(max(1, config.FINAL_TOKENS - config.WARMUP_TOKENS)) 123 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 124 | lr = config.SUPERVISOR_LR * lr_mult 125 | for param_group in supervisor_optimizer.param_groups: 126 | param_group['lr'] = lr 127 | else: 128 | lr = config.SUPERVISOR_LR 129 | self.logger.add_value("{}/LR".format("Train" if is_train else "Test"), lr) 130 | 131 | # report progress 132 | pbar.set_description( 133 | f"epoch {epoch} iter {it}: train loss {float(np.mean(loss_list)):.5f}. lr {lr:e}" 134 | ) 135 | 136 | if not is_train: 137 | test_loss = float(np.mean(loss_list)) 138 | return test_loss 139 | 140 | def run_evaluation(): 141 | test_loss = run_epoch('test') 142 | mAP_50, mAP, tps, fps, fns = self.evaluator.evaluate(save_results=False) 143 | self.logger.add_value("Test/TP", tps) 144 | self.logger.add_value("Test/FP", fps) 145 | self.logger.add_value("Test/FN", fns) 146 | self.logger.add_value("Test/mAP_50", mAP_50) 147 | self.logger.add_value("Test/mAP", mAP) 148 | detector_optimizer.zero_grad() 149 | supervisor_optimizer.zero_grad() 150 | return mAP 151 | 152 | self.tokens = 0 # counter used for learning rate decay 153 | run_evaluation() 154 | self.logger.log_values() 155 | for epoch in range(1, config.MAX_EPOCHS): 156 | run_epoch('train') 157 | if epoch % 1 == 0 and self.test_dataset is not None and self.evaluator is not None: 158 | run_evaluation() 159 | self.logger.log_values() 160 | 161 | if self.test_dataset is not None and config.MAX_EPOCHS - epoch <= config.SAVE_WINDOW: 162 | self.record_checkpoint(w=1 / config.SAVE_WINDOW) 163 | self.save_checkpoint() 164 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from utils.config_utils import ( 2 | get_config, 3 | get_args, 4 | build_model, 5 | build_evaluator 6 | ) 7 | 8 | 9 | def evaluate(): 10 | args = get_args() 11 | cfg = get_config(args.config_file) 12 | model = build_model(cfg.MODEL) 13 | evaluator = build_evaluator(model, cfg, load_checkpoint=True) 14 | evaluator.evaluate(save_results=True) 15 | 16 | 17 | if __name__ == "__main__": 18 | evaluate() 19 | -------------------------------------------------------------------------------- /images/teaser-wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/interactron/e94a1c3c7bd442708f4a3a6d8bc5f586f597a02d/images/teaser-wide.png -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/interactron/e94a1c3c7bd442708f4a3a6d8bc5f586f597a02d/models/__init__.py -------------------------------------------------------------------------------- /models/components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LinearBlock(nn.Module): 6 | def __init__(self, in_dim, out_dim, bias=False): 7 | super().__init__() 8 | self.model = nn.Sequential( 9 | nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias), 10 | # nn.LayerNorm(out_dim), 11 | nn.GELU(), 12 | ) 13 | 14 | def forward(self, x): 15 | og_shape = x.shape 16 | x = self.model(x.view(-1, og_shape[-1])) 17 | return x.view(*og_shape[:-1], -1) 18 | 19 | -------------------------------------------------------------------------------- /models/detr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.detr_models.detr import build 5 | from models.detr_models.util.misc import NestedTensor 6 | 7 | 8 | class detr(nn.Module): 9 | 10 | def __init__( 11 | self, 12 | config, 13 | ): 14 | super().__init__() 15 | self.model, self.criterion, self.postprocessor = build(config) 16 | self.model.load_state_dict(torch.load(config.WEIGHTS, map_location=torch.device('cpu'))['model']) 17 | self.logger = None 18 | self.mode = 'train' 19 | 20 | def predict(self, data): 21 | # reformat img and mask data 22 | b, s, c, w, h = data["frames"].shape 23 | img = data["frames"].view(b*s, c, w, h) 24 | mask = data["masks"].view(b*s, w, h) 25 | # reformat labels 26 | labels = [] 27 | for i in range(b): 28 | for j in range(s): 29 | labels.append({ 30 | "labels": data["category_ids"][i][j], 31 | "boxes": data["boxes"][i][j] 32 | }) 33 | # get predictions and losses 34 | out = self.model(NestedTensor(img, mask)) 35 | # loss = self.criterion(out, labels) 36 | # clean up predictions 37 | for key, val in out.items(): 38 | out[key] = val.view(b, s, *val.shape[1:]) 39 | 40 | return out 41 | 42 | def forward(self, data): 43 | # reformat img and mask data 44 | b, s, c, w, h = data["frames"].shape 45 | img = data["frames"].view(b*s, c, w, h) 46 | mask = data["masks"].view(b*s, w, h) 47 | # reformat labels 48 | labels = [] 49 | for i in range(b): 50 | for j in range(s): 51 | labels.append({ 52 | "labels": data["category_ids"][i][j], 53 | "boxes": data["boxes"][i][j] 54 | }) 55 | # get predictions and losses 56 | out = self.model(NestedTensor(img, mask)) 57 | losses = self.criterion(out, labels) 58 | loss = losses["loss_ce"] + 5 * losses["loss_bbox"] + 2 * losses["loss_giou"] 59 | loss.backward() 60 | # clean up predictions 61 | for key, val in out.items(): 62 | out[key] = val.view(b, s, *val.shape[1:]) 63 | 64 | return out, {k.replace("loss", "loss_detector"): v for k, v in losses.items()} 65 | 66 | def eval(self): 67 | return self.train(False) 68 | 69 | def train(self, mode=True): 70 | self.mode = 'train' if mode else 'test' 71 | # only train proposal generator of detector 72 | # self.model.backbone.eval() 73 | self.model.train(mode) 74 | return self 75 | 76 | def get_optimizer_groups(self, train_config): 77 | optim_groups = [{ 78 | "params": list(self.model.parameters()), "weight_decay": 0.0 79 | }] 80 | return optim_groups 81 | 82 | def set_logger(self, logger): 83 | assert self.logger is None, "This model already has a logger!" 
84 | self.logger = logger 85 | -------------------------------------------------------------------------------- /models/detr_models/__init__.py: -------------------------------------------------------------------------------- 1 | # This code was copied from https://github.com/facebookresearch/detr/blob/main/models/detr.py 2 | from .detr import build 3 | 4 | 5 | def build_model(args): 6 | return build(args) 7 | -------------------------------------------------------------------------------- /models/detr_models/backbone.py: -------------------------------------------------------------------------------- 1 | # This code was copied from https://github.com/facebookresearch/detr 2 | """ 3 | Backbone modules. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from models.detr_models.util.misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 23 | without which any other models than torchvision.models.resnet[18,34,50,101] 24 | produce nans. 25 | """ 26 | 27 | def __init__(self, n): 28 | super(FrozenBatchNorm2d, self).__init__() 29 | self.register_buffer("weight", torch.ones(n)) 30 | self.register_buffer("bias", torch.zeros(n)) 31 | self.register_buffer("running_mean", torch.zeros(n)) 32 | self.register_buffer("running_var", torch.ones(n)) 33 | 34 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 35 | missing_keys, unexpected_keys, error_msgs): 36 | num_batches_tracked_key = prefix + 'num_batches_tracked' 37 | if num_batches_tracked_key in state_dict: 38 | del state_dict[num_batches_tracked_key] 39 | 40 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 41 | state_dict, prefix, local_metadata, strict, 42 | missing_keys, unexpected_keys, error_msgs) 43 | 44 | def forward(self, x): 45 | # move reshapes to the beginning 46 | # to make it fuser-friendly 47 | w = self.weight.reshape(1, -1, 1, 1) 48 | b = self.bias.reshape(1, -1, 1, 1) 49 | rv = self.running_var.reshape(1, -1, 1, 1) 50 | rm = self.running_mean.reshape(1, -1, 1, 1) 51 | eps = 1e-5 52 | scale = w * (rv + eps).rsqrt() 53 | bias = b - rm * scale 54 | return x * scale + bias 55 | 56 | 57 | class BackboneBase(nn.Module): 58 | 59 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 60 | super().__init__() 61 | for name, parameter in backbone.named_parameters(): 62 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 63 | parameter.requires_grad_(False) 64 | if return_interm_layers: 65 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 66 | else: 67 | return_layers = {'layer4': "0"} 68 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 69 | self.num_channels = num_channels 70 | 71 | def forward(self, tensor_list: NestedTensor): 72 | xs = self.body(tensor_list.tensors) 73 | out: Dict[str, NestedTensor] = {} 74 | for name, x in xs.items(): 75 | m = tensor_list.mask 76 | assert m is not None 77 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 78 | out[name] 
= NestedTensor(x, mask) 79 | return out 80 | 81 | 82 | class Backbone(BackboneBase): 83 | """ResNet backbone with frozen BatchNorm.""" 84 | def __init__(self, name: str, 85 | train_backbone: bool, 86 | return_interm_layers: bool, 87 | dilation: bool): 88 | backbone = getattr(torchvision.models, name)( 89 | replace_stride_with_dilation=[False, False, dilation], 90 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 91 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 92 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 93 | 94 | 95 | class Joiner(nn.Sequential): 96 | def __init__(self, backbone, position_embedding): 97 | super().__init__(backbone, position_embedding) 98 | 99 | def forward(self, tensor_list: NestedTensor): 100 | xs = self[0](tensor_list) 101 | out: List[NestedTensor] = [] 102 | pos = [] 103 | for name, x in xs.items(): 104 | out.append(x) 105 | # position encoding 106 | pos.append(self[1](x).to(x.tensors.dtype)) 107 | 108 | return out, pos 109 | 110 | 111 | def build_backbone(args): 112 | position_embedding = build_position_encoding(args) 113 | train_backbone = True 114 | return_interm_layers = False 115 | backbone = Backbone('resnet50', train_backbone, return_interm_layers, True) 116 | model = Joiner(backbone, position_embedding) 117 | model.num_channels = backbone.num_channels 118 | return model -------------------------------------------------------------------------------- /models/detr_models/matcher.py: -------------------------------------------------------------------------------- 1 | # This code was copied from https://github.com/facebookresearch/detr 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from models.detr_models.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 10 | 11 | 12 | class HungarianMatcher(nn.Module): 13 | """This class computes an assignment between the targets and the predictions of the network 14 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 15 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 16 | while the others are un-matched (and thus treated as non-objects). 
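    The assignment itself is computed with scipy.optimize.linear_sum_assignment on a cost matrix
    that mixes the classification, L1 box, and generalized-IoU costs (weighted by cost_class,
    cost_bbox and cost_giou).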
17 | """ 18 | 19 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 20 | """Creates the matcher 21 | Params: 22 | cost_class: This is the relative weight of the classification error in the matching cost 23 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 24 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 25 | """ 26 | super().__init__() 27 | self.cost_class = cost_class 28 | self.cost_bbox = cost_bbox 29 | self.cost_giou = cost_giou 30 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 31 | 32 | @torch.no_grad() 33 | def forward(self, outputs, targets): 34 | """ Performs the matching 35 | Params: 36 | outputs: This is a dict that contains at least these entries: 37 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 38 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 39 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 40 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 41 | objects in the target) containing the class labels 42 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 43 | Returns: 44 | A list of size batch_size, containing tuples of (index_i, index_j) where: 45 | - index_i is the indices of the selected predictions (in order) 46 | - index_j is the indices of the corresponding selected targets (in order) 47 | For each batch element, it holds: 48 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 49 | """ 50 | bs, num_queries = outputs["pred_logits"].shape[:2] 51 | 52 | # We flatten to compute the cost matrices in a batch 53 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 54 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 55 | 56 | # Also concat the target labels and boxes 57 | tgt_ids = torch.cat([v["labels"] for v in targets]) 58 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 59 | 60 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 61 | # but approximate it in 1 - proba[target class]. 62 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
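        # Worked micro-example (illustrative only, values made up): with two queries and one target
        # of class 3, out_prob[:, 3] might be [0.7, 0.2], giving cost_class = [-0.7, -0.2]. Using
        # (1 - p) = [0.3, 0.8] instead shifts every entry of the cost matrix by the same constant,
        # so linear_sum_assignment returns the identical matching.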
63 | cost_class = -out_prob[:, tgt_ids] 64 | 65 | # Compute the L1 cost between boxes 66 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 67 | 68 | # Compute the giou cost betwen boxes 69 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 70 | 71 | # Final cost matrix 72 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 73 | C = C.view(bs, num_queries, -1).cpu() 74 | 75 | sizes = [len(v["boxes"]) for v in targets] 76 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 77 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 78 | 79 | 80 | def build_matcher(args): 81 | return HungarianMatcher(cost_class=args.SET_COST_CLASS, cost_bbox=args.SET_COST_BBOX, cost_giou=args.SET_COST_GIOU) -------------------------------------------------------------------------------- /models/detr_models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # This code was copied from https://github.com/facebookresearch/detr 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from models.detr_models.util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors 30 | mask = tensor_list.mask 31 | assert mask is not None 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 
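    Row and column indices (up to 50 of each) get their own nn.Embedding table; the two embeddings
    are concatenated to produce one positional feature per spatial location.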
54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super().__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, tensor_list: NestedTensor): 66 | x = tensor_list.tensors 67 | h, w = x.shape[-2:] 68 | i = torch.arange(w, device=x.device) 69 | j = torch.arange(h, device=x.device) 70 | x_emb = self.col_embed(i) 71 | y_emb = self.row_embed(j) 72 | pos = torch.cat([ 73 | x_emb.unsqueeze(0).repeat(h, 1, 1), 74 | y_emb.unsqueeze(1).repeat(1, w, 1), 75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 76 | return pos 77 | 78 | 79 | def build_position_encoding(args): 80 | N_steps = 256 // 2 81 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 82 | return position_embedding 83 | -------------------------------------------------------------------------------- /models/detr_models/transformer.py: -------------------------------------------------------------------------------- 1 | # This code was copied from https://github.com/facebookresearch/detr 2 | """ 3 | DETR Transformer class. 4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | """ 9 | import copy 10 | from typing import Optional, List 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn, Tensor 15 | 16 | 17 | class Transformer(nn.Module): 18 | 19 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 20 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 21 | activation="relu", normalize_before=False, 22 | return_intermediate_dec=False): 23 | super().__init__() 24 | 25 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 26 | dropout, activation, normalize_before) 27 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 28 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 29 | 30 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 31 | dropout, activation, normalize_before) 32 | decoder_norm = nn.LayerNorm(d_model) 33 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 34 | return_intermediate=return_intermediate_dec) 35 | 36 | self._reset_parameters() 37 | 38 | self.d_model = d_model 39 | self.nhead = nhead 40 | 41 | def _reset_parameters(self): 42 | for p in self.parameters(): 43 | if p.dim() > 1: 44 | nn.init.xavier_uniform_(p) 45 | 46 | def forward(self, src, mask, query_embed, pos_embed): 47 | # flatten NxCxHxW to HWxNxC 48 | bs, c, h, w = src.shape 49 | src = src.flatten(2).permute(2, 0, 1) 50 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 51 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 52 | mask = mask.flatten(1) 53 | 54 | tgt = torch.zeros_like(query_embed) 55 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 56 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 57 | pos=pos_embed, query_pos=query_embed) 58 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 59 | 60 | 61 | class TransformerEncoder(nn.Module): 62 | 63 | def __init__(self, encoder_layer, num_layers, norm=None): 64 | super().__init__() 65 | self.layers = 
_get_clones(encoder_layer, num_layers) 66 | self.num_layers = num_layers 67 | self.norm = norm 68 | 69 | def forward(self, src, 70 | mask: Optional[Tensor] = None, 71 | src_key_padding_mask: Optional[Tensor] = None, 72 | pos: Optional[Tensor] = None): 73 | output = src 74 | 75 | for layer in self.layers: 76 | output = layer(output, src_mask=mask, 77 | src_key_padding_mask=src_key_padding_mask, pos=pos) 78 | 79 | if self.norm is not None: 80 | output = self.norm(output) 81 | 82 | return output 83 | 84 | 85 | class TransformerDecoder(nn.Module): 86 | 87 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 88 | super().__init__() 89 | self.layers = _get_clones(decoder_layer, num_layers) 90 | self.num_layers = num_layers 91 | self.norm = norm 92 | self.return_intermediate = return_intermediate 93 | 94 | def forward(self, tgt, memory, 95 | tgt_mask: Optional[Tensor] = None, 96 | memory_mask: Optional[Tensor] = None, 97 | tgt_key_padding_mask: Optional[Tensor] = None, 98 | memory_key_padding_mask: Optional[Tensor] = None, 99 | pos: Optional[Tensor] = None, 100 | query_pos: Optional[Tensor] = None): 101 | output = tgt 102 | 103 | intermediate = [] 104 | 105 | for layer in self.layers: 106 | output = layer(output, memory, tgt_mask=tgt_mask, 107 | memory_mask=memory_mask, 108 | tgt_key_padding_mask=tgt_key_padding_mask, 109 | memory_key_padding_mask=memory_key_padding_mask, 110 | pos=pos, query_pos=query_pos) 111 | if self.return_intermediate: 112 | intermediate.append(self.norm(output)) 113 | 114 | if self.norm is not None: 115 | output = self.norm(output) 116 | if self.return_intermediate: 117 | intermediate.pop() 118 | intermediate.append(output) 119 | 120 | if self.return_intermediate: 121 | return torch.stack(intermediate) 122 | 123 | return output.unsqueeze(0) 124 | 125 | 126 | class TransformerEncoderLayer(nn.Module): 127 | 128 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 129 | activation="relu", normalize_before=False): 130 | super().__init__() 131 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 132 | # Implementation of Feedforward model 133 | self.linear1 = nn.Linear(d_model, dim_feedforward) 134 | self.dropout = nn.Dropout(dropout) 135 | self.linear2 = nn.Linear(dim_feedforward, d_model) 136 | 137 | self.norm1 = nn.LayerNorm(d_model) 138 | self.norm2 = nn.LayerNorm(d_model) 139 | self.dropout1 = nn.Dropout(dropout) 140 | self.dropout2 = nn.Dropout(dropout) 141 | 142 | self.activation = _get_activation_fn(activation) 143 | self.normalize_before = normalize_before 144 | 145 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 146 | return tensor if pos is None else tensor + pos 147 | 148 | def forward_post(self, 149 | src, 150 | src_mask: Optional[Tensor] = None, 151 | src_key_padding_mask: Optional[Tensor] = None, 152 | pos: Optional[Tensor] = None): 153 | q = k = self.with_pos_embed(src, pos) 154 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 155 | key_padding_mask=src_key_padding_mask)[0] 156 | src = src + self.dropout1(src2) 157 | src = self.norm1(src) 158 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 159 | src = src + self.dropout2(src2) 160 | src = self.norm2(src) 161 | return src 162 | 163 | def forward_pre(self, src, 164 | src_mask: Optional[Tensor] = None, 165 | src_key_padding_mask: Optional[Tensor] = None, 166 | pos: Optional[Tensor] = None): 167 | src2 = self.norm1(src) 168 | q = k = self.with_pos_embed(src2, pos) 169 | src2 = self.self_attn(q, k, 
value=src2, attn_mask=src_mask, 170 | key_padding_mask=src_key_padding_mask)[0] 171 | src = src + self.dropout1(src2) 172 | src2 = self.norm2(src) 173 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 174 | src = src + self.dropout2(src2) 175 | return src 176 | 177 | def forward(self, src, 178 | src_mask: Optional[Tensor] = None, 179 | src_key_padding_mask: Optional[Tensor] = None, 180 | pos: Optional[Tensor] = None): 181 | if self.normalize_before: 182 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 183 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 184 | 185 | 186 | class TransformerDecoderLayer(nn.Module): 187 | 188 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 189 | activation="relu", normalize_before=False): 190 | super().__init__() 191 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 192 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 193 | # Implementation of Feedforward model 194 | self.linear1 = nn.Linear(d_model, dim_feedforward) 195 | self.dropout = nn.Dropout(dropout) 196 | self.linear2 = nn.Linear(dim_feedforward, d_model) 197 | 198 | self.norm1 = nn.LayerNorm(d_model) 199 | self.norm2 = nn.LayerNorm(d_model) 200 | self.norm3 = nn.LayerNorm(d_model) 201 | self.dropout1 = nn.Dropout(dropout) 202 | self.dropout2 = nn.Dropout(dropout) 203 | self.dropout3 = nn.Dropout(dropout) 204 | 205 | self.activation = _get_activation_fn(activation) 206 | self.normalize_before = normalize_before 207 | 208 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 209 | return tensor if pos is None else tensor + pos 210 | 211 | def forward_post(self, tgt, memory, 212 | tgt_mask: Optional[Tensor] = None, 213 | memory_mask: Optional[Tensor] = None, 214 | tgt_key_padding_mask: Optional[Tensor] = None, 215 | memory_key_padding_mask: Optional[Tensor] = None, 216 | pos: Optional[Tensor] = None, 217 | query_pos: Optional[Tensor] = None): 218 | q = k = self.with_pos_embed(tgt, query_pos) 219 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 220 | key_padding_mask=tgt_key_padding_mask)[0] 221 | tgt = tgt + self.dropout1(tgt2) 222 | tgt = self.norm1(tgt) 223 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 224 | key=self.with_pos_embed(memory, pos), 225 | value=memory, attn_mask=memory_mask, 226 | key_padding_mask=memory_key_padding_mask)[0] 227 | tgt = tgt + self.dropout2(tgt2) 228 | tgt = self.norm2(tgt) 229 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 230 | tgt = tgt + self.dropout3(tgt2) 231 | tgt = self.norm3(tgt) 232 | return tgt 233 | 234 | def forward_pre(self, tgt, memory, 235 | tgt_mask: Optional[Tensor] = None, 236 | memory_mask: Optional[Tensor] = None, 237 | tgt_key_padding_mask: Optional[Tensor] = None, 238 | memory_key_padding_mask: Optional[Tensor] = None, 239 | pos: Optional[Tensor] = None, 240 | query_pos: Optional[Tensor] = None): 241 | tgt2 = self.norm1(tgt) 242 | q = k = self.with_pos_embed(tgt2, query_pos) 243 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 244 | key_padding_mask=tgt_key_padding_mask)[0] 245 | tgt = tgt + self.dropout1(tgt2) 246 | tgt2 = self.norm2(tgt) 247 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 248 | key=self.with_pos_embed(memory, pos), 249 | value=memory, attn_mask=memory_mask, 250 | key_padding_mask=memory_key_padding_mask)[0] 251 | tgt = tgt + self.dropout2(tgt2) 252 | tgt2 = self.norm3(tgt) 253 | tgt2 = 
self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 254 | tgt = tgt + self.dropout3(tgt2) 255 | return tgt 256 | 257 | def forward(self, tgt, memory, 258 | tgt_mask: Optional[Tensor] = None, 259 | memory_mask: Optional[Tensor] = None, 260 | tgt_key_padding_mask: Optional[Tensor] = None, 261 | memory_key_padding_mask: Optional[Tensor] = None, 262 | pos: Optional[Tensor] = None, 263 | query_pos: Optional[Tensor] = None): 264 | if self.normalize_before: 265 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 266 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 267 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 268 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 269 | 270 | 271 | def _get_clones(module, N): 272 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 273 | 274 | 275 | def build_transformer(args): 276 | return Transformer( 277 | d_model=256, 278 | dropout=0.1, 279 | nhead=8, 280 | dim_feedforward=2048, 281 | num_encoder_layers=6, 282 | num_decoder_layers=6, 283 | normalize_before=False, 284 | return_intermediate_dec=True, 285 | ) 286 | 287 | 288 | def _get_activation_fn(activation): 289 | """Return an activation function given a string""" 290 | if activation == "relu": 291 | return F.relu 292 | if activation == "gelu": 293 | return F.gelu 294 | if activation == "glu": 295 | return F.glu 296 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 297 | -------------------------------------------------------------------------------- /models/detr_models/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/interactron/e94a1c3c7bd442708f4a3a6d8bc5f586f597a02d/models/detr_models/util/__init__.py -------------------------------------------------------------------------------- /models/detr_models/util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 
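Boxes are handled in both (cx, cy, w, h) and (x0, y0, x1, y1) formats; the helpers below convert
between them and compute pairwise IoU / generalized IoU matrices.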
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | 8 | def box_cxcywh_to_xyxy(x): 9 | x_c, y_c, w, h = x.unbind(-1) 10 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 11 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | 15 | def box_xyxy_to_cxcywh(x): 16 | x0, y0, x1, y1 = x.unbind(-1) 17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 18 | (x1 - x0), (y1 - y0)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | # modified from torchvision to also return the union 23 | def box_iou(boxes1, boxes2): 24 | area1 = box_area(boxes1) 25 | area2 = box_area(boxes2) 26 | 27 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 28 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 29 | 30 | wh = (rb - lt).clamp(min=0) # [N,M,2] 31 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 32 | 33 | union = area1[:, None] + area2 - inter 34 | 35 | iou = inter / union 36 | return iou, union 37 | 38 | 39 | def generalized_box_iou(boxes1, boxes2): 40 | """ 41 | Generalized IoU from https://giou.stanford.edu/ 42 | The boxes should be in [x0, y0, x1, y1] format 43 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 44 | and M = len(boxes2) 45 | """ 46 | # degenerate boxes gives inf / nan results 47 | # so do an early check 48 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 49 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 50 | iou, union = box_iou(boxes1, boxes2) 51 | 52 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 53 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 54 | 55 | wh = (rb - lt).clamp(min=0) # [N,M,2] 56 | area = wh[:, :, 0] * wh[:, :, 1] 57 | 58 | return iou - (area - union) / area 59 | 60 | 61 | def masks_to_boxes(masks): 62 | """Compute the bounding boxes around the provided masks 63 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 64 | Returns a [N, 4] tensors, with the boxes in xyxy format 65 | """ 66 | if masks.numel() == 0: 67 | return torch.zeros((0, 4), device=masks.device) 68 | 69 | h, w = masks.shape[-2:] 70 | 71 | y = torch.arange(0, h, dtype=torch.float) 72 | x = torch.arange(0, w, dtype=torch.float) 73 | y, x = torch.meshgrid(y, x) 74 | 75 | x_mask = (masks * x.unsqueeze(0)) 76 | x_max = x_mask.flatten(1).max(-1)[0] 77 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 78 | 79 | y_mask = (masks * y.unsqueeze(0)) 80 | y_max = y_mask.flatten(1).max(-1)[0] 81 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 82 | 83 | return torch.stack([x_min, y_min, x_max, y_max], 1) -------------------------------------------------------------------------------- /models/detr_models/util/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Transforms and data augmentation for both image + bbox. 4 | """ 5 | import random 6 | 7 | import PIL 8 | import torch 9 | import torchvision.transforms as T 10 | import torchvision.transforms.functional as F 11 | 12 | from models.detr_models.util.box_ops import box_xyxy_to_cxcywh 13 | from models.detr_models.util.misc import interpolate 14 | 15 | 16 | def crop(image, target, region): 17 | cropped_image = F.crop(image, *region) 18 | 19 | if target is not None: 20 | target = target.copy() 21 | i, j, h, w = region 22 | 23 | # should we do something wrt the original size? 
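        # (the target size is simply reset to the crop region's (h, w); boxes and masks are
        # clipped / sliced to that same region below)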
24 | target["size"] = torch.tensor([h, w]) 25 | 26 | fields = ["labels", "area", "iscrowd"] 27 | 28 | if "boxes" in target: 29 | boxes = target["boxes"] 30 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 31 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 32 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 33 | cropped_boxes = cropped_boxes.clamp(min=0) 34 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 35 | target["boxes"] = cropped_boxes.reshape(-1, 4) 36 | target["area"] = area 37 | fields.append("boxes") 38 | 39 | if "masks" in target: 40 | # FIXME should we update the area here if there are no boxes? 41 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 42 | fields.append("masks") 43 | 44 | # remove elements for which the boxes or masks that have zero area 45 | if "boxes" in target or "masks" in target: 46 | # favor boxes selection when defining which elements to keep 47 | # this is compatible with previous implementation 48 | if "boxes" in target: 49 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 50 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 51 | else: 52 | keep = target['masks'].flatten(1).any(1) 53 | 54 | for field in fields: 55 | target[field] = target[field][keep] 56 | 57 | return cropped_image, target 58 | 59 | 60 | def hflip(image, target): 61 | flipped_image = F.hflip(image) 62 | 63 | w, h = image.size 64 | 65 | if target is not None: 66 | target = target.copy() 67 | if "boxes" in target: 68 | boxes = target["boxes"] 69 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 70 | target["boxes"] = boxes 71 | 72 | if "masks" in target: 73 | target['masks'] = target['masks'].flip(-1) 74 | 75 | return flipped_image, target 76 | 77 | 78 | def resize(image, target, size, max_size=None): 79 | # size can be min_size (scalar) or (w, h) tuple 80 | 81 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 82 | w, h = image_size 83 | if max_size is not None: 84 | min_original_size = float(min((w, h))) 85 | max_original_size = float(max((w, h))) 86 | if max_original_size / min_original_size * size > max_size: 87 | size = int(round(max_size * min_original_size / max_original_size)) 88 | 89 | if (w <= h and w == size) or (h <= w and h == size): 90 | return (h, w) 91 | 92 | if w < h: 93 | ow = size 94 | oh = int(size * h / w) 95 | else: 96 | oh = size 97 | ow = int(size * w / h) 98 | 99 | return (oh, ow) 100 | 101 | def get_size(image_size, size, max_size=None): 102 | if isinstance(size, (list, tuple)): 103 | return size[::-1] 104 | else: 105 | return get_size_with_aspect_ratio(image_size, size, max_size) 106 | 107 | size = get_size(image.size, size, max_size) 108 | rescaled_image = F.resize(image, size) 109 | 110 | if target is None: 111 | return rescaled_image, None 112 | 113 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 114 | ratio_width, ratio_height = ratios 115 | 116 | target = target.copy() 117 | if "boxes" in target: 118 | boxes = target["boxes"] 119 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 120 | target["boxes"] = scaled_boxes 121 | 122 | if "area" in target: 123 | area = target["area"] 124 | scaled_area = area * (ratio_width * ratio_height) 125 | target["area"] = scaled_area 126 | 127 | h, w = size 128 | target["size"] = torch.tensor([h, w]) 129 | 130 | if "masks" in target: 131 | target['masks'] = interpolate( 132 | 
target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 133 | 134 | return rescaled_image, target 135 | 136 | 137 | def pad(image, target, padding): 138 | # assumes that we only pad on the bottom right corners 139 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 140 | if target is None: 141 | return padded_image, None 142 | target = target.copy() 143 | # should we do something wrt the original size? 144 | target["size"] = torch.tensor(padded_image.size[::-1]) 145 | if "masks" in target: 146 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 147 | return padded_image, target 148 | 149 | 150 | class RandomCrop(object): 151 | def __init__(self, size): 152 | self.size = size 153 | 154 | def __call__(self, img, target): 155 | region = T.RandomCrop.get_params(img, self.size) 156 | return crop(img, target, region) 157 | 158 | 159 | class RandomSizeCrop(object): 160 | def __init__(self, min_size: int, max_size: int): 161 | self.min_size = min_size 162 | self.max_size = max_size 163 | 164 | def __call__(self, img: PIL.Image.Image, target: dict): 165 | w = random.randint(self.min_size, min(img.width, self.max_size)) 166 | h = random.randint(self.min_size, min(img.height, self.max_size)) 167 | region = T.RandomCrop.get_params(img, [h, w]) 168 | return crop(img, target, region) 169 | 170 | 171 | class CenterCrop(object): 172 | def __init__(self, size): 173 | self.size = size 174 | 175 | def __call__(self, img, target): 176 | image_width, image_height = img.size 177 | crop_height, crop_width = self.size 178 | crop_top = int(round((image_height - crop_height) / 2.)) 179 | crop_left = int(round((image_width - crop_width) / 2.)) 180 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 181 | 182 | 183 | class RandomHorizontalFlip(object): 184 | def __init__(self, p=0.5): 185 | self.p = p 186 | 187 | def __call__(self, img, target): 188 | if random.random() < self.p: 189 | return hflip(img, target) 190 | return img, target 191 | 192 | 193 | class RandomResize(object): 194 | def __init__(self, sizes, max_size=None): 195 | assert isinstance(sizes, (list, tuple)) 196 | self.sizes = sizes 197 | self.max_size = max_size 198 | 199 | def __call__(self, img, target=None): 200 | size = random.choice(self.sizes) 201 | return resize(img, target, size, self.max_size) 202 | 203 | 204 | class RandomPad(object): 205 | def __init__(self, max_pad): 206 | self.max_pad = max_pad 207 | 208 | def __call__(self, img, target): 209 | pad_x = random.randint(0, self.max_pad) 210 | pad_y = random.randint(0, self.max_pad) 211 | return pad(img, target, (pad_x, pad_y)) 212 | 213 | 214 | class RandomSelect(object): 215 | """ 216 | Randomly selects between transforms1 and transforms2, 217 | with probability p for transforms1 and (1 - p) for transforms2 218 | """ 219 | def __init__(self, transforms1, transforms2, p=0.5): 220 | self.transforms1 = transforms1 221 | self.transforms2 = transforms2 222 | self.p = p 223 | 224 | def __call__(self, img, target): 225 | if random.random() < self.p: 226 | return self.transforms1(img, target) 227 | return self.transforms2(img, target) 228 | 229 | 230 | class ToTensor(object): 231 | def __call__(self, img, target): 232 | return F.to_tensor(img), target 233 | 234 | 235 | class RandomErasing(object): 236 | 237 | def __init__(self, *args, **kwargs): 238 | self.eraser = T.RandomErasing(*args, **kwargs) 239 | 240 | def __call__(self, img, target): 241 | return self.eraser(img), target 242 | 243 | 244 | class 
Normalize(object): 245 | def __init__(self, mean, std): 246 | self.mean = mean 247 | self.std = std 248 | 249 | def __call__(self, image, target=None): 250 | image = F.normalize(image, mean=self.mean, std=self.std) 251 | if target is None: 252 | return image, None 253 | target = target.copy() 254 | h, w = image.shape[-2:] 255 | if "boxes" in target: 256 | boxes = target["boxes"] 257 | boxes = box_xyxy_to_cxcywh(boxes) 258 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 259 | target["boxes"] = boxes 260 | return image, target 261 | 262 | 263 | class Compose(object): 264 | def __init__(self, transforms): 265 | self.transforms = transforms 266 | 267 | def __call__(self, image, target): 268 | for t in self.transforms: 269 | image, target = t(image, target) 270 | return image, target 271 | 272 | def __repr__(self): 273 | format_string = self.__class__.__name__ + "(" 274 | for t in self.transforms: 275 | format_string += "\n" 276 | format_string += " {0}".format(t) 277 | format_string += "\n)" 278 | return format_string 279 | -------------------------------------------------------------------------------- /models/detr_multiframe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.detr_models.detr import build 5 | from models.detr_models.util.misc import NestedTensor 6 | from models.transformer import Transformer 7 | 8 | 9 | class detr_multiframe(nn.Module): 10 | 11 | def __init__( 12 | self, 13 | config, 14 | ): 15 | super().__init__() 16 | # build DETR detector 17 | self.detector, self.criterion, self.postprocessor = build(config) 18 | self.detector.load_state_dict(torch.load(config.WEIGHTS, map_location=torch.device('cpu'))['model']) 19 | # build fusion transformer 20 | self.fusion = Transformer(config) 21 | self.logger = None 22 | self.mode = 'train' 23 | 24 | def predict(self, data): 25 | # reformat img and mask data 26 | b, s, c, w, h = data["frames"].shape 27 | img = data["frames"].view(b*s, c, w, h) 28 | mask = data["masks"].view(b*s, w, h) 29 | # reformat labels 30 | labels = [] 31 | for i in range(b): 32 | for j in range(s): 33 | labels.append({ 34 | "labels": data["category_ids"][i][j], 35 | "boxes": data["boxes"][i][j] 36 | }) 37 | # get predictions and losses 38 | detr_out = self.detector(NestedTensor(img, mask)) 39 | # unfold images back into batch and sequences 40 | # for key in detr_out: 41 | # detr_out[key] = detr_out[key].view(b, s, *detr_out[key].shape[1:]) 42 | detr_out["embedded_memory_features"] = detr_out["embedded_memory_features"].unsqueeze(0) 43 | detr_out["box_features"] = detr_out["box_features"].unsqueeze(0) 44 | detr_out["pred_logits"] = detr_out["pred_logits"].unsqueeze(0) 45 | detr_out["pred_boxes"] = detr_out["pred_boxes"].unsqueeze(0) 46 | out = self.fusion(detr_out) 47 | 48 | predictions = { 49 | "pred_boxes": out["pred_boxes"].view(b, s, *out["pred_boxes"].shape[1:]), 50 | "pred_logits": out["pred_logits"].view(b, s, *out["pred_logits"].shape[1:]) 51 | } 52 | 53 | return predictions 54 | 55 | def forward(self, data): 56 | # reformat img and mask data 57 | b, s, c, w, h = data["frames"].shape 58 | img = data["frames"].view(b, s, c, w, h) 59 | mask = data["masks"].view(b, s, w, h) 60 | # reformat labels 61 | labels = [] 62 | for i in range(b): 63 | labels.append([]) 64 | for j in range(s): 65 | labels[i].append({ 66 | "labels": data["category_ids"][i][j], 67 | "boxes": data["boxes"][i][j] 68 | }) 69 | 70 | losses = [] 71 | out_logits_list = [] 72 | 
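        # Each batch element is treated as a separate task: its frame sequence runs through the
        # DETR detector and the fusion transformer, the per-task loss is backpropagated immediately
        # (so gradients accumulate across tasks), and only detached loss values are kept for logging.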
out_boxes_list = [] 73 | 74 | for task in range(b): 75 | # get predictions and losses 76 | # with torch.no_grad(): 77 | detr_out = self.detector(NestedTensor(img[task], mask[task])) 78 | # unfold images back into batch and sequences 79 | # for key in detr_out: 80 | # detr_out[key] = detr_out[key].view(b, s, *detr_out[key].shape[1:]) 81 | detr_out["embedded_memory_features"] = detr_out["embedded_memory_features"].unsqueeze(0) 82 | detr_out["box_features"] = detr_out["box_features"].unsqueeze(0) 83 | detr_out["pred_logits"] = detr_out["pred_logits"].unsqueeze(0) 84 | detr_out["pred_boxes"] = detr_out["pred_boxes"].unsqueeze(0) 85 | out = self.fusion(detr_out) 86 | # out["pred_boxes"] = detr_out["pred_boxes"] 87 | # del out['actions'] 88 | # for key in out: 89 | # out[key] = out[key].reshape(b * s, *out[key].shape[2:]) 90 | # for key in detr_out: 91 | # detr_out[key] = detr_out[key].reshape(b * s, *detr_out[key].shape[2:]) 92 | 93 | loss = self.criterion(out, labels[task], background_c=0.1) 94 | total_loss = loss["loss_ce"] + 5 * loss["loss_giou"] + 2 * loss["loss_bbox"] 95 | total_loss.backward() 96 | losses.append({k: v.detach() for k, v in loss.items()}) 97 | # clean up predictions 98 | # for key, val in out.items(): 99 | # out[key] = val.view(b, s, *val.shape[1:]) 100 | 101 | out_logits_list.append(out["pred_logits"][0:1]) 102 | out_boxes_list.append(out["pred_boxes"][0:1]) 103 | 104 | predictions = {"pred_logits": torch.stack(out_logits_list, dim=0), "pred_boxes": torch.stack(out_boxes_list, dim=0)} 105 | losses = {k.replace("loss", "loss_detector"): 106 | torch.mean(torch.stack([x[k] for x in losses])) 107 | for k, v in losses[0].items()} 108 | 109 | return predictions, losses 110 | 111 | def eval(self): 112 | return self.train(False) 113 | 114 | def train(self, mode=True): 115 | self.mode = 'train' if mode else 'test' 116 | self.detector.train(False) 117 | self.detector.transformer.decoder.train(mode) 118 | self.fusion.train(mode) 119 | return self 120 | 121 | def get_optimizer_groups(self, train_config): 122 | optim_groups = [ 123 | {"params": list(self.detector.parameters()), "weight_decay": 0.0}, 124 | {"params": list(self.fusion.parameters()), "weight_decay": 0.0}, 125 | ] 126 | return optim_groups 127 | 128 | def set_logger(self, logger): 129 | assert self.logger is None, "This model already has a logger!" 
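        # the assert above ensures a logger is attached to this model at most once
        # before it is stored on the line below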
130 | self.logger = logger 131 | 132 | -------------------------------------------------------------------------------- /models/five_frame_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import matplotlib.pyplot as plt 6 | import cv2 7 | 8 | from models.detectron2_detector import Detectron2Detector 9 | from models.detr import DETRDetector 10 | from models.gpt import GPT 11 | from models.components import LinearBlock 12 | from utils.constants import tlvis_classes 13 | from utils.model_utils import merge_batch_seq, unmerge_batch_seq 14 | from utils.detection_utils import iou 15 | from utils.time_utils import Timer 16 | from utils.viz_utils import draw_box 17 | 18 | 19 | class LinearBlock(nn.Module): 20 | def __init__(self, in_dim, out_dim, bias=False): 21 | super().__init__() 22 | self.model = nn.Sequential( 23 | nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias), 24 | nn.LayerNorm(out_dim), 25 | nn.GELU(), 26 | ) 27 | 28 | def forward(self, x): 29 | og_shape = x.shape 30 | x = self.model(x.view(-1, og_shape[-1])) 31 | return x.view(*og_shape[:-1], -1) 32 | 33 | 34 | class FiveFrameBaselineModel(nn.Module): 35 | 36 | def __init__(self, cfg): 37 | super().__init__() 38 | self.detector = DETRDetector(config=cfg) 39 | self.model = GPT(cfg.TRANSFORMER) 40 | self.proposal_encoder = LinearBlock(2264, cfg.TRANSFORMER.EMBEDDING_DIM, bias=False) 41 | self.img_feature_encoder = LinearBlock(2048, cfg.TRANSFORMER.EMBEDDING_DIM, bias=False) 42 | self.box_decoder = nn.Linear(in_features=1024, out_features=4, bias=False) 43 | self.category_decoder = nn.Linear(in_features=1024, out_features=1236, bias=False) 44 | self.cfg = cfg 45 | self.is_train = True 46 | self.timer = Timer() 47 | self.logger = None 48 | self.mode = 'train' 49 | 50 | def forward(self, images, labels): 51 | predictions = self.detector(images) 52 | predictions.nms(k=50) 53 | 54 | seq = self.fold_sequence(predictions) 55 | pred_embs = self.model(seq)[:, :250] 56 | pred_embs = F.gelu(pred_embs) 57 | predictions.set_logits(self.category_decoder(pred_embs), flat=True) 58 | # anchor_boxes = predictions.get_boxes() 59 | # anchor_boxes = anchor_boxes.view(anchor_boxes.shape[0], -1, anchor_boxes.shape[-1]) 60 | # predictions.set_boxes(anchor_boxes + self.box_decoder(pred_embs), flat=True) 61 | predictions.nms(k=50) 62 | 63 | with torch.no_grad(): 64 | labels.match_labels(predictions) 65 | 66 | # compute losses 67 | bounding_box_loss, category_loss = self.compute_losses(predictions, labels) 68 | losses = { 69 | "category_prediction_loss": category_loss, 70 | "bounding_box_loss": bounding_box_loss 71 | } 72 | 73 | return predictions, losses 74 | 75 | def compute_losses(self, predictions, labels): 76 | gt_cats = labels.get_matched_categories(flat=True) 77 | gt_boxes = labels.get_matched_boxes(flat=True) 78 | pred_logits = predictions.get_logits(flat=True) 79 | pred_boxes = predictions.get_boxes(flat=True) 80 | mask = gt_cats.view(-1) != self.cfg.DETECTOR.NUM_CLASSES 81 | # debugging 82 | if self.logger: 83 | self.logger.add_value( 84 | "{}/Number of matched ground truths".format(self.mode.capitalize()), 85 | torch.count_nonzero(mask) 86 | ) 87 | self.logger.add_value( 88 | "{}/Number of positive detections".format(self.mode.capitalize()), 89 | torch.count_nonzero(predictions.get_logits().argmax(-1) != self.cfg.DETECTOR.NUM_CLASSES) 90 | ) 91 | bbox_loss = F.mse_loss(pred_boxes.view(-1, 
4)[mask], gt_boxes.view(-1, 4)[mask]) 92 | category_loss = F.cross_entropy(pred_logits.view(-1, pred_logits.shape[-1]), gt_cats.view(-1)) 93 | return bbox_loss, category_loss 94 | 95 | def configure_optimizer(self, train_config): 96 | optim_groups = self.model.get_optimizer_groups(train_config) + self.detector.get_optimizer_groups(train_config) 97 | optim_groups.append({ 98 | "params": list(self.proposal_encoder.parameters()), "weight_decay": train_config.WEIGHT_DECAY 99 | }) 100 | optim_groups.append({ 101 | "params": list(self.img_feature_encoder.parameters()), "weight_decay": train_config.WEIGHT_DECAY 102 | }) 103 | optim_groups.append({ 104 | "params": list(self.box_decoder.parameters()), "weight_decay": train_config.WEIGHT_DECAY 105 | }) 106 | optim_groups.append({ 107 | "params": list(self.category_decoder.parameters()), "weight_decay": train_config.WEIGHT_DECAY 108 | }) 109 | assert train_config.OPTIM_TYPE in ["Adam", "AdamW", "SGD"], \ 110 | "Invalid optimizer type {}. Please select Adam, AdamW or SGD" 111 | if train_config.OPTIM_TYPE == "AdamW": 112 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.LEARNING_RATE, 113 | betas=(train_config.BETA1, train_config.BETA2)) 114 | elif train_config.OPTIM_TYPE == "Adam": 115 | optimizer = torch.optim.Adam(optim_groups, lr=train_config.LEARNING_RATE, 116 | betas=(train_config.BETA1, train_config.BETA2)) 117 | else: 118 | optimizer = torch.optim.SGD(optim_groups, lr=train_config.LEARNING_RATE, momentum=train_config.MOMENTUM) 119 | return optimizer 120 | 121 | def eval(self): 122 | return self.train(False) 123 | 124 | def train(self, mode=True): 125 | self.mode = 'train' if mode else 'test' 126 | self.model.train(mode) 127 | self.proposal_encoder.train(mode) 128 | self.img_feature_encoder.train(mode) 129 | self.category_decoder.train(mode) 130 | self.box_decoder.train(mode) 131 | self.detector.train(mode) 132 | self.is_train = mode 133 | return self 134 | 135 | def match_proposals_to_labels(self, proposals, bounding_boxes, categories): 136 | labels = torch.ones(proposals.shape[0], proposals.shape[1], device=categories.device, dtype=torch.long) 137 | labels *= self.cfg.DETECTOR.NUM_CLASSES 138 | boxes = torch.zeros_like(proposals) 139 | for n in range(proposals.shape[0]): 140 | ious = torchvision.ops.box_iou(proposals[n], bounding_boxes[n]) 141 | max_ious, max_iou_idxs = ious.max(dim=1) 142 | best_iou_categories = categories[n][max_iou_idxs] 143 | match_mask = max_ious > 0.5 144 | labels[n][match_mask] = best_iou_categories[match_mask] 145 | best_iou_boxes = bounding_boxes[n][max_iou_idxs] 146 | boxes[n][match_mask] = best_iou_boxes[match_mask].float() 147 | return labels, boxes 148 | 149 | def prune_predictions(self, logits, boxes, box_features, backbone_boxes, k=50): 150 | pruned_logits = torch.zeros(logits.shape[0], k, logits.shape[2], device=logits.device) 151 | pruned_logits[:, :, -1] = 1.0 152 | pruned_boxes = torch.zeros(boxes.shape[0], k, boxes.shape[2], device=boxes.device) 153 | pruned_backbone_boxes = torch.zeros_like(pruned_boxes) 154 | pruned_box_features = torch.zeros(box_features.shape[0], k, box_features.shape[2], device=box_features.device) 155 | for n in range(logits.shape[0]): 156 | cats = logits[n, :, :-1].argmax(dim=-1) 157 | scores, _ = torch.max(F.softmax(logits[n], dim=-1)[:, :-1], dim=-1) 158 | pruned_indexes = torchvision.ops.batched_nms(boxes[n], scores, cats, iou_threshold=0.5)[:k] 159 | t = pruned_indexes.shape[0] 160 | pruned_logits[n][:t] = logits[n][pruned_indexes] 161 | pruned_boxes[n][:t] = 
boxes[n][pruned_indexes] 162 | pruned_box_features[n][:t] = box_features[n][pruned_indexes] 163 | pruned_backbone_boxes[n][:t] = backbone_boxes[n][pruned_indexes] 164 | return pruned_logits, pruned_boxes, pruned_box_features, pruned_backbone_boxes 165 | 166 | def fold_sequence(self, predictions): 167 | img_features = predictions.get_image_features() 168 | box_features = predictions.get_box_features() 169 | boxes = predictions.get_boxes() 170 | logits = predictions.get_logits() 171 | detections = torch.cat((box_features, boxes, logits), dim=-1) 172 | b, t = img_features.shape[:2] 173 | img_features = img_features.permute(0, 1, 3, 4, 2) 174 | seq_img_features = self.img_feature_encoder(img_features.reshape(b, -1, img_features.shape[-1])) 175 | det_image_features = self.proposal_encoder(detections.reshape(b, -1, detections.shape[-1])) 176 | return torch.cat((det_image_features, seq_img_features), dim=1) 177 | 178 | def set_logger(self, logger): 179 | assert self.logger is None, "This model already has a logger!" 180 | self.logger = logger 181 | self.detector.set_logger(logger) 182 | -------------------------------------------------------------------------------- /models/gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPT model 3 | """ 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | 13 | class CausalSelfAttention(nn.Module): 14 | """ 15 | A vanilla multi-head masked self-attention layer with a projection at the end. 16 | It is possible to use torch.nn.MultiheadAttention here but I am including an 17 | explicit implementation here to show that there is nothing too scary here. 18 | """ 19 | 20 | def __init__(self, config): 21 | super().__init__() 22 | assert config.EMBEDDING_DIM % config.NUM_HEADS == 0 23 | # key, query, value projections for all heads 24 | self.key = nn.Linear(config.EMBEDDING_DIM, config.EMBEDDING_DIM) 25 | self.query = nn.Linear(config.EMBEDDING_DIM, config.EMBEDDING_DIM) 26 | self.value = nn.Linear(config.EMBEDDING_DIM, config.EMBEDDING_DIM) 27 | # regularization 28 | self.attn_drop = nn.Dropout(config.ATTENTION_PDROP) 29 | self.resid_drop = nn.Dropout(config.RESIDUAL_PDROP) 30 | # output projection 31 | self.proj = nn.Linear(config.EMBEDDING_DIM, config.EMBEDDING_DIM) 32 | # causal mask to ensure that attention is only applied to the left in the input sequence 33 | # self.register_buffer("mask", torch.tril(torch.ones(config.BLOCK_SIZE, config.BLOCK_SIZE)) 34 | # .view(1, 1, config.BLOCK_SIZE, config.BLOCK_SIZE)) 35 | self.register_buffer("mask", torch.ones(config.BLOCK_SIZE, config.BLOCK_SIZE) 36 | .view(1, 1, config.BLOCK_SIZE, config.BLOCK_SIZE)) 37 | self.NUM_HEADS = config.NUM_HEADS 38 | 39 | def forward(self, x, layer_past=None): 40 | B, T, C = x.size() 41 | 42 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 43 | k = self.key(x).view(B, T, self.NUM_HEADS, C // self.NUM_HEADS).transpose(1, 2) # (B, nh, T, hs) 44 | q = self.query(x).view(B, T, self.NUM_HEADS, C // self.NUM_HEADS).transpose(1, 2) # (B, nh, T, hs) 45 | v = self.value(x).view(B, T, self.NUM_HEADS, C // self.NUM_HEADS).transpose(1, 2) # (B, nh, T, hs) 46 | 47 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 48 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 49 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 50 | att = F.softmax(att, dim=-1) 
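        # att now holds (B, nh, T, T) attention weights; note that the registered
        # "mask" buffer is all ones (the causal tril mask is commented out in
        # __init__), so the masked_fill above is effectively a no-op and every
        # position can attend to every other position. Dropout is applied next.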
51 | att = self.attn_drop(att) 52 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 53 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 54 | 55 | # output projection 56 | y = self.resid_drop(self.proj(y)) 57 | return y 58 | 59 | 60 | class Block(nn.Module): 61 | """ an unassuming Transformer block """ 62 | 63 | def __init__(self, config): 64 | super().__init__() 65 | self.ln1 = nn.LayerNorm(config.EMBEDDING_DIM) 66 | self.ln2 = nn.LayerNorm(config.EMBEDDING_DIM) 67 | self.attn = CausalSelfAttention(config) 68 | self.mlp = nn.Sequential( 69 | nn.Linear(config.EMBEDDING_DIM, 4 * config.EMBEDDING_DIM), 70 | nn.GELU(), 71 | nn.Linear(4 * config.EMBEDDING_DIM, config.EMBEDDING_DIM), 72 | nn.Dropout(config.RESIDUAL_PDROP), 73 | ) 74 | 75 | def forward(self, x): 76 | x = x + self.attn(self.ln1(x)) 77 | x = x + self.mlp(self.ln2(x)) 78 | return x 79 | 80 | 81 | class GPT(nn.Module): 82 | """ the full GPT language model, with a context size of block_size """ 83 | 84 | def __init__(self, config): 85 | super().__init__() 86 | 87 | # # input embedding stem 88 | # self.tok_emb = nn.Embedding(config.vocab_size, config.EMBEDDING_DIM) 89 | self.pos_emb = nn.Parameter(torch.zeros(1, 255, config.EMBEDDING_DIM)) 90 | 91 | # self.seq_pos_embed = nn.Parameter(torch.zeros(1, 2060, config.EMBEDDING_DIM), requires_grad=False) 92 | self.seq_pos_embed = nn.Parameter(torch.zeros(1, 2060, config.EMBEDDING_DIM), requires_grad=True) 93 | self.embed_dim = config.EMBEDDING_DIM 94 | self.img_len = 19*19 95 | 96 | self.drop = nn.Dropout(config.EMBEDDING_PDROP) 97 | # transformer 98 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.NUM_LAYERS)]) 99 | # decoder head 100 | self.ln_f = nn.LayerNorm(config.EMBEDDING_DIM) 101 | self.head = nn.Linear(config.EMBEDDING_DIM, config.OUTPUT_SIZE, bias=False) 102 | 103 | self.block_size = config.BLOCK_SIZE 104 | self.apply(self._init_weights) 105 | 106 | def get_block_size(self): 107 | return self.block_size 108 | 109 | def init_pos_emb(self): 110 | img_sin_embed = get_2d_sincos_pos_embed(self.embed_dim // 2, int(self.img_len**.5)) 111 | img_pos_embed = torch.zeros((1, self.img_len, self.embed_dim)) 112 | img_pos_embed[:, :, :self.embed_dim // 2] = torch.from_numpy(img_sin_embed).float() 113 | 114 | seq_sin_embed = get_1d_sincos_pos_embed(self.embed_dim // 2, 11) 115 | seq_pos_embed = torch.zeros((1, 11, self.embed_dim)) 116 | seq_pos_embed[:, :, self.embed_dim // 2:] = torch.from_numpy(seq_sin_embed).float() 117 | 118 | pred_sin_embed = get_1d_sincos_pos_embed(self.embed_dim // 2, 50) 119 | pred_pos_embed = torch.zeros((1, 50, self.embed_dim)) 120 | pred_pos_embed[:, :, self.embed_dim // 2:] = torch.from_numpy(pred_sin_embed).float() + 0.2 121 | 122 | action_sin_embed = get_1d_sincos_pos_embed(self.embed_dim // 2, 5) 123 | action_pos_embed = torch.zeros((1, 5, self.embed_dim)) 124 | action_pos_embed[:, :, :self.embed_dim // 2] = torch.from_numpy(action_sin_embed).float() 125 | 126 | pos_emb = torch.zeros((1, 2060, self.embed_dim)) 127 | for i in range(5): 128 | pos_emb[:,(self.img_len+50)*i:(self.img_len+50)*i+self.img_len] = img_pos_embed + seq_pos_embed[:,i*2,:] 129 | pos_emb[:,(self.img_len+50)*i+self.img_len:(self.img_len+50)*(i+1)] = pred_pos_embed + seq_pos_embed[ 130 | :,i*2+1,:] 131 | pos_emb[:, 2055:, :] = action_pos_embed[:, :, :] + seq_pos_embed[:, -1, :] 132 | 133 | self.seq_pos_embed.data.copy_(pos_emb) 134 | 135 | def _init_weights(self, module): 136 | if isinstance(module, (nn.Linear, 
nn.Embedding)): 137 | module.weight.data.normal_(mean=0.0, std=0.02) 138 | if isinstance(module, nn.Linear) and module.bias is not None: 139 | module.bias.data.zero_() 140 | elif isinstance(module, nn.LayerNorm): 141 | module.bias.data.zero_() 142 | module.weight.data.fill_(1.0) 143 | 144 | def get_optimizer_groups(self, train_config): 145 | """ 146 | This long function is unfortunately doing something very simple and is being very defensive: 147 | We are separating out all parameters of the model into two buckets: those that will experience 148 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 149 | We are then returning the PyTorch optimizer object. 150 | """ 151 | 152 | # separate out all parameters to those that will and won't experience regularizing weight decay 153 | decay = set() 154 | no_decay = set() 155 | whitelist_weight_modules = (torch.nn.Linear, ) 156 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 157 | for mn, m in self.named_modules(): 158 | for pn, p in m.named_parameters(): 159 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 160 | 161 | if pn.endswith('bias'): 162 | # all biases will not be decayed 163 | no_decay.add(fpn) 164 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 165 | # weights of whitelist modules will be weight decayed 166 | decay.add(fpn) 167 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 168 | # weights of blacklist modules will NOT be weight decayed 169 | no_decay.add(fpn) 170 | 171 | # special case the position embedding parameter in the root GPT module as not decayed 172 | no_decay.add('pos_emb') 173 | 174 | # validate that we considered every parameter 175 | param_dict = {pn: p for pn, p in self.named_parameters()} 176 | inter_params = decay & no_decay 177 | union_params = decay | no_decay 178 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 179 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 180 | % (str(param_dict.keys() - union_params), ) 181 | 182 | # create the pytorch optimizer object 183 | optim_groups = [ 184 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.WEIGHT_DECAY}, 185 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 186 | ] 187 | return optim_groups 188 | 189 | def forward(self, seq): 190 | b, t = seq.shape[:2] 191 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
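        # seq is expected to arrive already embedded as a (b, t, EMBEDDING_DIM) tensor
        # (models/transformer.py folds image features, detection proposals and action
        # tokens into it); this module only adds positional embeddings and dropout,
        # runs the transformer blocks, and applies the final layer norm and linear head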
192 | 193 | # forward the GPT model 194 | position_embeddings = self.seq_pos_embed[:, :t, :] # each position maps to a (learnable) vector 195 | x = self.drop(seq + position_embeddings) 196 | x = self.blocks(x) 197 | x = self.ln_f(x) 198 | logits = self.head(x) 199 | 200 | return logits 201 | 202 | 203 | # Positional embeddings 204 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 205 | """ 206 | grid_size: int of the grid height and width 207 | return: 208 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 209 | """ 210 | grid_h = np.arange(grid_size, dtype=np.float32) 211 | grid_w = np.arange(grid_size, dtype=np.float32) 212 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 213 | grid = np.stack(grid, axis=0) 214 | 215 | grid = grid.reshape([2, 1, grid_size, grid_size]) 216 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 217 | if cls_token: 218 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 219 | return pos_embed 220 | 221 | 222 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 223 | assert embed_dim % 2 == 0 224 | 225 | # use half of dimensions to encode grid_h 226 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 227 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 228 | 229 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 230 | return emb 231 | 232 | 233 | def get_1d_sincos_pos_embed(embed_dim, n): 234 | grid = np.arange(n) 235 | return get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 236 | 237 | 238 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 239 | """ 240 | embed_dim: output dimension for each position 241 | pos: a list of positions to be encoded: size (M,) 242 | out: (M, D) 243 | """ 244 | assert embed_dim % 2 == 0 245 | omega = np.arange(embed_dim // 2, dtype=np.float) 246 | omega /= embed_dim / 2. 247 | omega = 1. 
/ 10000**omega # (D/2,) 248 | 249 | pos = pos.reshape(-1) # (M,) 250 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 251 | 252 | emb_sin = np.sin(out) # (M, D/2) 253 | emb_cos = np.cos(out) # (M, D/2) 254 | 255 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 256 | return emb 257 | -------------------------------------------------------------------------------- /models/interactron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | 6 | from models.detr_models.detr import build 7 | from models.detr_models.util.misc import NestedTensor 8 | from models.transformer import Transformer 9 | from utils.meta_utils import get_parameters, clone_parameters, sgd_step, set_parameters, detach_parameters, \ 10 | detach_gradients 11 | from utils.storage_utils import PathStorage 12 | 13 | 14 | class interactron(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | config, 19 | ): 20 | super().__init__() 21 | # build DETR detector 22 | self.detector, self.criterion, self.postprocessor = build(config) 23 | self.detector.load_state_dict(torch.load(config.WEIGHTS, map_location=torch.device('cpu'))['model']) 24 | # build fusion transformer 25 | self.fusion = Transformer(config) 26 | self.logger = None 27 | self.mode = 'train' 28 | self.path_storage = {} 29 | self.config = config 30 | 31 | def predict(self, data): 32 | 33 | # reformat img and mask data 34 | b, s, c, w, h = data["frames"].shape 35 | img = data["frames"].view(s, c, w, h) 36 | mask = data["masks"].view(s, w, h) 37 | 38 | theta = get_parameters(self.detector) 39 | theta_task = detach_parameters(clone_parameters(theta)) 40 | 41 | # get supervisor grads 42 | set_parameters(self.detector, theta_task) 43 | pre_adaptive_out = self.detector(NestedTensor(img, mask)) 44 | pre_adaptive_out["embedded_memory_features"] = pre_adaptive_out["embedded_memory_features"].unsqueeze(0) 45 | pre_adaptive_out["box_features"] = pre_adaptive_out["box_features"].unsqueeze(0) 46 | pre_adaptive_out["pred_logits"] = pre_adaptive_out["pred_logits"].unsqueeze(0) 47 | pre_adaptive_out["pred_boxes"] = pre_adaptive_out["pred_boxes"].unsqueeze(0) 48 | 49 | fusion_out = self.fusion(pre_adaptive_out) 50 | learned_loss = torch.norm(fusion_out["loss"]) 51 | detector_grad = torch.autograd.grad(learned_loss, theta_task, create_graph=True, retain_graph=True, 52 | allow_unused=True) 53 | fast_weights = sgd_step(theta_task, detector_grad, self.config.ADAPTIVE_LR) 54 | set_parameters(self.detector, fast_weights) 55 | post_adaptive_out = self.detector(NestedTensor(img[0:1], mask[0:1])) 56 | 57 | set_parameters(self.detector, theta) 58 | 59 | return {k: v.unsqueeze(0) for k, v in post_adaptive_out.items()} 60 | 61 | def forward(self, data, train=True): 62 | 63 | # reformat img and mask data 64 | b, s, c, w, h = data["frames"].shape 65 | img = data["frames"].view(b, s, c, w, h) 66 | mask = data["masks"].view(b, s, w, h) 67 | # reformat labels 68 | labels = [] 69 | for i in range(b): 70 | labels.append([]) 71 | for j in range(s): 72 | labels[i].append({ 73 | "labels": data["category_ids"][i][j], 74 | "boxes": data["boxes"][i][j] 75 | }) 76 | 77 | detector_losses = [] 78 | supervisor_losses = [] 79 | out_logits_list = [] 80 | out_boxes_list = [] 81 | 82 | theta = get_parameters(self.detector) 83 | 84 | for task in range(b): 85 | 86 | theta_task = clone_parameters(theta) 87 | 88 | # get supervisor grads 89 | detached_theta_task = 
detach_parameters(theta_task) 90 | set_parameters(self.detector, detached_theta_task) 91 | pre_adaptive_out = self.detector(NestedTensor(img[task], mask[task])) 92 | pre_adaptive_out["embedded_memory_features"] = pre_adaptive_out["embedded_memory_features"].unsqueeze(0) 93 | pre_adaptive_out["box_features"] = pre_adaptive_out["box_features"].unsqueeze(0) 94 | pre_adaptive_out["pred_logits"] = pre_adaptive_out["pred_logits"].unsqueeze(0) 95 | pre_adaptive_out["pred_boxes"] = pre_adaptive_out["pred_boxes"].unsqueeze(0) 96 | 97 | fusion_out = self.fusion(pre_adaptive_out) 98 | learned_loss = torch.norm(fusion_out["loss"]) 99 | detector_grad = torch.autograd.grad(learned_loss, detached_theta_task, create_graph=True, retain_graph=True, 100 | allow_unused=True) 101 | fast_weights = sgd_step(detached_theta_task, detector_grad, self.config.ADAPTIVE_LR) 102 | set_parameters(self.detector, fast_weights) 103 | post_adaptive_out = self.detector(NestedTensor(img[task], mask[task])) 104 | 105 | # lowest loss policy experiment 106 | first_frame_out = {k: v[[0]] for k, v in post_adaptive_out.items()} 107 | gt_loss = self.criterion(first_frame_out, [labels[task][0]], background_c=0.1) 108 | gt_loss = gt_loss["loss_ce"] + 5 * gt_loss["loss_giou"] + 2 * gt_loss["loss_bbox"] 109 | iip = data["initial_image_path"][task] 110 | rew = torch.mean(gt_loss).item() 111 | if iip not in self.path_storage: 112 | self.path_storage[iip] = PathStorage() 113 | self.path_storage[iip].add_path(data["actions"][task][:4], rew) 114 | best_path = torch.tensor(self.path_storage[iip].get_label(data["actions"][task][:4]), 115 | dtype=torch.long, device=gt_loss.device) 116 | 117 | supervisor_loss = self.criterion(post_adaptive_out, labels[task], background_c=0.1) 118 | supervisor_loss["loss_path"] = F.cross_entropy(fusion_out["actions"].view(4, 4), best_path) 119 | supervisor_loss["policy_reward"] = gt_loss 120 | supervisor_losses.append({k: v.detach() for k, v in supervisor_loss.items()}) 121 | supervisor_loss = supervisor_loss["loss_ce"] + 5 * supervisor_loss["loss_giou"] + \ 122 | 2 * supervisor_loss["loss_bbox"] + supervisor_loss["loss_path"] 123 | supervisor_loss.backward() 124 | 125 | # get detector grads 126 | fast_weights = sgd_step(theta_task, detach_gradients(detector_grad), self.config.ADAPTIVE_LR) 127 | set_parameters(self.detector, fast_weights) 128 | 129 | ridx = random.randint(0, 4) 130 | post_adaptive_out = self.detector(NestedTensor(img[task][ridx:ridx+1], mask[task][ridx:ridx+1])) 131 | detector_loss = self.criterion(post_adaptive_out, labels[task][ridx:ridx+1], background_c=0.1) 132 | detector_losses.append({k: v.detach() for k, v in detector_loss.items()}) 133 | detector_loss = detector_loss["loss_ce"] + 5 * detector_loss["loss_giou"] + 2 * detector_loss["loss_bbox"] 134 | detector_loss.backward() 135 | 136 | out_logits_list.append(post_adaptive_out["pred_logits"]) 137 | out_boxes_list.append(post_adaptive_out["pred_boxes"]) 138 | 139 | set_parameters(self.detector, theta) 140 | 141 | predictions = {"pred_logits": torch.stack(out_logits_list, dim=0), "pred_boxes": torch.stack(out_boxes_list, dim=0)} 142 | mean_detector_losses = {k.replace("loss", "loss_detector"): 143 | torch.mean(torch.stack([x[k] for x in detector_losses])) 144 | for k, v in detector_losses[0].items()} 145 | mean_supervisor_losses = {k.replace("loss", "loss_supervisor"): 146 | torch.mean(torch.stack([x[k] for x in supervisor_losses])) 147 | for k, v in supervisor_losses[0].items()} 148 | losses = mean_detector_losses 149 | 
losses.update(mean_supervisor_losses) 150 | 151 | return predictions, losses 152 | 153 | def eval(self): 154 | return self.train(False) 155 | 156 | def train(self, mode=True): 157 | self.mode = 'train' if mode else 'test' 158 | # only train proposal generator of detector 159 | self.detector.train(mode) 160 | self.fusion.train(mode) 161 | return self 162 | 163 | def get_optimizer_groups(self, train_config): 164 | optim_groups = [ 165 | {"params": list(self.decoder.parameters()), "weight_decay": 0.0}, 166 | {"params": list(self.detector.parameters()), "weight_decay": 0.0}, 167 | ] 168 | return optim_groups 169 | 170 | def set_logger(self, logger): 171 | assert self.logger is None, "This model already has a logger!" 172 | self.logger = logger 173 | 174 | def get_next_action(self, data): 175 | # reformat img and mask data 176 | b, s, c, w, h = data["frames"].shape 177 | img = data["frames"].view(b*s, c, w, h) 178 | mask = data["masks"].view(b*s, w, h) 179 | # reformat labels 180 | labels = [] 181 | for i in range(b): 182 | labels.append([]) 183 | for j in range(s): 184 | labels[i].append({ 185 | "labels": data["category_ids"][i][j], 186 | "boxes": data["boxes"][i][j] 187 | }) 188 | 189 | pre_adaptive_out = self.detector(NestedTensor(img, mask)) 190 | pre_adaptive_out["embedded_memory_features"] = pre_adaptive_out["embedded_memory_features"].unsqueeze(0) 191 | pre_adaptive_out["box_features"] = pre_adaptive_out["box_features"].unsqueeze(0) 192 | pre_adaptive_out["pred_logits"] = pre_adaptive_out["pred_logits"].unsqueeze(0) 193 | pre_adaptive_out["pred_boxes"] = pre_adaptive_out["pred_boxes"].unsqueeze(0) 194 | 195 | fusion_out = self.fusion(pre_adaptive_out) 196 | 197 | return fusion_out['actions'][s-1].argmax(dim=-1).item() 198 | 199 | -------------------------------------------------------------------------------- /models/interactron_random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.detr_models.detr import build 5 | from models.detr_models.util.misc import NestedTensor 6 | from models.new_transformer import Transformer 7 | from utils.meta_utils import get_parameters, clone_parameters, sgd_step, set_parameters, detach_parameters, \ 8 | detach_gradients 9 | 10 | 11 | class interactron_random(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | config, 16 | ): 17 | super().__init__() 18 | # build DETR detector 19 | self.detector, self.criterion, self.postprocessor = build(config) 20 | self.detector.load_state_dict(torch.load(config.WEIGHTS, map_location=torch.device('cpu'))['model']) 21 | # build fusion transformer 22 | self.fusion = Transformer(config) 23 | self.logger = None 24 | self.mode = 'train' 25 | self.config = config 26 | 27 | def predict(self, data): 28 | 29 | # reformat img and mask data 30 | b, s, c, w, h = data["frames"].shape 31 | img = data["frames"].view(s, c, w, h) 32 | mask = data["masks"].view(s, w, h) 33 | 34 | theta = get_parameters(self.detector) 35 | theta_task = detach_parameters(clone_parameters(theta)) 36 | 37 | # get supervisor grads 38 | set_parameters(self.detector, theta_task) 39 | pre_adaptive_out = self.detector(NestedTensor(img, mask)) 40 | pre_adaptive_out["embedded_memory_features"] = pre_adaptive_out["embedded_memory_features"].unsqueeze(0) 41 | pre_adaptive_out["box_features"] = pre_adaptive_out["box_features"].unsqueeze(0) 42 | pre_adaptive_out["pred_logits"] = pre_adaptive_out["pred_logits"].unsqueeze(0) 43 | pre_adaptive_out["pred_boxes"] = 
pre_adaptive_out["pred_boxes"].unsqueeze(0) 44 | 45 | fusion_out = self.fusion(pre_adaptive_out) 46 | learned_loss = torch.norm(fusion_out["loss"]) 47 | detector_grad = torch.autograd.grad(learned_loss, theta_task, create_graph=True, retain_graph=True, 48 | allow_unused=True) 49 | fast_weights = sgd_step(theta_task, detector_grad, self.config.ADAPTIVE_LR) 50 | set_parameters(self.detector, fast_weights) 51 | post_adaptive_out = self.detector(NestedTensor(img[0:1], mask[0:1])) 52 | 53 | set_parameters(self.detector, theta) 54 | 55 | return {k: v.unsqueeze(0) for k, v in post_adaptive_out.items()} 56 | 57 | def forward(self, data, train=True): 58 | # reformat img and mask data 59 | b, s, c, w, h = data["frames"].shape 60 | img = data["frames"].view(b, s, c, w, h) 61 | mask = data["masks"].view(b, s, w, h) 62 | # reformat labels 63 | labels = [] 64 | for i in range(b): 65 | labels.append([]) 66 | for j in range(s): 67 | labels[i].append({ 68 | "labels": data["category_ids"][i][j], 69 | "boxes": data["boxes"][i][j] 70 | }) 71 | 72 | detector_losses = [] 73 | supervisor_losses = [] 74 | out_logits_list = [] 75 | out_boxes_list = [] 76 | 77 | theta = get_parameters(self.detector) 78 | 79 | for task in range(b): 80 | 81 | theta_task = clone_parameters(theta) 82 | 83 | # get supervisor grads 84 | detached_theta_task = detach_parameters(theta_task) 85 | set_parameters(self.detector, detached_theta_task) 86 | pre_adaptive_out = self.detector(NestedTensor(img[task], mask[task])) 87 | pre_adaptive_out["embedded_memory_features"] = pre_adaptive_out["embedded_memory_features"].unsqueeze(0) 88 | pre_adaptive_out["box_features"] = pre_adaptive_out["box_features"].unsqueeze(0) 89 | pre_adaptive_out["pred_logits"] = pre_adaptive_out["pred_logits"].unsqueeze(0) 90 | pre_adaptive_out["pred_boxes"] = pre_adaptive_out["pred_boxes"].unsqueeze(0) 91 | 92 | fusion_out = self.fusion(pre_adaptive_out) 93 | learned_loss = torch.norm(fusion_out["loss"]) 94 | detector_grad = torch.autograd.grad(learned_loss, detached_theta_task, create_graph=True, retain_graph=True, 95 | allow_unused=True) 96 | fast_weights = sgd_step(detached_theta_task, detector_grad, self.config.ADAPTIVE_LR) 97 | set_parameters(self.detector, fast_weights) 98 | 99 | post_adaptive_out = self.detector(NestedTensor(img[task], mask[task])) 100 | supervisor_loss = self.criterion(post_adaptive_out, labels[task], background_c=0.1) 101 | supervisor_losses.append({k: v.detach() for k, v in supervisor_loss.items()}) 102 | supervisor_loss = supervisor_loss["loss_ce"] + 5 * supervisor_loss["loss_giou"] + 2 * supervisor_loss["loss_bbox"] 103 | supervisor_loss.backward() 104 | 105 | # get detector grads 106 | fast_weights = sgd_step(theta_task, detach_gradients(detector_grad), self.config.ADAPTIVE_LR) 107 | set_parameters(self.detector, fast_weights) 108 | 109 | import random 110 | ridx = random.randint(0, 4) 111 | # ridx = 0 112 | post_adaptive_out = self.detector(NestedTensor(img[task][ridx:ridx+1], mask[task][ridx:ridx+1])) 113 | detector_loss = self.criterion(post_adaptive_out, labels[task][ridx:ridx+1], background_c=0.1) 114 | detector_losses.append({k: v.detach() for k, v in detector_loss.items()}) 115 | detector_loss = detector_loss["loss_ce"] + 5 * detector_loss["loss_giou"] + 2 * detector_loss["loss_bbox"] 116 | detector_loss.backward() 117 | 118 | out_logits_list.append(post_adaptive_out["pred_logits"]) 119 | out_boxes_list.append(post_adaptive_out["pred_boxes"]) 120 | 121 | set_parameters(self.detector, theta) 122 | 123 | predictions = {"pred_logits": 
torch.stack(out_logits_list, dim=0), "pred_boxes": torch.stack(out_boxes_list, dim=0)} 124 | mean_detector_losses = {k.replace("loss", "loss_detector"): 125 | torch.mean(torch.stack([x[k] for x in detector_losses])) 126 | for k, v in detector_losses[0].items()} 127 | mean_supervisor_losses = {k.replace("loss", "loss_supervisor"): 128 | torch.mean(torch.stack([x[k] for x in supervisor_losses])) 129 | for k, v in supervisor_losses[0].items()} 130 | losses = mean_detector_losses 131 | losses.update(mean_supervisor_losses) 132 | return predictions, losses 133 | 134 | def eval(self): 135 | return self.train(False) 136 | 137 | def train(self, mode=True): 138 | self.mode = 'train' if mode else 'test' 139 | # only train proposal generator of detector 140 | self.detector.train(mode) 141 | self.fusion.train(mode) 142 | return self 143 | 144 | def get_optimizer_groups(self, train_config): 145 | optim_groups = [ 146 | {"params": list(self.decoder.parameters()), "weight_decay": 0.0}, 147 | {"params": list(self.detector.parameters()), "weight_decay": 0.0}, 148 | ] 149 | return optim_groups 150 | 151 | def set_logger(self, logger): 152 | assert self.logger is None, "This model already has a logger!" 153 | self.logger = logger 154 | 155 | -------------------------------------------------------------------------------- /models/learned_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import matplotlib.pyplot as plt 6 | import cv2 7 | 8 | from models.detectron2_detector import Detectron2Detector 9 | from models.gpt import GPT 10 | from models.components import LinearBlock 11 | from utils.constants import tlvis_classes 12 | from utils.model_utils import merge_batch_seq, unmerge_batch_seq 13 | from utils.detection_utils import iou 14 | from utils.time_utils import Timer 15 | from utils.viz_utils import draw_box 16 | 17 | 18 | class LinearBlock(nn.Module): 19 | def __init__(self, in_dim, out_dim, bias=False): 20 | super().__init__() 21 | self.model = nn.Sequential( 22 | nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias), 23 | nn.LayerNorm(out_dim), 24 | nn.GELU(), 25 | ) 26 | 27 | def forward(self, x): 28 | og_shape = x.shape 29 | x = self.model(x.view(-1, og_shape[-1])) 30 | return x.view(*og_shape[:-1], -1) 31 | 32 | 33 | class LearnedLossModel(nn.Module): 34 | 35 | def __init__(self, cfg): 36 | super().__init__() 37 | self.model = GPT(cfg.TRANSFORMER) 38 | self.proposal_encoder = LinearBlock(2264, cfg.TRANSFORMER.EMBEDDING_DIM, bias=False) 39 | self.img_feature_encoder = LinearBlock(2048, cfg.TRANSFORMER.EMBEDDING_DIM, bias=False) 40 | self.box_decoder = nn.Linear(in_features=1024, out_features=4, bias=False) 41 | self.category_decoder = nn.Linear(in_features=1024, out_features=1236, bias=False) 42 | self.cfg = cfg 43 | self.is_train = True 44 | self.timer = Timer() 45 | self.logger = None 46 | self.mode = 'train' 47 | if cfg.TRANSFORMER.PREDICT_ACTIONS: 48 | self.policy_tokens = nn.Parameter(1, 5, cfg.TRANSFORMER.EMBEDDING_DIM) 49 | 50 | def forward(self, predictions, images): 51 | 52 | seq = self.fold_sequence(predictions) 53 | if self.cfg.TRANSFORMER.PREDICT_ACTIONS: 54 | seq = torch.cat((seq, self.policy_tokens), dim=1) 55 | out = self.model(seq) 56 | pred_embs = out[:, :250] 57 | learned_loss = torch.norm(pred_embs, p=2) 58 | 59 | if self.cfg.TRANSFORMER.PREDICT_ACTIONS: 60 | action_predictions = out[:, -5:] 61 | return learned_loss, 
action_predictions 62 | 63 | return learned_loss 64 | 65 | def configure_optimizer(self, train_config): 66 | optim_groups = self.model.get_optimizer_groups(train_config) 67 | assert train_config.OPTIM_TYPE in ["Adam", "AdamW", "SGD"], \ 68 | "Invalid optimizer type {}. Please select Adam, AdamW or SGD" 69 | if train_config.OPTIM_TYPE == "AdamW": 70 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.LEARNING_RATE, 71 | betas=(train_config.BETA1, train_config.BETA2)) 72 | elif train_config.OPTIM_TYPE == "Adam": 73 | optimizer = torch.optim.Adam(optim_groups, lr=train_config.LEARNING_RATE, 74 | betas=(train_config.BETA1, train_config.BETA2)) 75 | else: 76 | optimizer = torch.optim.SGD(optim_groups, lr=train_config.LEARNING_RATE, momentum=train_config.MOMENTUM) 77 | return optimizer 78 | 79 | def eval(self): 80 | return self.train(False) 81 | 82 | def train(self, mode=True): 83 | self.mode = 'train' if mode else 'test' 84 | self.model.train(mode) 85 | self.is_train = mode 86 | return self 87 | 88 | def fold_sequence(self, predictions): 89 | img_features = predictions.get_image_features() 90 | box_features = predictions.get_box_features() 91 | boxes = predictions.get_boxes() 92 | logits = predictions.get_logits() 93 | detections = torch.cat((box_features, boxes, logits), dim=-1) 94 | b, t = img_features.shape[:2] 95 | img_features = img_features.permute(0, 1, 3, 4, 2) 96 | seq_img_features = self.img_feature_encoder(img_features.reshape(b, -1, img_features.shape[-1])) 97 | det_image_features = self.proposal_encoder(detections.reshape(b, -1, detections.shape[-1])) 98 | return torch.cat((det_image_features, seq_img_features), dim=1) 99 | 100 | def set_logger(self, logger): 101 | assert self.logger is None, "This model already has a logger!" 102 | self.logger = logger 103 | -------------------------------------------------------------------------------- /models/mlp_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import matplotlib.pyplot as plt 6 | import cv2 7 | 8 | from models.detectron2_detector import Detectron2Detector 9 | from models.gpt import GPT 10 | from models.components import LinearBlock 11 | from utils.constants import tlvis_classes 12 | from utils.model_utils import merge_batch_seq, unmerge_batch_seq 13 | from utils.detection_utils import iou 14 | from utils.time_utils import Timer 15 | from utils.viz_utils import draw_box 16 | from utils.detection_utils import Prediction 17 | 18 | 19 | class LinearBlock(nn.Module): 20 | def __init__(self, in_dim, out_dim, bias=False): 21 | super().__init__() 22 | self.model = nn.Sequential( 23 | nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias), 24 | # nn.LayerNorm(out_dim), 25 | nn.GELU(), 26 | ) 27 | 28 | def forward(self, x): 29 | og_shape = x.shape 30 | x = self.model(x.view(-1, og_shape[-1])) 31 | return x.view(*og_shape[:-1], -1) 32 | 33 | 34 | class MLPDetector(nn.Module): 35 | 36 | def __init__(self, cfg): 37 | super().__init__() 38 | self.model = nn.Sequential( 39 | LinearBlock(2264, 1024, bias=False), 40 | LinearBlock(1024, 1024, bias=False), 41 | nn.Linear(1024, 1236, bias=False), 42 | ) 43 | self.preprocessor = Detectron2Detector(config=cfg) 44 | self.preprocessor.eval() 45 | self.cfg = cfg 46 | self.is_train = True 47 | self.timer = Timer() 48 | self.logger = None 49 | self.mode = 'train' 50 | 51 | def preprocess(self, images): 52 | predictions = 
self.preprocessor(images) 53 | predictions.nms(k=50) 54 | return predictions 55 | 56 | def forward(self, predictions, labels, use_predictions_as_labels=False): 57 | 58 | image_features = predictions.get_image_features(flat=True).detach() 59 | box_features = predictions.get_box_features(flat=True).detach() 60 | logits = predictions.get_logits(flat=True).detach() 61 | boxes = predictions.get_boxes(flat=True).detach() 62 | new_logits = self.model(torch.cat((box_features, logits, boxes), dim=-1)) 63 | refined_predictions = Prediction( 64 | batch_size=predictions.batch_size, 65 | seq_len=predictions.seq_len, 66 | device=predictions.device, 67 | logits=new_logits.unsqueeze(0), 68 | boxes=boxes.unsqueeze(0), 69 | box_features=box_features.unsqueeze(0), 70 | image_features=image_features.unsqueeze(0) 71 | ) 72 | 73 | if use_predictions_as_labels: 74 | labels = predictions.make_labels_from_predictions(c=0.8) 75 | 76 | with torch.no_grad(): 77 | labels.match_labels(refined_predictions) 78 | 79 | # compute losses 80 | bounding_box_loss, category_loss = self.compute_losses(refined_predictions, labels) 81 | losses = { 82 | "category_prediction_loss": category_loss, 83 | "bounding_box_loss": 0.0 * bounding_box_loss # torch.tensor([0.0], device=category_loss.device) 84 | } 85 | 86 | return refined_predictions, losses 87 | 88 | def compute_losses(self, predictions, labels): 89 | gt_cats = labels.get_matched_categories(flat=True) 90 | gt_boxes = labels.get_matched_boxes(flat=True) 91 | pred_logits = predictions.get_logits(flat=True) 92 | pred_boxes = predictions.get_boxes(flat=True) 93 | box_mask = gt_cats.view(-1) != self.cfg.DETECTOR.NUM_CLASSES 94 | # cat_mask = box_mask.detach().clone() 95 | # _, pred_rankings_mask = torch.topk( 96 | # pred_logits.softmax(dim=-1)[:, :, -1].view(-1), int(len(cat_mask)*0.04)) 97 | # cat_mask[pred_rankings_mask] = 1.0 98 | # debugging 99 | if self.logger: 100 | self.logger.add_value( 101 | "{}/Number of matched ground truths".format(self.mode.capitalize()), 102 | torch.count_nonzero(box_mask) 103 | ) 104 | self.logger.add_value( 105 | "{}/Number of positive detections".format(self.mode.capitalize()), 106 | torch.count_nonzero(predictions.get_logits().argmax(-1) != self.cfg.DETECTOR.NUM_CLASSES) 107 | ) 108 | bbox_loss = F.l1_loss(pred_boxes.view(-1, 4)[box_mask], gt_boxes.view(-1, 4)[box_mask]) 109 | # nan guard 110 | if bbox_loss.isnan(): 111 | bbox_loss = torch.zeros_like(bbox_loss) 112 | weights = torch.ones(self.cfg.DETECTOR.NUM_CLASSES + 1, device=gt_cats.device) 113 | weights[-1] = 1 114 | category_loss = F.cross_entropy( 115 | pred_logits.view(-1, pred_logits.shape[-1]), # [cat_mask], 116 | gt_cats.view(-1), # [cat_mask], 117 | weight=weights 118 | ) 119 | return bbox_loss, category_loss 120 | 121 | def configure_optimizer(self, train_config): 122 | optim_groups = [{"params": list(self.model.parameters()), "weight_decay": train_config.WEIGHT_DECAY}] 123 | assert train_config.OPTIM_TYPE in ["Adam", "AdamW", "SGD"], \ 124 | "Invalid optimizer type {}. 
Please select Adam, AdamW or SGD" 125 | if train_config.OPTIM_TYPE == "AdamW": 126 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.LEARNING_RATE, 127 | betas=(train_config.BETA1, train_config.BETA2)) 128 | elif train_config.OPTIM_TYPE == "Adam": 129 | optimizer = torch.optim.Adam(optim_groups, lr=train_config.LEARNING_RATE, 130 | betas=(train_config.BETA1, train_config.BETA2)) 131 | else: 132 | optimizer = torch.optim.SGD(optim_groups, lr=train_config.LEARNING_RATE, momentum=train_config.MOMENTUM) 133 | return optimizer 134 | 135 | def eval(self): 136 | return self.train(False) 137 | 138 | def train(self, mode=True): 139 | self.mode = 'train' if mode else 'test' 140 | self.preprocessor.train(False) 141 | self.model.train(mode) 142 | self.is_train = mode 143 | return self 144 | 145 | def set_logger(self, logger): 146 | assert self.logger is None, "This model already has a logger!" 147 | self.logger = logger 148 | -------------------------------------------------------------------------------- /models/new_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | 6 | from models.detr_models.detr import MLP 7 | from models.detr_models.transformer import TransformerDecoderLayer, TransformerDecoder 8 | 9 | 10 | class Transformer(nn.Module): 11 | 12 | def __init__(self, config): 13 | super().__init__() 14 | self.img_feature_embedding = nn.Linear(config.IMG_FEATURE_SIZE, config.EMBEDDING_DIM) 15 | self.prediction_embedding = nn.Linear(config.BOX_EMB_SIZE + config.NUM_CLASSES + 5, config.EMBEDDING_DIM) 16 | self.box_decoder = MLP(config.OUTPUT_SIZE, 512, 4, 3) 17 | self.logit_decoder = nn.Linear(config.OUTPUT_SIZE, config.NUM_CLASSES + 1) 18 | self.loss_decoder = MLP(config.OUTPUT_SIZE, 512, 1, 3) 19 | self.action_decoder = MLP(config.OUTPUT_SIZE, 512, 4, 3) 20 | self.action_tokens = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(1, 5, config.EMBEDDING_DIM), 21 | a=math.sqrt(5))) 22 | # build transformer 23 | decoder_layer = TransformerDecoderLayer(config.EMBEDDING_DIM, config.NUM_HEADS, 2048, 0.1, "relu", False) 24 | decoder_norm = nn.LayerNorm(config.EMBEDDING_DIM) 25 | self.transformer = TransformerDecoder(decoder_layer, config.NUM_LAYERS, decoder_norm, return_intermediate=False) 26 | 27 | self.embed_dim = config.EMBEDDING_DIM 28 | self.img_len = 19 * 19 29 | 30 | self.pos_embed = nn.Parameter(torch.zeros(1, 1805, config.EMBEDDING_DIM), requires_grad=False) 31 | self.query_embed = nn.Parameter(torch.zeros(1, 255, config.EMBEDDING_DIM), requires_grad=True) 32 | self.init_pos_emb() 33 | 34 | def forward(self, x): 35 | # fold data into sequence 36 | img_feature_embedding = self.img_feature_embedding(x["embedded_memory_features"].permute(0, 1, 3, 4, 2)) 37 | preds = torch.cat((x["box_features"], x["pred_logits"], x["pred_boxes"]), dim=-1) 38 | prediction_embeddings = self.prediction_embedding(preds) 39 | b, s, p, n = prediction_embeddings.shape 40 | # create padded sequences 41 | memory = torch.zeros((b, 5 * 19 * 19, n), device=prediction_embeddings.device) 42 | memory[:, :(s * 19 * 19)] = img_feature_embedding.reshape(b, -1, n) 43 | tgt = torch.zeros((b, 255, n), device=prediction_embeddings.device) 44 | tgt[:, :(s * 50)] = prediction_embeddings.reshape(b, -1, n) 45 | tgt[:, 250:255] = self.action_tokens.repeat(b, 1, 1).reshape(b, -1, n) 46 | mask = torch.zeros((b, 5 * 19 * 19), dtype=torch.bool, device=x["box_features"].device) 47 | # pass sequence through 
model 48 | y = self.transformer(tgt.permute(1, 0, 2), memory.permute(1, 0, 2), memory_key_padding_mask=mask, 49 | pos=self.pos_embed.permute(1, 0, 2), query_pos=self.query_embed.permute(1, 0, 2)) 50 | # unfold data 51 | y_preds = y[:, :-5].reshape(b, s, p, -1) 52 | boxes = self.box_decoder(y_preds).sigmoid() 53 | logits = self.logit_decoder(y_preds) 54 | loss = self.loss_decoder(y_preds) 55 | actions = self.action_decoder(y[:, -5:-1].reshape(b, 4, -1)) 56 | 57 | return {"seq": y_preds.squeeze(), "pred_boxes": boxes.squeeze(), "pred_logits": logits.squeeze(), 58 | "loss": loss, "actions": actions.squeeze()} 59 | 60 | def init_pos_emb(self): 61 | img_sin_embed = get_2d_sincos_pos_embed(self.embed_dim // 2, int(self.img_len**.5)) 62 | img_pos_embed = torch.zeros((1, self.img_len, self.embed_dim)) 63 | img_pos_embed[:, :, :self.embed_dim // 2] = torch.from_numpy(img_sin_embed).float() 64 | 65 | seq_sin_embed = get_1d_sincos_pos_embed(self.embed_dim // 2, 5) 66 | seq_pos_embed = torch.zeros((1, 5, self.embed_dim)) 67 | seq_pos_embed[:, :, self.embed_dim // 2:] = torch.from_numpy(seq_sin_embed).float() 68 | 69 | pos_emb = torch.zeros((1, 1805, self.embed_dim)) 70 | for i in range(5): 71 | pos_emb[:, self.img_len*i:self.img_len*(i+1)] = img_pos_embed + seq_pos_embed[:, i] 72 | 73 | self.pos_embed.data.copy_(pos_emb) 74 | 75 | 76 | # Positional embeddings 77 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 78 | """ 79 | grid_size: int of the grid height and width 80 | return: 81 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 82 | """ 83 | grid_h = np.arange(grid_size, dtype=np.float32) 84 | grid_w = np.arange(grid_size, dtype=np.float32) 85 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 86 | grid = np.stack(grid, axis=0) 87 | 88 | grid = grid.reshape([2, 1, grid_size, grid_size]) 89 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 90 | if cls_token: 91 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 92 | return pos_embed 93 | 94 | 95 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 96 | assert embed_dim % 2 == 0 97 | 98 | # use half of dimensions to encode grid_h 99 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 100 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 101 | 102 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 103 | return emb 104 | 105 | 106 | def get_1d_sincos_pos_embed(embed_dim, n): 107 | grid = np.arange(n) 108 | return get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 109 | 110 | 111 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 112 | """ 113 | embed_dim: output dimension for each position 114 | pos: a list of positions to be encoded: size (M,) 115 | out: (M, D) 116 | """ 117 | assert embed_dim % 2 == 0 118 | omega = np.arange(embed_dim // 2, dtype=np.float) 119 | omega /= embed_dim / 2. 120 | omega = 1. 
/ 10000**omega # (D/2,) 121 | 122 | pos = pos.reshape(-1) # (M,) 123 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 124 | 125 | emb_sin = np.sin(out) # (M, D/2) 126 | emb_cos = np.cos(out) # (M, D/2) 127 | 128 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 129 | return emb 130 | -------------------------------------------------------------------------------- /models/single_frame_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.detr import detr 6 | 7 | 8 | class SingleFrameBaselineModel(nn.Module): 9 | 10 | def __init__(self, cfg): 11 | super().__init__() 12 | self.detector = detr(config=cfg) 13 | self.cfg = cfg 14 | self.logger = None 15 | self.mode = 'train' 16 | 17 | def forward(self, data): 18 | predictions, losses = self.detector(data) 19 | predictions.nms(k=50) 20 | 21 | return predictions, losses 22 | 23 | def configure_optimizer(self, train_config): 24 | # optim_groups = self.detector.get_optimizer_groups(train_config) 25 | optim_groups = [{"params": list(self.model.parameters()), "weight_decay": train_config.WEIGHT_DECAY}] 26 | assert train_config.OPTIM_TYPE in ["Adam", "AdamW", "SGD"], \ 27 | "Invalid optimizer type {}. Please select Adam, AdamW or SGD" 28 | if train_config.OPTIM_TYPE == "AdamW": 29 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.LEARNING_RATE, 30 | betas=(train_config.BETA1, train_config.BETA2)) 31 | elif train_config.OPTIM_TYPE == "Adam": 32 | optimizer = torch.optim.Adam(optim_groups, lr=train_config.LEARNING_RATE, 33 | betas=(train_config.BETA1, train_config.BETA2)) 34 | else: 35 | optimizer = torch.optim.SGD(optim_groups, lr=train_config.LEARNING_RATE, momentum=train_config.MOMENTUM) 36 | return optimizer 37 | 38 | def set_logger(self, logger): 39 | assert self.logger is None, "This model already has a logger!" 
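        # the logger is stored locally and then propagated to the wrapped DETR detector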
40 | self.logger = logger 41 | self.detector.set_logger(logger) 42 | 43 | def train(self, mode=True): 44 | self.mode = 'train' if mode else 'test' 45 | self.detector.train(mode=False) 46 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from models.gpt import GPT 6 | from models.detr_models.detr import MLP 7 | 8 | 9 | class MLP2(nn.Module): 10 | 11 | def __init__(self, in_dim, emb_dim, out_dim): 12 | super().__init__() 13 | self.model = nn.Sequential( 14 | nn.Linear(in_dim, emb_dim), 15 | nn.LayerNorm(emb_dim), 16 | nn.ReLU(), 17 | nn.Linear(emb_dim, emb_dim), 18 | nn.LayerNorm(emb_dim), 19 | nn.ReLU(), 20 | nn.Linear(emb_dim, emb_dim), 21 | nn.LayerNorm(emb_dim), 22 | nn.ReLU(), 23 | nn.Linear(emb_dim, emb_dim), 24 | nn.LayerNorm(emb_dim), 25 | nn.ReLU(), 26 | nn.Linear(emb_dim, out_dim) 27 | ) 28 | 29 | def forward(self, x): 30 | return self.model(x) 31 | 32 | 33 | class Transformer(nn.Module): 34 | 35 | def __init__(self, config): 36 | super().__init__() 37 | self.img_feature_embedding = nn.Linear(config.IMG_FEATURE_SIZE, config.EMBEDDING_DIM) 38 | self.prediction_embedding = nn.Linear(config.BOX_EMB_SIZE + config.NUM_CLASSES + 5, config.EMBEDDING_DIM) 39 | self.model = GPT(config) 40 | self.box_decoder = MLP(config.OUTPUT_SIZE, 256, 4, 3) 41 | self.logit_decoder = nn.Linear(config.OUTPUT_SIZE, config.NUM_CLASSES + 1) 42 | self.loss_decoder = MLP(config.OUTPUT_SIZE, 512, 1, 3) 43 | self.action_decoder = MLP(config.OUTPUT_SIZE, 512, 4, 3) 44 | self.action_tokens = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(1, 5, config.EMBEDDING_DIM), 45 | a=math.sqrt(5))) 46 | 47 | def forward(self, x): 48 | # fold data into sequence 49 | img_feature_embedding = self.img_feature_embedding(x["embedded_memory_features"].permute(0, 1, 3, 4, 2)) 50 | preds = torch.cat((x["box_features"], x["pred_logits"], x["pred_boxes"]), dim=-1) 51 | prediction_embeddings = self.prediction_embedding(preds) 52 | b, s, p, n = prediction_embeddings.shape 53 | n_preds = prediction_embeddings.shape[1] * prediction_embeddings.shape[2] 54 | seq = torch.cat((img_feature_embedding.reshape(b, -1, n), 55 | prediction_embeddings.reshape(b, -1, n), 56 | self.action_tokens.repeat(b, 1, 1).reshape(b, -1, n)), dim=1) 57 | y = self.model(seq) 58 | # unfold data 59 | y_preds = y[:, -(n_preds + 5):-5].reshape(b, s, p, -1) 60 | boxes = self.box_decoder(y_preds).sigmoid() 61 | logits = self.logit_decoder(y_preds) 62 | loss = self.loss_decoder(y_preds) 63 | actions = self.action_decoder(y[:, -5:-1].reshape(b, 4, -1)) 64 | 65 | return {"seq": y_preds.squeeze(), "pred_boxes": boxes.squeeze(), "pred_logits": logits.squeeze(), 66 | "loss": loss, "actions": actions.squeeze()} 67 | 68 | def get_optimizer_groups(self, train_config): 69 | # separate out all parameters to those that will and won't experience regularizing weight decay 70 | decay = set() 71 | no_decay = set() 72 | whitelist_weight_modules = (torch.nn.Linear, ) 73 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 74 | for mn, m in self.named_modules(): 75 | for pn, p in m.named_parameters(): 76 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 77 | 78 | if pn.endswith('bias'): 79 | # all biases will not be decayed 80 | no_decay.add(fpn) 81 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 82 | # weights of whitelist modules will 
be weight decayed 83 | decay.add(fpn) 84 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 85 | # weights of blacklist modules will NOT be weight decayed 86 | no_decay.add(fpn) 87 | 88 | # special case the position embedding parameter in the root GPT module as not decayed 89 | no_decay.add('model.pos_emb') 90 | no_decay.add('action_tokens') 91 | no_decay.add('model.seq_pos_embed') 92 | 93 | # validate that we considered every parameter 94 | param_dict = {pn: p for pn, p in self.named_parameters()} 95 | inter_params = decay & no_decay 96 | union_params = decay | no_decay 97 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 98 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 99 | % (str(param_dict.keys() - union_params), ) 100 | 101 | # create the pytorch optimizer object 102 | optim_groups = [ 103 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.WEIGHT_DECAY}, 104 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 105 | ] 106 | return optim_groups 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ai2thor>=3.3.4 2 | torch==1.9.0 3 | torchvision==0.10.0 4 | matplotlib==3.3.4 5 | tqdm==4.62.2 6 | pyyaml==5.4.1 7 | packaging==21.3 8 | scipy==1.8.0 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy 4 | from utils.config_utils import ( 5 | get_config, 6 | get_args, 7 | build_model, 8 | build_trainer, 9 | build_evaluator, 10 | ) 11 | 12 | 13 | def train(): 14 | # torch.use_deterministic_algorithms(True) 15 | random.seed(42) 16 | torch.manual_seed(42) 17 | torch.cuda.manual_seed(42) 18 | numpy.random.seed(42) 19 | args = get_args() 20 | cfg = get_config(args.config_file) 21 | model = build_model(cfg.MODEL) 22 | evaluator = build_evaluator(model, cfg) 23 | trainer = build_trainer(model, cfg, evaluator=evaluator) 24 | trainer.train() 25 | 26 | 27 | if __name__ == "__main__": 28 | train() 29 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/interactron/e94a1c3c7bd442708f4a3a6d8bc5f586f597a02d/utils/__init__.py -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import os 4 | 5 | 6 | ACTIONS = ["MoveAhead", "MoveBack", "RotateLeft", "RotateRight"] 7 | 8 | 9 | class Config: 10 | 11 | def __init__(self, **entries): 12 | objectefied_entires = {} 13 | for entrie, value in entries.items(): 14 | if type(value) is dict: 15 | objectefied_entires[entrie] = Config(**value) 16 | else: 17 | try: 18 | value = float(value) 19 | if value.is_integer(): 20 | value = int(value) 21 | except: 22 | pass 23 | objectefied_entires[entrie] = value 24 | self.__dict__.update(objectefied_entires) 25 | 26 | def dictionarize(self): 27 | fields = {} 28 | for k, v in self.__dict__.items(): 29 | if isinstance(v, Config): 30 | fields[k] = 
v.dictionarize() 31 | else: 32 | fields[k] = v 33 | return fields 34 | 35 | 36 | def get_config(cfg): 37 | assert os.path.exists(cfg), "File {} does not exist".format(cfg) 38 | with open(cfg) as f: 39 | args = yaml.safe_load(f) 40 | return Config(**args) 41 | 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser(description='Train Interactron Model') 45 | parser.add_argument('--config_file', type=str, required=True, 46 | help='path to the configuration file for this training run') 47 | parser.add_argument('--devices', type=str, default='cpu', help='device(s) to run training on (default: cpu)') 48 | 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | def build_model(args): 54 | arg_check(args.TYPE, ["detr", "detr_multiframe", "interactron_random", "interactron", "single_frame_baseline", 55 | "five_frame_baseline", "adaptive"], "model") 56 | if args.TYPE == "single_frame_baseline": 57 | from models.single_frame_baseline import SingleFrameBaselineModel 58 | model = SingleFrameBaselineModel(args) 59 | elif args.TYPE == "five_frame_baseline": 60 | from models.five_frame_baseline import FiveFrameBaselineModel 61 | model = FiveFrameBaselineModel(args) 62 | elif args.TYPE == "adaptive": 63 | from models.adaptive import AdaptiveModel 64 | model = AdaptiveModel(args) 65 | elif args.TYPE == "detr": 66 | from models.detr import detr 67 | model = detr(args) 68 | elif args.TYPE == "detr_multiframe": 69 | from models.detr_multiframe import detr_multiframe 70 | model = detr_multiframe(args) 71 | elif args.TYPE == "interactron_random": 72 | from models.interactron_random import interactron_random 73 | model = interactron_random(args) 74 | elif args.TYPE == "interactron": 75 | from models.interactron import interactron 76 | model = interactron(args) 77 | return model 78 | 79 | 80 | def build_trainer(model, args, evaluator=None): 81 | arg_check(args.TRAINER.TYPE, ["direct_supervision", "adaptive", "adaptive_interactive", "interactron_random", 82 | "interactron"], "supervisor") 83 | if args.TRAINER.TYPE == "direct_supervision": 84 | from engine.direct_supervision_trainer import DirectSupervisionTrainer 85 | trainer = DirectSupervisionTrainer(model, args, evaluator=evaluator) 86 | elif args.TRAINER.TYPE == "interactron_random": 87 | from engine.interactron_random_trainer import InteractronRandomTrainer 88 | trainer = InteractronRandomTrainer(model, args, evaluator=evaluator) 89 | elif args.TRAINER.TYPE == "interactron": 90 | from engine.interactron_trainer import InteractronTrainer 91 | trainer = InteractronTrainer(model, args, evaluator=evaluator) 92 | elif args.TRAINER.TYPE == "adaptive": 93 | from engine.adaptive_trainer import AdaptiveTrainer 94 | trainer = AdaptiveTrainer(model, args, evaluator=evaluator) 95 | elif args.TRAINER.TYPE == "adaptive_interactive": 96 | from engine.adaptive_interactive_trainer import AdaptiveInteractiveTrainer 97 | trainer = AdaptiveInteractiveTrainer(model, args, evaluator=evaluator) 98 | return trainer 99 | 100 | 101 | def build_evaluator(model, args, load_checkpoint=False): 102 | arg_check(args.EVALUATOR.TYPE, ["random_policy_evaluator", "interactive_evaluator", "every_path_evaluator"], 103 | "evaluator") 104 | if args.EVALUATOR.TYPE == "random_policy_evaluator": 105 | from engine.random_policy_evaluator import RandomPolicyEvaluator 106 | evaluator = RandomPolicyEvaluator(model, args, load_checkpoint=load_checkpoint) 107 | elif args.EVALUATOR.TYPE == "interactive_evaluator": 108 | from engine.interactive_evaluator import InteractiveEvaluator 109 |
evaluator = InteractiveEvaluator(model, args, load_checkpoint=load_checkpoint) 110 | elif args.EVALUATOR.TYPE == "every_path_evaluator": 111 | from engine.every_path_evaluator import EveryPathEvaluator 112 | evaluator = EveryPathEvaluator(model, args, load_checkpoint=load_checkpoint) 113 | return evaluator 114 | 115 | 116 | def arg_check(arg, list, argname): 117 | assert arg in list, "{} is not a valid {}. Please select one from {}".format(arg, argname, list) 118 | 119 | 120 | def iou(b1, b2): 121 | a1 = (b1[2] - b1[0]) * (b1[3] - b1[1]) 122 | a2 = (b2[2] - b2[0]) * (b2[3] - b2[1]) 123 | i = max(min(b1[2], b2[2]) - max(b1[0], b2[0]), 0) * max(min(b1[3], b2[3]) - max(b1[1], b2[1]), 0) 124 | u = a1 + a2 - i 125 | return i / u 126 | 127 | 128 | def compute_AP(precision, recall): 129 | p = precision 130 | r = recall 131 | return sum([r[0] * p[0]] + [(r[i]-r[i-1]) * ((p[i] + p[i-1])/2) for i in range(1, len(p))]) 132 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.tensorboard import SummaryWriter 3 | import numpy as np 4 | 5 | 6 | class TBLogger: 7 | 8 | def __init__(self, log_dir): 9 | self.writer = SummaryWriter(log_dir=log_dir) 10 | self.scalar_buffer = {} 11 | self.img_buffer = {} 12 | self.iter_counter = 0 13 | 14 | def add_value(self, name, value): 15 | assert any([isinstance(value, t) for t in [int, float, np.ndarray, torch.Tensor]]), \ 16 | "Invalid type {}. Only int, float, np.ndarray and torch.Tensor are accepted".format(type(value)) 17 | if isinstance(value, torch.Tensor): 18 | assert len(value.shape) == 0, \ 19 | "Got tensor of shape {}. Only single value tensors are valid.".format(value.shape) 20 | value = value.item() 21 | if name in self.scalar_buffer: 22 | self.scalar_buffer[name].append(value) 23 | else: 24 | self.scalar_buffer[name] = [value] 25 | 26 | def add_image(self, name, img): 27 | assert any([isinstance(img, t) for t in [torch.Tensor]]), \ 28 | "Invalid type {}. Only torch.Tensor are accepted".format(type(img)) 29 | self.img_buffer[name] = img 30 | 31 | def log_values(self): 32 | # log all scalars 33 | for name, values in self.scalar_buffer.items(): 34 | self.writer.add_scalar(name, np.mean(values), self.iter_counter) 35 | # log all images 36 | for name, value in self.img_buffer.items(): 37 | self.writer.add_image(name, value, self.iter_counter, dataformats='HWC') 38 | # clear out buffers 39 | self.scalar_buffer = {} 40 | self.img_buffer = {} 41 | self.iter_counter += 1 42 | -------------------------------------------------------------------------------- /utils/meta_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | 4 | 5 | def get_parameters(model): 6 | # get children form model! 7 | children = list(model.children()) 8 | flatt_children = [] 9 | if children == []: 10 | # if model has no children; model is last child! :O 11 | params = [] 12 | for name, param in model._parameters.items(): 13 | if param is not None and param.requires_grad: 14 | params.append(param) 15 | return params 16 | # return [child for child in model._parameters.values() if child.requires_grad] 17 | else: 18 | # look for children from children... to the last child! 
19 | for child in children: 20 | try: 21 | flatt_children.extend(get_parameters(child)) 22 | except TypeError: 23 | flatt_children.append(get_parameters(child)) 24 | return tuple(flatt_children) 25 | 26 | 27 | # def detach_parameters(model): 28 | # # get children form model! 29 | # children = list(model.children()) 30 | # flatt_children = [] 31 | # if children == []: 32 | # # if model has no children; model is last child! :O 33 | # for name, param in model._parameters.items(): 34 | # if param is not None and param.requires_grad: 35 | # model._parameters[name] = model._parameters[name].detach() 36 | # model._parameters[name].requires_grad = True 37 | # # return [child for child in model._parameters.values() if child.requires_grad] 38 | # else: 39 | # # look for children from children... to the last child! 40 | # for child in children: 41 | # try: 42 | # flatt_children.extend(detach_parameters(child)) 43 | # except TypeError: 44 | # flatt_children.append(detach_parameters(child)) 45 | # return flatt_children 46 | 47 | 48 | def detach_parameters(params): 49 | detached_params = [] 50 | for p in params: 51 | dp = p.clone().detach() 52 | dp.requires_grad = True 53 | detached_params.append(dp) 54 | return tuple(detached_params) 55 | 56 | 57 | def detach_gradients(params): 58 | detached_params = [] 59 | for p in params: 60 | if p is None: 61 | detached_params.append(None) 62 | else: 63 | dp = p.clone().detach() 64 | detached_params.append(dp) 65 | return tuple(detached_params) 66 | 67 | 68 | def set_parameters(model, params): 69 | # get children form model! 70 | if not isinstance(params, collections.abc.Iterator): 71 | params = iter(params) 72 | children = list(model.children()) 73 | flatt_children = [] 74 | if children == []: 75 | # if model has no children; model is last child! :O 76 | for name, param in model._parameters.items(): 77 | if param is not None and param.requires_grad: 78 | model._parameters[name] = next(params) 79 | else: 80 | # look for children from children... to the last child! 81 | for child in children: 82 | try: 83 | flatt_children.extend(set_parameters(child, params)) 84 | except TypeError: 85 | flatt_children.append(set_parameters(child, params)) 86 | return flatt_children 87 | 88 | 89 | # def clone_parameters(model, params): 90 | # # get children form model! 91 | # children = list(model.children()) 92 | # flatt_children = [] 93 | # if children == []: 94 | # # if model has no children; model is last child! :O 95 | # for name, param in model._parameters.items(): 96 | # if param is not None and param.requires_grad: 97 | # model._parameters[name] = next(params).clone() 98 | # else: 99 | # # look for children from children... to the last child! 100 | # for child in children: 101 | # try: 102 | # flatt_children.extend(clone_parameters(child, params)) 103 | # except TypeError: 104 | # flatt_children.append(clone_parameters(child, params)) 105 | # return flatt_children 106 | 107 | def clone_parameters(params): 108 | cloned_parameters = [] 109 | for param in params: 110 | cloned_parameters.append(param.clone()) 111 | return tuple(cloned_parameters) 112 | 113 | 114 | # def sgd_step(model, grads, lr=0.001): 115 | # # get children form model! 116 | # children = list(model.children()) 117 | # flatt_children = [] 118 | # if children == []: 119 | # # if model has no children; model is last child! 
:O 120 | # for name, param in model._parameters.items(): 121 | # if param is not None and param.requires_grad: 122 | # grad = next(grads) 123 | # if grad is not None: 124 | # model._parameters[name] = model._parameters[name] - lr * grad 125 | # else: 126 | # # look for children from children... to the last child! 127 | # for child in children: 128 | # try: 129 | # flatt_children.extend(sgd_step(child, grads, lr=lr)) 130 | # except TypeError: 131 | # flatt_children.append(sgd_step(child, grads, lr=lr)) 132 | # return flatt_children 133 | 134 | 135 | def sgd_step(params, grads, lr, clip=0.01): 136 | updated_params = [] 137 | for p, g in zip(params, grads): 138 | if g is None: 139 | updated_params.append(p) 140 | else: 141 | updated_params.append(p - torch.clip(lr * g, min=-clip, max=clip)) 142 | return tuple(updated_params) 143 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | def merge_batch_seq(x): 2 | return x.view(x.shape[0] * x.shape[1], *x.shape[2:]) 3 | 4 | 5 | def unmerge_batch_seq(x, b, s): 6 | return x.view(b, s, *x.shape[1:]) 7 | -------------------------------------------------------------------------------- /utils/storage_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Node: 5 | def __init__(self, cost=None, action=None): 6 | self.cost = cost 7 | self.action = action 8 | self.edges = [] 9 | 10 | def get_edges(self): 11 | return {e.value: e.b for e in self.edges} 12 | 13 | def add_edge(self, e): 14 | if e not in [x.value for x in self.edges]: 15 | self.edges.append(e) 16 | 17 | 18 | class Edge: 19 | def __init__(self, a, b, x): 20 | self.value = x 21 | self.a = a 22 | self.b = b 23 | 24 | 25 | class PathStorage: 26 | 27 | def __init__(self): 28 | self.root = Node(float('inf')) 29 | 30 | def add_path(self, path, ifga): 31 | curr = self.root 32 | for a in path: 33 | a = a.item() 34 | if a is None: 35 | print("wow") 36 | if ifga < curr.cost: 37 | curr.cost = ifga 38 | curr.action = a 39 | if a not in curr.get_edges(): 40 | curr.add_edge(Edge(curr, Node(float('inf')), a)) 41 | curr = curr.get_edges()[a] 42 | 43 | def get_label(self, path): 44 | actions = [] 45 | curr = self.root 46 | for a in path: 47 | a = a.item() 48 | actions.append(curr.action) 49 | curr = curr.get_edges()[a] 50 | return actions 51 | 52 | 53 | def collate_fn(batch): 54 | collated_batch = { 55 | 'frames': torch.stack([torch.stack(b['frames']) for b in batch]), 56 | "masks": torch.stack([torch.stack(b['masks']) for b in batch]), 57 | "actions": torch.stack([torch.tensor(b['actions'], dtype=torch.long) for b in batch]), 58 | "object_ids": [[torch.tensor(inst, dtype=torch.long) for inst in b['object_ids']] for b in batch], 59 | "category_ids": [[inst for inst in b['category_ids']] for b in batch], 60 | "boxes": [[inst for inst in b['boxes']] for b in batch], 61 | "episode_ids": torch.stack([torch.tensor(b['episode_ids'], dtype=torch.long) for b in batch]), 62 | "initial_image_path": [b['initial_image_path'] for b in batch], 63 | } 64 | return collated_batch 65 | -------------------------------------------------------------------------------- /utils/time_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer: 5 | 6 | def __init__(self): 7 | self.start_time = time.time() 8 | 9 | def tick(self, msg=""): 10 | end_time = time.time() 
11 | print("Tick:{}:{}".format(msg, end_time-self.start_time)) 12 | self.start_time = time.time() 13 | -------------------------------------------------------------------------------- /utils/transform_utis.py: -------------------------------------------------------------------------------- 1 | import models.detr_models.util.transforms as T 2 | import torchvision.transforms as TV 3 | 4 | 5 | transform = T.Compose([ 6 | T.RandomResize([300], max_size=300), 7 | T.Compose([ 8 | T.ToTensor(), 9 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 10 | ]) 11 | ]) 12 | 13 | train_transform = T.Compose([ 14 | T.RandomHorizontalFlip(), 15 | T.RandomResize([400, 500, 600]), 16 | T.RandomSizeCrop(300, 300), 17 | T.RandomResize([300], max_size=300), 18 | T.Compose([ 19 | T.ToTensor(), 20 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 21 | ]) 22 | ]) 23 | 24 | 25 | inv_transform = TV.Compose([ 26 | TV.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), 27 | TV.Normalize([-0.485, -0.456, -0.406], [1., 1., 1.,]), 28 | TV.ToPILImage() 29 | ]) 30 | -------------------------------------------------------------------------------- /utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def draw_pr_curve(points, mAP, path): 8 | plt.plot([x["recall"] for x in points], [x["precision"] for x in points]) 9 | plt.xlabel("Recall") 10 | plt.ylabel("Precision") 11 | plt.title("Precision Recall Curve for IOU=0.5, mAP={:.4f}".format(mAP)) 12 | plt.savefig(path) 13 | 14 | 15 | def draw_prediction_distribuion(tp, fp, path): 16 | plt.style.use('seaborn-deep') 17 | x = [p["confidence"] for p in tp] 18 | y = [p["confidence"] for p in fp] 19 | bins = np.linspace(0.0, 1.0, num=10) 20 | plt.hist([x, y], bins, label=['True Positives', 'False Positives']) 21 | plt.legend(loc='upper right') 22 | plt.title("Distribution of True Positive and False Positive Distribution Confidence") 23 | plt.xlabel("Confidence") 24 | plt.ylabel("Number of Predictions") 25 | plt.savefig(path) 26 | 27 | 28 | def draw_preds_and_labels(images, preds, labels): 29 | img = images.get_images()[0, 0].detach().cpu().numpy() 30 | preds = preds.get_boxes()[0, 0] 31 | matched_labels = labels.get_matched_boxes()[0, 0] 32 | gt_labels = labels.get_boxes()[0, 0] 33 | for i in range(matched_labels.shape[0]): 34 | img = draw_box(img, matched_labels[i], (255, 0, 0), 2) 35 | for i in range(gt_labels.shape[0]): 36 | # img = draw_box(img, preds[i], (0, 0, 0), 1) 37 | img = draw_box(img, gt_labels[i], (0, 255, 0), 1) 38 | return torch.tensor(img) 39 | 40 | 41 | def draw_box(img, box, color, thickness, label=False): 42 | b = box 43 | img = cv2.rectangle(img, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), color, thickness) 44 | if label: 45 | img = cv2.putText(img, label, (int(b[0]), int(b[1])), cv2.FONT_HERSHEY_SIMPLEX, 0.4, 46 | color, 1) 47 | return img 48 | 49 | --------------------------------------------------------------------------------